tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
var.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_VAR_H_
25 #define TVM_TIR_VAR_H_
26 
27 #include <tvm/ir/expr.h>
28 #include <tvm/node/node.h>
29 #include <tvm/runtime/data_type.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 namespace tir {
35 
47 class VarNode : public PrimExprNode {
48  public:
62 
64  v->Visit("dtype", &dtype);
65  v->Visit("name", &name_hint);
66  v->Visit("type_annotation", &type_annotation);
67  v->Visit("span", &span);
68  }
69 
70  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
71  if (!equal(dtype, other->dtype)) return false;
72  if (!equal(type_annotation, other->type_annotation)) return false;
73  return equal.FreeVarEqualImpl(this, other);
74  }
75 
76  void SHashReduce(SHashReducer hash_reduce) const {
77  hash_reduce(dtype);
78  hash_reduce(type_annotation);
79  hash_reduce.FreeVarHashImpl(this);
80  }
81 
82  static constexpr const char* _type_key = "tir.Var";
83  static constexpr const uint32_t _type_child_slots = 1;
85 };
86 
88 class Var : public PrimExpr {
89  public:
90  explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
97  TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32),
98  Span span = Span());
105  TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span());
111  TVM_DLL Var copy_with_name(const String& name) const;
117  TVM_DLL Var copy_with_suffix(const String& suffix) const;
124 
129  const VarNode* operator->() const { return get(); }
134  const VarNode* get() const { return static_cast<const VarNode*>(data_.get()); }
137 };
138 
143 class SizeVarNode : public VarNode {
144  public:
145  static constexpr const char* _type_key = "tir.SizeVar";
147 };
148 
150 class SizeVar : public Var {
151  public:
152  explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
159  TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32),
160  Span span = Span());
167  TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span());
172  const SizeVarNode* operator->() const { return get(); }
177  const SizeVarNode* get() const { return static_cast<const SizeVarNode*>(data_.get()); }
180 };
181 
183 
/*!
 * \brief Type of iteration variable.
 *  Each IterVar has a specific type.
 */
