Variational Distribution for hierarchical model


Hey everyone,

I am trying to implement a hierarchical logistic regression model.
The model definition is as follows:

import edward as ed
import tensorflow as tf
from edward.models import Bernoulli, Cauchy, Normal

# Hyperpriors
loc_w = Normal(loc=tf.zeros(D), scale=3.0 * tf.ones(D))
scale_w = Cauchy(loc=tf.zeros(D), scale=3.0 * tf.ones(D))

loc_b = Normal(loc=tf.zeros([]), scale=3.0 * tf.ones([]))
scale_b = Cauchy(loc=tf.zeros([]), scale=3.0 * tf.ones([]))
# Model variables
x = tf.placeholder(tf.float32, [None, D])
group = tf.placeholder(tf.int32, [None])  # A vector of group indices with values 0-3 for 4 groups

w = Normal(loc=tf.ones([num_groups, D]) * loc_w, scale=tf.ones([num_groups, D]) * scale_w)
b = Normal(loc=tf.ones(num_groups) * loc_b, scale=tf.ones(num_groups) * scale_b)

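# Row-wise dot product of each datapoint with its group's weights, plus the group bias.
# (The first part of this expression was lost in the post; tf.reduce_sum(x * ...) is an assumed reconstruction.)
y = Bernoulli(logits=tf.reduce_sum(x * tf.gather(w, group), axis=1) + tf.gather(b, group))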

As you can see, the weights form a matrix with one row per group; for example, 3 features and 4 groups give a 4x3 matrix. When I compute the output as a Bernoulli random variable, I select the weights (and the bias) that belong to the group of the current datapoint.
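To make the per-group selection concrete, here is a minimal NumPy sketch of what the gather step does. NumPy stands in for TensorFlow here, and the data, the number of datapoints, and the names w_per_point and b_per_point are made up for illustration; only num_groups = 4 and D = 3 come from the example above.

```python
import numpy as np

num_groups, D = 4, 3                      # sizes from the example in the post
rng = np.random.default_rng(0)

w = rng.normal(size=(num_groups, D))      # one row of weights per group
b = rng.normal(size=num_groups)           # one bias per group
x = rng.normal(size=(5, D))               # 5 hypothetical datapoints
group = np.array([0, 3, 1, 1, 2])         # group index of each datapoint

# Integer indexing picks one row of w per datapoint, like tf.gather(w, group).
w_per_point = w[group]                    # shape (5, D)
b_per_point = b[group]                    # shape (5,)

# Row-wise dot product of each datapoint with its own group's weights.
logits = np.sum(x * w_per_point, axis=1) + b_per_point

print(w_per_point.shape, logits.shape)
```

The point of the sketch is only the shapes: the weight matrix stays (num_groups, D), while the gathered weights and the logits have one entry per datapoint.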
My question is: how do I specify this in the variational distribution?

The code currently looks like this:

# Variational Inference
# Location of weights
qloc_w = Normal(loc=tf.get_variable("qw_loc_loc", [D]),
                scale=tf.nn.softplus(tf.get_variable("qw_loc_scale", [D])))
# Scale of weights
qscale_w = Normal(loc=tf.get_variable("qw_scale_loc", [D]),
                  scale=tf.nn.softplus(tf.get_variable("qw_scale_scale", [D])))
# Location of bias
qloc_b = Normal(loc=tf.get_variable("qb_loc_loc", [1]),
                scale=tf.nn.softplus(tf.get_variable("qb_loc_scale", [1])))
# Scale of bias
qscale_b = Normal(loc=tf.get_variable("qb_scale_loc", [1]),
                  scale=tf.nn.softplus(tf.get_variable("qb_scale_scale", [1])))
# Weights
qw = Normal(loc=tf.get_variable("qw_loc", [num_groups, D]),
            scale=tf.nn.softplus(tf.get_variable("qw_scale", [num_groups, D])))
# Bias
qb = Normal(loc=tf.get_variable("qb_loc", [num_groups, 1]),
            scale=tf.nn.softplus(tf.get_variable("qb_scale", [num_groups, 1])))

inference = ed.ReparameterizationKLqp(
    {w: qw, b: qb, loc_w: qloc_w, scale_w: qscale_w, loc_b: qloc_b, scale_b: qscale_b},
    data={x: X_train, y: y_train, group: g_train})
inference.initialize(n_print=10, n_iter=600)

This throws an error saying that the latent variables do not have the same shape: the weights are (4, 3) in the posterior distribution but (3,) in the variational one.
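One way to narrow this down is to write out the shape of every latent variable next to the shape of its variational counterpart and compare them pair by pair. The sketch below does that in plain Python, with the shapes read off the code as posted and assuming num_groups = 4 and D = 3. The helper shape_mismatches and the two dictionaries are hypothetical names for illustration, not part of Edward, and plain tuple equality is used as a stand-in for the library's own shape check.

```python
# Shapes read off the model and variational code above,
# assuming num_groups = 4 and D = 3 as in the example.
num_groups, D = 4, 3

model_shapes = {
    "w": (num_groups, D),
    "b": (num_groups,),
    "loc_w": (D,),
    "scale_w": (D,),
    "loc_b": (),
    "scale_b": (),
}
variational_shapes = {
    "w": (num_groups, D),     # qw
    "b": (num_groups, 1),     # qb
    "loc_w": (D,),            # qloc_w
    "scale_w": (D,),          # qscale_w
    "loc_b": (1,),            # qloc_b
    "scale_b": (1,),          # qscale_b
}


def shape_mismatches(model, variational):
    """Return (name, model_shape, variational_shape) for every pair that differs."""
    return [(name, shape, variational[name])
            for name, shape in model.items()
            if shape != variational[name]]


mismatches = shape_mismatches(model_shapes, variational_shapes)
for name, p_shape, q_shape in mismatches:
    print(f"{name}: model {p_shape} vs variational {q_shape}")
```

On the shapes above this flags b, loc_b and scale_b: the bias is (4,) in the model but (4, 1) in qb, and the scalar hyperpriors are () in the model but (1,) in their variational counterparts. It does not flag w, so it does not by itself account for the (4, 3) versus (3,) pair in the reported error; printing the actual shapes of each key and value passed to the inference object would be the next thing to check.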