tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 // Acknowledgement: Most operator APIs originate from Halide.
28 #ifndef TVM_TIR_OP_H_
29 #define TVM_TIR_OP_H_
30 
31 #include <tvm/ir/expr.h>
32 #include <tvm/ir/op.h>
33 #include <tvm/ir/type.h>
34 #include <tvm/tir/expr.h>
35 #include <tvm/tir/stmt.h>
36 
37 #include <algorithm>
38 #include <limits>
39 #include <type_traits>
40 
41 namespace tvm {
42 
// Registers the operator named "tir." OpName through TVM_REGISTER_OP and sets its
// "TScriptPrinterName" attribute to OpName. OpName must be a string literal, because the
// "tir." prefix is attached by adjacent string-literal concatenation.
43 #define TVM_TIR_REGISTER_OP(OpName) \
44  TVM_REGISTER_OP("tir." OpName).set_attr<TScriptPrinterName>("TScriptPrinterName", OpName)
45 
46 // Most common operators can be overloaded by argument type(PrimExpr).
47 // So we put them under the root namespace.
48 //
49 // We put more developer oriented APIs -- make_const and is_const under tir
50 // as they are more specific to the tir namespace.
51 
/*! \brief Get the type of the expression under the unified type system. */
63 TVM_DLL Type GetType(const PrimExpr& expr);
64 
73 
83 
/*! \brief Return the value. */
91 TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
92 
/*! \brief Presumably the maximum value representable by dtype -- TODO confirm against the implementation. */
99 TVM_DLL PrimExpr max_value(const DataType& dtype, Span span = Span());
100 
/*! \brief Presumably the minimum value representable by dtype -- TODO confirm against the implementation. */
107 TVM_DLL PrimExpr min_value(const DataType& dtype, Span span = Span());
108 
/*! \brief Presumably a constant representing infinity in dtype -- TODO confirm against the implementation. */
115 TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span());
116 
/*! \brief Cast value to type t. */
126 TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span());
/*! \brief Perform a reinterpret cast of value to type t. */
136 TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span());
/*! \brief Add operator. */
147 TVM_DLL PrimExpr add(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Subtraction operator. */
158 TVM_DLL PrimExpr sub(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Negation. */
168 TVM_DLL PrimExpr neg(PrimExpr a, Span span = Span());
/*! \brief Multiplication operator. */
179 TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Left shift operator. */
190 TVM_DLL PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Right shift operator. */
201 TVM_DLL PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Greater-than comparison. */
212 TVM_DLL PrimExpr greater(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Less-than comparison. */
234 TVM_DLL PrimExpr less(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Less-than-or-equal comparison. */
245 TVM_DLL PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Equality comparison. */
256 TVM_DLL PrimExpr equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Inequality comparison. */
267 TVM_DLL PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Logical and. */
277 TVM_DLL PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Logical or. */
287 TVM_DLL PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Logical not. */
296 TVM_DLL PrimExpr logical_not(PrimExpr a, Span span = Span());
/*! \brief Compute division in C semantics. */
311 TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Compute trunc(a / b). */
324 TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Compute the remainder of truncdiv. */
337 TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Compute floor(a / b) where a and b are non-negative. */
353 TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Compute ceil(a / b) where a and b are non-negative. */
369 TVM_DLL PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Compute the remainder of floor(a / b) where a and b are non-negative. */
384 TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Compute floor(a / b). */
395 TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Compute ceil(a / b). */
406 TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Compute the remainder of floordiv. */
417 TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Take the maximum of two values. */
428 TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Take the minimum of two values. */
439 TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Take the bitwise and of two values. */
450 TVM_DLL PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Take the bitwise or of two values. */
461 TVM_DLL PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Take the bitwise xor of two values. */
472 TVM_DLL PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span = Span());
/*! \brief Take the bitwise negation of a value. */
482 TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span());
/*! \brief Conditional expression. */
494 TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value,
495  Span span = Span());
/*! \brief Mark condition as likely. */
502 TVM_DLL PrimExpr likely(PrimExpr cond, Span span = Span());
/*! \brief Calculate power(x, y). */
509 TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y, Span span = Span());
/*! \brief Calculate the absolute value of x. */
517 TVM_DLL PrimExpr abs(PrimExpr x, Span span = Span());
/*! \brief Check if x is NaN. */
524 TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
525 
/*! \brief Check if x is finite. */
532 TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
533 
/*! \brief Check if x is infinite. */
540 TVM_DLL PrimExpr isinf(PrimExpr x, Span span = Span());
541 
/*! \brief Sum of source expression over axis. */
550 TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
551  Span span = Span());
552 
/*! \brief Logical And of source expression over axis. */
560 TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
561  Span span = Span());
562 
/*! \brief Logical Or of source expression over axis. */
571 TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
572  Span span = Span());
573 
/*! \brief Presumably the maximum of source expression over axis -- TODO confirm against the implementation. */
582 TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
583  Span span = Span());
584 
/*! \brief Presumably the minimum of source expression over axis -- TODO confirm against the implementation. */
593 TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
594  Span span = Span());
595 
/*! \brief Product of source expression over axis. */
604 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
605  Span span = Span());
606 
/*! \brief Calculate floor(x). */
613 TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span());
614 
/*! \brief Calculate ceil(x). */
621 TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());
622 
/*! \brief Calculate round(x). */
629 TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());
630 
/*! \brief Calculate std::nearbyint(x). */
638 TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());
639 
/*! \brief Calculate trunc(x). */
646 TVM_DLL PrimExpr trunc(PrimExpr x, Span span = Span());
647 
/*! \brief Construct a large uint constant from its low 32 bits and high 32 bits. */
656 TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span = Span());
657 
679  Span span = Span());
680 
/*! \brief Fast_erf_float expression from Eigen. Note that, unlike its neighbours, it takes no Span. */
688 TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits);
689 
690 // Intrinsic operators
// Declares an inline unary intrinsic `OpName(x, span)` that emits a call to the operator
// registered as "tir.<OpName>".
//
// For every dtype except bfloat16 the call is emitted directly with the dtype of x.
// A bfloat16 operand is first cast to a 32-bit float with the same number of lanes, the
// intrinsic is called on that value, and the result is cast back to the original dtype.
#define TVM_DECLARE_INTRIN_UNARY(OpName)                               \
  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {             \
    static const Op& op = Op::Get("tir." #OpName);                     \
    const DataType in_dtype = x.dtype();                               \
    if (!in_dtype.is_bfloat16()) {                                     \
      return tir::Call(in_dtype, op, {x}, span);                       \
    }                                                                  \
    const DataType fp32_dtype(kDLFloat, 32, in_dtype.lanes());         \
    PrimExpr widened = tir::Cast(fp32_dtype, {x}, span);               \
    PrimExpr fp32_result = tir::Call(fp32_dtype, op, {widened}, span); \
    return tir::Cast(in_dtype, {fp32_result}, span);                   \
  }
704 
730 
// Declares an inline binary intrinsic `OpName(x, y, span)` that emits a call to the operator
// registered as "tir.<OpName>". The result dtype is taken from x alone; y's dtype is not
// consulted here.
#define TVM_DECLARE_INTRIN_BINARY(OpName)                              \
  inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \
    static const Op& op = Op::Get("tir." #OpName);                     \
    const DataType result_dtype = x.dtype();                           \
    return tir::Call(result_dtype, op, {x, y}, span);                  \
  }
736 
742 
743 namespace tir {
744 
751 inline bool IsPointerType(const Type& type, const DataType& element_type) {
752  if (!type.defined()) return false;
753  if (const auto* ptr_type = type.as<PointerTypeNode>()) {
754  if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
755  return prim_type->dtype == element_type;
756  }
757  }
758  return false;
759 }
760 
/*!
 * \brief Make a const value with certain data type.
 *
 * Forward declaration; the definition appears later in this header. The defaulted second
 * template parameter removes this overload from consideration unless
 * std::is_pod<ValueType>::value is true.
 */
769 template <typename ValueType,
770  typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
771 inline PrimExpr make_const(DataType t, ValueType value, Span span = Span());
/*! \brief Make a const zero expr. Forward declaration; defined later in this header. */
778 inline PrimExpr make_zero(DataType t, Span span = Span());
785 inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
786  return make_const(DataType::UInt(1, lanes), 1);
787 }
794 inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
795  return make_const(DataType::UInt(1, lanes), 0);
796 }
803 inline const int64_t* as_const_int(const PrimExpr& x) {
804  if (!x.defined()) return nullptr;
805  if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) {
806  return &(op->value);
807  }
808 
809  return nullptr;
810 }
811 
/*! \brief Check whether x is a constant integer expression equal to value. Defined later in this header. */
818 inline bool is_const_int(const PrimExpr& x, int64_t value);
819 
/*! \brief Check whether stmt is nop. Defined later in this header. */
825 inline bool is_no_op(const tir::Stmt& stmt);
826 
/*! \brief Check whether x is a constant integer 1. */
833 inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); }
834 
/*! \brief Check whether x is a constant integer 0. */
841 inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
842 
/*! \brief Check whether x is a constant integer expression (of any value). Defined later in this header. */
848 inline bool is_const_int(const PrimExpr& x);
849 
/*! \brief Check whether x is an integer/float constant. Defined later in this header. */
855 inline bool is_const_number(const PrimExpr& x);
856 
866 template <typename FReduce>
867 inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values,
868  Span span = Span()) {
869  for (PrimExpr val : values) {
870  init_value = freduce(init_value, val, span);
871  }
872  return init_value;
873 }
874 
/*!
 * \brief Check whether x is a constant power of two.
 *        If x is a power of two, write the power to shift.
 */
