I am trying to apply the Bayesian neural network presented by Torsten Scholak at PyCon to some real-world data of mine, in order to familiarize myself with Edward and TensorFlow, and I am getting very weird results.
The network fits the data well, but only up to a certain point, after which it flatlines. I can't figure out which part of the code I should tweak. Here is the code for the network:
def neural_network_with_2_layers(x, W_0, W_1, b_0, b_1):
    """Evaluate a two-layer network with a tanh hidden layer.

    Computes ``tanh(x @ W_0 + b_0) @ W_1 + b_1`` and flattens the result
    into a rank-1 tensor.

    Args:
        x: Input passed to the first matrix product.
        W_0: Weights of the hidden layer.
        W_1: Weights of the output layer.
        b_0: Bias added before the tanh non-linearity.
        b_1: Bias added to the output layer.

    Returns:
        The network output reshaped to a flat vector.
    """
    pre_activation = tf.matmul(x, W_0) + b_0
    hidden = tf.nn.tanh(pre_activation)
    output = tf.matmul(hidden, W_1) + b_1
    # Collapse every dimension so callers get one value per row of the output.
    return tf.reshape(output, [-1])
dim = 10  # layer dimensions

# Priors: independent unit-variance, zero-mean Normals over every weight and
# bias of the two layers.
W_0 = Normal(loc=tf.zeros([D, dim]),
             scale=tf.ones([D, dim]))
W_1 = Normal(loc=tf.zeros([dim, 1]),
             scale=tf.ones([dim, 1]))
b_0 = Normal(loc=tf.zeros(dim),
             scale=tf.ones(dim))
b_1 = Normal(loc=tf.zeros(1),
             scale=tf.ones(1))

x = tf.placeholder(tf.float32, [N, D])

# Reshaping: the network returns a flat vector, so turn it back into a column
# that lines up with the [N, 1] noise scale below. N is used (rather than
# len(X_train)) so the shape agrees with the placeholder and the scale by
# construction.
a = neural_network_with_2_layers(x, W_0, W_1, b_0, b_1)
b = tf.reshape(a, [N, 1])
y = Normal(loc=b, scale=(tf.ones([N, 1]) * 0.1))  # constant noise
# BACKWARD MODEL A
def _mean_field_normal(shape):
    """Build a Normal with trainable, randomly initialised loc and scale.

    The scale variable is passed through softplus so it stays positive.
    """
    loc = tf.Variable(tf.random_normal(shape))
    unconstrained_scale = tf.Variable(tf.random_normal(shape))
    return Normal(loc=loc, scale=tf.nn.softplus(unconstrained_scale))


# One variational factor per latent variable of the model.
q_W_0 = _mean_field_normal([D, dim])
q_W_1 = _mean_field_normal([dim, 1])
q_b_0 = _mean_field_normal([dim])
q_b_1 = _mean_field_normal([1])

# Pair every prior with its approximating distribution and bind the
# placeholder and likelihood to the training data.
variational_posteriors = {W_0: q_W_0, b_0: q_b_0,
                          W_1: q_W_1, b_1: q_b_1}
observations = {x: X_train, y: Y_train}

inference = ed.KLqp(latent_vars=variational_posteriors, data=observations)
inference.run(n_samples=50, n_iter=20000)
Here are the results:
And here is the code used to plot them:
# CRITICISM A
# Scatter the training and test points, then overlay several network curves
# evaluated with the fitted variational distributions.
plt.scatter(X_train, Y_train, s=20.0)  # blue
plt.scatter(X_test, Y_test, s=20.0,  # red
            color=sns.color_palette().as_hex()[2])

n_grid = 1000
# NOTE(review): the grid is fixed to [-1, 1]; this presumably assumes the
# inputs were scaled into that range -- confirm against the real data.
x_grid = np.linspace(-1.0, 1.0, n_grid)
xp = tf.placeholder(tf.float32, [n_grid, D])

# Build the prediction op once instead of once per curve, so the graph does
# not grow with every plotted line.
# NOTE(review): each sess.run is expected to draw fresh values from the
# variational distributions, as it did when the op was rebuilt per iteration
# -- confirm against Edward's random-variable semantics.
posterior_curve = neural_network_with_2_layers(xp,
                                               q_W_0, q_W_1,
                                               q_b_0, q_b_1)

# The feed adds a trailing axis, so this only matches xp when D == 1.
grid_feed = {xp: x_grid[:, np.newaxis]}
for _ in range(10):
    plt.plot(x_grid,
             sess.run(posterior_curve, grid_feed),
             color='black', alpha=0.1)
Cheers