Graph convolutional networks (GCNs) are deep learning models used for a wide range of graph-based tasks such as node classification, link prediction, and visualization. Many of these tasks involve dynamic or evolving networks, such as social networks and recommendation networks. GCNs have difficulty with such networks because the node-to-embedding mapping they learn during training cannot be applied directly to nodes that are added as the network grows. The GraphSAGE model overcomes this limitation by making two changes to the GCNs used for fixed-structure networks: it samples each node's neighbors, and it learns functions that aggregate the sampled neighbors' features. These changes, explained below, make GraphSAGE computationally efficient enough to scale to very large graphs and also allow it to generate embeddings for nodes that were not seen during training. This inductive capability makes GraphSAGE a valuable node embedding tool for a large number of applications. In this post, I will highlight the aggregation and sampling used in GraphSAGE and provide a link to an example of its usage. Before reading on, you may want to go over my post on graph convolutional networks.
How Does GraphSAGE Perform Aggregation?
We will assume that each node has an associated feature vector. For example, if a node represents a person, its feature vector might encode that person's profile. In the absence of explicit features, GraphSAGE can use structural features of nodes, such as node degree.
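As a small illustration of the structural fallback, here is a Python sketch that uses each node's degree as a one-dimensional feature vector. The six-node graph is made up for illustration (it is not the graph in Figure 1), and `degree_features` is a hypothetical helper, not part of any library.

```python
# Hypothetical toy graph for illustration only (not the graph in Figure 1).
# Undirected, stored as an adjacency list: node -> list of neighbors.
graph = {
    "A": ["B", "C", "D"],
    "B": ["A", "C"],
    "C": ["A", "B", "E", "F"],
    "D": ["A"],
    "E": ["C"],
    "F": ["C"],
}


def degree_features(adj):
    """Use each node's degree as a one-dimensional feature vector."""
    return {node: [float(len(neighbors))] for node, neighbors in adj.items()}


features = degree_features(graph)
print(features["A"], features["D"])  # [3.0] [1.0]
```

In practice, degree is often combined with other structural statistics or rescaled, but a single number per node is enough to give the aggregators something to work with.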
The key idea of the GraphSAGE algorithm is that the embedding of a node is computed layer by layer, with each layer aggregating embeddings from a progressively larger neighborhood around the node. At layer 0, the embedding of a node is simply its feature vector. Thus, the layer-0 embedding of the target node A in the graph shown below is X_A.
Figure 1. Input Graph for Embedding
The layer-1 embedding of a node is obtained by aggregating the layer-0 embeddings of its immediate (one-hop) neighbors. Similarly, the layer-2 embedding is obtained by aggregating the layer-1 embeddings of those same neighbors; because each of them already summarizes its own one-hop neighborhood, the layer-2 embedding draws on the node's entire two-hop neighborhood. Figure 2 below shows this layer-wise embedding process for the target node A of the input graph of Figure 1.
Figure 2. Layer-wise Embedding of Target Node A
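To see which nodes can influence each layer's embedding, the following Python sketch groups nodes by their hop distance from a target using breadth-first search. The graph and the `hop_layers` helper are hypothetical, chosen only to make the expanding neighborhood visible.

```python
# Hypothetical toy graph, for illustration only.
graph = {
    "A": ["B", "C", "D"],
    "B": ["A", "C"],
    "C": ["A", "B", "E", "F"],
    "D": ["A"],
    "E": ["C"],
    "F": ["C"],
}


def hop_layers(adj, target, k_max):
    """Return [nodes at hop 0, nodes at hop 1, ..., nodes at hop k_max].

    Entry k holds the nodes whose shortest-path distance from target is exactly k.
    The layer-k embedding of target can depend on every node in entries 0..k.
    """
    layers = [{target}]
    seen = {target}
    for _ in range(k_max):
        frontier = set()
        for node in layers[-1]:
            for nbr in adj[node]:
                if nbr not in seen:
                    frontier.add(nbr)
        seen |= frontier
        layers.append(frontier)
    return layers


for k, nodes in enumerate(hop_layers(graph, "A", 2)):
    print(k, sorted(nodes))
# 0 ['A']
# 1 ['B', 'C', 'D']
# 2 ['E', 'F']
```

With two layers, nodes E and F never touch A directly, yet they reach A's layer-2 embedding through C's layer-1 embedding.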
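The aggregation itself is performed by aggregation functions learned separately for each layer; the square boxes in Figure 2 denote these functions. The configuration of Figure 2 thus defines a graph convolutional neural network in which, for each layer k = 1, ..., K, a node's output combines an aggregate of its neighbors' layer-(k-1) embeddings, transformed by a trainable matrix W_k, with the node's own layer-(k-1) embedding, transformed by a trainable matrix B_k, and passes the result through a nonlinearity.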
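Written out, one common way to express this layer-wise update in the W and B notation is sketched below. The symbols are assumptions for illustration: h_v^k is node v's layer-k embedding, N(v) is its set of neighbors, AGG_k is the layer-k aggregator (for example an element-wise mean), σ is a nonlinearity such as ReLU, and z_v is the final embedding, so x_A and z_A correspond to X_A and Z_A for the target node.

```latex
\begin{aligned}
h_v^{0} &= x_v \\
h_v^{k} &= \sigma\!\left( W_k \,\mathrm{AGG}_k\!\left(\{\, h_u^{k-1} : u \in N(v) \,\}\right) + B_k\, h_v^{k-1} \right), \qquad k = 1, \dots, K \\
z_v &= h_v^{K}
\end{aligned}
```

The GraphSAGE paper writes the same step as a single matrix applied to the concatenation of h_v^{k-1} and the aggregated neighbor vector, followed by L2 normalization of h_v^k. Splitting that matrix into a neighbor block W_k and a self block B_k gives the sum above, since multiplying a matrix by a concatenated vector equals the sum of its two blocks applied to the two halves.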
With K = 2, the layer-2 output is the final embedding vector Z_A of the target node in the above graph. The matrices W and B are trainable; they are learned by minimizing a suitable loss function. Both supervised training (for example, with a node classification loss) and unsupervised training (for example, with a loss that gives nearby nodes similar embeddings) are possible.
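The following pure-Python sketch runs this two-layer forward pass with a mean aggregator. It is an illustration under stated assumptions, not a reference implementation: the graph, the 2-dimensional features, and the layer sizes are hypothetical, the weights are random stand-ins for trained W_k and B_k, and every node is assumed to have at least one neighbor.

```python
import random

# Hypothetical toy graph and 2-dimensional node features, for illustration only.
graph = {
    "A": ["B", "C", "D"],
    "B": ["A", "C"],
    "C": ["A", "B", "E", "F"],
    "D": ["A"],
    "E": ["C"],
    "F": ["C"],
}
features = {
    "A": [1.0, 0.0], "B": [0.0, 1.0], "C": [1.0, 1.0],
    "D": [0.5, 0.0], "E": [0.0, 0.5], "F": [0.2, 0.8],
}


def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def mean(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def relu(x):
    return [max(0.0, v) for v in x]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def graphsage_forward(adj, feats, layers):
    """Compute h^K for every node; layers[k-1] holds the (W_k, B_k) pair."""
    h = dict(feats)  # layer 0: embeddings are the raw feature vectors
    for W, B in layers:
        h_next = {}
        for v, nbrs in adj.items():
            neigh = matvec(W, mean([h[u] for u in nbrs]))  # W_k * AGG(neighbors)
            self_part = matvec(B, h[v])                     # B_k * h_v^{k-1}
            h_next[v] = relu([a + b for a, b in zip(neigh, self_part)])
        h = h_next
    return h


rng = random.Random(0)
# Untrained random weights stand in for learned ones: 2 -> 4 -> 3 dimensions.
layers = [
    (random_matrix(4, 2, rng), random_matrix(4, 2, rng)),
    (random_matrix(3, 4, rng), random_matrix(3, 4, rng)),
]
z = graphsage_forward(graph, features, layers)
print(z["A"])  # Z_A: a 3-dimensional embedding for target node A
```

Swapping in a different aggregator, or adding the paper's L2 normalization after each layer, only changes the lines inside the inner loop; training would then adjust the matrices in `layers` to minimize the chosen loss.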
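While this aggregation scheme looks similar to the one used in graph convolutional networks, the difference is that GraphSAGE learns its aggregation functions during training, whereas in GCNs the aggregation is predefined and fixed (a normalized sum over each node's neighbors in the training graph). Because GraphSAGE's learned functions operate only on the features of a node's local neighborhood, not on the identity of particular nodes, they can be applied to nodes that were absent during training. This is what makes GraphSAGE an inductive learner, as opposed to GCNs, which are transductive learners.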
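The inductive point can be sketched in a few lines of Python. This example uses a max-pooling aggregator, one of the aggregators proposed in the GraphSAGE paper alongside mean and LSTM aggregators: each neighbor passes through a trainable layer before an element-wise max. The graph, features, dimensions, and helper names are hypothetical, and random values stand in for trained parameters. The point is that a node added after the parameters are fixed gets an embedding from the same shared functions.

```python
import random


def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def relu(x):
    return [max(0.0, v) for v in x]


def max_pool_aggregate(neighbor_vecs, W_pool, b_pool):
    """Pass each neighbor through a trainable layer, then take the element-wise max."""
    transformed = [
        relu([a + b for a, b in zip(matvec(W_pool, h), b_pool)])
        for h in neighbor_vecs
    ]
    return [max(col) for col in zip(*transformed)]


def sage_layer(v, adj, h, params):
    """One GraphSAGE-style layer for node v; params are shared by all nodes."""
    W_pool, b_pool, W, B = params
    agg = max_pool_aggregate([h[u] for u in adj[v]], W_pool, b_pool)
    return relu([a + b for a, b in zip(matvec(W, agg), matvec(B, h[v]))])


rng = random.Random(42)


def rand_matrix(rows, cols):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


# Stand-ins for trained parameters (feature dim 2, pooled dim 3, output dim 4).
params = (rand_matrix(3, 2), [0.0, 0.0, 0.0], rand_matrix(4, 3), rand_matrix(4, 2))

# Hypothetical graph and features at "training time".
adj = {"A": ["B", "C"], "B": ["A"], "C": ["A"]}
h0 = {"A": [1.0, 0.5], "B": [0.2, 0.1], "C": [0.0, 1.0]}
emb_A = sage_layer("A", adj, h0, params)

# Later, a new node D arrives with its own features and edges to A and C.
adj["D"] = ["A", "C"]
adj["A"].append("D")
adj["C"].append("D")
h0["D"] = [0.3, 0.7]
emb_D = sage_layer("D", adj, h0, params)  # same parameters, no retraining
print(len(emb_A), len(emb_D))  # 4 4
```

Nothing in `params` is indexed by node, which is why node D needs only its features and its edges to receive an embedding.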
Neighborhood Sampling in GraphSAGE
In graph convolutional networks, every neighbor within the specified number of hops of the target node participates in message passing and contributes to the computation of the target node's embedding. As the number of hops grows, the number of nodes contributing to this computation can grow exponentially for certain graphs. The problem is exacerbated when the neighborhood of a target node includes a hub or celebrity node with millions of connections. To avoid this exponential growth in computation, the GraphSAGE algorithm randomly samples a fixed-size set of neighbors at each layer instead of using all of them. This bounds the cost of computing each embedding and allows GraphSAGE to be used on extremely large graphs with billions of nodes. Figure 3 below illustrates neighborhood sampling; the "no" symbol marks the nodes that are not sampled.
Figure 3. Illustration of Neighborhood Sampling
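A minimal Python sketch of hop-by-hop sampling follows. It assumes uniform sampling without replacement that keeps all neighbors when a node has fewer than the budget, which is one common choice; the fanout list, the hub graph, and the helper names are hypothetical.

```python
import random


def sample_neighbors(adj, node, size, rng):
    """Uniformly sample up to `size` neighbors of `node`, without replacement."""
    nbrs = adj[node]
    if len(nbrs) <= size:
        return list(nbrs)  # fewer neighbors than the budget: keep them all
    return rng.sample(nbrs, size)


def sample_computation_graph(adj, target, fanouts, rng):
    """Sample neighbors hop by hop; fanouts[k] is the budget at hop k + 1.

    Returns one dict per hop mapping each frontier node to its sampled neighbors.
    """
    frontier = [target]
    hops = []
    for size in fanouts:
        hop = {v: sample_neighbors(adj, v, size, rng) for v in frontier}
        hops.append(hop)
        frontier = sorted({u for nbrs in hop.values() for u in nbrs})
    return hops


# Hypothetical graph: target A is connected to B and to a hub with 10,000 fans.
fans = [f"fan{i}" for i in range(10_000)]
adj = {"A": ["hub", "B"], "B": ["A"], "hub": ["A"] + fans}
adj.update({f: ["hub"] for f in fans})

rng = random.Random(0)
hops = sample_computation_graph(adj, "A", fanouts=[2, 5], rng=rng)
touched = {u for hop in hops for nbrs in hop.values() for u in nbrs}
print(len(hops[1]["hub"]), len(touched))
```

Without sampling, a two-hop computation for A would visit all 10,000 fans of the hub. With fanouts [2, 5], the work per target node is capped at 2 + 2 × 5 = 12 sampled neighbor slots no matter how large the hub is.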
Applying GraphSAGE to Perform Link Prediction
This link describes how to use GraphSAGE for link prediction. The implementation uses the Deep Graph Library (DGL). Check it out.
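The linked DGL example is the place to see a full pipeline. As a framework-free sketch of the idea, one common recipe (assumed here, not taken from that example, which may use a different scorer) scores a candidate edge by the dot product of its endpoints' GraphSAGE embeddings, then trains with observed edges as positives and randomly drawn node pairs as negatives. The GraphSAGE paper's unsupervised loss has the same shape, using σ(z_u · z_v) for nodes that co-occur on random walks and negative samples for the rest. The embeddings and function names below are hypothetical.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def edge_probability(z, u, v):
    """Score a candidate edge (u, v) from the two node embeddings."""
    return sigmoid(dot(z[u], z[v]))


def sample_negative_edges(nodes, existing, num, rng):
    """Draw random node pairs that are not existing edges (assumes a sparse graph)."""
    negatives = []
    while len(negatives) < num:
        u, v = rng.sample(nodes, 2)
        if (u, v) not in existing and (v, u) not in existing:
            negatives.append((u, v))
    return negatives


def link_loss(z, pos_edges, neg_edges):
    """Binary cross-entropy: real edges should score near 1, negatives near 0."""
    eps = 1e-12
    loss = 0.0
    for u, v in pos_edges:
        loss -= math.log(edge_probability(z, u, v) + eps)
    for u, v in neg_edges:
        loss -= math.log(1.0 - edge_probability(z, u, v) + eps)
    return loss / (len(pos_edges) + len(neg_edges))


# Hypothetical GraphSAGE output embeddings and observed edges.
z = {"A": [1.0, 0.0], "B": [0.9, 0.1], "C": [-1.0, 0.2], "D": [0.0, 1.0]}
pos_edges = [("A", "B"), ("C", "D")]
rng = random.Random(0)
neg_edges = sample_negative_edges(sorted(z), set(pos_edges), 2, rng)
print(edge_probability(z, "A", "B") > edge_probability(z, "A", "C"))  # True
print(link_loss(z, pos_edges, neg_edges))
```

In a real setup, the gradient of this loss would flow back into the aggregator weights, so the GraphSAGE layers learn embeddings whose dot products separate likely links from unlikely ones.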
Thanks 👍