tvm
reduction.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_REDUCTION_H_
25 #define TVM_TOPI_REDUCTION_H_
26 
27 #include <tvm/te/operation.h>
28 #include <tvm/topi/broadcast.h>
31 #include <tvm/topi/elemwise.h>
32 #include <tvm/topi/tags.h>
33 #include <tvm/topi/transform.h>
34 
35 #include <algorithm>
36 #include <iterator>
37 #include <string>
38 #include <vector>
39 
40 namespace tvm {
41 namespace topi {
42 
43 using namespace tvm::te;
44 
46 using FReduce = std::function<PrimExpr(PrimExpr source, const ffi::Array<IterVar>& axis,
47  ffi::Array<PrimExpr> init, Span span)>;
48 
50 using FCommReduce = std::function<ffi::Array<PrimExpr>(
51  ffi::Array<PrimExpr> exprs, const ffi::Array<IterVar>& axis, PrimExpr* condition)>;
52 
65 inline std::vector<int> GetRealAxis(int ndim, const ffi::Optional<ffi::Array<Integer>>& axis) {
66  std::vector<int> real_axis;
67  if (!axis.has_value()) {
68  for (int i = 0; i < ndim; ++i) {
69  real_axis.push_back(i);
70  }
71  } else {
72  // Use a set so duplicates are removed and the dims are sorted
73  for (auto elem : axis.value()) {
74  int64_t val = elem->value;
75  if (val < 0) {
76  val += ndim;
77  }
78  ICHECK_LT(val, ndim) << " exceeds the maximum dimension " << ndim;
79  ICHECK_GE(val, 0);
80  real_axis.push_back(static_cast<int>(val));
81  }
82  std::sort(real_axis.begin(), real_axis.end());
83  real_axis.resize(std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin());
84  }
85  return real_axis;
86 }
87 
89 inline ffi::Array<IterVar> MakeReduceAxes(const std::vector<int>& real_axis, const Tensor& data) {
90  ffi::Array<IterVar> reduce_axes;
91  for (auto i : real_axis) {
92  std::string name = "k" + std::to_string(i);
93  reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name));
94  }
95  return reduce_axes;
96 }
97 
99 inline ffi::Array<PrimExpr> MakeReduceTargetShape(const std::vector<int>& real_axis,
100  const Tensor& data, bool keepdims,
101  bool atleast1d) {
102  auto ndim = data->shape.size();
103  ffi::Array<PrimExpr> target_shape;
104  if (keepdims) {
105  for (size_t i = 0; i < ndim; ++i) {
106  if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
107  // real_axis contains i
108  target_shape.push_back(1);
109  } else {
110  target_shape.push_back(data->shape[i]);
111  }
112  }
113  } else {
114  for (size_t i = 0; i < ndim; ++i) {
115  if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
116  // real_axis does not contain i
117  target_shape.push_back(data->shape[i]);
118  }
119  }
120  }
121  if (target_shape.size() == 0 && atleast1d) {
122  target_shape.push_back(1);
123  }
124  return target_shape;
125 }
126 
140 inline Tensor DoCommReduce(const Tensor& data, FReduce func,
141  const ffi::Array<PrimExpr>& target_shape,
142  const std::vector<int>& reduce_axes,
143  const std::vector<int>& squeeze_axes, Span span = Span()) {
144  auto r_axes = MakeReduceAxes(reduce_axes, data);
145  auto compute = [&](const ffi::Array<Var>& indices) {
146  ffi::Array<PrimExpr> eval_range;
147  ffi::Array<Var> eval_indices;
148  int arg_counter = 0;
149  int red_counter = 0;
150 
151  for (size_t i = 0; i < data->shape.size(); ++i) {
152  bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end();
153  if (std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end()) {
154  // real_axis contains i
155  eval_range.push_back(r_axes[red_counter]);
156  eval_indices.push_back(r_axes[red_counter]->var);
157  red_counter++;
158  arg_counter += !squeeze_i;
159  continue;
160  }
161  eval_range.push_back(indices[arg_counter]);
162  arg_counter++;
163  }
164 
165  return func(data(eval_range), r_axes, {}, span);
166  };
167 
168  return tvm::te::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
169 }
170 
184 inline Tensor CommReduce(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
185  FReduce func, bool keepdims, bool atleast1d) {
186  auto ndim = data->shape.size();
187  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
188  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
189  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
190  return DoCommReduce(data, func, target_shape, real_axis,
191  keepdims ? std::vector<int>() : real_axis);
192 }
193 
207 inline Tensor CommReduceIdx(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
208  FCommReduce func, bool keepdims, bool atleast1d) {
209  auto ndim = data->shape.size();
210  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
211  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
212  auto reduce_axes = MakeReduceAxes(real_axis, data);
213  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
214 
215  auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func,
216  &data](const ffi::Array<Var>& indices) {
217  ffi::Array<PrimExpr> eval_range;
218  ffi::Array<PrimExpr> eval_indices;
219  int arg_counter = 0;
220  int red_counter = 0;
221 
222  for (size_t i = 0; i < ndim; ++i) {
223  if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
224  // real_axis contains i
225  eval_range.push_back(reduce_axes[red_counter]);
226  eval_indices.push_back(reduce_axes[red_counter]->var);
227  red_counter++;
228  } else {
229  if (!keepdims) {
230  eval_range.push_back(indices[arg_counter]);
231  arg_counter++;
232  } else {
233  eval_range.push_back(indices[i]);
234  }
235  }
236  }
237 
238  ffi::Array<PrimExpr> ravel_shape;
239  for (auto i : real_axis) {
240  ravel_shape.push_back(data->shape[i]);
241  }
242  auto idx = detail::RavelIndex(eval_indices, ravel_shape);
243  return func({idx, data(eval_range)}, reduce_axes, nullptr);
244  };
245 
246  auto temp_idx_val =
247  tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx);
248  auto temp_idx = temp_idx_val[0];
249  auto temp_val = temp_idx_val[1];
250  return tvm::te::compute(
251  target_shape, [&temp_idx](const ffi::Array<Var>& indices) { return temp_idx(indices); },
252  data->op->name + "_red", kCommReduceIdx);
253 }
254 
256 using FCombine = std::function<ffi::Array<PrimExpr>(ffi::Array<Var> lhs, ffi::Array<Var> rhs)>;
257 
259 using FIdentity = std::function<ffi::Array<PrimExpr>(std::vector<DataType> types)>;
260 
270 inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity,
271  std::string name = "reduce") {
272  return [fcombine, fidentity, name](ffi::Array<PrimExpr> exprs, const ffi::Array<IterVar>& axis,
273  PrimExpr* condition) {
274  ffi::Array<Var> lhs, rhs;
275  std::vector<DataType> dtypes;
276 
277  for (size_t i = 0; i < exprs.size(); ++i) {
278  auto dtype = exprs[i].dtype();
279  dtypes.push_back(dtype);
280  lhs.push_back(var(name + "_lhs_" + std::to_string(i), dtype));
281  rhs.push_back(var(name + "_rhs_" + std::to_string(i), dtype));
282  }
283 
284  auto result = fcombine(lhs, rhs);
285  auto id_elem = fidentity(dtypes);
286  auto cond = condition != nullptr ? *condition : tir::const_true();
287 
288  auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem);
289  ffi::Array<PrimExpr> outputs;
290  for (size_t i = 0; i < exprs.size(); ++i) {
291  outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i), {}));
292  }
293  return outputs;
294  };
295 }
296 
298 inline PrimExpr MinOp(PrimExpr source, ffi::Array<IterVar> axis, ffi::Array<PrimExpr> init = {},
299  Span span = Span()) {
300  return tvm::min(source, axis, init, span);
301 }
302 
304 inline PrimExpr MaxOp(PrimExpr source, ffi::Array<IterVar> axis, ffi::Array<PrimExpr> init = {},
305  Span span = Span()) {
306  return tvm::max(source, axis, init, span); // NOLINT(*)
307 }
308 
310 inline PrimExpr ProdOp(PrimExpr source, ffi::Array<IterVar> axis, ffi::Array<PrimExpr> init = {},
311  Span span = Span()) {
312  return tvm::prod(source, axis, init, span); // NOLINT(*)
313 }
314 
328 inline Tensor sum(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
329  bool keepdims = false, bool atleast1d = false) {
330  if (data->dtype.is_bool()) {
331  return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
332  } else {
333  return CommReduce(data, axis, tvm::sum, keepdims, atleast1d);
334  }
335 }
336 
337 inline Tensor collapse_sum(const Tensor& data, ffi::Array<PrimExpr> target_shape) {
338  const auto& ishape = data->shape;
339  const auto& oshape = target_shape;
340  int isize = data->shape.size();
341  int osize = target_shape.size();
342 
343  ICHECK_GE(isize, osize)
344  << "Invalid collapse: input dimensionality smaller than output dimensionality.\ninput shape: "
345  << data->shape << "\nvs\noutput shape: " << target_shape;
346 
347  std::vector<int> reduce_axes;
348  std::vector<int> squeeze_axes;
349  tvm::PrimExpr one(1);
350 
351  for (int i_ax = isize - 1, o_ax = osize - 1; i_ax >= 0; --i_ax) {
352  if (o_ax >= 0 && topi::detail::EqualCheck(ishape[i_ax], oshape[o_ax])) {
353  --o_ax;
354  continue;
355  }
356  reduce_axes.push_back(i_ax);
357  if (o_ax < 0) { // squeeze o_ax if was added during expansion
358  squeeze_axes.push_back(i_ax);
359  } else if (topi::detail::EqualCheck(one, oshape[o_ax])) {
360  --o_ax;
361  }
362  }
363 
364  if (reduce_axes.size() == 0) return topi::identity(data, "tensor", kCommReduce);
365 
366  std::reverse(reduce_axes.begin(), reduce_axes.end());
367  std::reverse(squeeze_axes.begin(), squeeze_axes.end());
368  return DoCommReduce(data, tvm::sum, target_shape, reduce_axes, squeeze_axes);
369 }
370 
385 inline Tensor all(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
386  bool keepdims = false, bool atleast1d = false) {
387  return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
388 }
389 
404 inline Tensor any(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
405  bool keepdims = false, bool atleast1d = false) {
406  return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
407 }
408 
423 inline Tensor min(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
424  bool keepdims = false, bool atleast1d = false) {
425  return CommReduce(data, axis, MinOp, keepdims, atleast1d);
426 }
427 
442 inline Tensor max(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
443  bool keepdims = false, bool atleast1d = false) {
444  return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
445 }
446 
447 inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
448  // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
449  auto fcombine = [=](ffi::Array<Var> lhs, ffi::Array<Var> rhs) {
450  ffi::Array<PrimExpr> result;
451 
452  // Casting to avoid operator ambiguity
453  PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
454  PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
455  PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
456  PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
457 
458  // These variables compare the actual values of the array
459  auto is_smaller = lhs_val < rhs_val;
460  auto is_same = lhs_val == rhs_val;
461 
462  // This checks if the indices are correct for the reduction. E.g. for select_last_index
463  // it gives precedence for later indices of the same element and precedence for sooner
464  // indices if not select_last_index;
465  PrimExpr proper_index;
466  if (select_last_index) {
467  proper_index = lhs_idx > rhs_idx;
468  } else {
469  proper_index = lhs_idx < rhs_idx;
470  }
471 
472  PrimExpr update_index = is_smaller || (is_same && proper_index);
473  result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx
474  result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1])); // val
475  return result;
476  };
477  auto fidentity = [&](std::vector<DataType> types) {
478  ffi::Array<PrimExpr> result;
479  result.push_back(tvm::tir::make_const(types[0], -1)); // idx
480  result.push_back(tvm::max_value(types[1])); // val
481  return result;
482  };
483  return MakeCommReducer(fcombine, fidentity, "argmin");
484 }
485 
502 inline Tensor argmin(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
503  bool keepdims = false, bool atleast1d = false,
504  bool select_last_index = false) {
505  auto reducer = MakeArgminReducer(select_last_index);
506  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
507 }
508 
509 inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
510  // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
511  auto fcombine = [=](ffi::Array<Var> lhs, ffi::Array<Var> rhs) {
512  ffi::Array<PrimExpr> result;
513 
514  // Casting to avoid operator ambiguity
515  PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
516  PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
517  PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
518  PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
519 
520  // These variables compare the actual values of the array
521  auto is_bigger = lhs_val > rhs_val;
522  auto is_same = lhs_val == rhs_val;
523 
524  // This checks if the indices are correct for the reduction. E.g. for select_last_index
525  // it gives precedence for later indices of the same element and precedence for sooner
526  // indices if not select_last_index;
527  PrimExpr proper_index;
528  if (select_last_index) {
529  proper_index = lhs_idx > rhs_idx;
530  } else {
531  proper_index = lhs_idx < rhs_idx;
532  }
533 
534  PrimExpr update_index = is_bigger || (is_same && proper_index);
535  result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx
536  result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val
537  return result;
538  };
539  auto fidentity = [&](std::vector<DataType> types) {
540  ffi::Array<PrimExpr> result;
541  result.push_back(tvm::tir::make_const(types[0], -1)); // idx
542  result.push_back(tvm::min_value(types[1])); // val
543  return result;
544  };
545  return MakeCommReducer(fcombine, fidentity, "argmax");
546 }
547 
563 inline Tensor argmax(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
564  bool keepdims = false, bool atleast1d = false,
565  bool select_last_index = false) {
566  auto reducer = MakeArgmaxReducer(select_last_index);
567  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
568 }
569 
583 inline Tensor prod(const Tensor& data, const ffi::Optional<ffi::Array<Integer>>& axis,
584  bool keepdims = false, bool atleast1d = false) {
585  return CommReduce(data, axis, ProdOp, keepdims, atleast1d);
586 }
587 
592  auto fcombine = [](ffi::Array<Var> lhs, ffi::Array<Var> rhs) {
593  ffi::Array<PrimExpr> result;
594  ICHECK_EQ(lhs.size(), rhs.size());
595  result.reserve(lhs.size());
596  for (size_t i = 0; i < lhs.size(); ++i) {
597  result.push_back(lhs[i] + rhs[i]);
598  }
599  return result;
600  };
601  auto fidentity = [](std::vector<DataType> types) {
602  ffi::Array<PrimExpr> result;
603  for (size_t i = 0; i < types.size(); ++i) {
604  result.push_back(tvm::tir::make_const(types[i], 0));
605  }
606  return result;
607  };
608  return MakeCommReducer(fcombine, fidentity, "tuple_sum");
609 }
610 
611 } // namespace topi
612 } // namespace tvm
613 #endif // TVM_TOPI_REDUCTION_H_
Broadcast op constructions.
Reference to PrimExprNode.
Definition: expr.h:124
Range container
Definition: expr.h:689
Definition: source_map.h:111
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:53
Managed reference to CommReducerNode.
Definition: expr.h:832
Managed reference to ReduceNode.
Definition: expr.h:876
Managed reference to SelectNode.
Definition: expr.h:515
Utility functions for handling constants in TVM expressions.
Elementwise op constructions.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:994
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:818
Tensor DoCommReduce(const Tensor &data, FReduce func, const ffi::Array< PrimExpr > &target_shape, const std::vector< int > &reduce_axes, const std::vector< int > &squeeze_axes, Span span=Span())
Create a reduction operation.
Definition: reduction.h:140
Tensor collapse_sum(const Tensor &data, ffi::Array< PrimExpr > target_shape)
Definition: reduction.h:337
std::vector< int > GetRealAxis(int ndim, const ffi::Optional< ffi::Array< Integer >> &axis)
Convert a reduction axis which could be empty or have negative elements into a real axis with valid d...
Definition: reduction.h:65
FCommReduce MakeTupleSumReducer()
Create commutative reducer summing over tuples.
Definition: reduction.h:591
FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, std::string name="reduce")
Create a commutative reducer for a reduction.
Definition: reduction.h:270
ffi::Array< PrimExpr > MakeReduceTargetShape(const std::vector< int > &real_axis, const Tensor &data, bool keepdims, bool atleast1d)
Calculate the target shape for a reduce op.
Definition: reduction.h:99
std::function< ffi::Array< PrimExpr >(ffi::Array< PrimExpr > exprs, const ffi::Array< IterVar > &axis, PrimExpr *condition)> FCommReduce
The operation to use for CommReduceIdx.
Definition: reduction.h:51
Tensor max(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the maximum of elements over a given axis.
Definition: reduction.h:442
ffi::Array< IterVar > MakeReduceAxes(const std::vector< int > &real_axis, const Tensor &data)
Enumerate the axes for a reduce op.
Definition: reduction.h:89
FCommReduce MakeArgmaxReducer(bool select_last_index=false)
Definition: reduction.h:509
PrimExpr MaxOp(PrimExpr source, ffi::Array< IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::max to ensure we get the correct overload.
Definition: reduction.h:304
std::function< PrimExpr(PrimExpr source, const ffi::Array< IterVar > &axis, ffi::Array< PrimExpr > init, Span span)> FReduce
The operation to use for CommReduce.
Definition: reduction.h:47
constexpr auto kCommReduce
Definition: tags.h:34
Tensor argmin(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false, bool select_last_index=false)
Creates an operation that finds the indices of the minimum values over a given axis.
Definition: reduction.h:502
Tensor CommReduceIdx(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, FCommReduce func, bool keepdims, bool atleast1d)
Create an index reduction operation.
Definition: reduction.h:207
constexpr auto kCommReduceIdx
Definition: tags.h:35
PrimExpr MinOp(PrimExpr source, ffi::Array< IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::min to ensure we get the correct overload.
Definition: reduction.h:298
FCommReduce MakeArgminReducer(bool select_last_index=false)
Definition: reduction.h:447
std::function< ffi::Array< PrimExpr >(ffi::Array< Var > lhs, ffi::Array< Var > rhs)> FCombine
A combiner function for a reduction.
Definition: reduction.h:256
Tensor identity(const Tensor &x, std::string name="T_identity", std::string tag=kElementWise)
Creates an operation that returns identity of a given tensor.
Definition: elemwise.h:152
PrimExpr ProdOp(PrimExpr source, ffi::Array< IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::prod to ensure we get the correct overload.
Definition: reduction.h:310
Tensor any(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that computes the logical OR of elements over a given axis.
Definition: reduction.h:404
Tensor prod(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates product operation over given axis.
Definition: reduction.h:583
Tensor min(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the minimum of elements over a given axis.
Definition: reduction.h:423
Tensor sum(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:328
Tensor CommReduce(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, FReduce func, bool keepdims, bool atleast1d)
Create a reduction operation.
Definition: reduction.h:184
Tensor argmax(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false, bool select_last_index=false)
Creates an operation that finds the indices of the maximum values over a given axis.
Definition: reduction.h:563
std::function< ffi::Array< PrimExpr >(std::vector< DataType > types)> FIdentity
An initializer function for a reduction.
Definition: reduction.h:259
Tensor all(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that computes the logical AND of elements over a given axis.
Definition: reduction.h:385
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr any(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
logical Or of source expression over axis
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr max_value(const DataType &dtype, Span span=Span())
PrimExpr sum(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr prod(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
product of source expression over axis
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
PrimExpr all(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
logical And of source expression over axis
Operation node can generate one or multiple Tensors.
Index ravel and unravel operations.
External function interface to rocBLAS libraries.
Transform op constructors.