tvm
analysis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_ANALYSIS_H_
25 #define TVM_RELAX_ANALYSIS_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/ir/diagnostic.h>
29 #include <tvm/ir/module.h>
30 #include <tvm/relax/expr.h>
31 #include <tvm/relax/op_attr_types.h>
32 #include <tvm/relax/struct_info.h>
33 #include <tvm/tir/function.h>
34 #include <tvm/tir/index_map.h>
35 
36 #include <functional>
37 #include <set>
38 #include <utility>
39 
40 namespace tvm {
41 namespace relax {
42 //-----------------------------------
43 // Shape expression analysis
44 //----------------------------------
/*!
 * \brief Check whether two symbolic shape arrays can be proven equal to each other.
 * \param lhs The left-hand-side shape, as an array of PrimExpr.
 * \param rhs The right-hand-side shape, as an array of PrimExpr.
 * \param ana The arithmetic analyzer used to attempt the proof.
 * \return Whether the two shapes can be proven equal.
 */
57 TVM_DLL bool CanProveShapeEqual(const ffi::Array<PrimExpr>& lhs, const ffi::Array<PrimExpr>& rhs,
58  arith::Analyzer* ana);
59 
/*!
 * \brief Overload of CanProveShapeEqual that takes the two shapes as relax expressions.
 * \param lhs The left-hand-side shape expression.
 * \param rhs The right-hand-side shape expression.
 * \param ana The arithmetic analyzer used to attempt the proof.
 * \return Whether the two shapes can be proven equal.
 * NOTE(review): presumably returns false when either expression is not a known
 * shape value -- confirm against the implementation.
 */
71 TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana);
72 
73 //-----------------------------------
74 // Foundational StructInfo analysis
75 //-----------------------------------
/*!
 * \brief Get the corresponding static type from a given struct info.
 * \param info The struct info.
 * \return The static type.
 */
81 TVM_DLL Type GetStaticType(const StructInfo& info);
82 
/*!
 * \brief Get the corresponding struct info from a static type.
 * \param type The static type.
 * \return The struct info.
 */
88 TVM_DLL StructInfo StructInfoFromType(const Type& type);
89 
/*!
 * \brief Derive the struct info of the value returned by `call`, given the
 *        callee's function struct info `finfo`.
 * \param finfo The struct info of the callee.
 * \param call The call expression.
 * \param ctx The block builder context.
 * \param ana Optional arithmetic analyzer (defaults to nullptr).
 * \return The derived return struct info.
 * NOTE(review): summary inferred from the name and signature; how `ctx` is
 * used is not visible in this header -- confirm.
 */
100 TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call,
101  const BlockBuilder& ctx, arith::Analyzer* ana = nullptr);
102 
/*!
 * \brief Erase the info to a corresponding more coarse grained struct info
 *        that is still well-defined.
 * \param info The struct info to be erased.
 * \param f_shape_var_map Optional callback that looks up a PrimExpr for a TIR
 *        variable (defaults to nullptr).
 * \param f_var_map Optional callback that looks up an Expr for a relax Var
 *        (defaults to nullptr).
 * \param ana Optional arithmetic analyzer (defaults to nullptr).
 * \return The erased struct info.
 * NOTE(review): presumably variables for which the callbacks yield no value are
 * treated as undefined and erased -- confirm against the implementation.
 */
159 TVM_DLL StructInfo EraseToWellDefined(
160  const StructInfo& info,
161  std::function<ffi::Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map = nullptr,
162  std::function<ffi::Optional<Expr>(const Var& var)> f_var_map = nullptr,
163  arith::Analyzer* ana = nullptr);
164 
/*!
 * \brief Overload of EraseToWellDefined that takes explicit variable maps
 *        instead of lookup callbacks.
 * \param info The struct info to be erased.
 * \param shape_var_map Map from TIR variables to their values.
 * \param var_map Map from relax variables to their values.
 * \param ana Optional arithmetic analyzer (defaults to nullptr).
 * \return The erased struct info.
 */
178 TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info,
179  ffi::Map<tir::Var, PrimExpr> shape_var_map,
180  ffi::Map<Var, Expr> var_map, arith::Analyzer* ana = nullptr);
181 
/*!
 * \brief Fine grained result of base check, as returned by StructInfoBaseCheck.
 * NOTE(review): presumably LSet is the value set described by `base` and RSet
 * the one described by `derived` -- confirm.
 */
