25 #ifndef TVM_TIR_EXPR_H_
26 #define TVM_TIR_EXPR_H_
28 #include <tvm/ffi/container/array.h>
29 #include <tvm/ffi/container/map.h>
30 #include <tvm/ffi/string.h>
42 #include <unordered_map>
106 template <
typename T>
116 refl::ObjectDef<T>().def_ro(
"a", &T::a).def_ro(
"b", &T::b);
219 static constexpr
const char*
_type_key =
"tirx.FloorDiv";
236 static constexpr
const char*
_type_key =
"tirx.FloorMod";
288 template <
typename T>
298 refl::ObjectDef<T>().def_ro(
"a", &T::a).def_ro(
"b", &T::b);
468 refl::ObjectDef<NotNode>().def_ro(
"a", &
NotNode::a);
502 refl::ObjectDef<SelectNode>()
543 refl::ObjectDef<BufferLoadNode>()
560 void LegalizeDType();
574 ffi::Optional<PrimExpr> predicate = std::nullopt,
Span span =
Span());
597 refl::ObjectDef<ProducerLoadNode>()
637 refl::ObjectDef<RampNode>()
666 refl::ObjectDef<BroadcastNode>()
698 refl::ObjectDef<LetNode>()
699 .def_ro(
"var", &
LetNode::var, refl::AttachFieldFlag::SEqHashDef())
765 refl::ObjectDef<ShuffleNode>()
778 TVM_DLL
Shuffle(ffi::Array<PrimExpr> vectors, ffi::Array<PrimExpr> indices,
Span span =
Span());
806 ffi::Array<PrimExpr>
operator()(ffi::Array<PrimExpr> a, ffi::Array<PrimExpr> b)
const;
815 refl::ObjectDef<CommReducerNode>()
833 TVM_DLL
CommReducer(ffi::Array<Var> lhs, ffi::Array<Var> rhs, ffi::Array<PrimExpr> result,
834 ffi::Array<PrimExpr> identity_element,
Span span =
Span());
860 refl::ObjectDef<ReduceNode>()
878 PrimExpr condition,
int value_index, ffi::Array<PrimExpr> init,
893 template <
typename K,
typename V>
895 std::unordered_map<K, V>
ret;
896 for (
auto kv : dmap) {
897 ret[kv.first] = kv.second;
906 inline constexpr
bool use_default_type_traits_v<tvm::tirx::StringImm> =
false;
910 :
public ObjectRefWithFallbackTraitsBase<tvm::tirx::StringImm, ffi::String> {
Symbolic n-dimensional array, to represent a memory buffer.
Constant floating point literals in the program.
Definition: expr.h:529
Constant integer literals in the program.
Definition: expr.h:494
Base node of all primitive expressions.
Definition: expr.h:93
Reference to PrimExprNode.
Definition: expr.h:126
DataType dtype() const
Definition: expr.h:140
Managed reference to RelaxExprNode.
Definition: expr.h:441
Definition: source_map.h:111
Runtime primitive data type.
Definition: data_type.h:47
a + b
Definition: expr.h:125
static constexpr const char * _type_key
Definition: expr.h:127
Managed reference to AddNode.
Definition: expr.h:134
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Add, PrimExpr, AddNode)
Add(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode)
a && b
Definition: expr.h:409
PrimExpr a
The left operand.
Definition: expr.h:412
PrimExpr b
The right operand.
Definition: expr.h:414
static void RegisterReflection()
Definition: expr.h:416
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.And", AndNode, PrimExprNode)
Managed reference to AndNode.
Definition: expr.h:427
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(And, PrimExpr, AndNode)
And(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode)
Base template to implement binary ops.
Definition: expr.h:107
static constexpr const int _type_child_slots
Definition: expr.h:119
PrimExpr b
The right operand.
Definition: expr.h:112
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode)
static constexpr const bool _type_final
Definition: expr.h:120
static void RegisterReflection()
Definition: expr.h:114
PrimExpr a
The left operand.
Definition: expr.h:110
Create a vector where all the elements are value.
Definition: expr.h:657
PrimExpr value
The base value.
Definition: expr.h:660
static void RegisterReflection()
Definition: expr.h:664
PrimExpr lanes
The number of lanes.
Definition: expr.h:662
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Broadcast", BroadcastNode, PrimExprNode)
Managed reference to BroadcastNode.
Definition: expr.h:677
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode)
Broadcast(PrimExpr value, PrimExpr lanes, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Broadcast, PrimExpr, BroadcastNode)
Load value from the high dimension buffer.
Definition: expr.h:532
Buffer buffer
The buffer variable.
Definition: expr.h:535
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferLoad", BufferLoadNode, PrimExprNode)
friend class VectorTypeRewriter
Definition: expr.h:563
ffi::Optional< PrimExpr > predicate
The predicate mask for loading values.
Definition: expr.h:539
friend class CustomDatatypesLowerer
Definition: expr.h:562
friend class Vectorizer
Definition: expr.h:564
static void RegisterReflection()
Definition: expr.h:541
ffi::Array< PrimExpr > indices
The indices location to be loaded.
Definition: expr.h:537
Managed reference to BufferLoadNode.
Definition: expr.h:571
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferLoad, PrimExpr, BufferLoadNode)
BufferLoad(Buffer buffer, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
Call node.
Definition: expr.h:720
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode)
ffi::Array< PrimExpr > args
The arguments.
Definition: expr.h:731
RelaxExpr op
The operator(function) being invoked.
Definition: expr.h:728
static void RegisterReflection()
Definition: expr.h:733
Managed reference to CallNode.
Definition: expr.h:744
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
Call(DataType dtype, RelaxExpr op, ffi::Array< PrimExpr > args, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode)
Cast value from one data type to another.
Definition: expr.h:79
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Cast", CastNode, PrimExprNode)
PrimExpr value
The original value to be cast to the target data type.
Definition: expr.h:82
static void RegisterReflection()
Definition: expr.h:84
Managed reference to CastNode.
Definition: expr.h:95
Cast(DataType dtype, PrimExpr value, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Cast, PrimExpr, CastNode)
Base template to implement comparison ops.
Definition: expr.h:289
static constexpr const bool _type_final
Definition: expr.h:302
PrimExpr a
The left operand.
Definition: expr.h:292
static constexpr const int _type_child_slots
Definition: expr.h:301
static void RegisterReflection()
Definition: expr.h:296
PrimExpr b
The right operand.
Definition: expr.h:294
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode)
A commutative reducer node to represent a commutative binary operator with identity element.
Definition: expr.h:791
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.CommReducer", CommReducerNode, Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:823
ffi::Array< Var > lhs
The left argument of reducer.
Definition: expr.h:794
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:811
ffi::Array< PrimExpr > identity_element
The identity element of reducer, which leaves other elements unchanged when combined with it,...
Definition: expr.h:804
static void RegisterReflection()
Definition: expr.h:813
ffi::Array< PrimExpr > operator()(ffi::Array< PrimExpr > a, ffi::Array< PrimExpr > b) const
Function call operator to combine a and b.
ffi::Array< Var > rhs
The right argument of reducer.
Definition: expr.h:796
ffi::Array< PrimExpr > result
The result of reducer.
Definition: expr.h:798
Managed reference to CommReducerNode.
Definition: expr.h:831
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CommReducer, ObjectRef, CommReducerNode)
CommReducer(ffi::Array< Var > lhs, ffi::Array< Var > rhs, ffi::Array< PrimExpr > result, ffi::Array< PrimExpr > identity_element, Span span=Span())
Managed reference to DataProducerNode.
Definition: buffer.h:286
a / b in the C semantics.
Definition: expr.h:180
static constexpr const char * _type_key
Definition: expr.h:182
Managed reference to DivNode.
Definition: expr.h:189
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode)
Div(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Div, PrimExpr, DivNode)
a == b
Definition: expr.h:307
static constexpr const char * _type_key
Definition: expr.h:309
Managed reference to EQNode.
Definition: expr.h:316
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EQ, PrimExpr, EQNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode)
EQ(PrimExpr a, PrimExpr b, Span span=Span())
Floor division, floor(a/b)
Definition: expr.h:217
static constexpr const char * _type_key
Definition: expr.h:219
Managed reference to FloorDivNode.
Definition: expr.h:226
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorDiv, PrimExpr, FloorDivNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode)
FloorDiv(PrimExpr a, PrimExpr b, Span span=Span())
The remainder of the floordiv.
Definition: expr.h:234
static constexpr const char * _type_key
Definition: expr.h:236
Managed reference to FloorModNode.
Definition: expr.h:243
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorMod, PrimExpr, FloorModNode)
FloorMod(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode)
a >= b
Definition: expr.h:392
static constexpr const char * _type_key
Definition: expr.h:394
Managed reference to GENode.
Definition: expr.h:401
GE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GE, PrimExpr, GENode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode)
a > b
Definition: expr.h:375
static constexpr const char * _type_key
Definition: expr.h:377
Managed reference to GTNode.
Definition: expr.h:384
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GT, PrimExpr, GTNode)
GT(PrimExpr a, PrimExpr b, Span span=Span())
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:296
Managed reference to LENode.
Definition: expr.h:367
TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LE, PrimExpr, LENode)
LE(PrimExpr a, PrimExpr b, Span span=Span())
a < b
Definition: expr.h:341
static constexpr const char * _type_key
Definition: expr.h:343
Managed reference to LTNode.
Definition: expr.h:350
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode)
LT(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LT, PrimExpr, LTNode)
Let binding. Bind var to value then evaluate body.
Definition: expr.h:687
PrimExpr body
The result expression.
Definition: expr.h:694
PrimExpr value
The value to be bound.
Definition: expr.h:692
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Let", LetNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:696
Var var
The variable.
Definition: expr.h:690
Managed reference to LetNode.
Definition: expr.h:710
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode)
Let(Var var, PrimExpr value, PrimExpr body, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Let, PrimExpr, LetNode)
max(a, b)
Definition: expr.h:268
static constexpr const char * _type_key
Definition: expr.h:270
Managed reference to MaxNode.
Definition: expr.h:277
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode)
Max(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Max, PrimExpr, MaxNode)
min(a, b)
Definition: expr.h:251
static constexpr const char * _type_key
Definition: expr.h:253
Managed reference to MinNode.
Definition: expr.h:260
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Min, PrimExpr, MinNode)
Min(PrimExpr a, PrimExpr b, Span span=Span())
a % b in the C semantics.
Definition: expr.h:200
static constexpr const char * _type_key
Definition: expr.h:202
Managed reference to ModNode.
Definition: expr.h:209
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mod, PrimExpr, ModNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode)
Mod(PrimExpr a, PrimExpr b, Span span=Span())
a * b
Definition: expr.h:160
static constexpr const char * _type_key
Definition: expr.h:162
Managed reference to MulNode.
Definition: expr.h:169
Mul(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mul, PrimExpr, MulNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode)
a != b
Definition: expr.h:324
static constexpr const char * _type_key
Definition: expr.h:326
Managed reference to NENode.
Definition: expr.h:333
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NE, PrimExpr, NENode)
NE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode)
static void RegisterReflection()
Definition: expr.h:466
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Not", NotNode, PrimExprNode)
PrimExpr a
The input operand.
Definition: expr.h:464
Managed reference to NotNode.
Definition: expr.h:477
Not(PrimExpr a, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Not, PrimExpr, NotNode)
a || b
Definition: expr.h:435
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Or", OrNode, PrimExprNode)
PrimExpr b
The right operand.
Definition: expr.h:440
static void RegisterReflection()
Definition: expr.h:442
PrimExpr a
The left operand.
Definition: expr.h:438
Managed reference to OrNode.
Definition: expr.h:453
Or(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Or, PrimExpr, OrNode)
Load value from the result produced by the producer.
Definition: expr.h:588
ffi::Array< PrimExpr > indices
The location arguments.
Definition: expr.h:593
DataProducer producer
The buffer producer.
Definition: expr.h:591
static void RegisterReflection()
Definition: expr.h:595
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ProducerLoad", ProducerLoadNode, PrimExprNode)
Managed reference to ProducerLoadNode.
Definition: expr.h:608
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ProducerLoad, PrimExpr, ProducerLoadNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode)
ProducerLoad(DataProducer producer, ffi::Array< PrimExpr > indices, Span span=Span())
Construct a vector with lanes elements where its i-th element equals base + i * stride....
Definition: expr.h:626
PrimExpr lanes
Total number of lanes.
Definition: expr.h:633
PrimExpr base
The base value.
Definition: expr.h:629
static void RegisterReflection()
Definition: expr.h:635
PrimExpr stride
The stride of each step.
Definition: expr.h:631
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Ramp", RampNode, PrimExprNode)
Managed reference to RampNode.
Definition: expr.h:649
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode)
Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Ramp, PrimExpr, RampNode)
Reduction operator.
Definition: expr.h:840
ffi::Array< PrimExpr > source
The source operand.
Definition: expr.h:845
ffi::Array< IterVar > axis
The reduction axis.
Definition: expr.h:849
CommReducer combiner
The commutative combiner.
Definition: expr.h:843
static void RegisterReflection()
Definition: expr.h:858
PrimExpr condition
Predicate on the reduction. Only add the body to the reduction if the condition is true.
Definition: expr.h:854
int value_index
The index of this reduce node.
Definition: expr.h:856
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Reduce", ReduceNode, PrimExprNode)
ffi::Array< PrimExpr > init
The init operand.
Definition: expr.h:847
Managed reference to ReduceNode.
Definition: expr.h:875
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Reduce, PrimExpr, ReduceNode)
Reduce(CommReducer combiner, ffi::Array< PrimExpr > src, ffi::Array< IterVar > rdom, PrimExpr condition, int value_index, ffi::Array< PrimExpr > init, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode)
Return true_value if condition is true; otherwise return false_value.
Definition: expr.h:491
static void RegisterReflection()
Definition: expr.h:500
PrimExpr false_value
The value to be returned when the condition is false.
Definition: expr.h:498
PrimExpr true_value
The value to be returned when the condition is true.
Definition: expr.h:496
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Select", SelectNode, PrimExprNode)
PrimExpr condition
The condition.
Definition: expr.h:494
Managed reference to SelectNode.
Definition: expr.h:514
Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Select, PrimExpr, SelectNode)
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ....
Definition: expr.h:756
ffi::Array< PrimExpr > indices
The indices of each element.
Definition: expr.h:761
ffi::Array< PrimExpr > vectors
The input vectors.
Definition: expr.h:759
static void RegisterReflection()
Definition: expr.h:763
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Shuffle", ShuffleNode, PrimExprNode)
Managed reference to ShuffleNode.
Definition: expr.h:776
static PrimExpr ExtractElement(PrimExpr vector, int index, Span span=Span())
static PrimExpr Concat(ffi::Array< PrimExpr > vectors, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Shuffle, PrimExpr, ShuffleNode)
Shuffle(ffi::Array< PrimExpr > vectors, ffi::Array< PrimExpr > indices, Span span=Span())
ffi::String constants, only used in asserts.
Definition: expr.h:52
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.StringImm", StringImmNode, PrimExprNode)
ffi::String value
The constant value content.
Definition: expr.h:55
static void RegisterReflection()
Definition: expr.h:57
Managed reference to StringImmNode.
Definition: expr.h:68
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
StringImm(ffi::String value, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, PrimExpr, StringImmNode)
a - b
Definition: expr.h:142
static constexpr const char * _type_key
Definition: expr.h:144
Managed reference to SubNode.
Definition: expr.h:151
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Sub, PrimExpr, SubNode)
Sub(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode)
a named variable in TIR
Definition: var.h:76
Defines the Functor data structures.
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
tvm::FloatImmNode FloatImmNode
Definition: expr.h:49
tvm::IntImmNode IntImmNode
Definition: expr.h:48
std::unordered_map< K, V > as_unordered_map(const ffi::Map< K, V > &dmap)
Definition: expr.h:894
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
static TVM_FFI_INLINE tvm::tirx::StringImm ConvertFallbackValue(ffi::String value)
Definition: expr.h:911
a <= b
Definition: expr.h:358
static constexpr const char * _type_key
Definition: expr.h:360