24 #ifndef TVM_IR_EXPR_H_
25 #define TVM_IR_EXPR_H_
37 #include <type_traits>
58 static constexpr
const char*
_type_key =
"BaseExpr";
393 template <
typename TTypeNode>
394 inline const TTypeNode*
type_as()
const;
426 v->Visit(
"span", &
span);
469 v->Visit(
"dtype", &
dtype);
470 v->Visit(
"value", &
value);
471 v->Visit(
"span", &
span);
516 v->Visit(
"dtype", &
dtype);
517 v->Visit(
"value", &
value);
518 v->Visit(
"span", &
span);
563 operator bool()
const {
return (*this)->value != 0; }
572 return Bool(a.operator
bool() || b.operator
bool());
577 return Bool(a.operator
bool() && b.operator
bool());
583 return a.operator bool() == b.operator bool();
615 template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
617 static_assert(std::is_same<
int,
typename std::underlying_type<Enum>::type>::value,
618 "declare enum to be enum int to use visitor");
625 data_ = ObjectRef::GetDataPtr<Object>(other);
632 ICHECK(
data_ !=
nullptr) <<
" Trying to reference a null Integer";
633 return (*this)->value;
637 if (
data_ ==
nullptr)
return Bool(
false);
638 return Bool((*this)->value == other);
641 template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
643 return *
this ==
static_cast<int>(other);
645 template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
647 return *
this !=
static_cast<int>(other);
666 v->Visit(
"min", &
min);
667 v->Visit(
"extent", &
extent);
668 v->Visit(
"span", &
span);
714 <<
"not populated the checked_type "
715 <<
"field for " << GetRef<RelaxExpr>(
this);
719 template <
typename TTypeNode>
721 static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
722 "TType must be a special case of type");
724 <<
"Type inference for this Expr has not completed. Try to call infer_type pass.";
726 ICHECK(node !=
nullptr) <<
"Expected type to be " << TTypeNode::_type_key <<
", but get "
747 template <
typename PODSub
class>
749 if (
auto opt = val.TryAsInt()) {
750 int64_t value = opt.
value();
755 return IntImm(dtype, value);
756 }
else if (
auto opt = val.TryAsBool()) {
763 template <
typename PODSub
class>
765 if (
auto opt = TryFrom(val)) {
768 return val.template AsObjectRef<tvm::IntImm>();
775 template <
typename PODSub
class>
780 return val.template AsObjectRef<tvm::Integer>();
787 template <
typename PODSub
class>
789 if (
auto opt = val.TryAsBool()) {
791 }
else if (
auto opt = val.TryAsInt()) {
792 int value = opt.value();
793 ICHECK(value == 0 || value == 1)
794 <<
"ValueError: boolean value can only be 0 or 1, but get " << value;
795 return tvm::Bool(
static_cast<bool>(value));
801 template <
typename PODSub
class>
803 if (
auto opt = TryFrom(val)) {
806 return val.template AsObjectRef<tvm::Bool>();
821 template <
typename PODSub
class>
823 if (
auto opt = TryFrom(val)) {
826 return val.template AsObjectRef<tvm::FloatImm>();
843 template <
typename PODSub
class>
845 if (val.template IsObjectRef<tvm::IntImm>()) {
846 return runtime::Int(val.template AsObjectRef<tvm::IntImm>()->value);
848 return val.template AsObjectRef<runtime::Int>();
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Base type of all the expressions.
Definition: expr.h:50
static constexpr const char * _type_key
Definition: expr.h:58
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:60
static constexpr const uint32_t _type_child_slots
Definition: expr.h:61
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object)
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:59
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:56
Managed reference to BaseExprNode.
Definition: expr.h:69
TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode)
Boolean constant.
Definition: expr.h:559
Bool operator!() const
Definition: expr.h:562
Bool(bool value, Span span=Span())
Definition: expr.h:561
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode)
Constant floating point literals in the program.
Definition: expr.h:510
bool SEqualReduce(const FloatImmNode *other, SEqualReducer equal) const
Definition: expr.h:521
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode)
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:515
double value
The constant value content.
Definition: expr.h:513
static constexpr const char * _type_key
Definition: expr.h:530
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:525
Managed reference class to FloatImmNode.
Definition: expr.h:539
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode)
TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode)
FloatImm(DataType dtype, double value, Span span=Span())
Constructor.
Global variable that lives in the top-level module.
Definition: expr.h:419
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelaxExprNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:436
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:424
static constexpr const char * _type_key
Definition: expr.h:441
String name_hint
The name of the variable; this only acts as a hint.
Definition: expr.h:422
bool SEqualReduce(const GlobalVarNode *other, SEqualReducer equal) const
Definition: expr.h:431
Managed reference to GlobalVarNode.
Definition: expr.h:449
GlobalVar(String name_hint, Type type={}, Span span={})
TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode)
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelaxExpr, GlobalVarNode)
Constant integer literals in the program.
Definition: expr.h:463
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode)
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:468
static constexpr const char * _type_key
Definition: expr.h:483
int64_t value
The internal value.
Definition: expr.h:466
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:478
bool SEqualReduce(const IntImmNode *other, SEqualReducer equal) const
Definition: expr.h:474
Managed reference class to IntImmNode.
Definition: expr.h:492
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode)
TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode)
IntImm(DataType dtype, int64_t value, Span span=Span())
Constructor.
Container of constant int that adds more constructors.
Definition: expr.h:594
Bool operator!=(Enum other) const
Definition: expr.h:646
Integer(ObjectPtr< Object > node)
constructor from node.
Definition: expr.h:600
Integer()
Definition: expr.h:596
Bool operator!=(int other) const
Definition: expr.h:640
Integer(Enum value)
Constructor from enum.
Definition: expr.h:616
Bool operator==(int other) const
Definition: expr.h:636
Integer(IntImm other)
Construct integer from int imm.
Definition: expr.h:609
int64_t IntValue() const
convert to int64_t
Definition: expr.h:631
Bool operator==(Enum other) const
Definition: expr.h:642
Integer & operator=(const IntImm &other)
Assign an expression to integer.
Definition: expr.h:624
Integer(int value, Span span=Span())
Construct integer from int value.
Definition: expr.h:604
Base node of all primitive expressions.
Definition: expr.h:86
static constexpr const uint32_t _type_child_slots
Definition: expr.h:107
TVM_OBJECT_ENABLE_SCRIPT_PRINTER()
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:102
static constexpr const char * _type_key
Definition: expr.h:106
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode)
Reference to PrimExprNode.
Definition: expr.h:115
DataType dtype() const
Definition: expr.h:129
TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode)
PrimExpr(float value)
construct from float.
PrimExpr(int32_t value)
construct from integer.
range over one dimension
Definition: expr.h:652
PrimExpr min
beginning of the range
Definition: expr.h:655
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:665
RangeNode(PrimExpr min, PrimExpr extent, Span span=Span())
Definition: expr.h:662
bool SEqualReduce(const RangeNode *other, SEqualReducer equal) const
Definition: expr.h:671
static constexpr const char * _type_key
Definition: expr.h:680
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:675
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object)
RangeNode()
constructor
Definition: expr.h:661
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:682
PrimExpr extent
the extent of the range
Definition: expr.h:657
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:681
Span span
the location of this range in the source
Definition: expr.h:659
Range container
Definition: expr.h:687
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
Construct a new range with min and extent. The corresponding constructor is removed,...
Range(PrimExpr begin, PrimExpr end, Span span=Span())
constructor by begin and end
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode)
Base node of all non-primitive expressions.
Definition: expr.h:362
Type checked_type_
Stores the result of type inference (type checking).
Definition: expr.h:370
static constexpr const char * _type_key
Definition: expr.h:396
const Type & checked_type() const
Definition: expr.h:712
Optional< ObjectRef > struct_info_
Stores the result of structure information of the expression that encapsulate both static shape and r...
Definition: expr.h:377
TVM_DECLARE_BASE_OBJECT_INFO(RelaxExprNode, BaseExprNode)
const TTypeNode * type_as() const
Check if the inferred (checked) type of the Expr is backed by a TTypeNode and return it.
Definition: expr.h:720
static constexpr const uint32_t _type_child_slots
Definition: expr.h:397
Managed reference to RelaxExprNode.
Definition: expr.h:405
TVM_DEFINE_OBJECT_REF_METHODS(RelaxExpr, BaseExpr, RelaxExprNode)
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:135
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void FreeVarHashImpl(const runtime::Object *var) const
Implementation for hash for a free var.
Definition: structural_hash.h:203
Definition: source_map.h:120
Managed reference to TypeNode.
Definition: type.h:93
Definition: boxed_primitive.h:81
Runtime primitive data type.
Definition: data_type.h:43
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:244
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:227
A custom smart pointer for Object.
Definition: object.h:363
Base class of all object reference.
Definition: object.h:520
bool defined() const
Definition: object.h:553
const Object * get() const
Definition: object.h:555
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:606
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:911
base class of all object containers.
Definition: object.h:172
Optional container that represents a nullable variant of T.
Definition: optional.h:51
T value() const
Definition: optional.h:92
Reference to string objects.
Definition: string.h:97
Internal base class to handle conversion to POD values.
Definition: packed_func.h:615
std::optional< double > TryAsFloat() const
Definition: packed_func.h:689
IR/AST nodes for the unified type system in TVM.
Box< bool > Bool
Boxed version of C++ bool.
Definition: boxed_primitive.h:121
Box< int64_t > Int
Boxed version of C++ int64_t.
Definition: boxed_primitive.h:99
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
PrimExpr operator!=(PrimExpr a, PrimExpr b)
not_equal
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr operator/(PrimExpr a, PrimExpr b)
division operator
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr operator>>(PrimExpr a, PrimExpr b)
right shift operator
PrimExpr operator<(PrimExpr a, PrimExpr b)
less
PrimExpr operator|(PrimExpr a, PrimExpr b)
take bitwise or of two values
PrimExpr operator==(PrimExpr a, PrimExpr b)
equal
PrimExpr operator~(PrimExpr a)
take bitwise negation of a value
PrimExpr operator>=(PrimExpr a, PrimExpr b)
greater_equal
PrimExpr operator<=(PrimExpr a, PrimExpr b)
less_equal
PrimExpr operator*(PrimExpr a, PrimExpr b)
multiplication operator
PrimExpr operator&(PrimExpr a, PrimExpr b)
take bitwise and of two values
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
PrimExpr operator!(PrimExpr a)
not
PrimExpr operator^(PrimExpr a, PrimExpr b)
take bitwise xor of two values
PrimExpr operator-(PrimExpr a, PrimExpr b)
subtraction operator
PrimExpr operator||(PrimExpr a, PrimExpr b)
or
PrimExpr operator>(PrimExpr a, PrimExpr b)
greater
PrimExpr operator+(PrimExpr a, PrimExpr b)
add operator
PrimExpr operator<<(PrimExpr a, PrimExpr b)
left shift operator
PrimExpr operator&&(PrimExpr a, PrimExpr b)
and
Definitions and helper macros for IR/AST nodes.
A managed object in the TVM runtime.
A map from source names to source code.
Runtime String container types.
ObjectRef equal functor.
Definition: object.h:666
ObjectRef hash functor.
Definition: object.h:656
static runtime::Int From(const PODSubclass &val)
Definition: expr.h:844
static tvm::Bool From(const PODSubclass &val)
Definition: expr.h:802
static Optional< tvm::Bool > TryFrom(const PODSubclass &val)
Definition: expr.h:788
static Optional< tvm::FloatImm > TryFrom(const TVMPODValue_ &val)
Definition: expr.h:813
static tvm::FloatImm From(const PODSubclass &val)
Definition: expr.h:822
static tvm::IntImm From(const PODSubclass &val)
Definition: expr.h:764
static Optional< tvm::IntImm > TryFrom(const PODSubclass &val)
Definition: expr.h:748
static tvm::Integer From(const PODSubclass &val)
Definition: expr.h:776
Type trait to specify special value conversion rules from TVMArgValue and TVMRetValue.
Definition: packed_func.h:1246