tvm
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 // Acknowledgement: Most operator APIs originate from Halide.
28 #ifndef TVM_TIR_OP_H_
29 #define TVM_TIR_OP_H_
30 
31 #include <tvm/ir/expr.h>
32 #include <tvm/ir/op.h>
33 #include <tvm/ir/type.h>
34 #include <tvm/tir/expr.h>
35 #include <tvm/tir/stmt.h>
36 
37 #include <algorithm>
38 #include <limits>
39 #include <type_traits>
40 
41 namespace tvm {
42 
43 // Most common operators can be overloaded by argument type(PrimExpr).
44 // So we put them under the root namespace.
45 // It is also necessary to overload operators for PrimExpr.
46 //
47 // We put more developer oriented APIs -- make_const and is_const under tir
48 // as they are more specific to the tir namespace.
49 
/*!
 * \brief Get the type of the expression under the unified type system.
 * \param expr The input expression.
 * \return The Type of expr.
 */
61 TVM_DLL Type GetType(const PrimExpr& expr);
62
/*!
 * \brief Get the implied DataType for storing values with type during runtime.
 * \param type The input type.
 * \return The runtime DataType.
 */
71 TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
72
/*!
 * \brief Return the value.
 * \param value The value to be returned.
 * \param span Source location attached to the created expression.
 * \return The resulting expression.
 */
80 TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
81
/*!
 * \brief Build an expression for a dtype-dependent constant.
 * NOTE(review): presumably the largest value representable in dtype -- confirm
 * against the implementation, which is not visible in this header.
 */
88 TVM_DLL PrimExpr max_value(const DataType& dtype, Span span = Span());
89
/*!
 * \brief Build an expression for a dtype-dependent constant.
 * NOTE(review): presumably the smallest value representable in dtype -- confirm
 * against the implementation, which is not visible in this header.
 */
96 TVM_DLL PrimExpr min_value(const DataType& dtype, Span span = Span());
97
/*!
 * \brief Build an expression for a dtype-dependent constant.
 * NOTE(review): presumably the infinity value of dtype -- confirm against the
 * implementation, which is not visible in this header.
 */
104 TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span());
105 
// Each named function below takes an optional Span, while the operator form
// that follows it takes only its operands.

