Variational distribution for a hierarchical model

Hey everyone,

I am trying to implement a hierarchical logistic regression model.
The model definition is as follows:

import tensorflow as tf
import edward as ed
from edward.models import Bernoulli, Cauchy, Normal

# Hyperpriors
loc_w = Normal(loc=tf.zeros(D), scale=3.0 * tf.ones(D))
# NB: a Cauchy has support on all of R, so it can produce negative scales;
# a positive prior (e.g. a half-Cauchy) would be the usual choice here.
scale_w = Cauchy(loc=tf.zeros(D), scale=3.0 * tf.ones(D))

loc_b = Normal(loc=tf.zeros([]), scale=3.0 * tf.ones([]))
scale_b = Cauchy(loc=tf.zeros([]), scale=3.0 * tf.ones([]))

# Model variables
x = tf.placeholder(tf.float32, [None, D])
group = tf.placeholder(tf.int32, [None])  # group index per data point, values 0-3 for 4 groups

w = Normal(loc=tf.ones([num_groups, D]) * loc_w, scale=tf.ones([num_groups, D]) * scale_w)
b = Normal(loc=tf.ones(num_groups) * loc_b, scale=tf.ones(num_groups) * scale_b)

# tf.gather(w, group) has shape (N, D), so the per-row dot product needs a
# reduce_sum over the feature axis (ed.dot expects a 1-D second argument).
y = Bernoulli(logits=tf.reduce_sum(x * tf.gather(w, group), axis=1) + tf.gather(b, group))

As you can see, the weights form a matrix with one row per group: with 3 features and 4 groups, for example, I get a 4x3 matrix. When I compute the output as a Bernoulli RV, I select the weight row matching each data point's group.
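For concreteness, here is a minimal sketch of the shapes involved in that gather (toy values, all names hypothetical):

import numpy as np
import tensorflow as tf

# Toy sizes: 5 data points, D = 3 features, 4 groups.
w_val = np.arange(12, dtype=np.float32).reshape(4, 3)  # one weight row per group
g_val = np.array([0, 2, 1, 3, 0], dtype=np.int32)      # group index per data point

rows = tf.gather(tf.constant(w_val), tf.constant(g_val))
with tf.Session() as sess:
    print(sess.run(rows).shape)  # (5, 3): the matching weight row for each point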
Now, my problem is: how do I specify this structure in the variational distribution?

The code currently looks like this:

# Variational inference
# Location of weights
qloc_w = Normal(loc=tf.get_variable("qw_loc_loc", [D]),
                scale=tf.nn.softplus(tf.get_variable("qw_loc_scale", [D])))  # softplus keeps the scale positive
# Scale of weights
qscale_w = Normal(loc=tf.get_variable("qw_scale_loc", [D]),
                  scale=tf.nn.softplus(tf.get_variable("qw_scale_scale", [D])))
# Location of bias (loc_b is a scalar, so its variational parameters are scalars too)
qloc_b = Normal(loc=tf.get_variable("qb_loc_loc", []),
                scale=tf.nn.softplus(tf.get_variable("qb_loc_scale", [])))
# Scale of bias
qscale_b = Normal(loc=tf.get_variable("qb_scale_loc", []),
                  scale=tf.nn.softplus(tf.get_variable("qb_scale_scale", [])))
# Weights
qw = Normal(loc=tf.get_variable("qw_loc", [num_groups, D]),
            scale=tf.nn.softplus(tf.get_variable("qw_scale", [num_groups, D])))
# Bias (b has shape (num_groups,), so qb has to match it exactly)
qb = Normal(loc=tf.get_variable("qb_loc", [num_groups]),
            scale=tf.nn.softplus(tf.get_variable("qb_scale", [num_groups])))

inference = ed.ReparameterizationKLqp({w: qw, b: qb,
                                       loc_w: qloc_w, scale_w: qscale_w,
                                       loc_b: qloc_b, scale_b: qscale_b},
                                      data={x: X_train, y: y_train, group: g_train})
inference.initialize(n_print=10, n_iter=600)
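After this I run the usual Edward update loop, roughly (a sketch of the standard pattern, nothing specific to this model):

sess = ed.get_session()
tf.global_variables_initializer().run()
for _ in range(inference.n_iter):
    info_dict = inference.update()  # one gradient step on the ELBO
    inference.print_progress(info_dict)
inference.finalize()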

This throws an error saying that the latent variables do not have the same shape: the weights have shape (4, 3) in the model but (3,) in the variational distribution.
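As a sanity check (a minimal sketch, assuming the variables defined above), the model and variational shapes can be compared pairwise; each pair should be identical:

# Print each model RV's shape next to its variational counterpart's shape.
latent_vars = {w: qw, b: qb, loc_w: qloc_w, scale_w: qscale_w,
               loc_b: qloc_b, scale_b: qscale_b}
for rv, qrv in latent_vars.items():
    print(rv.shape, qrv.shape)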

Have you figured things out?