195 enum class BaseCheckResult {
  /*! \brief The two value sets have no intersection at all: Intersect(LSet, RSet) = empty. */
199  kFailL0 = 0,
  /*! \brief LSet is not superset of RSet by only looking at static information. */
205  kFailL1 = 1,
  /*! \brief LSet is not superset of RSet because of mismatch in value information. */
224  kFailL2 = 2,
  /*! \brief LSet is superset of RSet. */
226  kPass = 3
227 };
228 
/*!
 * \brief Run a base check to see if base subsumes derived.
 * \param base The base struct info.
 * \param derived The derived struct info.
 * \param ana Optional arithmetic analyzer (defaults to nullptr).
 * \return The fine grained BaseCheckResult.
 */
241 TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived,
242  arith::Analyzer* ana = nullptr);
243 
/*!
 * \brief Check the relation of two struct info to see if one subsumes another one.
 * \param base The base struct info.
 * \param derived The derived struct info.
 * \param ana Optional arithmetic analyzer (defaults to nullptr).
 * \return Whether base subsumes derived.
 */
252 TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
253  arith::Analyzer* ana = nullptr);
254 
/*!
 * \brief Return the condition for which base is a superset of derived.
 * \param base The base struct info.
 * \param derived The derived struct info.
 * \return The condition, as a PrimExpr.
 */
275 TVM_DLL PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const StructInfo& derived);
276 
/*!
 * \brief Unify the two struct info to their least common ancestor.
 * \param lhs The left operand.
 * \param rhs The right operand.
 * \param ana Optional arithmetic analyzer (defaults to nullptr).
 * \return The unified struct info.
 */
285 TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
286  arith::Analyzer* ana = nullptr);
287 
/*!
 * \brief Get the TIR variables that appear in the input struct info.
 *        The returned list is deduplicated.
 */
294 TVM_DLL ffi::Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo);
295 
/*!
 * \brief Get the TIR variables in the input struct info that can be defined from it.
 * NOTE(review): wording inferred from the name "Definable"; presumably a subset of
 * TIRVarsInStructInfo -- confirm.
 */
308 TVM_DLL ffi::Array<tir::Var> DefinableTIRVarsInStructInfo(const StructInfo& sinfo);
309 
/*! \brief Collect expressions whose usage requires them to be non-negative. */
321 TVM_DLL ffi::Array<PrimExpr> CollectNonNegativeExpressions(const StructInfo& sinfo);
322 
/*!
 * \brief Get the TIR variables that are defined in the input function.
 *        The returned list is deduplicated.
 */
329 TVM_DLL ffi::Array<tir::Var> DefinedSymbolicVars(const Expr& expr);
330 
/*!
 * \brief Get the TIR variables that are used but not defined in the input function.
 *        The returned list is deduplicated.
 */
337 TVM_DLL ffi::Array<tir::Var> FreeSymbolicVars(const Expr& expr);
338 //-----------------------------------
339 // General IR analysis
340 //-----------------------------------
/*! \brief Get all bound variables from expression expr. */
351 TVM_DLL tvm::ffi::Array<Var> BoundVars(const Expr& expr);
352 
/*! \brief Get free variables (used but not bound) from expression expr. */
363 TVM_DLL tvm::ffi::Array<Var> FreeVars(const Expr& expr);
364 
/*! \brief Get all variables from expression expr. */
372 TVM_DLL tvm::ffi::Array<Var> AllVars(const Expr& expr);
373 
/*! \brief Get all global variables from expression expr. */
384 TVM_DLL tvm::ffi::Array<GlobalVar> AllGlobalVars(const Expr& expr);
385 
/*!
 * \brief Find all sets of recursive or mutually recursive functions in the module.
 * \return One inner array of GlobalVars per recursive group.
 */
