Model where last op is not RandomVariable


#1

Sometimes it is convenient to define a variable as a deterministic function of stochastic and deterministic components, i.e. defining:

# Normal Way
y = Normal(loc=foo, scale=bar)

# This Way
y = foo + Normal(loc=0, scale=bar)

The resulting node in the graph is an add op, but logically it is a random variable. However, when I define inference over this model (i.e. assign data to y in the second example), the inference algorithm seems to fail to calculate the log probabilities and doesn't move the variational model from its starting position.
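To spell out why I'd expect the two definitions to be interchangeable: adding a constant is a shift with unit Jacobian, so the log density of y under Normal(foo, bar) should equal the log density of the residual y - foo under Normal(0, bar). Here is a minimal sketch of that check in plain Python rather than Edward; `normal_logpdf` and the numeric values are made up purely for illustration.

```python
import math


def normal_logpdf(value, loc, scale):
    """Log density of a univariate Normal(loc, scale) evaluated at `value`."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


# Hypothetical numbers standing in for `foo`, `bar`, and an observed y.
foo, bar, y_obs = 1.5, 0.7, 2.3

# "Normal way": y ~ Normal(foo, bar), scored directly at the observation.
direct = normal_logpdf(y_obs, loc=foo, scale=bar)

# "This way": y = foo + err with err ~ Normal(0, bar). The shift has unit
# Jacobian, so log p(y) is the density of the implied residual y - foo.
shifted = normal_logpdf(y_obs - foo, loc=0.0, scale=bar)

# Both formulations assign the same log probability to the observation.
assert math.isclose(direct, shifted)
```

So mathematically the log probability is well defined in both cases; my guess is that the difficulty lies in how the add op is treated when data is bound to it, not in the model itself.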

Is there any way to do this (I thought maybe wrapping y in a PointMass)? I'm aware it's bad practice for the toy example, but it would allow me to, for example, define modular additive models with various latent components, which is very useful in practice.

Full failing example, based closely on the linear regression tutorial:

import edward as ed
import numpy as np
import tensorflow as tf

from edward.models import Normal

def build_toy_dataset(N, w, noise_std=0.1):
    D = len(w)
    x = np.random.randn(N, D)
    y = np.dot(x, w) + np.random.normal(0, noise_std, size=N)
    return x, y

N = 40  # number of data points
D = 10  # number of features

w_true = np.random.randn(D)
X_train, y_train = build_toy_dataset(N, w_true)
X_test, y_test = build_toy_dataset(N, w_true)

X = tf.placeholder(tf.float32, [N, D])
w = Normal(loc=tf.zeros(D), scale=tf.ones(D))
b = Normal(loc=tf.zeros(1), scale=tf.ones(1))
## BEGIN CHANGE
# original: y = Normal(loc=ed.dot(X, w) + b, scale=tf.ones(N))
err = Normal(loc=0.0, scale=tf.ones(N))
y = ed.dot(X, w) + b + err
## END CHANGE

qw = Normal(loc=tf.Variable(tf.random_normal([D])),
            scale=tf.nn.softplus(tf.Variable(tf.random_normal([D]))))
qb = Normal(loc=tf.Variable(tf.random_normal([1])),
            scale=tf.nn.softplus(tf.Variable(tf.random_normal([1]))))

inference = ed.KLqp({w: qw, b: qb}, data={X: X_train, y: y_train})
inference.run(n_samples=5, n_iter=250)

#2

These related threads answer your question: