tvm
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_ANALYSIS_H_
25 #define TVM_TIR_ANALYSIS_H_
26 
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <functional>
#include <string>
35 
36 namespace tvm {
37 namespace tir {
38 
54 struct ExprDeepEqual {
55  public:
56  TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
57 };
58 
65 template <class FLambda>
66 inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
67  for (const auto& kv : mod->functions) {
68  const BaseFunc& base_func = kv.second;
69  if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
70  fvisit(prim_func);
71  }
72  }
73 }
74 
80 TVM_DLL double EstimateTIRFlops(const Stmt& stmt);
81 
87 TVM_DLL double EstimateTIRFlops(const IRModule& mod);
88 
95 TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
96 
102 TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
103 
110 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
111 
118 TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);
119 
126 TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
127 
137 TVM_DLL bool VerifySSA(const PrimFunc& func);
138 
149 TVM_DLL bool VerifyMemory(const PrimFunc& func);
150 
170 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
171 
185  const Map<Var, Buffer>& buffer_var_map);
186 
196  const Map<Var, Buffer>& buffer_var_map);
197 
202 TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
203 
210 TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
211  const Integer& workspace_byte_alignment);
212 
222 
223 // Pass variants of verification analysis
224 // directly throws RuntimeError when verification fails.
225 namespace transform {
226 
229 
236 TVM_DLL Pass VerifySSA();
237 
244 TVM_DLL Pass VerifyMemory();
245 
254 TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
255 
256 } // namespace transform
257 } // namespace tir
258 } // namespace tvm
259 #endif // TVM_TIR_ANALYSIS_H_
Managed reference to BlockNode.
Definition: stmt.h:1260
Map< Buffer, Optional< Stmt > > DetectBufferAccessLCA(const PrimFunc &func)
Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level accesses (BufferLoad...
Array< Array< BufferRegion > > GetBlockReadWriteRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto-detect the block read/write region according to its body stmt. An opaque access will be counted ...
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Attribute types in the Op registry for TIR ops.
A variable node in the IR.
Definition: var.h:47
Primitive functions that contain TIR statements.
Definition: function.h:46
TIR Function.
tvm::transform::Pass Pass
Definition: transform.h:43
bool VerifyGPUCode(const PrimFunc &func, Map< String, PrimExpr > constraints)
Verify the correctness of GPU code. It checks whether the amount of memory usage or the numb...
CallEffectKind
The effect type of the call.
Definition: op_attr_types.h:60
TIR statements.
Compare two expressions recursively and check if they are equal to each other without var remapping...
Definition: analysis.h:54
Map< GlobalVar, BaseFunc > functions
A map from ids to all global functions.
Definition: module.h:59
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
bool VerifySSA(const PrimFunc &func)
Verifies whether the IR stmt or Expr is in SSA form. That is: each Var is defined and assigned once (i...
Managed reference to PrimFuncNode.
Definition: function.h:156
bool VerifyMemory(const PrimFunc &func)
Verify if memory accesses are legal for a specific target device type.
size_t CalculateWorkspaceBytes(const PrimFunc &func, const Integer &workspace_byte_alignment)
Calculate the workspace size in bytes needed by the TIR allocations inside the TIR PrimFunc.
Container of all statements.
Definition: stmt.h:57
Definition: transform.h:363
Array< Var > UndefinedVars(const Stmt &stmt, const Array< Var > &defs)
Find undefined vars in the statement.
void VisitPrimFuncs(const IRModule &mod, FLambda fvisit)
Visit the PrimFuncs in the IRModule.
Definition: analysis.h:66
size_t CalculateExprComplexity(const PrimExpr &expr)
Calculate the expression complexity based on the number of symbols it contains.
Array< Array< BufferRegion > > GetBlockAccessRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto-detect the block access region according to its body stmt. It will detect the access region as an...
Managed reference class to IRModuleNode.
Definition: module.h:360
bool UsesVar(const Stmt &stmt, std::function< bool(const VarNode *)> vset_contains)
Whether the given Stmt uses any var in the given variable set.
tvm::transform::PassContext PassContext
Definition: transform.h:47
Map container of NodeRef->NodeRef in the DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1268
CallEffectKind SideEffect(const PrimExpr &expr)
Analyze the side effect.
Managed reference to BaseFuncNode.
Definition: function.h:143
bool operator()(const PrimExpr &lhs, const PrimExpr &rhs) const
Reference to PrimExprNode.
Definition: expr.h:112
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
double EstimateTIRFlops(const Stmt &stmt)
Estimate the FLOPs of a TIR fragment.
Container of constant int that adds more constructors.
Definition: expr.h:404