tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_ANALYSIS_H_
25 #define TVM_RELAX_ANALYSIS_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/ir/diagnostic.h>
29 #include <tvm/ir/module.h>
30 #include <tvm/relax/expr.h>
32 #include <tvm/relax/struct_info.h>
33 #include <tvm/tir/function.h>
34 #include <tvm/tir/index_map.h>
35 
36 #include <functional>
37 #include <utility>
38 
39 namespace tvm {
40 namespace relax {
41 //-----------------------------------
42 // Shape expression analysis
43 //----------------------------------
/*!
 * \brief Check whether two symbolic shape arrays can be proven equal to each other.
 * \param lhs The first shape, given as an array of PrimExpr.
 * \param rhs The second shape, given as an array of PrimExpr.
 * \param ana The arithmetic analyzer supplied by the caller (no default value).
 * \return A bool; per the summary above, whether equality could be proven.
 */
56 TVM_DLL bool CanProveShapeEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs,
57  arith::Analyzer* ana);
58 
/*!
 * \brief Overload of CanProveShapeEqual that takes two relax Expr operands.
 * NOTE(review): presumably lhs/rhs are shape-valued expressions compared the same
 * way as in the Array<PrimExpr> overload -- confirm against the implementation.
 * \param lhs The first expression.
 * \param rhs The second expression.
 * \param ana The arithmetic analyzer supplied by the caller (no default value).
 */
70 TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana);
71 
72 //-----------------------------------
73 // Foundational StructInfo analysis
74 //-----------------------------------
/*!
 * \brief Get the corresponding static type from a given struct info.
 * \param info The struct info to convert.
 * \return The corresponding static Type.
 */
80 TVM_DLL Type GetStaticType(const StructInfo& info);
81 
/*!
 * \brief Get the corresponding struct info from a static type.
 * \param type The static type to convert.
 * \return The corresponding StructInfo.
 */
87 TVM_DLL StructInfo StructInfoFromType(const Type& type);
88 
/*!
 * \brief NOTE(review): no generated summary is available for this declaration.
 * Presumably derives the StructInfo of the value returned by `call`, using the
 * callee signature described by `finfo` -- confirm against the implementation.
 * \param finfo The function struct info.
 * \param call The call expression.
 * \param ctx The block builder passed as context.
 * \param ana Optional arithmetic analyzer; defaults to nullptr.
 * \return A StructInfo.
 */
99 TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call,
100  const BlockBuilder& ctx, arith::Analyzer* ana = nullptr);
101 
158 TVM_DLL StructInfo
160  std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map = nullptr,
161  std::function<Optional<Expr>(const Var& var)> f_var_map = nullptr,
162  arith::Analyzer* ana = nullptr);
163 
178  Map<Var, Expr> var_map, arith::Analyzer* ana = nullptr);
179 
/*! \brief Fine grained result of base check. */
193 enum class BaseCheckResult {
197  kFailL0 = 0,  //!< The two value sets have no intersection at all: Intersect(LSet, RSet) = empty.
203  kFailL1 = 1,  //!< LSet is not a superset of RSet by only looking at static information.
222  kFailL2 = 2,  //!< LSet is not a superset of RSet because of a mismatch in value information.
224  kPass = 3  //!< LSet is a superset of RSet.
225 };
226 
/*!
 * \brief Run a base check to see if base subsumes derived.
 * \param base The base struct info.
 * \param derived The derived struct info.
 * \param ana Optional arithmetic analyzer; defaults to nullptr.
 * \return The fine grained check result, as a BaseCheckResult.
 */
239 TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived,
240  arith::Analyzer* ana = nullptr);
241 
/*!
 * \brief Check the relation of two struct info to see if one subsumes another one.
 * \param base The base struct info.
 * \param derived The derived struct info.
 * \param ana Optional arithmetic analyzer; defaults to nullptr.
 * \return A bool. NOTE(review): presumably true when base subsumes derived -- confirm.
 */
250 TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
251  arith::Analyzer* ana = nullptr);
252 
/*!
 * \brief Return the condition for which base is a superset of derived.
 * \param base The base struct info.
 * \param derived The derived struct info.
 * \return The condition, expressed as a PrimExpr.
 */
