tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_TRANSFORM_H_
25 #define TVM_RELAY_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
29 #include <tvm/relay/expr.h>
30 #include <tvm/relay/function.h>
31 #include <tvm/relay/op.h>
34 #include <tvm/target/target.h>
36 
37 #include <string>
38 
39 namespace tvm {
40 namespace relay {
41 namespace transform {
42 
50 
62 
63 /*
64  * \brief Create a function pass.
65  *
66  * \param pass_func The packed function that contains the optimization.
67  * \param opt_level The optimization level of the function pass.
68  * \param name The name of the function pass.
69  * \param required The list of the passes that the function pass is dependent on.
70  *
71  * \return The created function pass.
72  */
75  int opt_level, String name, tvm::Array<String> required);
76 
101 TVM_DLL Pass DeadCodeElimination(bool inline_once = false, bool ignore_purity = false);
102 
116 
130 TVM_DLL Pass FoldConstant(bool fold_qnn = false);
131 
140 TVM_DLL Pass SplitArgs(uint64_t max_function_args);
141 
149 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
150 
157 TVM_DLL Pass DefuseOps();
158 
167 TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
168 
183 
198 TVM_DLL Pass ToANormalForm();
199 
207 TVM_DLL Expr ToANormalForm(const Expr& expr);
208 
223 TVM_DLL Pass ToCPS();
224 
234 
244 TVM_DLL Pass PartialEval();
245 
254 
260 TVM_DLL Pass FastMath();
261 
272 
282 TVM_DLL Pass InferType();
283 
297 TVM_DLL Type InferTypeLocal(const Expr& expr);
298 
309 
319 TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
320 
332 TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true);
333 
343 TVM_DLL Pass CombineParallelBatchMatmul(uint64_t min_num_branches = 3);
344 
351 
358 
365 TVM_DLL Pass FoldScaleAxis();
366 
374 
381 TVM_DLL Pass AlterOpLayout();
382 
388 
394 
415 TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts);
416 
427 TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize");
428 
435 
450 TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
451 
459 
466 TVM_DLL Pass Inline();
467 
477 
483 TVM_DLL Pass SimplifyExpr();
484 
491 
537 
547 TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device);
548 
556 
571 
580 
593 
604 
610 
620 
621 } // namespace transform
622 
637 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
638 
650 TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& binds);
651 
664 TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name,
665  std::function<ObjectRef(const Call&)> fcontext = nullptr,
666  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
667 
680 TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
681  std::function<ObjectRef(const Call&)> fcontext = nullptr,
682  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
683 
693 TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
694 
712 TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
713 
724 TVM_DLL Function UnCPS(const Function& f);
725 
733 TVM_DLL Expr DeDup(const Expr& e);
734 
735 namespace legalize {
736 TVM_DLL Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name);
737 } // namespace legalize
738 
739 } // namespace relay
740 } // namespace tvm
741 
742 #endif // TVM_RELAY_TRANSFORM_H_
Managed reference class to CompilationConfig.
Definition: compilation_config.h:191
Managed reference class to IRModuleNode.
Definition: module.h:348
Managed reference to RelayExprNode.
Definition: expr.h:433
Managed reference to TypeNode.
Definition: type.h:93
Managed reference class to VirtualDeviceNode.
Definition: virtual_device.h:271
Definition: expr.h:357
Managed reference to FunctionNode.
Definition: function.h:105
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:517
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:139
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:61
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:77
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Meta data that will be used to help optimization and analysis.
Definition: transform.h:282
Managed reference class for PassInfoNode.
Definition: transform.h:310
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:328
Definition: transform.h:362
Definition: transform.h:453
A helper class to collect all the targets in canonical form necessary for compilation.
Expr Legalize(const Expr &expr, const std::string &legalize_map_attr_name)
Type InferTypeLocal(const Expr &expr)
Infer the type of an expression, reusing existing type information.
Pass RewriteAnnotatedOps(int fallback_device)
Rewrite the annotated program.
Pass ManifestLifetimes()
A pass for manifesting variable lifetimes by inserting kill operations when variables become dead....
Pass FoldConstant(bool fold_qnn=false)
Fold constant expressions.
Pass CapturePostDfsIndexInSpans()
Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in their span,...
Pass RelayToTIRTargetHook(CompilationConfig config)
Run any custom passes registered under "RelayToTIR" attributes on TargetKinds.
Pass PlanDevices(CompilationConfig config)
Uses existing "on_device" and "device_copy" CallNodes to infer the VirtualDevice on which every Relay...
Pass InferType()
Infer the type of an expression.
Pass CreateFunctionPass(const runtime::TypedPackedFunc< Function(Function, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required)
Pass FastMath()
Replaces non linear activation functions with their fast but approximate counterparts.
Pass FuseOps(int fuse_opt_level=-1)
Fuse operations in an expr into separate functions.
Pass ToGraphNormalForm()
Remove let binding and directly share via pointer instead.
tvm::transform::Sequential Sequential
Definition: transform.h:49
Pass CombineParallelConv2D(uint64_t min_num_branches=3)
Combine parallel 2d convolutions into a single convolution if the number of branches of this conv2d o...
Pass BackwardFoldScaleAxis()
Backward fold axis scaling into weights of conv/dense operators.
Pass ConvertLayout(const Map< String, Array< String >> &desired_layouts)
Given a dest layout, this pass transforms the expr such that most of the ops' input data layout is cha...
Pass SimplifyExpr()
Simplify the Relay expression.
Pass DeadCodeElimination(bool inline_once=false, bool ignore_purity=false)
Remove let-bound expressions which do not affect the program result.
Pass ManifestAlloc(VirtualDevice cpu_virtual_device)
A pass for manifesting explicit memory allocations and rewriting specific dialects.
Pass AutoSchedulerLayoutRewrite()
Do layout rewrite according to the tile structure created by auto-scheduler.
Pass SimplifyExprPostAlterOp()
Stripped down version of SimplifyExpr which is run after AlterOpLayout.
Pass CombineParallelDense(uint64_t min_num_branches=3, bool to_batch_matmul=true)
Combine parallel dense ops into a single batch_matmul if the number of branches of this dense operato...
Pass ToBasicBlockNormalForm()
Turn an expression into Basic Block Normal Form.
Pass AlterOpLayout()
Alternate the layouts of operators or replace primitive operators with other expressions.
Pass AnnotateUsedMemory()
Annotates the minimum required memory of each primitive function callsite by analyzing the liveness o...
Pass EtaExpand(bool expand_constructor, bool expand_global_var)
Add abstraction over a constructor or global variable bound to a function.
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:48
Pass FoldScaleAxis()
A sequential pass that executes ForwardFoldScaleAxis and BackwardFoldScaleAxis passes.
tvm::transform::PassContext PassContext
Definition: transform.h:47
Pass ToANormalForm()
Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
Pass DefuseOps()
The inverse operation of FuseOps. It transforms a fused program returned by FuseOps into the program ...
Pass CanonicalizeCast()
Canonicalize cast expressions to make operator fusion more efficient.
Pass AnnotateMemoryScope()
Calls device dependent memory scope analysis pass, collects mapping of desirable expr->memory_scope a...
Pass PartialEval()
Aggressive constant propagation/constant folding/inlining.
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
tvm::transform::PassNode PassNode
Definition: transform.h:44
Pass CombineParallelBatchMatmul(uint64_t min_num_branches=3)
Combine parallel batch_matmul ops into a single batch_matmul if the number of branches of this dense ...
Pass ForwardFoldScaleAxis()
Forward fold axis scaling into weights of conv/dense operators.
Pass PartitionGraph()
Partition a Relay program into regions that can be executed on different backends.
Pass Legalize(const String &legalize_map_attr_name="FTVMLegalize")
Legalizes an expr with another expression.
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:46
Pass DynamicToStatic()
Find dynamic ops and make them static.
Pass MetaScheduleLayoutRewrite()
Do layout rewrite according to the tile structure created by meta-schedule.
Pass LazyGradientInit()
Convert all expressions of TensorType into GradCell, an algebraic data type defined in gradient....
Pass Inline()
Inline the global functions marked as inline in a given Relay IRModule.
Pass RemoveStandaloneReshapes()
Removes non-fused reshapes after lowering the graph. InferType() cannot be invoked after calling this...
Pass ToCPS()
Turn an expression into continuation-passing style (CPS).
Pass CanonicalizeOps()
Canonicalize some operators to the simplified operators. For example, bias_add can be canonicalized t...
Pass SimplifyInference()
Simplify certain operators during inference. For example, the result of a batch norm which is indexed...
Pass EliminateCommonSubexpr(runtime::PackedFunc fskip=nullptr)
Search and eliminate common subexpressions. For example, if there are two expressions evaluated to an ...
Pass FlattenAtrousConv()
This transform flattens atrous convolution, which corresponds to the sequence of operations: "space_t...
Pass SplitArgs(uint64_t max_function_args)
Split function with huge number of arguments to smaller pieces.
Pass RemoveUnusedFunctions(Array< runtime::String > entry_functions)
Remove the unused functions in the Relay IRModule.
Expr ForwardRewrite(const Expr &expr, const String &rewrite_map_attr_name, std::function< ObjectRef(const Call &)> fcontext=nullptr, std::function< Expr(const Expr &)> fmulti_ref_trigger=nullptr)
Apply rewrite rules to rewrite the expr in post DFS order. This function is used as a helper function...
Function UnCPS(const Function &f)
Remove the continuation argument of a CPS function.
Expr DeDup(const Expr &e)
Deduplicate the bound variables and type variables in the expression.
tvm::RelayExpr Expr
Definition: expr.h:54
Expr RewriteAnnotatedOps(const Expr &expr, int fallback_device)
Rewrite the annotated program.
Function SubstituteBoundVars(const Function &func, const tvm::Map< Var, Expr > &binds)
Substitute variables with new variables (including function parameters) in a function....
Function ToCPS(const Function &f, const IRModule &mod)
Turn an expression into continuation-passing style (CPS).
Expr Bind(const Expr &expr, const tvm::Map< Var, Expr > &binds)
Bind the free variables to a Relay expression. This is a helper function usually called by other pass...
tvm::transform::Pass Pass
Definition: transform.h:35
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Transform operators.
Relay expression language.
Relay Function.
Primitive operators (built-in intrinsics).
The Expr and related elements in DataFlow construction.
Compilation target object.
A compile time representation for where data is to be stored at runtime, and how to compile code to c...