tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_TRANSFORM_H_
25 #define TVM_RELAY_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
29 #include <tvm/relay/expr.h>
30 #include <tvm/relay/function.h>
31 #include <tvm/relay/op.h>
34 #include <tvm/target/target.h>
36 
37 #include <string>
38 
39 namespace tvm {
40 namespace relay {
41 namespace transform {
42 
50 
51 /*!
52  * \brief Create a function pass.
53  *
54  * \param pass_func The packed function that contains the optimization; it takes a Function, an IRModule and a PassContext and returns a Function.
55  * \param opt_level The optimization level of the function pass.
56  * \param name The name of the function pass.
57  * \param required The list of the passes that the function pass is dependent on.
58  *
59  * \return The created function pass.
60  */
61 TVM_DLL Pass CreateFunctionPass(
62  const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
63  int opt_level, String name, tvm::Array<String> required);
64 
/*! \brief Remove let-bound expressions that do not affect the program result. */
89 TVM_DLL Pass DeadCodeElimination(bool inline_once = false, bool ignore_purity = false);
90 
/*! \brief Convert all expressions of TensorType into GradCell, an algebraic data type. */
103 TVM_DLL Pass LazyGradientInit();
104 
/*! \brief Fold constant expressions. */
118 TVM_DLL Pass FoldConstant(bool fold_qnn = false);
119 
/*! \brief Split a function with a huge number of arguments into smaller pieces (limit: max_function_args). */
125 TVM_DLL Pass SplitArgs(int max_function_args);
126 
/*! \brief Fuse operations in an expression into separate functions. */
134 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
135 
/*! \brief Inverse of FuseOps: turns a fused program back into an unfused one. */
142 TVM_DLL Pass DefuseOps();
143 
/*! \brief Rewrite the annotated program. */
152 TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
153 
/*! \brief Turn an expression into Basic Block Normal Form. */
167 TVM_DLL Pass ToBasicBlockNormalForm();
168 
/*! \brief Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). */
183 TVM_DLL Pass ToANormalForm();
184 
/*! \brief Expression-level overload of ToANormalForm: takes and returns an Expr instead of producing a Pass. */
192 TVM_DLL Expr ToANormalForm(const Expr& expr);
193 
/*! \brief Turn an expression into continuation-passing style (CPS). */
208 TVM_DLL Pass ToCPS();
209 
/*! \brief Remove let bindings and share subexpressions directly via pointers instead. */
218 TVM_DLL Pass ToGraphNormalForm();
219 
/*! \brief Aggressive constant propagation / constant folding / inlining. */
229 TVM_DLL Pass PartialEval();
230 
/*! \brief Simplify certain operators during inference (e.g. batch norm). */
238 TVM_DLL Pass SimplifyInference();
239 
/*! \brief Replace non-linear activation functions with fast but approximate counterparts. */
245 TVM_DLL Pass FastMath();
246 
/*! \brief Find dynamic ops and make them static. */
256 TVM_DLL Pass DynamicToStatic();
257 
/*! \brief Infer the type of an expression. */
267 TVM_DLL Pass InferType();
268 
/*! \brief Infer the type of an expression, reusing existing type information. */
282 TVM_DLL Type InferTypeLocal(const Expr& expr);
283 
/*! \brief Search for and eliminate common subexpressions. */
293 TVM_DLL Pass EliminateCommonSubexpr(runtime::PackedFunc fskip = nullptr);
294 
/*! \brief Combine parallel 2d convolutions into a single convolution; branch threshold is min_num_branches. */
304 TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
305 
/*! \brief Combine parallel dense ops into a single batch_matmul; branch threshold is min_num_branches. */
317 TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true);
318 
/*! \brief Combine parallel batch_matmul ops into a single batch_matmul; branch threshold is min_num_branches. */
328 TVM_DLL Pass CombineParallelBatchMatmul(uint64_t min_num_branches = 3);
329 
/*! \brief Backward-fold axis scaling into the weights of conv/dense operators. */
335 TVM_DLL Pass BackwardFoldScaleAxis();
336 
/*! \brief Forward-fold axis scaling into the weights of conv/dense operators. */
342 TVM_DLL Pass ForwardFoldScaleAxis();
343 
/*! \brief Sequential pass that executes ForwardFoldScaleAxis and BackwardFoldScaleAxis. */
350 TVM_DLL Pass FoldScaleAxis();
351 
/*! \brief Canonicalize some operators into simplified operators (e.g. bias_add). */
358 TVM_DLL Pass CanonicalizeOps();
359 
/*! \brief Alter operator layouts or replace primitive operators with other expressions. */
366 TVM_DLL Pass AlterOpLayout();
367 
// NOTE(review): the symbol index of this page lists AutoSchedulerLayoutRewrite() and
// MetaScheduleLayoutRewrite(), but neither declaration appears in this listing; they were
// presumably dropped near these empty numbered lines. Confirm against the original header.
373 
379 
/*!
 * \brief Transform the expression so that most operators' input data layout is changed
 * toward desired_layouts (keys presumably operator names — confirm).
 */
400 TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts);
401 
/*!
 * \brief Legalize an expression by replacing it with another expression.
 * The op attribute name used for lookup defaults to "FTVMLegalize".
 */
412 TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize");
413 
/*! \brief Canonicalize cast expressions to make operator fusion more efficient. */
419 TVM_DLL Pass CanonicalizeCast();
420 
/*! \brief Add abstraction over a constructor or a global variable bound to a function. */
435 TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
436 
/*! \brief Partition a Relay program into regions that can be executed on different backends. */
443 TVM_DLL Pass PartitionGraph();
444 
/*! \brief Inline the global functions marked as inline in a given Relay IRModule. */
451 TVM_DLL Pass Inline();
452 
/*! \brief Remove the unused functions in the Relay IRModule, given its entry functions. */
461 TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
462 
/*! \brief Simplify the Relay expression. */
468 TVM_DLL Pass SimplifyExpr();
469 
/*! \brief Stripped-down version of SimplifyExpr that is run after AlterOpLayout. */
475 TVM_DLL Pass SimplifyExprPostAlterOp();
476 
// NOTE(review): the symbol index also lists RemoveStandaloneReshapes(), which does not
// appear in this listing; confirm against the original header.
522 
/*! \brief Manifest explicit memory allocations and rewrite specific dialects. */
532 TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device);
533 
/*! \brief Manifest variable lifetimes by inserting kill operations when variables become dead. */
540 TVM_DLL Pass ManifestLifetimes();
541 
/*! \brief Infer VirtualDevices from the existing "on_device" and "device_copy" CallNodes. */
555 TVM_DLL Pass PlanDevices(CompilationConfig config);
556 
/*! \brief Flatten atrous convolution patterns. */
564 TVM_DLL Pass FlattenAtrousConv();
565 
/*! \brief Annotate the minimum required memory of each primitive function callsite via liveness analysis. */
577 TVM_DLL Pass AnnotateUsedMemory();
578 
// NOTE(review): the symbol index also lists CapturePostDfsIndexInSpans(), which does not
// appear in this listing; confirm against the original header.
589 
/*! \brief Run the device-dependent memory scope analysis and collect an expr -> memory_scope mapping. */
594 TVM_DLL Pass AnnotateMemoryScope();
595 
// NOTE(review): the symbol index also lists RelayToTIRTargetHook(CompilationConfig config),
// which does not appear in this listing; confirm against the original header.
605 
606 } // namespace transform
607 
/*! \brief Bind the free variables of a Relay expression; a helper usually called by other passes. */
622 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
623 
/*!
 * \brief Substitute variables (including function parameters) in a function with new variables.
 * Expects all values in the bind map to be Vars.
 */
