Hi,
I’m trying to implement a very basic hierarchical model with real-valued data from several groups. I’d like to infer both the group means and the overall mean. Here’s my code.
import edward as ed
from edward.models import Normal
import numpy as np
import tensorflow as tf
# TOY DATA
N = 3 # number of groups
M = 1000 # samples per group
# mean for each group is different
# want to infer the group means plus the overall mean
actual_group_means = [0.1, 0.2, 0.3]
sigma = 0.1
observed_groups = np.repeat([0, 1, 2], M)
samples = [np.random.normal(actual_group_means[g], sigma, M) for g in range(N)]
observed_data = np.concatenate(samples)
# MODEL
groups = tf.placeholder(tf.int32, [M * N])
overall_mean = Normal(
    loc=tf.zeros(1),
    scale=tf.ones(1) * 0.05
)
q_overall_mean = Normal(
    loc=tf.Variable(tf.zeros(1)),
    scale=tf.nn.softplus(tf.Variable(tf.zeros(1)))
)
group_means = Normal(
    loc=tf.ones(N) * overall_mean,
    scale=tf.ones(N) * 0.05
)
q_group_means = Normal(
    loc=tf.Variable(tf.zeros(N)),
    scale=tf.nn.softplus(tf.Variable(tf.zeros(N)))
)
data = Normal(
    loc=tf.gather(group_means, groups),
    scale=tf.ones(shape=[N * M]) * sigma
)
# INFERENCE
for inference_alg in (ed.KLpq, ed.KLqp):
    inference = inference_alg(
        {
            overall_mean: q_overall_mean,
            group_means: q_group_means
        },
        data={
            groups: observed_groups,
            data: observed_data
        }
    )
    inference.run(n_samples=5, n_iter=1000)
    sess = ed.get_session()
    print('Using {}:'.format(inference_alg))
    print(q_overall_mean.mean().eval())
    print(q_group_means.mean().eval())
Output from this is:
Using <class 'edward.inferences.klpq.KLpq'>:
[ 0.4643645]
[ 0.08044512 0.20934561 0.26856244]
Using <class 'edward.inferences.klqp.KLqp'>:
[ 0.]
[ 0.09141546 0.1917349 0.30631313]
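As a reference point for what inference should return: since the model is linear-Gaussian throughout, the joint posterior mode can be computed exactly by solving the stationarity conditions of the log joint. A quick numpy sketch (this regenerates the toy data with a seed added for reproducibility; the precisions are the ones implied by the priors in the script above):

```python
import numpy as np

np.random.seed(0)  # seed added for reproducibility (not in the original script)

N, M = 3, 1000
actual_group_means = [0.1, 0.2, 0.3]
sigma = 0.1
samples = [np.random.normal(actual_group_means[g], sigma, M) for g in range(N)]
group_sample_means = np.array([s.mean() for s in samples])

# Precisions implied by the priors in the model above
tau0 = 1.0 / 0.05 ** 2   # overall_mean ~ Normal(0, 0.05)
tau = 1.0 / 0.05 ** 2    # group_means ~ Normal(overall_mean, 0.05)
lam = 1.0 / sigma ** 2   # data ~ Normal(group_mean, sigma)

# Setting the gradient of log p(mu, theta, y) to zero gives a linear system:
#   (tau0 + N*tau) * mu - tau * sum(theta)  = 0
#   -tau * mu + (tau + lam*M) * theta_g     = lam * M * ybar_g
A = np.zeros((N + 1, N + 1))
b = np.zeros(N + 1)
A[0, 0] = tau0 + N * tau
A[0, 1:] = -tau
for g in range(N):
    A[g + 1, 0] = -tau
    A[g + 1, g + 1] = tau + lam * M
    b[g + 1] = lam * M * group_sample_means[g]

mode = np.linalg.solve(A, b)
print('posterior mode, overall mean:', mode[0])  # ~0.15, pulled toward 0 by the tight prior
print('posterior mode, group means:', mode[1:])  # close to the sample means
```

Note that under the priors as written (overall mean prior N(0, 0.05)), the exact posterior mode for the overall mean sits near 0.15 rather than 0.2, so neither 0.46 nor 0.0 matches it.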
In both cases the group means are inferred reasonably well, but KLpq overestimates the overall mean and KLqp doesn’t move it from zero at all. Increasing the number of iterations doesn’t change this, and I see the same behaviour in the real use case this simplified version is based on.
Have I misspecified something in the model?
Thank you
UPDATE: With MAP inference the results look much better. I guess the question then becomes: is there some theoretical reason why variational inference would struggle in this case?