883 TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
884 
885 // Implementation details after this
886 inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); }
887 
888 inline bool is_const_number(const PrimExpr& x) {
889  if (x.as<tir::IntImmNode>()) {
890  return true;
891  } else if (x.as<tir::FloatImmNode>()) {
892  return true;
893  } else if (const auto* op = x.as<tir::BroadcastNode>()) {
894  return (op->value->IsInstance<tir::IntImmNode>() || op->value->IsInstance<tir::FloatImmNode>());
895  }
896  return false;
897 }
898 
899 inline bool is_positive_const(const PrimExpr& a) {
900  const int64_t* as_int = as_const_int(a);
901  return as_int && (*as_int > 0);
902 }
903 
904 inline bool is_negative_const(const PrimExpr& a) {
905  const int64_t* as_int = as_const_int(a);
906  return as_int && (*as_int < 0);
907 }
908 
909 inline bool is_const_int(const PrimExpr& x, int64_t value) {
910  const int64_t* as_int = as_const_int(x);
911  return as_int && (*as_int == value);
912 }
913 
914 inline bool is_no_op(const tir::Stmt& stmt) {
915  if (!stmt.defined()) return true;
916  if (const auto* op = stmt.as<tir::EvaluateNode>()) {
917  return is_const_int(op->value);
918  }
919  if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
920  return op->seq.size() == 0;
921  }
922  return false;
923 }
924 
925 template <typename ValueType>
926 inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
927  if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span);
928  if (t.is_uint()) {
929  // Use IntImm if it is a small integer
930  uint64_t uval = static_cast<uint64_t>(value);
931  if (value < static_cast<ValueType>(0)) {
932  LOG(FATAL) << "cannot make uint from negative value " << value;
933  } else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
934  return IntImm(t, static_cast<int64_t>(value), span);
935  } else {
936  uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
937  uint64_t low = uval & mask;
938  uint64_t high = uval >> 32U;
939  return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
940  }
941  }
942  if (t.is_float() || t.is_bfloat16() || t.is_float8())
943  return FloatImm(t, static_cast<double>(value), span);
944  // For now, we store const scalar values of custom datatypes within doubles; later, during the
945  // datatypes lowering pass, we will lower the value to its true representation in the format
946  // specified by the datatype.
947  // TODO(gus) when do we need to start worrying about doubles not being precise enough?
948  if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) {
949  return FloatImm(t, static_cast<double>(value), span);
950  }
951  LOG(FATAL) << "cannot make const for type " << t;
952  throw;
953 }
954 
955 template <>
956 inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
957  return MakeConstScalar(t, static_cast<int>(value), span);
958 }
959 
960 template <typename ValueType, typename>
961 inline PrimExpr make_const(DataType t, ValueType value, Span span) {
962  if (t.lanes() == 1) {
963  return MakeConstScalar(t, value, span);
964  } else {
965  return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
966  }
967 }
968 
969 inline PrimExpr make_zero(DataType t, Span span) {
970  if (t.is_handle()) {
971  return reinterpret(t, make_const(DataType::UInt(64), 0, span));
972  }
973  return make_const(t, 0, span);
974 }
975 
976 } // namespace tir
977 
978 // additional const expression overloading
// Defines the compound-assignment style function `Name(a, b)`: it evaluates OpFunc(a, b),
// stores the result into a, and returns that result by value.
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
  inline PrimExpr Name(PrimExpr& a, PrimExpr b) {   \
    PrimExpr updated = OpFunc(a, b);                \
    a = updated;                                    \
    return updated;                                 \
  }
