tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_IR_EXPR_H_
25 #define TVM_IR_EXPR_H_
26 
27 #include <tvm/ir/source_map.h>
28 #include <tvm/ir/type.h>
29 #include <tvm/node/node.h>
31 #include <tvm/runtime/object.h>
32 
33 #include <algorithm>
34 #include <limits>
35 #include <string>
36 #include <type_traits>
37 
38 namespace tvm {
39 
41 
42 // Forward-declare VirtualDevice to avoid circular imports.
43 class VirtualDevice;
44 
// BaseExprNode: base node type of all expressions; PrimExprNode and
// RelayExprNode further down both derive from it. Gaps in the numbering
// (51-54, 61) are source lines this Doxygen listing elides; line 61 presumably
// holds TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object).
49 class BaseExprNode : public Object {
50  public:
// Source location, kept only as debug information. Declared mutable so it can
// be assigned even through a const node.
55  mutable Span span;
56 
// Object-system registration: the type key, opt-in flags for structural
// equality (SEqualReduce) and structural hashing (SHashReduce), and the child
// slot count -- presumably type indices pre-reserved for subclasses (confirm
// against the object-system docs).
57  static constexpr const char* _type_key = "BaseExpr";
58  static constexpr const bool _type_has_method_sequal_reduce = true;
59  static constexpr const bool _type_has_method_shash_reduce = true;
60  static constexpr const uint32_t _type_child_slots = 62;
62 };
63 
// BaseExpr: managed reference to BaseExprNode. Its only member line (70) is
// elided from this listing -- presumably the object-ref methods macro.
68 class BaseExpr : public ObjectRef {
69  public:
71 };
72 
// PrimExprNode: base node of all primitive expressions. Its `DataType dtype`
// field (source line 101, elided here) holds the runtime data type of the
// expression and is what PrimExpr::dtype() below returns.
85 class PrimExprNode : public BaseExprNode {
86  public:
102 
104 
// Type key for primitive expressions; reserves 38 child type slots.
105  static constexpr const char* _type_key = "PrimExpr";
106  static constexpr const uint32_t _type_child_slots = 38;
108 };
109 
// PrimExpr: managed reference to PrimExprNode.
114 class PrimExpr : public BaseExpr {
115  public:
// Non-explicit (NOLINT) constructors: C++ int32_t and float values convert
// implicitly wherever a PrimExpr is expected.
120  TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
125  TVM_DLL PrimExpr(float value); // NOLINT(*)
126 
// Reads dtype straight from the node via static_cast; there is no null check,
// so this must only be called on a defined PrimExpr.
128  DataType dtype() const { return static_cast<const PrimExprNode*>(get())->dtype; }
129 
131 
132  private:
133  // Internal function for conversion.
// The PackedFuncValueConverter<PrimExpr> specialization at the end of this
// file calls FromObject_ as its fallback for object-typed values.
134  friend struct runtime::PackedFuncValueConverter<PrimExpr>;
135  TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
136 };
137 
// Operator overloads on PrimExpr. All are exported (TVM_DLL), take their
// operands by value and return a PrimExpr.
// Arithmetic: add, subtract, unary minus, multiply, divide.
147 TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
148 
158 TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
159 
168 TVM_DLL PrimExpr operator-(PrimExpr a);
169 
179 TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
180 
190 TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
191 
// Shifts: left shift and right shift.
201 TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
202 
212 TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
213 
// Comparisons: greater, greater_equal, less, less_equal, equal, not_equal.
// NOTE: these return a PrimExpr (a symbolic comparison), not a C++ bool.
223 TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
224 
234 TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
235 
245 TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
246 
256 TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
257 
267 TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
268 
278 TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
279 
// Logical and / or / not. Because these are overloaded functions, `a && b`
// and `a || b` evaluate BOTH operands: C++ short-circuiting does not apply.
288 TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
289 
298 TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
299 
307 TVM_DLL PrimExpr operator!(PrimExpr a);
308 
// Bitwise and / or / xor, plus unary bitwise negation (operator~).
318 TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
319 
329 TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
330 
340 TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
341 
350 TVM_DLL PrimExpr operator~(PrimExpr a);
351 
// RelayExprNode: base node of all non-primitive expressions.
361 class RelayExprNode : public BaseExprNode {
362  public:
// Type filled in by type checking; starts out as an undefined Type. Declared
// mutable so it can be populated through a const node.
369  mutable Type checked_type_ = Type(nullptr);
// Accessor for checked_type_; the out-of-line definition later in this file
// ICHECKs that the field has been populated before returning it.
373  inline const Type& checked_type() const;
// Returns checked_type_ downcast to TTypeNode; the definition later in this
// file ICHECKs both that the type is populated and that it is a TTypeNode.
384  template <typename TTypeNode>
385  inline const TTypeNode* type_as() const;
386 
410 
// Declared while VirtualDevice is only forward-declared at the top of this
// file, which is sufficient for a declaration. The backing `virtual_device_`
// field (source line 409) is elided from this listing.
422  VirtualDevice virtual_device() const;
423 
424  static constexpr const char* _type_key = "RelayExpr";
425  static constexpr const uint32_t _type_child_slots = 22;
427 };
428 
// RelayExpr: managed reference to RelayExprNode.
433 class RelayExpr : public BaseExpr {
434  public:
436 };
437 
438 class GlobalVar;
// GlobalVarNode: global variable that lives in the top-level module. Its
// `String name_hint` field (line 450) and the `void VisitAttrs(AttrVisitor* v)`
// signature (line 452) are elided from this listing; lines 453-457 below are
// that method's body.
447 class GlobalVarNode : public RelayExprNode {
448  public:
451 
// Reflection: exposes name, virtual device, span and checked type. Note the
// checked type is visited under the key "_checked_type_" (member: checked_type_).
453  v->Visit("name_hint", &name_hint);
454  v->Visit("virtual_device_", &virtual_device_);
455  v->Visit("span", &span);
456  v->Visit("_checked_type_", &checked_type_);
457  }
458 
// Structural equality: the names must match AND FreeVarEqualImpl (the equality
// rule for var-type objects) must accept the pair. Span, virtual device and
// checked type are not compared.
459  bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
460  // name matters for global var.
461  return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
462  }
463 
// Structural hash mirrors SEqualReduce: name plus free-variable hashing.
464  void SHashReduce(SHashReducer hash_reduce) const {
465  hash_reduce(name_hint);
466  hash_reduce.FreeVarHashImpl(this);
467  }
468 
469  static constexpr const char* _type_key = "GlobalVar";
471 };
472 
// GlobalVar: managed reference to GlobalVarNode. The type and span arguments
// default to empty values.
477 class GlobalVar : public RelayExpr {
478  public:
479  TVM_DLL explicit GlobalVar(String name_hint, Type type = {}, Span span = {});
480 
483 };
484 
485 // PrimExprs that are useful as runtime containers.
486 //
// IntImmNode: constant integer literal. The `void VisitAttrs(AttrVisitor* v)`
// signature (line 496) is elided; lines 497-500 are that method's body.
491 class IntImmNode : public PrimExprNode {
492  public:
// The constant value; stored as int64_t whatever the dtype is.
494  int64_t value;
495 
497  v->Visit("dtype", &dtype);
498  v->Visit("value", &value);
499  v->Visit("span", &span);
500  }
501 
// Structural equality and hashing use dtype and value only; span is ignored.
502  bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
503  return equal(dtype, other->dtype) && equal(value, other->value);
504  }
505 
506  void SHashReduce(SHashReducer hash_reduce) const {
507  hash_reduce(dtype);
508  hash_reduce(value);
509  }
510 
511  static constexpr const char* _type_key = "IntImm";
513 };
514 
// IntImm: managed reference to IntImmNode.
520 class IntImm : public PrimExpr {
521  public:
// NOTE(review): whether this constructor checks that `value` fits in `dtype`
// is not visible here; see the out-of-line definition.
528  TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());
529 
532 };
533 
// FloatImmNode: constant floating point literal. The
// `void VisitAttrs(AttrVisitor* v)` signature (line 543) is elided; lines
// 544-547 are that method's body.
538 class FloatImmNode : public PrimExprNode {
539  public:
// The constant value; stored as double whatever the dtype is.
541  double value;
542 
544  v->Visit("dtype", &dtype);
545  v->Visit("value", &value);
546  v->Visit("span", &span);
547  }
548 
// Structural equality and hashing use dtype and value only; span is ignored.
549  bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
550  return equal(dtype, other->dtype) && equal(value, other->value);
551  }
552 
553  void SHashReduce(SHashReducer hash_reduce) const {
554  hash_reduce(dtype);
555  hash_reduce(value);
556  }
557 
558  static constexpr const char* _type_key = "FloatImm";
560 };
561 
// FloatImm: managed reference class to FloatImmNode.
567 class FloatImm : public PrimExpr {
568  public:
575  TVM_DLL FloatImm(DataType dtype, double value, Span span = Span());
576 
579 };
580 
// Bool: boolean constant, i.e. an IntImm with DataType::Bool() built from a
// C++ bool. Line 593 (elided) presumably holds the object-ref methods macro
// that provides the ObjectPtr constructor used by the Bool converter.
587 class Bool : public IntImm {
588  public:
589  explicit Bool(bool value, Span span = Span()) : IntImm(DataType::Bool(), value, span) {}
590  Bool operator!() const { return Bool((*this)->value == 0); }
// Implicit conversion so a Bool can be tested directly in `if`. It dereferences
// the node with no null check.
// NOTE(review): the converter at the end of this file can return a null Bool,
// and converting one of those here is a null dereference.
591  operator bool() const { return (*this)->value != 0; }
592 
594 };
595 
596 // Overload operators to make sure we have the most fine grained types.
// Mixing Bool with bool (or Bool) yields a Bool for ||/&& and a plain bool for
// ==. These overloads presumably exist so such expressions do not fall back to
// the PrimExpr operators declared above. They are ordinary functions, so both
// operands are always evaluated (no short-circuiting).
597 inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); }
598 inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); }
599 inline Bool operator||(const Bool& a, const Bool& b) {
600  return Bool(a.operator bool() || b.operator bool());
601 }
602 inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); }
603 inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); }
604 inline Bool operator&&(const Bool& a, const Bool& b) {
605  return Bool(a.operator bool() && b.operator bool());
606 }
607 
608 inline bool operator==(const Bool& a, bool b) { return a.operator bool() == b; }
609 inline bool operator==(bool a, const Bool& b) { return a == b.operator bool(); }
610 inline bool operator==(const Bool& a, const Bool& b) {
611  return a.operator bool() == b.operator bool();
612 }
613 
// Integer: container of constant int that adds more constructors and
// null-aware comparisons on top of IntImm. An Integer may be null (built from a
// null ObjectPtr, and presumably when default-constructed); IntValue() and
// operator==(int) below handle that case explicitly.
622 class Integer : public IntImm {
623  public:
624  Integer() {}
628  explicit Integer(ObjectPtr<Object> node) : IntImm(node) {}
// Implicit: a plain int becomes an Int(32) IntImm.
632  Integer(int value, Span span = Span()) : IntImm(DataType::Int(32), value, span) {} // NOLINT(*)
// Implicit: wraps an existing IntImm by moving its reference; no node copy.
637  Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
// Construct from an enum. The static_assert requires the enum's underlying
// type to be exactly int, so the static_cast<int> is lossless.
643  template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
644  explicit Integer(Enum value) : Integer(static_cast<int>(value)) {
645  static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value,
646  "declare enum to be enum int to use visitor");
647  }
// Rebinds this reference to `other`'s node; the node is shared, not copied.
652  Integer& operator=(const IntImm& other) {
653  data_ = ObjectRef::GetDataPtr<Object>(other);
654  return *this;
655  }
// Returns the value as int64_t; ICHECK-fails on a null Integer.
659  int64_t IntValue() const {
660  ICHECK(data_ != nullptr) << " Trying to reference a null Integer";
661  return (*this)->value;
662  }
663  // comparators
// A null Integer compares unequal to every int, so `null != x` is true.
664  Bool operator==(int other) const {
665  if (data_ == nullptr) return Bool(false);
666  return Bool((*this)->value == other);
667  }
668  Bool operator!=(int other) const { return !(*this == other); }
// Enum overloads forward to the int versions via static_cast<int>.
669  template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
670  Bool operator==(Enum other) const {
671  return *this == static_cast<int>(other);
672  }
673  template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
674  Bool operator!=(Enum other) const {
675  return *this != static_cast<int>(other);
676  }
677 };
678 
// RangeNode: range over one dimension, described by `min` (the beginning) and
// `extent`. Those fields (source lines 683/685), the default constructor
// (line 689) and the VisitAttrs signature (line 693) are elided from this
// listing. It derives from Object rather than BaseExprNode, so it declares its
// own span and structural-equality/hash flags.
680 class RangeNode : public Object {
681  public:
687  mutable Span span;
// The parameters shadow the members; in the init list `min(min)` initializes
// the member from the parameter.
690  RangeNode(PrimExpr min, PrimExpr extent, Span span = Span())
691  : min(min), extent(extent), span(span) {}
692 
694  v->Visit("min", &min);
695  v->Visit("extent", &extent);
696  v->Visit("span", &span);
697  }
698 
// Structural equality and hashing use min and extent only; span is ignored.
699  bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
700  return equal(min, other->min) && equal(extent, other->extent);
701  }
702 
703  void SHashReduce(SHashReducer hash_reduce) const {
704  hash_reduce(min);
705  hash_reduce(extent);
706  }
707 
708  static constexpr const char* _type_key = "Range";
709  static constexpr const bool _type_has_method_sequal_reduce = true;
710  static constexpr const bool _type_has_method_shash_reduce = true;
712 };
713 
// Range: managed reference to RangeNode.
715 class Range : public ObjectRef {
716  public:
// Constructs from begin/end. Presumably this is a half-open [begin, end) that
// is stored as min/extent; confirm in the out-of-line definition.
723  TVM_DLL Range(PrimExpr begin, PrimExpr end, Span span = Span());
// Constructs directly from min and extent.
734  static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span = Span());
735  // declare range.
737 };
738 
739 // implementations
// Returns the type populated by type checking. ICHECK-fails with an internal
// error that names the expression if the field has not been populated yet.
740 inline const Type& RelayExprNode::checked_type() const {
741  ICHECK(checked_type_.defined()) << "internal error: the type checker has "
742  << "not populated the checked_type "
743  << "field for " << GetRef<RelayExpr>(this);
744  return this->checked_type_;
745 }
746 
// Downcasts checked_type_ to TTypeNode.
// Compile time: TTypeNode must derive from TypeNode.
// Run time: ICHECK-fails if type inference has not run, or if the checked type
// is a different node kind; that message names both the expected and the
// actual type key.
747 template <typename TTypeNode>
748 inline const TTypeNode* RelayExprNode::type_as() const {
749  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
750  "TType must be a special case of type");
751  ICHECK(checked_type_.defined())
752  << "Type inference for this Expr has not completed. Try to call infer_type pass.";
753  const TTypeNode* node = checked_type_.as<TTypeNode>();
754  ICHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get "
755  << checked_type_->GetTypeKey();
756  return node;
757 }
758 
759 } // namespace tvm
760 
761 namespace tvm {
762 namespace runtime {
763 // common rule for RetValue and ArgValue
// Specialization of runtime::PackedFuncValueConverter<PrimExpr>; the struct
// header (line 765) is elided from this listing. Converts a packed value:
//   null  -> undefined PrimExpr
//   int   -> IntImm with dtype Int(64) inside the branch guarded by elided
//            line 772 (presumably a "does not fit in int" check; <limits> is
//            included), otherwise Int(32)
//   float -> FloatImm with dtype Float(32)
//   other -> PrimExpr::FromObject_ on the held object
764 template <>
766  static PrimExpr From(const TVMPODValue_& val) {
767  if (val.type_code() == kTVMNullptr) {
768  return PrimExpr(ObjectPtr<Object>(nullptr));
769  }
770  if (val.type_code() == kDLInt) {
771  int64_t value = val.operator int64_t();
773  return IntImm(runtime::DataType::Int(64), value);
774  }
775  return IntImm(runtime::DataType::Int(32), val.operator int());
776  }
// NOTE(review): the value is read as double but tagged Float(32). Confirm that
// narrowing the dtype is intended for values not representable in float32.
777  if (val.type_code() == kDLFloat) {
778  return FloatImm(runtime::DataType::Float(32), val.operator double());
779  }
780 
781  return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
782  }
783 };
784 
// Converter specialization producing tvm::Integer; the struct header (line 786)
// is elided from this listing.
//   null -> null Integer
//   int  -> Integer(int), i.e. an Int(32) IntImm
//   else -> the held object cast to tvm::Integer
// NOTE(review): this tests kTVMArgInt while the PrimExpr converter tests
// kDLInt; presumably these are aliases, so confirm in c_runtime_api.h. Whether
// val.operator int() range-checks 64-bit inputs is not visible here.
785 template <>
787  static tvm::Integer From(const TVMPODValue_& val) {
788  if (val.type_code() == kTVMNullptr) {
789  return Integer(ObjectPtr<Object>(nullptr));
790  }
791  if (val.type_code() == kTVMArgInt) {
792  return Integer(val.operator int());
793  }
794  return val.AsObjectRef<tvm::Integer>();
795  }
796 };
797 
// Converter specialization producing tvm::Bool; the struct header (line 799) is
// elided from this listing.
//   null -> null Bool (callers must not test it via operator bool)
//   int  -> Bool, after ICHECKing that the value is 0 or 1
//   else -> the held object cast to tvm::Bool
798 template <>
800  static tvm::Bool From(const TVMPODValue_& val) {
801  if (val.type_code() == kTVMNullptr) {
802  return Bool(ObjectPtr<Object>(nullptr));
803  }
804  if (val.type_code() == kTVMArgInt) {
805  int v = val.operator int();
806  ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
807  return Bool(static_cast<bool>(v));
808  }
809  return val.AsObjectRef<tvm::Bool>();
810  }
811 };
812 
813 } // namespace runtime
814 } // namespace tvm
815 #endif // TVM_IR_EXPR_H_
Integer(Enum value)
Constructor from enum.
Definition: expr.h:644
PrimExpr operator!=(PrimExpr a, PrimExpr b)
not_equal
tvm::Span Span
Definition: base.h:65
static constexpr const char * _type_key
Definition: expr.h:57
void FreeVarHashImpl(const runtime::Object *var) const
Implementation for hash for a free var.
Definition: structural_hash.h:193
PrimExpr operator<(PrimExpr a, PrimExpr b)
less
double value
The constant value content.
Definition: expr.h:541
PrimExpr min
beginning of the range
Definition: expr.h:683
const Type & checked_type() const
Definition: expr.h:740
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
A custom smart pointer for Object.
Definition: object.h:358
PrimExpr operator||(PrimExpr a, PrimExpr b)
or
Boolean constant.
Definition: expr.h:587
Definitions and helper macros for IR/AST nodes.
Runtime String container types.
Internal base class to handle conversion to POD values.
Definition: packed_func.h:541
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:703
bool FreeVarEqualImpl(const runtime::Object *lhs, const runtime::Object *rhs) const
Implementation of the equality rule for var-type objects (e.g. TypeVar, tir::Var).
Definition: structural_equal.h:313
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:59
Integer()
Definition: expr.h:624
String name_hint
The name of the variable, this only acts as a hint.
Definition: expr.h:450
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Span span
the location of this range in the source
Definition: expr.h:687
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
Definition: c_runtime_api.h:175
Bool operator==(int other) const
Definition: expr.h:664
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Constant floating point literals in the program.
Definition: expr.h:538
Definition: loop_state.h:456
bool SEqualReduce(const IntImmNode *other, SEqualReducer equal) const
Definition: expr.h:502
A map from source names to source code.
DataType dtype() const
Definition: expr.h:128
Integer(int value, Span span=Span())
Construct integer from int value.
Definition: expr.h:632
base class of all object containers.
Definition: object.h:167
PrimExpr operator-(PrimExpr a, PrimExpr b)
subtraction operator
Integer & operator=(const IntImm &other)
Assign an expression to integer.
Definition: expr.h:652
Integer(IntImm other)
Construct integer from int imm.
Definition: expr.h:637
Managed reference to BaseExprNode.
Definition: expr.h:68
Constant integer literals in the program.
Definition: expr.h:491
PrimExpr extent
the extent of the range
Definition: expr.h:685
PrimExpr operator &&(PrimExpr a, PrimExpr b)
and
Integer(ObjectPtr< Object > node)
constructor from node.
Definition: expr.h:628
Managed reference class to FloatImmNode.
Definition: expr.h:567
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
PrimExpr operator &(PrimExpr a, PrimExpr b)
take bitwise and of two values
static tvm::Bool From(const TVMPODValue_ &val)
Definition: expr.h:800
bool SEqualReduce(const FloatImmNode *other, SEqualReducer equal) const
Definition: expr.h:549
Managed reference class to VirtualDeviceNode.
Definition: virtual_device.h:271
Range container.
Definition: expr.h:715
Definition: source_map.h:120
PrimExpr operator!(PrimExpr a)
not
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:55
int64_t IntValue() const
convert to int64_t
Definition: expr.h:659
static constexpr const uint32_t _type_child_slots
Definition: expr.h:60
IR/AST nodes for the unified type system in TVM.
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:464
bool SEqualReduce(const RangeNode *other, SEqualReducer equal) const
Definition: expr.h:699
Runtime primitive data type.
Definition: data_type.h:41
Base type of all the expressions.
Definition: expr.h:49
tvm::GlobalVar GlobalVar
Definition: expr.h:58
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:178
static PrimExpr From(const TVMPODValue_ &val)
Definition: expr.h:766
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:58
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:543
Bool operator!() const
Definition: expr.h:590
Managed reference class to IntImmNode.
Definition: expr.h:520
PrimExpr operator<<(PrimExpr a, PrimExpr b)
left shift operator
Managed reference to GlobalVarNode.
Definition: expr.h:477
PrimExpr operator^(PrimExpr a, PrimExpr b)
take bitwise xor of two values
RangeNode(PrimExpr min, PrimExpr extent, Span span=Span())
Definition: expr.h:690
ObjectRef virtual_device_
The virtual device (VirtualDevice) for this node (the result of device planning). For first-order exp...
Definition: expr.h:409
TObjectRef AsObjectRef() const
Definition: packed_func.h:1826
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
int64_t value
The internal value.
Definition: expr.h:494
Reference to string objects.
Definition: string.h:98
Managed reference to RelayExprNode.
Definition: expr.h:433
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:496
PrimExpr operator>>(PrimExpr a, PrimExpr b)
right shift operator
bool SEqualReduce(const GlobalVarNode *other, SEqualReducer equal) const
Definition: expr.h:459
PrimExpr operator==(PrimExpr a, PrimExpr b)
equal
Bool operator!=(Enum other) const
Definition: expr.h:674
tvm::Type Type
Definition: type.h:47
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:506
#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER()
Definition: script_printer.h:115
PrimExpr operator>=(PrimExpr a, PrimExpr b)
greater_equal
Bool operator!=(int other) const
Definition: expr.h:668
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:553
PrimExpr operator*(PrimExpr a, PrimExpr b)
multiplication operator
A managed object in the TVM runtime.
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
RangeNode()
constructor
Definition: expr.h:689
static tvm::Integer From(const TVMPODValue_ &val)
Definition: expr.h:787
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:693
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:452
const TTypeNode * type_as() const
Check if the inferred (checked) type of the Expr is backed by a TTypeNode and return it...
Definition: expr.h:748
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:101
Definition: c_runtime_api.h:178
PrimExpr operator/(PrimExpr a, PrimExpr b)
division operator
int type_code() const
Definition: packed_func.h:610
Bool(bool value, Span span=Span())
Definition: expr.h:589
PrimExpr operator<=(PrimExpr a, PrimExpr b)
less_equal
Managed reference to TypeNode.
Definition: type.h:93
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object)
PrimExpr operator~(PrimExpr a)
take bitwise negation of a value
Reference to PrimExprNode.
Definition: expr.h:114
Global variable that lives in the top-level module.
Definition: expr.h:447
PrimExpr operator|(PrimExpr a, PrimExpr b)
take bitwise or of two values
Base node of all non-primitive expressions.
Definition: expr.h:361
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
PrimExpr operator+(PrimExpr a, PrimExpr b)
add operator
Bool operator==(Enum other) const
Definition: expr.h:670
Type trait to specify special value conversion rules from TVMArgValue and TVMRetValue.
Definition: packed_func.h:1096
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:164
Base node of all primitive expressions.
Definition: expr.h:85
Container of constant int that adds more constructors.
Definition: expr.h:622
PrimExpr operator>(PrimExpr a, PrimExpr b)
greater
range over one dimension
Definition: expr.h:680