tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_TRANSFORM_H_
25 #define TVM_RELAX_TRANSFORM_H_
26 
27 #include <tvm/ffi/reflection/registry.h>
28 #include <tvm/ir/transform.h>
30 #include <tvm/relax/expr.h>
31 #include <tvm/tirx/function.h>
32 #include <tvm/tirx/index_map.h>
33 
34 namespace tvm {
35 namespace relax {
36 namespace transform {
37 
/*!
 * \brief Callback type that takes a Call and returns a map from string keys to arrays of
 *        strings. ConvertLayout accepts a value of this type as its layout_cb argument.
 * \note NOTE(review): presumably the returned map holds the desired layouts for the given call,
 *       mirroring ConvertLayout's desired_layouts parameter -- confirm against the pass.
 */
44 using LayoutCb = ffi::TypedFunction<ffi::Map<ffi::String, ffi::Array<ffi::String>>(Call)>;
45 
/*!
 * \brief Create a function pass.
 *
 * \param pass_func Callback that receives (Function, IRModule, PassContext) and returns a
 *                  Function.
 * \param opt_level Integer level associated with the pass (NOTE(review): presumably the
 *                  optimization level -- confirm against tvm/ir/transform.h).
 * \param name Name given to the pass.
 * \param required Array of strings (NOTE(review): presumably names of prerequisite passes --
 *                 confirm).
 * \param traceable Boolean flag; defaults to false.
 * \return The created Pass.
 */
57 TVM_DLL Pass CreateFunctionPass(std::function<Function(Function, IRModule, PassContext)> pass_func,
58  int opt_level, ffi::String name,
59  tvm::ffi::Array<ffi::String> required, bool traceable = false);
60 
/*!
 * \brief Create a dataflowblock pass.
 *
 * \param pass_func Callback that receives (DataflowBlock, IRModule, PassContext) and returns a
 *                  DataflowBlock.
 * \param opt_level Integer level associated with the pass (NOTE(review): presumably the
 *                  optimization level -- confirm against tvm/ir/transform.h).
 * \param name Name given to the pass.
 * \param required Array of strings (NOTE(review): presumably names of prerequisite passes --
 *                 confirm).
 * \param traceable Boolean flag; defaults to false.
 * \return The created Pass.
 */
72 TVM_DLL Pass CreateDataflowBlockPass(
73  std::function<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level,
74  ffi::String name, tvm::ffi::Array<ffi::String> required, bool traceable = false);
75 
/*!
 * \brief Perform lambda lifting to lift functions from nested into global.
 * \return The Pass.
 */
81 TVM_DLL Pass LambdaLift();
82 
/*!
 * \brief Transform all dataflow structure to non-dataflow version.
 * \return The Pass.
 */
88 TVM_DLL Pass ToNonDataflow();
89 
102 
109 
123 
146 
153 
/*!
 * \brief Transform Relax IR to normal form: transform the AST to A-normal form and fill the
 *        struct_info_ of expressions.
 * \return The Pass.
 */
160 TVM_DLL Pass Normalize();
161 
170 
183 
/*!
 * \brief Create the common-subexpression elimination pass.
 * \note NOTE(review): this listing carries no description for the pass; the summary above is
 *       taken from its name only.
 * \param call_only Boolean flag; defaults to false (NOTE(review): presumably restricts the
 *                  elimination to call nodes -- confirm against the implementation).
 * \return The Pass.
 */
191 TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
192 
/*!
 * \brief Bind params of function of the module to constant tensors.
 *
 * \param func_name Name of the function whose params are bound.
 * \param params Map from a parameter key (typed as Any) to the object bound to it.
 * \return The Pass.
 */
201 TVM_DLL Pass BindParams(ffi::String func_name, ffi::Map<Any, ffi::ObjectRef> params);
202 
/*!
 * \brief Bind symbolic vars to constant shape values.
 *
 * \param binding_map Map whose keys are either a tirx::Var or a string, and whose values are
 *                    the PrimExpr bound to that key.
 * \param func_name Optional function name; defaults to std::nullopt (NOTE(review): presumably
 *                  nullopt means every function is considered -- confirm).
 * \return The Pass.
 */
218 TVM_DLL Pass BindSymbolicVars(ffi::Map<ffi::Variant<tirx::Var, ffi::String>, PrimExpr> binding_map,
219  ffi::Optional<ffi::String> func_name = std::nullopt);
220 
/*!
 * \brief Fold constant expressions within dataflow blocks.
 * \return The Pass.
 */
228 TVM_DLL Pass FoldConstant();
229 
/*!
 * \brief Legalize high-level operator calls in Relax functions to call_tir with corresponding
 *        low-level TIR PrimFuncs.
 *
 * \param cmap Optional map from a string key to a function (NOTE(review): presumably operator
 *             name to a custom legalization function -- confirm).
 * \param skip_ops Optional array of strings (NOTE(review): presumably operator names to leave
 *                 untouched -- confirm).
 * \param enable_warning Boolean flag; defaults to false.
 * \return The Pass.
 */
254 TVM_DLL Pass LegalizeOps(ffi::Optional<ffi::Map<ffi::String, ffi::Function>> cmap,
255  ffi::Optional<ffi::Array<ffi::String>> skip_ops,
256  bool enable_warning = false);
257 
263 
275 
284 
/*!
 * \brief Lift transformation of the parameters of a function.
 *
 * \param shared_transform Either a Bool or an array of strings; defaults to Bool(false).
 *                         NOTE(review): the meaning of each alternative is not visible in this
 *                         listing -- confirm against the implementation.
 * \return The Pass.
 */
311 TVM_DLL Pass
312 LiftTransformParams(ffi::Variant<Bool, ffi::Array<ffi::String>> shared_transform = Bool(false));
313 
/*!
 * \brief Update virtual device.
 *
 * \param new_vdevice The VDevice to install.
 * \param index 64-bit integer index (NOTE(review): presumably selects which existing vdevice is
 *              replaced -- confirm).
 * \return The Pass.
 */
320 TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
321 
327 
333 
339 
348 
/*!
 * \brief Group bindings in a dataflow block of Relax functions and generate a new grouped Relax
 *        function for them.
 *
 * \param fuse_opt_level Integer fusion level; defaults to -1 (NOTE(review): presumably -1 means
 *                       the level is taken from the pass context -- confirm).
 * \return The Pass.
 */
360 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
361 
/*!
 * \brief The pattern object used as the input of FuseOpsByPattern.
 *
 * Holds the pattern itself plus optional helpers (annotation sub-patterns, an acceptance check
 * and an attribute getter). All five fields are exposed through reflection.
 */
367 class FusionPatternNode : public ffi::Object {
368  public:
  /*! \brief The name of pattern. It becomes the value of the kComposite attribute of a fused
   *         function. */
373  ffi::String name;
374 
  /*! \brief The dataflow pattern that will be used to match expression in the DataflowBlock. */
379  DFPattern pattern;
380 
  /*! \brief The map which is used to extract important expressions from the pattern match
   *         result. */
385  ffi::Map<ffi::String, DFPattern> annotation_patterns;
386 
  /*! \brief The function to determine whether the match result is accepted. May be
   *         std::nullopt. */
394  ffi::Optional<ffi::Function> check;
395 
  /*! \brief The function to get attributes for fused function. */
402  ffi::Optional<ffi::Function> attrs_getter;
403 
  /*! \brief Register every field above with the reflection registry via def_ro. */
404  static void RegisterReflection() {
405  namespace refl = tvm::ffi::reflection;
406  refl::ObjectDef<FusionPatternNode>()
407  .def_ro("name", &FusionPatternNode::name)
408  .def_ro("pattern", &FusionPatternNode::pattern)
409  .def_ro("annotation_patterns", &FusionPatternNode::annotation_patterns)
410  .def_ro("check", &FusionPatternNode::check)
411  .def_ro("attrs_getter", &FusionPatternNode::attrs_getter);
412  }
413  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.FusionPattern", FusionPatternNode,
414  ffi::Object);
415 };
416 
/*!
 * \brief Reference class for FusionPatternNode.
 */
417 class FusionPattern : public ffi::ObjectRef {
418  public:
  /*!
   * \brief Construct a pattern from all five FusionPatternNode fields.
   * \param name The name of pattern.
   * \param pattern The dataflow pattern used for matching.
   * \param annotation_patterns Map used to extract expressions from the match result.
   * \param check Optional function deciding whether a match result is accepted.
   * \param attrs_getter Optional function to get attributes for the fused function.
   */
419  FusionPattern(ffi::String name, DFPattern pattern,
420  ffi::Map<ffi::String, DFPattern> annotation_patterns,
421  ffi::Optional<ffi::Function> check, ffi::Optional<ffi::Function> attrs_getter);
422 
  /*!
   * \brief Convenience constructor: delegates to the full constructor with an empty annotation
   *        map and no check / attrs_getter functions.
   */
423  FusionPattern(ffi::String name, DFPattern pattern)
424  : FusionPattern(name, pattern, {}, std::nullopt, std::nullopt) {}
425 
426  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FusionPattern, ffi::ObjectRef, FusionPatternNode);
427 };
428 
/*!
 * \brief The input of FusionPattern::check.
 *
 * Bundles the matched expression with the maps a check function can consult. All five fields
 * are exposed through reflection.
 */
432 class PatternCheckContextNode : public ffi::Object {
433  public:
  /*! \brief The expression that's matched with the FusionPattern::pattern. */
437  Expr matched_expr;
438 
  /*! \brief A map which contains all expressions matched by the sub patterns in
   *         FusionPattern::annotation_patterns. */
443  ffi::Map<ffi::String, Expr> annotated_expr;
444 
  /*! \brief Map from variable to its value. It contains variables from bindings that are being
   *         fused by FuseOpsByPattern. */
449  ffi::Map<Var, Expr> matched_bindings;
450 
  /*! \brief A map mapping variable definitions to a set of uses. It has all variables used in
   *         the function. */
455  ffi::Map<Var, ffi::Array<Var>> var_usages;
456 
  /*! \brief Map from value to its bound variable. It doesn't have variables after the matched
   *         expression. */
461  ffi::Map<Expr, Var> value_to_bound_var;
462 
  /*! \brief Register every field above with the reflection registry via def_ro. */
463  static void RegisterReflection() {
464  namespace refl = tvm::ffi::reflection;
465  refl::ObjectDef<PatternCheckContextNode>()
466  .def_ro("matched_expr", &PatternCheckContextNode::matched_expr)
467  .def_ro("annotated_expr", &PatternCheckContextNode::annotated_expr)
468  .def_ro("matched_bindings", &PatternCheckContextNode::matched_bindings)
469  .def_ro("var_usages", &PatternCheckContextNode::var_usages)
470  .def_ro("value_to_bound_var", &PatternCheckContextNode::value_to_bound_var);
471  }
472  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.PatternCheckContext", PatternCheckContextNode,
473  ffi::Object);
474 };
475 
/*!
 * \brief Reference class for PatternCheckContextNode.
 */
