Understanding Neural ALU


Hi, you must have heard by now about the new paper published by Google DeepMind: Neural Arithmetic Logic Units. Last week I skimmed through it, and ever since I have been curious to understand the nuances of its mathematics. So here I try to write down my understanding of the paper after reading it. I don’t think there’s any better way to learn than to tell the world what you know, so here it is.

Today, neural networks as we know them are able to recognise objects, generate data, translate pretty much anything, and can even learn to drive cars, provided you have enough data to train on. But one thing neural nets lack is the ability to count or to learn numerical functions that generalise. They work pretty well inside the training range, but what happens when we extrapolate beyond that range and feed those values to the same nets? Of course they fail. That is the core theme of this paper. The Neural Arithmetic Logic Unit (NALU), as the authors put it, can track time, perform arithmetic over images of numbers, translate numerical language into real-valued scalars, execute computer code, and count objects in images.

To illustrate numerical extrapolation failures in neural networks, the authors trained ordinary neural networks to learn the identity function. The training data consisted of numbers between -5 and 5, and the learned function was tested on numbers between -20 and 20. The model was an autoencoder that accepted a scalar value as input (e.g. 2), encoded it using fully connected layers, then reconstructed the input value as a linear combination of the last hidden layer (e.g. output 2). Conclusion: all nonlinear activation functions fail to learn to represent numbers outside of the range seen during training, and the severity of this failure directly corresponds to the degree of non-linearity of the chosen activation.

This is what the autoencoder setup looks like:

Input (2) → Encoder (input layer + fully connected layer) → fully connected layer → Decoder (fully connected layer + reconstruction layer) → Output (2)

As you can see below, all nonlinearities gave a high mean absolute error outside the training range of -5 to 5.

Source: Neural Arithmetic Logic Units arXiv:1808.00508 [cs.NE]
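For the curious, here is a minimal sketch of that identity-function experiment in PyTorch. The layer sizes, optimizer, and number of training steps are my own arbitrary choices, not the paper's exact setup:

```python
import torch
import torch.nn as nn

# Sketch: train a small ReLU autoencoder to reproduce scalars from [-5, 5],
# then measure its error on the wider range [-20, 20].
torch.manual_seed(0)

model = nn.Sequential(            # layer sizes are illustrative, not from the paper
    nn.Linear(1, 8), nn.ReLU(),
    nn.Linear(8, 8), nn.ReLU(),
    nn.Linear(8, 1),
)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

x_train = torch.empty(256, 1).uniform_(-5, 5)
for step in range(2000):
    opt.zero_grad()
    loss = (model(x_train) - x_train).abs().mean()   # reconstruct the input
    loss.backward()
    opt.step()

with torch.no_grad():
    x_test = torch.linspace(-20, 20, 41).unsqueeze(1)
    mae = (model(x_test) - x_test).abs().squeeze(1)
print(mae)   # small inside [-5, 5], growing rapidly outside it
```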

Neural Accumulator and Neural ALU

To overcome the above problem, they introduced the NAC and the NALU.

The first model is the Neural Accumulator (NAC), which is like an ordinary linear layer, except that it has no bias or activation and its weight matrix is biased to consist of just -1s, 0s, and 1s: a = Wx. In other words, rather than arbitrarily scaling the rows of the input vector, the output is built by adding and subtracting them. This prevents the layer from changing the scale of the number representations when mapping the input to the output, meaning that they stay consistent throughout the model, no matter how many operations are chained together.

A hard constraint forcing every element of W to be exactly -1, 0, or 1 would make the layer impossible to train with gradient descent. So a continuous and differentiable parametrization is proposed: W is written in terms of unconstrained parameters as W = tanh(Ŵ) ⊙ σ(M̂), the element-wise product of a tanh and a sigmoid. Ŵ and M̂ can take any real values, while the resulting W is guaranteed to lie in [-1, 1] and is biased to be close to -1, 0, or 1.
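To make this parametrization concrete, here is a minimal NAC layer sketched in PyTorch; the Xavier initialization is my own choice and is not dictated by the formula above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NAC(nn.Module):
    """Neural Accumulator sketch: a bias-free linear layer whose effective
    weights W = tanh(W_hat) * sigmoid(M_hat) are pushed towards {-1, 0, 1}."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W_hat = nn.Parameter(torch.empty(out_dim, in_dim))
        self.M_hat = nn.Parameter(torch.empty(out_dim, in_dim))
        nn.init.xavier_uniform_(self.W_hat)   # initialization: arbitrary choice
        nn.init.xavier_uniform_(self.M_hat)

    def forward(self, x):
        W = torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)
        return F.linear(x, W)                 # a = Wx, no bias, no nonlinearity
```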

Neural Accumulator (NAC). Source: arXiv:1808.00508 [cs.NE]

To learn more complex mathematical functions like multiplication, the Neural Arithmetic Logic Unit (NALU) is introduced. It uses two NACs: one performs addition and subtraction as described above, and the other handles multiplication, division, and power functions such as √x. As with the NAC, there is the same bias against learning to rescale during the mapping from input to output.

If you look at the image below, there are two NAC cells shown in purple. The smaller one outputs the result of the NALU’s addition and subtraction operations using the equations above: a = Wx, where W = tanh(Ŵ) ⊙ σ(M̂). The second, bigger NAC cell operates in log space and is therefore capable of learning to multiply and divide, storing its result in m = exp(W log(|x| + ε)), where ε is a small constant that prevents log(0).

NALU Source: arXiv:1808.00508 [cs.NE]
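Here is a tiny, self-contained sketch of that log-space trick in PyTorch; the input values and random parameters are purely illustrative:

```python
import torch
import torch.nn.functional as F

# Multiplicative path of the NALU (sketch): the same weight construction is
# applied in log space, so additions of logs become multiplications after exp().
eps = 1e-7                                        # small constant to avoid log(0)
x = torch.tensor([[3.0, 4.0]])                    # toy input, purely illustrative
W_hat = torch.randn(1, 2)                         # unconstrained parameters
M_hat = torch.randn(1, 2)
W = torch.tanh(W_hat) * torch.sigmoid(M_hat)      # same parametrization as the NAC
m = torch.exp(F.linear(torch.log(torch.abs(x) + eps), W))
# if the learned W were exactly [[1, 1]], m would come out as ~ 3 * 4 = 12
```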

Finally, these two NAC cells are interpolated by a learned sigmoidal gate g (the orange cell). If the add/subtract cell’s output is weighted by g = 1, the multiplicative cell is turned off, and vice versa.

y = g ⊙ a + (1 − g) ⊙ m, where g = σ(Gx)

Altogether, this cell can learn arithmetic functions consisting of multiplication, addition, subtraction, division, and power functions in a way that extrapolates to numbers outside of the range observed during training.
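Putting the pieces together, a NALU cell might be sketched in PyTorch roughly as follows. This is an illustrative implementation based on the equations above (with a single W shared by both paths, as described), not the authors’ code, and the initialization is my own choice:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NALU(nn.Module):
    """NALU sketch: an additive NAC path, a multiplicative log-space path,
    and a learned sigmoidal gate that mixes the two."""
    def __init__(self, in_dim, out_dim, eps=1e-7):
        super().__init__()
        self.eps = eps
        self.W_hat = nn.Parameter(torch.empty(out_dim, in_dim))
        self.M_hat = nn.Parameter(torch.empty(out_dim, in_dim))
        self.G = nn.Parameter(torch.empty(out_dim, in_dim))
        for p in (self.W_hat, self.M_hat, self.G):
            nn.init.xavier_uniform_(p)            # initialization: arbitrary choice

    def forward(self, x):
        W = torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)
        a = F.linear(x, W)                                               # add/subtract path
        m = torch.exp(F.linear(torch.log(torch.abs(x) + self.eps), W))   # multiply/divide path
        g = torch.sigmoid(F.linear(x, self.G))                           # gate g = σ(Gx)
        return g * a + (1 - g) * m
```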


EXPERIMENTS

They performed various experiments to prove their claims.
First was the simple function learning task. The NAC and NALU were able to learn functions like addition, subtraction, multiplication, division, squaring, and square root, in both static and recurrent settings (a toy sketch of the static setup follows below). The result was that while several standard architectures succeed at these tasks in the interpolation case, none of them succeed at extrapolation. However, in both interpolation and extrapolation, the NAC succeeds at modeling addition and subtraction, whereas the more flexible NALU succeeds at multiplicative operations as well.
Second, they experimented with the MNIST Digit Counting and MNIST Digit Addition tasks. Standard architectures succeed on held-out sequences of the interpolation length, but they completely fail at extrapolation. Notably, the RNN-tanh and RNN-ReLU models also fail to interpolate to sequences shorter than those seen during training. However, the NAC and NALU both interpolate and extrapolate well.
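As a rough illustration of the static version of the function learning task, the following toy training loop (my own setup, not the paper's protocol) reuses the NALU class sketched above to learn a + b on small values and then checks it on a much larger range:

```python
import torch

# Toy static task: learn y = a + b from inputs in [0, 10],
# then test on inputs in [0, 100]. Reuses the NALU sketch from above.
torch.manual_seed(0)
model = NALU(in_dim=2, out_dim=1)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

x_train = torch.rand(1024, 2) * 10            # interpolation range
y_train = x_train.sum(dim=1, keepdim=True)

for step in range(5000):
    opt.zero_grad()
    loss = (model(x_train) - y_train).abs().mean()
    loss.backward()
    opt.step()

with torch.no_grad():
    x_test = torch.rand(1024, 2) * 100        # extrapolation range
    y_test = x_test.sum(dim=1, keepdim=True)
    print((model(x_test) - y_test).abs().mean())   # stays low if extrapolation works
```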

They also experimented with language-to-number translation, to test whether representations of number words are learned in a systematic way. Here the NALU in particular not only learned the task but also extrapolated well when translating number words into scalar values.
They also trained a reinforcement learning agent in a 5×5 grid world, where the objective is to reach the target location at exactly a specified time step T.

Learning to track time in a grid world environment Source: arXiv:1808.00508 [cs.NE]

There were two models: simply put, one with plain LSTM memory, and another whose LSTM state is passed through a NAC and fed back into the LSTM. Both agents were trained on episodes with T ∼ U{5, 12}, and both quickly learned to master the training episodes. However, the agent with the NAC performed well for T ≤ 19, whereas the performance of the standard LSTM agent deteriorated for T > 13; beyond these ranges both agents failed. The authors hypothesize that the more limited extrapolation (in terms of orders of magnitude) of the NAC here, relative to its other uses, was caused by the model still using the LSTM to encode numeracy to some degree.
