tvm
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 // Acknowledgement: Most operator APIs originate from Halide.
28 #ifndef TVM_TIRX_OP_H_
29 #define TVM_TIRX_OP_H_
30 
31 #include <tvm/ir/expr.h>
32 #include <tvm/ir/op.h>
33 #include <tvm/ir/type.h>
34 #include <tvm/tirx/builtin.h>
35 #include <tvm/tirx/expr.h>
36 #include <tvm/tirx/stmt.h>
39 
40 #include <algorithm>
41 #include <limits>
42 #include <type_traits>
43 
44 namespace tvm {
45 
// Register the operator "tirx.<OpName>" and attach its TVMScript printer name.
// OpName must be a string literal: it is concatenated with the literal "tirx.".
#define TVM_TIRX_REGISTER_OP(OpName) \
  TVM_REGISTER_OP("tirx." OpName).set_attr<TScriptPrinterName>("TScriptPrinterName", OpName)

// Alias of TVM_TIRX_REGISTER_OP; both spellings expand to the same registration.
#define TVM_TIR_REGISTER_OP(OpName) TVM_TIRX_REGISTER_OP(OpName)
50 
51 // Most common operators can be overloaded by argument type(PrimExpr).
52 // So we put them under the root namespace.
53 //
54 // We put more developer oriented APIs -- make_const and is_const under tirx
55 // as they are more specific to the tirx namespace.
56 
/*!
 * \brief Get the type of the expression under the unified type system.
 * \param expr The input expression.
 */
TVM_DLL Type GetType(const PrimExpr& expr);

/*!
 * \brief Return the value.
 * \param value The returned value.
 */
TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());

/*! \brief Return from a thread. */
TVM_DLL PrimExpr thread_return(Span span = Span());

/*! \brief Continue current loop. */
TVM_DLL PrimExpr continue_loop(Span span = Span());

/*! \brief Break current loop. */
TVM_DLL PrimExpr break_loop(Span span = Span());
119 
/*!
 * \brief Maximum value of dtype (meaning taken from the name — NOTE(review): confirm
 *  behavior for dtypes without a finite maximum).
 * \param dtype The data type.
 */
TVM_DLL PrimExpr max_value(const DataType& dtype, Span span = Span());

/*!
 * \brief Minimum value of dtype (meaning taken from the name — NOTE(review): confirm
 *  behavior for dtypes without a finite minimum).
 * \param dtype The data type.
 */
TVM_DLL PrimExpr min_value(const DataType& dtype, Span span = Span());

/*!
 * \brief Infinity of dtype (meaning taken from the name — NOTE(review): confirm which
 *  dtypes are supported).
 * \param dtype The data type.
 */
TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span());

/*!
 * \brief cast value to type.
 * \param t The target type.
 * \param value The value to be cast.
 */
TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span());
/*!
 * \brief perform reinterpret cast value to type.
 * \param t The target type.
 * \param value The value to be reinterpreted.
 */
TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span());
/*! \brief add operator: a + b. */
TVM_DLL PrimExpr add(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief subtraction operator: a - b. */
TVM_DLL PrimExpr sub(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief negation: -a. */
TVM_DLL PrimExpr neg(PrimExpr a, Span span = Span());
/*! \brief multiplication operator: a * b. */
TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief left shift operator: a << b. */
TVM_DLL PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief right shift operator: a >> b. */
TVM_DLL PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief greater: a > b. */
TVM_DLL PrimExpr greater(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief less: a < b. */
TVM_DLL PrimExpr less(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief less_equal: a <= b. */
TVM_DLL PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief equal: a == b. */
TVM_DLL PrimExpr equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief not_equal: a != b. */
TVM_DLL PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief logical and: a && b. */
TVM_DLL PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief logical or: a || b. */
TVM_DLL PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief logical not: !a. */
TVM_DLL PrimExpr logical_not(PrimExpr a, Span span = Span());
/*! \brief compute division in C semantics. */
TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute trunc(a / b). */
TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute the remainder of truncdiv. */
TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute floor(a / b) where a and b are non-negative. */
TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute ceil(a / b) where a and b are non-negative. */
TVM_DLL PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute the remainder of floor(a / b) where a and b are non-negative. */
TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute floor(a / b). */
TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Compute log(exp(a) + exp(b)). */
TVM_DLL PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute ceil(a / b). */
TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief compute the remainder of floordiv. */
TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take maximum of two values. */
TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take minimum of two values. */
TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take bitwise and of two values. */
TVM_DLL PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take bitwise or of two values. */
TVM_DLL PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take bitwise xor of two values. */
TVM_DLL PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief take bitwise negation of a value. */
TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span());
/*!
 * \brief Conditional expression.
 * \param cond The condition.
 * \param true_value The value when cond holds.
 * \param false_value The value when cond does not hold.
 */
TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value,
                              Span span = Span());
/*! \brief Mark condition as likely. */
TVM_DLL PrimExpr likely(PrimExpr cond, Span span = Span());
/*! \brief Calculate power(x, y). */
TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y, Span span = Span());
/*! \brief Calculate absolute value of x. */
TVM_DLL PrimExpr abs(PrimExpr x, Span span = Span());
/*! \brief Check if x is NaN. */
TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());

