I am trying to find the number of clusters in the Iris dataset without fixing K in advance.
To do this, I am following this tutorial: http://edwardlib.org/tutorials/unsupervised
This works fine; however, I want to replace the Dirichlet distribution with a Dirichlet process.
When I try to do this, I get the following error:
NotImplementedError: conjugate_log_prob not implemented for <class 'abc.Deterministic'>
Here is my code:
# Wrap the raw Iris feature matrix (the 'data' entry of the loaded dataset)
# in a DataFrame; later code reads its shape and its underlying values.
df = pd.DataFrame(data=load_iris()['data'])
def stick_breaking(beta):
    """Convert stick-breaking fractions into mixture weights.

    Element ``i`` of the result is ``beta[i]`` multiplied by the product of
    ``1 - beta[j]`` for every ``j < i``, i.e. the fraction broken off at step
    ``i`` times the portion of the stick that is still left at that point.

    Args:
        beta: Tensor of break fractions. The break index runs along the last
            axis (for a 1-D tensor this is the only axis).

    Returns:
        A tensor with the same shape as ``beta`` holding the weights. Because
        the construction is truncated, the weights along the last axis sum to
        ``1 - prod(1 - beta)`` rather than exactly 1.
    """
    # An exclusive cumulative product yields [1, (1-b0), (1-b0)(1-b1), ...],
    # which is exactly "prepend 1 and drop the last cumulative product".
    portion_remaining = tf.cumprod(1 - beta, axis=-1, exclusive=True)
    return beta * portion_remaining
# Truncation level of the stick-breaking construction (maximum number of
# mixture components), feature dimension, and number of observations.
k = 30
d = df.shape[1]
n = df.shape[0]

# Model: stick-breaking mixture weights over k diagonal-covariance Gaussians.
alpha = ed.models.HalfNormal(tf.ones(k))
beta = ed.models.Beta(alpha, tf.ones(k))
pi = ed.models.Deterministic(stick_breaking(beta))
mu = Normal(tf.zeros(d), tf.ones(d), sample_shape=k)
sigmasq = InverseGamma(tf.ones(d), tf.ones(d), sample_shape=k)
x = ParamMixture(
    pi,
    {'loc': mu, 'scale_diag': tf.sqrt(sigmasq)},
    MultivariateNormalDiag,
    sample_shape=n,
)
z = x.cat

# Number of posterior samples kept by each Empirical approximation.
t = 500
qalpha = Empirical(tf.get_variable(
    'qalpha', shape=[t, k], initializer=tf.ones_initializer()))
qbeta = Empirical(tf.get_variable(
    'qbeta', shape=[t, k], initializer=tf.constant_initializer(0.5)))
# 1.0 / k keeps the initial weight non-zero under Python 2 integer division.
qpi = Empirical(tf.get_variable(
    'qpi', shape=[t, k], initializer=tf.constant_initializer(1.0 / k)))
qmu = Empirical(tf.get_variable(
    'qmu', shape=[t, k, d], initializer=tf.zeros_initializer()))
qsigma = Empirical(tf.get_variable(
    'qsigma', shape=[t, k, d], initializer=tf.ones_initializer()))
qz = Empirical(tf.get_variable(
    'qz', shape=[t, n], initializer=tf.zeros_initializer(), dtype=tf.int32))

# NOTE(review): the reported NotImplementedError says Deterministic has no
# conjugate_log_prob, so this Gibbs call is expected to keep failing while
# `pi` is a Deterministic node. Presumably a non-conjugate inference method
# is needed for this model -- confirm against the Edward documentation.
inference = ed.Gibbs(
    {pi: qpi, mu: qmu, sigmasq: qsigma, z: qz, alpha: qalpha, beta: qbeta},
    data={x: df.values},
)
inference.initialize()

sess = ed.get_session()
tf.global_variables_initializer().run()

# Running average of the cluster means over the first `t_ph` stored samples.
t_ph = tf.placeholder(tf.int32, [])
running_cluster_means = tf.reduce_mean(qmu.params[:t_ph], axis=0)

for _ in range(inference.n_iter):
    info_dict = inference.update()
    inference.print_progress(info_dict)
    # From here on `t` holds the current iteration counter, not the sample
    # count used above to size the Empirical variables.
    t = info_dict['t']
    if t % inference.n_print == 0:
        print('\nInferred cluster means:')
        print(sess.run(running_cluster_means, {t_ph: t - 1}))