Neural Ordinary Differential Equations: Major Breakthrough in Neural Network Research

Source: Deep Learning on Medium


A new class of Deep Neural Networks based on ODE solvers

Neural Ordinary Differential Equations is the title of the paper that won one of the best paper awards at NeurIPS 2018 (Neural Information Processing Systems, a machine learning and computational neuroscience conference held every December and one of the biggest AI conferences of the year). The authors introduced a new kind of network without discrete layers. Instead, the network is modeled as a differential equation. This lets it draw on more than a hundred years of research on differential equation solvers, for example to approximate the function underlying time series data. Here is a glimpse of the paper through my lens:

Deep neural networks are stacks of discrete layers. Each layer performs matrix operations, and the network is optimized using gradients of a loss function such as cross-entropy. Simply stacking more layers does not keep improving results: each layer introduces a small error that compounds across the network, and very deep plain networks become hard to train. To address this, Microsoft Research introduced ResNet (Residual Network), which established the concept of skip connections. The authors of this paper point out that the mathematical form of a residual network resembles one step in the numerical solution of an ordinary differential equation (ODE). As a result, ODE solvers such as Euler's method can be used to compute the output of an ODE network, and the adjoint method can be used to compute its gradients. They also found that, on some tasks, ODE-based models make better predictions than recurrent networks.

An ODE network approximated the spiral function below better than a recurrent network

Let’s recap ODEs and Euler’s method for solving them

By definition, an ordinary differential equation (ODE) is a differential equation having one or more functions of one independent variable and the derivatives of those functions. The snippet below will take you back to your undergrad math classroom.
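For reference, here is the standard textbook form of a first-order ODE with an initial condition (an initial value problem), together with its solution written as an integral. The symbols y, f, t_0 and y_0 are generic notation, not taken from the paper.

```latex
\frac{dy}{dt} = f\big(y(t), t\big), \qquad y(t_0) = y_0,
\qquad
y(t_1) = y_0 + \int_{t_0}^{t_1} f\big(y(t), t\big)\, dt
```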

Evaluating an ODE net means solving an initial value problem. The idea is that, given the state of a system at time t = 0, the state at time t = 1 can be computed by integrating the derivative of the state through time, using a numerical solver to approximate the solution. The oldest and simplest method for solving an ODE is Euler’s method. It starts at the initial state and moves a finite step in the direction given by the vector field at that point. It then repeats from the new point to approximate the solution. Euler’s method is considered a poor solver because it accumulates error at every step.
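The procedure described above can be sketched in a few lines of Python. This is a minimal, illustrative implementation. The function name `euler` and the test problem dy/dt = y are my own choices, not from the paper.

```python
import math
from typing import Callable, List, Tuple


def euler(f: Callable[[float, float], float], y0: float,
          t0: float, t1: float, n_steps: int) -> Tuple[List[float], List[float]]:
    """Approximate dy/dt = f(y, t), y(t0) = y0, on [t0, t1] with fixed Euler steps."""
    dt = (t1 - t0) / n_steps          # fixed step size
    t, y = t0, y0
    ts, ys = [t], [y]
    for _ in range(n_steps):
        y = y + dt * f(y, t)          # follow the slope at the current point
        t = t + dt
        ts.append(t)
        ys.append(y)
    return ts, ys


# dy/dt = y with y(0) = 1 has the exact solution y(t) = e^t.
for n in (10, 100, 1000):
    _, ys = euler(lambda y, t: y, 1.0, 0.0, 1.0, n)
    print(f"{n:5d} steps: y(1) ~ {ys[-1]:.6f}, error {abs(ys[-1] - math.e):.2e}")
```

The error at t = 1 shrinks roughly tenfold each time the number of steps grows tenfold. It falls from about 0.12 with 10 steps to about 0.0014 with 1000. This slow first-order convergence is why plain Euler is a poor choice when precision matters.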

Euler’s method and ResNets

Neural networks (NNs) are insanely popular machine learning models. They consist of a series of layers, each performing matrix operations. In an NN, the input data is fed into a layer, multiplied by a weight matrix, and a bias is added. An activation function is then applied to the result, and the output of this layer is fed into the next layer. Each layer adds a small error that compounds over the network. A natural way to make a network more powerful is to add more layers. However, beyond a certain threshold, extra layers make a plain network harder to train and less accurate. This is why Microsoft introduced Residual Networks (ResNets).
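A plain (non-residual) stack of layers, as described above, might look like the following NumPy sketch. The ReLU activation, the square weight matrices and the random initialization are assumptions for illustration. `dense_layer` and `plain_network` are hypothetical helper names.

```python
import numpy as np


def dense_layer(x: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """One fully connected layer: weight matrix, then bias, then ReLU activation."""
    return np.maximum(0.0, W @ x + b)


def plain_network(x: np.ndarray, weights, biases) -> np.ndarray:
    """Feed the output of each layer into the next layer, with no shortcuts."""
    h = x
    for W, b in zip(weights, biases):
        h = dense_layer(h, W, b)
    return h


rng = np.random.default_rng(0)
dim, depth = 4, 3
weights = [rng.normal(scale=0.5, size=(dim, dim)) for _ in range(depth)]
biases = [np.zeros(dim) for _ in range(depth)]
print(plain_network(np.ones(dim), weights, biases))
```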

Residual networks feed the sum of the current layer's input and output into the next layer, repeating this across the whole network. The shortcut that carries the input past each layer is called a skip connection.

ResNet architecture

The update equation of a residual network can be seen as one step of Euler’s method for an initial value problem, with a step size of 1.
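Here h_t is the hidden state after layer t and f is the block with parameters θ_t. The symbols y_n, t_n and Δt are generic Euler notation. In this notation, the residual update and an Euler step have the same form, and they coincide when Δt = 1. Letting the number of layers grow while the step shrinks gives the continuous-depth ODE that an ODE net solves:

```latex
h_{t+1} = h_t + f(h_t, \theta_t)
\qquad \text{vs.} \qquad
y_{n+1} = y_n + \Delta t \, f(y_n, t_n)

\frac{dh(t)}{dt} = f\big(h(t), t, \theta\big)
```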

Thus residual networks can be viewed as differential equations discretized with Euler’s method.

Let’s look at the code for a ResNet below.

Define a ResNet block and the weights of the hidden units at every layer. To build a ResNet, compute the update at every layer and add it to the current hidden state. This looks just like Euler integration.

Resnet code block
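A minimal NumPy sketch of such a ResNet follows. It assumes each block's update f is a single tanh layer with its own weights. `residual_branch` and `resnet` are hypothetical names. The loop body has exactly the form of the Euler update with a step size of 1.

```python
import numpy as np


def residual_branch(h: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """f(h, theta_t): the update computed by one block (a single tanh layer here)."""
    return np.tanh(W @ h + b)


def resnet(h: np.ndarray, params) -> np.ndarray:
    """Apply h <- h + f(h, theta_t) once per layer, each layer with its own weights."""
    for W, b in params:
        h = h + residual_branch(h, W, b)   # an Euler step with step size 1
    return h


rng = np.random.default_rng(0)
dim, depth = 4, 6
params = [(rng.normal(scale=0.3, size=(dim, dim)), np.zeros(dim)) for _ in range(depth)]
print(resnet(np.ones(dim), params))
```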

To turn this into an ODE net, simply feed the current depth into the ResNet block as an extra input and use one fixed set of parameters for the entire depth. The dynamics are then defined even between layers and can change continuously with depth, so the loop over layers can be replaced by a call to an ODE solver.
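In the continuous-depth version sketched below, the depth t is passed into the dynamics and one parameter set (W, u, b) is shared across all depths. The paper relies on adaptive ODE solvers. Here, a hand-written fixed-step fourth-order Runge-Kutta integrator stands in for them, so the example needs nothing beyond NumPy. All names (`dynamics`, `rk4_solve`, `odenet`) are hypothetical.

```python
import numpy as np


def dynamics(h, t, W, u, b):
    """f(h, t, theta): one tanh layer that also receives the depth t as input."""
    return np.tanh(W @ h + u * t + b)


def rk4_solve(func, h0, t0, t1, n_steps, *theta):
    """Fixed-step 4th-order Runge-Kutta, standing in for an adaptive ODE solver."""
    dt = (t1 - t0) / n_steps
    h, t = h0, t0
    for _ in range(n_steps):
        k1 = func(h, t, *theta)
        k2 = func(h + 0.5 * dt * k1, t + 0.5 * dt, *theta)
        k3 = func(h + 0.5 * dt * k2, t + 0.5 * dt, *theta)
        k4 = func(h + dt * k3, t + dt, *theta)
        h = h + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        t = t + dt
    return h


def odenet(h0, W, u, b, n_steps=20):
    """Hidden state at depth t = 1, using ONE parameter set for the whole depth."""
    return rk4_solve(dynamics, h0, 0.0, 1.0, n_steps, W, u, b)


rng = np.random.default_rng(0)
dim = 4
W = rng.normal(scale=0.3, size=(dim, dim))
u = rng.normal(scale=0.3, size=dim)
b = np.zeros(dim)
h0 = np.ones(dim)
# More solver steps refine the same continuous model rather than adding parameters.
for n in (5, 20, 100):
    print(n, odenet(h0, W, u, b, n_steps=n))
```

Increasing `n_steps` changes how finely the solver integrates the same model without adding parameters. In this sense, depth becomes a solver setting rather than an architectural choice.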

The Adjoint Method: A promising direction

So when neural networks are modeled as ODEs, how do we train them? Backpropagating through every internal step of the solver works, but its memory cost grows with the number of steps. The answer is the decades-old adjoint method. This method involves computing three integrals. The first computes the adjoint, which captures how the loss function L changes with respect to the hidden state. The second recomputes the hidden state backward in time alongside the adjoint. The third tells us how the loss changes with respect to the weights. Combining these three quantities into a single vector lets one call to an ODE solver, run backward in time, compute all of them at once.
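In symbols, a standard statement of the adjoint method reads as follows. It uses h(t) for the hidden state, f(h, t, θ) for the dynamics, and t_0 < t_1 for the start and end times; this is my paraphrase of the usual formulation. The first line defines the adjoint and its dynamics. The second recomputes the hidden state backward, and the third gives the gradient with respect to the weights. All three are integrated together from t_1 back to t_0, starting from a(t_1) = ∂L/∂h(t_1).

```latex
a(t) = \frac{\partial L}{\partial h(t)}, \qquad
\frac{da(t)}{dt} = -\,a(t)^{\top}\, \frac{\partial f\big(h(t), t, \theta\big)}{\partial h}

h(t_0) = h(t_1) + \int_{t_1}^{t_0} f\big(h(t), t, \theta\big)\, dt

\frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^{\top}\, \frac{\partial f\big(h(t), t, \theta\big)}{\partial \theta}\, dt
```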

Reverse-mode differentiation of an ODE solution

Approximate the derivative, don’t differentiate the approximation!
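To make the adjoint method concrete, here is a pure-Python toy. It assumes a scalar hidden state, the hypothetical dynamics f(h, t, θ) = tanh(θh) and the loss L = h(t_1). The forward pass keeps only the final state. The gradients then come from a single backward solve of the augmented state [h, a, g], and they are checked against central finite differences. A fixed-step RK4 integrator is used for simplicity, whereas practical implementations use adaptive solvers.

```python
import math


# Hypothetical toy dynamics f(h, t, theta) = tanh(theta * h) and its partial derivatives.
def f(h, t, theta):
    return math.tanh(theta * h)


def df_dh(h, t, theta):
    return theta * (1.0 - math.tanh(theta * h) ** 2)


def df_dtheta(h, t, theta):
    return h * (1.0 - math.tanh(theta * h) ** 2)


def rk4(func, y, t0, t1, n_steps):
    """Fixed-step RK4 for a list-valued state; t1 < t0 integrates backward in time."""
    dt = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = func(y, t)
        k2 = func([yi + 0.5 * dt * k for yi, k in zip(y, k1)], t + 0.5 * dt)
        k3 = func([yi + 0.5 * dt * k for yi, k in zip(y, k2)], t + 0.5 * dt)
        k4 = func([yi + dt * k for yi, k in zip(y, k3)], t + dt)
        y = [yi + dt / 6.0 * (a + 2.0 * b + 2.0 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        t += dt
    return y


def forward(h0, theta, t0=0.0, t1=1.0, n_steps=100):
    """Solve the ODE forward and keep only the final state (no stored trajectory)."""
    return rk4(lambda y, t: [f(y[0], t, theta)], [h0], t0, t1, n_steps)[0]


def adjoint_gradients(h0, theta, t0=0.0, t1=1.0, n_steps=100):
    """Gradients of L = h(t1) w.r.t. h0 and theta via the augmented backward ODE."""
    h1 = forward(h0, theta, t0, t1, n_steps)

    def augmented(y, t):
        h, a, _ = y
        return [f(h, t, theta),                    # hidden state, recomputed backward
                -a * df_dh(h, t, theta),           # adjoint a(t) = dL/dh(t)
                -a * df_dtheta(h, t, theta)]       # accumulates dL/dtheta

    # Start from [h(t1), dL/dh(t1) = 1, 0] and run one solve from t1 back to t0.
    _, dL_dh0, dL_dtheta = rk4(augmented, [h1, 1.0, 0.0], t1, t0, n_steps)
    return dL_dh0, dL_dtheta


h0, theta, eps = 1.0, 0.5, 1e-5
g_h0, g_theta = adjoint_gradients(h0, theta)
fd_theta = (forward(h0, theta + eps) - forward(h0, theta - eps)) / (2 * eps)
fd_h0 = (forward(h0 + eps, theta) - forward(h0 - eps, theta)) / (2 * eps)
print("dL/dtheta:", g_theta, "finite difference:", fd_theta)
print("dL/dh0:   ", g_h0, "finite difference:", fd_h0)
```

Because nothing from the forward trajectory is stored, memory use does not grow with the number of solver steps. The price is an extra backward solve.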

There are no fixed layers in an ODE net; the closest analogue of depth is the number of times the ODE solver evaluates the dynamics network. So there is no need to specify the depth of the model in advance. The figure below shows the trajectories corresponding to different inputs in an ODE net. The dots mark the points where the solver evaluated the dynamics.

In that figure, the simple dynamics in the center require fewer evaluations than the more complicated dynamics on the sides. Also, the average number of dynamics evaluations required by the ODE solver increases during training. A downside of these models is therefore that their time cost cannot be directly controlled during training. The final model often ends up about two to four times more expensive than the corresponding ResNet.
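The link between the complexity of the dynamics and the number of evaluations can be reproduced with a small adaptive solver. This sketch uses the simple embedded Euler/Heun pair with a standard step-size rule, whereas the paper's experiments use higher-order adaptive Runge-Kutta solvers. The two test dynamics are hypothetical: a gentle decay, and the same decay plus a fast sinusoidal forcing term.

```python
import math


def adaptive_heun_euler(f, y0, t0, t1, tol=1e-4, dt=0.1):
    """Adaptive solver for scalar dy/dt = f(y, t) using an embedded Euler/Heun pair.

    Returns the state at t1 and the number of evaluations of f (NFE).
    """
    t, y, nfe = t0, y0, 0
    while t1 - t > 1e-12:
        dt = min(dt, t1 - t)
        k1 = f(y, t)
        k2 = f(y + dt * k1, t + dt)
        nfe += 2
        y_low = y + dt * k1                    # 1st-order (Euler) estimate
        y_high = y + 0.5 * dt * (k1 + k2)      # 2nd-order (Heun) estimate
        err = abs(y_high - y_low)              # local error estimate
        if err <= tol:                         # accept the step, else retry smaller
            t, y = t + dt, y_high
        # Rescale the step (safety factor 0.9, change clamped to [0.2x, 5x]).
        scale = 0.9 * math.sqrt(tol / err) if err > 0 else 5.0
        dt *= min(5.0, max(0.2, scale))
    return y, nfe


def smooth(y, t):
    return -0.5 * y                            # gentle decay


def wiggly(y, t):
    return -0.5 * y + 3.0 * math.sin(30 * t)   # same decay plus fast forcing


y_s, nfe_s = adaptive_heun_euler(smooth, 1.0, 0.0, 1.0)
y_w, nfe_w = adaptive_heun_euler(wiggly, 1.0, 0.0, 1.0)
print("smooth dynamics:", nfe_s, "evaluations")
print("wiggly dynamics:", nfe_w, "evaluations")
```

The solver shrinks its steps wherever the dynamics change quickly, so the wiggly system costs many more evaluations of f than the smooth one. This is the same mechanism that makes an ODE net's running time depend on what it has learned.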

ODE nets require fewer parameters than ResNets because the network dynamics change smoothly with depth. As a result, a single set of parameters can be shared across the whole depth instead of one set per layer. The table below shows test accuracies for a classification task: the ODE net achieves roughly the same test accuracy as the ResNet with about a third of the parameters.
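A back-of-the-envelope count shows where the savings come from. It assumes a hypothetical stack of fully connected dim-by-dim layers and an ODE net whose single shared layer takes t as one extra input. These counts are illustrative only and do not reproduce the paper's MNIST architecture or its parameter numbers.

```python
def dense_params(d_in: int, d_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return d_in * d_out + d_out


def resnet_params(dim: int, n_blocks: int) -> int:
    """Each residual block owns a separate dim -> dim layer."""
    return n_blocks * dense_params(dim, dim)


def odenet_params(dim: int) -> int:
    """One shared layer for all depths; it also takes t as an extra input."""
    return dense_params(dim + 1, dim)


print(resnet_params(64, 6), odenet_params(64))
```

For dim = 64 and six residual blocks, this gives 24960 parameters for the ResNet stack versus 4224 for the shared dynamics layer.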

Performance on MNIST

Key takeaways for ODE nets

  1. Constant memory cost with respect to depth
  2. Fewer parameters than ResNets
  3. Applicable to time series data

Real brains are also continuous-time systems