635 TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& binds);
636 
/*! \brief Apply rewrite rules (looked up by attribute name) to rewrite the expr in post-DFS order. */
649 TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name,
650  std::function<ObjectRef(const Call&)> fcontext = nullptr,
651  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
652 
/*! \brief Overload of ForwardRewrite that takes an explicit FForwardRewrite function instead of an attribute name. */
665 TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
666  std::function<ObjectRef(const Call&)> fcontext = nullptr,
667  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
668 
/*! \brief Expression-level overload of RewriteAnnotatedOps: takes and returns an Expr instead of producing a Pass. */
678 TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
679 
/*! \brief Function-level form of ToCPS: converts a single Function in the context of an IRModule. */
697 TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
698 
/*! \brief Remove the continuation argument of a CPS function. */
709 TVM_DLL Function UnCPS(const Function& f);
710 
/*! \brief Deduplicate the bound variables and type variables in the expression. */
718 TVM_DLL Expr DeDup(const Expr& e);
719 
720 namespace legalize {
/*!
 * \brief Expression-level counterpart of the Legalize pass: takes and returns an Expr.
 * Presumably looks up legalization functions via the op attribute named
 * legalize_map_attr_name (the pass-level default is "FTVMLegalize") — confirm.
 */
721 TVM_DLL Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name);
722 } // namespace legalize
723 
724 } // namespace relay
725 } // namespace tvm
726 
727 #endif // TVM_RELAY_TRANSFORM_H_
Pass AutoSchedulerLayoutRewrite()
Do layout rewrite according to the tile structure created by auto-scheduler.
Managed reference class to CompilationConfig.
Definition: compilation_config.h:191
Pass ToCPS()
Turn an expression into continuation-passing style (CPS).
Pass CanonicalizeCast()
Canonicalize cast expressions to make operator fusion more efficient.
Pass PartitionGraph()
Partition a Relay program into regions that can be executed on different backends.
Pass SimplifyExprPostAlterOp()
Stripped down version of SimplifyExpr which is run after AlterOpLayout.
Pass ToGraphNormalForm()
Remove let binding and directly share via pointer instead.
Pass BackwardFoldScaleAxis()
Backward fold axis scaling into weights of conv/dense operators.
Expr ForwardRewrite(const Expr &expr, const String &rewrite_map_attr_name, std::function< ObjectRef(const Call &)> fcontext=nullptr, std::function< Expr(const Expr &)> fmulti_ref_trigger=nullptr)
Apply rewrite rules to rewrite the expr in post DFS order. This function is used as a helper function...
Function UnCPS(const Function &f)
Remove the continuation argument of a CPS function.
Managed reference to FunctionNode.
Definition: function.h:105
A compile time representation for where data is to be stored at runtime, and how to compile code to c...
Relay expression language.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Pass ToANormalForm()
Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
Pass AnnotateMemoryScope()
Calls device dependent memory scope analysis pass, collects mapping of desirable expr->memory_scope a...
Pass FoldConstant(bool fold_qnn=false)
Fold constant expressions.
Pass InferType()
Infer the type of an expression.
tvm::transform::Sequential Sequential
Definition: transform.h:49
Pass CombineParallelConv2D(uint64_t min_num_branches=3)
Combine parallel 2d convolutions into a single convolution if the number of branches of this conv2d o...
Relay Function.
Pass Inline()
Inline the global functions marked as inline in a given Relay IRModule.
Pass Legalize(const String &legalize_map_attr_name="FTVMLegalize")
Legalizes an expr with another expression.
The Expr and related elements in DataFlow construction.
Pass EliminateCommonSubexpr(runtime::PackedFunc fskip=nullptr)
Search and eliminate common subexpression. For example, if there are two expressions evaluated to an ...
Definition: expr.h:357
tvm::transform::PassNode PassNode
Definition: transform.h:44
Pass ManifestLifetimes()
A pass for manifesting variable lifetimes by inserting kill operations when variables become dead...
Managed reference class for PassInfoNode.
Definition: transform.h:310
Pass AlterOpLayout()
Alternate the layouts of operators or replace primitive operators with other expressions.
Pass DeadCodeElimination(bool inline_once=false, bool ignore_purity=false)
Remove let-bound expressions which do not affect the program result.
Managed reference class to VirtualDeviceNode.
Definition: virtual_device.h:271
Pass EtaExpand(bool expand_constructor, bool expand_global_var)
Add abstraction over a constructor or global variable bound to a function.
Expr Bind(const Expr &expr, const tvm::Map< Var, Expr > &binds)
Bind the free variables to a Relay expression. This is a helper function usually called by other pass...
Type InferTypeLocal(const Expr &expr)
Infer the type of an expression, reusing existing type information.
Pass ForwardFoldScaleAxis()
Forward fold axis scaling into weights of conv/dense operators.
Expr DeDup(const Expr &e)
Deduplicate the bound variables and type variables in the expression.
Pass DefuseOps()
The inverse operation of FuseOps. It transforms a fused program returned by FuseOps into the program ...
Pass SimplifyInference()
Simplify certain operators during inference. For example, the result of a batch norm which is indexed...
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Pass PlanDevices(CompilationConfig config)
Uses existing "on_device" and "device_copy" CallNodes to infer the VirtualDevice on which every Relay...
Pass FoldScaleAxis()
A sequential pass that executes ForwardFoldScaleAxis and BackwardFoldScaleAxis passes.
Definition: transform.h:362
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:46
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
A helper class to collect all the targets in canonical form necessary for compilation.
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:60
Pass ManifestAlloc(VirtualDevice cpu_virtual_device)
A pass for manifesting explicit memory allocations and rewriting specific dialects.
Managed reference to RelayExprNode.
Definition: expr.h:433
Pass ConvertLayout(const Map< String, Array< String >> &desired_layouts)
Given a desired layout, this pass transforms the expr such that most of the ops' input data layout is cha...
Pass SplitArgs(int max_function_args)
Split function with huge number of arguments to smaller pieces.
Pass RewriteAnnotatedOps(int fallback_device)
Rewrite the annotated program.
Pass ToBasicBlockNormalForm()
Turn an expression to Basic Block Normal Form.
Pass CreateFunctionPass(const runtime::TypedPackedFunc< Function(Function, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required)
Pass FuseOps(int fuse_opt_level=-1)
Fuse operations in an expr into separate functions.
Base class of all object reference.
Definition: object.h:511
Function SubstituteBoundVars(const Function &func, const tvm::Map< Var, Expr > &binds)
Substitute variables with new variables (including function parameters) in a function. This is a helper function usually called by other pass functions to help optimizations. Expects all values in the bind map to be Vars.
Transform operators.
Pass MetaScheduleLayoutRewrite()
Do layout rewrite according to the tile structure created by meta-schedule.
Pass CombineParallelDense(uint64_t min_num_branches=3, bool to_batch_matmul=true)
Combine parallel dense ops into a single batch_matmul if the number of branches of this dense operato...
tvm::transform::Pass Pass
Definition: transform.h:35
tvm::RelayExpr Expr
Definition: expr.h:54
Pass LazyGradientInit()
Convert all expressions of TensorType into GradCell, an algebraic data type defined in gradient...
Meta data that will be used to help optimization and analysis.
Definition: transform.h:282
Managed reference class to IRModuleNode.
Definition: module.h:348
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:328
PassContextNode contains the information that a pass can rely on, such as analysis results...
Definition: transform.h:77
Pass DynamicToStatic()
Find Dynamic ops and make them static.
tvm::transform::PassContext PassContext
Definition: transform.h:47
Compilation target object.
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:138
Managed reference to TypeNode.
Definition: type.h:93
Pass RemoveStandaloneReshapes()
Removes non-fused reshapes after lowering the graph. InferType() cannot be invoked after calling this...
Pass CombineParallelBatchMatmul(uint64_t min_num_branches=3)
Combine parallel batch_matmul ops into a single batch_matmul if the number of branches of this dense ...
Pass FlattenAtrousConv()
This transform flattens atrous convolution, which corresponds to the sequence of operations: "space_t...
Pass CapturePostDfsIndexInSpans()
Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in their span...
Pass AnnotateUsedMemory()
Annotates the minimum required memory of each primitive function callsite by analyzing the liveness o...
Pass CanonicalizeOps()
Canonicalize some operators to the simplified operators. For example, bias_add can be canonicalized t...
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
Pass RemoveUnusedFunctions(Array< runtime::String > entry_functions)
Remove the unused functions in the Relay IRModule.
Pass RelayToTIRTargetHook(CompilationConfig config)
Run any custom passes registered under "RelayToTIR" attributes on TargetKinds.
Pass PartialEval()
Aggressive constant propagation/constant folding/inlining.
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:48
Primitive operators (builtin intrinsics).
Pass FastMath()
Replaces non linear activation functions with their fast but approximate counterparts.
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
Definition: transform.h:455
Pass SimplifyExpr()
Simplify the Relay expression.