Hello All,
Short story:
I am having trouble importing a pre-trained model into Edward for inference. My issue concerns the behavior of the variables argument to inference.run: when calling inference.run, how do I initialize only the tf.Variables defined internally by the inference object? Essentially, in my application I need to call inference.run(variables=[]) to prevent the pre-trained Variables from losing their values during inference; however, when I do that the internal variables are never initialized, and so the inference falls over.
Longer example:
A prototype of my problem is below. I load a pre-trained model, which consists of a graph with a number of tf.Variables whose values were set during training:
import edward as ed
import tensorflow as tf
# load a model I trained elsewhere.
saver = tf.train.import_meta_graph("test.meta")
sess = ed.get_session()
saver.restore(sess, "test")
# the model is a quadratic with some fitted parameters
print("parameters after load:", sess.run(tf.get_collection("model_parameters")))
Then I define an inference problem and do some MCMC (or whatever):
g = sess.graph
x = g.get_tensor_by_name("input:0")
y = g.get_tensor_by_name("output:0")
# generate a target output from a known input
# (compute y_dat before ed.copy replaces x, so it really uses x_dat)
x_dat = [0.5]
y_dat = sess.run(y, feed_dict={x: x_dat})
# set up the inference problem
Xin = ed.models.Uniform(low=[0.], high=[1.])
y = ed.copy(y, dict_swap={x: Xin})
print("parameters after init:", sess.run(tf.get_collection("model_parameters")))
Xq = ed.models.Empirical(tf.Variable(tf.zeros((100, 1))))
Xjump = ed.models.Normal(loc=Xin, scale=[0.1])
yout = ed.models.Normal(loc=y, scale=[0.01])
inference = ed.MetropolisHastings({Xin: Xq}, {Xin: Xjump}, {yout: y_dat})
inference.run()
print("parameters after run:", sess.run(tf.get_collection("model_parameters")))
The output shows that after the call to inference.run the model parameters have been reset; looking a little deeper, it appears the inference itself was also run with these reset (untrained) parameters:
INFO:tensorflow:Restoring parameters from test
parameters after load: [0.16402531, 0.94702035, 0.0043922863]
parameters after init: [0.16402531, 0.94702035, 0.0043922863]
100/100 [100%] ██████████████████████████████ Elapsed: 0s | Acceptance Rate: 0.970
parameters after run: [0.0, 0.0, 0.0]
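As a sanity check, here is the target output the chain should be reproducing, in plain Python with no TensorFlow. I am assuming the collection stores the quadratic's coefficients in (a, b, c) order, which the script above doesn't actually record:

```python
# Evaluate the loaded quadratic at the known input x_dat = [0.5].
# Coefficient ordering (a, b, c) is an assumption for illustration.
a, b, c = 0.16402531, 0.94702035, 0.0043922863
x_dat = 0.5
y_dat = a * x_dat ** 2 + b * x_dat + c
print(y_dat)  # ~0.519
```

So a healthy chain for Xin should concentrate near 0.5, whereas with the parameters reset to zero the model output is identically zero and the posterior over Xin is meaningless.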
I have tracked the problem down to the first few lines of the run method of the inference class (here), where a call to tf.global_variables_initializer() is made. You can switch this off by passing the variables parameter an empty list, but then the inference falls over because none of the internal Variables are initialized. I also can't get a list of those variables with a solution like this one:
inference = ed.MetropolisHastings( {Xin:Xq}, {Xin:Xjump}, {yout:y_dat} )
internal_vars = [v for v in tf.global_variables() if v.name.split(':')[0] in set(sess.run(tf.report_uninitialized_variables()))]
inference.run( variables=internal_vars )
because I don’t think they have been added to the graph until inference.run has been called (i.e., internal_vars is empty until after the call to run).
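For what it's worth, here is a plain-Python illustration (no TensorFlow; the variable names are made up) of the name matching that comprehension relies on. One thing to watch out for under Python 3: sess.run(tf.report_uninitialized_variables()) reports bare names as bytes, while v.name is a str with an output suffix like ":0", so both the split and an encode/decode are needed:

```python
# Hypothetical variable names, standing in for tf.global_variables():
# each tf.Variable's .name carries an output suffix like ":0".
global_variable_names = ["model/a:0", "model/b:0", "Empirical/params:0"]

# Stand-in for sess.run(tf.report_uninitialized_variables()), which
# reports bare names (bytes under Python 3).
uninitialized = {b"Empirical/params"}

internal_vars = [n for n in global_variable_names
                 if n.split(":")[0].encode() in uninitialized]
print(internal_vars)  # ['Empirical/params:0']
```

In my case, though, the list comes back empty regardless, because the internal variables haven't been created yet at the point where I query them.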
Thanks again,
Jim