tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
reduction.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_REDUCTION_H_
25 #define TVM_TOPI_REDUCTION_H_
26 
27 #include <tvm/te/operation.h>
28 #include <tvm/topi/broadcast.h>
31 #include <tvm/topi/elemwise.h>
32 #include <tvm/topi/tags.h>
33 #include <tvm/topi/transform.h>
34 
35 #include <algorithm>
36 #include <iterator>
37 #include <string>
38 #include <vector>
39 
40 namespace tvm {
41 namespace topi {
42 
43 using namespace tvm::te;
44 
/*! \brief The operation to use for CommReduce. */
using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis,
                                       Array<PrimExpr> init, Span span)>;

/*! \brief The operation to use for CommReduceIdx. */
using FCommReduce = std::function<Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis,
                                                  PrimExpr* condition)>;
52 
65 inline std::vector<int> GetRealAxis(int ndim, const Array<Integer>& axis) {
66  std::vector<int> real_axis;
67  if (!axis.defined() || axis.size() == 0) {
68  for (int i = 0; i < ndim; ++i) {
69  real_axis.push_back(i);
70  }
71  } else {
72  // Use a set so duplicates are removed and the dims are sorted
73  for (auto elem : axis) {
74  int64_t val = elem->value;
75  if (val < 0) {
76  val += ndim;
77  }
78  ICHECK_LE(val, ndim) << " exceeds the maximum dimension " << ndim;
79  ICHECK_GE(val, 0);
80  real_axis.push_back(static_cast<int>(val));
81  }
82  std::sort(real_axis.begin(), real_axis.end());
83  real_axis.resize(std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin());
84  }
85  return real_axis;
86 }
87 
89 inline Array<IterVar> MakeReduceAxes(const std::vector<int>& real_axis, const Tensor& data) {
90  Array<IterVar> reduce_axes;
91  for (auto i : real_axis) {
92  std::string name = "k" + std::to_string(i);
93  reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name));
94  }
95  return reduce_axes;
96 }
97 
99 inline Array<PrimExpr> MakeReduceTargetShape(const std::vector<int>& real_axis, const Tensor& data,
100  bool keepdims, bool atleast1d) {
101  auto ndim = data->shape.size();
102  Array<PrimExpr> target_shape;
103  if (keepdims) {
104  for (size_t i = 0; i < ndim; ++i) {
105  if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
106  // real_axis contains i
107  target_shape.push_back(1);
108  } else {
109  target_shape.push_back(data->shape[i]);
110  }
111  }
112  } else {
113  for (size_t i = 0; i < ndim; ++i) {
114  if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
115  // real_axis does not contain i
116  target_shape.push_back(data->shape[i]);
117  }
118  }
119  }
120  if (target_shape.size() == 0 && atleast1d) {
121  target_shape.push_back(1);
122  }
123  return target_shape;
124 }
125 
139 inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array<PrimExpr>& target_shape,
140  const std::vector<int>& reduce_axes,
141  const std::vector<int>& squeeze_axes, Span span = Span()) {
142  auto r_axes = MakeReduceAxes(reduce_axes, data);
143  auto compute = [&](const Array<Var>& indices) {
144  Array<PrimExpr> eval_range;
145  Array<Var> eval_indices;
146  int arg_counter = 0;
147  int red_counter = 0;
148 
149  for (size_t i = 0; i < data->shape.size(); ++i) {
150  bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end();
151  if (std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end()) {
152  // real_axis contains i
153  eval_range.push_back(r_axes[red_counter]);
154  eval_indices.push_back(r_axes[red_counter]->var);
155  red_counter++;
156  arg_counter += !squeeze_i;
157  continue;
158  }
159  eval_range.push_back(indices[arg_counter]);
160  arg_counter++;
161  }
162 
163  return func(data(eval_range), r_axes, {}, span);
164  };
165 
166  return tvm::te::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
167 }
168 
182 inline Tensor CommReduce(const Tensor& data, const Array<Integer>& axis, FReduce func,
183  bool keepdims, bool atleast1d) {
184  auto ndim = data->shape.size();
185  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
186  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
187  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
188  return DoCommReduce(data, func, target_shape, real_axis,
189  keepdims ? std::vector<int>() : real_axis);
190 }
191 
/*!
 * \brief Create an index reduction operation (the machinery behind argmax/argmin).
 *
 * A first compute stage produces a (flattened index, value) pair per output
 * element via the supplied FCommReduce; a second stage extracts the index half.
 *
 * \param data The input tensor.
 * \param axis The axes to reduce over (empty/undefined means all).
 * \param func The index-reduction function (see MakeArgminReducer/MakeArgmaxReducer).
 * \param keepdims If true, reduced axes are retained with size 1.
 * \param atleast1d Whether the output needs to be at least 1-dimensional.
 *
 * \return The tensor of (flattened) indices selected by the reduction.
 */
