tvm
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_ANALYSIS_H_
25 #define TVM_TIR_ANALYSIS_H_
26 
27 #include <tvm/ir/module.h>
28 #include <tvm/ir/transform.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/function.h>
31 #include <tvm/tir/op_attr_types.h>
32 #include <tvm/tir/stmt.h>
33 
34 #include <string>
35 
36 namespace tvm {
37 namespace tir {
38 
54 struct ExprDeepEqual {
55  public:
56  TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
57 };
58 
65 template <class FLambda>
66 inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
67  for (const auto& kv : mod->functions) {
68  const BaseFunc& base_func = kv.second;
69  if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
70  fvisit(prim_func);
71  }
72  }
73 }
74 
80 TVM_DLL double EstimateTIRFlops(const Stmt& stmt);
81 
87 TVM_DLL double EstimateTIRFlops(const IRModule& mod);
88 
95 TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
96 
102 TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
103 
110 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
111 
118 TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);
119 
126 TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
127 
137 TVM_DLL bool VerifySSA(const PrimFunc& func);
138 
149 TVM_DLL bool VerifyMemory(const PrimFunc& func);
150 
170 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
171 
185  const Map<Var, Buffer>& buffer_var_map);
186 
196  const Map<Var, Buffer>& buffer_var_map);
197 
202 TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
203 
210 TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
211  const Integer& workspace_byte_alignment);
212 
222 
230 TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);
231 
232 // Pass variants of verification analysis
233 // directly throws RuntimeError when verification fails.
234 namespace transform {
235 
238 
245 TVM_DLL Pass VerifySSA();
246 
253 TVM_DLL Pass VerifyMemory();
254 
263 TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
264 
274 TVM_DLL Pass OOBChecker();
275 
276 } // namespace transform
277 } // namespace tir
278 } // namespace tvm
279 #endif // TVM_TIR_ANALYSIS_H_
Managed reference to BlockNode.
Definition: stmt.h:1295
Map< Buffer, Optional< Stmt > > DetectBufferAccessLCA(const PrimFunc &func)
Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level access (BufferLoad...
Array< Array< BufferRegion > > GetBlockReadWriteRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto detect the block read/write region according to its body stmt. An opaque access will be counted ...
Pass OOBChecker()
Statically check TIR code for out of bounds array access.
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
bool VerifyWellFormed(const PrimFunc &func, bool assert_mode=true)
Verify if the given TIR is well-formed. The verification includes:
Attribute types in the Op registry for TIR ops.
A variable node in the IR.
Definition: var.h:47
Primitive functions that contains TIR statements.
Definition: function.h:46
TIR Function.
tvm::transform::Pass Pass
Definition: transform.h:43
bool VerifyGPUCode(const PrimFunc &func, Map< String, PrimExpr > constraints)
Verify the correctness of GPU code. It will check whether the amount of memory usage or the numb...
CallEffectKind
The effect type of the call.
Definition: op_attr_types.h:60
TIR statements.
Compare two expressions recursively and check if they are equal to each other without var remapping...
Definition: analysis.h:54
Map< GlobalVar, BaseFunc > functions
A map from ids to all global functions.
Definition: module.h:59
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool VerifySSA(const PrimFunc &func)
Verifies whether the IR stmt or Expr is in SSA form. That is: each Var is defined and assigned once(i...
Managed reference to PrimFuncNode.
Definition: function.h:156
bool VerifyMemory(const PrimFunc &func)
Verify if memory accesses are legal for a specific target device type.
size_t CalculateWorkspaceBytes(const PrimFunc &func, const Integer &workspace_byte_alignment)
Calculate the workspace size in bytes needed by the TIR allocations inside the TIR PrimFunc.
Container of all statements.
Definition: stmt.h:57
Definition: transform.h:363
Array< Var > UndefinedVars(const Stmt &stmt, const Array< Var > &defs)
Find undefined vars in the statement.
void VisitPrimFuncs(const IRModule &mod, FLambda fvisit)
Visit the PrimFuncs in the IRModule.
Definition: analysis.h:66
size_t CalculateExprComplexity(const PrimExpr &expr)
Calculate the expression complexity based on the number of symbols it contains.
Array< Array< BufferRegion > > GetBlockAccessRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto-detect the block access region according to its body stmt. It will detect the access region as an...
Managed reference class to IRModuleNode.
Definition: module.h:352
bool UsesVar(const Stmt &stmt, std::function< bool(const VarNode *)> vset_contains)
Whether the given Stmt uses any var in the given variable set.
tvm::transform::PassContext PassContext
Definition: transform.h:47
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
CallEffectKind SideEffect(const PrimExpr &expr)
Analyze the side effect.
Managed reference to BaseFuncNode.
Definition: function.h:143
bool operator()(const PrimExpr &lhs, const PrimExpr &rhs) const
Reference to PrimExprNode.
Definition: expr.h:112
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
double EstimateTIRFlops(const Stmt &stmt)
Estimate the FLOPs of a TIR fragment.
Container of constant int that adds more constructors.
Definition: expr.h:618