enum IterVarType : int {
  /*!
   * \brief Data parallel iteration.
   *  This normally corresponds to an axis of a Tensor.
   */
  kDataPar = 0,
  /*!
   * \brief The IterVar itself is a thread-index
   *  of a fixed thread launching group.
   */
  kThreadIndex = 1,
  /*!
   * \brief Commutative reduction.
   *  Cannot be directly parallelized.
   */
  kCommReduce = 2,
  /*!
   * \brief Serial loops with a loop-carried dependency;
   *  the iteration must execute in order and cannot be reordered.
   */
  kOrdered = 3,
  /*!
   * \brief IterVar is opaque.
   */
  kOpaque = 4,
  // The following are possible additional
  // types that are provided during schedule
  /*!
   * \brief The execution is unrolled.
   */
  kUnrolled = 5,
  /*!
   * \brief The loop is vectorized.
   */
  kVectorized = 6,
  /*!
   * \brief The loop is parallelized.
   */
  kParallelized = 7,
  /*!
   * \brief Marks boundary of tensorization intrinsic.
   */
  kTensorized = 8
};
253 
260 class IterVarNode : public Object {
261  public:
280  mutable Span span;
281 
283  v->Visit("dom", &dom);
284  v->Visit("var", &var);
285  v->Visit("iter_type", &iter_type);
286  v->Visit("thread_tag", &thread_tag);
287  v->Visit("span", &span);
288  }
289 
290  bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
291  return equal(dom, other->dom) && equal.DefEqual(var, other->var) &&
292  equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag);
293  }
294 
295  void SHashReduce(SHashReducer hash_reduce) const {
296  hash_reduce(dom);
297  hash_reduce.DefHash(var);
298  hash_reduce(iter_type);
299  hash_reduce(thread_tag);
300  }
301 
302  static constexpr const char* _type_key = "tir.IterVar";
303  static constexpr const bool _type_has_method_sequal_reduce = true;
304  static constexpr const bool _type_has_method_shash_reduce = true;
306 };
307 
314 class IterVar : public ObjectRef {
315  public:
316  TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "",
317  Span span = Span());
321  inline operator PrimExpr() const;
322 
325 };
326 
327 // inline implementations
328 inline IterVar::operator PrimExpr() const { return (*this)->var; }
329 
330 inline const char* IterVarType2String(IterVarType t) {
331  switch (t) {
332  case kDataPar:
333  return "DataPar";
334  case kThreadIndex:
335  return "ThreadIndex";
336  case kCommReduce:
337  return "CommReduce";
338  case kOrdered:
339  return "Ordered";
340  case kOpaque:
341  return "Opaque";
342  case kUnrolled:
343  return "Unrolled";
344  case kVectorized:
345  return "Vectorized";
346  case kParallelized:
347  return "Parallelized";
348  case kTensorized:
349  return "Tensorized";
350  }
351  return "Unknown";
352 }
353 } // namespace tir
354 } // namespace tvm
355 #endif // TVM_TIR_VAR_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:55
Base node of all primitive expressions.
Definition: expr.h:85
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:101
Reference to PrimExprNode.
Definition: expr.h:114
DataType dtype() const
Definition: expr.h:128
Range container
Definition: expr.h:715
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
void FreeVarHashImpl(const runtime::Object *var) const
Implementation for hash for a free var.
Definition: structural_hash.h:192
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:187
Definition: source_map.h:120
Managed reference to TypeNode.
Definition: type.h:93
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:42
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:176
A custom smart pointer for Object.
Definition: object.h:360
Base class of all object reference.
Definition: object.h:517
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:603
base class of all object containers.
Definition: object.h:169
Reference to string objects.
Definition: string.h:98
An iteration variable representing an iteration over a one dimensional interval.
Definition: var.h:260
Var var
The looping variable.
Definition: var.h:268
String thread_tag
additional tag on the iteration variable, set this if this is bound already to a known thread tag.
Definition: var.h:275
void SHashReduce(SHashReducer hash_reduce) const
Definition: var.h:295
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object)
static constexpr const bool _type_has_method_sequal_reduce
Definition: var.h:303
bool SEqualReduce(const IterVarNode *other, SEqualReducer equal) const
Definition: var.h:290
static constexpr const char * _type_key
Definition: var.h:302
Span span
Span that points to the original source code. Reserved debug information.
Definition: var.h:280
Range dom
The domain of iteration, if known; can be None. For the intermediate schedule node,...
Definition: var.h:266
static constexpr const bool _type_has_method_shash_reduce
Definition: var.h:304
void VisitAttrs(AttrVisitor *v)
Definition: var.h:282
IterVarType iter_type
The type of the IterVar.
Definition: var.h:270
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:314
IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag="", Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterVarNode)
TVM_DEFINE_OBJECT_REF_METHODS(IterVar, ObjectRef, IterVarNode)
A variable node representing a tensor index size, whose value must be non-negative.
Definition: var.h:143
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode)
static constexpr const char * _type_key
Definition: var.h:145
A named variable that represents a tensor index size.
Definition: var.h:150
SizeVar(String name_hint="s", DataType t=DataType::Int(32), Span span=Span())
constructor
const SizeVarNode * get() const
Get pointer to the internal value.
Definition: var.h:177
SizeVar(String name_hint, Type type_annotation, Span span=Span())
Constructor which provides a more detailed type annotation.
SizeVar(ObjectPtr< Object > n)
Definition: var.h:152
const SizeVarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:172
A variable node in the IR.
Definition: var.h:47
static constexpr const char * _type_key
Definition: var.h:82
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: var.h:76
void VisitAttrs(AttrVisitor *v)
Definition: var.h:63
Type type_annotation
type annotation of the variable.
Definition: var.h:61
bool SEqualReduce(const VarNode *other, SEqualReducer equal) const
Definition: var.h:70
static constexpr const uint32_t _type_child_slots
Definition: var.h:83
String name_hint
The hint to the variable name.
Definition: var.h:53
a named variable in TIR
Definition: var.h:88
const VarNode * get() const
Get pointer to the internal value.
Definition: var.h:134
Var(ObjectPtr< Object > n)
Definition: var.h:90
Var(String name_hint, Type type_annotation, Span span=Span())
Constructor which provides a more detailed type annotation.
const VarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:129
Var copy_with_name(const String &name) const
Make a new copy of var with same type, but a different name.
Var(String name_hint="v", DataType dtype=DataType::Int(32), Span span=Span())
Constructor.
Var copy_with_suffix(const String &suffix) const
Make a new copy of var with same type, append suffix.
Var copy_with_dtype(DataType dtype) const
Make a new copy of the variable with specified dtype.
Base expr nodes in TVM.
tvm::Span Span
Definition: base.h:65
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
IterVarType
Type of iteration variable. Each IterVar has a specific type.
Definition: var.h:191
@ kVectorized
The loop is vectorized.
Definition: var.h:243
@ kThreadIndex
The IterVar itself is a thread-index of a fixed thread launching group. Note that this is already ass...
Definition: var.h:208
@ kUnrolled
The execution is unrolled.
Definition: var.h:239
@ kTensorized
Marks boundary of tensorization intrinsic.
Definition: var.h:251
@ kDataPar
Data parallel iteration. This normally corresponds to an axis of a Tensor. Allows all IterVar manipulations...
Definition: var.h:200
@ kOrdered
Serial loops with loop carry dependency, the iteration must execute in order. Cannot be re-ordered.
Definition: var.h:223
@ kCommReduce
Commutative reduction. Cannot be directly parallelized.
Definition: var.h:215
@ kParallelized
The loop is parallelized.
Definition: var.h:247
@ kOpaque
IterVar is opaque.
Definition: var.h:233
const char * IterVarType2String(IterVarType t)
Definition: var.h:330
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Definitions and helper macros for IR/AST nodes.