Simple Hierarchical model fails

Hi,

I’m trying to implement a very basic hierarchical model with real-valued data from several groups. I’d like to infer both the group means and the overall mean. Here’s my code.

import edward as ed
from edward.models import Normal
import numpy as np
import tensorflow as tf

# TOY DATA
N = 3  # number of groups
M = 1000  # samples per group

# mean for each group is different
# want to infer the group means plus the overall mean
actual_group_means = [0.1, 0.2, 0.3]
sigma = 0.1

observed_groups = np.repeat([0, 1, 2], M)
samples = [np.random.normal(actual_group_means[g], sigma, M) for g in range(N)]
observed_data = np.concatenate(samples)

# MODEL
groups = tf.placeholder(tf.int32, [M * N])

overall_mean = Normal(
    loc=tf.zeros(1), 
    scale=tf.ones(1) * 0.05
)
q_overall_mean = Normal(
    loc=tf.Variable(tf.zeros(1)),
    scale=tf.nn.softplus(tf.Variable(tf.zeros(1)))
)

group_means = Normal(
    loc=tf.ones(N) * overall_mean,
    scale=tf.ones(N) * 0.05
)
q_group_means = Normal(
    loc=tf.Variable(tf.zeros(N)),
    scale=tf.nn.softplus(tf.Variable(tf.zeros(N)))
)

data = Normal(
    loc=tf.gather(group_means, groups),
    scale=tf.ones(shape=[N * M]) * sigma
)

for inference_alg in (ed.KLpq, ed.KLqp):
    inference = inference_alg(
        {
            overall_mean: q_overall_mean,
            group_means: q_group_means
        },
        data={
            groups: observed_groups,
            data: observed_data
        }
    )
    inference.run(n_samples=5, n_iter=1000)
    sess = ed.get_session()
    print('Using {}:'.format(inference_alg))
    print(q_overall_mean.mean().eval())
    print(q_group_means.mean().eval())

Output from this is:

Using <class 'edward.inferences.klpq.KLpq'>:
[ 0.4643645]
[ 0.08044512  0.20934561  0.26856244]
Using <class 'edward.inferences.klqp.KLqp'>:
[ 0.]
[ 0.09141546  0.1917349   0.30631313]

In both cases the group means are inferred, but for KLpq the overall mean is too high, and for KLqp it isn’t inferred at all (it stays at the prior mean of 0). I get the same results when I increase the number of iterations, and also in the real use case that this simplified version is based on.
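For reference, every distribution in this toy model is Normal, so the exact posterior of the overall mean can be computed in closed form by integrating out the group means. The sketch below is a NumPy stand-in, not Edward code. The helper name `overall_mean_posterior` is hypothetical, and the sketch assumes the same priors as the code above: scale 0.05 for the overall mean and for the group means around it.

```python
import numpy as np


def overall_mean_posterior(group_sample_means, samples_per_group, sigma,
                           group_scale, prior_loc, prior_scale):
    """Exact posterior N(mean, sd**2) of the overall mean, with group means integrated out.

    Assumes: overall_mean ~ N(prior_loc, prior_scale**2),
    group_mean_g ~ N(overall_mean, group_scale**2), y_gi ~ N(group_mean_g, sigma**2).
    """
    ybar = np.asarray(group_sample_means, dtype=float)
    # Integrating out group_mean_g leaves ybar_g ~ N(overall_mean, group_scale**2 + sigma**2 / M)
    var_per_group = group_scale ** 2 + sigma ** 2 / samples_per_group
    # Conjugate Normal update: precisions add, means are precision-weighted
    precision = 1.0 / prior_scale ** 2 + len(ybar) / var_per_group
    mean = (prior_loc / prior_scale ** 2 + ybar.sum() / var_per_group) / precision
    return float(mean), float(np.sqrt(1.0 / precision))


# Regenerate toy data like the question's (seeded here for reproducibility)
rng = np.random.default_rng(0)
actual_group_means = [0.1, 0.2, 0.3]
sigma, M = 0.1, 1000
ybar = [rng.normal(g_mean, sigma, M).mean() for g_mean in actual_group_means]

post_mean, post_sd = overall_mean_posterior(ybar, M, sigma, group_scale=0.05,
                                            prior_loc=0.0, prior_scale=0.05)
print(post_mean, post_sd)
```

The result should land near 0.15 with a standard deviation of about 0.025. It does not land at the empirical average of 0.2, because the tight prior on the overall mean (centred at 0, scale 0.05) legitimately shrinks it toward 0. So 0.46 is too high, but 0.2 would not be the right target either.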

Have I misspecified something in the model?

Thank you

UPDATE: With MAP inference the results look much better. I guess the question then becomes: is there some theoretical reason why variational inference would struggle in this case?

How do your results change if you increase n_samples to 50 or 100?

With n_samples=100 the inferred overall mean with KLpq is much better, but for KLqp I get the same results:

1000/1000 [100%] ██████████████████████████████ Elapsed: 51s | Loss: 26.828
Using <class 'edward.inferences.klpq.KLpq'>:
[ 0.13855419]
[ 0.09777226  0.19409294  0.29982021]
1000/1000 [100%] ██████████████████████████████ Elapsed: 33s | Loss: -2617.888
Using <class 'edward.inferences.klqp.KLqp'>:
[ 0.]
[ 0.09611346  0.20162852  0.29795927]
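ed.KLpq is sensitive to n_samples, and one plausible contributing factor is this: its gradient relies on self-normalized importance weights over n_samples draws from q, and such estimates are biased and very noisy when q is much wider than the posterior. The NumPy sketch below is a simplified stand-in, not Edward's implementation, and all function names in it are hypothetical. It integrates out the group means and uses a fixed proposal N(0, log 2), which matches the initial q_overall_mean (softplus(0) = log 2). It then compares the average error of the self-normalized estimate of the posterior mean for 5 versus 100 draws.

```python
import numpy as np


def log_normal_pdf(x, loc, scale):
    return -0.5 * np.log(2.0 * np.pi * scale ** 2) - 0.5 * ((x - loc) / scale) ** 2


def log_target(mu, ybar, per_group_var, prior_scale):
    """Unnormalized log posterior of the overall mean (group means integrated out)."""
    lp = log_normal_pdf(mu, 0.0, prior_scale)
    # ybar_g ~ N(mu, per_group_var) for each group; sum the log densities over groups
    lp = lp + log_normal_pdf(ybar[None, :], mu[:, None], np.sqrt(per_group_var)).sum(axis=1)
    return lp


def snis_mean(n_samples, rng, ybar, per_group_var, prior_scale, q_loc, q_scale):
    """Self-normalized importance-sampling estimate of E[overall_mean | data], proposal q."""
    mu = rng.normal(q_loc, q_scale, n_samples)
    log_w = log_target(mu, ybar, per_group_var, prior_scale) - log_normal_pdf(mu, q_loc, q_scale)
    w = np.exp(log_w - log_w.max())  # subtract the max before exponentiating for stability
    w /= w.sum()
    return float(np.sum(w * mu))


ybar = np.array([0.1, 0.2, 0.3])           # group sample means, taken at their true values
per_group_var = 0.05 ** 2 + 0.1 ** 2 / 1000
prior_scale = 0.05
q_loc, q_scale = 0.0, np.log(2.0)          # softplus(0) = log 2, the question's initial scale

# Reference posterior mean by brute-force numerical integration on a fine grid
grid = np.linspace(-1.0, 1.0, 200_001)
log_p = log_target(grid, ybar, per_group_var, prior_scale)
p = np.exp(log_p - log_p.max())
reference = float(np.sum(grid * p) / np.sum(p))

rng = np.random.default_rng(1)
mean_abs_error = {}
for n in (5, 100):
    errors = [abs(snis_mean(n, rng, ybar, per_group_var, prior_scale, q_loc, q_scale) - reference)
              for _ in range(2000)]
    mean_abs_error[n] = float(np.mean(errors))
print(reference, mean_abs_error)
```

With 5 draws the normalized weights typically collapse onto a single sample, so the estimate lands roughly wherever the nearest draw happened to fall. With 100 draws the average error is much smaller.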

Unfortunately I don’t have good intuition for why ed.KLqp gives bad results for the overall mean. The model seems well-behaved enough that I doubt the global optimum of the objective is that bad (but I wouldn’t bet on this argument). If you’d like to investigate further, you can try fixing the group means to ed.KLqp's estimates and running inference over only the overall mean.

Thanks for your thoughts. I tried your suggestion of fixing the group means and running inference only on the overall mean, but the results are in line with the above: the ed.KLqp posterior mean of the overall mean equals the prior mean, whereas the ed.KLpq overall mean improves as n_samples increases.
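With the group means held fixed, the overall mean's posterior is again available in closed form, so this experiment has a concrete target. The minimal NumPy sketch below (the helper name is hypothetical) plugs in the ed.KLqp group-mean estimates from the n_samples=100 run and the question's priors.

```python
import numpy as np


def overall_mean_given_group_means(group_means, group_scale, prior_loc, prior_scale):
    """Exact posterior N(mean, sd**2) of overall_mean when the group means are held fixed."""
    theta = np.asarray(group_means, dtype=float)
    # Each fixed group mean acts as one observation of overall_mean with scale group_scale
    precision = 1.0 / prior_scale ** 2 + len(theta) / group_scale ** 2
    mean = (prior_loc / prior_scale ** 2 + theta.sum() / group_scale ** 2) / precision
    return float(mean), float(np.sqrt(1.0 / precision))


# Group means fixed to the ed.KLqp estimates from the n_samples=100 run in this thread
fixed_group_means = [0.09611346, 0.20162852, 0.29795927]
mean, sd = overall_mean_given_group_means(fixed_group_means, group_scale=0.05,
                                          prior_loc=0.0, prior_scale=0.05)
print(round(mean, 4), round(sd, 4))  # 0.1489 0.025
```

In this formula the prior contributes a precision of 400 out of a total of 1600. Changing prior_loc to some constant c therefore moves the answer by only a quarter of c. A result that tracks c exactly suggests the group-mean terms are not reaching q_overall_mean at all.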

If I replace the overall mean prior with overall_mean = Normal(loc=tf.ones(1) * c, ...) for some constant c, the behaviour is the same: the inferred overall mean equals c. Is it significant that the overall mean is not only inferred poorly, but that its posterior mean equals the prior mean exactly? Does this suggest that no inference updates are occurring, or that this is a ‘bad’ initialisation for the problem?

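The problem is that ed.KLqp decides it can use the analytic KL divergence between group_means and q_group_means because both are Normal. In reality it cannot, because the prior on group_means depends on overall_mean, so that KL term has to be averaged over q_overall_mean. With the analytic shortcut, the term is evaluated against the prior's own overall_mean rather than a draw from q_overall_mean. The only part of the loss that still involves q_overall_mean is then its KL divergence to its own prior. Minimizing that simply puts q_overall_mean back on the prior, which is why its posterior mean equals the prior mean.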
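The missing expectation can be made concrete in NumPy. The sketch below uses hypothetical parameter values, not fitted ones. It estimates E over q_overall_mean of KL(q_group_means || N(overall_mean, 0.05²)) by Monte Carlo and checks it against its closed form. It also computes the derivative of that expectation with respect to the loc of q_overall_mean. That derivative is nonzero and is exactly the signal that gets lost when the KL is evaluated with the prior's loc disconnected from q_overall_mean.

```python
import numpy as np


def kl_normal(m, s, loc, scale):
    """KL( N(m, s^2) || N(loc, scale^2) ), elementwise."""
    return np.log(scale / s) + (s ** 2 + (m - loc) ** 2) / (2.0 * scale ** 2) - 0.5


# Illustrative variational parameters (hypothetical values, not fitted ones)
m = np.array([0.1, 0.2, 0.3])       # q_group_means locs
s = np.array([0.01, 0.01, 0.01])    # q_group_means scales
a, b = 0.15, 0.025                  # q_overall_mean = N(a, b^2)
tau = 0.05                          # prior scale of group_means around overall_mean

# The ELBO term actually needed: E_{mu ~ q_overall_mean}[ KL(q_group_means || N(mu, tau^2)) ],
# estimated here by Monte Carlo over draws of mu
rng = np.random.default_rng(0)
mu = rng.normal(a, b, size=200_000)
mc_expected_kl = kl_normal(m[None, :], s[None, :], mu[:, None], tau).sum(axis=1).mean()

# Closed form of the same expectation: KL evaluated at mu = a, plus b^2 / (2 tau^2) per group
closed_expected_kl = kl_normal(m, s, a, tau).sum() + len(m) * b ** 2 / (2.0 * tau ** 2)

# Derivative with respect to a is -sum(m - a) / tau^2, nonzero unless a == mean(m):
# this is what should pull q_overall_mean toward the group means
grad_wrt_a = -np.sum(m - a) / tau ** 2
print(mc_expected_kl, closed_expected_kl, grad_wrt_a)
```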

You get the correct answer by explicitly using ed.ReparameterizationKLqp, which estimates all of the KL divergences by sampling instead of using the analytic shortcut.

https://aksarkar.github.io/nwas/klqp.html
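ed.ReparameterizationKLqp maximizes a Monte Carlo estimate of the mean-field ELBO for the same independent-Normal variational family used in the question. Because this toy model is conjugate, that same mean-field objective can also be maximized exactly by coordinate-ascent variational inference (CAVI). The sketch below swaps in CAVI in NumPy to show where a correctly computed ELBO should put q_overall_mean. It is not Edward's algorithm, and the function name is hypothetical. It starts from the prior mean, where ed.KLqp got stuck.

```python
import numpy as np


def cavi_hierarchical(ybar, M, sigma, tau, prior_loc, prior_scale, n_iter=50):
    """Coordinate-ascent mean-field VI for overall_mean -> group_means -> data, all Normal.

    Returns (loc, scale) of q(overall_mean) and (locs, scale) of q(group_means).
    """
    ybar = np.asarray(ybar, dtype=float)
    n_groups = len(ybar)
    # Optimal q precisions do not depend on the other factor's mean, so they are fixed
    group_prec = 1.0 / tau ** 2 + M / sigma ** 2
    overall_prec = 1.0 / prior_scale ** 2 + n_groups / tau ** 2
    a = prior_loc  # start q(overall_mean) at the prior mean, like the stuck KLqp result
    for _ in range(n_iter):
        # q(group_means): data pull toward ybar, prior pulls toward E_q[overall_mean] = a
        m = (a / tau ** 2 + M * ybar / sigma ** 2) / group_prec
        # q(overall_mean): prior pulls toward prior_loc, group means pull toward E_q[group_means]
        a = (prior_loc / prior_scale ** 2 + m.sum() / tau ** 2) / overall_prec
    return float(a), float(np.sqrt(1.0 / overall_prec)), m, float(np.sqrt(1.0 / group_prec))


# Group sample means at their true values, with the question's scales and priors
q_loc, q_scale, group_locs, group_scale = cavi_hierarchical(
    [0.1, 0.2, 0.3], M=1000, sigma=0.1, tau=0.05, prior_loc=0.0, prior_scale=0.05)
print(q_loc, q_scale, group_locs)
```

The overall mean converges to about 0.15 within a couple of iterations, matching the exact posterior mean for these group sample means, while q_group_means stays close to the data. A stochastic optimizer such as ed.ReparameterizationKLqp should land near the same values, up to Monte Carlo noise.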
