tvm
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_ANALYSIS_H_
25 #define TVM_RELAX_ANALYSIS_H_
26 
#include <tvm/arith/analyzer.h>
#include <tvm/ir/diagnostic.h>
#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/function.h>

#include <functional>
#include <utility>
37 
38 namespace tvm {
39 namespace relax {
40 //-----------------------------------
41 // Shape expression analysis
42 //----------------------------------
55 TVM_DLL bool CanProveShapeEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs,
56  arith::Analyzer* ana);
57 
69 TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana);
70 
71 //-----------------------------------
72 // Foundational StructInfo analysis
73 //-----------------------------------
79 TVM_DLL Type GetStaticType(const StructInfo& info);
80 
86 TVM_DLL StructInfo StructInfoFromType(const Type& type);
87 
98 TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call,
99  const BlockBuilder& ctx, arith::Analyzer* ana = nullptr);
100 
157 TVM_DLL StructInfo
159  std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map = nullptr,
160  std::function<Optional<Expr>(const Var& var)> f_var_map = nullptr,
161  arith::Analyzer* ana = nullptr);
162 
177  Map<Var, Expr> var_map, arith::Analyzer* ana = nullptr);
178 
/*!
 * \brief Fine grained result of base check.
 *
 * The enumerators are ordered so that a larger value is closer to passing.
 * LSet presumably denotes the set of values described by the base struct info
 * and RSet the set described by the derived one (NOTE(review): confirm).
 */
