24 #ifndef TVM_TOPI_DETAIL_BROADCAST_H_
25 #define TVM_TOPI_DETAIL_BROADCAST_H_
38 struct BroadcastHelper {
39 std::deque<tvm::PrimExpr> common_shape;
40 std::deque<tvm::tirx::Var> all_vars;
41 std::deque<tvm::tirx::Var> vars1;
42 std::deque<tvm::tirx::Var> vars2;
46 TVM_FFI_ICHECK(type1.is_scalar() && type2.is_scalar());
47 TVM_FFI_ICHECK(type1.code() == type2.code());
51 inline BroadcastHelper BroadcastShape(
const tvm::ffi::Array<tvm::PrimExpr>& shape1,
52 const tvm::ffi::Array<tvm::PrimExpr>& shape2) {
54 int s1_size = shape1.size();
55 int s2_size = shape2.size();
59 auto cast_if_needed = [](
DataType to_type, PrimExpr expr) {
60 return to_type != expr.dtype() ?
cast(to_type, expr) : expr;
63 for (i = 1; i <=
std::min(s1_size, s2_size); ++i) {
67 DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype());
70 if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
71 bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
72 bh.vars1.push_front(bh.all_vars[0]);
73 bh.vars2.push_front(bh.all_vars[0]);
74 }
else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) {
75 TVM_FFI_ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i]));
76 bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
77 bh.vars2.push_front(bh.all_vars[0]);
78 }
else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
79 bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
80 bh.vars1.push_front(bh.all_vars[0]);
81 }
else if (!static_size1 && !static_size2) {
82 bh.common_shape.push_front(
83 cast_if_needed(common_type,
max(shape1[s1_size - i], shape2[s2_size - i])));
84 bh.vars1.push_front(bh.all_vars[0]);
85 bh.vars2.push_front(bh.all_vars[0]);
86 }
else if (!static_size1) {
87 bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
88 bh.vars2.push_front(bh.all_vars[0]);
89 bh.vars1.push_front(bh.all_vars[0]);
90 }
else if (!static_size2) {
91 bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
92 bh.vars1.push_front(bh.all_vars[0]);
93 bh.vars2.push_front(bh.all_vars[0]);
95 TVM_FFI_ICHECK(
false) <<
"Incompatible broadcast dims: " << shape1[s1_size - i] <<
" and "
96 << shape2[s2_size - i] <<
" in: "
97 << tvm::ffi::Array<tvm::PrimExpr>(shape1.begin(), shape1.end())
99 << tvm::ffi::Array<tvm::PrimExpr>(shape2.begin(), shape2.end());
103 auto max_size =
std::max(s1_size, s2_size);
104 auto&
shape = (s1_size > s2_size) ? shape1 : shape2;
105 auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
106 for (; i <= max_size; ++i) {
108 bh.common_shape.push_front(
shape[max_size - i]);
109 vars.push_front(bh.all_vars[0]);
114 inline tvm::ffi::Array<tvm::PrimExpr> InputIndexFromBroadcast(
115 const tvm::ffi::Array<tvm::tirx::Var>& ovars,
const tvm::te::Tensor& T,
116 const std::deque<tvm::tirx::Var>& my_vars,
const std::deque<tvm::tirx::Var>& all_vars) {
117 tvm::ffi::Array<tvm::PrimExpr> ivars;
118 TVM_FFI_ICHECK_EQ(ovars.size(), all_vars.size());
120 size_t expected_dims = T->shape.size();
121 for (
size_t i = 0; i < ovars.size(); ++i) {
123 for (
size_t j = 0; j < my_vars.size(); ++j) {
124 if (all_vars[i].same_as(my_vars[j])) {
125 ivars.push_back(ovars[i]);
132 if (!found && (ovars.size() - i) <= expected_dims) {
136 TVM_FFI_ICHECK(expected_dims == ivars.size());
140 template <
typename FBinaryExpr>
143 const std::string& tag =
"") {
144 auto bh = BroadcastShape(A->shape, B->shape);
145 auto l = [&](tvm::ffi::Array<tvm::tirx::Var> ovars) {
146 return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
147 B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
150 tvm::ffi::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()), l, name, tag);
Reference to PrimExprNode.
Definition: expr.h:126
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
a named variable in TIR
Definition: var.h:76
Utility functions for handling constants in TVM expressions.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis)
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:1021
tvm::IntImmNode IntImmNode
Definition: expr.h:48
Tensor max(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the maximum of elements over a given axis.
Definition: reduction.h:442
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type,...
Definition: elemwise.h:277
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
runtime::DataType DataType
Definition: data_type.h:462
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
Operation node can generate one or multiple Tensors.