tvm.transform
Common pass infrastructure across IR variants.
- tvm.transform.ApplyPassToFunction(transform: tvm.ir.transform.Pass, func_name_regex: str, error_if_no_function_matches_regex: bool = False) → tvm.ir.transform.Pass
Utility to apply a pass to specific functions in an IRModule.
TVM uses IRModule to IRModule transformations at all stages of lowering. These transformations may be useful when hand-writing an optimized model or when optimizing specific kernels within an IRModule. This utility applies a pass only to the selected functions and leaves the other functions in the module unchanged.
- Parameters
transform (Pass) – The IRModule to IRModule pass to be applied.
func_name_regex (str) – A regex used to select the functions to be updated. The pass will be applied to all functions whose name matches the regex.
error_if_no_function_matches_regex (bool) – Specifies the behavior if an IRModule does not contain any function matching the provided regex. If true, an error will be raised. If false (default), the IRModule will be returned unmodified.
- Returns
new_transform – The modified IRModule to IRModule pass.
- Return type
tvm.ir.transform.Pass
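The following pure-Python sketch shows how the selection works. It does not call TVM. The IRModule is modeled as a dict that maps function names to bodies. `apply_pass_to_function` and `upper_pass` are hypothetical names. Matching the regex against the whole function name is an assumption about the matching semantics.

```python
import re
from typing import Callable, Dict

# Hypothetical model: an "IRModule" is a dict mapping function names to bodies.
Module = Dict[str, str]


def apply_pass_to_function(
    transform: Callable[[Module], Module],
    func_name_regex: str,
    error_if_no_function_matches_regex: bool = False,
) -> Callable[[Module], Module]:
    """Wrap `transform` so it only sees functions whose names match the regex."""

    def wrapped(mod: Module) -> Module:
        # Assumed semantics: the regex must match the whole function name.
        selected = {n: f for n, f in mod.items() if re.fullmatch(func_name_regex, n)}
        if not selected:
            if error_if_no_function_matches_regex:
                raise ValueError(f"no function matches {func_name_regex!r}")
            return mod  # returned unmodified
        # Run the pass on the matching subset only, then merge the results back.
        result = dict(mod)
        result.update(transform(selected))
        return result

    return wrapped


def upper_pass(mod: Module) -> Module:
    """Toy 'pass' that upper-cases every function body it is given."""
    return {name: body.upper() for name, body in mod.items()}


only_fused = apply_pass_to_function(upper_pass, r"fused_.*")
print(only_fused({"main": "call", "fused_add": "add"}))
# {'main': 'call', 'fused_add': 'ADD'}
```

In this sketch, `main` passes through untouched because its name does not match the regex. The wrapped pass never sees it.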
- class tvm.transform.ModulePass
A pass that works on tvm.IRModule. Users don’t need to interact with this class directly. Instead, a module pass should be created through module_pass, because the design of the module_pass API is flexible enough to handle the creation of a module pass in different manners. In addition, all members of a module pass can be accessed from the base class. The same rule applies to FunctionPass as well.
- class tvm.transform.Pass
The base class of all passes. All methods here are thin wrappers around implementations in the backend. They exist so that users can conveniently interact with the base class.
- property info
Get the pass metadata as a PassInfo object.
- class tvm.transform.PassContext(opt_level=2, required_pass=None, disabled_pass=None, instruments=None, config=None, trace=None, trace_stack=None, make_traceable=None, num_evals=0, tuning_api_database=None)
The context in which a Relay optimization or analysis runs. Each pass context holds auxiliary information that helps an optimization pass, such as the error reporter that records errors raised during optimization.
- opt_level: Optional[int]
The optimization level of this pass context.
- required_pass: Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are required by a certain pass.
- disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are disabled.
- instruments: Optional[Sequence[PassInstrument]]
The list of pass instrument implementations.
- config: Optional[Dict[str, Object]]
Additional configurations for specific passes.
- trace: Optional[relax.tuning.Trace]
Initial trace for trace mode.
- trace_stack: Optional[List[relax.tuning_api.Trace]]
Initial trace stack for trace mode.
- make_traceable: Optional[List[str]]
List of passes to make traceable.
- num_evals: int
Initial number of evaluations conducted in the pipeline.
- tuning_api_database: Optional[relax.tuning_api.JSONDatabase]
Tuning API database, retrievable with get_tuning_api_database().
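opt_level, required_pass and disabled_pass together decide whether a given pass runs. The pure-Python sketch below models that decision. `pass_enabled` and the pass names are hypothetical. The precedence follows our reading of TVM's behavior and is not copied from its implementation: a disabled pass never runs, a required pass always runs, and any other pass runs only when its opt_level is within the context's.

```python
from typing import Iterable


def pass_enabled(
    pass_name: str,
    pass_opt_level: int,
    ctx_opt_level: int = 2,
    required_pass: Iterable[str] = (),
    disabled_pass: Iterable[str] = (),
) -> bool:
    """Hypothetical model of how a pass context gates a single pass."""
    # Explicitly disabled passes never run.
    if pass_name in set(disabled_pass):
        return False
    # Explicitly required passes run regardless of opt_level.
    if pass_name in set(required_pass):
        return True
    # Otherwise the pass runs only if the context's opt_level is high enough.
    return ctx_opt_level >= pass_opt_level


print(pass_enabled("MyPass", 2))                                  # True
print(pass_enabled("HeavyPass", 3))                               # False
print(pass_enabled("HeavyPass", 3, required_pass=["HeavyPass"]))  # True
print(pass_enabled("MyPass", 1, disabled_pass={"MyPass"}))        # False
```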
- override_instruments(instruments)
Override instruments within this PassContext.
If there are existing instruments, their exit_pass_ctx callbacks are called first. The context then switches to the new instruments and calls their enter_pass_ctx callbacks.
- instruments: Sequence[PassInstrument]
The list of pass instrument implementations.
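The callback ordering described for override_instruments is sketched below in pure Python. `RecordingInstrument` and `ContextModel` are hypothetical stand-ins for PassInstrument and PassContext. Only the method names enter_pass_ctx and exit_pass_ctx come from the text above.

```python
from typing import List, Sequence


class RecordingInstrument:
    """Hypothetical instrument that logs its enter/exit callbacks."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def enter_pass_ctx(self) -> None:
        self.log.append(f"enter {self.name}")

    def exit_pass_ctx(self) -> None:
        self.log.append(f"exit {self.name}")


class ContextModel:
    """Hypothetical stand-in for a pass context that holds instruments."""

    def __init__(self, instruments: Sequence[RecordingInstrument] = ()) -> None:
        self.instruments = list(instruments)

    def override_instruments(self, instruments: Sequence[RecordingInstrument]) -> None:
        # Existing instruments are torn down first...
        for inst in self.instruments:
            inst.exit_pass_ctx()
        # ...then the new ones replace them and are set up.
        self.instruments = list(instruments)
        for inst in self.instruments:
            inst.enter_pass_ctx()


log: List[str] = []
ctx = ContextModel([RecordingInstrument("old", log)])
ctx.override_instruments([RecordingInstrument("new", log)])
print(log)  # ['exit old', 'enter new']
```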
- static current()
Return the current pass context.
- static list_configs()
List all registered PassContext configuration names and metadata.
- push_trace(trace)
Push a trace onto the stack.
- pop_trace(return_current=True)
Pop the topmost trace from the stack.
- Returns
Trace
- Return type
Optional[relax.tuning.Trace]
- get_trace_stack()
Get the current trace stack.
- get_trace_stack_size()
Get the size of the current trace stack.
- get_current_trace()
Get the trace on the top of the stack.
- get_tuning_api_database()
Get the tuning API database.
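The trace methods above act on a last-in, first-out stack. The list-backed sketch below shows their relationship. `TraceStackModel` is hypothetical and traces are plain strings. Returning None when `return_current=False` is an assumption made for the sketch.

```python
from typing import List, Optional


class TraceStackModel:
    """Hypothetical list-backed model of the PassContext trace stack."""

    def __init__(self, trace_stack: Optional[List[str]] = None) -> None:
        self._stack: List[str] = list(trace_stack or [])

    def push_trace(self, trace: str) -> None:
        self._stack.append(trace)  # new trace becomes the top

    def pop_trace(self, return_current: bool = True) -> Optional[str]:
        top = self._stack.pop()  # remove the topmost trace
        return top if return_current else None

    def get_trace_stack(self) -> List[str]:
        return list(self._stack)  # copy, bottom first

    def get_trace_stack_size(self) -> int:
        return len(self._stack)

    def get_current_trace(self) -> str:
        return self._stack[-1]  # top of the stack


stack = TraceStackModel()
stack.push_trace("trace-0")
stack.push_trace("trace-1")
print(stack.get_trace_stack_size())  # 2
print(stack.get_current_trace())     # trace-1
print(stack.pop_trace())             # trace-1
print(stack.get_trace_stack())       # ['trace-0']
```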
- class tvm.transform.PassInfo(opt_level, name, required=None, traceable=False)
This class contains the metadata required by a pass. It is the container of information needed to run an optimization or analysis. It can be extended with new members when more metadata is needed.
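The dataclass below mirrors the constructor fields of PassInfo in plain Python. It also shows one way to extend it with an extra member. `PassInfoModel`, `TimedPassInfoModel`, `budget_seconds` and the pass names are hypothetical. With TVM installed, the real constructor takes the same positional order, for example `tvm.transform.PassInfo(2, "MyPass")`.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class PassInfoModel:
    """Hypothetical mirror of PassInfo(opt_level, name, required=None, traceable=False)."""

    opt_level: int
    name: str
    required: List[str] = field(default_factory=list)  # passes this one depends on
    traceable: bool = False


@dataclass
class TimedPassInfoModel(PassInfoModel):
    """Hypothetical extension: a new metadata member added by subclassing."""

    budget_seconds: float = 1.0


info = PassInfoModel(2, "MyPass", required=["OtherPass"])
print(info)
# PassInfoModel(opt_level=2, name='MyPass', required=['OtherPass'], traceable=False)
timed = TimedPassInfoModel(3, "SlowPass", budget_seconds=5.0)
print(timed.opt_level, timed.budget_seconds)  # 3 5.0
```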
- tvm.transform.PrintIR(header='', show_meta_data=False)
A special trace pass that prints the header and IR.
- class tvm.transform.Sequential(passes=None, opt_level=0, name='sequential', required=None, traceable=False)
A pass that works on a sequence of pass objects. Multiple passes can be executed sequentially using this class.
Note that users can also provide a series of passes that they don’t want to apply when running a sequential pass. Pass dependency will be resolved in the backend as well.
- Parameters
passes (Optional[List[Pass]]) – A sequence of candidate passes for optimization.
opt_level (Optional[int]) – The optimization level of this sequential pass. The opt_level of a default sequential pass is set to 0. Note that some of the passes within the Sequential may still not be executed if their opt_level is higher than the provided opt_level.
name (Optional[str]) – The name of the sequential pass.
required (Optional[List[str]]) – The list of passes that the sequential pass is dependent on.
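The opt_level note above can be illustrated with a pure-Python model of a sequential pipeline. Passes run in order, and any pass whose opt_level exceeds the threshold is skipped. `SimplePass`, `run_sequential`, the list-of-op-names module and the two toy passes are hypothetical. Treating the threshold as a single context-level integer is an assumption of the sketch.

```python
from dataclasses import dataclass
from typing import Callable, List

Module = List[str]  # hypothetical: a module modeled as a list of op names


@dataclass
class SimplePass:
    name: str
    opt_level: int
    fn: Callable[[Module], Module]


def run_sequential(passes: List[SimplePass], mod: Module, ctx_opt_level: int = 2) -> Module:
    """Apply passes in order, skipping those above the opt_level threshold."""
    for p in passes:
        if p.opt_level > ctx_opt_level:
            continue  # too aggressive for this level; not executed
        mod = p.fn(mod)
    return mod


dedupe = SimplePass("Dedupe", 1, lambda m: list(dict.fromkeys(m)))
sort_ops = SimplePass("SortOps", 3, sorted)

print(run_sequential([dedupe, sort_ops], ["mul", "add", "mul"]))                   # ['mul', 'add']
print(run_sequential([dedupe, sort_ops], ["mul", "add", "mul"], ctx_opt_level=3))  # ['add', 'mul']
```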
- tvm.transform.module_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False)
Decorate a module pass.
When pass_func is provided, this function returns the decorated result directly. Otherwise, it returns a decorator.
pass_func can also be a class type with a method transform_module. This function will create a decorated ModulePass using transform_module as the pass function.
- Parameters
pass_func (Optional[Callable[(Module, PassContext) -> Module]]) – The transformation function or class.
opt_level (int) – The optimization level of this module pass.
name (Optional[str]) – The name of the module pass. The name can be empty. In that case, the name of the optimization function is used as the pass name.
required (Optional[List[str]]) – The list of passes that the module pass depends on.
traceable (bool) – Whether the module pass is traceable.
- Returns
create_module_pass – A decorator if pass_func is not provided; otherwise, the decorated result. The returned decorator behaves in one of two ways depending on its input: decorating a pass function returns a new ModulePass, and decorating a class type returns a new ModulePass class.
- Return type
Union[Callable, ModulePass]
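The two calling forms described above can be sketched in pure Python. Used bare, the decorator receives the function directly. Used with keyword arguments, it first returns a decorator. `module_pass_model` and `ModulePassModel` are hypothetical and do not create real TVM passes. The sketch also omits the class-type form.

```python
from typing import Any, Callable, Optional


class ModulePassModel:
    """Hypothetical stand-in for a module pass wrapping a (mod, ctx) -> mod function."""

    def __init__(self, fn: Callable[[Any, Any], Any], opt_level: int, name: str) -> None:
        self.fn = fn
        self.opt_level = opt_level
        self.name = name

    def __call__(self, mod: Any, ctx: Any = None) -> Any:
        return self.fn(mod, ctx)


def module_pass_model(pass_func: Optional[Callable] = None, opt_level: int = 0,
                      name: Optional[str] = None):
    def create(fn: Callable) -> ModulePassModel:
        # Fall back to the function's own name when no name is given.
        return ModulePassModel(fn, opt_level, name or fn.__name__)

    if pass_func is not None:
        return create(pass_func)  # bare form: @module_pass_model
    return create                 # parameterized form: @module_pass_model(opt_level=...)


@module_pass_model
def identity(mod, ctx):
    return mod


@module_pass_model(opt_level=2, name="AppendOne")
def append_one(mod, ctx):
    return mod + [1]


print(identity.name)                          # identity
print(append_one.name, append_one.opt_level)  # AppendOne 2
print(append_one([0]))                        # [0, 1]
```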
Examples
The following code block decorates a module pass class.
from tvm import relay, transform

# opt_level is given explicitly so that the pass metadata has an integer level.
@relay.transform.module_pass(opt_level=1)
class CustomPipeline:
    def __init__(self, enable_fold):
        self.enable_fold = enable_fold
        self.cse = relay.transform.EliminateCommonSubexpr()
        self.const_fold = relay.transform.FoldConstant()

    def transform_module(self, mod, ctx):
        # A Pass is called with the module only; the active PassContext applies.
        mod = self.cse(mod)
        if self.enable_fold:
            mod = self.const_fold(mod)
        return mod

# create an instance of the customized pipeline
pipeline = CustomPipeline(enable_fold=False)
assert isinstance(pipeline, transform.ModulePass)
# run the pipeline on an existing tvm.IRModule named input_module
output_module = pipeline(input_module)
The following code creates a module pass by decorating a user-defined transform function.
import tvm
from tvm import relay

@relay.transform.module_pass(opt_level=2)
def add_abs_function(mod, ctx):
    tp = relay.TensorType((10,), "float32")
    x = relay.var("x", tp)
    gv = relay.GlobalVar("var")
    func = relay.Function([x], relay.abs(x))
    new_mod = tvm.IRModule({gv: func})
    new_mod.update(mod)
    return new_mod

# The function is not named "transform", so it does not shadow tvm.transform.
module_pass = add_abs_function
assert isinstance(module_pass, tvm.transform.ModulePass)
assert module_pass.info.opt_level == 2
# Given a module m, the optimization can be invoked as follows:
updated_mod = module_pass(m)
# updated_mod contains the functions of m plus a function, bound to the global "var", that computes abs(x).