Simple GP classification

Hi folks,

I’ve been trying to fit a GP classification model to a very simple two-class dataset. Does anyone have a clue why the predictive mean is so poor?

Increasing n_samples and n_iter didn’t help either. The length scale of the GP kernel is fixed here (there is no prior distribution over it), but from a heuristic check I believe the current value is good enough.
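One way to make that heuristic check concrete is to compute the prior correlation between neighbouring training inputs. This assumes `edward.util.rbf` uses the standard squared-exponential form variance * exp(-0.5 * ||x - x'||^2 / lengthscale^2) with unit variance; the helper `rbf_corr` and the alternative lengthscales 5 and 20 are illustrative choices, not values from the post:

```python
import numpy as np

# Training inputs from the post: 150 evenly spaced points on [-100, 100].
X_train = np.linspace(-100, 100, 150)
spacing = X_train[1] - X_train[0]  # distance between neighbouring inputs (~1.34)


def rbf_corr(distance, lengthscale):
    """Prior correlation k(x, x') / k(x, x) for a squared-exponential kernel (assumed form)."""
    return np.exp(-0.5 * (distance / lengthscale) ** 2)


for ell in (0.5, 5.0, 20.0):
    print(f"lengthscale={ell:5.1f}  corr(neighbours)={rbf_corr(spacing, ell):.3f}")
```

With inputs about 1.34 apart, a lengthscale of 0.5 gives a neighbour correlation of roughly 0.03, so the prior treats the 150 latent values as nearly independent and each one is informed mainly by its own single label. That is a hypothesis worth testing, not a confirmed diagnosis of the fit.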

Also, a second question: why does my query test set have to be the same size as the training set? The shape of my X placeholder is [None, D].

Thanks in advance!


import numpy as np
import matplotlib.pyplot as pl
import tensorflow as tf
import edward as ed
from edward.models import Bernoulli, MultivariateNormalTriL, Normal
from edward.util import rbf


X_train = np.linspace(-100, 100, 150)[:, np.newaxis]
y_train = np.array([0]*70 + [1]*80)

N = X_train.shape[0]  # number of data points
D = X_train.shape[1]  # number of features

print("NxD={}x{}".format(N, D))


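# GP prior over one latent logit per training point. X accepts any number of
# rows, but loc=tf.zeros(N) fixes the length of f to N.
X = tf.placeholder(tf.float32, [None, D])
f = MultivariateNormalTriL(loc=tf.zeros(N), scale_tril=tf.cholesky(rbf(X, lengthscale=0.5)))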
y = Bernoulli(logits=f)

# Mean-field variational approximation: an independent Normal per training point.
qf = Normal(loc=tf.Variable(tf.random_normal([N])),
            scale=tf.nn.softplus(tf.Variable(tf.random_normal([N]))))

inference = ed.KLqp({f: qf}, data={X: X_train, y: y_train})
inference.run(n_samples=10, n_iter=500)

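# Swap f for qf. The copy's logits are draws from qf, one per training point,
# so its output has N entries regardless of what is fed to X.
y_post = ed.copy(y, {f: qf})
sess = ed.get_session()
X_test = X_train  # np.linspace(-2, 12, 100)[:, np.newaxis]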

# Monte Carlo average of the predictive probability over T draws from qf.
y_q = 0
T = 20
for i in range(T):
    y_q = y_q + sess.run(y_post.mean(), feed_dict={X: X_test})

y_q = y_q / T

pl.scatter(X_train[:,0], y_train, c='b')
pl.plot(X_test[:,0], y_q, c='r')
pl.show()
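The averaging loop estimates E_q[sigmoid(f_i)] at each training point, because every `sess.run` call draws a fresh sample from qf; this is not the same as applying the sigmoid to qf's mean. A small NumPy illustration, using hypothetical values for one latent's mean and standard deviation, shows how a wide qf pulls the average toward 0.5, which is one thing worth checking when the red curve looks flat:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def mc_predictive_mean(mu, sigma, n_draws=200_000, seed=0):
    """Monte Carlo estimate of E[sigmoid(f)] for f ~ Normal(mu, sigma**2)."""
    rng = np.random.default_rng(seed)
    return sigmoid(rng.normal(mu, sigma, size=n_draws)).mean()


mu = 2.0  # hypothetical posterior mean of a single latent f_i
for sigma in (0.1, 1.0, 3.0):  # hypothetical posterior standard deviations
    print(f"sigma={sigma}: sigmoid(mu)={sigmoid(mu):.3f}, "
          f"E[sigmoid(f)]={mc_predictive_mean(mu, sigma):.3f}")
```

For a positive mean the average decreases toward 0.5 as sigma grows, so inspecting the fitted scales (for example with `sess.run(qf.stddev())`, assuming the standard TensorFlow distribution method) is a cheap diagnostic.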

A GP as you wrote it doesn’t generalize to test inputs. f defines local latent variables, one per data point. This implies that the variational approximation qf is specific to the training data, and predicting on test data would require new inference.
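To get predictions at new inputs without rerunning inference, one standard option is to keep the fitted q(f) = N(m, S) at the training inputs and push it through the GP conditional p(f* | f). The marginal for f* is then Gaussian with mean K_{*n} K^{-1} m and variance k_{**} - K_{*n} K^{-1} (K - S) K^{-1} K_{n*}. The NumPy sketch below is an illustration outside Edward, not Edward's API: the kernel helper, `predict_proba`, the jitter, and the fitted qf values are all hypothetical. In the Edward script, m and the diagonal of S could be read off the fitted qf, for example with `sess.run([qf.mean(), qf.stddev()])`.

```python
import numpy as np


def rbf(A, B, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between the rows of A and B."""
    scaled = (A[:, None, :] - B[None, :, :]) / lengthscale
    return variance * np.exp(-0.5 * np.sum(scaled ** 2, axis=-1))


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def predict_proba(X_train, qf_mean, qf_std, X_test, lengthscale,
                  jitter=1e-6, n_draws=2000, seed=0):
    """Monte Carlo estimate of P(y* = 1) at test inputs of any size M.

    Assumes q(f) = N(qf_mean, diag(qf_std**2)) at the N training inputs and a
    zero-mean GP prior with unit variance; f* | f follows the GP conditional.
    """
    N = X_train.shape[0]
    K = rbf(X_train, X_train, lengthscale) + jitter * np.eye(N)  # (N, N)
    K_sn = rbf(X_test, X_train, lengthscale)                      # (M, N)
    A = np.linalg.solve(K, K_sn.T).T                              # K_sn K^{-1}
    mean = A @ qf_mean                                            # (M,)
    S = np.diag(qf_std ** 2)
    # Marginal variance: diag(K_ss - A (K - S) A^T), with prior variance 1.
    var = 1.0 - np.einsum("ij,jk,ik->i", A, K - S, A)
    var = np.maximum(var, 1e-12)  # guard against tiny negative round-off
    rng = np.random.default_rng(seed)
    f_star = mean + np.sqrt(var) * rng.standard_normal((n_draws, len(mean)))
    return sigmoid(f_star).mean(axis=0)


# Hypothetical fitted values: 20 training points, negative latents on the left.
X_tr = np.linspace(-5, 5, 20)[:, None]
m = np.where(X_tr[:, 0] < 0, -2.0, 2.0)
s = np.full(20, 0.5)
X_te = np.array([[-4.0], [0.0], [4.0], [100.0]])  # M = 4 rows, not N = 20
p = predict_proba(X_tr, m, s, X_te, lengthscale=1.0)
print(p.round(3))
```

Because test inputs enter only through kernel evaluations, X_te can have any number of rows. Far from the data (x = 100) the mean returns to 0 and the predictive probability to about 0.5, as the prior implies.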

For a review, I recommend reading, e.g., Section 2 of Snelson and Ghahramani (2007).