How to initialize `inference` object's internal Variables

Hello All,

Short story:

I am having trouble importing a pre-trained model into Edward for inference. My issue concerns the variables argument to inference.run(): specifically, when calling inference.run(), how do I initialize only the tf.Variables defined internally by the inference object? In my application I need to call inference.run( variables=[] ) to prevent the pre-trained Variables from losing their values during inference. However, when I do that the internal variables are not initialized, and the inference falls over.

Longer example:

A prototype of my problem is below. I load a pre-trained model, which consists of a graph with a bunch of tf.Variables whose values were set by the training:

import edward as ed
import tensorflow as tf

# load a model I trained elsewhere.

saver = tf.train.import_meta_graph("test.meta")

sess = ed.get_session()

saver.restore(sess, "test")

# the model is a quadratic with some fitted parameters

print "parameters after load:", sess.run( tf.get_collection("model_parameters") )

Then I want to define an inference problem and do some MCMC (or whatever):

g = sess.graph

x = g.get_tensor_by_name("input:0")
y = g.get_tensor_by_name("output:0")

# set up the inference problem

Xin = ed.models.Uniform(low=[0.], high=[1.])
y = ed.copy(y, dict_swap={x:Xin})

# target output

x_dat = [0.5]
y_dat = sess.run( y, feed_dict={x:x_dat} )

print "parameters after init:", sess.run( tf.get_collection("model_parameters") )

Xq = ed.models.Empirical(tf.Variable(tf.zeros((100,1))))
Xjump = ed.models.Normal( loc=Xin, scale=[0.1] )
yout = ed.models.Normal( loc=y, scale=[0.01])

inference = ed.MetropolisHastings( {Xin:Xq}, {Xin:Xjump}, {yout:y_dat} )
inference.run()

print "parameters after run:", sess.run( tf.get_collection("model_parameters") )

The output shows that after the call to inference.run() the model parameters have been reset. Looking a little deeper, it seems the inference was also done with the untrained model parameters:

INFO:tensorflow:Restoring parameters from test
parameters after load: [0.16402531, 0.94702035, 0.0043922863]
parameters after init: [0.16402531, 0.94702035, 0.0043922863]
100/100 [100%] ██████████████████████████████ Elapsed: 0s | Acceptance Rate: 0.970
parameters after run: [0.0, 0.0, 0.0]

I have tracked the problem down to the first few lines of the run method of the Inference class, where a call to tf.global_variables_initializer() is made. You can switch this off by passing an empty list as the variables argument, but then the inference falls over because none of the internal Variables are initialized. I also can't get a list of those variables with a solution like this one:

inference = ed.MetropolisHastings( {Xin:Xq}, {Xin:Xjump}, {yout:y_dat} )
internal_vars = [v for v in tf.global_variables() if v.name.split(':')[0] in set(sess.run(tf.report_uninitialized_variables()))]
inference.run( variables=internal_vars )

because I don't think they have been added to the graph until inference.run() has been called (i.e., internal_vars is empty until after the call to run).
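TensorFlow's tf.report_uninitialized_variables() returns the op names (without the ":0" output suffix) of variables that already exist in the graph but have no value yet, so it cannot see variables the inference object has not created. A minimal sketch, assuming TF 1.x graph mode and Edward 1.x: call inference.initialize() yourself first, then select the uninitialized variables. Restored parameters count as initialized, so only the inference's internal Variables (plus any of your own uninitialized ones) should be selected. The helper name select_uninitialized is hypothetical, not part of either library.

```python
def select_uninitialized(variables, reported_uninitialized):
    """Return the variables whose op name appears in reported_uninitialized.

    variables: objects with a TensorFlow-style .name such as "scope/w:0"
        (e.g. the result of tf.global_variables()).
    reported_uninitialized: op names without the ":0" suffix, as returned by
        sess.run(tf.report_uninitialized_variables()); under Python 3 these
        come back as bytes, so they are decoded before comparison.
    """
    reported = set(
        n.decode("utf-8") if isinstance(n, bytes) else n
        for n in reported_uninitialized
    )
    # v.name looks like "scope/w:0"; strip the output index to get the op name.
    return [v for v in variables if v.name.split(":")[0] in reported]


# Intended use (TF 1.x graph mode with Edward 1.x; not executed here):
#
#   inference = ed.MetropolisHastings({Xin: Xq}, {Xin: Xjump}, {yout: y_dat})
#   inference.initialize()            # creates the internal Variables
#   sess = ed.get_session()
#   uninit = select_uninitialized(
#       tf.global_variables(),
#       sess.run(tf.report_uninitialized_variables()))
#   sess.run(tf.variables_initializer(uninit))
#   # Now drive inference.update() yourself; calling inference.run() here
#   # would call initialize() a second time.
```

The catch is that inference.run() always calls initialize() itself, so once you have initialized manually you also have to write the update loop yourself rather than calling run().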

Thanks again,


Have you seen Saving Model Parameters? The short answer is that importing only the metagraph isn't supported. This is because the serialized output writes only TensorFlow ops to disk, not the random variables; it's one of those low-level TensorFlow issues that we're trying to resolve as we merge Edward code natively into TensorFlow.

If you checkpoint only your model's tf.Variables with tf.train.Saver, you can restore them into the graph; this does require rebuilding the graph in code first, though.

Thanks for the reply! I'm not sure how this solves my problem. The pre-trained model is trained using TensorFlow only, and I can successfully import the variable values from that stage, but they get lost when I initialize an ed.Inference object. Maybe my current workaround will clarify:

Edit: in reality, I have kept all of the code from the base run method and only spliced in the new definition of init.

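<load a pre-trained model and build a graph of RandomVariables around it>

infer = <an Edward Inference instance>
sess = ed.get_session()

# snapshot the Variables that already exist (including the restored model parameters)
temp = set(tf.global_variables())
infer.initialize(*args, **kwargs)
# anything new was created by the inference object itself
internal_vars = list(set(tf.global_variables()) - temp)

# initialize only those, leaving the restored values untouched
init = tf.variables_initializer(internal_vars)
sess.run(init)

for _ in range(infer.n_iter):
    info_dict = infer.update()
    infer.print_progress(info_dict)

infer.finalize()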
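The workaround can be packaged as a reusable function. This is a sketch under stated assumptions: it follows the structure of Edward 1.x's Inference.run() as I understand it (initialize, run the initializer, then update/print_progress for n_iter steps, then finalize) but omits run()'s queue-runner coordinator and logging handling. The function name and the extra_variables parameter are hypothetical. extra_variables exists because any variable you create yourself before the snapshot, such as the tf.Variable passed to ed.models.Empirical, is not produced by initialize() and still needs initializing if nothing else has done it. tf.global_variables and tf.variables_initializer are passed in as arguments rather than imported.

```python
def run_preserving_restored_variables(inference, sess, global_variables,
                                      variables_initializer,
                                      extra_variables=(), init_kwargs=None):
    """Sketch of Inference.run() that initializes only the variables created
    by inference.initialize() (plus extra_variables), so restored
    tf.Variables keep their values. Returns the variables it initialized."""
    # Variables that exist before the inference builds its own ops.
    before = set(global_variables())
    inference.initialize(**(init_kwargs or {}))
    # Keep creation order; select only variables that are new.
    created = [v for v in global_variables() if v not in before]
    to_init = created + [v for v in extra_variables if v not in created]
    sess.run(variables_initializer(to_init))

    # Same update loop as the base run method.
    for _ in range(inference.n_iter):
        info_dict = inference.update()
        inference.print_progress(info_dict)
    inference.finalize()
    return to_init
```

With TensorFlow and Edward this would be called as run_preserving_restored_variables(inference, ed.get_session(), tf.global_variables, tf.variables_initializer, extra_variables=[xq_var]), where xq_var is a hypothetical name for the tf.Variable handed to ed.models.Empirical. Injecting the two TensorFlow functions keeps the loop itself free of library imports, which also lets it be exercised with stand-in objects.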