Pass Infrastructure
Both Relay and TVM IR contain a series of optimization passes that improve model performance metrics such as mean inference latency, memory footprint, or power consumption on specific devices. There is a suite of standard optimizations as well as machine-learning-specific ones, including constant folding, dead code elimination, operator layout alteration, operator fusion, buffer handling, and loop transformation. Each of these passes is structured as an IR-to-IR transformation that uses analysis results collected during and/or before traversal.
However, as TVM evolves quickly, the need for a more systematic and efficient way to manage these passes is becoming apparent. In addition, a generic framework that manages the passes across different layers of the TVM stack (e.g. Relay and tir) paves the way for developers to quickly prototype and plug the implemented passes into the system.
This document describes the design of such an infrastructure, which borrows from the way production compilers manage optimization passes and from the style that modern deep learning frameworks adopt to build up layers.
For example, many existing production compilers, such as GCC and LLVM, employ pass managers to effectively manage the execution of passes. Initially, managing passes is straightforward because the number of passes is small, but mature compilers contain hundreds of individual passes. External users often want their custom passes to be scheduled correctly without having to modify a single handcrafted pass order.
Similarly, modern deep learning frameworks, such as PyTorch and MXNet Gluon, enable a pass-style layer construction scheme through Sequential and Block, respectively. With such constructs, these frameworks can conveniently add modules/layers to their containers and build up neural networks easily.
The design of the Relay pass infra is largely inspired by the hierarchical pass manager used in LLVM and the block-style containers used in the popular deep learning frameworks. The major goals of the pass infra include:
enabling better programmatic orchestration of optimizations. This allows users to flexibly customize and build their own optimization pipelines.
providing a user-friendly way to debug optimization passes.
relieving developers from manually resolving the dependencies between passes.
simplifying the implementation of new passes for developers. For example, users can implement a pass in Python and let the pass infra manage its execution.
The Design
We focus on ease of extension for users, making it possible for users to quickly add new passes without loss of backward compatibility. The design contains both the backend and the frontend. The former implements the main logic of the pass infra. The latter provides simple APIs for users to interact with, i.e., allowing users to quickly create their own optimization pipelines.
C++ Backend
We provide a PassInfo object to contain the basic information needed by a pass. name is the pass name, opt_level indicates at which optimization level the pass will be enabled, and required represents the passes that are required to execute a certain pass (see include/tvm/ir/transform.h for more details). For example, during registration of a pass (covered later), pass developers can specify the name of the pass, the optimization level it will be performed at, and/or the passes that are required.
opt_level could be used to help the pass infra identify whether a certain pass needs to be executed when running under a user-provided optimization level. The required field can be used by the pass infra to resolve pass dependencies.
class PassInfoNode : public Object {
String name;
int opt_level;
Array<String> required;
};
PassContext
PassContext carries useful information for an optimization pass. For example, it contains the error reporting system so optimization authors can provide diagnostics about why an optimization fails. PassContext is also designed to replace the old BuildConfig, which was used to help users configure the compilation options, including the optimization level and required/disabled passes. For instance, we may have a configuration that performs all passes at opt_level=3 with some disabled passes using disabled_pass=xx provided by PassContext. Now we could glob all passes at opt_level=3 and exclude those in the disabled pass list. PassContext also provides a way to instrument all passes. See section Pass Instrument.
This class is designed for users to conveniently write the Python with syntax to perform optimizations under a certain configuration. In addition, users can obtain the context that is available within a certain program scope in a thread-safe way through PassContext::Current(), since a thread-local store PassContextThreadLocalStore is used to hold the created pass context objects. Examples will be provided later to show how we can use both the C++ and Python APIs to create a compilation pipeline using pass context.
class PassContextNode : public Object {
public:
int opt_level{2};
Array<String> required_pass;
Array<String> disabled_pass;
mutable Optional<DiagnosticContext> diag_ctx;
Map<String, ObjectRef> config;
Array<instrument::PassInstrument> instruments;
};
class PassContext : public ObjectRef {
public:
TVM_DLL static PassContext Create();
TVM_DLL static PassContext Current();
TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
/* Other fields are omitted. */
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python `with` like syntax.
friend class tvm::With<PassContext>;
};
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
/*! \brief The current pass context. */
std::stack<PassContext> context_stack;
PassContextThreadLocalEntry() {
default_context = PassContext(make_object<PassContextNode>());
}
};
/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
PassContextThreadLocalStore;
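The scoping behavior described above can be mimicked in plain Python. The following is a minimal sketch, not TVM code: ToyPassContext and its fields are hypothetical names, and threading.local stands in for dmlc::ThreadLocalStore. Each thread gets its own default context and its own stack, so entering a context with the with statement in one thread does not affect what current() returns in another.

```python
import threading


class ToyPassContext:
    """Hypothetical stand-in for PassContext (not TVM's API)."""

    # Per-thread storage, playing the role of PassContextThreadLocalStore.
    _store = threading.local()

    def __init__(self, opt_level=2, disabled_pass=()):
        self.opt_level = opt_level
        self.disabled_pass = list(disabled_pass)

    @classmethod
    def _stack(cls):
        # Lazily create this thread's default context and scope stack.
        if not hasattr(cls._store, "stack"):
            cls._store.stack = []
            cls._store.default = cls()
        return cls._store.stack

    @classmethod
    def current(cls):
        # Innermost active context, or this thread's default context.
        stack = cls._stack()
        return stack[-1] if stack else cls._store.default

    def __enter__(self):
        self._stack().append(self)  # like EnterWithScope()
        return self

    def __exit__(self, ptype, value, trace):
        self._stack().pop()  # like ExitWithScope()
```

For example, inside a with ToyPassContext(opt_level=3): block, ToyPassContext.current().opt_level is 3; after the block it falls back to the default of 2.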
Pass Constructs
The pass infra is designed in a hierarchical manner, and it can work at different granularities of Relay/tir programs. A pure virtual class PassNode is introduced to serve as the base of the different optimization passes. This class contains several virtual methods that must be implemented by the subclasses at the level of modules, functions, or sequences of passes.
class PassNode : public Object {
  virtual PassInfo Info() const = 0;
  virtual IRModule operator()(const IRModule& mod,
                              const PassContext& pass_ctx) const = 0;
};
The functor shows how a pass must be realized, i.e., it always works on an IRModule under a certain context. All passes are designed in an IRModule-to-IRModule manner. Therefore, optimizations governed by the pass infra will always update the whole module.
Several subclasses have been created to implement different types of optimization passes, e.g., function-level passes, module-level passes, and sequential passes. Each subclass can itself act as a pass manager. For instance, it could collect the required passes and execute them, or build a dependency graph based on the given metadata. Their full definitions can be found in src/relay/ir/transform.cc and src/ir/transform.cc.
Module-Level Passes
Module-level passes are geared mainly toward global and inter-procedural optimizations (IPO), similar to the module pass used in LLVM. Typical passes in Relay that need the global picture of a module, such as A-normal form conversion and lambda lifting, fall into this set. At this level, users can even add and/or delete functions in a module. Note that all passes
class ModulePassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
// Other members/methods are omitted
};
pass_info maintains the information needed by a module-level pass. pass_func implements the actual optimization. For example, we may need to perform dead code elimination on the module. We could implement the algorithm in the pass_func and let it run on a module. It will then remove the dead code, including the unused functions in the module. Note that this field is designed as a packed function, which enables the implementation of the optimization in both C++ and Python.
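As a toy illustration of a module-level pass_func (module in, module out, with a context argument), the sketch below removes functions that are unreachable from an entry point. The dictionary-based module, the name remove_unused_functions, and the entry name "main" are assumptions made for illustration; this is not how TVM represents an IRModule. It shows why such a pass needs the whole module: reachability can only be decided globally, and the pass deletes functions.

```python
from typing import Dict, Set

# Hypothetical toy IR: each global function name maps to the set of global
# functions it calls. This is not TVM's IRModule.
ToyModule = Dict[str, Set[str]]


def remove_unused_functions(mod: ToyModule, ctx=None, entry: str = "main") -> ToyModule:
    """Module-level pass: keep only functions reachable from `entry`."""
    reachable: Set[str] = set()
    worklist = [entry]
    while worklist:
        name = worklist.pop()
        if name in reachable or name not in mod:
            continue
        reachable.add(name)
        worklist.extend(mod[name])
    # Build a new module instead of mutating the input (module -> module).
    return {name: set(callees) for name, callees in mod.items() if name in reachable}
```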
Function-Level Passes
Function-level passes are used to implement various intra-function optimizations for a given Relay/tir module. Such a pass fetches one function at a time from the function list of a module for optimization and yields a rewritten Relay Function or tir PrimFunc. Most passes fall into this category, such as common subexpression elimination and inference simplification in Relay, as well as vectorization and storage flattening in tir.
Note that the scope of passes at this level is either a Relay function or a tir primitive function. Therefore, we cannot add or delete a function through these passes as they are not aware of the global information.
class FunctionPassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
bool SkipFunction(const Function& func) const;
// Other members/methods are omitted...
};
pass_info is identical to what we just described for the module pass. pass_func takes a function for optimization; it also needs a module, as we may use it for reporting errors. A function can be annotated with “SkipOptimization” so that it is ignored during optimization.
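The per-function dispatch, including the SkipOptimization check, can be sketched as follows. ToyFunction and toy_function_pass are hypothetical names, and modeling the annotation as an entry in an attrs dictionary is an assumption; the sketch only mirrors the structure of FunctionPassNode: it loops over the functions of a module, keeps annotated ones untouched, and rewrites the rest one at a time.

```python
from dataclasses import dataclass, field, replace
from typing import Callable, Dict


@dataclass(frozen=True)
class ToyFunction:
    """Hypothetical stand-in for a Relay Function or tir PrimFunc."""
    body: str
    attrs: Dict[str, int] = field(default_factory=dict)


def toy_function_pass(rewrite: Callable) -> Callable:
    """Lift a per-function rewrite(func, mod, ctx) into a module -> module pass."""

    def run(mod: Dict[str, ToyFunction], ctx=None) -> Dict[str, ToyFunction]:
        new_mod = {}
        for name, func in mod.items():
            if func.attrs.get("SkipOptimization", 0):
                # Mirrors SkipFunction(): annotated functions are kept as-is.
                new_mod[name] = func
            else:
                # The rewrite sees one function at a time (plus the module,
                # e.g. for error reporting) and cannot add or delete functions.
                new_mod[name] = rewrite(func, mod, ctx)
        return new_mod

    return run


# Example rewrite: upper-case every function body.
upper = toy_function_pass(lambda f, mod, ctx: replace(f, body=f.body.upper()))
```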
Sequential Passes
SequentialPass is similar to PyTorch nn.Sequential: it contains a host of passes for execution.
class SequentialPassNode : PassNode {
PassInfo pass_info;
// Passes need to be executed.
Array<Pass> passes;
bool PassEnabled(const PassInfo& info) const;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};
Only a few passes in Relay are currently put in this group. For example, FoldScaleAxis needs to dispatch ForwardFoldScaleAxis and BackwardFoldScaleAxis internally, and BackwardFoldScaleAxis is recommended to run first. This pass, hence, is an ideal candidate for SequentialPass.
The following code shows how individual passes in a sequential pass are invoked. Essentially, we sequentially execute each pass in a sequential pass using the order that they were appended to the pass list.
Module SequentialPassNode::operator()(const Module& module,
                                      const PassContext& pass_ctx) const {
  Module mod = module;
  for (const Pass& pass : passes) {
    ICHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& pass_info = pass->Info();
    // Skip passes that are disabled or not enabled at the configured level.
    if (!PassEnabled(pass_info)) continue;
    // Run the prerequisites first; `required` holds pass names (Array<String>).
    for (const String& name : pass_info->required) {
      mod = GetPass(name)(mod, pass_ctx);
    }
    mod = pass(mod, pass_ctx);
  }
  return mod;
}
Upon the invocation of a pass, we first check whether this pass is enabled. This is done by first checking whether the pass is explicitly disabled by the user, followed by inspecting whether it is specified as a required pass by the user. If it is still undetermined whether this pass is enabled, its opt_level will be checked.
This pass will be enabled, and therefore executed, only when its optimization level is not greater than the optimization level configured in the pass context.
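The enable check and the sequential loop can be modeled in a few lines of Python. This is a hedged sketch under simplifying assumptions: PassInfo, ToyPass, ToyContext, pass_enabled, and run_sequential are hypothetical names, passes are plain module-to-module callables, and the registry is a dictionary keyed by pass name. The ordering of the checks follows the description above: an explicit disable wins, then an explicit requirement, and only then the opt_level comparison.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class PassInfo:
    name: str
    opt_level: int
    required: List[str] = field(default_factory=list)


@dataclass
class ToyPass:
    info: PassInfo
    func: Callable  # module -> module (the context argument is omitted here)


@dataclass
class ToyContext:
    opt_level: int = 2
    required_pass: List[str] = field(default_factory=list)
    disabled_pass: List[str] = field(default_factory=list)


def pass_enabled(info: PassInfo, ctx: ToyContext) -> bool:
    if info.name in ctx.disabled_pass:   # an explicit disable wins
        return False
    if info.name in ctx.required_pass:   # an explicit requirement comes next
        return True
    # Otherwise the pass's opt_level must not exceed the configured level.
    return info.opt_level <= ctx.opt_level


def run_sequential(passes: List[ToyPass], registry: Dict[str, ToyPass], mod, ctx: ToyContext):
    for p in passes:
        if not pass_enabled(p.info, ctx):
            continue
        # Prerequisites are looked up by name and run first, as in the C++ loop.
        for dep in p.info.required:
            mod = registry[dep].func(mod)
        mod = p.func(mod)
    return mod
```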
To execute the pass, we first need to retrieve the registered pass from the TVM packed function registry using the pass name. This is possible because every pass is registered with an API endpoint, as we will show later.
Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry;
std::string fpass_name = "relay._transform." + pass_name;
const auto* f = Registry::Get(fpass_name);
ICHECK(f != nullptr) << "Cannot find " << fpass_name
                     << " to create the pass " << pass_name;
return (*f)();
}
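The name-based lookup can be sketched in Python with a plain dictionary standing in for the TVM packed function registry. register_global, get_pass, and the _REGISTRY dictionary are hypothetical names invented for this sketch; in TVM the registration happens through TVM_REGISTER_GLOBAL, as the Pass Registration section shows. Each lookup calls the registered factory, so every call returns a freshly constructed pass object.

```python
from typing import Callable, Dict

# Hypothetical registry: global name -> zero-argument pass factory.
_REGISTRY: Dict[str, Callable[[], Callable]] = {}


def register_global(name: str):
    """Decorator that records a pass factory under a global name."""
    def decorator(factory: Callable[[], Callable]):
        _REGISTRY[name] = factory
        return factory
    return decorator


def get_pass(pass_name: str) -> Callable:
    """Look up "relay._transform.<pass_name>" and construct the pass."""
    fpass_name = "relay._transform." + pass_name
    factory = _REGISTRY.get(fpass_name)
    if factory is None:
        raise KeyError(f"Cannot find {fpass_name} to create the pass {pass_name}")
    return factory()


@register_global("relay._transform.FoldConstant")
def fold_constant_factory():
    # Placeholder pass body: returns the module unchanged.
    return lambda mod: mod
```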
Some helper functions are provided to create each type of the aforementioned passes. These helpers are also exposed to the Python frontend so that users can conveniently use Python APIs to create a specific pass object.
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
Pass Registration
We’ve covered the concept of the different levels of passes and the context used for compilation. Next, we show how easily users can register a pass, using constant folding as an example. This pass has already been implemented to fold constants in a Relay function (found in src/relay/transforms/fold_constant.cc).
An API is provided to perform the Expr-to-Expr transformation.
Expr FoldConstant(const Expr& expr);
In order to register this pass with the pass infra, we first need to decide at which level this pass will be performed. As constant folding happens on individual functions, we should intuitively create a FunctionPass for it through CreateFunctionPass. The pass_func is returned as a packed function that invokes the Expr-to-Expr API on each function in an IRModule. {} indicates that no prerequisite is required for this pass; otherwise, the pass developer has to identify and list them.
Meanwhile, a pass API endpoint is registered with the name relay._transform.FoldConstant. This pass, therefore, becomes an entry in the registry that can be accessed by both C++ (e.g., the GetPass above) and Python when needed.
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
} // namespace transform
To allow other C++ modules to apply this pass, we declare a free function in include/tvm/relay/transform.h as follows:
TVM_DLL Pass FoldConstant();
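To see how an Expr-to-Expr API becomes a function-level pass, here is a self-contained toy in Python. The tuple-based expression language and the names fold_constant_expr and fold_constant_pass are assumptions for illustration only and bear no relation to Relay's data structures; the point is the layering: a recursive Expr-to-Expr folder, wrapped by a function that applies it to every function in a module, just as the packed pass_func applies FoldConstant to each function.

```python
from typing import Dict, Union

# Hypothetical toy expressions: int constants, str variables, and
# ("add" | "mul", lhs, rhs) tuples. This is not Relay's Expr.
Expr = Union[int, str, tuple]


def fold_constant_expr(expr: Expr) -> Expr:
    """Expr -> Expr: evaluate every subtree whose operands are constants."""
    if not isinstance(expr, tuple):
        return expr
    op, lhs, rhs = expr
    lhs, rhs = fold_constant_expr(lhs), fold_constant_expr(rhs)
    if isinstance(lhs, int) and isinstance(rhs, int):
        if op == "add":
            return lhs + rhs
        if op == "mul":
            return lhs * rhs
        raise ValueError(f"unknown operator: {op}")
    return (op, lhs, rhs)


def fold_constant_pass(mod: Dict[str, Expr], ctx=None) -> Dict[str, Expr]:
    """Function-level wrapper: apply the Expr -> Expr API to each function body."""
    return {name: fold_constant_expr(body) for name, body in mod.items()}
```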
Pass Instrument
Pass Instrument is a mechanism to analyze the pass itself. For example, we can use the infrastructure to know how much time and memory a pass requires or how a pass can transform the IR module.
We introduce four instrument points in the life-cycle of PassContext.
TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
InstrumentEnterPassContext is called immediately when entering the scope of the PassContext instance.
InstrumentExitPassContext is called when leaving the scope of PassContext, or when exceptions occur during the execution of passes. This method is also called when instruments is overridden by override_instruments in tvm.transform.PassContext. See Override Instruments in Current PassContext.
InstrumentBeforePass is called before execution.
InstrumentAfterPass is called after execution if the pass should be run. The behavior is like:
if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) {
new_ir_module = run_pass(ir_module, pass_ctx);
pass_ctx.InstrumentAfterPass(new_ir_module, pass_info);
return new_ir_module;
}
The PassInstrument interface allows you to run arbitrary code inside the above four methods. Multiple PassInstrument instances can be registered into a single PassContext. PassInstrument instances are called sequentially in the order of the instruments argument passed to PassContext.
PassInstrument provides the following interfaces:
namespace instrument {
class PassInstrumentNode : public Object {
public:
String name;
virtual void EnterPassContext() const = 0;
virtual void ExitPassContext() const = 0;
virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;
virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;
/* Other fields are omitted. */
};
class PassInstrument : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};
} // namespace instrument
A Python frontend is provided to implement PassInstrument quickly. See Pass Instrument.
Within a PassContext, the call sequence of a PassInstrument instance is like:
with PassContext(instruments=[pi]):  # pi = a PassInstrument implementation.
    pi.EnterPassContext()
    if pi.ShouldRun(Pass1):
        pi.RunBeforePass()
        Pass1()
        pi.RunAfterPass()
    if pi.ShouldRun(Pass2):
        pi.RunBeforePass()
        Pass2()
        pi.RunAfterPass()
    pi.ExitPassContext()
Here is a brief introduction to the relations between PassInstrument interfaces and PassContext methods. See src/ir/transform.cc for more details.
InstrumentEnterPassContext
EnterPassContext() is executed in the order of instruments passed to the PassContext.
When an exception is raised, PassContext disables the pass instrumentation by clearing all registered PassInstrument instances. Then PassContext executes the ExitPassContext() method of each PassInstrument instance that successfully finished EnterPassContext().
For example, if PassInstrument A, B, and C are registered to a PassContext, and A finished EnterPassContext() while B throws an exception, then C is never executed; ExitPassContext() of A is executed.
InstrumentExitPassContext
ExitPassContext() of each PassInstrument instance is executed in the order of instruments passed to the PassContext.
When an exception occurs, instruments is cleared. PassInstrument instances registered after the one throwing the exception do not execute ExitPassContext.
InstrumentBeforePass
ShouldRun is executed if the pass is not listed as a required pass.
RunBeforePass is executed in the order of instruments if the pass is not blocked by ShouldRun.
Note that InstrumentBeforePass returns a boolean indicating whether or not the pass should be run.
When an exception occurs, it is thrown immediately. We rely on the Python context manager to exit PassContext safely (meaning ExitPassContext of each instrument will be run; for C++, please refer to include/tvm/support/with.h).
InstrumentAfterPass
RunAfterPass is executed in the order of instruments passed to the PassContext.
When an exception occurs, it is thrown immediately. We rely on the Python context manager or the With class (include/tvm/support/with.h) to exit PassContext safely.
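The entering and exiting rules above, including the A/B/C example, can be checked with a small Python model. ToyInstrumentedContext and Recorder are hypothetical names, and the model captures only the ordering and cleanup rules described here, not TVM's logging or its C++ exception types.

```python
class ToyInstrumentedContext:
    """Hypothetical model of PassContext's enter/exit instrumentation."""

    def __init__(self, instruments):
        self.instruments = list(instruments)

    def instrument_enter(self):
        entered = []
        try:
            for inst in self.instruments:
                inst.enter_pass_ctx()
                entered.append(inst)
        except Exception:
            # Disable instrumentation, let only the instruments that finished
            # entering run their exit hook, then re-raise the error.
            self.instruments = []
            for inst in entered:
                inst.exit_pass_ctx()
            raise

    def instrument_exit(self):
        try:
            for inst in self.instruments:
                inst.exit_pass_ctx()
        except Exception:
            # Instruments after the failing one are skipped; clear them all.
            self.instruments = []
            raise


class Recorder:
    """Hypothetical instrument that logs its hooks and can fail on entry."""

    def __init__(self, name, log, fail_on_enter=False):
        self.name, self.log, self.fail_on_enter = name, log, fail_on_enter

    def enter_pass_ctx(self):
        if self.fail_on_enter:
            raise RuntimeError(f"{self.name} failed to enter")
        self.log.append("enter " + self.name)

    def exit_pass_ctx(self):
        self.log.append("exit " + self.name)


# The A/B/C example: B raises, so C never runs and only A exits.
log = []
ctx = ToyInstrumentedContext(
    [Recorder("A", log), Recorder("B", log, fail_on_enter=True), Recorder("C", log)]
)
try:
    ctx.instrument_enter()
except RuntimeError:
    pass
# log is now ["enter A", "exit A"] and ctx.instruments is empty.
```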
Built-in Instrument
There are several built-in instruments. Those marked with TODO are not implemented yet.
PassTimingInstrument (see src/ir/instrument.cc)
Profile the execution time of passes.
PrintIRBefore(TODO)
Print the IR module before the pass transforms it.
tvm.transform.PrintIR() can also serve this purpose if we insert it around passes. However, with the PassInstrument, we don’t need to modify the sequence of passes.
PrintAfter(TODO)
Print the IR module after the pass transforms it.
Python Frontend
Only some simple APIs are needed for the frontend side. For example, we can provide users the following APIs to create and execute a pass (full implementation is provided in python/tvm/relay/transform/transform.py and python/tvm/ir/transform.py). The backend receives the information and decides which function it should use to create a Pass object.
PassContext
The Python frontend provides a wrapper for the PassContext to enable the with syntax by overriding __enter__ and __exit__. A current static method is offered for users to get the context that is in use under a certain scope.
@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
    def __enter__(self):
        _transform.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace):
        _transform.ExitPassContext(self)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()
A PassContext is used to configure the compilation options, including the optimization level and required/disabled passes. It can also take a dictionary of configs so that different passes can conveniently fetch the passed data, such as fallback device info and step/depth for loop unrolling. In order to enable fetching the required config, the key must be registered through TVM_REGISTER_PASS_CONFIG_OPTION. For example, the following is used by the loop unrolling pass:
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
Please refer to src/tir/transforms/unroll_loop.cc for more details.
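The register-before-use rule for config keys can be illustrated with a small Python model. register_pass_config_option, ToyConfigContext, and get_config are hypothetical names standing in for TVM_REGISTER_PASS_CONFIG_OPTION and PassContext's config handling; the error types are choices made for this sketch, and the auto_max_step field is used only as an example value.

```python
from typing import Any, Dict, Optional, Type

# Hypothetical registry: config key -> expected value type.
_CONFIG_OPTIONS: Dict[str, Type] = {}


def register_pass_config_option(key: str, value_type: Type) -> None:
    """Declare a config key so that contexts will accept it."""
    _CONFIG_OPTIONS[key] = value_type


class ToyConfigContext:
    """Hypothetical context that only accepts registered config keys."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        config = dict(config or {})
        for key, value in config.items():
            if key not in _CONFIG_OPTIONS:
                raise AttributeError(f"Invalid config option '{key}'")
            expected = _CONFIG_OPTIONS[key]
            if not isinstance(value, expected):
                raise TypeError(f"Config option '{key}' expects {expected.__name__}")
        self.config = config

    def get_config(self, key: str, default: Any = None) -> Any:
        # Passes read their settings by key, falling back to a default.
        return self.config.get(key, default)


# Analogous in spirit to TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", ...).
register_pass_config_option("tir.UnrollLoop", dict)
```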
Pass Objects
Pass is the base class of all pass objects. All methods here are just simple wrappers of functions implemented in the backend. They are defined for users to conveniently interact with the base class in Python. Only __call__ is defined in the pass base class, which makes subclasses callable objects so that they can be invoked easily (e.g., pass_xx(arg)) for execution.
@register_relay_node
class Pass(RelayNode):
    def __call__(self, mod):
        return _transform.RunPass(self, mod)
Some auxiliary APIs are provided to enable easy creation of passes from the Python frontend and to let the pass infra control the execution. For example, module_pass, function_pass, and sequential are provided to users so that they can customize their own pass or pass pipeline.
For all the passes that are implemented in the C++ backend, we provide corresponding Python APIs in python/tvm/ir/transform.py and python/tvm/relay/transform/transform.py, respectively. For instance, constant folding has a Python API like the following:
def FoldConstant():
    return _transform.FoldConstant()
Users can build a pass through decoration like the following:
import tvm
from tvm import relay

@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
    tp = relay.TensorType((10,), "float32")
    x = relay.var("x", tp)
    gv = relay.GlobalVar("abs")
    func = relay.Function([x], relay.abs(x))
    new_mod = tvm.IRModule({gv: func})
    new_mod.update(mod)
    return new_mod

module_pass = transform
# The decorator has turned `transform` into a ModulePass object.
assert isinstance(module_pass, tvm.transform.ModulePass)
assert module_pass.info.opt_level == 2
The transform function here adds an abs function to the input module, but it could be any customized optimization at the module level. After creating this module_pass, users can apply it to any Relay module. For example, we can build an empty module and apply this pass to add an abs function.
mod = tvm.IRModule()
mod = module_pass(mod)
Correspondingly, we also offer such functionality for function_pass. For instance, an example function-level pass could be written as follows:
import tvm
from tvm import relay

@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # Just for demo purposes
        # Transform func to new_func
        return self.new_func

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function with f1
fpass = TestReplaceFunc(f1)
# Build an input module whose main function is f2
input_mod = tvm.IRModule.from_expr(f2)
# Now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)
Alternatively, users can also directly register a pass without using the decorators and then invoke it. For more examples about how to customize your own optimization pipeline and debug Relay and tir passes, please refer to the use pass infra tutorial.
Pass Instrument
One can implement a PassInstrument by using the pass_instrument decorator (python/tvm/ir/instrument.py) on a class implementing the following methods. Note that it is recommended to use the pass_instrument decorator to implement PassInstrument, instead of overriding or subclassing.
enter_pass_ctx
This method is run when entering PassContext.
exit_pass_ctx
This method is run when exiting PassContext.
should_run
This method is run before a pass is executed, returning a boolean indicating whether or not the pass should be run.
run_before_pass
If a pass should be run, this method is run just before pass execution.
run_after_pass
This method is run right after a pass has been executed.
PassInstrument instances can be registered through the instruments argument in tvm.transform.PassContext.
The use pass instrument tutorial provides examples of how to implement PassInstrument with Python APIs.
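Without depending on TVM, the shape of such a class and the order in which its hooks fire can be sketched as below. ToyTimingInstrument and run_with_instruments are hypothetical names; the driver passes a pass name where TVM would pass a PassInfo, and it omits the required-pass exemption from should_run. In the spirit of a timing instrument such as PassTimingInstrument, it takes a timestamp in run_before_pass and records the elapsed time in run_after_pass.

```python
import time


class ToyTimingInstrument:
    """Hypothetical instrument implementing the five hooks listed above."""

    def __init__(self, skip=()):
        self.skip = set(skip)  # pass names this instrument vetoes
        self.timings = {}      # pass name -> elapsed seconds
        self.trace = []        # context-level hooks, in call order
        self._start = 0.0

    def enter_pass_ctx(self):
        self.trace.append("enter_pass_ctx")

    def exit_pass_ctx(self):
        self.trace.append("exit_pass_ctx")

    def should_run(self, mod, info):
        return info not in self.skip

    def run_before_pass(self, mod, info):
        self._start = time.perf_counter()

    def run_after_pass(self, mod, info):
        self.timings[info] = time.perf_counter() - self._start


def run_with_instruments(passes, mod, instruments):
    """Toy driver: run (name, fn) passes under the given instruments."""
    for inst in instruments:
        inst.enter_pass_ctx()
    try:
        for name, fn in passes:
            # Every instrument is asked; any False vetoes the pass.
            votes = [inst.should_run(mod, name) for inst in instruments]
            if not all(votes):
                continue
            for inst in instruments:
                inst.run_before_pass(mod, name)
            mod = fn(mod)
            for inst in instruments:
                inst.run_after_pass(mod, name)
    finally:
        for inst in instruments:
            inst.exit_pass_ctx()
    return mod
```

For example, running two passes A and B under ToyTimingInstrument(skip={"B"}) executes only A and records a timing for A alone.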
Override Instruments in Current PassContext
The override_instruments method is provided to override the instruments of the current PassContext. For example, if passes are run without explicitly creating a new PassContext, one can still register PassInstrument into the global PassContext by:
cur_pass_ctx = tvm.transform.PassContext.current()
# override PassInstrument instances
cur_pass_ctx.override_instruments([pass_inst])
mod = pass_seq(mod)
result = pass_inst.get_result()
Note that when override_instruments is called, the exit_pass_ctx method of the old PassInstrument instances is called. Then the enter_pass_ctx method of the new PassInstrument instances is called.
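The ordering described above (old instruments exit, then new ones enter) can be expressed as a short Python model. ToyOverridableContext and LogInstrument are hypothetical names; entering instruments on with is left out so that only the override step is shown.

```python
class ToyOverridableContext:
    """Hypothetical model of override_instruments (not TVM's class)."""

    def __init__(self, instruments=()):
        self.instruments = list(instruments)

    def override_instruments(self, instruments):
        # The old instruments leave the context first...
        for inst in self.instruments:
            inst.exit_pass_ctx()
        # ...then the replacements are installed and entered.
        self.instruments = list(instruments)
        for inst in self.instruments:
            inst.enter_pass_ctx()


class LogInstrument:
    """Hypothetical instrument that records its context hooks."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def enter_pass_ctx(self):
        self.log.append("enter " + self.name)

    def exit_pass_ctx(self):
        self.log.append("exit " + self.name)


log = []
ctx = ToyOverridableContext([LogInstrument("old", log)])
ctx.override_instruments([LogInstrument("new", log)])
# log is now ["exit old", "enter new"]
```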