476 class PatternCheckContext : public ffi::ObjectRef {
477  public:
  /*!
   * \brief Construct a context from all five PatternCheckContextNode fields.
   * \param matched_expr The expression matched with the FusionPattern::pattern.
   * \param annotated_expr Expressions matched by the annotation sub patterns, keyed by name.
   * \param matched_bindings Map from variable to its value.
   * \param var_usages Map from variable definitions to their uses.
   * \param value_to_bound_var Map from value to its bound variable.
   */
478  PatternCheckContext(Expr matched_expr, ffi::Map<ffi::String, Expr> annotated_expr,
479  ffi::Map<Var, Expr> matched_bindings,
480  ffi::Map<Var, ffi::Array<Var>> var_usages,
481  ffi::Map<Expr, Var> value_to_bound_var);
482 
483  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PatternCheckContext, ffi::ObjectRef,
484  PatternCheckContextNode);
485 };
486 
/*!
 * \brief Reverse-mode automatic differentiation.
 *
 * \param func_name Name of the function to differentiate.
 * \param require_grads Optional array of Vars; defaults to std::nullopt (NOTE(review):
 *                      presumably the inputs whose gradients are requested -- confirm).
 * \param target_index Integer index; defaults to 0 (NOTE(review): presumably selects which
 *                     output is differentiated -- confirm).
 * \return The Pass.
 */
