19 #ifndef TVM_TIR_UTILS_H_
20 #define TVM_TIR_UTILS_H_
25 #include <unordered_map>
37 #define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
38 SRef->StmtAs<Type>(); \
49 #define TVM_SREF_TO_BLOCK(SRef) \
51 auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode) \
52 << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \
53 << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \
65 #define TVM_SREF_TO_FOR(SRef) \
67 auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode) \
68 << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \
69 << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \
80 #define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
90 #define TVM_TYPE_AS(From, Type) \
92 auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type) \
93 << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \
94 << "`, but gets: " << ((From).defined() ? (From)->GetTypeKey() : "None"); \
106 inline void SetSeqIndex(std::unordered_map<const StmtNode*, StmtSRef>& stmt2ref,
107 const Stmt& stmt,
int seq_index,
bool include_loops =
true) {
109 const BlockNode* block = realize->block.get();
110 ICHECK(stmt2ref.count(block));
111 stmt2ref.at(block)->seq_index = seq_index;
112 }
else if (
const auto* block = stmt.
as<
BlockNode>()) {
113 ICHECK(stmt2ref.count(block));
114 stmt2ref.at(block)->seq_index = seq_index;
115 }
else if (
const auto* loop = stmt.
as<
ForNode>()) {
116 if (!include_loops)
return;
117 ICHECK(stmt2ref.count(loop));
118 stmt2ref.at(loop)->seq_index = seq_index;
129 std::unordered_map<const StmtNode*, StmtSRef>& stmt2ref,
130 const SeqStmtNode* seq_stmt,
bool include_loops =
true) {
132 for (
const Stmt& stmt : seq_stmt->
seq) {
Definition of the two pillar data structures for TensorIR scheduling: StmtSRef and BlockScope.
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
A block is a basic schedule unit in TIR.
Definition: stmt.h:1258
A block realization node represents the execution of a block at the given binding values.
Definition: stmt.h:1342
A for loop, with possible type annotations.
Definition: stmt.h:967
The container of a sequence statement. Represents a sequence of statements.
Definition: stmt.h:670
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:673
Container of all statements.
Definition: stmt.h:59
void SetSeqIndexInChildren(std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const SeqStmtNode *seq_stmt, bool include_loops=true)
Update seq_index of the children of a SeqStmt.
Definition: utils.h:128
void SetSeqIndex(std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const Stmt &stmt, int seq_index, bool include_loops=true)
Set the StmtSRefNode::seq_index field for stmt.
Definition: utils.h:106
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36