273 TVM_DLL PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const StructInfo& derived);
274 
/*!
 * \brief Unify the two struct info to their least common ancestor.
 * \param lhs The first struct info.
 * \param rhs The second struct info.
 * \param ana Optional arithmetic analyzer; defaults to nullptr.
 * \return The least common ancestor, as a StructInfo.
 */
283 TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
284  arith::Analyzer* ana = nullptr);
285 
293 
307 
320 
328 
/*!
 * \brief Get the TIR variables that are used but not defined in the input function.
 * The returned list is deduplicated.
 * \param expr The expression to inspect.
 * \return The list of TIR variables.
 */
335 TVM_DLL Array<tir::Var> FreeSymbolicVars(const Expr& expr);
336 //-----------------------------------
337 // General IR analysis
338 //-----------------------------------
/*!
 * \brief Get all bound variables from expression expr.
 * \param expr The expression to inspect.
 * \return The list of bound variables.
 */
349 TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
350 
/*!
 * \brief Get the free variables from expression expr.
 * NOTE(review): the generated summary for this function reads "free type
 * parameters", but the declared return type is Array<Var> -- confirm the wording.
 * \param expr The expression to inspect.
 * \return The list of variables.
 */
361 TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
362 
/*!
 * \brief Get all variables from expression expr.
 * \param expr The expression to inspect.
 * \return The list of variables.
 */
370 TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
371 
383 
408 
416 
/*!
 * \brief Analyze the var -> value mapping from VarBindings.
 * NOTE(review): this summary is the one generated for the IRModule overload;
 * confirm that it applies unchanged to this Expr overload.
 * \param expr The expression to analyze.
 * \return A map from Var to Expr.
 */
423 TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const Expr& expr);
424 
432 
440 
448 
/*!
 * \brief Get the use-def chain of variables inside a function.
 * \param expr The expression to analyze.
 * \return A pair: first, a map from variables to the downstream usages of each
 * variable; second, a list of the variables produced as output.
 * NOTE(review): unlike the neighbouring declarations, this one carries no
 * TVM_DLL export macro -- confirm whether that is intentional.
 */
460 std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Expr& expr);
461 
/*! \brief A utility struct returned by CollectVarUsage. */
464 struct VarUsageInfo {
465  /*! \brief A map from variables to the bound expression.
466  *
467  * This is equivalent to the output of AnalyzeVar2Value.
468  * Documents the member `bound_values` (Map<Var, Expr>), whose declaration line is elided from this listing.
468  */
470 
471  /*! \brief The map from variables to downstream usages of the variable.
472  *
473  * This is equivalent to the first output of FunctionUseDef.
474  * Documents the member `downstream_usage` (Map<Var, Array<Var>>), whose declaration line is elided from this listing.
474  */
476 
477  /*! \brief A list of variables produced as output.
478  *
479  * This is equivalent to the second output of FunctionUseDef.
480  * Documents the member `outputs` (Array<Var>), whose declaration line is elided from this listing.
480  */
482 };
483 
495 
/*!
 * \brief Remove unused statements inside DataflowBlocks.
 * \param expr The expression to process (taken by value).
 * \return The resulting expression.
 */
504 TVM_DLL Expr RemoveAllUnused(Expr expr);
505 
516 
/*!
 * \brief Check if the given PrimFunc is essentially doing a reshape operation.
 * \param func The TIR PrimFunc to inspect.
 * \return A bool; per the summary above, whether the function is essentially a reshape.
 */
530 TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
531 
/*!
 * \brief Check if the given expression (likely a function body) contains any impure calls.
 * \param expr The expression to inspect.
 * \param own_name Optional expression; defaults to an empty Optional.
 * NOTE(review): presumably the name by which the enclosing function refers to
 * itself, so recursive calls can be recognised -- confirm.
 * \return An Optional<Expr>. NOTE(review): presumably the impure call that was
 * found, or an empty Optional when there is none -- confirm.
 */
