tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_TRANSFORM_H_
25 #define TVM_RELAY_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
29 #include <tvm/relay/expr.h>
30 #include <tvm/relay/function.h>
31 #include <tvm/relay/op.h>
34 #include <tvm/target/target.h>
36 
37 #include <string>
38 
39 namespace tvm {
40 namespace relay {
41 namespace transform {
42 
62 
74 
75 /*
76  * \brief Create a function pass.
77  *
78  * \param pass_func The packed function that contains the optimization.
79  * \param opt_level The optimization level of the function pass.
80  * \param name The name of the function pass.
81  * \param required The list of the passes that the function pass is dependent on.
82  *
83  * \return The created function pass.
84  */
87  int opt_level, String name, tvm::Array<String> required, bool traceable = false);
88 
/*!
 * \brief Remove let-bound expressions which do not affect the program result.
 *
 * Both flags default to false.
 *
 * \return The pass.
 */
113 TVM_DLL Pass DeadCodeElimination(bool inline_once = false, bool ignore_purity = false);
114 
128 
/*!
 * \brief Fold constant expressions.
 *
 * NOTE(review): fold_qnn (default false) presumably enables folding of QNN ops as well - confirm.
 *
 * \return The pass.
 */
142 TVM_DLL Pass FoldConstant(bool fold_qnn = false);
143 
/*!
 * \brief Split function with huge number of arguments to smaller pieces.
 *
 * \return The pass.
 */
152 TVM_DLL Pass SplitArgs(uint64_t max_function_args);
153 
/*!
 * \brief Fuse operations in an expression into separate functions.
 *
 * fuse_opt_level defaults to -1.
 *
 * \return The pass.
 */
161 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
162 
/*!
 * \brief The inverse operation of FuseOps. It operates on a fused program returned by FuseOps.
 *
 * \return The pass.
 */
169 TVM_DLL Pass DefuseOps();
170 
/*!
 * \brief Rewrite the annotated program.
 *
 * \return The pass.
 */
179 TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
180 
195 
/*!
 * \brief Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
 *
 * \return The pass.
 */
210 TVM_DLL Pass ToANormalForm();
211 
/*!
 * \brief Expression-level overload of ToANormalForm: takes an Expr and returns an Expr.
 *
 * NOTE(review): presumably returns the A-Normal Form of expr, mirroring the
 * ToANormalForm() pass - confirm.
 */
219 TVM_DLL Expr ToANormalForm(const Expr& expr);
220 
/*!
 * \brief Turn an expression into continuation-passing style (CPS).
 *
 * \return The pass.
 */
235 TVM_DLL Pass ToCPS();
236 
246 
/*!
 * \brief Aggressive constant propagation/constant folding/inlining.
 *
 * \return The pass.
 */
256 TVM_DLL Pass PartialEval();
257 
266 
/*!
 * \brief Replaces non-linear activation functions with their fast but approximate counterparts.
 *
 * \return The pass.
 */
272 TVM_DLL Pass FastMath();
273 
284 
/*!
 * \brief Infer the type of an expression.
 *
 * \return The pass.
 */
294 TVM_DLL Pass InferType();
295 
/*!
 * \brief Infer the type of an expression, reusing existing type information.
 *
 * \param expr The expression whose type is inferred.
 *
 * \return The inferred type.
 */
309 TVM_DLL Type InferTypeLocal(const Expr& expr);
310 
321 
/*!
 * \brief Combine parallel 2d convolutions into a single convolution, conditional on the
 *  number of branches.
 *
 * NOTE(review): min_num_branches (default 3) is presumably the branch-count threshold
 * for combining - confirm the exact comparison.
 *
 * \return The pass.
 */
331 TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
332 
/*!
 * \brief Combine parallel dense ops into a single batch_matmul, conditional on the
 *  number of branches.
 *
 * NOTE(review): min_num_branches (default 3) is presumably the branch-count threshold;
 * to_batch_matmul (default true) presumably selects the batch_matmul form - confirm.
 *
 * \return The pass.
 */
344 TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true);
345 
/*!
 * \brief Combine parallel batch_matmul ops into a single batch_matmul, conditional on the
 *  number of branches.
 *
 * NOTE(review): min_num_branches (default 3) is presumably the branch-count threshold
 * for combining - confirm the exact comparison.
 *
 * \return The pass.
 */
355 TVM_DLL Pass CombineParallelBatchMatmul(uint64_t min_num_branches = 3);
356 
363 
370 
/*!
 * \brief A sequential pass that executes ForwardFoldScaleAxis and BackwardFoldScaleAxis passes.
 *
 * \return The pass.
 */
377 TVM_DLL Pass FoldScaleAxis();
378 
386 
/*!
 * \brief Alter the layouts of operators or replace primitive operators with other expressions.
 *
 * \return The pass.
 */
393 TVM_DLL Pass AlterOpLayout();
394 
400 
406 
/*!
 * \brief Given a destination layout, transform the expression with respect to the input
 *  data layout of most of its ops.
 *
 * NOTE(review): desired_layouts maps a String to an Array of Strings; presumably
 * op name -> desired layouts for that op - confirm.
 *
 * \return The pass.
 */
427 TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts);
428 
/*!
 * \brief Legalizes an expr with another expression.
 *
 * legalize_map_attr_name defaults to "FTVMLegalize".
 *
 * \return The pass.
 */
