Dirichlet Process Mixture Models (DPMM) with hierarchical structure

Hi @vsmolyakov

One of the problems in the example you mentioned is that the distributions are not well defined. An updated version of the code for a fixed K may help. It is based on the examples posted by @ecosang on GitHub. See also the discussion.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import edward as ed
import tensorflow as tf
from tensorflow.contrib.linalg import LinearOperatorTriL
ds = tf.contrib.distributions

import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import matplotlib.cm as cm
import six

plt.style.use('ggplot')

from edward.models import Dirichlet, Categorical, \
MultivariateNormalTriL, WishartCholesky, Mixture


def plot_point_cov(points, nstd=2, ax=None, **kwargs):
    """Draw an `nstd`-sigma covariance ellipse for a cloud of 2-D points.

    The ellipse is centered at the sample mean of `points` and shaped by
    their sample covariance (delegates to `plot_cov_ellipse`).

    Parameters
    ----------
    points : array-like, shape (n, 2)
        The data points (rows are observations).
    nstd : float
        Number of standard deviations determining the ellipse size.
    ax : matplotlib axes, optional
        Axes to draw on; defaults to the current axes.
    **kwargs
        Forwarded to `matplotlib.patches.Ellipse`.

    Returns
    -------
        A matplotlib ellipse artist
    """
    center = points.mean(axis=0)
    covariance = np.cov(points, rowvar=False)
    return plot_cov_ellipse(covariance, center, nstd, ax, **kwargs)

def plot_cov_ellipse(cov, pos, nstd=2, ax=None, **kwargs):
    """Draw an `nstd`-sigma ellipse for covariance `cov` centered at `pos`.

    The ellipse axes are aligned with the eigenvectors of `cov`; its
    width and height are full diameters, `2 * nstd * sqrt(eigenvalue)`.

    Returns
    -------
        A matplotlib ellipse artist
    """
    if ax is None:
        ax = plt.gca()
    # Eigen-decomposition with eigenvalues sorted in descending order;
    # the leading eigenvector fixes the ellipse orientation.
    eigvals, eigvecs = np.linalg.eigh(cov)
    descending = eigvals.argsort()[::-1]
    eigvals = eigvals[descending]
    eigvecs = eigvecs[:, descending]
    angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
    # Width and height are "full" widths, not radii.
    width, height = 2 * nstd * np.sqrt(eigvals)
    artist = Ellipse(xy=pos, width=width, height=height, angle=angle, **kwargs)
    ax.add_artist(artist)
    return artist

# -- Dataset --
# Ground-truth parameters of the 3-component, 2-D Gaussian mixture that
# generates the synthetic data.

N = 5000  # number of data points
K = 3  # number of components
D = 2  # dimensionality of data

true_pi = np.array([0.3, 0.6, 0.1])  # mixing weights (sum to 1)
true_mus = [[10, 1], [-10, -10], [-1, 20]]  # component means
# Component covariance matrices (2x2, symmetric).
true_stds = [[[3, -0.1], [-0.1, 2]], 
[[4, 0.0], [0.0, 1]], [[3, 0.2], [0.2, 4]]]
# NOTE(review): true_cholesky is never used below — candidate for removal.
true_cholesky = tf.cholesky(true_stds)

def build_toy_dataset(N,pi,mus,stds):
  """Sample N points from a 2-D Gaussian mixture.

  For each point, a component index is drawn from Multinomial(1, pi),
  then an observation from that component's Gaussian.

  Parameters
  ----------
  N : int, number of points to generate
  pi : sequence of mixing weights
  mus : per-component mean vectors
  stds : per-component covariance matrices

  Returns
  -------
  (x, label) : float32 arrays of shapes (N, 2) and (N,);
      label[n] is the component index that generated x[n].
  """
  samples = np.zeros((N, 2), dtype=np.float32)
  labels = np.zeros((N,), dtype=np.float32)
  for i in range(N):
    component = np.argmax(np.random.multinomial(1, pi))
    samples[i, :] = np.random.multivariate_normal(mus[component], stds[component])
    labels[i] = component
  return samples, labels

