tvm
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_ANALYSIS_H_
25 #define TVM_TIR_ANALYSIS_H_
26 
27 #include <tvm/ir/module.h>
28 #include <tvm/ir/transform.h>
29 #include <tvm/s_tir/analysis.h>
30 #include <tvm/target/target.h>
31 #include <tvm/tirx/expr.h>
32 #include <tvm/tirx/function.h>
33 #include <tvm/tirx/op_attr_types.h>
34 #include <tvm/tirx/stmt.h>
35 
36 #include <optional>
37 #include <string>
38 
39 namespace tvm {
40 
// Forward declaration only: this header names arith::Analyzer without
// including its definition.
41 namespace arith {
42 class Analyzer;
43 }
44 
45 namespace tirx {
46 
/*!
 * \brief Compare two expressions recursively and check if they are equal
 *        to each other without var remapping.
 *
 * Function object: the comparison is exposed through a const call operator
 * taking the two expressions and returning bool.
 */
62 struct ExprDeepEqual {
63  public:
64  TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
65 };
66 
/*!
 * \brief Visit the PrimFuncs in the IRModule.
 *
 * Iterates over every entry of mod->functions and calls fvisit for each
 * entry whose function can be viewed as a PrimFuncNode. Entries of any
 * other function kind are skipped.
 *
 * \tparam FLambda Type of the visitor callable.
 * \param mod The module whose functions are traversed.
 * \param fvisit Visitor, taken by value and invoked as
 *        fvisit(const PrimFuncNode*).
 */
73 template <class FLambda>
74 inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
75  for (const auto& kv : mod->functions) {
76  const BaseFunc& base_func = kv.second;
// as<PrimFuncNode>() yields a null pointer for non-PrimFunc entries, so
// only PrimFuncs reach the visitor.
77  if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
78  fvisit(prim_func);
79  }
80  }
81 }
82 
/*!
 * \brief Find undefined vars in the statement.
 * \param stmt The statement to be checked.
 * \param defs Vars supplied by the caller; presumably these are treated as
 *        already defined -- TODO confirm against the implementation.
 * \return Array of the undefined vars.
 */
89 TVM_DLL ffi::Array<Var> UndefinedVars(const Stmt& stmt, const ffi::Array<Var>& defs);
90 
/*!
 * \brief Overload of UndefinedVars that takes an expression instead of a
 *        statement and no list of defs.
 * \param expr The expression to be checked.
 * \return Array of the undefined vars.
 */
96 TVM_DLL ffi::Array<Var> UndefinedVars(const PrimExpr& expr);
97 
/*!
 * \brief Overload of UndefinedVars that takes an expression together with a
 *        list of defs.
 * \param expr The expression to be checked.
 * \param defs Vars supplied by the caller; presumably these are treated as
 *        already defined -- TODO confirm against the implementation.
 * \return Array of the undefined vars.
 */
104 TVM_DLL ffi::Array<Var> UndefinedVars(const PrimExpr& expr, const ffi::Array<Var>& defs);
105 
/*!
 * \brief Analyze the side effect of an expression.
 * \param expr The expression to be analyzed.
 * \return The effect kind, expressed as a CallEffectKind value.
 */
112 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
113 
/*!
 * \brief Whether the given Stmt uses any var in the given variable set.
 * \param stmt The statement to be checked.
 * \param vset_contains Predicate over const VarNode* that describes the
 *        variable set (membership test).
 * \return Whether `stmt` uses any var in the set.
 *
 * NOTE(review): this signature uses std::function, but the visible include
 * list of this header has only <optional> and <string>; <functional> is
 * presumably pulled in transitively -- confirm, or include it directly.
 */
120 TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);
121 
/*!
 * \brief Overload of UsesVar that checks an expression instead of a statement.
 * \param expr The expression to be checked.
 * \param vset_contains Predicate over const VarNode* that describes the
 *        variable set (membership test).
 * \return Whether `expr` uses any var in the set.
 */
128 TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
129 
/*!
 * \brief Verify whether the function is in SSA form, that is: each Var is
 *        defined and assigned once.
 * \param func The function to be verified.
 * \return Whether the function is in SSA form.
 */
139 TVM_DLL bool VerifySSA(const PrimFunc& func);
140 
/*!
 * \brief Verify if memory accesses are legal for a specific target device type.
 * \param func The function to be verified.
 * \return Result of the verification as a bool.
 */
151 TVM_DLL bool VerifyMemory(const PrimFunc& func);
152 
/*!
 * \brief Calculate the expression complexity based on the number of symbols
 *        it contains.
 * \param expr The expression to be measured.
 * \return The computed complexity.
 */
157 TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
158 
/*!
 * \brief Calculate the constant size in bytes needed by the TIR allocations
 *        inside the TIR PrimFunc.
 * \param func The TIR PrimFunc to be analyzed.
 * \param constant_byte_alignment Byte alignment, passed as an Integer;
 *        presumably applied to each constant allocation -- TODO confirm.
 * \return The size in bytes.
 */
164 TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& constant_byte_alignment);
165 
/*!
 * \brief Calculate the workspace size in bytes needed by the TIR allocations
 *        inside the TIR PrimFunc.
 * \param func The TIR PrimFunc to be analyzed.
 * \param workspace_byte_alignment Byte alignment, passed as an Integer;
 *        presumably applied to each workspace allocation -- TODO confirm.
 * \return The size in bytes.
 */
