tvm
var.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_VAR_H_
25 #define TVM_TIR_VAR_H_
26 
27 #include <tvm/ir/expr.h>
28 #include <tvm/node/node.h>
29 #include <tvm/runtime/data_type.h>
30 
31 #include <functional>
32 #include <string>
33 
34 namespace tvm {
35 namespace tir {
36 
/*!
 * \brief A variable node in the IR.
 *
 * NOTE(review): RegisterReflection() below reads `VarNode::type_annotation`,
 * but that field's declaration is not among the visible lines; confirm it is
 * present in the full header.
 */
48 class VarNode : public PrimExprNode {
49  public:
/*! \brief The hint to the variable name. */
54  ffi::String name_hint;
63 
/*!
 * \brief Register reflection metadata for VarNode.
 *
 * `name_hint` is exposed read-only under the key "name" and carries the
 * SEqHashIgnore flag, i.e. it is meant to be ignored by structural
 * equality/hashing. `type_annotation` is exposed read-only under its own name.
 */
64  static void RegisterReflection() {
65  namespace refl = tvm::ffi::reflection;
66  refl::ObjectDef<VarNode>()
67  .def_ro("name", &VarNode::name_hint, refl::AttachFieldFlag::SEqHashIgnore())
68  .def_ro("type_annotation", &VarNode::type_annotation);
69  }
70 
/*! \brief Structural equality/hash kind for variables: FreeVar. */
71  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar;
// Child type slots reserved for subclasses; presumably for SizeVarNode,
// which derives from VarNode — confirm against the FFI type-index docs.
72  static constexpr const uint32_t _type_child_slots = 1;
74 };
75 
/*!
 * \brief a named variable in TIR
 *
 * NOTE(review): the `copy_with_dtype(DataType)` member and the reference
 * macro lines are not among the visible lines of this listing.
 */
77 class Var : public PrimExpr {
78  public:
/*! \brief Construct from an ffi::UnsafeInit tag, forwarding it to PrimExpr. */
79  explicit Var(ffi::UnsafeInit tag) : PrimExpr(tag) {}
/*! \brief Wrap an existing VarNode pointer. */
80  explicit Var(ObjectPtr<VarNode> n) : PrimExpr(n) {}
/*!
 * \brief Constructor.
 * \param name_hint Hint for the variable name (defaults to "v").
 * \param dtype Data type of the variable (defaults to 32-bit int).
 * \param span Source span of the variable (defaults to an empty Span).
 */
87  TVM_DLL explicit Var(ffi::String name_hint = "v", DataType dtype = DataType::Int(32),
88  Span span = Span());
/*!
 * \brief Constructor which provides a more detailed type annotation.
 * \param name_hint Hint for the variable name.
 * \param type_annotation The type annotation of the variable.
 * \param span Source span of the variable (defaults to an empty Span).
 */
95  TVM_DLL explicit Var(ffi::String name_hint, Type type_annotation, Span span = Span());
/*!
 * \brief Make a new copy of var with same type, but a different name.
 * \param name The name to use for the copy.
 * \return The new variable.
 */
101  TVM_DLL Var copy_with_name(const ffi::String& name) const;
/*!
 * \brief Make a new copy of var with same type, append suffix.
 * \param suffix The suffix to append.
 * \return The new variable.
 */
107  TVM_DLL Var copy_with_suffix(const ffi::String& suffix) const;
114 
/*! \brief Get pointer to the internal value. */
119  const VarNode* operator->() const { return get(); }
/*!
 * \brief Get pointer to the internal value.
 * \note Uses an unchecked static_cast of the held object to VarNode.
 */
124  const VarNode* get() const { return static_cast<const VarNode*>(data_.get()); }
127 };
128 
/*!
 * \brief A variable node that represents a tensor index size, whose value
 *        must be non-negative.
 */
133 class SizeVarNode : public VarNode {
134  public:
/*!
 * \brief Register reflection metadata for SizeVarNode.
 *
 * No fields are added beyond those inherited from VarNode.
 */
135  static void RegisterReflection() {
136  namespace refl = tvm::ffi::reflection;
137  refl::ObjectDef<SizeVarNode>();
138  }
140 };
141 
/*! \brief A named variable that represents a tensor index size. */
143 class SizeVar : public Var {
144  public:
/*! \brief Wrap an existing SizeVarNode pointer. */
145  explicit SizeVar(ObjectPtr<SizeVarNode> n) : Var(n) {}
/*! \brief Construct from an ffi::UnsafeInit tag, forwarding it to Var. */
146  explicit SizeVar(ffi::UnsafeInit tag) : Var(tag) {}
/*!
 * \brief Constructor.
 * \param name_hint Hint for the variable name (defaults to "s").
 * \param t Data type of the variable (defaults to 32-bit int).
 * \param span Source span of the variable (defaults to an empty Span).
 */
153  TVM_DLL explicit SizeVar(ffi::String name_hint = "s", DataType t = DataType::Int(32),
154  Span span = Span());
/*!
 * \brief Constructor which provides a more detailed type annotation.
 * \param name_hint Hint for the variable name.
 * \param type_annotation The type annotation of the variable.
 * \param span Source span of the variable (defaults to an empty Span).
 */
161  TVM_DLL explicit SizeVar(ffi::String name_hint, Type type_annotation, Span span = Span());
/*! \brief Get pointer to the internal value. */
166  const SizeVarNode* operator->() const { return get(); }
/*!
 * \brief Get pointer to the internal value.
 * \note Uses an unchecked static_cast of the held object to SizeVarNode.
 */
171  const SizeVarNode* get() const { return static_cast<const SizeVarNode*>(data_.get()); }
174 };
175 
/*!
 * \brief A region, expressed as an array of Range values.
 * \note Presumably one Range per dimension — confirm against users of Region.
 */
176 using Region = ffi::Array<Range>;
177 
/*!
 * \brief Type of iteration variable. Each IterVar has a specific type.
 *
 * NOTE(review): the enumerators kThreadIndex, kCommReduce, kUnrolled,
 * kVectorized and kParallelized are not among the visible lines of this
 * listing.
 */
185 enum IterVarType : int {
/*!
 * \brief Data parallel iteration. This normally corresponds to an axis of a
 *        Tensor.
 */
194  kDataPar = 0,
/*!
 * \brief Serial loops with loop carry dependency; the iteration must execute
 *        in order. Cannot be re-ordered.
 */
217  kOrdered = 3,
/*! \brief IterVar is opaque. */
227  kOpaque = 4,
228  // The following are possible additional
229  // types that are provided during schedule
/*! \brief Marks boundary of tensorization intrinsic. */
245  kTensorized = 8
246 };
247 
// An iteration variable representing an iteration over a one dimensional
// interval. NOTE(review): the class header and the `dom`, `var` and
// `iter_type` declarations are not among the visible lines of this listing.
255  public:
/*!
 * \brief Additional tag on the iteration variable; set this if the variable
 *        is already bound to a known thread tag.
 */
269  ffi::String thread_tag;
/*!
 * \brief Span that points to the original source code.
 *        Reserved debug information.
 * \note Declared mutable, and not registered in RegisterReflection() below.
 */
274  mutable Span span;
275 
/*! \brief Convert to a PrimExpr by returning the looping variable `var`. */
276  PrimExpr ToPrimExpr() const final { return var; }
277 
/*!
 * \brief Register reflection metadata for IterVarNode.
 *
 * Exposes dom, var, iter_type and thread_tag read-only. `var` carries the
 * SEqHashDef flag (presumably marking it as the definition site of the
 * variable for structural equality — confirm against the FFI docs).
 */
278  static void RegisterReflection() {
279  namespace refl = tvm::ffi::reflection;
280  refl::ObjectDef<IterVarNode>()
281  .def_ro("dom", &IterVarNode::dom)
282  .def_ro("var", &IterVarNode::var, refl::AttachFieldFlag::SEqHashDef())
283  .def_ro("iter_type", &IterVarNode::iter_type)
284  .def_ro("thread_tag", &IterVarNode::thread_tag);
285  }
286 
/*! \brief Structural equality/hash kind for IterVar: TreeNode. */
287  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
289 };
290 
/*! \brief Iteration Variable, represents an iteration over an integer interval. */
297 class IterVar : public PrimExprConvertible {
298  public:
/*!
 * \brief Constructor.
 * \param dom The domain of iteration.
 * \param var The looping variable.
 * \param iter_type The type of the IterVar.
 * \param thread_tag Optional thread tag (defaults to the empty string).
 * \param span Source span (defaults to an empty Span).
 */
299  TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, ffi::String thread_tag = "",
300  Span span = Span());
/*!
 * \brief Implicit conversion to PrimExpr; the inline definition in this header
 *        returns the node's `var`.
 */
304  inline operator PrimExpr() const;
305 
308 };
309 
310 // inline implementations
311 inline IterVar::operator PrimExpr() const { return (*this)->var; }
312 
313 inline const char* IterVarType2String(IterVarType t) {
314  switch (t) {
315  case kDataPar:
316  return "DataPar";
317  case kThreadIndex:
318  return "ThreadIndex";
319  case kCommReduce:
320  return "CommReduce";
321  case kOrdered:
322  return "Ordered";
323  case kOpaque:
324  return "Opaque";
325  case kUnrolled:
326  return "Unrolled";
327  case kVectorized:
328  return "Vectorized";
329  case kParallelized:
330  return "Parallelized";
331  case kTensorized:
332  return "Tensorized";
333  }
334  return "Unknown";
335 }
336 } // namespace tir
337 } // namespace tvm
338 
339 /*! \brief Allow tir.Var as key in STL tables
340  *
341  * For most TIR expressions, it would be ambiguous whether the
342  * expression should follow reference equality or structural equality.
343  * This is not the case for variables, which do not contain nested
344  * internal structure, and are frequently used as keys in lookup
345  * tables.
346  *
347  * Providing `std::hash` and `std::equal_to` specializations for
348  * `tir::Var` allows it to be used as a key in STL tables. For
349  * `PrimExpr`, the user must specify the type of equality used
350  * (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
351  * `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
352  */
353 template <>
354 struct std::hash<tvm::tir::Var> {
355  std::size_t operator()(const tvm::tir::Var& var) const {
356  return tvm::runtime::ObjectPtrHash()(var);
357  }
358 };
359 
360 template <>
361 struct std::equal_to<tvm::tir::Var> {
362  bool operator()(const tvm::tir::Var& var_a, const tvm::tir::Var& var_b) const {
363  return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
364  }
365 };
366 
367 #endif // TVM_TIR_VAR_H_
Base class for other IR constructs that can be converted to PrimExpr. This is useful for the FFI to c...
Definition: expr.h:154
Managed reference to PrimExprConvertibleNode.
Definition: expr.h:165
Base node of all primitive expressions.
Definition: expr.h:91
Reference to PrimExprNode.
Definition: expr.h:124
DataType dtype() const
Definition: expr.h:138
Range container
Definition: expr.h:689
Definition: source_map.h:111
Managed reference to TypeNode.
Definition: type.h:100
Runtime primitive data type.
Definition: data_type.h:47
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:274
An iteration variable representing an iteration over a one dimensional interval.
Definition: var.h:254
ffi::String thread_tag
Additional tag on the iteration variable; set this if the variable is already bound to a known thread tag.
Definition: var.h:269
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: var.h:287
Var var
The looping variable.
Definition: var.h:262
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.IterVar", IterVarNode, PrimExprConvertibleNode)
Span span
Span that points to the original source code. Reserved debug information.
Definition: var.h:274
static void RegisterReflection()
Definition: var.h:278
PrimExpr ToPrimExpr() const final
Definition: var.h:276
Range dom
The domain of iteration, if known; can be None. For the intermediate schedule node,...
Definition: var.h:260
IterVarType iter_type
The type of the IterVar.
Definition: var.h:264
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:297
IterVar(Range dom, Var var, IterVarType iter_type, ffi::String thread_tag="", Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterVarNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IterVar, PrimExprConvertible, IterVarNode)
A variable node that represents a tensor index size, whose value must be non-negative.
Definition: var.h:133
static void RegisterReflection()
Definition: var.h:135
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SizeVar", SizeVarNode, VarNode)
A named variable that represents a tensor index size.
Definition: var.h:143
SizeVar(ffi::String name_hint, Type type_annotation, Span span=Span())
Constructor which provides a more detailed type annotation.
const SizeVarNode * get() const
Get pointer to the internal value.
Definition: var.h:171
SizeVar(ffi::UnsafeInit tag)
Definition: var.h:146
SizeVar(ffi::String name_hint="s", DataType t=DataType::Int(32), Span span=Span())
constructor
SizeVar(ObjectPtr< SizeVarNode > n)
Definition: var.h:145
const SizeVarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:166
A variable node in the IR.
Definition: var.h:48
ffi::String name_hint
The hint to the variable name.
Definition: var.h:54
static void RegisterReflection()
Definition: var.h:64
Type type_annotation
type annotation of the variable.
Definition: var.h:62
TVM_FFI_DECLARE_OBJECT_INFO("tir.Var", VarNode, PrimExprNode)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: var.h:71
static constexpr const uint32_t _type_child_slots
Definition: var.h:72
a named variable in TIR
Definition: var.h:77
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:124
Var(ffi::String name_hint="v", DataType dtype=DataType::Int(32), Span span=Span())
Constructor.
Var(ffi::UnsafeInit tag)
Definition: var.h:79
Var(ffi::String name_hint, Type type_annotation, Span span=Span())
Constructor which provides a more detailed type annotation.
Var copy_with_suffix(const ffi::String &suffix) const
Make a new copy of var with same type, append suffix.
Var(ObjectPtr< VarNode > n)
Definition: var.h:80
const VarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:119
Var copy_with_dtype(DataType dtype) const
Make a new copy of the variable with specified dtype.
Var copy_with_name(const ffi::String &name) const
Make a new copy of var with same type, but a different name.
Base expr nodes in TVM.
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
ffi::Array< Range > Region
Definition: var.h:176
IterVarType
Type of iteration variable. Each IterVar has a specific type.
Definition: var.h:185
@ kVectorized
The loop is vectorized.
Definition: var.h:237
@ kThreadIndex
The IterVar itself is a thread-index of a fixed thread launching group. Note that this is already ass...
Definition: var.h:202
@ kUnrolled
The execution is unrolled.
Definition: var.h:233
@ kTensorized
Marks boundary of tensorization intrinsic.
Definition: var.h:245
@ kDataPar
Data parallel iteration. This normally corresponds to axis of Tensor. Allow all IterVar manipulations...
Definition: var.h:194
@ kOrdered
Serial loops with loop carry dependency, the iteration must execute in order. Cannot be re-ordered.
Definition: var.h:217
@ kCommReduce
Commutative reduction. Cannot be directly parallelized.
Definition: var.h:209
@ kParallelized
The loop is parallelized.
Definition: var.h:241
@ kOpaque
IterVar is opaque.
Definition: var.h:227
const char * IterVarType2String(IterVarType t)
Definition: var.h:313
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Definitions and helper macros for IR/AST nodes.