inline Tensor CommReduceIdx(const Tensor& data, const Array<Integer>& axis, FCommReduce func,
                            bool keepdims, bool atleast1d) {
  auto ndim = data->shape.size();
  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
  auto reduce_axes = MakeReduceAxes(real_axis, data);
  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);

  auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func,
                  &data](const Array<Var>& indices) {
    // eval_range: one coordinate per input dimension (IterVar for reduced dims,
    // output index otherwise). eval_indices: only the reduction variables, used
    // below to form the flattened index being reduced over.
    Array<PrimExpr> eval_range;
    Array<PrimExpr> eval_indices;
    int arg_counter = 0;  // walks `indices` when keepdims is false
    int red_counter = 0;  // walks `reduce_axes`

    for (size_t i = 0; i < ndim; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        // real_axis contains i
        eval_range.push_back(reduce_axes[red_counter]);
        eval_indices.push_back(reduce_axes[red_counter]->var);
        red_counter++;
      } else {
        if (!keepdims) {
          // Without keepdims the output has fewer dims, so track them separately.
          eval_range.push_back(indices[arg_counter]);
          arg_counter++;
        } else {
          // With keepdims, output and input dimensions line up one-to-one.
          eval_range.push_back(indices[i]);
        }
      }
    }

    // Flatten the multi-dimensional reduction coordinates into a single linear
    // index over the reduced extents.
    Array<PrimExpr> ravel_shape;
    for (auto i : real_axis) {
      ravel_shape.push_back(data->shape[i]);
    }
    auto idx = detail::RavelIndex(eval_indices, ravel_shape);
    return func({idx, data(eval_range)}, reduce_axes, nullptr);
  };

  // Stage 1: joint (index, value) reduction with two outputs.
  auto temp_idx_val =
      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx);
  auto temp_idx = temp_idx_val[0];
  // NOTE(review): temp_val is never read below; the value output exists only as
  // the paired second result of the Reduce.
  auto temp_val = temp_idx_val[1];
  // Stage 2: expose just the index component as the final result.
  return tvm::te::compute(
      target_shape, [&temp_idx](const Array<Var>& indices) { return temp_idx(indices); },
      data->op->name + "_red", kCommReduceIdx);
}
252 
/*! \brief A combiner function for a reduction. */
using FCombine = std::function<Array<PrimExpr>(Array<Var> lhs, Array<Var> rhs)>;

/*! \brief An initializer function for a reduction. */
using FIdentity = std::function<Array<PrimExpr>(std::vector<DataType> types)>;
258 
268 inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity,
269  std::string name = "reduce") {
270  return [fcombine, fidentity, name](Array<PrimExpr> exprs, const Array<IterVar>& axis,
271  PrimExpr* condition) {
272  Array<Var> lhs, rhs;
273  std::vector<DataType> dtypes;
274 
275  for (size_t i = 0; i < exprs.size(); ++i) {
276  auto dtype = exprs[i].dtype();
277  dtypes.push_back(dtype);
278  lhs.push_back(var(name + "_lhs_" + std::to_string(i), dtype));
279  rhs.push_back(var(name + "_rhs_" + std::to_string(i), dtype));
280  }
281 
282  auto result = fcombine(lhs, rhs);
283  auto id_elem = fidentity(dtypes);
284  auto cond = condition != nullptr ? *condition : tir::const_true();
285 
286  auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem);
287  Array<PrimExpr> outputs;
288  for (size_t i = 0; i < exprs.size(); ++i) {
289  outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i), {}));
290  }
291  return outputs;
292  };
293 }
294 
/*! \brief Wrap tvm::min to ensure we get the correct overload. */
inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {},
                      Span span = Span()) {
  return tvm::min(source, axis, init, span);
}
300 
/*! \brief Wrap tvm::max to ensure we get the correct overload. */
inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {},
                      Span span = Span()) {
  return tvm::max(source, axis, init, span);  // NOLINT(*)
}
306 
/*! \brief Wrap tvm::prod to ensure we get the correct overload. */
inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {},
                       Span span = Span()) {
  return tvm::prod(source, axis, init, span);  // NOLINT(*)
}
312 
326 inline Tensor sum(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
327  bool atleast1d = false) {
328  if (data->dtype.is_bool()) {
329  return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
330  } else {
331  return CommReduce(data, axis, tvm::sum, keepdims, atleast1d);
332  }
333 }
334 
/*!
 * \brief Sum `data` down to `target_shape` by reducing over every dimension
 * that does not line up with a matching target dimension.
 *
 * \param data The input tensor.
 * \param target_shape The shape to collapse to; dimensions are matched from the
 *        back, broadcast-style. Shapes must be constant integers.
 *
 * \return The collapsed tensor (identity of data when nothing is reduced).
 */
inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
  ICHECK_GE(data->shape.size(), target_shape.size());
  auto ishape = detail::GetConstIntValues(data->shape, "ishape");
  auto oshape = detail::GetConstIntValues(target_shape, "oshape");

  std::vector<int> reduce_axes;
  std::vector<int> squeeze_axes;
  // Walk both shapes from the back, since broadcasting aligns trailing dims.
  for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
    if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
      // Dimensions agree: nothing to reduce here.
      --o_ax;
      continue;
    }
    // Mismatched input dimension: it must be summed away.
    reduce_axes.push_back(i_ax);
    if (o_ax < 0) {  // squeeze o_ax if was added during expansion
      squeeze_axes.push_back(i_ax);
    } else if (oshape[o_ax] == 1) {
      // Target keeps a size-1 slot for this dimension (keepdims-style).
      --o_ax;
    }
  }

  if (reduce_axes.size() == 0) return topi::identity(data, "tensor", kCommReduce);

  // Axes were collected back-to-front; restore ascending order for DoCommReduce.
  std::reverse(reduce_axes.begin(), reduce_axes.end());
  std::reverse(squeeze_axes.begin(), squeeze_axes.end());
  return DoCommReduce(data, tvm::sum, target_shape, reduce_axes, squeeze_axes);
}
361 
/*!
 * \brief Creates an operation that computes the logical AND of elements over a given axis.
 *
 * \param data The input boolean tensor.
 * \param axis The axes to reduce over (empty/undefined means all).
 * \param keepdims If true, reduced axes are retained with size 1.
 * \param atleast1d Whether the output needs to be at least 1-dimensional.
 *
 * \return The result tensor.
 */
inline Tensor all(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                  bool atleast1d = false) {
  return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
}
380 
/*!
 * \brief Creates an operation that computes the logical OR of elements over a given axis.
 *
 * \param data The input boolean tensor.
 * \param axis The axes to reduce over (empty/undefined means all).
 * \param keepdims If true, reduced axes are retained with size 1.
 * \param atleast1d Whether the output needs to be at least 1-dimensional.
 *
 * \return The result tensor.
 */
inline Tensor any(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                  bool atleast1d = false) {
  return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
}
399 
/*!
 * \brief Creates an operation that finds the minimum of elements over a given axis.
 *
 * \param data The input tensor.
 * \param axis The axes to reduce over (empty/undefined means all).
 * \param keepdims If true, reduced axes are retained with size 1.
 * \param atleast1d Whether the output needs to be at least 1-dimensional.
 *
 * \return The result tensor.
 */
inline Tensor min(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                  bool atleast1d = false) {
  return CommReduce(data, axis, MinOp, keepdims, atleast1d);
}
418 
/*!
 * \brief Creates an operation that finds the maximum of elements over a given axis.
 *
 * \param data The input tensor.
 * \param axis The axes to reduce over (empty/undefined means all).
 * \param keepdims If true, reduced axes are retained with size 1.
 * \param atleast1d Whether the output needs to be at least 1-dimensional.
 *
 * \return The result tensor.
 */
inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                  bool atleast1d = false) {
  return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
}
437 
438 inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
439  // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
440  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
441  Array<PrimExpr> result;
442 
443  // Casting to avoid operator ambiguity
444  PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
445  PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
446  PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
447  PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
448 
449  // These variables compare the actual values of the array
450  auto is_smaller = lhs_val < rhs_val;
451  auto is_same = lhs_val == rhs_val;
452 
453  // This checks if the indices are correct for the reduction. E.g. for select_last_index
454  // it gives precedence for later indices of the same element and precedence for sooner
455  // indices if not select_last_index;
456  PrimExpr proper_index;
457  if (select_last_index) {
458  proper_index = lhs_idx > rhs_idx;
459  } else {
460  proper_index = lhs_idx < rhs_idx;
461  }
462 
463  PrimExpr update_index = is_smaller || (is_same && proper_index);
464  result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx
465  result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1])); // val
466  return result;
467  };
468  auto fidentity = [&](std::vector<DataType> types) {
469  Array<PrimExpr> result;
470  result.push_back(tvm::tir::make_const(types[0], -1)); // idx
471  result.push_back(tvm::max_value(types[1])); // val
472  return result;
473  };
474  return MakeCommReducer(fcombine, fidentity, "argmin");
475 }
476 
493 inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
494  bool atleast1d = false, bool select_last_index = false) {
495  auto reducer = MakeArgminReducer(select_last_index);
496  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
497 }
498 
499 inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
500  // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
501  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
502  Array<PrimExpr> result;
503 
504  // Casting to avoid operator ambiguity
505  PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
506  PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
507  PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
508  PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
509 
510  // These variables compare the actual values of the array
511  auto is_bigger = lhs_val > rhs_val;
512  auto is_same = lhs_val == rhs_val;
513 
514  // This checks if the indices are correct for the reduction. E.g. for select_last_index
515  // it gives precedence for later indices of the same element and precedence for sooner
516  // indices if not select_last_index;
517  PrimExpr proper_index;
518  if (select_last_index) {
519  proper_index = lhs_idx > rhs_idx;
520  } else {
521  proper_index = lhs_idx < rhs_idx;
522  }
523 
524  PrimExpr update_index = is_bigger || (is_same && proper_index);
525  result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx
526  result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val
527  return result;
528  };
529  auto fidentity = [&](std::vector<DataType> types) {
530  Array<PrimExpr> result;
531  result.push_back(tvm::tir::make_const(types[0], -1)); // idx
532  result.push_back(tvm::min_value(types[1])); // val
533  return result;
534  };
535  return MakeCommReducer(fcombine, fidentity, "argmax");
536 }
537 
553 inline Tensor argmax(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
554  bool atleast1d = false, bool select_last_index = false) {
555  auto reducer = MakeArgmaxReducer(select_last_index);
556  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
557 }
558 
/*!
 * \brief Creates product operation over given axis.
 *
 * \param data The input tensor.
 * \param axis The axes to reduce over (empty/undefined means all).
 * \param keepdims If true, reduced axes are retained with size 1.
 * \param atleast1d Whether the output needs to be at least 1-dimensional.
 *
 * \return The result tensor.
 */
