A (sometimes) faster alternative to a list of nn.Linear layers

Source: Deep Learning on Medium

tl;dr I made a PyTorch layer that’s sometimes faster than a list of nn.Linear layers: https://gist.githubusercontent.com/Multihuntr/085a3da993f27d787863d22d888ae87b/raw/84b52877793dce9ab9d1c8b00db02f585d0166d1/partitioned_linear.py

Disclaimer: This method uses a lot more memory, and is only faster for certain sizes (which I discuss) and only on the GPU. In the best case it’s around 10x faster; in the worst case it’s a bit slower.

I recently implemented a model where I had some input like

inp.shape # [batch, n, in_features]

And I wanted to apply n different nn.Linear modules, one to each of the n slices of the input. So I made a list of nn.Linear layers and just looped over them:

out = []
for i, lin in enumerate(linear_layers):
    out.append(lin(inp[:, i]))
out = torch.stack(out, dim=1)
out.shape # [batch, n, out_features]

This works, but it’s pretty slow on a GPU once n gets above 8 or so. So I tried to find a layer that would go straight from inp to out without going through lists and torch.stack. This is my account of a short foray into madness, resulting in a new (questionably useful) PyTorch layer to do just that.

First, why is the loop slow? I can think of two reasons:

  1. Every CUDA operation carries some launch overhead, and the loop issues at least n of them.
  2. The weights of the separate linear layers aren’t contiguous in memory, so the GPU has to operate on different parts of memory.

It’s actually so slow on the GPU that a single linear layer over the flattened input

lin_big = nn.Linear(n*in_features, n*out_features)

can run much faster than the above loop.
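
The following is only a rough sketch of how such a comparison could be set up; it is not the script that produced the timings in this post. The sizes match the ones discussed in the text (n = in_features = out_features = 32), while the batch size, the iteration count and the `time_fn` helper are assumptions made for illustration.

```python
import time

import torch
import torch.nn as nn

# Sizes from the text; batch and iters are assumed values for illustration.
batch, n, in_features, out_features = 64, 32, 32, 32
iters = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

inp = torch.randn(batch, n, in_features, device=device)
linear_layers = nn.ModuleList(
    [nn.Linear(in_features, out_features) for _ in range(n)]
).to(device)
lin_big = nn.Linear(n * in_features, n * out_features).to(device)


def small_linears(x):
    # One nn.Linear per slice along dim 1, then stack the results.
    out = [lin(x[:, i]) for i, lin in enumerate(linear_layers)]
    return torch.stack(out, dim=1)


def big_linear(x):
    # Flatten the n slices into one long vector per sample.
    return lin_big(x.reshape(x.shape[0], -1))


def time_fn(fn, x):
    # Hypothetical helper. CUDA kernels run asynchronously, so synchronise
    # before reading the clock.
    with torch.no_grad():
        fn(x)  # warm-up call
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(x)
        if device == "cuda":
            torch.cuda.synchronize()
        return time.perf_counter() - start


t_big = time_fn(big_linear, inp)
t_small = time_fn(small_linears, inp)
print(f"Single big linear - {t_big:.5f}s")
print(f"Multiple small linears - {t_small:.5f}s")
```

On a CPU the gap will likely look quite different, since the effect described here is GPU-specific.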


Let’s stop and compare the number of calculations being done, and guess which will be faster. In my timing script n = in_features = out_features = 32, so n*in_features = n*out_features = 1024. Then lin_big has a weight matrix shaped [1024, 1024], which means 1024*1024 ~= 1 million multiply-adds per sample. Compare that to 32 layers with weight matrices shaped [32, 32], which come to only 32 * 32 * 32 ~= 32k multiply-adds per sample. So the big layer does 32x more arithmetic. On my machine, the script gives:

Single big linear - 0.09648s
Multiple small linears - 1.20879s

Woah! It’s not just a little bit faster: the big layer is over 12x faster, despite doing far more arithmetic. This discrepancy is consistent with the reasons I gave above for why multiple small linear layers are so slow. So we just need to do a single matrix multiplication, without any loops, with the weights of all the linear layers stored in one place, while getting the same result as the loop.
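
To make the operation counts quoted earlier concrete, here is the arithmetic as a few lines of Python. It counts one multiply-add per weight per sample and ignores biases, which is a simplification; the variable names `big_ops` and `small_ops` are just for illustration.

```python
n = in_features = out_features = 32

# One multiply-add per weight, per sample (biases ignored).
big_ops = (n * in_features) * (n * out_features)  # single [1024, 1024] layer
small_ops = n * in_features * out_features        # 32 layers of [32, 32]

print(big_ops)               # 1048576
print(small_ops)             # 32768
print(big_ops // small_ops)  # 32
```

The ratio is exactly n, so the big layer does n times the work and still wins on the GPU.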

If we think a little bit more about the big linear layer, we realise that the set of calculations we want to do is actually a subset of the calculations it’s already doing. Remember that every output is a linear combination of every input. So all we have to do is zero out the right parts of the weight matrix. Let’s reduce the sizes and look at a concrete example.

Let’s say we have in_features=2, n=2, out_features=1, batch=1. Then, after flattening, we have an input shaped [1, 4] and we need an output shaped [1, 2]. Then

big_lin = nn.Linear(4, 2, bias=False)
big_lin.weight.data.shape # [2, 4]

That is, big_lin.weight.data[0] is for calculating the first output, and big_lin.weight.data[1] is for calculating the second output. Likewise, big_lin.weight.data[:, :2] is applied to the first 2 inputs, and big_lin.weight.data[:, 2:] is applied to the second 2 inputs.

So if the weight matrix looks like this

x_0 x_1 x_2 x_3
x_4 x_5 x_6 x_7

then setting it to

x_0 x_1 0   0
0   0   x_6 x_7

means that x_0 and x_1 will use the first two inputs for the first output, and x_6 and x_7 will use the second two inputs for the second output. Nothing else contributes to either output.
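
The tiny example above can be checked numerically. This sketch builds two small bias-free layers, copies their weights into the matching positions of a zeroed big_lin, and compares the two results. The names `small`, `expected` and `actual` are made up for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two small layers, each mapping 2 inputs to 1 output (no bias, as in the text).
small = [nn.Linear(2, 1, bias=False) for _ in range(2)]

big_lin = nn.Linear(4, 2, bias=False)
with torch.no_grad():
    big_lin.weight.zero_()
    big_lin.weight[0, :2] = small[0].weight[0]  # row 0: x_0 x_1 0   0
    big_lin.weight[1, 2:] = small[1].weight[0]  # row 1: 0   0   x_6 x_7

inp = torch.randn(1, 2, 2)  # [batch, n, in_features]

with torch.no_grad():
    # The loop version: one small layer per slice, then stack.
    expected = torch.stack([lin(inp[:, i]) for i, lin in enumerate(small)], dim=1)
    # The single-matmul version: flatten, multiply, un-flatten.
    actual = big_lin(inp.reshape(1, 4)).reshape(1, 2, 1)

print(torch.allclose(expected, actual))
```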

As you might imagine, as these sizes get bigger the matrix becomes mostly zeros: only 1 in every n entries is actually used. This is why the memory cost goes up dramatically. Sparse matrix multiplication algorithms are a hairy topic that we’re not going to look at or use. We’re just interested in getting at least the performance of big_lin.
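
To put a number on how sparse the matrix is, here is a small sketch that builds the block-diagonal pattern with torch.block_diag and measures the fraction of entries that are used. The sizes are the n = in_features = out_features = 32 case from earlier; the `mask` name is just for illustration.

```python
import torch

n, in_features, out_features = 32, 32, 32

# Block-diagonal pattern: one [out_features, in_features] block of ones per layer.
mask = torch.block_diag(
    *[torch.ones(out_features, in_features) for _ in range(n)]
)

print(mask.shape)          # torch.Size([1024, 1024])
print(int(mask.sum()))     # 32768 weights actually used
print(mask.mean().item())  # 0.03125, i.e. 1/n
```

So in this case the dense layer stores 32x more weights than the list of small layers needs.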

We’re not done, though…

We can easily create a linear layer whose weights have 0s in the correct locations, but the optimiser won’t keep them at 0. So, instead of relying on the weights being 0, we have to mask out the parts we don’t want every time we use the weights (this only matters during training).
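
As a minimal sketch of the masking idea (not the code from the gist linked at the top, which may well differ), a layer along these lines could look as follows. The class name `MaskedBlockLinear`, the initialisation and the zero bias are all assumptions made for this illustration.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedBlockLinear(nn.Module):
    """Hypothetical name: n independent linear maps stored in one dense weight."""

    def __init__(self, n, in_features, out_features):
        super().__init__()
        self.n, self.out_features = n, out_features
        self.weight = nn.Parameter(torch.empty(n * out_features, n * in_features))
        self.bias = nn.Parameter(torch.zeros(n * out_features))
        # Assumed choice: scale by the fan-in of one small layer, not the big one.
        bound = 1 / math.sqrt(in_features)
        nn.init.uniform_(self.weight, -bound, bound)
        # A buffer follows .to(device) but is not a trainable parameter.
        mask = torch.block_diag(
            *[torch.ones(out_features, in_features) for _ in range(n)]
        )
        self.register_buffer("mask", mask)

    def forward(self, inp):
        # inp: [batch, n, in_features] -> out: [batch, n, out_features]
        flat = inp.reshape(inp.shape[0], -1)
        # Masking on every call also zeroes the gradient of off-block weights.
        out = F.linear(flat, self.weight * self.mask, self.bias)
        return out.reshape(-1, self.n, self.out_features)


layer = MaskedBlockLinear(n=4, in_features=3, out_features=2)
out = layer(torch.randn(5, 4, 3))
out.sum().backward()
print(out.shape)  # torch.Size([5, 4, 2])
```

Because the mask is applied inside forward, the off-block weights can hold any value without affecting the output, and they receive zero gradient.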

Ok, now we’re done.

How much faster is it?

Let’s call the 1024->1024 big linear case A.
Let’s call the 32x (32->32) multiple small linear case B.
Let’s call this new layer case C.

I did some timing comparisons to see how much of a difference it actually makes. Playing around with the numbers on my machine (starting from the 1024->1024 and 32x (32->32) setup above), if I choose
batch <= 128,
in_size = 1024
n_l = 32
out_size = 1024

Then case A or C is more than 10x faster than case B.

As the batch size increases, the performance gap decreases.
If I choose batch = 2048, then case A is only 1.5x faster than case B.

As n_l decreases, the performance of case B improves dramatically.
If I choose n_l <= 8, then case B is the fastest option.

As in_size increases, the performance gap decreases.
If I choose in_size = 8096, then case A is ~2.5x faster than case B, and case C is only ~1.5x faster than case B.
If I choose in_size = 256, then case A or C is ~18x faster than case B.

As out_size increases, the performance gap decreases.
If I choose out_size = 8096, then case A is ~2.5x faster than case B, and case C is only ~1.5x faster than case B.
If I choose out_size = 256, then case A or C is ~18x faster than case B.


It might be useful to someone, and was fun to work out.