Multinomial-Logistic Normal Model, Probably My Fault

So the issue seems to have been how I was handling batch sampling from the multinomial.

The following code snippet seemed to fix the problem.

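import numpy as np
import tensorflow as tf
from scipy.stats import multinomial
from edward.models import Multinomial  # assumption: Edward's Multinomial is the class being patched


def _sample_n(self, n=1, seed=None):
    # Python function that returns samples as a NumPy array.
    # Note: in scipy's rvs, `n` is the number of trials per draw
    # (the number of draws is controlled by `size`).
    def np_sample(p, n):
        return multinomial.rvs(p=p, n=n, random_state=seed).astype(np.float32)

    # Wrap the Python function as a TensorFlow op.
    val = tf.py_func(np_sample, [self.probs, n], [tf.float32])[0]
    # py_func output has unknown shape; reshape to [n] + batch_shape + event_shape.
    batch_event_shape = self.batch_shape.concatenate(self.event_shape)
    shape = tf.concat(
        [tf.expand_dims(n, 0), tf.convert_to_tensor(batch_event_shape)], 0)
    val = tf.reshape(val, shape)
    return val


# Patch the sampler onto the Multinomial class.
Multinomial._sample_n = _sample_n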
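
For anyone checking their own shapes: here is a rough, plain-Python sketch of what batch sampling from a multinomial should return, namely n draws for every probability vector in the batch, laid out as [n][batch][event]. It is only an illustration, not the Edward code; the function name `sample_multinomial_batch` and the `total_count` parameter are made up for this example.

```python
import random


def sample_multinomial_batch(probs, total_count, n, seed=None):
    """Draw n multinomial count vectors for every row of probs.

    probs: list of probability vectors, one per batch member.
    total_count: number of trials per draw (hypothetical parameter name).
    Returns a nested list indexed as [sample][batch][event].
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        draw = []
        for p in probs:
            counts = [0] * len(p)
            # Each trial picks one category k with probability p[k].
            for k in rng.choices(range(len(p)), weights=p, k=total_count):
                counts[k] += 1
            draw.append(counts)
        samples.append(draw)
    return samples


# Two batch members, three categories each, four draws per member.
probs = [[0.2, 0.3, 0.5], [0.9, 0.05, 0.05]]
samples = sample_multinomial_batch(probs, total_count=10, n=4, seed=0)
```

Each innermost vector sums to `total_count`, and the outer dimension is the number of samples, which is the layout the reshape in the snippet above is aiming for.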