tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_TRANSFORM_H_
25 #define TVM_TIR_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
28 #include <tvm/target/target.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/function.h>
31 
32 #include <string>
33 #include <vector>
34 
35 namespace tvm {
36 namespace tir {
37 namespace transform {
38 
46 
47 /*!
48  * \brief Create a function pass that optimizes PrimFuncs.
49  *
50  * \param pass_func The packed function that contains the optimization; it is called with a PrimFunc, its IRModule and the PassContext, and returns a PrimFunc.
51  * \param opt_level The optimization level of the function pass.
52  * \param name The name of the function pass.
53  * \param required The list of the passes that the function pass is dependent on.
54  *
55  * \return The created function pass.
56  */
57 TVM_DLL Pass CreatePrimFuncPass(
58  const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
59  int opt_level, String name, tvm::Array<String> required);
60 
66 TVM_DLL Pass InjectPrefetch();
67 
68 // TODO(tvm-team): consolidate configs to the PassContext
78 TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false);
79 
93 TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin);
94 
100 TVM_DLL Pass CoProcSync();
101 
108 TVM_DLL Pass LiftAttrScope(String attr_key);
109 
115 TVM_DLL Pass LoopPartition();
116 
124 TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
125 
131 TVM_DLL Pass InjectVirtualThread();
132 
138 TVM_DLL Pass InjectDoubleBuffer();
139 
148 TVM_DLL Pass StorageRewrite();
149 
156 TVM_DLL Pass UnrollLoop();
157 
163 TVM_DLL Pass RemoveNoOp();
164 
170 TVM_DLL Pass RewriteUnsafeSelect();
171 
177 TVM_DLL Pass Simplify();
178 
190 TVM_DLL Pass ConvertSSA();
191 
197 TVM_DLL Pass InstrumentBoundCheckers();
198 
225 TVM_DLL Pass MakePackedAPI();
226 
236 TVM_DLL Pass MakeUnpackedAPI();
237 
248 TVM_DLL Pass RemapThreadAxis(Map<String, IterVar> axis_map);
249 
257 TVM_DLL Pass LowerCustomDatatypes();
258 
264 TVM_DLL Pass DecorateDeviceScope();
265 
271 TVM_DLL Pass SplitHostDevice();
272 
278 TVM_DLL Pass SkipAssert();
279 
286 TVM_DLL Pass ThreadSync(String storage_scope);
287 
293 TVM_DLL Pass LowerThreadAllreduce();
294 
300 TVM_DLL Pass InferFragment();
301 
305 static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin";
306 
311 TVM_DLL Pass LowerTVMBuiltin();
312 
318 TVM_DLL Pass LowerIntrin();
319 
324 TVM_DLL Pass LowerWarpMemory();
325 
334 
340 TVM_DLL Pass CombineContextCall();
341 
350 TVM_DLL Pass NarrowDataType(int target_bits);
351 
357 TVM_DLL Pass BF16ComputeLegalize();
358 
363 TVM_DLL Pass BF16StorageLegalize();
364 
373 TVM_DLL Pass PointerValueTypeRewrite();
374 
381 TVM_DLL Pass HoistIfThenElse();
382 
394 TVM_DLL Pass HoistExpression();
395 
402 
407 TVM_DLL Pass LowerInitBlock();
408 
416 
423 TVM_DLL Pass ConvertBlocksToOpaque();
424 
462 TVM_DLL Pass CompactBufferAllocation();
463 
467 TVM_DLL Pass LegalizePackedCalls();
468 
473 TVM_DLL Pass LowerMatchBuffer();
474 
479 TVM_DLL Pass LowerOpaqueBlock();
480 
486 TVM_DLL Pass FlattenBuffer();
487 
488 /*!
489  * \brief Flatten the multi-dimensional read/write
490  * to two-dimensional texture Load/Store and realize
491  * texture buffer allocations.
492  *
493  * \return The Pass
494  */
495 TVM_DLL Pass TextureFlatten();
496 
497 /*!
498  * \brief Lower VTCM allocations.
499  *
500  * \return The Pass
501  */
502 TVM_DLL Pass LowerVtcmAlloc();
503 
507 TVM_DLL Pass LowerAsyncDMA();
508 
516 TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);
517 
523 TVM_DLL Pass InstallDebugSpans();
524 
535 TVM_DLL Pass UnifyThreadBinding();
536 
541 
549 TVM_DLL Pass ConvertForLoopsToSerial();
550 
558 
658 TVM_DLL Pass InjectSoftwarePipeline();
659 
660 TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);  // NOTE(review): this declaration has no doc comment; add a \brief describing how `constants` are bound.
661 
668 
673 TVM_DLL Pass LowerAutoCopy();
674 
679 TVM_DLL Pass RenormalizeSplitPattern();
680 
685 TVM_DLL Pass BindTarget(Target target);
686 
691 TVM_DLL Pass AnnotateEntryFunc();
692 
697 TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);
698 
703 TVM_DLL Pass InjectPTXAsyncCopy();
704 
709 TVM_DLL Pass InjectPTXLDG32(bool enable_ptx_ldg32 = true);
710 
723 TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite = false);
724 
730 
736 
737 } // namespace transform
738 } // namespace tir
739 } // namespace tvm
740 
741 #endif // TVM_TIR_TRANSFORM_H_
Pass LowerDeviceStorageAccessInfo()
Lower attached storage access information on device.
Pass LowerIntrin()
Lower the target specific function intrinsics in each of the function.
Pass InferFragment()
Infer the TensorCore fragment information using tensor intrinsics.
Pass StorageFlatten(int cache_line_size, bool create_bound_attribute=false)
Flatten the multi-dimensional read/write to single dimensional Load/Store.
Pass LowerMatchBuffer()
Remove match buffers inside the block. Also, it will validate the binding.
Pass SplitHostDevice()
Split the function into a host function and device functions.
Pass ConvertForLoopsToSerial()
This pass is a post-scheduling pass that converts all Parallel For loops to Serial ones. It is run to reduce memory usage and/or when the executor/backend does not support parallel launch of For loops.
Pass MakeUnpackedAPI()
Transform the high-level PrimFunc to a C signature that can be used to call the operator directly...
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Pass LiftAttrScope(String attr_key)
Lift common attrs with attr_key to outer scope.
Pass InjectVirtualThread()
Inject virtual thread loops.
Pass CombineContextCall()
Combine context calls in the host function.
Pass InstrumentBoundCheckers()
Instruments bound checkers.
Pass LowerInitBlock()
Lower block init stmt into IfThenElse stmts.
Pass DecorateDeviceScope()
Decorate all the functions' bodies as device functions.
Pass ManifestSharedMemoryLocalStage()
Add the explicit local stage for the shared memory access on GPU.
tvm::transform::Sequential Sequential
Definition: transform.h:49
Pass HoistIfThenElse()
Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
Pass NarrowDataType(int target_bits)
Narrow down PrimExpr datatype in stmt to target_bits.
Pass LowerCrossThreadReduction()
Lower cross-thread reduction from thread bindings to intrinsic function calls.
PrimFuncFrame PrimFunc()
The primitive function statement.
Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin)
Inject copy intrinsics with optional pad.
TIR Function.
Pass MakePackedAPI()
Transform the high-level PrimFunc to a low-level version that can be used as an API function...
Pass LoopPartition()
Partition loops in the stmt.
Pass BindTarget(Target target)
Annotate a PrimFunc with a given target.
Pass ConvertBlocksToOpaque()
Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding it...
tvm::transform::PassNode PassNode
Definition: transform.h:44
Pass UnrollLoop()
Unroll the constant loop marked by unroll. This pass also automatically attaches the pragma unroll tag to l...
Pass CoProcSync()
Detect and insert sync points to co-processor.
Pass LowerOpaqueBlock()
Remove the block to ensure that the TIR cannot be scheduled again.
Pass AnnotateEntryFunc()
Set a PrimFunc as the entry point if it is the only function in the IRModule.
Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite=false)
Remove the weight layout rewrite block.
Pass InjectSoftwarePipeline()
This pass transforms annotated loops into pipelined ones where producers and consumers are overlapped...
Pass InstrumentProfileIntrinsics()
Insert intrinsic calls to instrument function and loop level profiling.
Pass HoistExpression()
Hoist loop-invariant expression nodes to outside the eligible loops.
Pass RewriteUnsafeSelect()
Detect and rewrite unsafe select that contains memory access.
Pass FlattenBuffer()
Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional BufferLoad/BufferStore...
Pass LowerThreadAllreduce()
Lower cross-thread allreduce.
Pass ExtractPrimFuncConstants()
Pass to collect TIR non-scalar constants into the module's 'Constants' attribute.
Pass LowerTVMBuiltin()
Lower builtin intrinsics.
Pass LowerAsyncDMA()
Lower Async TIR primitives to DMA copy and wait builtins.
Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< PrimFunc(PrimFunc, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required)
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Pass LowerAutoCopy()
Automatically do memory optimizations for auto copy blocks.
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:46
Pass MergeDynamicSharedMemoryAllocations()
Pass InjectPTXLDG32(bool enable_ptx_ldg32=true)
Pass to rewrite global to local memory copy on CUDA with ldg32 instruction.
Pass InstallDebugSpans()
Add TIR-printer output as debug information to all ops in the module.
Pass VectorizeLoop(bool enable_vectorize=true)
Lower vectorization loops.
Pass Filter(runtime::TypedPackedFunc< bool(PrimFunc)> fcond)
Filter PrimFuncs with a given condition.
Pass StorageRewrite()
Rewrite storage allocation pattern. Moves the allocation to the outermost possible scope, trying to share space between allocations to make a static allocation plan when possible.
tvm::transform::Pass Pass
Definition: transform.h:35
Pass InjectPrefetch()
Inject prefetch instructions into stmt.
Pass Simplify()
Run arithmetic simplifications on the statements and expressions.
Pass RemapThreadAxis(Map< String, IterVar > axis_map)
Remap the thread axis.
Pass UnifiedStaticMemoryPlanner()
This is the unified static memory planner pass that will plan for memory intra- and inter- PrimFuncs ...
tvm::transform::PassContext PassContext
Definition: transform.h:47
Compilation target object.
Pass ConvertSSA()
Convert an IRModule to SSA form.
Pass LowerWarpMemory()
Lower warp memory access to low-level device related function calls.
Pass CompactBufferAllocation()
Compact the buffer access region by removing the buffer regions that are not accessed, i.e. narrowing the buffer shape and adjusting the access region if necessary.
Pass PlanAndUpdateBufferAllocationLocation()
Locate the buffer allocation at the exact position (usually the LCA of the buffer accesses). This pass will inject an opaque block with alloc_buffers at the allocation site.
Pass PointerValueTypeRewrite()
Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use the m...
Pass RenormalizeSplitPattern()
Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
Pass RemoveNoOp()
Remove No Op from the Stmt.
Pass BindParams(const Array< runtime::NDArray > &constants)
Pass LowerCustomDatatypes()
Lower custom datatypes.
Pass InjectDoubleBuffer()
Inject double buffer statements.
Pass BF16StorageLegalize()
Legalize bf16 storage types to u16.
Pass ThreadSync(String storage_scope)
Insert sync between parallel read/write of shared buffers.
Pass InjectPTXAsyncCopy()
Pass to rewrite global to shared memory copy on CUDA with asynchronous copy.
Pass CommonSubexprElimTIR(bool enable_cse_tir=true, bool identify_equiv_terms=false)
Implements a Common Subexpression Elimination (CSE) for TIR which introduces let-in bindings for dupl...
Pass BF16ComputeLegalize()
Legalize bf16 compute Ops. Add a cast to fp32 before Ops, then add a cast back to bf16...
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:48
IRModuleFrame IRModule()
The IRModule declaration statement.
Pass SkipAssert()
Skip assert stmt.
tvm::transform::PassInfo PassInfo
Definition: transform.h:45
Pass UnifyThreadBinding()
Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., "threadIdx.x") use different IterVars and variables in their AttrStmts. After the unification, we use a consolidated IterVar and a variable for them.