/*! \brief Check if x is finite. */
TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());

/*! \brief Check if x is infinite. */
TVM_DLL PrimExpr isinf(PrimExpr x, Span span = Span());
578 
/*!
 * \brief sum of source expression over axis.
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output (optional, empty by default).
 */
TVM_DLL PrimExpr sum(PrimExpr source, ffi::Array<tirx::IterVar> axis,
                     ffi::Array<PrimExpr> init = {}, Span span = Span());

/*! \brief logical And of source expression over axis. */
TVM_DLL PrimExpr all(PrimExpr source, ffi::Array<tirx::IterVar> axis,
                     ffi::Array<PrimExpr> init = {}, Span span = Span());

/*! \brief logical Or of source expression over axis. */
TVM_DLL PrimExpr any(PrimExpr source, ffi::Array<tirx::IterVar> axis,
                     ffi::Array<PrimExpr> init = {}, Span span = Span());

/*!
 * \brief Reduction overload of max over axis.
 *  NOTE(review): description mirrors the sibling sum/prod reductions; confirm.
 */
TVM_DLL PrimExpr max(PrimExpr source, ffi::Array<tirx::IterVar> axis,
                     ffi::Array<PrimExpr> init = {}, Span span = Span());

/*!
 * \brief Reduction overload of min over axis.
 *  NOTE(review): description mirrors the sibling sum/prod reductions; confirm.
 */
TVM_DLL PrimExpr min(PrimExpr source, ffi::Array<tirx::IterVar> axis,
                     ffi::Array<PrimExpr> init = {}, Span span = Span());

/*! \brief product of source expression over axis. */
TVM_DLL PrimExpr prod(PrimExpr source, ffi::Array<tirx::IterVar> axis,
                      ffi::Array<PrimExpr> init = {}, Span span = Span());
643 
/*! \brief Calculate floor(x). */
TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span());

/*! \brief Calculate ceil(x). */
TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());

/*! \brief Round x to the nearest integer, ties to even. */
TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());

/*! \brief Round x to the nearest integer, ties to even. */
TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());

/*! \brief Calculate trunc(x). */
TVM_DLL PrimExpr trunc(PrimExpr x, Span span = Span());

/*!
 * \brief Construct a large uint constant by its low 32 bits and high 32 bits.
 * \param dtype The data type.
 * \param low The lower 32 bits (passed as int64_t).
 * \param high The higher 32 bits (passed as int64_t).
 */
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span = Span());
700 
722  Span span = Span());
723 
/*!
 * \brief Fast_erf_float expression from Eigen.
 * \param arg The input expression.
 * \param bits NOTE(review): presumably the float bit width of arg — confirm with the definition.
 */
TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits);
732 
/*!
 * \brief Reject non-floating-point inputs to a float-only unary intrinsic.
 * \param op_name Intrinsic name without the "tirx." prefix; used only in the error message.
 * \param dtype Data type of the input expression.
 * \note Passes when dtype.is_float() or dtype.is_bfloat16(); otherwise fails a
 *  TVM_FFI_CHECK with error kind TypeError.
 */