inline Tensor prod(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
                   bool atleast1d = false) {
  return CommReduce(data, axis, ProdOp, keepdims, atleast1d);
}
576 
581  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
582  Array<PrimExpr> result;
583  ICHECK_EQ(lhs.size(), rhs.size());
584  result.reserve(lhs.size());
585  for (size_t i = 0; i < lhs.size(); ++i) {
586  result.push_back(lhs[i] + rhs[i]);
587  }
588  return result;
589  };
590  auto fidentity = [](std::vector<DataType> types) {
591  Array<PrimExpr> result;
592  for (size_t i = 0; i < types.size(); ++i) {
593  result.push_back(tvm::tir::make_const(types[i], 0));
594  }
595  return result;
596  };
597  return MakeCommReducer(fcombine, fidentity, "tuple_sum");
598 }
599 
600 } // namespace topi
601 } // namespace tvm
602 #endif // TVM_TOPI_REDUCTION_H_
Tensor max(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the maximum of elements over a given axis.
Definition: reduction.h:433
tvm::Span Span
Definition: base.h:65
void reserve(int64_t n)
Make sure the list has the capacity of at least n.
Definition: array.h:569
Tensor CommReduceIdx(const Tensor &data, const Array< Integer > &axis, FCommReduce func, bool keepdims, bool atleast1d)
Create an index reduction operation.
Definition: reduction.h:205
Managed reference to CommReducerNode.
Definition: expr.h:1025
FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, std::string name="reduce")
Create a commutative reducer for a reduction.
Definition: reduction.h:268
Tensor collapse_sum(const Tensor &data, Array< PrimExpr > target_shape)
Definition: reduction.h:335
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
constexpr auto kCommReduceIdx
Definition: tags.h:35
std::function< Array< PrimExpr >(Array< Var > lhs, Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:254
Managed reference to ReduceNode.
Definition: expr.h:1089
FCommReduce MakeTupleSumReducer()
Create communitive reducer summing over tuples.
Definition: reduction.h:580
PrimExpr min_value(const DataType &dtype, Span span=Span())
FCommReduce MakeArgmaxReducer(bool select_last_index=false)
Definition: reduction.h:499
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:954
Tensor min(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the minimum of elements over a given axis.
Definition: reduction.h:414
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Tensor expression language DSL.
Definition: extracted_task.h:33
Array< IterVar > MakeReduceAxes(const std::vector< int > &real_axis, const Tensor &data)
Enumerate the axes for a reduce op.
Definition: reduction.h:89
Array< PrimExpr > MakeReduceTargetShape(const std::vector< int > &real_axis, const Tensor &data, bool keepdims, bool atleast1d)
Calculate the target shape for a reduce op.
Definition: reduction.h:99
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
FCommReduce MakeArgminReducer(bool select_last_index=false)
Definition: reduction.h:438
Tensor sum(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:326
Utility functions for handling constants in TVM expressions.
Range constainer.
Definition: expr.h:715
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:785
Definition: source_map.h:120
size_t size() const
Definition: array.h:420
PrimExpr MaxOp(PrimExpr source, Array< IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::max to ensure we get the correct overload.
Definition: reduction.h:302
bool defined() const
Definition: object.h:544
PrimExpr MinOp(PrimExpr source, Array< IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::min to ensure we get the correct overload.
Definition: reduction.h:296
Tensor identity(const Tensor &x, std::string name="T_identity", std::string tag=kElementWise)
Creates an operation that returns identity of a given tensor.
Definition: elemwise.h:152
Elementwise op constructions.
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
std::function< Array< PrimExpr >(std::vector< DataType > types)> FIdentity
An initializer function for a reduction.
Definition: reduction.h:257
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
std::function< Array< PrimExpr >(Array< PrimExpr > exprs, const Array< IterVar > &axis, PrimExpr *condition)> FCommReduce
The operation to use for CommReduceIdx.
Definition: reduction.h:51
PrimExpr any(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
logical Or of source expression over axis
PrimExpr ProdOp(PrimExpr source, Array< IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::prod to ensure we get the correct overload.
Definition: reduction.h:308
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
constexpr auto kCommReduce
Definition: tags.h:34
Tensor argmax(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false, bool select_last_index=false)
Creates an operation that finds the indices of the maximum values over a given axis.
Definition: reduction.h:553
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
iterator end() const
Definition: array.h:390
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Operation node can generate one or multiple Tensors.
PrimExpr all(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
logical And of source expression over axis
Managed reference to SelectNode.
Definition: expr.h:609
Tensor prod(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates product operation over given axis.
Definition: reduction.h:572
Transform op constructors.
PrimExpr max_value(const DataType &dtype, Span span=Span())
Tensor any(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that computes the logical OR of elements over a given axis.
Definition: reduction.h:395
std::function< PrimExpr(PrimExpr source, const Array< IterVar > &axis, Array< PrimExpr > init, Span span)> FReduce
The operation to use for CommReduce.
Definition: reduction.h:47
External function interface to rocBLAS libraries.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
std::vector< int > GetRealAxis(int ndim, const Array< Integer > &axis)
Convert a reduction axis which could be empty or have negative elements into a real axis with valid d...
Definition: reduction.h:65
Tensor all(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that computes the logical AND of elements over a given axis. ...
Definition: reduction.h:376
Broadcast op constructions.
Reference to PrimExprNode.
Definition: expr.h:114
Index ravel and unraval operations.
Tensor DoCommReduce(const Tensor &data, FReduce func, const Array< PrimExpr > &target_shape, const std::vector< int > &reduce_axes, const std::vector< int > &squeeze_axes, Span span=Span())
Create a reduction operation.
Definition: reduction.h:139
PrimExpr prod(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
product of source expression over axis
Tensor argmin(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false, bool select_last_index=false)
Creates an operation that finds the indices of the minimum values over a given axis.
Definition: reduction.h:493
Tensor CommReduce(const Tensor &data, const Array< Integer > &axis, FReduce func, bool keepdims, bool atleast1d)
Create a reduction operation.
Definition: reduction.h:182