tvm
utils.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_TIR_UTILS_H_
20 #define TVM_TIR_UTILS_H_
21 
22 #include <tvm/tir/block_scope.h>
23 #include <tvm/tir/stmt.h>
24 
25 #include <unordered_map>
26 
27 namespace tvm {
28 namespace tir {
29 
/*!
 * \brief Downcast the statement referred to by a StmtSRef and ICHECK the result.
 * \param Result The name of the variable being initialized by this macro; it is the
 *        expression checked by ICHECK.
 * \param SRef The StmtSRef whose statement is downcast via `StmtAs<Type>()`.
 * \param Type The target node type, e.g. `::tvm::tir::BlockNode`.
 * \note The macro expands to TWO statements:
 *         `(SRef)->StmtAs<Type>();`
 *         `ICHECK(Result)`
 *       It is therefore meant to be used only as the initializer of a declaration of
 *       `Result`, optionally followed by streamed error text, e.g.
 *         `auto result = TVM_SREF_AS_OR_ERR(result, sref, BlockNode) << "message";`
 *       Because it is two statements, do not use it as the body of an unbraced `if`/loop.
 *       `SRef` is parenthesized so that compound argument expressions expand correctly.
 */
#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
  (SRef)->StmtAs<Type>();                      \
  ICHECK(Result)
40 
/*!
 * \brief Convert a StmtSRef into the BlockNode it refers to.
 * \param SRef The StmtSRef to convert; it is expected to point to a Block.
 * \return The value of `(SRef)->StmtAs<::tvm::tir::BlockNode>()`, checked by ICHECK to be
 *         non-null. On failure the ICHECK message names the macro argument and the actual
 *         type key of the referenced statement ("None" when the sref has no statement).
 * \note Implemented as an immediately-invoked lambda so the macro is usable as an expression.
 */
#define TVM_SREF_TO_BLOCK(SRef)                                                \
  [&]() {                                                                      \
    auto result = (SRef)->StmtAs<::tvm::tir::BlockNode>();                     \
    ICHECK(result) << "TypeError: Expects StmtSRef `" << #SRef                 \
                   << "` points to `Block`, but gets: "                        \
                   << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");    \
    return result;                                                             \
  }()
56 
/*!
 * \brief Convert a StmtSRef into the ForNode (loop) it refers to.
 * \param SRef The StmtSRef to convert; it is expected to point to a For loop.
 * \return The value of `(SRef)->StmtAs<::tvm::tir::ForNode>()`, checked by ICHECK to be
 *         non-null. On failure the ICHECK message names the macro argument and the actual
 *         type key of the referenced statement ("None" when the sref has no statement).
 * \note Implemented as an immediately-invoked lambda so the macro is usable as an expression.
 */
#define TVM_SREF_TO_FOR(SRef)                                                  \
  [&]() {                                                                      \
    auto result = (SRef)->StmtAs<::tvm::tir::ForNode>();                       \
    ICHECK(result) << "TypeError: Expects StmtSRef `" << #SRef                 \
                   << "` points to `Loop`, but gets: "                         \
                   << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");    \
    return result;                                                             \
  }()
72 
/*!
 * \brief Downcast an object reference with `as<Type>()` and ICHECK the result.
 * \param Result The name of the variable being initialized by this macro; it is the
 *        expression checked by ICHECK.
 * \param From The object reference to downcast.
 * \param Type The target node type.
 * \note The macro expands to TWO statements:
 *         `(From).as<Type>();`
 *         `ICHECK(Result)`
 *       It is therefore meant to be used only as the initializer of a declaration of
 *       `Result`, optionally followed by streamed error text, e.g.
 *         `auto result = TVM_TYPE_AS_OR_ERR(result, obj, ForNode) << "message";`
 *       Because it is two statements, do not use it as the body of an unbraced `if`/loop.
 *       `From` is parenthesized so that compound argument expressions (e.g. `*ptr`)
 *       expand correctly.
 */
#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
  (From).as<Type>();                           \
  ICHECK(Result)
83 
/*!
 * \brief Downcast an object reference to `Type` via `as<Type>()`, requiring success.
 * \param From The object reference to downcast.
 * \param Type The target node type; its `_type_key` is used in the error message.
 * \return The value of `(From).as<Type>()`, checked by ICHECK to be non-null. On failure the
 *         ICHECK message names the macro argument, the expected type key, and the actual
 *         type key ("None" when `From` is undefined).
 * \note Implemented as an immediately-invoked lambda so the macro is usable as an expression.
 */
#define TVM_TYPE_AS(From, Type)                                                          \
  [&]() {                                                                                \
    auto result = (From).as<Type>();                                                     \
    ICHECK(result) << "TypeError: Expects `" << #From << "` to have type `"              \
                   << Type::_type_key << "`, but gets: "                                 \
                   << ((From).defined() ? (From)->GetTypeKey() : "None");                \
    return result;                                                                       \
  }()
97 
106 inline void SetSeqIndex(std::unordered_map<const StmtNode*, StmtSRef>& stmt2ref, // NOLINT(*)
107  const Stmt& stmt, int seq_index, bool include_loops = true) {
108  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
109  const BlockNode* block = realize->block.get();
110  ICHECK(stmt2ref.count(block));
111  stmt2ref.at(block)->seq_index = seq_index;
112  } else if (const auto* block = stmt.as<BlockNode>()) {
113  ICHECK(stmt2ref.count(block));
114  stmt2ref.at(block)->seq_index = seq_index;
115  } else if (const auto* loop = stmt.as<ForNode>()) {
116  if (!include_loops) return;
117  ICHECK(stmt2ref.count(loop));
118  stmt2ref.at(loop)->seq_index = seq_index;
119  }
120 }
121 
129  std::unordered_map<const StmtNode*, StmtSRef>& stmt2ref, // NOLINT(*)
130  const SeqStmtNode* seq_stmt, bool include_loops = true) {
131  int i = 0;
132  for (const Stmt& stmt : seq_stmt->seq) {
133  SetSeqIndex(stmt2ref, stmt, i, include_loops);
134  ++i;
135  }
136 }
137 
138 } // namespace tir
139 } // namespace tvm
140 
141 #endif // TVM_TIR_UTILS_H_
Definition of the two pillar data structures for TensorIR scheduling: StmtSRef and BlockScope.
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
A block is a basic schedule unit in TIR.
Definition: stmt.h:1258
A block realization node represents execution of the block at the binding values.
Definition: stmt.h:1342
A for loop, with possible type annotations.
Definition: stmt.h:967
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:670
Array< Stmt > seq
internal sequence content.
Definition: stmt.h:673
Container of all statements.
Definition: stmt.h:59
void SetSeqIndexInChildren(std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const SeqStmtNode *seq_stmt, bool include_loops=true)
Update seq_index of the children of a SeqStmt.
Definition: utils.h:128
void SetSeqIndex(std::unordered_map< const StmtNode *, StmtSRef > &stmt2ref, const Stmt &stmt, int seq_index, bool include_loops=true)
Set the StmtSRefNode::seq_index field for stmt.
Definition: utils.h:106
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
TIR statements.