tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_TRANSFORM_H_
25 #define TVM_RELAY_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
29 #include <tvm/relay/expr.h>
30 #include <tvm/relay/function.h>
31 #include <tvm/relay/op.h>
33 #include <tvm/target/target.h>
34 
35 #include <string>
36 
37 namespace tvm {
38 namespace relay {
39 namespace transform {
40 
48 
49 /*
50  * \brief Create a function pass.
51  *
52  * \param pass_func The packed function that contains the optimization.
53  * \param opt_level The optimization level of the function pass.
54  * \param name The name of the function pass.
55  * \param required The list of the passes that the function pass is dependent on.
56  *
57  * \return The created function pass.
58  */
59 TVM_DLL Pass CreateFunctionPass(
61  int opt_level, String name, tvm::Array<String> required);
62 
77 TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
78 
91 TVM_DLL Pass LazyGradientInit();
92 
98 TVM_DLL Pass FoldConstant();
99 
105 TVM_DLL Pass SplitArgs(int max_function_args);
106 
114 TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
115 
122 TVM_DLL Pass DefuseOps();
123 
132 TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
133 
147 TVM_DLL Pass ToBasicBlockNormalForm();
148 
163 TVM_DLL Pass ToANormalForm();
164 
172 TVM_DLL Expr ToANormalForm(const Expr& expr);
173 
188 TVM_DLL Pass ToCPS();
189 
198 TVM_DLL Pass ToGraphNormalForm();
199 
209 TVM_DLL Pass PartialEval();
210 
218 TVM_DLL Pass SimplifyInference();
219 
225 TVM_DLL Pass FastMath();
226 
236 TVM_DLL Pass DynamicToStatic();
237 
247 TVM_DLL Pass InferType();
248 
258 TVM_DLL Pass EliminateCommonSubexpr(runtime::PackedFunc fskip = nullptr);
259 
269 TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
270 
282 TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true);
283 
293 TVM_DLL Pass CombineParallelBatchMatmul(uint64_t min_num_branches = 3);
294 
300 TVM_DLL Pass BackwardFoldScaleAxis();
301 
307 TVM_DLL Pass ForwardFoldScaleAxis();
308 
315 TVM_DLL Pass FoldScaleAxis();
316 
323 TVM_DLL Pass CanonicalizeOps();
324 
331 TVM_DLL Pass AlterOpLayout();
332 
338 
359 TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts);
360 
371 TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize");
372 
378 TVM_DLL Pass CanonicalizeCast();
379 
394 TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
395 
402 TVM_DLL Pass PartitionGraph();
403 
410 TVM_DLL Pass Inline();
411 
420 TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
421 
427 TVM_DLL Pass SimplifyExpr();
428 
434 TVM_DLL Pass RelayToTIRTargetHook();
435 
445 TVM_DLL Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets);
446 
456 TVM_DLL Pass PlanDevices(DLDeviceType default_device_type);
457 
458 } // namespace transform
459 
470 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
471 
484 TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name,
485  std::function<ObjectRef(const Call&)> fcontext = nullptr,
486  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
487 
500 TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
501  std::function<ObjectRef(const Call&)> fcontext = nullptr,
502  std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
503 
513 TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
514 
532 TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
533 
544 TVM_DLL Function UnCPS(const Function& f);
545 
553 TVM_DLL Expr DeDup(const Expr& e);
554 
555 } // namespace relay
556 } // namespace tvm
557 
558 #endif // TVM_RELAY_TRANSFORM_H_
Pass AutoSchedulerLayoutRewrite()
Do layout rewrite according to the tile structure created by auto-scheduler.
Pass ToCPS()
Turn an expression into continuation-passing style (CPS).
Pass CanonicalizeCast()
Canonicalize cast expressions to make operator fusion more efficient.
Pass PartitionGraph()
Partition a Relay program into regions that can be executed on different backends.
Pass PlanDevices(DLDeviceType default_device_type)
Uses existing "on_device" and "device_copy" CallNodes to infer the device on which every Relay sub-ex...
Pass ToGraphNormalForm()
Remove let bindings and share values directly via pointers instead.
Pass BackwardFoldScaleAxis()
Backward fold axis scaling into weights of conv/dense operators.
Expr ForwardRewrite(const Expr &expr, const String &rewrite_map_attr_name, std::function< ObjectRef(const Call &)> fcontext=nullptr, std::function< Expr(const Expr &)> fmulti_ref_trigger=nullptr)
Apply rewrite rules to rewrite the expr in post DFS order. This function is used as a helper function...
Pass DeadCodeElimination(bool inline_once=false)
Remove expressions that do not affect the program result.
Function UnCPS(const Function &f)
Remove the continuation argument of a CPS function.
Managed reference to FunctionNode.
Definition: function.h:104
Relay expression language.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Pass ManifestAlloc(Target target_host, Map< tvm::Integer, tvm::Target > targets)
A pass for manifesting explicit memory allocations and rewriting specific dialects.
Pass ToANormalForm()
Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
Pass InferType()
Infer the type of an expression.
tvm::transform::Sequential Sequential
Definition: transform.h:47
Pass CombineParallelConv2D(uint64_t min_num_branches=3)
Combine parallel 2d convolutions into a single convolution if the number of branches of this conv2d o...
Relay Function.
Pass Inline()
Inline the global functions marked as inline in a given Relay IRModule.
Pass Legalize(const String &legalize_map_attr_name="FTVMLegalize")
Legalizes an expr with another expression.
The Expr and related elements in DataFlow construction.
Pass EliminateCommonSubexpr(runtime::PackedFunc fskip=nullptr)
Search and eliminate common subexpressions. For example, if there are two expressions evaluated to an ...
Definition: expr.h:301
tvm::transform::PassNode PassNode
Definition: transform.h:42
tvm::transform::Pass Pass
Definition: transform.h:41
Managed reference class for PassInfoNode.
Definition: transform.h:311
Pass AlterOpLayout()
Alternate the layouts of operators or replace primitive operators with other expressions.
Pass EtaExpand(bool expand_constructor, bool expand_global_var)
Add abstraction over a constructor or global variable bound to a function.
Expr Bind(const Expr &expr, const tvm::Map< Var, Expr > &binds)
Bind the free variables to a Relay expression. This is a helper function usually called by other pass...
Pass ForwardFoldScaleAxis()
Forward fold axis scaling into weights of conv/dense operators.
Expr DeDup(const Expr &e)
Deduplicate the bound variables and type variables in the expression.
Pass DefuseOps()
The inverse operation of FuseOps. It transforms a fused program returned by FuseOps into the program ...
Pass SimplifyInference()
Simplify certain operators during inference. For example, the result of a batch norm which is indexed...
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Pass FoldScaleAxis()
A sequential pass that executes ForwardFoldScaleAxis and BackwardFoldScaleAxis passes.
Definition: transform.h:363
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:44
PassContext that is used to configure the pass behavior.
Definition: transform.h:154
Reference to string objects.
Definition: string.h:129
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:136
Managed reference to RelayExprNode.
Definition: expr.h:177
Pass ConvertLayout(const Map< String, Array< String >> &desired_layouts)
Given a destination layout, this pass transforms the expr such that most of the ops' input data layout is cha...
Pass SplitArgs(int max_function_args)
Split functions with a huge number of arguments into smaller pieces.
Managed reference class to TargetNode.
Definition: target.h:132
Pass RewriteAnnotatedOps(int fallback_device)
Rewrite the annotated program.
Pass ToBasicBlockNormalForm()
Turn an expression into Basic Block Normal Form.
Pass CreateFunctionPass(const runtime::TypedPackedFunc< Function(Function, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required)
Pass FuseOps(int fuse_opt_level=-1)
Fuse operations in an expr into separate functions.
Base class of all object reference.
Definition: object.h:504
Transform operators.
Pass CombineParallelDense(uint64_t min_num_branches=3, bool to_batch_matmul=true)
Combine parallel dense ops into a single batch_matmul if the number of branches of this dense operato...
tvm::RelayExpr Expr
Definition: expr.h:43
Pass LazyGradientInit()
Convert all expressions of TensorType into GradCell, an algebraic data type defined in gradient...
Meta data that will be used to help optimization and analysis.
Definition: transform.h:283
Managed reference class to IRModuleNode.
Definition: module.h:352
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:329
PassContextNode contains the information that a pass can rely on, such as analysis results...
Definition: transform.h:78
Pass DynamicToStatic()
Find dynamic ops and make them static.
tvm::transform::PassContext PassContext
Definition: transform.h:45
Compilation target object.
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when the map is referenced in more than two places.
Definition: map.h:1235
Pass RelayToTIRTargetHook()
Run any registered RelayToTIR passes registered on the functions in a module.
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:68
Pass CombineParallelBatchMatmul(uint64_t min_num_branches=3)
Combine parallel batch_matmul ops into a single batch_matmul if the number of branches of this batch_matmul ...
Pass CanonicalizeOps()
Canonicalize some operators to the simplified operators. For example, bias_add can be canonicalized t...
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:271
Pass RemoveUnusedFunctions(Array< runtime::String > entry_functions)
Remove the unused functions in the Relay IRModule.
Pass PartialEval()
Aggressive constant propagation/constant folding/inlining.
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:46
Primitive operators (builtin intrinsics).
Pass FastMath()
Replaces non-linear activation functions with their fast but approximate counterparts.
Pass FoldConstant()
Fold constant expressions.
tvm::transform::PassInfo PassInfo
Definition: transform.h:43
Definition: transform.h:397
Pass SimplifyExpr()
Simplify the Relay expression.