I define the prior as w = Normal(loc=tf.zeros([self.n_in, self.n_out]), scale=2.0 * tf.ones([self.n_in, self.n_out])), so w.shape is TensorShape([Dimension(3), Dimension(2)])
I also define the posterior as qw=Empirical(params=tf.Variable(np.zeros([self.n_in, self.n_out]))), then qw.shape becomes TensorShape([Dimension(2)])
When I try inference = ed.SGHMC({self.w: self.qw}, data={self.x: X_train}), there is a TypeError: Key-value pair in latent_vars does not have same shape: (3, 2), (2,)
Thank you, dustin. So in the case of Probabilistic PCA, I have to modify the code as below to do SGHMC inference. The syntaxe" qz=Empirical(params=tf.Variable(tf.zeros([N, N, K])))" have to include two N, both mean the number of samples? It looks somewhat weired and the results of SGHMC looks not correct
N = 5000 # number of data points
D = 2 # data dimensionality
K = 1 # latent dimensionality
x_train = build_toy_dataset(N, D, K)
w = Normal(loc=tf.zeros([D, K]), scale=2.0 * tf.ones([D, K]))
z = Normal(loc=tf.zeros([N, K]), scale=tf.ones([N, K]))
x = Normal(loc=tf.matmul(w, z, transpose_b=True), scale=tf.ones([D, N]))
#For SGHMC inference
qw=Empirical(params=tf.Variable(tf.zeros([N, D, K])))
qz=Empirical(params=tf.Variable(tf.zeros([N, N, K])))
inference = ed.SGHMC({w: qw, z:qz}, data={x: x_train})
inference.run(step_size=1e-3)
#For KLqp inference
#qw = Normal(loc=tf.Variable(tf.random_normal([D, K])),
# scale=tf.nn.softplus(tf.Variable(tf.random_normal([D, K]))))
#qz = Normal(loc=tf.Variable(tf.random_normal([N, K])),
# scale=tf.nn.softplus(tf.Variable(tf.random_normal([N, K]))))
#inference = ed.KLqp({w: qw, z: qz}, data={x: x_train})
#inference.run(n_iter=500, n_print=100, n_samples=10)
w_post_mean=self.qw.mean().eval()
The outer dimension refers to the number of desired posterior samples. It has no relation to the number of observations.
re:SGHMC. SGMCMC algorithms can be very finicky. Have you tried tuning them and using longer chains (i.e., more posterior samples)? If you’re not doing data subsampling, ed.HMC is likely a better choice.