512 TVM_DLL Pass Gradient(ffi::String func_name,
513  ffi::Optional<ffi::Array<Var>> require_grads = std::nullopt,
514  int target_index = 0);
515 
/*!
 * \brief Apply pattern matching to each function in the given module, and group the matched
 *        expressions.
 *
 * \param patterns The FusionPattern objects to match.
 * \param bind_constants Boolean flag; defaults to true.
 * \param annotate_codegen Boolean flag; defaults to false.
 * \param entry_function_names Array of function names; defaults to empty.
 * \note NOTE(review): the effect of the two flags and of entry_function_names is not visible in
 *       this listing -- confirm against the implementation.
 * \return The Pass.
 */
536 TVM_DLL Pass FuseOpsByPattern(const tvm::ffi::Array<FusionPattern>& patterns,
537  bool bind_constants = true, bool annotate_codegen = false,
538  const tvm::ffi::Array<ffi::String>& entry_function_names = {});
539 
548 
/*!
 * \brief Fuse relax sub-function into a larger TIR function if possible.
 * \return The Pass.
 */
555 TVM_DLL Pass FuseTIR();
556 
/*!
 * \brief Run codegen.
 *
 * \param target_options Optional two-level map: string key -> (string key -> value).
 *                       NOTE(review): presumably target name -> option name -> option value --
 *                       confirm.
 * \param entry_functions Array of function names (NOTE(review): presumably the entry points of
 *                        the module -- confirm).
 * \return The Pass.
 */
