tvm.transform

Common pass infrastructure across IR variants.

tvm.transform.ApplyPassToFunction(transform: tvm.ir.transform.Pass, func_name_regex: str, error_if_no_function_matches_regex: bool = False) -> tvm.ir.transform.Pass

Utility to apply a pass to specific functions in an IRModule.

TVM uses IRModule-to-IRModule transformations at all stages of lowering. Applying such a transformation to only part of a module can be useful when hand-writing an optimized model, or when optimizing specific kernels within an IRModule. This utility applies a pass to the selected functions without altering the other functions in the module.

Parameters
  • transform (Pass) – The IRModule to IRModule pass to be applied.

  • func_name_regex (str) – A regex used to select the functions to be updated. The pass will be applied to all functions whose name matches the regex.

  • error_if_no_function_matches_regex (bool) – Specifies the behavior if an IRModule does not contain any function matching the provided regex. If true, an error will be raised. If false (default), the IRModule will be returned unmodified.

Returns

new_transform – The modified IRModule to IRModule pass.

Return type

Pass
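The selection behavior described above can be modeled in plain Python. This is a simplified sketch, not TVM code: the "module" is a dict from function name to a string body, and apply_pass_to_function and to_upper are hypothetical names. The use of re.fullmatch is an assumption, since the text above does not say whether the regex must match the whole function name.

```python
import re
from typing import Callable, Dict

ToyModule = Dict[str, str]  # function name -> function body (toy stand-in for an IRModule)


def apply_pass_to_function(
    transform: Callable[[ToyModule], ToyModule],
    func_name_regex: str,
    error_if_no_function_matches_regex: bool = False,
) -> Callable[[ToyModule], ToyModule]:
    """Wrap `transform` so that it only updates functions whose name matches the regex."""
    pattern = re.compile(func_name_regex)

    def new_transform(mod: ToyModule) -> ToyModule:
        # Assumption of this sketch: the regex must match the whole name.
        selected = {name: body for name, body in mod.items() if pattern.fullmatch(name)}
        if not selected:
            if error_if_no_function_matches_regex:
                raise ValueError(f"No function matches regex {func_name_regex!r}")
            return mod  # nothing matched: return the module unmodified
        updated = dict(mod)  # non-matching functions are carried over untouched
        updated.update(transform(selected))
        return updated

    return new_transform


def to_upper(mod: ToyModule) -> ToyModule:
    """Hypothetical 'pass' that upper-cases every function body it is given."""
    return {name: body.upper() for name, body in mod.items()}


module = {"main": "add(x, y)", "matmul_kernel": "dot(a, b)", "relu_kernel": "max(x, 0)"}
only_kernels = apply_pass_to_function(to_upper, r".*_kernel")
result = only_kernels(module)
```

In this sketch, only the two functions whose names end in _kernel are rewritten, and main is carried over unchanged.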

class tvm.transform.ModulePass

A pass that works on tvm.IRModule. Users don’t need to interact with this class directly. Instead, a module pass should be created through module_pass, because the design of the module_pass API is flexible enough to handle the creation of a module pass in different manners. In addition, all members of a module pass can be accessed from the base class. The same rule applies to FunctionPass as well.

class tvm.transform.Pass

The base class of all passes. All methods here are just simple wrappers that are implemented in the backend. They are defined for users to conveniently interact with the base class.

property info

Get the pass metadata.

class tvm.transform.PassContext(opt_level=2, required_pass=None, disabled_pass=None, instruments=None, config=None, trace=None, trace_stack=None, make_traceable=None, num_evals=0, tuning_api_database=None)

The basis on which a Relay optimization/analysis runs. Each pass context contains auxiliary information that is used to help an optimization pass, such as the error reporter that records errors raised during optimization.

opt_level: Optional[int]

The optimization level of this pass.

required_pass: Optional[Union[List[str], Set[str], Tuple[str]]]

The list of passes that are required by a certain pass.

disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]]

The list of passes that are disabled.

instruments: Optional[Sequence[PassInstrument]]

The list of pass instrument implementations.

config: Optional[Dict[str, Object]]

Additional configurations for specific passes.

trace: Optional[relax.tuning.Trace]

Initial trace for trace mode.

trace_stack: Optional[List[relax.tuning_api.Trace]]

Initial trace stack for trace mode.

make_traceable: Optional[List[str]]

List of passes to make traceable.

num_evals: int

Initial number of evaluations conducted in the pipeline.

tuning_api_database: Optional[relax.tuning_api.JSONDatabase]

override_instruments(instruments)

Override instruments within this PassContext.

If there are existing instruments, their exit_pass_ctx callbacks are called first. The context then switches to the new instruments and calls their enter_pass_ctx callbacks.

instruments: Sequence[PassInstrument]

The list of pass instrument implementations.
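The callback order described for override_instruments can be illustrated with a small stand-alone model. This is a sketch under stated assumptions, not the TVM implementation: ToyPassContext and RecordingInstrument are hypothetical classes, and only the enter_pass_ctx and exit_pass_ctx callbacks named above are modeled.

```python
from typing import List, Sequence


class RecordingInstrument:
    """Hypothetical stand-in for a PassInstrument that logs its callbacks."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def enter_pass_ctx(self) -> None:
        self.log.append(f"{self.name}.enter_pass_ctx")

    def exit_pass_ctx(self) -> None:
        self.log.append(f"{self.name}.exit_pass_ctx")


class ToyPassContext:
    """Simplified model of a pass context (not the TVM class)."""

    def __init__(self, instruments: Sequence[RecordingInstrument] = ()) -> None:
        self.instruments = list(instruments)

    def override_instruments(self, instruments: Sequence[RecordingInstrument]) -> None:
        # 1. Existing instruments receive their exit callbacks.
        for inst in self.instruments:
            inst.exit_pass_ctx()
        # 2. The context switches to the new instruments.
        self.instruments = list(instruments)
        # 3. The new instruments receive their enter callbacks.
        for inst in self.instruments:
            inst.enter_pass_ctx()


log: List[str] = []
ctx = ToyPassContext([RecordingInstrument("old", log)])
ctx.override_instruments([RecordingInstrument("new", log)])
```

After the call, log holds the old instrument's exit callback followed by the new instrument's enter callback.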

static current()

Return the current pass context.

static list_configs()

List all registered PassContext configuration names and metadata.

Returns

configs

Return type

Dict[str, Dict[str, str]]

push_trace(trace)

Push a trace into the stack.

pop_trace(return_current=True)

Pop the topmost trace from the stack.

Returns

Trace

Return type

Optional[relax.tuning.Trace]

get_trace_stack()

Get the current trace stack.

get_trace_stack_size()

Get the size of the current trace stack.

get_current_trace()

Get the trace on the top of the stack.

set_num_evals(num: int)

Set the number of evaluations conducted in the pipeline.

inc_num_evals(num: int)

Increment the number of evaluations conducted in the pipeline.

get_tuning_api_database()

Get the tuning API database.

class tvm.transform.PassInfo(opt_level, name, required=None, traceable=False)

The class contains the meta data required by a pass. It is the container of information needed by running an optimization or analysis. This class can be extended by adding new members when more meta data is needed.

Parameters
  • opt_level (int) – The optimization level of this pass.

  • name (str) – The pass name.

  • required (List[str]) – The list of passes that are required by a certain pass.

tvm.transform.PrintIR(header='', show_meta_data=False)

A special trace pass that prints the header and IR.

Parameters
  • header (str) – The header to be displayed along with the dump.

  • show_meta_data (bool) – A boolean flag to indicate if meta data should be printed.

Returns

The pass.

Return type

Pass

class tvm.transform.Sequential(passes=None, opt_level=0, name='sequential', required=None, traceable=False)

A pass that works on a sequence of pass objects. Multiple passes can be executed sequentially using this class.

Note that users can also provide a series of passes that they don’t want to apply when running a sequential pass. Pass dependency will be resolved in the backend as well.

Parameters
  • passes (Optional[List[Pass]]) – A sequence of candidate passes for optimization.

  • opt_level (Optional[int]) – The optimization level of this sequential pass. The opt_level of a default sequential pass is set to 0. Note that some of the passes within the Sequential may still not be executed if their opt_level is higher than the provided opt_level.

  • name (Optional[str]) – The name of the sequential pass.

  • required (Optional[List[str]]) – The list of passes that the sequential pass is dependent on.
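The filtering described above, where a pass is skipped when it is disabled or when its opt_level exceeds the provided one, can be sketched in plain Python. This is a simplified model rather than TVM's backend logic: ToyPass, record, and run_sequential are hypothetical names, the "module" is a list that records which passes ran, and pass dependency resolution is not modeled.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, List


@dataclass
class ToyPass:
    """Hypothetical stand-in for a pass: a name, an opt_level, and a function."""

    name: str
    opt_level: int
    func: Callable[[List[str]], List[str]]


def record(name: str) -> Callable[[List[str]], List[str]]:
    """Build a pass function that appends its own name to the toy module."""
    return lambda mod: mod + [name]


def run_sequential(
    passes: Iterable[ToyPass],
    mod: List[str],
    opt_level: int = 0,
    disabled_pass: Iterable[str] = (),
) -> List[str]:
    """Run the passes in order, skipping disabled or too-aggressive ones."""
    disabled = set(disabled_pass)
    for p in passes:
        if p.name in disabled:
            continue  # the user asked not to apply this pass
        if p.opt_level > opt_level:
            continue  # pass requires a higher optimization level
        mod = p.func(mod)
    return mod


passes = [ToyPass("A", 0, record("A")), ToyPass("B", 2, record("B")), ToyPass("C", 3, record("C"))]
default_run = run_sequential(passes, [])
level2_run = run_sequential(passes, [], opt_level=2)
filtered_run = run_sequential(passes, [], opt_level=3, disabled_pass=["B"])
```

With the default level of 0 only pass A runs, which is why a sequence built with default settings may execute fewer passes than it lists.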

tvm.transform.module_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False)

Decorate a module pass.

This function returns a callback when pass_func is provided. Otherwise, it serves as a decorator function.

pass_func can also be a class type with a method transform_module. This function will create a decorated ModulePass using transform_module as the pass function.

Parameters
  • pass_func (Optional[Callable[(Module, PassContext) -> Module]]) – The transformation function or class.

  • opt_level (int) – The optimization level of this module pass.

  • name (Optional[str]) – The name of the module pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

  • required (Optional[List[str]]) – The list of passes that the module pass is dependent on.

  • traceable (bool) – Whether the module pass is traceable.

Returns

create_module_pass – If pass_func is not provided, a decorator is returned; otherwise, the decorated result is returned. The returned decorator behaves in one of two ways depending on its input: decorating a pass function returns a new ModulePass, and decorating a class type returns a new ModulePass class.

Return type

Union[Callable, ModulePass]

Examples

The following code block decorates a module pass class.

import tvm
from tvm import relay

# opt_level is passed explicitly; the value 1 is illustrative.
@tvm.transform.module_pass(opt_level=1)
class CustomPipeline:
    def __init__(self, enable_fold):
        self.enable_fold = enable_fold
        self.cse = relay.transform.EliminateCommonSubexpr()
        self.const_fold = relay.transform.FoldConstant()

    def transform_module(self, mod, ctx):
        # A pass object is called with the module only.
        mod = self.cse(mod)
        if self.enable_fold:
            mod = self.const_fold(mod)
        return mod

# Create an instance of the customized pipeline.
pipeline = CustomPipeline(enable_fold=False)
assert isinstance(pipeline, tvm.transform.ModulePass)
# Run the pipeline on an existing IRModule, input_module.
output_module = pipeline(input_module)

The following code creates a module pass by decorating a user-defined transform function.

import tvm
from tvm import relay

@tvm.transform.module_pass(opt_level=2)
def transform(mod, ctx):
    tp = relay.TensorType((10,), "float32")
    x = relay.var("x", tp)
    gv = relay.GlobalVar("var")
    func = relay.Function([x], relay.abs(x))
    new_mod = tvm.IRModule({gv: func})
    new_mod.update(mod)
    return new_mod

# After decoration, the name transform refers to the pass object, not the plain function.
module_pass = transform
assert isinstance(module_pass, tvm.transform.ModulePass)
assert module_pass.info.opt_level == 2

# Given a module m, the optimization can be invoked as follows:
updated_mod = module_pass(m)
# updated_mod should now contain the added abs function.
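
The two calling conventions described above (a direct call with pass_func, and use as a decorator with arguments) follow a common Python pattern, sketched below without TVM. Everything here is hypothetical and simplified: toy_module_pass and ToyModulePass are illustrative names, the toy module is a list of function names, the PassContext argument is omitted, decorating a class type is not covered, and treating a missing opt_level as an error is a choice made for this sketch.

```python
from typing import Callable, List, Optional


class ToyModulePass:
    """Hypothetical stand-in for a module pass: wraps a function plus metadata."""

    def __init__(self, func: Callable[[List[str]], List[str]], opt_level: int, name: str) -> None:
        self.func = func
        self.opt_level = opt_level
        self.name = name

    def __call__(self, mod: List[str]) -> List[str]:
        return self.func(mod)


def toy_module_pass(pass_func=None, opt_level: Optional[int] = None, name: Optional[str] = None):
    """Return a ToyModulePass if pass_func is given, otherwise return a decorator."""
    if opt_level is None:
        # Choice of this sketch: opt_level must always be supplied.
        raise ValueError("opt_level must be provided")

    def create_module_pass(func):
        # An empty name falls back to the name of the wrapped function.
        return ToyModulePass(func, opt_level, name or func.__name__)

    if pass_func is not None:
        return create_module_pass(pass_func)  # direct call: wrap immediately
    return create_module_pass  # decorator use: wrap later


def add_marker(mod: List[str]) -> List[str]:
    return mod + ["marker"]


# Direct call with pass_func provided.
marker_pass = toy_module_pass(add_marker, opt_level=1)


# Decorator use with arguments and no pass_func.
@toy_module_pass(opt_level=2, name="Uppercase")
def uppercase(mod: List[str]) -> List[str]:
    return [fn.upper() for fn in mod]


out = uppercase(marker_pass(["main"]))
```

Both forms produce the same kind of object, so a caller can compose the results without caring how each pass was created.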