tvm
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_ANALYSIS_H_
25 #define TVM_TIR_ANALYSIS_H_
26 
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <functional>
#include <string>
35 
36 namespace tvm {
37 namespace tir {
38 
54 struct ExprDeepEqual {
55  public:
56  TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
57 };
58 
65 template <class FLambda>
66 inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
67  for (const auto& kv : mod->functions) {
68  const BaseFunc& base_func = kv.second;
69  if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
70  fvisit(prim_func);
71  }
72  }
73 }
74 
80 TVM_DLL double EstimateTIRFlops(const Stmt& stmt);
81 
87 TVM_DLL double EstimateTIRFlops(const IRModule& mod);
88 
95 TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
96 
102 TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
103 
110 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
111 
118 TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);
119 
126 TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
127 
137 TVM_DLL bool VerifySSA(const PrimFunc& func);
138 
149 TVM_DLL bool VerifyMemory(const PrimFunc& func);
150 
170 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
171 
178 TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit);
179 
193  const Map<Var, Buffer>& buffer_var_map);
194 
204  const Map<Var, Buffer>& buffer_var_map);
205 
210 TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
211 
217 TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& constant_byte_alignment);
218 
225 TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
226  const Integer& workspace_byte_alignment);
227 
233 
243 
251 TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);
252 
260 const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var);
261 
276 const tir::BlockNode* FindAnchorBlock(const IRModule& mod);
277 
278 // Pass variants of verification analysis
279 // directly throws RuntimeError when verification fails.
280 namespace transform {
281 
284 
291 TVM_DLL Pass VerifySSA();
292 
299 TVM_DLL Pass VerifyMemory();
300 
309 TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
310 
319 TVM_DLL Pass VerifyVTCMLimit(const Integer& limit);
320 
330 TVM_DLL Pass OOBChecker();
331 
332 } // namespace transform
333 } // namespace tir
334 } // namespace tvm
335 #endif // TVM_TIR_ANALYSIS_H_
Managed reference to BlockNode.
Definition: stmt.h:1308
Map< Buffer, Optional< Stmt > > DetectBufferAccessLCA(const PrimFunc &func)
Detect the lowest common ancestor(LCA) of buffer access, including both high-level access(BufferLoad...
Array< Array< BufferRegion > > GetBlockReadWriteRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto detect the block read/write region according to its body stmt. An opaque access will be counted ...
Pass OOBChecker()
Statically check TIR code for out of bounds array access.
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
bool VerifyWellFormed(const PrimFunc &func, bool assert_mode=true)
Verify if the given TIR is well-formed. The verification includes:
Attribute types in the Op registry for TIR ops.
A variable node in the IR.
Definition: var.h:47
tvm::Map< String, Integer > CalculateAllocatedBytes(const PrimFunc &func)
Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc. ...
Primitive functions that contain TIR statements.
Definition: function.h:46
TIR Function.
const tir::BlockNode * FindAnchorBlock(const IRModule &mod)
Find the "anchor block" of the given module. We define the anchor block to be the block with (1) an i...
tvm::transform::Pass Pass
Definition: transform.h:43
const PrimFuncNode * FindEntryFunc(const IRModule &mod, GlobalVar *result_g_var)
Find the entry function of the given IRModule, i.e., the function marked by tir::attr::kIsEntryFunc, the one named main, or the only PrimFunc.
size_t CalculateConstantBytes(const PrimFunc &func, const Integer &constant_byte_alignment)
Calculate the constants size in bytes needed by the TIR allocates inside the TIR PrimFunc.
bool VerifyGPUCode(const PrimFunc &func, Map< String, PrimExpr > constraints)
Verify the correctness of GPU code. It will check whether the amount of memory usage or the numb...
CallEffectKind
The effect type of the call.
Definition: op_attr_types.h:62
TIR statements.
Compare two expressions recursively and check if they are equal to each other without var remapping...
Definition: analysis.h:54
Map< GlobalVar, BaseFunc > functions
A map from ids to all global functions.
Definition: module.h:59
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool VerifySSA(const PrimFunc &func)
Verifies whether the IR stmt or Expr is in SSA form. That is: each Var is defined and assigned once(i...
Managed reference to PrimFuncNode.
Definition: function.h:143
Managed reference to GlobalVarNode.
Definition: expr.h:475
bool VerifyMemory(const PrimFunc &func)
Verify if memory accesses are legal for a specific target device type.
size_t CalculateWorkspaceBytes(const PrimFunc &func, const Integer &workspace_byte_alignment)
Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc.
Container of all statements.
Definition: stmt.h:57
Definition: transform.h:363
Array< Var > UndefinedVars(const Stmt &stmt, const Array< Var > &defs)
Find undefined vars in the statement.
A block is a basic schedule unit in TIR.
Definition: stmt.h:1241
void VisitPrimFuncs(const IRModule &mod, FLambda fvisit)
Visit the PrimFuncs in the IRModule.
Definition: analysis.h:66
size_t CalculateExprComplexity(const PrimExpr &expr)
Calculate the expression complexity based on the number of symbols it contains.
Array< Array< BufferRegion > > GetBlockAccessRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto detect the block access region according to its body stmt. It will detect the access region as an...
Managed reference class to IRModuleNode.
Definition: module.h:352
bool UsesVar(const Stmt &stmt, std::function< bool(const VarNode *)> vset_contains)
Whether the given Stmt uses any var in the given variable set.
tvm::transform::PassContext PassContext
Definition: transform.h:47
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
CallEffectKind SideEffect(const PrimExpr &expr)
Analyze the side effect.
Managed reference to BaseFuncNode.
Definition: function.h:143
bool operator()(const PrimExpr &lhs, const PrimExpr &rhs) const
bool VerifyVTCMLimit(const PrimFunc &func, Integer limit)
Verifies that the VTCM usage of the given prim_func is within the provided limit. ...
Reference to PrimExprNode.
Definition: expr.h:112
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
double EstimateTIRFlops(const Stmt &stmt)
Estimate the FLOPs of a TIR fragment.
Container of constant int that adds more constructors.
Definition: expr.h:620