Gaussian mixture model (Dirichlet process)


#1

I am trying to find the number of clusters in the Iris dataset, without defining K!

To do this I follow this tutorial: http://edwardlib.org/tutorials/unsupervised

This works fine, however I want to replace the Dirichlet Distribution with a Dirichlet Process.

When I try to do this, I get the following error:

NotImplementedError: conjugate_log_prob not implemented for <class 'abc.Deterministic'>

Here is my code:

df = pd.DataFrame(load_iris()['data'])
def stick_breaking(beta):
    portion_remaining = tf.concat([[1], tf.cumprod(1 - beta)[:-1]], axis=0)
    return beta * portion_remaining

k = 30

d = df.shape[1]
n = df.shape[0]

alpha = ed.models.HalfNormal(tf.ones(k))
beta = ed.models.Beta(alpha, tf.ones(k))
pi = ed.models.Deterministic(stick_breaking(beta))

mu = Normal(tf.zeros(d), tf.ones(d), sample_shape=k)

sigmasq = InverseGamma(tf.ones(d), tf.ones(d), sample_shape=k)
x = ParamMixture(pi, {'loc': mu, 'scale_diag': tf.sqrt(sigmasq)},
                 MultivariateNormalDiag,
                 sample_shape=n)
z = x.cat

t = 500

qalpha = Empirical(tf.get_variable('qalpha', shape=[t, k], initializer=tf.ones_initializer()))
qbeta = Empirical(tf.get_variable('qbeta', shape=[t, k], initializer=tf.constant_initializer(0.5)))
qpi = Empirical(tf.get_variable('qpi', shape=[t, k], initializer=tf.constant_initializer(1 / k)))
qmu = Empirical(tf.get_variable('qmu', shape=[t, k, d], initializer=tf.zeros_initializer()))
qsigma = Empirical(tf.get_variable('qsigma', shape=[t, k, d], initializer=tf.ones_initializer()))
qz = Empirical(tf.get_variable('qz', shape=[t, n], initializer=tf.zeros_initializer(), dtype=tf.int32)

inference = ed.Gibbs({pi: qpi, mu: qmu, sigmasq: qsigma, z: qz, alpha: qalpha, beta: qbeta},
data={x: df.values}))
inference.initialize()

sess = ed.get_session()
tf.global_variables_initializer().run()

t_ph = tf.placeholder(tf.int32, [])
running_cluster_means = tf.reduce_mean(mu_param.params[:t_ph], axis=0)

for _ in range(inference.n_iter):
    info_dict = inference.update()
    inference.print_progress(info_dict)
    t = info_dict['t']
    if t % inference.n_print == 0:
        print('\nInferred cluster means:')
        print(sess.run(running_cluster_means, {t_ph: t - 1}))