inline void CheckMathUnaryOpInputDType(const char* op_name, DataType dtype) {
  TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16(), TypeError)
      << "tirx." << op_name << " only supports floating-point inputs, but got " << dtype;
}
737 
// Intrinsic operators
//
// TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) defines
//   inline PrimExpr OpName(PrimExpr x, Span span = Span())
// which validates x's dtype with CheckInputDType(#OpName, dtype) and emits a call to the
// registered op "tirx.<OpName>". For a bfloat16 input the op is evaluated in float32 with
// the same lane count and the result is cast back to the input's bfloat16 dtype.
// (Block comments are used inside the macro body because a // comment would swallow the
// line continuation.)
#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \
  inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
    static const Op& op = Op::Get("tirx." #OpName); \
    const DataType in_dtype = x.dtype(); \
    CheckInputDType(#OpName, in_dtype); \
    if (!in_dtype.is_bfloat16()) { \
      return tirx::Call(in_dtype, op, {x}, span); \
    } \
    /* bfloat16: widen to float32, apply the op, narrow back. */ \
    const DataType wide_dtype(kDLFloat, 32, in_dtype.lanes()); \
    PrimExpr widened = tirx::Cast(wide_dtype, x, span); \
    PrimExpr wide_result = tirx::Call(wide_dtype, op, {widened}, span); \
    return tirx::Cast(in_dtype, wide_result, span); \
  }

// Unary intrinsic without an input dtype check (the check is a no-op lambda).
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
  TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, [](const char*, DataType) {})

// Unary intrinsic restricted to float / bfloat16 inputs via CheckMathUnaryOpInputDType.
#define TVM_DECLARE_FLOAT_INTRIN_UNARY(OpName) \
  TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckMathUnaryOpInputDType)
759 
785 
// Defines inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) that emits a
// call to the registered op "tirx.<OpName>"; the result dtype is taken from x.
#define TVM_DECLARE_INTRIN_BINARY(OpName) \
  inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \
    static const Op& op = Op::Get("tirx." #OpName); \
    const DataType result_dtype = x.dtype(); \
    return tirx::Call(result_dtype, op, {x, y}, span); \
  }
791 
797 
798 namespace tirx {
799 
806 inline bool IsPointerType(const Type& type, const DataType& element_type) {
807  if (!type.defined()) return false;
808  if (const auto* ptr_type = type.as<PointerTypeNode>()) {
809  if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
810  return prim_type->dtype == element_type;
811  }
812  }
813  return false;
814 }
815 
824 template <typename ValueType,
825  typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
826 inline PrimExpr make_const(DataType t, ValueType value, Span span = Span());
833 inline PrimExpr make_zero(DataType t, Span span = Span());
840 inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
841  return make_const(DataType::Bool(lanes), 1);
842 }
849 inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
850  return make_const(DataType::Bool(lanes), 0);
851 }
858 inline const int64_t* as_const_int(const PrimExpr& x) {
859  if (!x.defined()) return nullptr;
860  if (const tirx::IntImmNode* op = x.as<tirx::IntImmNode>()) {
861  return &(op->value);
862  }
863 
864  return nullptr;
865 }
866 
/*!
 * \brief Check whether x is a constant integer expression equal to value.
 * \param x The input argument.
 * \param value The value to compare against.
 * \return true if x is an IntImm node whose value equals value.
 */
inline bool is_const_int(const PrimExpr& x, int64_t value);

/*!
 * \brief Check whether stmt is nop.
 * \param stmt The input statement.
 * \return true if stmt is undefined, an Evaluate of an integer constant, or an empty SeqStmt.
 */
inline bool is_no_op(const tirx::Stmt& stmt);
881 
888 inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); }
889 
896 inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
897 
/*!
 * \brief Check whether x is a constant integer expression.
 * \param x The input argument.
 * \return true if x is an IntImm node.
 */
inline bool is_const_int(const PrimExpr& x);

/*!
 * \brief Check whether x is an integer/float constant.
 * \param x The input argument.
 * \return true for IntImm, FloatImm, or a Broadcast whose value is an IntImm/FloatImm.
 */
