#ifndef TVM_TOPI_NN_POOLING_H_
#define TVM_TOPI_NN_POOLING_H_

// ...

inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
                             const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
                             const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                             const size_t height_axis, const size_t width_axis,
                             bool count_include_pad) {
  ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)";
  ICHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
  ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
  ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
  ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
  auto kernel_height = cast(DataType::Int(32), kernel_size[0]);
  auto kernel_width = cast(DataType::Int(32), kernel_size[1]);
  auto stride_height = cast(DataType::Int(32), stride_size[0]);
  auto stride_width = cast(DataType::Int(32), stride_size[1]);

  auto height = cast(DataType::Int(32), x->shape[height_axis]);
  auto width = cast(DataType::Int(32), x->shape[width_axis]);

  auto pad_top = cast(DataType::Int(32), padding_size[0]);
  auto pad_left = cast(DataType::Int(32), padding_size[1]);
  auto pad_bottom = cast(DataType::Int(32), padding_size[2]);
  auto pad_right = cast(DataType::Int(32), padding_size[3]);
  if (ceil_mode) {
    // Additional padding to ensure we do ceil instead of floor when
    // dividing by stride.
    pad_bottom += stride_height - 1;
    pad_right += stride_width - 1;
  }
  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_before.Set(height_axis, pad_top);
  pad_before.Set(width_axis, pad_left);

  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_after.Set(height_axis, pad_bottom);
  pad_after.Set(width_axis, pad_right);

  arith::Analyzer analyzer;
  auto out_height =
      analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
  auto out_width =
      analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
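  // For example (illustrative numbers): with height = 32, kernel_height = 3,
  // pad_top = pad_bottom = 1 and stride_height = 2, the output height is
  // (32 - 3 + 1 + 1) / 2 + 1 = 16.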
  auto data_shape = x->shape;
  for (size_t i = 0; i < data_shape.size(); ++i) {
    data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
  }

  Array<PrimExpr> out_shape = data_shape;
  out_shape.Set(height_axis, out_height);
  out_shape.Set(width_axis, out_width);
  const int64_t* padding_h0 = as_const_int(pad_top);
  const int64_t* padding_w0 = as_const_int(pad_left);
  const int64_t* padding_h1 = as_const_int(pad_bottom);
  const int64_t* padding_w1 = as_const_int(pad_right);
  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
                      ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

  if (pool_type == kMaxPool) {
    Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
    ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
    ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
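    // The argmax reduction below records each window's maximum as a flattened
    // ("raveled") index into the padded input, so the gradient compute can
    // test whether a given input element was its window's argmax.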
    auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
    auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");

    auto argmax = MakeArgmaxReducer();
    auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;

    auto mp_argmax = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& inds) {
          Array<PrimExpr> window_inds{inds.begin(), inds.end()};
          window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
          window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
          auto idx = detail::RavelIndex(window_inds, ravel_shape);
          return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
        },
        "maxpool_grad_argmax", kCommReduceIdx);
    auto mp_inds = mp_argmax[0];

    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
    return tvm::te::compute(
        data_shape,
        [&](const Array<Var>& inds) {
          Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
          pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
          pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
          auto idx = detail::RavelIndex(pad_inds, ravel_shape);

          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
          out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);

          PrimExpr out_idx_lower_h =
              tir::Select(pad_inds[height_axis] < kernel_height, make_const(DataType::Int(32), 0),
                          (pad_inds[height_axis] - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_inds[width_axis] < kernel_width, make_const(DataType::Int(32), 0),
                          (pad_inds[width_axis] - kernel_width) / stride_width + 1);

          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[width_axis] >= out_idx_lower_w),
                                         mp_inds(out_idx) == idx),
                                out_grad(out_idx), make_const(x->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_max");
  } else if (pool_type == kAvgPool) {
    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
    return tvm::te::compute(
        data_shape,
        [&](const Array<Var>& inds) {
          PrimExpr pad_h_idx = inds[height_axis] + pad_top;
          PrimExpr pad_w_idx = inds[width_axis] + pad_left;

          // Output indices whose pooling windows cover this input element.
          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
          out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));

          PrimExpr out_idx_lower_h =
              tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
                          (pad_h_idx - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
                          (pad_w_idx - kernel_width) / stride_width + 1);
          PrimExpr divide_factor;  // number of pooled elements
          if (count_include_pad) {
            divide_factor = kernel_height * kernel_width;
          } else {
            PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
            PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;

            PrimExpr h_end = min(h_start + kernel_height, height);
            PrimExpr w_end = min(w_start + kernel_width, width);
            h_start = max(h_start, make_const(DataType::Int(32), 0));
            w_start = max(w_start, make_const(DataType::Int(32), 0));
            divide_factor = max(h_end - h_start, make_const(DataType::Int(32), 1)) *
                            max(w_end - w_start, make_const(DataType::Int(32), 1));
          }
          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[height_axis] < out_height),
                                         tir::And(out_idx[width_axis] >= out_idx_lower_w,
                                                  out_idx[width_axis] < out_width)),
                                out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_avg");
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return Tensor();
  }
}
inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
                                    int* width_axis) {
  *depth_axis = -1;
  *height_axis = -1;
  *width_axis = -1;
  int curr_idx = 0;
  for (size_t i = 0; i < layout.size(); ++i) {
    if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) {
      if (layout[i] == 'D') {
        if (*depth_axis != -1) return false;
        *depth_axis = curr_idx;
      } else if (layout[i] == 'H') {
        if (*height_axis != -1) return false;
        *height_axis = curr_idx;
      } else if (layout[i] == 'W') {
        if (*width_axis != -1) return false;
        *width_axis = curr_idx;
      } else if (layout[i] == 'd' || layout[i] == 'h' || layout[i] == 'w') {
        // Do not support splitting width, height or depth, e.g. NCHW16w.
        return false;
      }
      ++curr_idx;
    }
  }
  if (*depth_axis == -1 || *height_axis == -1 || *width_axis == -1) return false;
  return true;
}
inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
  int dummy;
  ICHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);
  if (*height_axis != -1 && *width_axis != -1) {
    return true;
  }
  return false;
}
inline bool find_width(const std::string& layout, int* width_axis) {
  int dummy;
  ICHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);
  if (*width_axis != -1) {
    return true;
  }
  return false;
}
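// Example (a sketch of the helpers above; results follow the logic as
// written): layout letters are counted left to right, and lowercase packing
// factors such as "16c" keep H and W unsplit, so they are still accepted.
inline void layout_axis_example() {
  int h = -1, w = -1;
  ICHECK(find_height_width("NCHW", &h, &w));     // h == 2, w == 3
  ICHECK(find_height_width("NHWC", &h, &w));     // h == 1, w == 2
  ICHECK(find_height_width("NCHW16c", &h, &w));  // h == 2, w == 3
}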
inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
                        const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
                        PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
                        bool count_include_pad = true) {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
                        height_axis, width_axis, count_include_pad);
}
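// Usage sketch (not part of the original header; shapes and names are
// illustrative): gradient of a 2x2, stride-2 max pool over NCHW data. out_grad
// carries the 16x16 output-shaped gradient, and the result is input-shaped.
inline Tensor pool_grad_example() {
  Tensor data = tvm::te::placeholder({1, 3, 32, 32}, DataType::Float(32), "data");
  Tensor out_grad = tvm::te::placeholder({1, 3, 16, 16}, DataType::Float(32), "out_grad");
  return pool_grad(out_grad, data, {2, 2}, {2, 2}, {0, 0, 0, 0}, kMaxPool,
                   /*ceil_mode=*/false, "NCHW");
}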
inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  return indexdiv(out_index * idim, odim);
}

inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
  return tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1);
}
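// For example (illustrative numbers): mapping an output of size 3 onto an
// input of size 10 gives the windows [0, 4), [3, 7), [6, 10) -- every input
// element is covered and window lengths differ by at most one.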
inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
                                 PoolType pool_type, const std::vector<int>& axes) {
  const auto n_dim = output_size.size();
  ICHECK_EQ(axes.size(), n_dim) << "The number of axes does not match the in/out dimension";

  Array<PrimExpr> data_shape = x->shape;
  for (size_t i = 0; i < data_shape.size(); ++i) {
    data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
  }

  Array<PrimExpr> out_shape = data_shape;
  Array<PrimExpr> in_size, out_size;
  for (size_t i = 0; i < n_dim; ++i) {
    in_size.push_back(data_shape[axes[i]]);
    out_size.push_back(cast(DataType::Int(32), output_size[i]));
    out_shape.Set(axes[i], out_size[i]);
  }
  auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
    Array<PrimExpr> indices;
    for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
    Array<tir::IterVar> reduce_axes;
    for (size_t i = 0; i < n_dim; ++i) {
      auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
      auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
      auto rv_name = "rv" + std::to_string(i);
      auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
      reduce_axes.push_back(rv_axis);
      if (reduce_indices) {
        indices.Set(axes[i], i_start + rv_axis);
      }
    }
    return std::make_tuple(indices, reduce_axes);
  };
  Map<String, ObjectRef> attrs;
  if (pool_type == kMaxPool) {
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::max(x(indices), reduce_axes);
        },
        "adaptive_pool_max", "adaptive_pool_max", attrs);
  } else if (pool_type == kAvgPool) {
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::sum(x(indices), reduce_axes);
        },
        "adaptive_pool_sum", "adaptive_pool_sum");
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, false);

          PrimExpr divide_factor = tvm::cast(x->dtype, 1);
          for (size_t i = 0; i < n_dim; ++i) {
            divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
          }

          return div(pool_sum(indices), divide_factor);
        },
        "adaptive_pool_avg", kElementWise, attrs);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}
inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size, PoolType pool_type,
                            const std::string& layout = "NCHW") {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
}
inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCDHW") {
  int depth_axis = -1, height_axis = -1, width_axis = -1;
  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
      << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
}
inline Tensor adaptive_pool1d(const Tensor& x, const Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCW") {
  int width_axis = -1;
  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
}

// ...
inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
                           const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                           const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                           const std::vector<int>& axis, bool count_include_pad) {
  int k_size = kernel_size.size();
  int x_size = x->shape.size();
  ICHECK_EQ(stride_size.size(), k_size)
      << "Pooling stride_size must have the same number of elements as kernel";
  ICHECK_EQ(padding_size.size(), k_size * 2)
      << "Pooling padding_size must have double the elements of kernel";
  ICHECK_EQ(axis.size(), k_size) << "axis must have the same number of elements as kernel";
  Array<tir::IterVar> daxis;
  std::vector<PrimExpr> kernel(k_size);
  std::vector<PrimExpr> stride(k_size);
  std::vector<PrimExpr> dilation(k_size);
  std::vector<PrimExpr> pad_head(k_size);
  std::vector<PrimExpr> pad_tail(k_size);
  std::vector<PrimExpr> offset(k_size, 0);
  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
  Array<PrimExpr> data_shape = x->shape;
  for (size_t i = 0; i < data_shape.size(); ++i) {
    data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
  }
  Array<PrimExpr> out_shape = data_shape;

  arith::Analyzer analyzer;
  bool do_pad = false;
  for (int i = 0; i < k_size; i++) {
    int ii = axis[i];
    kernel[i] = cast(DataType::Int(32), kernel_size[i]);
    stride[i] = cast(DataType::Int(32), stride_size[i]);
    dilation[i] = cast(DataType::Int(32), dilation_size[i]);
    pad_head[i] = cast(DataType::Int(32), padding_size[i]);
    pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);

    if (ceil_mode) {
      // offset[i] is additional padding that makes the output-size division
      // round up instead of down. The offset is remembered so the extra ceil
      // padding can be excluded from the averaging denominator later.
      offset[i] = stride[i] - 1;
      pad_tail[i] += offset[i];
    }

    const int64_t* padding0 = as_const_int(pad_head[i]);
    const int64_t* padding1 = as_const_int(pad_tail[i]);
    do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);
    daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));

    pad_before.Set(ii, pad_head[i]);
    pad_after.Set(ii, pad_tail[i]);

    PrimExpr numerator =
        data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
    auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
    out_shape.Set(ii, out_dim);
  }
  Map<String, ObjectRef> attrs;
  if (pool_type == kMaxPool) {
    auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& v : output) indices.push_back(v);
          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
          }
          return tvm::max(temp(indices), daxis);
        },
        "pool_max", "pool_max", attrs);
  } else if (pool_type == kAvgPool) {
    // Pad the input with zeros, since zero contributes nothing to the sum.
    auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;

    // First compute the sum over each pooling window.
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& v : output) indices.push_back(v);
          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
          }
          return tvm::sum(temp(indices), daxis);
        },
        "pool_sum", "pool_sum");
    // Then divide the sum by the (possibly clipped) window size.
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& v : output) indices.push_back(v);
          if (count_include_pad) {
            std::vector<PrimExpr> start(k_size);
            std::vector<PrimExpr> end(k_size);
            auto num_el = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              int ii = axis[i];
              start[i] = output[ii] * stride[i] - pad_head[i];
              // end[i] is the index of the last element in the window, not one
              // past it; clip it so the ceil_mode offset is never counted.
              end[i] = start[i] + (kernel[i] - 1) * dilation[i];
              end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
              num_el *= (end[i] - start[i]) / dilation[i] + 1;
            }
            return div(pool_sum(indices), num_el);
          } else {
            std::vector<PrimExpr> start(k_size);
            std::vector<PrimExpr> end(k_size);
            auto num_el = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              int ii = axis[i];
              start[i] = output[ii] * stride[i] - pad_head[i];
              end[i] = start[i] + (kernel[i] - 1) * dilation[i];
              // Jump over the leading padded taps to the first in-bounds one.
              PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
              jumps_to_non_pad = max(jumps_to_non_pad, make_const(DataType::Int(32), 0));
              end[i] = min(end[i], data_shape[ii] - 1);
              num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
            }
            PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1));
            return div(pool_sum(indices), divide_factor);
          }
        },
        "pool_avg", kElementWise, attrs);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}
inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
                     const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                     const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCW", bool count_include_pad = true) {
  int width_axis = -1;
  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
  std::vector<int> axis = {width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}
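// Usage sketch (not part of the original header; names and shapes are
// illustrative): dilated max pooling along the width axis of an NCW tensor.
// With kernel 3 and dilation 2 the window spans (3 - 1) * 2 + 1 = 5 input
// columns, so padding {2, 2} keeps the output width at 128.
inline Tensor pool1d_example() {
  Tensor data = tvm::te::placeholder({1, 16, 128}, DataType::Float(32), "data");
  return pool1d(data, /*kernel_size=*/{3}, /*stride_size=*/{1}, /*dilation_size=*/{2},
                /*padding_size=*/{2, 2}, kMaxPool, /*ceil_mode=*/false, "NCW");
}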
inline Tensor pool2d(const Tensor& x, const Array<PrimExpr>& kernel_size,
                     const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                     const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCHW", bool count_include_pad = true) {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  std::vector<int> axis = {height_axis, width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}
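// Usage sketch (illustrative): the classic 3x3, stride-2 max pool with
// symmetric padding. padding_size is ordered {top, left, bottom, right}, so
// the 56x56 input below pools to (56 - 3 + 2) / 2 + 1 = 28x28.
inline Tensor pool2d_example() {
  Tensor data = tvm::te::placeholder({1, 64, 56, 56}, DataType::Float(32), "data");
  return pool2d(data, {3, 3}, {2, 2}, /*dilation_size=*/{1, 1}, {1, 1, 1, 1}, kMaxPool,
                /*ceil_mode=*/false, "NCHW");
}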
inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
                     const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                     const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCDHW", bool count_include_pad = true) {
  int depth_axis = -1, height_axis = -1, width_axis = -1;
  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
      << "Unsupported layout " << layout;
  std::vector<int> axis = {depth_axis, height_axis, width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}
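// Usage sketch (illustrative): 2x2x2 average pooling over NCDHW data. The six
// padding values list all head paddings first, then all tail paddings, i.e.
// {front, top, left, back, bottom, right}.
inline Tensor pool3d_example() {
  Tensor data = tvm::te::placeholder({1, 8, 16, 32, 32}, DataType::Float(32), "data");
  return pool3d(data, {2, 2, 2}, {2, 2, 2}, /*dilation_size=*/{1, 1, 1}, {0, 0, 0, 0, 0, 0},
                kAvgPool, /*ceil_mode=*/false, "NCDHW", /*count_include_pad=*/true);
}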
#endif  // TVM_TOPI_NN_POOLING_H_