tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_TRANSFORM_H_
25 #define TVM_TIR_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/function.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 namespace tir {
35 namespace transform {
36 
44 
45 /*
46  * \brief Create a function pass that optimizes PrimFuncs.
47  *
48  * \param pass_func The packed function that contains the optimization.
49  * \param opt_level The optimization level of the function pass.
50  * \param name The name of the function pass.
51  * \param required The list of the passes that the function pass is dependent on.
52  *
53  * \return The created function pass.
54  */
55 TVM_DLL Pass CreatePrimFuncPass(
56  const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
57  int opt_level, String name, tvm::Array<String> required);
58 
64 TVM_DLL Pass InjectPrefetch();
65 
66 // TODO(tvm-team): consolidate configs to the PassContext
76 TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false);
77 
91 TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin);
92 
98 TVM_DLL Pass CoProcSync();
99 
106 TVM_DLL Pass LiftAttrScope(String attr_key);
107 
113 TVM_DLL Pass LoopPartition();
114 
122 TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
123 
129 TVM_DLL Pass InjectVirtualThread();
130 
136 TVM_DLL Pass InjectDoubleBuffer();
137 
146 TVM_DLL Pass StorageRewrite();
147 
154 TVM_DLL Pass UnrollLoop();
155 
161 TVM_DLL Pass RemoveNoOp();
162 
168 TVM_DLL Pass RewriteUnsafeSelect();
169 
175 TVM_DLL Pass Simplify();
176 
182 TVM_DLL Pass InstrumentBoundCheckers();
183 
213 TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
214 
224 TVM_DLL Pass MakeUnpackedAPI();
225 
236 TVM_DLL Pass RemapThreadAxis(Map<String, IterVar> axis_map);
237 
245 TVM_DLL Pass LowerCustomDatatypes();
246 
252 TVM_DLL Pass DecorateDeviceScope();
253 
259 TVM_DLL Pass SplitHostDevice();
260 
266 TVM_DLL Pass SkipAssert();
267 
274 TVM_DLL Pass ThreadSync(String storage_scope);
275 
281 TVM_DLL Pass LowerThreadAllreduce();
282 
288 TVM_DLL Pass InferFragment();
289 
294 TVM_DLL Pass LowerTVMBuiltin();
295 
301 TVM_DLL Pass LowerIntrin();
302 
307 TVM_DLL Pass LowerWarpMemory();
308 
317 
323 TVM_DLL Pass CombineContextCall();
324 
333 TVM_DLL Pass NarrowDataType(int target_bits);
334 
340 TVM_DLL Pass BF16Legalize();
341 
350 TVM_DLL Pass PointerValueTypeRewrite();
351 
358 TVM_DLL Pass HoistIfThenElse();
359 
364 TVM_DLL Pass LowerInitBlock();
365 
373 
380 TVM_DLL Pass ConvertBlocksToOpaque();
381 
419 TVM_DLL Pass CompactBufferAllocation();
420 
424 TVM_DLL Pass LegalizePackedCalls();
425 
430 TVM_DLL Pass LowerMatchBuffer();
431 
438 TVM_DLL Pass FlattenBuffer();
439 
440 /*
441  * \brief Flatten the multi-dimensional read/write
442  * to two dimensional texture Load/Store and realize
443  * texture buffer allocations.
444  *
445  * \return The Pass
446  */
447 TVM_DLL Pass TextureFlatten();
448 
459 TVM_DLL Pass UnifyThreadBinding();
460 
465 
473 TVM_DLL Pass ConvertForLoopsToSerial();
474 
475 } // namespace transform
476 } // namespace tir
477 } // namespace tvm
478 
479 #endif // TVM_TIR_TRANSFORM_H_
Pass LowerDeviceStorageAccessInfo()
Lower attached storage access information on device.
Pass LowerIntrin()
Lower the target-specific function intrinsics in each function.
Pass InferFragment()
Infer the TensorCore fragment information using tensor intrinsics.
Pass StorageFlatten(int cache_line_size, bool create_bound_attribute=false)
Flatten the multi-dimensional read/write to single dimensional Load/Store.
Pass LowerMatchBuffer()
Remove match buffers inside the block. Also, it will validate the binding.
Pass SplitHostDevice()
Split the function into a host function and device functions.
Pass ConvertForLoopsToSerial()
This is a post-scheduling pass that converts all Parallel For loops to Serial ones. It is run to reduce memory usage and/or when the executor/backend does not support parallel launch of For loops.
Pass MakeUnpackedAPI()
Transform the high-level PrimFunc to a C signature that can be used to call the operator directly...
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Pass LiftAttrScope(String attr_key)
Lift common attrs with attr_key to outer scope.
Pass InjectVirtualThread()
Inject virtual thread loops.
Pass CombineContextCall()
Combine context calls in the host function.
Pass InstrumentBoundCheckers()
Instruments bound checkers.
Pass LowerInitBlock()
Lower block init stmt into IfThenElse stmts.
Pass DecorateDeviceScope()
Decorate the body of every function as a device function.
tvm::transform::Sequential Sequential
Definition: transform.h:47
Pass HoistIfThenElse()
Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
Pass MakePackedAPI(int num_unpacked_args)
Transform the high-level PrimFunc to a low-level version that can be used as an API function...
Pass NarrowDataType(int target_bits)
Narrow down PrimExpr datatype in stmt to target_bits.
Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin)
Inject copy intrinsics with optional pad.
TIR Function.
Pass LoopPartition()
Partition loops in the stmt.
Pass ConvertBlocksToOpaque()
Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding it...
tvm::transform::PassNode PassNode
Definition: transform.h:42
tvm::transform::Pass Pass
Definition: transform.h:41
Pass UnrollLoop()
Unroll the constant loops marked by unroll. This pass also automatically attaches the pragma unroll tag to l...
Pass CoProcSync()
Detect and insert sync points to co-processor.
Pass RewriteUnsafeSelect()
Detect and rewrite unsafe select that contains memory access.
Pass FlattenBuffer()
Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store. Also remove Block so that the flattened TIR cannot be scheduled again.
Pass LowerThreadAllreduce()
Lower cross-thread allreduce.
Pass LowerTVMBuiltin()
Lower builtin intrinsics.
Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< PrimFunc(PrimFunc, IRModule, PassContext)> &pass_func, int opt_level, String name, tvm::Array< String > required)
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
tvm::transform::PassInfoNode PassInfoNode
Definition: transform.h:44
Pass MergeDynamicSharedMemoryAllocations()
Pass VectorizeLoop(bool enable_vectorize=true)
Lower vectorization loops.
Pass StorageRewrite()
Rewrite storage allocation pattern. Moves allocations to the outermost possible scope and tries to share space between allocations to make a static allocation plan when possible.
Pass BF16Legalize()
Legalize bf16 typed Ops. Add a cast to fp32 before Ops, then add a cast back to bf16.
Pass InjectPrefetch()
Inject prefetch instructions into stmt.
Pass Simplify()
Run arithmetic simplifications on the statements and expressions.
Pass RemapThreadAxis(Map< String, IterVar > axis_map)
Remap the thread axis.
tvm::transform::PassContext PassContext
Definition: transform.h:45
Pass LowerWarpMemory()
Lower warp memory access to low-level device related function calls.
Pass CompactBufferAllocation()
Compact the buffer access region by removing the buffer regions that are not accessed, i.e. narrowing the buffer shape and adjusting the access region if necessary.
Pass PlanAndUpdateBufferAllocationLocation()
Locate the buffer allocation to the exact position (usually the LCA of the buffer accesses). This pass will inject an opaque block with alloc_buffers at the allocation site.
Pass PointerValueTypeRewrite()
Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use the m...
Pass RemoveNoOp()
Remove No Op from the Stmt.
Pass LowerCustomDatatypes()
Lower custom datatypes.
Pass InjectDoubleBuffer()
Inject double buffer statements.
Pass ThreadSync(String storage_scope)
Insert sync between parallel read/write of shared buffers.
tvm::transform::PassContextNode PassContextNode
Definition: transform.h:46
Pass SkipAssert()
Skip assert stmts.
tvm::transform::PassInfo PassInfo
Definition: transform.h:43
Pass UnifyThreadBinding()
Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., "threadIdx.x") use different IterVars and variables in their AttrStmts. After the unification, we use a consolidated IterVar and a variable for them.