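I'm doing the following.

Model:

    import tensorflow as tf
    import edward as ed
    from edward.models import Categorical, Dirichlet

    session = ed.get_session()

    # Prior over the three category probabilities
    mot1 = Dirichlet(tf.constant([1.0, 1.0, 1.0]), name='mot1')
    # Variational family; softplus keeps the learned concentration positive
    qmot1 = Dirichlet(tf.nn.softplus(tf.Variable(tf.constant([1.0, 1.0, 1.0]))), name="qmot1")

    # 100 rows of the same weights [1, 5, 15]
    probs = tf.reshape(tf.tile(tf.constant([1.0, 5.0, 15.0]), tf.constant([100])), [100, 3])
    # Sample 100 observations once and store them as a constant
    data = tf.constant(Categorical(probs=probs).eval())

    # 100 rows of one evaluated sample of mot1
    probs1 = tf.reshape(tf.tile(mot1.eval(), tf.constant([100])), [100, 3])
    w = Categorical(probs=probs1)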
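For reference, a categorical distribution's parameters are probabilities that sum to 1, so the weights [1, 5, 15] used to generate `data` correspond to the probabilities [1/21, 5/21, 15/21] (presumably the sampler normalizes them, as TensorFlow's logit-based categorical sampling effectively does). Because only the ratios matter, data drawn this way cannot single out [1, 5, 15] over, say, [2, 10, 30]. A plain NumPy sketch, not Edward; `normalize` is a hypothetical helper:

```python
import numpy as np

def normalize(weights):
    """Hypothetical helper: scale non-negative weights so they sum to 1."""
    w = np.asarray(weights, dtype=float)
    return w / w.sum()

p = normalize([1.0, 5.0, 15.0])
print(p)  # approx. [0.048 0.238 0.714]

# Any positive rescaling of the weights gives the same probabilities,
# so data drawn from them cannot identify [1, 5, 15] specifically.
print(np.allclose(p, normalize([2.0, 10.0, 30.0])))  # True
```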
Inference:

    inference = ed.KLqp({mot1: qmot1}, data={data: w})
    inference.initialize(n_iter=1000)
    tf.global_variables_initializer().run()

    for _ in range(inference.n_iter):
        info_dict = inference.update()
        inference.print_progress(info_dict)
        # Print qmot1 as evaluated by the session after each update
        print(session.run(qmot1))
        print("______________")

    inference.finalize()
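Note that `session.run(qmot1)` evaluates the random variable itself, which in Edward means drawing a sample from the current variational Dirichlet rather than returning its concentration parameters. A Dirichlet sample is a probability vector that sums to 1 and changes on every run, so even a `qmot1` whose concentration were exactly [1, 5, 15] would not print [1, 5, 15]. A small sketch of that behaviour using NumPy's `Generator.dirichlet` (not Edward; the concentration value is purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
alpha = np.array([1.0, 5.0, 15.0])  # illustrative concentration, not a learned value

# Each draw is a probability vector on the simplex, not the concentration itself.
for _ in range(3):
    sample = rng.dirichlet(alpha)
    print(sample, sample.sum())

# The mean of Dirichlet(alpha) is alpha / alpha.sum().
print(alpha / alpha.sum())
```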
I expect qmot1 to converge towards [1.0, 5.0, 15.0], but it does not; instead the loop prints values that make no sense to me. I'm surely missing some basics, but I don't know what.
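For this toy setup the exact posterior is known in closed form, which gives a concrete target to compare `qmot1` against. With a Dirichlet(1, 1, 1) prior and categorical observations, the posterior over the category probabilities is Dirichlet(1 + counts), where `counts` holds how often each category was observed; this is the distribution a correctly wired KLqp run would be trying to approximate. Its mean approaches [1/21, 5/21, 15/21] as data accumulate, not [1, 5, 15]. A NumPy sketch of the conjugate update, simulating its own 100 draws (so the counts are illustrative, not the ones from the TensorFlow session):

```python
import numpy as np

rng = np.random.default_rng(0)
true_p = np.array([1.0, 5.0, 15.0]) / 21.0   # normalized generating weights
draws = rng.choice(3, size=100, p=true_p)    # 100 simulated observations
counts = np.bincount(draws, minlength=3)     # occurrences of each category

prior_alpha = np.array([1.0, 1.0, 1.0])      # same prior as mot1
post_alpha = prior_alpha + counts            # conjugate Dirichlet update
post_mean = post_alpha / post_alpha.sum()    # expected probability vector

print("posterior concentration:", post_alpha)
print("posterior mean:", post_mean)
```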