Temporal Graph Networks

The original article was published by Michael Bronstein on Artificial Intelligence on Medium.

Research on graph neural networks (GNNs) has surged to become one of the hottest topics in machine learning this year. GNNs have seen a series of recent successes in problems from the fields of biology, chemistry, social science, physics, and many others. So far, GNN models have been primarily developed for static graphs that do not change over time. However, many interesting real-world graphs are dynamic and evolving in time, with prominent examples including social networks, financial transactions, and recommender systems. In many cases, it is the dynamic behaviour of such systems that conveys important insights, which would be lost if one considered only a static graph.

A dynamic network of Twitter users interacting with tweets and following each other. All the edges have a timestamp. Given such a dynamic graph, we want to predict future interactions, e.g., which tweet a user will like or whom they will follow.

A dynamic graph can be represented as an ordered list or asynchronous “stream” of timed events, such as additions or deletions of nodes and edges [1]. A social network like Twitter is a good illustration to keep in mind: when a person joins the platform, a new node is created. When they follow another user, a follow edge is created. When they change their profile, the node is updated.
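As a minimal sketch of this representation, the snippet below models each event as a small record and replays a stream in time order to recover the graph at the end. The `Event` schema, the event-kind strings and the `replay` helper are hypothetical names chosen for illustration, not part of TGN.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Event:
    """One timed event in a dynamic graph (hypothetical schema)."""
    t: float                    # timestamp of the event
    kind: str                   # "add_node", "add_edge", "delete_edge" or "update_node"
    src: int                    # node the event concerns (source node for edge events)
    dst: Optional[int] = None   # destination node for edge events, None otherwise
    features: tuple = ()        # optional event features (e.g. an encoding of a tweet)

def replay(events):
    """Apply events in chronological order and return the final node and edge sets."""
    nodes, edges = set(), set()
    for e in sorted(events, key=lambda e: e.t):
        if e.kind == "add_node":
            nodes.add(e.src)
        elif e.kind == "add_edge":
            nodes.update((e.src, e.dst))
            edges.add((e.src, e.dst))
        elif e.kind == "delete_edge":
            edges.discard((e.src, e.dst))
        # "update_node" only changes node features, so the topology is untouched
    return nodes, edges

stream = [
    Event(t=1.0, kind="add_node", src=0),                      # user 0 joins
    Event(t=2.0, kind="add_node", src=1),                      # user 1 joins
    Event(t=3.0, kind="add_edge", src=0, dst=1),               # user 0 follows user 1
    Event(t=4.0, kind="update_node", src=1, features=(0.5,)),  # user 1 edits profile
]
nodes, edges = replay(stream)
print(nodes, edges)  # {0, 1} {(0, 1)}
```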

This stream of events is ingested by an encoder neural network that produces a time-dependent embedding for each node of the graph. The embedding can then be fed into a decoder designed for a specific task. One example task is predicting future interactions by trying to answer the question: what is the probability of having an edge between nodes i and j at time t? The ability to answer this question is crucial for recommendation systems that, e.g., suggest whom a social network user should follow or decide which content to display. The following figure illustrates this scenario:

Example of a TGN encoder ingesting a dynamic graph with seven visible edges (with timestamps t₁ to t₇), with the goal of predicting the future interaction between nodes 2 and 4 at time t₈ (grey edge). To do so, TGN computes embeddings for nodes 2 and 4 at time t₈. These embeddings are then concatenated and fed into a decoder (e.g. an MLP), which outputs the probability of the interaction happening.
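The decoder described in the caption can be sketched in plain Python as a one-hidden-layer MLP applied to the concatenated embeddings, followed by a sigmoid. The hidden size, the ReLU activation and the helper name `mlp_decoder` are assumptions; the weights here are random and untrained, so the printed probability carries no meaning until the model is trained.

```python
import math
import random

def mlp_decoder(z_i, z_j, W1, b1, w2, b2):
    """Concatenate two node embeddings, apply a ReLU hidden layer, and squash
    the resulting score into a probability with a sigmoid."""
    x = list(z_i) + list(z_j)                       # concatenation [z_i || z_j]
    hidden = [max(0.0, sum(w * v for w, v in zip(row, x)) + b)
              for row, b in zip(W1, b1)]            # ReLU hidden layer
    score = sum(w * h for w, h in zip(w2, hidden)) + b2
    return 1.0 / (1.0 + math.exp(-score))           # probability of the interaction

random.seed(0)
dim, hid = 4, 8                                     # assumed embedding / hidden sizes
W1 = [[random.gauss(0, 0.1) for _ in range(2 * dim)] for _ in range(hid)]
b1 = [0.0] * hid
w2 = [random.gauss(0, 0.1) for _ in range(hid)]
z2 = [0.1, -0.2, 0.3, 0.0]                          # toy embedding of node 2 at t8
z4 = [0.2, 0.1, -0.1, 0.4]                          # toy embedding of node 4 at t8
p = mlp_decoder(z2, z4, W1, b1, w2, b2=0.0)
print(f"P(edge 2-4 at t8) = {p:.3f}")
```

Because concatenation is ordered, this decoder is not symmetric in i and j, which suits directed interactions such as a user liking a tweet.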

The key piece of the above setup is the encoder, which can be trained with any decoder. On the aforementioned task of future interaction prediction, training can be done in a self-supervised manner: during each epoch, the encoder processes the events in chronological order and predicts the next interaction based on the previous ones [2].
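A sketch of this chronological, self-supervised set-up, assuming interactions are (source, destination, time) tuples. The observed interactions serve as positive examples; pairing each one with a randomly drawn destination as a negative is an assumed sampling scheme that the text does not specify. The helper names are hypothetical, and the encoder and decoder calls are left as a comment.

```python
import random

def chronological_batches(events, batch_size):
    """Yield consecutive batches of (src, dst, t) interactions in time order."""
    ordered = sorted(events, key=lambda e: e[2])
    for start in range(0, len(ordered), batch_size):
        yield ordered[start:start + batch_size]

def self_supervised_targets(batch, all_nodes, rng):
    """Observed interactions become positives (label 1); each is paired with a
    randomly drawn destination as a negative (label 0, assumed sampling scheme)."""
    examples = []
    for src, dst, t in batch:
        examples.append((src, dst, t, 1))
        negative = rng.choice([n for n in all_nodes if n != dst])
        examples.append((src, negative, t, 0))
    return examples

events = [(0, 1, 3.0), (1, 2, 1.0), (2, 3, 2.0), (0, 3, 4.0)]
rng = random.Random(0)
seen = []  # interactions the encoder has already ingested
for batch in chronological_batches(events, batch_size=2):
    latest_seen = max((t for _, _, t in seen), default=float("-inf"))
    assert min(t for _, _, t in batch) >= latest_seen  # the batch never precedes the history
    examples = self_supervised_targets(batch, all_nodes=[0, 1, 2, 3], rng=rng)
    # ... embed nodes using only `seen`, score `examples` with a decoder, take a loss step ...
    seen.extend(batch)
```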

Temporal Graph Network (TGN) is a general encoder architecture we developed at Twitter with colleagues Fabrizio Frasca, Davide Eynard, Ben Chamberlain, and Federico Monti [3]. This model can be applied to various problems of learning on dynamic graphs represented as a stream of events. In a nutshell, a TGN encoder creates compressed representations of the nodes based on their interactions and updates them upon each event. To accomplish this, TGN has the following main components:

Memory. The memory stores the states of all the nodes, acting as a compressed representation of each node’s past interactions. It is analogous to the hidden state of an RNN; however, here we have a separate state vector sᵢ(t) for each node i. When a new node appears, we add a corresponding state initialised as a vector of zeros. Moreover, since the memory for each node is just a state vector (and not a parameter), it can also be updated at test time when the model ingests new interactions.
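A minimal sketch of the memory as a lazily initialised dictionary of per-node vectors. The `Memory` class and its `last_update` field (recording when each state last changed) are illustrative assumptions; note that nothing here is a trainable parameter.

```python
class Memory:
    """Per-node state vectors s_i(t). They are plain data rather than trainable
    parameters, so they can keep changing at test time as new events arrive."""

    def __init__(self, dim):
        self.dim = dim
        self.state = {}        # node id -> current state vector
        self.last_update = {}  # node id -> time of last update (None if never updated)

    def get(self, node):
        # a node seen for the first time gets a state of zeros
        if node not in self.state:
            self.state[node] = [0.0] * self.dim
            self.last_update[node] = None
        return self.state[node]

    def set(self, node, new_state, t):
        self.state[node] = list(new_state)
        self.last_update[node] = t

memory = Memory(dim=3)
print(memory.get(42))  # [0.0, 0.0, 0.0]
```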

Message Function is the main mechanism for updating the memory. Given an interaction between nodes i and j at time t, the message function computes two messages (one for i and one for j), which are used to update the memory. This is analogous to the messages computed in message-passing graph neural networks [4]. The message is a function of the memories of nodes i and j at the instant t⁻ just before the interaction, the interaction time t, and the edge features [5].
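Spelled out, with msg denoting the learnable message function, eᵢⱼ(t) the edge features and t⁻ the instant just before the interaction (notation assumed here), the two messages can be written as:

$$
\begin{aligned}
m_i(t) &= \mathrm{msg}\big(s_i(t^-),\, s_j(t^-),\, t,\, e_{ij}(t)\big),\\
m_j(t) &= \mathrm{msg}\big(s_j(t^-),\, s_i(t^-),\, t,\, e_{ij}(t)\big).
\end{aligned}
$$

Each node’s own memory comes first, so the same function produces different messages for i and j; using two separate functions for the two endpoints would be an equally valid reading of the description.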

Memory Updater is used to update the memory with the new messages. This module is usually implemented as an RNN.
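Putting the message function and memory updater together, a toy sketch in plain Python might look as follows. Two simplifications are assumptions, not the published design: the message is a plain concatenation of its inputs, and the updater is a vanilla RNN cell, where a gated cell such as a GRU would be a more common choice. The raw time t is also fed in directly, whereas a learned time encoding would be more typical.

```python
import math
import random

def message(s_self, s_other, t, e_feat):
    """Concatenation message: own memory, other node's memory, time, edge features."""
    return list(s_self) + list(s_other) + [t] + list(e_feat)

class RNNUpdater:
    """Vanilla RNN cell s_new = tanh(W m + U s_old + b), a simple stand-in
    for a gated memory updater."""

    def __init__(self, msg_dim, mem_dim, seed=0):
        rng = random.Random(seed)
        self.W = [[rng.gauss(0, 0.1) for _ in range(msg_dim)] for _ in range(mem_dim)]
        self.U = [[rng.gauss(0, 0.1) for _ in range(mem_dim)] for _ in range(mem_dim)]
        self.b = [0.0] * mem_dim

    def __call__(self, m, s_old):
        return [math.tanh(sum(w * x for w, x in zip(W_row, m))
                          + sum(u * h for u, h in zip(U_row, s_old)) + b)
                for W_row, U_row, b in zip(self.W, self.U, self.b)]

mem_dim, edge_dim = 3, 2
updater = RNNUpdater(msg_dim=2 * mem_dim + 1 + edge_dim, mem_dim=mem_dim)
s_i, s_j = [0.0] * mem_dim, [0.0] * mem_dim   # memories at t- (new nodes start at zero)
e_ij, t = [1.0, 0.5], 8.0
m_i = message(s_i, s_j, t, e_ij)              # message for node i
m_j = message(s_j, s_i, t, e_ij)              # message for node j
s_i, s_j = updater(m_i, s_i), updater(m_j, s_j)
```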

Given that the memory of a node is a vector updated over time, the most straightforward approach is to use it directly as the node embedding. In practice, however, this is a bad idea due to the staleness problem: since the memory is updated only when a node is involved in an interaction, a long period of inactivity causes the node’s memory to go out of date. As an illustration, think of a user staying off Twitter for a few months. When the user returns, they might have developed new interests in the meantime, so the memory of their past activity is no longer relevant. We therefore need a better way to compute the embeddings.

Embedding. A solution is to look at the node neighbours. To solve the staleness problem, the embedding module computes the temporal embedding of a node by performing a graph aggregation over the spatiotemporal neighbours of that node. Even if a node has been inactive for a while, it is likely that some of its neighbours have been active, and by aggregating their memories, TGN can compute an up-to-date embedding for the node. In our example, even when a user stays off Twitter, their friends continue to be active, so when they return, the friends’ recent activity is likely to be more relevant than the user’s own history.

The graph embedding module computes the embedding of a target node by performing an aggregation over its temporal neighbourhood. In the above diagram, when computing the embedding for node 1 at some time t later than its interactions with nodes 2, 3 and 4 but earlier than its interaction with node 5, the temporal neighbourhood includes only edges that occurred before time t. Therefore, the edge with node 5 is not involved in the computation, as it happens in the future. Instead, the embedding module aggregates the features (v) and memories (s) of neighbours 2, 3 and 4, as well as the features on the edges, to compute a representation for node 1. The best-performing graph embedding module in our experiments is graph attention, which learns which neighbours are the most important based on their memory, features and time of interaction.
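A parameter-free sketch of the temporal neighbourhood and an attention-style aggregation over it, assuming edges are stored as (u, v, timestamp, edge_features) tuples and using made-up timestamps for node 1’s edges. The score subtracts the time elapsed since each interaction, a hand-coded stand-in for learned time encodings, and the trainable query, key and value projections of real graph attention are omitted.

```python
import math

def temporal_neighbours(edges, node, t):
    """(neighbour, timestamp, edge_features) for every edge of `node` strictly before t."""
    return [(v if u == node else u, ts, feat)
            for u, v, ts, feat in edges
            if ts < t and node in (u, v)]

def softmax(xs):
    m = max(xs)                              # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention_embedding(node, t, edges, memory, features, edge_dim):
    """Embedding of `node` at time t: its own [memory || features] concatenated with
    an attention-weighted summary of its temporal neighbours. No trainable weights."""
    own = memory[node] + features[node]
    neigh = temporal_neighbours(edges, node, t)
    if not neigh:                            # no past interactions: empty summary
        return own + [0.0] * (len(own) + edge_dim)
    keys = [memory[n] + features[n] for n, _, _ in neigh]
    values = [memory[n] + features[n] + list(feat) for n, _, feat in neigh]
    # dot-product relevance minus elapsed time (assumed recency term)
    scores = [sum(q * k for q, k in zip(own, key)) - (t - ts)
              for key, (_, ts, _) in zip(keys, neigh)]
    weights = softmax(scores)
    summary = [sum(w * v[d] for w, v in zip(weights, values))
               for d in range(len(values[0]))]
    return own + summary

# Node 1 interacted with nodes 2, 3, 4 and 5 at (made-up) times 2.0, 3.0, 4.0 and 5.0.
edges = [(1, 2, 2.0, [0.1]), (1, 3, 3.0, [0.2]), (1, 4, 4.0, [0.3]), (1, 5, 5.0, [0.4])]
memory = {n: [0.0, 0.0] for n in range(1, 6)}
features = {1: [1.0], 2: [0.5], 3: [-0.5], 4: [1.0], 5: [2.0]}
print([n for n, _, _ in temporal_neighbours(edges, 1, t=4.5)])  # [2, 3, 4]
z1 = attention_embedding(1, 4.5, edges, memory, features, edge_dim=1)
```

Querying at t = 4.5 leaves node 5 out, exactly because its edge lies in the future relative to t.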

The overall computations performed by TGN on a batch of training data are summarised in the following figure:

Computations performed by TGN on a batch of training data. On one side, embeddings are produced by the embedding module using the temporal graph and the nodes’ memory (1). The embeddings are then used to predict the batch interactions and compute the loss (2, 3). On the other side, these same interactions are used to update the memory (4, 5).

By looking at the above diagram, you may be wondering how the memory-related modules (Message function, Message aggregator, and Memory updater) are trained, given that they seem not to directly influence the loss and therefore do not receive a gradient. In order for these modules to influence the loss, we need to update the memory before predicting the batch interactions. However, that would cause leakage, since the memory would already contain information about what we are trying to predict. The strategy we propose to get around this problem is to update the memory with messages coming from previous batches, and then predict the interactions. The diagram below shows the flow of operations of TGN, which is necessary to train the memory-related modules:

Flow of operations of TGN necessary to train the memory-related modules. A new component is introduced, the raw message store, which keeps the information needed to compute messages (which we call raw messages) for interactions the model has already processed. This allows the model to delay the memory update brought by an interaction to later batches. First, the memory is updated using messages computed from raw messages stored in previous batches (1 and 2). The embeddings can then be computed using the just-updated memory (grey link) (3). By doing this, the computations of the memory-related modules directly influence the loss (4, 5), and they receive a gradient. Finally, the raw messages for this batch’s interactions are stored in the raw message store (6) to be used in future batches.
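The ordering in the caption can be sketched with a hypothetical `RawMessageStore` and `train_step`. Since this is plain Python without automatic differentiation, no gradients are computed; instead, toy stand-ins for the memory updater and the loss record what memory each batch could see, which makes the one-batch delay (and the absence of leakage) visible.

```python
class RawMessageStore:
    """Hypothetical raw message store: for every node, the raw ingredients
    (other endpoint, time, edge features) of interactions already processed
    but not yet written into the memory."""

    def __init__(self):
        self.pending = {}

    def add(self, src, dst, t, feat):
        self.pending.setdefault(src, []).append((dst, t, feat))
        self.pending.setdefault(dst, []).append((src, t, feat))

    def pop_all(self):
        pending, self.pending = self.pending, {}
        return pending


def train_step(batch, store, update_memory, compute_loss):
    # (1, 2) update the memory from raw messages stored by *previous* batches
    for node, raw_messages in store.pop_all().items():
        update_memory(node, raw_messages)
    # (3-5) embed with the just-updated memory, predict this batch, compute the loss;
    # in an autodiff framework the loss would now depend on the memory modules
    loss = compute_loss(batch)
    # (6) only now stash this batch's raw messages for later batches
    for src, dst, t, feat in batch:
        store.add(src, dst, t, feat)
    return loss


# Toy stand-ins: the "memory" counts absorbed interactions, and the "loss"
# records which memory values were visible when each batch was predicted.
memory, seen = {}, []

def update_memory(node, raw_messages):
    memory[node] = memory.get(node, 0) + len(raw_messages)

def compute_loss(batch):
    seen.append({n: memory.get(n, 0) for src, dst, _, _ in batch for n in (src, dst)})
    return 0.0

store = RawMessageStore()
for batch in ([(0, 1, 1.0, None)], [(0, 2, 2.0, None)], [(1, 2, 3.0, None)]):
    train_step(batch, store, update_memory, compute_loss)
print(seen)  # [{0: 0, 1: 0}, {0: 1, 2: 0}, {1: 1, 2: 1}]
```

When a batch is predicted, the memory reflects only earlier batches, never the batch itself.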

In extensive experimental validation on various dynamic graphs, TGN significantly outperformed competing methods [6] on the tasks of future edge prediction and dynamic node classification both in terms of accuracy and speed. One such dynamic graph is Wikipedia, where users and pages are nodes, and an interaction represents a user editing a page. An encoding of the edit text is used as interaction features. The task in this case is to predict which page a user will edit at a given time. We compared different variants of TGN with baseline methods:

Comparison of various configurations of TGN and older methods (TGAT and Jodie) on future link prediction on the Wikipedia dataset in terms of prediction accuracy and time. We wish more papers reported both of these important criteria in a rigorous way.

This ablation study sheds light on the importance of the different TGN modules and allows us to draw a few general conclusions. First, the memory is important: its absence leads to a large drop in performance [7]. Second, the use of the embedding module (as opposed to directly outputting the memory state) is important, and graph attention-based embedding appears to perform best. Third, having the memory makes it sufficient to use only one graph attention layer (which drastically reduces the computation time), since the memory of 1-hop neighbours gives the model indirect access to information about 2-hop neighbours.

As a concluding remark, we consider learning on dynamic graphs an almost virgin research area, with numerous important and exciting applications and a significant potential impact. We believe that our TGN model is an important step towards advancing the ability to learn on dynamic graphs, consolidating and extending previous results. As this research field develops, better and bigger benchmarks will become essential. We are now working on creating new dynamic graph datasets and tasks as part of the Open Graph Benchmark.