tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_ANALYSIS_H_
25 #define TVM_TIR_ANALYSIS_H_
26 
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <functional>
#include <optional>
#include <string>
36 
37 namespace tvm {
38 
// Forward declaration only: this header just passes arith::Analyzer around by
// pointer (see IdentifyMemCpy); the class itself is defined in analyzer.h.
namespace arith {
class Analyzer;
}  // namespace arith
42 
43 namespace tir {
44 
60 struct ExprDeepEqual {
61  public:
62  TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
63 };
64 
71 template <class FLambda>
72 inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
73  for (const auto& kv : mod->functions) {
74  const BaseFunc& base_func = kv.second;
75  if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
76  fvisit(prim_func);
77  }
78  }
79 }
80 
86 TVM_DLL double EstimateTIRFlops(const Stmt& stmt);
87 
93 TVM_DLL double EstimateTIRFlops(const IRModule& mod);
94 
101 TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
102 
108 TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
109 
116 TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr, const Array<Var>& defs);
117 
124 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
125 
132 TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);
133 
140 TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
141 
151 TVM_DLL bool VerifySSA(const PrimFunc& func);
152 
163 TVM_DLL bool VerifyMemory(const PrimFunc& func);
164 
184 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
185 
192 TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit);
193 
207  const Map<Var, Buffer>& buffer_var_map);
208 
218  const Map<Var, Buffer>& buffer_var_map);
219 
230 };
231 
241 TVM_DLL std::optional<MemCpyDetails> IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer);
242 
247 TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
248 
254 TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& constant_byte_alignment);
255 
262 TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
263  const Integer& workspace_byte_alignment);
264 
270 
280 
288 TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);
289 
297 const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var);
298 
313 const tir::BlockNode* FindAnchorBlock(const IRModule& mod);
314 
315 // Pass variants of verification analysis
316 // directly throws RuntimeError when verification fails.
317 namespace transform {
318 
321 
328 TVM_DLL Pass VerifySSA();
329 
336 TVM_DLL Pass VerifyMemory();
337 
346 TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
347 
356 TVM_DLL Pass VerifyVTCMLimit(const Integer& limit);
357 
367 TVM_DLL Pass OOBChecker();
368 
369 } // namespace transform
370 } // namespace tir
371 } // namespace tvm
372 #endif // TVM_TIR_ANALYSIS_H_
Managed reference to BlockNode.
Definition: stmt.h:1258
Map< Buffer, Optional< Stmt > > DetectBufferAccessLCA(const PrimFunc &func)
Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level accesses (BufferLoad...
Array< Array< BufferRegion > > GetBlockReadWriteRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto detect the block read/write region according to its body stmt. An opaque access will be counted ...
Pass OOBChecker()
Statically check TIR code for out of bounds array access.
Array< Var > UndefinedVars(const PrimExpr &expr, const Array< Var > &defs)
Find undefined vars in the expression.
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
bool VerifyWellFormed(const PrimFunc &func, bool assert_mode=true)
Verify if the given TIR is well-formed. The verification includes:
Attribute types in the Op registry for TIR ops.
bool UsesVar(const PrimExpr &expr, std::function< bool(const VarNode *)> vset_contains)
Whether the given PrimExpr uses any var in the given variable set.
A variable node in the IR.
Definition: var.h:47
tvm::Map< String, Integer > CalculateAllocatedBytes(const PrimFunc &func)
Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc. ...
Primitive functions that contain TIR statements.
Definition: function.h:46
TIR Function.
Managed reference to ForNode.
Definition: stmt.h:962
double EstimateTIRFlops(const IRModule &mod)
Estimate the FLOPs of TIRs in an IRModule.
Managed reference to BufferRegionNode.
Definition: stmt.h:1099
const tir::BlockNode * FindAnchorBlock(const IRModule &mod)
Find the "anchor block" of the given module. We define the anchor block to be the block with (1) an i...
tvm::transform::Pass Pass
Definition: transform.h:43
const PrimFuncNode * FindEntryFunc(const IRModule &mod, GlobalVar *result_g_var)
Find the entry function of the given IRModule, i.e., the function marked by tir::attr::kIsEntryFunc, the function named main, or the only PrimFunc in the module.
size_t CalculateConstantBytes(const PrimFunc &func, const Integer &constant_byte_alignment)
Calculate the size in bytes of the constants needed by the TIR allocations inside the TIR PrimFunc.
Helper struct for return value of IdentifyMemCpy.
Definition: analysis.h:227
CallEffectKind
The effect type of the call.
Definition: op_attr_types.h:88
TIR statements.
Pass VerifyVTCMLimit(const Integer &limit)
Pass to check whether the size of the allocated VTCM memory satisfies the limit.
Compare two expressions recursively and check if they are equal to each other without var remapping...
Definition: analysis.h:60
BufferRegion source
Definition: analysis.h:228
Map< GlobalVar, BaseFunc > functions
A map from ids to all global functions.
Definition: module.h:59
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Managed reference to PrimFuncNode.
Definition: function.h:145
Managed reference to GlobalVarNode.
Definition: expr.h:477
size_t CalculateWorkspaceBytes(const PrimFunc &func, const Integer &workspace_byte_alignment)
Calculate the workspace size in bytes needed by the TIR allocations inside the TIR PrimFunc.
Container of all statements.
Definition: stmt.h:59
Definition: transform.h:362
Pass VerifyGPUCode(Map< String, PrimExpr > constraints)
Pass variant of VerifyGPUCode.
A block is a basic schedule unit in TIR.
Definition: stmt.h:1191
void VisitPrimFuncs(const IRModule &mod, FLambda fvisit)
Visit the PrimFuncs in the IRModule.
Definition: analysis.h:72
size_t CalculateExprComplexity(const PrimExpr &expr)
Calculate the expression complexity based on the number of symbols it contains.
Array< Array< BufferRegion > > GetBlockAccessRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto-detect the block access region according to its body stmt. It will detect the access region as an...
std::optional< MemCpyDetails > IdentifyMemCpy(const For &loop, arith::Analyzer *analyzer)
Identify whether a For loop is semantically equivalent to MemCpy.
Pass VerifyMemory()
Pass variant of VerifyMemory.
Managed reference class to IRModuleNode.
Definition: module.h:348
tvm::transform::PassContext PassContext
Definition: transform.h:47
Map container of NodeRef->NodeRef in the DSL graph. Map implements copy-on-write semantics: the map is mutable, but a copy is made on write when the map is referenced in more than one place.
Definition: map.h:1271
CallEffectKind SideEffect(const PrimExpr &expr)
Analyze the side effect.
Managed reference to BaseFuncNode.
Definition: function.h:143
BufferRegion dest
Definition: analysis.h:229
Pass VerifySSA()
Pass variant of VerifySSA.
Reference to PrimExprNode.
Definition: expr.h:114
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:579
Container of constant int that adds more constructors.
Definition: expr.h:622