Multinomial on convolutionnal model returns nans (equivalent code with Pyro)


#1

it’s been a month since I try to implement a model inference with several frameworks to compare efficiency. I would like to make the inference work with Edward to see this.

Code in Edward: 

%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow as tf
from edward.models import (Categorical, Dirichlet, Empirical, InverseGamma, MultivariateNormalDiag, Normal, ParamMixture, 
                           PointMass, Multinomial)
import edward as ed
import math
import matplotlib.cm as cm
from matplotlib import pyplot as plt
import numpy as np
# from PIL import Image
# from PIL import ImageDraw
# from PIL import ImageFont
from random import randint

def roll(tenseur):
    tenseur=tf.transpose(tenseur, perm=[0,3,2,1])
    tenseur=tf.squeeze(tenseur,[0])
    tenseur=tf.concat([tenseur,tf.zeros([tenseur.shape[0],tenseur.shape[1],tenseur.shape[0]-1])],axis=2)
    T=tf.zeros([0,tenseur.shape[1],tenseur.shape[2]])
    for i in range(tenseur.shape[0]):
        rectangle=tf.expand_dims(tf.manip.roll(tenseur[i,:,:], shift=i, axis=1),0)
        T = tf.concat([T,rectangle],axis=0)
    T=tf.reduce_sum(T,0)
    return T   

A=tf.constant([[1.,1.,0.,0.,0.],[0.,4.,30.,1.,0.]])
B=tf.constant([[2.,1.,0.,0.,3.],[4.,5.,0.,0.,0.]])
C=tf.constant([[1.,1.,20.,10.,0.],[0.,4.,1.,1.,0.]])
A=A/tf.reduce_sum(A)
B=B/tf.reduce_sum(B)
C=C/tf.reduce_sum(C)

true_motifs=tf.stack([A,B,C],0)
true_motifs=tf.transpose(true_motifs,[1,2,0])
true_motifs=tf.expand_dims(true_motifs,0)

z=np.zeros((1,1,3,10))
z[0,0,0,0]=1
z[0,0,1,4]=1
z[0,0,2,6]=1
z=z/np.sum(z)
z = tf.constant(z, dtype=tf.float32)

motifs = Dirichlet(tf.ones([1,2,5,3]), name='motifs')
qmotifs = Dirichlet(tf.nn.softplus(tf.Variable(tf.ones([1,2,5,3]))), name="qmotifs")

results = tf.nn.conv2d(true_motifs,z,padding="VALID",strides=[1,1,1,1])
results=roll(results)
results=tf.reshape(results,[55])

def p_w_ta_d(motifs):
    t = tf.nn.conv2d(motifs,z,padding="VALID",strides=[1,1,1,1])
    t = roll(t)
    t = tf.reshape(t,[55])
    return t

data = Multinomial(3000.,probs=results)
w = Multinomial(3000.,probs=p_w_ta_d(motifs))

inference = ed.KLqp({motifs: qmotifs}, data={w: data})
inference.initialize(optimizer=tf.train.AdamOptimizer(learning_rate=0.001,beta1=0.9,beta2=0.999,epsilon=1e-08,use_locking=False))
inference.n_iter = 800

session = ed.get_session()

for _ in range(inference.n_iter):
    info_dict = inference.update()
    inference.print_progress(info_dict)
    print("qmotifs mean:")
    print(session.run(qmotifs.mean()))
    print("______________")
    print("qmotifs:")
    print(session.run(qmotifs))
    
inference.finalize()
tf.global_variables_initializer().run()

That returns me nans on qmotifs.mean(). If I replace multinomial by Categorical(…, sample_shape=[1,3000]) which should gives me slighly same results as a Multinomial(3000., …), then qmotif.mean() stay on the same values.

Code in Pyro that works both with Multinomial/Categorical. They have an option “constraint=constraints.positive” that prevent nans, I guess, otherwise the code is strictly equivalent:

import time
import math
import matplotlib.cm as cm
from matplotlib import pyplot as plt
import numpy as np
# from PIL import Image
# from PIL import ImageDraw
# from PIL import ImageFont
from random import randint

from __future__ import print_function
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import pyro
from pyro.optim import Adam
from pyro.infer import SVI
import pyro.distributions as pdist
import torch.distributions as tdist
import torch.distributions.constraints as constraints
import pyro.infer
from pyro.infer import SVI, Trace_ELBO
import pyro.optim

%matplotlib inline
# import some dependencies
try:
    import seaborn as sns
    sns.set()
except ImportError:
    pass

torch.manual_seed(101)

from IPython.core.debugger import set_trace
softplus = torch.nn.Softplus()

def roll(doc):
    doc=doc.view(10,5,2)
    doc = torch.cat((doc, torch.zeros(doc.shape[0],doc.shape[1],doc.shape[0]-1)),dim=2)
    for i in range(doc.shape[0]):
        for j in range(doc.shape[1]):
            doc[i,j,:]=torch.cat((doc[i,j,-i:],doc[i,j,0:-i]),dim=0)
    doc = torch.sum(doc,0)
    return doc

def normalize_t(tensor):
    return tensor/tensor.norm(1)

A=torch.Tensor([[1,1,0,0,0],[0,4,30,1,0]])
B=torch.Tensor([[2,1,0,0,3],[4,5,0,0,0]])
C=torch.Tensor([[1,1,20,10,0],[0,4,1,1,0]])

motif1=normalize_t(A)
motif2=normalize_t(B)
motif3=normalize_t(C)

z=torch.zeros(10,3,1,1)
z[0,0,0,0]=1
z[4,1,0,0]=1
z[6,2,0,0]=1
z=normalize_t(z)
        
motifs = torch.zeros(1,3,2,5)
motifs[0,0,:,:]=motif1
motifs[0,1,]=motif2
motifs[0,2,:,:]=motif3

results = F.conv2d(motifs,z)
results=roll(results)
results=results.view(1,55)

def p_w_ta_d(motifs):
    t = F.conv2d(motifs,z)
    t = roll(t)
    t =  t.view(1,55)
    return t

pyro.clear_param_store()

def model(data):
    alpha = torch.ones(1,3,2,5)
    mots = pyro.sample("latent",pdist.Dirichlet(concentration=alpha))
    with pyro.iarange("data",len(data)): 
        pyro.sample("observe", pdist.Multinomial(3000,probs=p_w_ta_d(mots)),obs=data)

def guide(data):
    qalpha = pyro.param("qalpha", torch.ones(1,3,2,5),constraint=constraints.positive)
    pyro.sample("latent",pdist.Dirichlet(concentration=qalpha))

adam_params = {"lr": 0.0005, "betas": (0.9, 0.999)}
optimizer = pyro.optim.Adam(adam_params)

svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 20000
for step in range(n_steps):
    svi.step(data)
    if step % 100 == 0:
        print(step)
        print(pyro.param("qalpha"))
        print("___________________")

Any ideas ? ^^


#2

This is the same when I’m doing more simply:

data <- multinomial(tf.nn.conv2d(input, filter))
parameter <- Dirichlet(…)
qparameter <- Dirichlet(…)
latent <- tf.nn.conv2d(parameter, filter)
w <- multinomial(latent)
inference({parameter:qparameter}, {w:data})

without the “complex” roll function, it is then a basic convolution.

I’m missing something ?