.. rst-class:: sphx-glr-example-title

.. _sphx_glr_how_to_work_with_relay_build_gcn.py:


Building a Graph Convolutional Network
======================================
**Author**: Yulun Yao, Chien-Yu Lin

This article is an introductory tutorial on building a Graph Convolutional Network (GCN) with Relay.
In this tutorial, we run our GCN on the Cora dataset as a demonstration.
The Cora dataset is a common benchmark for Graph Neural Networks (GNN) and for frameworks that support GNN training and inference.
We load the dataset directly from the DGL library to allow an apples-to-apples comparison against DGL.

Please refer to the DGL documentation for DGL installation at
https://docs.dgl.ai/install/index.html.

Please refer to the PyTorch guide for PyTorch installation at
https://pytorch.org/get-started/locally/.

Define GCN in DGL with PyTorch backend
--------------------------------------

DGL example: https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn

This part reuses the code from the above example.


.. code-block:: default

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import dgl
    import networkx as nx
    from dgl.nn.pytorch import GraphConv


    class GCN(nn.Module):
        def __init__(self, g, n_infeat, n_hidden, n_classes, n_layers, activation):
            super(GCN, self).__init__()
            self.g = g
            self.layers = nn.ModuleList()
            self.layers.append(GraphConv(n_infeat, n_hidden, activation=activation))
            for i in range(n_layers - 1):
                self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
            self.layers.append(GraphConv(n_hidden, n_classes))

        def forward(self, features):
            h = features
            for i, layer in enumerate(self.layers):
                # handle API changes across different DGL versions
                if dgl.__version__ > "0.3":
                    h = layer(self.g, h)
                else:
                    h = layer(h, self.g)
            return h


.. rst-class:: sphx-glr-script-out

Out:

..
.. code-block:: none

    Using backend: pytorch


Define the functions to load dataset and evaluate accuracy
----------------------------------------------------------
You may substitute your own dataset for this part; here we load the data from DGL.


.. code-block:: default

    from dgl.data import load_data
    from collections import namedtuple


    def load_dataset(dataset="cora"):
        args = namedtuple("args", ["dataset"])
        data = load_data(args(dataset))

        # Remove any existing self-loops, then add exactly one self-loop per node,
        # so that a node's feature is passed to itself only once
        g = data.graph
        g.remove_edges_from(nx.selfloop_edges(g))
        g.add_edges_from(zip(g.nodes, g.nodes))

        return g, data


    def evaluate(data, logits):
        test_mask = data.test_mask  # the test set which isn't included in the training phase

        pred = logits.argmax(axis=1)
        acc = ((pred == data.labels) * test_mask).sum() / test_mask.sum()

        return acc


Load the data and set up model parameters
-----------------------------------------


.. code-block:: default

    """
    Parameters
    ----------
    dataset: str
        Name of dataset. You can choose from ['cora', 'citeseer', 'pubmed'].

    num_layers: int
        number of hidden layers

    num_hidden: int
        number of the hidden units in the hidden layer

    infeat_dim: int
        dimension of the input features

    num_classes: int
        dimension of model output (Number of classes)
    """

    dataset = "cora"
    g, data = load_dataset(dataset)

    num_layers = 1
    num_hidden = 16
    infeat_dim = data.features.shape[1]
    num_classes = data.num_labels


.. rst-class:: sphx-glr-script-out

Out:

.. code-block:: none

    NumNodes: 2708
    NumEdges: 10556
    NumFeats: 1433
    NumClasses: 7
    NumTrainingSamples: 140
    NumValidationSamples: 500
    NumTestSamples: 1000
    Done loading data from cached files.
    /usr/local/lib/python3.6/dist-packages/dgl/data/utils.py:285: UserWarning: Property dataset.graph will be deprecated, please use dataset[0] instead.
      warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
    /usr/local/lib/python3.6/dist-packages/dgl/data/utils.py:285: UserWarning: Property dataset.feat will be deprecated, please use g.ndata['feat'] instead.
      warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
    /usr/local/lib/python3.6/dist-packages/dgl/data/utils.py:285: UserWarning: Property dataset.num_labels will be deprecated, please use dataset.num_classes instead.
      warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))


Set up the DGL-PyTorch model and get the golden results
-------------------------------------------------------
The weights are trained with https://github.com/dmlc/dgl/blob/master/examples/pytorch/gcn/train.py


.. code-block:: default

    from tvm.contrib.download import download_testdata
    from dgl import DGLGraph

    features = torch.FloatTensor(data.features)
    dgl_g = DGLGraph(g)

    torch_model = GCN(dgl_g, infeat_dim, num_hidden, num_classes, num_layers, F.relu)

    # Download the pretrained weights
    model_url = "https://homes.cs.washington.edu/~cyulin/media/gnn_model/gcn_%s.torch" % (dataset)
    model_path = download_testdata(model_url, "gcn_%s.pickle" % (dataset), module="gcn_model")

    # Load the weights into the model
    torch_model.load_state_dict(torch.load(model_path))


.. rst-class:: sphx-glr-script-out

Out:

.. code-block:: none

    /usr/local/lib/python3.6/dist-packages/dgl/base.py:45: DGLWarning: Recommend creating graphs by `dgl.graph(data)` instead of `dgl.DGLGraph(data)`.
      return warnings.warn(message, category=category, stacklevel=1)


Run the DGL model and test for accuracy
---------------------------------------


.. code-block:: default

    torch_model.eval()
    with torch.no_grad():
        logits_torch = torch_model(features)
    print("Print the first five outputs from DGL-PyTorch execution\n", logits_torch[:5])

    acc = evaluate(data, logits_torch.numpy())
    print("Test accuracy of DGL results: {:.2%}".format(acc))


..
.. rst-class:: sphx-glr-script-out

Out:

.. code-block:: none

    Print the first five outputs from DGL-PyTorch execution
     tensor([[-2.2395, -0.9681,  3.4042, -0.1481, -0.0272, -1.2441, -1.8549],
            [-1.6017, -1.3846,  0.7642,  2.5430, -1.7420, -1.3704,  0.4249],
            [-2.0039, -1.2357,  2.4931,  1.0323, -1.3252, -1.3401, -0.5114],
            [ 0.1647, -2.0421, -0.2668,  0.1527, -0.6965,  1.1109,  1.1034],
            [-0.8606, -0.6954,  0.1959,  0.6853,  0.0284, -0.6652,  0.2225]])
    /usr/local/lib/python3.6/dist-packages/dgl/data/utils.py:285: UserWarning: Property dataset.test_mask will be deprecated, please use g.ndata['test_mask'] instead.
      warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
    /usr/local/lib/python3.6/dist-packages/dgl/data/utils.py:285: UserWarning: Property dataset.label will be deprecated, please use g.ndata['label'] instead.
      warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
    Test accuracy of DGL results: 5.30%


Define Graph Convolution Layer in Relay
---------------------------------------
To run GCN on TVM, we first need to implement the Graph Convolution Layer.
You may refer to https://github.com/dmlc/dgl/blob/master/python/dgl/nn/mxnet/conv/graphconv.py for a GraphConv layer implemented in DGL with the MXNet backend.

The layer is defined by the operations below. Note that we apply two transposes to keep the adjacency matrix on the right-hand side of the sparse_dense operator.
This method is temporary and will be updated in the next few weeks, once sparse matrix transpose and support for a left sparse operator are available.

.. math::

    \mbox{GraphConv}(A, H, W) = A * H * W = ((H * W)^t * A^t)^t = ((W^t * H^t) * A^t)^t


..
.. code-block:: default

    from tvm import relay
    from tvm.contrib import graph_executor
    import tvm
    from tvm import te


    def GraphConv(layer_name, input_dim, output_dim, adj, input, norm=None, bias=True, activation=None):
        """
        Parameters
        ----------
        layer_name: str
            Name of layer

        input_dim: int
            Input dimension per node feature

        output_dim: int
            Output dimension per node feature

        adj: namedtuple
            Graph representation (adjacency matrix) in sparse format (`data`, `indices`, `indptr`),
            where `data` has shape [num_nonzeros], `indices` has shape [num_nonzeros],
            and `indptr` has shape [num_nodes + 1]

        input: relay.Expr
            Input feature to the current layer with shape [num_nodes, input_dim]

        norm: relay.Expr
            Norm passed to this layer to normalize features before and after convolution.

        bias: bool
            Set bias to True to add a bias term in the GCN layer

        activation: callable
            Activation function applied to the output,
            e.g. relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu}

        Returns
        -------
        output: tvm.relay.Expr
            The output tensor for this layer with shape [num_nodes, output_dim]
        """
        if norm is not None:
            input = relay.multiply(input, norm)

        weight = relay.var(layer_name + ".weight", shape=(input_dim, output_dim))
        weight_t = relay.transpose(weight)
        dense = relay.nn.dense(weight_t, input)
        output = relay.nn.sparse_dense(dense, adj)
        output_t = relay.transpose(output)
        if norm is not None:
            output_t = relay.multiply(output_t, norm)
        if bias is True:
            _bias = relay.var(layer_name + ".bias", shape=(output_dim, 1))
            output_t = relay.nn.bias_add(output_t, _bias, axis=-1)
        if activation is not None:
            output_t = activation(output_t)
        return output_t


Prepare the parameters needed in the GraphConv layers
-----------------------------------------------------


..
.. code-block:: default

    import numpy as np
    import networkx as nx


    def prepare_params(g, data):
        params = {}
        params["infeats"] = data.features.numpy().astype(
            "float32"
        )  # Only support float32 as feature for now

        # Generate adjacency matrix
        adjacency = nx.to_scipy_sparse_matrix(g)
        params["g_data"] = adjacency.data.astype("float32")
        params["indices"] = adjacency.indices.astype("int32")
        params["indptr"] = adjacency.indptr.astype("int32")

        # Normalization w.r.t. node degrees
        degs = [g.in_degree[i] for i in range(g.number_of_nodes())]
        params["norm"] = np.power(degs, -0.5).astype("float32")
        params["norm"] = params["norm"].reshape((params["norm"].shape[0], 1))

        return params


    params = prepare_params(g, data)

    # Check shape of features and the validity of adjacency matrix
    assert len(params["infeats"].shape) == 2
    assert (
        params["g_data"] is not None and params["indices"] is not None and params["indptr"] is not None
    )
    assert params["infeats"].shape[0] == params["indptr"].shape[0] - 1


Put layers together
-------------------


..
.. code-block:: default

    # Define input features, norms, adjacency matrix in Relay
    infeats = relay.var("infeats", shape=data.features.shape)

    norm = relay.Constant(tvm.nd.array(params["norm"]))
    g_data = relay.Constant(tvm.nd.array(params["g_data"]))
    indices = relay.Constant(tvm.nd.array(params["indices"]))
    indptr = relay.Constant(tvm.nd.array(params["indptr"]))

    Adjacency = namedtuple("Adjacency", ["data", "indices", "indptr"])
    adj = Adjacency(g_data, indices, indptr)

    # Construct the 2-layer GCN
    layers = []
    layers.append(
        GraphConv(
            layer_name="layers.0",
            input_dim=infeat_dim,
            output_dim=num_hidden,
            adj=adj,
            input=infeats,
            norm=norm,
            activation=relay.nn.relu,
        )
    )
    layers.append(
        GraphConv(
            layer_name="layers.1",
            input_dim=num_hidden,
            output_dim=num_classes,
            adj=adj,
            input=layers[-1],
            norm=norm,
            activation=None,
        )
    )

    # Analyze free variables and generate Relay function
    output = layers[-1]


Compile and run with TVM
------------------------

Export the weights from the PyTorch model to a Python dict.


.. code-block:: default

    model_params = {}
    for param_tensor in torch_model.state_dict():
        model_params[param_tensor] = torch_model.state_dict()[param_tensor].numpy()

    for i in range(num_layers + 1):
        params["layers.%d.weight" % (i)] = model_params["layers.%d.weight" % (i)]
        params["layers.%d.bias" % (i)] = model_params["layers.%d.bias" % (i)]

    # Set the TVM build target
    target = "llvm"  # Currently only `llvm` is supported as target

    func = relay.Function(relay.analysis.free_vars(output), output)
    func = relay.build_module.bind_params_by_name(func, params)
    mod = tvm.IRModule()
    mod["main"] = func
    # Build with Relay
    with tvm.transform.PassContext(opt_level=0):  # Currently only opt_level=0 is supported
        lib = relay.build(mod, target, params=params)

    # Generate graph executor
    dev = tvm.device(target, 0)
    m = graph_executor.GraphModule(lib["default"](dev))


Run the TVM model, test for accuracy and verify with DGL
--------------------------------------------------------


..
.. code-block:: default

    m.run()
    logits_tvm = m.get_output(0).numpy()
    print("Print the first five outputs from TVM execution\n", logits_tvm[:5])

    labels = data.labels
    test_mask = data.test_mask

    acc = evaluate(data, logits_tvm)
    print("Test accuracy of TVM results: {:.2%}".format(acc))

    import tvm.testing

    # Verify the results with the DGL model
    tvm.testing.assert_allclose(logits_torch, logits_tvm, atol=1e-3)


.. rst-class:: sphx-glr-script-out

Out:

.. code-block:: none

    Print the first five outputs from TVM execution
     [[-2.2394986  -0.9680933   3.4041846  -0.14806426 -0.02724874 -1.2441163  -1.8548993 ]
     [-1.6016592  -1.3846085   0.7641872   2.5430043  -1.7419695  -1.3703678   0.42491326]
     [-2.0038617  -1.2356598   2.4931228   1.0322791  -1.325198   -1.3400824  -0.51143134]
     [ 0.16473567 -2.0420618  -0.26682284  0.15265226 -0.6964847   1.1109071   1.103439  ]
     [-0.8606019  -0.69538236  0.1958623   0.6853092   0.02840531 -0.6652414   0.22247872]]
    Test accuracy of TVM results: 5.30%
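
As a sanity check on the transpose identity used in the GraphConv layer, the following minimal sketch replays the same sequence of operations in plain NumPy on a toy graph. It is not part of the original tutorial. The helpers ``dense``, ``sparse_dense``, ``csr_to_dense`` and ``graph_conv`` and the 4-node directed graph are hypothetical stand-ins. They assume that ``relay.nn.dense(x, w)`` computes ``x @ w.T`` and that ``relay.nn.sparse_dense(x, adj)`` computes ``x @ A.T`` for an adjacency matrix ``A`` stored in CSR form. Normalization, bias, and activation are left out to keep the identity itself in focus.

```python
import numpy as np


def dense(data, weight):
    """Assumed semantics of relay.nn.dense: data @ weight.T."""
    return data @ weight.T


def csr_to_dense(data, indices, indptr, num_cols):
    """Expand CSR arrays (data, indices, indptr) into a dense matrix."""
    num_rows = len(indptr) - 1
    out = np.zeros((num_rows, num_cols), dtype=data.dtype)
    for row in range(num_rows):
        start, end = indptr[row], indptr[row + 1]
        out[row, indices[start:end]] = data[start:end]
    return out


def sparse_dense(dense_mat, adj):
    """Assumed semantics of relay.nn.sparse_dense: dense_mat @ A.T, A given in CSR."""
    data, indices, indptr = adj
    a = csr_to_dense(data, indices, indptr, num_cols=dense_mat.shape[1])
    return dense_mat @ a.T


def graph_conv(adj, h, w):
    """Compute A @ H @ W as ((W^t @ H^t) @ A^t)^t, mirroring the Relay layer."""
    weight_t = w.T                      # shape [output_dim, input_dim]
    d = dense(weight_t, h)              # W^t @ H^t, shape [output_dim, num_nodes]
    out = sparse_dense(d, adj)          # (W^t @ H^t) @ A^t
    return out.T                        # shape [num_nodes, output_dim]


# Toy directed graph with self-loops; deliberately non-symmetric so that
# confusing A with A^t would change the result.
a_dense = np.array(
    [[1, 1, 0, 0],
     [0, 1, 1, 0],
     [0, 0, 1, 1],
     [1, 0, 0, 1]],
    dtype="float32",
)

# Build the CSR triple (np.nonzero returns entries in row-major order).
rows, cols = np.nonzero(a_dense)
data = a_dense[rows, cols]
indices = cols.astype("int32")
indptr = np.concatenate([[0], np.cumsum((a_dense != 0).sum(axis=1))]).astype("int32")
adj = (data, indices, indptr)

rng = np.random.default_rng(0)
h = rng.standard_normal((4, 3)).astype("float32")  # node features [num_nodes, input_dim]
w = rng.standard_normal((3, 2)).astype("float32")  # weights [input_dim, output_dim]

result = graph_conv(adj, h, w)
reference = a_dense @ h @ w
print("identity holds:", np.allclose(result, reference, atol=1e-5))
```

Using a non-symmetric adjacency matrix is a deliberate choice here: for an undirected graph ``A`` equals its transpose, so a misplaced transpose would go unnoticed.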