tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_TRANSFORM_H_
25 #define TVM_RELAX_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
29 #include <tvm/relax/expr.h>
30 #include <tvm/tir/function.h>
31 #include <tvm/tir/index_map.h>
32 namespace tvm {
33 namespace relax {
34 namespace transform {
35 
41 
55  int opt_level, String name, tvm::Array<String> required, bool traceable = false);
56 
70  int opt_level, String name, tvm::Array<String> required, bool traceable = false);
71 
77 TVM_DLL Pass LambdaLift();
78 
84 TVM_DLL Pass ToNonDataflow();
85 
98 
105 
119 
141 
148 
155 TVM_DLL Pass Normalize();
156 
165 
178 
186 TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
187 
196 TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params);
197 
214  Optional<String> func_name = NullOpt);
215 
223 TVM_DLL Pass FoldConstant();
224 
248 TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning = false);
249 
255 
267 
276 
303 TVM_DLL Pass LiftTransformParams(Variant<Bool, Array<String>> shared_transform = Bool(false));
304 
311 TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
312 
318 
324 
330 
339 
351 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
352 
358 class FusionPatternNode : public Object {
359  public:
365 
371 
377 
386 
394 
396  v->Visit("name", &name);
397  v->Visit("pattern", &pattern);
398  v->Visit("annotation_patterns", &annotation_patterns);
399  v->Visit("check", &check);
400  v->Visit("attrs_getter", &attrs_getter);
401  }
402 
403  static constexpr const char* _type_key = "relax.transform.FusionPattern";
405 };
406 
407 class FusionPattern : public ObjectRef {
408  public:
409  FusionPattern(String name, DFPattern pattern, Map<String, DFPattern> annotation_patterns,
410  Optional<PackedFunc> check, Optional<PackedFunc> attrs_getter);
411 
413  : FusionPattern(name, pattern, {}, NullOpt, NullOpt) {}
414 
416 };
417 
422  public:
427 
433 
439 
445 
451 
453  v->Visit("matched_expr", &matched_expr);
454  v->Visit("annotated_expr", &annotated_expr);
455  v->Visit("matched_bindings", &matched_bindings);
456  v->Visit("var_usages", &var_usages);
457  v->Visit("value_to_bound_var", &value_to_bound_var);
458  }
459 
460  static constexpr const char* _type_key = "relax.transform.PatternCheckContext";
462 };
463 
465  public:
466  PatternCheckContext(Expr matched_expr, Map<String, Expr> annotated_expr,
467  Map<Var, Expr> matched_bindings, Map<Var, Array<Var>> var_usages,
468  Map<Expr, Var> value_to_bound_var);
469 
472 };
473 
499 TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = NullOpt,
500  int target_index = 0);
501 
522 TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
523  bool annotate_codegen = false,
524  const tvm::Array<String>& entry_function_names = {});
525 
534 
541 TVM_DLL Pass FuseTIR();
542 
550  Array<runtime::String> entry_functions);
551 
561 
571 
586 TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
587  const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
589  const Map<String, Array<Array<IntImm>>>& input_axis_separators);
590 
597 TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
598 
606 TVM_DLL Pass ConvertToDataflow(int min_size = 2);
607 
624 TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions = {});
625 
636 
647 TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
648  Optional<Array<String>> fp16_input_names = NullOpt);
649 
656 
667 TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark);
668 
669 } // namespace transform
670 } // namespace relax
671 } // namespace tvm
672 
673 #endif // TVM_RELAX_TRANSFORM_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Boolean constant.
Definition: expr.h:597
Managed reference class to IRModuleNode.
Definition: module.h:366
Managed reference to RelayExprNode.
Definition: expr.h:442
Managed reference to VDeviceNode.
Definition: global_info.h:95
Managed reference to dataflow patterns.
Definition: dataflow_pattern.h:101
Definition: expr.h:806
Definition: expr.h:995
Definition: expr.h:422
The pattern object used as the input of FuseOpsByPattern. For bindings to be fused,...
Definition: transform.h:358
Map< String, DFPattern > annotation_patterns
The map which is used to extract important expressions from the pattern match result....
Definition: transform.h:376
TVM_DECLARE_FINAL_OBJECT_INFO(FusionPatternNode, Object)
String name
The name of pattern. It becomes the value of the kComposite attribute of a fused function after succe...
Definition: transform.h:364
void VisitAttrs(tvm::AttrVisitor *v)
Definition: transform.h:395
Optional< PackedFunc > attrs_getter
The function to get attributes for fused function.
Definition: transform.h:393
Optional< PackedFunc > check
The function to determine whether the match result is accepted. This can be NullOpt if check function...
Definition: transform.h:385
static constexpr const char * _type_key
Definition: transform.h:403
DFPattern pattern
The dataflow pattern that will be used to match expression in the DataflowBlock. All the call nodes c...
Definition: transform.h:370
Definition: transform.h:407
FusionPattern(String name, DFPattern pattern, Map< String, DFPattern > annotation_patterns, Optional< PackedFunc > check, Optional< PackedFunc > attrs_getter)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode)
FusionPattern(String name, DFPattern pattern)
Definition: transform.h:412
The input of FusionPattern::check.
Definition: transform.h:421
void VisitAttrs(tvm::AttrVisitor *v)
Definition: transform.h:452
Map< Expr, Var > value_to_bound_var
Map from value to its bound variable. It doesn't have variables after the matched expression.
Definition: transform.h:450
Map< Var, Expr > matched_bindings
Map from variable to its value. It contains variables from bindings that is being fused by FuseOpsByP...
Definition: transform.h:438
Map< String, Expr > annotated_expr
A map which contains all expressions matched by the sub patterns in FusionPattern::annotation_pattern...
Definition: transform.h:432
TVM_DECLARE_FINAL_OBJECT_INFO(PatternCheckContextNode, Object)
static constexpr const char * _type_key
Definition: transform.h:460
Map< Var, Array< Var > > var_usages
A map mapping variable definitions to a set of uses. It has all variables used in the function.
Definition: transform.h:444
Expr matched_expr
The expression that's matched with the FusionPattern::pattern.
Definition: transform.h:426
PatternCheckContext(Expr matched_expr, Map< String, Expr > annotated_expr, Map< Var, Expr > matched_bindings, Map< Var, Array< Var >> var_usages, Map< Expr, Var > value_to_bound_var)
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef, PatternCheckContextNode)
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:519
base class of all object containers.
Definition: object.h:171
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:63
Definition: variant.h:69
PassContext that is used to configure the pass behavior.
Definition: transform.h:182
Managed reference class for PassInfoNode.
Definition: transform.h:373
Definition: transform.h:426
Defines a remapping of buffer indices.
Pass RealizeVDevice()
Propagate virtual device information.
Pass MergeCompositeFunctions()
Group one or multiple composite functions created by FuseOpsByPattern into a new function....
tvm::relax::DataflowBlock DataflowBlock
Definition: transform.h:40
Pass DataflowUseInplaceCalls()
Pass that changes calls to operators that can be done in-place (generally, these are elementwise oper...
Pass CallTIRRewrite()
Perform explicit tensor allocation for call_tir and call_dps_packed.
Pass CreateFunctionPass(const runtime::TypedPackedFunc< Function(Function, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required, bool traceable=false)
Create a function pass.
Pass Gradient(String func_name, Optional< Array< Var >> require_grads=NullOpt, int target_index=0)
Reverse-mode automatic differentiation.
Pass Normalize()
Transform Relax IR to normal form: transform AST to A-normal form, and fill the checked_type_ and sha...
tvm::relax::Function Function
Definition: transform.h:39
Pass LegalizeOps(Optional< Map< String, PackedFunc >> cmap, bool enable_warning=false)
Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR Pr...
Pass DecomposeOpsForTraining(Optional< String > func_name)
Decompose composite operators during training. For example, The result of batch norm (a triple) will ...
Pass RemoveUnusedParameters()
Remove unused parameters to internal functions.
Pass ToMixedPrecision(const DataType &out_dtype, Optional< Array< String >> fp16_input_names=NullOpt)
Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only,...
Pass FuseTIR()
Fuse relax sub-function into a larger TIR function if possible. this pass works together with FuseOps...
Pass FewShotTuning(int valid_count, bool benchmark)
The pass is designed for few shot tuning for static shape PrimFuncs. It examines all the blocks withi...
Pass CreateDataflowBlockPass(const runtime::TypedPackedFunc< DataflowBlock(DataflowBlock, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required, bool traceable=false)
Create a dataflowblock pass.
Pass SplitLayoutRewritePreproc()
Split the layout rewrite preproc block to a separate tir::PrimFunc.
Pass AttachGlobalSymbol()
Attach global_symbol to Relax functions and TIR Primfuncs for codegen.
Pass DeadCodeElimination(Array< runtime::String > entry_functions={})
Dead code elimination.
Pass AlterOpImpl(const Map< String, tir::PrimFunc > &op_impl_map, const Map< String, Array< tir::IndexMap >> &op_buffer_transforms, const Map< String, Array< Array< IntImm >>> &axis_separators, const Map< String, Array< Array< IntImm >>> &input_axis_separators)
Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in op_impl_map,...
Pass FuseOps(int fuse_opt_level=-1)
This pass groups bindings in a dataflow block of Relax functions and generates a new grouped Relax fu...
Pass FuseOpsByPattern(const tvm::Array< FusionPattern > &patterns, bool bind_constants=true, bool annotate_codegen=false, const tvm::Array< String > &entry_function_names={})
Apply pattern matching to each function in the given module, and group matched expressions into a new...
Pass AnnotateTIROpPattern()
Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
Pass RemovePurityChecking()
Activate force_pure on all pure functions in the module and unwrap all pure override ops into the nor...
Pass AttachAttrLayoutFreeBuffers()
Attach layout free buffers to the tir::PrimFunc.
Pass NormalizeGlobalVar()
Possibly rename the GlobalVar in an IRModule to ensure these properties:
Pass ConvertLayout(Map< String, Array< String >> desired_layouts)
Layout conversion pass.
Pass RemoveUnusedOutputs()
Remove unused outputs from internal functions.
Pass ExpandTupleArguments()
Expand tuple arguments to internal functions.
Pass DecomposeOpsForInference(Optional< String > func_name)
Decompose composite operators during inference. For example, The result of batch norm (a triple) will...
Pass FoldConstant()
Fold constant expressions within dataflow blocks.
Pass CanonicalizeBindings()
Simplify a Relax module by folding var bindings and match shape nodes, as well as tuple indices....
Pass RewriteDataflowReshape()
Convert all reshape-like call_tir whose corresponding binding vars are DataflowVars to relax....
Pass RunCodegen(Optional< Map< String, Map< String, ObjectRef >>> target_options, Array< runtime::String > entry_functions)
Run codegen.
Pass LambdaLift()
Perform lambda lifting to lift functions from nested into global.
Pass BindSymbolicVars(Map< ObjectRef, PrimExpr > binding_map, Optional< String > func_name=NullOpt)
Bind symbolic vars to constant shape values.
Pass RewriteCUDAGraph()
Rewrite a Relax module for executing with CUDA graph. This pass identifies the regions that can be ex...
Pass BindParams(String func_name, Map< ObjectRef, ObjectRef > params)
Bind params of function of the module to constant tensors.
Pass ConvertToDataflow(int min_size=2)
A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks.
Pass EliminateCommonSubexpr(bool call_only=false)
Pass UpdateVDevice(VDevice new_vdevice, int64_t index)
Update virtual device.
Pass StaticPlanBlockMemory()
The static memory planning pass on BindingBlock level. The pass will reuse allocated memory to its be...
Pass ToNonDataflow()
Transform all dataflow structure to non-dataflow version.
Pass LiftTransformParams(Variant< Bool, Array< String >> shared_transform=Bool(false))
Lift transformation of the parameters of a function.
tvm::transform::PassContext PassContext
Definition: transform.h:47
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
Box< bool > Bool
Boxed version of C++ bool.
Definition: boxed_primitive.h:121
constexpr const char * axis_separators
Marks the physical axis separators.
Definition: stmt.h:1458
tvm::transform::Pass Pass
Definition: transform.h:35
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
A pattern language for matching dataflow properties.
TIR Function.