Correct way to evaluate linear regression predictive variance with learned alpha


#1

In the standard Bayesian linear regression tutorial, the weight variances are fixed.
I have a model in which they are learned as well: weight_variance = 1/alpha, where alpha is the inverse variance (precision), modelled in Edward with a gamma prior and a lognormal variational distribution qalpha.

W = Normal(loc=tf.zeros([D, 1]), scale=tf.ones([D, 1]) * tf.sqrt(1.0 / alpha))  # scale is a standard deviation, hence the sqrt of the variance 1/alpha

Using the latent dictionary pair

{alpha:qalpha}

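When evaluating the model, one typically wants the predictive variance at each test point, which from Murphy is

pred_var(x) = 1/beta + x^T V x

where beta is the noise precision (not the weight precision alpha), so 1/beta is the observation noise variance, and V is the posterior covariance of the linear regression weights W.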
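For reference, here is a minimal NumPy sketch of what that formula gives in the fully conjugate case, assuming a fixed weight-prior precision and a fixed noise precision (called alpha and beta in the code). The function name, the toy data and the parameter values are all made up for illustration and are not taken from the model in this post. It shows the behaviour one would expect to see: the variance is smallest where training points are dense and grows away from them.

```python
import numpy as np

def predictive_variance(X_train, X_test, alpha, beta):
    """Closed-form predictive variance 1/beta + x^T V x for fixed alpha and beta."""
    D = X_train.shape[1]
    # Posterior covariance of the weights under a N(0, I/alpha) prior
    # and Gaussian observation noise with precision beta.
    V = np.linalg.inv(alpha * np.eye(D) + beta * X_train.T @ X_train)
    # Row-wise quadratic form x^T V x for every test point.
    quad = np.einsum("nd,de,ne->n", X_test, V, X_test)
    return 1.0 / beta + quad

rng = np.random.default_rng(0)
# Hypothetical training inputs clustered near x = 0, with a bias column.
x_train = rng.normal(0.0, 0.5, size=50)
X_train = np.column_stack([np.ones_like(x_train), x_train])

# One test point inside the training cluster and one far away from it.
x_grid = np.array([0.0, 5.0])
X_test = np.column_stack([np.ones_like(x_grid), x_grid])

var_near, var_far = predictive_variance(X_train, X_test, alpha=1.0, beta=4.0)
```

If a variational approximation produces a variance curve that looks nothing like this shape, comparing against the closed form on a small problem can help locate the discrepancy.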
My model is getting a nice predictive mean by running

sess.run(Y_post, feed_dict={X: X_test})

and taking the mean over multiple samples, or simply by evaluating the means of the variational distributions (their mean() methods).

However, I’m having great difficulty getting anything meaningful from the variance. When I plot it over a dense grid of test points it is incredibly noisy, and it does not appear to reflect where many training points were available and where there were none. Could someone check whether what I’m doing is correct, and suggest how to improve it?

Note:

  • The tf.diag is just for completeness, in case one were to use a multivariate qW. Since qW ~ Normal(), we only really have the diagonal components.
  • The reduce_sum applies this across multiple test points and replaces a per-test-point double dot product.
  • X is an (N, D) feature matrix.
  • Because qbeta is a lognormal TransformedDistribution, in the actual code it is evaluated like

qbetamean = qbeta.bijector._forward_fn(qbeta.distribution.mean())
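One thing that may be worth checking here: if the bijector is an exponential (an assumption, since the post does not show how qbeta is built), pushing the mean of the underlying normal through the forward function gives the median of the lognormal, not its mean. The sketch below uses hypothetical values mu and sigma for the underlying normal and compares the two quantities with a Monte Carlo estimate.

```python
import numpy as np

# Hypothetical parameters of the underlying normal: log(beta) ~ Normal(mu, sigma).
mu, sigma = 0.5, 0.8

# Pushing the normal's mean through exp gives the lognormal's median.
forward_of_mean = np.exp(mu)

# The lognormal's actual mean includes a variance correction.
lognormal_mean = np.exp(mu + 0.5 * sigma ** 2)

# Monte Carlo estimate of the mean, for comparison.
rng = np.random.default_rng(0)
samples = np.exp(rng.normal(mu, sigma, size=200_000))
mc_mean = samples.mean()
```

The gap between the two grows with sigma. Note also that the noise variance term averages 1/beta, and for a lognormal E[1/beta] = exp(-mu + sigma^2/2), which in general differs from 1/E[beta].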

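# Mean of the variational distribution qbeta
qbetamean = self.qbeta.mean()
# Diagonal covariance matrix built from the variational weight variances
V = tf.diag(tf.reshape(qW.variance(), [-1]))
# Row-wise x^T V x for every test point (row) in X
dot1 = tf.matmul(X, V)
mult1 = tf.multiply(dot1, X)
phi_var_phi = tf.reduce_sum(mult1, axis=1, keep_dims=True)
# First term of pred_var plus the x^T V x term
predictive_variance = 1.0 / qbetamean + phi_var_phi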
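As a sanity check on the linear algebra, the matmul / multiply / reduce_sum steps can be reproduced in plain NumPy and compared against an explicit per-row x^T V x. This is only a sketch with random stand-in data (w_var plays the role of qW.variance(); the sizes are arbitrary), not the TensorFlow code itself.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 5, 3
X = rng.normal(size=(N, D))
w_var = rng.uniform(0.1, 1.0, size=D)  # stand-in for qW.variance()

# Same steps as the snippet: diagonal matrix, matmul, elementwise product, row sum.
V = np.diag(w_var)
phi_var_phi = np.sum((X @ V) * X, axis=1, keepdims=True)

# Explicit per-test-point double dot product x^T V x.
explicit = np.array([[x @ V @ x] for x in X])

# Shortcut for a diagonal V: weight the squared features by the variances.
shortcut = (X ** 2) @ w_var.reshape(-1, 1)
```

All three agree, so the quadratic-form computation itself looks fine; with a diagonal V the shortcut avoids building the D x D matrix.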


#2

Perhaps you could link to your code, as it's a bit unclear exactly what you are doing. For example, what is Y_post?

 sess.run(Y_post, feed_dict={X: X_test})

Assuming this is more or less the same as the tutorial, you seem to be doing just a plug-in approximation.
To get the full predictive distribution, you need to take the likelihood at the new points, multiply it by the posterior of w, and integrate out w.
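A rough way to approximate that integral is Monte Carlo: draw w from the (approximate) posterior, draw y from the likelihood given each w, and look at the spread of the resulting y values. The NumPy sketch below assumes a factorised Gaussian posterior with made-up means and standard deviations and, for simplicity, a fixed noise precision beta; none of these numbers come from the model in this thread.

```python
import numpy as np

rng = np.random.default_rng(0)
S = 100_000  # number of posterior samples

# Hypothetical factorised Gaussian posterior over two weights.
w_mean = np.array([1.0, -2.0])
w_std = np.array([0.3, 0.5])
beta = 4.0  # noise precision, held fixed here for simplicity

x_star = np.array([1.0, 2.0])  # a single test point

# Draw weights from the posterior, then draw y* from the likelihood for each draw.
w_samples = w_mean + w_std * rng.normal(size=(S, 2))
y_samples = w_samples @ x_star + rng.normal(scale=np.sqrt(1.0 / beta), size=S)

mc_mean = y_samples.mean()
mc_var = y_samples.var()

# Closed form for comparison: 1/beta + x^T V x with V = diag(w_std**2).
exact_var = 1.0 / beta + np.sum(x_star ** 2 * w_std ** 2)
```

If the noise precision also has a posterior, the same recipe applies with one beta drawn per sample instead of a fixed value.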