Note
Click here to download the full example code
Building a Graph Convolutional Network¶
Author: Yulun Yao, Chien-Yu Lin
This article is an introductory tutorial on building a Graph Convolutional Network (GCN) with Relay. In this tutorial, we run our GCN on the Cora dataset as a demonstration. The Cora dataset is a common benchmark for Graph Neural Networks (GNN) and for frameworks that support GNN training and inference. We load the dataset directly from the DGL library so that we can make an apples-to-apples comparison against DGL.
Please refer to DGL doc for DGL installation at https://docs.dgl.ai/install/index.html.
Please refer to PyTorch guide for PyTorch installation at https://pytorch.org/get-started/locally/.
Define GCN in DGL with PyTorch backend¶
DGL example: https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn This part reuses the code from the above example.
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import networkx as nx
from dgl.nn.pytorch import GraphConv
class GCN(nn.Module):
    """Graph Convolutional Network built from DGL ``GraphConv`` layers (PyTorch backend).

    The network consists of one input layer (n_infeat -> n_hidden),
    ``n_layers - 1`` hidden layers (n_hidden -> n_hidden), and one output layer
    (n_hidden -> n_classes) without activation, i.e. ``n_layers + 1`` layers total.

    Parameters
    ----------
    g: DGLGraph
        Graph that every layer convolves over.
    n_infeat: int
        Dimension of the input node features.
    n_hidden: int
        Number of hidden units per hidden layer.
    n_classes: int
        Dimension of the output (number of classes).
    n_layers: int
        Number of layers that produce hidden features.
    activation: callable
        Activation applied after every layer except the output layer.
    """

    def __init__(self, g, n_infeat, n_hidden, n_classes, n_layers, activation):
        super(GCN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        self.layers.append(GraphConv(n_infeat, n_hidden, activation=activation))
        for _ in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
        self.layers.append(GraphConv(n_hidden, n_classes))

    @staticmethod
    def _parse_version(version):
        """Return the leading numeric release components of ``version`` as ints.

        For example "0.4.3post2" -> (0, 4, 3). Returns () if ``version`` does
        not start with a number.
        """
        import re

        match = re.match(r"\d+(?:\.\d+)*", version)
        if match is None:
            return ()
        return tuple(int(part) for part in match.group(0).split("."))

    def forward(self, features):
        """Apply every layer in order over ``self.g`` and return the last layer's output."""
        # Handle API changes for different DGL versions: releases after 0.3 take
        # (graph, feat), older ones take (feat, graph). Compare numeric tuples, since a
        # plain string comparison misorders multi-digit components ("0.10" < "0.3").
        graph_first = self._parse_version(dgl.__version__) > (0, 3)
        h = features
        for layer in self.layers:
            h = layer(self.g, h) if graph_first else layer(h, self.g)
        return h
Using backend: pytorch
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/backend/pytorch/tensor.py:17: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
if LooseVersion(th.__version__) < LooseVersion("1.5.0"):
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/backend/pytorch/tensor.py:333: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
if LooseVersion(th.__version__) >= LooseVersion("1.10.0"):
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/backend/pytorch/sparse.py:8: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/dataloading/pytorch/dataloader.py:22: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
PYTORCH_VER = LooseVersion(th.__version__)
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/dataloading/pytorch/dataloader.py:23: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
PYTORCH_16 = PYTORCH_VER >= LooseVersion("1.6.0")
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/dataloading/pytorch/dataloader.py:24: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
PYTORCH_17 = PYTORCH_VER >= LooseVersion("1.7.0")
Define the functions to load dataset and evaluate accuracy¶
You may substitute your own dataset for this part; here we load the data from DGL.
from dgl.data import load_data
from collections import namedtuple
def load_dataset(dataset="cora"):
    """Load a dataset through DGL and give every node exactly one self-loop.

    Parameters
    ----------
    dataset: str
        Dataset name forwarded to ``dgl.data.load_data`` (e.g. 'cora').

    Returns
    -------
    g: networkx graph
        ``data.graph`` with its self-loops normalized (modified in place).
    data:
        The dataset object returned by ``load_data``.
    """
    # load_data is called with an argparse-style object; a namedtuple with a
    # ``dataset`` field stands in for it here.
    args = namedtuple("args", ["dataset"])
    data = load_data(args(dataset))

    g = data.graph
    # Drop any self-loops already present. The iterator is materialized with list()
    # so the graph is not mutated while nx.selfloop_edges is still iterating over it.
    g.remove_edges_from(list(nx.selfloop_edges(g)))
    # Then add exactly one self-loop per node, so each node's own feature is
    # aggregated once rather than being passed to itself multiple times.
    g.add_edges_from(zip(g.nodes, g.nodes))
    return g, data
def evaluate(data, logits):
    """Compute classification accuracy restricted to the test nodes.

    Parameters
    ----------
    data:
        Object exposing ``test_mask`` (per-node 0/1 or boolean mask of test
        nodes) and ``labels`` (per-node ground-truth class ids).
    logits: numpy.ndarray
        Per-node class scores; the predicted class is the argmax along axis 1.

    Returns
    -------
    The number of correctly predicted test nodes divided by the number of test nodes.
    """
    mask = data.test_mask  # nodes held out from the training phase
    is_correct = logits.argmax(axis=1) == data.labels
    correct_in_test = (is_correct * mask).sum()
    return correct_in_test / mask.sum()
Load the data and set up model parameters¶
"""
Parameters
----------
dataset: str
    Name of dataset. You can choose from ['cora', 'citeseer', 'pubmed'].
num_layers: int
    Number of GraphConv layers that produce hidden features. The model also has
    one output layer, so it contains num_layers + 1 GraphConv layers in total.
num_hidden: int
    Number of hidden units in each hidden layer.
infeat_dim: int
    Dimension of the input node features (read from the dataset).
num_classes: int
    Dimension of the model output, i.e. the number of classes (read from the dataset).
"""
dataset = "cora"
g, data = load_dataset(dataset)

num_layers = 1
num_hidden = 16
infeat_dim = data.features.shape[1]
num_classes = data.num_labels
Downloading /workspace/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /workspace/.dgl/cora_v2
Finished data loading and preprocessing.
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done saving data into cached files.
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.graph will be deprecated, please use dataset[0] instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.feat will be deprecated, please use g.ndata['feat'] instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.num_labels will be deprecated, please use dataset.num_classes instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
Set up the DGL-PyTorch model and get the golden results¶
The weights are trained with https://github.com/dmlc/dgl/blob/master/examples/pytorch/gcn/train.py
from tvm.contrib.download import download_testdata
from dgl import DGLGraph

# Node features as a float tensor and the networkx graph wrapped as a DGL graph.
features = torch.FloatTensor(data.features)
dgl_g = DGLGraph(g)

torch_model = GCN(dgl_g, infeat_dim, num_hidden, num_classes, num_layers, F.relu)

# Download the pretrained weights
model_url = "https://homes.cs.washington.edu/~cyulin/media/gnn_model/gcn_%s.torch" % (dataset)
model_path = download_testdata(model_url, "gcn_%s.pickle" % (dataset), module="gcn_model")

# Load the weights into the model. map_location="cpu" lets a checkpoint saved from
# GPU tensors load on a CPU-only machine; this model is never moved off the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
torch_model.load_state_dict(torch.load(model_path, map_location="cpu"))
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.feat will be deprecated, please use g.ndata['feat'] instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: Recommend creating graphs by `dgl.graph(data)` instead of `dgl.DGLGraph(data)`.
return warnings.warn(message, category=category, stacklevel=1)
<All keys matched successfully>
Run the DGL model and test for accuracy¶
# Inference only: switch to eval mode and disable autograd tracking.
torch_model.eval()
with torch.no_grad():
    logits_torch = torch_model(features)

print("Print the first five outputs from DGL-PyTorch execution\n", logits_torch[:5])

# Accuracy over the test nodes, computed on a numpy copy of the logits.
acc = evaluate(data, logits_torch.numpy())
print(f"Test accuracy of DGL results: {acc:.2%}")
Print the first five outputs from DGL-PyTorch execution
tensor([[-0.2198, -0.7980, 0.0784, 0.9232, -0.9319, -0.7733, 0.9410],
[-0.4646, -0.6606, -0.1732, 1.1829, -0.3705, -0.5535, 0.0858],
[-0.0031, -0.4156, 0.0175, 0.4765, -0.5887, -0.3609, 0.2278],
[-0.8559, -0.8860, 1.4782, 0.9262, -1.3100, -1.0960, -0.0908],
[-0.0702, -1.1651, 1.1453, -0.3586, -0.4938, -0.2288, 0.1827]])
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.test_mask will be deprecated, please use g.ndata['test_mask'] instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.label will be deprecated, please use g.ndata['label'] instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
Test accuracy of DGL results: 10.00%
Define Graph Convolution Layer in Relay¶
To run GCN on TVM, we first need to implement a Graph Convolution Layer. You may refer to https://github.com/dmlc/dgl/blob/master/python/dgl/nn/mxnet/conv/graphconv.py for a GraphConv layer implemented in DGL with the MXNet backend.
The layer is defined with the operations below. Note that we apply two transposes to keep the adjacency matrix on the right-hand side of the sparse_dense operator. This workaround is temporary and will be updated once sparse matrix transpose and left-hand sparse operands are supported.
\[\mbox{GraphConv}(A, H, W) = A * H * W = ((H * W)^t * A^t)^t = ((W^t * H^t) * A^t)^t\]
from tvm import relay
from tvm.contrib import graph_executor
import tvm
from tvm import te
def GraphConv(layer_name, input_dim, output_dim, adj, input, norm=None, bias=True, activation=None):
    """Build one graph convolution layer as a Relay expression.

    Computes ``activation(norm * (A * (norm * H) * W) + b)``, where A is the
    adjacency matrix, H the input features, and W the layer weight. The
    product A * H * W is evaluated as ((W^t * H^t) * A^t)^t so that the sparse
    adjacency stays on the right-hand side of ``sparse_dense``.

    Parameters
    ----------
    layer_name: str
        Name of layer; the weight and bias vars are named
        ``<layer_name>.weight`` and ``<layer_name>.bias``.
    input_dim: int
        Input dimension per node feature.
    output_dim: int
        Output dimension per node feature.
    adj: namedtuple
        Graph representation (adjacency matrix) in sparse format (`data`, `indices`, `indptr`),
        where `data` has shape [num_nonzeros], `indices` has shape [num_nonzeros], and
        `indptr` has shape [num_nodes + 1].
    input: relay.Expr
        Input feature to the current layer, with shape [num_nodes, input_dim].
    norm: relay.Expr
        Norm passed to this layer to normalize features before and after convolution.
    bias: bool
        Set to True to add a bias to the layer output.
    activation: <function relay.op.nn>
        Activation function applied to the output, e.g. relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu}.

    Returns
    -------
    output: tvm.relay.Expr
        The output tensor for this layer, [num_nodes, output_dim].
    """
    if norm is not None:
        input = relay.multiply(input, norm)

    weight = relay.var(layer_name + ".weight", shape=(input_dim, output_dim))
    weight_t = relay.transpose(weight)
    # nn.dense(x, y) computes x * y^t, so this is W^t * H^t = (H * W)^t: [output_dim, num_nodes].
    hw_t = relay.nn.dense(weight_t, input)
    # sparse_dense(x, A) computes x * A^t, giving (H * W)^t * A^t = (A * H * W)^t.
    output = relay.nn.sparse_dense(hw_t, adj)
    # Transpose back to [num_nodes, output_dim].
    output_t = relay.transpose(output)
    if norm is not None:
        output_t = relay.multiply(output_t, norm)
    if bias:
        # bias_add expects a 1-D bias whose length equals the size of the channel axis (-1).
        _bias = relay.var(layer_name + ".bias", shape=(output_dim,))
        output_t = relay.nn.bias_add(output_t, _bias, axis=-1)
    if activation is not None:
        output_t = activation(output_t)
    return output_t
Prepare the parameters needed in the GraphConv layers¶
import numpy as np
import networkx as nx
def prepare_params(g, data):
    """Precompute the numpy inputs used to build the Relay GCN.

    Parameters
    ----------
    g: networkx graph
        Graph whose nodes are assumed to be labeled 0..N-1, matching the
        row order of ``data.features``.
    data:
        Dataset object exposing ``features``.

    Returns
    -------
    dict with keys:
        "infeats": float32 node features.
        "g_data", "indices", "indptr": CSR adjacency (float32 / int32 / int32).
        "norm": float32 column vector [N, 1] of in-degree^-0.5 per node.
    """
    params = {}
    # Only float32 features are supported for now. np.asarray accepts both a numpy
    # array and a CPU torch tensor, whichever form the installed DGL returns.
    params["infeats"] = np.asarray(data.features).astype("float32")

    # Pin the node order to 0..N-1 so that adjacency rows line up with the
    # feature rows and with the degrees computed below.
    nodelist = list(range(g.number_of_nodes()))

    # Generate the adjacency matrix. NetworkX 3.0 removed to_scipy_sparse_matrix,
    # so prefer to_scipy_sparse_array when it is available.
    to_scipy_sparse = getattr(nx, "to_scipy_sparse_array", None) or nx.to_scipy_sparse_matrix
    adjacency = to_scipy_sparse(g, nodelist=nodelist, format="csr")
    params["g_data"] = adjacency.data.astype("float32")
    params["indices"] = adjacency.indices.astype("int32")
    params["indptr"] = adjacency.indptr.astype("int32")

    # Normalization w.r.t. node in-degrees, shaped as a column vector so it
    # broadcasts across the feature dimension.
    degs = [g.in_degree[i] for i in nodelist]
    params["norm"] = np.power(degs, -0.5).astype("float32").reshape(-1, 1)
    return params
params = prepare_params(g, data)

# Check the shape of the features and the consistency of the CSR adjacency matrix.
assert params["infeats"].ndim == 2
# One feature row per node: a CSR matrix over N nodes has N + 1 row pointers.
assert params["infeats"].shape[0] == params["indptr"].shape[0] - 1
# CSR invariant: indptr[-1] is the number of stored entries, which must equal the
# lengths of both the column-index array and the value array.
assert params["indptr"][-1] == params["indices"].shape[0] == params["g_data"].shape[0]
# The normalization factors form a column vector with one entry per node.
assert params["norm"].shape == (params["infeats"].shape[0], 1)
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.feat will be deprecated, please use g.ndata['feat'] instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
Put layers together¶
# Define input features, norms, adjacency matrix in Relay
infeats = relay.var("infeats", shape=data.features.shape)
norm = relay.Constant(tvm.nd.array(params["norm"]))
g_data = relay.Constant(tvm.nd.array(params["g_data"]))
indices = relay.Constant(tvm.nd.array(params["indices"]))
indptr = relay.Constant(tvm.nd.array(params["indptr"]))

# sparse_dense consumes the CSR adjacency as a (data, indices, indptr) namedtuple.
Adjacency = namedtuple("Adjacency", ["data", "indices", "indptr"])
adj = Adjacency(g_data, indices, indptr)

# Construct the GCN with the same structure as the PyTorch model: one input layer,
# (num_layers - 1) hidden layers and one output layer, named "layers.<k>" so the
# weight/bias var names match the keys of torch_model.state_dict().
layers = []
layers.append(
    GraphConv(
        layer_name="layers.0",
        input_dim=infeat_dim,
        output_dim=num_hidden,
        adj=adj,
        input=infeats,
        norm=norm,
        activation=relay.nn.relu,
    )
)
for layer_idx in range(1, num_layers):
    layers.append(
        GraphConv(
            layer_name="layers.%d" % layer_idx,
            input_dim=num_hidden,
            output_dim=num_hidden,
            adj=adj,
            input=layers[-1],
            norm=norm,
            activation=relay.nn.relu,
        )
    )
layers.append(
    GraphConv(
        layer_name="layers.%d" % num_layers,
        input_dim=num_hidden,
        output_dim=num_classes,
        adj=adj,
        input=layers[-1],
        norm=norm,
        activation=None,
    )
)

# The last layer's expression is the network output, [num_nodes, num_classes].
output = layers[-1]
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.feat will be deprecated, please use g.ndata['feat'] instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
Compile and run with TVM¶
Export the weights from the PyTorch model to a Python dict
# Export the trained weights as numpy arrays. state_dict() is built once here,
# instead of being rebuilt for every key.
model_params = {name: tensor.numpy() for name, tensor in torch_model.state_dict().items()}

# Copy each GraphConv layer's weight and bias into ``params`` under the same
# "layers.<k>.weight" / "layers.<k>.bias" names used for the Relay vars.
for i in range(num_layers + 1):
    params["layers.%d.weight" % (i)] = model_params["layers.%d.weight" % (i)]
    params["layers.%d.bias" % (i)] = model_params["layers.%d.bias" % (i)]

# Set the TVM build target
target = "llvm"  # Currently only support `llvm` as target

# Analyze free variables and generate the Relay function, then bind every free
# variable that has an entry in ``params`` as a constant.
func = relay.Function(relay.analysis.free_vars(output), output)
func = relay.build_module.bind_params_by_name(func, params)
mod = tvm.IRModule()
mod["main"] = func

# Build with Relay
with tvm.transform.PassContext(opt_level=0):  # Currently only support opt_level=0
    lib = relay.build(mod, target, params=params)

# Generate graph executor
dev = tvm.device(target, 0)
m = graph_executor.GraphModule(lib["default"](dev))
Run the TVM model, test for accuracy and verify with DGL¶
import tvm.testing

# No set_input call: the graph inputs are expected to have been bound as
# constants when the module was built.
m.run()
logits_tvm = m.get_output(0).numpy()
print("Print the first five outputs from TVM execution\n", logits_tvm[:5])

labels = data.labels
test_mask = data.test_mask
acc = evaluate(data, logits_tvm)
print(f"Test accuracy of TVM results: {acc:.2%}")

# Verify the results with the DGL model
tvm.testing.assert_allclose(logits_torch, logits_tvm, atol=1e-3)
Print the first five outputs from TVM execution
[[-0.21976954 -0.7979525 0.07836491 0.9232204 -0.93188703 -0.7732947
0.9410008 ]
[-0.4645713 -0.66060466 -0.17316166 1.1828876 -0.37051404 -0.5534965
0.08579484]
[-0.00308266 -0.41562504 0.0175378 0.47649348 -0.5886737 -0.3609016
0.22782072]
[-0.8559376 -0.8860172 1.4782399 0.9262254 -1.3099641 -1.0960144
-0.09084877]
[-0.07015878 -1.1651071 1.1452857 -0.35857323 -0.49377596 -0.22878847
0.18269953]]
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.label will be deprecated, please use g.ndata['label'] instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
/venv/apache-tvm-py3.7/lib/python3.7/site-packages/dgl/data/utils.py:286: UserWarning: Property dataset.test_mask will be deprecated, please use g.ndata['test_mask'] instead.
warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
Test accuracy of TVM results: 10.00%