24 #ifndef TVM_TOPI_DETAIL_BROADCAST_H_ 25 #define TVM_TOPI_DETAIL_BROADCAST_H_ 38 struct BroadcastHelper {
39 std::deque<tvm::PrimExpr> common_shape;
40 std::deque<tvm::tir::Var> all_vars;
41 std::deque<tvm::tir::Var> vars1;
42 std::deque<tvm::tir::Var> vars2;
48 int s1_size = shape1.
size();
49 int s2_size = shape2.
size();
52 for (i = 1; i <=
std::min(s1_size, s2_size); ++i) {
57 if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
58 bh.common_shape.push_front(shape1[s1_size - i]);
59 bh.vars1.push_front(bh.all_vars[0]);
60 bh.vars2.push_front(bh.all_vars[0]);
61 }
else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) {
62 ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i]));
63 bh.common_shape.push_front(shape2[s2_size - i]);
64 bh.vars2.push_front(bh.all_vars[0]);
65 }
else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
66 bh.common_shape.push_front(shape1[s1_size - i]);
67 bh.vars1.push_front(bh.all_vars[0]);
68 }
else if (!static_size1 && !static_size2) {
69 bh.common_shape.push_front(
max(shape1[s1_size - i], shape2[s2_size - i]));
70 bh.vars1.push_front(bh.all_vars[0]);
71 bh.vars2.push_front(bh.all_vars[0]);
72 }
else if (!static_size1) {
73 bh.common_shape.push_front(shape2[s2_size - i]);
74 bh.vars2.push_front(bh.all_vars[0]);
75 bh.vars1.push_front(bh.all_vars[0]);
76 }
else if (!static_size2) {
77 bh.common_shape.push_front(shape1[s1_size - i]);
78 bh.vars1.push_front(bh.all_vars[0]);
79 bh.vars2.push_front(bh.all_vars[0]);
81 ICHECK(
false) <<
"Incompatible broadcast dims: " << shape1[s1_size - i] <<
" and " 82 << shape2[s2_size - i]
88 auto max_size =
std::max(s1_size, s2_size);
89 auto&
shape = (s1_size > s2_size) ? shape1 : shape2;
90 auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
91 for (; i <= max_size; ++i) {
93 bh.common_shape.push_front(
shape[max_size - i]);
94 vars.push_front(bh.all_vars[0]);
101 const std::deque<tvm::tir::Var>& my_vars,
const std::deque<tvm::tir::Var>& all_vars) {
103 ICHECK_EQ(ovars.
size(), all_vars.size());
105 size_t expected_dims = T->shape.size();
106 for (
size_t i = 0; i < ovars.
size(); ++i) {
108 for (
size_t j = 0; j < my_vars.size(); ++j) {
109 if (all_vars[i].same_as(my_vars[j])) {
117 if (!found && (ovars.
size() - i) <= expected_dims) {
121 ICHECK(expected_dims == ivars.
size());
125 template <
typename FBinaryExpr>
128 const std::string& tag =
"") {
129 auto bh = BroadcastShape(A->shape, B->shape);
131 return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
132 B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
142 #endif // TVM_TOPI_DETAIL_BROADCAST_H_ Tensor max(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the maximum of elements over a given axis.
Definition: reduction.h:429
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
a named variable in TIR
Definition: var.h:88
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:436
Utility functions for handling constants in TVM expressions.
size_t size() const
Definition: array.h:399
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
tvm::IntImmNode IntImmNode
Definition: expr.h:49
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:1138
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1758
iterator end() const
Definition: array.h:369
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
iterator begin() const
Definition: array.h:366
Operation node can generate one or multiple Tensors.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis)
Reference to PrimExprNode.
Definition: expr.h:112
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865