/*! \brief cast value to type t. */
115 TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span());
/*! \brief perform reinterpret cast of value to type t. */
125 TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span());
/*! \brief add operator */
136 TVM_DLL PrimExpr add(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief add operator */
146 TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
/*! \brief subtraction operator */
157 TVM_DLL PrimExpr sub(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief subtraction operator */
167 TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
/*! \brief negation. */
177 TVM_DLL PrimExpr neg(PrimExpr a, Span span = Span());
/*! \brief negation (unary minus). */
186 TVM_DLL PrimExpr operator-(PrimExpr a);
/*! \brief multiplication operator */
197 TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief multiplication operator */
207 TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
/*! \brief division operator */
217 TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
/*! \brief left shift operator */
228 TVM_DLL PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief left shift operator */
238 TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
/*! \brief right shift operator */
249 TVM_DLL PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief right shift operator */
259 TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
/*! \brief greater */
270 TVM_DLL PrimExpr greater(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief greater */
280 TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
/*! \brief greater_equal */
291 TVM_DLL PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief greater_equal */
301 TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
/*! \brief less */
312 TVM_DLL PrimExpr less(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief less */
322 TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
/*! \brief less_equal */
333 TVM_DLL PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief less_equal */
343 TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
/*! \brief equal */
354 TVM_DLL PrimExpr equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief equal */
364 TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
/*! \brief not_equal */
375 TVM_DLL PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief not_equal */
385 TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
/*! \brief and */
395 TVM_DLL PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief and */
404 TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
/*! \brief or */
414 TVM_DLL PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief or */
423 TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
/*! \brief not */
432 TVM_DLL PrimExpr logical_not(PrimExpr a, Span span = Span());
/*! \brief not */
440 TVM_DLL PrimExpr operator!(PrimExpr a);
/*! \brief compute division in C semantics. */
455 TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute trunc(a / b) */
468 TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute the remainder of truncdiv */
481 TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute floor(a / b) where a and b are non-negative. */
497 TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
 * \brief compute the remainder for non-negative a and b.
 * NOTE(review): presumably the remainder matching indexdiv -- confirm.
 */
512 TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute floor(a / b) */
523 TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute the remainder of floordiv */
534 TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take maximum of two values */
545 TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take minimum of two values */
556 TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take bitwise and of two values */
567 TVM_DLL PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take bitwise and of two values */
577 TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
/*! \brief take bitwise or of two values */
588 TVM_DLL PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take bitwise or of two values */
598 TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
/*! \brief take bitwise xor of two values */
609 TVM_DLL PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take bitwise xor of two values */
619 TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
/*! \brief take bitwise negation of a value */
629 TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span());
/*! \brief take bitwise negation of a value */
638 TVM_DLL PrimExpr operator~(PrimExpr a);
/*!
 * \brief Conditional expression.
 * \param cond The condition.
 * \param true_value The value when cond is true.
 * \param false_value The value when cond is false.
 * \param span Source location attached to the created expression.
 */
650 TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value,
651  Span span = Span());
/*! \brief Mark condition as likely. */
658 TVM_DLL PrimExpr likely(PrimExpr cond, Span span = Span());
/*! \brief Calculate power(x, y) */
665 TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y, Span span = Span());
/*! \brief Calculate absolute value of x. */
673 TVM_DLL PrimExpr abs(PrimExpr x, Span span = Span());
/*! \brief Check if x is NaN. */
680 TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
681
/*! \brief Check if x is finite. */
688 TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
689
/*! \brief Check if x is infinite. */
696 TVM_DLL PrimExpr isinf(PrimExpr x, Span span = Span());
697 
// Reductions: every function below takes the expression to reduce (source),
// the iteration variables to reduce over (axis), an optional init list that
// defaults to empty, and an optional Span.

/*! \brief sum of source expression over axis */
706 TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
707  Span span = Span());
708
/*! \brief logical And of source expression over axis */
716 TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
717  Span span = Span());
718
/*! \brief logical Or of source expression over axis */
727 TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
728  Span span = Span());
729
/*! \brief NOTE(review): presumably the maximum of source expression over axis -- confirm. */
738 TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
739  Span span = Span());
740
/*! \brief NOTE(review): presumably the minimum of source expression over axis -- confirm. */
749 TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
750  Span span = Span());
751
/*! \brief product of source expression over axis */
760 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
761  Span span = Span());
762 
/*! \brief Calculate floor(x) */
769 TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span());
770
/*! \brief Calculate ceil(x) */
777 TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());
778
/*! \brief Calculate round(x) */
785 TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());
786
/*! \brief Calculates std::nearbyint(x) */
794 TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());
795
/*! \brief Calculate trunc(x) */
802 TVM_DLL PrimExpr trunc(PrimExpr x, Span span = Span());
803
/*!
 * \brief Construct a large uint constant by its low 32 bits and high 32 bits.
 * \param dtype The final data type.
 * \param low The lower 32 bits of the constant.
 * \param high The higher 32 bits of the constant.
 * \param span Source location attached to the created expression.
 */
812 TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span = Span());
813
/*!
 * \brief Execute a multiplication between two Q-numbers x and y followed by a
 *        right shift s.
 * NOTE(review): the exact formula and the role of q are not visible in this
 * header -- consult the implementation before relying on rounding details.
 */
834 TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
835  Span span = Span());
836 
837 // Intrinsic operators
/*!
 * \brief Define an inline one-argument intrinsic wrapper named OpName.
 *
 * The generated function looks up the operator registered under the string
 * "tir." followed by the stringified OpName (once, via a function-local
 * static) and returns a tir::Call to it with the single argument x. The
 * result dtype is taken from x.dtype().
 */
838 #define TVM_DECLARE_INTRIN_UNARY(OpName) \
839  inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
840  static const Op& op = Op::Get("tir." #OpName); \
841  return tir::Call(x.dtype(), op, {x}, span); \
842  }
843 
868 
/*!
 * \brief Define an inline two-argument intrinsic wrapper named OpName.
 *
 * The generated function looks up the operator registered under the string
 * "tir." followed by the stringified OpName (once, via a function-local
 * static) and returns a tir::Call to it with arguments {x, y}. The result
 * dtype is taken from the first argument only (x.dtype()); y's dtype is not
 * consulted here.
 */
