25 #ifndef TVM_TIR_EXPR_H_ 26 #define TVM_TIR_EXPR_H_ 43 #include <unordered_map> 59 v->Visit(
"dtype", &
dtype);
60 v->Visit(
"value", &value);
61 v->Visit(
"span", &
span);
70 static constexpr
const char*
_type_key =
"tir.StringImm";
94 v->Visit(
"dtype", &
dtype);
95 v->Visit(
"value", &value);
96 v->Visit(
"span", &
span);
126 template <
typename T>
135 v->Visit(
"dtype", &(this->
dtype));
138 v->Visit(
"span", &
span);
243 static constexpr
const char*
_type_key =
"tir.FloorDiv";
259 static constexpr
const char*
_type_key =
"tir.FloorMod";
308 template <
typename T>
317 v->Visit(
"dtype", &(this->
dtype));
320 v->Visit(
"span", &
span);
441 v->Visit(
"dtype", &(this->
dtype));
444 v->Visit(
"span", &
span);
480 v->Visit(
"dtype", &
dtype);
483 v->Visit(
"span", &
span);
517 v->Visit(
"dtype", &
dtype);
519 v->Visit(
"span", &
span);
562 v->Visit(
"dtype", &
dtype);
563 v->Visit(
"condition", &condition);
564 v->Visit(
"true_value", &true_value);
565 v->Visit(
"false_value", &false_value);
566 v->Visit(
"span", &
span);
576 hash_reduce(condition);
577 hash_reduce(true_value);
578 hash_reduce(false_value);
614 v->Visit(
"dtype", &(this->
dtype));
615 v->Visit(
"buffer", &buffer);
616 v->Visit(
"indices", &indices);
617 v->Visit(
"span", &
span);
628 hash_reduce(indices);
631 static constexpr
const char*
_type_key =
"tir.BufferLoad";
644 void LegalizeDType();
646 friend class CustomDatatypesLowerer;
647 friend class VectorTypeRewriter;
648 friend class Vectorizer;
679 v->Visit(
"dtype", &(this->
dtype));
680 v->Visit(
"producer", &producer);
681 v->Visit(
"indices", &indices);
682 v->Visit(
"span", &
span);
692 hash_reduce(producer);
693 hash_reduce(indices);
696 static constexpr
const char*
_type_key =
"tir.ProducerLoad";
736 v->Visit(
"dtype", &
dtype);
737 v->Visit(
"buffer_var", &buffer_var);
738 v->Visit(
"index", &index);
739 v->Visit(
"predicate", &predicate);
740 v->Visit(
"span", &
span);
750 hash_reduce(buffer_var);
752 hash_reduce(predicate);
789 v->Visit(
"dtype", &
dtype);
790 v->Visit(
"base", &base);
791 v->Visit(
"stride", &stride);
792 v->Visit(
"lanes", &lanes);
793 v->Visit(
"span", &
span);
831 v->Visit(
"dtype", &
dtype);
832 v->Visit(
"value", &value);
833 v->Visit(
"lanes", &lanes);
834 v->Visit(
"span", &
span);
847 static constexpr
const char*
_type_key =
"tir.Broadcast";
874 v->Visit(
"dtype", &
dtype);
875 v->Visit(
"var", &var);
876 v->Visit(
"value", &value);
877 v->Visit(
"body", &body);
878 v->Visit(
"span", &
span);
923 v->Visit(
"dtype", &
dtype);
925 v->Visit(
"args", &args);
926 v->Visit(
"span", &
span);
966 v->Visit(
"dtype", &
dtype);
967 v->Visit(
"vectors", &vectors);
968 v->Visit(
"indices", &indices);
969 v->Visit(
"span", &
span);
979 hash_reduce(vectors);
980 hash_reduce(indices);
1028 v->Visit(
"lhs", &lhs);
1029 v->Visit(
"rhs", &rhs);
1030 v->Visit(
"result", &result);
1031 v->Visit(
"identity_element", &identity_element);
1032 v->Visit(
"span", &span);
1043 hash_reduce(result);
1044 hash_reduce(identity_element);
1085 v->Visit(
"dtype", &
dtype);
1086 v->Visit(
"combiner", &combiner);
1087 v->Visit(
"source", &source);
1088 v->Visit(
"init", &init);
1089 v->Visit(
"axis", &axis);
1090 v->Visit(
"condition", &condition);
1091 v->Visit(
"value_index", &value_index);
1092 v->Visit(
"span", &
span);
1106 hash_reduce(combiner);
1107 hash_reduce(source);
1109 hash_reduce(condition);
1110 hash_reduce(value_index);
1133 v->Visit(
"dtype", &
dtype);
1134 v->Visit(
"span", &
span);
1172 template <
typename K,
typename V>
1174 std::unordered_map<K, V>
ret;
1175 for (
auto kv : dmap) {
1176 ret[kv.first] = kv.second;
1187 #endif // TVM_TIR_EXPR_H_ Managed reference to MulNode.
Definition: expr.h:196
tvm::Span Span
Definition: base.h:65
Let binding. Bind var to value then evaluate body.
Definition: expr.h:864
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:440
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1132
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1141
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
Definition: structural_equal.h:165
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:327
Managed reference to CommReducerNode.
Definition: expr.h:1057
PrimExpr body
The result expression.
Definition: expr.h:871
Managed reference to CastNode.
Definition: expr.h:116
Var var
The variable.
Definition: expr.h:867
PrimExpr b
The right operand.
Definition: expr.h:132
PrimExpr predicate
The predicate to mask which lanes would be loaded.
Definition: expr.h:733
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:735
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:788
Definitions and helper macros for IR/AST nodes.
PrimExpr b
The right operand.
Definition: expr.h:477
Managed reference to ReduceNode.
Definition: expr.h:1121
Runtime String container types.
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:873
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:933
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:574
Array< Var > rhs
The right argument of reducer.
Definition: expr.h:1010
Var buffer_var
The buffer variable.
Definition: expr.h:729
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
PrimExpr a
The left operand.
Definition: expr.h:312
static constexpr const char * _type_key
Definition: expr.h:70
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:690
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:59
Managed reference to LTNode.
Definition: expr.h:378
PrimExpr a
The left operand.
Definition: expr.h:130
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:678
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:102
PrimExpr value
The value to be bound.
Definition: expr.h:869
Array< PrimExpr > init
The init operand.
Definition: expr.h:1073
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:801
String constants, only used in asserts.
Definition: expr.h:53
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr condition
Predicate on the reduction. Only add the body to the reduction if condition is true.
Definition: expr.h:1080
SizeVar ToSizeVar() const
Convert to SizeVar.
Definition: expr.h:1147
a named variable in TIR
Definition: var.h:88
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:58
Constant floating point literals in the program.
Definition: expr.h:321
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:301
PrimExpr true_value
value to be returned when condition is true.
Definition: expr.h:557
Managed reference to MinNode.
Definition: expr.h:282
PrimExpr b
The right operand.
Definition: expr.h:438
Definition: loop_state.h:456
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:625
a * b
Definition: expr.h:187
Managed reference to GTNode.
Definition: expr.h:410
bool SEqualReduce(const CallNode *other, SEqualReducer equal) const
Definition: expr.h:929
Managed reference to CallNode.
Definition: expr.h:947
Managed reference to GENode.
Definition: expr.h:426
Array< PrimExpr > identity_element
The identity element of reducer, which leaves other elements unchanged when combined with it...
Definition: expr.h:1018
int lanes
The number of lanes.
Definition: expr.h:828
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1027
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:316
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:965
bool SEqualReduce(const LoadNode *other, SEqualReducer equal) const
Definition: expr.h:743
bool SEqualReduce(const T *other, SEqualReducer equal) const
Definition: expr.h:141
base class of all object containers.
Definition: object.h:167
Any shape.
Definition: expr.h:1130
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:1025
bool SEqualReduce(const LetNode *other, SEqualReducer equal) const
Definition: expr.h:881
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ...
Definition: expr.h:958
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:103
std::unordered_map< K, V > as_unordered_map(const Map< K, V > &dmap)
Definition: expr.h:1173
Managed reference to FloorModNode.
Definition: expr.h:266
a + b
Definition: expr.h:155
Managed reference to LENode.
Definition: expr.h:394
bool SEqualReduce(const OrNode *other, SEqualReducer equal) const
Definition: expr.h:486
Constant integer literals in the program.
Definition: expr.h:275
PrimExpr base
The base value.
Definition: expr.h:782
PrimExpr index
The index locations to be loaded.
Definition: expr.h:731
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1084
Var ToVar() const
Convert to var.
Definition: expr.h:1144
PrimExpr a
The input operand.
Definition: expr.h:514
a || b
Definition: expr.h:472
Runtime Array container types.
tvm::tir::Any Any
Definition: type.h:45
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:479
bool SEqualReduce(const SelectNode *other, SEqualReducer equal) const
Definition: expr.h:569
Array< PrimExpr > vectors
the input vectors.
Definition: expr.h:961
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:748
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:68
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:922
Managed reference to AnyNode.
Definition: expr.h:1157
Managed reference to DivNode.
Definition: expr.h:215
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:145
Managed reference to NENode.
Definition: expr.h:362
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:55
PrimExpr stride
The stride of each step.
Definition: expr.h:784
Managed reference to BroadcastNode.
Definition: expr.h:855
Runtime primitive data type.
Definition: data_type.h:41
bool SEqualReduce(const CommReducerNode *other, SEqualReducer equal) const
Definition: expr.h:1035
Managed reference to ModNode.
Definition: expr.h:234
Base template to implement comparison ops.
Definition: expr.h:309
CommReducer combiner
The commutative combiner.
Definition: expr.h:1069
Array< PrimExpr > source
The source operand.
Definition: expr.h:1071
bool SEqualReduce(const NotNode *other, SEqualReducer equal) const
Definition: expr.h:522
PrimExpr a
The left operand.
Definition: expr.h:436
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:516
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
ObjectRef hash functor.
Definition: object.h:624
PrimExpr b
The right operand.
Definition: expr.h:314
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:58
bool SEqualReduce(const T *other, SEqualReducer equal) const
Definition: expr.h:323
Managed reference to AddNode.
Definition: expr.h:164
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:526
bool SEqualReduce(const ProducerLoadNode *other, SEqualReducer equal) const
Definition: expr.h:685
Managed reference to LetNode.
Definition: expr.h:901
Create a vector where all the elements are value.
Definition: expr.h:823
PrimExpr value
The original value to be cast.
Definition: expr.h:91
A commutative reducer node to represent a commutative binary operator with identity element...
Definition: expr.h:1005
Array< PrimExpr > indices
The indices of each element.
Definition: expr.h:963
Managed reference to OrNode.
Definition: expr.h:504
a > b
Definition: expr.h:401
tvm::IntImmNode IntImmNode
Definition: expr.h:49
int value_index
the index of this reduce node
Definition: expr.h:1082
Cast value from one data type to another.
Definition: expr.h:88
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1040
Reference to string objects.
Definition: string.h:124
Managed reference to RelayExprNode.
Definition: expr.h:217
bool SEqualReduce(const ShuffleNode *other, SEqualReducer equal) const
Definition: expr.h:972
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode)
Managed reference to MaxNode.
Definition: expr.h:298
bool SEqualReduce(const CastNode *other, SEqualReducer equal) const
Definition: expr.h:99
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
bool SEqualReduce(const BroadcastNode *other, SEqualReducer equal) const
Definition: expr.h:837
Managed reference to LoadNode.
Definition: expr.h:763
Array< Var > lhs
The left argument of reducer.
Definition: expr.h:1008
Managed reference to FloorDivNode.
Definition: expr.h:250
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:93
tvm::FloatImmNode FloatImmNode
Definition: expr.h:50
Array< PrimExpr > indices
The location arguments.
Definition: expr.h:676
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:613
RelayExpr op
The operator(function) being invoked.
Definition: expr.h:918
bool SEqualReduce(const RampNode *other, SEqualReducer equal) const
Definition: expr.h:796
Load the value from buffer_var.
Definition: expr.h:726
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Defines the Functor data structures.
Managed reference to SubNode.
Definition: expr.h:180
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
!a
Definition: expr.h:511
a named variable represents a tensor index size
Definition: var.h:144
max(a, b)
Definition: expr.h:289
Managed reference to DataProducerNode.
Definition: buffer.h:293
bool SEqualReduce(const AndNode *other, SEqualReducer equal) const
Definition: expr.h:447
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:886
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types...
Definition: buffer.h:160
bool SEqualReduce(const AnyNode *other, SEqualReducer equal) const
Definition: expr.h:1137
Array< IterVar > axis
The reduction axis.
Definition: expr.h:1075
Symbolic n-dimensional array, to represent a memory buffer.
String value
The constant value content.
Definition: expr.h:56
PrimExpr false_value
value to be returned when condition is false.
Definition: expr.h:559
Managed reference to SelectNode.
Definition: expr.h:589
Managed reference to RampNode.
Definition: expr.h:816
Array< PrimExpr > indices
The indices location to be loaded.
Definition: expr.h:611
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:561
Managed reference to ProducerLoadNode.
Definition: expr.h:704
Construct a vector with lanes elements where its i-th element equals base + i * stride. This is useful to construct an index for a contiguous vector load.
Definition: expr.h:779
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:134
Array< PrimExpr > result
The result of reducer.
Definition: expr.h:1012
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:490
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:101
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1103
min(a, b)
Definition: expr.h:273
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1268
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
The remainder of the floordiv.
Definition: expr.h:257
Managed reference to AndNode.
Definition: expr.h:465
Runtime Map container types.
a == b
Definition: expr.h:337
a && b
Definition: expr.h:433
DataProducer producer
The buffer producer.
Definition: expr.h:674
Base template to implement binary ops.
Definition: expr.h:127
Managed reference to ShuffleNode.
Definition: expr.h:991
Managed reference to NotNode.
Definition: expr.h:539
a < b
Definition: expr.h:369
Load value from the high dimension buffer.
Definition: expr.h:606
bool SEqualReduce(const ReduceNode *other, SEqualReducer equal) const
Definition: expr.h:1095
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:451
bool SEqualReduce(const BufferLoadNode *other, SEqualReducer equal) const
Definition: expr.h:620
PrimExpr value
The base value.
Definition: expr.h:826
a % b in the C semantics.
Definition: expr.h:225
Reference to PrimExprNode.
Definition: expr.h:112
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:977
PrimExpr a
The left operand.
Definition: expr.h:475
Floor division, floor(a/b)
Definition: expr.h:241
a - b
Definition: expr.h:171
Array< PrimExpr > args
The arguments.
Definition: expr.h:921
Call node.
Definition: expr.h:910
Managed reference to StringImmNode.
Definition: expr.h:78
Buffer buffer
The buffer variable.
Definition: expr.h:609
return true_value if condition is true, otherwise return false_value.
Definition: expr.h:552
Reduction operator.
Definition: expr.h:1066
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
a / b in the C semantics.
Definition: expr.h:206
a <= b
Definition: expr.h:385
Managed reference to EQNode.
Definition: expr.h:346
Load value from the result produced by the producer.
Definition: expr.h:671
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:841
bool SEqualReduce(const StringImmNode *other, SEqualReducer equal) const
Definition: expr.h:64
Array< T > Concat(Array< T > lhs, const Array< T > &rhs)
Concat two Arrays.
Definition: array.h:719
a != b
Definition: expr.h:353
Managed reference to BufferLoadNode.
Definition: expr.h:655
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:154
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:179
Base node of all primitive expressions.
Definition: expr.h:85
int lanes
Total number of lanes.
Definition: expr.h:786
a >= b
Definition: expr.h:417
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:830
PrimExpr condition
The condition.
Definition: expr.h:555