Saving and Loading Aboleth Models

In this tutorial we will cover the basics of how to save and load models constructed with Aboleth. We don't provide any saving and loading code in the library itself; instead we rely directly on TensorFlow's functionality.

Naming the Graph

Even though the whole graph you create is saved and automatically named, it helps when loading to know the exact names of the parts of the graph you want to evaluate. So to begin, we will create a very simple Bayesian linear regressor with placeholders for data. Let's start with the placeholders,

with tf.name_scope("Placeholders"):
    n_samples_ = tf.placeholder_with_default(NSAMPLES, shape=[],
                                             name="samples")
    X_ = tf.placeholder_with_default(X_train, shape=(None, D),
                                     name="X")
    Y_ = tf.placeholder_with_default(Y_train, shape=(None, 1),
                                     name="Y")

We have used a name_scope here for easy reference later. Also, we'll assume variables in all-caps have been defined elsewhere (we sketch some illustrative definitions after the next code block). Now let's make our simple network (just a linear layer),

net = ab.stack(
    ab.InputLayer(name='X', n_samples=n_samples_),
    ab.DenseVariational(output_dim=1, full=True)
)
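
As mentioned, the variables in all-caps are assumed to be defined elsewhere. Purely for illustration (these particular names and values are our own placeholders, not part of any Aboleth demo), they could look something like this,

import numpy as np

# Illustrative settings only -- substitute your own data and values
NSAMPLES = 5                # stochastic samples drawn through the net in training
PREDICTSAMPLES = 20         # (more) samples to draw at prediction time
NEPOCHS = 10                # passes over the training data
N, D = 1000, 10             # number of training points and input features
NOISE = tf.Variable(1.0)    # likelihood noise, kept positive later with ab.pos
CONFIG = None               # or a tf.ConfigProto for the training session
CHECKPOINT_DIR = "./"       # must match the checkpoint_dir used when training

X_train = np.random.randn(N, D).astype(np.float32)
Y_train = np.random.randn(N, 1).astype(np.float32)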

And now let's build and name our graph, associating names with the parts of it we will want to evaluate later,

with tf.name_scope("Model"):
    f, kl = net(X=X_)
    likelihood = tf.distributions.Normal(loc=f, scale=ab.pos(NOISE))
    loss = ab.elbo(likelihood, Y_, N, kl)

with tf.name_scope("Predict"):
    tf.identity(f, name="f")
    ab.sample_mean(f, name="Ey")

Note how we have used tf.identity here to name the latent function, f; again, this is so we can easily load it by name later for drawing samples from our network. Also, we don't need to assign these operations to any variables (unless we want to use them before saving); we just need to build them into the graph.
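
Just to illustrate what this naming buys us, the operations can already be looked up by name in the default graph (this quick check is ours, not part of the demo),

# hypothetical sanity check: the named tensors can be recovered from the graph
g = tf.get_default_graph()
print(g.get_tensor_by_name("Predict/f:0"))
print(g.get_tensor_by_name("Predict/Ey:0"))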

Saving the Graph

At this point we recommend reading the TensorFlow tutorial on saving and restoring. We typically use a tf.train.MonitoredTrainingSession, as it handles all of the model saving, checkpointing, etc. You can see how we do this in the SARCOS demo, but we have also copied the code below for convenience,

# Training graph with step counter
with tf.name_scope("Train"):
    optimizer = tf.train.AdamOptimizer()
    global_step = tf.train.create_global_step()
    train = optimizer.minimize(loss, global_step=global_step)

# Logging
log = tf.train.LoggingTensorHook(
    {'step': global_step, 'loss': loss},
    every_n_iter=1000
)

# Training loop
with tf.train.MonitoredTrainingSession(
        config=CONFIG,
        checkpoint_dir="./",
        save_summaries_steps=None,
        save_checkpoint_secs=20,
        save_summaries_secs=20,
        hooks=[log]
) as sess:
    for i in range(NEPOCHS):

        # your training code here
        ...
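
What goes in the body of that loop is up to you (mini-batching, feeding data, and so on). As a minimal sketch only, and not the SARCOS demo's actual code, if the placeholder defaults already point at the full training arrays then each epoch can be a single run of the training op,

for i in range(NEPOCHS):
    # the placeholder defaults supply X_train and Y_train, so a simple
    # full-batch update needs no feed_dict
    sess.run(train)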

This code also makes it easy to use TensorBoard to monitor your training; simply point it at the checkpoint_dir and run it like,

$ tensorboard --logdir=<checkpoint_dir>

Once you are satisfied that your model has converged, you can just kill the Python process. If you think it could do with a bit more "baking", simply re-run the training script and the MonitoredTrainingSession will ensure you resume learning where you left off!

Loading Specific Parts of the Graph for Prediction

Typically we only want to evaluate particular parts of the graph (that is, the ones we named previously). In this section we'll go through how to load the last checkpoint saved by the MonitoredTrainingSession, and how to get hold of the tensors that we named. We then use these tensors to predict on new query data!

# Get latest checkpoint
model = tf.train.latest_checkpoint(CHECKPOINT_DIR)

# Make a graph and a session we will populate with our saved graph
graph = tf.Graph()
with graph.as_default():

    sess = tf.Session()
    with sess.as_default():

        # Restore graph
        saver = tf.train.import_meta_graph("{}.meta".format(model))
        saver.restore(sess, model)

        # Restore place holders
        X_ = graph.get_operation_by_name("Placeholders/X").outputs[0]
        Y_ = graph.get_operation_by_name("Placeholders/Y").outputs[0]
        n_samples_ = graph.get_operation_by_name(
            "Placeholders/samples").outputs[0]

        feed_dict = {X_: X_test, n_samples_: PREDICTSAMPLES}

        f = graph.get_operation_by_name("Predict/f").outputs[0]
        Ey = graph.get_operation_by_name("Predict/Ey").outputs[0]

        f_samples, y_pred = sess.run([f, Ey], feed_dict=feed_dict)

The most complicated part of the above code is remembering all of the boilerplate to insert the saved graph into a new session, and then to get our placeholders and prediction tensors. Once we have done this, though, evaluating the operations we need for prediction is handled in the usual way. We have also assumed in this demo that you want to use more samples for prediction (PREDICTSAMPLES) than for training (NSAMPLES), so we have made the number of samples a placeholder too.
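
As a final, purely illustrative step (not part of the original demo), the raw draws in f_samples can be summarised into predictive credible intervals with plain NumPy,

import numpy as np

# f_samples is assumed to have shape (PREDICTSAMPLES, n_test, 1):
# one draw of the latent function per sample, per query point
lower, upper = np.percentile(f_samples, [10, 90], axis=0)

# y_pred is already the mean over those samples
print(y_pred.shape, lower.shape, upper.shape)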

That’s it!