Transfer Learning Abstractions with Category Theory
Deep functors might help us create radical new techniques for transfer learning, and they might help us unlock better abstractions within deep learning systems. All thanks to category theory and constraint-based programming.
Category theory is a precise subject. I don’t intend to be 100% complete when I explain concepts within it, but I bet I can explain enough without misrepresenting anything.
A Driving Problem
Let’s start with a simple problem: create a neural network that takes in an image of N dots and outputs an image of N+1 dots. Then, let’s make “N+1” swappable for another computation, like “N×2”, and make the input/output formats swappable for other formats like binary arrays and images of digits.
To make stuff easier, let’s guarantee that N is a digit, and all operations always end with “modulo 10,” pinning their outputs between 0 and 9.
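Concretely, the two computations from the problem statement (written as plain Python functions for illustration; the names are my own) look like:

```python
# Toy digit operations from the problem statement.
# Every operation ends with "modulo 10", so outputs stay in 0..9.

def add_one(n: int) -> int:
    return (n + 1) % 10

def times_two(n: int) -> int:
    return (n * 2) % 10
```

For example, `add_one(9)` wraps around to `0`, and `times_two(7)` gives `4`.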
Here’s the tough part: after initial training, these components must be swappable without further training.
Consider the network:
This network can work incredibly well on our first problem: turn an image of N dots into an image of N+1 dots. However, there’s no way to swap the computation, N+1, while keeping the encoding and decoding logic unaffected. Somehow, we need to separate the network into a few tasks, maybe then we could swap stuff out…
It’s (naively) impossible to learn a network separated in modules like this. Neural networks are tightly coupled from input to output. There’s no way to tell when encoding, decoding, and “Add 1” begin and end; they’re each distributed throughout a typical network. But if we had a decoupled network, we could swap each of the modules to read different kinds of inputs, perform different math operations, and output different formats.
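To make the goal concrete, here is a minimal non-neural sketch of that decomposition. Every function name here is a stand-in I invented; in a real system each piece would be a trained network, and the shared “fuzzy digit” latent would be a tensor rather than a plain `int`. The point is only the swapping mechanics:

```python
# Sketch of the desired decomposition: input format -> latent -> output format.
# Each function is a hand-written stand-in for a learned module.

def encode_dots(dots: list) -> int:      # Dot Images -> latent digit
    return len(dots) % 10

def encode_binary(bits: list) -> int:    # Binary Numbers -> latent digit
    return int("".join(map(str, bits)), 2) % 10

def decode_dots(n: int) -> list:         # latent digit -> Dot Images
    return ["."] * n

def decode_binary(n: int) -> list:       # latent digit -> Binary Numbers
    return [int(b) for b in format(n, "04b")]

def add_one(n: int) -> int:
    return (n + 1) % 10

def times_two(n: int) -> int:
    return (n * 2) % 10

def pipeline(encode, op, decode):
    """Glue any encoder, digit operation, and decoder into one pipeline."""
    return lambda x: decode(op(encode(x)))

# Any encoder, operation, and decoder can be mixed and matched freely:
dots_plus_one = pipeline(encode_dots, add_one, decode_dots)
binary_times_two = pipeline(encode_binary, times_two, decode_binary)
```

Because the modules only agree on the latent digit in the middle, swapping `add_one` for `times_two`, or dots for binary, requires no change to the other two pieces — which is exactly the property we want the trained networks to have.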
One solution to creating a deep neural network with swappable modules is to create many interconnected networks, trained together:
In this graph, we have many input formats, many output formats, and many math operations we want to perform. It’s possible we can learn it by using multiple inputs and outputs or by training one path at a time until convergence.
Let’s add the identity (ID) function to the digit operation module so that the inputs and outputs of that module are constrained to be similar; this will satisfy eagle-eyed readers later.
Look at the inputs and outputs of the digit operation modules. Let’s declare that everything going in and out of that module is in the category of Fuzzy Digits, because we’re pretty much guaranteed that each of those tensors contains information about a digit, but it’s unclear how they’re represented. It’s clear from the diagram that all Fuzzy Digits are compatible with multiple interfaces, so it makes sense to put them in a category together.
Similarly, we can assign each of the I/O formats category names. Let’s declare every input to the dot counter and output of the dot image output to be in Dot Images. We also have Digit Images and Binary Numbers.
Morphisms are like functions in programming — they take us from one object to another. In category theory, morphisms are drawn as arrows between objects within the bounds of a category.
We have several morphisms available to us in Fuzzy Digits. We’ve already learned N+1 and N×2, for example. When we have a morphism in Fuzzy Digits, we can apply it to an object in Fuzzy Digits to get another object in Fuzzy Digits.
It’s worth noting that morphisms compose, just like functions, and that composition is associative.
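A quick sanity check of that claim with the digit operations (plain Python stand-ins for the learned morphisms):

```python
def compose(g, f):
    """Morphism composition: (g . f)(x) = g(f(x))."""
    return lambda x: g(f(x))

add_one = lambda n: (n + 1) % 10
times_two = lambda n: (n * 2) % 10
identity = lambda n: n  # the ID morphism every object carries

# Associativity: (h . g) . f == h . (g . f) on every digit.
left = compose(compose(add_one, times_two), add_one)
right = compose(add_one, compose(times_two, add_one))
assert all(left(n) == right(n) for n in range(10))

# Identity laws: composing with ID changes nothing.
assert all(compose(identity, add_one)(n) == add_one(n) for n in range(10))
```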
Functors “map” the morphisms in one category into another category.
In programming, you can map functions onto lists. A list holds values of some one type; when you “map” over a list, you apply a function to each of its members. So we have two categories: the category of types and the category of lists. “Map” carries a given function from the category of types into the category of lists.
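In Python terms, the built-in `map` plays this role, and it obeys the two functor laws — it preserves identity and composition:

```python
double = lambda n: n * 2
inc = lambda n: n + 1
xs = [1, 2, 3]

# "map" lifts a function on values to a function on lists.
print(list(map(double, xs)))  # [2, 4, 6]

# Functor law 1: mapping the identity is the identity.
assert list(map(lambda x: x, xs)) == xs

# Functor law 2: mapping a composition equals composing the maps.
assert list(map(lambda x: inc(double(x)), xs)) == list(map(inc, map(double, xs)))
```

These two laws are exactly what “some underlying structure is preserved” means formally, and they carry over to the deep functors below.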
Here’s our deep functor:
You can see that the morphisms (arrows) in Fuzzy Digits are mapped one-to-one in Digit Images. When in Digit Images, we prefix the names of morphisms with “F” to denote that they are mapped through a functor named “F”. You can see that some underlying structure is preserved in the target category; in this case, the underlying digit remains the same even when we move from the category of Fuzzy Digits to Digit Images.
We can add Binary Numbers to the diagram:
In our toy system, we have a few functors going from the category of Fuzzy Digits: the Dot Images functor, the Binary Numbers functor, and the Digit Images functor. Each functor maps the morphisms in Fuzzy Digits into their own respective category.
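One way to see what “F” does mechanically: mapping a digit morphism into Dot Images amounts to sandwiching it between a decoder and an encoder. This is a hand-written sketch (the encoder/decoder here are trivial stand-ins for learned networks, and the names are mine), but the functor law F(g ∘ f) = F(g) ∘ F(f) is visible:

```python
def encode(dots):   # Dot Images -> Fuzzy Digits (stand-in for a learned encoder)
    return len(dots) % 10

def decode(n):      # Fuzzy Digits -> Dot Images (stand-in for a learned decoder)
    return ["."] * n

def F(f):
    """Map a digit morphism f into Dot Images: F(f) = decode . f . encode."""
    return lambda dots: decode(f(encode(dots)))

add_one = lambda n: (n + 1) % 10
times_two = lambda n: (n * 2) % 10

# Functor law: F(g . f) == F(g) . F(f).  This holds here because
# encode(decode(n)) == n for every digit n — the round trip is the identity,
# which is what the ID constraint on the digit module is buying us.
g_after_f = lambda n: times_two(add_one(n))
image = ["."] * 3
assert F(g_after_f)(image) == F(times_two)(F(add_one)(image))
```

In a trained system the round trip would only be approximately the identity, so the functor law would hold approximately — which is arguably the right notion of “preserving structure” for fuzzy categories anyway.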
What’s totally cool about this is that we can learn new functors from Fuzzy Digits and new morphisms in Fuzzy Digits, hopefully without needing to retrain everything else. For example, we could learn an integer division morphism or a functor into the category of Roman numerals.
Taking it Further
Deep functors are just a fancy way to apply transfer learning.
Personally, I’d like to see this applied to other categories, like the category of headshot photos. Imagine creating a functor to the category of headshots from the category of headwear. Perhaps it would help people create headwear-specific models quickly with constrained training data and integrate them within many larger models.
Potentially more interestingly, how can we apply deep functors internally in models, enabling internally learned abstraction and reuse? Perhaps we can learn to map concepts into different domains, like musical scales into different keys.
Learning Category Theory
Please check out Bartosz Milewski’s free book, Category Theory for Programmers, especially if you’re familiar with programming. It was just right for me. I borrowed some imagery for my own diagrams.
I’m Job Searching! February 2018
I’m on the job hunt! I have broad experience in machine learning, functional programming, music, and entrepreneurship. I’m looking for a place to apply my skills, preferably in Los Angeles or another non-Bay-Area location.
Source: Deep Learning on Medium