tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_TRANSFORM_H_
25 #define TVM_TIR_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
28 #include <tvm/target/target.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/function.h>
31 
32 #include <string>
33 #include <vector>
34 
35 namespace tvm {
36 namespace tir {
37 namespace transform {
38 
47 
48 /*!
49  * \brief Create a function pass that optimizes PrimFuncs.
50  *
51  * \param pass_func The packed function that contains the optimization.
52  * \param opt_level The optimization level of the function pass.
53  * \param name The name of the function pass.
54  * \param required The list of the passes that the function pass is dependent on.
55  *
56  * \return The created function pass.
57  */
58 TVM_DLL Pass CreatePrimFuncPass(std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
59  int opt_level, ffi::String name,
60  tvm::ffi::Array<ffi::String> required, bool traceable = false);
61 
69 TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
70 
79 TVM_DLL Pass StorageRewrite();
80 
87 TVM_DLL Pass UnrollLoop();
88 
94 TVM_DLL Pass RemoveNoOp();
95 
102 
108 TVM_DLL Pass Simplify();
109 
121 TVM_DLL Pass ConvertSSA();
122 
129 
156 TVM_DLL Pass MakePackedAPI();
157 
168 
179 TVM_DLL Pass RemapThreadAxis(ffi::Map<ffi::String, IterVar> axis_map);
180 
189 
196 
208 
224 
241 
247 TVM_DLL Pass SkipAssert();
248 
255 TVM_DLL Pass ThreadSync(ffi::String storage_scope);
256 
263 
269 TVM_DLL Pass InferFragment();
270 
274 static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin";
275 
281 
287 TVM_DLL Pass LowerIntrin();
288 
294 
303 
310 
319 TVM_DLL Pass NarrowDataType(int target_bits);
320 
328 
335 
343 TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype = "float16");
344 
350 
357 
364 
374 
382 
395 
401 TVM_DLL Pass FlattenBuffer();
402 
403 /*
404  * \brief Lower VTCM allocations
405  *
406  * \return The Pass
407  */
409 
413 TVM_DLL Pass LowerAsyncDMA();
414 
422 TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);
423 
428 
437 
445 
446 TVM_DLL Pass BindParams(const ffi::Array<runtime::Tensor>& constants);
447 
454 
460 
465 TVM_DLL Pass BindTarget(Target target);
466 
472 
477 TVM_DLL Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond);
478 
484 
489 TVM_DLL Pass InjectPTXLDG32(bool enable_ptx_ldg32 = true);
490 
503 TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_tensor_rewrite = false);
504 
510 
522 
530 
531 } // namespace transform
532 } // namespace tir
533 } // namespace tvm
534 
535 #endif // TVM_TIR_TRANSFORM_H_
Managed reference class to IRModuleNode.
Definition: module.h:256
Managed reference class to TargetNode.
Definition: target.h:192
Managed reference to PrimFuncNode.
Definition: function.h:129
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:79
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Meta data that will be used to help optimization and analysis.
Definition: transform.h:319
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:370
Definition: transform.h:400
Definition: transform.h:491
tvm::transform::PassContext PassContext
Definition: transform.h:40
tvm::transform::PassInfo PassInfo
Definition: transform.h:39
tvm::transform::Pass Pass
Definition: transform.h:38
PrimFuncFrame PrimFunc(bool is_private)
The primitive function statement.
Pass CreatePrimFuncPass(std::function< PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
Pass LowerIntrin()
Lower the target specific function intrinsics in each of the function.
Pass LowerCustomDatatypes()
Lower custom datatypes.
Pass ConvertForLoopsToSerial()
This pass is post-scheduling pass to convert all Parallel For loops to Serial ones....
Pass LowerThreadAllreduce()
Lower cross-thread allreduce.
Pass MakeUnpackedAPI()
Transform the high-level PrimFunc to a C signature that can be used to call the operator directly.
Pass BF16ComputeLegalize()
Legalize bf16 compute Ops. Add a cast to fp32 before Ops, then add a cast back to bf16.
Pass InstrumentProfileIntrinsics()
Insert intrinsic calls to instrument function and loop level profiling.
Pass PointerValueTypeRewrite()
Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use the m...
Pass MakePackedAPI()
Transform the high-level PrimFunc to a low-level version that can be used as an API function.
Pass FlattenBuffer()
Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional BufferLoad/BufferStore...
Pass BindParams(const ffi::Array< runtime::Tensor > &constants)
Pass MergeSharedMemoryAllocations()
Pass SplitHostDevice()
Split the function into a host function and device functions.
Pass RewriteUnsafeSelect()
Detect and rewrite unsafe select that contains memory access.
Pass LowerDeviceStorageAccessInfo()
Lower attached storage access information on device.
Pass RemoveWeightLayoutRewriteBlock(bool skip_tensor_rewrite=false)
Remove the weight layout rewrite block.
Pass FP8StorageLegalize()
Legalize fp8 storage types to u8.
Pass RenormalizeSplitPattern()
Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
Pass InlinePrivateFunctions()
Inline calls to private functions.
Pass DecorateDeviceScope()
Decorate all the function's body as device function.
Pass SkipAssert()
Skip assert stmt.
Pass ForceNarrowIndexToInt32()
Force to narrow down indexing expressions and integer buffers to int32 dtype.
Pass AnnotateEntryFunc()
Set a PrimFunc as the entry point if it is the only function in the IRModule.
Pass ConvertSSA()
Convert an IRModule to be SSA form.
Pass Simplify()
Run arithmetic simplifications on the statements and expressions.
Pass CommonSubexprElimTIR(bool enable_cse_tir=true, bool identify_equiv_terms=false)
Implements a Common Subexpression Elimination (CSE) for TIR which introduces let-in bindings for dupl...
Pass RemoveNoOp()
Remove No Op from the Stmt.
Pass AnnotateDeviceRegions()
Annotate locations that should be run on the device.
Pass LowerTVMBuiltin()
Lower builtin intrinsics.
Pass InstrumentBoundCheckers()
Instruments bound checkers.
Pass UseAssumeToReduceBranches()
This pass analyzes primfunc & eliminates branches introduced due to layout-specific padding....
Pass FP8ComputeLegalize(ffi::String promote_dtype="float16")
Legalize fp8 compute Ops. Add a cast to fp16/fp32 before Ops, then add a cast back to fp8.
Pass LowerDeviceKernelLaunch()
Lower cross-device function calls.
Pass NarrowDataType(int target_bits)
Narrow down PrimExpr datatype in stmt to target_bits.
Pass HoistExpression()
Hoist loop-invariant expression nodes to outside the eligible loops.
Pass DefaultGPUSchedule()
The pass sets default thread bindings for PrimFuncs, including symbolic shape functions,...
Pass ThreadSync(ffi::String storage_scope)
Insert sync between parallel read/write of shared buffers.
Pass UnrollLoop()
Unroll the constant loop marked by unroll. This pass also automatically attaches pragma unroll tag to l...
Pass StorageRewrite()
Rewrite storage allocation pattern. Moves the allocation to the outermost possible scope....
Pass ExtractPrimFuncConstants()
Pass to collect tir non-scalar constants into module's 'Constants' attribute.
Pass InjectPTXLDG32(bool enable_ptx_ldg32=true)
Pass to rewrite global to local memory copy on CUDA with ldg32 instruction.
Pass InjectPTXAsyncCopy()
Pass to rewrite global to shared memory copy on CUDA with asynchronous copy.
Pass UnifiedStaticMemoryPlanner()
This is the unified static memory planner pass that will plan for memory intra- and inter- PrimFuncs ...
Pass HoistIfThenElse()
Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
Pass RemapThreadAxis(ffi::Map< ffi::String, IterVar > axis_map)
Remap the thread axis.
Pass LowerWarpMemory()
Lower warp memory access to low-level device related function calls.
Pass InferFragment()
Infer the TensorCore fragment information using tensor intrinsics.
Pass BF16StorageLegalize()
Legalize bf16 storage types to u16.
Pass BindTarget(Target target)
Annotate a PrimFunc with a given target.
Pass Filter(ffi::TypedFunction< bool(PrimFunc)> fcond)
Filter PrimFuncs with a given condition.
Pass CombineContextCall()
Combine context calls in the host function.
Pass VectorizeLoop(bool enable_vectorize=true)
Lower vectorization loops.
Pass LowerAsyncDMA()
Lower Async TIR primitives to DMA copy and wait builtins.
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, ffi::String name, ffi::Array< ffi::String > required, bool traceable=false)
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Compilation target object.
TIR expressions.
TIR Function.