How to Use TVM Pass Instrument
Author: Chi-Wei Wang
As more and more passes are implemented, it becomes useful to instrument pass execution, analyze per-pass effects, and observe various events.
We can instrument passes by providing a list of tvm.ir.instrument.PassInstrument instances to tvm.transform.PassContext. TVM provides a built-in pass instrument for collecting timing information (tvm.ir.instrument.PassTimingInstrument), and custom instruments can be defined with the tvm.instrument.pass_instrument() decorator.
This tutorial demonstrates how developers can use PassContext to instrument passes. Please also refer to the Pass Infrastructure documentation.
import tvm
import tvm.relay as relay
from tvm.relay.testing import resnet
from tvm.contrib.download import download_testdata
from tvm.relay.build_module import bind_params_by_name
from tvm.ir.instrument import (
    PassTimingInstrument,
    pass_instrument,
)
Create An Example Relay Program
We use the pre-defined ResNet-18 network in Relay.
batch_size = 1
num_of_image_class = 1000
image_shape = (3, 224, 224)
output_shape = (batch_size, num_of_image_class)
# Reuse batch_size instead of repeating the literal value.
relay_mod, relay_params = resnet.get_workload(
    num_layers=18, batch_size=batch_size, image_shape=image_shape
)
print("Printing the IR module...")
print(relay_mod.astext(show_meta_data=False))
Create PassContext With Instruments
To run all passes with an instrument, pass it via the instruments argument to the PassContext constructor. The built-in PassTimingInstrument is used here to profile the execution time of each pass.
timing_inst = PassTimingInstrument()
with tvm.transform.PassContext(instruments=[timing_inst]):
    relay_mod = relay.transform.InferType()(relay_mod)
    relay_mod = relay.transform.FoldScaleAxis()(relay_mod)
    # before exiting the context, get profile results.
    profiles = timing_inst.render()
print("Printing results of timing profile...")
print(profiles)
Use Current PassContext With Instruments
One can also use the current PassContext and register PassInstrument instances with the override_instruments method.
Note that override_instruments first calls the exit_pass_ctx method of any instruments that are already registered. It then switches to the new instruments and calls their enter_pass_ctx method.
Refer to the following sections and tvm.instrument.pass_instrument() for these methods.
cur_pass_ctx = tvm.transform.PassContext.current()
cur_pass_ctx.override_instruments([timing_inst])
relay_mod = relay.transform.InferType()(relay_mod)
relay_mod = relay.transform.FoldScaleAxis()(relay_mod)
profiles = timing_inst.render()
print("Printing results of timing profile...")
print(profiles)
Register an empty list to clear the existing instruments.
Note that this calls exit_pass_ctx of PassTimingInstrument, which clears the collected profiles, so nothing is printed.
cur_pass_ctx.override_instruments([])
# Uncomment the call to .render() to see a warning like:
# Warning: no passes have been profiled, did you enable pass profiling?
# profiles = timing_inst.render()
Create Customized Instrument Class
A customized instrument class can be created using the tvm.instrument.pass_instrument() decorator.
Let’s create an instrument class that calculates how each pass changes the number of occurrences of each operator. We can look at op.name to find the name of each operator, and we count operators before and after each pass to calculate the difference.
@pass_instrument
class RelayCallNodeDiffer:
    def __init__(self):
        self._op_diff = []
        # Passes can be nested.
        # Use stack to make sure we get correct before/after pairs.
        self._op_cnt_before_stack = []

    def enter_pass_ctx(self):
        self._op_diff = []
        self._op_cnt_before_stack = []

    def exit_pass_ctx(self):
        assert len(self._op_cnt_before_stack) == 0, "The stack is not empty. Something wrong."

    def run_before_pass(self, mod, info):
        self._op_cnt_before_stack.append((info.name, self._count_nodes(mod)))

    def run_after_pass(self, mod, info):
        # Pop out the latest recorded pass.
        name_before, op_to_cnt_before = self._op_cnt_before_stack.pop()
        assert name_before == info.name, "name_before: {}, info.name: {} doesn't match".format(
            name_before, info.name
        )
        cur_depth = len(self._op_cnt_before_stack)
        op_to_cnt_after = self._count_nodes(mod)
        op_diff = self._diff(op_to_cnt_after, op_to_cnt_before)
        # only record passes causing differences.
        if op_diff:
            self._op_diff.append((cur_depth, info.name, op_diff))

    def get_pass_to_op_diff(self):
        """
        return [
          (depth, pass_name, {op_name: diff_num, ...}), ...
        ]
        """
        return self._op_diff

    @staticmethod
    def _count_nodes(mod):
        """Count the number of occurrences of each operator in the module"""
        ret = {}

        def visit(node):
            if isinstance(node, relay.expr.Call):
                if hasattr(node.op, "name"):
                    op_name = node.op.name
                else:
                    # Some CallNode may not have 'name' such as relay.Function
                    return
                ret[op_name] = ret.get(op_name, 0) + 1

        relay.analysis.post_order_visit(mod["main"], visit)
        return ret

    @staticmethod
    def _diff(d_after, d_before):
        """Calculate the difference of two dictionary along their keys.
        The result is values in d_after minus values in d_before.
        """
        ret = {}
        key_after, key_before = set(d_after), set(d_before)
        for k in key_before & key_after:
            tmp = d_after[k] - d_before[k]
            if tmp:
                ret[k] = tmp
        for k in key_after - key_before:
            ret[k] = d_after[k]
        for k in key_before - key_after:
            ret[k] = -d_before[k]
        return ret
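To make the before/after bookkeeping concrete, here is a minimal standalone sketch of the same subtraction that _diff performs, written without TVM. The function name diff_counts and the operator counts are hypothetical and chosen only for illustration.

```python
def diff_counts(after, before):
    """Return after - before per key, dropping keys whose count did not change."""
    diff = {}
    for op in set(after) | set(before):
        delta = after.get(op, 0) - before.get(op, 0)
        if delta:
            diff[op] = delta
    return diff


# Hypothetical operator counts before and after some pass.
before = {"nn.conv2d": 20, "add": 8, "nn.batch_norm": 5}
after = {"nn.conv2d": 20, "add": 6, "layout_transform": 4}

result = diff_counts(after, before)
for op_name in sorted(result):
    print(op_name, result[op_name])
```

Operators that only appear after the pass get a positive entry, operators that disappear get a negative one, and unchanged operators such as nn.conv2d are omitted.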
Apply Passes and Multiple Instrument Classes
We can use multiple instrument classes in a PassContext. Note, however, that instrument methods are executed sequentially, in the order of the instruments argument. An instrument such as PassTimingInstrument therefore inevitably includes the execution time of other instruments' methods in its final profile result.
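The ordering effect can be illustrated with a small pure-Python model. This is only a sketch, not TVM's implementation: FakeClock, CostlyInstrument, TimerInstrument and run_pass are hypothetical names, and the model assumes callbacks run in list order both before and after a pass.

```python
class FakeClock:
    """Deterministic clock so the example does not depend on real timing."""

    def __init__(self):
        self.now = 0

    def advance(self, ticks):
        self.now += ticks


class CostlyInstrument:
    """Hypothetical instrument whose callbacks each take `cost` ticks."""

    def __init__(self, clock, cost):
        self.clock = clock
        self.cost = cost

    def run_before_pass(self, name):
        self.clock.advance(self.cost)

    def run_after_pass(self, name):
        self.clock.advance(self.cost)


class TimerInstrument:
    """Hypothetical timer: records ticks between its own before/after callbacks."""

    def __init__(self, clock):
        self.clock = clock
        self.elapsed = {}
        self._start = None

    def run_before_pass(self, name):
        self._start = self.clock.now

    def run_after_pass(self, name):
        self.elapsed[name] = self.clock.now - self._start


def run_pass(name, pass_cost, instruments, clock):
    # Assumption: callbacks run in list order both before and after the pass.
    for inst in instruments:
        inst.run_before_pass(name)
    clock.advance(pass_cost)  # the pass itself
    for inst in instruments:
        inst.run_after_pass(name)


def measure(timer_first):
    clock = FakeClock()
    timer = TimerInstrument(clock)
    costly = CostlyInstrument(clock, cost=5)
    instruments = [timer, costly] if timer_first else [costly, timer]
    run_pass("FoldConstant", 10, instruments, clock)
    return timer.elapsed["FoldConstant"]


elapsed_timer_last = measure(timer_first=False)
elapsed_timer_first = measure(timer_first=True)
print(elapsed_timer_last, elapsed_timer_first)
```

In this model the pass itself costs 10 ticks, yet the timer measures more in both orders: placed last it absorbs the other instrument's run_after_pass, placed first it absorbs the other instrument's run_before_pass.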
call_node_inst = RelayCallNodeDiffer()
desired_layouts = {
    "nn.conv2d": ["NHWC", "HWIO"],
}
pass_seq = tvm.transform.Sequential(
    [
        relay.transform.FoldConstant(),
        relay.transform.ConvertLayout(desired_layouts),
        relay.transform.FoldConstant(),
    ]
)
relay_mod["main"] = bind_params_by_name(relay_mod["main"], relay_params)
# timing_inst is put after call_node_inst.
# So the execution time of ``call_node_inst.run_after_pass()`` is also counted.
with tvm.transform.PassContext(opt_level=3, instruments=[call_node_inst, timing_inst]):
    relay_mod = pass_seq(relay_mod)
    profiles = timing_inst.render()
# Uncomment the next line to see timing-profile results.
# print(profiles)
The result of call_node_inst.get_pass_to_op_diff() shows how many CallNode instances of each operator type every pass added or removed.
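One way to read those records is to indent each entry by its nesting depth. The sketch below assumes records shaped like the docstring of get_pass_to_op_diff, that is (depth, pass_name, {op_name: diff_num}). The helper format_op_diff and the sample records are hypothetical, not real output from the run above.

```python
def format_op_diff(records):
    """Render (depth, pass_name, op_diff) records, indented by nesting depth."""
    lines = []
    for depth, pass_name, op_diff in records:
        indent = "  " * depth
        changes = ", ".join(
            "{}: {:+d}".format(op, delta) for op, delta in sorted(op_diff.items())
        )
        lines.append("{}{}: {}".format(indent, pass_name, changes))
    return lines


# Hypothetical records. A nested pass finishes before its parent,
# so it is recorded first.
sample_records = [
    (1, "FoldConstant", {"add": -2}),
    (0, "sequential", {"add": -2, "layout_transform": 4}),
]

formatted = format_op_diff(sample_records)
for line in formatted:
    print(line)
```

With a real instrument, the list returned by call_node_inst.get_pass_to_op_diff() can be passed in place of sample_records.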
Exception Handling
Let’s see what happens if an exception occurs in a method of a PassInstrument.
Define PassInstrument classes that raise exceptions when entering or exiting a PassContext:
class PassExampleBase:
    def __init__(self, name):
        self._name = name

    def enter_pass_ctx(self):
        print(self._name, "enter_pass_ctx")

    def exit_pass_ctx(self):
        print(self._name, "exit_pass_ctx")

    def should_run(self, mod, info):
        print(self._name, "should_run")
        return True

    def run_before_pass(self, mod, pass_info):
        print(self._name, "run_before_pass")

    def run_after_pass(self, mod, pass_info):
        print(self._name, "run_after_pass")


@pass_instrument
class PassFine(PassExampleBase):
    pass


@pass_instrument
class PassBadEnterCtx(PassExampleBase):
    def enter_pass_ctx(self):
        print(self._name, "bad enter_pass_ctx!!!")
        raise ValueError("{} bad enter_pass_ctx".format(self._name))


@pass_instrument
class PassBadExitCtx(PassExampleBase):
    def exit_pass_ctx(self):
        print(self._name, "bad exit_pass_ctx!!!")
        raise ValueError("{} bad exit_pass_ctx".format(self._name))
If an exception occurs in enter_pass_ctx, PassContext disables pass instrumentation and runs the exit_pass_ctx of each PassInstrument that successfully finished enter_pass_ctx.
In the following example, we can see that the exit_pass_ctx of PassFine_0 is executed after the exception.
demo_ctx = tvm.transform.PassContext(
    instruments=[
        PassFine("PassFine_0"),
        PassBadEnterCtx("PassBadEnterCtx"),
        PassFine("PassFine_1"),
    ]
)
try:
    with demo_ctx:
        relay_mod = relay.transform.InferType()(relay_mod)
except ValueError as ex:
    print("Catching", str(ex).split("\n")[-1])
Exceptions in PassInstrument instances cause all instruments of the current PassContext to be cleared, so nothing is printed when override_instruments is called.
demo_ctx.override_instruments([]) # no PassFine_0 exit_pass_ctx printed....etc
If an exception occurs in exit_pass_ctx, pass instrumentation is disabled and the exception is propagated. That means PassInstrument instances registered after the one throwing the exception do not execute exit_pass_ctx.
demo_ctx = tvm.transform.PassContext(
    instruments=[
        PassFine("PassFine_0"),
        PassBadExitCtx("PassBadExitCtx"),
        PassFine("PassFine_1"),
    ]
)
try:
    # PassFine_1 executes enter_pass_ctx, but not exit_pass_ctx.
    with demo_ctx:
        relay_mod = relay.transform.InferType()(relay_mod)
except ValueError as ex:
    print("Catching", str(ex).split("\n")[-1])
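The two failure modes described so far can be summarized in a small pure-Python model. This is a simplified sketch of the behaviour stated in this tutorial, not TVM's actual implementation. ModelInstrument, enter_all and exit_all are hypothetical names.

```python
class ModelInstrument:
    """Hypothetical instrument that logs its calls and can fail on demand."""

    def __init__(self, name, log, fail_on=None):
        self.name = name
        self.log = log
        self.fail_on = fail_on  # "enter", "exit", or None

    def enter_pass_ctx(self):
        if self.fail_on == "enter":
            raise ValueError("{} bad enter_pass_ctx".format(self.name))
        self.log.append("{} enter_pass_ctx".format(self.name))

    def exit_pass_ctx(self):
        if self.fail_on == "exit":
            raise ValueError("{} bad exit_pass_ctx".format(self.name))
        self.log.append("{} exit_pass_ctx".format(self.name))


def enter_all(instruments):
    """On failure: disable instrumentation, exit those that entered, re-raise."""
    entered = []
    try:
        for inst in instruments:
            inst.enter_pass_ctx()
            entered.append(inst)
    except ValueError:
        instruments.clear()  # instrumentation is disabled
        for inst in entered:
            inst.exit_pass_ctx()
        raise


def exit_all(instruments):
    """On failure: disable instrumentation and re-raise; later exits are skipped."""
    try:
        for inst in instruments:
            inst.exit_pass_ctx()
    except ValueError:
        instruments.clear()
        raise


def make_case(log, bad_name, fail_on):
    return [
        ModelInstrument("PassFine_0", log),
        ModelInstrument(bad_name, log, fail_on=fail_on),
        ModelInstrument("PassFine_1", log),
    ]


# Case 1: the second instrument fails while entering.
enter_log = []
enter_case = make_case(enter_log, "PassBadEnterCtx", "enter")
try:
    enter_all(enter_case)
except ValueError as ex:
    enter_log.append("caught: {}".format(ex))

# Case 2: the second instrument fails while exiting.
exit_log = []
exit_case = make_case(exit_log, "PassBadExitCtx", "exit")
enter_all(exit_case)
try:
    exit_all(exit_case)
except ValueError as ex:
    exit_log.append("caught: {}".format(ex))

print(enter_log)
print(exit_log)
```

In the first case only PassFine_0 gets an exit call, because it is the only instrument that finished entering. In the second case PassFine_1 enters but never exits, and in both cases the instrument list ends up empty.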
Exceptions that occur in should_run, run_before_pass, or run_after_pass are not handled explicitly; we rely on the context manager (the with syntax) to exit the PassContext safely.
We use run_before_pass as an example:
@pass_instrument
class PassBadRunBefore(PassExampleBase):
    def run_before_pass(self, mod, pass_info):
        print(self._name, "bad run_before_pass!!!")
        raise ValueError("{} bad run_before_pass".format(self._name))


demo_ctx = tvm.transform.PassContext(
    instruments=[
        PassFine("PassFine_0"),
        PassBadRunBefore("PassBadRunBefore"),
        PassFine("PassFine_1"),
    ]
)
try:
    # All exit_pass_ctx are called.
    with demo_ctx:
        relay_mod = relay.transform.InferType()(relay_mod)
except ValueError as ex:
    print("Catching", str(ex).split("\n")[-1])
Also note that pass instrumentation is not disabled in this case. So if we call override_instruments, the exit_pass_ctx of the previously registered PassInstrument instances is called.
demo_ctx.override_instruments([])
If we don’t wrap pass execution in the with syntax, exit_pass_ctx is not called. Let’s try this with the current PassContext:
cur_pass_ctx = tvm.transform.PassContext.current()
cur_pass_ctx.override_instruments(
    [
        PassFine("PassFine_0"),
        PassBadRunBefore("PassBadRunBefore"),
        PassFine("PassFine_1"),
    ]
)
Then call passes. As expected, exit_pass_ctx is not executed after the exception.
try:
    # No ``exit_pass_ctx`` got executed.
    relay_mod = relay.transform.InferType()(relay_mod)
except ValueError as ex:
    print("Catching", str(ex).split("\n")[-1])
Clear instruments.
cur_pass_ctx.override_instruments([])
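The difference between running passes inside and outside a with block comes down to ordinary Python context-manager semantics. The sketch below illustrates this with a hypothetical ModelContext class standing in for a pass context. It does not use TVM and only models the enter/exit calls.

```python
class ModelContext:
    """Hypothetical stand-in for a pass context; not TVM's implementation."""

    def __init__(self, log):
        self.log = log

    def __enter__(self):
        self.log.append("enter_pass_ctx")
        return self

    def __exit__(self, exc_type, exc, tb):
        # Runs even when the body of the with block raised.
        self.log.append("exit_pass_ctx")
        return False  # do not swallow the exception


def failing_pass():
    raise ValueError("bad run_before_pass")


# Inside a with block: the exit hook runs before the exception propagates.
with_log = []
try:
    with ModelContext(with_log):
        failing_pass()
except ValueError:
    with_log.append("caught")

# Without a with block: nothing triggers the exit hook.
no_with_log = []
ctx = ModelContext(no_with_log)
ctx.__enter__()  # models instruments being entered outside a with block
try:
    failing_pass()
except ValueError:
    no_with_log.append("caught")

print(with_log)
print(no_with_log)
```

This is why the tutorial clears the instruments explicitly at the end: outside a with block, cleanup has to be triggered by hand.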