tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
var.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_VAR_H_
25 #define TVM_TIR_VAR_H_
26 
27 #include <tvm/ir/expr.h>
28 #include <tvm/node/node.h>
29 #include <tvm/runtime/data_type.h>
30 
31 #include <functional>
32 #include <string>
33 
34 namespace tvm {
35 namespace tir {
36 
48 class VarNode : public PrimExprNode {
49  public:
63 
65  v->Visit("dtype", &dtype);
66  v->Visit("name", &name_hint);
67  v->Visit("type_annotation", &type_annotation);
68  v->Visit("span", &span);
69  }
70 
71  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
72  if (!equal(dtype, other->dtype)) return false;
73  if (!equal(type_annotation, other->type_annotation)) return false;
74  return equal.FreeVarEqualImpl(this, other);
75  }
76 
77  void SHashReduce(SHashReducer hash_reduce) const {
78  hash_reduce(dtype);
79  hash_reduce(type_annotation);
80  hash_reduce.FreeVarHashImpl(this);
81  }
82 
83  static constexpr const char* _type_key = "tir.Var";
84  static constexpr const uint32_t _type_child_slots = 1;
86 };
87 
89 class Var : public PrimExpr {
90  public:
91  explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
98  TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32),
99  Span span = Span());
106  TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span());
112  TVM_DLL Var copy_with_name(const String& name) const;
118  TVM_DLL Var copy_with_suffix(const String& suffix) const;
125 
130  const VarNode* operator->() const { return get(); }
135  const VarNode* get() const { return static_cast<const VarNode*>(data_.get()); }
138 };
139 
144 class SizeVarNode : public VarNode {
145  public:
146  static constexpr const char* _type_key = "tir.SizeVar";
148 };
149 
151 class SizeVar : public Var {
152  public:
153  explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
160  TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32),
161  Span span = Span());
168  TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span());
173  const SizeVarNode* operator->() const { return get(); }
178  const SizeVarNode* get() const { return static_cast<const SizeVarNode*>(data_.get()); }
181 };
182 
184 
192 enum IterVarType : int {
201  kDataPar = 0,
224  kOrdered = 3,
234  kOpaque = 4,
235  // The following are possible additional
236  // types that are provided during schedule
252  kTensorized = 8
253 };
254 
261 class IterVarNode : public Object {
262  public:
281  mutable Span span;
282 
284  v->Visit("dom", &dom);
285  v->Visit("var", &var);
286  v->Visit("iter_type", &iter_type);
287  v->Visit("thread_tag", &thread_tag);
288  v->Visit("span", &span);
289  }
290 
291  bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
292  return equal(dom, other->dom) && equal.DefEqual(var, other->var) &&
293  equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag);
294  }
295 
296  void SHashReduce(SHashReducer hash_reduce) const {
297  hash_reduce(dom);
298  hash_reduce.DefHash(var);
299  hash_reduce(iter_type);
300  hash_reduce(thread_tag);
301  }
302 
303  static constexpr const char* _type_key = "tir.IterVar";
304  static constexpr const bool _type_has_method_sequal_reduce = true;
305  static constexpr const bool _type_has_method_shash_reduce = true;
307 };
308 
315 class IterVar : public ObjectRef {
316  public:
317  TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "",
318  Span span = Span());
322  inline operator PrimExpr() const;
323 
326 };
327 
328 // inline implementations
329 inline IterVar::operator PrimExpr() const { return (*this)->var; }
330 
331 inline const char* IterVarType2String(IterVarType t) {
332  switch (t) {
333  case kDataPar:
334  return "DataPar";
335  case kThreadIndex:
336  return "ThreadIndex";
337  case kCommReduce:
338  return "CommReduce";
339  case kOrdered:
340  return "Ordered";
341  case kOpaque:
342  return "Opaque";
343  case kUnrolled:
344  return "Unrolled";
345  case kVectorized:
346  return "Vectorized";
347  case kParallelized:
348  return "Parallelized";
349  case kTensorized:
350  return "Tensorized";
351  }
352  return "Unknown";
353 }
354 } // namespace tir
355 } // namespace tvm
356 
357 /* \brief Allow tir.Var as key in STL tables
358  *
359  * For most TIR expressions, it would be ambiguous whether the
360  * expression should follow reference equality or structural equality.
361  * This is not the case for variables, which do not contain nested
362  * internal structure, and are frequently used as keys in lookup
363  * tables.
364  *
365  * Providing `std::hash` and `std::equal_to` specializations for
366  * `tir::Var` allows it to be used as a key in STL tables. For
367  * `PrimExpr`, the user must specify the type of equality used
368  * (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
369  * `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
370  */
371 template <>
372 struct std::hash<tvm::tir::Var> {
373  std::size_t operator()(const tvm::tir::Var& var) const {
375  }
376 };
377 
378 template <>
379 struct std::equal_to<tvm::tir::Var> {
380  bool operator()(const tvm::tir::Var& var_a, const tvm::tir::Var& var_b) const {
381  return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
382  }
383 };
384 
385 #endif // TVM_TIR_VAR_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:56
Base node of all primitive expressions.
Definition: expr.h:86
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:102
Reference to PrimExprNode.
Definition: expr.h:115
DataType dtype() const
Definition: expr.h:129
Range container
Definition: expr.h:687
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:135
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void FreeVarHashImpl(const runtime::Object *var) const
Implementation for hash for a free var.
Definition: structural_hash.h:203
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
Managed reference to TypeNode.
Definition: type.h:93
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:227
A custom smart pointer for Object.
Definition: object.h:363
Base class of all object reference.
Definition: object.h:520
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:606
base class of all object containers.
Definition: object.h:172
Reference to string objects.
Definition: string.h:97
An iteration variable representing an iteration over a one dimensional interval.
Definition: var.h:261
Var var
The looping variable.
Definition: var.h:269
String thread_tag
Additional tag on the iteration variable; set this if the variable is already bound to a known thread tag.
Definition: var.h:276
void SHashReduce(SHashReducer hash_reduce) const
Definition: var.h:296
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object)
static constexpr const bool _type_has_method_sequal_reduce
Definition: var.h:304
bool SEqualReduce(const IterVarNode *other, SEqualReducer equal) const
Definition: var.h:291
static constexpr const char * _type_key
Definition: var.h:303
Span span
Span that points to the original source code. Reserved debug information.
Definition: var.h:281
Range dom
The domain of iteration, if known; can be None. For the intermediate schedule node,...
Definition: var.h:267
static constexpr const bool _type_has_method_shash_reduce
Definition: var.h:305
void VisitAttrs(AttrVisitor *v)
Definition: var.h:283
IterVarType iter_type
The type of the IterVar.
Definition: var.h:271
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:315
IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag="", Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterVarNode)
TVM_DEFINE_OBJECT_REF_METHODS(IterVar, ObjectRef, IterVarNode)
A variable node represent a tensor index size, whose value must be non-negative.
Definition: var.h:144
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode)
static constexpr const char * _type_key
Definition: var.h:146
a named variable represents a tensor index size
Definition: var.h:151
SizeVar(String name_hint="s", DataType t=DataType::Int(32), Span span=Span())
constructor
const SizeVarNode * get() const
Get pointer to the internal value.
Definition: var.h:178
SizeVar(String name_hint, Type type_annotation, Span span=Span())
Constructor which provides a more detailed type annotation.
SizeVar(ObjectPtr< Object > n)
Definition: var.h:153
const SizeVarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:173
A variable node in the IR.
Definition: var.h:48
static constexpr const char * _type_key
Definition: var.h:83
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: var.h:77
void VisitAttrs(AttrVisitor *v)
Definition: var.h:64
Type type_annotation
type annotation of the variable.
Definition: var.h:62
bool SEqualReduce(const VarNode *other, SEqualReducer equal) const
Definition: var.h:71
static constexpr const uint32_t _type_child_slots
Definition: var.h:84
String name_hint
The hint to the variable name.
Definition: var.h:54
a named variable in TIR
Definition: var.h:89
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:135
Var(ObjectPtr< Object > n)
Definition: var.h:91
Var(String name_hint, Type type_annotation, Span span=Span())
Constructor which provides a more detailed type annotation.
const VarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:130
Var copy_with_name(const String &name) const
Make a new copy of var with same type, but a different name.
Var(String name_hint="v", DataType dtype=DataType::Int(32), Span span=Span())
Constructor.
Var copy_with_suffix(const String &suffix) const
Make a new copy of var with same type, append suffix.
Var copy_with_dtype(DataType dtype) const
Make a new copy of the variable with specified dtype.
Base expr nodes in TVM.
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
IterVarType
Type of iteration variable. Each IterVar has a specific type.
Definition: var.h:192
@ kVectorized
The loop is vectorized.
Definition: var.h:244
@ kThreadIndex
The IterVar itself is a thread-index of a fixed thread launching group. Note that this is already ass...
Definition: var.h:209
@ kUnrolled
The execution is unrolled.
Definition: var.h:240
@ kTensorized
Marks boundary of tensorization intrinsic.
Definition: var.h:252
@ kDataPar
Data parallel iteration. This normally corresponds to axis of Tensor. Allow all IterVar manipulations...
Definition: var.h:201
@ kOrdered
Serial loops with loop carry dependency, the iteration must execute in order. Cannot be re-ordered.
Definition: var.h:224
@ kCommReduce
Commutative reduction. Cannot be directly parallelized.
Definition: var.h:216
@ kParallelized
The loop is parallelized.
Definition: var.h:248
@ kOpaque
IterVar is opaque.
Definition: var.h:234
const char * IterVarType2String(IterVarType t)
Definition: var.h:331
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Definitions and helper macros for IR/AST nodes.
ObjectRef equal functor.
Definition: object.h:666
ObjectRef hash functor.
Definition: object.h:656