24 #ifndef TVM_TOPI_BROADCAST_H_
25 #define TVM_TOPI_BROADCAST_H_
50 std::string name =
"T_broadcast_to",
52 ICHECK_GE(output_shape.
size(), t->shape.size())
53 <<
"Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape
54 <<
"\nvs\ninput: " << t;
55 auto bh = detail::BroadcastShape(output_shape, t->shape);
56 ICHECK_EQ(output_shape.
size(), bh.common_shape.size());
58 for (
size_t i = 0; i < output_shape.
size(); ++i) {
59 if (output_shape[i].as<tir::IntImmNode>() ==
nullptr) {
62 ICHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
67 return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
72 #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
73 inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
74 inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \
75 std::string name = "T_" #Name, std::string tag = kBroadcast) { \
76 auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
77 return detail::WithBroadcast(l, A, B, name, tag); \
79 inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \
80 std::string name = "T_" #Name, std::string tag = kElementWise) { \
81 auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
82 return tvm::te::compute( \
83 A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \
85 inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \
86 std::string name = "T_" #Name, std::string tag = kElementWise) { \
87 auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
88 return tvm::te::compute( \
89 B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \
92 #define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \
93 inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) { \
94 return topi::OpName(A, B); \
96 inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) { \
97 return topi::OpName(A, B); \
99 inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) { \
100 return topi::OpName(A, B); \
253 if (a.dtype().is_int() || a.dtype().is_uint()) {
254 return floordiv(a, b);
256 return floor(div(a, b));
272 if (a.dtype().is_int() || a.dtype().is_uint()) {
273 return truncdiv(a, b);
275 return trunc(div(a, b));
304 if (a.dtype().is_int() || a.dtype().is_uint()) {
305 return floormod(a, b);
307 return a - floor_divide(a, b) * b;
323 if (a.dtype().is_int() || a.dtype().is_uint()) {
324 return truncmod(a, b);
326 return a - trunc_divide(a, b) * b;
#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName)
Definition: broadcast.h:92
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)
Definition: broadcast.h:72
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
size_t size() const
Definition: array.h:420
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Utility functions for handling constants in TVM expressions.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
tvm::PrimExpr floor_mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:309
constexpr auto kBroadcast
Definition: tags.h:36
tvm::PrimExpr bitwise_or(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:169
tvm::PrimExpr not_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:447
tvm::PrimExpr less(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:421
tvm::PrimExpr subtract(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:211
tvm::PrimExpr logical_or(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:128
tvm::PrimExpr left_shift(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:380
tvm::PrimExpr trunc_mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:328
tvm::PrimExpr greater_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:460
tvm::te::Tensor broadcast_to(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &output_shape, std::string name="T_broadcast_to", std::string tag=kBroadcast)
Creates an operation that broadcasts a tensor into a compatible shape according to numpy's rules.
Definition: broadcast.h:48
tvm::PrimExpr bitwise_and(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:155
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:239
tvm::PrimExpr multiply(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:225
tvm::PrimExpr floor_divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:258
tvm::PrimExpr minimum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:354
tvm::PrimExpr logical_and(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:114
tvm::PrimExpr trunc_divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:277
tvm::PrimExpr right_shift(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:394
tvm::PrimExpr add(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:197
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
tvm::PrimExpr less_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:473
tvm::PrimExpr equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:434
tvm::PrimExpr bitwise_xor(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:183
tvm::PrimExpr greater(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:408
tvm::PrimExpr power(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:367
tvm::PrimExpr logical_xor(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:142
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:341
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
PrimExpr pow(PrimExpr x, PrimExpr y, Span span=Span())
Calculate power(x, y)
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values