Multidimensional MDN

Hi,

I’m trying to extend the MDN example to the case where both the input and the output are multidimensional.
Currently I’m getting an exception about a shape mismatch.
Can you provide a simple example that works?

Thanks a lot

Hi,

To move this question along: so far I’ve found that when I run the following command:
o = Mixture(cat=cat, components=components, value=tf.zeros_like(o_ph))

the Edward infrastructure expects the batch shapes of cat and components[0] to be the same.

While cat has a single batch dimension:
TensorShape([Dimension(None)])
(None is the number of sample points),

components[0] in my case is a Normal distribution whose batch_shape is:
TensorShape([Dimension(None), Dimension(2)])
(2 is the dimension of my output).

When creating the Mixture, I get the error:
ValueError: Shapes (?,) and (?, 2) are not compatible
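
For concreteness, a minimal sketch that reproduces the mismatch (D, K, and the placeholder shapes here are made up for illustration):

import tensorflow as tf
from edward.models import Categorical, Mixture, Normal

D = 2  # output dimension
K = 3  # number of mixture components

o_ph = tf.placeholder(tf.float32, [None, D])
logits = tf.placeholder(tf.float32, [None, K])  # stand-in for the network output

cat = Categorical(logits=logits)
components = [Normal(loc=tf.zeros_like(o_ph), scale=tf.ones_like(o_ph))
              for _ in range(K)]

print(cat.batch_shape)            # (?,)   -- one category draw per sample point
print(components[0].batch_shape)  # (?, 2) -- the output axis is treated as a batch axis
print(components[0].event_shape)  # ()     -- i.e. a univariate normal

# This line then raises the ValueError above:
o = Mixture(cat=cat, components=components, value=tf.zeros_like(o_ph))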

So, why does Mixture expect cat and the components to share a batch shape?
And how can I fix it?

Thanks

Hi, I was having the same problem, and your forum post was the only other mention of it I could find.

Anyway, I’ve since figured it out by going through the tf.contrib.distributions.Mixture code. Apparently distributions infer something about what the different axes you’re working with mean. So if you use a Normal distribution (even if you feed it a loc/scale with two axes), it assumes you want a UNIvariate normal distribution and that you simply have two batch dimensions.

This becomes a problem later on when you build the mixture: the mixture sees that the categorical distribution has a batch shape of [None] (and an event shape of []) while the Gaussians have a batch shape of [None, 2] (and an event shape of []). What you want it to think is that the categorical distribution has a batch shape of [None] and an event shape of [], and that the Gaussians have a batch shape of [None] and an event shape of [2].

To do this, you have to use a different class, namely ed.models.MultivariateNormalDiag. This gives you the same distribution mathematically as before, but it treats the last axis as an event (or depth, whatever you want to call it) axis rather than a batch axis.
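
For example, here’s a minimal sketch of the shape difference (D, K, and the placeholder shapes are made up for illustration):

import tensorflow as tf
from edward.models import Categorical, Mixture, MultivariateNormalDiag

D = 2  # output (event) dimension
K = 3  # number of mixture components

y_ph = tf.placeholder(tf.float32, [None, D])
logits = tf.placeholder(tf.float32, [None, K])  # stand-in for the network output

cat = Categorical(logits=logits)
components = [MultivariateNormalDiag(loc=tf.zeros_like(y_ph),
                                     scale_diag=tf.ones_like(y_ph))
              for _ in range(K)]

print(components[0].batch_shape)  # (?,) -- now matches cat’s batch shape
print(components[0].event_shape)  # (2,) -- the last axis is the event axis

y = Mixture(cat=cat, components=components, value=tf.zeros_like(y_ph))  # no shape error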

Hope that helps others with the same problem.

Thanks @kkleidal. I tried MultivariateNormalDiag in place of Normal and the shape error disappears. However, I’m now getting a confusing error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-ea93ed6724dc> in <module>()
     15 components = [MultivariateNormalDiag(loc=loc, scale=scale) for loc, scale
     16               in zip(tf.unstack(tf.transpose(locs)),
---> 17                      tf.unstack(tf.transpose(scales)))]
     18 y = Mixture(cat=cat, components=components, value=tf.zeros_like(y_ph))
     19 # Note: A bug exists in Mixture which prevents samples from it to have

<ipython-input-3-ea93ed6724dc> in <listcomp>(.0)
     13 
     14 cat = Categorical(logits=logits)
---> 15 components = [MultivariateNormalDiag(loc=loc, scale=scale) for loc, scale
     16               in zip(tf.unstack(tf.transpose(locs)),
     17                      tf.unstack(tf.transpose(scales)))]

~/anaconda3/lib/python3.6/site-packages/edward/models/random_variables.py in __init__(self, *args, **kwargs)
     19     # to use _candidate's docstring, must write a new __init__ method
     20     def __init__(self, *args, **kwargs):
---> 21       _RandomVariable.__init__(self, *args, **kwargs)
     22     __init__.__doc__ = _candidate.__init__.__doc__
     23     _params = {'__doc__': _candidate.__doc__,

~/anaconda3/lib/python3.6/site-packages/edward/models/random_variable.py in __init__(self, *args, **kwargs)
    110       self._kwargs['collections'] = collections
    111 
--> 112     super(RandomVariable, self).__init__(*args, **kwargs)
    113 
    114     self._sample_shape = tf.TensorShape(sample_shape)

TypeError: __init__() got an unexpected keyword argument 'scale'

Any input would be greatly appreciated - this is the only relevant thread I could find anywhere.
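
For reference, in case others hit this TypeError: it looks like a keyword-name issue. Edward’s MultivariateNormalDiag wraps tf.contrib.distributions.MultivariateNormalDiag, whose diagonal standard-deviation argument is named scale_diag, not scale. So a likely fix (assuming the shapes of locs and scales are otherwise right) is:

components = [MultivariateNormalDiag(loc=loc, scale_diag=scale)  # scale_diag, not scale
              for loc, scale in zip(tf.unstack(tf.transpose(locs)),
                                    tf.unstack(tf.transpose(scales)))]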