Sampling Large Graphs in PyTorch Geometric


Sometimes we encounter large graphs that force us beyond the available memory of our GPU or CPU. In these cases we can utilize graph sampling techniques. PyTorch Geometric is a graph deep learning library that allows us to easily implement many graph neural network architectures. The library contains many standard graph deep learning datasets like Cora, Citeseer, and Pubmed. But recently there has been a push toward large-scale open graph datasets like the Open Graph Benchmark (OGB) [3]. The OGB datasets range from ‘small’ networks like ogbn-arxiv (169,343 nodes) all the way up to ‘large’ datasets like ogbn-papers100M (111,059,956 nodes). Maybe ogbn-arxiv can fit in memory if you are simply doing node classification with a small GCN, but try anything beyond this, or use a medium-to-large OGB dataset, and you might have to resort to sampling the graph.
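
As a concrete starting point, here is roughly how you would load one of these OGB datasets and pull out the training node indices used in the NeighborSampler example below (a quick sketch using the ogb package; the root directory is my own choice):

from ogb.nodeproppred import PygNodePropPredDataset

# Load ogbn-arxiv as a PyTorch Geometric Data object (the root path is an assumption).
dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='data/ogb')
data = dataset[0]

# OGB ships its own train/valid/test node index splits.
split_idx = dataset.get_idx_split()
train_idx = split_idx['train']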

There are various ways to sample a large graph, and I will cover two of the most prominent methods.

1. NeighborSampler

Sketch of bipartite graphs from 3-layer Neighborhood Sampler

2. GraphSAINTSampler

Sketch of subgraph sampler from a GraphSAINTSampler mini-batch

The NeighborSampler class is from the GraphSAGE paper, Inductive Representation Learning on Large Graphs [2]. If you haven’t used SAGEConv before, you can think of it as learning a function that outputs node embeddings based on the neighborhood of a node rather than learning all of the node embeddings directly. This makes it particularly useful for inductive tasks as you can pass the GNN nodes that it has never seen before.

Below, I sketch an abstract view of NeighborSampler for a 3-layer GNN with equal sizes, where each layer is represented as a bipartite graph of neighborhood samples.

Sketch of 3-layer GNN with Neighborhood Sampling

The source nodes are sketched in blue and the target nodes in red. As you can see, the target nodes also have self-loops, so they appear on both the left and right sides of each bipartite graph. Here is what the code looks like:

from torch_geometric.loader import NeighborSampler  # torch_geometric.data in older PyG versions

train_loader = NeighborSampler(data.edge_index, node_idx=train_idx,
                               sizes=[15, 10, 5], batch_size=1024,
                               shuffle=True, num_workers=12)

sizes specifies the number of neighbors to sample for each source node at each layer. Imagine starting with the set of nodes we want to compute embeddings for (these are the final red target nodes in layer 3 of the image above). The sampler then works outward from those nodes, hop by hop, as specified by the hyperparameters, and the resulting bipartite graphs are returned to us in reverse order. This means that as we pass through all of the layers of the GNN, we end up with embeddings for exactly the nodes we are interested in.

Note: Node IDs in each mini-batch are the original node IDs from the larger graph. This sampler does not return standalone subgraphs per se; it returns neighborhood samples used to learn an aggregator function.
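
To make the reverse ordering concrete, we can pull a single mini-batch from the loader above and inspect it (a quick sketch; the print formatting is mine):

batch_size, n_id, adjs = next(iter(train_loader))

# adjs holds one bipartite graph per GNN layer, ordered from the
# outermost sampled hop down to the target nodes of the mini-batch.
for i, (edge_index, e_id, size) in enumerate(adjs):
    # size = (number of source nodes, number of target nodes)
    print(f'layer {i}: {size[0]} source nodes -> {size[1]} target nodes')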

From the GraphSAGE example in PyTorch Geometric on the ogbn-products dataset, we can see that each mini-batch from the train_loader consists of batch_size, n_id, and adjs:

for batch_size, n_id, adjs in train_loader:
    ...
    out = model(x[n_id], adjs)
    ...

n_id contains the node IDs of every node used in the sampling procedure, including the sampled neighbors and the source nodes. By passing our model x[n_id], we isolate the node feature vectors of only those nodes used in this batch's computation. There are three adj entries in adjs, each consisting of an edge_index, e_id, and size. So in our forward pass of the SAGE model in PyTorch Geometric we have:

def forward(self, x, adjs):
    for i, (edge_index, _, size) in enumerate(adjs):
        x_target = x[:size[1]]  # the target nodes are the first size[1] rows of x
        x = self.convs[i]((x, x_target), edge_index)
        if i != self.num_layers - 1:
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
    return x.log_softmax(dim=-1)

Now you can see that the three bipartite graphs in adjs are each passed to the three convolutional layers respectively.
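
For reference, the constructor that defines self.convs and self.num_layers in that example looks roughly like the sketch below (the three-layer depth matches the sizes above; the hidden dimension is left as an argument):

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
        super().__init__()
        self.num_layers = num_layers
        # One SAGEConv per sampled layer; each consumes one bipartite graph from adjs.
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))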

One possible downside of this approach is that we are not actually sampling a standalone subgraph for each batch of the training data. The sampler imitates a GNN convolving across the full training network rather than extracting an independent subgraph at each iteration. On one hand, this reduces the bias that subgraph sampling can introduce into your training loop; on the other hand, if you are doing something beyond simple node classification or link prediction, I have run into indexing issues because the mini-batches are not self-contained subgraphs.
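
Before moving on, here is roughly how a full training step ties these pieces together, adapted from the ogbn-products GraphSAGE example (it assumes model, optimizer, device, and the full feature matrix x and label vector y are already set up):

for batch_size, n_id, adjs in train_loader:
    # Move the per-layer bipartite graphs to the target device.
    adjs = [adj.to(device) for adj in adjs]
    optimizer.zero_grad()
    out = model(x[n_id], adjs)
    # The first batch_size entries of n_id are the target nodes,
    # so their labels line up with the model output.
    loss = F.nll_loss(out, y[n_id[:batch_size]])
    loss.backward()
    optimizer.step()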

GraphSAINTSampler, in contrast, allows you to work with actual subgraphs of the original training graph and relabels the node IDs from 0 to n - 1, where n is the number of nodes in the subgraph.

The GraphSAINTSampler parent class has three child classes: GraphSAINTNodeSampler, GraphSAINTEdgeSampler, and GraphSAINTRandomWalkSampler. Each of these classes uses its respective sampling scheme to compute the importance of nodes, which translates into a probability distribution for sampling [5]. This initial sampling acts as a pre-processing step that estimates the probability of a node v in V and an edge e in E being sampled. That probability is later used as a normalization factor on the subgraph [4].

Here is an example of the GraphSAINTRandomWalkSampler in the graph_saint example in PyTorch Geometric.

from torch_geometric.loader import GraphSAINTRandomWalkSampler  # torch_geometric.data in older PyG versions

loader = GraphSAINTRandomWalkSampler(data, batch_size=6000,
                                     walk_length=2, num_steps=5,
                                     sample_coverage=100,
                                     save_dir=dataset.processed_dir,
                                     num_workers=4)

Keep in mind that the loader may not cover the entire dataset, depending on how you set the hyperparameters. The batch_size hyperparameter is the number of walks to sample per batch. For example, with the Citeseer dataset and batch_size = 1, walk_length = 1, and num_steps = 1, we get 1 data sample with 2 nodes. With batch_size = 10 we get 1 data sample with 20 nodes. With batch_size = 100 we get around 200 nodes, although the exact count changes at each iteration (e.g. 189, 191, etc.). The num_steps hyperparameter is the number of iterations per epoch, so if we increase num_steps to 2 the number of nodes grows to around 380 with batch_size = 100 and walk_length = 1. The walk_length hyperparameter is the length of each random walk, and the number of returned nodes will vary widely depending on the assortativity of your network. Under the hood, the random walks are performed by a C++ (and, on GPU, CUDA) implementation operating on the sparse tensor representation from PyTorch Geometric's torch_sparse library, exposed as a custom TorchScript operator. See the PyTorch tutorial on extending TorchScript with custom C++ operators for more on this.

Number of nodes in GraphSAINT data loader for various hyperparameters on Citeseer dataset
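
You can check counts like these yourself by iterating a small loader over Citeseer (a quick sketch; the Planetoid root path is my own choice, and older PyG versions import the sampler from torch_geometric.data):

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import GraphSAINTRandomWalkSampler

dataset = Planetoid(root='data/Planetoid', name='CiteSeer')
data = dataset[0]

loader = GraphSAINTRandomWalkSampler(data, batch_size=100, walk_length=1,
                                     num_steps=2)

# Each iteration yields a relabelled subgraph; node counts vary slightly per sample.
for sub_data in loader:
    print(sub_data.num_nodes, sub_data.num_edges)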

One of the most important aspects of the GraphSAINT paper is the computation of normalization statistics to reduce the bias in each subgraph sample. The GraphSAINTSampler computes these statistics, a step partly controlled by the sample_coverage hyperparameter. The resulting statistics are returned as an edge_norm attribute on the data object. We can use edge_norm to rescale the edge_weight attribute before the forward pass of our graph neural network:

edge_weight = data.edge_norm * data.edge_weight            
out = model(data.x, data.edge_index, edge_weight)
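
In context, a training step from the graph_saint example looks roughly like the following sketch (it assumes model, optimizer, and device are defined, and, as in that example, it additionally reweights the loss by the per-node statistic data.node_norm):

import torch.nn.functional as F

for data in loader:
    data = data.to(device)
    optimizer.zero_grad()
    # Scale the edge weights by the GraphSAINT edge normalization statistics.
    edge_weight = data.edge_norm * data.edge_weight
    out = model(data.x, data.edge_index, edge_weight)
    # Reweight the per-node loss by node_norm to reduce sampling bias,
    # then restrict it to the training nodes of this subgraph.
    loss = F.nll_loss(out, data.y, reduction='none')
    loss = (loss * data.node_norm)[data.train_mask].sum()
    loss.backward()
    optimizer.step()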