tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_TRANSFORM_H_
25 #define TVM_RELAX_TRANSFORM_H_
26 
27 #include <tvm/ffi/reflection/registry.h>
28 #include <tvm/ir/transform.h>
30 #include <tvm/relax/expr.h>
31 #include <tvm/tirx/function.h>
32 #include <tvm/tirx/index_map.h>
33 
34 namespace tvm {
35 namespace relax {
36 namespace transform {
37 
/*!
 * \brief Typed callback that takes a Call and returns a map from string to an
 *        array of strings. Used as the type of the `layout_cb` parameter of ConvertLayout.
 */
44 using LayoutCb = ffi::TypedFunction<ffi::Map<ffi::String, ffi::Array<ffi::String>>(Call)>;
45 
/*!
 * \brief Create a function pass.
 * \param pass_func Callable taking (Function, IRModule, PassContext) and returning a Function.
 * \param opt_level Integer optimization level associated with the pass.
 * \param name Name given to the pass.
 * \param required Array of strings passed as `required`
 *        (NOTE(review): presumably names of prerequisite passes - confirm).
 * \param traceable Boolean flag, defaults to false.
 * \return The created Pass.
 */
57 TVM_DLL Pass CreateFunctionPass(std::function<Function(Function, IRModule, PassContext)> pass_func,
58  int opt_level, ffi::String name,
59  tvm::ffi::Array<ffi::String> required, bool traceable = false);
60 
73  std::function<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level,
74  ffi::String name, tvm::ffi::Array<ffi::String> required, bool traceable = false);
75 
/*! \brief Perform lambda lifting to lift functions from nested into global. */
81 TVM_DLL Pass LambdaLift();
82 
/*! \brief Transform all dataflow structure to non-dataflow version. */
88 TVM_DLL Pass ToNonDataflow();
89 
102 
109 
123 
146 
153 
/*!
 * \brief Transform Relax IR to normal form: transform the AST to A-normal form and fill
 *        in struct_info_ (the index description is truncated beyond this point).
 */
160 TVM_DLL Pass Normalize();
161 
170 
183 
/*!
 * \brief Returns the EliminateCommonSubexpr pass.
 * \param call_only Boolean option, defaults to false.
 *        NOTE(review): no description is visible in this listing; presumably restricts
 *        elimination to call expressions - confirm against the original header.
 */
191 TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
192 
/*!
 * \brief Bind params of function of the module to constant tensors.
 * \param func_name Name of the function whose params are bound.
 * \param params Map supplying the bindings (keys are Any, values are ObjectRef).
 */
201 TVM_DLL Pass BindParams(ffi::String func_name, ffi::Map<Any, ObjectRef> params);
202 
/*!
 * \brief Bind symbolic vars to constant shape values.
 * \param binding_map Map from a symbolic variable (either a tirx::Var or a string) to a PrimExpr.
 * \param func_name Optional function name, defaults to std::nullopt.
 *        NOTE(review): presumably nullopt means "apply to all functions" - confirm.
 */
218 TVM_DLL Pass BindSymbolicVars(ffi::Map<ffi::Variant<tirx::Var, ffi::String>, PrimExpr> binding_map,
219  ffi::Optional<ffi::String> func_name = std::nullopt);
220 
/*! \brief Fold constant expressions within dataflow blocks. */
228 TVM_DLL Pass FoldConstant();
229 
/*!
 * \brief Legalize high-level operator calls in Relax functions to call_tir with
 *        corresponding low-level TIR functions (index description truncated).
 * \param cmap Optional map from string to ffi::Function.
 *        NOTE(review): presumably operator name -> custom legalization function - confirm.
 * \param skip_ops Optional array of strings.
 *        NOTE(review): presumably operator names to leave un-legalized - confirm.
 * \param enable_warning Boolean flag, defaults to false.
 */
254 TVM_DLL Pass LegalizeOps(ffi::Optional<ffi::Map<ffi::String, ffi::Function>> cmap,
255  ffi::Optional<ffi::Array<ffi::String>> skip_ops,
256  bool enable_warning = false);
257 
263 
275 
284 
/*!
 * \brief Lift transformation of the parameters of a function.
 * \param shared_transform Either a Bool or an array of strings; defaults to Bool(false).
 */
311 TVM_DLL Pass
312 LiftTransformParams(ffi::Variant<Bool, ffi::Array<ffi::String>> shared_transform = Bool(false));
313 
/*!
 * \brief Update virtual device.
 * \param new_vdevice The VDevice to use.
 * \param index 64-bit integer index (NOTE(review): what it indexes is not visible here - confirm).
 */
320 TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
321 
327 
333 
339 
348 
/*!
 * \brief Groups bindings in a dataflow block of Relax functions and generates a new
 *        grouped Relax function (index description truncated).
 * \param fuse_opt_level Integer fusion optimization level, defaults to -1.
 *        NOTE(review): the meaning of -1 is not visible in this listing - confirm.
 */
360 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
361 
/*!
 * \brief The pattern object used as the input of FuseOpsByPattern.
 *
 * NOTE(review): RegisterReflection below registers a `pattern` field, but its declaration
 * (`DFPattern pattern`, listed at transform.h:379 in the cross-reference index) does not
 * appear in this listing; it was presumably dropped during extraction - confirm against
 * the original header.
 */
367 class FusionPatternNode : public Object {
368  public:
  /*! \brief The name of the pattern. It becomes the value of the kComposite attribute of a fused function. */
373  ffi::String name;
374 
380 
  /*! \brief The map which is used to extract important expressions from the pattern match result. */
385  ffi::Map<ffi::String, DFPattern> annotation_patterns;
386 
  /*! \brief The function to determine whether the match result is accepted. Optional; may be std::nullopt. */
394  ffi::Optional<ffi::Function> check;
395 
  /*! \brief The function to get attributes for the fused function. Optional. */
402  ffi::Optional<ffi::Function> attrs_getter;
403 
  /*! \brief Register each field with the FFI reflection ObjectDef via def_ro, under its own name. */
404  static void RegisterReflection() {
405  namespace refl = tvm::ffi::reflection;
406  refl::ObjectDef<FusionPatternNode>()
407  .def_ro("name", &FusionPatternNode::name)
408  .def_ro("pattern", &FusionPatternNode::pattern)
409  .def_ro("annotation_patterns", &FusionPatternNode::annotation_patterns)
410  .def_ro("check", &FusionPatternNode::check)
411  .def_ro("attrs_getter", &FusionPatternNode::attrs_getter);
412  }
  // Declares object type info under the type key "relax.transform.FusionPattern".
413  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.FusionPattern", FusionPatternNode, Object);
414 };
415 
/*!
 * \brief Reference class paired with FusionPatternNode.
 *
 * NOTE(review): the cross-reference index lists
 * TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FusionPattern, ObjectRef, FusionPatternNode)
 * for this class, but that line (425) is absent from this listing - confirm against the
 * original header.
 */
416 class FusionPattern : public ObjectRef {
417  public:
  /*! \brief Construct from all five fields of FusionPatternNode (defined out of line). */
418  FusionPattern(ffi::String name, DFPattern pattern,
419  ffi::Map<ffi::String, DFPattern> annotation_patterns,
420  ffi::Optional<ffi::Function> check, ffi::Optional<ffi::Function> attrs_getter);
421 
  /*!
   * \brief Convenience constructor: delegates to the five-argument constructor with an
   *        empty annotation map and std::nullopt for both check and attrs_getter.
   */
422  FusionPattern(ffi::String name, DFPattern pattern)
423  : FusionPattern(name, pattern, {}, std::nullopt, std::nullopt) {}
424 
426 };
427 
/*!
 * \brief The input of FusionPattern::check.
 *
 * NOTE(review): RegisterReflection below registers a `matched_expr` field, but its
 * declaration (`Expr matched_expr`, listed at transform.h:436 in the cross-reference
 * index as "the expression that's matched with the FusionPattern::pattern") does not
 * appear in this listing; presumably dropped during extraction - confirm.
 */
431 class PatternCheckContextNode : public Object {
432  public:
437 
  /*! \brief A map which contains all expressions matched by the sub patterns in FusionPattern::annotation_patterns. */
442  ffi::Map<ffi::String, Expr> annotated_expr;
443 
  /*! \brief Map from variable to its value. Contains variables from bindings that are being fused. */
448  ffi::Map<Var, Expr> matched_bindings;
449 
  /*! \brief A map from variable definitions to a set of uses. It has all variables used in the function. */
454  ffi::Map<Var, ffi::Array<Var>> var_usages;
455 
  /*! \brief Map from value to its bound variable. It doesn't have variables after the matched expression. */
460  ffi::Map<Expr, Var> value_to_bound_var;
461 
  /*! \brief Register each field with the FFI reflection ObjectDef via def_ro, under its own name. */
462  static void RegisterReflection() {
463  namespace refl = tvm::ffi::reflection;
464  refl::ObjectDef<PatternCheckContextNode>()
465  .def_ro("matched_expr", &PatternCheckContextNode::matched_expr)
466  .def_ro("annotated_expr", &PatternCheckContextNode::annotated_expr)
467  .def_ro("matched_bindings", &PatternCheckContextNode::matched_bindings)
468  .def_ro("var_usages", &PatternCheckContextNode::var_usages)
469  .def_ro("value_to_bound_var", &PatternCheckContextNode::value_to_bound_var);
470  }
  // Declares object type info under the type key "relax.transform.PatternCheckContext".
471  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.PatternCheckContext", PatternCheckContextNode,
472  Object);
473 };
474 
/*!
 * \brief Reference class paired with PatternCheckContextNode.
 *
 * NOTE(review): the cross-reference index lists
 * TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PatternCheckContext, ObjectRef,
 * PatternCheckContextNode) for this class, but that line is absent from this listing
 * (lines 482-483 are missing) - confirm against the original header.
 */
475 class PatternCheckContext : public ObjectRef {
476  public:
  /*! \brief Construct from the five fields of PatternCheckContextNode (defined out of line). */
477  PatternCheckContext(Expr matched_expr, ffi::Map<ffi::String, Expr> annotated_expr,
478  ffi::Map<Var, Expr> matched_bindings,
479  ffi::Map<Var, ffi::Array<Var>> var_usages,
480  ffi::Map<Expr, Var> value_to_bound_var);
481 
484 };
485 
/*!
 * \brief Reverse-mode automatic differentiation.
 * \param func_name Name of the function to differentiate.
 * \param require_grads Optional array of Vars, defaults to std::nullopt.
 *        NOTE(review): presumably the variables whose gradients are required - confirm.
 * \param target_index Integer index, defaults to 0.
 *        NOTE(review): presumably selects which output is the differentiation target - confirm.
 */
511 TVM_DLL Pass Gradient(ffi::String func_name,
512  ffi::Optional<ffi::Array<Var>> require_grads = std::nullopt,
513  int target_index = 0);
514 
/*!
 * \brief Apply pattern matching to each function in the given module, and group matched
 *        expressions into a new function (index description truncated).
 * \param patterns The FusionPattern objects to match.
 * \param bind_constants Boolean flag, defaults to true.
 * \param annotate_codegen Boolean flag, defaults to false.
 * \param entry_function_names Array of function names, defaults to empty.
 */
535 TVM_DLL Pass FuseOpsByPattern(const tvm::ffi::Array<FusionPattern>& patterns,
536  bool bind_constants = true, bool annotate_codegen = false,
537  const tvm::ffi::Array<ffi::String>& entry_function_names = {});
538 
547 
/*!
 * \brief Fuse relax sub-function into a larger TIR function if possible.
 *        This pass works together with FuseOps (index description truncated).
 */
554 TVM_DLL Pass FuseTIR();
555 
/*!
 * \brief Run codegen.
 * \param target_options Optional two-level map: string -> (string -> Any).
 *        NOTE(review): presumably codegen target name -> its option map - confirm.
 * \param entry_functions Array of function names.
 */
562 TVM_DLL Pass
563 RunCodegen(ffi::Optional<ffi::Map<ffi::String, ffi::Map<ffi::String, ffi::Any>>> target_options,
564  ffi::Array<ffi::String> entry_functions);
565 
/*!
 * \brief Decompose composite operators during inference (the index mentions batch norm
 *        as an example; description truncated).
 * \param func_name Optional function name.
 */
574 TVM_DLL Pass DecomposeOpsForInference(ffi::Optional<ffi::String> func_name);
575 
/*!
 * \brief Decompose composite operators during training (the index mentions batch norm
 *        as an example; description truncated).
 * \param func_name Optional function name.
 */
584 TVM_DLL Pass DecomposeOpsForTraining(ffi::Optional<ffi::String> func_name);
585 
601  const ffi::Map<ffi::String, tirx::PrimFunc>& op_impl_map,
602  const ffi::Map<ffi::String, ffi::Array<tirx::IndexMap>>& op_buffer_transforms,
603  const ffi::Map<ffi::String, ffi::Optional<ffi::Array<ffi::Array<IntImm>>>>& axis_separators,
604  const ffi::Map<ffi::String, ffi::Optional<ffi::Array<ffi::Array<IntImm>>>>&
605  input_axis_separators);
606 
/*!
 * \brief Layout conversion pass.
 * \param desired_layouts Map from string to an array of strings.
 *        NOTE(review): presumably operator name -> desired layouts - confirm.
 * \param layout_cb Callback of type LayoutCb (takes a Call, returns a map of the same shape
 *        as desired_layouts).
 */