869 #define TVM_DECLARE_INTRIN_BINARY(OpName) \
870  inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \
871  static const Op& op = Op::Get("tir." #OpName); \
872  return tir::Call(x.dtype(), op, {x, y}, span); \
873  }
874 
880 
881 namespace tir {
882 
889 inline bool IsPointerType(const Type& type, const DataType& element_type) {
890  if (!type.defined()) return false;
891  if (const auto* ptr_type = type.as<PointerTypeNode>()) {
892  if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
893  return prim_type->dtype == element_type;
894  }
895  }
896  return false;
897 }
898 
/*!
 * \brief Make a const value with certain data type.
 * \param t The target type.
 * \param value The input value.
 * \param span Source location attached to the created expression.
 * \return the result expression.
 * \tparam ValueType The constant value type; the overload participates only
 *         when std::is_pod<ValueType> holds (see the enable_if below).
 */
907 template <typename ValueType,
908  typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
909 inline PrimExpr make_const(DataType t, ValueType value, Span span = Span());
/*!
 * \brief Make a const zero expr.
 * \param t The target type.
 * \param span Source location attached to the created expression.
 * \return the result expression.
 */
916 inline PrimExpr make_zero(DataType t, Span span = Span());
923 inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
924  return make_const(DataType::UInt(1, lanes), 1);
925 }
932 inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
933  return make_const(DataType::UInt(1, lanes), 0);
934 }
941 inline const int64_t* as_const_int(const PrimExpr& x) {
942  if (!x.defined()) return nullptr;
943  if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) {
944  return &(op->value);
945  } else {
946  return nullptr;
947  }
948 }
949 
/*!
 * \brief Check whether x is a constant integer expression.
 * \param x The input argument.
 * \param value The value to be compared against.
 * \return whether x is constant expression equal to value.
 */
956 inline bool is_const_int(const PrimExpr& x, int64_t value);
957
/*!
 * \brief Check whether stmt is nop.
 * \param stmt The input statement.
 * \return whether stmt is nop.
 */
963 inline bool is_no_op(const tir::Stmt& stmt);
964 
971 inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); }
972 
979 inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
980 
/*!
 * \brief Check whether x is a constant integer expression.
 * \param x The input argument.
 * \return true when x is an IntImm or a Broadcast of an IntImm.
 */
986 inline bool is_const_int(const PrimExpr& x);
987
/*!
 * \brief Check whether x is an integer/float constant.
 * \param x The input argument.
 * \return true when x is an IntImm, a FloatImm, or a Broadcast of either.
 */
993 inline bool is_const_number(const PrimExpr& x);
994
/*!
 * \brief Left fold.
 * \param freduce The reduction function, invoked as freduce(acc, value, span).
 * \param init_value The initial accumulator value.
 * \param values The values to be folded, visited in order.
 * \param span Source location passed to every freduce call.
 * \return The final accumulator.
 * \tparam FReduce The type of the reduction.
 */
1004 template <typename FReduce>
1005 inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values,
1006  Span span = Span());
1007
/*!
 * \brief Check whether x is a constant power of two.
 * If x is a power of two, write the power to shift.
 * \param x The input expression.
 * \param shift The output shift if x is power of two.
 * \return whether x is constant power of two.
 */
