tvm
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_ANALYSIS_H_
25 #define TVM_TIR_ANALYSIS_H_
26 
27 #include <tvm/ir/module.h>
28 #include <tvm/ir/transform.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/function.h>
31 #include <tvm/tir/op_attr_types.h>
32 #include <tvm/tir/stmt.h>
33 
34 #include <string>
35 
36 namespace tvm {
37 namespace tir {
38 
54 struct ExprDeepEqual {
55  public:
56  TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
57 };
58 
65 template <class FLambda>
66 inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
67  for (const auto& kv : mod->functions) {
68  const BaseFunc& base_func = kv.second;
69  if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
70  fvisit(prim_func);
71  }
72  }
73 }
74 
81 TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
82 
88 TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
89 
96 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
97 
104 TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);
105 
112 TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
113 
123 TVM_DLL bool VerifySSA(const PrimFunc& func);
124 
135 TVM_DLL bool VerifyMemory(const PrimFunc& func);
136 
156 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
157 
171  const Map<Var, Buffer>& buffer_var_map);
172 
182  const Map<Var, Buffer>& buffer_var_map);
183 
188 TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
189 
196 TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
197  const Integer& workspace_byte_alignment);
198 
208 
209 // Pass variants of verification analysis
210 // directly throws RuntimeError when verification fails.
211 namespace transform {
212 
215 
222 TVM_DLL Pass VerifySSA();
223 
230 TVM_DLL Pass VerifyMemory();
231 
240 TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
241 
242 } // namespace transform
243 } // namespace tir
244 } // namespace tvm
245 #endif // TVM_TIR_ANALYSIS_H_
Managed reference to BlockNode.
Definition: stmt.h:1164
Map< Buffer, Optional< Stmt > > DetectBufferAccessLCA(const PrimFunc &func)
Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level accesses (BufferLoad...
Array< Array< BufferRegion > > GetBlockReadWriteRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto detect the block read/write region according to its body stmt. An opaque access will be counted ...
IRModule that holds the functions and type definitions.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Attribute types in the Op registry for TIR ops.
A variable node in the IR.
Definition: var.h:47
Primitive functions that contain TIR statements.
Definition: function.h:46
TIR Function.
tvm::transform::Pass Pass
Definition: transform.h:41
bool VerifyGPUCode(const PrimFunc &func, Map< String, PrimExpr > constraints)
Verify the correctness of GPU code. It will check whether the amount of memory usage or the numb...
CallEffectKind
The effect type of the call.
Definition: op_attr_types.h:60
TIR statements.
Compare two expressions recursively and check if they are equal to each other without var remapping...
Definition: analysis.h:54
Map< GlobalVar, BaseFunc > functions
A map from ids to all global functions.
Definition: module.h:57
TIR expressions.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
bool VerifySSA(const PrimFunc &func)
Verifies whether the IR stmt or Expr is in SSA form. That is: each Var is defined and assigned once(i...
Managed reference to PrimFuncNode.
Definition: function.h:135
bool VerifyMemory(const PrimFunc &func)
Verify if memory accesses are legal for a specific target device type.
size_t CalculateWorkspaceBytes(const PrimFunc &func, const Integer &workspace_byte_alignment)
Calculate the workspace size in bytes needed by the TIR allocations inside the TIR PrimFunc.
Container of all statements.
Definition: stmt.h:57
Definition: transform.h:363
Array< Var > UndefinedVars(const Stmt &stmt, const Array< Var > &defs)
Find undefined vars in the statement.
void VisitPrimFuncs(const IRModule &mod, FLambda fvisit)
Visit the PrimFuncs in the IRModule.
Definition: analysis.h:66
size_t CalculateExprComplexity(const PrimExpr &expr)
Calculate the expression complexity based on the number of symbols it contains.
Array< Array< BufferRegion > > GetBlockAccessRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto detect the block access region according to its body stmt. It will detect the access region as an...
Managed reference class to IRModuleNode.
Definition: module.h:352
bool UsesVar(const Stmt &stmt, std::function< bool(const VarNode *)> vset_contains)
Whether the given Stmt uses any var in the given variable set.
tvm::transform::PassContext PassContext
Definition: transform.h:45
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than one place.
Definition: map.h:1235
CallEffectKind SideEffect(const PrimExpr &expr)
Analyze the side effect.
Managed reference to BaseFuncNode.
Definition: function.h:143
bool operator()(const PrimExpr &lhs, const PrimExpr &rhs) const
Reference to PrimExprNode.
Definition: expr.h:109
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:271
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:858
Container of constant int that adds more constructors.
Definition: expr.h:356