Multidimensional MDN



I’m trying to extend the MDN example to the case where both the input and the output are multidimensional.
Currently I’m getting an exception about a shape mismatch.
Can you provide a simple example that works?

Thanks a lot



To follow up on this question: so far I’ve found that the problem appears when I run the following command:
o = Mixture(cat=cat, components=components, value=tf.zeros_like(o_ph))

Edward expects the batch shapes of cat and components[0] to be the same.

cat has a batch shape with a single dimension of size None
(None is the number of sample points).

components[0] in my case is a Normal distribution, and its batch_shape is:
TensorShape([Dimension(None), Dimension(2)])
(2 is the current dimension of my output)

When creating the Mixture, I get this error:
ValueError: Shapes (?,) and (?, 2) are not compatible

So why does Mixture expect cat and the components to share the same batch shape?
And how can I fix it?



Hi, I was having this same problem and your forum post was the only other mention of the problem I could find.

Anyway, I’ve since figured it out by going through the tf.contrib.distributions.Mixture code. Distributions infer what each axis of their parameters means. If you use a Normal distribution, even one fed a loc/scale with two axes, it assumes you want a univariate normal distribution and treats both axes as batch dimensions.

This becomes a problem when you build the mixture: it sees that the categorical distribution has a batch shape of [None] (and an event shape of []), while the Gaussians have a batch shape of [None, 2] (and an event shape of []). What you want is for the categorical distribution to have a batch shape of [None] and an event shape of [], and for the Gaussians to have a batch shape of [None] and an event shape of [2].

To get that, use a different class, namely ed.models.MultivariateNormalDiag. It gives the same distribution mathematically as you had before, but it treats the last axis as an event (or depth, whatever you want to call it) axis rather than a batch axis.
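
To make the batch/event distinction concrete, here is a minimal NumPy sketch of a mixture log-density. It is not Edward's or TensorFlow's actual implementation, only an illustration of the shape bookkeeping. The function names diag_normal_log_prob and mixture_log_prob and the sizes N, K, D are made up for this example. The point is that each component must return one log-density per sample (shape [N]) so it lines up with the mixing weights, which is what summing over the event axis achieves.

```python
import numpy as np


def diag_normal_log_prob(y, loc, scale):
    """Log-density of a diagonal Gaussian. y, loc, scale have shape [N, D]."""
    # Per-dimension univariate log-densities, shape [N, D]. Stopping here is
    # what a univariate Normal with batch shape [N, D] would give you.
    per_dim = (-0.5 * np.log(2.0 * np.pi) - np.log(scale)
               - 0.5 * ((y - loc) / scale) ** 2)
    # Treat the last axis as the event axis: sum it out, leaving shape [N].
    return per_dim.sum(axis=-1)


def mixture_log_prob(y, logits, locs, scales):
    """Mixture log-density, shape [N].

    y:            [N, D] targets
    logits:       [N, K] unnormalised mixing weights
    locs, scales: [N, K, D] per-component parameters
    """
    num_components = logits.shape[1]
    # Numerically stable log-softmax over components, shape [N, K].
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_pi = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # One [N] vector per component, stacked into [N, K].
    comp = np.stack(
        [diag_normal_log_prob(y, locs[:, k], scales[:, k])
         for k in range(num_components)],
        axis=1)
    # log-sum-exp over the component axis.
    joint = log_pi + comp
    m = joint.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(joint - m).sum(axis=1, keepdims=True)))[:, 0]


# Illustrative sizes: 5 samples, 3 mixture components, 2 output dimensions.
rng = np.random.default_rng(0)
N, K, D = 5, 3, 2
y = rng.normal(size=(N, D))
logits = rng.normal(size=(N, K))
locs = rng.normal(size=(N, K, D))
scales = np.exp(rng.normal(size=(N, K, D)))  # strictly positive

log_p = mixture_log_prob(y, logits, locs, scales)
print(log_p.shape)
```

If diag_normal_log_prob skipped the final sum, comp would have shape [N, K, D] and could no longer be added to log_pi of shape [N, K], which is the same kind of mismatch as the (?,) versus (?, 2) error above.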

Hope that helps others with the same problem.