#ifndef TVM_TOPI_REDUCTION_H_
#define TVM_TOPI_REDUCTION_H_

// The reduction signature used by CommReduce.
using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis,
                                       Array<PrimExpr> init, Span span)>;

// The reduction signature used by CommReduceIdx.
using FCommReduce = std::function<Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis,
                                                  PrimExpr* condition)>;

// Convert a reduction axis, which may be empty or contain negative entries, into a sorted,
// de-duplicated list of valid dimension indices.
inline std::vector<int> GetRealAxis(int ndim, const Array<Integer>& axis) {
  std::vector<int> real_axis;
  if (!axis.defined() || axis.size() == 0) {
    for (int i = 0; i < ndim; ++i) {
      real_axis.push_back(i);
    }
  } else {
    for (auto elem : axis) {
      int64_t val = elem->value;
      if (val < 0) val += ndim;
      ICHECK_LE(val, ndim) << " exceeds the maximum dimension " << ndim;
      real_axis.push_back(static_cast<int>(val));
    }
    std::sort(real_axis.begin(), real_axis.end());
    real_axis.resize(std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin());
  }
  return real_axis;
}
// Enumerate the reduction IterVars for a reduce op.
inline Array<IterVar> MakeReduceAxes(const std::vector<int>& real_axis, const Tensor& data) {
  Array<IterVar> reduce_axes;
  for (auto i : real_axis) {
    std::string name = "k" + std::to_string(i);
    reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name));
  }
  return reduce_axes;
}

// Calculate the output shape for a reduce op.
inline Array<PrimExpr> MakeReduceTargetShape(const std::vector<int>& real_axis, const Tensor& data,
                                             bool keepdims, bool atleast1d) {
  auto ndim = data->shape.size();
  Array<PrimExpr> target_shape;
  if (keepdims) {
    for (size_t i = 0; i < ndim; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        target_shape.push_back(1);  // reduced dimensions are kept with extent 1
      } else {
        target_shape.push_back(data->shape[i]);
      }
    }
  } else {
    for (size_t i = 0; i < ndim; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
        target_shape.push_back(data->shape[i]);  // reduced dimensions are dropped
      }
    }
  }
  if (target_shape.size() == 0 && atleast1d) {
    target_shape.push_back(1);
  }
  return target_shape;
}
// Create a reduction over the given target shape, reducing reduce_axes and squeezing
// squeeze_axes out of the output indexing.
inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array<PrimExpr>& target_shape,
                           const std::vector<int>& reduce_axes,
                           const std::vector<int>& squeeze_axes, Span span = Span()) {
  auto r_axes = MakeReduceAxes(reduce_axes, data);
  auto compute = [&](const Array<Var>& indices) {
    Array<PrimExpr> eval_range;
    Array<Var> eval_indices;
    int arg_counter = 0;
    int red_counter = 0;
    for (size_t i = 0; i < data->shape.size(); ++i) {
      bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end();
      if (std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end()) {
        // This dimension is reduced: index it with the reduction IterVar.
        eval_range.push_back(r_axes[red_counter]);
        eval_indices.push_back(r_axes[red_counter]->var);
        red_counter++;
        arg_counter += !squeeze_i;
        continue;
      }
      eval_range.push_back(indices[arg_counter]);
      arg_counter++;
    }
    return func(data(eval_range), r_axes, {}, span);
  };
  return tvm::te::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
}
// Create a reduction operation: normalize the axis list, compute the output shape, and reduce.
inline Tensor CommReduce(const Tensor& data, const Array<Integer>& axis, FReduce func,
                         bool keepdims, bool atleast1d) {
  auto ndim = data->shape.size();
  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
  return DoCommReduce(data, func, target_shape, real_axis,
                      keepdims ? std::vector<int>() : real_axis);
}
// Create an index reduction operation (used by argmin/argmax): reduces (index, value) pairs.
inline Tensor CommReduceIdx(const Tensor& data, const Array<Integer>& axis, FCommReduce func,
                            bool keepdims, bool atleast1d) {
  auto ndim = data->shape.size();
  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
  auto reduce_axes = MakeReduceAxes(real_axis, data);
  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);

  auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func,
                  &data](const Array<Var>& indices) {
    Array<PrimExpr> eval_range;
    Array<PrimExpr> eval_indices;
    int arg_counter = 0;
    int red_counter = 0;
    for (size_t i = 0; i < ndim; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        // Reduced dimension: index by the reduction IterVar and record its var for the index.
        eval_range.push_back(reduce_axes[red_counter]);
        eval_indices.push_back(reduce_axes[red_counter]->var);
        red_counter++;
      } else if (!keepdims) {
        eval_range.push_back(indices[arg_counter]);
        arg_counter++;
      } else {
        eval_range.push_back(indices[i]);
      }
    }
    // Flatten the reduction indices into a single linear index.
    Array<PrimExpr> ravel_shape;
    for (auto i : real_axis) {
      ravel_shape.push_back(data->shape[i]);
    }
    auto idx = detail::RavelIndex(eval_indices, ravel_shape);
    return func({idx, data(eval_range)}, reduce_axes, nullptr);
  };

  auto temp_idx_val =
      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx);
  auto temp_idx = temp_idx_val[0];
  auto temp_val = temp_idx_val[1];
  return tvm::te::compute(
      target_shape, [&temp_idx](const Array<Var>& indices) { return temp_idx(indices); },
      data->op->name + "_red", kCommReduceIdx);
}
// A combiner function for a reduction.
using FCombine = std::function<Array<PrimExpr>(Array<Var> lhs, Array<Var> rhs)>;

// An initializer function for a reduction.
using FIdentity = std::function<Array<PrimExpr>(std::vector<DataType> types)>;

// Create a commutative reducer from a combine function and an identity function.
inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity,
                                   std::string name = "reduce") {
  return [fcombine, fidentity, name](Array<PrimExpr> exprs, const Array<IterVar>& axis,
                                     PrimExpr* condition) {
    Array<Var> lhs, rhs;
    std::vector<DataType> dtypes;
    for (size_t i = 0; i < exprs.size(); ++i) {
      auto dtype = exprs[i].dtype();
      dtypes.push_back(dtype);
      lhs.push_back(var(name + "_lhs_" + std::to_string(i), dtype));
      rhs.push_back(var(name + "_rhs_" + std::to_string(i), dtype));
    }
    auto result = fcombine(lhs, rhs);
    auto id_elem = fidentity(dtypes);
    auto cond = condition != nullptr ? *condition : tir::const_true();
    auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem);
    Array<PrimExpr> outputs;
    for (size_t i = 0; i < exprs.size(); ++i) {
      outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i), {}));
    }
    return outputs;
  };
}
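MakeCommReducer is the factory behind the argmin/argmax reducers defined below: fcombine describes how two partial results merge, and fidentity supplies the identity element for each dtype. As an aside (not part of the header), here is a minimal sketch of a custom single-value reducer built the same way; the name MakeProdCommReducer is hypothetical, and the sketch assumes it sits in the same tvm::topi context so unqualified helpers such as make_const resolve as they do in the surrounding code.

// Hypothetical sketch: a single-value product reducer assembled with MakeCommReducer.
inline FCommReduce MakeProdCommReducer() {
  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
    Array<PrimExpr> result;
    result.push_back(lhs[0] * rhs[0]);  // combine two partial products
    return result;
  };
  auto fidentity = [](std::vector<DataType> types) {
    Array<PrimExpr> result;
    result.push_back(make_const(types[0], 1));  // multiplicative identity
    return result;
  };
  return MakeCommReducer(fcombine, fidentity, "prod");
}

The (index, value) reducers below follow the same pattern, just with two-element tuples and a Select-based combine.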
// Wrappers that pin down the reduction overloads of tvm::min, tvm::max, and tvm::prod for FReduce.
inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}, Span span = Span()) {
  return tvm::min(source, axis, init, span);
}
inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}, Span span = Span()) {
  return tvm::max(source, axis, init, span);
}
inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}, Span span = Span()) {
  return tvm::prod(source, axis, init, span);
}

// Creates an operation that sums array elements over a given axis (logical OR for bool tensors).
inline Tensor sum(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                  bool atleast1d = false) {
  if (data->dtype.is_bool()) {
    return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
  } else {
    return CommReduce(data, axis, tvm::sum, keepdims, atleast1d);
  }
}

// Collapse data into target_shape by summing over the axes that were broadcast.
inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
  ICHECK_GE(data->shape.size(), target_shape.size());
  auto ishape = detail::GetConstIntValues(data->shape, "ishape");
  auto oshape = detail::GetConstIntValues(target_shape, "oshape");

  std::vector<int> reduce_axes;
  std::vector<int> squeeze_axes;
  for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
    if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
      --o_ax;
      continue;
    }
    reduce_axes.push_back(i_ax);
    if (o_ax < 0) {
      // No counterpart in target_shape: this dimension is squeezed away entirely.
      squeeze_axes.push_back(i_ax);
    } else if (oshape[o_ax] == 1) {
      --o_ax;
    }
  }

  if (reduce_axes.size() == 0) return topi::identity(data, "tensor", kCommReduce);

  std::reverse(reduce_axes.begin(), reduce_axes.end());
  std::reverse(squeeze_axes.begin(), squeeze_axes.end());
  return DoCommReduce(data, tvm::sum, target_shape, reduce_axes, squeeze_axes);
}
// Reductions of elements over a given axis: logical AND / OR, minimum, and maximum.
inline Tensor all(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                  bool atleast1d = false) {
  return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
}
inline Tensor any(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                  bool atleast1d = false) {
  return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
}
inline Tensor min(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                  bool atleast1d = false) {
  return CommReduce(data, axis, MinOp, keepdims, atleast1d);
}
inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                  bool atleast1d = false) {
  return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
}
// Build the (index, value) reducer used by argmin; select_last_index picks the later index on ties.
inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
    Array<PrimExpr> result;
    // Cast to PrimExpr to avoid operator ambiguity; lhs/rhs are (index, value) pairs.
    PrimExpr lhs_idx = lhs[0], rhs_idx = rhs[0];
    PrimExpr lhs_val = lhs[1], rhs_val = rhs[1];
    auto is_smaller = lhs_val < rhs_val;
    auto is_same = lhs_val == rhs_val;
    // On equal values, prefer the later index if select_last_index is set, otherwise the earlier one.
    PrimExpr proper_index = select_last_index ? (lhs_idx > rhs_idx) : (lhs_idx < rhs_idx);
    PrimExpr update_index = is_smaller || (is_same && proper_index);
    result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // index
    result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1]));    // value
    return result;
  };
  auto fidentity = [&](std::vector<DataType> types) {
    Array<PrimExpr> result;
    result.push_back(make_const(types[0], -1));  // identity index
    result.push_back(max_value(types[1]));       // identity value for a minimum
    return result;
  };
  return MakeCommReducer(fcombine, fidentity, "argmin");
}

// Creates an operation that finds the indices of the minimum values over a given axis.
inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                     bool atleast1d = false, bool select_last_index = false) {
  auto reducer = MakeArgminReducer(select_last_index);
  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}

// Build the (index, value) reducer used by argmax; the same tie-breaking rule applies.
inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
    Array<PrimExpr> result;
    PrimExpr lhs_idx = lhs[0], rhs_idx = rhs[0];
    PrimExpr lhs_val = lhs[1], rhs_val = rhs[1];
    auto is_bigger = lhs_val > rhs_val;
    auto is_same = lhs_val == rhs_val;
    PrimExpr proper_index = select_last_index ? (lhs_idx > rhs_idx) : (lhs_idx < rhs_idx);
    PrimExpr update_index = is_bigger || (is_same && proper_index);
    result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // index
    result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1]));     // value
    return result;
  };
  auto fidentity = [&](std::vector<DataType> types) {
    Array<PrimExpr> result;
    result.push_back(make_const(types[0], -1));  // identity index
    result.push_back(min_value(types[1]));       // identity value for a maximum
    return result;
  };
  return MakeCommReducer(fcombine, fidentity, "argmax");
}

// Creates an operation that finds the indices of the maximum values over a given axis.
inline Tensor argmax(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                     bool atleast1d = false, bool select_last_index = false) {
  auto reducer = MakeArgmaxReducer(select_last_index);
  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}
// Creates product operation over a given axis.
inline Tensor prod(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                   bool atleast1d = false) {
  return CommReduce(data, axis, ProdOp, keepdims, atleast1d);
}

// Create a commutative reducer that sums element-wise over tuples of values.
inline FCommReduce MakeTupleSumReducer() {
  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
    Array<PrimExpr> result;
    ICHECK_EQ(lhs.size(), rhs.size());
    result.reserve(lhs.size());
    for (size_t i = 0; i < lhs.size(); ++i) {
      result.push_back(lhs[i] + rhs[i]);
    }
    return result;
  };
  auto fidentity = [](std::vector<DataType> types) {
    Array<PrimExpr> result;
    for (size_t i = 0; i < types.size(); ++i) {
      result.push_back(make_const(types[i], 0));
    }
    return result;
  };
  return MakeCommReducer(fcombine, fidentity, "tuple_sum");
}
#endif  // TVM_TOPI_REDUCTION_H_
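For context, a minimal usage sketch (not part of the header) showing how these reductions are typically driven from the TE API. The placeholder shape, the tensor names, and the ExampleReductions wrapper are illustrative assumptions rather than code from this file.

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>

// Sketch: reduce a [4, 8] float32 placeholder along axis 1.
void ExampleReductions() {
  tvm::te::Tensor A = tvm::te::placeholder({4, 8}, tvm::DataType::Float(32), "A");
  // Sum over axis 1, keeping the reduced dimension: result shape [4, 1].
  tvm::te::Tensor s = tvm::topi::sum(A, {1}, /*keepdims=*/true);
  // Index of the maximum along axis 1: result shape [4].
  tvm::te::Tensor idx = tvm::topi::argmax(A, {1});
  (void)s;
  (void)idx;
}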