563 TVM_DLL Pass
564 RunCodegen(ffi::Optional<ffi::Map<ffi::String, ffi::Map<ffi::String, ffi::Any>>> target_options,
565  ffi::Array<ffi::String> entry_functions);
566 
/*!
 * \brief Decompose composite operators during inference.
 *
 * \param func_name Optional function name (NOTE(review): presumably an empty value means all
 *                  functions are processed -- confirm).
 * \return The Pass.
 */
575 TVM_DLL Pass DecomposeOpsForInference(ffi::Optional<ffi::String> func_name);
576 
/*!
 * \brief Decompose composite operators during training.
 *
 * \param func_name Optional function name (NOTE(review): presumably an empty value means all
 *                  functions are processed -- confirm).
 * \return The Pass.
 */
585 TVM_DLL Pass DecomposeOpsForTraining(ffi::Optional<ffi::String> func_name);
586 
/*!
 * \brief Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in
 *        op_impl_map.
 *
 * \param op_impl_map Map from a string key to the PrimFunc associated with it.
 * \param op_buffer_transforms Map from a string key to an array of index maps.
 * \param axis_separators Map from a string key to an optional array of IntImm arrays.
 * \param input_axis_separators Map from a string key to an optional array of IntImm arrays.
 * \note NOTE(review): the string keys are presumably operator names in all four maps -- confirm
 *       against the implementation.
 * \return The Pass.
 */
