Adaptive Computation Time (ACT) in Neural Networks. (Part 1)


The interesting question is: how many computation steps should the network perform for each input step?

We introduce a halting unit (essentially a sigmoidal unit) whose role is to decide whether to stop the computation or to continue for another step.

Processing stops when the sum of the halting unit outputs (h) gets close to 1.0 (more precisely, reaches 1.0 minus epsilon, where epsilon is a small value set to 0.01 in the paper). The last h is then replaced by the remainder so that sum(hᵢ) equals exactly 1.0. This procedure gives us the halting probabilities (pᵢ) of the intermediate steps, which are used for mean-field updates of the states and outputs.
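In other words, the final state and output are weighted sums of the intermediate ones, with the halting probabilities as weights: s = Σₙ pₙ·sⁿ and y = Σₙ pₙ·yⁿ, where n runs over the intermediate steps.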

So, for example, suppose the halting unit outputs were 0.1, 0.3, 0.4 and 0.4 at four consecutive steps. After the fourth step the running sum exceeds 1 − epsilon, so we produce the list of weights [0.1, 0.3, 0.4, 0.2] (the last value is replaced by the remainder 1 − 0.8 = 0.2) and compute the final state and output as weighted sums with these weights:

s = 0.1*s¹ + 0.3*s² + 0.4*s³ + 0.2*s⁴

y = 0.1*y¹ + 0.3*y² + 0.4*y³ + 0.2*y⁴
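Here is a minimal sketch of this halting rule in NumPy, using the numbers from the example above (illustrative code, not the paper's implementation; the toy states s¹…s⁴ are made up):

```python
import numpy as np

eps = 0.01
halting_outputs = [0.1, 0.3, 0.4, 0.4]           # h at each intermediate step
states = [np.array([1.0, 0.0]), np.array([0.0, 1.0]),
          np.array([1.0, 1.0]), np.array([2.0, 0.0])]  # toy states s^1..s^4

weights, running_sum = [], 0.0
for h in halting_outputs:
    if running_sum + h >= 1.0 - eps:             # halting condition reached
        weights.append(1.0 - running_sum)        # last weight = remainder
        break
    weights.append(h)
    running_sum += h

print([round(w, 2) for w in weights])            # [0.1, 0.3, 0.4, 0.2]
s = sum(w * s_n for w, s_n in zip(weights, states))   # mean-field state
```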

We need to limit the computation time; otherwise, the network has an incentive to ponder each input for as long as possible. So we add a ponder cost to the loss function. The ponder cost measures the total amount of computation spent on the sequence, and it is added to the total loss with a weight τ, the time penalty hyperparameter.
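As a rough sketch of how the penalty enters training (the helper and the value of τ are illustrative assumptions, not the paper's code; the paper defines the per-input ponder cost as the number of steps N plus the remainder R):

```python
# Illustrative sketch: combining the task loss with the ponder cost.
def total_loss(task_loss, num_steps, remainder, tau=1e-3):
    ponder_cost = num_steps + remainder   # N + R for one input step
    return task_loss + tau * ponder_cost

loss = total_loss(task_loss=0.42, num_steps=4, remainder=0.2)
```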

The ponder cost is discontinuous at some points, but in practice this is simply ignored: the gradient of the ponder cost with respect to the halting activations (which flows through the remainder term) is straightforward, so the network can be differentiated as usual and trained with backpropagation.

[Figure: an addition training example.]

The idea was tested on a set of tasks specially designed to be hard without such a mechanism: determining the parity of a sequence of binary numbers (importantly, the whole sequence is passed as a single vector in one step, rather than element by element as is typical in RNN training), solving logic tasks, adding up the elements of a sequence to produce cumulative sums, and, finally, sorting sequences of 2 to 15 numbers.
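To make the parity setup concrete, here is a toy data generator in the same spirit: the whole binary sequence is presented to the network as one input vector in a single step. The sizes and the ±1 encoding here are illustrative assumptions, not the paper's exact configuration.

```python
import numpy as np

def parity_example(dim=8, rng=np.random.default_rng(0)):
    x = np.zeros(dim)
    k = rng.integers(1, dim + 1)             # number of non-zero entries
    x[:k] = rng.choice([-1.0, 1.0], size=k)  # binary values encoded as ±1
    target = int((x == 1.0).sum() % 2)       # parity of the number of ones
    return x, target

x, y = parity_example()
```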

The ACT nets produce significantly better results than the baselines (the same RNNs or LSTMs but without ACT), yet they require careful tuning of the time penalty hyperparameter (the authors used grid search).