tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_S_TIR_TRANSFORM_H_
25 #define TVM_S_TIR_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
28 #include <tvm/target/target.h>
29 #include <tvm/tirx/transform.h>
30 
31 #include <string>
32 #include <vector>
33 
34 namespace tvm {
35 namespace s_tir {
36 
45 
46 namespace transform {
47 
51 
57 
64 
// Lower block init stmt into IfThenElse stmts.
// Takes no arguments and returns the constructed Pass.
69 TVM_DLL Pass LowerInitBlock();
70 
78 
86 
92 
// Compact the buffer access region by removing the buffer regions that are
// not accessed. Returns the constructed Pass.
// NOTE(review): `is_strict` defaults to true; what strict vs. non-strict mode
// changes is not visible in this listing -- confirm against the implementation.
131 TVM_DLL Pass CompactBufferAllocation(bool is_strict = true);
132 
138 
144 
150 
156 
168 
206 
// Automatically do memory optimizations for auto copy blocks.
// Takes no arguments and returns the constructed Pass.
211 TVM_DLL Pass LowerAutoCopy();
212 
218 
221 
// Partition loops in the stmt.
// Takes no arguments and returns the constructed Pass.
227 TVM_DLL Pass LoopPartition();
228 
235 
242 
// Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
// Returns the constructed Pass.
// NOTE(review): `variant` defaults to the empty string; the set of accepted
// variant names and their effect is not visible here -- confirm before use.
251 TVM_DLL Pass HoistIfThenElse(tvm::ffi::String variant = "");
252 
264 
270 
276 
282 
// Rewrite global to local memory copy on CUDA with ldg32 instruction.
// Returns the constructed Pass.
// NOTE(review): `enable_inject` defaults to true; presumably false turns the
// rewrite off -- confirm against the implementation.
288 TVM_DLL Pass InjectPTXLDG32(bool enable_inject = true);
289 
295 
301 
// Insert sync between parallel read/write of shared buffers.
// Returns the constructed Pass.
// NOTE(review): `storage_scope` has no default and must be supplied;
// presumably it names the storage scope to synchronize -- confirm valid values.
307 TVM_DLL Pass ThreadSync(tvm::ffi::String storage_scope);
308 
// Infer the TensorCore fragment information using tensor intrinsics.
// Takes no arguments and returns the constructed Pass.
313 TVM_DLL Pass InferFragment();
314 
320 
// Lower Async TIR primitives to DMA copy and wait builtins.
// Takes no arguments and returns the constructed Pass.
325 TVM_DLL Pass LowerAsyncDMA();
326 
332 
338 
344 
// Remove weight layout rewrite block before benchmark.
// Returns the constructed Pass.
// NOTE(review): `skip_tensor_rewrite` defaults to false; what is skipped when
// it is true is not visible in this listing -- confirm against the implementation.
350 TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_tensor_rewrite = false);
351 
357 
363 
369 
370 } // namespace transform
371 } // namespace s_tir
372 } // namespace tvm
373 
374 #endif // TVM_S_TIR_TRANSFORM_H_
Managed reference to PrimFuncNode.
Definition: function.h:130
Definition: transform.h:400
tvm::transform::PassContext PassContext
Definition: transform.h:40
tvm::transform::Pass Pass
Definition: transform.h:38
Pass HoistIfThenElse(tvm::ffi::String variant="")
Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
Pass DecorateDeviceScope()
Decorate all the function's body as device function.
Pass CanonicalizeLoop()
Canonicalize loop to start from zero.
Pass LowerAutoCopy()
Automatically do memory optimizations for auto copy blocks.
Pass CompactBufferAllocation(bool is_strict=true)
Compact the buffer access region by removing the buffer regions that are not accessed,...
Pass LowerVtcmAlloc()
Lower VTCM allocations.
Pass LowerInitBlock()
Lower block init stmt into IfThenElse stmts.
Pass UseAssumeToReduceBranches()
Eliminate branches by leveraging buffer assumptions (T.assume).
Pass InjectVirtualThread()
Inject virtual thread loops.
Pass HoistExpression()
Hoist loop-invariant expressions to outside the eligible loops.
Pass PlanAndUpdateBufferAllocationLocation()
Locate the buffer allocation to the exact position (usually the LCA of the buffer accesses)...
Pass LowerCrossThreadReduction()
Lower cross-thread reduction from thread bindings to intrinsic function calls.
Pass InjectSoftwarePipeline()
This pass transforms annotated loops into pipelined ones where producers and consumers are overlapped...
Pass RemoveStoreUndef()
Remove stores of tirx::builtin::undef.
Pass InjectDoubleBuffer()
Inject double buffer statements.
Pass TransformMmaBufferLayout()
Transform Mma scope (m16n8k8.matrixA/B/C) to local scope with layout transformation.
Pass InjectPTXAsyncCopy()
Rewrite global to shared memory copy on CUDA with asynchronous copy.
Pass LowerMatchBuffer()
Remove match buffers inside the block. Also, it will validate the binding.
Pass LoopPartition()
Partition loops in the stmt.
Pass AnnotateIrregularLoop()
Annotate irregular loop mark.
Pass InstrumentProfileIntrinsics()
Insert intrinsic calls to instrument function and loop level profiling.
Pass LowerOpaqueBlock()
Remove the block to ensure that the TIR cannot be scheduled again.
Pass LowerAsyncDMA()
Lower Async TIR primitives to DMA copy and wait builtins.
Pass InjectPTXLDG32(bool enable_inject=true)
Rewrite global to local memory copy on CUDA with ldg32 instruction.
Pass InjectPermutedLayout()
Inject permuted layout for shared memory.
Pass ConvertBlocksToOpaque()
Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding it...
Pass LowerThreadAllreduce()
Lower cross thread allreduce.
Pass RenormalizeSplitPattern()
Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()).
Pass ManifestSharedMemoryLocalStage()
Add the explicit local stage for the shared memory access on GPU.
Pass InstrumentBoundCheckers()
Instruments bound checkers.
Pass LiftThreadBinding()
Lift the same thread bindings to their LCA loops.
Pass DefaultGPUSchedule()
Set default thread bindings for GPU PrimFuncs.
Pass RemoveWeightLayoutRewriteBlock(bool skip_tensor_rewrite=false)
Remove weight layout rewrite block before benchmark.
Pass RewriteUnsafeSelect()
Detect and rewrite unsafe select that contains memory access.
Pass UnifyThreadBinding()
Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread....
Pass MergeSharedMemoryAllocations()
Merge multiple TIR-level shared memory allocations into one.
Pass ThreadSync(tvm::ffi::String storage_scope)
Insert sync between parallel read/write of shared buffers.
Pass InferFragment()
Infer the TensorCore fragment information using tensor intrinsics.
tirx::PrimFunc RenewDefs(const tirx::PrimFunc &func)
Renew the definition nodes for a TIR, including Var, Buffer and IterVar. This pass works as a simple ...
Pass CreatePrimFuncPass(std::function< PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, int opt_level, ffi::String name, tvm::ffi::Array< ffi::String > required, bool traceable=false)
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
Compilation target object.
TIR specific transformation passes.