tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_TRANSFORM_H_
25 #define TVM_TIR_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
28 #include <tvm/target/target.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/function.h>
31 
32 #include <string>
33 #include <vector>
34 
35 namespace tvm {
36 namespace tir {
37 namespace transform {
38 
47 
48 /*!
49  * \brief Create a function pass that optimizes PrimFuncs.
50  *
51  * \param pass_func The packed function that contains the optimization.
52  * \param opt_level The optimization level of the function pass.
53  * \param name The name of the function pass.
54  * \param required The list of the passes that the function pass is dependent on.
55  * \param traceable Optional flag (defaults to false). NOTE(review): presumably marks the created pass as traceable -- confirm.
56  * \return The created function pass.
57  */
58 TVM_DLL Pass CreatePrimFuncPass(std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
59  int opt_level, String name, tvm::Array<String> required,
60  bool traceable = false);
61 
67 TVM_DLL Pass LoopPartition();  // Partitions loops in the stmt.
68 
76 TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);  // Lowers vectorization loops.
77 
84 
91 
101 
108 TVM_DLL Pass UnrollLoop();  // Unrolls the constant loop marked by unroll.
109 
115 TVM_DLL Pass RemoveNoOp();  // Removes no-ops from the Stmt.
116 
123 
129 TVM_DLL Pass Simplify();  // Runs arithmetic simplifications on the statements and expressions.
130 
142 TVM_DLL Pass ConvertSSA();  // Converts an IRModule to SSA form.
143 
150 
177 TVM_DLL Pass MakePackedAPI();  // Transforms the high-level PrimFunc to a low-level version usable as an API function.
178 
189 
200 TVM_DLL Pass RemapThreadAxis(Map<String, IterVar> axis_map);  // Remaps the thread axis.
201 
210 
217 
229 
245 
262 
268 TVM_DLL Pass SkipAssert();  // Skips assert stmt.
269 
276 TVM_DLL Pass ThreadSync(String storage_scope);  // Inserts sync between parallel read/write of shared buffers.
277 
284 
290 TVM_DLL Pass InferFragment();  // Infers the TensorCore fragment information using tensor intrinsics.
291 
295 static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin";  // NOTE(review): presumably a key used to disable the LowerTVMBuiltin pass -- confirm.
296 
302 
308 TVM_DLL Pass LowerIntrin();  // Lowers the target-specific function intrinsics in each function.
309 
315 
324 
331 
340 TVM_DLL Pass NarrowDataType(int target_bits);  // Narrows down PrimExpr datatype in stmt to target_bits.
341 
349 
356 
364 TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");  // Legalizes fp8 compute ops: casts to fp16/fp32 before ops, then back to fp8.
365 
371 
378 
385 
395 
403 
416 
423 
429 
437 
445 
451 
490 TVM_DLL Pass CompactBufferAllocation(bool is_strict = true);  // Compacts the buffer access region by removing buffer regions that are not accessed.
491 
497 
503 
509 
515 
521 TVM_DLL Pass FlattenBuffer();  // Flattens multi-dimensional BufferLoad/BufferStore to single-dimensional ones.
522 
523 /*
524  * \brief Lower VTCM allocations
525  *
526  * \return The Pass
527  */
529 
533 TVM_DLL Pass LowerAsyncDMA();  // Lowers Async TIR primitives to DMA copy and wait builtins.
534 
542 TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);  // Common Subexpression Elimination (CSE) for TIR; introduces let-in bindings.
543 
555 
560 
569 
577 
678 
679 TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);  // NOTE(review): no description visible here; presumably binds `constants` to PrimFunc params -- confirm.
680 
687 
692 TVM_DLL Pass LowerAutoCopy();  // Automatically does memory optimizations for auto copy blocks.
693 
699 
704 TVM_DLL Pass BindTarget(Target target);  // Annotates a PrimFunc with the given target.
705 
711 
716 TVM_DLL Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond);  // Filters PrimFuncs with the given condition.
717 
723 
728 TVM_DLL Pass InjectPTXLDG32(bool enable_ptx_ldg32 = true);  // Rewrites global to local memory copy on CUDA with the ldg32 instruction.
729 
742 TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite = false);  // Removes the weight layout rewrite block.
743 
749 
755 
767 
775 
776 } // namespace transform
777 } // namespace tir
778 } // namespace tvm
779 
780 #endif // TVM_TIR_TRANSFORM_H_
Managed reference class to IRModuleNode.
Definition: module.h:257
Managed reference class to TargetNode.
Definition: target.h:191
Managed reference to PrimFuncNode.
Definition: function.h:131
PassContextNode contains the information that a pass can rely on, such as analysis results.
Definition: transform.h:79
PassContext that is used to configure the pass behavior.
Definition: transform.h:156
Meta data that will be used to help optimization and analysis.
Definition: transform.h:315
PassNode is the base type of different types of optimization passes. It is designed as a pure class an...
Definition: transform.h:368
Definition: transform.h:400
Definition: transform.h:493
tvm::transform::PassContext PassContext
Definition: transform.h:40
tvm::transform::PassInfo PassInfo
Definition: transform.h:39
tvm::transform::Pass Pass
Definition: transform.h:38
PrimFuncFrame PrimFunc(bool is_private)
The primitive function statement.
Pass ThreadSync(String storage_scope)
Insert sync between parallel read/write of shared buffers.
Pass LowerIntrin()
Lower the target specific function intrinsics in each of the function.
Pass LowerCustomDatatypes()
Lower custom datatypes.
Pass ConvertForLoopsToSerial()
This pass is post-scheduling pass to convert all Parallel For loops to Serial ones....
Pass LowerThreadAllreduce()
Lower cross thread allreduce.
Pass LoopPartition()
partition loops in the stmt.
Pass MakeUnpackedAPI()
Transform the high-level PrimFunc to a C signature that can be used to call the operator directly.
Pass TransformMmaBufferLayout()
Transform Mma scope (m16n8k8.matrixA/B/C) to local scope with layout transformation.
Pass BF16ComputeLegalize()
Legalize bf16 compute Ops. Add a cast to fp32 before Ops, then add a cast back to bf16.
Pass InjectSoftwarePipeline()
This pass transforms annotated loops into pipelined ones where producers and consumers are overlapped...
Pass CompactBufferAllocation(bool is_strict=true)
Compact the buffer access region by removing the buffer regions that are not accessed,...
Pass RemapThreadAxis(Map< String, IterVar > axis_map)
Remap the thread axis.
Pass InstrumentProfileIntrinsics()
Insert intrinsic calls to instrument function and loop level profiling.
Pass PointerValueTypeRewrite()
Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use the m...
Pass InjectVirtualThread()
Inject virtual thread loops.
Pass MakePackedAPI()
Transform the high-level PrimFunc to a low-level version that can be used as an API function.
Pass FlattenBuffer()
Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional BufferLoad/BufferStore...
Pass LowerAutoCopy()
Automatically do memory optimizations for auto copy blocks.
Pass FP8ComputeLegalize(String promote_dtype_str="float16")
Legalize fp8 compute Ops. Add a cast to fp16/fp32 before Ops, then add a cast back to fp8.
Pass MergeSharedMemoryAllocations()
Pass SplitHostDevice()
Split the function into a host function and device functions.
Pass RewriteUnsafeSelect()
Detect and rewrite unsafe select that contains memory access.
Pass CreatePrimFuncPass(std::function< PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, int opt_level, String name, tvm::Array< String > required, bool traceable=false)
Pass LowerDeviceStorageAccessInfo()
Lower attached storage access information on device.
Pass InjectDoubleBuffer()
Inject double buffer statements.
Pass LowerOpaqueBlock()
Remove the block to ensure that the TIR can not be scheduled again.
Pass FP8StorageLegalize()
Legalize fp8 storage types to u8.
Pass RenormalizeSplitPattern()
Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
Pass InlinePrivateFunctions()
Inline calls to private functions.
Pass PlanAndUpdateBufferAllocationLocation()
Locate the buffer allocation to the exact position (usually is the lca of buffer access)....
Pass LowerCrossThreadReduction()
Lower cross-thread reduction from thread bindings to intrinsic function calls.
Pass DecorateDeviceScope()
Decorate all the function's body as device function.
Pass SkipAssert()
skip assert stmt.
Pass ForceNarrowIndexToInt32()
Force to narrow down indexing expressions and integer buffers to int32 dtype.
Pass AnnotateEntryFunc()
Set a PrimFunc as the entry point if it is only function in IRModule.
Pass ConvertSSA()
Convert an IRModule to be SSA form.
Pass Simplify()
Run arithmetic simplifications on the statements and expressions.
Pass CommonSubexprElimTIR(bool enable_cse_tir=true, bool identify_equiv_terms=false)
Implements a Common Subexpression Elimination (CSE) for TIR which introduces let-in bindings for dupl...
Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite=false)
Remove the weight layout rewrite block.
Pass RemoveNoOp()
Remove No Op from the Stmt.
Pass AnnotateDeviceRegions()
Annotate locations that should be run on the device.
Pass UnifyThreadBinding()
Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread....
Pass LowerTVMBuiltin()
Lower builtin intrinsics.
Pass InstrumentBoundCheckers()
Instruments bound checkers.
Pass UseAssumeToReduceBranches()
This pass analyzes primfunc & eliminates branch introduced due to layout specific padding....
Pass LowerDeviceKernelLaunch()
Lower cross-device function calls.
Pass NarrowDataType(int target_bits)
Narrow down PrimExpr datatype in stmt to target_bits.
Pass HoistExpression()
Hoist loop-invariant expressions nodes to outside the eligible loops.
Pass DefaultGPUSchedule()
The pass sets default thread bindings for PrimFuncs, including symbolic shape functions,...
Pass LowerMatchBuffer()
Remove match buffers inside the block. Also, it will validate the binding.
Pass UnrollLoop()
unroll the constant loop marked by unroll. This pass also automatically attach pragma unroll tag to l...
Pass StorageRewrite()
Rewrite storage allocation pattern. Moves the allocation to outer most possible scope....
Pass BindParams(const Array< runtime::NDArray > &constants)
Pass ExtractPrimFuncConstants()
Pass to collect tir non-scalar constants into module's 'Constants' attribute.
Pass InjectPTXLDG32(bool enable_ptx_ldg32=true)
Pass to rewrite global to local memory copy on CUDA with ldg32 instruction.
Pass InjectPTXAsyncCopy()
Pass to rewrite global to shared memory copy on CUDA with asynchronous copy.
Pass UnifiedStaticMemoryPlanner()
This is the unified static memory planner pass that will plan for memory intra- and inter- PrimFuncs ...
Pass HoistIfThenElse()
Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
Pass LowerWarpMemory()
Lower warp memory access to low-level device related function calls.
Pass InferFragment()
Infer the TensorCore fragment information using tensor intrinsics.
Pass ConvertBlocksToOpaque()
Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding it...
Pass BF16StorageLegalize()
Legalize bf16 storage types to u16.
Pass BindTarget(Target target)
Annotate a PrimFunc with a given target.
Pass LowerInitBlock()
Lower block init stmt into IfThenElse stmts.
Pass LiftThreadBinding()
Lift the same thread bindings to their LCA loops.
Pass Filter(ffi::TypedFunction< bool(PrimFunc)> fcond)
Filter PrimFuncs with a given condition.
Pass CombineContextCall()
Combine context calls in the host function.
Pass ManifestSharedMemoryLocalStage()
Add the explicit local stage for the shared memory access on GPU.
Pass VectorizeLoop(bool enable_vectorize=true)
Lower vectorization loops.
Pass LowerAsyncDMA()
Lower Async TIR primitives to DMA copy and wait builtins.
Pass InjectPermutedLayout()
Inject permuted layout for shared memory.
Pass CreateModulePass(std::function< IRModule(IRModule, PassContext)> pass_func, int opt_level, String name, Array< String > required, bool traceable=false)
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Compilation target object.
TIR expressions.
TIR Function.