Using RelaxedBernoulli for a Zero-Inflated Poisson model

Hi Dustin and members of the Edward forum,

I have implemented a zero-inflated Poisson model using variational inference. I replaced the Bernoulli distribution with the RelaxedBernoulli distribution, which has a continuous density and therefore works with KLqp's reparameterization gradients. Inspecting the parameter updates, I find that the best estimates sometimes occur partway through the run rather than at the end. When I rerun the inference, I also occasionally get NaNs for the loss. Does anyone have a suggestion on how to make the optimization more stable? (I've put one untested idea of my own after the code.)
I’ve posted a copy of my code below:

import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from edward.models import Beta, Gamma, Poisson, RelaxedBernoulli

# Generate data

n = 100
pi_true = 0.6   # probability of a structural (inflated) zero
mu_true = 10    # Poisson mean for the non-inflated counts
np.random.seed(1335)
z_true = np.random.binomial(1, pi_true, n)   # z = 1 marks a structural zero
x_obs = np.random.poisson(lam=np.multiply(1 - z_true, mu_true))

# Model

pi = Beta(30.0, 20.0)   # prior on the zero-inflation probability
mu = Gamma(20.0, 2.0)   # prior on the Poisson rate
z = RelaxedBernoulli(temperature=0.01, probs=pi, sample_shape=n)
x = Poisson(rate=tf.multiply(mu, 1.0 - z))   # rate approaches 0 as z approaches 1

# Variational Inference

qpi = Beta(tf.nn.softplus(tf.Variable(tf.random_normal([]))),
           tf.nn.softplus(tf.Variable(tf.random_normal([]))))

qmu = Gamma(tf.nn.softplus(tf.Variable(tf.random_normal([]))),
            tf.nn.softplus(tf.Variable(tf.random_normal([]))))

# Note: RelaxedBernoulli's second positional argument is `logits`, not `probs`,
# so the probability must be passed by keyword (squashed to (0, 1) by a sigmoid).
qz = RelaxedBernoulli(tf.nn.softplus(tf.Variable(tf.random_normal([]))),
                      probs=tf.sigmoid(tf.Variable(tf.random_normal([]))),
                      sample_shape=n)

inference = ed.KLqp({pi: qpi, mu: qmu, z: qz}, data={x: x_obs})
inference.initialize(n_samples=3, n_print=100, n_iter=1000)

sess = ed.get_session()
tf.global_variables_initializer().run()
for _ in range(inference.n_iter):
    info_dict = inference.update()
    inference.print_progress(info_dict)
    t = info_dict['t']
    if t == 1 or t % inference.n_print == 0:
        qpi_mean, qmu_mean = sess.run([qpi.mean(), qmu.mean()])
        print("")
        print("Inferred probability of a zero-count:")
        print(qpi_mean)
        print("")
        print("Inferred Poisson mean:")
        print(qmu_mean)
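
One direction I have been considering but have not verified (the 1e-5 jitter on the rate, the optimizer choice, and the best-so-far bookkeeping are all guesses on my part):

# Untested stabilization ideas:
#   1. keep the Poisson rate strictly positive so its log can't blow up,
#   2. pass a lower-learning-rate optimizer to inference.initialize,
#   3. keep a copy of the best (lowest-loss) estimates seen during the run.
x = Poisson(rate=tf.multiply(mu, 1.0 - z) + 1e-5)

inference = ed.KLqp({pi: qpi, mu: qmu, z: qz}, data={x: x_obs})
inference.initialize(n_samples=10, n_print=100, n_iter=1000,
                     optimizer=tf.train.AdamOptimizer(1e-2))

sess = ed.get_session()
tf.global_variables_initializer().run()

best_loss = np.inf
for _ in range(inference.n_iter):
    info_dict = inference.update()
    inference.print_progress(info_dict)
    loss = info_dict['loss']
    if np.isfinite(loss) and loss < best_loss:
        best_loss = loss
        best_pi, best_mu = sess.run([qpi.mean(), qmu.mean()])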

Try replacing the Gamma with a log-normal distribution and let's check the new results.
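
For example (an untested sketch on my end; it assumes Edward's TransformedDistribution wrapper and the Exp bijector from tf.contrib.distributions):

from edward.models import Normal, TransformedDistribution

# Log-normal variational family for mu: exp of a learnable Normal.
qmu = TransformedDistribution(
    distribution=Normal(tf.Variable(tf.random_normal([])),
                        tf.nn.softplus(tf.Variable(tf.random_normal([])))),
    bijector=tf.contrib.distributions.bijectors.Exp())

One caveat: mean() is not defined for a TransformedDistribution, so the progress printout would need to report the underlying Normal's parameters (or a Monte Carlo average of samples) instead of qmu.mean().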

Hi Ermia,

The RelaxedBernoulli is unstable even with the log-normal distribution. The plain Bernoulli with MAP inference performs better.
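
A minimal sketch of a MAP setup along those lines (this version sums the Bernoulli z out analytically and optimizes pi and mu directly in TensorFlow; treat the details as assumptions rather than my exact code):

import numpy as np
import tensorflow as tf

x_data = tf.constant(x_obs, dtype=tf.float32)

# Unconstrained parameters: pi = sigmoid(pi_logit), mu = exp(log_mu).
pi_logit = tf.Variable(0.0)
log_mu = tf.Variable(1.0)
pi_hat = tf.sigmoid(pi_logit)
mu_hat = tf.exp(log_mu)

# Zero-inflated Poisson log-likelihood with z summed out:
#   x = 0: log(pi + (1 - pi) * exp(-mu))
#   x > 0: log(1 - pi) + x * log(mu) - mu - lgamma(x + 1)
pois_lp = x_data * tf.log(mu_hat) - mu_hat - tf.lgamma(x_data + 1.0)
zero_lp = tf.log(pi_hat + (1.0 - pi_hat) * tf.exp(-mu_hat)) + tf.zeros_like(x_data)
nonzero_lp = tf.log(1.0 - pi_hat) + pois_lp
log_lik = tf.reduce_sum(tf.where(tf.equal(x_data, 0.0), zero_lp, nonzero_lp))

# Same priors as the original model: Beta(30, 20) on pi and Gamma(20, 2) on mu.
log_prior = (29.0 * tf.log(pi_hat) + 19.0 * tf.log(1.0 - pi_hat)
             + 19.0 * tf.log(mu_hat) - 2.0 * mu_hat)

train_op = tf.train.AdamOptimizer(0.1).minimize(-(log_lik + log_prior))

with tf.Session() as s:
    s.run(tf.global_variables_initializer())
    for _ in range(500):
        s.run(train_op)
    print(s.run([pi_hat, mu_hat]))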

Best,
Mark