984 
// Defines overloads of the binary function Name that mix a PrimExpr with a C++ scalar:
// (PrimExpr, float), (float, PrimExpr), (int, PrimExpr), (PrimExpr, int), (PrimExpr, double).
// A float operand is wrapped with PrimExpr(b); an int operand becomes a constant of the other
// operand's dtype; a double operand becomes a Float(64) constant.
// Note that no (double, PrimExpr) overload is generated.
985 #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
986  inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
987  inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
988  inline PrimExpr Name(int a, const PrimExpr& b) { \
989  return Name(tir::make_const(b.dtype(), a), b); \
990  } \
991  inline PrimExpr Name(const PrimExpr& a, int b) { \
992  return Name(a, tir::make_const(a.dtype(), b)); \
993  } \
994  inline PrimExpr Name(const PrimExpr& a, double b) { \
995  return Name(a, tir::make_const(DataType::Float(64), b)); \
996  }
997 
// Same overload set as TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD, for functions that take a trailing
// Span: each overload accepts an optional span and forwards it to the PrimExpr-only Name.
// The span is not passed to the make_const calls that build the scalar operand.
998 #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
999  inline PrimExpr Name(const PrimExpr& a, float b, Span span = Span()) { \
1000  return Name(a, PrimExpr(b), span); \
1001  } \
1002  inline PrimExpr Name(float a, const PrimExpr& b, Span span = Span()) { \
1003  return Name(PrimExpr(a), b, span); \
1004  } \
1005  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1006  return Name(tir::make_const(b.dtype(), a), b, span); \
1007  } \
1008  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1009  return Name(a, tir::make_const(a.dtype(), b), span); \
1010  } \
1011  inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
1012  return Name(a, tir::make_const(DataType::Float(64), b), span); \
1013  }
1014 
// Defines (PrimExpr, bool) and (bool, PrimExpr) overloads of Name; the bool operand is wrapped
// with PrimExpr(...) before forwarding to the PrimExpr-only Name.
1015 #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
1016  inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); } \
1017  inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); }
1018 
// Span-taking variant of TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD: defines (PrimExpr, bool) and
// (bool, PrimExpr) overloads of Name that accept an optional span and forward it.
1019 #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
1020  inline PrimExpr Name(const PrimExpr& a, bool b, Span span = Span()) { \
1021  return Name(a, PrimExpr(b), span); \
1022  } \
1023  inline PrimExpr Name(bool a, const PrimExpr& b, Span span = Span()) { \
1024  return Name(PrimExpr(a), b, span); \
1025  }
1026 
// Defines (PrimExpr, int) and (int, PrimExpr) overloads of Name; the int operand is turned into
// a constant with the dtype of the PrimExpr operand via tir::make_const.
1027 #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
1028  inline PrimExpr Name(const PrimExpr& a, int b) { \
1029  return Name(a, tir::make_const(a.dtype(), b)); \
1030  } \
1031  inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tir::make_const(b.dtype(), a), b); }
1032 
// Span-taking variant of TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD: defines (PrimExpr, int) and
// (int, PrimExpr) overloads of Name that accept an optional span and forward it to Name.
// The span is not passed to the make_const call that builds the int operand.
1033 #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
1034  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1035  return Name(a, tir::make_const(a.dtype(), b), span); \
1036  } \
1037  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1038  return Name(tir::make_const(b.dtype(), a), b, span); \
1039  }
1040 
// Instantiate the helper macros to generate the concrete operator overloads:
// compound assignment (+=, -=, *=), scalar-mixed comparisons (>, <) and int-mixed shifts (>>, <<).
1041 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
1042 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
1043 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
1047 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
1049 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
1061 // integer related ops
1073 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
1074 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
1078 // logical ops
1083 
/*!
 * \brief Helper function to raise a compiler error about division ambiguity.
 *
 * Instantiating this template with a class type TA fails the static_assert and prints the
 * message below; instantiating it with a non-class type compiles to an empty function.
 *
 * \param a The operand; only its type is used.
 * \tparam TA The type of the operand.
 */
