tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_TRANSFORM_H_
25 #define TVM_TIR_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
28 #include <tvm/target/target.h>
29 #include <tvm/tirx/expr.h>
30 #include <tvm/tirx/function.h>
31 
32 #include <string>
33 #include <vector>
34 
35 namespace tvm {
36 namespace tirx {
37 namespace transform {
38 
47 
48 /*!
49  * \brief Create a function pass that optimizes PrimFuncs.
50  *
51  * \param pass_func The callable, taking (PrimFunc, IRModule, PassContext) and returning a PrimFunc, that contains the optimization.
52  * \param opt_level The optimization level of the function pass.
53  * \param name The name of the function pass.
54  * \param required The list of the passes that the function pass is dependent on.
55  * \param traceable Boolean flag, false by default. NOTE(review): its effect is not visible in this header - confirm against the implementation.
56  * \return The created function pass.
57  */
58 TVM_DLL Pass CreatePrimFuncPass(std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
59  int opt_level, ffi::String name,
60  tvm::ffi::Array<ffi::String> required, bool traceable = false);
61 
69 TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);  //!< Lower vectorization loops.
70 
79 TVM_DLL Pass StorageRewrite();  //!< Rewrite storage allocation pattern; moves the allocation to the outermost possible scope.
80 
87 TVM_DLL Pass UnrollLoop();  //!< Unroll the constant loop marked by unroll.
88 
94 TVM_DLL Pass RemoveNoOp();  //!< Remove No Op from the Stmt.
95 
101 TVM_DLL Pass Simplify();  //!< Run arithmetic simplifications on the statements and expressions.
102 
114 TVM_DLL Pass ConvertSSA();  //!< Convert an IRModule to be SSA form.
115 
142 TVM_DLL Pass MakePackedAPI();  //!< Transform the high-level PrimFunc to a low-level version that can be used as an API function.
143 
154 TVM_DLL Pass RemapThreadAxis(ffi::Map<ffi::String, IterVar> axis_map);  //!< Remap the thread axis according to axis_map.
155 
164 
176 
192 
209 
215 TVM_DLL Pass SkipAssert();  //!< Skip assert stmt.
216 
220 static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin";  // String key "disable_lower_builtin". NOTE(review): presumably used to opt out of the LowerTVMBuiltin pass - confirm where it is read.
221 
227 
233 TVM_DLL Pass LowerIntrin();  //!< Lower the target-specific function intrinsics in each function.
234 
240 
249 TVM_DLL Pass NarrowDataType(int target_bits);  //!< Narrow down PrimExpr datatype in stmt to target_bits.
250 
258 
265 
273 TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype = "float16");  //!< Legalize fp8 compute Ops: add a cast to fp16/fp32 before Ops, then add a cast back to fp8.
274 
280 
287 
294 
304 
310 TVM_DLL Pass FlattenBuffer();  //!< Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional BufferLoad/BufferStore.
311 
318 
326 
331 TVM_DLL Pass BindTarget(Target target);  //!< Annotate a PrimFunc with a given target.
332 
338 
343 TVM_DLL Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond);  //!< Filter PrimFuncs with a given condition (fcond).
344 
357 } // namespace transform
358 } // namespace tirx
359 } // namespace tvm
360 
361 #endif // TVM_TIR_TRANSFORM_H_
Managed reference class to IRModuleNode.
Definition: module.h:257
Managed reference class to TargetNode.
Definition: target.h:135
Managed reference to PrimFuncNode.
Definition: function.h:130
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:79
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Meta data that will be used to help optimization and analysis.
Definition: transform.h:319
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:370
Definition: transform.h:400
Definition: transform.h:491
tvm::transform::PassContext PassContext
Definition: transform.h:40
tvm::transform::PassInfo PassInfo
Definition: transform.h:39
tvm::transform::Pass Pass
Definition: transform.h:38
PrimFuncFrame PrimFunc(bool is_private)
The primitive function statement.
Pass SplitHostDevice()
Split the function into a host function and device functions.
Pass PointerValueTypeRewrite()
Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use the m...
Pass SkipAssert()
Skip assert stmt.
Pass RemapThreadAxis(ffi::Map< ffi::String, IterVar > axis_map)
Remap the thread axis.
Pass Simplify()
Run arithmetic simplifications on the statements and expressions.
Pass BindTarget(Target target)
Annotate a PrimFunc with a given target.
Pass AnnotateEntryFunc()
Set a PrimFunc as the entry point if it is the only function in the IRModule.
Pass CommonSubexprElim()
Implements Common Subexpression Elimination (CSE) for TIR which introduces Bind statements for duplic...
Pass MakePackedAPI()
Transform the high-level PrimFunc to a low-level version that can be used as an API function.
Pass FlattenBuffer()
Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional BufferLoad/BufferStore...
Pass StorageRewrite()
Rewrite storage allocation pattern. Moves the allocation to the outermost possible scope....
Pass ConvertSSA()
Convert an IRModule to be SSA form.
Pass VectorizeLoop(bool enable_vectorize=true)
Lower vectorization loops.
Pass UnifiedStaticMemoryPlanner()
This is the unified static memory planner pass that will plan for memory intra- and inter- PrimFuncs ...
Pass LowerWarpMemory()
Lower warp memory access to low-level device related function calls.
Pass AnnotateDeviceRegions()
Annotate locations that should be run on the device.
Pass BF16ComputeLegalize()
Legalize bf16 compute Ops. Add a cast to fp32 before Ops, then add a cast back to bf16.
Pass CreatePrimFuncPass(std::function< PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Pass NarrowDataType(int target_bits)
Narrow down PrimExpr datatype in stmt to target_bits.
Pass LowerDeviceKernelLaunch()
Lower cross-device function calls.
Pass RemoveNoOp()
Remove No Op from the Stmt.
Pass InlinePrivateFunctions()
Inline calls to private functions.
Pass LowerIntrin()
Lower the target specific function intrinsics in each of the function.
Pass Filter(ffi::TypedFunction< bool(PrimFunc)> fcond)
Filter PrimFuncs with a given condition.
Pass UnrollLoop()
Unroll the constant loop marked by unroll. This pass also automatically attaches pragma unroll tag to l...
Pass FP8ComputeLegalize(ffi::String promote_dtype="float16")
Legalize fp8 compute Ops. Add a cast to fp16/fp32 before Ops, then add a cast back to fp8.
Pass LowerCustomDatatypes()
Lower custom datatypes.
Pass LowerTVMBuiltin()
Lower builtin intrinsics.
Pass BF16StorageLegalize()
Legalize bf16 storage types to u16.
Pass FP8StorageLegalize()
Legalize fp8 storage types to u8.
Pass ForceNarrowIndexToInt32()
Force to narrow down indexing expressions and integer buffers to int32 dtype.
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, ffi::String name, ffi::Array< ffi::String > required, bool traceable=false)
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
Compilation target object.
TIR expressions.
TIR Function.