409 TVM_DLL tvm::ffi::Array<tvm::ffi::Array<GlobalVar>> DetectRecursion(const IRModule& m);
410 
/*! \brief Analyze var -> value mapping from VarBindings in a module. */
417 TVM_DLL ffi::Map<Var, Expr> AnalyzeVar2Value(const IRModule& m);
418 
/*! \brief Analyze var -> value mapping from VarBindings in an expression. */
425 TVM_DLL ffi::Map<Var, Expr> AnalyzeVar2Value(const Expr& expr);
426 
/*! \brief Analyze var -> value mapping from VarBindings in a dataflow block. */
433 TVM_DLL ffi::Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb);
434 
/*!
 * \brief Return a mapping from variable name to its Bindings.
 *        Several bindings may share one name, hence the array value.
 */
441 TVM_DLL ffi::Map<ffi::String, ffi::Array<Binding>> NameToBinding(const Function& fn);
442 
/*! \brief Get the use-def chain of variables inside a dataflow block. */
449 TVM_DLL ffi::Map<Var, ffi::Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb);
450 
/*!
 * \brief Get the use-def chain of variables inside a function.
 * \return A pair whose first element maps each variable to its downstream
 *         usages, and whose second element lists the variables produced as output.
 */
// Exported with TVM_DLL like every other analysis entry point in this header.
462 TVM_DLL std::pair<ffi::Map<Var, ffi::Array<Var>>, ffi::Array<Var>> FunctionUseDef(const Expr& expr);
463 
/*! \brief A utility struct returned by CollectVarUsage. */
466 struct VarUsageInfo {
467  /*! \brief A map from variables to the bound expression.
468  *
469  * This is equivalent to the output of AnalyzeVar2Value.
470  */
471  ffi::Map<Var, Expr> bound_values;
472 
473  /*! \brief The map from variables to downstream usages of the variable.
474  *
475  * This is equivalent to the first output of FunctionUseDef.
476  */
477  ffi::Map<Var, ffi::Array<Var>> downstream_usage;
478 
479  /*! \brief A list of variables produced as output.
480  *
481  * This is equivalent to the second output of FunctionUseDef.
482  */
483  ffi::Array<Var> outputs;
484 };
485 
/*!
 * \brief Collect variable bindings and usage.
 * \param expr The expression to analyze.
 * \return A VarUsageInfo holding the bound values, downstream usages and outputs.
 */
496 TVM_DLL VarUsageInfo CollectVarUsage(const Expr& expr);
497
/*!
 * \brief Get the used variables in an expression.
 * NOTE(review): the returned VarNode pointers are non-owning; presumably they are
 * only valid while `expr` keeps the nodes alive -- confirm.
 */
507 TVM_DLL std::set<const VarNode*> GetUsedVars(const Expr& expr);
508 
/*! \brief Remove unused statements inside DataflowBlocks. */
517 TVM_DLL Expr RemoveAllUnused(Expr expr);
518 
/*!
 * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
 * \param func The PrimFunc to be analyzed.
 * \return The Op Pattern Kind.
 */
528 TVM_DLL OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);
529
/*!
 * \brief Check if the given PrimFunc is essentially doing a reshape operation.
 * \param func The PrimFunc to be checked.
 * \return Whether the PrimFunc has a reshape pattern.
 */
543 TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
544 
/*!
 * \brief Find an impure call in the given expression (likely a function body), if any.
 * \param expr The expression to search.
 * \param own_name Optional expression (defaults to nullopt).
 *        NOTE(review): presumably a name to exclude from the search, e.g. the
 *        function's own name for recursive calls -- confirm.
 * \return An impure call found in `expr`, or nullopt if there is none.
 */
557 TVM_DLL ffi::Optional<Expr> FindImpureCall(
558  const Expr& expr, const ffi::Optional<Expr>& own_name = ffi::Optional<Expr>(std::nullopt));
559 
/*!
 * \brief Check if the given expression (likely a function body) contains any impure calls.
 * \param expr The expression to check.
 * \param own_name Optional expression with the same meaning as in FindImpureCall
 *        (defaults to nullopt).
 * \return Whether an impure call is present.
 */
571 TVM_DLL bool ContainsImpureCall(
572  const Expr& expr, const ffi::Optional<Expr>& own_name = ffi::Optional<Expr>(std::nullopt));
573 
/*!
 * \brief Check if the IRModule or Function is well formed.
 * \param obj The IRModule or Function to check.
 * \param check_struct_info Whether to also check struct info (defaults to true).
 * \return Whether `obj` is well formed.
 */
