Not only observations, but observations and their uncertainties

Dear users,

I'm trying to fit a simple model that has not only observations, but observations and their uncertainties. I've modified the normal_normal example on GitHub to include the uncertainties, and I've switched the inference from HMC to KLqp. Is this the correct way to do it? The problem I see in my code is the following line:

```python
x_obs = Normal(loc=x, scale=x_s)
```

because its shape grows with N, which affects KLqp. For instance, if N=100000, then

```
>>> x_obs.shape
TensorShape([Dimension(100000)])
```
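The `[N]` shape is the batch shape of the likelihood: `x_obs` holds one Normal term per data point, so evaluating its log-density costs O(N) per KLqp step, while `qmu` still has only two variational parameters. If the intended model is the usual measurement-error one, where each observation has its own true value x_i ~ Normal(mu, 1) measured with known noise s_i (an assumption; the posted code shares one scalar `x`), then each x_i can be integrated out analytically: obs_i ~ Normal(mu, sqrt(1 + s_i^2)). The NumPy sketch below evaluates that marginal log-likelihood; `marginal_log_lik` is a hypothetical helper, not an Edward function.

```python
import numpy as np


def marginal_log_lik(mu, x_obs, x_s):
  """Log-likelihood of mu with the per-observation true values integrated out.

  Assumed model (hypothetical, not the posted code):
    x_i   ~ Normal(mu, 1)
    obs_i ~ Normal(x_i, x_s[i])
  which implies obs_i ~ Normal(mu, sqrt(1 + x_s[i]**2)).
  """
  x_obs = np.asarray(x_obs, dtype=np.float64)
  var = 1.0 + np.asarray(x_s, dtype=np.float64) ** 2  # total variance per point
  # Sum of N independent Normal log-densities: the cost is linear in N.
  terms = -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (x_obs - mu) ** 2 / var
  return float(np.sum(terms))


# Usage with data shaped like the post's (all-zero observations).
rng = np.random.default_rng(42)
N = 100000
x_data = np.zeros(N)
x_uncert = np.abs(rng.normal(size=N))  # scales must be positive
print(marginal_log_lik(0.0, x_data, x_uncert))
```

Because no per-observation latent variable remains, only `mu` needs an approximate posterior, regardless of N.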

Maybe it's better to split the dataset into minibatches, as @dustin does in the Data Subsampling section? What do you think?
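The idea behind data subsampling is that a log-likelihood summed over a random minibatch of M out of N points, multiplied by N/M, is an unbiased estimate of the full sum, so each optimization step touches only M points. In Edward this rescaling is what the `scale` argument of `inference.initialize` is for (the Data Subsampling tutorial passes something like `scale={x: float(N) / M}`; check the docs for your version). A minimal NumPy sketch of the estimator itself, assuming the measurement-error model obs_i ~ Normal(mu, sqrt(1 + s_i^2)); `full_log_lik` and `subsampled_log_lik` are illustrative names, not library API:

```python
import numpy as np


def normal_log_density(obs, mu, var):
  # Elementwise log N(obs | mu, var).
  return -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (obs - mu) ** 2 / var


def full_log_lik(mu, x_obs, x_s):
  # Exact sum over all N observations (assumed marginal model).
  var = 1.0 + np.asarray(x_s, dtype=np.float64) ** 2
  return float(np.sum(normal_log_density(np.asarray(x_obs, dtype=np.float64), mu, var)))


def subsampled_log_lik(mu, x_obs, x_s, batch_size, rng):
  """Unbiased minibatch estimate of full_log_lik (hypothetical helper)."""
  x_obs = np.asarray(x_obs, dtype=np.float64)
  x_s = np.asarray(x_s, dtype=np.float64)
  n = x_obs.shape[0]
  # Draw a minibatch without replacement, as in one step of stochastic VI.
  idx = rng.choice(n, size=batch_size, replace=False)
  var = 1.0 + x_s[idx] ** 2
  batch_sum = np.sum(normal_log_density(x_obs[idx], mu, var))
  # Rescale by N / M so the estimate matches the full sum in expectation.
  return float(n / batch_size * batch_sum)


rng = np.random.default_rng(0)
N, M = 100000, 1000
x_data = np.zeros(N)
x_uncert = np.abs(rng.normal(size=N))
print(full_log_lik(0.0, x_data, x_uncert))
print(subsampled_log_lik(0.0, x_data, x_uncert, M, rng))
```

The two printed values should be close but not identical; the minibatch estimate is noisy, and that noise is the price paid for the O(M) cost per step.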

The modified normal_normal.py example:

```python
"""Normal-normal model with known per-observation uncertainties, fit with KLqp."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from edward.models import Normal


def main(_):
  ed.set_seed(42)

  # DATA
  N = 100000
  x_data = np.zeros(N, dtype=np.float32)
  # Scales must be positive; np.random.normal alone also returns negative
  # values, which are not valid standard deviations.
  x_uncert = np.abs(np.random.normal(size=N)).astype(np.float32)

  # MODEL: Normal-Normal with known variance, plus measurement noise.
  mu = Normal(loc=0.0, scale=1.0)
  # Note: x is a single scalar shared by all N observations, and it is not
  # passed to KLqp below, so it gets no approximate posterior of its own.
  x = Normal(loc=mu, scale=1.0)

  x_s = tf.placeholder(tf.float32, [N])
  x_obs = Normal(loc=x, scale=x_s)

  # INFERENCE
  qmu = Normal(loc=tf.Variable(tf.random_normal([])),
               scale=tf.nn.softplus(tf.Variable(tf.random_normal([]))))

  # The analytic solution N(loc=0.0, scale=\sqrt{1/51}=0.140) from the
  # original example holds only for its 50 noise-free observations; it does
  # not apply to this model.
  inference = ed.KLqp({mu: qmu}, data={x_obs: x_data, x_s: x_uncert})
  inference.run()

  # CRITICISM
  sess = ed.get_session()
  mean, stddev = sess.run([qmu.mean(), qmu.stddev()])
  print("Inferred posterior mean:")
  print(mean)
  print("Inferred posterior stddev:")
  print(stddev)

  # Visual diagnostics of the fitted approximation.
  samples = sess.run(qmu.sample(1000))

  # Plot histogram.
  plt.hist(samples, bins='auto')
  plt.show()

  # Trace plot. Draws from qmu are independent, so unlike with HMC this is
  # not a convergence check.
  plt.plot(samples)
  plt.show()

if __name__ == "__main__":
  tf.app.run()
```
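The `N(loc=0.0, scale=0.140)` value quoted in the example's comments only describes the original 50-point example, so a closed-form reference helps judge what KLqp returns here. Assuming the measurement-error model with per-observation true values (x_i ~ Normal(mu, 1), obs_i ~ Normal(x_i, s_i), prior mu ~ Normal(0, 1)), the posterior for mu is Normal with precision 1 + sum_i 1/(1 + s_i^2) and mean (sum_i obs_i/(1 + s_i^2)) divided by that precision. With s_i = 0 and N = 50 this reduces to the original example's sqrt(1/51), about 0.140. A NumPy sketch; `posterior_mu` is a hypothetical helper:

```python
import numpy as np


def posterior_mu(x_obs, x_s, prior_scale=1.0):
  """Exact posterior mean and stddev of mu (hypothetical helper).

  Assumed model: mu ~ Normal(0, prior_scale), x_i ~ Normal(mu, 1),
  obs_i ~ Normal(x_i, x_s[i]), with each x_i integrated out.
  """
  x_obs = np.asarray(x_obs, dtype=np.float64)
  var = 1.0 + np.asarray(x_s, dtype=np.float64) ** 2  # marginal variance of obs_i
  precision = 1.0 / prior_scale ** 2 + np.sum(1.0 / var)
  mean = np.sum(x_obs / var) / precision
  return float(mean), float(1.0 / np.sqrt(precision))


# Noise-free check against the original 50-point example:
# mean 0.0, stddev sqrt(1/51), about 0.140.
print(posterior_mu(np.zeros(50), np.zeros(50)))

# Data shaped like the post's.
rng = np.random.default_rng(42)
N = 100000
print(posterior_mu(np.zeros(N), np.abs(rng.normal(size=N))))
```

If KLqp's inferred mean and stddev for `qmu` land near these two numbers, the variational fit is behaving; a large gap would point at the model specification rather than at N.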