inline bool is_const_number(const PrimExpr& x);
911 
921 template <typename FReduce>
922 inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const ffi::Array<PrimExpr>& values,
923  Span span = Span()) {
924  for (PrimExpr val : values) {
925  init_value = freduce(init_value, val, span);
926  }
927  return init_value;
928 }
929 
/*!
 * \brief Check whether x is a constant power of two.
 *  If x is a power of two, write the power to shift.
 * \param x The input expression.
 * \param shift The output shift if x is a power of two.
 * \return whether x is a constant power of two.
 */
TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
939 
940 // Implementation details after this
941 inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); }
942 
943 inline bool is_const_number(const PrimExpr& x) {
944  if (x.as<tirx::IntImmNode>()) {
945  return true;
946  } else if (x.as<tirx::FloatImmNode>()) {
947  return true;
948  } else if (const auto* op = x.as<tirx::BroadcastNode>()) {
949  return (op->value->IsInstance<tirx::IntImmNode>() ||
950  op->value->IsInstance<tirx::FloatImmNode>());
951  }
952  return false;
953 }
954 
955 inline bool is_positive_const(const PrimExpr& a) {
956  const int64_t* as_int = as_const_int(a);
957  return as_int && (*as_int > 0);
958 }
959 
960 inline bool is_negative_const(const PrimExpr& a) {
961  const int64_t* as_int = as_const_int(a);
962  return as_int && (*as_int < 0);
963 }
964 
965 inline bool is_const_int(const PrimExpr& x, int64_t value) {
966  const int64_t* as_int = as_const_int(x);
967  return as_int && (*as_int == value);
968 }
969 
970 inline bool is_no_op(const tirx::Stmt& stmt) {
971  if (!stmt.defined()) return true;
972  if (const auto* op = stmt.as<tirx::EvaluateNode>()) {
973  return is_const_int(op->value);
974  }
975  if (const auto* op = stmt.as<tirx::SeqStmtNode>()) {
976  return op->seq.size() == 0;
977  }
978  return false;
979 }
980 
981 template <typename ValueType>
982 inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
983  if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), span);
984  if (t.is_uint()) {
985  // Use IntImm if it is a small integer
986  uint64_t uval = static_cast<uint64_t>(value);
987  if (value < static_cast<ValueType>(0)) {
988  TVM_FFI_THROW(InternalError) << "cannot make uint from negative value " << value;
989  } else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
990  return IntImm(t, static_cast<int64_t>(value), span);
991  } else {
992  uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
993  uint64_t low = uval & mask;
994  uint64_t high = uval >> 32U;
995  return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
996  }
997  }
998  if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || t.is_float4())
999  return FloatImm(t, static_cast<double>(value), span);
1000  // For now, we store const scalar values of custom datatypes within doubles; later, during the
1001  // datatypes lowering pass, we will lower the value to its true representation in the format
1002  // specified by the datatype.
1003  // TODO(gus) when do we need to start worrying about doubles not being precise enough?
1004  if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) {
1005  return FloatImm(t, static_cast<double>(value), span);
1006  }
1007  TVM_FFI_THROW(InternalError) << "cannot make const for type " << t;
1008  throw;
1009 }
1010 
1011 template <>
1012 inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
1013  return MakeConstScalar(t, static_cast<int>(value), span);
1014 }
1015 
/*!
 * \brief Make a const value with certain data type.
 *  Scalar t: MakeConstScalar(t, value). Fixed-length vector t: Broadcast of the scalar
 *  element constant over t.lanes(). Otherwise a Broadcast over a computed `lanes` expression.
 * \param t The target type.
 * \param value The input value.
 * \param span Span attached to the created node(s).
 */