template <typename TA>
inline void DivAmbiguityError(const TA& a) {
  static_assert(!std::is_class<TA>::value,
                "TVM supports multiple types of integer divisions, "
                "please call div, indexdiv/indexmod, "
                "floordiv/floormod or truncdiv/truncmod directly "
                "to avoid ambiguity in the code. "
                "Checkout these functions in tir/op.h.");
}
1099 
1100 // The following operators are not intended to be used in the codebase.
1101 // Instead, they generate clear compiler errors that ask developers
1102 // to use the specific division function.
1103 // The second template argument is necessary to make sure the
1104 // code is compiled lazily by the compiler, i.e. only on invocation.
1105 template <typename TB>
1106 inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
1107  DivAmbiguityError(a);
1108  return a;
1109 }
1110 
1111 template <typename TB>
1112 inline PrimExpr operator/=(const PrimExpr& a, const TB& b) {
1113  DivAmbiguityError(a);
1114  return a;
1115 }
1116 
1117 template <typename TB>
1118 inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
1119  DivAmbiguityError(a);
1120  return a;
1121 }
1122 } // namespace tvm
1123 #endif // TVM_TIR_OP_H_
Constant floating point literals in the program.
Definition: expr.h:538
Managed reference class to FloatImmNode.
Definition: expr.h:567
Constant integer literals in the program.
Definition: expr.h:491
int64_t value
the Internal value.
Definition: expr.h:494
Managed reference class to IntImmNode.
Definition: expr.h:520
Low-level raw pointer type.
Definition: type.h:150
Reference to PrimExprNode.
Definition: expr.h:114
Primitive data types used in the low-level IR.
Definition: type.h:106
Definition: source_map.h:120
Managed reference to TypeNode.
Definition: type.h:93
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:42
bool is_handle() const
Definition: data_type.h:115
bool is_uint() const
Definition: data_type.h:113
DataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:138
@ kCustomBegin
Definition: data_type.h:60
bool is_int() const
Definition: data_type.h:111
int code() const
Definition: data_type.h:87
int lanes() const
Definition: data_type.h:93
bool is_float8() const
Definition: data_type.h:101
bool is_bfloat16() const
Definition: data_type.h:109
static DataType UInt(int bits, int lanes=1)
Construct an uint type.
Definition: data_type.h:183
bool is_float() const
Definition: data_type.h:99
bool defined() const
Definition: object.h:550
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:894
Create a vector where all the elements are value.
Definition: expr.h:787
Managed reference to BroadcastNode.
Definition: expr.h:819
Evaluates an expression. This is mostly used for putting a Call node into Stmt.
Definition: stmt.h:699
The container of seq statement. Represent a sequence of statements.
Definition: stmt.h:666
Container of all statements.
Definition: stmt.h:59
Base expr nodes in TVM.
Primitive operators(builtin intrinsics) and registry for them.
IR/AST nodes for the unified type system in TVM.
tvm::Span Span
Definition: base.h:65
PrimExpr MakeConstScalar(DataType t, ValueType value, Span span=Span())
Definition: op.h:926
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:961
bool is_const_power_of_two_integer(const PrimExpr &x, int *shift)
Check whether x is a constant power of two. If x is a power of two, write the power to shift.
bool is_zero(const PrimExpr &x)
Check whether x is a constant integer 0.
Definition: op.h:841
bool IsPointerType(const Type &type, const DataType &element_type)
Check if type is a pointer to a runtime element type.
Definition: op.h:751
bool is_negative_const(const PrimExpr &a)
Definition: op.h:904
bool is_const_number(const PrimExpr &x)
Check whether x is an integer/float constant.
Definition: op.h:888
bool is_const_int(const PrimExpr &x, int64_t value)
Check whether x is a constant integer expression.
Definition: op.h:909
PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array< PrimExpr > &values, Span span=Span())
Left fold.
Definition: op.h:867
bool is_positive_const(const PrimExpr &a)
Definition: op.h:899
PrimExpr const_false(int lanes=1, Span span=Span())
Make a constant false expression.
Definition: op.h:794
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:785
bool is_no_op(const tir::Stmt &stmt)
Check whether stmt is nop.
Definition: op.h:914
bool is_one(const PrimExpr &x)
Check whether x is a constant integer 1.
Definition: op.h:833
const int64_t * as_const_int(const PrimExpr &x)
Get x as constant int expression.
Definition: op.h:803
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:969
std::function< PrimExpr(PrimExpr source, const Array< IterVar > &axis, Array< PrimExpr > init, Span span)> FReduce
The operation to use for CommReduce.
Definition: reduction.h:47
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
runtime::DataType GetRuntimeDataType(const Type &type)
Get the implied DataType for storing values with type during runtime.
PrimExpr isfinite(PrimExpr x, Span span=Span())
Check if x is finite.
PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span=Span())
compute ceil(a / b)
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr tanh(PrimExpr x, Span span=Span())
Definition: op.h:709
PrimExpr erf(PrimExpr x, Span span=Span())
Definition: op.h:708
PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span=Span())
compute ceil(a / b) where a and b are non-negative.
PrimExpr log10(PrimExpr x, Span span=Span())
Definition: op.h:715
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr operator/(PrimExpr a, PrimExpr b)
division operator
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span=Span())
and
PrimExpr hypot(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:740
PrimExpr log1p(PrimExpr x, Span span=Span())
Definition: op.h:716
void DivAmbiguityError(const TA &a)
Helper function to raise a compiler error about division ambiguity.
Definition: op.h:1090
PrimExpr prod(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
product of source expression over axis
PrimExpr likely(PrimExpr cond, Span span=Span())
Mark condition as likely.
PrimExpr reinterpret(const DataType &t, PrimExpr value, Span span=Span())
perform reinterpret cast value to type.
PrimExpr atan2(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:737
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr bitwise_neg(PrimExpr a, Span span=Span())
take bitwise negation of a value
PrimExpr cosh(PrimExpr x, Span span=Span())
Definition: op.h:720
PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span=Span())
or
PrimExpr atan(PrimExpr x, Span span=Span())
Definition: op.h:725
Type GetType(const PrimExpr &expr)
Get the type of the expression under the unified type system.
PrimExpr isnan(PrimExpr x, Span span=Span())
Check if x is NaN.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr max_value(const DataType &dtype, Span span=Span())
PrimExpr exp2(PrimExpr x, Span span=Span())
Definition: op.h:706
PrimExpr rsqrt(PrimExpr x, Span span=Span())
Definition: op.h:712
PrimExpr operator/=(const PrimExpr &a, const TB &b)
Definition: op.h:1112
PrimExpr asinh(PrimExpr x, Span span=Span())
Definition: op.h:727
PrimExpr less(PrimExpr a, PrimExpr b, Span span=Span())
less
PrimExpr sin(PrimExpr x, Span span=Span())
Definition: op.h:721
PrimExpr trunc(PrimExpr x, Span span=Span())
Calculate trunc(x)
PrimExpr round(PrimExpr x, Span span=Span())
Calculate round(x)
Type GetTypeFromRuntimeDataType(const DataType &dtype)
Get the type corresponding to DataType.
PrimExpr neg(PrimExpr a, Span span=Span())
negation.
PrimExpr ceil(PrimExpr x, Span span=Span())
Calculate ceil(x)
PrimExpr any(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
logical Or of source expression over axis
PrimExpr pow(PrimExpr x, PrimExpr y, Span span=Span())
Calculate power(x, y)
PrimExpr logical_not(PrimExpr a, Span span=Span())
not
PrimExpr exp10(PrimExpr x, Span span=Span())
Definition: op.h:707
PrimExpr copysign(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:739
PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise xor of two values
PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span=Span())
less_equal
PrimExpr greater(PrimExpr a, PrimExpr b, Span span=Span())
greater
PrimExpr exp(PrimExpr x, Span span=Span())
Definition: op.h:705
PrimExpr floormod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floordiv
PrimExpr infinity(const DataType &dtype, Span span=Span())
PrimExpr sub(PrimExpr a, PrimExpr b, Span span=Span())
subtraction operator
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr nextafter(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:738
PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span=Span())
Construct a large uint constant by its low 32 bits and high 32bits.
PrimExpr asin(PrimExpr x, Span span=Span())
Definition: op.h:723
PrimExpr sigmoid(PrimExpr x, Span span=Span())
Definition: op.h:710
PrimExpr max(const PrimExpr &a, double b, Span span=Span())
Definition: op.h:1051
PrimExpr acos(PrimExpr x, Span span=Span())
Definition: op.h:724
PrimExpr mul(PrimExpr a, PrimExpr b, Span span=Span())
multiplication operator
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
PrimExpr floor(PrimExpr x, Span span=Span())
Calculate floor(x)
PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span=Span())
greater_equal
PrimExpr operator%(const PrimExpr &a, const TB &b)
Definition: op.h:1118
PrimExpr abs(PrimExpr x, Span span=Span())
Calculate absolute value of x.
PrimExpr atanh(PrimExpr x, Span span=Span())
Definition: op.h:728
PrimExpr sqrt(PrimExpr x, Span span=Span())
Definition: op.h:711
PrimExpr isinf(PrimExpr x, Span span=Span())
Check if x is infinite.
PrimExpr log2(PrimExpr x, Span span=Span())
Definition: op.h:714
PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span=Span())
not_equal
PrimExpr ldexp(PrimExpr x, PrimExpr y, Span span=Span())
Definition: op.h:741
PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute trunc(a / b)
PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span=Span())
Execute a multiplication between two Q-numbers x and y followed by a right shift s....
PrimExpr popcount(PrimExpr x, Span span=Span())
Definition: op.h:717
PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise and of two values
PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span=Span())
left shift operator
PrimExpr sinh(PrimExpr x, Span span=Span())
Definition: op.h:722
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floor(a / b) where a and b are non-negative.
PrimExpr all(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
logical And of source expression over axis
PrimExpr add(PrimExpr a, PrimExpr b, Span span=Span())
add operator
PrimExpr log(PrimExpr x, Span span=Span())
Definition: op.h:713
PrimExpr nearbyint(PrimExpr x, Span span=Span())
Calculates std::nearbyint(x)
PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span=Span())
right shift operator
PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span=Span())
take bitwise or of two values
PrimExpr clz(PrimExpr x, Span span=Span())
Definition: op.h:729
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
PrimExpr acosh(PrimExpr x, Span span=Span())
Definition: op.h:726
PrimExpr tan(PrimExpr x, Span span=Span())
Definition: op.h:718
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr cos(PrimExpr x, Span span=Span())
Definition: op.h:719
PrimExpr fast_erf_float_expr(PrimExpr arg, int bits)
Fast_erf_float expression from Eigen.
TIR statements.
TIR expressions.
#define TVM_DECLARE_INTRIN_UNARY(OpName)
Definition: op.h:691
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1027
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1019
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)
Definition: op.h:979
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:998
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:985
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)
Definition: op.h:1015
#define TVM_DECLARE_INTRIN_BINARY(OpName)
Definition: op.h:731
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name)
Definition: op.h:1033