40 #include <type_traits>
// Registers the operator named "tirx." OpName and sets its
// TScriptPrinterName attribute to OpName. OpName must be a string literal,
// because it is joined to "tirx." by compile-time literal concatenation.
44 #define TVM_TIR_REGISTER_OP(OpName) \
45 TVM_REGISTER_OP("tirx." OpName).set_attr<TScriptPrinterName>("TScriptPrinterName", OpName)
584 ffi::Array<PrimExpr> init = {}, Span span = Span());
594 ffi::Array<PrimExpr> init = {}, Span span = Span());
605 ffi::Array<PrimExpr> init = {}, Span span = Span());
616 ffi::Array<PrimExpr> init = {}, Span span = Span());
627 ffi::Array<PrimExpr> init = {}, Span span = Span());
638 ffi::Array<PrimExpr> init = {}, Span span = Span());
731 <<
"tirx." << op_name <<
" only supports floating-point inputs, but got " << dtype;
735 #define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \
736 inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
737 static const Op& op = Op::Get("tirx." #OpName); \
738 CheckInputDType(#OpName, x.dtype()); \
739 if (x.dtype().is_bfloat16()) { \
740 DataType bf16_dtype = x.dtype(); \
741 DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
742 PrimExpr x_fp32 = tirx::Cast(fp32_dtype, {x}, span); \
743 PrimExpr result_fp32 = tirx::Call(fp32_dtype, op, {x_fp32}, span); \
744 return tirx::Cast(bf16_dtype, {result_fp32}, span); \
746 return tirx::Call(x.dtype(), op, {x}, span); \
// Declares a unary intrinsic wrapper `OpName(PrimExpr x, Span span)` for the
// op "tirx.<OpName>" without any input-dtype validation: the check hook given
// to TVM_DECLARE_INTRIN_UNARY_WITH_CHECK is a lambda that ignores its arguments.
750 #define TVM_DECLARE_INTRIN_UNARY(OpName) \
751 TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, [](const char*, DataType) {})
// Declares a unary intrinsic wrapper `OpName(PrimExpr x, Span span)` for the
// op "tirx.<OpName>" whose input dtype is validated by
// CheckMathUnaryOpInputDType before the call is built. NOTE(review): the
// visible fragment of that check reports that the op "only supports
// floating-point inputs"; confirm the exact accepted dtypes at its definition.
753 #define TVM_DECLARE_FLOAT_INTRIN_UNARY(OpName) \
754 TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckMathUnaryOpInputDType)
782 #define TVM_DECLARE_INTRIN_BINARY(OpName) \
783 inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \
784 static const Op& op = Op::Get("tirx." #OpName); \
785 return tirx::Call(x.dtype(), op, {x, y}, span); \
803 if (!type.defined())
return false;
805 if (
const auto* prim_type = ptr_type->element_type.as<
PrimTypeNode>()) {
806 return prim_type->dtype == element_type;
820 template <
typename ValueType,
821 typename =
typename std::enable_if<std::is_pod<ValueType>::value>::type>
855 if (!x.defined())
return nullptr;
917 template <
typename FReduce>
921 init_value = freduce(init_value, val, span);
953 return as_int && (*as_int > 0);
958 return as_int && (*as_int < 0);
963 return as_int && (*as_int == value);
967 if (!stmt.defined())
return true;
972 return op->seq.size() == 0;
977 template <
typename ValueType>
982 uint64_t uval =
static_cast<uint64_t
>(value);
983 if (value <
static_cast<ValueType
>(0)) {
984 TVM_FFI_THROW(InternalError) <<
"cannot make uint from negative value " << value;
986 return IntImm(t,
static_cast<int64_t
>(value), span);
988 uint64_t mask = (
static_cast<uint64_t
>(1) << 32U) - 1U;
989 uint64_t low = uval & mask;
990 uint64_t high = uval >> 32U;
991 return LargeUIntImm(t,
static_cast<int64_t
>(low),
static_cast<int64_t
>(high), span);
995 return FloatImm(t,
static_cast<double>(value), span);
1001 return FloatImm(t,
static_cast<double>(value), span);
1003 TVM_FFI_THROW(InternalError) <<
"cannot make const for type " << t;
1012 template <
typename ValueType,
typename>
1037 #define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
1038 inline PrimExpr Name(PrimExpr& a, PrimExpr b) { \
1043 #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
1044 inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
1045 inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
1046 inline PrimExpr Name(int a, const PrimExpr& b) { \
1047 return Name(tirx::make_const(b.dtype(), a), b); \
1049 inline PrimExpr Name(const PrimExpr& a, int b) { \
1050 return Name(a, tirx::make_const(a.dtype(), b)); \
1052 inline PrimExpr Name(const PrimExpr& a, double b) { \
1053 return Name(a, tirx::make_const(DataType::Float(64), b)); \
1056 #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
1057 inline PrimExpr Name(const PrimExpr& a, float b, Span span = Span()) { \
1058 return Name(a, PrimExpr(b), span); \
1060 inline PrimExpr Name(float a, const PrimExpr& b, Span span = Span()) { \
1061 return Name(PrimExpr(a), b, span); \
1063 inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1064 return Name(tirx::make_const(b.dtype(), a), b, span); \
1066 inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1067 return Name(a, tirx::make_const(a.dtype(), b), span); \
1069 inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
1070 return Name(a, tirx::make_const(DataType::Float(64), b), span); \
// Defines overloads of `Name` that accept a bool on either side. The bool
// is wrapped with PrimExpr(...) and forwarded to the two-PrimExpr `Name`.
// NOTE(review): the dtype of PrimExpr(bool) depends on which PrimExpr
// constructor overload resolution selects for a bool argument; confirm it
// yields a boolean constant rather than an integer one.
1073 #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
1074 inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); } \
1075 inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); }
1077 #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
1078 inline PrimExpr Name(const PrimExpr& a, bool b, Span span = Span()) { \
1079 return Name(a, PrimExpr(b), span); \
1081 inline PrimExpr Name(bool a, const PrimExpr& b, Span span = Span()) { \
1082 return Name(PrimExpr(a), b, span); \
1085 #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
1086 inline PrimExpr Name(const PrimExpr& a, int b) { \
1087 return Name(a, tirx::make_const(a.dtype(), b)); \
1089 inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::make_const(b.dtype(), a), b); }
1091 #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
1092 inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1093 return Name(a, tirx::make_const(a.dtype(), b), span); \
1095 inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1096 return Name(tirx::make_const(b.dtype(), a), b, span); \
1148 template <
typename TA>
1150 constexpr
bool div_ambiguity = !std::is_class<TA>::value;
1151 static_assert(div_ambiguity,
1152 "TVM supports multiple types of integer divisions, "
1153 "please call div, indexdiv/indexmod, "
1154 "floordiv/floormod or truncdiv/truncmod directly "
1155 "to avoid ambiguity in the code. "
1156 "Checkout these functions in tirx/op.h.");
1164 template <
typename TB>
1170 template <
typename TB>
1176 template <
typename TB>
Constant floating point literals in the program.
Definition: expr.h:529
Managed reference class to FloatImmNode.
Definition: expr.h:546
Constant integer literals in the program.
Definition: expr.h:494
int64_t value
the Internal value.
Definition: expr.h:497
Managed reference class to IntImmNode.
Definition: expr.h:511
Low-level raw pointer type.
Definition: type.h:152
Reference to PrimExprNode.
Definition: expr.h:126
Primitive data types used in the low-level IR.
Definition: type.h:112
Definition: source_map.h:111
Managed reference to TypeNode.
Definition: type.h:99
Runtime primitive data type.
Definition: data_type.h:47
bool is_handle() const
Definition: data_type.h:198
bool is_uint() const
Definition: data_type.h:196
bool is_float6() const
Definition: data_type.h:159
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:240
@ kCustomBegin
Definition: data_type.h:75
bool is_bool() const
Definition: data_type.h:143
bool is_int() const
Definition: data_type.h:194
int code() const
Definition: data_type.h:114
int lanes() const
Definition: data_type.h:120
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:387
int vscale_factor() const
Definition: data_type.h:129
bool is_fixed_length_vector() const
Definition: data_type.h:205
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
bool is_scalar() const
Definition: data_type.h:141
bool is_float8() const
Definition: data_type.h:151
bool is_bfloat16() const
Definition: data_type.h:192
bool is_float4() const
Definition: data_type.h:164
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct a uint type.
Definition: data_type.h:286
bool is_float() const
Definition: data_type.h:147
Create a vector where all the elements are value.
Definition: expr.h:657
Managed reference to BroadcastNode.
Definition: expr.h:677
Managed reference to CallNode.
Definition: expr.h:744
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:336
Managed reference to MulNode.
Definition: expr.h:169
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:311
Container of all statements.
Definition: stmt.h:65
Primitive operators(builtin intrinsics) and registry for them.
IR/AST nodes for the unified type system in TVM.
const Op & vscale()
Get the target's vscale value. It will be lowered to llvm.vscale intrinsic (https://llvm....
bool is_const_number(const PrimExpr &x)
Check whether x is an integer/float constant.
Definition: op.h:939
bool is_zero(const PrimExpr &x)
Check whether x is a constant integer 0.
Definition: op.h:892
bool is_const_power_of_two_integer(const PrimExpr &x, int *shift)
Check whether x is a constant power of two. If x is a power of two, write its base-2 exponent to shift.
bool IsPointerType(const Type &type, const DataType &element_type)
Check if type is a pointer to a runtime element type.
Definition: op.h:802
bool is_positive_const(const PrimExpr &a)
Definition: op.h:951
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1013
PrimExpr MakeConstScalar(DataType t, ValueType value, Span span=Span())
Definition: op.h:978
PrimExpr const_false(int lanes=1, Span span=Span())
Make a constant false expression.
Definition: op.h:845
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:1027
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:836
bool is_negative_const(const PrimExpr &a)
Definition: op.h:956
bool is_const_int(const PrimExpr &x, int64_t value)
Check whether x is a constant integer expression.
Definition: op.h:961
bool is_one(const PrimExpr &x)
Check whether x is a constant integer 1.
Definition: op.h:884
bool is_no_op(const tirx::Stmt &stmt)
Check whether stmt is nop.
Definition: op.h:966
PrimExpr foldl(FReduce freduce, PrimExpr init_value, const ffi::Array< PrimExpr > &values, Span span=Span())
Left fold.
Definition: op.h:918
const int64_t * as_const_int(const PrimExpr &x)
Get x as constant int expression.
Definition: op.h:854
std::function< PrimExpr(PrimExpr source, const ffi::Array< IterVar > &axis, ffi::Array< PrimExpr > init, Span span)> FReduce
The operation to use for CommReduce.
Definition: reduction.h:47
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
runtime::DataType GetRuntimeDataType(const Type &type)
Get the implied DataType for storing values with type during runtime.
PrimExpr isfinite(PrimExpr x, Span span=Span())
Check if x is finite.
PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span=Span())
compute ceil(a / b)
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr tanh(PrimExpr x, Span span=Span())
Definition: op.h:760
PrimExpr erf(PrimExpr x, Span span=Span())
Definition: op.h:759
PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span=Span())
compute ceil(a / b) where a and b are non-negative.
PrimExpr log10(PrimExpr x, Span span=Span())
Definition: op.h:766
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr operator/(PrimExpr a, PrimExpr b)
division operator
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span=Span())
and
PrimExpr hypot(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:791
PrimExpr log1p(PrimExpr x, Span span=Span())
Definition: op.h:767
void DivAmbiguityError(const TA &a)
Helper function to raise a compiler error about division ambiguity.
Definition: op.h:1149
PrimExpr likely(PrimExpr cond, Span span=Span())
Mark condition as likely.
PrimExpr reinterpret(const DataType &t, PrimExpr value, Span span=Span())
perform reinterpret cast value to type.
PrimExpr atan2(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:788
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr bitwise_neg(PrimExpr a, Span span=Span())
take bitwise negation of a value
PrimExpr cosh(PrimExpr x, Span span=Span())
Definition: op.h:771
PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span=Span())
or
PrimExpr thread_return(Span span=Span())
Return from a thread.
PrimExpr atan(PrimExpr x, Span span=Span())
Definition: op.h:776
Type GetType(const PrimExpr &expr)
Get the type of the expression under the unified type system.
PrimExpr isnan(PrimExpr x, Span span=Span())
Check if x is NaN.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr max_value(const DataType &dtype, Span span=Span())
PrimExpr exp2(PrimExpr x, Span span=Span())
Definition: op.h:757
PrimExpr rsqrt(PrimExpr x, Span span=Span())
Definition: op.h:763
PrimExpr operator/=(const PrimExpr &a, const TB &b)
Definition: op.h:1171
PrimExpr asinh(PrimExpr x, Span span=Span())
Definition: op.h:778
PrimExpr less(PrimExpr a, PrimExpr b, Span span=Span())
less
PrimExpr sin(PrimExpr x, Span span=Span())
Definition: op.h:772
PrimExpr trunc(PrimExpr x, Span span=Span())
Calculate trunc(x)
PrimExpr round(PrimExpr x, Span span=Span())
Round x to the nearest integer, ties to even.
Type GetTypeFromRuntimeDataType(const DataType &dtype)
Get the type corresponding to DataType.
PrimExpr neg(PrimExpr a, Span span=Span())
negation.
PrimExpr ceil(PrimExpr x, Span span=Span())
Calculate ceil(x)
PrimExpr pow(PrimExpr x, PrimExpr y, Span span=Span())
Calculate power(x, y)
PrimExpr logical_not(PrimExpr a, Span span=Span())
not
PrimExpr exp10(PrimExpr x, Span span=Span())
Definition: op.h:758
PrimExpr copysign(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:790
PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise xor of two values
PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span=Span())
less_equal
PrimExpr any(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
logical Or of source expression over axis
PrimExpr greater(PrimExpr a, PrimExpr b, Span span=Span())
greater
PrimExpr exp(PrimExpr x, Span span=Span())
Definition: op.h:756
PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span=Span())
Compute log(exp(a) + exp(b)).
PrimExpr floormod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floordiv
PrimExpr infinity(const DataType &dtype, Span span=Span())
PrimExpr sub(PrimExpr a, PrimExpr b, Span span=Span())
subtraction operator
PrimExpr all(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
logical And of source expression over axis
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr nextafter(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:789
PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span=Span())
Construct a large uint constant by its low 32 bits and high 32bits.
PrimExpr asin(PrimExpr x, Span span=Span())
Definition: op.h:774
PrimExpr prod(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
product of source expression over axis
PrimExpr sigmoid(PrimExpr x, Span span=Span())
Definition: op.h:761
PrimExpr max(const PrimExpr &a, double b, Span span=Span())
Definition: op.h:1109
PrimExpr acos(PrimExpr x, Span span=Span())
Definition: op.h:775
PrimExpr mul(PrimExpr a, PrimExpr b, Span span=Span())
multiplication operator
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
void CheckMathUnaryOpInputDType(const char *op_name, DataType dtype)
Definition: op.h:729
PrimExpr sum(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr floor(PrimExpr x, Span span=Span())
Calculate floor(x)
PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span=Span())
greater_equal
PrimExpr operator%(const PrimExpr &a, const TB &b)
Definition: op.h:1177
PrimExpr abs(PrimExpr x, Span span=Span())
Calculate absolute value of x.
PrimExpr atanh(PrimExpr x, Span span=Span())
Definition: op.h:779
PrimExpr sqrt(PrimExpr x, Span span=Span())
Definition: op.h:762
PrimExpr isinf(PrimExpr x, Span span=Span())
Check if x is infinite.
PrimExpr continue_loop(Span span=Span())
Continue current loop.
PrimExpr log2(PrimExpr x, Span span=Span())
Definition: op.h:765
PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span=Span())
not_equal
PrimExpr ldexp(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:792
PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute trunc(a / b)
PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span=Span())
Execute a multiplication between two Q-numbers x and y followed by a right shift s....
PrimExpr popcount(PrimExpr x, Span span=Span())
Definition: op.h:768
PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise and of two values
PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span=Span())
left shift operator
PrimExpr sinh(PrimExpr x, Span span=Span())
Definition: op.h:773
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floor(a / b) where a and b are non-negative.
PrimExpr break_loop(Span span=Span())
Break current loop.
PrimExpr add(PrimExpr a, PrimExpr b, Span span=Span())
add operator
PrimExpr log(PrimExpr x, Span span=Span())
Definition: op.h:764
PrimExpr nearbyint(PrimExpr x, Span span=Span())
Round x to the nearest integer, ties to even.
PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span=Span())
right shift operator
PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise or of two values
PrimExpr clz(PrimExpr x, Span span=Span())
Definition: op.h:780
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
PrimExpr acosh(PrimExpr x, Span span=Span())
Definition: op.h:777
PrimExpr tan(PrimExpr x, Span span=Span())
Definition: op.h:769
PrimExpr cos(PrimExpr x, Span span=Span())
Definition: op.h:770
PrimExpr fast_erf_float_expr(PrimExpr arg, int bits)
Fast_erf_float expression from Eigen.
#define TVM_DECLARE_INTRIN_UNARY(OpName)
Definition: op.h:750
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1085
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1077
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)
Definition: op.h:1037
#define TVM_DECLARE_FLOAT_INTRIN_UNARY(OpName)
Definition: op.h:753
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1056
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1043
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1073
#define TVM_DECLARE_INTRIN_BINARY(OpName)
Definition: op.h:782
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1091