439 TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize");
440 
447 
/*!
 * \brief Add abstraction over a constructor or global variable bound to a function.
 *
 * \return The pass.
 */
462 TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
463 
471 
/*!
 * \brief Inline the global functions marked as inline in a given Relay IRModule.
 *
 * \return The pass.
 */
478 TVM_DLL Pass Inline();
479 
489 
/*!
 * \brief Simplify the Relay expression.
 *
 * \return The pass.
 */
495 TVM_DLL Pass SimplifyExpr();
496 
503 
549 
/*!
 * \brief A pass for manifesting explicit memory allocations and rewriting specific dialects.
 *
 * \return The pass.
 */
559 TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device);
560 
568 
583 
592 
605 
616 
622 
632 
633 } // namespace transform
634 
/*!
 * \brief Bind the free variables to a Relay expression. This is a helper function usually
 *  called by other passes.
 *
 * \param expr The input expression.
 * \param binds The map from Var to Expr describing the bindings.
 *
 * \return The resulting expression.
 */
649 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
650 
/*!
 * \brief Substitute variables with new variables (including function parameters) in a function.
 *
 * \param func The input function.
 * \param binds The map from Var to Expr describing the substitutions.
 *
 * \return The resulting function.
 */
662 TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& binds);
663 
/*!
 * \brief Apply rewrite rules to rewrite the expr in post DFS order. This function is used
 *  as a helper function.
 *
 * NOTE(review): rewrite_map_attr_name is presumably the name of the op attribute that
 * holds the rewrite rule - confirm. fcontext and fmulti_ref_trigger are optional
 * callbacks that default to nullptr.
 *
 * \return The rewritten expression.
 */
676 TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name,
677  std::function<ObjectRef(const Call&)> fcontext = nullptr,
678  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
679 
/*!
 * \brief Overload of ForwardRewrite that takes the rewrite function (FForwardRewrite)
 *  directly rather than a String attribute name.
 *
 * fcontext and fmulti_ref_trigger are optional callbacks that default to nullptr.
 *
 * \return The rewritten expression.
 */
692 TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
693  std::function<ObjectRef(const Call&)> fcontext = nullptr,
694  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
695 
/*!
 * \brief Rewrite the annotated program (expression-level overload: takes and returns an Expr).
 */
705 TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
706 
/*!
 * \brief Turn an expression into continuation-passing style (CPS).
 *
 * Function-level overload: takes the function f together with an IRModule mod and
 * returns a Function.
 */
724 TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
725 
/*!
 * \brief Remove the continuation argument of a CPS function.
 *
 * \param f The input function.
 *
 * \return The resulting function.
 */
736 TVM_DLL Function UnCPS(const Function& f);
737 
/*!
 * \brief Deduplicate the bound variables and type variables in the expression.
 *
 * \param e The input expression.
 *
 * \return The resulting expression.
 */
