In TensorFlow, it’s possible to train multiple intertwined deep neural networks together. This can help a ton with transfer learning, where you may have a great model for a domain, but only a little training data to learn a new task in that domain.
Here’s the task: train networks that can take in a binary or decimal number from 0 to 9, add 1 to it or multiply it by 2, and return a binary or decimal number from 0 to 9.
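To make the task concrete, here is a minimal sketch of how training examples could be generated in plain Python. The 4-bit binary encoding, the one-hot decimal encoding, and wrapping results modulo 10 so they stay in the 0 to 9 range are all assumptions for illustration, not details from the post. The helper names are hypothetical.

```python
import random

def to_binary(n, bits=4):
    """Encode 0..9 as a list of bits, most significant first (assumed encoding)."""
    return [(n >> i) & 1 for i in reversed(range(bits))]

def to_one_hot(n, size=10):
    """Encode 0..9 as a one-hot list of length 10 (assumed encoding)."""
    return [1 if i == n else 0 for i in range(size)]

def make_example(rng):
    """Build one hypothetical training example for the task."""
    n = rng.randrange(10)
    use_plus1 = rng.random() < 0.5
    # Assumption: wrap around so the answer stays a single digit.
    target = (n + 1) % 10 if use_plus1 else (n * 2) % 10
    return {
        # Both input forms are kept so either encoder could be fed.
        "binary_input": to_binary(n),
        "digit_input": to_one_hot(n),
        "use_binary_encoder": rng.random() < 0.5,
        "use_plus1": use_plus1,
        "binary_target": to_binary(target),
        "digit_target": to_one_hot(target),
    }

example = make_example(random.Random(0))
```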
The naive approach would be to train each path separately, with no intertwining networks. It’s the easiest and fastest way to complete this task, but that wouldn’t help us learn!
Let’s constrain ourselves to only training 2 encoder modules (binary and decimal), 2 math operation modules (plus 1 and times 2), and 2 decoder modules (binary and decimal). We can compose the pieces together at execution time.
Here’s what one path through our model might look like:
But training just that one path wouldn’t help us reuse the modules. We need to be able to reuse the encoders, math operations, and decoders. We can solve that by intertwining and training the networks like so:
Everything is pretty typical here; it’ll mostly be a lot of dense layers. The exception is the modules that select between branches. In programming, you might call them case expressions. In electrical engineering, you might call them multiplexers. In TensorFlow, we call them tf.where, tf.cond, and tf.case operations.
We won’t use tf.cond or tf.case in this example because they don’t work well with batches: they choose one branch for the whole batch instead of one branch per example.
tf.where lets us choose between two branches for each example in a batch. In plain Python, it might look like this:
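def where(tests, xs, ys):
    # "as" is a reserved word in Python, so the branches are named xs and ys
    result = []
    for test, x, y in zip(tests, xs, ys):
        val = x if test else y
        result.append(val)
    return result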
About performance and execution semantics (cool!)
Remember that TensorFlow uses a dataflow model, but it’s not necessarily lazy. During execution, both branches (read: tensors) in a tf.where need to be executed. With a tf.where, we throw out half the computed values. This will eat some processor cycles, but considering the nature of GPUs (SIMD good, branching bad), it’s likely the fastest way. If you need lazy instead of eager execution, try tf.cond or tf.case instead and give up the easy batch processing.
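The difference between the two execution styles can be illustrated with a small plain-Python sketch. It is only an analogy, not TensorFlow code: the call counters and the function names `where_style` and `cond_style` are made up for this illustration.

```python
# Count how often each branch is evaluated.
calls = {"plus1": 0, "times2": 0}

def plus1(x):
    calls["plus1"] += 1
    return x + 1

def times2(x):
    calls["times2"] += 1
    return x * 2

def where_style(use_plus1, xs):
    """Compute both branches for every example, then select per example."""
    a = [plus1(x) for x in xs]
    b = [times2(x) for x in xs]
    return [p if t else q for t, p, q in zip(use_plus1, a, b)]

def cond_style(use_plus1_for_batch, xs):
    """Pick one branch for the whole batch and compute only that branch."""
    branch = plus1 if use_plus1_for_batch else times2
    return [branch(x) for x in xs]

xs = [1, 2, 3, 4]
selected = where_style([True, False, True, False], xs)
where_calls = dict(calls)  # both branches ran on all 4 examples

calls.update(plus1=0, times2=0)
whole_batch = cond_style(True, xs)
cond_calls = dict(calls)   # only the chosen branch ran
```

In the where-style version, half of the eight computed values are discarded, but each example gets its own choice. The cond-style version does no wasted work, but the whole batch has to share one branch.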
Here’s how we can select between the plus 1 branch and the times 2 branch:
math_output = tf.where(use_plus1, plus1_output, times2_output)
# use_plus1 is a boolean of shape [batch_size]
# plus1_output has batch_size as its first dim
# times2_output has the same shape as plus1_output
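The shape requirements in those comments can be sketched with nested lists standing in for tensors. This is a plain-Python stand-in, not the TensorFlow op, and `select_rows` is a hypothetical helper name.

```python
def select_rows(mask, a_rows, b_rows):
    """Row-wise select: mask holds one boolean per example, while
    a_rows and b_rows hold one feature vector per example."""
    if not (len(mask) == len(a_rows) == len(b_rows)):
        raise ValueError("mask and both outputs must share the batch dimension")
    return [a if m else b for m, a, b in zip(mask, a_rows, b_rows)]

use_plus1 = [True, False, True]                        # shape [batch_size]
plus1_output = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]    # shape [batch_size, 2]
times2_output = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # same shape as plus1_output
math_output = select_rows(use_plus1, plus1_output, times2_output)
```

One thing worth checking against the TensorFlow version in use: in TensorFlow 2.x, tf.where broadcasts its arguments NumPy-style, so a mask of shape [batch_size] may need reshaping to [batch_size, 1] before it lines up with rank-2 outputs.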
Controlling the branches
For each training example, we need to let our network know:
- which encoder to use
- which math operation to use
This information can be fed in through the feed_dict as normal. We don’t need to choose between decoders.
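Here is a rough sketch of how the two per-example flags could steer a batch through the modules. Simple functions stand in for the trained networks, and every name here is hypothetical, as are the modulo 10 wrap-around and the placeholder inputs for the unused encoder.

```python
def select(flags, a, b):
    """Per-example choice between two lists of branch outputs."""
    return [x if f else y for f, x, y in zip(flags, a, b)]

# Stand-ins for trained modules (real ones would be stacks of dense layers).
def binary_encoder(bits):
    return int("".join(str(b) for b in bits), 2)

def digit_encoder(one_hot):
    return one_hot.index(1)

def plus1(h):
    return (h + 1) % 10   # assumption: wrap to stay within 0..9

def times2(h):
    return (h * 2) % 10   # assumption: wrap to stay within 0..9

def binary_decoder(h):
    return [int(c) for c in format(h, "04b")]

def digit_decoder(h):
    return [1 if i == h else 0 for i in range(10)]

def forward(binary_inputs, digit_inputs, use_binary_encoder, use_plus1):
    # Both encoders run on every example; the flag picks which result to keep.
    encoded = select(use_binary_encoder,
                     [binary_encoder(b) for b in binary_inputs],
                     [digit_encoder(d) for d in digit_inputs])
    # Both math modules run as well; the second flag picks per example.
    math_out = select(use_plus1,
                      [plus1(h) for h in encoded],
                      [times2(h) for h in encoded])
    # No flag for the decoders: both always run.
    return ([binary_decoder(h) for h in math_out],
            [digit_decoder(h) for h in math_out])

# Example 0: 3 via the binary encoder, plus 1. Example 1: 4 via the digit encoder, times 2.
binary_inputs = [[0, 0, 1, 1], [0, 0, 0, 0]]            # second row is a placeholder
digit_inputs = [digit_decoder(0), digit_decoder(4)]     # first row is a placeholder
binary_out, digit_out = forward(binary_inputs, digit_inputs,
                                use_binary_encoder=[True, False],
                                use_plus1=[True, False])
```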
Output and loss
There’s no reason not to train every decoder for every training example. We can combine their losses by simply adding them.
total_loss = binary_loss + digit_loss
Thankfully, the gradients for each decoder will stay the same as if it were trained alone. The partial derivative of total_loss with respect to each decoder’s loss is 1, so adding the losses doesn’t rescale either one. Things might go haywire with certain optimizers when one loss is much higher than the other, but it works great for this example.
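A quick numeric check makes the gradient claim easier to trust. This sketch uses two toy quadratic losses, each with a single made-up parameter, in place of the real decoder losses. The loss functions and values are assumptions chosen only for illustration.

```python
def binary_loss(w_bin):
    """Toy loss that depends only on the binary decoder's parameter."""
    return (w_bin - 3.0) ** 2

def digit_loss(w_dig):
    """Toy loss that depends only on the digit decoder's parameter."""
    return 10.0 * (w_dig + 1.0) ** 2

def total_loss(w_bin, w_dig):
    return binary_loss(w_bin) + digit_loss(w_dig)

def numeric_grad(f, x, eps=1e-6):
    """Central finite-difference estimate of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

w_bin, w_dig = 0.5, 2.0
# Gradient for the binary decoder when trained on its own loss.
grad_alone = numeric_grad(binary_loss, w_bin)
# Gradient for the same parameter when trained on the summed loss.
grad_in_total = numeric_grad(lambda w: total_loss(w, w_dig), w_bin)
```

The two gradients match even though the digit loss is much larger here. Note that this only covers parameters owned by a single decoder. Shared modules such as the encoders and math operations receive the sum of the gradients from both losses.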
Putting it all together
Check out the code! It’s short, sweet, and copy-pastable.
Source: Deep Learning on Medium