24 #ifndef TVM_TOPI_BROADCAST_H_ 25 #define TVM_TOPI_BROADCAST_H_ 50 std::string name =
"T_broadcast_to",
52 ICHECK_GE(output_shape.
size(), t->shape.size())
53 <<
"Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape
54 <<
"\nvs\ninput: " << t;
55 auto bh = detail::BroadcastShape(output_shape, t->shape);
56 ICHECK_EQ(output_shape.
size(), bh.common_shape.size());
58 for (
size_t i = 0; i < output_shape.
size(); ++i) {
59 if (output_shape[i].as<tir::IntImmNode>() ==
nullptr) {
62 ICHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
63 oshape.push_back(bh.common_shape[i]);
67 return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
72 #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ 73 inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ 74 inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ 75 std::string name = "T_" #Name, std::string tag = kBroadcast) { \ 76 auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ 77 return detail::WithBroadcast(l, A, B, name, tag); \ 79 inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \ 80 std::string name = "T_" #Name, std::string tag = kElementWise) { \ 81 auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ 82 return tvm::te::compute( \ 83 A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \ 85 inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ 86 std::string name = "T_" #Name, std::string tag = kElementWise) { \ 87 auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ 88 return tvm::te::compute( \ 89 B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \ 92 #define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ 93 inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) { \ 94 return topi::OpName(A, B); \ 96 inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) { \ 97 return topi::OpName(A, B); \ 99 inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) { \ 100 return topi::OpName(A, B); \ 253 if (a.dtype().is_int() || a.dtype().is_uint()) {
272 if (a.dtype().is_int() || a.dtype().is_uint()) {
304 if (a.dtype().is_int() || a.dtype().is_uint()) {
323 if (a.dtype().is_int() || a.dtype().is_uint()) {
478 #endif // TVM_TOPI_BROADCAST_H_ tvm::PrimExpr less_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:473
tvm::PrimExpr bitwise_and(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:155
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
tvm::PrimExpr trunc_divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:277
tvm::PrimExpr minimum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:354
tvm::PrimExpr greater_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:460
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:239
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Tensor floor(const Tensor &x, std::string name="T_" "floor", std::string tag=kElementWise)
Definition: elemwise.h:57
tvm::PrimExpr less(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:421
tvm::PrimExpr logical_xor(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:142
tvm::PrimExpr not_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:447
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
tvm::PrimExpr floor_mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:309
Utility functions for handling constants in TVM expressions.
constexpr auto kBroadcast
Definition: tags.h:36
PrimExpr floormod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floordiv
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
size_t size() const
Definition: array.h:420
tvm::PrimExpr trunc_mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:328
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
tvm::PrimExpr multiply(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:225
tvm::PrimExpr bitwise_or(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:169
tvm::PrimExpr add(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:197
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)
Definition: broadcast.h:72
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName)
Definition: broadcast.h:92
tvm::PrimExpr right_shift(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:394
tvm::PrimExpr subtract(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:211
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
tvm::PrimExpr logical_and(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:114
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Tensor trunc(const Tensor &x, std::string name="T_" "trunc", std::string tag=kElementWise)
Definition: elemwise.h:60
tvm::PrimExpr equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:434
tvm::PrimExpr floor_divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:258
tvm::PrimExpr bitwise_xor(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:183
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:341
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
tvm::te::Tensor broadcast_to(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &output_shape, std::string name="T_broadcast_to", std::string tag=kBroadcast)
Creates an operation that broadcasts a tensor into a compatible shape according to numpy's rules...
Definition: broadcast.h:48
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
tvm::PrimExpr greater(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:408
PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute trunc(a / b)
tvm::PrimExpr power(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:367
tvm::PrimExpr left_shift(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:380
PrimExpr pow(PrimExpr x, PrimExpr y, Span span=Span())
Calculate power(x, y)
tvm::PrimExpr logical_or(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:128