Problems when building a variational model with a Gaussian mixture prior rather than a normal prior

Hello!

I have been studying the Bayesian neural network (BNN) model and noticed that the prior on its weights is a standard normal. I would like to use a multimodal prior instead, so that I can explore multimodality in the BNN parameters.

Unfortunately, when I change the prior to a mixture of Gaussians, I get the error below.

ValueError                                Traceback (most recent call last)
<ipython-input-5-9bb61b223807> in <module>()
     11     W_0 = edm.ParamMixture(probs, {'loc': mu, 'scale_diag': tf.sqrt(sigmasq)},
     12                  edm.MultivariateNormalDiag,
---> 13                  sample_shape=N)
...
...
ValueError: Dimensions must be equal, but are 50 and 2 for 'W_0_1/ParamMixture/sample/mul' (op: 'Mul') with input shapes: [50,2,1,2], [2,50,1,1].
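
To check that I am reading the message correctly, here is a small plain-Python sketch of the usual broadcasting rule (align shapes from the last axis; each pair of sizes must match or one of them must be 1). It is not Edward or TensorFlow code, and `broadcast_shape` is a hypothetical helper I wrote only for this illustration. The two shapes are copied from the error message.

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of a and b, or None if they are incompatible."""
    out = []
    # Align from the trailing axis; missing leading axes count as size 1.
    for i in range(1, max(len(a), len(b)) + 1):
        x = a[-i] if i <= len(a) else 1
        y = b[-i] if i <= len(b) else 1
        if x != y and x != 1 and y != 1:
            return None  # sizes differ and neither is 1
        out.append(max(x, y))
    return tuple(reversed(out))


# Shapes copied from the ValueError above.
lhs = (50, 2, 1, 2)
rhs = (2, 50, 1, 1)
print(broadcast_shape(lhs, rhs))  # None

# For comparison: the same right-hand shape with its first two axes swapped.
print(broadcast_shape(lhs, (50, 2, 1, 1)))  # (50, 2, 1, 2)
```

So the two operands of the 'Mul' op seem to carry the axes of size 50 and 2 in opposite order, although I do not know which part of my code causes that.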

My code is:

K = 2
with tf.variable_scope("W_0"):
    probs = edm.Dirichlet(tf.ones(K))
    mu = edm.Normal(tf.zeros([D, 2]), tf.ones([D, 2]), sample_shape=K)
    sigmasq = edm.InverseGamma(tf.ones([D, 2]), tf.ones([D, 2]), sample_shape=K)
    W_0 = edm.ParamMixture(probs, {'loc': mu, 'scale_diag': tf.sqrt(sigmasq)},
                 edm.MultivariateNormalDiag,
                 sample_shape=N)

It is adapted from the tutorial.
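
To make the parameter shapes in my code concrete, here is a plain-Python sketch (again not Edward itself) of how I understand a random variable's shape to be assembled: sample shape first, then batch shape, then event shape. `rv_shape` is a hypothetical helper, and D = 3 is an arbitrary value chosen only for illustration.

```python
def rv_shape(sample_shape, batch_shape, event_shape=()):
    """Shape of a draw: sample_shape + batch_shape + event_shape."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


K = 2  # number of mixture components, as in my code
D = 3  # hypothetical value, only for illustration

# My code: a Normal with loc of shape [D, 2] and sample_shape=K.
mu_matrix = rv_shape((K,), (D, 2))
# For comparison: a Normal with loc of shape [D] and sample_shape=K.
mu_vector = rv_shape((K,), (D,))

print(mu_matrix)  # (2, 3, 2)
print(mu_vector)  # (2, 3)
```

If this is right, mu and sigmasq in my code have three axes, [K, D, 2], rather than two, [K, D]. I am not sure whether ParamMixture with MultivariateNormalDiag can handle that extra axis, or whether this is related to the error.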

Thanks.