1016 TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
1017 
1018 // Implementation details after this
1019 inline bool is_const_int(const PrimExpr& x) {
1020  if (x.as<tir::IntImmNode>()) {
1021  return true;
1022  } else if (const auto* op = x.as<tir::BroadcastNode>()) {
1023  const PrimExpr& val = op->value;
1024  if (val.as<tir::IntImmNode>()) {
1025  return true;
1026  }
1027  }
1028  return false;
1029 }
1030 
1031 inline bool is_const_number(const PrimExpr& x) {
1032  if (x.as<tir::IntImmNode>()) {
1033  return true;
1034  } else if (x.as<tir::FloatImmNode>()) {
1035  return true;
1036  } else if (const auto* op = x.as<tir::BroadcastNode>()) {
1037  return (op->value->IsInstance<tir::IntImmNode>() || op->value->IsInstance<tir::FloatImmNode>());
1038  }
1039  return false;
1040 }
1041 
1042 inline bool is_positive_const(const PrimExpr& a) {
1043  if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
1044  return op->value > 0;
1045  } else {
1046  return false;
1047  }
1048 }
1049 
1050 inline bool is_negative_const(const PrimExpr& a) {
1051  if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
1052  return op->value < 0;
1053  } else {
1054  return false;
1055  }
1056 }
1057 
1058 inline bool is_const_int(const PrimExpr& x, int64_t value) {
1059  if (const auto* op = x.as<tir::IntImmNode>()) {
1060  return op->value == value;
1061  } else if (const auto* op = x.as<tir::BroadcastNode>()) {
1062  const PrimExpr& val = op->value;
1063  if (const auto* opv = val.as<tir::IntImmNode>()) {
1064  return opv->value == value;
1065  }
1066  }
1067  return false;
1068 }
1069 
1070 inline bool is_no_op(const tir::Stmt& stmt) {
1071  if (!stmt.defined()) return true;
1072  if (const auto* op = stmt.as<tir::EvaluateNode>()) {
1073  return is_const_int(op->value);
1074  }
1075  if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
1076  return op->seq.size() == 0;
1077  }
1078  return false;
1079 }
1080 
1081 template <typename ValueType>
1082 inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
1083  if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span);
1084  if (t.is_uint()) {
1085  // Use IntImm if it is a small integer
1086  uint64_t uval = static_cast<uint64_t>(value);
1087  if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
1088  return IntImm(t, static_cast<int64_t>(value), span);
1089  } else {
1090  uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
1091  uint64_t low = uval & mask;
1092  uint64_t high = uval >> 32U;
1093  return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
1094  }
1095  }
1096  if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value), span);
1097  // For now, we store const scalar values of custom datatypes within doubles; later, during the
1098  // datatypes lowering pass, we will lower the value to its true representation in the format
1099  // specified by the datatype.
1100  // TODO(gus) when do we need to start worrying about doubles not being precise enough?
1101  if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) {
1102  return FloatImm(t, static_cast<double>(value), span);
1103  }
1104  LOG(FATAL) << "cannot make const for type " << t;
1105  return PrimExpr();
1106 }
1107 
1108 template <typename ValueType, typename>
1109 inline PrimExpr make_const(DataType t, ValueType value, Span span) {
1110  if (t.lanes() == 1) {
1111  return MakeConstScalar(t, value, span);
1112  } else {
1113  return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
1114  }
1115 }
1116 
1117 inline PrimExpr make_zero(DataType t, Span span) {
1118  if (t.is_handle()) {
1119  return reinterpret(t, make_const(DataType::UInt(64), 0, span));
1120  }
1121  return make_const(t, 0, span);
1122 }
1123 
1124 template <typename FReduce>
1125 inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values,
1126  Span span) {
1127  for (PrimExpr val : values) {
1128  init_value = freduce(init_value, val, span);
1129  }
1130  return init_value;
1131 }
1132 
1133 } // namespace tir
1134 
1135 // additional const expression overloading
/*!
 * \brief Define a compound-assignment style function Name for PrimExpr.
 *
 * The generated function assigns OpFunc(a, b) to a (taken by non-const
 * reference) and returns the updated value of a by value -- not a reference
 * to a.
 */
1136 #define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
1137  inline PrimExpr Name(PrimExpr& a, PrimExpr b) { \
1138  a = OpFunc(a, b); \
1139  return a; \
1140  }
1141 
/*!
 * \brief Define overloads of binary function Name that mix a PrimExpr with a
 *        C++ scalar.
 *
 * Generated overloads: (PrimExpr, float), (float, PrimExpr), (int, PrimExpr),
 * (PrimExpr, int) and (PrimExpr, double). A float operand is wrapped with
 * PrimExpr(b); an int operand becomes a constant of the *other* operand's
 * dtype; a double operand becomes a Float(64) constant. There is no
 * (double, PrimExpr) overload.
 */
