“GANs” vs “ODEs”: the end of mathematical modeling?

Source: Deep Learning on Medium


Disentangling neural network representations [source]

Hi everyone! In this article, I would like to make a connection between classical mathematical modeling, which we study in school and college, and machine learning, which also models the objects and processes around us, but in a totally different manner. While mathematicians create models based on their expertise and understanding of the world, machine learning algorithms describe the world in a hidden, not fully understandable way, which in most cases is even more accurate than the mathematical models developed by human experts. However, in many applications (such as healthcare, finance, or the military), we need clear and interpretable decisions, which machine learning algorithms, and deep learning models in particular, are not designed to provide.

This post reviews the main characteristics we expect from any model and the pros and cons of "classical" mathematical modeling and machine learning modeling, and then presents a candidate that combines the best of both worlds: disentangled representation learning.

Also, if you want to try disentangled representations on your own data, check out my implementation on GitHub and this library from Google Research.

What’s wrong with deep learning?

Since the deep learning revolution, we have tried to apply neural networks everywhere. In many important domains, such as computer vision, natural language processing, speech analysis, and signal processing, this really makes sense and helps achieve state-of-the-art results. Ultimately, all this deep learning hype is about automatic feature extraction from complex data through a combination of linear and non-linear transformations in a neural net, ending with a vector, also called an "embedding", that represents all the necessary information about the input object and allows us to perform classification or regression on it.
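The pipeline just described can be sketched in a few lines of NumPy. This is a minimal illustration with untrained, randomly initialized weights; the layer sizes, the three output classes, and the names `embed` and `classify` are hypothetical choices, not taken from any particular model.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

def init_layer(n_in, n_out):
    # Random weights (scaled by 1/sqrt(n_in)) and zero biases for one dense layer.
    return rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out)

# Hypothetical encoder: 64-dim input -> 32 hidden units -> 8-dim embedding.
W1, b1 = init_layer(64, 32)
W2, b2 = init_layer(32, 8)
# Linear classification head on top of the embedding (3 classes).
W_head, b_head = init_layer(8, 3)

def embed(x):
    """Linear map + non-linearity, then another linear map: the 'embedding'."""
    h = relu(x @ W1 + b1)
    return h @ W2 + b2

def classify(x):
    """Softmax over class logits computed from the embedding."""
    logits = embed(x) @ W_head + b_head
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

x = rng.normal(size=(5, 64))   # a batch of 5 fake inputs
z = embed(x)                   # shape (5, 8): opaque feature vectors
probs = classify(x)            # shape (5, 3): class probabilities
```

The eight numbers in each row of `z` carry whatever the network needs for the task, but nothing in them says what they mean.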

These "embeddings" are indeed very good in terms of feature extraction and accuracy, but they fall short in numerous ways:

  • Interpretation: a vector of size N tells me nothing about why a particular decision was made; only reverse-engineering approaches can highlight "objects of interest" in the input data.
  • Needs a lot of data: deep learning doesn't really work on 10–100 samples.
  • Unsupervised learning: most applications still require labeled training data.
  • Zero-shot reuse: this is a really important issue today, since a neural network trained on one dataset can very rarely be applied directly to another, similar dataset without retraining.
  • Objects generation: can I generate a real object from this embedding? Probably yes, with GANs.
  • Objects manipulation: can I manipulate particular properties of the input object with this embedding? Not really.
  • Theoretical foundation: well, we have the universal approximation theorem. Not much more.

It seems that these problems are very difficult to solve within the modern machine learning framework. Yet until recently, we were somehow dealing with all of them!

What's good about mathematical modeling?

Most mathematicians 20, 50, or even 100 years ago never encountered the problems mentioned above at all. Why? Because they were busy with mathematical modeling, i.e., describing objects and processes from the real world using mathematical abstractions such as distributions, formulas, or differential equations (that's why we have "ODEs", ordinary differential equations, in the title). Let's go through the "problems checklist" again, this time thinking about mathematical models created by scientists from scratch. I'll still use the term "embedding" here, but it will represent the parameters of the mathematical model, i.e., the set of degrees of freedom in the differential equation.

  • Interpretation: every mathematical model is built on how a scientist describes the object, with clear motivation and understanding. For example, to describe physical motion, our embedding would consist of the object's mass, its speed, and its coordinates in space: no abstract vectors!
  • Needs a lot of data: most of the scientific breakthroughs we know today were not made on "ImageNet-sized" datasets.
  • Unsupervised learning: well, not really the case for mathematical modeling either :)
  • Zero-shot reuse: the same stochastic differential equation, say, for geometric Brownian motion, can be applied in finance, biology, or physics: just rename the parameters.
  • Objects generation: out-of-the-box, just sampling parameters.
  • Objects manipulation: out-of-the-box, just manipulating parameters.
  • Theoretical foundation: hundreds of years of science.
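To make the last few points concrete, here is a minimal NumPy sketch (my own illustration, not from the original post) of geometric Brownian motion, dS = μS dt + σS dW, where the "embedding" is just the parameter pair (μ, σ). The function name `simulate_gbm`, the parameter ranges, and the 180-step horizon are arbitrary assumptions.

```python
import numpy as np

def simulate_gbm(s0, mu, sigma, n_steps=180, dt=1.0 / 180, rng=None):
    """Sample one path of geometric Brownian motion dS = mu*S dt + sigma*S dW.

    Uses the exact log-normal update
    S_{t+dt} = S_t * exp((mu - sigma**2 / 2) * dt + sigma * sqrt(dt) * Z).
    """
    rng = np.random.default_rng() if rng is None else rng
    z = rng.standard_normal(n_steps)
    log_increments = (mu - 0.5 * sigma**2) * dt + sigma * np.sqrt(dt) * z
    # Prepend 0 so the path starts exactly at s0.
    return s0 * np.exp(np.concatenate([[0.0], np.cumsum(log_increments)]))

rng = np.random.default_rng(42)

# Generation: sample the "embedding" (mu, sigma), then sample a path from it.
mu, sigma = rng.normal(0.05, 0.02), rng.uniform(0.1, 0.3)
generated = simulate_gbm(100.0, mu, sigma, rng=rng)

# Manipulation: reuse the same noise, change only the volatility parameter.
calm = simulate_gbm(100.0, 0.05, 0.1, rng=np.random.default_rng(7))
wild = simulate_gbm(100.0, 0.05, 0.5, rng=np.random.default_rng(7))
```

Sampling (μ, σ) gives generation for free, and turning the single σ knob while reusing the same noise makes the path rougher or smoother. The same function would describe a stock price or a growing population; only the meaning attached to the parameters changes.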

So why don't we use differential equations for everything? Well, it turns out that on complex, large-scale data they perform much worse. That's why we are riding the deep learning wave today. Still, we'd like to keep the nice properties of human-developed models as well.

Combining machine learning and human-based modeling

What if we could still use neural networks, which are so accurate at analyzing complex data, while also keeping the properties described above? Interpretability, the ability to generate and manipulate objects, unsupervised feature learning, and zero-shot reuse in different environments: where are you? For example, from a feature extractor for facial images, I would like to see something like this:

Almost unsupervised disentanglement

It works with images, which are too complex for differential equations or other models; it allows objects to be generated and manipulated; its features are interpretable; and it can most probably do all of this on another dataset as well. One problem with this work is that it is not fully unsupervised. Another important issue concerns manipulation: when I change, for example, the "beard" feature, the face automatically becomes more masculine, which means that the learned features, although interpretable, are correlated with each other, or, in other words, entangled.

β-VAE

However, there is an approach that can help us obtain a disentangled representation: an embedding in which each element is responsible for a single factor of variation, and which can be used for classification, generation, or manipulation tasks on new data (in a zero-shot fashion). This algorithm was developed at DeepMind and is based on the variational autoencoder, but it places a bigger emphasis on the KL divergence between the encoder's latent distribution and the prior than on the reconstruction loss. For more details, I recommend the following video, which explains the idea behind beta-VAE and its applications in supervised learning and reinforcement learning very well.
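In formula terms, a standard VAE minimizes reconstruction error plus the KL term, and β-VAE multiplies the KL term by β > 1. Below is a minimal NumPy sketch of that objective for a diagonal-Gaussian encoder and a standard-normal prior. The squared-error reconstruction term is an assumption (it corresponds to a Gaussian decoder), and the default weights are arbitrary. As an option it also implements the capacity-controlled form, in which the KL is pulled toward a target value C; I am assuming this is the variant behind the capacity C = 25 used in the ECG experiments.

```python
import numpy as np

def gaussian_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dims."""
    return 0.5 * np.sum(np.exp(logvar) + mu**2 - 1.0 - logvar, axis=-1)

def beta_vae_loss(x, x_recon, mu, logvar, beta=4.0, capacity=None, gamma=100.0):
    """Batch-averaged beta-VAE objective (to be minimized).

    - Standard form: recon + beta * KL, where beta > 1 puts extra pressure
      on the latent code to match the factorized prior.
    - If `capacity` is given: recon + gamma * |KL - capacity| instead
      (assumed capacity-controlled variant; gamma's default is arbitrary).
    """
    recon = np.sum((x - x_recon) ** 2, axis=-1)  # Gaussian-decoder assumption
    kl = gaussian_kl(mu, logvar)
    if capacity is None:
        return np.mean(recon + beta * kl)
    return np.mean(recon + gamma * np.abs(kl - capacity))
```

With `beta=1.0` and no capacity, this reduces to the ordinary VAE loss, which makes the "bigger emphasis on the KL divergence" a single-parameter change.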

beta-VAE applications to machine learning and reinforcement learning

After watching this video, you can see that beta-VAEs are indeed able to extract the factors of variation of the input data: directions of physical motion, object sizes, colors, and orientations. They can also separate objects of interest from the background in reinforcement learning applications, and they enable zero-shot reuse, in real-world environments, of agents that were trained in simulation.

My own experiments

Since I work a lot with medical and financial applications, where disentanglement could really solve many practical problems related to model interpretability, artificial data generation, and zero-shot learning, I tried using beta-VAEs on ECG data and BTC price data. You can find the code for training the model in my GitHub. First, I applied a beta-VAE (a really simple MLP network) to electrocardiograms from the PTB Diagnostic dataset, which has essentially three factors of variation: the different leads/forms of the ECG; the pulse, which changes over time for each person; and the diagnosis, i.e., infarction or its absence. I trained the VAE for 50 epochs with a bottleneck size of 10, a learning rate of 5e-4, and capacity C = 25 (see this work for details). Each input was a single heartbeat. As I expected, my model learned the real factors of variation in the dataset. In the picture below, you can see how I manipulate the input heartbeat (black line) by changing a single bottleneck feature from -3 to 3 while keeping all the others fixed. The 5th feature is responsible for changing the form of the heartbeat, the 8th one stands for the cardiac condition (blue ECGs have infarction symptoms, while the red ones try to be "the opposite"), and the 10th feature slightly changes the pulse.

Disentangling ECG beats
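The manipulation in the figure is a latent traversal: encode a beat, sweep one bottleneck coordinate while freezing the rest, and decode. A minimal sketch follows. The `decode` function here is a random, untrained MLP standing in for the real decoder (whose architecture is not reproduced here), and the 100-sample beat length is an arbitrary assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a trained decoder: a random 2-layer MLP mapping
# a 10-dim latent code to a 100-sample "heartbeat".
W1 = rng.normal(size=(10, 64)) / np.sqrt(10)
W2 = rng.normal(size=(64, 100)) / np.sqrt(64)

def decode(z):
    return np.tanh(z @ W1) @ W2

def latent_traversal(z, dim, values):
    """Decode copies of z in which only coordinate `dim` is set to each value."""
    codes = np.repeat(z[None, :], len(values), axis=0)
    codes[:, dim] = values
    return codes, decode(codes)

z = rng.normal(size=10)              # in practice: the encoder mean for one beat
values = np.linspace(-3.0, 3.0, 7)   # the -3..3 sweep from the experiment
# 0-based index 4 corresponds to the "5th feature".
codes, beats = latent_traversal(z, dim=4, values=values)
```

Plotting each row of `beats` on top of the input beat gives a picture of the kind shown above; with a well-disentangled model, only one property of the beat should change along the sweep.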

With financial data, things are less clear (not surprisingly). The training parameters were similar, but each input was a 180-minute window of BTC prices collected in 2017. I expected the beta-VAE to learn some "standard" financial time series models, such as mean reversion, but the obtained representation is relatively difficult to interpret. Okay, I can tell that feature #5 changes the trend of the input time series, but #2, #4, and #6 add or remove curves in different parts of the series or make it more or less "volatile".

Disentangling BTC close prices
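For comparison, the kind of "standard" mean-reverting model mentioned above is the Ornstein-Uhlenbeck process, dX = θ(μ - X) dt + σ dW, whose three parameters (reversion speed θ, long-run level μ, noise scale σ) are each interpretable. Below is a minimal Euler-Maruyama sketch; the parameter values are arbitrary and not fitted to any BTC data, and one step per minute is an assumption that merely mirrors the 180-minute windows.

```python
import numpy as np

def simulate_ou(x0, theta, mu, sigma, n_steps=180, dt=1.0, rng=None):
    """Euler-Maruyama path of dX = theta*(mu - X) dt + sigma dW.

    theta: speed of mean reversion, mu: long-run level, sigma: noise scale.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.empty(n_steps + 1)
    x[0] = x0
    for t in range(n_steps):
        dw = np.sqrt(dt) * rng.standard_normal()
        x[t + 1] = x[t] + theta * (mu - x[t]) * dt + sigma * dw
    return x

# A noisy 180-step path and the same model with the noise switched off.
noisy = simulate_ou(110.0, theta=0.05, mu=100.0, sigma=0.5, rng=np.random.default_rng(1))
drift_only = simulate_ou(110.0, theta=0.05, mu=100.0, sigma=0.0)
```

A disentangled encoder trained on such paths would ideally dedicate one latent coordinate to each of θ, μ, and σ; on real BTC windows the learned features were much harder to map onto parameters like these.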

Multiple objects disentanglement

What about a situation where several objects appear in an image and we want to find separate factors for each of them? Again, DeepMind makes us happy with its results. I won't dive into the details; just check the two GIFs below, get motivated, and read the corresponding papers linked in the tweets. Worth it :)

Summary

Let's conclude by describing the beta-VAE approach with the same checklist we used for "normal" deep learning and mathematical modeling.

  • Interpretation: totally interpretable features; we just need to validate each particular embedding element.
  • Needs a lot of data: well, still the case, since we're operating in deep learning territory.
  • Unsupervised learning: 100% unsupervised.
  • Zero-shot reuse: the reinforcement learning examples from the video speak for themselves.
  • Objects generation: easy sampling, as in any VAE.
  • Objects manipulation: nice and easy with any factor of variation you want.
  • Theoretical foundation: work in progress :/
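Two of these points are easy to sketch. Generation just means decoding codes sampled from the N(0, I) prior. For validating embedding elements, a common practical first step (my suggestion, not something from the experiments above) is to rank latent dimensions by their average KL divergence to the prior, since dimensions with near-zero KL are effectively unused and need no interpretation. The decoder and encoder outputs below are hypothetical stand-ins.

```python
import numpy as np

def sample_from_prior(decode, n_samples, latent_dim, rng):
    """VAE generation: draw z ~ N(0, I) and push it through the decoder."""
    z = rng.standard_normal((n_samples, latent_dim))
    return decode(z)

def kl_per_dimension(mu, logvar):
    """Average KL to N(0, 1) for each latent dimension over a batch.

    Dimensions with KL close to 0 are (nearly) unused by the encoder, so only
    the high-KL ones need to be checked for an interpretable meaning.
    """
    return np.mean(0.5 * (np.exp(logvar) + mu**2 - 1.0 - logvar), axis=0)

rng = np.random.default_rng(0)
W = rng.normal(size=(10, 100))  # hypothetical linear decoder
samples = sample_from_prior(lambda z: z @ W, n_samples=4, latent_dim=10, rng=rng)

# Hypothetical encoder outputs for 256 inputs: only dims 4 and 7 carry information.
mu = np.zeros((256, 10))
mu[:, [4, 7]] = rng.normal(size=(256, 2))
logvar = np.zeros((256, 10))
logvar[:, [4, 7]] = -4.0
ranking = np.argsort(kl_per_dimension(mu, logvar))[::-1]  # most informative first
```

The top-ranked dimensions are the ones worth traversing and labeling by hand; the rest collapse to the prior and can be ignored.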

We get almost all the good properties of mathematical modeling, together with deep learning's ability to analyze complex data types with high accuracy. So a very natural question arises: if I can learn such nice representations from complex data in a completely unsupervised way, does that mean the end of "classical" mathematical modeling? Do we really need to design complex models by hand if an ML model can do it and we just have to analyze its features? It's up to us to decide :)

P.S.
Also follow me on Facebook for AI articles that are too short for Medium, on Instagram for personal stuff, and on LinkedIn! Contact me if you want to collaborate on interpretable AI applications or other ML projects.