24 #ifndef TVM_TOPI_REDUCTION_H_
25 #define TVM_TOPI_REDUCTION_H_
47 ffi::Array<PrimExpr> init,
Span span)>;
51 ffi::Array<PrimExpr> exprs,
const ffi::Array<IterVar>& axis,
PrimExpr* condition)>;
65 inline std::vector<int>
GetRealAxis(
int ndim,
const ffi::Optional<ffi::Array<Integer>>& axis) {
66 std::vector<int> real_axis;
67 if (!axis.has_value()) {
68 for (
int i = 0; i < ndim; ++i) {
69 real_axis.push_back(i);
73 for (
auto elem : axis.value()) {
74 int64_t val = elem->value;
78 ICHECK_LT(val, ndim) <<
" exceeds the maximum dimension " << ndim;
80 real_axis.push_back(
static_cast<int>(val));
82 std::sort(real_axis.begin(), real_axis.end());
83 real_axis.resize(std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin());
90 ffi::Array<IterVar> reduce_axes;
91 for (
auto i : real_axis) {
92 std::string name =
"k" + std::to_string(i);
100 const Tensor& data,
bool keepdims,
102 auto ndim = data->shape.size();
103 ffi::Array<PrimExpr> target_shape;
105 for (
size_t i = 0; i < ndim; ++i) {
106 if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
108 target_shape.push_back(1);
110 target_shape.push_back(data->shape[i]);
114 for (
size_t i = 0; i < ndim; ++i) {
115 if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
117 target_shape.push_back(data->shape[i]);
121 if (target_shape.size() == 0 && atleast1d) {
122 target_shape.push_back(1);
141 const ffi::Array<PrimExpr>& target_shape,
142 const std::vector<int>& reduce_axes,
143 const std::vector<int>& squeeze_axes,
Span span =
Span()) {
145 auto compute = [&](
const ffi::Array<Var>& indices) {
146 ffi::Array<PrimExpr> eval_range;
147 ffi::Array<Var> eval_indices;
151 for (
size_t i = 0; i < data->shape.size(); ++i) {
152 bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end();
153 if (std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end()) {
155 eval_range.push_back(r_axes[red_counter]);
156 eval_indices.push_back(r_axes[red_counter]->
var);
158 arg_counter += !squeeze_i;
161 eval_range.push_back(indices[arg_counter]);
165 return func(data(eval_range), r_axes, {}, span);
185 FReduce func,
bool keepdims,
bool atleast1d) {
186 auto ndim = data->shape.size();
187 ICHECK_NE(ndim, 0) <<
"Cannot reduce a 0 dim Tensor";
188 auto real_axis =
GetRealAxis(
static_cast<int>(ndim), axis);
190 return DoCommReduce(data, func, target_shape, real_axis,
191 keepdims ? std::vector<int>() : real_axis);
209 auto ndim = data->shape.size();
210 ICHECK_NE(ndim, 0) <<
"Cannot reduce a 0 dim Tensor";
211 auto real_axis =
GetRealAxis(
static_cast<int>(ndim), axis);
215 auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func,
216 &data](
const ffi::Array<Var>& indices) {
217 ffi::Array<PrimExpr> eval_range;
218 ffi::Array<PrimExpr> eval_indices;
222 for (
size_t i = 0; i < ndim; ++i) {
223 if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
225 eval_range.push_back(reduce_axes[red_counter]);
226 eval_indices.push_back(reduce_axes[red_counter]->
var);
230 eval_range.push_back(indices[arg_counter]);
233 eval_range.push_back(indices[i]);
238 ffi::Array<PrimExpr> ravel_shape;
239 for (
auto i : real_axis) {
240 ravel_shape.push_back(data->shape[i]);
242 auto idx = detail::RavelIndex(eval_indices, ravel_shape);
243 return func({idx, data(eval_range)}, reduce_axes,
nullptr);
248 auto temp_idx = temp_idx_val[0];
249 auto temp_val = temp_idx_val[1];
251 target_shape, [&temp_idx](
const ffi::Array<Var>& indices) {
return temp_idx(indices); },
256 using FCombine = std::function<ffi::Array<PrimExpr>(ffi::Array<Var> lhs, ffi::Array<Var> rhs)>;
259 using FIdentity = std::function<ffi::Array<PrimExpr>(std::vector<DataType> types)>;
271 std::string name =
"reduce") {
272 return [fcombine, fidentity, name](ffi::Array<PrimExpr> exprs,
const ffi::Array<IterVar>& axis,
274 ffi::Array<Var> lhs, rhs;
275 std::vector<DataType> dtypes;
277 for (
size_t i = 0; i < exprs.size(); ++i) {
278 auto dtype = exprs[i].dtype();
279 dtypes.push_back(dtype);
280 lhs.push_back(
var(name +
"_lhs_" + std::to_string(i), dtype));
281 rhs.push_back(
var(name +
"_rhs_" + std::to_string(i), dtype));
284 auto result = fcombine(lhs, rhs);
285 auto id_elem = fidentity(dtypes);
289 ffi::Array<PrimExpr> outputs;
290 for (
size_t i = 0; i < exprs.size(); ++i) {
291 outputs.push_back(
tvm::tir::Reduce(combiner, exprs, axis, cond,
static_cast<int>(i), {}));
300 return tvm::min(source, axis, init, span);
306 return tvm::max(source, axis, init, span);
312 return tvm::prod(source, axis, init, span);
329 bool keepdims =
false,
bool atleast1d =
false) {
330 if (data->dtype.is_bool()) {
338 const auto& ishape = data->shape;
339 const auto& oshape = target_shape;
340 int isize = data->shape.size();
341 int osize = target_shape.size();
343 ICHECK_GE(isize, osize)
344 <<
"Invalid collapse: input dimensionality smaller than output dimensionality.\ninput shape: "
345 << data->shape <<
"\nvs\noutput shape: " << target_shape;
347 std::vector<int> reduce_axes;
348 std::vector<int> squeeze_axes;
351 for (
int i_ax = isize - 1, o_ax = osize - 1; i_ax >= 0; --i_ax) {
352 if (o_ax >= 0 && topi::detail::EqualCheck(ishape[i_ax], oshape[o_ax])) {
356 reduce_axes.push_back(i_ax);
358 squeeze_axes.push_back(i_ax);
359 }
else if (topi::detail::EqualCheck(one, oshape[o_ax])) {
366 std::reverse(reduce_axes.begin(), reduce_axes.end());
367 std::reverse(squeeze_axes.begin(), squeeze_axes.end());
386 bool keepdims =
false,
bool atleast1d =
false) {
405 bool keepdims =
false,
bool atleast1d =
false) {
424 bool keepdims =
false,
bool atleast1d =
false) {
443 bool keepdims =
false,
bool atleast1d =
false) {
449 auto fcombine = [=](ffi::Array<Var> lhs, ffi::Array<Var> rhs) {
450 ffi::Array<PrimExpr> result;
459 auto is_smaller = lhs_val < rhs_val;
460 auto is_same = lhs_val == rhs_val;
466 if (select_last_index) {
467 proper_index = lhs_idx > rhs_idx;
469 proper_index = lhs_idx < rhs_idx;
472 PrimExpr update_index = is_smaller || (is_same && proper_index);
477 auto fidentity = [&](std::vector<DataType> types) {
478 ffi::Array<PrimExpr> result;
503 bool keepdims =
false,
bool atleast1d =
false,
504 bool select_last_index =
false) {
506 return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
511 auto fcombine = [=](ffi::Array<Var> lhs, ffi::Array<Var> rhs) {
512 ffi::Array<PrimExpr> result;
521 auto is_bigger = lhs_val > rhs_val;
522 auto is_same = lhs_val == rhs_val;
528 if (select_last_index) {
529 proper_index = lhs_idx > rhs_idx;
531 proper_index = lhs_idx < rhs_idx;
534 PrimExpr update_index = is_bigger || (is_same && proper_index);
539 auto fidentity = [&](std::vector<DataType> types) {
540 ffi::Array<PrimExpr> result;
564 bool keepdims =
false,
bool atleast1d =
false,
565 bool select_last_index =
false) {
567 return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
584 bool keepdims =
false,
bool atleast1d =
false) {
592 auto fcombine = [](ffi::Array<Var> lhs, ffi::Array<Var> rhs) {
593 ffi::Array<PrimExpr> result;
594 ICHECK_EQ(lhs.size(), rhs.size());
595 result.reserve(lhs.size());
596 for (
size_t i = 0; i < lhs.size(); ++i) {
597 result.push_back(lhs[i] + rhs[i]);
601 auto fidentity = [](std::vector<DataType> types) {
602 ffi::Array<PrimExpr> result;
603 for (
size_t i = 0; i < types.size(); ++i) {
Broadcast op constructions.
Reference to PrimExprNode.
Definition: expr.h:124
Range container
Definition: expr.h:689
Definition: source_map.h:111
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:53
Managed reference to CommReducerNode.
Definition: expr.h:832
Managed reference to ReduceNode.
Definition: expr.h:876
Managed reference to SelectNode.
Definition: expr.h:515
Utility functions for handling constants in TVM expressions.
Elementwise op constructions.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis)
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:994
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:818
Tensor DoCommReduce(const Tensor &data, FReduce func, const ffi::Array< PrimExpr > &target_shape, const std::vector< int > &reduce_axes, const std::vector< int > &squeeze_axes, Span span=Span())
Create a reduction operation.
Definition: reduction.h:140
Tensor collapse_sum(const Tensor &data, ffi::Array< PrimExpr > target_shape)
Definition: reduction.h:337
std::vector< int > GetRealAxis(int ndim, const ffi::Optional< ffi::Array< Integer >> &axis)
Convert a reduction axis which could be empty or have negative elements into a real axis with valid d...
Definition: reduction.h:65
FCommReduce MakeTupleSumReducer()
Create commutative reducer summing over tuples.
Definition: reduction.h:591
FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, std::string name="reduce")
Create a commutative reducer for a reduction.
Definition: reduction.h:270
ffi::Array< PrimExpr > MakeReduceTargetShape(const std::vector< int > &real_axis, const Tensor &data, bool keepdims, bool atleast1d)
Calculate the target shape for a reduce op.
Definition: reduction.h:99
std::function< ffi::Array< PrimExpr >(ffi::Array< PrimExpr > exprs, const ffi::Array< IterVar > &axis, PrimExpr *condition)> FCommReduce
The operation to use for CommReduceIdx.
Definition: reduction.h:51
Tensor max(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the maximum of elements over a given axis.
Definition: reduction.h:442
ffi::Array< IterVar > MakeReduceAxes(const std::vector< int > &real_axis, const Tensor &data)
Enumerate the axes for a reduce op.
Definition: reduction.h:89
FCommReduce MakeArgmaxReducer(bool select_last_index=false)
Definition: reduction.h:509
PrimExpr MaxOp(PrimExpr source, ffi::Array< IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::max to ensure we get the correct overload.
Definition: reduction.h:304
std::function< PrimExpr(PrimExpr source, const ffi::Array< IterVar > &axis, ffi::Array< PrimExpr > init, Span span)> FReduce
The operation to use for CommReduce.
Definition: reduction.h:47
constexpr auto kCommReduce
Definition: tags.h:34
Tensor argmin(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false, bool select_last_index=false)
Creates an operation that finds the indices of the minimum values over a given axis.
Definition: reduction.h:502
Tensor CommReduceIdx(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, FCommReduce func, bool keepdims, bool atleast1d)
Create an index reduction operation.
Definition: reduction.h:207
constexpr auto kCommReduceIdx
Definition: tags.h:35
PrimExpr MinOp(PrimExpr source, ffi::Array< IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::min to ensure we get the correct overload.
Definition: reduction.h:298
FCommReduce MakeArgminReducer(bool select_last_index=false)
Definition: reduction.h:447
std::function< ffi::Array< PrimExpr >(ffi::Array< Var > lhs, ffi::Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:256
Tensor identity(const Tensor &x, std::string name="T_identity", std::string tag=kElementWise)
Creates an operation that returns identity of a given tensor.
Definition: elemwise.h:152
PrimExpr ProdOp(PrimExpr source, ffi::Array< IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::prod to ensure we get the correct overload.
Definition: reduction.h:310
Tensor any(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that computes the logical OR of elements over a given axis.
Definition: reduction.h:404
Tensor prod(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates product operation over given axis.
Definition: reduction.h:583
Tensor min(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the minimum of elements over a given axis.
Definition: reduction.h:423
Tensor sum(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:328
Tensor CommReduce(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, FReduce func, bool keepdims, bool atleast1d)
Create a reduction operation.
Definition: reduction.h:184
Tensor argmax(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false, bool select_last_index=false)
Creates an operation that finds the indices of the maximum values over a given axis.
Definition: reduction.h:563
std::function< ffi::Array< PrimExpr >(std::vector< DataType > types)> FIdentity
An initializer function for a reduction.
Definition: reduction.h:259
Tensor all(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that computes the logical AND of elements over a given axis.
Definition: reduction.h:385
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr any(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
logical Or of source expression over axis
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr max_value(const DataType &dtype, Span span=Span())
PrimExpr sum(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr prod(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
product of source expression over axis
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
PrimExpr all(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
logical And of source expression over axis
Operation node can generate one or multiple Tensors.
Index ravel and unravel operations.