tvm
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_S_TIR_ANALYSIS_H_
25 #define TVM_S_TIR_ANALYSIS_H_
26 
27 #include <tvm/ir/module.h>
28 #include <tvm/ir/transform.h>
29 #include <tvm/target/target.h>
30 #include <tvm/tirx/function.h>
31 #include <tvm/tirx/stmt.h>
32 
33 #include <optional>
34 
35 namespace tvm {
36 namespace tirx {
37 
50 TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetSBlockAccessRegion(  // Auto-detect the block access region according to its body stmt.
51  const SBlock& block, const ffi::Map<Var, Buffer>& buffer_var_map);
52 
61 TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetSBlockReadWriteRegion(  // Auto-detect the block read/write region according to its body stmt.
62  const SBlock& block, const ffi::Map<Var, Buffer>& buffer_var_map);
63 
72 TVM_DLL ffi::Map<Buffer, ffi::Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);  // Detect the lowest common ancestor (LCA) of buffer accesses.
73 
89 
90 } // namespace tirx
91 
92 namespace arith {
93 class Analyzer;  // Forward declaration of arith::Analyzer (used by IdentifyMemCpy below).
94 }
95 
96 namespace s_tir {
97 
98 using namespace tvm::tirx;
99 
105 TVM_DLL double EstimateTIRFlops(const Stmt& stmt);  // Estimate the FLOPs of a TIR fragment.
106 
112 TVM_DLL double EstimateTIRFlops(const IRModule& mod);  // IRModule overload of EstimateTIRFlops.
113 
120 TVM_DLL bool IsPureFunction(const PrimFunc& func, bool assert_on_error = false);  // Analyze the side effects of a function.
121 
128 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, ffi::Map<ffi::String, PrimExpr> constraints);  // Verify the correctness of GPU code.
129 
131 struct MemCpyDetails {  // Helper struct for return value of IdentifyMemCpy.
132  BufferRegion source;
133  BufferRegion dest;
134 };
135 
141 TVM_DLL std::optional<MemCpyDetails> IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer);  // Identify whether a For loop is semantically equivalent to MemCpy.
142 
148 TVM_DLL ffi::Map<ffi::String, ffi::Map<ffi::String, Integer>> CalculateAllocatedBytes(  // Allocated memory per scope, in bytes, needed inside the PrimFunc.
149  const PrimFunc& func);
150 
156 TVM_DLL ffi::Map<ffi::String, ffi::Map<ffi::String, Integer>> CalculateAllocatedBytes(  // IRModule overload of CalculateAllocatedBytes.
157  const IRModule& mod);
158 
163 TVM_DLL ffi::Array<tvm::transform::Pass> GetVTCMCompactionPasses();  // Lowering passes used to calculate the compacted VTCM allocation size.
164 
171 TVM_DLL bool VerifyVTCMLimit(const IRModule& mod, Integer limit);  // Verify the VTCM usage for all prim_funcs in the given IRModule.
172 
179 TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit);  // PrimFunc overload of VerifyVTCMLimit.
180 
181 namespace transform {
182 
185 
191 TVM_DLL Pass VerifyGPUCode(ffi::Map<ffi::String, PrimExpr> constraints);  // Pass form of s_tir::VerifyGPUCode.
192 
198 TVM_DLL Pass VerifyVTCMLimit(ffi::Optional<Target> default_target = std::nullopt);  // Pass form of s_tir::VerifyVTCMLimit.
199 
204 TVM_DLL Pass OOBChecker();  // Statically check TIR code for out-of-bounds array access.
205 
206 } // namespace transform
207 } // namespace s_tir
208 } // namespace tvm
209 #endif // TVM_S_TIR_ANALYSIS_H_
Managed reference class to IRModuleNode.
Definition: module.h:257
Container of constant int that adds more constructors.
Definition: expr.h:601
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:634
Managed reference to BufferRegionNode.
Definition: stmt.h:716
Managed reference to ForNode.
Definition: stmt.h:640
Managed reference to PrimFuncNode.
Definition: function.h:130
A block is a basic schedule unit in TIR.
Definition: stmt.h:799
Managed reference to SBlockNode.
Definition: stmt.h:846
Container of all statements.
Definition: stmt.h:65
Definition: transform.h:400
IRModule that holds the functions and type definitions.
tvm::transform::Pass Pass
Definition: transform.h:35
tvm::transform::PassContext PassContext
Definition: transform.h:37
Pass OOBChecker()
Statically check TIR code for out-of-bounds array access.
std::optional< MemCpyDetails > IdentifyMemCpy(const For &loop, arith::Analyzer *analyzer)
Identify whether a For loop is semantically equivalent to MemCpy.
ffi::Map< ffi::String, ffi::Map< ffi::String, Integer > > CalculateAllocatedBytes(const PrimFunc &func)
Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc.
bool VerifyGPUCode(const PrimFunc &func, ffi::Map< ffi::String, PrimExpr > constraints)
Verify the correctness of GPU code.
bool VerifyVTCMLimit(const IRModule &mod, Integer limit)
Verifies that the VTCM usage of all prim_funcs in the given IRModule is within the given limit.
bool IsPureFunction(const PrimFunc &func, bool assert_on_error=false)
Analyze the side effects of a function.
double EstimateTIRFlops(const Stmt &stmt)
Estimate the FLOPs of a TIR fragment.
ffi::Array< tvm::transform::Pass > GetVTCMCompactionPasses()
Get the list of lowering passes to calculate the compacted VTCM allocation size.
Definition: axis_group_graph.h:39
ffi::Map< Buffer, ffi::Optional< Stmt > > DetectBufferAccessLCA(const PrimFunc &func)
Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level accesses (BufferLoad,...
const tirx::SBlockNode * FindAnchorBlock(const IRModule &mod)
Find the "anchor block" of the given module. We define the anchor block to be the block with (1) an i...
ffi::Array< ffi::Array< BufferRegion > > GetSBlockReadWriteRegion(const SBlock &block, const ffi::Map< Var, Buffer > &buffer_var_map)
Auto-detect the block read/write region according to its body stmt. An opaque access will be counted ...
ffi::Array< ffi::Array< BufferRegion > > GetSBlockAccessRegion(const SBlock &block, const ffi::Map< Var, Buffer > &buffer_var_map)
Auto-detect the block access region according to its body stmt. It will detect the access region as an...
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
Helper struct for return value of IdentifyMemCpy.
Definition: analysis.h:131
BufferRegion source
Definition: analysis.h:132
BufferRegion dest
Definition: analysis.h:133
Compilation target object.
TIR Function.
TIR statements.