template <typename ValueType, typename>
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
  if (t.is_scalar()) {
    return MakeConstScalar(t, value, span);
  } else {
    if (t.is_fixed_length_vector()) {
      return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
    } else {
      // NOTE(review): the initializer of `lanes` (listing line 1025) is missing from this
      // view, so the statement below is incomplete as shown. Presumably it derives the lane
      // count from t.vscale_factor() and the vscale() builtin — confirm against the header.
      PrimExpr lanes =
      return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span);
    }
  }
}
1030 
1031 inline PrimExpr make_zero(DataType t, Span span) {
1032  if (t.is_handle()) {
1033  return reinterpret(t, make_const(DataType::UInt(64), 0, span));
1034  }
1035  return make_const(t, 0, span);
1036 }
1037 
1038 } // namespace tirx
1039 
1040 // additional const expression overloading
// Defines compound assignment Name(PrimExpr& a, PrimExpr b) as a = OpFunc(a, b) and
// returns the updated value of a (by value).
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
  inline PrimExpr Name(PrimExpr& a, PrimExpr b) { \
    PrimExpr updated = OpFunc(a, b); \
    a = updated; \
    return updated; \
  }
1046 
// Overloads of a binary operator/function Name that mix a PrimExpr with a C++ scalar:
//  - float: wrapped with PrimExpr(float);
//  - int:   make_const with the dtype of the PrimExpr operand;
//  - double: make_const with DataType::Float(64).
// Every scalar kind is accepted on either side. The (double, PrimExpr) form mirrors
// (PrimExpr, double). Without it, a double left operand converts equally well to the
// float and int overloads, so e.g. `2.0 * x` is ambiguous and fails to compile.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
  inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
  inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
  inline PrimExpr Name(int a, const PrimExpr& b) { \
    return Name(tirx::make_const(b.dtype(), a), b); \
  } \
  inline PrimExpr Name(const PrimExpr& a, int b) { \
    return Name(a, tirx::make_const(a.dtype(), b)); \
  } \
  inline PrimExpr Name(const PrimExpr& a, double b) { \
    return Name(a, tirx::make_const(DataType::Float(64), b)); \
  } \
  inline PrimExpr Name(double a, const PrimExpr& b) { \
    return Name(tirx::make_const(DataType::Float(64), a), b); \
  }
1059 
// Span-taking variant of TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD: same scalar conversions,
// with span forwarded to Name. Includes the (double, PrimExpr) mirror of (PrimExpr, double),
// without which a double left operand is ambiguous between the float and int overloads.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
  inline PrimExpr Name(const PrimExpr& a, float b, Span span = Span()) { \
    return Name(a, PrimExpr(b), span); \
  } \
  inline PrimExpr Name(float a, const PrimExpr& b, Span span = Span()) { \
    return Name(PrimExpr(a), b, span); \
  } \
  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
    return Name(tirx::make_const(b.dtype(), a), b, span); \
  } \
  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
    return Name(a, tirx::make_const(a.dtype(), b), span); \
  } \
  inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
    return Name(a, tirx::make_const(DataType::Float(64), b), span); \
  } \
  inline PrimExpr Name(double a, const PrimExpr& b, Span span = Span()) { \
    return Name(tirx::make_const(DataType::Float(64), a), b, span); \
  }
1076 
// Overloads of a logical operator/function Name taking a bool on either side; the bool is
// wrapped with PrimExpr(bool).
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
  inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
  inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); }

// Span-taking variant of TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD.
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
  inline PrimExpr Name(bool a, const PrimExpr& b, Span span = Span()) { \
    return Name(PrimExpr(a), b, span); \
  } \
  inline PrimExpr Name(const PrimExpr& a, bool b, Span span = Span()) { \
    return Name(a, PrimExpr(b), span); \
  }
1088 
// Overloads of an integer operator/function Name taking an int on either side; the int is
// converted with make_const to the dtype of the PrimExpr operand.
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
  inline PrimExpr Name(int a, const PrimExpr& b) { \
    return Name(tirx::make_const(b.dtype(), a), b); \
  } \
  inline PrimExpr Name(const PrimExpr& a, int b) { \
    return Name(a, tirx::make_const(a.dtype(), b)); \
  }

// Span-taking variant of TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD.
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
    return Name(tirx::make_const(b.dtype(), a), b, span); \
  } \
  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
    return Name(a, tirx::make_const(a.dtype(), b), span); \
  }
1102 
// Compound assignment for PrimExpr: `a op= b` rebinds a to `a op b`.
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
// PrimExpr-vs-scalar overloads of comparison operators.
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>);  // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<);  // NOLINT(*)
// integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>);  // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<);  // NOLINT(*)
// logical ops

