tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_TRANSFORM_H_
25 #define TVM_RELAX_TRANSFORM_H_
26 
27 #include <tvm/ffi/reflection/registry.h>
28 #include <tvm/ir/transform.h>
30 #include <tvm/relax/expr.h>
31 #include <tvm/tir/function.h>
32 #include <tvm/tir/index_map.h>
33 
34 namespace tvm {
35 namespace relax {
36 namespace transform {
37 
44 
56 TVM_DLL Pass CreateFunctionPass(std::function<Function(Function, IRModule, PassContext)> pass_func,
57  int opt_level, ffi::String name,
58  tvm::ffi::Array<ffi::String> required, bool traceable = false);  // Create a function pass from pass_func; traceable defaults to false.
59 
72  std::function<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level,
73  ffi::String name, tvm::ffi::Array<ffi::String> required, bool traceable = false);
74 
80 TVM_DLL Pass LambdaLift();  // Perform lambda lifting: lift nested functions into global functions.
81 
87 TVM_DLL Pass ToNonDataflow();  // Transform all dataflow structure to its non-dataflow version.
88 
101 
108 
122 
145 
152 
159 TVM_DLL Pass Normalize();  // Transform Relax IR to normal form (A-normal form) and fill in struct_info_.
160 
169 
182 
190 TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);  // NOTE(review): call_only presumably restricts elimination to call expressions — confirm.
191 
200 TVM_DLL Pass BindParams(ffi::String func_name, ffi::Map<Any, ObjectRef> params);  // Bind params of a module function (presumably the one named func_name) to constant tensors.
201 
217 TVM_DLL Pass BindSymbolicVars(ffi::Map<ffi::Variant<tir::Var, ffi::String>, PrimExpr> binding_map,
218  ffi::Optional<ffi::String> func_name = std::nullopt);  // Bind symbolic vars to constant shape values; func_name is optional.
219 
227 TVM_DLL Pass FoldConstant();  // Fold constant expressions within dataflow blocks.
228 
253 TVM_DLL Pass LegalizeOps(ffi::Optional<ffi::Map<ffi::String, ffi::Function>> cmap,
254  ffi::Optional<ffi::Array<ffi::String>> skip_ops,
255  bool enable_warning = false);  // Legalize high-level operator calls in Relax functions to call_tir with low-level TIR implementations.
256 
262 
274 
283 
310 TVM_DLL Pass
311 LiftTransformParams(ffi::Variant<Bool, ffi::Array<ffi::String>> shared_transform = Bool(false));  // Lift transformation of the parameters of a function.
312 
319 TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);  // Update virtual device.
320 
326 
332 
338 
347 
359 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);  // Group bindings in dataflow blocks of Relax functions into new grouped Relax functions.
360 
366 class FusionPatternNode : public Object {  // Pattern object used as the input of FuseOpsByPattern.
367  public:
372  ffi::String name;  // Pattern name; becomes the value of the kComposite attribute of the fused function.
373 
379 
384  ffi::Map<ffi::String, DFPattern> annotation_patterns;  // Used to extract important expressions from the pattern match result.
385 
393  ffi::Optional<ffi::Function> check;  // Decides whether a match result is accepted; may be std::nullopt.
394 
401  ffi::Optional<ffi::Function> attrs_getter;  // Gets attributes for the fused function.
402 
403  static void RegisterReflection() {  // Registers every field via def_ro; `pattern` is declared at transform.h:378 (not shown in this listing).
404  namespace refl = tvm::ffi::reflection;
405  refl::ObjectDef<FusionPatternNode>()
406  .def_ro("name", &FusionPatternNode::name)
407  .def_ro("pattern", &FusionPatternNode::pattern)
408  .def_ro("annotation_patterns", &FusionPatternNode::annotation_patterns)
409  .def_ro("check", &FusionPatternNode::check)
410  .def_ro("attrs_getter", &FusionPatternNode::attrs_getter);
411  }
412  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.FusionPattern", FusionPatternNode, Object);
413 };
414 
415 class FusionPattern : public ObjectRef {  // Reference type for FusionPatternNode.
416  public:
417  FusionPattern(ffi::String name, DFPattern pattern,
418  ffi::Map<ffi::String, DFPattern> annotation_patterns,
419  ffi::Optional<ffi::Function> check, ffi::Optional<ffi::Function> attrs_getter);
420 
421  FusionPattern(ffi::String name, DFPattern pattern)  // Delegates with no annotation patterns, check, or attrs_getter.
422  : FusionPattern(name, pattern, {}, std::nullopt, std::nullopt) {}
423 
425 };
426 
430 class PatternCheckContextNode : public Object {  // The input of FusionPattern::check.
431  public:
436 
441  ffi::Map<ffi::String, Expr> annotated_expr;  // Expressions matched by the sub patterns in FusionPattern::annotation_patterns.
442 
447  ffi::Map<Var, Expr> matched_bindings;  // Variable -> value, for bindings being fused by FuseOpsByPattern.
448 
453  ffi::Map<Var, ffi::Array<Var>> var_usages;  // Variable definition -> its uses; covers all variables used in the function.
454 
459  ffi::Map<Expr, Var> value_to_bound_var;  // Value -> bound variable; excludes variables after the matched expression.
460 
461  static void RegisterReflection() {  // Registers every field via def_ro; `matched_expr` is declared at transform.h:435 (not shown in this listing).
462  namespace refl = tvm::ffi::reflection;
463  refl::ObjectDef<PatternCheckContextNode>()
464  .def_ro("matched_expr", &PatternCheckContextNode::matched_expr)
465  .def_ro("annotated_expr", &PatternCheckContextNode::annotated_expr)
466  .def_ro("matched_bindings", &PatternCheckContextNode::matched_bindings)
467  .def_ro("var_usages", &PatternCheckContextNode::var_usages)
468  .def_ro("value_to_bound_var", &PatternCheckContextNode::value_to_bound_var);
469  }
470  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.PatternCheckContext", PatternCheckContextNode,
471  Object);
472 };
473 
474 class PatternCheckContext : public ObjectRef {  // Reference type for PatternCheckContextNode.
475  public:
476  PatternCheckContext(Expr matched_expr, ffi::Map<ffi::String, Expr> annotated_expr,
477  ffi::Map<Var, Expr> matched_bindings,
478  ffi::Map<Var, ffi::Array<Var>> var_usages,
479  ffi::Map<Expr, Var> value_to_bound_var);  // One argument per PatternCheckContextNode field.
480 
483 };
484 
510 TVM_DLL Pass Gradient(ffi::String func_name,
511  ffi::Optional<ffi::Array<Var>> require_grads = std::nullopt,
512  int target_index = 0);  // Reverse-mode automatic differentiation.
513 
534 TVM_DLL Pass FuseOpsByPattern(const tvm::ffi::Array<FusionPattern>& patterns,
535  bool bind_constants = true, bool annotate_codegen = false,
536  const tvm::ffi::Array<ffi::String>& entry_function_names = {});  // Pattern-match each function in the module and group matched expressions.
537 
546 
553 TVM_DLL Pass FuseTIR();  // Fuse Relax sub-functions into larger TIR functions where possible; works together with FuseOps.
554 
561 TVM_DLL Pass
562 RunCodegen(ffi::Optional<ffi::Map<ffi::String, ffi::Map<ffi::String, ffi::Any>>> target_options,
563  ffi::Array<ffi::String> entry_functions);  // Run codegen; target_options is optional.
564 
573 TVM_DLL Pass DecomposeOpsForInference(ffi::Optional<ffi::String> func_name);  // Decompose composite operators (e.g. batch norm) for inference.
574 
583 TVM_DLL Pass DecomposeOpsForTraining(ffi::Optional<ffi::String> func_name);  // Decompose composite operators (e.g. batch norm) for training.
584 
600  const ffi::Map<ffi::String, tir::PrimFunc>& op_impl_map,
601  const ffi::Map<ffi::String, ffi::Array<tir::IndexMap>>& op_buffer_transforms,
602  const ffi::Map<ffi::String, ffi::Optional<ffi::Array<ffi::Array<IntImm>>>>& axis_separators,
603  const ffi::Map<ffi::String, ffi::Optional<ffi::Array<ffi::Array<IntImm>>>>&
604  input_axis_separators);
605 
612 TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts);  // Layout conversion pass driven by desired_layouts.
613 
621 TVM_DLL Pass ConvertToDataflow(int min_size = 2);  // Convert consecutive dataflow operations in binding blocks into dataflow blocks; NOTE(review): min_size presumably is the minimum run length — confirm.
622 
639 TVM_DLL Pass DeadCodeElimination(ffi::Array<ffi::String> entry_functions = {});  // Dead code elimination; entry_functions defaults to empty.
640 
651 
662 TVM_DLL Pass
663 ToMixedPrecision(const DataType& out_dtype,
664  ffi::Optional<ffi::Array<ffi::String>> fp16_input_names = std::nullopt);  // Automatic mixed precision pass; currently assumes an fp32-only input module.
665 
672 
683 TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark);  // Few-shot tuning for static-shape PrimFuncs.
684 
691 
692 } // namespace transform
693 } // namespace relax
694 } // namespace tvm
695 
696 #endif // TVM_RELAX_TRANSFORM_H_
Boolean constant.
Definition: expr.h:565
Managed reference class to IRModuleNode.
Definition: module.h:256
Reference to PrimExprNode.
Definition: expr.h:124
Managed reference to RelaxExprNode.
Definition: expr.h:439
Managed reference to VDeviceNode.
Definition: global_info.h:87
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:101
Definition: expr.h:697
Definition: expr.h:834
Definition: expr.h:381
The pattern object used as the input of FuseOpsByPattern. For bindings to be fused,...
Definition: transform.h:366
ffi::Optional< ffi::Function > attrs_getter
The function to get attributes for fused function.
Definition: transform.h:401
ffi::Optional< ffi::Function > check
The function to determine whether the match result is accepted. This can be std::nullopt if check fun...
Definition: transform.h:393
ffi::String name
The name of pattern. It becomes the value of the kComposite attribute of a fused function after succe...
Definition: transform.h:372
ffi::Map< ffi::String, DFPattern > annotation_patterns
The map which is used to extract important expressions from the pattern match result....
Definition: transform.h:384
DFPattern pattern
The dataflow pattern that will be used to match expression in the DataflowBlock. All the call nodes c...
Definition: transform.h:378
static void RegisterReflection()
Definition: transform.h:403
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.FusionPattern", FusionPatternNode, Object)
Definition: transform.h:415
FusionPattern(ffi::String name, DFPattern pattern)
Definition: transform.h:421
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FusionPattern, ObjectRef, FusionPatternNode)
FusionPattern(ffi::String name, DFPattern pattern, ffi::Map< ffi::String, DFPattern > annotation_patterns, ffi::Optional< ffi::Function > check, ffi::Optional< ffi::Function > attrs_getter)
The input of FusionPattern::check.
Definition: transform.h:430
ffi::Map< Expr, Var > value_to_bound_var
Map from value to its bound variable. It doesn't have variables after the matched expression.
Definition: transform.h:459
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.PatternCheckContext", PatternCheckContextNode, Object)
ffi::Map< Var, Expr > matched_bindings
Map from variable to its value. It contains variables from bindings that are being fused by FuseOpsByP...
Definition: transform.h:447
static void RegisterReflection()
Definition: transform.h:461
ffi::Map< ffi::String, Expr > annotated_expr
A map which contains all expressions matched by the sub patterns in FusionPattern::annotation_pattern...
Definition: transform.h:441
Expr matched_expr
The expression that's matched with the FusionPattern::pattern.
Definition: transform.h:435
ffi::Map< Var, ffi::Array< Var > > var_usages
A map mapping variable definitions to a set of uses. It has all variables used in the function.
Definition: transform.h:453
PatternCheckContext(Expr matched_expr, ffi::Map< ffi::String, Expr > annotated_expr, ffi::Map< Var, Expr > matched_bindings, ffi::Map< Var, ffi::Array< Var >> var_usages, ffi::Map< Expr, Var > value_to_bound_var)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PatternCheckContext, ObjectRef, PatternCheckContextNode)
Runtime primitive data type.
Definition: data_type.h:47
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Managed reference class for PassInfoNode.
Definition: transform.h:350
Definition: transform.h:400
A pattern language for matching dataflow properties.
Defines a remapping of buffer indices.
Definition: repr_printer.h:91
Pass RealizeVDevice()
Propagate virtual device information.
Pass Gradient(ffi::String func_name, ffi::Optional< ffi::Array< Var >> require_grads=std::nullopt, int target_index=0)
Reverse-mode automatic differentiation.
Pass CreateFunctionPass(std::function< Function(Function, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Create a function pass.
Pass MergeCompositeFunctions()
Group one or multiple composite functions created by FuseOpsByPattern into a new function....
tvm::relax::DataflowBlock DataflowBlock
Definition: transform.h:42
Pass DataflowUseInplaceCalls()
Pass that changes calls to operators that can be done in-place (generally, these are elementwise oper...
Pass FuseOpsByPattern(const tvm::ffi::Array< FusionPattern > &patterns, bool bind_constants=true, bool annotate_codegen=false, const tvm::ffi::Array< ffi::String > &entry_function_names={})
Apply pattern matching to each function in the given module, and group matched expressions into a new...
Pass ToMixedPrecision(const DataType &out_dtype, ffi::Optional< ffi::Array< ffi::String >> fp16_input_names=std::nullopt)
Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only,...
Pass CallTIRRewrite()
Perform explicit tensor allocation for call_tir and call_dps_packed.
Pass BindParams(ffi::String func_name, ffi::Map< Any, ObjectRef > params)
Bind params of function of the module to constant tensors.
Pass Normalize()
Transform Relax IR to normal form: transform AST to A-normal form, and fill the struct_info_ of expre...
tvm::relax::Function Function
Definition: transform.h:41
Pass LiftTransformParams(ffi::Variant< Bool, ffi::Array< ffi::String >> shared_transform=Bool(false))
Lift transformation of the parameters of a function.
Pass RemoveUnusedParameters()
Remove unused parameters to internal functions.
Pass SpecializePrimFuncBasedOnCallSite()
This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. Primarily used to u...
Pass FuseTIR()
Fuse Relax sub-function into a larger TIR function if possible. This pass works together with FuseOps...
Pass FewShotTuning(int valid_count, bool benchmark)
The pass is designed for few shot tuning for static shape PrimFuncs. It examines all the blocks withi...
Pass LegalizeOps(ffi::Optional< ffi::Map< ffi::String, ffi::Function >> cmap, ffi::Optional< ffi::Array< ffi::String >> skip_ops, bool enable_warning=false)
Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR Pr...
Pass BindSymbolicVars(ffi::Map< ffi::Variant< tir::Var, ffi::String >, PrimExpr > binding_map, ffi::Optional< ffi::String > func_name=std::nullopt)
Bind symbolic vars to constant shape values.
Pass ConvertLayout(ffi::Map< ffi::String, ffi::Array< ffi::String >> desired_layouts)
Layout conversion pass.
Pass SplitLayoutRewritePreproc()
Split the layout rewrite preproc block to a separate tir::PrimFunc.
Pass AttachGlobalSymbol()
Attach global_symbol to Relax functions and TIR Primfuncs for codegen.
Pass FuseOps(int fuse_opt_level=-1)
This pass groups bindings in a dataflow block of Relax functions and generates a new grouped Relax fu...
Pass AnnotateTIROpPattern()
Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
Pass RemovePurityChecking()
Activate force_pure on all pure functions in the module and unwrap all pure override ops into the nor...
Pass AttachAttrLayoutFreeBuffers()
Attach layout free buffers to the tir::PrimFunc.
tvm::transform::PassContext PassContext
Definition: transform.h:40
Pass NormalizeGlobalVar()
Possibly rename the GlobalVar in an IRModule to ensure these properties:
tvm::transform::PassInfo PassInfo
Definition: transform.h:39
Pass RunCodegen(ffi::Optional< ffi::Map< ffi::String, ffi::Map< ffi::String, ffi::Any >>> target_options, ffi::Array< ffi::String > entry_functions)
Run codegen.
Pass RemoveUnusedOutputs()
Remove unused outputs from internal functions.
Pass ExpandTupleArguments()
Expand tuple arguments to internal functions.
Pass FoldConstant()
Fold constant expressions within dataflow blocks.
tvm::transform::Pass Pass
Definition: transform.h:38
Pass CanonicalizeBindings()
Simplify a Relax module by folding var bindings and match shape nodes, as well as tuple indices....
Pass RewriteDataflowReshape()
Convert all reshape-like call_tir whose corresponding binding vars are DataflowVars to relax....
Pass DecomposeOpsForTraining(ffi::Optional< ffi::String > func_name)
Decompose composite operators during training. For example, the result of batch norm (a triple) will ...
Pass LambdaLift()
Perform lambda lifting to lift functions from nested into global.
Pass CreateDataflowBlockPass(std::function< DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Create a dataflowblock pass.
Pass RewriteCUDAGraph()
Rewrite a Relax module for executing with CUDA graph. This pass identifies the regions that can be ex...
Pass DecomposeOpsForInference(ffi::Optional< ffi::String > func_name)
Decompose composite operators during inference. For example, the result of batch norm (a triple) will...
Pass ConvertToDataflow(int min_size=2)
A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks.
Pass EliminateCommonSubexpr(bool call_only=false)
Pass UpdateVDevice(VDevice new_vdevice, int64_t index)
Update virtual device.
Pass AlterOpImpl(const ffi::Map< ffi::String, tir::PrimFunc > &op_impl_map, const ffi::Map< ffi::String, ffi::Array< tir::IndexMap >> &op_buffer_transforms, const ffi::Map< ffi::String, ffi::Optional< ffi::Array< ffi::Array< IntImm >>>> &axis_separators, const ffi::Map< ffi::String, ffi::Optional< ffi::Array< ffi::Array< IntImm >>>> &input_axis_separators)
Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in op_impl_map,...
Pass StaticPlanBlockMemory()
The static memory planning pass on BindingBlock level. The pass will reuse allocated memory to its be...
Pass ToNonDataflow()
Transform all dataflow structure to non-dataflow version.
Pass DeadCodeElimination(ffi::Array< ffi::String > entry_functions={})
Dead code elimination.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1103
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, ffi::String name, ffi::Array< ffi::String > required, bool traceable=false)
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
TIR Function.