Hi guys. I’m trying to recreate the Bernoulli mixture model over binarized MNIST digits from Section 9.3.3 of Bishop’s textbook in Edward, but I’m struggling with the ParamMixture class.

So far, I have:

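import tensorflow as tf
from edward.models import Bernoulli, Beta, Dirichlet, ParamMixture

N = 10000  # number of data points
K = 10  # number of mixture components
D = 28 * 28  # pixels per image
pi = Dirichlet(tf.ones(K), sample_shape=D)  # shape [D, K]
mu = Beta(tf.ones(D), tf.ones(D), sample_shape=K)  # shape [K, D]
x = ParamMixture(pi, {'probs': mu}, Bernoulli, sample_shape=N)
z = x.cat  # component assignments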

This allows me to define a ParamMixture with the right number of dimensions. However, I get an error during training: Incompatible shapes: [10000,10,784] vs. [10000,784,10]
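To be concrete about the shapes I'm aiming for, here is a plain-Python sketch of the generative process as I understand it from Bishop. No Edward is involved. The function name `sample_bernoulli_mixture`, the flat Dirichlet/Beta priors, and the small N are just my choices for illustration, so treat this as a sketch rather than a reference implementation.

```python
import random


def sample_bernoulli_mixture(n, k, d, seed=0):
    """Draw one dataset from a Bernoulli mixture with flat priors (illustrative only)."""
    rng = random.Random(seed)

    # Mixing weights: a single Dirichlet(1, ..., 1) draw over the k components,
    # obtained by normalizing Gamma(1, 1) draws -> list of length k.
    gammas = [rng.gammavariate(1.0, 1.0) for _ in range(k)]
    total = sum(gammas)
    pi = [g / total for g in gammas]

    # Component means: one Beta(1, 1) draw per component and pixel -> k rows of length d.
    mu = [[rng.betavariate(1.0, 1.0) for _ in range(d)] for _ in range(k)]

    # Assignments: one component index per data point -> list of length n.
    z = rng.choices(range(k), weights=pi, k=n)

    # Observations: pixel i of point j is Bernoulli(mu[z[j]][i]) -> n rows of length d.
    x = [[int(rng.random() < p) for p in mu[zj]] for zj in z]
    return pi, mu, z, x


pi, mu, z, x = sample_bernoulli_mixture(n=100, k=10, d=28 * 28)
print(len(pi), (len(mu), len(mu[0])), len(z), (len(x), len(x[0])))
```

So the mixing weights have one entry per component, the means are K by D, the assignments are one per data point, and the data are N by D.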

If I change any of the parameter shapes, ParamMixture itself complains; if I keep a ParamMixture that builds, I get the shape error above during inference.

In short: does anyone have an example of how to create a multi-dimensional Bernoulli Mixture Model? Any help would be greatly appreciated!