25 #ifndef TVM_TIR_EXPR_H_
26 #define TVM_TIR_EXPR_H_
28 #include <tvm/ffi/container/array.h>
29 #include <tvm/ffi/container/map.h>
30 #include <tvm/ffi/string.h>
43 #include <unordered_map>
107 template <
typename T>
117 refl::ObjectDef<T>().def_ro(
"a", &T::a).def_ro(
"b", &T::b);
220 static constexpr
const char*
_type_key =
"tir.FloorDiv";
237 static constexpr
const char*
_type_key =
"tir.FloorMod";
289 template <
typename T>
299 refl::ObjectDef<T>().def_ro(
"a", &T::a).def_ro(
"b", &T::b);
469 refl::ObjectDef<NotNode>().def_ro(
"a", &
NotNode::a);
503 refl::ObjectDef<SelectNode>()
544 refl::ObjectDef<BufferLoadNode>()
561 void LegalizeDType();
575 ffi::Optional<PrimExpr> predicate = std::nullopt,
Span span =
Span());
598 refl::ObjectDef<ProducerLoadNode>()
638 refl::ObjectDef<RampNode>()
667 refl::ObjectDef<BroadcastNode>()
699 refl::ObjectDef<LetNode>()
700 .def_ro(
"var", &
LetNode::var, refl::AttachFieldFlag::SEqHashDef())
766 refl::ObjectDef<ShuffleNode>()
779 TVM_DLL
Shuffle(ffi::Array<PrimExpr> vectors, ffi::Array<PrimExpr> indices,
Span span =
Span());
807 ffi::Array<PrimExpr>
operator()(ffi::Array<PrimExpr> a, ffi::Array<PrimExpr> b)
const;
816 refl::ObjectDef<CommReducerNode>()
834 TVM_DLL
CommReducer(ffi::Array<Var> lhs, ffi::Array<Var> rhs, ffi::Array<PrimExpr> result,
835 ffi::Array<PrimExpr> identity_element,
Span span =
Span());
861 refl::ObjectDef<ReduceNode>()
879 PrimExpr condition,
int value_index, ffi::Array<PrimExpr> init,
894 template <
typename K,
typename V>
896 std::unordered_map<K, V>
ret;
897 for (
auto kv : dmap) {
898 ret[kv.first] = kv.second;
907 inline constexpr
bool use_default_type_traits_v<tvm::tir::StringImm> =
false;
911 :
public ObjectRefWithFallbackTraitsBase<tvm::tir::StringImm, ffi::String> {
Symbolic n-dimensional array, to represent a memory buffer.
Constant floating point literals in the program.
Definition: expr.h:528
Constant integer literals in the program.
Definition: expr.h:493
Base node of all primitive expressions.
Definition: expr.h:91
Reference to PrimExprNode.
Definition: expr.h:124
DataType dtype() const
Definition: expr.h:138
Managed reference to RelaxExprNode.
Definition: expr.h:439
Definition: source_map.h:111
Runtime primitive data type.
Definition: data_type.h:47
a + b
Definition: expr.h:126
static constexpr const char * _type_key
Definition: expr.h:128
Managed reference to AddNode.
Definition: expr.h:135
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Add, PrimExpr, AddNode)
Add(PrimExpr a, PrimExpr b, Span span=Span())
a && b
Definition: expr.h:410
PrimExpr a
The left operand.
Definition: expr.h:413
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.And", AndNode, PrimExprNode)
PrimExpr b
The right operand.
Definition: expr.h:415
static void RegisterReflection()
Definition: expr.h:417
Managed reference to AndNode.
Definition: expr.h:428
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode)
And(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(And, PrimExpr, AndNode)
Base template to implement binary ops.
Definition: expr.h:108
PrimExpr b
The right operand.
Definition: expr.h:113
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:115
PrimExpr a
The left operand.
Definition: expr.h:111
static constexpr const bool _type_final
Definition: expr.h:121
static constexpr const int _type_child_slots
Definition: expr.h:120
Create a vector where all the elements are value.
Definition: expr.h:658
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Broadcast", BroadcastNode, PrimExprNode)
PrimExpr value
The base value.
Definition: expr.h:661
static void RegisterReflection()
Definition: expr.h:665
PrimExpr lanes
The number of lanes.
Definition: expr.h:663
Managed reference to BroadcastNode.
Definition: expr.h:678
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Broadcast, PrimExpr, BroadcastNode)
Broadcast(PrimExpr value, PrimExpr lanes, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode)
Load value from the high dimension buffer.
Definition: expr.h:533
ffi::Array< PrimExpr > indices
The indices location to be loaded.
Definition: expr.h:538
friend class VectorTypeRewriter
Definition: expr.h:564
friend class CustomDatatypesLowerer
Definition: expr.h:563
ffi::Optional< PrimExpr > predicate
The predicate mask for loading values.
Definition: expr.h:540
friend class Vectorizer
Definition: expr.h:565
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferLoad", BufferLoadNode, PrimExprNode)
Buffer buffer
The buffer variable.
Definition: expr.h:536
static void RegisterReflection()
Definition: expr.h:542
Managed reference to BufferLoadNode.
Definition: expr.h:572
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferLoad, PrimExpr, BufferLoadNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode)
BufferLoad(Buffer buffer, ffi::Array< PrimExpr > indices, ffi::Optional< PrimExpr > predicate=std::nullopt, Span span=Span())
Buffer is a symbolic n-darray structure. It is a composition of primitive symbolic types,...
Definition: buffer.h:156
Call node.
Definition: expr.h:721
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Call", CallNode, PrimExprNode)
RelaxExpr op
The operator(function) being invoked.
Definition: expr.h:729
ffi::Array< PrimExpr > args
The arguments.
Definition: expr.h:732
static void RegisterReflection()
Definition: expr.h:734
Managed reference to CallNode.
Definition: expr.h:745
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode)
Call(DataType dtype, RelaxExpr op, ffi::Array< PrimExpr > args, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode)
Cast value from one data type to another.
Definition: expr.h:80
PrimExpr value
Original data type.
Definition: expr.h:83
static void RegisterReflection()
Definition: expr.h:85
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Cast", CastNode, PrimExprNode)
Managed reference to CastNode.
Definition: expr.h:96
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Cast, PrimExpr, CastNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode)
Cast(DataType dtype, PrimExpr value, Span span=Span())
Base template to implement comparison ops.
Definition: expr.h:290
static constexpr const int _type_child_slots
Definition: expr.h:302
static void RegisterReflection()
Definition: expr.h:297
PrimExpr a
The left operand.
Definition: expr.h:293
PrimExpr b
The right operand.
Definition: expr.h:295
static constexpr const bool _type_final
Definition: expr.h:303
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode)
A commutative reducer node to represent a commutative binary operator with identity element.
Definition: expr.h:792
ffi::Array< PrimExpr > identity_element
The identity element of reducer, which leaves other elements unchanged when combined with it,...
Definition: expr.h:805
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: expr.h:824
ffi::Array< PrimExpr > result
The result of reducer.
Definition: expr.h:799
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.CommReducer", CommReducerNode, Object)
ffi::Array< Var > rhs
The right argument of reducer.
Definition: expr.h:797
ffi::Array< Var > lhs
The left argument of reducer.
Definition: expr.h:795
ffi::Array< PrimExpr > operator()(ffi::Array< PrimExpr > a, ffi::Array< PrimExpr > b) const
Function call operator to combine a and b.
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:812
static void RegisterReflection()
Definition: expr.h:814
Managed reference to CommReducerNode.
Definition: expr.h:832
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CommReducer, ObjectRef, CommReducerNode)
CommReducer(ffi::Array< Var > lhs, ffi::Array< Var > rhs, ffi::Array< PrimExpr > result, ffi::Array< PrimExpr > identity_element, Span span=Span())
Managed reference to DataProducerNode.
Definition: buffer.h:286
a / b in the C semnatics.
Definition: expr.h:181
static constexpr const char * _type_key
Definition: expr.h:183
Managed reference to DivNode.
Definition: expr.h:190
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode)
Div(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Div, PrimExpr, DivNode)
a == b
Definition: expr.h:308
static constexpr const char * _type_key
Definition: expr.h:310
Managed reference to EQNode.
Definition: expr.h:317
EQ(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EQ, PrimExpr, EQNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode)
Floor division, floor(a/b)
Definition: expr.h:218
static constexpr const char * _type_key
Definition: expr.h:220
Managed reference to FloorDivNode.
Definition: expr.h:227
FloorDiv(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorDiv, PrimExpr, FloorDivNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode)
The remainder of the floordiv.
Definition: expr.h:235
static constexpr const char * _type_key
Definition: expr.h:237
Managed reference to FloorModNode.
Definition: expr.h:244
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorMod, PrimExpr, FloorModNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode)
FloorMod(PrimExpr a, PrimExpr b, Span span=Span())
a >= b
Definition: expr.h:393
static constexpr const char * _type_key
Definition: expr.h:395
Managed reference to GENode.
Definition: expr.h:402
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GE, PrimExpr, GENode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode)
GE(PrimExpr a, PrimExpr b, Span span=Span())
a > b
Definition: expr.h:376
static constexpr const char * _type_key
Definition: expr.h:378
Managed reference to GTNode.
Definition: expr.h:385
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode)
GT(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GT, PrimExpr, GTNode)
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:297
Managed reference to LENode.
Definition: expr.h:368
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LE, PrimExpr, LENode)
LE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode)
a < b
Definition: expr.h:342
static constexpr const char * _type_key
Definition: expr.h:344
Managed reference to LTNode.
Definition: expr.h:351
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode)
LT(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LT, PrimExpr, LTNode)
Let binding. Bind var to value then evaluate body.
Definition: expr.h:688
static void RegisterReflection()
Definition: expr.h:697
Var var
The variable.
Definition: expr.h:691
PrimExpr value
The value to be binded.
Definition: expr.h:693
PrimExpr body
The result expression.
Definition: expr.h:695
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Let", LetNode, PrimExprNode)
Managed reference to LetNode.
Definition: expr.h:711
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Let, PrimExpr, LetNode)
Let(Var var, PrimExpr value, PrimExpr body, Span span=Span())
max(a, b)
Definition: expr.h:269
static constexpr const char * _type_key
Definition: expr.h:271
Managed reference to MaxNode.
Definition: expr.h:278
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Max, PrimExpr, MaxNode)
Max(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode)
min(a, b)
Definition: expr.h:252
static constexpr const char * _type_key
Definition: expr.h:254
Managed reference to MinNode.
Definition: expr.h:261
Min(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Min, PrimExpr, MinNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode)
a % b in the C semnatics.
Definition: expr.h:201
static constexpr const char * _type_key
Definition: expr.h:203
Managed reference to ModNode.
Definition: expr.h:210
Mod(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mod, PrimExpr, ModNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode)
a * b
Definition: expr.h:161
static constexpr const char * _type_key
Definition: expr.h:163
Managed reference to MulNode.
Definition: expr.h:170
Mul(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mul, PrimExpr, MulNode)
a != b
Definition: expr.h:325
static constexpr const char * _type_key
Definition: expr.h:327
Managed reference to NENode.
Definition: expr.h:334
NE(PrimExpr a, PrimExpr b, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NE, PrimExpr, NENode)
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Not", NotNode, PrimExprNode)
PrimExpr a
The input operand.
Definition: expr.h:465
static void RegisterReflection()
Definition: expr.h:467
Managed reference to NotNode.
Definition: expr.h:478
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Not, PrimExpr, NotNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode)
Not(PrimExpr a, Span span=Span())
a || b
Definition: expr.h:436
PrimExpr b
The right operand.
Definition: expr.h:441
PrimExpr a
The left operand.
Definition: expr.h:439
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Or", OrNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:443
Managed reference to OrNode.
Definition: expr.h:454
Or(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Or, PrimExpr, OrNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode)
Load value from the result produced by the producer.
Definition: expr.h:589
static void RegisterReflection()
Definition: expr.h:596
DataProducer producer
The buffer producer.
Definition: expr.h:592
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.ProducerLoad", ProducerLoadNode, PrimExprNode)
ffi::Array< PrimExpr > indices
The location arguments.
Definition: expr.h:594
Managed reference to ProducerLoadNode.
Definition: expr.h:609
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ProducerLoad, PrimExpr, ProducerLoadNode)
ProducerLoad(DataProducer producer, ffi::Array< PrimExpr > indices, Span span=Span())
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode)
Construct a vector with lanes elements where its i-th element equals base + i * stride....
Definition: expr.h:627
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Ramp", RampNode, PrimExprNode)
PrimExpr stride
The stride of each step.
Definition: expr.h:632
PrimExpr lanes
Total number of lanes.
Definition: expr.h:634
static void RegisterReflection()
Definition: expr.h:636
PrimExpr base
The base value.
Definition: expr.h:630
Managed reference to RampNode.
Definition: expr.h:650
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode)
Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Ramp, PrimExpr, RampNode)
Reduction operator.
Definition: expr.h:841
int value_index
the index of this reduce node
Definition: expr.h:857
CommReducer combiner
The commutative combiner.
Definition: expr.h:844
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Reduce", ReduceNode, PrimExprNode)
static void RegisterReflection()
Definition: expr.h:859
ffi::Array< PrimExpr > source
The source operand.
Definition: expr.h:846
PrimExpr condition
Predicate on the reduction Only add the body to reduction if condition is true.
Definition: expr.h:855
ffi::Array< IterVar > axis
The reduction axis.
Definition: expr.h:850
ffi::Array< PrimExpr > init
The init operand.
Definition: expr.h:848
Managed reference to ReduceNode.
Definition: expr.h:876
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode)
Reduce(CommReducer combiner, ffi::Array< PrimExpr > src, ffi::Array< IterVar > rdom, PrimExpr condition, int value_index, ffi::Array< PrimExpr > init, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Reduce, PrimExpr, ReduceNode)
return true_value if condition is true, otherwise return false_value.
Definition: expr.h:492
PrimExpr condition
The condition.
Definition: expr.h:495
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Select", SelectNode, PrimExprNode)
PrimExpr true_value
value to be returned when condition is true.
Definition: expr.h:497
PrimExpr false_value
value to be returned when condition is false.
Definition: expr.h:499
static void RegisterReflection()
Definition: expr.h:501
Managed reference to SelectNode.
Definition: expr.h:515
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Select, PrimExpr, SelectNode)
Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Shuffle instruction. vec = concat(vectors) result = (vec[indices[0]], vec[indices[1]] ....
Definition: expr.h:757
ffi::Array< PrimExpr > indices
The indices of each element.
Definition: expr.h:762
static void RegisterReflection()
Definition: expr.h:764
ffi::Array< PrimExpr > vectors
the input vectors.
Definition: expr.h:760
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Shuffle", ShuffleNode, PrimExprNode)
Managed reference to ShuffleNode.
Definition: expr.h:777
static PrimExpr Concat(ffi::Array< PrimExpr > vectors, Span span=Span())
Shuffle(ffi::Array< PrimExpr > vectors, ffi::Array< PrimExpr > indices, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Shuffle, PrimExpr, ShuffleNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode)
static PrimExpr ExtractElement(PrimExpr vector, int index, Span span=Span())
ffi::String constants, only used in asserts.
Definition: expr.h:53
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.StringImm", StringImmNode, PrimExprNode)
ffi::String value
The constant value content.
Definition: expr.h:56
static void RegisterReflection()
Definition: expr.h:58
Managed reference to StringImmNode.
Definition: expr.h:69
StringImm(ffi::String value, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, PrimExpr, StringImmNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode)
a - b
Definition: expr.h:143
static constexpr const char * _type_key
Definition: expr.h:145
Managed reference to SubNode.
Definition: expr.h:152
Sub(PrimExpr a, PrimExpr b, Span span=Span())
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Sub, PrimExpr, SubNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode)
a named variable in TIR
Definition: var.h:77
Defines the Functor data structures.
Definition: repr_printer.h:91
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
std::unordered_map< K, V > as_unordered_map(const ffi::Map< K, V > &dmap)
Definition: expr.h:895
tvm::FloatImmNode FloatImmNode
Definition: expr.h:50
tvm::IntImmNode IntImmNode
Definition: expr.h:49
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Definitions and helper macros for IR/AST nodes.
static TVM_FFI_INLINE tvm::tir::StringImm ConvertFallbackValue(ffi::String value)
Definition: expr.h:912
a <= b
Definition: expr.h:359
static constexpr const char * _type_key
Definition: expr.h:361