tvm
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_ANALYSIS_H_
25 #define TVM_RELAX_ANALYSIS_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/ir/diagnostic.h>
29 #include <tvm/ir/module.h>
30 #include <tvm/relax/expr.h>
32 #include <tvm/relax/struct_info.h>
33 #include <tvm/tir/function.h>
34 #include <tvm/tir/index_map.h>
35 
36 #include <functional>
37 #include <utility>
38 
39 namespace tvm {
40 namespace relax {
41 //-----------------------------------
42 // Shape expression analysis
43 //----------------------------------
56 TVM_DLL bool CanProveShapeEqual(const ffi::Array<PrimExpr>& lhs, const ffi::Array<PrimExpr>& rhs,
57  arith::Analyzer* ana);
58 
70 TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana);
71 
72 //-----------------------------------
73 // Foundational StructInfo analysis
74 //-----------------------------------
80 TVM_DLL Type GetStaticType(const StructInfo& info);
81 
87 TVM_DLL StructInfo StructInfoFromType(const Type& type);
88 
99 TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call,
100  const BlockBuilder& ctx, arith::Analyzer* ana = nullptr);
101 
159  const StructInfo& info,
160  std::function<ffi::Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map = nullptr,
161  std::function<ffi::Optional<Expr>(const Var& var)> f_var_map = nullptr,
162  arith::Analyzer* ana = nullptr);
163 
178  ffi::Map<tir::Var, PrimExpr> shape_var_map,
179  ffi::Map<Var, Expr> var_map, arith::Analyzer* ana = nullptr);
180 
194 enum class BaseCheckResult {
198  kFailL0 = 0,
204  kFailL1 = 1,
223  kFailL2 = 2,
225  kPass = 3
226 };
227 
240 TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived,
241  arith::Analyzer* ana = nullptr);
242 
251 TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
252  arith::Analyzer* ana = nullptr);
253 
274 TVM_DLL PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const StructInfo& derived);
275 
284 TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
285  arith::Analyzer* ana = nullptr);
286 
293 TVM_DLL ffi::Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo);
294 
307 TVM_DLL ffi::Array<tir::Var> DefinableTIRVarsInStructInfo(const StructInfo& sinfo);
308 
320 TVM_DLL ffi::Array<PrimExpr> CollectNonNegativeExpressions(const StructInfo& sinfo);
321 
328 TVM_DLL ffi::Array<tir::Var> DefinedSymbolicVars(const Expr& expr);
329 
336 TVM_DLL ffi::Array<tir::Var> FreeSymbolicVars(const Expr& expr);
337 //-----------------------------------
338 // General IR analysis
339 //-----------------------------------
350 TVM_DLL tvm::ffi::Array<Var> BoundVars(const Expr& expr);
351 
362 TVM_DLL tvm::ffi::Array<Var> FreeVars(const Expr& expr);
363 
371 TVM_DLL tvm::ffi::Array<Var> AllVars(const Expr& expr);
372 
383 TVM_DLL tvm::ffi::Array<GlobalVar> AllGlobalVars(const Expr& expr);
384 
408 TVM_DLL tvm::ffi::Array<tvm::ffi::Array<GlobalVar>> DetectRecursion(const IRModule& m);
409 
416 TVM_DLL ffi::Map<Var, Expr> AnalyzeVar2Value(const IRModule& m);
417 
424 TVM_DLL ffi::Map<Var, Expr> AnalyzeVar2Value(const Expr& expr);
425 
432 TVM_DLL ffi::Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb);
433 
440 TVM_DLL ffi::Map<ffi::String, ffi::Array<Binding>> NameToBinding(const Function& fn);
441 
448 TVM_DLL ffi::Map<Var, ffi::Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb);
449 
461 std::pair<ffi::Map<Var, ffi::Array<Var>>, ffi::Array<Var>> FunctionUseDef(const Expr& expr);
462 
465 struct VarUsageInfo {
466  /* \brief A map from variables to the bound expression.
467  *
468  * This is equivalent to the output of AnalyzeVar2Value
469  */
470  ffi::Map<Var, Expr> bound_values;
471 
472  /* \brief The map from variables to downstream usages of the variable
473  *
474  * This is equivalent to the first output of FunctionUseDef.
475  */
476  ffi::Map<Var, ffi::Array<Var>> downstream_usage;
477 
478  /* \brief A list of variables produced as output
479  *
480  * This is equivalent to the second output of FunctionUseDef
481  */
482  ffi::Array<Var> outputs;
483 };
484 
496 
505 TVM_DLL Expr RemoveAllUnused(Expr expr);
506 
517 
531 TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
532 
545 TVM_DLL ffi::Optional<Expr> FindImpureCall(
546  const Expr& expr, const ffi::Optional<Expr>& own_name = ffi::Optional<Expr>(std::nullopt));
547 
559 TVM_DLL bool ContainsImpureCall(
560  const Expr& expr, const ffi::Optional<Expr>& own_name = ffi::Optional<Expr>(std::nullopt));
561 
573 TVM_DLL bool WellFormed(ffi::Variant<IRModule, Function> obj, bool check_struct_info = true);
574 
585 TVM_DLL ffi::Map<tir::Block, ffi::Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
586  const Function& fn, ffi::Array<tir::IndexMap> write_buffer_transformations);
587 
588 /* \brief Collect variables whose value can be computed at compile-time
589  *
590  * If a function has the `kNumInput` attribute, then the first
591  * `kNumInput` parameters are provided at run-time, while all
592  * remaining parameters may be known at compile-time. This utility
593  * collects all variable bindings that only depend, directly or
594  * indirectly, on the parameters known at compile-time.
595  *
596  * \param func The relax::Function to analyze
597  *
598  * \return The set of variables that can be computed at compile-time,
599  * in order of their occurrence within the function.
600  */
601 TVM_DLL ffi::Array<Var> ComputableAtCompileTime(const Function& func);
602 
603 } // namespace relax
604 } // namespace tvm
605 
606 #endif // TVM_RELAX_ANALYSIS_H_
Algebra expression simplifications.
Managed reference class to IRModuleNode.
Definition: module.h:256
Reference to PrimExprNode.
Definition: expr.h:124
Managed reference to RelaxExprNode.
Definition: expr.h:439
Managed reference to TypeNode.
Definition: type.h:100
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:634
Definition: block_builder.h:264
Definition: expr.h:176
Definition: expr.h:693
Managed reference to FuncStructInfoNode.
Definition: struct_info.h:308
Definition: expr.h:830
Managed reference to StructInfoNode.
Definition: expr.h:132
Definition: expr.h:377
Managed reference to PrimFuncNode.
Definition: function.h:129
a named variable in TIR
Definition: var.h:77
A new diagnostic interface for TVM error reporting.
Defines a remapping of buffer indices.
IRModule that holds the functions and type definitions.
BaseCheckResult
Fine grained result of base check.
Definition: analysis.h:194
@ kPass
LSet is a superset of RSet.
@ kFailL2
LSet is not a superset of RSet because of a mismatch in value information.
@ kFailL1
LSet is not a superset of RSet, judging only by static information.
@ kFailL0
The two value sets have no intersection at all: Intersect(LSet, RSet) = empty.
ffi::Optional< Expr > FindImpureCall(const Expr &expr, const ffi::Optional< Expr > &own_name=ffi::Optional< Expr >(std::nullopt))
Find an impure call within the given expression (likely a function body), returning it if one exists.
ffi::Array< tir::Var > FreeSymbolicVars(const Expr &expr)
Get the TIR variables that are used but not defined in the input function. The returned list is dedup...
BaseCheckResult StructInfoBaseCheck(const StructInfo &base, const StructInfo &derived, arith::Analyzer *ana=nullptr)
Run a base check to see if base subsumes derived.
StructInfo StructInfoLCA(const StructInfo &lhs, const StructInfo &rhs, arith::Analyzer *ana=nullptr)
Unify the two struct info to their least common ancestor.
bool ContainsImpureCall(const Expr &expr, const ffi::Optional< Expr > &own_name=ffi::Optional< Expr >(std::nullopt))
Check if the given expression (likely a function body) contains any impure calls.
Expr RemoveAllUnused(Expr expr)
Remove unused statements inside DataflowBlocks.
ffi::Array< tir::Var > DefinableTIRVarsInStructInfo(const StructInfo &sinfo)
Get the TIR variables that appear in the input struct info.
ffi::Map< tir::Block, ffi::Map< ObjectRef, tir::IndexMap > > SuggestLayoutTransforms(const Function &fn, ffi::Array< tir::IndexMap > write_buffer_transformations)
Using the layout transforms on the outputs, suggest layout transformations on the blocks and buffers f...
bool CanProveShapeEqual(const ffi::Array< PrimExpr > &lhs, const ffi::Array< PrimExpr > &rhs, arith::Analyzer *ana)
Check whether the two symbolic shape arrays can be proven equal to each other.
ffi::Map< Var, Expr > AnalyzeVar2Value(const IRModule &m)
Analyze var -> value mapping from VarBindings.
bool WellFormed(ffi::Variant< IRModule, Function > obj, bool check_struct_info=true)
Check if the IRModule or Function is well formed.
Type GetStaticType(const StructInfo &info)
Get the corresponding static type from a given struct info.
StructInfo EraseToWellDefined(const StructInfo &info, std::function< ffi::Optional< PrimExpr >(const tir::Var &var)> f_shape_var_map=nullptr, std::function< ffi::Optional< Expr >(const Var &var)> f_var_map=nullptr, arith::Analyzer *ana=nullptr)
Erase the info to a corresponding more coarse-grained struct info that is still well-defined (with all...
StructInfo StructInfoFromType(const Type &type)
Get the corresponding struct info from static type.
ffi::Array< Var > ComputableAtCompileTime(const Function &func)
ffi::Array< PrimExpr > CollectNonNegativeExpressions(const StructInfo &sinfo)
Collect expressions whose usage requires them to be non-negative.
ffi::Array< tir::Var > TIRVarsInStructInfo(const StructInfo &sinfo)
Get the TIR variables that appear in the input struct info. The returned list is deduplicated - each ...
bool HasReshapePattern(const tir::PrimFunc &func)
Check if the given PrimFunc is essentially doing a reshape operation. The reshape operation also incl...
tvm::ffi::Array< Var > FreeVars(const Expr &expr)
Get the free variables from expression expr.
ffi::Map< Var, ffi::Array< Var > > DataflowBlockUseDef(const DataflowBlock &dfb)
Get the use-def chain of variables inside a dataflow block.
OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc &func)
Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
ffi::Map< ffi::String, ffi::Array< Binding > > NameToBinding(const Function &fn)
Return a mapping from variable name to its Bindings.
tvm::ffi::Array< GlobalVar > AllGlobalVars(const Expr &expr)
Get all global variables from expression expr.
ffi::Array< tir::Var > DefinedSymbolicVars(const Expr &expr)
Get the TIR variables that are defined in the input function. The returned list is deduplicated - each TI...
StructInfo DeriveCallRetStructInfo(const FuncStructInfo &finfo, const Call &call, const BlockBuilder &ctx, arith::Analyzer *ana=nullptr)
OpPatternKind
Definition: op_attr_types.h:34
tvm::ffi::Array< Var > AllVars(const Expr &expr)
Get all variables from expression expr.
bool IsBaseOf(const StructInfo &base, const StructInfo &derived, arith::Analyzer *ana=nullptr)
Check the relation of two struct info to see if one subsumes another one.
tvm::ffi::Array< Var > BoundVars(const Expr &expr)
Get all bound variables from expression expr.
PrimExpr StructInfoBaseCheckPrecondition(const StructInfo &base, const StructInfo &derived)
Return the condition for which base is a superset of derived.
tvm::ffi::Array< tvm::ffi::Array< GlobalVar > > DetectRecursion(const IRModule &m)
Find all sets of recursive or mutually recursive functions in the module.
std::pair< ffi::Map< Var, ffi::Array< Var > >, ffi::Array< Var > > FunctionUseDef(const Expr &expr)
Get the use-def chain of variables inside a function.
VarUsageInfo CollectVarUsage(const Expr &expr)
Collect variable bindings and usage.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Data structures that can appear in operator attributes.
A utility struct returned by CollectVarUsage.
Definition: analysis.h:465
ffi::Array< Var > outputs
Definition: analysis.h:482
ffi::Map< Var, ffi::Array< Var > > downstream_usage
Definition: analysis.h:476
ffi::Map< Var, Expr > bound_values
Definition: analysis.h:470
TIR Function.