tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_TRANSFORM_H_
25 #define TVM_RELAX_TRANSFORM_H_
26 
27 #include <tvm/ffi/reflection/registry.h>
28 #include <tvm/ir/transform.h>
30 #include <tvm/relax/expr.h>
31 #include <tvm/tir/function.h>
32 #include <tvm/tir/index_map.h>
33 
34 namespace tvm {
35 namespace relax {
36 namespace transform {
37 
38 using Pass = tvm::transform::Pass;
39 using PassInfo = tvm::transform::PassInfo;
40 using PassContext = tvm::transform::PassContext;
41 using Function = tvm::relax::Function;
42 using DataflowBlock = tvm::relax::DataflowBlock;
44 
56 TVM_DLL Pass CreateFunctionPass(std::function<Function(Function, IRModule, PassContext)> pass_func,
57  int opt_level, ffi::String name,
58  tvm::ffi::Array<ffi::String> required, bool traceable = false);  // Create a function pass.
59 
71 TVM_DLL Pass CreateDataflowBlockPass(
72  std::function<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level,
73  ffi::String name, tvm::ffi::Array<ffi::String> required, bool traceable = false);  // Create a dataflow-block pass.
74 
80 TVM_DLL Pass LambdaLift();  // Lambda lifting: lift nested functions into global functions.
81 
87 TVM_DLL Pass ToNonDataflow();  // Transform all dataflow structure to its non-dataflow version.
88 
101 
108 
122 
144 
151 
158 TVM_DLL Pass Normalize();  // Transform Relax IR to A-normal form and fill in struct_info_ of expressions.
159 
168 
181 
189 TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);  // Common-subexpression elimination (per name); NOTE(review): call_only semantics not shown in this listing.
190 
199 TVM_DLL Pass BindParams(ffi::String func_name, ffi::Map<Any, ObjectRef> params);  // Bind params of a module function to constant tensors.
200 
216 TVM_DLL Pass BindSymbolicVars(ffi::Map<ffi::Variant<tir::Var, ffi::String>, PrimExpr> binding_map,
217  ffi::Optional<ffi::String> func_name = std::nullopt);  // Bind symbolic vars to constant shape values.
218 
226 TVM_DLL Pass FoldConstant();  // Fold constant expressions within dataflow blocks.
227 
251 TVM_DLL Pass LegalizeOps(ffi::Optional<ffi::Map<ffi::String, ffi::Function>> cmap,
252  bool enable_warning = false);  // Legalize high-level operator calls in Relax functions to call_tir of low-level TIR.
253 
259 
271 
280 
307 TVM_DLL Pass
308 LiftTransformParams(ffi::Variant<Bool, ffi::Array<ffi::String>> shared_transform = Bool(false));  // Lift transformation of a function's parameters.
309 
316 TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);  // Update virtual device.
317 
323 
329 
335 
344 
356 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);  // Group bindings in dataflow blocks of Relax functions into new grouped Relax functions.
357 
363 class FusionPatternNode : public Object {  // Pattern object used as the input of FuseOpsByPattern.
364  public:
369  ffi::String name;  // Pattern name; becomes the kComposite attribute of the fused function.
370 
375  DFPattern pattern;  // Dataflow pattern matched against expressions in the DataflowBlock.
376 
381  ffi::Map<ffi::String, DFPattern> annotation_patterns;  // Extracts important expressions from the match result.
382 
390  ffi::Optional<ffi::Function> check;  // Decides whether a match is accepted; may be std::nullopt.
391 
398  ffi::Optional<ffi::Function> attrs_getter;  // Computes attributes for the fused function.
399 
400  static void RegisterReflection() {
401  namespace refl = tvm::ffi::reflection;
402  refl::ObjectDef<FusionPatternNode>()
403  .def_ro("name", &FusionPatternNode::name)
404  .def_ro("pattern", &FusionPatternNode::pattern)
405  .def_ro("annotation_patterns", &FusionPatternNode::annotation_patterns)
406  .def_ro("check", &FusionPatternNode::check)
407  .def_ro("attrs_getter", &FusionPatternNode::attrs_getter);
408  }
409  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.FusionPattern", FusionPatternNode, Object);
410 };
411 
412 class FusionPattern : public ObjectRef {  // Non-nullable reference wrapper around FusionPatternNode.
413  public:
414  FusionPattern(ffi::String name, DFPattern pattern,
415  ffi::Map<ffi::String, DFPattern> annotation_patterns,
416  ffi::Optional<ffi::Function> check, ffi::Optional<ffi::Function> attrs_getter);
417 
418  FusionPattern(ffi::String name, DFPattern pattern)  // Overload: no annotation patterns, check, or attrs_getter.
419  : FusionPattern(name, pattern, {}, std::nullopt, std::nullopt) {}
420 
421  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FusionPattern, ObjectRef, FusionPatternNode);
422 };
423 
427 class PatternCheckContextNode : public Object {  // The input of FusionPattern::check.
428  public:
432  Expr matched_expr;  // The expression matched with FusionPattern::pattern.
433 
438  ffi::Map<ffi::String, Expr> annotated_expr;  // Expressions matched by FusionPattern::annotation_patterns.
439 
444  ffi::Map<Var, Expr> matched_bindings;  // Variable -> value for bindings being fused by FuseOpsByPattern.
445 
450  ffi::Map<Var, ffi::Array<Var>> var_usages;  // Variable definition -> its uses; covers all variables in the function.
451 
456  ffi::Map<Expr, Var> value_to_bound_var;  // Value -> bound variable; excludes variables after the matched expression.
457 
458  static void RegisterReflection() {
459  namespace refl = tvm::ffi::reflection;
460  refl::ObjectDef<PatternCheckContextNode>()
461  .def_ro("matched_expr", &PatternCheckContextNode::matched_expr)
462  .def_ro("annotated_expr", &PatternCheckContextNode::annotated_expr)
463  .def_ro("matched_bindings", &PatternCheckContextNode::matched_bindings)
464  .def_ro("var_usages", &PatternCheckContextNode::var_usages)
465  .def_ro("value_to_bound_var", &PatternCheckContextNode::value_to_bound_var);
466  }
467  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.PatternCheckContext", PatternCheckContextNode,
468  Object);
469 };
470 
471 class PatternCheckContext : public ObjectRef {  // Non-nullable reference wrapper around PatternCheckContextNode.
472  public:
473  PatternCheckContext(Expr matched_expr, ffi::Map<ffi::String, Expr> annotated_expr,
474  ffi::Map<Var, Expr> matched_bindings,
475  ffi::Map<Var, ffi::Array<Var>> var_usages,
476  ffi::Map<Expr, Var> value_to_bound_var);
477 
478  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PatternCheckContext, ObjectRef,
479  PatternCheckContextNode);
480 };
481 
507 TVM_DLL Pass Gradient(ffi::String func_name,
508  ffi::Optional<ffi::Array<Var>> require_grads = std::nullopt,
509  int target_index = 0);  // Reverse-mode automatic differentiation.
510 
531 TVM_DLL Pass FuseOpsByPattern(const tvm::ffi::Array<FusionPattern>& patterns,
532  bool bind_constants = true, bool annotate_codegen = false,
533  const tvm::ffi::Array<ffi::String>& entry_function_names = {});  // Group pattern-matched expressions of each function into new functions.
534 
543 
550 TVM_DLL Pass FuseTIR();  // Fuse Relax sub-functions into larger TIR functions where possible; works with FuseOps.
551 
558 TVM_DLL Pass
559 RunCodegen(ffi::Optional<ffi::Map<ffi::String, ffi::Map<ffi::String, ffi::Any>>> target_options,
560  ffi::Array<ffi::String> entry_functions);  // Run codegen.
561 
570 TVM_DLL Pass DecomposeOpsForInference(ffi::Optional<ffi::String> func_name);  // Decompose composite operators (e.g. batch norm) for inference.
571 
580 TVM_DLL Pass DecomposeOpsForTraining(ffi::Optional<ffi::String> func_name);  // Decompose composite operators (e.g. batch norm) for training.
581 
596 TVM_DLL Pass AlterOpImpl(
597  const ffi::Map<ffi::String, tir::PrimFunc>& op_impl_map,
598  const ffi::Map<ffi::String, ffi::Array<tir::IndexMap>>& op_buffer_transforms,
599  const ffi::Map<ffi::String, ffi::Optional<ffi::Array<ffi::Array<IntImm>>>>& axis_separators,
600  const ffi::Map<ffi::String, ffi::Optional<ffi::Array<ffi::Array<IntImm>>>>&
601  input_axis_separators);  // Replace PrimFuncs whose kOperatorName attribute matches an entry of op_impl_map.
602 
609 TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts);  // Layout conversion pass.
610 
618 TVM_DLL Pass ConvertToDataflow(int min_size = 2);  // Convert consecutive dataflow operations in binding blocks into dataflow blocks.
619 
636 TVM_DLL Pass DeadCodeElimination(ffi::Array<ffi::String> entry_functions = {});  // Dead code elimination.
637 
648 
659 TVM_DLL Pass
660 ToMixedPrecision(const DataType& out_dtype,
661  ffi::Optional<ffi::Array<ffi::String>> fp16_input_names = std::nullopt);  // Automatic mixed precision; currently assumes an fp32-only input module.
662 
669 
680 TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark);  // Few-shot tuning for static-shape PrimFuncs.
681 
682 } // namespace transform
683 } // namespace relax
684 } // namespace tvm
685 
686 #endif // TVM_RELAX_TRANSFORM_H_
Boolean constant.
Definition: expr.h:565
Managed reference class to IRModuleNode.
Definition: module.h:256
Reference to PrimExprNode.
Definition: expr.h:124
Managed reference to RelaxExprNode.
Definition: expr.h:439
Managed reference to VDeviceNode.
Definition: global_info.h:87
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:101
Definition: expr.h:693
Definition: expr.h:830
Definition: expr.h:377
The pattern object used as the input of FuseOpsByPattern. For bindings to be fused,...
Definition: transform.h:363
ffi::Optional< ffi::Function > attrs_getter
The function to get attributes for fused function.
Definition: transform.h:398
ffi::Optional< ffi::Function > check
The function to determine whether the match result is accepted. This can be std::nullopt if check fun...
Definition: transform.h:390
ffi::String name
The name of pattern. It becomes the value of the kComposite attribute of a fused function after succe...
Definition: transform.h:369
ffi::Map< ffi::String, DFPattern > annotation_patterns
The map which is used to extract important expressions from the pattern match result....
Definition: transform.h:381
DFPattern pattern
The dataflow pattern that will be used to match expressions in the DataflowBlock. All the call nodes c...
Definition: transform.h:375
static void RegisterReflection()
Definition: transform.h:400
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.FusionPattern", FusionPatternNode, Object)
Definition: transform.h:412
FusionPattern(ffi::String name, DFPattern pattern)
Definition: transform.h:418
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FusionPattern, ObjectRef, FusionPatternNode)
FusionPattern(ffi::String name, DFPattern pattern, ffi::Map< ffi::String, DFPattern > annotation_patterns, ffi::Optional< ffi::Function > check, ffi::Optional< ffi::Function > attrs_getter)
The input of FusionPattern::check.
Definition: transform.h:427
ffi::Map< Expr, Var > value_to_bound_var
Map from value to its bound variable. It does not include variables bound after the matched expression.
Definition: transform.h:456
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.PatternCheckContext", PatternCheckContextNode, Object)
ffi::Map< Var, Expr > matched_bindings
Map from variable to its value. It contains variables from bindings that are being fused by FuseOpsByP...
Definition: transform.h:444
static void RegisterReflection()
Definition: transform.h:458
ffi::Map< ffi::String, Expr > annotated_expr
A map which contains all expressions matched by the sub patterns in FusionPattern::annotation_pattern...
Definition: transform.h:438
Expr matched_expr
The expression that's matched with the FusionPattern::pattern.
Definition: transform.h:432
ffi::Map< Var, ffi::Array< Var > > var_usages
A map from each variable definition to its set of uses. It contains every variable used in the function.
Definition: transform.h:450
PatternCheckContext(Expr matched_expr, ffi::Map< ffi::String, Expr > annotated_expr, ffi::Map< Var, Expr > matched_bindings, ffi::Map< Var, ffi::Array< Var >> var_usages, ffi::Map< Expr, Var > value_to_bound_var)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PatternCheckContext, ObjectRef, PatternCheckContextNode)
Runtime primitive data type.
Definition: data_type.h:47
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Managed reference class for PassInfoNode.
Definition: transform.h:350
Definition: transform.h:400
A pattern language for matching dataflow properties.
Defines a remapping of buffer indices.
Definition: repr_printer.h:91
Pass RealizeVDevice()
Propagate virtual device information.
Pass Gradient(ffi::String func_name, ffi::Optional< ffi::Array< Var >> require_grads=std::nullopt, int target_index=0)
Reverse-mode automatic differentiation.
Pass CreateFunctionPass(std::function< Function(Function, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Create a function pass.
Pass MergeCompositeFunctions()
Group one or multiple composite functions created by FuseOpsByPattern into a new function....
tvm::relax::DataflowBlock DataflowBlock
Definition: transform.h:42
Pass DataflowUseInplaceCalls()
Pass that changes calls to operators that can be done in-place (generally, these are elementwise oper...
Pass FuseOpsByPattern(const tvm::ffi::Array< FusionPattern > &patterns, bool bind_constants=true, bool annotate_codegen=false, const tvm::ffi::Array< ffi::String > &entry_function_names={})
Apply pattern matching to each function in the given module, and group matched expressions into a new...
Pass ToMixedPrecision(const DataType &out_dtype, ffi::Optional< ffi::Array< ffi::String >> fp16_input_names=std::nullopt)
Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only,...
Pass CallTIRRewrite()
Perform explicit tensor allocation for call_tir and call_dps_packed.
Pass BindParams(ffi::String func_name, ffi::Map< Any, ObjectRef > params)
Bind params of function of the module to constant tensors.
Pass Normalize()
Transform Relax IR to normal form: transform AST to A-normal form, and fill the struct_info_ of expre...
tvm::relax::Function Function
Definition: transform.h:41
Pass LiftTransformParams(ffi::Variant< Bool, ffi::Array< ffi::String >> shared_transform=Bool(false))
Lift transformation of the parameters of a function.
Pass RemoveUnusedParameters()
Remove unused parameters to internal functions.
Pass FuseTIR()
Fuse a Relax sub-function into a larger TIR function if possible. This pass works together with FuseOps...
Pass FewShotTuning(int valid_count, bool benchmark)
The pass is designed for few-shot tuning of static-shape PrimFuncs. It examines all the blocks withi...
Pass BindSymbolicVars(ffi::Map< ffi::Variant< tir::Var, ffi::String >, PrimExpr > binding_map, ffi::Optional< ffi::String > func_name=std::nullopt)
Bind symbolic vars to constant shape values.
Pass ConvertLayout(ffi::Map< ffi::String, ffi::Array< ffi::String >> desired_layouts)
Layout conversion pass.
Pass SplitLayoutRewritePreproc()
Split the layout rewrite preproc block to a separate tir::PrimFunc.
Pass AttachGlobalSymbol()
Attach global_symbol to Relax functions and TIR Primfuncs for codegen.
Pass FuseOps(int fuse_opt_level=-1)
This pass groups bindings in a dataflow block of Relax functions and generates a new grouped Relax fu...
Pass AnnotateTIROpPattern()
Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
Pass RemovePurityChecking()
Activate force_pure on all pure functions in the module and unwrap all pure override ops into the nor...
Pass AttachAttrLayoutFreeBuffers()
Attach layout free buffers to the tir::PrimFunc.
tvm::transform::PassContext PassContext
Definition: transform.h:40
Pass NormalizeGlobalVar()
Possibly rename the GlobalVar in an IRModule to ensure these properties:
tvm::transform::PassInfo PassInfo
Definition: transform.h:39
Pass RunCodegen(ffi::Optional< ffi::Map< ffi::String, ffi::Map< ffi::String, ffi::Any >>> target_options, ffi::Array< ffi::String > entry_functions)
Run codegen.
Pass LegalizeOps(ffi::Optional< ffi::Map< ffi::String, ffi::Function >> cmap, bool enable_warning=false)
Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR Pr...
Pass RemoveUnusedOutputs()
Remove unused outputs from internal functions.
Pass ExpandTupleArguments()
Expand tuple arguments to internal functions.
Pass FoldConstant()
Fold constant expressions within dataflow blocks.
tvm::transform::Pass Pass
Definition: transform.h:38
Pass CanonicalizeBindings()
Simplify a Relax module by folding var bindings and match shape nodes, as well as tuple indices....
Pass RewriteDataflowReshape()
Convert all reshape-like call_tir whose corresponding binding vars are DataflowVars to relax....
Pass DecomposeOpsForTraining(ffi::Optional< ffi::String > func_name)
Decompose composite operators during training. For example, the result of batch norm (a triple) will ...
Pass LambdaLift()
Perform lambda lifting to lift functions from nested into global.
Pass CreateDataflowBlockPass(std::function< DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Create a dataflowblock pass.
Pass RewriteCUDAGraph()
Rewrite a Relax module for execution with CUDA graph. This pass identifies the regions that can be ex...
Pass DecomposeOpsForInference(ffi::Optional< ffi::String > func_name)
Decompose composite operators during inference. For example, the result of batch norm (a triple) will...
Pass ConvertToDataflow(int min_size=2)
A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks.
Pass EliminateCommonSubexpr(bool call_only=false)
Pass UpdateVDevice(VDevice new_vdevice, int64_t index)
Update virtual device.
Pass AlterOpImpl(const ffi::Map< ffi::String, tir::PrimFunc > &op_impl_map, const ffi::Map< ffi::String, ffi::Array< tir::IndexMap >> &op_buffer_transforms, const ffi::Map< ffi::String, ffi::Optional< ffi::Array< ffi::Array< IntImm >>>> &axis_separators, const ffi::Map< ffi::String, ffi::Optional< ffi::Array< ffi::Array< IntImm >>>> &input_axis_separators)
Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in op_impl_map,...
Pass StaticPlanBlockMemory()
The static memory planning pass on BindingBlock level. The pass will reuse allocated memory to its be...
Pass ToNonDataflow()
Transform all dataflow structure to non-dataflow version.
Pass DeadCodeElimination(ffi::Array< ffi::String > entry_functions={})
Dead code elimination.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1094
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, ffi::String name, ffi::Array< ffi::String > required, bool traceable=false)
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
TIR Function.