enum class BaseCheckResult {
  /*! \brief The two value sets have no intersection at all: Intersect(LSet, RSet) = empty. */
  kFailL0 = 0,
  /*! \brief LSet is not a superset of RSet by only looking at static information. */
  kFailL1 = 1,
  /*! \brief LSet is not a superset of RSet because of mismatch in value information. */
  kFailL2 = 2,
  /*! \brief LSet is a superset of RSet. */
  kPass = 3
};
225 
238 TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived,
239  arith::Analyzer* ana = nullptr);
240 
249 TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
250  arith::Analyzer* ana = nullptr);
251 
272 TVM_DLL PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const StructInfo& derived);
273 
282 TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
283  arith::Analyzer* ana = nullptr);
284 
292 
306 
319 
327 
334 TVM_DLL Array<tir::Var> FreeSymbolicVars(const Expr& expr);
335 //-----------------------------------
336 // General IR analysis
337 //-----------------------------------
348 TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
349 
360 TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
361 
369 TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
370 
382 
407 
415 
422 TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const Expr& expr);
423 
431 
439 
447 
459 std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Expr& expr);
460 
463 struct VarUsageInfo {
464  /* \brief A map from variables to the bound expression.
465  *
466  * This is equivalent to the output of AnalyzeVar2Value
467  */
469 
470  /* \brief The map from variables to downstream usages of the variable
471  *
472  * This is equivalent to the first output of FunctionUseDef.
473  */
475 
476  /* \brief A list of variables produced as output
477  *
478  * This is equivalent to the second output of FunctionUseDef
479  */
481 };
482 
494 
503 TVM_DLL Expr RemoveAllUnused(Expr expr);
504 
515 
529 TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
530 
543 TVM_DLL Optional<Expr> FindImpureCall(const Expr& expr,
544  const Optional<Expr>& own_name = Optional<Expr>(nullptr));
545 
557 TVM_DLL bool ContainsImpureCall(const Expr& expr,
558  const Optional<Expr>& own_name = Optional<Expr>(nullptr));
559 
571 TVM_DLL bool WellFormed(Variant<IRModule, Function> obj, bool check_struct_info = true);
572 
584  const Function& fn, Array<tir::IndexMap> write_buffer_transformations);
585 
586 /* \brief Collect variables whose value can be computed at compile-time
587  *
588  * If a function has the `kNumInput` attribute, then the first
589  * `kNumInput` parameters are provided at run-time, while all
590  * remaining parameters may be known at compile-time. This utility
591  * collects all variable bindings that only depend, directly or
592  * indirectly, on the parameters known at compile-time.
593  *
594  * \param func The relax::Function to analyze
595  *
596  * \return The set of variables that can be computed at compile-time,
597  * in order of their occurrence within the function.
598  */
600 
601 } // namespace relax
602 } // namespace tvm
603 
604 #endif // TVM_RELAX_ANALYSIS_H_
Algebra expression simplifications.
Managed reference class to IRModuleNode.
Definition: module.h:366
Reference to PrimExprNode.
Definition: expr.h:115
Managed reference to RelayExprNode.
Definition: expr.h:442
Managed reference to TypeNode.
Definition: type.h:93
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:629
Definition: block_builder.h:264
Definition: expr.h:190
Definition: expr.h:806
Managed reference to FuncStructInfoNode.
Definition: struct_info.h:362
Definition: expr.h:995
Managed reference to StructInfoNode.
Definition: expr.h:129
Definition: expr.h:422
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Definition: variant.h:69
Managed reference to PrimFuncNode.
Definition: function.h:145
a named variable in TIR
Definition: var.h:89
A new diagnostic interface for TVM error reporting.
IRModule that holds the functions and type definitions.
BaseCheckResult
Fine grained result of base check.
Definition: analysis.h:192
@ kPass
LSet is a superset of RSet.
@ kFailL2
LSet is not a superset of RSet because of mismatch in value information.
@ kFailL1
LSet is not a superset of RSet by only looking at static information.
@ kFailL0
The two value sets have no intersection at all: Intersect(LSet, RSet) = empty.
tvm::Array< Var > AllVars(const Expr &expr)
Get all variables from expression expr.
bool WellFormed(Variant< IRModule, Function > obj, bool check_struct_info=true)
Check if the IRModule is well formed.
BaseCheckResult StructInfoBaseCheck(const StructInfo &base, const StructInfo &derived, arith::Analyzer *ana=nullptr)
Run a base check to see if base subsumes derived.
StructInfo StructInfoLCA(const StructInfo &lhs, const StructInfo &rhs, arith::Analyzer *ana=nullptr)
Unify the two struct info to their least common ancestor.
Array< tir::Var > FreeSymbolicVars(const Expr &expr)
Get the TIR variables that are used but not defined in the input function. The returned list is dedup...
Array< tir::Var > TIRVarsInStructInfo(const StructInfo &sinfo)
Get the TIR variables that appear in the input struct info. The returned list is deduplicated - each ...
Expr RemoveAllUnused(Expr expr)
Remove unused statements inside DataflowBlocks.
Map< String, Array< Binding > > NameToBinding(const Function &fn)
Return a mapping from variable name to its Bindings.
tvm::Array< GlobalVar > AllGlobalVars(const Expr &expr)
Get all global variables from expression expr.
tvm::Array< tvm::Array< GlobalVar > > DetectRecursion(const IRModule &m)
Find all sets of recursive or mutually recursive functions in the module.
Map< tir::Block, Map< ObjectRef, tir::IndexMap > > SuggestLayoutTransforms(const Function &fn, Array< tir::IndexMap > write_buffer_transformations)
Using the layout transforms on the outputs, suggest layout transformation on the blocks and buffers f...
bool ContainsImpureCall(const Expr &expr, const Optional< Expr > &own_name=Optional< Expr >(nullptr))
Check if the given expression (likely a function body) contains any impure calls.
std::pair< Map< Var, Array< Var > >, Array< Var > > FunctionUseDef(const Expr &expr)
Get the use-def chain of variables inside a function.
Type GetStaticType(const StructInfo &info)
Get the corresponding static type from a given struct info.
StructInfo StructInfoFromType(const Type &type)
Get the corresponding struct info from static type.
bool CanProveShapeEqual(const Array< PrimExpr > &lhs, const Array< PrimExpr > &rhs, arith::Analyzer *ana)
Can prove the two symbolic shape arrays equal to each other.
tvm::Array< Var > BoundVars(const Expr &expr)
Get all bound variables from expression expr.
bool HasReshapePattern(const tir::PrimFunc &func)
Check if the given PrimFunc is essentially doing a reshape operation. The reshape operation also incl...
Array< tir::Var > DefinedSymbolicVars(const Expr &expr)
Get the TIR variables that are defined in the input function. The returned list is deduplicated - each TI...
tvm::Array< Var > FreeVars(const Expr &expr)
Get free variables from expression expr.
Array< PrimExpr > CollectNonNegativeExpressions(const StructInfo &sinfo)
Collect expressions whose usage requires them to be non-negative.
Array< tir::Var > DefinableTIRVarsInStructInfo(const StructInfo &sinfo)
Get the TIR variables that appear in the input struct info.
StructInfo DeriveCallRetStructInfo(const FuncStructInfo &finfo, const Call &call, const BlockBuilder &ctx, arith::Analyzer *ana=nullptr)
Optional< Expr > FindImpureCall(const Expr &expr, const Optional< Expr > &own_name=Optional< Expr >(nullptr))
Check if the given expression (likely a function body) contains any impure calls.
bool IsBaseOf(const StructInfo &base, const StructInfo &derived, arith::Analyzer *ana=nullptr)
Check the relation of two struct info to see if one subsumes another one.
Array< Var > ComputableAtCompileTime(const Function &func)
StructInfo EraseToWellDefined(const StructInfo &info, std::function< Optional< PrimExpr >(const tir::Var &var)> f_shape_var_map=nullptr, std::function< Optional< Expr >(const Var &var)> f_var_map=nullptr, arith::Analyzer *ana=nullptr)
Erase the info to a corresponding, more coarse-grained struct info that is still well-defined (with all...
relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc &func)
Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
Map< Var, Array< Var > > DataflowBlockUseDef(const DataflowBlock &dfb)
Get the use-def chain of variables inside a dataflow block.
PrimExpr StructInfoBaseCheckPrecondition(const StructInfo &base, const StructInfo &derived)
Return the condition for which base is a superset of derived.
Map< Var, Expr > AnalyzeVar2Value(const IRModule &m)
Analyze var -> value mapping from VarBindings.
VarUsageInfo CollectVarUsage(const Expr &expr)
Collect variable bindings and usage.
OpPatternKind
operator pattern used in graph fusion
Definition: op_attr_types.h:45
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
The Expr and related elements in DataFlow construction.
A utility struct returned by CollectVarUsage.
Definition: analysis.h:463
Map< Var, Expr > bound_values
Definition: analysis.h:468
Array< Var > outputs
Definition: analysis.h:480
Map< Var, Array< Var > > downstream_usage
Definition: analysis.h:474
TIR Function.