24 #ifndef TVM_TOPI_DETAIL_BROADCAST_H_
25 #define TVM_TOPI_DETAIL_BROADCAST_H_
38 struct BroadcastHelper {
39 std::deque<tvm::PrimExpr> common_shape;
40 std::deque<tvm::tir::Var> all_vars;
41 std::deque<tvm::tir::Var> vars1;
42 std::deque<tvm::tir::Var> vars2;
46 ICHECK(type1.is_scalar() && type2.is_scalar());
47 ICHECK(type1.code() == type2.code());
54 int s1_size = shape1.
size();
55 int s2_size = shape2.
size();
59 auto cast_if_needed = [](
DataType to_type, PrimExpr expr) {
60 return to_type != expr.dtype() ?
cast(to_type, expr) : expr;
63 for (i = 1; i <=
std::min(s1_size, s2_size); ++i) {
67 DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype());
70 if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
71 bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
72 bh.vars1.push_front(bh.all_vars[0]);
73 bh.vars2.push_front(bh.all_vars[0]);
74 }
else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) {
75 ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i]));
76 bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
77 bh.vars2.push_front(bh.all_vars[0]);
78 }
else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
79 bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
80 bh.vars1.push_front(bh.all_vars[0]);
81 }
else if (!static_size1 && !static_size2) {
82 bh.common_shape.push_front(
83 cast_if_needed(common_type,
max(shape1[s1_size - i], shape2[s2_size - i])));
84 bh.vars1.push_front(bh.all_vars[0]);
85 bh.vars2.push_front(bh.all_vars[0]);
86 }
else if (!static_size1) {
87 bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
88 bh.vars2.push_front(bh.all_vars[0]);
89 bh.vars1.push_front(bh.all_vars[0]);
90 }
else if (!static_size2) {
91 bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
92 bh.vars1.push_front(bh.all_vars[0]);
93 bh.vars2.push_front(bh.all_vars[0]);
95 ICHECK(
false) <<
"Incompatible broadcast dims: " << shape1[s1_size - i] <<
" and "
96 << shape2[s2_size - i]
102 auto max_size =
std::max(s1_size, s2_size);
103 auto&
shape = (s1_size > s2_size) ? shape1 : shape2;
104 auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
105 for (; i <= max_size; ++i) {
107 bh.common_shape.push_front(
shape[max_size - i]);
108 vars.push_front(bh.all_vars[0]);
115 const std::deque<tvm::tir::Var>& my_vars,
const std::deque<tvm::tir::Var>& all_vars) {
117 ICHECK_EQ(ovars.
size(), all_vars.size());
119 size_t expected_dims = T->shape.size();
120 for (
size_t i = 0; i < ovars.
size(); ++i) {
122 for (
size_t j = 0; j < my_vars.size(); ++j) {
123 if (all_vars[i].same_as(my_vars[j])) {
131 if (!found && (ovars.
size() - i) <= expected_dims) {
135 ICHECK(expected_dims == ivars.
size());
139 template <
typename FBinaryExpr>
142 const std::string& tag =
"") {
143 auto bh = BroadcastShape(A->shape, B->shape);
145 return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
146 B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
Reference to PrimExprNode.
Definition: expr.h:115
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
iterator end() const
Definition: array.h:390
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
iterator begin() const
Definition: array.h:387
size_t size() const
Definition: array.h:420
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
a named variable in TIR
Definition: var.h:89
Utility functions for handling constants in TVM expressions.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
tvm::IntImmNode IntImmNode
Definition: expr.h:49
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:976
Tensor max(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the maximum of elements over a given axis.
Definition: reduction.h:440
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type, a Broadcast is generated; otherwise a Cast is generated.
Definition: elemwise.h:281
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
runtime::DataType DataType
Definition: data_type.h:493
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
Operation node can generate one or multiple Tensors.