tvm.relay.transform
The Relay IR namespace containing transformations.
Functions:
- recast – Convert the types of operations in a graph to a new value.
- AlterOpLayout – Alter the layouts of operators or replace primitive operators with other expressions.
- AnnotateSpans – Annotate a program with span information by first generating its textual representation and then parsing it back into a Relay AST annotated with span information.
- AnnotateTarget – Annotate ops in an expression with a provided compiler/target and then use it for codegen.
- BackwardFoldScaleAxis – Backward fold axis scaling into weights of conv2d/dense.
- BatchingOps – Batch parallel operators into one for Conv2D, Dense and BatchMatmul.
- CanonicalizeCast – Canonicalize cast expressions to make operator fusion more efficient.
- CanonicalizeOps – Canonicalize special operators to basic operators.
- CapturePostDfsIndexInSpans – Capture the post-dfs index and dominator post-dfs index of (most) expression nodes in their span, in the form "index:<post-dfs index>:<dominator post-dfs index>".
- CollagePartition – Partition the bodies of all functions according to the available targets so as to minimize model latency.
- CombineParallelBatchMatmul – Combine multiple batch matmul operators into one.
- CombineParallelConv2D – Combine multiple conv2d operators into one.
- CombineParallelDense – Combine multiple dense operators into one.
- Conv2dToSparse – Rewrite qualified nn.conv2d operations to nn.sparse_conv2d.
- Conv2dToSparse2 – Rewrite frozen nn.conv2d operations to nn.sparse_conv2d.
- ConvertLayout – Given a destination layout, transform the expression so that the input data layout of most ops is changed to that layout.
- DeadCodeElimination – Remove expressions that do not have any users (dead code).
- Defunctionalization – Perform defunctionalization on func, transforming func from a higher-order program to a first-order program.
- DefuseOps – The inverse operation of FuseOps.
- DenseToSparse – Rewrite qualified nn.dense operations to nn.sparse_dense.
- DivToMul – Transform division by a constant into multiplication by the inverse of the constant.
- DynamicToStatic – If possible, convert tvm.relay.dynamic* ops to their static versions.
- EliminateCommonSubexpr – Eliminate common subexpressions.
- EtaExpand – Add abstraction over a constructor or global variable bound to a function.
- FakeQuantizationToInteger – Find regions of the graph enclosed by qnn.dequantize and qnn.quantize ops and rewrite them into integer versions.
- FastMath – Convert expensive non-linear functions to their fast but approximate counterparts.
- FirstOrderGradient – Transform all global functions in the module to return the original result, paired with the gradients of the inputs.
- FlattenAtrousConv – Find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd operations and convert it into a single convolution with modified dilation and padding.
- FoldConstant – Fold the constant expressions in a Relay program.
- FoldConstantExpr – Fold the constant expressions in a Relay expression.
- FoldExplicitPadding – Find explicit padding before an op that can support implicit padding and fuse them.
- FoldScaleAxis – Fold the scaling of axis into weights of conv2d/dense.
- ForwardFoldScaleAxis – Fold the scaling of axis into weights of conv2d/dense.
- FuseOps – Fuse operators in an expr to a larger operator according to some rules.
- InferType – Infer the type of an expr.
- InferTypeLocal – Infer the type of a single expr, reusing type information to do so.
- Inline – Perform inlining on the given Relay IR module.
- InlineCompilerFunctionsBoundTo – Inline all global functions bound to a global var in global_vars.
- LambdaLift – Lift closures to global functions.
- LazyGradientInit – Reduce memory usage of gradient tensors.
- Legalize – Legalize an expression with another expression.
- ManifestLifetimes – Manifest the lifetimes of variables after allocations have been manifested, by inserting kill operations once variables become dead.
- MarkCompilerFunctionsAsExtern – Mark all global functions which have a "Compiler" attribute matching compiler_filter as 'extern'.
- MergeCompilerRegions – Merge together compiler regions.
- MergeComposite – Merge multiple operators into a single composite Relay function.
- OutlineCompilerFunctionsWithExistingGlobalSymbols – Outline all literal functions in direct call positions which have a "Compiler" attribute.
- PartialEvaluate – Evaluate the static fragment of the code.
- PartitionGraph – Partition a Relay program into regions that can be executed on different backends.
- PlanDevices – Use existing "on_device" and "device_copy" calls to infer the virtual device on which every Relay sub-expression should run and the result stored.
- RemoveUnusedFunctions – Remove unused global Relay functions in a Relay module.
- SimplifyExpr – Simplify the Relay expression, including merging consecutive reshapes.
- SimplifyFCTranspose – Rewrite y = nn.dense(x, transpose(w, [1, 0])) to y = nn.dense(x, wt).
- SimplifyInference – Simplify the data-flow graph for the inference phase.
- SplitArgs – Split functions with a huge number of arguments into smaller pieces.
- ToANormalForm – Turn a Graph Normal Form expression into an A Normal Form expression.
- ToANormalFormExpr – ToANormalForm, but on the expression level.
- ToBasicBlockNormalForm – Turn an expression into Basic Block Normal Form.
- ToCPS – Turn an expression into continuation-passing style (CPS).
- ToGraphNormalForm – Turn a Relay program in A Normal Form into Graph Normal Form.
- ToMixedPrecision – Automatic mixed precision rewriter.
- build_config – Configure the build behavior by setting config variables.
- function_pass – Decorate a function pass.
- gradient – Transform the input function, returning a function that calculates the original result, paired with the gradient of the input.
- to_cps – Turn an expression into a CPS expression.
- un_cps – Turn a CPS function into a Function without the continuation argument.
Classes:
- FlexibleShapeDispatch – Enable inference of multiple shaped inputs in one module.
- ChangeBatch – Change the batch size.
- FunctionPass – A pass that works on each tvm.relay.Function in a module.
- LayoutConfig – A structure for customizing the ConvertLayout pass.
- tvm.relay.transform.recast(expr, dtype, out_dtype, ops=None, skip_layers=None)
Convert the types of operations in a graph to a new value. Note that this is primarily useful for testing performance of individual operations at the new datatype. In a real setting, this pass will almost certainly do a poor job converting from one datatype to another as it just applies hard casting. For example, when recasting from float to integer, many small values will simply be set to 0. Although this will allow autotuning and benchmarking to produce proper timings at the new data type, the output of the model will of course be heavily impacted.
- Parameters
expr (tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule) – The original function that will have its type changed.
dtype (str) – The target type to cast to.
out_dtype (str) – The output type to cast to.
ops (List[str]) – A list of operations that should have their type changed, others will be left as is.
skip_layers (List[int]) – A list of integers indicating operations that should not have their type changed, counted starting with the first valid operation encountered. Negative indices are allowed and indicate starting at the last layer.
- Returns
output_expr – The graph after recasting to the specified datatype.
- Return type
tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule
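The accuracy loss from hard casting is easy to reproduce outside TVM. The plain-Python sketch below uses a hypothetical helper, hard_cast_to_int, that truncates toward zero the way a C-style float-to-integer cast does; it only illustrates the effect described above and does not call recast.

```python
def hard_cast_to_int(values):
    """Hypothetical stand-in for a hard float -> int cast (truncates toward zero)."""
    return [int(v) for v in values]


activations = [0.03, -0.4, 0.97, 1.6, -2.2]
casted = hard_cast_to_int(activations)
print(casted)  # [0, 0, 0, 1, -2]: every value with magnitude below 1 becomes 0
```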
- class tvm.relay.transform.FlexibleShapeDispatch(buckets, axis=0, auto_pad=False, pad_value=0, input_indices=None, affects_output=True)
Enable inference of multiple shaped inputs in one module.
This transformation adds a handler around a module that checks input shapes and dispatches to a subgraph specialized for the specific shape of that input. If no exactly matching subgraph is available, the input is run using full dynamism. For best performance, specify all the sizes the module is likely to see using the buckets argument.
By default, this pass dispatches shapes that exactly match one of the buckets to a corresponding subgraph. All non-matching shapes use the same fully dynamic fallback, which can be detrimental to performance for those shapes. Setting auto_pad to True causes this pass to round up the shape of non-matching inputs to the closest bucket. This allows them to use the tuned kernels of the bucket shapes, which can improve performance.
Models that have multiple inputs sharing a dynamic axis, which is common for batch size or sequence length dynamism, are supported through the input_indices argument.
Many types of dynamism, such as batching, affect both the input and output shape; however, this is not always the case. If the output shape is independent of the input, the affects_output argument of this pass must be set to False.
- Parameters
buckets (list[int]) – The sizes of the input dimension that should be explicitly handled. Each value in buckets will have a corresponding subgraph constructed to handle it.
axis (int) – The dimension of the input that should be made flexible. This will most often be used for the batch dimension.
auto_pad (Optional[bool]) – If True, then padding will be inserted to values that don’t match one of the provided buckets.
pad_value (Optional[float]) – When auto_pad is true, padding will be done with this value.
input_indices (Optional[List[int]]) – Which inputs should be dispatched dynamically, provided by index. All inputs must share the same dynamic axis.
affects_output (Optional[bool]) – Whether the change in input shape has a corresponding effect on the output shape. Batching, for example, affects both the input and output, whereas changing the sequence length in an NLP model typically does not.
- Returns
ret – A pass that can be applied to a module to add flexible shape handling.
- Return type
tvm.transform.Pass
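The dispatch rule described above can be modelled in a few lines of plain Python. In this sketch, select_bucket and pad_rows are hypothetical helpers, not TVM APIs; returning None stands for the fully dynamic fallback, and treating sizes larger than every bucket as non-matching even with auto_pad is an assumption of the sketch.

```python
import bisect


def select_bucket(size, buckets, auto_pad=False):
    """Return the bucket that should handle `size`, or None for the dynamic fallback."""
    ordered = sorted(buckets)
    if size in ordered:
        return size  # exact match: dispatch to the specialized subgraph
    if auto_pad:
        idx = bisect.bisect_left(ordered, size)
        if idx < len(ordered):
            return ordered[idx]  # round up to the closest larger bucket
    return None  # no usable bucket: run with full dynamism


def pad_rows(rows, bucket, pad_value=0):
    """Pad a batch (list of equal-length rows) along axis 0 up to `bucket` rows."""
    width = len(rows[0])
    return rows + [[pad_value] * width for _ in range(bucket - len(rows))]


buckets = [1, 4, 8]
print(select_bucket(4, buckets))                 # 4
print(select_bucket(3, buckets))                 # None
print(select_bucket(3, buckets, auto_pad=True))  # 4
print(pad_rows([[1, 2], [3, 4], [5, 6]], 4))     # [[1, 2], [3, 4], [5, 6], [0, 0]]
```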
- tvm.relay.transform.AlterOpLayout()
Alter the layouts of operators or replace primitive operators with other expressions. This pass can be used for computing convolution in custom layouts or other general weight pre-transformation.
- Returns
ret – The registered pass that alters the layout of operators.
- Return type
tvm.transform.Pass
- tvm.relay.transform.AnnotateSpans()
Annotate a program with span information by first generating its textual representation and then parsing it back into a Relay AST annotated with span information.
- Returns
ret – The registered AnnotateSpans pass.
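- Return type
tvm.transform.Pass
- tvm.relay.transform.AnnotateTarget(targets, include_non_call_ops=True)
Annotate ops in an expression with a provided compiler/target and then use it for codegen.
- Parameters
targets (str or List[str]) – The list of target compilers used for codegen.
include_non_call_ops (bool) – If True, non-call ops are also annotated with targets; if False, non-call ops are not processed.
- Returns
ret – The annotated pass that wraps ops with subgraph_start and subgraph_end.
- Return type
tvm.transform.Pass
- tvm.relay.transform.BackwardFoldScaleAxis()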
Backward fold axis scaling into weights of conv2d/dense.
- Returns
ret – The registered pass to backward fold expressions.
- Return type
tvm.transform.Pass
Note
It is recommended to call backward_fold_scale_axis before using forward_fold_scale_axis as backward folding targets the common conv->bn pattern.
- tvm.relay.transform.BatchingOps()
Batch parallel operators into one for Conv2D, Dense and BatchMatmul.
- Returns
ret – The sequential pass that applies batching for different operator types.
- Return type
tvm.transform.Pass
- tvm.relay.transform.CanonicalizeCast()
Canonicalize cast expressions to make operator fusion more efficient.
- Returns
ret – The registered pass that canonicalizes cast expressions.
- Return type
tvm.transform.Pass
- tvm.relay.transform.CanonicalizeOps()
Canonicalize special operators to basic operators. This can simplify subsequent analysis, e.g. expanding bias_add to expand_dims and broadcast_add.
- Returns
ret – The registered pass performing the canonicalization.
- Return type
tvm.transform.Pass
- tvm.relay.transform.CapturePostDfsIndexInSpans()
Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in their span, in the form “index:<post-dfs index>:<dominator post-dfs index>”.
This is useful for debugging since a) it helps identify pretty-printed sub-expressions within the overall model and b) the indexes are heavily used by Collage for its compact representation of sub-graphs.
Note that Op and Constructor nodes are not changed even though they are assigned a post-dfs index.
- Returns
ret – The pass.
- Return type
tvm.transform.Pass
- class tvm.relay.transform.ChangeBatch(data, batch_size=16)
Change the batch size.
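- Parameters
data (Dict[relay.Var, int]) – A dictionary of all the params to change. The keys are the params, and the values are the dimension that holds the batch.
batch_size (int) – The batch size to change to.
- Returns
pass – The pass.
- Return type
FunctionPass
- tvm.relay.transform.CollagePartition(config, cost_estimator=None)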
Partition the bodies of all functions according to the available targets so as to minimize model latency. See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md.
- Parameters
config (CompilationConfig) – The available targets.
cost_estimator (CostEstimator, optional) – The custom cost estimator to use for costing each candidate partition.
- Returns
ret – The pass.
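- Return type
tvm.transform.Pass
- tvm.relay.transform.CombineParallelBatchMatmul(min_num_branches=3)
Combine multiple batch matmul operators into one. For example, parallel batch_matmul operators that share the same input data would become a single batch_matmul whose weights are the concatenation of the branches' weights.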
- Parameters
min_num_branches (int) – The minimum number of required parallel branches for performing this optimization.
- Returns
ret – The registered pass that combines parallel batch_matmul operators.
- Return type
tvm.transform.Pass
- tvm.relay.transform.CombineParallelConv2D(min_num_branches=3)
Combine multiple conv2d operators into one.
- Parameters
min_num_branches (int) – The minimum number of required parallel branches for performing this optimization.
- Returns
ret – The registered pass that combines parallel conv2d operators.
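- Return type
tvm.transform.Pass
- tvm.relay.transform.CombineParallelDense(min_num_branches=3, to_batch=True)
Combine multiple dense operators into one. For example, parallel dense operators that share the same input data would become a single batch_matmul or, if to_batch=False, a single dense operator whose output concatenates the outputs of all branches.
- Parameters
min_num_branches (int) – The minimum number of required parallel branches for performing this optimization.
to_batch (bool) – If True, combine parallel dense ops into a batch_matmul op. If False, combine them into a dense op.
- Returns
ret – The registered pass that combines parallel dense operators.
- Return type
tvm.transform.Pass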
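The arithmetic behind the to_batch=False form can be checked in plain Python. The sketch assumes the Relay convention that nn.dense(x, w) multiplies x by the transpose of w, with w of shape [units, in_features]; the dense helper is a hypothetical stand-in for that operator, not TVM code.

```python
def dense(x, w):
    """Relay-style dense: y[i][j] = sum_k x[i][k] * w[j][k] (w has shape [units, in])."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


x = [[1, 2], [3, 4]]
w1 = [[1, 0], [0, 1]]
w2 = [[2, 1], [1, 2]]

# Two parallel branches sharing the same input `x`.
y1, y2 = dense(x, w1), dense(x, w2)

# Combined form: one dense over the weights concatenated along the units axis,
# then the output columns are split back into the original branches.
y = dense(x, w1 + w2)  # list concatenation == concat along axis 0
units1 = len(w1)
y1_combined = [row[:units1] for row in y]
y2_combined = [row[units1:] for row in y]

print(y1 == y1_combined and y2 == y2_combined)  # True
```

One larger matrix product replaces several small ones, which is the kind of consolidation the pass aims for; the split step is what keeps each branch's result available to its consumers.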
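- tvm.relay.transform.Conv2dToSparse(weight_name, weight_shape, layout, kernel_size)
Rewrite qualified `nn.conv2d` operations to `nn.sparse_conv2d`.
- tvm.relay.transform.Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold)
Rewrite frozen `nn.conv2d` operations to `nn.sparse_conv2d`.
- Parameters
layout (str) – The layout of the data.
kernel_size (int) – The kernel size of conv2d.
blocksize (Tuple[int, int]) – The block size of the BSR matrix.
sparsity_threshold (float) – The minimal sparsity required for conversion. If the weight sparsity is lower than this threshold, the original operation is kept.
- Returns
ret – The registered pass that rewrites nn.conv2d to nn.sparse_conv2d.
- Return type
tvm.transform.Pass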
- tvm.relay.transform.ConvertLayout(desired_layouts)
Given a destination layout, this pass transforms the expression so that the input data layout of most operators is changed to the destination layout. Ideally, only two layout transforms remain: one at the start and one at the end.
This pass is not part of relay.build and is expected to be called between the framework-to-Relay parser and the relay.build call. This is very helpful for hardware backends that support or prefer only one type of data layout.
RFC - https://discuss.tvm.apache.org/t/layout-conversion-pass/4009
This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. For now, new layouts can be defined only for conv2d ops; most other operators try to adapt to their input layout using the InferCorrectLayout infrastructure.
- Parameters
desired_layouts (map of op_name to list of layouts) – Specify a mapping of operator names to a list of layouts to convert to, in the order defined by the operator. An example for nn.conv2d could be: {"nn.conv2d": ["NHWC", "OHWI"]}, where the first item in the list specifies the data layout and the second specifies the kernel layout.
- Returns
pass – The pass.
- Return type
tvm.transform.Pass
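For simple (non-packed) layouts, each layout transform that ConvertLayout inserts amounts to a transpose whose axis order follows from the two layout strings. The plain-Python sketch below derives that order for the example mapping above; layout_permutation is a hypothetical helper, the NCHW/OIHW source layouts are assumed, and packed layouts such as NCHW4c are not handled.

```python
def layout_permutation(src, dst):
    """Axis order that transposes a tensor from layout `src` to layout `dst`."""
    if sorted(src) != sorted(dst):
        raise ValueError(f"{src} and {dst} do not contain the same axes")
    return [src.index(axis) for axis in dst]


# The desired_layouts mapping from the parameter description above.
desired_layouts = {"nn.conv2d": ["NHWC", "OHWI"]}
data_layout, kernel_layout = desired_layouts["nn.conv2d"]

data_perm = layout_permutation("NCHW", data_layout)
kernel_perm = layout_permutation("OIHW", kernel_layout)
nchw_shape = (1, 3, 224, 224)
nhwc_shape = tuple(nchw_shape[i] for i in data_perm)

print(data_perm, kernel_perm)  # [0, 2, 3, 1] [0, 2, 3, 1]
print(nhwc_shape)              # (1, 224, 224, 3)
```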
- tvm.relay.transform.DeadCodeElimination(inline_once=False, ignore_impurity=False)
Remove expressions that do not have any users (dead code).
- Parameters
inline_once (Optional[Bool]) – Whether to inline a binding that is referenced exactly once.
ignore_impurity (Optional[Bool]) – Whether to ignore possible side-effects in let-bound expressions.
- Returns
ret – The registered pass that eliminates the dead code in a Relay program.
- Return type
tvm.transform.Pass
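The core idea can be modelled on a flat list of let-bindings: walk them backwards, keep a binding only if something later uses it, and treat possibly side-effecting bindings as used unless ignore_impurity is set. This is a simplified plain-Python sketch with a hypothetical helper; it does not model inline_once, nested scopes, or TVM's actual purity analysis.

```python
def eliminate_dead_bindings(bindings, result_uses, ignore_impurity=False):
    """Drop let-bindings whose variable is never used (toy model, not TVM code).

    `bindings` is a list of (name, uses, is_pure) tuples in program order and
    `result_uses` is the set of names the final result refers to.
    """
    live = set(result_uses)
    kept = []
    for name, uses, is_pure in reversed(bindings):
        # An unused impure binding is kept for its side effects unless
        # ignore_impurity says those effects may be disregarded.
        if name in live or (not is_pure and not ignore_impurity):
            kept.append((name, uses, is_pure))
            live |= set(uses)
    return list(reversed(kept))


program = [
    ("a", {"x"}, True),
    ("b", {"a"}, True),        # never used: dead
    ("c", {"x"}, False),       # never used, but impure (e.g. writes a reference)
    ("d", {"a", "x"}, True),
]
kept = [name for name, _, _ in eliminate_dead_bindings(program, {"d"})]
kept_ignoring = [name for name, _, _ in eliminate_dead_bindings(program, {"d"}, ignore_impurity=True)]
print(kept)           # ['a', 'c', 'd']
print(kept_ignoring)  # ['a', 'd']
```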
- tvm.relay.transform.Defunctionalization(func, mod)
Performs defunctionalization on func, transforming func from a higher-order program to a first-order program.
At each call site, the function is cloned and type parameters are substituted in. Function arguments are encoded as datatypes and additional apply functions are used for application.
- Parameters
func (tvm.relay.Function) – The input function, which should not be polymorphic or be higher-order. This is because all types must be known and we can’t encode function arguments to the program itself.
mod (tvm.IRModule) – The IRModule containing function and type definitions, which is also mutated during this pass.
- Returns
expr – The output function.
- Return type
tvm.relay.Function
- tvm.relay.transform.DefuseOps()
The inverse operation of FuseOps. It transforms a fused program returned by FuseOps into the program before FuseOps. (i.e., x == DefuseOps(FuseOps(x)))
- Returns
ret – The registered pass for operator defusion.
- Return type
tvm.transform.Pass
- tvm.relay.transform.DenseToSparse(weight_name, weight_shape)
Rewrite qualified `nn.dense` operations to `nn.sparse_dense`.
This pass is used in `data_dep_optimization.bsr_dense`.
The parameters of this pass are generated by `analysis.sparse_dense.process_params`.
- tvm.relay.transform.DivToMul()
Transform division by a constant into multiplication by the inverse of the constant.
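The rewrite relies on x / c and x * (1 / c) being interchangeable, with 1 / c evaluated once ahead of time; multiplication is usually cheaper than division on most hardware. The plain-Python sketch below (not TVM code) shows that the two forms agree exactly when c is a power of two, and only approximately otherwise, since 1 / c must then be rounded.

```python
import math

x = [1.0, 2.5, -7.0, 10.0]

# Dividing by a power of two: 1/4 is exact in binary, so both forms agree bit for bit.
c = 4.0
inv_c = 1.0 / c  # the inverse is computed once, ahead of time
exact_match = [v / c for v in x] == [v * inv_c for v in x]

# Other constants: 1/3 is rounded, so the two forms may differ in the last bit.
c = 3.0
inv_c = 1.0 / c
divided = [v / c for v in x]
multiplied = [v * inv_c for v in x]
close_match = all(math.isclose(a, b) for a, b in zip(divided, multiplied))

print(exact_match, close_match)  # True True
```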
- tvm.relay.transform.DynamicToStatic()
If possible, convert tvm.relay.dynamic* ops to their static versions.
- Returns
ret – The registered pass for dynamic->static conversion.
- Return type
tvm.transform.Pass
- tvm.relay.transform.EliminateCommonSubexpr(fskip=None)
Eliminate common subexpressions.
- Parameters
fskip (Callable) – The callback function that decides whether an expression should be skipped.
- Returns
ret – The registered pass that eliminates common subexpressions.
- Return type
tvm.transform.Pass
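Common-subexpression elimination can be pictured as flattening an expression into let-style bindings while reusing the variable of any structurally identical computation seen before. The plain-Python sketch below is a toy model with a hypothetical helper, not TVM's implementation; its fskip argument mimics the callback above by keeping selected expressions from being shared.

```python
def cse_to_bindings(expr, fskip=None):
    """Flatten a nested-tuple expression into let-style bindings, sharing duplicates.

    Expressions are tuples ("op", arg, ...) with string leaves. When fskip
    returns True for an expression, that expression is not shared.
    """
    bindings = []  # (var, expression) pairs in evaluation order
    table = {}     # structural key -> var that already computes it

    def visit(e):
        if not isinstance(e, tuple):
            return e
        op, *args = e
        key = (op, *(visit(a) for a in args))
        shareable = fskip is None or not fskip(key)
        if shareable and key in table:
            return table[key]  # reuse the earlier, identical computation
        var = f"%{len(bindings)}"
        bindings.append((var, key))
        if shareable:
            table[key] = var
        return var

    result = visit(expr)
    return bindings, result


expr = ("add", ("mul", "a", "b"), ("mul", "a", "b"))
bindings, result = cse_to_bindings(expr)
print(bindings)  # [('%0', ('mul', 'a', 'b')), ('%1', ('add', '%0', '%0'))]

skipped, _ = cse_to_bindings(expr, fskip=lambda e: e[0] == "mul")
print(len(skipped))  # 3: both multiplications are kept
```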
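- tvm.relay.transform.EtaExpand(expand_constructor=False, expand_global_var=False)
Add abstraction over a constructor or global variable bound to a function.
- Parameters
expand_constructor (bool) – Whether to expand constructors.
expand_global_var (bool) – Whether to expand global variables.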
- Returns
ret – The registered pass that eta expands an expression.
- Return type
tvm.transform.Pass
- tvm.relay.transform.FakeQuantizationToInteger(hard_fail=False, use_qat=False, optional_qnn_ops=None)
Find regions of the graph of the form

    x     w
    |     |
    dq    dq
     \   /
      op1
       |
      op2
       |
       q

where q == qnn.quantize and dq == qnn.dequantize, and rewrite them into integer versions of op1 and op2.
Rules for rewriting individual ops are in fake_quantization_to_integer.py.
- Parameters
hard_fail (boolean) – How to deal with errors during graph rewriting. If true, raise an error. If false, skip rewriting the subgraph.
use_qat (boolean) –
Whether to perform an additional QAT pass that converts enabled operations with dequantized inputs. For example, if op2 in the graph above is not registered with the FakeQuantizationToInteger attribute, op1 can still be converted. The converted pattern is:

    x     w
    |     |
     \   /
      op1
       |
       dq
       |
      op2
       |
       q

optional_qnn_ops (List[str]) – Specify a list of operator names to explicitly enable conversion for specific ops disabled by default. Example: ['nn.softmax']
- Returns
ret – The registered FakeQuantizationToInteger pass.
- Return type
tvm.transform.Pass
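Why such a region can be rewritten into integer arithmetic is easiest to see for a single add. The plain-Python sketch below assumes int8 values and one scale and zero point shared by both inputs and the output, which reduces the integer version to an addition and a subtraction. The helper names are hypothetical, and the real rules in fake_quantization_to_integer.py are more general than this special case.

```python
def quantize(x, scale, zero_point):
    """qnn.quantize-style affine quantization to int8: round, shift, clip."""
    return max(-128, min(127, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """qnn.dequantize-style affine dequantization back to float."""
    return (q - zero_point) * scale


SCALE, ZERO_POINT = 0.5, 3  # assumed to be shared by both inputs and the output


def fake_quantized_add(qa, qb):
    """The dq -> add -> q pattern: the addition itself runs in floating point."""
    total = dequantize(qa, SCALE, ZERO_POINT) + dequantize(qb, SCALE, ZERO_POINT)
    return quantize(total, SCALE, ZERO_POINT)


def integer_add(qa, qb):
    """Integer-only rewrite: with equal scales and zero points,
    (qa - zp) * s + (qb - zp) * s quantizes back to qa + qb - zp."""
    return max(-128, min(127, qa + qb - ZERO_POINT))


matches = all(
    fake_quantized_add(a, b) == integer_add(a, b)
    for a in range(-128, 128)
    for b in range(-128, 128)
)
print(matches)  # True
```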
- tvm.relay.transform.FastMath()
Convert expensive non-linear functions to their fast but approximate counterparts.
- Returns
ret – The registered pass to perform fast math operations.
- Return type
tvm.transform.Pass
- tvm.relay.transform.FirstOrderGradient()
Transforms all global functions in the module to return the original result, paired with the gradients of the inputs. This pass transforms each global function independently and does not support interprocedural AD. Additionally, this pass does not support any control-flow or references, and should only be used on pure data-flow graphs.
- Returns
ret – The registered FirstOrderGradient pass.
- Return type
tvm.transform.Pass
- tvm.relay.transform.FlattenAtrousConv()
The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd operations:

     x      w
     |      |
    s2b     |
      \    /
      conv2d
        |
       b2s

and convert them into subgraphs with a convolution with the modified "dilation" and recalculated "padding" parameters.
- Returns
ret – The registered FlattenAtrousConv pass.
- Return type
tvm.transform.Pass
- tvm.relay.transform.FoldConstant(fold_qnn=False)
Fold the constant expressions in a Relay program.
For backward-compatibility reasons, it skips QNN primitives by default. Some transformation passes, such as FakeQuantizationToInteger, require QNN primitives to be kept for constant subgraphs, and uncontrolled constant folding of QNN primitives may break the applicability of FakeQuantizationToInteger. We suggest using the FoldConstant pass with the non-default value fold_qnn=True only after all other QNN-sensitive passes have been applied.
- Parameters
fold_qnn (bool) – Whether to fold constants for QNN operations.
- Returns
ret – The registered pass for constant folding.
- Return type
tvm.transform.Pass
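Constant folding evaluates, at compile time, every sub-expression whose inputs are all known constants. The plain-Python sketch below is a toy model on a tiny expression language, not TVM code; its hypothetical skip argument leaves selected ops unfolded, similar in spirit to how QNN primitives are left alone when fold_qnn=False.

```python
import operator

OPS = {"add": operator.add, "mul": operator.mul}


def fold_constants(expr, skip=frozenset()):
    """Evaluate every binary sub-expression whose operands are both constants.

    Expressions are numbers (constants), strings (free variables) or tuples
    (op, lhs, rhs). Ops listed in `skip` are never folded.
    """
    if not isinstance(expr, tuple):
        return expr
    op, lhs, rhs = expr
    lhs, rhs = fold_constants(lhs, skip), fold_constants(rhs, skip)
    if op not in skip and isinstance(lhs, (int, float)) and isinstance(rhs, (int, float)):
        return OPS[op](lhs, rhs)  # both operands are known: compute now
    return (op, lhs, rhs)


program = ("add", ("mul", "x", ("add", 2, 3)), ("mul", 4, 5))  # x * (2 + 3) + 4 * 5
print(fold_constants(program))                # ('add', ('mul', 'x', 5), 20)
print(fold_constants(program, skip={"mul"}))  # ('add', ('mul', 'x', 5), ('mul', 4, 5))
```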
- tvm.relay.transform.FoldConstantExpr(expr, mod, fold_qnn=False)
Fold the constant expressions in a Relay program.
- Parameters
expr (Expr) – The expression to fold.
mod (IRModule) – The module the expr lives in (for global calls).
fold_qnn (bool) – Whether to fold constants for QNN operations.
- Returns
new_expr – The expr after Constant Folding
- Return type
Expr
- tvm.relay.transform.FoldExplicitPadding()
FoldExplicitPadding finds explicit padding before an op that can support implicit padding and fuses them.
- Returns
ret – The registered ImplicitPadding pass.
- Return type
tvm.transform.Pass
- tvm.relay.transform.FoldScaleAxis()
Fold the scaling of axis into weights of conv2d/dense. This pass will invoke both forward and backward scale folding.
- Returns
ret – The registered pass to fold expressions.
- Return type
tvm.transform.Pass
Note
Internally, we will call backward_fold_scale_axis before using forward_fold_scale_axis as backward folding targets the common conv->bn pattern.
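For a scale applied after the layer (the backward direction, as in the conv->bn pattern mentioned in the note), folding rests on a simple identity: scaling output channel j of dense(x, w) is the same as using weights whose row j is scaled. The plain-Python sketch below checks this for a dense layer; dense is a hypothetical stand-in for nn.dense, assuming w has shape [units, in_features]. conv2d follows the same reasoning per output channel because it is also linear in its weights.

```python
def dense(x, w):
    """Relay-style dense: y[i][j] = sum_k x[i][k] * w[j][k], with w of shape [units, in]."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


x = [[1.0, 2.0], [3.0, 4.0]]
w = [[0.5, -1.0], [2.0, 1.5]]
scale = [2.0, 0.25]  # one scale per output channel

# Before folding: the dense output is multiplied by a per-channel scale.
before = [[v * s for v, s in zip(row, scale)] for row in dense(x, w)]

# After folding: each weight row (output channel) absorbs its scale, which can
# be precomputed because the weights are constants.
w_folded = [[v * s for v in w_row] for w_row, s in zip(w, scale)]
after = dense(x, w_folded)

print(before == after)  # True (values chosen so the float results are exact)
```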
- tvm.relay.transform.ForwardFoldScaleAxis()
Fold the scaling of axis into weights of conv2d/dense.
- Returns
ret – The registered pass to forward fold expressions.
- Return type
tvm.transform.Pass
Note
It is recommended to call backward_fold_scale_axis before using forward_fold_scale_axis, as backward folding targets the common conv->bn pattern.
- class tvm.relay.transform.FunctionPass
A pass that works on each tvm.relay.Function in a module. A function pass class should be created through function_pass.
- tvm.relay.transform.FuseOps(fuse_opt_level=-1)
Fuse operators in an expr to a larger operator according to some rules.
- Parameters
fuse_opt_level (int) – The level of fuse optimization. -1 indicates that the level will be inferred from pass context.
- Returns
ret – The registered pass for operator fusion.
- Return type
tvm.transform.Pass
- tvm.relay.transform.InferType()
Infer the type of an expr.
- Returns
ret – The registered type inference pass.
- Return type
tvm.transform.Pass
- tvm.relay.transform.InferTypeLocal(expr)
Infer the type of a single expr, reusing type information to do so.
This populates the checked_type field in expr. We assume existing type information in the graph is correct!
- Parameters
expr (relay.Expr) – The expression we want to know the type of
- Returns
type – The type of the expression
- Return type
relay.Type
- tvm.relay.transform.Inline()
Perform inlining on the given Relay IR module. The global functions that are marked as inline should always be inlined. A cost model will be needed in the future to decide whether it is profitable to inline a function.
- Returns
ret – The registered pass that performs inlining for a Relay IR module.
- Return type
tvm.transform.Pass
- tvm.relay.transform.InlineCompilerFunctionsBoundTo(global_vars)
Inlines all global functions bound to a global var in global_vars.
Both the global "Compiler" attributed function and any calls to "Composite" functions in its body are inlined.
This pass may be useful for external codegen which needs to undo partitioning based on properties of the entire partition.
- Parameters
global_vars (Array[tvm.relay.GlobalVar]) – The global vars of all ‘Compiler’ functions to inline.
- Returns
ret – The pass.
- Return type
tvm.transform.Pass
- tvm.relay.transform.LambdaLift()
Lift closures to global functions.
- Returns
ret – The registered pass that lifts the lambda function.
- Return type
tvm.transform.Pass
- class tvm.relay.transform.LayoutConfig(skip_layers=None)
A structure for customizing the ConvertLayout pass.
- tvm.relay.transform.LazyGradientInit()
Reduces memory usage of gradient tensors.
- Returns
ret – A pass which delays and/or reduces memory allocation, by lazily allocating zero- or one-filled tensors.
- Return type
tvm.transform.Pass
- tvm.relay.transform.Legalize(legalize_map_attr_name='FTVMLegalize')
Legalizes an expression with another expression. This pass can be used to replace an expr with another expr for target-dependent optimizations. For example, one expr, though semantically equivalent to the other, can have better performance on a target. This pass can be used to legalize the expr in a target-dependent manner.
- Parameters
legalize_map_attr_name (str) – The Op’s attr name which corresponds to the legalize rule function.
- Returns
ret – The registered pass that rewrites an expr.
- Return type
tvm.transform.Pass
- tvm.relay.transform.ManifestLifetimes()
Manifest the lifetimes of variables after allocations have been manifested, by inserting kill operations once variables become dead.
- tvm.relay.transform.MarkCompilerFunctionsAsExtern(compiler_filter='')
Marks all global functions which have a “Compiler” attribute matching compiler_filter as ‘extern’.
The function’s attributes are replaced with a single “Extern” attribute, and all calls to the function are switched to use the ‘call_lowered’ calling convention.
If compiler_filter is non-empty, only functions with that as their attribute value are marked.
This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to clean up the IRModule after custom lowering.
- Parameters
compiler_filter (String) – If non-empty, the “Compiler” attribute to filter on.
- Returns
ret – The pass.
- Return type
tvm.transform.Pass
- tvm.relay.transform.MergeCompilerRegions()
Merge together compiler regions.
- Returns
ret – The registered pass that merges compiler regions.
- Return type
tvm.transform.Pass
- tvm.relay.transform.MergeComposite(pattern_table)
Merge multiple operators into a single composite relay function.
- Parameters
pattern_table (List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Function]]) – A list of (pattern_name, pattern, check) tuples. The order of the patterns in the list determines the order of priority in which they are matched. 'check' is a function that checks whether an extracted pattern matches. It can be implemented by the pattern writer; if not specified, it always returns True.
- Returns
ret – The registered pass that merges operators into a single composite relay function.
- Return type
tvm.transform.Pass
- tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter='')
Outlines all literal functions in direct call positions which have a “Compiler” attribute.
The outlined functions are bound to unique global vars according to their existing “global_symbol” attribute. At most one function with the same global symbol is outlined.
If compiler_filter is non-empty, only functions with that as their attribute value are outlined.
This pass may be useful for external codegen using the “RelayToTIR” custom pass mechanism to prepare the IRModule before custom lowering.
- Parameters
compiler_filter (String) – If non-empty, the “Compiler” attribute to filter on.
- Returns
ret – The pass.
- Return type
tvm.transform.Pass
- tvm.relay.transform.PartialEvaluate()
Evaluate the static fragment of the code.
Note
This transformation could be either Module -> Module or Expr -> Expr. It will directly transform the input expression to a new one if the target expression is provided. Otherwise, it will rely on the pass manager to carry out transformation.
- Returns
ret – The registered pass that performs partial evaluation on an expression.
- Return type
tvm.transform.Pass
- tvm.relay.transform.PartitionGraph(mod_name='default', bind_constants=True)
Partition a Relay program into regions that can be executed on different backends.
- Parameters
mod_name (string) – Controls the prefix of the name of each partitioned subgraph. If mod_name is None, the tvmgen_ prefix is used. Otherwise, the tvmgen_mod_name_ prefix is used.
bind_constants (bool) – Whether or not to bind constants in partitioned subgraphs. Note that the codegen needs to maintain the bound constants; otherwise, the constants will be maintained by the metadata module. It is therefore recommended for C-source based codegens to set bind_constants=False to avoid embedding large constants in a C source file.
- Returns
ret – The registered pass that partitions the Relay program.
- Return type
tvm.transform.Pass
- tvm.relay.transform.PlanDevices(config)
Uses existing "on_device" and "device_copy" calls to infer the virtual device on which every Relay sub-expression should run and the result stored. Captures the result of that analysis using new "on_device" and "device_copy" calls. Sub-expressions which are not otherwise constrained are assigned to the default primitive virtual device described by config. However, data and computations which must be hosted on a CPU (such as shapes and shape functions) use the host virtual device of the config.
- Parameters
config (tvm.CompilationConfig) – The compilation configuration, specifying available targets and default devices.
- Returns
ret – The pass.
- Return type
tvm.transform.Pass
- tvm.relay.transform.RemoveUnusedFunctions(entry_functions=None)
Remove unused global relay functions in a relay module.
- Parameters
entry_functions (list[string]) – The set of entry functions to start from.
- Returns
ret – The registered pass to remove unused functions.
- Return type
tvm.transform.Pass
- tvm.relay.transform.SimplifyExpr()
Simplify the Relay expression, including merging consecutive reshapes.
- Returns
ret – The registered SimplifyExpr pass.
- Return type
tvm.transform.Pass
- tvm.relay.transform.SimplifyFCTranspose(target_weight_name)
Rewrite `y = nn.dense(x, transpose(w, [1, 0]))` to `y = nn.dense(x, wt)`.
This pass is used in `data_dep_optimization.simplify_fc_transpose`.
- tvm.relay.transform.SimplifyInference()
Simplify the data-flow graph for the inference phase. A simplified expression which is semantically equal to the input expression is returned.
Note that batch norms will only be simplified if their result is indexed at tuple index 0.
- Returns
ret – The registered pass to perform operator simplification.
- Return type
tvm.transform.Pass
- tvm.relay.transform.SplitArgs(max_function_args)
Split functions with a huge number of arguments into smaller pieces.
- Parameters
max_function_args (int) – Maximum number of function arguments. If it equals 0 then SplitArgs shouldn’t split the function.
- Returns
ret – The registered pass.
- Return type
tvm.transform.Pass
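The chunking arithmetic implied by max_function_args can be sketched in plain Python. This is a simplified model under stated assumptions: split_args is a hypothetical helper, and which operators the real pass splits, and how the pieces are recombined into an equivalent program, are not modelled.

```python
def split_args(args, max_function_args):
    """Chunk an argument list so no piece has more than max_function_args entries.

    A limit of 0 means "do not split", as described above.
    """
    if max_function_args == 0:
        return [list(args)]
    return [list(args[i:i + max_function_args]) for i in range(0, len(args), max_function_args)]


args = [f"p{i}" for i in range(7)]
print(split_args(args, 3))  # [['p0', 'p1', 'p2'], ['p3', 'p4', 'p5'], ['p6']]
print(split_args(args, 0))  # [['p0', 'p1', 'p2', 'p3', 'p4', 'p5', 'p6']]
```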
- tvm.relay.transform.ToANormalForm()
Turn a Graph Normal Form expression into an A Normal Form expression. The scope of the root expression is the global scope. The scope of any non-root expression is the least common ancestor of all the scopes in which it is used. Values are ordered by post-DFS order in each scope.
- Returns
ret – The registered pass that transforms an expression into A Normal Form.
- Return type
Union[tvm.transform.Pass, tvm.relay.Expr]
- tvm.relay.transform.ToANormalFormExpr(e)
ToANormalForm, but on expression level.
- Parameters
e (Expr) – The graph expression.
- Returns
ret – The transformed expression.
- Return type
Expr
- tvm.relay.transform.ToBasicBlockNormalForm()
Turn an expression into Basic Block Normal Form. We define a block as a group of expressions implied by the scope structure. Each graph node can only belong to a single block. Any value that is used in multiple blocks has to be referred to by a Var which is defined in a block whose scope is the least common ancestor of the blocks in which this value is used.
- Returns
ret – The registered pass that transforms an expression into Basic Block Normal Form.
- Return type
tvm.transform.Pass
- tvm.relay.transform.ToCPS(expr, mod=None)
Turn an expression into continuation-passing style (CPS).
Every intermediate computation is passed to a continuation.
- Returns
result – The registered pass that transforms an expression into CPS.
- Return type
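The idea of passing every intermediate computation to a continuation can be made concrete with a small plain-Python analogy, independent of Relay: in CPS, each function takes an extra continuation argument k and hands its result to k instead of returning it. The function names are illustrative only.

```python
def add_cps(a, b, k):
    """Addition in CPS: instead of returning a + b, pass it to the continuation k."""
    return k(a + b)


def mul_cps(a, b, k):
    """Multiplication in CPS."""
    return k(a * b)


def program_direct(x, y):
    return (x + y) * y  # direct style: intermediate results are returned


def program_cps(x, y, k):
    # CPS: the intermediate sum s is handed to a continuation that performs the
    # multiplication and then passes the final result on to k.
    return add_cps(x, y, lambda s: mul_cps(s, y, k))


print(program_direct(2, 3))            # 15
print(program_cps(2, 3, lambda r: r))  # 15 (identity continuation)
```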
- tvm.relay.transform.ToGraphNormalForm()
Turn a Relay program in A Normal Form into Graph Normal Form.
- Returns
ret – The registered pass that transforms an expression into Graph Normal Form.
- Return type
tvm.transform.Pass
- tvm.relay.transform.ToMixedPrecision(mixed_precision_type='float16', missing_op_mode=1)
Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version where as many operations as possible are in the target mixed_precision_type.
- Parameters
mixed_precision_type (str) – The target datatype to transform operations in the graph to use.
missing_op_mode (int) – Determines how to handle ops not registered with FTVMMixedPrecisionConversionType:
0: Does not allow any missing ops. Throws an error when encountering any.
1: Allows missing ops but emits warnings.
2: Allows missing ops and silently ignores them.
relay.ToMixedPrecision.keep_orig_output_dtype (boolean) – Defines whether outputs should be retained in the original data type or converted to mixed_precision_type. By default this parameter is False, and the transformation modifies the data types of outputs to mixed_precision_type. This parameter is not an explicit argument of the transformation, but should be passed through tvm.transform.PassContext.
- Returns
ret – The registered pass.
- Return type
tvm.transform.Pass
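The three missing_op_mode values amount to a reporting policy. The sketch below illustrates only that policy and performs no dtype rewriting. The function `check_missing_ops` and the `registered` set are hypothetical stand-ins for the FTVMMixedPrecisionConversionType registry, not TVM APIs.

```python
import warnings


def check_missing_ops(ops, registered, missing_op_mode=1):
    """Report ops that have no mixed-precision conversion rule.

    ops:        op names appearing in the graph
    registered: op names with a (hypothetical) conversion rule
    Returns the list of missing op names.
    """
    missing = [op for op in ops if op not in registered]
    if missing_op_mode not in (0, 1, 2):
        raise ValueError(f"invalid missing_op_mode: {missing_op_mode}")
    if missing and missing_op_mode == 0:
        # mode 0: any missing op is an error
        raise ValueError(f"ops without a conversion rule: {missing}")
    if missing and missing_op_mode == 1:
        # mode 1: tolerate missing ops, but warn
        warnings.warn(f"ops without a conversion rule: {missing}")
    # mode 2: tolerate missing ops silently
    return missing


print(check_missing_ops(["conv2d", "my_op"], {"conv2d"}, missing_op_mode=2))  # ['my_op']
```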
- tvm.relay.transform.build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None)¶
Configure the build behavior by setting config variables. This function will be deprecated in TVM v0.7; use tvm.transform.PassContext directly instead.
- Parameters
opt_level (int, optional) –
Optimization level. The optimization passes and their levels are as follows:
```python
OPT_PASS_LEVEL = {
    "SimplifyInference": 0,
    "OpFusion": 1,
    "FoldConstant": 2,
    "FoldScaleAxis": 3,
    "AlterOpLayout": 3,
    "CanonicalizeOps": 3,
    "CanonicalizeCast": 3,
    "EliminateCommonSubexpr": 3,
    "CombineParallelConv2D": 4,
    "CombineParallelDense": 4,
    "CombineParallelBatchMatmul": 4,
    "FastMath": 4,
}
```
required_pass (set of str, optional) – Optimization passes that are required regardless of optimization level.
disabled_pass (set of str, optional) – Optimization passes to be disabled during optimization.
trace (Callable[[IRModule, PassInfo, bool], None]) – A tracing function for debugging or introspection.
- Returns
pass_context – The pass context for optimizations.
- Return type
tvm.transform.PassContext
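The interaction of opt_level, required_pass and disabled_pass can be sketched with the OPT_PASS_LEVEL table. The rule encoded here is our assumption about how the pass context gates passes. A disabled pass never runs. A required pass always runs. Any other pass runs when its level is at most opt_level. The helper `enabled_passes` is hypothetical.

```python
OPT_PASS_LEVEL = {
    "SimplifyInference": 0,
    "OpFusion": 1,
    "FoldConstant": 2,
    "FoldScaleAxis": 3,
    "AlterOpLayout": 3,
    "CanonicalizeOps": 3,
    "CanonicalizeCast": 3,
    "EliminateCommonSubexpr": 3,
    "CombineParallelConv2D": 4,
    "CombineParallelDense": 4,
    "CombineParallelBatchMatmul": 4,
    "FastMath": 4,
}


def enabled_passes(opt_level=2, required_pass=None, disabled_pass=None):
    """Return the passes that would run, in table order (assumed rule)."""
    required = set(required_pass or ())
    disabled = set(disabled_pass or ())
    enabled = []
    for name, level in OPT_PASS_LEVEL.items():
        if name in disabled:  # disabling takes precedence
            continue
        if name in required or level <= opt_level:
            enabled.append(name)
    return enabled


print(enabled_passes(opt_level=2))
# ['SimplifyInference', 'OpFusion', 'FoldConstant']
print(enabled_passes(opt_level=2, required_pass={"FastMath"}, disabled_pass={"FoldConstant"}))
# ['SimplifyInference', 'OpFusion', 'FastMath']
```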
- tvm.relay.transform.function_pass(pass_func=None, opt_level=None, name=None, required=None)¶
Decorate a function pass.
This function returns a decorator when pass_func is not provided. Otherwise, it returns the function pass created from the given optimization function.
- Parameters
pass_func (Optional[Callable[(Function, Module, PassContext) -> Function]]) – The transformation function or class.
opt_level (int) – The optimization level of this function pass.
name (Optional[str]) – The name of the function pass. The name can be empty, in which case the name of the optimization function is used as the pass name.
required (Optional[List[str]]) – The list of passes that the function pass depends on.
- Returns
create_function_pass – A decorator is returned if pass_func is not provided; otherwise, the decorated result is returned. The returned decorator behaves differently depending on its input: decorating a pass function returns a new FunctionPass, and decorating a class type returns a new FunctionPass class.
- Return type
Union[Callable, FunctionPass]
Examples
The following code block decorates a function pass class.
```python
import tvm
from tvm import relay


@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func


x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function with f1
fpass = TestReplaceFunc(f1)
# build a module whose main function is f2
input_mod = tvm.IRModule.from_expr(f2)
# now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)
```
The following code creates a function pass by decorating a user-defined transform function.
```python
import tvm
from tvm import relay


@relay.transform.function_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func


function_pass = transform
assert isinstance(function_pass, relay.transform.FunctionPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as follows:
x = relay.var("x", shape=(10, 20))
m = tvm.IRModule.from_expr(relay.Function([x], relay.log(x)))
updated_mod = function_pass(m)
# The transformation has now been applied to every function in
# the provided module m, and the updated module is returned.
```
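The two call forms described above (with and without pass_func) follow a common Python pattern for decorators with optional arguments. The toy below mimics only that pattern. It operates on a plain dict of Python callables rather than an IRModule. The names `function_pass_sketch` and `add_one` are hypothetical and not part of TVM.

```python
def function_pass_sketch(pass_func=None, opt_level=None, name=None):
    """Mimic the two call forms: with or without pass_func."""
    def create(func):
        # toy "pass": apply func to every function in a dict-based module
        def run(module):
            return {gv: func(fn, module, None) for gv, fn in module.items()}
        run.name = name or func.__name__
        run.opt_level = opt_level
        return run

    if pass_func is not None:
        return create(pass_func)  # called directly: build the pass now
    return create  # used as @decorator(...): return the decorator


@function_pass_sketch(opt_level=2)
def add_one(fn, mod, ctx):
    # wrap each function so that its result is incremented
    return lambda x: fn(x) + 1


toy_mod = {"main": lambda x: 2 * x}
print(add_one.name, add_one.opt_level)  # add_one 2
print(add_one(toy_mod)["main"](3))      # 7
```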
- tvm.relay.transform.gradient(expr, mod=None, mode='higher_order')¶
Transform the input function, returning a function that calculates the original result paired with the gradient of the input.
- Parameters
expr (tvm.relay.Expr) – The input expression, which is a Function or a GlobalVar.
mod (Optional[tvm.IRModule]) – The global module.
mode (Optional[str]) – The mode of the automatic differentiation algorithm. ‘first_order’ only works on first-order code but does not produce references or closures. ‘higher_order’ works on all code, using references and closures.
- Returns
expr – The transformed expression.
- Return type
tvm.relay.Expr
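The shape of the result can be illustrated with a tiny reverse-mode differentiator over scalars: the original value paired with the gradient with respect to each input. This is a pure-Python sketch with hypothetical names (`Node`, `backprop`, `grad_pair`). It supports only `+` and `*` and does not reflect how Relay's gradient pass is implemented.

```python
class Node:
    """A scalar value that records how it was computed."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (input node, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        return Node(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Node(self.value * other.value, ((self, other.value), (other, self.value)))


def backprop(out):
    # order nodes so that each node comes after all of its inputs
    order, seen = [], set()

    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(out)
    out.grad = 1.0
    # walk backwards, accumulating d(out)/d(node) into each input
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += local * node.grad


def grad_pair(f):
    """Return a function computing (f(*args), (df/darg_0, ...))."""
    def wrapped(*args):
        inputs = [Node(float(a)) for a in args]
        out = f(*inputs)
        backprop(out)
        return out.value, tuple(node.grad for node in inputs)
    return wrapped


# f(x, y) = x * y + y, so df/dx = y and df/dy = x + 1
print(grad_pair(lambda x, y: x * y + y)(3.0, 4.0))  # (16.0, (4.0, 4.0))
```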
- tvm.relay.transform.to_cps(func, mod=None)¶
Turn an expression into a CPS expression.
Every intermediate computation is passed to a continuation.
- Parameters
func (tvm.relay.Function) – The input function.
mod (Optional[tvm.IRModule]) – The global module.
- Returns
result – The output function.
- Return type
tvm.relay.Function
- tvm.relay.transform.un_cps(func)¶
Turn a CPS function into a Function without the continuation argument.
- Note that this does not restore exactly the same interface as before the CPS conversion:
If the input or output is higher-order, it will still be in CPS form.
- Parameters
func (tvm.relay.Function) – The input function.
- Returns
result – The output function.
- Return type
tvm.relay.Function
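A minimal pure-Python sketch of what un_cps does conceptually, using hypothetical helper names. The continuation argument is filled in with the identity function. The second call illustrates the note above: a higher-order argument (`inc_cps`) must still be supplied in CPS form.

```python
def un_cps_sketch(f_cps):
    """Wrap a CPS function so it can be called in direct style."""
    def direct(*args):
        # supply the identity continuation and return whatever it receives
        return f_cps(*args, lambda result: result)
    return direct


def square_cps(x, k):
    return k(x * x)


def inc_cps(x, k):
    return k(x + 1)


def apply_twice_cps(f_cps, x, k):
    # f_cps is a function argument, and it is itself in CPS form
    return f_cps(x, lambda y: f_cps(y, k))


print(un_cps_sketch(square_cps)(4))                # 16
print(un_cps_sketch(apply_twice_cps)(inc_cps, 1))  # 3
```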