614 TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
615  LayoutCb layout_cb);
616 
/*!
 * \brief A pass that converts consecutive dataflow operations inside binding blocks into
 *        dataflow blocks.
 * \param min_size Integer threshold, defaults to 2.
 *        NOTE(review): presumably the minimum number of consecutive operations - confirm.
 */
624 TVM_DLL Pass ConvertToDataflow(int min_size = 2);
625 
/*!
 * \brief Dead code elimination.
 * \param entry_functions Array of function names, defaults to empty.
 */
642 TVM_DLL Pass DeadCodeElimination(ffi::Array<ffi::String> entry_functions = {});
643 
654 
/*!
 * \brief Automatic mixed precision pass. Currently the pass assumes the input module to
 *        be fp32 only (index description truncated).
 * \param out_dtype A DataType, passed by const reference.
 * \param fp16_input_names Optional array of names, defaults to std::nullopt.
 */
665 TVM_DLL Pass
666 ToMixedPrecision(const DataType& out_dtype,
667  ffi::Optional<ffi::Array<ffi::String>> fp16_input_names = std::nullopt);
668 
675 
682 
683 } // namespace transform
684 } // namespace relax
685 } // namespace tvm
686 
687 #endif // TVM_RELAX_TRANSFORM_H_
Boolean constant.
Definition: expr.h:566
Managed reference class to IRModuleNode.
Definition: module.h:257
Reference to PrimExprNode.
Definition: expr.h:126
Managed reference to RelaxExprNode.
Definition: expr.h:441
Managed reference to VDeviceNode.
Definition: global_info.h:87
Definition: expr.h:180
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:101
Definition: expr.h:695
Definition: expr.h:832
Definition: expr.h:380
The pattern object used as the input of FuseOpsByPattern. For bindings to be fused,...
Definition: transform.h:367
ffi::Optional< ffi::Function > attrs_getter
The function to get attributes for fused function.
Definition: transform.h:402
ffi::Optional< ffi::Function > check
The function to determine whether the match result is accepted. This can be std::nullopt if check fun...
Definition: transform.h:394
ffi::String name
The name of the pattern. It becomes the value of the kComposite attribute of a fused function after succe...
Definition: transform.h:373
ffi::Map< ffi::String, DFPattern > annotation_patterns
The map which is used to extract important expressions from the pattern match result....
Definition: transform.h:385
DFPattern pattern
The dataflow pattern that will be used to match expression in the DataflowBlock. All the call nodes c...
Definition: transform.h:379
static void RegisterReflection()
Definition: transform.h:404
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.FusionPattern", FusionPatternNode, Object)
Definition: transform.h:416
FusionPattern(ffi::String name, DFPattern pattern)
Definition: transform.h:422
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(FusionPattern, ObjectRef, FusionPatternNode)
FusionPattern(ffi::String name, DFPattern pattern, ffi::Map< ffi::String, DFPattern > annotation_patterns, ffi::Optional< ffi::Function > check, ffi::Optional< ffi::Function > attrs_getter)
The input of FusionPattern::check.
Definition: transform.h:431
ffi::Map< Expr, Var > value_to_bound_var
Map from value to its bound variable. It doesn't have variables after the matched expression.
Definition: transform.h:460
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.transform.PatternCheckContext", PatternCheckContextNode, Object)
ffi::Map< Var, Expr > matched_bindings
Map from variable to its value. It contains variables from bindings that are being fused by FuseOpsByP...
Definition: transform.h:448
static void RegisterReflection()
Definition: transform.h:462
ffi::Map< ffi::String, Expr > annotated_expr
A map which contains all expressions matched by the sub patterns in FusionPattern::annotation_pattern...
Definition: transform.h:442
Expr matched_expr
The expression that's matched with the FusionPattern::pattern.
Definition: transform.h:436
ffi::Map< Var, ffi::Array< Var > > var_usages
A map mapping variable definitions to a set of uses. It has all variables used in the function.
Definition: transform.h:454
PatternCheckContext(Expr matched_expr, ffi::Map< ffi::String, Expr > annotated_expr, ffi::Map< Var, Expr > matched_bindings, ffi::Map< Var, ffi::Array< Var >> var_usages, ffi::Map< Expr, Var > value_to_bound_var)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PatternCheckContext, ObjectRef, PatternCheckContextNode)
Runtime primitive data type.
Definition: data_type.h:47
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Managed reference class for PassInfoNode.
Definition: transform.h:350
Definition: transform.h:400
A pattern language for matching dataflow properties.
Defines a remapping of buffer indices.
Definition: repr_printer.h:91
Pass RealizeVDevice()
Propagate virtual device information.
Pass Gradient(ffi::String func_name, ffi::Optional< ffi::Array< Var >> require_grads=std::nullopt, int target_index=0)
Reverse-mode automatic differentiation.
Pass CreateFunctionPass(std::function< Function(Function, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Create a function pass.
Pass MergeCompositeFunctions()
Group one or multiple composite functions created by FuseOpsByPattern into a new function....
tvm::relax::DataflowBlock DataflowBlock
Definition: transform.h:42
Pass DataflowUseInplaceCalls()
Pass that changes calls to operators that can be done in-place (generally, these are elementwise oper...
Pass FuseOpsByPattern(const tvm::ffi::Array< FusionPattern > &patterns, bool bind_constants=true, bool annotate_codegen=false, const tvm::ffi::Array< ffi::String > &entry_function_names={})
Apply pattern matching to each function in the given module, and group matched expressions into a new...
Pass ToMixedPrecision(const DataType &out_dtype, ffi::Optional< ffi::Array< ffi::String >> fp16_input_names=std::nullopt)
Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only,...
Pass BindSymbolicVars(ffi::Map< ffi::Variant< tirx::Var, ffi::String >, PrimExpr > binding_map, ffi::Optional< ffi::String > func_name=std::nullopt)
Bind symbolic vars to constant shape values.
Pass CallTIRRewrite()
Perform explicit tensor allocation for call_tir and call_dps_packed.
Pass BindParams(ffi::String func_name, ffi::Map< Any, ObjectRef > params)
Bind params of function of the module to constant tensors.
Pass Normalize()
Transform Relax IR to normal form: transform AST to A-normal form, and fill the struct_info_ of expre...
tvm::relax::Function Function
Definition: transform.h:41
Pass LiftTransformParams(ffi::Variant< Bool, ffi::Array< ffi::String >> shared_transform=Bool(false))
Lift transformation of the parameters of a function.
Pass RemoveUnusedParameters()
Remove unused parameters to internal functions.
Pass SpecializePrimFuncBasedOnCallSite()
This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. Primarily used to u...
Pass FuseTIR()
Fuse relax sub-function into a larger TIR function if possible. This pass works together with FuseOps...
Pass LegalizeOps(ffi::Optional< ffi::Map< ffi::String, ffi::Function >> cmap, ffi::Optional< ffi::Array< ffi::String >> skip_ops, bool enable_warning=false)
Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR Pr...
Pass AlterOpImpl(const ffi::Map< ffi::String, tirx::PrimFunc > &op_impl_map, const ffi::Map< ffi::String, ffi::Array< tirx::IndexMap >> &op_buffer_transforms, const ffi::Map< ffi::String, ffi::Optional< ffi::Array< ffi::Array< IntImm >>>> &axis_separators, const ffi::Map< ffi::String, ffi::Optional< ffi::Array< ffi::Array< IntImm >>>> &input_axis_separators)
Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in op_impl_map,...
Pass SplitLayoutRewritePreproc()
Split the layout rewrite preproc block to a separate tirx::PrimFunc.
Pass AttachGlobalSymbol()
Attach global_symbol to Relax functions and TIR Primfuncs for codegen.
Pass FuseOps(int fuse_opt_level=-1)
This pass groups bindings in a dataflow block of Relax functions and generates a new grouped Relax fu...
Pass AnnotateTIROpPattern()
Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
Pass RemovePurityChecking()
Activate force_pure on all pure functions in the module and unwrap all pure override ops into the nor...
Pass AttachAttrLayoutFreeBuffers()
Attach layout free buffers to the tirx::PrimFunc.
tvm::transform::PassContext PassContext
Definition: transform.h:40
Pass NormalizeGlobalVar()
Possibly rename the GlobalVar in an IRModule to ensure these properties:
tvm::transform::PassInfo PassInfo
Definition: transform.h:39
Pass RunCodegen(ffi::Optional< ffi::Map< ffi::String, ffi::Map< ffi::String, ffi::Any >>> target_options, ffi::Array< ffi::String > entry_functions)
Run codegen.
Pass RemoveUnusedOutputs()
Remove unused outputs from internal functions.
Pass ExpandTupleArguments()
Expand tuple arguments to internal functions.
Pass FoldConstant()
Fold constant expressions within dataflow blocks.
tvm::transform::Pass Pass
Definition: transform.h:38
Pass CanonicalizeBindings()
Simplify a Relax module by folding var bindings and match shape nodes, as well as tuple indices....
Pass RewriteDataflowReshape()
Convert all reshape-like call_tir whose corresponding binding vars are DataflowVars to relax....
Pass DecomposeOpsForTraining(ffi::Optional< ffi::String > func_name)
Decompose composite operators during training. For example, the result of batch norm (a triple) will ...
ffi::TypedFunction< ffi::Map< ffi::String, ffi::Array< ffi::String > >(Call)> LayoutCb
Definition: transform.h:44
Pass LambdaLift()
Perform lambda lifting to lift functions from nested into global.
Pass CreateDataflowBlockPass(std::function< DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Create a dataflowblock pass.
Pass RewriteCUDAGraph()
Rewrite a Relax module for executing with CUDA graph. This pass identifies the regions that can be ex...
Pass DecomposeOpsForInference(ffi::Optional< ffi::String > func_name)
Decompose composite operators during inference. For example, the result of batch norm (a triple) will...
Pass ConvertToDataflow(int min_size=2)
A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks.
Pass EliminateCommonSubexpr(bool call_only=false)
Pass UpdateVDevice(VDevice new_vdevice, int64_t index)
Update virtual device.
Pass StaticPlanBlockMemory()
The static memory planning pass on BindingBlock level. The pass will reuse allocated memory to its be...
Pass ToNonDataflow()
Transform all dataflow structure to non-dataflow version.
Pass ConvertLayout(ffi::Map< ffi::String, ffi::Array< ffi::String >> desired_layouts, LayoutCb layout_cb)
Layout conversion pass.
Pass DeadCodeElimination(ffi::Array< ffi::String > entry_functions={})
Dead code elimination.
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:233
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, ffi::String name, ffi::Array< ffi::String > required, bool traceable=false)
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
TIR Function.