25 #ifndef TVM_TIR_EXPR_H_
26 #define TVM_TIR_EXPR_H_
28 #include <tvm/ffi/container/array.h>
29 #include <tvm/ffi/container/map.h>
30 #include <tvm/ffi/string.h>
43 #include <unordered_map>
63 static constexpr
const char*
_type_key =
"tir.StringImm";
92 static constexpr
const char*
_type_key =
"tir.Cast";
111 template <
typename T>
121 refl::ObjectDef<T>().def_ro(
"a", &T::a).def_ro(
"b", &T::b);
222 static constexpr
const char*
_type_key =
"tir.FloorDiv";
239 static constexpr
const char*
_type_key =
"tir.FloorMod";
291 template <
typename T>
301 refl::ObjectDef<T>().def_ro(
"a", &T::a).def_ro(
"b", &T::b);
473 refl::ObjectDef<NotNode>().def_ro(
"a", &
NotNode::a);
509 refl::ObjectDef<SelectNode>()
552 refl::ObjectDef<BufferLoadNode>()
558 static constexpr
const char*
_type_key =
"tir.BufferLoad";
571 void LegalizeDType();
585 Optional<PrimExpr> predicate = std::nullopt,
Span span =
Span());
608 refl::ObjectDef<ProducerLoadNode>()
613 static constexpr
const char*
_type_key =
"tir.ProducerLoad";
649 refl::ObjectDef<RampNode>()
680 refl::ObjectDef<BroadcastNode>()
685 static constexpr
const char*
_type_key =
"tir.Broadcast";
714 refl::ObjectDef<LetNode>()
715 .def_ro(
"var", &
LetNode::var, refl::AttachFieldFlag::SEqHashDef())
785 refl::ObjectDef<ShuffleNode>()
800 TVM_DLL
Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices,
Span span =
Span());
828 Array<PrimExpr>
operator()(Array<PrimExpr> a, Array<PrimExpr> b)
const;
837 refl::ObjectDef<CommReducerNode>()
845 static constexpr
const char*
_type_key =
"tir.CommReducer";
856 TVM_DLL
CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
857 Array<PrimExpr> identity_element,
Span span =
Span());
883 refl::ObjectDef<ReduceNode>()
903 int value_index, Array<PrimExpr> init,
Span span =
Span());
917 template <
typename K,
typename V>
919 std::unordered_map<K, V>
ret;
920 for (
auto kv : dmap) {
921 ret[kv.first] = kv.second;
930 inline constexpr
bool use_default_type_traits_v<tvm::tir::StringImm> =
false;
934 :
public ObjectRefWithFallbackTraitsBase<tvm::tir::StringImm, String> {
Symbolic n-dimensional array, to represent a memory buffer.
Constant floating point literals in the program.
Definition: expr.h:538
Constant integer literals in the program.
Definition: expr.h:501
Base node of all primitive expressions.
Definition: expr.h:95
Reference to PrimExprNode.
Definition: expr.h:129
DataType dtype() const
Definition: expr.h:143
Managed reference to RelaxExprNode.
Definition: expr.h:446
Definition: source_map.h:113
Runtime primitive data type.
Definition: data_type.h:47
a + b
Definition: expr.h:128
static constexpr const char * _type_key
Definition: expr.h:130
Managed reference to AddNode.
Definition: expr.h:137
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode)
TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode)
Add(PrimExpr a, PrimExpr b, Span span=Span())
a && b
Definition: expr.h:410
PrimExpr a
The left operand.
Definition: expr.h:413
static constexpr const char * _type_key
Definition: expr.h:422
PrimExpr b
The right operand.
Definition: expr.h:415
TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:417
Managed reference to AndNode.
Definition: expr.h:430
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode)
And(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode)
Base template to implement binary ops.
Definition: expr.h:112
PrimExpr b
The right operand.
Definition: expr.h:117
TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:119
PrimExpr a
The left operand.
Definition: expr.h:115
Create a vector where all the elements are value.
Definition: expr.h:671
static constexpr const char * _type_key
Definition: expr.h:685
TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode)
PrimExpr value
The base value.
Definition: expr.h:674
static void RegisterReflection()
Definition: expr.h:678
PrimExpr lanes
The number of lanes.
Definition: expr.h:676
Managed reference to BroadcastNode.
Definition: expr.h:693
Broadcast(PrimExpr value, PrimExpr lanes, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode)
Load value from the high dimension buffer.
Definition: expr.h:541
friend class VectorTypeRewriter
Definition: expr.h:574
friend class CustomDatatypesLowerer
Definition: expr.h:573
TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode)
Array< PrimExpr > indices
The indices location to be loaded.
Definition: expr.h:546
friend class Vectorizer
Definition: expr.h:575
Buffer buffer
The buffer variable.
Definition: expr.h:544
static constexpr const char * _type_key
Definition: expr.h:558
static void RegisterReflection()
Definition: expr.h:550
Optional< PrimExpr > predicate
The predicate mask for loading values.
Definition: expr.h:548
Managed reference to BufferLoadNode.
Definition: expr.h:582
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode)
BufferLoad(Buffer buffer, Array< PrimExpr > indices, Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode)
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:157
Call node.
Definition: expr.h:738
RelaxExpr op
The operator(function) being invoked.
Definition: expr.h:746
static constexpr const char * _type_key
Definition: expr.h:756
Array< PrimExpr > args
The arguments.
Definition: expr.h:749
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:751
Managed reference to CallNode.
Definition: expr.h:764
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode)
Call(DataType dtype, RelaxExpr op, Array< PrimExpr > args, Span span=Span())
Cast value from one data type to another.
Definition: expr.h:82
PrimExpr value
The value to be cast (carries the original data type).
Definition: expr.h:85
static void RegisterReflection()
Definition: expr.h:87
TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode)
static constexpr const char * _type_key
Definition: expr.h:92
Managed reference to CastNode.
Definition: expr.h:100
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode)
TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode)
Cast(DataType dtype, PrimExpr value, Span span=Span())
Base template to implement comparison ops.
Definition: expr.h:292
static void RegisterReflection()
Definition: expr.h:299
TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode)
PrimExpr a
The left operand.
Definition: expr.h:295
PrimExpr b
The right operand.
Definition: expr.h:297
A commutative reducer node to represent a commutative binary operator with identity element.
Definition: expr.h:813
Array< Var > rhs
The right argument of reducer.
Definition: expr.h:818
Array< Var > lhs
The left argument of reducer.
Definition: expr.h:816
static constexpr const char * _type_key
Definition: expr.h:845
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:846
Array< PrimExpr > operator()(Array< PrimExpr > a, Array< PrimExpr > b) const
Function call operator to combine a and b.
Array< PrimExpr > result
The result of reducer.
Definition: expr.h:820
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object)
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:833
static void RegisterReflection()
Definition: expr.h:835
Array< PrimExpr > identity_element
The identity element of reducer, which leaves other elements unchanged when combined with it,...
Definition: expr.h:826
Managed reference to CommReducerNode.
Definition: expr.h:854
CommReducer(Array< Var > lhs, Array< Var > rhs, Array< PrimExpr > result, Array< PrimExpr > identity_element, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode)
Managed reference to DataProducerNode.
Definition: buffer.h:288
a / b in the C semantics.
Definition: expr.h:183
static constexpr const char * _type_key
Definition: expr.h:185
Managed reference to DivNode.
Definition: expr.h:192
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode)
Div(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode)
a == b
Definition: expr.h:308
static constexpr const char * _type_key
Definition: expr.h:310
Managed reference to EQNode.
Definition: expr.h:317
TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode)
EQ(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode)
Floor division, floor(a/b)
Definition: expr.h:220
static constexpr const char * _type_key
Definition: expr.h:222
Managed reference to FloorDivNode.
Definition: expr.h:229
TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode)
FloorDiv(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode)
The remainder of the floordiv.
Definition: expr.h:237
static constexpr const char * _type_key
Definition: expr.h:239
Managed reference to FloorModNode.
Definition: expr.h:246
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode)
TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode)
FloorMod(PrimExpr a, PrimExpr b, Span span=Span())
a >= b
Definition: expr.h:393
static constexpr const char * _type_key
Definition: expr.h:395
Managed reference to GENode.
Definition: expr.h:402
TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode)
GE(PrimExpr a, PrimExpr b, Span span=Span())
a > b
Definition: expr.h:376
static constexpr const char * _type_key
Definition: expr.h:378
Managed reference to GTNode.
Definition: expr.h:385
TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode)
GT(PrimExpr a, PrimExpr b, Span span=Span())
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:298
Managed reference to LENode.
Definition: expr.h:368
LE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode)
TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode)
a < b
Definition: expr.h:342
static constexpr const char * _type_key
Definition: expr.h:344
Managed reference to LTNode.
Definition: expr.h:351
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode)
LT(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode)
Let binding. Bind var to value then evaluate body.
Definition: expr.h:703
static void RegisterReflection()
Definition: expr.h:712
Var var
The variable.
Definition: expr.h:706
static constexpr const char * _type_key
Definition: expr.h:720
PrimExpr value
The value to be bound.
Definition: expr.h:708
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode)
PrimExpr body
The result expression.
Definition: expr.h:710
Managed reference to LetNode.
Definition: expr.h:728
TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode)
Let(Var var, PrimExpr value, PrimExpr body, Span span=Span())
max(a, b)
Definition: expr.h:271
static constexpr const char * _type_key
Definition: expr.h:273
Managed reference to MaxNode.
Definition: expr.h:280
TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode)
Max(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode)
min(a, b)
Definition: expr.h:254
static constexpr const char * _type_key
Definition: expr.h:256
Managed reference to MinNode.
Definition: expr.h:263
Min(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode)
a % b in the C semantics.
Definition: expr.h:203
static constexpr const char * _type_key
Definition: expr.h:205
Managed reference to ModNode.
Definition: expr.h:212
TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode)
Mod(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode)
a * b
Definition: expr.h:163
static constexpr const char * _type_key
Definition: expr.h:165
Managed reference to MulNode.
Definition: expr.h:172
Mul(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode)
a != b
Definition: expr.h:325
static constexpr const char * _type_key
Definition: expr.h:327
Managed reference to NENode.
Definition: expr.h:334
NE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode)
TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode)
TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode)
PrimExpr a
The input operand.
Definition: expr.h:469
static constexpr const char * _type_key
Definition: expr.h:476
static void RegisterReflection()
Definition: expr.h:471
Managed reference to NotNode.
Definition: expr.h:484
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode)
Not(PrimExpr a, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode)
a || b
Definition: expr.h:438
PrimExpr b
The right operand.
Definition: expr.h:443
PrimExpr a
The left operand.
Definition: expr.h:441
TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode)
static constexpr const char * _type_key
Definition: expr.h:450
static void RegisterReflection()
Definition: expr.h:445
Managed reference to OrNode.
Definition: expr.h:458
TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode)
Or(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode)
Load value from the result produced by the producer.
Definition: expr.h:599
static constexpr const char * _type_key
Definition: expr.h:613
Array< PrimExpr > indices
The location arguments.
Definition: expr.h:604
static void RegisterReflection()
Definition: expr.h:606
DataProducer producer
The buffer producer.
Definition: expr.h:602
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode)
Managed reference to ProducerLoadNode.
Definition: expr.h:621
TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode)
ProducerLoad(DataProducer producer, Array< PrimExpr > indices, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode)
Construct a vector with lanes elements where its i-th element equals base + i * stride....
Definition: expr.h:638
static constexpr const char * _type_key
Definition: expr.h:655
PrimExpr stride
The stride of each step.
Definition: expr.h:643
TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode)
PrimExpr lanes
Total number of lanes.
Definition: expr.h:645
static void RegisterReflection()
Definition: expr.h:647
PrimExpr base
The base value.
Definition: expr.h:641
Managed reference to RampNode.
Definition: expr.h:663
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode)
TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode)
Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span=Span())
Reduction operator.
Definition: expr.h:863
Array< PrimExpr > init
The init operand.
Definition: expr.h:870
int value_index
The index of this reduce node.
Definition: expr.h:879
static constexpr const char * _type_key
Definition: expr.h:892
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode)
Array< IterVar > axis
The reduction axis.
Definition: expr.h:872
CommReducer combiner
The commutative combiner.
Definition: expr.h:866
static void RegisterReflection()
Definition: expr.h:881
Array< PrimExpr > source
The source operand.
Definition: expr.h:868
PrimExpr condition
Predicate on the reduction. Only add the body to the reduction if condition is true.
Definition: expr.h:877
Managed reference to ReduceNode.
Definition: expr.h:900
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode)
TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode)
Reduce(CommReducer combiner, Array< PrimExpr > src, Array< IterVar > rdom, PrimExpr condition, int value_index, Array< PrimExpr > init, Span span=Span())
Return true_value if condition is true, otherwise return false_value.
Definition: expr.h:498
PrimExpr condition
The condition.
Definition: expr.h:501
PrimExpr true_value
value to be returned when condition is true.
Definition: expr.h:503
static constexpr const char * _type_key
Definition: expr.h:515
TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode)
PrimExpr false_value
value to be returned when condition is false.
Definition: expr.h:505
static void RegisterReflection()
Definition: expr.h:507
Managed reference to SelectNode.
Definition: expr.h:523
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode)
TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode)
Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ....
Definition: expr.h:776
Array< PrimExpr > indices
The indices of each element.
Definition: expr.h:781
static constexpr const char * _type_key
Definition: expr.h:790
static void RegisterReflection()
Definition: expr.h:783
TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode)
Array< PrimExpr > vectors
The input vectors.
Definition: expr.h:779
Managed reference to ShuffleNode.
Definition: expr.h:798
static PrimExpr Concat(Array< PrimExpr > vectors, Span span=Span())
Shuffle(Array< PrimExpr > vectors, Array< PrimExpr > indices, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode)
TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode)
static PrimExpr ExtractElement(PrimExpr vector, int index, Span span=Span())
String constants, only used in asserts.
Definition: expr.h:53
String value
The constant value content.
Definition: expr.h:56
static constexpr const char * _type_key
Definition: expr.h:63
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:58
Managed reference to StringImmNode.
Definition: expr.h:71
StringImm(String value, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode)
a - b
Definition: expr.h:145
static constexpr const char * _type_key
Definition: expr.h:147
Managed reference to SubNode.
Definition: expr.h:154
Sub(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode)
a named variable in TIR
Definition: var.h:78
Defines the Functor data structures.
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::unordered_map< K, V > as_unordered_map(const Map< K, V > &dmap)
Definition: expr.h:918
tvm::FloatImmNode FloatImmNode
Definition: expr.h:50
tvm::IntImmNode IntImmNode
Definition: expr.h:49
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Definitions and helper macros for IR/AST nodes.
static TVM_FFI_INLINE tvm::tir::StringImm ConvertFallbackValue(String value)
Definition: expr.h:935
a <= b
Definition: expr.h:359
static constexpr const char * _type_key
Definition: expr.h:361