tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_TRANSFORM_H_
25 #define TVM_RELAX_TRANSFORM_H_
26 
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/relax/dataflow_pattern.h>
#include <tvm/relax/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/index_map.h>
33 
34 namespace tvm {
35 namespace relax {
36 namespace transform {
37 
44 
56 TVM_DLL Pass CreateFunctionPass(std::function<Function(Function, IRModule, PassContext)> pass_func,
57  int opt_level, String name, tvm::Array<String> required,
58  bool traceable = false);
59 
72  std::function<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level,
73  String name, tvm::Array<String> required, bool traceable = false);
74 
80 TVM_DLL Pass LambdaLift();
81 
87 TVM_DLL Pass ToNonDataflow();
88 
101 
108 
122 
144 
151 
158 TVM_DLL Pass Normalize();
159 
168 
181 
189 TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
190 
199 TVM_DLL Pass BindParams(String func_name, Map<Any, ObjectRef> params);
200 
216 TVM_DLL Pass BindSymbolicVars(Map<Variant<tir::Var, String>, PrimExpr> binding_map,
217  Optional<String> func_name = std::nullopt);
218 
226 TVM_DLL Pass FoldConstant();
227 
251 TVM_DLL Pass LegalizeOps(Optional<Map<String, ffi::Function>> cmap, bool enable_warning = false);
252 
258 
270 
279 
306 TVM_DLL Pass LiftTransformParams(Variant<Bool, Array<String>> shared_transform = Bool(false));
307 
314 TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
315 
321 
327 
333 
342 
354 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
355 
361 class FusionPatternNode : public Object {
362  public:
367  String name;
368 
374 
379  Map<String, DFPattern> annotation_patterns;
380 
388  Optional<ffi::Function> check;
389 
396  Optional<ffi::Function> attrs_getter;
397 
398  static void RegisterReflection() {
399  namespace refl = tvm::ffi::reflection;
400  refl::ObjectDef<FusionPatternNode>()
401  .def_ro("name", &FusionPatternNode::name)
402  .def_ro("pattern", &FusionPatternNode::pattern)
403  .def_ro("annotation_patterns", &FusionPatternNode::annotation_patterns)
404  .def_ro("check", &FusionPatternNode::check)
405  .def_ro("attrs_getter", &FusionPatternNode::attrs_getter);
406  }
407 
408  static constexpr const char* _type_key = "relax.transform.FusionPattern";
410 };
411 
412 class FusionPattern : public ObjectRef {
413  public:
414  FusionPattern(String name, DFPattern pattern, Map<String, DFPattern> annotation_patterns,
415  Optional<ffi::Function> check, Optional<ffi::Function> attrs_getter);
416 
417  FusionPattern(String name, DFPattern pattern)
418  : FusionPattern(name, pattern, {}, std::nullopt, std::nullopt) {}
419 
421 };
422 
426 class PatternCheckContextNode : public Object {
427  public:
432 
437  Map<String, Expr> annotated_expr;
438 
443  Map<Var, Expr> matched_bindings;
444 
449  Map<Var, Array<Var>> var_usages;
450 
455  Map<Expr, Var> value_to_bound_var;
456 
457  static void RegisterReflection() {
458  namespace refl = tvm::ffi::reflection;
459  refl::ObjectDef<PatternCheckContextNode>()
460  .def_ro("matched_expr", &PatternCheckContextNode::matched_expr)
461  .def_ro("annotated_expr", &PatternCheckContextNode::annotated_expr)
462  .def_ro("matched_bindings", &PatternCheckContextNode::matched_bindings)
463  .def_ro("var_usages", &PatternCheckContextNode::var_usages)
464  .def_ro("value_to_bound_var", &PatternCheckContextNode::value_to_bound_var);
465  }
466 
467  static constexpr const char* _type_key = "relax.transform.PatternCheckContext";
469 };
470 
471 class PatternCheckContext : public ObjectRef {
472  public:
473  PatternCheckContext(Expr matched_expr, Map<String, Expr> annotated_expr,
474  Map<Var, Expr> matched_bindings, Map<Var, Array<Var>> var_usages,
475  Map<Expr, Var> value_to_bound_var);
476 
479 };
480 
506 TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = std::nullopt,
507  int target_index = 0);
508 
529 TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
530  bool annotate_codegen = false,
531  const tvm::Array<String>& entry_function_names = {});
532 
541 
548 TVM_DLL Pass FuseTIR();
549 
556 TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ffi::Any>>> target_options,
557  Array<String> entry_functions);
558 
567 TVM_DLL Pass DecomposeOpsForInference(Optional<String> func_name);
568 
577 TVM_DLL Pass DecomposeOpsForTraining(Optional<String> func_name);
578 
593 TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
594  const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
595  const Map<String, Optional<Array<Array<IntImm>>>>& axis_separators,
596  const Map<String, Optional<Array<Array<IntImm>>>>& input_axis_separators);
597 
604 TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
605 
613 TVM_DLL Pass ConvertToDataflow(int min_size = 2);
614 
631 TVM_DLL Pass DeadCodeElimination(Array<String> entry_functions = {});
632 
643 
654 TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
655  Optional<Array<String>> fp16_input_names = std::nullopt);
656 
663 
674 TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark);
675 
676 } // namespace transform
677 } // namespace relax
678 } // namespace tvm
679 
680 #endif // TVM_RELAX_TRANSFORM_H_
Boolean constant.
Definition: expr.h:577
Managed reference class to IRModuleNode.
Definition: module.h:257
Reference to PrimExprNode.
Definition: expr.h:129
Managed reference to RelaxExprNode.
Definition: expr.h:446
Managed reference to VDeviceNode.
Definition: global_info.h:90
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:102
Definition: expr.h:720
Definition: expr.h:862
Definition: expr.h:387
The pattern object used as the input of FuseOpsByPattern. For bindings to be fused,...
Definition: transform.h:361
Map< String, DFPattern > annotation_patterns
The map which is used to extract important expressions from the pattern match result....
Definition: transform.h:379
TVM_DECLARE_FINAL_OBJECT_INFO(FusionPatternNode, Object)
String name
The name of the pattern. It becomes the value of the kComposite attribute of a fused function after succe...
Definition: transform.h:367
static constexpr const char * _type_key
Definition: transform.h:408
DFPattern pattern
The dataflow pattern that will be used to match expressions in the DataflowBlock. All the call nodes c...
Definition: transform.h:373
Optional< ffi::Function > check
The function to determine whether the match result is accepted. This can be std::nullopt if check fun...
Definition: transform.h:388
static void RegisterReflection()
Definition: transform.h:398
Optional< ffi::Function > attrs_getter
The function to get attributes for fused function.
Definition: transform.h:396
Definition: transform.h:412
FusionPattern(String name, DFPattern pattern, Map< String, DFPattern > annotation_patterns, Optional< ffi::Function > check, Optional< ffi::Function > attrs_getter)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode)
FusionPattern(String name, DFPattern pattern)
Definition: transform.h:417
The input of FusionPattern::check.
Definition: transform.h:426
Map< Expr, Var > value_to_bound_var
Map from value to its bound variable. It doesn't have variables after the matched expression.
Definition: transform.h:455
Map< Var, Expr > matched_bindings
Map from variable to its value. It contains variables from bindings that are being fused by FuseOpsByPattern...
Definition: transform.h:443
Map< String, Expr > annotated_expr
A map which contains all expressions matched by the sub patterns in FusionPattern::annotation_pattern...
Definition: transform.h:437
TVM_DECLARE_FINAL_OBJECT_INFO(PatternCheckContextNode, Object)
static constexpr const char * _type_key
Definition: transform.h:467
static void RegisterReflection()
Definition: transform.h:457
Map< Var, Array< Var > > var_usages
A map mapping variable definitions to a set of uses. It has all variables used in the function.
Definition: transform.h:449
Expr matched_expr
The expression that's matched with the FusionPattern::pattern.
Definition: transform.h:431
PatternCheckContext(Expr matched_expr, Map< String, Expr > annotated_expr, Map< Var, Expr > matched_bindings, Map< Var, Array< Var >> var_usages, Map< Expr, Var > value_to_bound_var)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef, PatternCheckContextNode)
Runtime primitive data type.
Definition: data_type.h:47
PassContext that is used to configure the pass behavior.
Definition: transform.h:156
Managed reference class for PassInfoNode.
Definition: transform.h:349
Definition: transform.h:400
A pattern language for matching dataflow properties.
Defines a remapping of buffer indices.
Definition: repr_printer.h:91
Pass RealizeVDevice()
Propagate virtual device information.
Pass ToMixedPrecision(const DataType &out_dtype, Optional< Array< String >> fp16_input_names=std::nullopt)
Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only,...
Pass MergeCompositeFunctions()
Group one or multiple composite functions created by FuseOpsByPattern into a new function....
tvm::relax::DataflowBlock DataflowBlock
Definition: transform.h:42
Pass DataflowUseInplaceCalls()
Pass that changes calls to operators that can be done in-place (generally, these are elementwise oper...
Pass CallTIRRewrite()
Perform explicit tensor allocation for call_tir and call_dps_packed.
Pass DeadCodeElimination(Array< String > entry_functions={})
Dead code elimination.
Pass Normalize()
Transform Relax IR to normal form: transform AST to A-normal form, and fill the struct_info_ of expre...
tvm::relax::Function Function
Definition: transform.h:41
Pass DecomposeOpsForTraining(Optional< String > func_name)
Decompose composite operators during training. For example, The result of batch norm (a triple) will ...
Pass RemoveUnusedParameters()
Remove unused parameters to internal functions.
Pass AlterOpImpl(const Map< String, tir::PrimFunc > &op_impl_map, const Map< String, Array< tir::IndexMap >> &op_buffer_transforms, const Map< String, Optional< Array< Array< IntImm >>>> &axis_separators, const Map< String, Optional< Array< Array< IntImm >>>> &input_axis_separators)
Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in op_impl_map,...
Pass BindParams(String func_name, Map< Any, ObjectRef > params)
Bind params of function of the module to constant tensors.
Pass FuseTIR()
Fuse a Relax sub-function into a larger TIR function if possible. This pass works together with FuseOps...
Pass FewShotTuning(int valid_count, bool benchmark)
The pass is designed for few shot tuning for static shape PrimFuncs. It examines all the blocks withi...
Pass SplitLayoutRewritePreproc()
Split the layout rewrite preproc block to a separate tir::PrimFunc.
Pass AttachGlobalSymbol()
Attach global_symbol to Relax functions and TIR Primfuncs for codegen.
Pass FuseOps(int fuse_opt_level=-1)
This pass groups bindings in a dataflow block of Relax functions and generates a new grouped Relax fu...
Pass FuseOpsByPattern(const tvm::Array< FusionPattern > &patterns, bool bind_constants=true, bool annotate_codegen=false, const tvm::Array< String > &entry_function_names={})
Apply pattern matching to each function in the given module, and group matched expressions into a new...
Pass AnnotateTIROpPattern()
Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
Pass RemovePurityChecking()
Activate force_pure on all pure functions in the module and unwrap all pure override ops into the nor...
Pass CreateFunctionPass(std::function< Function(Function, IRModule, PassContext)> pass_func, int opt_level, String name, tvm::Array< String > required, bool traceable=false)
Create a function pass.
Pass AttachAttrLayoutFreeBuffers()
Attach layout free buffers to the tir::PrimFunc.
tvm::transform::PassContext PassContext
Definition: transform.h:40
Pass NormalizeGlobalVar()
Possibly rename the GlobalVar in an IRModule to ensure these properties:
tvm::transform::PassInfo PassInfo
Definition: transform.h:39
Pass ConvertLayout(Map< String, Array< String >> desired_layouts)
Layout conversion pass.
Pass RemoveUnusedOutputs()
Remove unused outputs from internal functions.
Pass ExpandTupleArguments()
Expand tuple arguments to internal functions.
Pass BindSymbolicVars(Map< Variant< tir::Var, String >, PrimExpr > binding_map, Optional< String > func_name=std::nullopt)
Bind symbolic vars to constant shape values.
Pass DecomposeOpsForInference(Optional< String > func_name)
Decompose composite operators during inference. For example, The result of batch norm (a triple) will...
Pass FoldConstant()
Fold constant expressions within dataflow blocks.
Pass Gradient(String func_name, Optional< Array< Var >> require_grads=std::nullopt, int target_index=0)
Reverse-mode automatic differentiation.
tvm::transform::Pass Pass
Definition: transform.h:38
Pass CanonicalizeBindings()
Simplify a Relax module by folding var bindings and match shape nodes, as well as tuple indices....
Pass RewriteDataflowReshape()
Convert all reshape-like call_tir whose corresponding binding vars are DataflowVars to relax....
Pass CreateDataflowBlockPass(std::function< DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level, String name, tvm::Array< String > required, bool traceable=false)
Create a dataflowblock pass.
Pass LambdaLift()
Perform lambda lifting to lift functions from nested into global.
Pass RewriteCUDAGraph()
Rewrite a Relax module for executing with CUDA graph. This pass identifies the regions that can be ex...
Pass ConvertToDataflow(int min_size=2)
A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks.
Pass EliminateCommonSubexpr(bool call_only=false)
Pass UpdateVDevice(VDevice new_vdevice, int64_t index)
Update virtual device.
Pass RunCodegen(Optional< Map< String, Map< String, ffi::Any >>> target_options, Array< String > entry_functions)
Run codegen.
Pass StaticPlanBlockMemory()
The static memory planning pass on BindingBlock level. The pass will reuse allocated memory to its be...
Pass ToNonDataflow()
Transform all dataflow structure to non-dataflow version.
Pass LiftTransformParams(Variant< Bool, Array< String >> shared_transform=Bool(false))
Lift transformation of the parameters of a function.
Pass LegalizeOps(Optional< Map< String, ffi::Function >> cmap, bool enable_warning=false)
Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR Pr...
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1123
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, String name, Array< String > required, bool traceable=false)
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
TIR Function.