745 TVM_DLL Expr DeDup(const Expr& e);
746 
747 namespace legalize {
/*!
 * \brief Expression-level Legalize: takes an Expr and an attribute name, returns an Expr.
 *
 * NOTE(review): presumably the expression-level counterpart of the transform::Legalize
 * pass, with legalize_map_attr_name naming the op attribute to look up - confirm.
 */
748 TVM_DLL Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name);
749 } // namespace legalize
750 
751 } // namespace relay
752 } // namespace tvm
753 
754 #endif // TVM_RELAY_TRANSFORM_H_
Managed reference class to CompilationConfig.
Definition: compilation_config.h:191
Managed reference class to IRModuleNode.
Definition: module.h:366
Managed reference to RelayExprNode.
Definition: expr.h:442
Managed reference class to TargetNode.
Definition: target.h:200
Managed reference to TypeNode.
Definition: type.h:93
Managed reference class to VirtualDeviceNode.
Definition: virtual_device.h:271
Definition: expr.h:357
Managed reference to FunctionNode.
Definition: function.h:105
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Module container of TVM.
Definition: module.h:79
Base class of all object reference.
Definition: object.h:519
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:63
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:77
PassContext that is used to configure the pass behavior.
Definition: transform.h:182
Meta data that will be used to help optimization and analysis.
Definition: transform.h:341
Managed reference class for PassInfoNode.
Definition: transform.h:373
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:392
Definition: transform.h:426
Definition: transform.h:517
A helper class to collect all the targets in canonical form necessary for compilation.
tvm::relax::Function Function
Definition: transform.h:39
Expr Legalize(const Expr &expr, const std::string &legalize_map_attr_name)
Type InferTypeLocal(const Expr &expr)
Infer the type of an expression, reusing existing type information.
Pass RewriteAnnotatedOps(int fallback_device)
Rewrite the annotated program.
Pass ManifestLifetimes()
A pass for manifesting variable lifetimes by inserting kill operations when variables become dead....
Pass FoldConstant(bool fold_qnn=false)
Fold constant expressions.
Pass CapturePostDfsIndexInSpans()
Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in their span,...
Pass RelayToTIRTargetHook(CompilationConfig config)
Run any custom passes registered under "RelayToTIR" attributes on TargetKinds.
Pass PlanDevices(CompilationConfig config)
Uses existing "on_device" and "device_copy" CallNodes to infer the VirtualDevice on which every Relay...
Pass InferType()
Infer the type of an expression.
Pass FastMath()
Replaces non-linear activation functions with their fast but approximate counterparts.
Pass CreateFunctionPass(const runtime::TypedPackedFunc< Function(Function, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required, bool traceable=false)
Pass FuseOps(int fuse_opt_level=-1)
Fuse operations in an expr into separate functions.
Pass ToGraphNormalForm()
Remove let binding and directly share via pointer instead.
tvm::transform::Sequential Sequential
Definition: transform.h:49
Pass CombineParallelConv2D(uint64_t min_num_branches=3)
Combine parallel 2d convolutions into a single convolution if the number of branches of this conv2d o...
Pass BackwardFoldScaleAxis()
Backward fold axis scaling into weights of conv/dense operators.
Pass ConvertLayout(const Map< String, Array< String >> &desired_layouts)
Given a dest layout, this pass transforms the expr such that most of the ops input data layout is cha...
Pass SimplifyExpr()
Simplify the Relay expression.
Pass DeadCodeElimination(bool inline_once=false, bool ignore_purity=false)
Remove let-bound expressions which do not affect the program result.
Pass ManifestAlloc(VirtualDevice cpu_virtual_device)
A pass for manifesting explicit memory allocations and rewriting specific dialects.
Pass AutoSchedulerLayoutRewrite()
Do layout rewrite according to the tile structure created by auto-scheduler.
Pass SimplifyExprPostAlterOp()
Stripped down version of SimplifyExpr which is run after AlterOpLayout.
Pass CombineParallelDense(uint64_t min_num_branches=3, bool to_batch_matmul=true)
Combine parallel dense ops into a single batch_matmul if the number of branches of this dense operato...
Pass ToBasicBlockNormalForm()
Turn an expression to Basic Block Normal Form.
Pass AlterOpLayout()
Alter the layouts of operators or replace primitive operators with other expressions.
Pass AnnotateUsedMemory()
Annotates the minimum required memory of each primitive function callsite by analyzing the liveness o...
Pass EtaExpand(bool expand_constructor, bool expand_global_var)
Add abstraction over a constructor or global variable bound to a function.
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:48
Pass FoldScaleAxis()
A sequential pass that executes ForwardFoldScaleAxis and BackwardFoldScaleAxis passes.
tvm::transform::PassContext PassContext
Definition: transform.h:47
Pass ToANormalForm()
Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
Pass DefuseOps()
The inverse operation of FuseOps. It transforms a fused program returned by FuseOps into the program ...
Pass CanonicalizeCast()
Canonicalize cast expressions to make operator fusion more efficient.
Pass AnnotateMemoryScope()
Calls device dependent memory scope analysis pass, collects mapping of desirable expr->memory_scope a...
Pass PartialEval()
Aggressive constant propagation/constant folding/inlining.
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
tvm::transform::PassNode PassNode
Definition: transform.h:44
Pass CombineParallelBatchMatmul(uint64_t min_num_branches=3)
Combine parallel batch_matmul ops into a single batch_matmul if the number of branches of this dense ...
Pass ForwardFoldScaleAxis()
Forward fold axis scaling into weights of conv/dense operators.
Pass PartitionGraph()
Partition a Relay program into regions that can be executed on different backends.
Pass Legalize(const String &legalize_map_attr_name="FTVMLegalize")
Legalizes an expr with another expression.
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:46
Pass DynamicToStatic()
Find Dynamic ops and make them static.
Pass MetaScheduleLayoutRewrite()
Do layout rewrite according to the tile structure created by meta-schedule.
Pass LazyGradientInit()
Convert all expressions of TensorType into GradCell, an algebraic data type defined in gradient....
Pass Inline()
Inline the global functions marked as inline in a given Relay IRModule.
Pass RemoveStandaloneReshapes()
Removes non-fused reshapes after lowering the graph. InferType() cannot be invoked after calling this...
Pass ToCPS()
Turn an expression into continuation-passing style (CPS).
Pass CanonicalizeOps()
Canonicalize some operators to the simplified operators. For example, bias_add can be canonicalized t...
Pass SimplifyInference()
Simplify certain operators during inference. For example, the result of a batch norm which is indexed...
Pass EliminateCommonSubexpr(runtime::PackedFunc fskip=nullptr)
Search and eliminate common subexpression. For example, if there are two expressions evaluated to an ...
Pass FlattenAtrousConv()
This transform flattens atrous convolution, which corresponds to the sequence of operations: "space_t...
Pass SplitArgs(uint64_t max_function_args)
Split function with huge number of arguments to smaller pieces.
Pass RemoveUnusedFunctions(Array< runtime::String > entry_functions)
Remove the unused functions in the Relay IRModule.
Expr ForwardRewrite(const Expr &expr, const String &rewrite_map_attr_name, std::function< ObjectRef(const Call &)> fcontext=nullptr, std::function< Expr(const Expr &)> fmulti_ref_trigger=nullptr)
Apply rewrite rules to rewrite the expr in post DFS order. This function is used as a helper function...
Function UnCPS(const Function &f)
Remove the continuation argument of a CPS function.
Expr DeDup(const Expr &e)
Deduplicate the bound variables and type variables in the expression.
tvm::RelayExpr Expr
Definition: expr.h:54
Expr RewriteAnnotatedOps(const Expr &expr, int fallback_device)
Rewrite the annotated program.
Function SubstituteBoundVars(const Function &func, const tvm::Map< Var, Expr > &binds)
Substitute variables with new variables (including function parameters) in a function....
Function ToCPS(const Function &f, const IRModule &mod)
Turn an expression into continuation-passing style (CPS).
Expr Bind(const Expr &expr, const tvm::Map< Var, Expr > &binds)
Bind the free variables to a Relay expression. This is a helper function usually called by other pass...
tvm::transform::Pass Pass
Definition: transform.h:35
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Transform operators.
Relay expression language.
Relay Function.
Primitive operators (builtin intrinsics).
The Expr and related elements in DataFlow construction.
Compilation target object.
A compile time representation for where data is to be stored at runtime, and how to compile code to c...