601 TVM_DLL Pass AlterOpImpl(
602  const ffi::Map<ffi::String, tirx::PrimFunc>& op_impl_map,
603  const ffi::Map<ffi::String, ffi::Array<tirx::IndexMap>>& op_buffer_transforms,
604  const ffi::Map<ffi::String, ffi::Optional<ffi::Array<ffi::Array<IntImm>>>>& axis_separators,
605  const ffi::Map<ffi::String, ffi::Optional<ffi::Array<ffi::Array<IntImm>>>>&
606  input_axis_separators);
607 
/*!
 * \brief Layout conversion pass.
 *
 * \param desired_layouts Map from a string key to an array of layout strings (NOTE(review):
 *                        presumably keyed by operator name -- confirm).
 * \param layout_cb Callback of type LayoutCb, i.e. a function from a Call to a map of the same
 *                  shape as desired_layouts.
 * \return The Pass.
 */
615 TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
616  LayoutCb layout_cb);
617 
/*!
 * \brief A pass that converts consecutive dataflow operations inside binding blocks into
 *        dataflow blocks.
 *
 * \param min_size Integer threshold; defaults to 2 (NOTE(review): presumably the minimum number
 *                 of consecutive operations needed to form a block -- confirm).
 * \return The Pass.
 */
625 TVM_DLL Pass ConvertToDataflow(int min_size = 2);
626 
/*!
 * \brief Dead code elimination.
 *
 * \param entry_functions Array of function names; defaults to empty (NOTE(review): presumably
 *                        the functions treated as live roots -- confirm).
 * \return The Pass.
 */
643 TVM_DLL Pass DeadCodeElimination(ffi::Array<ffi::String> entry_functions = {});
644 
655 
/*!
 * \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
 *        only.
 *
 * \param out_dtype Data type passed to the pass (NOTE(review): presumably the output dtype of
 *                  accumulating operators -- confirm).
 * \param fp16_input_names Optional array of names; defaults to std::nullopt (NOTE(review):
 *                         presumably function inputs to be treated as fp16 -- confirm).
 * \return The Pass.
 */