/*!
 * \brief Helper function to raise a compiler error about division ambiguity.
 *  The static_assert depends on TA: it fails for class types (such as PrimExpr) and
 *  passes, doing nothing, for non-class types.
 * \tparam TA The type of the left operand.
 */
template <typename TA>
inline void DivAmbiguityError(const TA& a) {
  static_assert(!std::is_class<TA>::value,
                "TVM supports multiple types of integer divisions, "
                "please call div, indexdiv/indexmod, "
                "floordiv/floormod or truncdiv/truncmod directly "
                "to avoid ambiguity in the code. "
                "Checkout these functions in tirx/op.h.");
  (void)a;  // the operand only drives type deduction
}
1162 
1163 // The following code are not intended to be used in the codebase.
1164 // Instead, they generate clear compiler errors that ask developers
1165 // to use the specific division function.
1166 // The second template argument is necessary to make sure the
1167 // code compiles lazily by the compiler during invocation.
1168 template <typename TB>
1169 inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
1170  DivAmbiguityError(a);
1171  return a;
1172 }
1173 
1174 template <typename TB>
1175 inline PrimExpr operator/=(const PrimExpr& a, const TB& b) {
1176  DivAmbiguityError(a);
1177  return a;
1178 }
1179 
1180 template <typename TB>
1181 inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
1182  DivAmbiguityError(a);
1183  return a;
1184 }
1185 } // namespace tvm
1186 #endif // TVM_TIR_OP_H_
Constant floating point literals in the program.
Definition: expr.h:529
Managed reference class to FloatImmNode.
Definition: expr.h:546
Constant integer literals in the program.
Definition: expr.h:494
int64_t value
The internal value.
Definition: expr.h:497
Managed reference class to IntImmNode.
Definition: expr.h:511
Low-level raw pointer type.
Definition: type.h:152
Reference to PrimExprNode.
Definition: expr.h:126
Primitive data types used in the low-level IR.
Definition: type.h:112
Definition: source_map.h:111
Managed reference to TypeNode.
Definition: type.h:99
Runtime primitive data type.
Definition: data_type.h:45
bool is_handle() const
Definition: data_type.h:196
bool is_uint() const
Definition: data_type.h:194
bool is_float6() const
Definition: data_type.h:157
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:238
@ kCustomBegin
Definition: data_type.h:73
bool is_bool() const
Definition: data_type.h:141
bool is_int() const
Definition: data_type.h:192
int code() const
Definition: data_type.h:112
int lanes() const
Definition: data_type.h:118
static DataType Bool(int lanes=1, bool is_scalable=false)
Construct a bool type.
Definition: data_type.h:385
int vscale_factor() const
Definition: data_type.h:127
bool is_fixed_length_vector() const
Definition: data_type.h:203
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:276
bool is_scalar() const
Definition: data_type.h:139
bool is_float8() const
Definition: data_type.h:149
bool is_bfloat16() const
Definition: data_type.h:190
bool is_float4() const
Definition: data_type.h:162
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct an uint type.
Definition: data_type.h:284
bool is_float() const
Definition: data_type.h:145
Create a vector where all the elements are value.
Definition: expr.h:658
Managed reference to BroadcastNode.
Definition: expr.h:678
Managed reference to CallNode.
Definition: expr.h:745
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:338
Managed reference to MulNode.
Definition: expr.h:170
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:313
Container of all statements.
Definition: stmt.h:67
Base expr nodes in TVM.
Primitive operators(builtin intrinsics) and registry for them.
IR/AST nodes for the unified type system in TVM.
const Op & vscale()
Get the target's vscale value. It will be lowered to llvm.vscale intrinsic (https://llvm....
bool is_const_number(const PrimExpr &x)
Check whether x is an integer/float constant.
Definition: op.h:943
bool is_zero(const PrimExpr &x)
Check whether x is a constant integer 0.
Definition: op.h:896
bool is_const_power_of_two_integer(const PrimExpr &x, int *shift)
Check whether x is a constant power of two. If x is a power of two, write the power to shift.
bool IsPointerType(const Type &type, const DataType &element_type)
Check if type is a pointer to a runtime element type.
Definition: op.h:806
bool is_positive_const(const PrimExpr &a)
Definition: op.h:955
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1017
PrimExpr MakeConstScalar(DataType t, ValueType value, Span span=Span())
Definition: op.h:982
PrimExpr const_false(int lanes=1, Span span=Span())
Make a constant false expression.
Definition: op.h:849
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:1031
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:840
bool is_negative_const(const PrimExpr &a)
Definition: op.h:960
bool is_const_int(const PrimExpr &x, int64_t value)
Check whether x is a constant integer expression.
Definition: op.h:965
bool is_one(const PrimExpr &x)
Check whether x is a constant integer 1.
Definition: op.h:888
bool is_no_op(const tirx::Stmt &stmt)
Check whether stmt is nop.
Definition: op.h:970
PrimExpr foldl(FReduce freduce, PrimExpr init_value, const ffi::Array< PrimExpr > &values, Span span=Span())
Left fold.
Definition: op.h:922
const int64_t * as_const_int(const PrimExpr &x)
Get x as constant int expression.
Definition: op.h:858
std::function< PrimExpr(PrimExpr source, const ffi::Array< IterVar > &axis, ffi::Array< PrimExpr > init, Span span)> FReduce
The operation to use for CommReduce.
Definition: reduction.h:47
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
runtime::DataType GetRuntimeDataType(const Type &type)
Get the implied DataType for storing values with type during runtime.
PrimExpr isfinite(PrimExpr x, Span span=Span())
Check if x is finite.
PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span=Span())
compute ceil(a / b)
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr tanh(PrimExpr x, Span span=Span())
Definition: op.h:764
PrimExpr erf(PrimExpr x, Span span=Span())
Definition: op.h:763
PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span=Span())
compute ceil(a / b) where a and b are non-negative.
PrimExpr log10(PrimExpr x, Span span=Span())
Definition: op.h:770
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr operator/(PrimExpr a, PrimExpr b)
division operator
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span=Span())
and
PrimExpr hypot(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:795
PrimExpr log1p(PrimExpr x, Span span=Span())
Definition: op.h:771
void DivAmbiguityError(const TA &a)
Helper function to raise a compiler error about division ambiguity.
Definition: op.h:1153
PrimExpr likely(PrimExpr cond, Span span=Span())
Mark condition as likely.
PrimExpr reinterpret(const DataType &t, PrimExpr value, Span span=Span())
perform reinterpret cast value to type.
PrimExpr atan2(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:792
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr bitwise_neg(PrimExpr a, Span span=Span())
take bitwise negation of a value
PrimExpr cosh(PrimExpr x, Span span=Span())
Definition: op.h:775
PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span=Span())
or
PrimExpr thread_return(Span span=Span())
Return from a thread.
PrimExpr atan(PrimExpr x, Span span=Span())
Definition: op.h:780
Type GetType(const PrimExpr &expr)
Get the type of the expression under the unified type system.
PrimExpr isnan(PrimExpr x, Span span=Span())
Check if x is NaN.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr max_value(const DataType &dtype, Span span=Span())
PrimExpr exp2(PrimExpr x, Span span=Span())
Definition: op.h:761
PrimExpr rsqrt(PrimExpr x, Span span=Span())
Definition: op.h:767
PrimExpr operator/=(const PrimExpr &a, const TB &b)
Definition: op.h:1175
PrimExpr asinh(PrimExpr x, Span span=Span())
Definition: op.h:782
PrimExpr less(PrimExpr a, PrimExpr b, Span span=Span())
less
PrimExpr sin(PrimExpr x, Span span=Span())
Definition: op.h:776
PrimExpr trunc(PrimExpr x, Span span=Span())
Calculate trunc(x)
PrimExpr round(PrimExpr x, Span span=Span())
Round x to the nearest integer, ties to even.
Type GetTypeFromRuntimeDataType(const DataType &dtype)
Get the type corresponding to DataType.
PrimExpr neg(PrimExpr a, Span span=Span())
negation.
PrimExpr ceil(PrimExpr x, Span span=Span())
Calculate ceil(x)
PrimExpr pow(PrimExpr x, PrimExpr y, Span span=Span())
Calculate power(x, y)
PrimExpr logical_not(PrimExpr a, Span span=Span())
not
PrimExpr exp10(PrimExpr x, Span span=Span())
Definition: op.h:762
PrimExpr copysign(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:794
PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise xor of two values
PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span=Span())
less_equal
PrimExpr any(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
logical Or of source expression over axis
PrimExpr greater(PrimExpr a, PrimExpr b, Span span=Span())
greater
PrimExpr exp(PrimExpr x, Span span=Span())
Definition: op.h:760
PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span=Span())
Compute log(exp(a) + exp(b)).
PrimExpr floormod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floordiv
PrimExpr infinity(const DataType &dtype, Span span=Span())
PrimExpr sub(PrimExpr a, PrimExpr b, Span span=Span())
subtraction operator
PrimExpr all(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
logical And of source expression over axis
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr nextafter(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:793
PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span=Span())
Construct a large uint constant by its low 32 bits and high 32 bits.
PrimExpr asin(PrimExpr x, Span span=Span())
Definition: op.h:778
PrimExpr prod(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
product of source expression over axis
PrimExpr sigmoid(PrimExpr x, Span span=Span())
Definition: op.h:765
PrimExpr max(const PrimExpr &a, double b, Span span=Span())
Definition: op.h:1113
PrimExpr acos(PrimExpr x, Span span=Span())
Definition: op.h:779
PrimExpr mul(PrimExpr a, PrimExpr b, Span span=Span())
multiplication operator
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
void CheckMathUnaryOpInputDType(const char *op_name, DataType dtype)
Definition: op.h:733
PrimExpr sum(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr floor(PrimExpr x, Span span=Span())
Calculate floor(x)
PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span=Span())
greater_equal
PrimExpr operator%(const PrimExpr &a, const TB &b)
Definition: op.h:1181
PrimExpr abs(PrimExpr x, Span span=Span())
Calculate absolute value of x.
PrimExpr atanh(PrimExpr x, Span span=Span())
Definition: op.h:783
PrimExpr sqrt(PrimExpr x, Span span=Span())
Definition: op.h:766
PrimExpr isinf(PrimExpr x, Span span=Span())
Check if x is infinite.
PrimExpr continue_loop(Span span=Span())
Continue current loop.
PrimExpr log2(PrimExpr x, Span span=Span())
Definition: op.h:769
PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span=Span())
not_equal
PrimExpr ldexp(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:796
PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute trunc(a / b)
PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span=Span())
Execute a multiplication between two Q-numbers x and y followed by a right shift s....
PrimExpr popcount(PrimExpr x, Span span=Span())
Definition: op.h:772
PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise and of two values
PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span=Span())
left shift operator
PrimExpr sinh(PrimExpr x, Span span=Span())
Definition: op.h:777
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floor(a / b) where a and b are non-negative.
PrimExpr break_loop(Span span=Span())
Break current loop.
PrimExpr add(PrimExpr a, PrimExpr b, Span span=Span())
add operator
PrimExpr log(PrimExpr x, Span span=Span())
Definition: op.h:768
PrimExpr nearbyint(PrimExpr x, Span span=Span())
Round x to the nearest integer, ties to even.
PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span=Span())
right shift operator
PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise or of two values
PrimExpr clz(PrimExpr x, Span span=Span())
Definition: op.h:784
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
PrimExpr acosh(PrimExpr x, Span span=Span())
Definition: op.h:781
PrimExpr tan(PrimExpr x, Span span=Span())
Definition: op.h:773
PrimExpr cos(PrimExpr x, Span span=Span())
Definition: op.h:774
PrimExpr fast_erf_float_expr(PrimExpr arg, int bits)
Fast_erf_float expression from Eigen.
TIR builtin intrinsics.
TIR expressions.
#define TVM_DECLARE_INTRIN_UNARY(OpName)
Definition: op.h:754
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1089
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1081
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)
Definition: op.h:1041
#define TVM_DECLARE_FLOAT_INTRIN_UNARY(OpName)
Definition: op.h:757
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1060
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1047
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1077
#define TVM_DECLARE_INTRIN_BINARY(OpName)
Definition: op.h:786
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1095
TIR statements.