1142 #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
1143  inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
1144  inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
1145  inline PrimExpr Name(int a, const PrimExpr& b) { \
1146  return Name(tir::make_const(b.dtype(), a), b); \
1147  } \
1148  inline PrimExpr Name(const PrimExpr& a, int b) { \
1149  return Name(a, tir::make_const(a.dtype(), b)); \
1150  } \
1151  inline PrimExpr Name(const PrimExpr& a, double b) { \
1152  return Name(a, tir::make_const(DataType::Float(64), b)); \
1153  }
1154 
/*!
 * \brief Span-taking variant of the PrimExpr/scalar binary overloads.
 *
 * Generates the (PrimExpr, float), (float, PrimExpr), (int, PrimExpr),
 * (PrimExpr, int) and (PrimExpr, double) overloads of Name, each with a
 * trailing Span parameter that is forwarded to the PrimExpr/PrimExpr form of
 * Name. The span is not passed to the constant built for the scalar operand.
 */
1155 #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
1156  inline PrimExpr Name(const PrimExpr& a, float b, Span span = Span()) { \
1157  return Name(a, PrimExpr(b), span); \
1158  } \
1159  inline PrimExpr Name(float a, const PrimExpr& b, Span span = Span()) { \
1160  return Name(PrimExpr(a), b, span); \
1161  } \
1162  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1163  return Name(tir::make_const(b.dtype(), a), b, span); \
1164  } \
1165  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1166  return Name(a, tir::make_const(a.dtype(), b), span); \
1167  } \
1168  inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
1169  return Name(a, tir::make_const(DataType::Float(64), b), span); \
1170  }
1171 
/*!
 * \brief Define (PrimExpr, bool) and (bool, PrimExpr) overloads of Name.
 *
 * The bool operand is wrapped with PrimExpr(b) and the call is forwarded to
 * the PrimExpr/PrimExpr form of Name.
 */
1172 #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
1173  inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); } \
1174  inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); }
1175 
/*!
 * \brief Span-taking variant of the PrimExpr/bool overloads of Name.
 *
 * The bool operand is wrapped with PrimExpr(b); the trailing Span is
 * forwarded to the PrimExpr/PrimExpr form of Name.
 */
1176 #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
1177  inline PrimExpr Name(const PrimExpr& a, bool b, Span span = Span()) { \
1178  return Name(a, PrimExpr(b), span); \
1179  } \
1180  inline PrimExpr Name(bool a, const PrimExpr& b, Span span = Span()) { \
1181  return Name(PrimExpr(a), b, span); \
1182  }
1183 
/*!
 * \brief Define (PrimExpr, int) and (int, PrimExpr) overloads of Name.
 *
 * The int operand is turned into a constant of the other operand's dtype via
 * tir::make_const, then the PrimExpr/PrimExpr form of Name is called.
 */
1184 #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
1185  inline PrimExpr Name(const PrimExpr& a, int b) { \
1186  return Name(a, tir::make_const(a.dtype(), b)); \
1187  } \
1188  inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tir::make_const(b.dtype(), a), b); }
1189 
/*!
 * \brief Span-taking variant of the PrimExpr/int overloads of Name.
 *
 * The int operand is turned into a constant of the other operand's dtype via
 * tir::make_const; the trailing Span is forwarded to the PrimExpr/PrimExpr
 * form of Name (not to the constant).
 */
1190 #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
1191  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1192  return Name(a, tir::make_const(a.dtype(), b), span); \
1193  } \
1194  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1195  return Name(tir::make_const(b.dtype(), a), b, span); \
1196  }
1197 
// Instantiate the overload-generating macros.
// Compound assignment is provided for +, - and * only; no operator/= is
// generated here.
1198 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
1199 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
1200 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
// PrimExpr-vs-scalar (float/int/double) forms of the comparison operators.
1204 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
1206 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
1218 // integer related ops
// PrimExpr-vs-int forms of the shift operators.
1230 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
1231 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
1235 // logical ops
1240 
/*!
 * \brief Helper function to raise a compiler error about division ambiguity.
 *
 * Instantiating this template with a class type TA fails a static_assert
 * whose message points to the explicit division functions. Instantiating it
 * with a non-class type compiles and does nothing at run time.
 *
 * \param a The left operand (unused; only its type matters).
 * \tparam TA Any class type will trigger the error.
 */
