Tutorial: Batched graph classifi

Author: 魏鹏飞 | Published: 2020-03-07 17:41

In this tutorial, you learn how to use DGL to batch multiple graphs of variable size and shape. The tutorial also demonstrates training a graph neural network for a simple graph classification task.

Graph classification is an important problem with applications across many fields, such as bioinformatics, chemoinformatics, social network analysis, urban computing, and cybersecurity. Applying graph neural networks to this problem has recently become a popular approach, as the following research references show: Ying et al., 2018; Cangea et al., 2018; Knyazev et al., 2018; Bianchi et al., 2019; Liao et al., 2019; Gao et al., 2019.

Simple graph classification task

In this tutorial, you learn how to perform batched graph classification with DGL. The example task objective is to classify eight types of topologies shown here.

DGL implements a synthetic dataset, data.MiniGCDataset. The dataset has eight different types of graphs, and each class has the same number of graph samples.

from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
# A dataset with 80 samples; each graph has
# between 10 and 20 nodes.
'''
Init signature: MiniGCDataset(num_graphs, min_num_v, max_num_v)
Parameters
----------
num_graphs: int
    Number of graphs in this dataset.
min_num_v: int
    Minimum number of nodes for graphs
max_num_v: int
    Maximum number of nodes for graphs
'''
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()

Form a graph mini-batch

To train neural networks efficiently, a common practice is to batch multiple samples together to form a mini-batch. Batching fixed-shaped tensor inputs is common. For example, batching two images of size 28 x 28 gives a tensor of shape 2 x 28 x 28. By contrast, batching graph inputs has two challenges:

  • Graphs are sparse.
  • Graphs can vary in size, for example in the number of nodes and edges.

To address this, DGL provides the dgl.batch() API. It leverages the idea that a batch of graphs can be viewed as one large graph that has many disjoint connected components. Below is a visualization that gives the general idea.
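The "one large graph" view can be sketched without DGL. The snippet below is only an illustration of the idea and not DGL's implementation: each graph is represented as a hypothetical (num_nodes, edge_list) pair, and batching shifts node IDs by the number of nodes already placed, so the components never share a node.

```python
def batch_edge_lists(graphs):
    """Merge (num_nodes, edges) pairs into one disjoint-union graph.

    Returns the total node count, the merged edge list, and the
    per-graph node counts needed to tell the components apart later.
    """
    offset = 0
    merged_edges = []
    nodes_per_graph = []
    for num_nodes, edges in graphs:
        # Shift node IDs so they do not collide with earlier graphs.
        merged_edges.extend((u + offset, v + offset) for u, v in edges)
        nodes_per_graph.append(num_nodes)
        offset += num_nodes
    return offset, merged_edges, nodes_per_graph


triangle = (3, [(0, 1), (1, 2), (2, 0)])
path = (2, [(0, 1)])
total_nodes, edges, sizes = batch_edge_lists([triangle, path])
print(total_nodes, edges, sizes)
```

No edge crosses from one component to another, so message passing on the merged graph never mixes information between the original graphs.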

Define the following collate function to form a mini-batch from a given list of graph and label pairs.
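The collate function itself does not appear in this copy of the tutorial. A minimal sketch of what it presumably does is given here, under the assumption that it unzips the (graph, label) pairs, merges the graphs with dgl.batch(), and packs the labels into a tensor. The factory make_collate and its parameters batch_fn and label_fn are hypothetical names, introduced only so that the sketch runs without DGL or PyTorch installed.

```python
def make_collate(batch_fn, label_fn):
    """Build a collate function from two callables.

    batch_fn merges a list of graphs into one batched graph;
    label_fn packs a list of labels into one label container.
    Both are assumptions of this sketch.
    """
    def collate(samples):
        # `samples` is a list of (graph, label) pairs.
        graphs, labels = map(list, zip(*samples))
        return batch_fn(graphs), label_fn(labels)
    return collate


# Stand-ins for the graph and tensor libraries, for demonstration only.
collate_demo = make_collate(batch_fn=tuple, label_fn=list)
batched, labels = collate_demo([("g0", 3), ("g1", 5)])
print(batched, labels)
```

With DGL and PyTorch imported, collate = make_collate(dgl.batch, torch.tensor) would give a function that fits the collate_fn argument of a PyTorch DataLoader.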

The return type of dgl.batch() is still a graph. In the same way, a batch of tensors is still a tensor. This means that any code that works for one graph immediately works for a batch of graphs. More importantly, because DGL processes messages on all nodes and edges in parallel, this greatly improves efficiency.

Graph classifier

Graph classification proceeds as follows.

graph_classifier.png

From a batch of graphs, perform message passing and graph convolution for nodes to communicate with others. After message passing, compute a tensor for graph representation from node (and edge) attributes. This step might be called readout or aggregation. Finally, the graph representations are fed into a classifier g to predict the graph labels.

Graph convolution

The graph convolution operation is basically the same as that of a graph convolutional network (GCN). To learn more, see the GCN tutorial. The only difference is that we replace

h_v^{(l+1)}=\text{ReLU}\left(b^{(l)}+\sum_{u\in N(v)}h_u^{(l)}W^{(l)}\right)

by

h_v^{(l+1)}=\text{ReLU}\left(b^{(l)}+\frac{1}{|N(v)|}\sum_{u\in N(v)}h_u^{(l)}W^{(l)}\right)

The replacement of summation by average is to balance nodes with different degrees. This gives a better performance for this experiment.

The self edges added in the dataset initialization allow you to include the original node feature h^{(l)}_v when taking the average.
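The effect of averaging instead of summing can be seen on a toy example. The sketch below is plain Python rather than DGL code, uses scalar node features, and leaves out the weight W, the bias b, and the ReLU; the function aggregate and the star graph are assumptions made for illustration.

```python
def aggregate(adjacency, features, reduce="mean"):
    """One round of neighbor aggregation on scalar node features.

    adjacency[v] lists the in-neighbors of node v, self loop included.
    """
    out = []
    for neighbors in adjacency:
        total = sum(features[u] for u in neighbors)
        out.append(total / len(neighbors) if reduce == "mean" else total)
    return out


# A star graph: node 0 is the hub, nodes 1 to 3 are leaves.
# Every node also lists itself, mirroring the self edges in the dataset.
adjacency = [[0, 1, 2, 3], [1, 0], [2, 0], [3, 0]]
features = [1.0, 1.0, 1.0, 1.0]
summed = aggregate(adjacency, features, reduce="sum")
averaged = aggregate(adjacency, features, reduce="mean")
print(summed)
print(averaged)
```

With summation the hub ends up with a value twice as large as the leaves purely because it has more neighbors, while averaging keeps all nodes on the same scale.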

import dgl.function as fn
import torch
import torch.nn as nn


# Sends a message of node feature h.
msg = fn.copy_src(src='h', out='m')

def reduce(nodes):
    """Take an average over all neighbor node features hu and use it to
    overwrite the original node feature."""
    accum = torch.mean(nodes.mailbox['m'], 1)
    return {'h': accum}

class NodeApplyModule(nn.Module):
    """Update the node feature hv with ReLU(Whv+b)."""
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h' : h}

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # Initialize the node features with h.
        g.ndata['h'] = feature
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')

Readout and classification

For this demonstration, consider initial node features to be their degrees. After two rounds of graph convolution, perform a graph readout by averaging over all node features for each graph in the batch.

h_g=\frac{1}{|V|}\sum_{v\in V}h_v\tag{1}

In DGL, dgl.mean_nodes() handles this task for a batch of graphs with variable size. You then feed the graph representations into a classifier with one linear layer to obtain pre-softmax logits.
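What a per-graph mean readout computes on a batch can be sketched in plain Python. This is not dgl.mean_nodes() itself; it assumes the node features of the batched graph are stored graph after graph, and that a hypothetical list nodes_per_graph records how many nodes each graph contributed. Each output row is equation (1) applied to one graph.

```python
def mean_readout(node_features, nodes_per_graph):
    """Average node feature vectors separately for each graph in a batch."""
    readouts = []
    start = 0
    for count in nodes_per_graph:
        # Rows belonging to the current graph.
        block = node_features[start:start + count]
        dim = len(block[0])
        readouts.append(
            [sum(row[d] for row in block) / count for d in range(dim)])
        start += count
    return readouts


node_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],  # graph 0, 3 nodes
                 [10.0, 0.0], [20.0, 0.0]]            # graph 1, 2 nodes
hg = mean_readout(node_features, nodes_per_graph=[3, 2])
print(hg)
```

The result has one row per graph regardless of how many nodes each graph has, which is what lets a fixed-size linear classifier consume it.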

import dgl
import torch.nn.functional as F


class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()

        self.layers = nn.ModuleList([
            GCN(in_dim, hidden_dim, F.relu),
            GCN(hidden_dim, hidden_dim, F.relu)])
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # For undirected graphs, in_degree is the same as
        # out_degree.
        h = g.in_degrees().view(-1, 1).float()
        for conv in self.layers:
            h = conv(g, h)
        g.ndata['h'] = h
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)

Setup and training

Create a synthetic dataset of 400 graphs with 10 to 20 nodes each. 320 graphs constitute the training set and 80 graphs constitute the test set.

import torch.optim as optim
from torch.utils.data import DataLoader

# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function
# defined before.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)

# Create model
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for batch_idx, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    # Average the loss over the mini-batches of this epoch.
    epoch_loss /= (batch_idx + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

# Results
Epoch 0, loss 2.1296
Epoch 1, loss 1.9604
Epoch 2, loss 1.8552
Epoch 3, loss 1.7858
Epoch 4, loss 1.6951
Epoch 5, loss 1.6152
Epoch 6, loss 1.5279
Epoch 7, loss 1.4640
Epoch 8, loss 1.3852
Epoch 9, loss 1.3196
Epoch 10, loss 1.2540
Epoch 11, loss 1.2263
Epoch 12, loss 1.1799
Epoch 13, loss 1.1836
......
......
......
Epoch 69, loss 0.4761
Epoch 70, loss 0.4719
Epoch 71, loss 0.4638
Epoch 72, loss 0.4573
Epoch 73, loss 0.4677
Epoch 74, loss 0.4489
Epoch 75, loss 0.4524
Epoch 76, loss 0.4447
Epoch 77, loss 0.4342
Epoch 78, loss 0.4452
Epoch 79, loss 0.4344

The learning curve of a run is presented below.

plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()
image.png

The trained model is evaluated on the test set created above. The running time of this tutorial is restricted to make it easy to deploy, so with longer training you are likely to get a higher accuracy (80% to 90%) than the figures printed below.

model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))

# Results
Accuracy of sampled predictions on the test set: 68.7500%
Accuracy of argmax predictions on the test set: 81.2500%
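The two accuracy figures come from two ways of turning class probabilities into a prediction: drawing a class at random according to the probabilities, or always taking the most probable class. The sketch below illustrates the difference in plain Python; the logits are made up for the example, and the helper names are assumptions rather than part of the tutorial.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax_prediction(probs):
    # Index of the most probable class.
    return max(range(len(probs)), key=probs.__getitem__)


def sampled_prediction(probs, rng):
    # Draw one class index with probability proportional to probs.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = softmax([2.0, 0.5, 0.1])  # made-up logits for one graph
rng = random.Random(0)
samples = [sampled_prediction(probs, rng) for _ in range(1000)]
share_of_top = samples.count(argmax_prediction(probs)) / len(samples)
print(argmax_prediction(probs), share_of_top)
```

Sampling picks the top class only about as often as its probability, which is why the sampled accuracy (68.75%) falls below the argmax accuracy (81.25%) whenever the model is right but not fully confident.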

The animation here plots the probability that a trained model predicts the correct graph type.

image.png

To understand the node and graph representations that a trained model learned, we use t-SNE for dimensionality reduction and visualization.

image.png

While the visualization does suggest some clustering effects of the node features, you would not expect a perfect result, because these node features are determined by the node degrees. The graph features, by contrast, are better separated.

Original source:
https://docs.dgl.ai/tutorials/basics/4_batch.html
