tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_TRANSFORM_H_
25 #define TVM_TIR_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
28 #include <tvm/target/target.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/function.h>
31 
32 #include <string>
33 #include <vector>
34 
35 namespace tvm {
36 namespace tir {
37 namespace transform {
38 
46 
47 /*
48  * \brief Create a function pass that optimizes PrimFuncs.
49  *
50  * \param pass_func The packed function that contains the optimization.
51  * \param opt_level The optimization level of the function pass.
52  * \param name The name of the function pass.
53  * \param required The list of the passes that the function pass is dependent on.
54  *
55  * \return The created function pass.
56  */
59  int opt_level, String name, tvm::Array<String> required, bool traceable = false);
60 
66 TVM_DLL Pass InjectPrefetch();
67 
68 // TODO(tvm-team): consolidate configs to the PassContext
78 TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false);
79 
93 TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin);
94 
100 TVM_DLL Pass CoProcSync();
101 
108 TVM_DLL Pass LiftAttrScope(String attr_key);
109 
115 TVM_DLL Pass LoopPartition();
116 
124 TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
125 
132 
139 
149 
156 TVM_DLL Pass UnrollLoop();
157 
163 TVM_DLL Pass RemoveNoOp();
164 
171 
177 TVM_DLL Pass Simplify();
178 
190 TVM_DLL Pass ConvertSSA();
191 
198 
225 TVM_DLL Pass MakePackedAPI();
226 
237 
249 
258 
265 
277 
293 
310 
316 TVM_DLL Pass SkipAssert();
317 
324 TVM_DLL Pass ThreadSync(String storage_scope);
325 
332 
338 TVM_DLL Pass InferFragment();
339 
343 static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin";
344 
350 
356 TVM_DLL Pass LowerIntrin();
357 
363 
372 
379 
388 TVM_DLL Pass NarrowDataType(int target_bits);
389 
397 
404 
412 TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
413 
419 
426 
433 
443 
451 
464 
471 
477 
485 
493 
499 
538 TVM_DLL Pass CompactBufferAllocation(bool is_strict = true);
539 
544 
550 
556 
562 
568 
574 TVM_DLL Pass FlattenBuffer();
575 
576 /*
577  * \brief Flatten the multi-dimensional read/write
578  * to two dimensional texture Load/Store and realize
579  * texture buffer allocations.
580  *
581  * \return The Pass
582  */
584 
585 /*
586  * \brief Lower VTCM allocations
587  *
588  * \return The Pass
589  */
591 
595 TVM_DLL Pass LowerAsyncDMA();
596 
604 TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);
605 
612 
624 
629 
638 
646 
747 
748 TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);
749 
756 
761 TVM_DLL Pass LowerAutoCopy();
762 
768 
773 TVM_DLL Pass BindTarget(Target target);
774 
780 
786 
792 
797 TVM_DLL Pass InjectPTXLDG32(bool enable_ptx_ldg32 = true);
798 
811 TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite = false);
812 
818 
824 
836 
844 
845 } // namespace transform
846 } // namespace tir
847 } // namespace tvm
848 
849 #endif // TVM_TIR_TRANSFORM_H_
Managed reference class to IRModuleNode.
Definition: module.h:366
Managed reference class to TargetNode.
Definition: target.h:200
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Reference to string objects.
Definition: string.h:98
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:63
Managed reference to PrimFuncNode.
Definition: function.h:145
PassContext that is used to configure the pass behavior.
Definition: transform.h:182
Definition: transform.h:426
tvm::transform::Sequential Sequential
Definition: transform.h:49
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:48
tvm::transform::PassContext PassContext
Definition: transform.h:47
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
tvm::transform::PassNode PassNode
Definition: transform.h:44
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:46
PrimFuncFrame PrimFunc(bool is_private)
The primitive function statement.
Pass ThreadSync(String storage_scope)
Insert sync between parallel read/write of shared buffers.
Pass LowerIntrin()
Lower the target specific function intrinsics in each of the function.
Pass LowerCustomDatatypes()
Lower custom datatypes.
Pass ConvertForLoopsToSerial()
This pass is post-scheduling pass to convert all Parallel For loops to Serial ones....
Pass LowerThreadAllreduce()
Lower cross-thread allreduce.
Pass LoopPartition()
Partition loops in the stmt.
Pass MakeUnpackedAPI()
Transform the high-level PrimFunc to a C signature that can be used to call the operator directly.
Pass TransformMmaBufferLayout()
Transform Mma scope (m16n8k8.matrixA/B/C) to local scope with layout transformation.
Pass LiftAttrScope(String attr_key)
Lift common attrs with attr_key to outer scope.
Pass BF16ComputeLegalize()
Legalize bf16 compute Ops. Add a cast to fp32 before Ops, then add a cast back to bf16.
Pass InjectSoftwarePipeline()
This pass transforms annotated loops into pipelined ones where producers and consumers are overlapped...
Pass CompactBufferAllocation(bool is_strict=true)
Compact the buffer access region by removing the buffer regions that are not accessed,...
Pass InjectPrefetch()
Inject prefetch instructions into stmt.
Pass RemapThreadAxis(Map< String, IterVar > axis_map)
Remap the thread axis.
Pass InstrumentProfileIntrinsics()
Insert intrinsic calls to instrument function and loop level profiling.
Pass PointerValueTypeRewrite()
Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use the m...
Pass InjectVirtualThread()
Inject virtual thread loops.
Pass MakePackedAPI()
Transform the high-level PrimFunc to a low-level version that can be used as an API function.
Pass FlattenBuffer()
Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional BufferLoad/BufferStore...
Pass LowerAutoCopy()
Automatically do memory optimizations for auto copy blocks.
Pass FP8ComputeLegalize(String promote_dtype_str="float16")
Legalize fp8 compute Ops. Add a cast to fp16/fp32 before Ops, then add a cast back to fp8.
Pass MergeSharedMemoryAllocations()
Pass SplitHostDevice()
Split the function into a host function and device functions.
Pass RewriteUnsafeSelect()
Detect and rewrite unsafe select that contains memory access.
Pass LowerDeviceStorageAccessInfo()
Lower attached storage access information on device.
Pass InjectDoubleBuffer()
Inject double buffer statements.
Pass LowerOpaqueBlock()
Remove the block to ensure that the TIR can not be scheduled again.
Pass FP8StorageLegalize()
Legalize fp8 storage types to u8.
Pass RenormalizeSplitPattern()
Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
Pass InlinePrivateFunctions()
Inline calls to private functions.
Pass PlanAndUpdateBufferAllocationLocation()
Locate the buffer allocation to the exact position (usually is the lca of buffer access)....
Pass LowerCrossThreadReduction()
Lower cross-thread reduction from thread bindings to intrinsic function calls.
Pass DecorateDeviceScope()
Decorate all the function's body as device function.
Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< PrimFunc(PrimFunc, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required, bool traceable=false)
Pass SkipAssert()
Skip assert stmt.
Pass Filter(runtime::TypedPackedFunc< bool(PrimFunc)> fcond)
Filter PrimFuncs with a given condition.
Pass ForceNarrowIndexToInt32()
Force to narrow down indexing expressions and integer buffers to int32 dtype.
Pass AnnotateEntryFunc()
Set a PrimFunc as the entry point if it is the only function in the IRModule.
Pass StorageFlatten(int cache_line_size, bool create_bound_attribute=false)
Flatten the multi-dimensional read/write to single dimensional Load/Store.
Pass ConvertSSA()
Convert an IRModule to be SSA form.
Pass Simplify()
Run arithmetic simplifications on the statements and expressions.
Pass CommonSubexprElimTIR(bool enable_cse_tir=true, bool identify_equiv_terms=false)
Implements a Common Subexpression Elimination (CSE) for TIR which introduces let-in bindings for dupl...
Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite=false)
Remove the weight layout rewrite block.
Pass RemoveNoOp()
Remove No Op from the Stmt.
Pass AnnotateDeviceRegions()
Annotate locations that should be run on the device.
Pass UnifyThreadBinding()
Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread....
Pass LowerTVMBuiltin()
Lower builtin intrinsics.
Pass InstrumentBoundCheckers()
Instruments bound checkers.
Pass UseAssumeToReduceBranches()
This pass analyzes primfunc & eliminates branches introduced due to layout-specific padding....
Pass CoProcSync()
Detect and insert sync points to co-processor.
Pass LowerDeviceKernelLaunch()
Lower cross-device function calls.
Pass NarrowDataType(int target_bits)
Narrow down PrimExpr datatype in stmt to target_bits.
Pass HoistExpression()
Hoist loop-invariant expression nodes to outside the eligible loops.
Pass DefaultGPUSchedule()
The pass sets default thread bindings for PrimFuncs, including symbolic shape functions,...
Pass LowerMatchBuffer()
Remove match buffers inside the block. Also, it will validate the binding.
Pass UnrollLoop()
Unroll the constant loop marked by unroll. This pass also automatically attaches a pragma unroll tag to l...
Pass StorageRewrite()
Rewrite storage allocation pattern. Moves the allocation to outer most possible scope....
Pass BindParams(const Array< runtime::NDArray > &constants)
Pass ExtractPrimFuncConstants()
Pass to collect tir non-scalar constants into module's 'Constants' attribute.
Pass InjectPTXLDG32(bool enable_ptx_ldg32=true)
Pass to rewrite global to local memory copy on CUDA with ldg32 instruction.
Pass InjectPTXAsyncCopy()
Pass to rewrite global to shared memory copy on CUDA with asynchronous copy.
Pass UnifiedStaticMemoryPlanner()
This is the unified static memory planner pass that will plan for memory intra- and inter- PrimFuncs ...
Pass HoistIfThenElse()
Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
Pass LowerWarpMemory()
Lower warp memory access to low-level device related function calls.
Pass InstallDebugSpans()
Add TIR-printer output as debug information to all ops in the module.
Pass InferFragment()
Infer the TensorCore fragment information using tensor intrinsics.
Pass ConvertBlocksToOpaque()
Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding it...
Pass BF16StorageLegalize()
Legalize bf16 storage types to u16.
Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin)
Inject copy intrinsics with optional pad.
Pass BindTarget(Target target)
Annotate a PrimFunc with a given target.
Pass LowerInitBlock()
Lower block init stmt into IfThenElse stmts.
Pass LiftThreadBinding()
Lift the same thread bindings to their LCA loops.
Pass CombineContextCall()
Combine context calls in the host function.
Pass ManifestSharedMemoryLocalStage()
Add the explicit local stage for the shared memory access on GPU.
Pass VectorizeLoop(bool enable_vectorize=true)
Lower vectorization loops.
Pass LowerAsyncDMA()
Lower Async TIR primitives to DMA copy and wait builtins.
Pass InjectPermutedLayout()
Inject permuted layout for shared memory.
tvm::transform::Pass Pass
Definition: transform.h:35
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Compilation target object.
TIR expressions.
TIR Function.