Error implementing a simple mixture of Gaussians


#1

I get the following error when running the code below to learn a simple mixture of Gaussians. Can anybody help me out, please?

Received a label value of 1 which is outside the valid range of [0, 2).  Label values: 2 2 2 2...

The piece of code is quite simple:

%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import six
import tensorflow as tf

from edward.models import (
    Categorical, Dirichlet, Empirical, InverseGamma,
    MultivariateNormalDiag, Normal, ParamMixture, Multinomial)

plt.style.use('ggplot')

def build_toy_dataset(N):
  pi = np.array([0.4, 0.6])
  mus = [1, -1]
  stds = [0.1, 0.1]
  x = np.zeros(N, dtype=np.float32)
  for n in range(N):
    k = np.argmax(np.random.multinomial(1, pi))
    x[n] = np.random.normal(mus[k], stds[k])

  return x


N = 500  # number of data points
K = 2  # number of components
D = 1  # dimensionality of data
#ed.set_seed(42)

x_train = build_toy_dataset(N)

pi = Dirichlet(np.zeros(K, dtype=np.float32))
mu = Normal(loc = 0.0, scale = 1.0, sample_shape = K)
sigma = InverseGamma(1.0, 2.0, sample_shape=K)

x = ParamMixture(pi, {'loc': mu, 'scale': tf.sqrt(sigma)},
                 Normal,
                 sample_shape=N)
z = x.cat

qpi = Dirichlet(tf.Variable(tf.ones(K) / K))
qmu = Normal(loc = tf.Variable(0.0), scale = tf.nn.softplus(tf.Variable(0.0)), sample_shape = K)
qsigma = InverseGamma(tf.Variable(2.0), tf.Variable(2.0), sample_shape = K)

qz = Categorical(logits = tf.Variable(tf.zeros([N,K])))

print(sigma)
print(qsigma)
print(z)
print(qz)


inference = ed.ScoreKLqp({pi: qpi, mu: qmu, sigma: qsigma, z: qz},
                     data={x: x_train})
inference.initialize()

sess = ed.get_session()
tf.global_variables_initializer().run()

for _ in range(inference.n_iter):
  info_dict = inference.update()
  inference.print_progress(info_dict)