tvm
var.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_VAR_H_
25 #define TVM_TIR_VAR_H_
26 
27 #include <tvm/ir/expr.h>
28 #include <tvm/node/node.h>
29 #include <tvm/runtime/data_type.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 namespace tir {
35 
47 class VarNode : public PrimExprNode {
48  public:
62 
64  v->Visit("dtype", &dtype);
65  v->Visit("name", &name_hint);
66  v->Visit("type_annotation", &type_annotation);
67  v->Visit("span", &span);
68  }
69 
70  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
71  if (!equal(dtype, other->dtype)) return false;
72  if (!equal(type_annotation, other->type_annotation)) return false;
73  return equal.FreeVarEqualImpl(this, other);
74  }
75 
76  void SHashReduce(SHashReducer hash_reduce) const {
77  hash_reduce(dtype);
78  hash_reduce(type_annotation);
79  hash_reduce.FreeVarHashImpl(this);
80  }
81 
82  static constexpr const char* _type_key = "tir.Var";
83  static constexpr const uint32_t _type_child_slots = 1;
85 };
86 
88 class Var : public PrimExpr {
89  public:
90  explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
97  TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32),
98  Span span = Span());
105  TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span());
111  TVM_DLL Var copy_with_suffix(const String& suffix) const;
117  TVM_DLL Var copy_with_dtype(DataType dtype) const;
118 
123  const VarNode* operator->() const { return get(); }
128  const VarNode* get() const { return static_cast<const VarNode*>(data_.get()); }
131 };
132 
137 class SizeVarNode : public VarNode {
138  public:
139  static constexpr const char* _type_key = "tir.SizeVar";
141 };
142 
144 class SizeVar : public Var {
145  public:
146  explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
153  TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32),
154  Span span = Span());
159  const SizeVarNode* operator->() const { return get(); }
164  const SizeVarNode* get() const { return static_cast<const SizeVarNode*>(data_.get()); }
167 };
168 
170 
178 enum IterVarType : int {
187  kDataPar = 0,
210  kOrdered = 3,
220  kOpaque = 4,
221  // The following are possible additional
222  // types that are provided during schedule
239 };
240 
247 class IterVarNode : public Object {
248  public:
267  mutable Span span;
268 
270  v->Visit("dom", &dom);
271  v->Visit("var", &var);
272  v->Visit("iter_type", &iter_type);
273  v->Visit("thread_tag", &thread_tag);
274  v->Visit("span", &span);
275  }
276 
277  bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
278  return equal(dom, other->dom) && equal.DefEqual(var, other->var) &&
279  equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag);
280  }
281 
282  void SHashReduce(SHashReducer hash_reduce) const {
283  hash_reduce(dom);
284  hash_reduce.DefHash(var);
285  hash_reduce(iter_type);
286  hash_reduce(thread_tag);
287  }
288 
289  static constexpr const char* _type_key = "tir.IterVar";
290  static constexpr const bool _type_has_method_sequal_reduce = true;
291  static constexpr const bool _type_has_method_shash_reduce = true;
293 };
294 
301 class IterVar : public ObjectRef {
302  public:
303  TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "",
304  Span span = Span());
308  inline operator PrimExpr() const;
309 
312 };
313 
314 // inline implementations
315 inline IterVar::operator PrimExpr() const { return (*this)->var; }
316 
317 inline const char* IterVarType2String(IterVarType t) {
318  switch (t) {
319  case kDataPar:
320  return "DataPar";
321  case kThreadIndex:
322  return "ThreadIndex";
323  case kCommReduce:
324  return "CommReduce";
325  case kOrdered:
326  return "Ordered";
327  case kOpaque:
328  return "Opaque";
329  case kUnrolled:
330  return "Unrolled";
331  case kVectorized:
332  return "Vectorized";
333  case kParallelized:
334  return "Parallelized";
335  case kTensorized:
336  return "Tensorized";
337  }
338  return "Unknown";
339 }
340 } // namespace tir
341 } // namespace tvm
342 #endif // TVM_TIR_VAR_H_
tvm::Span Span
Definition: base.h:65
void FreeVarHashImpl(const runtime::Object *var) const
Implementation for hash for a free var.
Definition: structural_hash.h:185
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
A custom smart pointer for Object.
Definition: object.h:358
Definitions and helper macros for IR/AST nodes.
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode)
static constexpr const uint32_t _type_child_slots
Definition: var.h:83
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation for equality rule of var type objects (e.g. TypeVar, tir::Var).
Definition: structural_equal.h:277
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
Base expr nodes in TVM.
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:59
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:102
void SHashReduce(SHashReducer hash_reduce) const
Definition: var.h:282
Var(ObjectPtr< Object > n)
Definition: var.h:90
String thread_tag
additional tag on the iteration variable, set this if this is already bound to a known thread tag...
Definition: var.h:262
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
a named variable in TIR
Definition: var.h:88
String name_hint
The hint to the variable name.
Definition: var.h:53
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:301
A variable node in the IR.
Definition: var.h:47
Type type_annotation
type annotation of the variable.
Definition: var.h:61
base class of all object containers.
Definition: object.h:167
A variable node that represents a tensor index size, whose value must be non-negative.
Definition: var.h:137
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
bool SEqualReduce(const VarNode *other, SEqualReducer equal) const
Definition: var.h:70
An iteration variable representing an iteration over a one dimensional interval.
Definition: var.h:247
IterVarType
Type of iteration variable. Each IterVar has a specific type.
Definition: var.h:178
SizeVar(ObjectPtr< Object > n)
Definition: var.h:146
Range container.
Definition: expr.h:713
Definition: span.h:115
void VisitAttrs(AttrVisitor *v)
Definition: var.h:269
Span span
Span that points to the original source code. Reserved debug information.
Definition: var.h:267
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:55
Runtime primitive data type.
Definition: data_type.h:41
void VisitAttrs(AttrVisitor *v)
Definition: var.h:63
const VarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:123
const char * IterVarType2String(IterVarType t)
Definition: var.h:317
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:58
Commutative reduction. Cannot be directly parallelized.
Definition: var.h:202
Var var
The looping variable.
Definition: var.h:255
IterVarType iter_type
The type of the IterVar.
Definition: var.h:257
Reference to string objects.
Definition: string.h:97
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
The loop is vectorized.
Definition: var.h:230
The execution is unrolled.
Definition: var.h:226
bool SEqualReduce(const IterVarNode *other, SEqualReducer equal) const
Definition: var.h:277
Data parallel iteration. This normally corresponds to axis of Tensor. Allow all IterVar manipulations...
Definition: var.h:187
void SHashReduce(SHashReducer hash_reduce) const
Definition: var.h:76
Serial loops with loop carry dependency, the iteration must execute in order. Cannot be re-ordered...
Definition: var.h:210
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Marks boundary of tensorization intrinsic.
Definition: var.h:238
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
a named variable that represents a tensor index size
Definition: var.h:144
static constexpr const char * _type_key
Definition: var.h:82
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
The IterVar itself is a thread-index of a fixed thread launching group. Note that this is already ass...
Definition: var.h:195
const SizeVarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:159
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:101
Managed reference to TypeNode.
Definition: type.h:93
Reference to PrimExprNode.
Definition: expr.h:112
Range dom
the domain of iteration, if known; can be None for the intermediate schedule node, before scheduling.
Definition: var.h:253
IterVar is opaque.
Definition: var.h:220
The loop is parallelized.
Definition: var.h:234
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:164
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:179
Base node of all primitive expressions.
Definition: expr.h:85