Source: Deep Learning on Medium

tl;dr: I made a PyTorch layer that’s sometimes faster than a list of `nn.Linear` layers: https://gist.githubusercontent.com/Multihuntr/085a3da993f27d787863d22d888ae87b/raw/84b52877793dce9ab9d1c8b00db02f585d0166d1/partitioned_linear.py

**Disclaimer**: This method uses a lot more memory, and it is only faster for certain sizes (which I discuss) and only on the GPU. In the best case it’s around 10x faster; in the worst case it’s a bit slower.

I recently implemented a model where I had some input like

`inp.shape # [batch, n, in_features]`

and I wanted to apply `n` different `nn.Linear` modules, one to each of the `n` slices of the input. So I made a list of `nn.Linear` layers and just did:

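```python
import torch
import torch.nn as nn

batch, n, in_features, out_features = 64, 32, 32, 32  # illustrative sizes
inp = torch.randn(batch, n, in_features)
linear_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(n))

out = []
for i, lin in enumerate(linear_layers):
    out.append(lin(inp[:, i]))  # the i-th layer only sees the i-th slice
out = torch.stack(out, dim=1)
out.shape  # [batch, n, out_features]
```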

This works, but it’s pretty slow on a GPU once `n` gets above 8 or so. So I tried to find a layer that would go straight from `inp` to `out` without going through lists and `torch.stack`-ing. This is my account of a short foray into madness, resulting in a new (questionably useful) PyTorch layer that does just that.

First, why is the loop slow? I think there are two reasons:

- There’s a fixed overhead for every CUDA operation, and the loop pays it at least once per layer.
- The weights for the linear layers aren’t contiguous, so the GPU has to operate on different parts of memory.

It’s actually so slow on the GPU that a single big linear layer, `lin_big = nn.Linear(n*in_features, n*out_features)`, can run much faster than the above loop.

Let’s stop and look at the difference in the number of calculations being done, and guess which will be faster. In the above script, `n = in_features = out_features = 32`, so we have `n*in_features = n*out_features = 1024`. Then `lin_big` has a weight matrix shaped `[1024, 1024]`, so there are `1024*1024 ~= 1 million` floating point operations (multiply-adds) per example. Compare that to `32` layers with weight matrices shaped `[32, 32]`, which come to only `32 * 32 * 32 ~= 32k` multiply-adds per example. On my machine, the above script gives:

```
Single big linear - 0.09648s
Multiple small linears - 1.20879s
```

Woah! It’s not just a little bit faster: the big layer wins by about 12x, despite doing 32 times as much arithmetic. This discrepancy fits the reasons I gave above for why multiple small linear layers are so slow. So we just need to do a single matrix multiplication, without any loops, and keep the weights of all the linear layers together in memory, while still getting the same result.
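If you want to check the operation counts and the timing gap on your own hardware, here is a rough benchmarking sketch. It is not the script the numbers above came from: the batch size, the 50 iterations and the single warm-up call are my assumptions, and only the sizes `n = in_features = out_features = 32` come from the text. Note that `big_forward` is not yet equivalent to the loop (every output still sees every input); it only measures speed. On a CPU the ranking may well differ, since the trick is GPU-specific.

```python
import time
import torch
import torch.nn as nn

batch, n, in_features, out_features = 64, 32, 32, 32  # batch size is an assumption

# Multiply-adds per example (bias ignored): the big layer does n times more work.
big_macs = (n * in_features) * (n * out_features)  # 1024 * 1024
small_macs = n * in_features * out_features        # 32 layers of 32 * 32
print(big_macs, small_macs, big_macs // small_macs)  # 1048576 32768 32

device = "cuda" if torch.cuda.is_available() else "cpu"
inp = torch.randn(batch, n, in_features, device=device)
linear_layers = nn.ModuleList(
    nn.Linear(in_features, out_features) for _ in range(n)
).to(device)
lin_big = nn.Linear(n * in_features, n * out_features).to(device)


def loop_forward(x):
    # At least n separate matmul launches, plus the stack.
    return torch.stack([lin(x[:, i]) for i, lin in enumerate(linear_layers)], dim=1)


def big_forward(x):
    # One matmul over the flattened input (speed comparison only, not equivalent yet).
    return lin_big(x.reshape(x.shape[0], -1)).reshape(x.shape[0], n, out_features)


def time_it(fn, x, iters=50):
    with torch.no_grad():
        fn(x)  # warm-up run
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # CUDA calls are async; wait before stopping the clock
    return time.perf_counter() - start


print(f"Single big linear - {time_it(big_forward, inp):.5f}s")
print(f"Multiple small linears - {time_it(loop_forward, inp):.5f}s")
```

The `torch.cuda.synchronize()` calls matter: without them the timer would mostly measure how fast kernels are queued, not how fast they run.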

If we think a little bit more about the big linear layer, we realise that the set of calculations we want to do is actually a subset of the calculations it’s already doing. Remember that every output is a linear combination of every input. So all we have to do is zero out the right parts of the weight matrix. Let’s reduce the sizes and look at a concrete example.

Let’s say we have `in_features=2`, `n=2`, `out_features=1` and `batch=1`. Then, after flattening the `n` dimension, we have an input shaped `[1, 4]`, and we need an output shaped `[1, 2]`. So

```python
import torch.nn as nn

big_lin = nn.Linear(4, 2, bias=False)
big_lin.weight.data.shape  # [2, 4]
```

That is, `big_lin.weight.data[0]` is for calculating the first output, and `big_lin.weight.data[1]` is for calculating the second output. Likewise, `big_lin.weight.data[:, :2]` is for calculating from the first 2 inputs, and `big_lin.weight.data[:, 2:]` is for calculating from the second 2 inputs.
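A quick way to convince yourself of this row/column reading is to check it numerically. This sketch uses the toy sizes above (`in_features=2`, `n=2`, `out_features=1`, `batch=1`); the random seed and input values are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
big_lin = nn.Linear(4, 2, bias=False)
x = torch.randn(1, 4)     # batch=1; n=2 slices of in_features=2, flattened
W = big_lin.weight.data   # [2, 4]: one row per output

out = big_lin(x)          # [1, 2]

# Row j of W produces output j.
assert torch.allclose(out[0, 0], W[0] @ x[0], atol=1e-6)
assert torch.allclose(out[0, 1], W[1] @ x[0], atol=1e-6)

# Columns :2 only ever multiply inputs 0-1 (the first slice),
# columns 2: only ever multiply inputs 2-3 (the second slice).
from_first = x[0, :2] @ W[:, :2].T   # contribution of the first slice to both outputs
from_second = x[0, 2:] @ W[:, 2:].T  # contribution of the second slice to both outputs
assert torch.allclose(out[0], from_first + from_second, atol=1e-6)
```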

So if the weight matrix looks like this

```
x_0 x_1 x_2 x_3
x_4 x_5 x_6 x_7
```

then setting it to

```
x_0 x_1  0   0
 0   0  x_6 x_7
```

means that `x_0` and `x_1` will use the first two inputs for the first output, and `x_6` and `x_7` will use the second two inputs for the second output. Nothing else is used for calculating anything else.
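The same zeroing pattern works for any sizes: the allowed entries form `n` blocks of shape `[out_features, in_features]` down the diagonal. Here is a sketch that builds that 0/1 pattern with `torch.block_diag`, zeroes the big layer with it, and checks the result against applying each diagonal block to its own slice. It uses the toy sizes from the text, except with `batch=3`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, n, in_features, out_features = 3, 2, 2, 1

big_lin = nn.Linear(n * in_features, n * out_features, bias=False)

# 0/1 mask with n blocks of shape [out_features, in_features] on the diagonal.
# For these sizes: [[1, 1, 0, 0], [0, 0, 1, 1]]
mask = torch.block_diag(*[torch.ones(out_features, in_features) for _ in range(n)])

with torch.no_grad():
    big_lin.weight.mul_(mask)  # zeroes x_2, x_3, x_4, x_5

inp = torch.randn(batch, n, in_features)
out_big = big_lin(inp.reshape(batch, -1)).reshape(batch, n, out_features)

# Reference: apply each diagonal block to its own slice, one at a time.
W = big_lin.weight
blocks = [
    W[i * out_features:(i + 1) * out_features, i * in_features:(i + 1) * in_features]
    for i in range(n)
]
out_loop = torch.stack([inp[:, i] @ blocks[i].T for i in range(n)], dim=1)
assert torch.allclose(out_big, out_loop, atol=1e-6)
```

Flattening with `reshape(batch, -1)` puts slice `i` in columns `i*in_features` to `(i+1)*in_features - 1`, which is exactly why the column blocks line up with the slices.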

As you might imagine, as these sizes get bigger, we end up with a mostly sparse matrix. This is why the memory cost goes up dramatically. Sparse matrix multiplication algorithms are a hairy topic that we’re not going to look at or use. We’re just interested in getting at least the performance of `big_lin`.
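To put numbers on "mostly sparse": the big layer stores `(n*in_features) * (n*out_features)` weights, but only the `n` diagonal blocks (`n*in_features*out_features` entries) are ever non-zero, so the useful fraction is `1/n`. A small helper (the name `weight_counts` is hypothetical) to tabulate that:

```python
def weight_counts(n, in_features, out_features):
    """Return (stored, useful, useful fraction) for the big masked layer."""
    stored = (n * in_features) * (n * out_features)  # full dense weight matrix
    useful = n * in_features * out_features          # entries inside the diagonal blocks
    return stored, useful, useful / stored           # the fraction works out to 1/n


for n in (2, 8, 32, 128):
    stored, useful, frac = weight_counts(n, 32, 32)
    print(f"n={n:4d}  stored={stored:>11,d}  useful={useful:>9,d}  non-zero={frac:.2%}")
```

Gradients (and any optimiser state, such as momentum) have the same dense shape as the weight, so this overhead is paid several times over during training.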

**We’re not done, though…**

We can easily create a linear layer whose weights have 0s in the correct locations, but the optimiser won’t keep them as 0, because the gradients for those entries are generally non-zero. So, instead of relying on the weights being 0, we have to mask out the parts we don’t want every time before using the weights (this only matters during training, since that’s when the weights get updated).
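To see both halves of that argument concretely, the sketch below uses plain SGD with arbitrary random data and learning rate. It first zeroes the weights once and takes one optimiser step, then instead multiplies the weight by the mask inside the forward pass via `torch.nn.functional.linear`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
mask = torch.block_diag(torch.ones(1, 2), torch.ones(1, 2))  # [[1, 1, 0, 0], [0, 0, 1, 1]]
x, target = torch.randn(8, 4), torch.randn(8, 2)

# Attempt 1: zero the cross-block weights once, then train as usual.
lin = nn.Linear(4, 2, bias=False)
with torch.no_grad():
    lin.weight.mul_(mask)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)
F.mse_loss(lin(x), target).backward()
opt.step()
print((lin.weight[mask == 0] != 0).any())  # the "zeros" have been updated away from 0

# Attempt 2: leave the stored weight alone, but multiply by the mask on every use.
lin2 = nn.Linear(4, 2, bias=False)
before = lin2.weight.detach().clone()
opt2 = torch.optim.SGD(lin2.parameters(), lr=0.1)
F.mse_loss(F.linear(x, lin2.weight * mask), target).backward()
print((lin2.weight.grad[mask == 0] == 0).all())  # masked entries get exactly zero gradient
opt2.step()
print(torch.equal(lin2.weight[mask == 0], before[mask == 0]))  # so SGD leaves them unchanged
```

In the second case the stored entries outside the blocks keep their initial random values, but they are multiplied by 0 on every forward pass, so they never affect the output. With weight decay they would shrink over time, but they would still never be used.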

**Ok, now we’re done.**
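The actual layer is in the gist linked at the top. What follows is an independent, minimal sketch of the same idea, assuming the `[batch, n, in_features]` layout used throughout; the class name `MaskedPartitionedLinear` is hypothetical, and the initialisation (mimicking `nn.Linear`’s default bounds per block) is my choice rather than something stated in this post. The check at the bottom copies `n` ordinary `nn.Linear` layers into the diagonal blocks and confirms the outputs match the loop.

```python
import math
import torch
import torch.nn as nn


class MaskedPartitionedLinear(nn.Module):
    """Hypothetical sketch: n independent linear maps computed as one masked matmul.

    Input:  [batch, n, in_features]
    Output: [batch, n, out_features]
    """

    def __init__(self, n, in_features, out_features, bias=True):
        super().__init__()
        self.n, self.in_features, self.out_features = n, in_features, out_features
        self.weight = nn.Parameter(torch.empty(n * out_features, n * in_features))
        self.bias = nn.Parameter(torch.empty(n * out_features)) if bias else None
        # 0/1 block-diagonal mask. A buffer follows .to(device) but is not trained.
        mask = torch.block_diag(*[torch.ones(out_features, in_features) for _ in range(n)])
        self.register_buffer("mask", mask)
        self.reset_parameters()

    def reset_parameters(self):
        # Same bounds nn.Linear uses by default, with fan_in = in_features per block.
        bound = 1 / math.sqrt(self.in_features)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        batch = x.shape[0]
        w = self.weight * self.mask  # re-apply the mask on every call
        out = nn.functional.linear(x.reshape(batch, -1), w, self.bias)
        return out.reshape(batch, self.n, self.out_features)


# Check: copy n ordinary nn.Linear layers into the diagonal blocks and compare.
torch.manual_seed(0)
n, in_features, out_features = 4, 3, 5
layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(n))
layer = MaskedPartitionedLinear(n, in_features, out_features)
with torch.no_grad():
    for i, lin in enumerate(layers):
        rows = slice(i * out_features, (i + 1) * out_features)
        cols = slice(i * in_features, (i + 1) * in_features)
        layer.weight[rows, cols] = lin.weight
        layer.bias[rows] = lin.bias

x = torch.randn(7, n, in_features)
expected = torch.stack([lin(x[:, i]) for i, lin in enumerate(layers)], dim=1)
assert torch.allclose(layer(x), expected, atol=1e-6)
```

Registering the mask as a buffer keeps it out of `parameters()`, so the optimiser never touches it, while `.to(device)` still moves it alongside the weights. The `self.weight * self.mask` product allocates a masked copy of the full dense weight on every forward pass, which is part of the extra memory cost mentioned in the disclaimer.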