172 TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
173  const Integer& workspace_byte_alignment);
174 
/*!
 * \brief Verify if the given TIR is well-formed.
 * \param func The PrimFunc to be verified.
 * \param assert_mode Defaults to true. NOTE(review): presumably selects
 *        between raising an error and returning false on failure -- confirm
 *        against the implementation.
 * \return Result of the verification as a bool.
 */
195 TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);
196 
/*!
 * \brief Overload of VerifyWellFormed that takes a whole IRModule instead of
 *        a single PrimFunc.
 * \param mod The module to be verified.
 * \param assert_mode Defaults to true. NOTE(review): presumably selects
 *        between raising an error and returning false on failure -- confirm
 *        against the implementation.
 * \return Result of the verification as a bool.
 */
209 TVM_DLL bool VerifyWellFormed(const IRModule& mod, bool assert_mode = true);
210 
/*!
 * \brief Find the entry function of the given IRModule, i.e., functions
 *        marked by tirx::attr::kIsEntryFunc.
 * \param mod The IRModule to search.
 * \param result_g_var Pointer to a GlobalVar; presumably an output that
 *        receives the GlobalVar of the entry function -- TODO confirm,
 *        including whether nullptr is accepted.
 * \return Pointer to the entry function's PrimFuncNode.
 *
 * NOTE(review): unlike the other free functions declared in this header,
 * this declaration is not marked TVM_DLL -- confirm that is intentional.
 */
218 const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var);
219 
220 // Pass variants of the verification analyses;
221 // they directly throw RuntimeError when verification fails.
222 namespace transform {
223 
226 
/*!
 * \brief Pass variant of VerifySSA.
 * \return The pass.
 */
233 TVM_DLL Pass VerifySSA();
234 
/*!
 * \brief Pass variant of VerifyMemory.
 * \return The pass.
 */
241 TVM_DLL Pass VerifyMemory();
242 
243 } // namespace transform
244 } // namespace tirx
245 } // namespace tvm
246 #endif // TVM_TIR_ANALYSIS_H_
Managed reference to BaseFuncNode.
Definition: function.h:233
Managed reference to GlobalVarNode.
Definition: expr.h:482
Managed reference class to IRModuleNode.
Definition: module.h:257
Container of constant int that adds more constructors.
Definition: expr.h:601
Reference to PrimExprNode.
Definition: expr.h:126
Primitive function that contains TIR statements.
Definition: function.h:49
Managed reference to PrimFuncNode.
Definition: function.h:130
Container of all statements.
Definition: stmt.h:65
A variable node in the IR.
Definition: var.h:47
Definition: transform.h:400
IRModule that holds the functions and type definitions.
tvm::transform::Pass Pass
Definition: transform.h:35
tvm::transform::PassContext PassContext
Definition: transform.h:37
Pass VerifySSA()
Pass variant of VerifySSA.
Pass VerifyMemory()
Pass variant of VerifyMemory.
size_t CalculateExprComplexity(const PrimExpr &expr)
Calculate the expression complexity based on number of symbols it contains.
bool VerifyMemory(const PrimFunc &func)
Verify if memory accesses are legal for a specific target device type.
bool VerifyWellFormed(const PrimFunc &func, bool assert_mode=true)
Verify if the given TIR is well-formed. The verification includes:
const PrimFuncNode * FindEntryFunc(const IRModule &mod, GlobalVar *result_g_var)
Find the entry function of the given IRModule, i.e., functions marked by tirx::attr::kIsEntryFunc,...
bool VerifySSA(const PrimFunc &func)
Verifies whether the IR stmt or Expr is in SSA form. That is: each Var is defined and assigned once (i...
CallEffectKind SideEffect(const PrimExpr &expr)
Analyze the side effect of an expression.
size_t CalculateWorkspaceBytes(const PrimFunc &func, const Integer &workspace_byte_alignment)
Calculate the workspace size in bytes needed by the TIR allocations inside the TIR PrimFunc.
bool UsesVar(const Stmt &stmt, std::function< bool(const VarNode *)> vset_contains)
Whether the given Stmt uses any var in the given variable set.
ffi::Array< Var > UndefinedVars(const Stmt &stmt, const ffi::Array< Var > &defs)
Find undefined vars in the statement.
size_t CalculateConstantBytes(const PrimFunc &func, const Integer &constant_byte_alignment)
Calculate the constant size in bytes needed by the TIR allocations inside the TIR PrimFunc.
CallEffectKind
The effect type of the call.
Definition: op_attr_types.h:88
void VisitPrimFuncs(const IRModule &mod, FLambda fvisit)
Visit the PrimFuncs in the IRModule.
Definition: analysis.h:74
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
Analysis utilities for Schedulable TensorIR (S-TIR).
Compare two expressions recursively and check if they are equal to each other without var remapping.
Definition: analysis.h:62
bool operator()(const PrimExpr &lhs, const PrimExpr &rhs) const
Compilation target object.
TIR expressions.
TIR Function.
Attribute types in the Op registry for TIR ops.
TIR statements.