666 TVM_DLL Pass
667 ToMixedPrecision(const DataType& out_dtype,
668  ffi::Optional<ffi::Array<ffi::String>> fp16_input_names = std::nullopt);
669 
676 
683 
684 } // namespace transform
685 } // namespace relax
686 } // namespace tvm
687 
688 #endif // TVM_RELAX_TRANSFORM_H_
Boolean constant.
Definition: expr.h:566
Managed reference class to IRModuleNode.
Definition: module.h:258
Reference to PrimExprNode.
Definition: expr.h:126
Managed reference to RelaxExprNode.
Definition: expr.h:441
Managed reference to VDeviceNode.
Definition: global_info.h:87
Definition: expr.h:180
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:101
Definition: expr.h:698
Definition: expr.h:835
Definition: expr.h:380
The pattern object used as the input of FuseOpsByPattern. For bindings to be fused,...
Definition: transform.h:367
ffi::Optional< ffi::Function > attrs_getter
The function to get attributes for fused function.
Definition: transform.h:402
ffi::Optional< ffi::Function > check
The function to determine whether the match result is accepted. This can be std::nullopt if check fun...
Definition: transform.h:394
ffi::String name
The name of pattern. It becomes the value of the kComposite attribute of a fused function after succe...
Definition: transform.h:373
ffi::Map< ffi::String, DFPattern > annotation_patterns
The map which is used to extract important expressions from the pattern match result....
Definition: transform.h:385
DFPattern pattern
The dataflow pattern that will be used to match expression in the DataflowBlock. All the call nodes c...
Definition: transform.h:379
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.FusionPattern", FusionPatternNode, ffi::Object)
static void RegisterReflection()
Definition: transform.h:404
Definition: transform.h:417
FusionPattern(ffi::String name, DFPattern pattern)
Definition: transform.h:423
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FusionPattern, ffi::ObjectRef, FusionPatternNode)
FusionPattern(ffi::String name, DFPattern pattern, ffi::Map< ffi::String, DFPattern > annotation_patterns, ffi::Optional< ffi::Function > check, ffi::Optional< ffi::Function > attrs_getter)
The input of FusionPattern::check.
Definition: transform.h:432
ffi::Map< Expr, Var > value_to_bound_var
Map from value to its bound variable. It doesn't have variables after the matched expression.
Definition: transform.h:461
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.PatternCheckContext", PatternCheckContextNode, ffi::Object)
ffi::Map< Var, Expr > matched_bindings
Map from variable to its value. It contains variables from bindings that are being fused by FuseOpsByP...
Definition: transform.h:449
static void RegisterReflection()
Definition: transform.h:463
ffi::Map< ffi::String, Expr > annotated_expr
A map which contains all expressions matched by the sub patterns in FusionPattern::annotation_pattern...
Definition: transform.h:443
Expr matched_expr
The expression that's matched with the FusionPattern::pattern.
Definition: transform.h:437
ffi::Map< Var, ffi::Array< Var > > var_usages
A map mapping variable definitions to a set of uses. It has all variables used in the function.
Definition: transform.h:455
PatternCheckContext(Expr matched_expr, ffi::Map< ffi::String, Expr > annotated_expr, ffi::Map< Var, Expr > matched_bindings, ffi::Map< Var, ffi::Array< Var >> var_usages, ffi::Map< Expr, Var > value_to_bound_var)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PatternCheckContext, ffi::ObjectRef, PatternCheckContextNode)
Runtime primitive data type.
Definition: data_type.h:45
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Managed reference class for PassInfoNode.
Definition: transform.h:350
Definition: transform.h:400
A pattern language for matching dataflow properties.
Defines a remapping of buffer indices.
Pass RealizeVDevice()
Propagate virtual device information.
Pass Gradient(ffi::String func_name, ffi::Optional< ffi::Array< Var >> require_grads=std::nullopt, int target_index=0)
Reverse-mode automatic differentiation.
Pass CreateFunctionPass(std::function< Function(Function, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Create a function pass.
Pass MergeCompositeFunctions()
Group one or multiple composite functions created by FuseOpsByPattern into a new function....
tvm::relax::DataflowBlock DataflowBlock
Definition: transform.h:42
Pass DataflowUseInplaceCalls()
Pass that changes calls to operators that can be done in-place (generally, these are elementwise oper...
Pass FuseOpsByPattern(const tvm::ffi::Array< FusionPattern > &patterns, bool bind_constants=true, bool annotate_codegen=false, const tvm::ffi::Array< ffi::String > &entry_function_names={})
Apply pattern matching to each function in the given module, and group matched expressions into a new...
Pass ToMixedPrecision(const DataType &out_dtype, ffi::Optional< ffi::Array< ffi::String >> fp16_input_names=std::nullopt)
Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only,...
Pass BindSymbolicVars(ffi::Map< ffi::Variant< tirx::Var, ffi::String >, PrimExpr > binding_map, ffi::Optional< ffi::String > func_name=std::nullopt)
Bind symbolic vars to constant shape values.
Pass CallTIRRewrite()
Perform explicit tensor allocation for call_tir and call_dps_packed.
Pass Normalize()
Transform Relax IR to normal form: transform AST to A-normal form, and fill the struct_info_ of expre...
tvm::relax::Function Function
Definition: transform.h:41
Pass LiftTransformParams(ffi::Variant< Bool, ffi::Array< ffi::String >> shared_transform=Bool(false))
Lift transformation of the parameters of a function.
Pass RemoveUnusedParameters()
Remove unused parameters to internal functions.
Pass SpecializePrimFuncBasedOnCallSite()
This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. Primarily used to u...
Pass FuseTIR()
Fuse relax sub-function into a larger TIR function if possible. This pass works together with FuseOps...
Pass LegalizeOps(ffi::Optional< ffi::Map< ffi::String, ffi::Function >> cmap, ffi::Optional< ffi::Array< ffi::String >> skip_ops, bool enable_warning=false)
Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR Pr...
Pass AlterOpImpl(const ffi::Map< ffi::String, tirx::PrimFunc > &op_impl_map, const ffi::Map< ffi::String, ffi::Array< tirx::IndexMap >> &op_buffer_transforms, const ffi::Map< ffi::String, ffi::Optional< ffi::Array< ffi::Array< IntImm >>>> &axis_separators, const ffi::Map< ffi::String, ffi::Optional< ffi::Array< ffi::Array< IntImm >>>> &input_axis_separators)
Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in op_impl_map,...
Pass SplitLayoutRewritePreproc()
Split the layout rewrite preproc block to a separate tirx::PrimFunc.
Pass AttachGlobalSymbol()
Attach global_symbol to Relax functions and TIR Primfuncs for codegen.
Pass FuseOps(int fuse_opt_level=-1)
This pass groups bindings in a dataflow block of Relax functions and generates a new grouped Relax fu...
Pass AnnotateTIROpPattern()
Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
Pass RemovePurityChecking()
Activate force_pure on all pure functions in the module and unwrap all pure override ops into the nor...
Pass AttachAttrLayoutFreeBuffers()
Attach layout free buffers to the tirx::PrimFunc.
tvm::transform::PassContext PassContext
Definition: transform.h:40
Pass NormalizeGlobalVar()
Possibly rename the GlobalVar in an IRModule to ensure these properties:
tvm::transform::PassInfo PassInfo
Definition: transform.h:39
Pass RunCodegen(ffi::Optional< ffi::Map< ffi::String, ffi::Map< ffi::String, ffi::Any >>> target_options, ffi::Array< ffi::String > entry_functions)
Run codegen.
Pass BindParams(ffi::String func_name, ffi::Map< Any, ffi::ObjectRef > params)
Bind params of function of the module to constant tensors.
Pass RemoveUnusedOutputs()
Remove unused outputs from internal functions.
Pass ExpandTupleArguments()
Expand tuple arguments to internal functions.
Pass FoldConstant()
Fold constant expressions within dataflow blocks.
tvm::transform::Pass Pass
Definition: transform.h:38
Pass CanonicalizeBindings()
Simplify a Relax module by folding var bindings and match shape nodes, as well as tuple indices....
Pass RewriteDataflowReshape()
Convert all reshape-like call_tir whose corresponding binding vars are DataflowVars to relax....
Pass DecomposeOpsForTraining(ffi::Optional< ffi::String > func_name)
Decompose composite operators during training. For example, the result of batch norm (a triple) will ...
ffi::TypedFunction< ffi::Map< ffi::String, ffi::Array< ffi::String > >(Call)> LayoutCb
Definition: transform.h:44
Pass LambdaLift()
Perform lambda lifting to lift functions from nested into global.
Pass CreateDataflowBlockPass(std::function< DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Create a dataflowblock pass.
Pass RewriteCUDAGraph()
Rewrite a Relax module for executing with CUDA graph. This pass identifies the regions that can be ex...
Pass DecomposeOpsForInference(ffi::Optional< ffi::String > func_name)
Decompose composite operators during inference. For example, the result of batch norm (a triple) will...
Pass ConvertToDataflow(int min_size=2)
A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks.
Pass EliminateCommonSubexpr(bool call_only=false)
Pass UpdateVDevice(VDevice new_vdevice, int64_t index)
Update virtual device.
Pass StaticPlanBlockMemory()
The static memory planning pass on BindingBlock level. The pass will reuse allocated memory to its be...
Pass ToNonDataflow()
Transform all dataflow structure to non-dataflow version.
Pass ConvertLayout(ffi::Map< ffi::String, ffi::Array< ffi::String >> desired_layouts, LayoutCb layout_cb)
Layout conversion pass.
Pass DeadCodeElimination(ffi::Array< ffi::String > entry_functions={})
Dead code elimination.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:233
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, ffi::String name, ffi::Array< ffi::String > required, bool traceable=false)
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
TIR Function.