tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_TRANSFORM_H_
25 #define TVM_TIR_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
28 #include <tvm/target/target.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/function.h>
31 
32 #include <string>
33 #include <vector>
34 
35 namespace tvm {
36 namespace tir {
37 namespace transform {
38 
46 
47 /*
48  * \brief Create a function pass that optimizes PrimFuncs.
49  *
50  * \param pass_func The packed function that contains the optimization.
51  * \param opt_level The optimization level of the function pass.
52  * \param name The name of the function pass.
53  * \param required The list of the passes that the function pass is dependent on.
54  *
55  * \return The created function pass.
56  */
59  int opt_level, String name, tvm::Array<String> required);
60 
66 TVM_DLL Pass InjectPrefetch();
67 
68 // TODO(tvm-team): consolidate configs to the PassContext
78 TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false);
79 
93 TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin);
94 
100 TVM_DLL Pass CoProcSync();
101 
108 TVM_DLL Pass LiftAttrScope(String attr_key);
109 
115 TVM_DLL Pass LoopPartition();
116 
124 TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
125 
132 
139 
149 
156 TVM_DLL Pass UnrollLoop();
157 
163 TVM_DLL Pass RemoveNoOp();
164 
171 
177 TVM_DLL Pass Simplify();
178 
190 TVM_DLL Pass ConvertSSA();
191 
198 
225 TVM_DLL Pass MakePackedAPI();
226 
237 
249 
258 
265 
277 
293 
310 
316 TVM_DLL Pass SkipAssert();
317 
324 TVM_DLL Pass ThreadSync(String storage_scope);
325 
332 
338 TVM_DLL Pass InferFragment();
339 
343 static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin";
344 
350 
356 TVM_DLL Pass LowerIntrin();
357 
363 
372 
379 
388 TVM_DLL Pass NarrowDataType(int target_bits);
389 
396 
403 TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
404 
410 
416 
426 
434 
447 
454 
460 
468 
476 
482 
521 TVM_DLL Pass CompactBufferAllocation(bool is_strict = true);
522 
527 
533 
539 
545 
551 
557 TVM_DLL Pass FlattenBuffer();
558 
559 /*
560  * \brief Flatten the multi-dimensional read/write
561  * to two dimensional texture Load/Store and realize
562  * texture buffer allocations.
563  *
564  * \return The Pass
565  */
567 
568 /*
569  * \brief Lower VTCM allocations
570  *
571  * \return The Pass
572  */
574 
578 TVM_DLL Pass LowerAsyncDMA();
579 
587 TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);
588 
595 
607 
612 
621 
629 
730 
731 TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);
732 
739 
744 TVM_DLL Pass LowerAutoCopy();
745 
751 
756 TVM_DLL Pass BindTarget(Target target);
757 
763 
769 
775 
780 TVM_DLL Pass InjectPTXLDG32(bool enable_ptx_ldg32 = true);
781 
794 TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite = false);
795 
801 
807 
808 } // namespace transform
809 } // namespace tir
810 } // namespace tvm
811 
812 #endif // TVM_TIR_TRANSFORM_H_
Managed reference class to IRModuleNode.
Definition: module.h:348
Managed reference class to TargetNode.
Definition: target.h:192
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:139
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:61
Managed reference to PrimFuncNode.
Definition: function.h:145
PassContext that is used to configure the pass behavior.
Definition: transform.h:153
Definition: transform.h:362
tvm::transform::Sequential Sequential
Definition: transform.h:49
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:48
tvm::transform::PassContext PassContext
Definition: transform.h:47
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
tvm::transform::PassNode PassNode
Definition: transform.h:44
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:46
PrimFuncFrame PrimFunc(bool is_private)
The primitive function statement.
Pass ThreadSync(String storage_scope)
Insert sync between parallel read/write of shared buffers.
Pass MergeDynamicSharedMemoryAllocations()
Pass LowerIntrin()
Lower the target specific function intrinsics in each of the function.
Pass LowerCustomDatatypes()
Lower custom datatypes.
Pass ConvertForLoopsToSerial()
This pass is post-scheduling pass to convert all Parallel For loops to Serial ones....
Pass LowerThreadAllreduce()
Lower cross-thread allreduce.
Pass LoopPartition()
Partition loops in the stmt.
Pass MakeUnpackedAPI()
Transform the high-level PrimFunc to a C signature that can be used to call the operator directly.
Pass TransformMmaBufferLayout()
Transform Mma scope (m16n8k8.matrixA/B/C) to local scope with layout transformation.
Pass LiftAttrScope(String attr_key)
Lift common attrs with attr_key to outer scope.
Pass BF16ComputeLegalize()
Legalize bf16 compute Ops. Add a cast to fp32 before Ops, then add a cast back to bf16.
Pass InjectSoftwarePipeline()
This pass transforms annotated loops into pipelined ones where producers and consumers are overlapped...
Pass CompactBufferAllocation(bool is_strict=true)
Compact the buffer access region by removing the buffer regions that are not accessed,...
Pass InjectPrefetch()
Inject prefetch instructions into stmt.
Pass RemapThreadAxis(Map< String, IterVar > axis_map)
Remap the thread axis.
Pass InstrumentProfileIntrinsics()
Insert intrinsic calls to instrument function and loop level profiling.
Pass PointerValueTypeRewrite()
Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use the m...
Pass InjectVirtualThread()
Inject virtual thread loops.
Pass MakePackedAPI()
Transform the high-level PrimFunc to a low-level version that can be used as an API function.
Pass FlattenBuffer()
Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional BufferLoad/BufferStore...
Pass LowerAutoCopy()
Automatically do memory optimizations for auto copy blocks.
Pass FP8ComputeLegalize(String promote_dtype_str="float16")
Legalize fp8 compute Ops. Add a cast to fp16/fp32 before Ops, then add a cast back to fp8.
Pass SplitHostDevice()
Split the function into a host function and device functions.
Pass RewriteUnsafeSelect()
Detect and rewrite unsafe select that contains memory access.
Pass LowerDeviceStorageAccessInfo()
Lower attached storage access information on device.
Pass InjectDoubleBuffer()
Inject double buffer statements.
Pass LowerOpaqueBlock()
Remove the block to ensure that the TIR can not be scheduled again.
Pass FP8StorageLegalize()
Legalize fp8 storage types to u8.
Pass RenormalizeSplitPattern()
Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
Pass PlanAndUpdateBufferAllocationLocation()
Locate the buffer allocation to the exact position (usually is the lca of buffer access)....
Pass LowerCrossThreadReduction()
Lower cross-thread reduction from thread bindings to intrinsic function calls.
Pass DecorateDeviceScope()
Decorate all the function's body as device function.
Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< PrimFunc(PrimFunc, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required)
Pass SkipAssert()
Skip assert stmt.
Pass Filter(runtime::TypedPackedFunc< bool(PrimFunc)> fcond)
Filter PrimFuncs with a given condition.
Pass AnnotateEntryFunc()
Set a PrimFunc as the entry point if it is the only function in the IRModule.
Pass StorageFlatten(int cache_line_size, bool create_bound_attribute=false)
Flatten the multi-dimensional read/write to single dimensional Load/Store.
Pass ConvertSSA()
Convert an IRModule to be SSA form.
Pass Simplify()
Run arithmetic simplifications on the statements and expressions.
Pass CommonSubexprElimTIR(bool enable_cse_tir=true, bool identify_equiv_terms=false)
Implements a Common Subexpression Elimination (CSE) for TIR which introduces let-in bindings for dupl...
Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite=false)
Remove the weight layout rewrite block.
Pass RemoveNoOp()
Remove No Op from the Stmt.
Pass AnnotateDeviceRegions()
Annotate locations that should be run on the device.
Pass UnifyThreadBinding()
Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread....
Pass LowerTVMBuiltin()
Lower builtin intrinsics.
Pass InstrumentBoundCheckers()
Instruments bound checkers.
Pass CoProcSync()
Detect and insert sync points to co-processor.
Pass LowerDeviceKernelLaunch()
Lower cross-device function calls.
Pass NarrowDataType(int target_bits)
Narrow down PrimExpr datatype in stmt to target_bits.
Pass HoistExpression()
Hoist loop-invariant expression nodes to outside the eligible loops.
Pass LowerMatchBuffer()
Remove match buffers inside the block. Also, it will validate the binding.
Pass UnrollLoop()
Unroll the constant loop marked by unroll. This pass also automatically attaches pragma unroll tag to l...
Pass StorageRewrite()
Rewrite storage allocation pattern. Moves the allocation to outer most possible scope....
Pass BindParams(const Array< runtime::NDArray > &constants)
Pass ExtractPrimFuncConstants()
Pass to collect tir non-scalar constants into module's 'Constants' attribute.
Pass InjectPTXLDG32(bool enable_ptx_ldg32=true)
Pass to rewrite global to local memory copy on CUDA with ldg32 instruction.
Pass InjectPTXAsyncCopy()
Pass to rewrite global to shared memory copy on CUDA with asynchronous copy.
Pass UnifiedStaticMemoryPlanner()
This is the unified static memory planner pass that will plan for memory intra- and inter- PrimFuncs ...
Pass HoistIfThenElse()
Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
Pass LowerWarpMemory()
Lower warp memory access to low-level device related function calls.
Pass InstallDebugSpans()
Add TIR-printer output as debug information to all ops in the module.
Pass InferFragment()
Infer the TensorCore fragment information using tensor intrinsics.
Pass ConvertBlocksToOpaque()
Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding it...
Pass BF16StorageLegalize()
Legalize bf16 storage types to u16.
Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin)
Inject copy intrinsics with optional pad.
Pass BindTarget(Target target)
Annotate a PrimFunc with a given target.
Pass LowerInitBlock()
Lower block init stmt into IfThenElse stmts.
Pass LiftThreadBinding()
Lift the same thread bindings to their LCA loops.
Pass CombineContextCall()
Combine context calls in the host function.
Pass ManifestSharedMemoryLocalStage()
Add the explicit local stage for the shared memory access on GPU.
Pass VectorizeLoop(bool enable_vectorize=true)
Lower vectorization loops.
Pass LowerAsyncDMA()
Lower Async TIR primitives to DMA copy and wait builtins.
Pass InjectPermutedLayout()
Inject permuted layout for shared memory.
tvm::transform::Pass Pass
Definition: transform.h:35
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Compilation target object.
TIR expressions.
TIR Function.