544 TVM_DLL Optional<Expr> FindImpureCall(const Expr& expr,
545  const Optional<Expr>& own_name = Optional<Expr>(nullptr));
546 
/*!
 * \brief Check if the given expression (likely a function body) contains any impure calls.
 * \param expr The expression to inspect.
 * \param own_name Optional expression; defaults to an empty Optional.
 * NOTE(review): presumably the name by which the enclosing function refers to
 * itself, so recursive calls can be recognised -- confirm.
 * \return A bool; per the summary above, whether an impure call is present.
 */
558 TVM_DLL bool ContainsImpureCall(const Expr& expr,
559  const Optional<Expr>& own_name = Optional<Expr>(nullptr));
560 
/*!
 * \brief Check if the given IRModule or Function is well formed.
 * \param obj The object to check: either an IRModule or a Function.
 * \param check_struct_info Defaults to true.
 * NOTE(review): presumably toggles the struct-info portion of the check -- confirm.
 * \return A bool; per the summary above, whether obj is well formed.
 */
572 TVM_DLL bool WellFormed(Variant<IRModule, Function> obj, bool check_struct_info = true);
573 
585  const Function& fn, Array<tir::IndexMap> write_buffer_transformations);
586 
587 /*! \brief Collect variables whose value can be computed at compile-time
588  *
589  * If a function has the `kNumInput` attribute, then the first
590  * `kNumInput` parameters are provided at run-time, while all
591  * remaining parameters may be known at compile-time. This utility
592  * collects all variable bindings that only depend, directly or
593  * indirectly, on the parameters known at compile-time.
594  *
595  * \param func The relax::Function to analyze
596  *
597  * \return The set of variables that can be computed at compile-time,
598  * in order of their occurrence within the function.
599  */
601 
602 } // namespace relax
603 } // namespace tvm
604 
605 #endif // TVM_RELAX_ANALYSIS_H_
Algebra expression simplifications.
Managed reference class to IRModuleNode.
Definition: module.h:249
Reference to PrimExprNode.
Definition: expr.h:115
Managed reference to RelaxExprNode.
Definition: expr.h:405
Managed reference to TypeNode.
Definition: type.h:93
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:629
Definition: block_builder.h:265
Definition: expr.h:191
Definition: expr.h:807
Managed reference to FuncStructInfoNode.
Definition: struct_info.h:362
Definition: expr.h:996
Managed reference to StructInfoNode.
Definition: expr.h:130
Definition: expr.h:423
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Definition: variant.h:69
Managed reference to PrimFuncNode.
Definition: function.h:144
a named variable in TIR
Definition: var.h:89
A new diagnostic interface for TVM error reporting.
Defines a remapping of buffer indices.
IRModule that holds the functions and type definitions.
BaseCheckResult
Fine grained result of base check.
Definition: analysis.h:193
@ kPass
LSet is superset of RSet.
@ kFailL2
LSet is not a superset of RSet because of a mismatch in value information.
@ kFailL1
LSet is not superset of RSet by only looking at static information.
@ kFailL0
The two value sets have no intersection at all: Intersect(LSet, RSet) = empty.
tvm::Array< Var > AllVars(const Expr &expr)
Get all variables from expression expr.
bool WellFormed(Variant< IRModule, Function > obj, bool check_struct_info=true)
Check if the IRModule is well formed.
BaseCheckResult StructInfoBaseCheck(const StructInfo &base, const StructInfo &derived, arith::Analyzer *ana=nullptr)
Run a base check to see if base subsumes derived.
StructInfo StructInfoLCA(const StructInfo &lhs, const StructInfo &rhs, arith::Analyzer *ana=nullptr)
Unify the two struct info to their least common ancestor.
Array< tir::Var > FreeSymbolicVars(const Expr &expr)
Get the TIR variables that are used but not defined in the input function. The returned list is dedup...
Array< tir::Var > TIRVarsInStructInfo(const StructInfo &sinfo)
Get the TIR variables that appear in the input struct info. The returned list is deduplicated - each ...
Expr RemoveAllUnused(Expr expr)
Remove unused statements inside DataflowBlocks.
Map< String, Array< Binding > > NameToBinding(const Function &fn)
Return a mapping from variable name to its Bindings.
tvm::Array< GlobalVar > AllGlobalVars(const Expr &expr)
Get all global variables from expression expr.
tvm::Array< tvm::Array< GlobalVar > > DetectRecursion(const IRModule &m)
Find all sets of recursive or mutually recursive functions in the module.
Map< tir::Block, Map< ObjectRef, tir::IndexMap > > SuggestLayoutTransforms(const Function &fn, Array< tir::IndexMap > write_buffer_transformations)
Using the layout transforms on the outputs, suggest layout transformation on the blocks and buffers f...
bool ContainsImpureCall(const Expr &expr, const Optional< Expr > &own_name=Optional< Expr >(nullptr))
Check if the given expression (likely a function body) contains any impure calls.
std::pair< Map< Var, Array< Var > >, Array< Var > > FunctionUseDef(const Expr &expr)
Get the use-def chain of variables inside a function.
Type GetStaticType(const StructInfo &info)
Get the corresponding static type from a given struct info.
StructInfo StructInfoFromType(const Type &type)
Get the corresponding struct info from static type.
bool CanProveShapeEqual(const Array< PrimExpr > &lhs, const Array< PrimExpr > &rhs, arith::Analyzer *ana)
Check whether the two symbolic shape arrays can be proven equal to each other.
tvm::Array< Var > BoundVars(const Expr &expr)
Get all bound variables from expression expr.
bool HasReshapePattern(const tir::PrimFunc &func)
Check if the given PrimFunc is essentially doing a reshape operation. The reshape operation also incl...
Array< tir::Var > DefinedSymbolicVars(const Expr &expr)
Get the TIR variables that are defined in the input function. The returned list is deduplicated - each TI...
tvm::Array< Var > FreeVars(const Expr &expr)
Get free variables from expression expr.
OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc &func)
Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
Array< PrimExpr > CollectNonNegativeExpressions(const StructInfo &sinfo)
Collect expressions whose usage requires them to be non-negative.
Array< tir::Var > DefinableTIRVarsInStructInfo(const StructInfo &sinfo)
Get the TIR variables that appear in the input struct info.
StructInfo DeriveCallRetStructInfo(const FuncStructInfo &finfo, const Call &call, const BlockBuilder &ctx, arith::Analyzer *ana=nullptr)
OpPatternKind
Definition: op_attr_types.h:34
Optional< Expr > FindImpureCall(const Expr &expr, const Optional< Expr > &own_name=Optional< Expr >(nullptr))
Check if the given expression (likely a function body) contains any impure calls.
bool IsBaseOf(const StructInfo &base, const StructInfo &derived, arith::Analyzer *ana=nullptr)
Check the relation of two struct info to see if one subsumes another one.
Array< Var > ComputableAtCompileTime(const Function &func)
StructInfo EraseToWellDefined(const StructInfo &info, std::function< Optional< PrimExpr >(const tir::Var &var)> f_shape_var_map=nullptr, std::function< Optional< Expr >(const Var &var)> f_var_map=nullptr, arith::Analyzer *ana=nullptr)
Erase the info to a corresponding more coarse grained struct info that is still well-defined(with all...
Map< Var, Array< Var > > DataflowBlockUseDef(const DataflowBlock &dfb)
Get the use-def chain of variables inside a dataflow block.
PrimExpr StructInfoBaseCheckPrecondition(const StructInfo &base, const StructInfo &derived)
Return the condition for which base is a superset of derived.
Map< Var, Expr > AnalyzeVar2Value(const IRModule &m)
Analyze var -> value mapping from VarBindings.
VarUsageInfo CollectVarUsage(const Expr &expr)
Collect variable bindings and usage.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Data structures that can appear in operator attributes.
A utility struct returned by CollectVarUsage.
Definition: analysis.h:464
Map< Var, Expr > bound_values
Definition: analysis.h:469
Array< Var > outputs
Definition: analysis.h:481
Map< Var, Array< Var > > downstream_usage
Definition: analysis.h:475
TIR Function.