Source: Deep Learning on Medium
Keras Custom Training Loop
How to build a custom training loop at a lower level of abstraction: K.function, opt.get_updates usage and other stuff under the hood of the Keras engine
Among deep learning libraries, Keras is a high-level one, and we all love it for that. It abstracts away most of the pain that our no less beloved TensorFlow brings along in exchange for crunching data very efficiently on the GPU.
I use Keras at work and for my personal projects, and I am deeply in love with its API and approach to model building. But what happens when you want to do something outside the box? I will tell you: you stumble upon the framework. Which framework? The framework of the Keras paradigm. In Keras, things are easy and pragmatic: you follow the steps and things work amazingly well. But if, for whatever reason, you need to skip a step or detour from the main route, things start to get messy.
You could argue: “but Keras is highly flexible, it has this amazing functional API for building daydream labyrinthic models, support for writing custom layers, the powerful Generators for handling Sequences, Images, multiprocessing, multi input-output, GPU parallelism and…”. I know, and in fact I know you know, or at least I expect so; otherwise you would not be reading this post.
But in spite of this flexibility, I could still point out some fairly annoying experiences in Keras, such as loss functions with multiple inputs/parameters or loading saved models with custom layers… But somehow you can get those solved with some workarounds or by digging a bit into the code.
However, one of the things I struggled with the most is creating a custom training loop. But why the heck would you want to build a custom training loop in the first place? Isn’t the whole point of Keras to abstract such nuances away so you can focus on the model? Well, that’s quite true, but there are some corner cases in which you will want to get your hands dirty: when your model has multiple inputs and outputs of different shapes (not concatenable) and a single loss function, when you need access to the gradients of the optimization at training time… or for specific applications: GANs and Reinforcement Learning mainly, as far as I am concerned (let me know in the comments if you find others so I can learn too). The main reason to write this post is to clarify (or document, if you prefer) the usage of certain tools in the Keras engine to build a custom training loop without being strictly constrained by the framework.
So, enough of the boring stuff, show me the code!
Here it is, so you can take a look:
You could break this down into some pieces:
- Dataset creation: dummy dataset for our example.
- Keras default workflow: which includes model, loss function and optimizer definition.
- Graph creation: creation of the computational graph and linking of all its parts. This is where we depart from the default Keras workflow, in which this step is done under the hood by the Keras engine.
- K.function usage: this is the tricky part; K.function is not very well documented, so I’ll try to shed some light on it.
- Training loop: nothing special here, just the for loops and some printing to monitor the evolution of the training (and testing of course).
A dummy dataset for our case: given 2 numbers in the range [0, 9], the network must predict their sum. So:
- Samples = [ 100 x 2 ], so 100 samples of 2 features (the 2 numbers to sum)
- Targets = [ 100 ], the result of the sum of those 100 samples. Ideally, I’d like this to be [ 100 x 1 ], but we are all familiar with how NumPy drops the reduced dimension.
A small test dataset with 10 samples has been created as well.
import numpy as np
# Training samples
samples = np.random.randint(0, 10, size=(100, 2))  # high is exclusive: values in [0, 9]
targets = np.sum(samples, axis=-1)
# Samples for testing
samples_test = np.random.randint(0, 10, size=(10, 2))
targets_test = np.sum(samples_test, axis=-1)
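A quick NumPy check of the shapes described above. It also shows that passing keepdims=True to np.sum would keep the targets as [ 100 x 1 ]. The fixed seed is an addition for reproducibility only.

```python
import numpy as np

np.random.seed(0)  # arbitrary seed, only for reproducibility

# randint's upper bound is exclusive, so 10 yields integers in [0, 9]
samples = np.random.randint(0, 10, size=(100, 2))

targets = np.sum(samples, axis=-1)                    # reduced axis is dropped
targets_2d = np.sum(samples, axis=-1, keepdims=True)  # reduced axis is kept

print(samples.shape)     # (100, 2)
print(targets.shape)     # (100,)
print(targets_2d.shape)  # (100, 1)
```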
You already know this, so:
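# Assumes standalone Keras 2.x on a graph-mode (TF 1.x) backend
from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras import backend as K

x = Input(shape=(2,))  # 2 features per sample
y = Dense(units=1)(x)
model = Model(x, y)

def loss_fn(y_true, y_pred):
    # You can get as crazy and twisted as you
    # want here, no Keras restrictions this time :)
    loss_value = K.sum(K.pow((y_true - y_pred), 2))
    return loss_value

# Optimizer to run the gradients
optimizer = Adam(lr=1e-4)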
The only thing to note here is that you can make the loss function as crazy and twisted as you want, since the Keras engine won’t stop you with its _standardize_user_data and won’t complain. You can use multiple inputs, a loss function configurable through arguments… I have implemented a simple sum of squared errors (SSE) for this demo.
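To illustrate the "configurable loss function" point, here is a minimal sketch of the closure pattern, written in NumPy so it runs without Keras. make_weighted_sse and its weight argument are hypothetical; in the Keras version the inner body would use K.sum and K.pow, as loss_fn does.

```python
import numpy as np

def make_weighted_sse(weight):
    """Return an SSE loss scaled by `weight` (hypothetical example)."""
    def loss_fn(y_true, y_pred):
        # Same sum of squared errors as loss_fn, times a configurable factor
        return weight * np.sum(np.power(y_true - y_pred, 2))
    return loss_fn

plain_sse = make_weighted_sse(1.0)
double_sse = make_weighted_sse(2.0)

y_true = np.array([3.0, 5.0])
y_pred = np.array([2.0, 5.5])
print(plain_sse(y_true, y_pred))   # 1.25
print(double_sse(y_true, y_pred))  # 2.5
```

Because the factor is captured when the loss is built, the returned function keeps the plain (y_true, y_pred) signature.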
Graph creation and linking
This is the tricky part. In Keras the only graph you define is the computation flow of your model (and the loss function if you want, but under some restrictions). But you do not define the linking between the loss function, the model, and the gradients computation or the parameters update.
This is when we need to change our view from a code workflow to a graph workflow, or tensors flow (liked that joke?). So, apart from defining the input to our model, the model itself and the loss function, which we have already done, we need to:
- Create the input for our ground truth so we can compute the loss, the so-called y_true
- Get the output of the model, or the prediction, well known as y_pred
- Link the model prediction and the ground truth with the loss function (already created as well)
y_true = Input(shape=())  # one scalar target per sample
y_pred = model(x)
loss = loss_fn(y_true, y_pred)
If you pay close attention to this, you’ll find out that the only inputs to this graph are x and y_true (I know you guessed, because they are the only variables with an Input call assigned, but just in case…).
So we already have the inputs, the model and the loss to be minimized all in one computational graph or in a graph where tensors flow (sorry I can’t stop it).
The only thing left to do is define the graph for computing the gradients of the loss with respect to the weights of the model and for updating those weights according to the learning rate. Easy, right? Well, this is exactly what optimizer.get_updates does. Given some parameters and a loss dependent upon those parameters, it returns the computational graph for computing the gradients (using the infamous K.gradients) and updating the weights.
This is done by calling optimizer.get_updates with the loss as we defined it and the parameters, or trainable weights of the model, to optimize.
# Operation for getting
# gradients and updating weights
updates_op = optimizer.get_updates(
    loss=loss, params=model.trainable_weights)
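Conceptually, the update operations returned by get_updates compute the gradients of the loss with respect to each weight and apply one Adam step. The following NumPy sketch does the same by hand for our Dense(1) model and SSE loss. The gradient formulas are derived manually rather than with K.gradients, sse_and_grads and adam_step are hypothetical helpers, and the hyperparameters are the usual Adam defaults, so Keras's own implementation may differ in small details such as where epsilon enters.

```python
import numpy as np

def sse_and_grads(W, b, x, y_true):
    """Forward pass of Dense(1) + SSE loss, plus hand-derived gradients."""
    y_pred = x @ W + b                     # shape (batch, 1)
    err = y_true.reshape(-1, 1) - y_pred
    loss = np.sum(err ** 2)
    grad_W = -2.0 * x.T @ err              # dL/dW, shape (features, 1)
    grad_b = -2.0 * np.sum(err, axis=0)    # dL/db, shape (1,)
    return loss, [grad_W, grad_b]

def adam_step(params, grads, state, lr=1e-4, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """One in-place Adam update; `state` holds the step count t and moments m, v."""
    state["t"] += 1
    t = state["t"]
    for p, g, m, v in zip(params, grads, state["m"], state["v"]):
        m[...] = beta_1 * m + (1 - beta_1) * g       # first moment
        v[...] = beta_2 * v + (1 - beta_2) * g ** 2  # second moment
        m_hat = m / (1 - beta_1 ** t)                # bias corrections
        v_hat = v / (1 - beta_2 ** t)
        p -= lr * m_hat / (np.sqrt(v_hat) + eps)

W = np.zeros((2, 1))
b = np.zeros(1)
params = [W, b]
state = {"t": 0,
         "m": [np.zeros_like(p) for p in params],
         "v": [np.zeros_like(p) for p in params]}

x = np.array([[3.0, 4.0]])   # one sample, batch dimension included
y = np.array([7.0])
loss_before, grads = sse_and_grads(W, b, x, y)
adam_step(params, grads, state)
loss_after, _ = sse_and_grads(W, b, x, y)
print(loss_before, loss_after)
```

In the Keras version none of this is written by hand: K.gradients derives the gradient expressions from the graph, and get_updates wraps the optimizer arithmetic into update operations.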
Right now we have our graph done. Well, we actually have two graphs:
- Graph 1: Inputs = [ x, y_true ], Outputs = [ loss]
- Graph 2: Inputs = [ loss, weights ], Outputs = [ weights updated ]
The first graph corresponds to the forward pass of the network, and the second one to the backward pass, or optimization step.
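For our Dense(1) model, with weights W and bias b, the two graphs can be written out explicitly for a batch of N samples x_i with targets y_i. For readability, the backward graph is shown as a plain gradient descent step with learning rate η; Adam consumes the same gradients but rescales each step using running moment estimates.

```latex
\text{Graph 1 (forward):}\quad \hat{y}_i = x_i^\top W + b, \qquad
\mathcal{L} = \sum_{i=1}^{N} \left(y_i - \hat{y}_i\right)^2

\text{Graph 2 (backward):}\quad
\frac{\partial \mathcal{L}}{\partial W} = -2 \sum_{i=1}^{N} \left(y_i - \hat{y}_i\right) x_i, \qquad
\frac{\partial \mathcal{L}}{\partial b} = -2 \sum_{i=1}^{N} \left(y_i - \hat{y}_i\right)

W \leftarrow W - \eta \, \frac{\partial \mathcal{L}}{\partial W}, \qquad
b \leftarrow b - \eta \, \frac{\partial \mathcal{L}}{\partial b}
```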
So now we have 2 graphs that are supposed to run very efficiently on the GPU. OK, so how do we run them? That is where K.function helps us.
K.function (keras.backend.function, for completeness) is used similarly to the TensorFlow (this one was legitimate) tf.Session for instantiating the graph and session.run for running it. So the K.function description would be:
K.function returns a function that runs a single iteration, or forward pass, of the computational graph described previously, referenced by the inputs and outputs given as its parameters. If the updates keyword is set, it also runs the backward pass with the operations of the graph passed through that keyword.
With this in mind, we create 2 functions for executing the graphs defined previously:
The train function, which runs a forward and a backward pass on each call. It is configured as:
train = K.function(inputs=[x, y_true], outputs=[loss],
                   updates=updates_op)
And the test function, which only computes a forward pass, given that it is intended for testing and not for updating the weights. Note that the updates keyword is not set.
test = K.function(inputs=[x, y_true], outputs=[loss])
Note that both of these functions take a single argument as input, which is a list. This list must contain the values for the input tensors specified in inputs. On the other hand, the output will be a list with the values of the output tensors specified in outputs.
In our case, a call will look like loss = train([sample, target]), and since the returned loss is a list, we take its first element: loss = loss[0]. You can check the code to see it in context.
Finally, we can set up our custom training loop. There is nothing special here, just a default training loop; however, I will point out some specifics. Although the code contains both the training loop and the test loop, I will focus only on the former, and you can extend it to the latter.
The devil is in the details, so:
- tqdm is just a library that implements a progress bar reporting the progress of the training during one epoch.
- Since we are taking one sample at a time, batch_size = 1, indexing a single sample with NumPy drops the batch dimension, but the model expects an input with 2 dimensions, batch and features, so we need to add the batch dimension back manually.
- The graph works with tensors, so we need to transform the inputs to the graph into tensors.
- A training loss accumulator computes the running mean of the loss at each step in loss_train_mean; this metric is printed in the progress bar at each time step to monitor the evolution of the training.
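Putting these details together, here is a minimal NumPy sketch of the training loop that runs without Keras: one sample per step, the batch dimension added back with np.expand_dims (one possible choice), and a running mean of the training loss per epoch. train_step is a hypothetical stand-in for the train function, using hand-derived SSE gradients and plain SGD at an assumed learning rate of 0.005 instead of Adam so that a few epochs suffice; tqdm is left out to keep it dependency-free.

```python
import numpy as np

np.random.seed(0)  # arbitrary seed, only for reproducibility
samples = np.random.randint(0, 10, size=(100, 2))
targets = np.sum(samples, axis=-1)

# Dense(units=1) parameters, zero-initialised for simplicity
W = np.zeros((2, 1))
b = np.zeros(1)
params = [W, b]
lr = 0.005  # assumed learning rate for plain SGD

def train_step(params, sample, target):
    """Hypothetical stand-in for train([sample, target]): forward pass + SGD update."""
    W, b = params
    y_pred = sample @ W + b                 # (1, 2) @ (2, 1) -> (1, 1)
    err = target.reshape(-1, 1) - y_pred
    loss = np.sum(err ** 2)                 # SSE, like loss_fn
    W -= lr * (-2.0 * sample.T @ err)       # in-place, so params are updated
    b -= lr * (-2.0 * np.sum(err, axis=0))
    return [loss]                           # a list, like K.function's outputs

epoch_losses = []
for epoch in range(30):
    loss_train_mean = 0.0
    for i in range(len(samples)):
        # Indexing drops the batch dimension; add it back: (2,) -> (1, 2)
        sample = np.expand_dims(samples[i], axis=0)
        target = np.expand_dims(targets[i], axis=0)  # () -> (1,)
        loss = train_step(params, sample, target)[0]
        # Running mean of the losses seen so far in this epoch
        loss_train_mean += (loss - loss_train_mean) / (i + 1)
    epoch_losses.append(loss_train_mean)
    print(f"Epoch {epoch + 1}: training loss {loss_train_mean:.4f}")
```

The learning rate keeps lr · ||[x1, x2, 1]||² below 1 for every sample (the largest squared norm is 9² + 9² + 1 = 163), which guarantees that no single-sample step moves the weights further from the exact solution W = [1, 1], b = 0.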
The usage of the train function deserves special mention. Each time we call it, the train function runs a single execution of the graph, a forward and a backward pass, on the GPU with the given inputs. This graph takes the sample and the target as inputs to perform a training step, and returns a list of outputs, in our case just one: the loss of that training step.
Make sure the dimensions match for every input tensor: the Keras engine checks this by default, but here we aim to take control. So, when we call train([sample, target]), sample must have the same number of dimensions as x = Input(shape=(2,)), that is 2, the batch and feature dimensions, and target must have the same dimensions as y_true = Input(shape=()), which holds a zero-dimensional tensor, a single number (a scalar), for each sample.
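A quick NumPy check of this dimension bookkeeping: indexing one row drops the batch dimension, and np.expand_dims (one way of adding it back; reshape works too) restores it. The batched shapes in the comments assume that Keras prepends a batch dimension of unknown size (None) to the shape passed to Input.

```python
import numpy as np

samples = np.random.randint(0, 10, size=(100, 2))
targets = np.sum(samples, axis=-1)

row = samples[0]                             # shape (2,): batch dimension gone
sample = np.expand_dims(row, axis=0)         # shape (1, 2): matches (None, 2)
target = np.expand_dims(targets[0], axis=0)  # shape (1,): matches (None,)

print(row.shape, sample.shape, target.shape)  # (2,) (1, 2) (1,)
```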
If you run the code you will see something like this:
Training Loss is the mean of the training losses over that epoch, and Test Loss is the mean of the test losses over the samples of the test dataset.
After some iterations, you can see how both the training loss and the test loss decrease.
Although the purpose of this guide is to show how to work with Keras at a low level of abstraction, this is not good practice or the preferred approach. Why? Because Keras’ high level of abstraction is the desired outcome of a carefully designed project. So these tools should be the exception and not the norm, unless you really want to work at a low level of abstraction.
But what if I want to work at a low level of abstraction? (Maybe you are a researcher in optimization writing a new optimizer, or maybe you are working on algebra for scientific computing… who knows?) In that case, you could ask yourself: is there any alternative way to build training loops if I really want to take control of everything?
Well, I am happy you asked that question, because as an engineer I just cannot show one solution without proposing alternatives to consider. There are 2 other great approaches to train your model easily and still keep things under your control:
- TensorLayer: a widely overlooked library built on top of TensorFlow that aims to provide the main building blocks of deep learning models without giving up the low-level specifics.
- TensorFlow 2.0 custom training loop: with the integration of Keras into version 2.0 of TensorFlow, you kind of get the best of both worlds: the high-level building blocks of Keras with the low-level flow control of TensorFlow.
So if you really need the low level, go for TensorFlow; if you just need a modest level of abstraction, go for TensorLayer; but if, like me, you work mostly at the high level, go for Keras and use these tricks to overcome the unusual corner cases where you really need them.