Using RelaxedBernoulli for a Zero-Inflated Poisson model

Hi Dustin and members of the Edward forum,

I have implemented a zero-inflated Poisson model using variational inference. I replaced the Bernoulli distribution with the relaxed Bernoulli distribution, a continuous relaxation whose density is defined on (0, 1). When I inspect the parameter updates, the best estimates sometimes occur partway through the run rather than at the end. When I rerun the inference, I occasionally get NaN values for the loss. Does anyone have suggestions on how to make the inference more stable?
I’ve posted a copy of my code below:

import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from edward.models import Beta, Gamma, Poisson, RelaxedBernoulli

# Generate data

n = 100
pi_true = 0.6
mu_true = 10
z_true = np.random.binomial(1, pi_true, n)
x_obs = np.random.poisson(lam=np.multiply(1 - z_true, mu_true))

# Model
pi = Beta(30.0, 20.0)
mu = Gamma(20.0, 2.0)
z = RelaxedBernoulli(0.01, probs=pi, sample_shape=n)
x = Poisson(rate=tf.multiply(mu, 1.0-z))

# Variational Inference

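qpi = Beta(tf.nn.softplus(tf.Variable(tf.random_normal([]))),
           tf.nn.softplus(tf.Variable(tf.random_normal([]))))
qmu = Gamma(tf.nn.softplus(tf.Variable(tf.random_normal([]))),
            tf.nn.softplus(tf.Variable(tf.random_normal([]))))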
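# One unconstrained logit per observation, so each z_i can get its own
# posterior probability. RelaxedBernoulli's second positional argument is
# `logits`, so passing it through softplus would force every probability above 0.5.
qz = RelaxedBernoulli(tf.nn.softplus(tf.Variable(tf.random_normal([]))),
                      logits=tf.Variable(tf.random_normal([n])))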

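inference = ed.KLqp({pi: qpi, mu: qmu, z: qz}, data={x: x_obs})
inference.initialize(n_samples=3, n_print=100, n_iter=1000)

sess = ed.get_session()
tf.global_variables_initializer().run()  # variables must be initialized before update()
for _ in range(inference.n_iter):
    info_dict = inference.update()
    t = info_dict['t']
    if t == 1 or t % inference.n_print == 0:
        qpi_mean, qmu_mean = sess.run([qpi.mean(), qmu.mean()])
        print("Iteration {}: loss = {:.3f}".format(t, info_dict['loss']))
        print("Inferred probability of a zero-count:", qpi_mean)
        print("Inferred Poisson mean:", qmu_mean)
inference.finalize()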
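One plausible source of the NaN losses (an interpretation, not something confirmed inside Edward): with a temperature of 0.01, relaxed Bernoulli samples y = sigmoid((logit(p) + L) / temperature), with L standard logistic noise, are so sharp that in float32 (the precision these Edward tensors use) many of them round to exactly 0.0 or 1.0. The relaxed Bernoulli log-density involves log(y) and log(1 - y), which are infinite at those endpoints. Likewise, when z is exactly 1.0 the Poisson rate mu * (1.0 - z) becomes 0, and if the Poisson log-probability is computed naively as x*log(rate) - rate - log(x!), a positive count gives -inf and a zero count gives 0 * (-inf) = NaN. The NumPy sketch below reproduces both effects outside TensorFlow; the helper names are hypothetical.

```python
import math
import numpy as np


def relaxed_bernoulli_sample(p, temperature, size, rng):
    """Float32 draws from a relaxed (Concrete) Bernoulli with success prob p.

    y = sigmoid((logit(p) + L) / temperature), where L ~ standard logistic.
    """
    with np.errstate(divide="ignore", over="ignore"):
        u = rng.uniform(size=size).astype(np.float32)
        logistic_noise = np.log(u) - np.log1p(-u)
        logit_p = np.float32(math.log(p / (1.0 - p)))
        scaled = (logit_p + logistic_noise) / np.float32(temperature)
        # Plain float32 sigmoid: exp() overflows to inf for very negative inputs.
        return (1.0 / (1.0 + np.exp(-scaled))).astype(np.float32)


def naive_poisson_logpmf(x, rate):
    """log Poisson(x | rate) computed naively as x*log(rate) - rate - log(x!)."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return x * np.log(rate) - rate - math.lgamma(x + 1.0)


rng = np.random.default_rng(0)
saturated = {}
for temperature in (0.01, 1.0):
    y = relaxed_bernoulli_sample(0.6, temperature, 10_000, rng)
    # Fraction of samples that are exactly 0.0 or 1.0 in float32.
    saturated[temperature] = float(np.mean((y == 0.0) | (y == 1.0)))
    print(f"temperature={temperature}: saturated fraction = {saturated[temperature]:.3f}")

mu = 10.0
z = np.float32(1.0)           # a fully saturated relaxed sample
rate = mu * (1.0 - float(z))  # the model's Poisson rate, now exactly 0
print(naive_poisson_logpmf(5, rate))  # -inf: positive count under rate 0
print(naive_poisson_logpmf(0, rate))  # nan: 0 * log(0) in the naive formula
```

If this is what is happening, things worth trying include a higher prior temperature than 0.01, keeping z away from exactly 0 and 1, or adding a small constant to the Poisson rate; none of these has been verified in Edward here.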

Try replacing the Gamma with a log-normal distribution, and let's check the new results.
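For context on this suggestion: a log-normal factor for mu can be sampled as mu = exp(m + s * eps) with eps ~ N(0, 1), so every draw is positive and is a simple differentiable function of the free parameters m and s, which is one commonly cited practical reason to prefer it. The NumPy sketch below only checks that construction (m and s are illustrative names and values); it does not show how the distribution is wired into Edward.

```python
import numpy as np


def sample_lognormal(m, s, size, rng):
    """Reparameterized log-normal draws: exp(m + s * eps), eps ~ N(0, 1)."""
    eps = rng.standard_normal(size)
    return np.exp(m + s * eps)


rng = np.random.default_rng(1)
m, s = np.log(10.0), 0.1  # location and scale on the log scale (illustrative)
mu_samples = sample_lognormal(m, s, 100_000, rng)

# For a log-normal, the median is exp(m) and the mean is exp(m + s**2 / 2).
print(np.median(mu_samples), np.exp(m))
print(mu_samples.mean(), np.exp(m + s**2 / 2))
```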

Hi Ermia,

The relaxed Bernoulli is still unstable, even with the log-normal distribution. Using the discrete Bernoulli with MAP inference performs better.
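For comparison with the discrete-Bernoulli route, here is a sketch of a different technique from Edward's MAP inference but in the same spirit: expectation-maximization, which keeps z discrete. Only zero counts can be structural zeros, so the E-step computes P(z_i = 1 | x_i = 0) = pi / (pi + (1 - pi) * exp(-mu)) and sets it to 0 for positive counts; the M-step then has closed-form MAP updates for pi and mu under the Beta(30, 20) and Gamma(20, 2) priors from the post (reading Gamma's second argument as a rate, as in Edward/TensorFlow). This NumPy code is an illustration under those assumptions, not the code referred to in this reply.

```python
import numpy as np


def zip_map_em(x, a=30.0, b=20.0, alpha=20.0, beta=2.0, n_iter=200):
    """MAP estimates of (pi, mu) for a zero-inflated Poisson via EM.

    Priors: pi ~ Beta(a, b), mu ~ Gamma(alpha, rate=beta).
    pi is the probability of a structural zero (z = 1 in the model above).
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    pi, mu = 0.5, max(x.mean(), 1e-3)  # simple starting values
    for _ in range(n_iter):
        # E-step: responsibility r_i = P(z_i = 1 | x_i).
        p_zero = pi + (1.0 - pi) * np.exp(-mu)
        r = np.where(x == 0, pi / p_zero, 0.0)
        # M-step: closed-form maximizers of the expected log joint, priors included.
        pi = (r.sum() + a - 1.0) / (n + a + b - 2.0)
        mu = (((1.0 - r) * x).sum() + alpha - 1.0) / ((1.0 - r).sum() + beta)
    return pi, mu


# Data generated the same way as in the original post.
rng = np.random.default_rng(0)
n, pi_true, mu_true = 100, 0.6, 10
z_true = rng.binomial(1, pi_true, n)
x_obs = rng.poisson((1 - z_true) * mu_true)

pi_hat, mu_hat = zip_map_em(x_obs)
print(f"MAP pi = {pi_hat:.3f}, MAP mu = {mu_hat:.3f}")
```

Because every update is in closed form, there is no learning rate, temperature, or sampling noise involved, which may be part of why a discrete formulation behaves more predictably than the relaxed one.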