# Draw one synthetic dataset from the true mixture.
xn_data, xn_label = build_toy_dataset(N,true_pi, true_mus, true_stds)

# -- Generative model --
pi = Dirichlet(tf.ones(K))  # Weights
z = Categorical(probs=pi,sample_shape=N) # Assignments

# Fixed (non-trainable) hyperparameters of the prior.
nu0 = tf.Variable(D, dtype=tf.float32, trainable=False)  # Wishart degrees of freedom
psi0 = tf.Variable(np.eye(D), dtype=tf.float32, trainable=False)  # Wishart scale matrix
mu0 = tf.Variable(np.zeros(D), dtype=tf.float32, trainable=False)  # prior mean
k0 = tf.Variable(1., dtype=tf.float32, trainable=False)  # mean scale factor

# K scale matrices from a Wishart prior; with
# cholesky_input_output_matrices=True the samples are Cholesky factors.
sigma = WishartCholesky(df=nu0, scale=psi0,cholesky_input_output_matrices=True, sample_shape = K)
# Component means; the TriL scale of each mean is tied to k0 * sigma.
mu = MultivariateNormalTriL(mu0, k0*sigma)
components = [MultivariateNormalTriL(mu[k],sigma[k]) for k in range(K)]
# Mixture likelihood over the N observations.
xn = Mixture(cat=z,components=components,sample_shape=N)

# -- Variational approximation --
# q(sigma): Wishart whose scale is a learned lower-triangular matrix with
# a softplus-transformed (hence positive) diagonal.
qpsi0 = tf.Variable(tf.random_normal([K, D, D], dtype=tf.float32))
Ltril = LinearOperatorTriL(ds.matrix_diag_transform(qpsi0, transform=tf.nn.softplus)).to_dense()
qsigma = WishartCholesky(df=[100]*K,scale=Ltril) #,cholesky_input_output_matrices=True)


# q(mu): multivariate normal with learned mean and TriL scale.
# NOTE(review): qR is unconstrained (random normal, no positivity on the
# diagonal) yet is used directly as a TriL scale — confirm this is intended.
qmu0 = tf.Variable(tf.random_normal([K, D], dtype=tf.float32))
qR = tf.Variable(tf.random_normal([K, D, D], dtype=tf.float32))
qmu = MultivariateNormalTriL(qmu0, qR)


inference = ed.KLqp({mu: qmu, sigma: qsigma}, data={xn: xn_data})
inference.initialize(n_iter=1000, n_print=100, n_samples=30)
sess = ed.get_session()
init = tf.global_variables_initializer()
init.run()

# -- Training loop --
# Run the KLqp updates, printing diagnostics every 30 steps and recording
# the loss so the optimization curve can be plotted afterwards.
# (Fix: the original used `_` as a *used* loop variable; by convention `_`
# means "unused", so it is renamed to `step`.)
learning_curve = []
for step in range(inference.n_iter):
    info_dict = inference.update()
    if step % 30 == 0:
        print(info_dict)
    learning_curve.append(info_dict['loss'])

# The loss spans orders of magnitude, so use a log-scaled y axis.
plt.semilogy(learning_curve)

# -- Posterior summary and visualization --
# Monte Carlo estimates of the posterior means of mu and sigma
# (average over 100 samples from each variational factor).
post_mu_mean = qmu.sample(100).eval().mean(axis=0)
post_sigma_mean= qsigma.sample(100).eval().mean(axis=0)

plt.figure()
# Green ellipses: true components; red ellipses: inferred components.
# NOTE(review): post_sigma_mean averages qsigma samples, which may be
# Cholesky factors rather than covariances — confirm before interpreting
# the red ellipses quantitatively.
for i in range(K):
  plot_cov_ellipse(true_stds[i], true_mus[i], nstd=3, alpha=0.3, color = "green")
  plot_cov_ellipse(post_sigma_mean[i], post_mu_mean[i], nstd=3, alpha=0.3, color = "red")

plt.scatter(xn_data[:, 0], xn_data[:, 1])
plt.show()
1 Like