24 #ifndef TVM_TIR_ANALYSIS_H_ 25 #define TVM_TIR_ANALYSIS_H_ 71 template <
class FLambda>
74 const BaseFunc& base_func = kv.second;
132 TVM_DLL
bool UsesVar(
const Stmt& stmt, std::function<
bool(
const VarNode*)> vset_contains);
263 const Integer& workspace_byte_alignment);
317 namespace transform {
372 #endif // TVM_TIR_ANALYSIS_H_ Managed reference to BlockNode.
Definition: stmt.h:1258
Map< Buffer, Optional< Stmt > > DetectBufferAccessLCA(const PrimFunc &func)
Detect the lowest common ancestor (LCA) of buffer access, including both high-level access (BufferLoad...
Array< Array< BufferRegion > > GetBlockReadWriteRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto-detect the block read/write region according to its body stmt. An opaque access will be counted ...
Array< Var > UndefinedVars(const PrimExpr &expr, const Array< Var > &defs)
Find undefined vars in the expression.
IRModule that holds the functions and type definitions.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
bool VerifyWellFormed(const PrimFunc &func, bool assert_mode=true)
Verify if the given TIR is well-formed. The verification includes:
Attribute types in the Op registry for TIR ops.
bool UsesVar(const PrimExpr &expr, std::function< bool(const VarNode *)> vset_contains)
Whether the given PrimExpr uses any var in the given variable set.
A variable node in the IR.
Definition: var.h:47
tvm::Map< String, Integer > CalculateAllocatedBytes(const PrimFunc &func)
Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc. ...
Primitive functions that contain TIR statements.
Definition: function.h:46
Managed reference to ForNode.
Definition: stmt.h:962
double EstimateTIRFlops(const IRModule &mod)
Estimate the FLOPs of TIRs in an IRModule.
Managed reference to BufferRegionNode.
Definition: stmt.h:1099
const tir::BlockNode * FindAnchorBlock(const IRModule &mod)
Find the "anchor block" of the given module. We define the anchor block to be the block with (1) an i...
const PrimFuncNode * FindEntryFunc(const IRModule &mod, GlobalVar *result_g_var)
Find the entry function of the given IRModule, i.e., a function marked by tir::attr::kIsEntryFunc, a function named main, or the only PrimFunc.
size_t CalculateConstantBytes(const PrimFunc &func, const Integer &constant_byte_alignment)
Calculate the constants size in bytes needed by the TIR allocates inside the TIR PrimFunc.
Helper struct for return value of IdentifyMemCpy.
Definition: analysis.h:227
CallEffectKind
The effect type of the call.
Definition: op_attr_types.h:88
Compare two expressions recursively and check if they are equal to each other without var remapping...
Definition: analysis.h:60
BufferRegion source
Definition: analysis.h:228
Map< GlobalVar, BaseFunc > functions
A map from ids to all global functions.
Definition: module.h:59
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Managed reference to PrimFuncNode.
Definition: function.h:145
Managed reference to GlobalVarNode.
Definition: expr.h:477
size_t CalculateWorkspaceBytes(const PrimFunc &func, const Integer &workspace_byte_alignment)
Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc.
Container of all statements.
Definition: stmt.h:59
A block is a basic schedule unit in TIR.
Definition: stmt.h:1191
void VisitPrimFuncs(const IRModule &mod, FLambda fvisit)
Visit the PrimFuncs in the IRModule.
Definition: analysis.h:72
size_t CalculateExprComplexity(const PrimExpr &expr)
Calculate the expression complexity based on the number of symbols it contains.
Array< Array< BufferRegion > > GetBlockAccessRegion(const Block &block, const Map< Var, Buffer > &buffer_var_map)
Auto-detect the block access region according to its body stmt. It will detect the access region as an...
std::optional< MemCpyDetails > IdentifyMemCpy(const For &loop, arith::Analyzer *analyzer)
Identify whether a For loop is semantically equivalent to MemCpy.
Managed reference class to IRModuleNode.
Definition: module.h:348
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
CallEffectKind SideEffect(const PrimExpr &expr)
Analyze the side effect.
Managed reference to BaseFuncNode.
Definition: function.h:143
BufferRegion dest
Definition: analysis.h:229
Reference to PrimExprNode.
Definition: expr.h:114
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:579
Container of constant int that adds more constructors.
Definition: expr.h:622