Nan in summary histogram for: gradient

Dear Edward users,

I randomly get the following error message when running inference with KLqp:

InvalidArgumentError (see above for traceback): Nan in summary histogram for: gradient/posterior/qmu_loc/0
         [[Node: gradient/posterior/qmu_loc/0 = HistogramSummary[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](gradient/posterior/qmu_loc/0/tag, gradients/AddN_27)]]

I’ve read something about decreasing the learning rate, but I don’t know how to pass that option to inference:

inference.run(n_samples=5, n_iter=250, logdir='log2DMH')

You can pass in your own optimizer and lower its learning rate.

learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(learning_rate)
inference.initialize(..., optimizer=optimizer)
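To see why a smaller learning rate helps, here is a toy sketch (my own illustration, not Edward's code): plain gradient descent on f(x) = x², where a step size above 1.0 makes every update overshoot the minimum and |x| grow without bound. The same runaway growth is what eventually turns a variational loss into inf and then nan.

```python
import numpy as np

def gradient_descent(learning_rate, n_steps=50, x0=1.0):
    """Minimize f(x) = x^2 by gradient descent; gradient is 2x."""
    x = np.float64(x0)
    for _ in range(n_steps):
        x = x - learning_rate * 2.0 * x
    return x

x_small = gradient_descent(0.1)   # iterates shrink toward the minimum at 0
x_large = gradient_descent(1.5)   # each step flips sign and doubles |x|: divergence
```

With the large step size the iterates explode after a few dozen updates, which is the behavior the NaN summary-histogram error is reporting on the real loss.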

Hi Dustin! Many thanks! That changed the results.

Now I want to understand why with learning_rate = 1e-3 we obtain Loss: nan, but with learning_rate = 1e-2 we obtain Loss: 18073.822. I’ll check the references in Classes of Inference:

KLqp supports

1. score function gradients (Paisley et al., 2012)
2. reparameterization gradients (Kingma and Welling, 2014)

of the loss function.
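The difference between the two estimators can be sketched on a toy problem (again my own NumPy illustration, not Edward's implementation): estimating d/dmu of E over x ~ N(mu, 1) of x², whose true value is 2*mu.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, n = 1.5, 200000
eps = rng.standard_normal(n)
x = mu + eps  # samples from N(mu, 1); true gradient of E[x^2] w.r.t. mu is 2*mu = 3

# 1. Score function estimator: E[f(x) * d/dmu log q(x; mu)], using
#    d/dmu log N(x; mu, 1) = (x - mu). Works for any q, including discrete
#    ones, but typically has higher variance.
score_grad = np.mean(x**2 * (x - mu))

# 2. Reparameterization estimator: write x = mu + eps and differentiate
#    through the sample, d/dmu f(mu + eps) = 2 * (mu + eps). Lower variance,
#    but requires a differentiable reparameterization of q.
reparam_grad = np.mean(2.0 * x)
```

Both averages land near 3.0, but the spread of the individual score-function terms is much larger, which is why KLqp prefers reparameterization gradients when the variational family allows them.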

learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(learning_rate)
inference.initialize(n_samples=5, n_iter=250, logdir='log2DMH', optimizer=optimizer)
inference.run()
1000/1000 [100%] ██████████████████████████████ Elapsed: 1s | Loss: nan    
learning_rate = 1e-2
optimizer = tf.train.AdamOptimizer(learning_rate)
inference.initialize(n_samples=5, n_iter=250, logdir='log2DMH', optimizer=optimizer)
inference.run()
1000/1000 [100%] ██████████████████████████████ Elapsed: 1s | Loss: 18073.822

---- EDIT ----
It was my fault! I had forgotten to take fine control of the training procedure. Things are much better now.

learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(learning_rate)
inference.initialize(n_samples=30, n_iter=5000, logdir='log2DMH', optimizer=optimizer)
tf.global_variables_initializer().run()

for _ in range(inference.n_iter):
  info_dict = inference.update()
  inference.print_progress(info_dict)

inference.finalize()
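One nice thing about the manual loop is that you can guard against divergence yourself before NaNs reach the summary writer. A generic sketch of the pattern, where fake_update is a made-up stand-in for inference.update() (which in Edward returns a dict containing a 'loss' entry):

```python
import math

def fake_update(step):
    # Hypothetical stand-in for inference.update(): pretend the loss
    # improves for a few steps and then diverges to NaN.
    losses = [100.0, 50.0, 30.0, 20.0, float('nan'), float('nan')]
    return {'t': step, 'loss': losses[step]}

history = []
for step in range(6):
    info_dict = fake_update(step)
    if math.isnan(info_dict['loss']):
        # Stop early instead of logging NaNs into the summary histogram.
        break
    history.append(info_dict['loss'])
```

In the real loop you would put the same math.isnan check on the dict returned by inference.update() and break (or lower the learning rate) when it fires.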