template <typename TA>
inline void DivAmbiguityError(const TA& a) {
  static_assert(!std::is_class<TA>::value,
                "TVM supports multiple types of integer divisions, "
                "please call div, indexdiv/indexmod, "
                "floordiv/floormod or truncdiv/truncmod directly "
                "to avoid ambiguity in the code. "
                "Checkout these functions in tir/op.h.");
}
1256 
1257 // The following code is not intended to be used in the codebase.
1258 // Instead, they generate clear compiler errors that ask developers
1259 // to use the specific division function.
1260 // The second template argument is necessary to make sure the
1261 // code compiles lazily by the compiler during invocation.
/*!
 * \brief Catch-all operator/ that exists only to produce a compile error.
 *
 * When this template is selected and instantiated, it calls
 * DivAmbiguityError(a) with a PrimExpr, whose static_assert rejects class
 * types and reports that an explicit division function must be used. The
 * template parameter TB keeps the body from being instantiated until the
 * operator is actually used.
 */
1262 template <typename TB>
1263 inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
1264  DivAmbiguityError(a);
1265  return a;
1266 }
1267 
/*!
 * \brief Catch-all operator/= that exists only to produce a compile error.
 *
 * When instantiated, it calls DivAmbiguityError(a) with a PrimExpr, whose
 * static_assert rejects class types and reports that an explicit division
 * function must be used. The template parameter TB keeps the body from being
 * instantiated until the operator is actually used.
 */
1268 template <typename TB>
1269 inline PrimExpr operator/=(const PrimExpr& a, const TB& b) {
1270  DivAmbiguityError(a);
1271  return a;
1272 }
1273 
/*!
 * \brief Catch-all operator% that exists only to produce a compile error.
 *
 * When instantiated, it calls DivAmbiguityError(a) with a PrimExpr, whose
 * static_assert rejects class types and reports that an explicit
 * division/modulo function must be used. The template parameter TB keeps the
 * body from being instantiated until the operator is actually used.
 */
