*Automatically bind your fields to and from a TensorFlow graph*

Wouldn't it be cool to automatically bind class fields to TensorFlow variables in a graph, and restore them without manually getting each variable back from it?

*The code for this article can be found **here**; a Jupyter notebook version can be found **here**.*

Imagine you have a `Model` class:

https://gist.github.com/764c20a0b7c871851f2b6d354fd17372

Usually, you first **build** your model and then you **train** it. Afterwards, you want to **get** the old variables back from the saved graph without rebuilding the whole model from scratch.

https://gist.github.com/e46c216a4883a88930268d3d72860788

`<tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>`

Now, imagine we have just trained our model and we want to store it. The usual pattern is

https://gist.github.com/b752a437082a584fbc4d0b55046b596f

Now you want to perform **inference**, i.e. get your stuff back, by loading the stored graph. In our case, we want the variable named `variable`:

https://gist.github.com/d992c0a431745306635e62cf13d5ae98

`INFO:tensorflow:Restoring parameters from /tmp/model.ckpt`

Now we can get our `variable` back from the graph:

https://gist.github.com/4310a0bd3b1eccd7b63785fde7cebcfb

`name: "variable" op: "VariableV2" attr { key: "container" value { s: "" } } attr { key: "dtype" value { type: DT_INT32 } } attr { key: "shape" value { shape { dim { size: 1 } } } } attr { key: "shared_name" value { s: "" } }`

But what if we want to use our `model` class again? If we now try to access `model.variable`, we get `None`:

https://gist.github.com/734dc7d2e3b4677321f427f77af72703

`None`

One solution is to **build the whole model again** and restore the graph after that:

https://gist.github.com/23e4054280e934c9a4b4c3d9eb715def

`INFO:tensorflow:Restoring parameters from /tmp/model.ckpt <tf.Variable 'variable:0' shape=(1,) dtype=int32_ref>`

You can already see that this is a big waste of time. We can bind `model.variable` directly to the correct graph node:

https://gist.github.com/69d878978e8196c710256dfea1b24bf3

`name: "variable" op: "VariableV2" attr { key: "container" value { s: "" } } attr { key: "dtype" value { type: DT_INT32 } } attr { key: "shape" value { shape { dim { size: 1 } } } } attr { key: "shared_name" value { s: "" } }`

Now image we have a very big model with nested variables. In order to correct restore each variable pointer in the model you need to:

- name each variable
- get the variables back from the graph

Wouldn't it be cool if we could automatically retrieve all the variables set as fields in the `Model` class?
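In plain Python terms, the two manual steps above boil down to a name → lookup → `setattr` loop. Here is an illustrative, TensorFlow-free sketch (`rebind` is a made-up helper name, not part of any library):

```python
# Illustrative sketch: re-point each named field on `model` using a lookup
# function. With TensorFlow, the lookup would be graph.get_tensor_by_name.
def rebind(model, names, lookup):
    for field, tensor_name in names.items():
        setattr(model, field, lookup(tensor_name))

# Toy usage: a plain dict stands in for the restored graph.
class Model:
    pass

graph = {'variable:0': object()}
model = Model()
rebind(model, {'variable': 'variable:0'}, graph.get)
print(model.variable is graph['variable:0'])  # → True
```

Doing this by hand for every field of a big model is exactly the boilerplate we want to automate.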

### TFGraphConvertible

I have created a class called `TFGraphConvertible`. You can use `TFGraphConvertible` to automatically **serialize** and **deserialize** a class.
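The real implementation lives in the gist below; conceptually, a minimal sketch of such a mixin might look like this (the method bodies and signatures here are my assumption, written duck-typed so that any object with a `.name` attribute and any graph exposing `get_tensor_by_name` will work, as `tf.Variable` and `tf.Graph` do in TF 1.x):

```python
# A minimal, duck-typed sketch of a TFGraphConvertible-style mixin
# (assumed signatures; the article's actual implementation is in the gist).
class TFGraphConvertible:
    def to_graph(self, fields):
        # Serialize: map each field name to its tensor's name in the graph.
        # `fields` is {field_name: tensor-like with a .name attribute}.
        return {name: tensor.name for name, tensor in fields.items()}

    def from_graph(self, serialized, graph):
        # Deserialize: re-bind each field by looking its name up in `graph`,
        # which must expose get_tensor_by_name (as tf.Graph does in TF 1.x).
        for field, tensor_name in serialized.items():
            setattr(self, field, graph.get_tensor_by_name(tensor_name))
```

With TensorFlow, you would typically call `model.to_graph(fields=vars(model))` after building, and `model.from_graph(serialized, tf.get_default_graph())` after restoring a checkpoint.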

Let’s recreate our model

https://gist.github.com/4bcb5cc60c98d447d65d8a79ee92d8ac

It exposes two methods: `to_graph` and `from_graph`.

### Serialize — to_graph

In order to **serialize a class**, you call the **to_graph** method, which creates a dictionary mapping field names to TensorFlow variable names. You need to pass a `fields` argument, a dictionary of the fields we want to serialize. In our case, we can just pass all of them.

https://gist.github.com/bd46adf6bce31bb83c1a6ac1baacb83e

`{'variable': 'variable_2:0'}`

It will create a dictionary with all the fields as keys and the corresponding TensorFlow variable names as values.

### Deserialize — from_graph

In order to **deserialize a class**, you call the **from_graph** method, which takes the previously created dictionary and binds each class field to the correct TensorFlow variable.

https://gist.github.com/e922b3949aebe63e12f8d2368a749225

`None <tf.Tensor 'variable_2:0' shape=(1,) dtype=int32_ref>`

And now you have your `model` back!

### Full Example

Let’s see a more interesting example! We are going to train and then restore a model for the MNIST dataset.

https://gist.github.com/be6a79ff84d333a0896ff57fc6105bc0

Let’s get the dataset!

https://gist.github.com/c7d22cdc77e53fe74c491be60fec4890

`Using TensorFlow backend.`

Now it is time to train it

https://gist.github.com/7e22a21a3fedbeaf8a2cac3b1ea98e7e

`0.125 0.46875 0.8125 0.953125 0.828125 0.890625 0.796875 0.9375 0.953125 0.921875`

Perfect! Let’s store the serialized model in memory

https://gist.github.com/e5ecc4a5dcefd595f1b503b3c761a7c8

`{'x': 'ExpandDims:0', 'y': 'one_hot:0', 'forward_raw': 'dense_1/BiasAdd:0', 'accuracy': 'Mean:0', 'loss': 'Mean_1:0', 'train_step': 'Adam'}`

Then we reset the graph and recreate the model:

https://gist.github.com/cf720a3eff740cda8bc7a4437a2a9af4

`INFO:tensorflow:Restoring parameters from /tmp/model.ckpt`

Of course, the variables in our `mnist_model` do not exist yet:

https://gist.github.com/1e4401a45bd0809af57486891ce99655

`--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) <ipython-input-21-9def5e0d8f6c> in <module>() ----> 1 mnist_model.accuracy AttributeError: 'MNISTModel' object has no attribute 'accuracy'`

Let’s recreate them by calling the `from_graph` method.

https://gist.github.com/44a2aab032610f7b0d3da037c533d84b

`<tf.Tensor 'Mean:0' shape=() dtype=float32>`

Now `mnist_model` is ready to go; let’s see the accuracy on a batch of the test set:

https://gist.github.com/70b7c85931867f9d3953cb5c1e391fcc

`INFO:tensorflow:Restoring parameters from /tmp/model.ckpt 1.0`

### Conclusion

In this tutorial we have seen how to serialize a class and bind each field back to the correct tensor in the TensorFlow graph. Note that you can store the `serialized_model` in `.json` format and load it directly from anywhere. This way, you can create your model using object-oriented programming and retrieve all the variables inside it without having to rebuild it.
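For example, storing and reloading the mapping is plain `json` (the field names below are taken from the article's own output; the path is illustrative):

```python
import json

# serialized mapping of field names -> tensor names, as produced by to_graph
serialized_model = {'x': 'ExpandDims:0', 'accuracy': 'Mean:0'}

# store it ...
with open('/tmp/serialized_model.json', 'w') as f:
    json.dump(serialized_model, f)

# ... and load it back anywhere, ready to be passed to from_graph
with open('/tmp/serialized_model.json') as f:
    restored = json.load(f)

print(restored == serialized_model)  # → True
```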

Thank you for reading

Francesco Saverio Zuppichini

*Originally published at **gist.github.com**.*

Source: Deep Learning on Medium