585 TVM_DLL bool WellFormed(ffi::Variant<IRModule, Function> obj, bool check_struct_info = true);
586 
/*!
 * \brief Using the layout transforms on the outputs, suggest layout
 *        transformations on the blocks and buffers used by `fn`.
 * \param fn The function to analyze.
 * \param write_buffer_transformations The layout transforms applied to the outputs.
 * \return For each tir::Block, a map from object (presumably buffer or block) to
 *         the suggested IndexMap -- NOTE(review): key meaning to be confirmed.
 */
597 TVM_DLL ffi::Map<tir::Block, ffi::Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
598  const Function& fn, ffi::Array<tir::IndexMap> write_buffer_transformations);
599 
600 /*! \brief Collect variables whose value can be computed at compile-time.
601  *
602  * If a function has the `kNumInput` attribute, then the first
603  * `kNumInput` parameters are provided at run-time, while all
604  * remaining parameters may be known at compile-time. This utility
605  * collects all variable bindings that only depend, directly or
606  * indirectly, on the parameters known at compile-time.
607  *
608  * \param func The relax::Function to analyze.
609  *
610  * \return The set of variables that can be computed at compile-time,
611  * in order of their occurrence within the function.
612  */
613 TVM_DLL ffi::Array<Var> ComputableAtCompileTime(const Function& func);
614 
615 } // namespace relax
616 } // namespace tvm
617 
618 #endif // TVM_RELAX_ANALYSIS_H_
Algebra expression simplifications.
Managed reference class to IRModuleNode.
Definition: module.h:256
Reference to PrimExprNode.
Definition: expr.h:124
Managed reference to RelaxExprNode.
Definition: expr.h:439
Managed reference to TypeNode.
Definition: type.h:100
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:634
Definition: block_builder.h:264
Definition: expr.h:180
Definition: expr.h:695
Managed reference to FuncStructInfoNode.
Definition: struct_info.h:308
Definition: expr.h:832
Managed reference to StructInfoNode.
Definition: expr.h:132
Definition: expr.h:380
Managed reference to PrimFuncNode.
Definition: function.h:129
a named variable in TIR
Definition: var.h:77
A new diagnostic interface for TVM error reporting.
Defines a remapping of buffer indices.
IRModule that holds the functions and type definitions.
BaseCheckResult
Fine grained result of base check.
Definition: analysis.h:195
@ kPass
LSet is superset of RSet.
@ kFailL2
LSet is not superset of RSet because of mismatch in value information.
@ kFailL1
LSet is not superset of RSet by only looking at static information.
@ kFailL0
The two value sets have no intersection at all: Intersect(LSet, RSet) = empty.
ffi::Optional< Expr > FindImpureCall(const Expr &expr, const ffi::Optional< Expr > &own_name=ffi::Optional< Expr >(std::nullopt))
Find an impure call in the given expression (likely a function body), if any.
ffi::Array< tir::Var > FreeSymbolicVars(const Expr &expr)
Get the TIR variables that are used but not defined in the input function. The returned list is dedup...
BaseCheckResult StructInfoBaseCheck(const StructInfo &base, const StructInfo &derived, arith::Analyzer *ana=nullptr)
Run a base check to see if base subsumes derived.
StructInfo StructInfoLCA(const StructInfo &lhs, const StructInfo &rhs, arith::Analyzer *ana=nullptr)
Unify the two struct info to their least common ancestor.
bool ContainsImpureCall(const Expr &expr, const ffi::Optional< Expr > &own_name=ffi::Optional< Expr >(std::nullopt))
Check if the given expression (likely a function body) contains any impure calls.
Expr RemoveAllUnused(Expr expr)
Remove unused statements inside DataflowBlocks.
ffi::Array< tir::Var > DefinableTIRVarsInStructInfo(const StructInfo &sinfo)
Get the TIR variables that appear in the input struct info.
ffi::Map< tir::Block, ffi::Map< ObjectRef, tir::IndexMap > > SuggestLayoutTransforms(const Function &fn, ffi::Array< tir::IndexMap > write_buffer_transformations)
Using the layout transforms on the outputs, suggest layout transformation on the blocks and buffers f...
std::set< const VarNode * > GetUsedVars(const Expr &expr)
Get the used variables in an expression.
bool CanProveShapeEqual(const ffi::Array< PrimExpr > &lhs, const ffi::Array< PrimExpr > &rhs, arith::Analyzer *ana)
Check whether the two symbolic shape arrays can be proven equal to each other.
ffi::Map< Var, Expr > AnalyzeVar2Value(const IRModule &m)
Analyze var -> value mapping from VarBindings.
bool WellFormed(ffi::Variant< IRModule, Function > obj, bool check_struct_info=true)
Check if the IRModule (or Function) is well formed.
Type GetStaticType(const StructInfo &info)
Get the corresponding static type from a given struct info.
StructInfo EraseToWellDefined(const StructInfo &info, std::function< ffi::Optional< PrimExpr >(const tir::Var &var)> f_shape_var_map=nullptr, std::function< ffi::Optional< Expr >(const Var &var)> f_var_map=nullptr, arith::Analyzer *ana=nullptr)
Erase the info to a corresponding more coarse grained struct info that is still well-defined(with all...
StructInfo StructInfoFromType(const Type &type)
Get the corresponding struct info from static type.
ffi::Array< Var > ComputableAtCompileTime(const Function &func)
Collect variables whose value can be computed at compile-time.
ffi::Array< PrimExpr > CollectNonNegativeExpressions(const StructInfo &sinfo)
Collect expressions whose usage requires them to be non-negative.
ffi::Array< tir::Var > TIRVarsInStructInfo(const StructInfo &sinfo)
Get the TIR variables that appear in the input struct info. The returned list is deduplicated - each ...
bool HasReshapePattern(const tir::PrimFunc &func)
Check if the given PrimFunc is essentially doing a reshape operation. The reshape operation also incl...
tvm::ffi::Array< Var > FreeVars(const Expr &expr)
Get free variables from expression expr.
ffi::Map< Var, ffi::Array< Var > > DataflowBlockUseDef(const DataflowBlock &dfb)
Get the use-def chain of variables inside a dataflow block.
OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc &func)
Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
ffi::Map< ffi::String, ffi::Array< Binding > > NameToBinding(const Function &fn)
Return a mapping from variable name to its Bindings.
tvm::ffi::Array< GlobalVar > AllGlobalVars(const Expr &expr)
Get all global variables from expression expr.
ffi::Array< tir::Var > DefinedSymbolicVars(const Expr &expr)
Get the TIR variables that are defined in the input function. The returned list is deduplicated - each TI...
StructInfo DeriveCallRetStructInfo(const FuncStructInfo &finfo, const Call &call, const BlockBuilder &ctx, arith::Analyzer *ana=nullptr)
OpPatternKind
Definition: op_attr_types.h:34
tvm::ffi::Array< Var > AllVars(const Expr &expr)
Get all variables from expression expr.
bool IsBaseOf(const StructInfo &base, const StructInfo &derived, arith::Analyzer *ana=nullptr)
Check the relation of two struct info to see if one subsumes another one.
tvm::ffi::Array< Var > BoundVars(const Expr &expr)
Get all bound variables from expression expr.
PrimExpr StructInfoBaseCheckPrecondition(const StructInfo &base, const StructInfo &derived)
Return the condition for which base is a superset of derived.
tvm::ffi::Array< tvm::ffi::Array< GlobalVar > > DetectRecursion(const IRModule &m)
Find all sets of recursive or mutually recursive functions in the module.
std::pair< ffi::Map< Var, ffi::Array< Var > >, ffi::Array< Var > > FunctionUseDef(const Expr &expr)
Get the use-def chain of variables inside a function.
VarUsageInfo CollectVarUsage(const Expr &expr)
Collect variable bindings and usage.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Data structures that can appear in operator attributes.
A utility struct returned by CollectVarUsage.
Definition: analysis.h:466
ffi::Array< Var > outputs
Definition: analysis.h:483
ffi::Map< Var, ffi::Array< Var > > downstream_usage
Definition: analysis.h:477
ffi::Map< Var, Expr > bound_values
Definition: analysis.h:471
TIR Function.