25 #ifndef TVM_TIR_EXPR_H_ 26 #define TVM_TIR_EXPR_H_ 43 #include <unordered_map> 59 v->Visit(
"dtype", &
dtype);
60 v->Visit(
"value", &value);
61 v->Visit(
"span", &
span);
70 static constexpr
const char*
_type_key =
"tir.StringImm";
95 v->Visit(
"dtype", &
dtype);
96 v->Visit(
"value", &value);
97 v->Visit(
"span", &
span);
128 template <
typename T>
137 v->Visit(
"dtype", &(this->
dtype));
140 v->Visit(
"span", &
span);
250 static constexpr
const char*
_type_key =
"tir.FloorDiv";
267 static constexpr
const char*
_type_key =
"tir.FloorMod";
319 template <
typename T>
328 v->Visit(
"dtype", &(this->
dtype));
331 v->Visit(
"span", &
span);
458 v->Visit(
"dtype", &(this->
dtype));
461 v->Visit(
"span", &
span);
498 v->Visit(
"dtype", &
dtype);
501 v->Visit(
"span", &
span);
536 v->Visit(
"dtype", &
dtype);
538 v->Visit(
"span", &
span);
582 v->Visit(
"dtype", &
dtype);
583 v->Visit(
"condition", &condition);
584 v->Visit(
"true_value", &true_value);
585 v->Visit(
"false_value", &false_value);
586 v->Visit(
"span", &
span);
596 hash_reduce(condition);
597 hash_reduce(true_value);
598 hash_reduce(false_value);
635 v->Visit(
"dtype", &(this->
dtype));
636 v->Visit(
"buffer", &buffer);
637 v->Visit(
"indices", &indices);
638 v->Visit(
"span", &
span);
649 hash_reduce(indices);
652 static constexpr
const char*
_type_key =
"tir.BufferLoad";
665 void LegalizeDType();
667 friend class CustomDatatypesLowerer;
668 friend class VectorTypeRewriter;
669 friend class Vectorizer;
700 v->Visit(
"dtype", &(this->
dtype));
701 v->Visit(
"producer", &producer);
702 v->Visit(
"indices", &indices);
703 v->Visit(
"span", &
span);
713 hash_reduce(producer);
714 hash_reduce(indices);
717 static constexpr
const char*
_type_key =
"tir.ProducerLoad";
758 v->Visit(
"dtype", &
dtype);
759 v->Visit(
"buffer_var", &buffer_var);
760 v->Visit(
"index", &index);
761 v->Visit(
"predicate", &predicate);
762 v->Visit(
"span", &
span);
772 hash_reduce(buffer_var);
774 hash_reduce(predicate);
812 v->Visit(
"dtype", &
dtype);
813 v->Visit(
"base", &base);
814 v->Visit(
"stride", &stride);
815 v->Visit(
"lanes", &lanes);
816 v->Visit(
"span", &
span);
855 v->Visit(
"dtype", &
dtype);
856 v->Visit(
"value", &value);
857 v->Visit(
"lanes", &lanes);
858 v->Visit(
"span", &
span);
871 static constexpr
const char*
_type_key =
"tir.Broadcast";
899 v->Visit(
"dtype", &
dtype);
900 v->Visit(
"var", &var);
901 v->Visit(
"value", &value);
902 v->Visit(
"body", &body);
903 v->Visit(
"span", &
span);
949 v->Visit(
"dtype", &
dtype);
951 v->Visit(
"args", &args);
952 v->Visit(
"span", &
span);
993 v->Visit(
"dtype", &
dtype);
994 v->Visit(
"vectors", &vectors);
995 v->Visit(
"indices", &indices);
996 v->Visit(
"span", &
span);
1006 hash_reduce(vectors);
1007 hash_reduce(indices);
1056 v->Visit(
"lhs", &lhs);
1057 v->Visit(
"rhs", &rhs);
1058 v->Visit(
"result", &result);
1059 v->Visit(
"identity_element", &identity_element);
1060 v->Visit(
"span", &span);
1071 hash_reduce(result);
1072 hash_reduce(identity_element);
1113 v->Visit(
"dtype", &
dtype);
1114 v->Visit(
"combiner", &combiner);
1115 v->Visit(
"source", &source);
1116 v->Visit(
"init", &init);
1117 v->Visit(
"axis", &axis);
1118 v->Visit(
"condition", &condition);
1119 v->Visit(
"value_index", &value_index);
1120 v->Visit(
"span", &
span);
1134 hash_reduce(combiner);
1135 hash_reduce(source);
1137 hash_reduce(condition);
1138 hash_reduce(value_index);
1162 v->Visit(
"dtype", &
dtype);
1163 v->Visit(
"span", &
span);
1202 template <
typename K,
typename V>
1204 std::unordered_map<K, V>
ret;
1205 for (
auto kv : dmap) {
1206 ret[kv.first] = kv.second;
1217 #endif // TVM_TIR_EXPR_H_ Managed reference to MulNode.
Definition: expr.h:200
tvm::Span Span
Definition: base.h:65
Let binding. Bind var to value then evaluate body.
Definition: expr.h:889
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:457
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1161
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1170
bool DefEqual(const ObjectRef &lhs, const ObjectRef &rhs)
Reduce condition to comparison of two definitions, where free vars can be mapped. ...
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:338
Managed reference to CommReducerNode.
Definition: expr.h:1085
PrimExpr body
The result expression.
Definition: expr.h:896
Managed reference to CastNode.
Definition: expr.h:117
Var var
The variable.
Definition: expr.h:892
PrimExpr b
The right operand.
Definition: expr.h:134
PrimExpr predicate
The predicate to mask which lanes would be loaded.
Definition: expr.h:755
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:757
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:811
Definitions and helper macros for IR/AST nodes.
PrimExpr b
The right operand.
Definition: expr.h:495
Managed reference to ReduceNode.
Definition: expr.h:1149
Runtime String container types.
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:898
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:959
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:594
Array< Var > rhs
The right argument of reducer.
Definition: expr.h:1038
Var buffer_var
The buffer variable.
Definition: expr.h:751
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
PrimExpr a
The left operand.
Definition: expr.h:323
static constexpr const char * _type_key
Definition: expr.h:70
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:711
static constexpr const bool _type_has_method_shash_reduce
Definition: expr.h:59
Managed reference to LTNode.
Definition: expr.h:391
PrimExpr a
The left operand.
Definition: expr.h:132
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:699
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:102
PrimExpr value
The value to be binded.
Definition: expr.h:894
Array< PrimExpr > init
The init operand.
Definition: expr.h:1101
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:824
String constants, only used in asserts.
Definition: expr.h:53
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr condition
Predicate on the reduction Only add the body to reduction if condition is true.
Definition: expr.h:1108
SizeVar ToSizeVar() const
Convert to SizeVar.
Definition: expr.h:1176
a named variable in TIR
Definition: var.h:88
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:58
Constant floating point literals in the program.
Definition: expr.h:536
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:301
PrimExpr true_value
value to be returned when condition is true.
Definition: expr.h:577
Managed reference to MinNode.
Definition: expr.h:291
PrimExpr b
The right operand.
Definition: expr.h:455
Definition: loop_state.h:456
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:646
a * b
Definition: expr.h:191
Managed reference to GTNode.
Definition: expr.h:425
bool SEqualReduce(const CallNode *other, SEqualReducer equal) const
Definition: expr.h:955
Managed reference to CallNode.
Definition: expr.h:973
Managed reference to GENode.
Definition: expr.h:442
Array< PrimExpr > identity_element
The identity element of reducer, which leaves other elements unchanged when combined with it...
Definition: expr.h:1046
int lanes
The number of lanes.
Definition: expr.h:852
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1055
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:327
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:992
bool SEqualReduce(const LoadNode *other, SEqualReducer equal) const
Definition: expr.h:765
bool SEqualReduce(const T *other, SEqualReducer equal) const
Definition: expr.h:143
base class of all object containers.
Definition: object.h:167
Any shape.
Definition: expr.h:1159
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:1053
bool SEqualReduce(const LetNode *other, SEqualReducer equal) const
Definition: expr.h:906
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ...
Definition: expr.h:985
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:104
std::unordered_map< K, V > as_unordered_map(const Map< K, V > &dmap)
Definition: expr.h:1203
Managed reference to FloorModNode.
Definition: expr.h:274
a + b
Definition: expr.h:157
Managed reference to LENode.
Definition: expr.h:408
bool SEqualReduce(const OrNode *other, SEqualReducer equal) const
Definition: expr.h:504
Constant integer literals in the program.
Definition: expr.h:489
PrimExpr base
The base value.
Definition: expr.h:805
PrimExpr index
The index locations to be loaded.
Definition: expr.h:753
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:1112
Var ToVar() const
Convert to var.
Definition: expr.h:1173
PrimExpr a
The input operand.
Definition: expr.h:533
a || b
Definition: expr.h:490
Runtime Array container types.
tvm::tir::Any Any
Definition: type.h:45
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:497
bool SEqualReduce(const SelectNode *other, SEqualReducer equal) const
Definition: expr.h:589
Array< PrimExpr > vectors
the input vectors.
Definition: expr.h:988
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:770
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:68
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:948
Managed reference to AnyNode.
Definition: expr.h:1186
Managed reference to DivNode.
Definition: expr.h:220
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:147
Managed reference to NENode.
Definition: expr.h:374
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:55
PrimExpr stride
The stride of each step.
Definition: expr.h:807
Managed reference to BroadcastNode.
Definition: expr.h:879
Runtime primitive data type.
Definition: data_type.h:41
bool SEqualReduce(const CommReducerNode *other, SEqualReducer equal) const
Definition: expr.h:1063
Managed reference to ModNode.
Definition: expr.h:240
Base template to implement comparison ops.
Definition: expr.h:320
CommReducer combiner
The commutative combiner.
Definition: expr.h:1097
Array< PrimExpr > source
The source operand.
Definition: expr.h:1099
bool SEqualReduce(const NotNode *other, SEqualReducer equal) const
Definition: expr.h:541
PrimExpr a
The left operand.
Definition: expr.h:453
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:535
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
ObjectRef hash functor.
Definition: object.h:624
PrimExpr b
The right operand.
Definition: expr.h:325
static constexpr const bool _type_has_method_sequal_reduce
Definition: expr.h:58
bool SEqualReduce(const T *other, SEqualReducer equal) const
Definition: expr.h:334
Managed reference to AddNode.
Definition: expr.h:166
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:545
bool SEqualReduce(const ProducerLoadNode *other, SEqualReducer equal) const
Definition: expr.h:706
Managed reference to LetNode.
Definition: expr.h:926
Create a vector where all the elements are value.
Definition: expr.h:847
PrimExpr value
Original data type.
Definition: expr.h:92
A commutative reducer node to represent a commutative binary operator with identity element...
Definition: expr.h:1033
Array< PrimExpr > indices
The indices of each element.
Definition: expr.h:990
Managed reference to OrNode.
Definition: expr.h:522
a > b
Definition: expr.h:416
tvm::IntImmNode IntImmNode
Definition: expr.h:49
int value_index
the index of this reduce node
Definition: expr.h:1110
Cast value from one data type to another.
Definition: expr.h:89
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1068
Reference to string objects.
Definition: string.h:97
LetFrame Let(Var var, PrimExpr value)
The let binding.
Managed reference to RelayExprNode.
Definition: expr.h:431
bool SEqualReduce(const ShuffleNode *other, SEqualReducer equal) const
Definition: expr.h:999
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode)
Managed reference to MaxNode.
Definition: expr.h:308
bool SEqualReduce(const CastNode *other, SEqualReducer equal) const
Definition: expr.h:100
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
bool SEqualReduce(const BroadcastNode *other, SEqualReducer equal) const
Definition: expr.h:861
Managed reference to LoadNode.
Definition: expr.h:785
Array< Var > lhs
The left argument of reducer.
Definition: expr.h:1036
Managed reference to FloorDivNode.
Definition: expr.h:257
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:94
tvm::FloatImmNode FloatImmNode
Definition: expr.h:50
Array< PrimExpr > indices
The location arguments.
Definition: expr.h:697
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:634
RelayExpr op
The operator(function) being invoked.
Definition: expr.h:944
bool SEqualReduce(const RampNode *other, SEqualReducer equal) const
Definition: expr.h:819
Load the value from buffer_var.
Definition: expr.h:748
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Defines the Functor data structures.
Managed reference to SubNode.
Definition: expr.h:183
Base class of all object reference.
Definition: object.h:511
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
!a
Definition: expr.h:530
a named variable represents a tensor index size
Definition: var.h:144
max(a, b)
Definition: expr.h:299
Managed reference to DataProducerNode.
Definition: buffer.h:293
bool SEqualReduce(const AndNode *other, SEqualReducer equal) const
Definition: expr.h:464
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:911
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types...
Definition: buffer.h:160
bool SEqualReduce(const AnyNode *other, SEqualReducer equal) const
Definition: expr.h:1166
Array< IterVar > axis
The reduction axis.
Definition: expr.h:1103
Symbolic n-dimensional array, to represent a memory buffer.
String value
The constant value content.
Definition: expr.h:56
PrimExpr false_value
value to be returned when condition is false.
Definition: expr.h:579
Managed reference to SelectNode.
Definition: expr.h:609
Managed reference to RampNode.
Definition: expr.h:839
Array< PrimExpr > indices
The indices location to be loaded.
Definition: expr.h:632
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:581
Managed reference to ProducerLoadNode.
Definition: expr.h:725
Construct a vector with lanes elements where its i-th element equals base + i * stride. This is useful to construct a index for a continuous vector load.
Definition: expr.h:802
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:136
Array< PrimExpr > result
The result of reducer.
Definition: expr.h:1040
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:508
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:101
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1131
min(a, b)
Definition: expr.h:282
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
The remainder of the floordiv.
Definition: expr.h:265
Managed reference to AndNode.
Definition: expr.h:482
Runtime Map container types.
a == b
Definition: expr.h:348
a && b
Definition: expr.h:450
DataProducer producer
The buffer producer.
Definition: expr.h:695
Base template to implement binary ops.
Definition: expr.h:129
Managed reference to ShuffleNode.
Definition: expr.h:1018
Managed reference to NotNode.
Definition: expr.h:558
a < b
Definition: expr.h:382
Load value from the high dimension buffer.
Definition: expr.h:627
bool SEqualReduce(const ReduceNode *other, SEqualReducer equal) const
Definition: expr.h:1123
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:468
bool SEqualReduce(const BufferLoadNode *other, SEqualReducer equal) const
Definition: expr.h:641
PrimExpr value
The base value.
Definition: expr.h:850
a % b in the C semnatics.
Definition: expr.h:231
Reference to PrimExprNode.
Definition: expr.h:112
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:1004
PrimExpr a
The left operand.
Definition: expr.h:493
Floor division, floor(a/b)
Definition: expr.h:248
a - b
Definition: expr.h:174
Array< PrimExpr > args
The arguments.
Definition: expr.h:947
Call node.
Definition: expr.h:936
Managed reference to StringImmNode.
Definition: expr.h:78
Buffer buffer
The buffer variable.
Definition: expr.h:630
return true_value if condition is true, otherwise return false_value.
Definition: expr.h:572
Reduction operator operator.
Definition: expr.h:1094
Var Reduce(Range dom, PrimExpr binding, DataType dtype=DataType::Int(32))
The reduced block axis defining function.
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:728
a / b in the C semnatics.
Definition: expr.h:211
a <= b
Definition: expr.h:399
Managed reference to EQNode.
Definition: expr.h:357
Load value from the result produced by the producer.
Definition: expr.h:692
void SHashReduce(SHashReducer hash_reduce) const
Definition: expr.h:865
bool SEqualReduce(const StringImmNode *other, SEqualReducer equal) const
Definition: expr.h:64
Array< T > Concat(Array< T > lhs, const Array< T > &rhs)
Concat two Arrays.
Definition: array.h:840
a != b
Definition: expr.h:365
Managed reference to BufferLoadNode.
Definition: expr.h:676
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:164
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:179
Base node of all primitive expressions.
Definition: expr.h:85
int lanes
Total number of lanes.
Definition: expr.h:809
a >= b
Definition: expr.h:433
void VisitAttrs(AttrVisitor *v)
Definition: expr.h:854
PrimExpr condition
The condition.
Definition: expr.h:575