tvm
exec_scope.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
24 #ifndef TVM_TIRX_EXEC_SCOPE_H_
25 #define TVM_TIRX_EXEC_SCOPE_H_
26 
27 #include <tvm/ffi/container/variant.h>
28 #include <tvm/ir/module.h>
29 #include <tvm/tirx/var.h>
30 
31 #include <string>
32 #include <utility>
33 
34 namespace tvm {
35 namespace tirx {
36 
45 enum class ScopeKind : int {  // Execution scope kinds; the target scope kind of an ExecScopeStmt.
46  kWorld = 0,  // widest: ScopeKindHigher treats a smaller underlying value as a wider scope
47  kKernel = 1,
48  kCluster = 2,
49  kCta = 3,
50  kWarpgroup = 4,
51  kWarp = 5,
52  kThread = 6,  // narrowest
53 };
54 
56 TVM_DLL std::string ScopeKindToString(ScopeKind kind);  // ScopeKind -> string name, e.g. kKernel -> "kernel"
57 
59 TVM_DLL ScopeKind StringToScopeKind(const ffi::String& name);  // parse a scope name into a ScopeKind; FATAL if unknown
60 
81 enum class ScopeBinding : int {  // Closed set of (parent, child) scope bindings a ScopeIdDef may use.
82  kKernelCluster = 0,  // names appear to read as <parent><child>, e.g. kernel -> cluster
83  kKernelCta = 1,
84  kClusterCta = 2,
85  kCtaWarpgroup = 3,
86  kCtaWarp = 4,
87  kWarpgroupWarp = 5,
88  kWarpThread = 6,
89  kCtaThread = 7,
90  kWarpgroupThread = 8,
91  kClusterCtaPair = 9,  // NOTE(review): presumably a CTA pair within a cluster -- confirm
92 };
93 
95 TVM_DLL std::pair<ffi::String, ffi::String> ScopeBindingToStringPair(ScopeBinding binding);  // binding -> (parent, cur) names
96 
98 TVM_DLL ScopeBinding StringPairToScopeBinding(const ffi::String& parent, const ffi::String& cur);  // (parent, cur) -> binding; FATAL if unknown
99 
100 /******** Definition of ScopeId ********/
101 class ScopeIdDefNode : public ffi::Object {  // Node for ScopeIdDef: scope-id vars defined over one ScopeBinding.
102  public:
104  ffi::Array<Var> def_ids;  // the scope-id variables this def introduces
117  ffi::Optional<ffi::Array<PrimExpr>> extents;  // per-dimension extents; absent means deferred (unknown) extent
124  ffi::Optional<ffi::Array<PrimExpr>> preferred_extents;  // cluster->cta only; maps to cudaLaunchAttributePreferredClusterDimension
125 
126  static void RegisterReflection() {  // exposes all four fields read-only through ffi reflection
127  namespace refl = tvm::ffi::reflection;
128  refl::ObjectDef<ScopeIdDefNode>()
129  .def_ro("def_ids", &ScopeIdDefNode::def_ids, refl::AttachFieldFlag::SEqHashDef())  // presumably marks def_ids as definition sites for structural eq/hash -- confirm
130  .def_ro("extents", &ScopeIdDefNode::extents)
131  .def_ro("scope", &ScopeIdDefNode::scope)  // NOTE(review): the `ScopeBinding scope` field (line 119 per the index) is not shown in this listing
132  .def_ro("preferred_extents", &ScopeIdDefNode::preferred_extents);
133  }
134 
135  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
136  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ScopeIdDef", ScopeIdDefNode, ffi::Object);  // type key "tirx.ScopeIdDef"; final
137 };
138 
139 class ScopeIdDef : public ffi::ObjectRef {  // Reference wrapper for ScopeIdDefNode.
140  public:
141  TVM_DLL explicit ScopeIdDef(ffi::Array<Var> def_ids, ffi::Optional<ffi::Array<PrimExpr>> extents,
142  ScopeBinding scope,
143  ffi::Optional<ffi::Array<PrimExpr>> preferred_extents =
144  ffi::Optional<ffi::Array<PrimExpr>>(std::nullopt));  // preferred_extents defaults to absent
145 
147  bool is_deferred() const { return !get()->extents.has_value(); }  // true when extents is absent
148 
151  // NOTE(review): fused_extent() (product of extents; requires !is_deferred()) and the ref/COW macros from the index are not shown in this listing
154 };
155 
157  public:  // verifier class body; its header (line 156) is not shown in this listing
158  using ScopeIdSet = std::unordered_map<ScopeBinding, ScopeIdDef>;  // at most one def per binding. NOTE(review): <unordered_map> is not among the visible includes -- relies on a transitive include
159 
169  enum class Mode { kRelaxed, kStrict };  // verification mode; what differs between modes is not visible here
170 
172  bool Verify(const ffi::Array<ScopeIdDef>& defs, Mode mode = Mode::kStrict);  // checks defs are well formed; strict by default. Non-const: presumably fills id_set (line 178, not shown) -- confirm
173 
179 };
180 
186  public:  // static resolver for ScopeIdDef values (header line 185 not shown); replaces the former ScopeIdResolveTable registry
187  using LaunchParams = std::unordered_map<ffi::String, IterVar>;  // presumably keyed by launch thread tag (e.g. "threadIdx.x") -- confirm
188 
190  TVM_DLL static ffi::Array<PrimExpr> Resolve(ScopeBinding binding,  // resolve a def for a canonical binding + target
191  const ffi::Optional<ffi::Array<PrimExpr>>& extents,
192  int out_dim, const ffi::String& target_kind,  // out_dim: presumably the number of output dims -- confirm
193  const LaunchParams& params);
194 
196  TVM_DLL static PrimExpr ComputeWarpIdInCta(const LaunchParams& params);  // warp_id_in_cta shuffle expression built from threadIdx in params
197 };
198 
206  return static_cast<int>(a) < static_cast<int>(b);  // ScopeKindHigher body (signature at line 205 not shown): smaller value is wider, so true means a is strictly wider than b
207 }
208 
210 TVM_DLL bool ScopeNameHigher(const ffi::String& a, const ffi::String& b);  // same comparison keyed by scope name; FATAL on a bad name
211 
212 /******** Definition of Execution Scope ********/
213 class ExecScopeNode : public ffi::Object {  // Node for an execution scope: a ScopeKind plus the scope-id defs it carries.
214  public:
215  ffi::Array<ScopeIdDef> scope_id_def;  // scope-id definitions attached to this scope
216 
219  // NOTE(review): the `ScopeKind kind` field (line 218 per the index) is not shown in this listing
220 
221  ffi::String name() const { return ScopeKindToString(kind); }  // printable name derived from kind
222 
223  static void RegisterReflection() {  // exposes kind and scope_id_def read-only
224  namespace refl = tvm::ffi::reflection;
225  refl::ObjectDef<ExecScopeNode>()
226  .def_ro("kind", &ExecScopeNode::kind)
227  .def_ro("scope_id_def", &ExecScopeNode::scope_id_def);
228  }
229 
230  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
231  TVM_FFI_DECLARE_OBJECT_INFO("tirx.ExecScope", ExecScopeNode, ffi::Object);  // non-FINAL, unlike ScopeIdDefNode; presumably allows subclasses -- confirm
232 };
233 
234 class ExecScope : public ffi::ObjectRef {  // Reference wrapper for ExecScopeNode.
235  public:
237  TVM_DLL explicit ExecScope(ScopeKind kind, ffi::Array<ScopeIdDef> scope_id_def = {});  // canonical constructor from a ScopeKind
239  TVM_DLL explicit ExecScope(const ffi::String& name, ffi::Array<ScopeIdDef> scope_id_def = {})  // parses name, then delegates
240  : ExecScope(StringToScopeKind(name), std::move(scope_id_def)) {}  // StringToScopeKind FATALs on an unknown name
241  // NOTE(review): line 242 (TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE per the index) is not shown in this listing
243 };
244 
245 } // namespace tirx
246 } // namespace tvm
247 
248 #endif // TVM_TIRX_EXEC_SCOPE_H_
Reference to PrimExprNode.
Definition: expr.h:126
Definition: exec_scope.h:213
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: exec_scope.h:230
ScopeKind kind
Scope identity; one of the closed ScopeKind values.
Definition: exec_scope.h:218
static void RegisterReflection()
Definition: exec_scope.h:223
TVM_FFI_DECLARE_OBJECT_INFO("tirx.ExecScope", ExecScopeNode, ffi::Object)
ffi::Array< ScopeIdDef > scope_id_def
Definition: exec_scope.h:215
ffi::String name() const
Human-readable name derived from kind (for printing / errors).
Definition: exec_scope.h:221
Definition: exec_scope.h:234
ExecScope(const ffi::String &name, ffi::Array< ScopeIdDef > scope_id_def={})
Construct from a name string (FATALs on unknown name).
Definition: exec_scope.h:239
ExecScope(ScopeKind kind, ffi::Array< ScopeIdDef > scope_id_def={})
Construct from a ScopeKind (canonical).
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExecScope, ffi::ObjectRef, ExecScopeNode)
Definition: exec_scope.h:101
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ScopeIdDef", ScopeIdDefNode, ffi::Object)
ScopeBinding scope
The (parent, cur) binding of this scope id as a closed enum.
Definition: exec_scope.h:119
static void RegisterReflection()
Definition: exec_scope.h:126
ffi::Optional< ffi::Array< PrimExpr > > extents
The extents of the ScopeId.
Definition: exec_scope.h:117
ffi::Array< Var > def_ids
The ScopeId defined.
Definition: exec_scope.h:104
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: exec_scope.h:135
ffi::Optional< ffi::Array< PrimExpr > > preferred_extents
Optional preferred extents (cluster→cta only). Maps to cudaLaunchAttributePreferredClusterDimension (...
Definition: exec_scope.h:124
Definition: exec_scope.h:156
std::unordered_map< ScopeBinding, ScopeIdDef > ScopeIdSet
Definition: exec_scope.h:158
ScopeIdSet id_set
The resolved scope id set; id_set[binding] is the best-known def for that binding (extents filled in ...
Definition: exec_scope.h:178
Mode
Verification mode.
Definition: exec_scope.h:169
bool Verify(const ffi::Array< ScopeIdDef > &defs, Mode mode=Mode::kStrict)
Verify the scope id definitions are well formed.
Definition: exec_scope.h:139
TVM_DEFINE_OBJECT_REF_COW_METHOD(ScopeIdDefNode)
ScopeIdDef(ffi::Array< Var > def_ids, ffi::Optional< ffi::Array< PrimExpr >> extents, ScopeBinding scope, ffi::Optional< ffi::Array< PrimExpr >> preferred_extents=ffi::Optional< ffi::Array< PrimExpr >>(std::nullopt))
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScopeIdDef, ffi::ObjectRef, ScopeIdDefNode)
bool is_deferred() const
Whether this def has a deferred (unknown) extent.
Definition: exec_scope.h:147
PrimExpr fused_extent() const
Product of all extent dimensions. PRECONDITION: !is_deferred().
Static resolver for ScopeIdDef values. Replaces the former ScopeIdResolveTable runtime registry with ...
Definition: exec_scope.h:185
static ffi::Array< PrimExpr > Resolve(ScopeBinding binding, const ffi::Optional< ffi::Array< PrimExpr >> &extents, int out_dim, const ffi::String &target_kind, const LaunchParams &params)
Resolve a ScopeIdDef for a given canonical binding + target.
static PrimExpr ComputeWarpIdInCta(const LaunchParams &params)
Compute the warp_id_in_cta shuffle expression from threadIdx in launch params.
std::unordered_map< ffi::String, IterVar > LaunchParams
Definition: exec_scope.h:187
IRModule that holds the functions and type definitions.
std::pair< ffi::String, ffi::String > ScopeBindingToStringPair(ScopeBinding binding)
Convert a ScopeBinding to its (parent, cur) string pair.
bool ScopeKindHigher(ScopeKind a, ScopeKind b)
Strict-weak ordering meaning "a is wider than b" on scope kinds: world > kernel > cluster > cta > warpgroup > warp > thread.
Definition: exec_scope.h:205
ScopeKind StringToScopeKind(const ffi::String &name)
Parse a string name to a ScopeKind. FATAL if unknown.
ScopeKind
The target execution scope kind of an ExecScopeStmt.
Definition: exec_scope.h:45
bool ScopeNameHigher(const ffi::String &a, const ffi::String &b)
String-keyed convenience over ScopeKindHigher. FATALs on bad name.
ScopeBinding
The binding between a parent scope and a child scope as used by a ScopeIdDef. The closed enum of vali...
Definition: exec_scope.h:81
ScopeBinding StringPairToScopeBinding(const ffi::String &parent, const ffi::String &cur)
Parse a (parent, cur) string pair to a ScopeBinding. FATAL if unknown.
std::string ScopeKindToString(ScopeKind kind)
Convert a ScopeKind to its string name (e.g. kKernel -> "kernel").
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
Variables in the TIR.