#ifndef TVM_TOPI_NN_POOLING_H_
#define TVM_TOPI_NN_POOLING_H_

// ... (includes and namespace openings elided)

inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
                             const Array<PrimExpr>& kernel_size,
                             const Array<PrimExpr>& stride_size,
                             const Array<PrimExpr>& padding_size, PoolType pool_type,
                             bool ceil_mode, const size_t height_axis, const size_t width_axis,
                             bool count_include_pad) {
  ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must be >= 2-D (H, W)";
  ICHECK(x->shape.size() >= 2) << "Pooling input must be >= 2-D (H, W)";
  ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
  ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
  ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
  auto kernel_height = cast(DataType::Int(32), kernel_size[0]);
  auto kernel_width = cast(DataType::Int(32), kernel_size[1]);
  auto stride_height = cast(DataType::Int(32), stride_size[0]);
  auto stride_width = cast(DataType::Int(32), stride_size[1]);

  auto height = cast(DataType::Int(32), x->shape[height_axis]);
  auto width = cast(DataType::Int(32), x->shape[width_axis]);

  auto pad_top = cast(DataType::Int(32), padding_size[0]);
  auto pad_left = cast(DataType::Int(32), padding_size[1]);
  auto pad_bottom = cast(DataType::Int(32), padding_size[2]);
  auto pad_right = cast(DataType::Int(32), padding_size[3]);
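
  // In ceil mode, widening the bottom/right padding by (stride - 1) makes the
  // floor division in the output-size computation below round up instead.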
  if (ceil_mode) {
    pad_bottom += stride_height - 1;
    pad_right += stride_width - 1;
  }
  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_before.Set(height_axis, pad_top);
  pad_before.Set(width_axis, pad_left);

  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_after.Set(height_axis, pad_bottom);
  pad_after.Set(width_axis, pad_right);
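
  // Sliding-window output extent (floor division):
  //   out = (in - kernel + pad_begin + pad_end) / stride + 1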
  arith::Analyzer analyzer;
  auto out_height =
      analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
  auto out_width =
      analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));

  Array<PrimExpr> data_shape = x->shape;
  for (size_t i = 0; i < data_shape.size(); ++i) {
    data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
  }

  Array<PrimExpr> out_shape = data_shape;
  out_shape.Set(height_axis, out_height);
  out_shape.Set(width_axis, out_width);
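
  // Padding is materialized only when some pad amount is a nonzero
  // compile-time constant; as_const_int returns nullptr for symbolic pads.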
  const int64_t* padding_h0 = as_const_int(pad_top);
  const int64_t* padding_w0 = as_const_int(pad_left);
  const int64_t* padding_h1 = as_const_int(pad_bottom);
  const int64_t* padding_w1 = as_const_int(pad_right);
  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
                      ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

  if (pool_type == kMaxPool) {
    Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
    ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
    ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));

    auto argmax = MakeArgmaxReducer();
    auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
    auto mp_argmax = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& inds) {
          Array<PrimExpr> window_inds{inds.begin(), inds.end()};
          window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
          window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
          auto idx = detail::RavelIndex(window_inds, ravel_shape);
          return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
        },
        "maxpool_grad_argmax", kCommReduceIdx);

    auto mp_inds = mp_argmax[0];
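
    // Second pass: an input element receives out_grad wherever the recorded
    // argmax index of an overlapping window equals its own raveled index.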
    return tvm::te::compute(
        data_shape,
        [&](const Array<Var>& inds) {
          Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
          pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
          pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
          auto idx = detail::RavelIndex(pad_inds, ravel_shape);

          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
          out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);

          PrimExpr out_idx_lower_h =
              tir::Select(pad_inds[height_axis] < kernel_height, make_const(DataType::Int(32), 0),
                          (pad_inds[height_axis] - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_inds[width_axis] < kernel_width, make_const(DataType::Int(32), 0),
                          (pad_inds[width_axis] - kernel_width) / stride_width + 1);

          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[width_axis] >= out_idx_lower_w),
                                         mp_inds(out_idx) == idx),
                                out_grad(out_idx), make_const(x->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_max");
  } else if (pool_type == kAvgPool) {
    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
    return tvm::te::compute(
        data_shape,
        [&](const Array<Var>& inds) {
          PrimExpr pad_h_idx = inds[height_axis] + pad_top;
          PrimExpr pad_w_idx = inds[width_axis] + pad_left;
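          // Each output window covering this input element contributes
          // out_grad / divide_factor; the windowh/windoww reduction below
          // sums those contributions.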

          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, pad_h_idx / stride_height - windowh);
          out_idx.Set(width_axis, pad_w_idx / stride_width - windoww);

          PrimExpr out_idx_lower_h =
              tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
                          (pad_h_idx - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
                          (pad_w_idx - kernel_width) / stride_width + 1);

          PrimExpr divide_factor;  // number of pooled elements
          if (count_include_pad) {
            divide_factor = kernel_height * kernel_width;
          } else {
            PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
            PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;

            PrimExpr h_end = min(h_start + kernel_height, height);
            PrimExpr w_end = min(w_start + kernel_width, width);
            h_start = max(h_start, make_const(DataType::Int(32), 0));
            w_start = max(w_start, make_const(DataType::Int(32), 0));
            divide_factor =
                max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
          }
          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[height_axis] < out_height),
                                         tir::And(out_idx[width_axis] >= out_idx_lower_w,
                                                  out_idx[width_axis] < out_width)),
                                out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
              {windowh, windoww});
210 "T_pool_grad",
"pool_grad_avg");
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return Tensor();
  }
}

inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
                                    int* width_axis) {
  if (depth_axis) *depth_axis = -1;
  if (height_axis) *height_axis = -1;
  if (width_axis) *width_axis = -1;
  int curr_idx = 0;
  for (size_t i = 0; i < layout.size(); ++i) {
    if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) {
      if (layout[i] == 'D' && depth_axis) {
        if (*depth_axis != -1) return false;
        *depth_axis = curr_idx;
      } else if (layout[i] == 'H' && height_axis) {
        if (*height_axis != -1) return false;
        *height_axis = curr_idx;
      } else if (layout[i] == 'W' && width_axis) {
        if (*width_axis != -1) return false;
        *width_axis = curr_idx;
      } else if (layout[i] == 'd' || layout[i] == 'h' || layout[i] == 'w') {
        // Lowercase d/h/w mark split sub-dimensions, which pooling does not support.
        return false;
      }
      ++curr_idx;
    }
  }
  if ((depth_axis && *depth_axis == -1) || (height_axis && *height_axis == -1) ||
      (width_axis && *width_axis == -1))
    return false;
  return true;
}
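
// For example, find_depth_height_width("NCHW", nullptr, &h, &w) sets h = 2 and
// w = 3, while a split layout such as "NCHW16w" is rejected via the lowercase
// 'w' branch above.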

inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
  return find_depth_height_width(layout, nullptr, height_axis, width_axis);
}

inline bool find_width(const std::string& layout, int* width_axis) {
  return find_depth_height_width(layout, nullptr, nullptr, width_axis);
}
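
/*!
 * \brief Calculate the gradient of pooling on the height and width dimensions
 *        of data. The height and width axes are inferred from the layout
 *        string (e.g. "NCHW").
 */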
inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
                        const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
                        PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
                        bool count_include_pad = true) {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
                        height_axis, width_axis, count_include_pad);
}
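
// A minimal usage sketch (hypothetical names `data`/`grad`; assumes a float32
// NCHW input with a 2x2, stride-2 max pool, so the pooled shape is 16x16):
//
//   te::Tensor data = te::placeholder({1, 3, 32, 32}, DataType::Float(32), "data");
//   te::Tensor grad = te::placeholder({1, 3, 16, 16}, DataType::Float(32), "grad");
//   te::Tensor dx = pool_grad(grad, data, {2, 2}, {2, 2}, {0, 0, 0, 0},
//                             PoolType::kMaxPool, /*ceil_mode=*/false);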

inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  return indexdiv(out_index * idim, odim);
}

inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  return indexdiv((out_index + 1) * idim + odim - 1, odim);
}
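
// Adaptive window i along an axis spans [floor(i * idim / odim),
// ceil((i + 1) * idim / odim)); e.g. idim = 10, odim = 3 yields the
// overlapping windows [0, 4), [3, 7), [6, 10).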

inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
                                 PoolType pool_type, const std::vector<int>& axes) {
  const auto n_dim = output_size.size();
  ICHECK_EQ(axes.size(), n_dim) << "The number of axes must match the output dimension";

  Array<PrimExpr> data_shape = x->shape;
  for (size_t i = 0; i < data_shape.size(); ++i) {
    data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
  }
  Array<PrimExpr> out_shape = data_shape;
  Array<PrimExpr> in_size, out_size;
  for (size_t i = 0; i < n_dim; ++i) {
    in_size.push_back(data_shape[axes[i]]);
    out_size.push_back(cast(DataType::Int(32), output_size[i]));
    out_shape.Set(axes[i], out_size[i]);
  }

  auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
    Array<PrimExpr> indices;
    for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
    Array<tir::IterVar> reduce_axes;
    for (size_t i = 0; i < n_dim; ++i) {
      auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
      auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
      auto rv_name = "rv" + std::to_string(i);
      auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
      reduce_axes.push_back(rv_axis);
      if (reduce_indices) {
        indices.Set(axes[i], i_start + rv_axis);
      }
    }
    return std::make_tuple(indices, reduce_axes);
  };
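
  // get_iter_vars yields the read indices plus one reduce axis per pooled
  // dimension; with reduce_indices = false the indices are left untouched so
  // the reduce-axis extents can be reused solely as window sizes.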

  Map<String, ObjectRef> attrs;
  // ...
  if (pool_type == kMaxPool) {
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::max(x(indices), reduce_axes);
        },
        "adaptive_pool_max", "adaptive_pool_max", attrs);
  } else if (pool_type == kAvgPool) {
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::sum(x(indices), reduce_axes);
        },
        "adaptive_pool_sum", "adaptive_pool_sum");
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, false);

          PrimExpr divide_factor = tvm::cast(x->dtype, 1);
          for (size_t i = 0; i < n_dim; ++i) {
            divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
          }
          return div(pool_sum(indices), divide_factor);
        },
        "adaptive_pool_avg", kElementWise, attrs);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}
inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size,
                            PoolType pool_type, const std::string& layout = "NCHW") {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
}

inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCDHW") {
  int depth_axis = -1, height_axis = -1, width_axis = -1;
  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
      << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
}

inline Tensor adaptive_pool1d(const Tensor& x, const Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCW") {
  int width_axis = -1;
  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
}

// ...
inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
                           const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                           const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                           const std::vector<int>& axis, bool count_include_pad) {
  int k_size = kernel_size.size();
  int x_size = x->shape.size();
  ICHECK_EQ(stride_size.size(), k_size)
      << "Pooling stride_size must have the same number of elements as kernel";
  ICHECK_EQ(padding_size.size(), k_size * 2)
      << "Pooling padding_size must have twice as many elements as kernel";
  ICHECK_EQ(axis.size(), k_size) << "axis must have the same number of elements as kernel";

  Array<tir::IterVar> daxis;
  std::vector<PrimExpr> kernel(k_size);
  std::vector<PrimExpr> stride(k_size);
  std::vector<PrimExpr> dilation(k_size);
  std::vector<PrimExpr> pad_head(k_size);
  std::vector<PrimExpr> pad_tail(k_size);
  std::vector<PrimExpr> offset(k_size, 0);
  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
  Array<PrimExpr> data_shape = x->shape;
  for (size_t i = 0; i < data_shape.size(); ++i) {
    data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
  }
  Array<PrimExpr> out_shape = data_shape;

  bool do_pad = false;
  for (int i = 0; i < k_size; i++) {
    int ii = axis[i];
    // ... (kernel[i], stride[i], dilation[i], pad_head[i] and pad_tail[i] are
    // cast to int32 from the corresponding input arrays)
    if (ceil_mode) {
      // Extra tail padding so the output-extent division rounds up; with
      // count_include_pad these offset elements are excluded again below.
      offset[i] = stride[i] - 1;
      pad_tail[i] += offset[i];
    }

    const int64_t* padding0 = as_const_int(pad_head[i]);
    const int64_t* padding1 = as_const_int(pad_tail[i]);
    do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);

    daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));

    pad_before.Set(ii, pad_head[i]);
    pad_after.Set(ii, pad_tail[i]);

    arith::Analyzer analyzer;
    PrimExpr numerator =
        data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
    auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
    out_shape.Set(ii, out_dim);
  }
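
  // With dilation, a window touches (kernel - 1) * dilation + 1 input cells,
  // hence out = (in + pad_head + pad_tail - ((kernel - 1) * dilation + 1)) / stride + 1.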

  if (pool_type == kMaxPool) {
    // ... (the attrs map used below is assembled here)
    auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);
          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
          }
          return tvm::max(temp(indices), daxis);
592 "pool_max",
"pool_max", attrs);
  } else if (pool_type == kAvgPool) {
    auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);
          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
          }
          return tvm::sum(temp(indices), daxis);
611 "pool_sum",
"pool_sum");

    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);
          if (count_include_pad) {
            std::vector<PrimExpr> start(k_size);
            std::vector<PrimExpr> end(k_size);
            PrimExpr num_el = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              int ii = axis[i];
              start[i] = output[ii] * stride[i] - pad_head[i];
              // In ceil_mode the output extent included the extra offset[i]
              // padding, so clamp the window end to exclude it again.
              end[i] = start[i] + (kernel[i] - 1) * dilation[i];
              end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
              num_el *= (end[i] - start[i]) / dilation[i] + 1;
            }
            return div(pool_sum(indices), num_el);
          } else {
            std::vector<PrimExpr> start(k_size);
            std::vector<PrimExpr> end(k_size);
            PrimExpr num_el = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              int ii = axis[i];
              start[i] = output[ii] * stride[i] - pad_head[i];
              end[i] = start[i] + (kernel[i] - 1) * dilation[i];
              // Number of dilation steps from a (possibly padded, negative)
              // start position to the first in-bounds tap.
              PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
              jumps_to_non_pad = max(jumps_to_non_pad, make_const(DataType::Int(32), 0));

              end[i] = min(end[i], data_shape[ii] - 1);
              num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
            }
            PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1));
            return div(pool_sum(indices), divide_factor);
          }
        },
        "pool_avg", kElementWise);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}
inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
                     const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                     const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCW", bool count_include_pad = true) {
  int width_axis = -1;
  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
  std::vector<int> axis = {width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}
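
/*!
 * \brief Perform pooling on the height and width dimensions of data. The axes
 *        are determined by the layout string, e.g. "NCHW".
 */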
inline Tensor pool2d(const Tensor& x, const Array<PrimExpr>& kernel_size,
                     const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                     const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCHW", bool count_include_pad = true) {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  std::vector<int> axis = {height_axis, width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}
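
/*!
 * \brief Perform pooling on the depth, height and width dimensions of data,
 *        determined by the layout string (e.g. "NCDHW"). Pooled dimensions
 *        cannot be split: NCDHW16c is valid, while NCDHW16d, NCDHW16h and
 *        NCDHW16w are not.
 */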
inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
                     const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                     const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCDHW", bool count_include_pad = true) {
  int depth_axis = -1, height_axis = -1, width_axis = -1;
  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
      << "Unsupported layout " << layout;
  std::vector<int> axis = {depth_axis, height_axis, width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}

// ...

#endif  // TVM_TOPI_NN_POOLING_H_