The issue seemed to be how I was handling batch sampling from the multinomial distribution.
The following code snippet, which overrides `Multinomial._sample_n`, seemed to fix the problem.
def _sample_n(self, n=1, seed=None):
    """Draw ``n`` samples from this Multinomial on the host via NumPy.

    Replacement for ``Multinomial._sample_n``. Sampling is done with
    ``numpy.random.RandomState.multinomial`` inside a ``tf.py_func`` op,
    one (broadcast) batch member at a time, so batched ``probs`` and
    ``total_count`` are supported.

    Args:
      n: Number of samples to draw (Python int or scalar int tensor). This
        is the *sample count*; the number of trials in each sample comes
        from ``self.total_count``.
      seed: Optional integer seed forwarded to ``numpy.random.RandomState``.
        NOTE: the generator is re-created every time the op executes, so a
        fixed seed yields the same draw on every run of the op.

    Returns:
      Tensor of dtype ``self.dtype`` with shape
      ``[n] + batch_shape + event_shape``.
    """
    np_dtype = self.dtype.as_numpy_dtype

    def np_sample(total_count, probs, num_samples):
        # Host-side sampler; receives the op inputs as NumPy values.
        probs = np.asarray(probs, dtype=np.float64)
        num_samples = int(num_samples)
        num_classes = probs.shape[-1]
        counts = np.asarray(total_count)
        # Batch shape = broadcast of total_count's shape with probs[..., 0].
        batch_shape = np.broadcast(counts, probs[..., 0]).shape
        flat_counts = np.broadcast_to(counts, batch_shape).reshape(-1)
        flat_probs = np.broadcast_to(
            probs, batch_shape + (num_classes,)).reshape(-1, num_classes)
        rng = np.random.RandomState(seed)
        out = np.zeros((num_samples, flat_counts.size, num_classes),
                       dtype=np_dtype)
        for i, (count, pvals) in enumerate(zip(flat_counts, flat_probs)):
            # Renormalize in float64: float32 probs can sum slightly above 1,
            # which RandomState.multinomial rejects.
            # Assumes total_count holds integer values (truncated by int()).
            out[:, i, :] = rng.multinomial(
                int(count), pvals / pvals.sum(), size=num_samples)
        return out.reshape((num_samples,) + batch_shape + (num_classes,))

    # Wrap the Python sampler as a TensorFlow op; its output has unknown shape.
    val = tf.py_func(
        np_sample, [self.total_count, self.probs, n], [self.dtype])[0]
    # Restore the shape [n] + batch_shape + event_shape. The *_tensor() forms
    # also work when the static shapes are not fully defined.
    shape = tf.concat(
        [tf.expand_dims(n, 0),
         self.batch_shape_tensor(),
         self.event_shape_tensor()], 0)
    return tf.reshape(val, shape)
# Monkey-patch the replacement sampler onto the Multinomial class so every
# instance's sampling goes through the _sample_n defined above.
setattr(Multinomial, "_sample_n", _sample_n)