tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_TRANSFORM_H_
25 #define TVM_RELAY_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
29 #include <tvm/relay/expr.h>
30 #include <tvm/relay/function.h>
31 #include <tvm/relay/op.h>
34 #include <tvm/target/target.h>
36 
37 #include <string>
38 
39 namespace tvm {
40 namespace relay {
41 namespace transform {
42 
50 
51 /*
52  * \brief Create a function pass.
53  *
54  * \param pass_func The packed function that contains the optimization.
55  * \param opt_level The optimization level of the function pass.
56  * \param name The name of the function pass.
57  * \param required The list of the passes that the function pass is dependent on.
58  *
59  * \return The created function pass.
60  */
63  int opt_level, String name, tvm::Array<String> required);
64 
89 TVM_DLL Pass DeadCodeElimination(bool inline_once = false, bool ignore_purity = false);
90 
104 
118 TVM_DLL Pass FoldConstant(bool fold_qnn = false);
119 
125 TVM_DLL Pass SplitArgs(int max_function_args);
126 
134 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
135 
142 TVM_DLL Pass DefuseOps();
143 
152 TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
153 
168 
183 TVM_DLL Pass ToANormalForm();
184 
192 TVM_DLL Expr ToANormalForm(const Expr& expr);
193 
208 TVM_DLL Pass ToCPS();
209 
219 
229 TVM_DLL Pass PartialEval();
230 
239 
245 TVM_DLL Pass FastMath();
246 
257 
267 TVM_DLL Pass InferType();
268 
282 TVM_DLL Type InferTypeLocal(const Expr& expr);
283 
294 
304 TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
305 
317 TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true);
318 
328 TVM_DLL Pass CombineParallelBatchMatmul(uint64_t min_num_branches = 3);
329 
336 
343 
350 TVM_DLL Pass FoldScaleAxis();
351 
359 
366 TVM_DLL Pass AlterOpLayout();
367 
373 
379 
400 TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts);
401 
412 TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize");
413 
420 
435 TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
436 
444 
451 TVM_DLL Pass Inline();
452 
462 
468 TVM_DLL Pass SimplifyExpr();
469 
476 
522 
532 TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device);
533 
541 
556 
565 
578 
589 
595 
605 
606 } // namespace transform
607 
622 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
623 
635 TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& binds);
636 
649 TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name,
650  std::function<ObjectRef(const Call&)> fcontext = nullptr,
651  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
652 
665 TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
666  std::function<ObjectRef(const Call&)> fcontext = nullptr,
667  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
668 
678 TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
679 
697 TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
698 
709 TVM_DLL Function UnCPS(const Function& f);
710 
718 TVM_DLL Expr DeDup(const Expr& e);
719 
720 namespace legalize {
721 TVM_DLL Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name);
722 } // namespace legalize
723 
724 } // namespace relay
725 } // namespace tvm
726 
727 #endif // TVM_RELAY_TRANSFORM_H_
Managed reference class to CompilationConfig.
Definition: compilation_config.h:191
Managed reference class to IRModuleNode.
Definition: module.h:348
Managed reference to RelayExprNode.
Definition: expr.h:433
Managed reference to TypeNode.
Definition: type.h:93
Managed reference class to VirtualDeviceNode.
Definition: virtual_device.h:271
Definition: expr.h:357
Managed reference to FunctionNode.
Definition: function.h:105
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics,...
Definition: map.h:1271
Base class of all object reference.
Definition: object.h:515
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:138
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:60
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:77
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Meta data that will be used to help optimization and analysis.
Definition: transform.h:282
Managed reference class for PassInfoNode.
Definition: transform.h:310
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:328
Definition: transform.h:362
Definition: transform.h:453
A helper class to collect all the targets in canonical form necessary for compilation.
Expr Legalize(const Expr &expr, const std::string &legalize_map_attr_name)
Type InferTypeLocal(const Expr &expr)
Infer the type of an expression, reusing existing type information.
Pass RewriteAnnotatedOps(int fallback_device)
Rewrite the annotated program.
Pass ManifestLifetimes()
A pass for manifesting variable lifetimes by inserting kill operations when variables become dead....
Pass FoldConstant(bool fold_qnn=false)
Fold constant expressions.
Pass CapturePostDfsIndexInSpans()
Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in their span,...
Pass RelayToTIRTargetHook(CompilationConfig config)
Run any custom passes registered under "RelayToTIR" attributes on TargetKinds.
Pass PlanDevices(CompilationConfig config)
Uses existing "on_device" and "device_copy" CallNodes to infer the VirtualDevice on which every Relay...
Pass InferType()
Infer the type of an expression.
Pass CreateFunctionPass(const runtime::TypedPackedFunc< Function(Function, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required)
Pass FastMath()
Replaces non-linear activation functions with their fast but approximate counterparts.
Pass SplitArgs(int max_function_args)
Split function with huge number of arguments to smaller pieces.
Pass FuseOps(int fuse_opt_level=-1)
Fuse operations in an expr into separate functions.
Pass ToGraphNormalForm()
Remove let binding and directly share via pointer instead.
tvm::transform::Sequential Sequential
Definition: transform.h:49
Pass CombineParallelConv2D(uint64_t min_num_branches=3)
Combine parallel 2d convolutions into a single convolution if the number of branches of this conv2d o...
Pass BackwardFoldScaleAxis()
Backward fold axis scaling into weights of conv/dense operators.
Pass ConvertLayout(const Map< String, Array< String >> &desired_layouts)
Given a dest layout, this pass transforms the expr such that most of the ops' input data layout is cha...
Pass SimplifyExpr()
Simplify the Relay expression.
Pass DeadCodeElimination(bool inline_once=false, bool ignore_purity=false)
Remove let-bound expressions which do not affect the program result.
Pass ManifestAlloc(VirtualDevice cpu_virtual_device)
A pass for manifesting explicit memory allocations and rewriting specific dialects.
Pass AutoSchedulerLayoutRewrite()
Do layout rewrite according to the tile structure created by auto-scheduler.
Pass SimplifyExprPostAlterOp()
Stripped down version of SimplifyExpr which is run after AlterOpLayout.
Pass CombineParallelDense(uint64_t min_num_branches=3, bool to_batch_matmul=true)
Combine parallel dense ops into a single batch_matmul if the number of branches of this dense operato...
Pass ToBasicBlockNormalForm()
Turn an expression to Basic Block Normal Form.
Pass AlterOpLayout()
Alter the layouts of operators or replace primitive operators with other expressions.
Pass AnnotateUsedMemory()
Annotates the minimum required memory of each primitive function callsite by analyzing the liveness o...
Pass EtaExpand(bool expand_constructor, bool expand_global_var)
Add abstraction over a constructor or global variable bound to a function.
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:48
Pass FoldScaleAxis()
A sequential pass that executes ForwardFoldScaleAxis and BackwardFoldScaleAxis passes.
tvm::transform::PassContext PassContext
Definition: transform.h:47
Pass ToANormalForm()
Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
Pass DefuseOps()
The inverse operation of FuseOps. It transforms a fused program returned by FuseOps into the program ...
Pass CanonicalizeCast()
Canonicalize cast expressions to make operator fusion more efficient.
Pass AnnotateMemoryScope()
Calls device dependent memory scope analysis pass, collects mapping of desirable expr->memory_scope a...
Pass PartialEval()
Aggressive constant propagation/constant folding/inlining.
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
tvm::transform::PassNode PassNode
Definition: transform.h:44
Pass CombineParallelBatchMatmul(uint64_t min_num_branches=3)
Combine parallel batch_matmul ops into a single batch_matmul if the number of branches of this dense ...
Pass ForwardFoldScaleAxis()
Forward fold axis scaling into weights of conv/dense operators.
Pass PartitionGraph()
Partition a Relay program into regions that can be executed on different backends.
Pass Legalize(const String &legalize_map_attr_name="FTVMLegalize")
Legalizes an expr with another expression.
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:46
Pass DynamicToStatic()
Find Dynamic ops and make them static.
Pass MetaScheduleLayoutRewrite()
Do layout rewrite according to the tile structure created by meta-schedule.
Pass LazyGradientInit()
Convert all expressions of TensorType into GradCell, an algebraic data type defined in gradient....
Pass Inline()
Inline the global functions marked as inline in a given Relay IRModule.
Pass RemoveStandaloneReshapes()
Removes non-fused reshapes after lowering the graph. InferType() cannot be invoked after calling this...
Pass ToCPS()
Turn an expression into continuation-passing style (CPS).
Pass CanonicalizeOps()
Canonicalize some operators to the simplified operators. For example, bias_add can be canonicalized t...
Pass SimplifyInference()
Simplify certain operators during inference. For example, the result of a batch norm which is indexed...
Pass EliminateCommonSubexpr(runtime::PackedFunc fskip=nullptr)
Search and eliminate common subexpression. For example, if there are two expressions evaluated to an ...
Pass FlattenAtrousConv()
This transform flattens atrous convolution, which corresponds to the sequence of operations: "space_t...
Pass RemoveUnusedFunctions(Array< runtime::String > entry_functions)
Remove the unused functions in the Relay IRModule.
Expr ForwardRewrite(const Expr &expr, const String &rewrite_map_attr_name, std::function< ObjectRef(const Call &)> fcontext=nullptr, std::function< Expr(const Expr &)> fmulti_ref_trigger=nullptr)
Apply rewrite rules to rewrite the expr in post DFS order. This function is used as a helper function...
Function UnCPS(const Function &f)
Remove the continuation argument of a CPS function.
Expr DeDup(const Expr &e)
Deduplicate the bound variables and type variables in the expression.
tvm::RelayExpr Expr
Definition: expr.h:54
Expr RewriteAnnotatedOps(const Expr &expr, int fallback_device)
Rewrite the annotated program.
Function SubstituteBoundVars(const Function &func, const tvm::Map< Var, Expr > &binds)
Substitute variables with new variables (including function parameters) in a function....
Function ToCPS(const Function &f, const IRModule &mod)
Turn an expression into continuation-passing style (CPS).
Expr Bind(const Expr &expr, const tvm::Map< Var, Expr > &binds)
Bind the free variables to a Relay expression. This is a helper function usually called by other pass...
tvm::transform::Pass Pass
Definition: transform.h:35
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Transform operators.
Relay expression language.
Relay Function.
Primitive operators(builtin intrinsics).
The Expr and related elements in DataFlow construction.
Compilation target object.
A compile time representation for where data is to be stored at runtime, and how to compile code to c...