24 #ifndef TVM_TOPI_NN_POOLING_H_
25 #define TVM_TOPI_NN_POOLING_H_
50 const ffi::Array<PrimExpr>& kernel_size,
51 const ffi::Array<PrimExpr>& stride_size,
52 const ffi::Array<PrimExpr>& padding_size,
PoolType pool_type,
53 bool ceil_mode,
const size_t height_axis,
const size_t width_axis,
54 bool count_include_pad) {
55 TVM_FFI_ICHECK(out_grad->shape.size() >= 2) <<
"Pooling grad output must >= 2-D (H, W)";
56 TVM_FFI_ICHECK(x->shape.size() >= 2) <<
"Pooling input must >= 2-D (H, W)";
57 TVM_FFI_ICHECK_EQ(kernel_size.size(), 2) <<
"Pooling kernel_size must have 2 elements";
58 TVM_FFI_ICHECK_EQ(stride_size.size(), 2) <<
"Pooling stride_size must have 2 elements";
59 TVM_FFI_ICHECK_EQ(padding_size.size(), 4) <<
"Pooling padding_size must have 4 elements";
61 auto kernel_height = kernel_size[0];
62 auto kernel_width = kernel_size[1];
63 auto stride_height = stride_size[0];
64 auto stride_width = stride_size[1];
66 auto height = x->shape[height_axis];
67 auto width = x->shape[width_axis];
69 auto pad_top = padding_size[0];
70 auto pad_left = padding_size[1];
71 auto pad_bottom = padding_size[2];
72 auto pad_right = padding_size[3];
77 pad_bottom += stride_height - 1;
78 pad_right += stride_width - 1;
81 ffi::Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
82 pad_before.Set(height_axis, pad_top);
83 pad_before.Set(width_axis, pad_left);
85 ffi::Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
86 pad_after.Set(height_axis, pad_bottom);
87 pad_after.Set(width_axis, pad_right);
90 analyzer.
Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
92 analyzer.
Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
97 ffi::Array<PrimExpr> data_shape = x->shape;
98 ffi::Array<PrimExpr> out_shape = data_shape;
99 out_shape.Set(height_axis, out_height);
100 out_shape.Set(width_axis, out_width);
106 const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
107 ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
110 ffi::Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
111 ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
112 ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
120 auto pad_x = do_pad ?
pad(x, pad_before, pad_after,
tvm::min_value(x->dtype),
"pad_temp") : x;
124 [&](
const ffi::Array<Var>& inds) {
125 ffi::Array<PrimExpr> window_inds{inds.begin(), inds.end()};
126 window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
127 window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
128 auto idx = detail::RavelIndex(window_inds, ravel_shape);
129 return argmax({idx, pad_x(window_inds)}, {dheight, dwidth},
nullptr);
133 auto mp_inds = mp_argmax[0];
137 [&](
const ffi::Array<Var>& inds) {
138 ffi::Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
139 pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
140 pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
141 auto idx = detail::RavelIndex(pad_inds, ravel_shape);
143 ffi::Array<PrimExpr> out_idx{inds.begin(), inds.end()};
144 out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
145 out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
148 pad_inds[height_axis] < kernel_height,
make_const(pad_inds[height_axis].dtype(), 0),
149 (pad_inds[height_axis] - kernel_height) / stride_height + 1);
151 pad_inds[width_axis] < kernel_width,
make_const(pad_inds[width_axis].dtype(), 0),
152 (pad_inds[width_axis] - kernel_width) / stride_width + 1);
156 out_idx[width_axis] >= out_idx_lower_w),
157 mp_inds(out_idx) == idx),
161 "T_pool_grad",
"pool_grad_max");
169 [&](
const ffi::Array<Var>& inds) {
170 PrimExpr pad_h_idx = inds[height_axis] + pad_top;
171 PrimExpr pad_w_idx = inds[width_axis] + pad_left;
174 ffi::Array<PrimExpr> out_idx{inds.begin(), inds.end()};
175 out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
176 out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
180 (pad_h_idx - kernel_height) / stride_height + 1);
183 (pad_w_idx - kernel_width) / stride_width + 1);
186 if (count_include_pad) {
187 divide_factor = kernel_height * kernel_width;
189 PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
190 PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
192 PrimExpr h_end =
min(h_start + kernel_height, height);
193 PrimExpr w_end =
min(w_start + kernel_width, width);
201 out_idx[height_axis] < out_height),
202 tirx::And(out_idx[width_axis] >= out_idx_lower_w,
203 out_idx[width_axis] < out_width)),
204 out_grad(out_idx) / divide_factor,
make_const(out_grad->dtype, 0)),
207 "T_pool_grad",
"pool_grad_avg");
209 LOG(ERROR) <<
"Unrecognized pool_type: " << pool_type;
227 if (depth_axis) *depth_axis = -1;
228 if (height_axis) *height_axis = -1;
229 if (width_axis) *width_axis = -1;
231 for (
size_t i = 0; i < layout.size(); ++i) {
232 if ((layout[i] >=
'A' && layout[i] <=
'Z') || (layout[i] >=
'a' && layout[i] <=
'z')) {
233 if (layout[i] ==
'D' && depth_axis) {
234 if (*depth_axis != -1)
return false;
235 *depth_axis = curr_idx;
236 }
else if (layout[i] ==
'H' && height_axis) {
237 if (*height_axis != -1)
return false;
238 *height_axis = curr_idx;
239 }
else if (layout[i] ==
'W' && width_axis) {
240 if (*width_axis != -1)
return false;
241 *width_axis = curr_idx;
242 }
else if (layout[i] ==
'd' || layout[i] ==
'h' || layout[i] ==
'w') {
249 if ((depth_axis && *depth_axis == -1) || (height_axis && *height_axis == -1) ||
250 (width_axis && *width_axis == -1))
259 inline bool find_width(
const std::string& layout,
int* width_axis) {
295 const ffi::Array<PrimExpr>& kernel_size,
296 const ffi::Array<PrimExpr>& stride_size,
297 const ffi::Array<PrimExpr>& padding_size,
PoolType pool_type,
298 bool ceil_mode,
const std::string& layout =
"NCHW",
299 bool count_include_pad =
true) {
300 int height_axis = -1, width_axis = -1;
302 <<
"Unsupported layout " << layout;
303 return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
304 height_axis, width_axis, count_include_pad);
308 return indexdiv(out_index * idim, odim);
327 PoolType pool_type,
const std::vector<int>& axes) {
328 const auto n_dim = output_size.size();
329 TVM_FFI_ICHECK_EQ(axes.size(), n_dim) <<
"The number of axes not equal to the in/out dimension";
331 ffi::Array<PrimExpr> data_shape = x->shape;
332 ffi::Array<PrimExpr> out_shape = data_shape;
333 ffi::Array<PrimExpr> in_size, out_size;
334 for (
size_t i = 0; i < n_dim; ++i) {
335 in_size.push_back(data_shape[axes[i]]);
336 out_size.push_back(output_size[i]);
337 out_shape.Set(axes[i], out_size[i]);
340 auto get_iter_vars = [=](
const ffi::Array<Var>& output,
bool reduce_indices) {
341 ffi::Array<PrimExpr> indices;
342 for (
size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
343 ffi::Array<tirx::IterVar> reduce_axes;
344 for (
size_t i = 0; i < n_dim; ++i) {
345 auto i_start =
start_index(output[axes[i]], out_size[i], in_size[i]);
346 auto i_end =
end_index(output[axes[i]], out_size[i], in_size[i]);
347 auto rv_name =
"rv" + std::to_string(i);
349 reduce_axes.push_back(rv_axis);
350 if (reduce_indices) {
351 indices.Set(axes[i], i_start + rv_axis);
354 return std::make_tuple(indices, reduce_axes);
357 ffi::Map<ffi::String, ffi::Any> attrs;
359 attrs.Set(
"schedule_rule", tvm::ffi::String(
"meta_schedule.adaptive_pool_max"));
362 [&](
const ffi::Array<Var>& output) {
363 ffi::Array<PrimExpr> indices;
364 ffi::Array<tirx::IterVar> reduce_axes;
365 std::tie(indices, reduce_axes) = get_iter_vars(output,
true);
366 return tvm::max(x(indices), reduce_axes);
368 "adaptive_pool_max",
"adaptive_pool_max", attrs);
370 attrs.Set(
"schedule_rule", tvm::ffi::String(
"meta_schedule.adaptive_pool_avg"));
373 [&](
const ffi::Array<Var>& output) {
374 ffi::Array<PrimExpr> indices;
375 ffi::Array<tirx::IterVar> reduce_axes;
376 std::tie(indices, reduce_axes) = get_iter_vars(output,
true);
377 return tvm::sum(x(indices), reduce_axes);
379 "adaptive_pool_sum",
"adaptive_pool_sum");
383 [&](
const ffi::Array<Var>& output) {
384 ffi::Array<PrimExpr> indices;
385 ffi::Array<tirx::IterVar> reduce_axes;
386 std::tie(indices, reduce_axes) = get_iter_vars(output,
false);
389 for (
size_t i = 0; i < n_dim; ++i) {
393 return div(pool_sum(indices), divide_factor);
397 LOG(ERROR) <<
"Unrecognized pool_type: " << pool_type;
429 PoolType pool_type,
const std::string& layout =
"NCHW") {
430 int height_axis = -1, width_axis = -1;
432 <<
"Unsupported layout " << layout;
445 PoolType pool_type,
const std::string& layout =
"NCDHW") {
446 int depth_axis = -1, height_axis = -1, width_axis = -1;
448 <<
"Unsupported layout " << layout;
449 return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
461 PoolType pool_type,
const std::string& layout =
"NCW") {
463 TVM_FFI_ICHECK(
find_width(layout, &width_axis)) <<
"Unsupported layout " << layout;
493 return adaptive_pool(x, ffi::Array<PrimExpr>{1, 1}, pool_type, layout);
513 const ffi::Array<PrimExpr>& stride_size,
514 const ffi::Array<PrimExpr>& dilation_size,
515 const ffi::Array<PrimExpr>& padding_size,
PoolType pool_type,
516 bool ceil_mode,
const std::vector<int>& axis,
bool count_include_pad) {
517 int k_size = kernel_size.size();
518 int x_size = x->shape.size();
519 TVM_FFI_ICHECK_EQ(stride_size.size(), k_size)
520 <<
"Pooling stride_size must have same elements as kernel";
521 TVM_FFI_ICHECK_EQ(padding_size.size(), k_size * 2)
522 <<
"Pooling padding_size must has double elements of"
524 TVM_FFI_ICHECK_EQ(axis.size(), k_size) <<
"axis must have same elements as kernel";
526 ffi::Array<IterVar> daxis;
527 std::vector<PrimExpr> kernel(k_size);
528 std::vector<PrimExpr> stride(k_size);
529 std::vector<PrimExpr> dilation(k_size);
530 std::vector<PrimExpr> pad_head(k_size);
531 std::vector<PrimExpr> pad_tail(k_size);
532 std::vector<PrimExpr> offset(k_size, 0);
533 ffi::Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
534 ffi::Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
535 ffi::Array<PrimExpr> data_shape = x->shape;
536 ffi::Array<PrimExpr> out_shape = data_shape;
539 for (
int i = 0; i < k_size; i++) {
541 kernel[i] = kernel_size[i];
542 stride[i] = stride_size[i];
543 dilation[i] = dilation_size[i];
544 pad_head[i] = padding_size[i];
545 pad_tail[i] = padding_size[i + k_size];
553 offset[i] = stride[i] - 1;
554 pad_tail[i] += offset[i];
559 do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);
563 pad_before.Set(ii, pad_head[i]);
564 pad_after.Set(ii, pad_tail[i]);
569 data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
570 auto raw_out =
indexdiv(numerator, stride[i]) + 1;
575 auto invalid_last = (raw_out - 1) * stride[i] >= data_shape[ii] + pad_head[i];
577 out_shape.Set(ii, out_dim);
579 auto out_dim = analyzer.
Simplify(raw_out);
580 out_shape.Set(ii, out_dim);
584 ffi::Map<ffi::String, ffi::Any> attrs;
586 auto temp = do_pad ?
pad(x, pad_before, pad_after,
tvm::min_value(x->dtype),
"pad_temp") : x;
587 attrs.Set(
"schedule_rule", tvm::ffi::String(
"meta_schedule.pool_max"));
590 [&](
const ffi::Array<Var>& output) {
591 ffi::Array<PrimExpr> indices;
592 for (
const Var&
var : output) indices.push_back(
var);
594 for (
int i = 0; i < k_size; i++) {
596 indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
598 return tvm::max(temp(indices), daxis);
600 "pool_max",
"pool_max", attrs);
602 attrs.Set(
"schedule_rule", tvm::ffi::String(
"meta_schedule.pool_avg"));
604 auto temp = do_pad ?
pad(x, pad_before, pad_after, 0,
"pad_temp") : x;
609 [&](
const ffi::Array<Var>& output) {
610 ffi::Array<PrimExpr> indices;
611 for (
const Var&
var : output) indices.push_back(
var);
613 for (
int i = 0; i < k_size; i++) {
615 indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
617 return tvm::sum(temp(indices), daxis);
619 "pool_sum",
"pool_sum");
624 [&](
const ffi::Array<Var>& output) {
625 ffi::Array<PrimExpr> indices;
626 for (
const Var&
var : output) indices.push_back(
var);
627 if (count_include_pad) {
628 std::vector<PrimExpr> start(k_size);
629 std::vector<PrimExpr> end(k_size);
631 for (
int i = 0; i < k_size; i++) {
633 start[i] = output[ii] * stride[i] - pad_head[i];
638 end[i] = start[i] + (kernel[i] - 1) * dilation[i];
639 end[i] =
min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
640 num_el *= (end[i] - start[i]) / dilation[i] + 1;
642 return div(pool_sum(indices), num_el);
644 std::vector<PrimExpr> start(k_size);
645 std::vector<PrimExpr> end(k_size);
647 for (
int i = 0; i < k_size; i++) {
654 start[i] = output[ii] * stride[i] - pad_head[i];
655 end[i] = start[i] + (kernel[i] - 1) * dilation[i];
660 PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
661 jumps_to_non_pad =
max(jumps_to_non_pad,
make_const(jumps_to_non_pad.dtype(), 0));
663 end[i] =
min(end[i], data_shape[ii] - 1);
664 num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
668 return div(pool_sum(indices), divide_factor);
673 LOG(ERROR) <<
"Unrecognized pool_type: " << pool_type;
709 const ffi::Array<PrimExpr>& stride_size,
710 const ffi::Array<PrimExpr>& dilation_size,
711 const ffi::Array<PrimExpr>& padding_size,
PoolType pool_type,
bool ceil_mode,
712 const std::string& layout =
"NCW",
bool count_include_pad =
true) {
714 TVM_FFI_ICHECK(
find_width(layout, &width_axis)) <<
"Unsupported layout " << layout;
715 std::vector<int> axis = {width_axis};
716 return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
717 ceil_mode, axis, count_include_pad);
751 const ffi::Array<PrimExpr>& stride_size,
752 const ffi::Array<PrimExpr>& dilation_size,
753 const ffi::Array<PrimExpr>& padding_size,
PoolType pool_type,
bool ceil_mode,
754 const std::string& layout =
"NCHW",
bool count_include_pad =
true) {
755 int height_axis = -1, width_axis = -1;
757 <<
"Unsupported layout " << layout;
758 std::vector<int> axis = {height_axis, width_axis};
759 return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
760 ceil_mode, axis, count_include_pad);
795 const ffi::Array<PrimExpr>& stride_size,
796 const ffi::Array<PrimExpr>& dilation_size,
797 const ffi::Array<PrimExpr>& padding_size,
PoolType pool_type,
bool ceil_mode,
798 const std::string& layout =
"NCDHW",
bool count_include_pad =
true) {
799 int depth_axis = -1, height_axis = -1, width_axis = -1;
801 <<
"Unsupported layout " << layout;
802 std::vector<int> axis = {depth_axis, height_axis, width_axis};
803 return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
804 ceil_mode, axis, count_include_pad);
Algebraic expression simplifications.
Reference to PrimExprNode.
Definition: expr.h:126
DataType dtype() const
Definition: expr.h:140
Range container
Definition: expr.h:690
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:634
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Managed reference to AndNode.
Definition: expr.h:427
Managed reference to SelectNode.
Definition: expr.h:514
a named variable in TIR
Definition: var.h:76
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1007
const int64_t * as_const_int(const PrimExpr &x)
Get x as constant int expression.
Definition: op.h:848
Tensor adaptive_pool3d(const Tensor &x, const ffi::Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCDHW")
Adaptively perform pooling on three dimensional data. See the two dimensional version above for details.
Definition: pooling.h:444
Tensor adaptive_pool(const Tensor &x, const ffi::Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCHW")
Adaptively perform pooling on the height and width dimensions of data. The pooling kernel and stride sizes are automatically chosen for the desired output sizes.
Definition: pooling.h:428
PoolType
Pooling type.
Definition: pooling.h:44
@ kAvgPool
Definition: pooling.h:45
@ kMaxPool
Definition: pooling.h:46
Tensor adaptive_pool1d(const Tensor &x, const ffi::Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCW")
Adaptively perform pooling on one dimensional data. See the two dimensional version above for details...
Definition: pooling.h:460
Tensor pool3d(const Tensor &x, const ffi::Array< PrimExpr > &kernel_size, const ffi::Array< PrimExpr > &stride_size, const ffi::Array< PrimExpr > &dilation_size, const ffi::Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCDHW", bool count_include_pad=true)
Perform pooling on depth, height and width dimension of data. It decides the depth,...
Definition: pooling.h:794
Tensor pool2d(const Tensor &x, const ffi::Array< PrimExpr > &kernel_size, const ffi::Array< PrimExpr > &stride_size, const ffi::Array< PrimExpr > &dilation_size, const ffi::Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCHW", bool count_include_pad=true)
Perform pooling on height and width dimension of data. It decides the height and width dimension acco...
Definition: pooling.h:750
Tensor pool_grad_impl(const Tensor &out_grad, const Tensor &x, const ffi::Array< PrimExpr > &kernel_size, const ffi::Array< PrimExpr > &stride_size, const ffi::Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad)
Definition: pooling.h:49
PrimExpr start_index(const Var &out_index, const PrimExpr &odim, const PrimExpr &idim)
Definition: pooling.h:307
PrimExpr end_index(const Var &out_index, const PrimExpr &odim, const PrimExpr &idim)
Definition: pooling.h:311
Tensor pool_grad(const Tensor &out_grad, const Tensor &x, const ffi::Array< PrimExpr > &kernel_size, const ffi::Array< PrimExpr > &stride_size, const ffi::Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCHW", bool count_include_pad=true)
Calculate gradient of pooling on height and width dimension of data. It decides the height and width ...
Definition: pooling.h:294
bool find_depth_height_width(const std::string &layout, int *depth_axis, int *height_axis, int *width_axis)
Find index of Depth, Height or Width dimension in a layout string.
Definition: pooling.h:225
bool find_width(const std::string &layout, int *width_axis)
Definition: pooling.h:259
Tensor pool_impl_nd(const Tensor &x, const ffi::Array< PrimExpr > &kernel_size, const ffi::Array< PrimExpr > &stride_size, const ffi::Array< PrimExpr > &dilation_size, const ffi::Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::vector< int > &axis, bool count_include_pad)
Perform pooling on N dimensions of data.
Definition: pooling.h:512
bool find_height_width(const std::string &layout, int *height_axis, int *width_axis)
Definition: pooling.h:255
Tensor global_pool(const Tensor &x, PoolType pool_type, const std::string &layout="NCHW")
Perform global pooling on height and width dimension of data. It decides the height and width dimensi...
Definition: pooling.h:492
Tensor adaptive_pool_impl(const Tensor &x, const ffi::Array< PrimExpr > &output_size, PoolType pool_type, const std::vector< int > &axes)
Perform adaptive pooling on N dimensional data.
Definition: pooling.h:326
Tensor pool1d(const Tensor &x, const ffi::Array< PrimExpr > &kernel_size, const ffi::Array< PrimExpr > &stride_size, const ffi::Array< PrimExpr > &dilation_size, const ffi::Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCW", bool count_include_pad=true)
Perform pooling on the width dimension of data. Width axis is determined by the layout string in whic...
Definition: pooling.h:708
constexpr auto kElementWise
Definition: tags.h:32
Tensor max(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the maximum of elements over a given axis.
Definition: reduction.h:442
FCommReduce MakeArgmaxReducer(bool select_last_index=false)
Definition: reduction.h:509
tvm::te::Tensor pad(const tvm::te::Tensor &t, const tvm::ffi::Array< tvm::PrimExpr > &pad_before, tvm::ffi::Array< tvm::PrimExpr > pad_after=tvm::ffi::Array< tvm::PrimExpr >(), PrimExpr pad_value=PrimExpr(), std::string name="T_pad", std::string tag=kElementWise, std::string pad_mode="constant", const ffi::Array< PrimExpr > *dyn_output_shape=nullptr)
Creates an operation that performs padding.
Definition: nn.h:156
constexpr auto kCommReduceIdx
Definition: tags.h:35
Tensor min(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the minimum of elements over a given axis.
Definition: reduction.h:423
Tensor argmax(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false, bool select_last_index=false)
Creates an operation that finds the indices of the maximum values over a given axis.
Definition: reduction.h:563
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr min_value(const DataType &dtype, Span span=Span())
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr sum(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floor division, i.e. a - floor(a / b) * b, where a and b are non-negative.
Reduction op constructors.