1274 template <typename TB>
1275 inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
1276  DivAmbiguityError(a);
1277  return a;
1278 }
1279 } // namespace tvm
1280 #endif // TVM_TIR_OP_H_
PrimExpr rsqrt(PrimExpr x, Span span=Span())
Definition: op.h:851
PrimExpr operator!=(PrimExpr a, PrimExpr b)
not_equal
tvm::Span Span
Definition: base.h:65
bool is_const_power_of_two_integer(const PrimExpr &x, int *shift)
Check whether x is a constant power of two. If x is a power of two, write the power to the shift...
PrimExpr operator<(PrimExpr a, PrimExpr b)
less
PrimExpr likely(PrimExpr cond, Span span=Span())
Mark condition as likely.
bool is_int() const
Definition: data_type.h:99
PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise xor of two values
PrimExpr log10(PrimExpr x, Span span=Span())
Definition: op.h:854
Bool operator &&(const Bool &a, bool b)
Definition: expr.h:342
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
PrimExpr neg(PrimExpr a, Span span=Span())
negation.
PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span=Span())
greater_equal
bool is_one(const PrimExpr &x)
Check whether x is a constant integer 1.
Definition: op.h:971
PrimExpr popcount(PrimExpr x, Span span=Span())
Definition: op.h:855
PrimExpr atan(PrimExpr x, Span span=Span())
Definition: op.h:863
Bool operator||(const Bool &a, bool b)
Definition: expr.h:337
PrimExpr abs(PrimExpr x, Span span=Span())
Calculate absolute value of x.
PrimExpr floor(PrimExpr x, Span span=Span())
Calculate floor(x)
PrimExpr exp10(PrimExpr x, Span span=Span())
Definition: op.h:846
Definition: data_type.h:57
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of indexdiv(a, b), where a and b are non-negative.
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1109
Base expr nodes in TVM.
PrimExpr sinh(PrimExpr x, Span span=Span())
Definition: op.h:860
PrimExpr add(PrimExpr a, PrimExpr b, Span span=Span())
add operator
#define TVM_DECLARE_INTRIN_BINARY(OpName)
Definition: op.h:869
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)
Definition: op.h:1136
PrimExpr tan(PrimExpr x, Span span=Span())
Definition: op.h:856
PrimExpr sub(PrimExpr a, PrimExpr b, Span span=Span())
subtraction operator
PrimExpr atanh(PrimExpr x, Span span=Span())
Definition: op.h:866
PrimExpr nearbyint(PrimExpr x, Span span=Span())
Calculates std::nearbyint(x)
bool is_float() const
Definition: data_type.h:93
PrimExpr asin(PrimExpr x, Span span=Span())
Definition: op.h:861
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:592
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr ceil(PrimExpr x, Span span=Span())
Calculate ceil(x)
PrimExpr ldexp(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:879
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
Constant floating point literals in the program.
Definition: expr.h:279
PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span=Span())
or
int code() const
Definition: data_type.h:81
PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise or of two values
PrimExpr max(const PrimExpr &a, double b, Span span=Span())
Definition: op.h:1208
PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array< PrimExpr > &values, Span span=Span())
Left fold.
Definition: op.h:1125
const int64_t * as_const_int(const PrimExpr &x)
Get x as constant int expression.
Definition: op.h:941
PrimExpr greater(PrimExpr a, PrimExpr b, Span span=Span())
greater
PrimExpr MakeConstScalar(DataType t, ValueType value, Span span=Span())
Definition: op.h:1082
PrimExpr atan2(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:875
PrimExpr operator-(PrimExpr a, PrimExpr b)
subtraction operator
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1172
runtime::DataType GetRuntimeDataType(const Type &type)
Get the implied DataType for storing values with type during runtime.
Low-level raw pointer type.
Definition: type.h:150
PrimExpr asinh(PrimExpr x, Span span=Span())
Definition: op.h:865
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1176
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr log(PrimExpr x, Span span=Span())
Definition: op.h:852
PrimExpr round(PrimExpr x, Span span=Span())
Calculate round(x)
Constant integer literals in the program.
Definition: expr.h:233
Primitive operators(builtin intrinsics) and registry for them.
Managed reference class to FloatImmNode.
Definition: expr.h:308
PrimExpr exp2(PrimExpr x, Span span=Span())
Definition: op.h:845
PrimExpr operator &(PrimExpr a, PrimExpr b)
take bitwise and of two values
PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span=Span())
less_equal
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:923
Definition: span.h:115
TIR statements.
PrimExpr floormod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floordiv
PrimExpr operator/=(const PrimExpr &a, const TB &b)
Definition: op.h:1269
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr hypot(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:878
PrimExpr operator!(PrimExpr a)
not
PrimExpr const_false(int lanes=1, Span span=Span())
Make a constant false expression.
Definition: op.h:932
IR/AST nodes for the unified type system in TVM.
Managed reference to BroadcastNode.
Definition: expr.h:839
bool defined() const
Definition: object.h:537
Runtime primitive data type.
Definition: data_type.h:41
TIR expressions.
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr reinterpret(const DataType &t, PrimExpr value, Span span=Span())
perform reinterpret cast value to type.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
Managed reference class to IntImmNode.
Definition: expr.h:262
PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span=Span())
Execute a multiplication between two Q-numbers x and y followed by a right shift s. The mathematical expression is:
PrimExpr operator<<(PrimExpr a, PrimExpr b)
left shift operator
#define TVM_DECLARE_INTRIN_UNARY(OpName)
Definition: op.h:838
PrimExpr operator^(PrimExpr a, PrimExpr b)
take bitwise xor of two values
PrimExpr isfinite(PrimExpr x, Span span=Span())
Check if x is finite.
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1184
Create a vector where all the elements are value.
Definition: expr.h:807
PrimExpr clz(PrimExpr x, Span span=Span())
Definition: op.h:867
Container of all statements.
Definition: stmt.h:57
PrimExpr cosh(PrimExpr x, Span span=Span())
Definition: op.h:858
bool is_uint() const
Definition: data_type.h:101
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1142
int64_t value
the Internal value.
Definition: expr.h:236
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:1117
PrimExpr bitwise_neg(PrimExpr a, Span span=Span())
take bitwise negation of a value
PrimExpr any(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
logical Or of source expression over axis
PrimExpr erf(PrimExpr x, Span span=Span())
Definition: op.h:847
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:738
PrimExpr operator>>(PrimExpr a, PrimExpr b)
right shift operator
PrimExpr nextafter(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:876
PrimExpr operator==(PrimExpr a, PrimExpr b)
equal
int lanes() const
Definition: data_type.h:87
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span=Span())
and
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
tvm::Type Type
Definition: type.h:47
bool is_bfloat16() const
Definition: data_type.h:97
PrimExpr operator>=(PrimExpr a, PrimExpr b)
greater_equal
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
PrimExpr acos(PrimExpr x, Span span=Span())
Definition: op.h:862
Type GetType(const PrimExpr &expr)
Get the type of the expression under the unified type system.
PrimExpr operator*(PrimExpr a, PrimExpr b)
multiplication operator
PrimExpr log2(PrimExpr x, Span span=Span())
Definition: op.h:853
PrimExpr all(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
logical And of source expression over axis
bool IsPointerType(const Type &type, const DataType &element_type)
Check if type is a pointer to a runtime element type.
Definition: op.h:889
bool is_const_int(const PrimExpr &x, int64_t value)
Check whether x is a constant integer expression.
Definition: op.h:1058
PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span=Span())
not_equal
PrimExpr cos(PrimExpr x, Span span=Span())
Definition: op.h:857
PrimExpr acosh(PrimExpr x, Span span=Span())
Definition: op.h:864
PrimExpr sqrt(PrimExpr x, Span span=Span())
Definition: op.h:850
PrimExpr tanh(PrimExpr x, Span span=Span())
Definition: op.h:848
PrimExpr infinity(const DataType &dtype, Span span=Span())
PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span=Span())
Construct a large uint constant by its low 32 bits and high 32bits.
PrimExpr max_value(const DataType &dtype, Span span=Span())
bool is_const_number(const PrimExpr &x)
Check whether x is an integer/float constant.
Definition: op.h:1031
PrimExpr trunc(PrimExpr x, Span span=Span())
Calculate trunc(x)
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1190
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr sin(PrimExpr x, Span span=Span())
Definition: op.h:859
std::function< PrimExpr(PrimExpr source, const Array< IterVar > &axis, Array< PrimExpr > init, Span span)> FReduce
The operation to use for CommReduce.
Definition: reduction.h:47
bool is_negative_const(const PrimExpr &a)
Definition: op.h:1050
void DivAmbiguityError(const TA &a)
Helper function to raise a compiler error about division ambiguity.
Definition: op.h:1247
PrimExpr operator/(PrimExpr a, PrimExpr b)
division operator
PrimExpr mul(PrimExpr a, PrimExpr b, Span span=Span())
multiplication operator
PrimExpr operator<=(PrimExpr a, PrimExpr b)
less_equal
PrimExpr copysign(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:877
Managed reference to TypeNode.
Definition: type.h:93
PrimExpr logical_not(PrimExpr a, Span span=Span())
not
bool is_positive_const(const PrimExpr &a)
Definition: op.h:1042
PrimExpr operator%(const PrimExpr &a, const TB &b)
Definition: op.h:1275
PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span=Span())
right shift operator
PrimExpr operator~(PrimExpr a)
take bitwise negation of a value
PrimExpr exp(PrimExpr x, Span span=Span())
Definition: op.h:844
bool is_handle() const
Definition: data_type.h:103
Reference to PrimExprNode.
Definition: expr.h:109
PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise and of two values
Primitive data types used in the low-level IR.
Definition: type.h:106
bool is_no_op(const tir::Stmt &stmt)
Check whether stmt is nop.
Definition: op.h:1070
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:858
PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute trunc(a / b)
PrimExpr sigmoid(PrimExpr x, Span span=Span())
Definition: op.h:849
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1155
PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span=Span())
left shift operator
PrimExpr operator|(PrimExpr a, PrimExpr b)
take bitwise or of two values
PrimExpr prod(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
product of source expression over axis
PrimExpr operator+(PrimExpr a, PrimExpr b)
add operator
bool is_zero(const PrimExpr &x)
Check whether x is a constant integer 0.
Definition: op.h:979
runtime::DataType DataType
Definition: data_type.h:389
PrimExpr less(PrimExpr a, PrimExpr b, Span span=Span())
less
static DataType UInt(int bits, int lanes=1)
Construct an uint type.
Definition: data_type.h:161
PrimExpr isinf(PrimExpr x, Span span=Span())
Check if x is infinite.
PrimExpr pow(PrimExpr x, PrimExpr y, Span span=Span())
Calculate power(x, y)
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:126
PrimExpr isnan(PrimExpr x, Span span=Span())
Check if x is NaN.
PrimExpr operator>(PrimExpr a, PrimExpr b)
greater