tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
var.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TIR_VAR_H_
25 #define TVM_TIR_VAR_H_
26 
27 #include <tvm/ir/expr.h>
28 #include <tvm/node/node.h>
29 #include <tvm/runtime/data_type.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 namespace tir {
35 
47 class VarNode : public PrimExprNode {
48  public:
62 
64  v->Visit("dtype", &dtype);
65  v->Visit("name", &name_hint);
66  v->Visit("type_annotation", &type_annotation);
67  v->Visit("span", &span);
68  }
69 
70  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
71  if (!equal(dtype, other->dtype)) return false;
72  if (!equal(type_annotation, other->type_annotation)) return false;
73  return equal.FreeVarEqualImpl(this, other);
74  }
75 
// Structural hash of a VarNode. It mixes in dtype and then type_annotation
// (the same fields, in the same order, that SEqualReduce compares), and then
// delegates to the reducer's free-variable hashing rule (FreeVarHashImpl).
// That call sits where SEqualReduce calls FreeVarEqualImpl.
// name_hint and span are not hashed.
76  void SHashReduce(SHashReducer hash_reduce) const {
77  hash_reduce(dtype);
78  hash_reduce(type_annotation);
79  hash_reduce.FreeVarHashImpl(this);
80  }
81 
82  static constexpr const char* _type_key = "tir.Var";
83  static constexpr const uint32_t _type_child_slots = 1;
85 };
86 
// Managed reference to VarNode: a named variable in TIR.
88 class Var : public PrimExpr {
89  public:
// Wrap an existing object pointer; forwards to PrimExpr(n). No type check
// appears in this constructor.
90  explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
// Construct a variable from a name hint and a primitive dtype.
// Defaults: name hint "v", dtype Int(32), empty Span.
97  TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32),
98  Span span = Span());
// Construct a variable from a name hint and an explicit type annotation.
105  TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span());
// NOTE(review): presumably returns a new Var whose name hint has `suffix`
// appended. The definition is not in this header; confirm in the .cc file.
111  TVM_DLL Var copy_with_suffix(const String& suffix) const;
// NOTE(review): presumably returns a new Var with the given dtype. The
// definition is not shown here; confirm.
117  TVM_DLL Var copy_with_dtype(DataType dtype) const;
118 
// Access the underlying node. get() uses static_cast with no runtime type
// check, so the held object is assumed to be a VarNode.
123  const VarNode* operator->() const { return get(); }
128  const VarNode* get() const { return static_cast<const VarNode*>(data_.get()); }
131 };
132 
// A variable node representing a tensor index size, whose value must be
// non-negative. It derives from VarNode and registers under type key
// "tir.SizeVar".
137 class SizeVarNode : public VarNode {
138  public:
139  static constexpr const char* _type_key = "tir.SizeVar";
141 };
142 
// Managed reference to SizeVarNode: a named variable representing a tensor
// index size.
144 class SizeVar : public Var {
145  public:
// Wrap an existing object pointer; forwards to Var(n).
146  explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
// Construct from a name hint and a primitive dtype.
// Defaults: name hint "s", dtype Int(32), empty Span.
153  TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32),
154  Span span = Span());
// Construct from a name hint and an explicit type annotation.
161  TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span());
// Access the underlying node. get() uses static_cast with no runtime type
// check, so the held object is assumed to be a SizeVarNode.
166  const SizeVarNode* operator->() const { return get(); }
171  const SizeVarNode* get() const { return static_cast<const SizeVarNode*>(data_.get()); }
174 };
175 
177 
// Type of an iteration variable. Each IterVar has a specific type.
185 enum IterVarType : int {
// Data-parallel iteration; normally corresponds to an axis of a Tensor.
194  kDataPar = 0,
// Serial loop with a loop-carried dependency: iterations must execute in
// order and cannot be reordered.
217  kOrdered = 3,
// The IterVar is opaque.
227  kOpaque = 4,
228  // The following are additional iteration types that may be
229  // provided during scheduling.
246 };
247 
254 class IterVarNode : public Object {
255  public:
274  mutable Span span;
275 
277  v->Visit("dom", &dom);
278  v->Visit("var", &var);
279  v->Visit("iter_type", &iter_type);
280  v->Visit("thread_tag", &thread_tag);
281  v->Visit("span", &span);
282  }
283 
284  bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
285  return equal(dom, other->dom) && equal.DefEqual(var, other->var) &&
286  equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag);
287  }
288 
// Structural hash of an IterVarNode. It mixes in dom, var, iter_type and
// thread_tag in that order, matching the field order of SEqualReduce.
// `var` goes through DefHash, at the position where SEqualReduce uses
// DefEqual, rather than being hashed as a plain field.
// span is not hashed.
289  void SHashReduce(SHashReducer hash_reduce) const {
290  hash_reduce(dom);
291  hash_reduce.DefHash(var);
292  hash_reduce(iter_type);
293  hash_reduce(thread_tag);
294  }
295 
296  static constexpr const char* _type_key = "tir.IterVar";
297  static constexpr const bool _type_has_method_sequal_reduce = true;
298  static constexpr const bool _type_has_method_shash_reduce = true;
300 };
301 
// Iteration variable: represents an iteration over an integer interval.
308 class IterVar : public ObjectRef {
309  public:
// Construct from an iteration domain, the loop variable and its iteration
// type. thread_tag defaults to "" and span to Span().
310  TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "",
311  Span span = Span());
// Conversion to PrimExpr yielding the loop variable `var`; it is defined
// inline later in this header.
// NOTE(review): the operator is not `explicit`, so an IterVar converts
// implicitly wherever a PrimExpr is expected.
315  inline operator PrimExpr() const;
316 
319 };
320 
321 // inline implementations
322 inline IterVar::operator PrimExpr() const { return (*this)->var; }
323 
324 inline const char* IterVarType2String(IterVarType t) {
325  switch (t) {
326  case kDataPar:
327  return "DataPar";
328  case kThreadIndex:
329  return "ThreadIndex";
330  case kCommReduce:
331  return "CommReduce";
332  case kOrdered:
333  return "Ordered";
334  case kOpaque:
335  return "Opaque";
336  case kUnrolled:
337  return "Unrolled";
338  case kVectorized:
339  return "Vectorized";
340  case kParallelized:
341  return "Parallelized";
342  case kTensorized:
343  return "Tensorized";
344  }
345  return "Unknown";
346 }
347 } // namespace tir
348 } // namespace tvm
349 #endif // TVM_TIR_VAR_H_
tvm::Span Span
Definition: base.h:65
void FreeVarHashImpl(const runtime::Object *var) const
Implementation for hash for a free var.
Definition: structural_hash.h:193
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
A custom smart pointer for Object.
Definition: object.h:358
Definitions and helper macros for IR/AST nodes.
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode)
static constexpr const uint32_t _type_child_slots
Definition: var.h:83
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation for equality rule of var type objects (e.g. TypeVar, tir::Var).
Definition: structural_equal.h:313
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
Base expr nodes in TVM.
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:59
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
void SHashReduce(SHashReducer hash_reduce) const
Definition: var.h:289
Var(ObjectPtr< Object > n)
Definition: var.h:90
String thread_tag
Additional tag on the iteration variable; set this if the variable is already bound to a known thread tag...
Definition: var.h:269
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
a named variable in TIR
Definition: var.h:88
String name_hint
The hint to the variable name.
Definition: var.h:53
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:308
A variable node in the IR.
Definition: var.h:47
Type type_annotation
type annotation of the variable.
Definition: var.h:61
base class of all object containers.
Definition: object.h:167
A variable node representing a tensor index size, whose value must be non-negative.
Definition: var.h:137
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
bool SEqualReduce(const VarNode *other, SEqualReducer equal) const
Definition: var.h:70
An iteration variable representing an iteration over a one dimensional interval.
Definition: var.h:254
IterVarType
Type of iteration variable. Each IterVar has a specific type.
Definition: var.h:185
SizeVar(ObjectPtr< Object > n)
Definition: var.h:146
Range container.
Definition: expr.h:715
Definition: source_map.h:120
void VisitAttrs(AttrVisitor *v)
Definition: var.h:276
Span span
Span that points to the original source code. Reserved debug information.
Definition: var.h:274
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:55
Runtime primitive data type.
Definition: data_type.h:41
void VisitAttrs(AttrVisitor *v)
Definition: var.h:63
const VarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:123
const char * IterVarType2String(IterVarType t)
Definition: var.h:324
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:58
Commutative reduction. Cannot be directly parallelized.
Definition: var.h:209
Var var
The looping variable.
Definition: var.h:262
IterVarType iter_type
The type of the IterVar.
Definition: var.h:264
Reference to string objects.
Definition: string.h:98
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
The loop is vectorized.
Definition: var.h:237
The execution is unrolled.
Definition: var.h:233
bool SEqualReduce(const IterVarNode *other, SEqualReducer equal) const
Definition: var.h:284
Data parallel iteration. This normally corresponds to axis of Tensor. Allow all IterVar manipulations...
Definition: var.h:194
void SHashReduce(SHashReducer hash_reduce) const
Definition: var.h:76
Serial loops with loop carry dependency, the iteration must execute in order. Cannot be re-ordered...
Definition: var.h:217
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Marks boundary of tensorization intrinsic.
Definition: var.h:245
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
A named variable representing a tensor index size.
Definition: var.h:144
static constexpr const char * _type_key
Definition: var.h:82
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
The IterVar itself is a thread-index of a fixed thread launching group. Note that this is already ass...
Definition: var.h:202
const SizeVarNode * operator->() const
Get pointer to the internal value.
Definition: var.h:166
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:101
Managed reference to TypeNode.
Definition: type.h:93
Reference to PrimExprNode.
Definition: expr.h:114
Range dom
The domain of iteration, if known; it can be None for an intermediate schedule node before scheduling.
Definition: var.h:260
IterVar is opaque.
Definition: var.h:227
The loop is parallelized.
Definition: var.h:241
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:164
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:187
Base node of all primitive expressions.
Definition: expr.h:85