Use Relay Visualizer to Visualize Relay

Author: Chi-Wei Wang

A Relay IR module can contain many operations. Although each operation is usually easy to understand, combining them can produce a complicated, hard-to-read graph, and optimization passes can make the graph harder still to follow.

This utility visualizes an IR module as nodes and edges. It defines a set of interfaces for the parser, plotter (renderer), graph, node, and edge. A default parser is provided, and users can implement their own renderers to render the graph.

Here we use a renderer that draws the graph as text. It is a lightweight, AST-like visualizer inspired by clang ast-dump. We will also show how to implement customized parsers and renderers through the interface classes.

For more details, please refer to tvm.contrib.relay_viz.

from typing import (
    Dict,
    Union,
    Tuple,
    List,
)
import tvm
from tvm import relay
from tvm.contrib import relay_viz
from tvm.contrib.relay_viz.interface import (
    VizEdge,
    VizNode,
    VizParser,
)
from tvm.contrib.relay_viz.terminal import (
    TermGraph,
    TermPlotter,
    TermVizParser,
)

Define a Relay IR Module with multiple GlobalVar

Let’s build an example Relay IR Module containing multiple GlobalVar. We define an add function and call it in the main function.

data = relay.var("data")
bias = relay.var("bias")
add_op = relay.add(data, bias)
add_func = relay.Function([data, bias], add_op)
add_gvar = relay.GlobalVar("AddFunc")

input0 = relay.var("input0")
input1 = relay.var("input1")
input2 = relay.var("input2")
add_01 = relay.Call(add_gvar, [input0, input1])
add_012 = relay.Call(add_gvar, [input2, add_01])
main_func = relay.Function([input0, input1, input2], add_012)
main_gvar = relay.GlobalVar("main")

mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func})

Render the graph with Relay Visualizer on the terminal

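The terminal renderer can show a Relay IR module as text, similar to clang AST-dump. Rendering mod with the default parser and plotter, by constructing relay_viz.RelayVisualizer(mod) and calling its render() method, should show both the main and AddFunc functions, with AddFunc called twice in main.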

@main([Var(input0), Var(input1), Var(input2)])
`--Call
   |--GlobalVar AddFunc
   |--Var(Input) name_hint: input2
   `--Call
      |--GlobalVar AddFunc
      |--Var(Input) name_hint: input0
      `--Var(Input) name_hint: input1
@AddFunc([Var(data), Var(bias)])
`--Call
   |--add
   |--Var(Input) name_hint: data
   `--Var(Input) name_hint: bias
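
The tree layout above can be reproduced with a short standalone sketch. It does not use TVM and is not how the terminal plotter is implemented; `TreeNode` and `render_tree` are hypothetical names introduced only to illustrate the clang ast-dump style of connectors, where the last child gets a backtick branch and earlier children get a vertical bar.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class TreeNode:
    """Hypothetical stand-in for a rendered node: a label plus child nodes."""

    label: str
    children: List["TreeNode"] = field(default_factory=list)


def render_tree(root: TreeNode) -> str:
    """Render a tree with ast-dump style connectors."""
    lines = [root.label]

    def walk(node: TreeNode, prefix: str) -> None:
        for index, child in enumerate(node.children):
            is_last = index == len(node.children) - 1
            lines.append(prefix + ("`--" if is_last else "|--") + child.label)
            # Below a non-last child, keep a vertical bar so siblings stay connected.
            walk(child, prefix + ("   " if is_last else "|  "))

    walk(root, "")
    return "\n".join(lines)


inner_call = TreeNode(
    "Call",
    [
        TreeNode("GlobalVar AddFunc"),
        TreeNode("Var(Input) name_hint: input0"),
        TreeNode("Var(Input) name_hint: input1"),
    ],
)
outer_call = TreeNode(
    "Call",
    [
        TreeNode("GlobalVar AddFunc"),
        TreeNode("Var(Input) name_hint: input2"),
        inner_call,
    ],
)
main_tree = TreeNode("@main([Var(input0), Var(input1), Var(input2)])", [outer_call])
print(render_tree(main_tree))
```

Printing `main_tree` gives the same eight lines as the @main portion of the output above.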

Customize the Parser for Relay Types of Interest

Sometimes we want to emphasize particular information or parse things differently for a specific use. We can provide a customized parser as long as it obeys the interface. Here we demonstrate how to customize the parsing of relay.Var nodes (the nodes created by relay.var). To do so, we implement the abstract interface tvm.contrib.relay_viz.interface.VizParser.

class YourAwesomeParser(VizParser):
    def __init__(self):
        self._delegate = TermVizParser()

    def get_node_edges(
        self,
        node: relay.Expr,
        relay_param: Dict[str, tvm.runtime.NDArray],
        node_to_id: Dict[relay.Expr, str],
    ) -> Tuple[Union[VizNode, None], List[VizEdge]]:

        if isinstance(node, relay.Var):
            # use a separate name so the `node` argument is not shadowed
            viz_node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}")
            # no edge is introduced. So return an empty list.
            return viz_node, []

        # delegate other types to the other parser.
        return self._delegate.get_node_edges(node, relay_param, node_to_id)

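Pass the parser and a renderer of your choice to the visualizer. Here we just use the terminal renderer: the visualizer is constructed as relay_viz.RelayVisualizer(mod, {}, TermPlotter(), YourAwesomeParser()), and then its render() method is called.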

@main([Var(input0), Var(input1), Var(input2)])
`--Call
   |--GlobalVar AddFunc
   |--AwesomeVar name_hint input2
   `--Call
      |--GlobalVar AddFunc
      |--AwesomeVar name_hint input0
      `--AwesomeVar name_hint input1
@AddFunc([Var(data), Var(bias)])
`--Call
   |--add
   |--AwesomeVar name_hint data
   `--AwesomeVar name_hint bias

Customization around Graph and Plotter

Besides parsers, we can also customize graphs and renderers by implementing the abstract classes tvm.contrib.relay_viz.interface.VizGraph and tvm.contrib.relay_viz.interface.Plotter. To keep the demo short, here we subclass the TermGraph defined in terminal.py instead. We add a hook that duplicates each AwesomeVar node produced by the parser above, and make TermPlotter use the new class.

class AwesomeGraph(TermGraph):
    def node(self, viz_node):
        # add the node first
        super().node(viz_node)
        # if it's AwesomeVar, duplicate it.
        if viz_node.type_name == "AwesomeVar":
            duplicated_id = f"duplicated_{viz_node.identity}"
            duplicated_type = "double AwesomeVar"
            super().node(VizNode(duplicated_id, duplicated_type, ""))
            # connect the duplicated var to the original one
            super().edge(VizEdge(duplicated_id, viz_node.identity))


# override TermPlotter to use `AwesomeGraph` instead
class AwesomePlotter(TermPlotter):
    def create_graph(self, name):
        self._name_to_graph[name] = AwesomeGraph(name)
        return self._name_to_graph[name]


viz = relay_viz.RelayVisualizer(mod, {}, AwesomePlotter(), YourAwesomeParser())
viz.render()

@main([Var(input0), Var(input1), Var(input2)])
`--Call
   |--GlobalVar AddFunc
   |--AwesomeVar name_hint input2
   |  `--double AwesomeVar
   `--Call
      |--GlobalVar AddFunc
      |--AwesomeVar name_hint input0
      |  `--double AwesomeVar
      `--AwesomeVar name_hint input1
         `--double AwesomeVar
@AddFunc([Var(data), Var(bias)])
`--Call
   |--add
   |--AwesomeVar name_hint data
   |  `--double AwesomeVar
   `--AwesomeVar name_hint bias
      `--double AwesomeVar
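
The direction of the edge added in the hook is what places each duplicate beneath its original. The following standalone sketch illustrates this under an assumption taken from the output above: a terminal-style graph lists the start node of an edge beneath its end node. `ToyGraph` and `DuplicatingGraph` are hypothetical stand-ins that take plain strings; they are not the TVM TermGraph and do not use VizNode or VizEdge.

```python
from collections import defaultdict
from typing import Dict, List


class ToyGraph:
    """Hypothetical stand-in for a graph class; this is not the TVM TermGraph."""

    def __init__(self) -> None:
        self.labels: Dict[str, str] = {}
        # Maps the end of each edge to the start nodes attached to it.
        self.starts_of: Dict[str, List[str]] = defaultdict(list)

    def node(self, node_id: str, label: str) -> None:
        self.labels[node_id] = label

    def edge(self, start: str, end: str) -> None:
        self.starts_of[end].append(start)

    def children(self, node_id: str) -> List[str]:
        """Labels that a tree printer would list beneath node_id."""
        return [self.labels[start] for start in self.starts_of.get(node_id, [])]


class DuplicatingGraph(ToyGraph):
    def node(self, node_id: str, label: str) -> None:
        # Add the node first, then attach a duplicate if it is an AwesomeVar.
        super().node(node_id, label)
        if label.startswith("AwesomeVar"):
            duplicated_id = f"duplicated_{node_id}"
            super().node(duplicated_id, "double AwesomeVar")
            # start = duplicate, end = original: the duplicate hangs below the original.
            super().edge(duplicated_id, node_id)


graph = DuplicatingGraph()
graph.node("call", "Call")
graph.node("data", "AwesomeVar name_hint data")
graph.edge("data", "call")
print(graph.children("call"))
print(graph.children("data"))
```

Swapping the two arguments of the edge in the hook would instead make the original node appear beneath its duplicate in this toy model.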

Summary

This tutorial demonstrates how to use and customize Relay Visualizer. The class tvm.contrib.relay_viz.RelayVisualizer is composed of interfaces defined in interface.py.

The visualizer is aimed at quick look-then-fix iterations. Its constructor arguments are intended to be simple, while customization is still possible through a set of interface classes.
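
As a rough mental model of that composition, the sketch below wires a parser, a graph, and a plotter together in plain Python. Every class and function here (`ToyParser`, `ToyGraph`, `ToyPlotter`, `collect`, `visualize`) is a hypothetical stand-in, and the toy expressions are plain dictionaries. It only assumes the division of labor shown in this tutorial: the plotter creates one graph per function, the parser turns each expression into a node plus edges, and the graph records them. It is not the TVM implementation.

```python
from typing import Dict, List, Tuple

Node = Tuple[str, str]  # (node id, label)
Edge = Tuple[str, str]  # (start node id, end node id)


class ToyParser:
    """Hypothetical parser: maps one expression to a node and its input edges."""

    def get_node_edges(self, expr: dict, expr_to_id: Dict[int, str]) -> Tuple[Node, List[Edge]]:
        node_id = expr_to_id[id(expr)]
        # Each argument feeds into this expression: edge = (argument id, this id).
        edges = [(expr_to_id[id(arg)], node_id) for arg in expr.get("args", [])]
        return (node_id, expr["op"]), edges


class ToyGraph:
    """Hypothetical graph: only records what the parser produced."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.nodes: List[Node] = []
        self.edges: List[Edge] = []


class ToyPlotter:
    """Hypothetical plotter: owns one graph per function and renders them all."""

    def __init__(self) -> None:
        self.graphs: Dict[str, ToyGraph] = {}

    def create_graph(self, name: str) -> ToyGraph:
        self.graphs[name] = ToyGraph(name)
        return self.graphs[name]

    def render(self) -> str:
        return "\n".join(
            f"@{g.name}: {len(g.nodes)} nodes, {len(g.edges)} edges"
            for g in self.graphs.values()
        )


def collect(expr: dict, expr_to_id: Dict[int, str], order: List[dict]) -> None:
    """Assign an id to every expression, visiting arguments first."""
    for arg in expr.get("args", []):
        collect(arg, expr_to_id, order)
    if id(expr) not in expr_to_id:
        expr_to_id[id(expr)] = str(len(expr_to_id))
        order.append(expr)


def visualize(functions: Dict[str, dict], plotter: ToyPlotter, parser: ToyParser) -> str:
    for name, body in functions.items():
        graph = plotter.create_graph(name)
        expr_to_id: Dict[int, str] = {}
        order: List[dict] = []
        collect(body, expr_to_id, order)
        for expr in order:
            node, edges = parser.get_node_edges(expr, expr_to_id)
            graph.nodes.append(node)
            graph.edges.extend(edges)
    return plotter.render()


data, bias = {"op": "Var data"}, {"op": "Var bias"}
add_body = {"op": "add", "args": [data, bias]}
summary = visualize({"AddFunc": add_body}, ToyPlotter(), ToyParser())
print(summary)
```

Because the three roles only meet inside `visualize`, replacing any one of them (as the tutorial does with its custom parser and plotter) leaves the other two untouched.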
