tvm
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_IR_EXPR_H_
25 #define TVM_IR_EXPR_H_
26 
27 #include <tvm/ir/span.h>
28 #include <tvm/ir/type.h>
29 #include <tvm/node/node.h>
31 #include <tvm/runtime/object.h>
32 
33 #include <algorithm>
34 #include <limits>
35 #include <string>
36 #include <type_traits>
37 
38 namespace tvm {
39 
41 
/*! \brief Base type of all the expressions; derives from Object. */
// NOTE(review): this Doxygen listing skips original lines 42-45, 48-51 and 58
// (gaps in the numbering); presumably doc comments and a type-info macro such
// as TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object) — confirm against the real header.
46 class BaseExprNode : public Object {
47  public:
/*!
 * \brief Span that points to the original source code; reserved debug information.
 *  Declared mutable, so it can be assigned through a const node.
 */
52  mutable Span span;
53 
/*! \brief Type key string for this node type. */
54  static constexpr const char* _type_key = "BaseExpr";
// Flags declaring that this node hierarchy supports structural equality and hashing;
// subclasses (e.g. GlobalVarNode, IntImmNode) provide SEqualReduce/SHashReduce.
55  static constexpr const bool _type_has_method_sequal_reduce = true;
56  static constexpr const bool _type_has_method_shash_reduce = true;
// Number of child type slots reserved for subclasses (presumably a type-index
// allocation hint for the object system — confirm).
57  static constexpr const uint32_t _type_child_slots = 62;
59 };
60 
/*! \brief Managed reference to BaseExprNode. */
// NOTE(review): original line 67 is omitted from this listing (presumably the
// object-ref methods macro) — confirm against the real header.
65 class BaseExpr : public ObjectRef {
66  public:
68 };
69 
/*! \brief Base node of all primitive expressions. */
// NOTE(review): original lines 84-98 are omitted from this listing; per the index
// at the end of this page they include the `DataType dtype` member (expr.h:98),
// which PrimExpr::dtype() reads. Line 102 is also omitted — confirm.
82 class PrimExprNode : public BaseExprNode {
83  public:
99 
/*! \brief Type key string for this node type. */
100  static constexpr const char* _type_key = "PrimExpr";
// Child type slots reserved for PrimExpr subclasses (presumably a type-index hint — confirm).
101  static constexpr const uint32_t _type_child_slots = 38;
103 };
104 
/*! \brief Reference to PrimExprNode. */
// NOTE(review): several original lines (e.g. 111-114, 116-119, 122, 125) are
// omitted from this listing; presumably docs and the object-ref macro — confirm.
109 class PrimExpr : public BaseExpr {
110  public:
/*! \brief Implicit construction from an int32 value (definition not shown here). */
115  TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*! \brief Implicit construction from a float value (definition not shown here). */
120  TVM_DLL PrimExpr(float value); // NOLINT(*)
121 
/*!
 * \brief Data type of the referenced expression, read from the PrimExprNode.
 * \note Dereferences get() without a null check, so the reference must be defined.
 */
123  DataType dtype() const { return static_cast<const PrimExprNode*>(get())->dtype; }
124 
126 
127  private:
128  // Internal function for conversion.
// PackedFuncValueConverter<PrimExpr> (later in this file) calls FromObject_ to turn
// a generic ObjectRef held by a packed-function value into a PrimExpr.
129  friend struct runtime::PackedFuncValueConverter<PrimExpr>;
130  TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
131 };
132 
/*! \brief Base node of all non-primitive expressions. */
// NOTE(review): original lines 144-149, 151-153, 155-164 and 170 are omitted
// from this listing (presumably doc comments and a type-info macro) — confirm.
142 class RelayExprNode : public BaseExprNode {
143  public:
/*!
 * \brief Type populated by the type checker; starts undefined (Type(nullptr)).
 *  Declared mutable so it can be filled in on a const node.
 */
150  mutable Type checked_type_ = Type(nullptr);
/*! \return checked_type_; ICHECK-fails if the type checker has not populated it yet. */
154  inline const Type& checked_type() const;
/*!
 * \brief Return checked_type_ viewed as a TTypeNode.
 * \tparam TTypeNode A subclass of TypeNode (enforced by static_assert).
 * \return Non-null pointer; ICHECK-fails if checked_type_ is undefined or is not a TTypeNode.
 */
165  template <typename TTypeNode>
166  inline const TTypeNode* type_as() const;
167 
/*! \brief Type key string for this node type. */
168  static constexpr const char* _type_key = "RelayExpr";
// Child type slots reserved for RelayExpr subclasses (presumably a type-index hint — confirm).
169  static constexpr const uint32_t _type_child_slots = 22;
171 };
172 
/*! \brief Managed reference to RelayExprNode. */
// NOTE(review): original line 179 is omitted from this listing (presumably the
// object-ref methods macro) — confirm against the real header.
177 class RelayExpr : public BaseExpr {
178  public:
180 };
181 
// Forward declaration of the managed reference type for GlobalVarNode.
182 class GlobalVar;
/*! \brief Global variable that lives in the top-level module. */
// NOTE(review): this listing omits original lines 193-194 (per the index at the
// end of this page, the `String name_hint` member, expr.h:194), line 196 (the
// `void VisitAttrs(AttrVisitor* v) {` header, expr.h:196) and line 213, so the
// Visit calls below appear without their enclosing function — confirm.
191 class GlobalVarNode : public RelayExprNode {
192  public:
195 
// Reflection: expose name_hint, span and the checked type to the attribute visitor.
197  v->Visit("name_hint", &name_hint);
198  v->Visit("span", &span);
199  v->Visit("_checked_type_", &checked_type_);
200  }
201 
// Structural equality: names must match, then the two vars are compared as free
// variables via FreeVarEqualImpl. span and checked_type_ are not compared.
202  bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
203  // name matters for global var.
204  return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
205  }
206 
// Structural hash: hash the name, then hash as a free variable (mirrors SEqualReduce).
207  void SHashReduce(SHashReducer hash_reduce) const {
208  hash_reduce(name_hint);
209  hash_reduce.FreeVarHashImpl(this);
210  }
211 
/*! \brief Type key string for this node type. */
212  static constexpr const char* _type_key = "GlobalVar";
214 };
215 
/*! \brief Managed reference to GlobalVarNode. */
// NOTE(review): original line 224 is omitted from this listing (presumably the
// object-ref methods macro) — confirm.
220 class GlobalVar : public RelayExpr {
221  public:
/*!
 * \brief Construct a global variable (definition not shown here).
 * \param name_hint Name of the variable; it only acts as a hint.
 */
222  TVM_DLL explicit GlobalVar(String name_hint);
223 
225 };
226 
227 // PrimExprs that are useful as runtime containers.
228 //
/*! \brief Constant integer literals in the program. */
// NOTE(review): this listing omits original lines 235, 238 (the
// `void VisitAttrs(AttrVisitor* v) {` header, expr.h:238) and 254, so the Visit
// calls below appear without their enclosing function — confirm.
233 class IntImmNode : public PrimExprNode {
234  public:
/*! \brief The constant integer value, stored as int64_t. */
236  int64_t value;
237 
// Reflection: expose dtype, value and span to the attribute visitor.
239  v->Visit("dtype", &dtype);
240  v->Visit("value", &value);
241  v->Visit("span", &span);
242  }
243 
// Structural equality compares dtype and value only; span is not compared.
244  bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
245  return equal(dtype, other->dtype) && equal(value, other->value);
246  }
247 
// Structural hash mixes dtype and value, consistent with SEqualReduce.
248  void SHashReduce(SHashReducer hash_reduce) const {
249  hash_reduce(dtype);
250  hash_reduce(value);
251  }
252 
/*! \brief Type key string for this node type. */
253  static constexpr const char* _type_key = "IntImm";
255 };
256 
/*! \brief Managed reference class to IntImmNode. */
// NOTE(review): original lines 264-269 and 272 are omitted from this listing
// (presumably docs and the object-ref macro) — confirm.
262 class IntImm : public PrimExpr {
263  public:
/*!
 * \brief Construct an integer immediate (definition not shown here).
 * \param dtype Data type of the constant.
 * \param value The constant value.
 * \param span Source location; empty by default.
 */
270  TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());
271 
273 };
274 
/*! \brief Constant floating point literals in the program. */
// NOTE(review): this listing omits original lines 281, 284 (the
// `void VisitAttrs(AttrVisitor* v) {` header, expr.h:284) and 300, so the Visit
// calls below appear without their enclosing function — confirm.
279 class FloatImmNode : public PrimExprNode {
280  public:
/*! \brief The constant value content, stored as double. */
282  double value;
283 
// Reflection: expose dtype, value and span to the attribute visitor.
285  v->Visit("dtype", &dtype);
286  v->Visit("value", &value);
287  v->Visit("span", &span);
288  }
289 
// Structural equality compares dtype and value only; span is not compared.
290  bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
291  return equal(dtype, other->dtype) && equal(value, other->value);
292  }
293 
// Structural hash mixes dtype and value, consistent with SEqualReduce.
294  void SHashReduce(SHashReducer hash_reduce) const {
295  hash_reduce(dtype);
296  hash_reduce(value);
297  }
298 
/*! \brief Type key string for this node type. */
299  static constexpr const char* _type_key = "FloatImm";
301 };
302 
/*! \brief Managed reference class to FloatImmNode. */
// NOTE(review): original lines 310-315 and 318 are omitted from this listing
// (presumably docs and the object-ref macro) — confirm.
308 class FloatImm : public PrimExpr {
309  public:
/*!
 * \brief Construct a floating point immediate (definition not shown here).
 * \param dtype Data type of the constant.
 * \param value The constant value.
 * \param span Source location; empty by default.
 */
316  TVM_DLL FloatImm(DataType dtype, double value, Span span = Span());
317 
319 };
320 
/*! \brief Boolean constant: an IntImm whose dtype is DataType::Bool(). */
// NOTE(review): original line 333 is omitted from this listing; presumably the
// object-ref methods macro that supplies the ObjectPtr constructor used by the
// Bool value converter further down — confirm.
327 class Bool : public IntImm {
328  public:
/*! \brief Construct from a C++ bool, stored as an IntImm of DataType::Bool(). */
329  explicit Bool(bool value, Span span = Span()) : IntImm(DataType::Bool(), value, span) {}
/*! \brief Logical negation: true iff the stored value is 0. */
330  Bool operator!() const { return Bool((*this)->value == 0); }
/*!
 * \brief Implicit conversion: true iff the stored value is nonzero.
 * \note No null check here (unlike Integer::operator int64_t); the Bool must be defined.
 */
331  operator bool() const { return (*this)->value != 0; }
332 
334 };
335 
336 // Overload operators to make sure we have the most fine grained types.
337 inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); }
338 inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); }
339 inline Bool operator||(const Bool& a, const Bool& b) {
340  return Bool(a.operator bool() || b.operator bool());
341 }
342 inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); }
343 inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); }
344 inline Bool operator&&(const Bool& a, const Bool& b) {
345  return Bool(a.operator bool() && b.operator bool());
346 }
347 
356 class Integer : public IntImm {
357  public:
358  Integer() {}
362  explicit Integer(ObjectPtr<Object> node) : IntImm(node) {}
366  Integer(int value, Span span = Span()) : IntImm(DataType::Int(32), value, span) {} // NOLINT(*)
371  Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
377  template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
378  explicit Integer(Enum value) : Integer(static_cast<int>(value)) {
379  static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value,
380  "declare enum to be enum int to use visitor");
381  }
386  Integer& operator=(const IntImm& other) {
387  data_ = ObjectRef::GetDataPtr<Object>(other);
388  return *this;
389  }
393  operator int64_t() const {
394  ICHECK(data_ != nullptr) << " Trying to reference a null Integer";
395  return (*this)->value;
396  }
397  // comparators
398  Bool operator==(int other) const {
399  if (data_ == nullptr) return Bool(false);
400  return Bool((*this)->value == other);
401  }
402  Bool operator!=(int other) const { return !(*this == other); }
403  template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
404  Bool operator==(Enum other) const {
405  return *this == static_cast<int>(other);
406  }
407  template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
408  Bool operator!=(Enum other) const {
409  return *this != static_cast<int>(other);
410  }
411 };
412 
/*! \brief Range over one dimension. */
// NOTE(review): this listing omits original lines 416-420 and 422-423. Per the
// index at the end of this page, these include the `PrimExpr min` (417) and
// `PrimExpr extent` (419) members and a `RangeNode()` constructor (423). Line 427
// (the `void VisitAttrs(AttrVisitor* v) {` header) and line 445 are also
// omitted — confirm.
414 class RangeNode : public Object {
415  public:
/*! \brief The location of this range in the source; mutable, so settable on a const node. */
421  mutable Span span;
/*! \brief Construct from a minimum, an extent and an optional span. */
424  RangeNode(PrimExpr min, PrimExpr extent, Span span = Span())
425  : min(min), extent(extent), span(span) {}
426 
// Reflection: expose min, extent and span to the attribute visitor.
428  v->Visit("min", &min);
429  v->Visit("extent", &extent);
430  v->Visit("span", &span);
431  }
432 
// Structural equality compares min and extent only; span is not compared.
433  bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
434  return equal(min, other->min) && equal(extent, other->extent);
435  }
436 
// Structural hash mixes min and extent, consistent with SEqualReduce.
437  void SHashReduce(SHashReducer hash_reduce) const {
438  hash_reduce(min);
439  hash_reduce(extent);
440  }
441 
/*! \brief Type key string for this node type. */
442  static constexpr const char* _type_key = "Range";
// Declares that RangeNode provides SEqualReduce/SHashReduce (defined above).
443  static constexpr const bool _type_has_method_sequal_reduce = true;
444  static constexpr const bool _type_has_method_shash_reduce = true;
446 };
447 
/*! \brief Range container; presumably the managed reference to RangeNode — confirm. */
// NOTE(review): original lines 451-456, 458-467 and 470 are omitted from this
// listing (presumably docs and the ref macro that the "declare range." comment
// refers to) — confirm.
449 class Range : public ObjectRef {
450  public:
/*!
 * \brief Construct a range from a begin and an end bound (definition not shown here;
 *        presumably half-open [begin, end) — confirm).
 */
457  TVM_DLL Range(PrimExpr begin, PrimExpr end, Span span = Span());
/*! \brief Construct a range directly from its min and extent (definition not shown here). */
468  static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span = Span());
469  // declare range.
471 };
472 
473 // implementations
474 inline const Type& RelayExprNode::checked_type() const {
475  ICHECK(checked_type_.defined()) << "internal error: the type checker has "
476  << "not populated the checked_type "
477  << "field for " << GetRef<RelayExpr>(this);
478  return this->checked_type_;
479 }
480 
481 template <typename TTypeNode>
482 inline const TTypeNode* RelayExprNode::type_as() const {
483  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
484  "TType must be a special case of type");
485  ICHECK(checked_type_.defined())
486  << "Type inference for this Expr has not completed. Try to call infer_type pass.";
487  const TTypeNode* node = checked_type_.as<TTypeNode>();
488  ICHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get "
489  << checked_type_->GetTypeKey();
490  return node;
491 }
492 
493 } // namespace tvm
494 
495 namespace tvm {
496 namespace runtime {
497 // common rule for RetValue and ArgValue
// Conversion rule from a packed-function value to PrimExpr:
//   kTVMNullptr -> undefined PrimExpr (null ObjectPtr)
//   kDLInt      -> PrimExpr(int), via val.operator int()
//   kDLFloat    -> PrimExpr(float); the double is narrowed to float, so precision can be lost
//   otherwise   -> PrimExpr::FromObject_ on the held object reference
// NOTE(review): int64 values outside the int32 range depend on how
// TVMPODValue_::operator int() handles them (rejected vs. truncated) — confirm.
// NOTE(review): original line 499, presumably
// `struct PackedFuncValueConverter<PrimExpr> {`, is omitted from this listing.
498 template <>
500  static PrimExpr From(const TVMPODValue_& val) {
501  if (val.type_code() == kTVMNullptr) {
502  return PrimExpr(ObjectPtr<Object>(nullptr));
503  }
504  if (val.type_code() == kDLInt) {
505  return PrimExpr(val.operator int());
506  }
507  if (val.type_code() == kDLFloat) {
508  return PrimExpr(static_cast<float>(val.operator double()));
509  }
510 
511  return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
512  }
513 };
514 
// Conversion rule from a packed-function value to Integer:
//   kTVMNullptr -> undefined Integer (null ObjectPtr)
//   kTVMArgInt  -> Integer(int) (a 32-bit IntImm), via val.operator int()
//   otherwise   -> the held object reference, requested as tvm::Integer
// NOTE(review): this rule tests kTVMArgInt while the PrimExpr rule tests kDLInt;
// presumably the same type code — confirm.
// NOTE(review): original line 516, presumably
// `struct PackedFuncValueConverter<tvm::Integer> {`, is omitted from this listing.
515 template <>
517  static tvm::Integer From(const TVMPODValue_& val) {
518  if (val.type_code() == kTVMNullptr) {
519  return Integer(ObjectPtr<Object>(nullptr));
520  }
521  if (val.type_code() == kTVMArgInt) {
522  return Integer(val.operator int());
523  }
524  return val.AsObjectRef<tvm::Integer>();
525  }
526 };
527 
// Conversion rule from a packed-function value to Bool:
//   kTVMNullptr -> undefined Bool (null ObjectPtr)
//   kTVMArgInt  -> Bool; only 0 and 1 are accepted, anything else ICHECK-fails
//   otherwise   -> the held object reference, requested as tvm::Bool
// NOTE(review): original line 529, presumably
// `struct PackedFuncValueConverter<tvm::Bool> {`, is omitted from this listing.
528 template <>
530  static tvm::Bool From(const TVMPODValue_& val) {
531  if (val.type_code() == kTVMNullptr) {
532  return Bool(ObjectPtr<Object>(nullptr));
533  }
534  if (val.type_code() == kTVMArgInt) {
535  int v = val.operator int();
536  ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
537  return Bool(static_cast<bool>(v));
538  }
539  return val.AsObjectRef<tvm::Bool>();
540  }
541 };
542 
543 } // namespace runtime
544 } // namespace tvm
545 #endif // TVM_IR_EXPR_H_
Integer(Enum value)
Constructor from enum.
Definition: expr.h:378
tvm::Span Span
Definition: base.h:65
static constexpr const char * _type_key
Definition: expr.h:54
void FreeVarHashImpl(const runtime::Object *var) const
Implementation for hash for a free var.
Definition: structural_hash.h:184
double value
The constant value content.
Definition: expr.h:282
PrimExpr min
beginning of the range
Definition: expr.h:417
const Type & checked_type() const
Definition: expr.h:474
Bool operator &&(const Bool &a, bool b)
Definition: expr.h:342
A custom smart pointer for Object.
Definition: object.h:356
Boolean constant.
Definition: expr.h:327
Definitions and helper macros for IR/AST nodes.
Bool operator||(const Bool &a, bool b)
Definition: expr.h:337
Span information for debugging purposes.
Runtime String container types.
Internal base class to handle conversion to POD values.
Definition: packed_func.h:485
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:437
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation for equality rule of var type objects (e.g. TypeVar, tir::Var).
Definition: structural_equal.h:190
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:56
Integer()
Definition: expr.h:358
String name_hint
The name of the variable, this only acts as a hint.
Definition: expr.h:194
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Span span
the location of this range in the source
Definition: expr.h:421
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:101
Definition: c_runtime_api.h:111
Bool operator==(int other) const
Definition: expr.h:398
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Constant floating point literals in the program.
Definition: expr.h:279
Definition: loop_state.h:456
bool SEqualReduce(const IntImmNode *other, SEqualReducer equal) const
Definition: expr.h:244
DataType dtype() const
Definition: expr.h:123
Integer(int value, Span span=Span())
Construct integer from int value.
Definition: expr.h:366
base class of all object containers.
Definition: object.h:165
Integer & operator=(const IntImm &other)
Assign an expression to integer.
Definition: expr.h:386
Integer(IntImm other)
Construct integer from int imm.
Definition: expr.h:371
Managed reference to BaseExprNode.
Definition: expr.h:65
Constant integer literals in the program.
Definition: expr.h:233
PrimExpr extent
the extent of the range
Definition: expr.h:419
Integer(ObjectPtr< Object > node)
constructor from node.
Definition: expr.h:362
Managed reference class to FloatImmNode.
Definition: expr.h:308
Visitor class to get the attributes of an AST/IR node; it is called once for each field of the node.
Definition: reflection.h:52
static tvm::Bool From(const TVMPODValue_ &val)
Definition: expr.h:530
bool SEqualReduce(const FloatImmNode *other, SEqualReducer equal) const
Definition: expr.h:290
Range container.
Definition: expr.h:449
Definition: span.h:115
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:52
static constexpr const uint32_t _type_child_slots
Definition: expr.h:57
IR/AST nodes for the unified type system in TVM.
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:207
bool SEqualReduce(const RangeNode *other, SEqualReducer equal) const
Definition: expr.h:433
Runtime primitive data type.
Definition: data_type.h:41
Base type of all the expressions.
Definition: expr.h:46
tvm::GlobalVar GlobalVar
Definition: expr.h:47
static PrimExpr From(const TVMPODValue_ &val)
Definition: expr.h:500
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:55
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:284
Bool operator!() const
Definition: expr.h:330
Managed reference class to IntImmNode.
Definition: expr.h:262
Managed reference to GlobalVarNode.
Definition: expr.h:220
RangeNode(PrimExpr min, PrimExpr extent, Span span=Span())
Definition: expr.h:424
TObjectRef AsObjectRef() const
Definition: packed_func.h:1559
int64_t value
The internal value.
Definition: expr.h:236
Reference to string objects.
Definition: string.h:129
Managed reference to RelayExprNode.
Definition: expr.h:177
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:706
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:238
bool SEqualReduce(const GlobalVarNode *other, SEqualReducer equal) const
Definition: expr.h:202
Bool operator!=(Enum other) const
Definition: expr.h:408
tvm::Type Type
Definition: type.h:47
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:248
Bool operator!=(int other) const
Definition: expr.h:402
Base class of all object reference.
Definition: object.h:504
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:294
A managed object in the TVM runtime.
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:664
RangeNode()
constructor
Definition: expr.h:423
static tvm::Integer From(const TVMPODValue_ &val)
Definition: expr.h:517
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:427
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:196
const TTypeNode * type_as() const
Check whether the inferred (checked) type of the Expr is backed by a TTypeNode and return it.
Definition: expr.h:482
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:98
Definition: c_runtime_api.h:114
int type_code() const
Definition: packed_func.h:547
Bool(bool value, Span span=Span())
Definition: expr.h:329
Managed reference to TypeNode.
Definition: type.h:93
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object)
Reference to PrimExprNode.
Definition: expr.h:109
Global variable that lives in the top-level module.
Definition: expr.h:191
Base node of all non-primitive expressions.
Definition: expr.h:142
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:721
Bool operator==(Enum other) const
Definition: expr.h:404
Type trait to specify special value conversion rules from TVMArgValue and TVMRetValue.
Definition: packed_func.h:1037
Base node of all primitive expressions.
Definition: expr.h:82
Container of constant int that adds more constructors.
Definition: expr.h:356
range over one dimension
Definition: expr.h:414