A small example of Dirichlet/Multinomial inference

%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow as tf
from edward.models import (Multinomial, Dirichlet)
import edward as ed

# Dirichlet/Multinomial toy problem: observed counts come from a Multinomial
# whose probabilities are proportional to true_pi1 + true_pi2, and pi1 is
# inferred with a Dirichlet prior while true_pi2 is treated as known.
#
# Both weight vectors are normalized so that they lie on the probability
# simplex, the same space as a Dirichlet draw such as pi1. Without this step,
# pi1 + true_pi2 (where true_pi2 alone sums to 32) cannot be made proportional
# to true_pi1 + true_pi2 by any pi1 that sums to one. The posterior then
# settles on some other point of the simplex instead of the normalized
# true_pi1.
true_pi1 = tf.constant([10., 5., 10., 20.])
true_pi2 = tf.constant([4., 8., 5., 15.])
true_pi1 = true_pi1 / tf.reduce_sum(true_pi1)
true_pi2 = true_pi2 / tf.reduce_sum(true_pi2)

# Prior over pi1 and its variational approximation; softplus keeps the
# trainable concentration parameters positive.
pi1 = Dirichlet(tf.ones(4))
qpi1 = Dirichlet(tf.nn.softplus(tf.Variable(tf.ones(4))), name="qpi1")

# Observed counts (5000 trials) and the model for them; pi1 enters the model
# only through the logits of z.
# NOTE(review): z_data is a RandomVariable rather than a fixed tensor; confirm
# whether Edward redraws it on every update (a fixed sample may be preferable).
z_data = Multinomial(total_count=5000., logits=tf.log(true_pi1 + true_pi2))
z = Multinomial(total_count=5000., logits=tf.log(pi1 + true_pi2))

n_iter = 1000
print_every = 100  # how often to report the current variational mean

inference = ed.KLqp({pi1: qpi1}, data={z: z_data})
# Pass the iteration count to initialize() rather than assigning
# inference.n_iter afterwards, so initialization is configured from one value.
inference.initialize(n_iter=n_iter)

session = ed.get_session()
tf.global_variables_initializer().run()

for t in range(inference.n_iter):
    info_dict = inference.update()
    inference.print_progress(info_dict)
    # Report periodically (and on the final step) instead of every iteration.
    if t % print_every == 0 or t == inference.n_iter - 1:
        print("______________")
        print("MEAN QPI1:")
        print(session.run(qpi1.mean()))

inference.finalize()

# The variational mean printed above should approach this target.
print("TARGET (normalized true_pi1):")
print(session.run(true_pi1))

Hello

Does anyone know why, at the end of training, `session.run(qpi1.mean())` returns [[0.69216776 0.03820721 0.17060468 0.09902035]]

instead of the normalized true_pi1, i.e. `true_pi1 / tf.reduce_sum(true_pi1)` = [[0.22222222, 0.11111111, 0.22222222, 0.44444445]]?

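Solved: it was a normalization problem. true_pi1 and true_pi2 had to be normalized to sum to one, because pi1 is a Dirichlet draw and therefore always sums to one. With the unnormalized vectors, no value of pi1 can make pi1 + true_pi2 proportional to true_pi1 + true_pi2, so the posterior cannot recover true_pi1.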