pooling.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_POOLING_H_
25 #define TVM_TOPI_NN_POOLING_H_
26 
27 #include <tvm/arith/analyzer.h>
29 #include <tvm/topi/nn.h>
30 #include <tvm/topi/reduction.h>
31 #include <tvm/topi/tags.h>
32 
33 #include <algorithm>
34 #include <string>
35 #include <vector>
36 
37 namespace tvm {
38 namespace topi {
39 namespace nn {
40 
41 using namespace tvm::te;
42 
44 enum PoolType : int {
45   kAvgPool,
46   kMaxPool,
47 };
48 
49 inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
50  const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
51  const Array<PrimExpr>& padding_size, PoolType pool_type,
52  bool ceil_mode, const size_t height_axis, const size_t width_axis,
53  bool count_include_pad) {
54  ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must be >= 2-D (H, W)";
55  ICHECK(x->shape.size() >= 2) << "Pooling input must be >= 2-D (H, W)";
56  ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
57  ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
58  ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
59 
60  auto kernel_height = cast(DataType::Int(32), kernel_size[0]);
61  auto kernel_width = cast(DataType::Int(32), kernel_size[1]);
62  auto stride_height = cast(DataType::Int(32), stride_size[0]);
63  auto stride_width = cast(DataType::Int(32), stride_size[1]);
64 
65  auto height = cast(DataType::Int(32), x->shape[height_axis]);
66  auto width = cast(DataType::Int(32), x->shape[width_axis]);
67 
68  auto pad_top = cast(DataType::Int(32), padding_size[0]);
69  auto pad_left = cast(DataType::Int(32), padding_size[1]);
70  auto pad_bottom = cast(DataType::Int(32), padding_size[2]);
71  auto pad_right = cast(DataType::Int(32), padding_size[3]);
72 
73  if (ceil_mode) {
74  // Additional padding to ensure we do ceil instead of floor when
75  // dividing by stride.
76  pad_bottom += stride_height - 1;
77  pad_right += stride_width - 1;
78  }
79 
80  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
81  pad_before.Set(height_axis, pad_top);
82  pad_before.Set(width_axis, pad_left);
83 
84  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
85  pad_after.Set(height_axis, pad_bottom);
86  pad_after.Set(width_axis, pad_right);
87  arith::Analyzer analyzer;
88  auto out_height =
89  analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
90  auto out_width =
91  analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
92 
93  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
94  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");
95 
96  Array<PrimExpr> data_shape = x->shape;
97  for (size_t i = 0; i < data_shape.size(); ++i) {
98  data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
99  }
100 
101  Array<PrimExpr> out_shape = data_shape;
102  out_shape.Set(height_axis, out_height);
103  out_shape.Set(width_axis, out_width);
104 
105  const int64_t* padding_h0 = as_const_int(pad_top);
106  const int64_t* padding_w0 = as_const_int(pad_left);
107  const int64_t* padding_h1 = as_const_int(pad_bottom);
108  const int64_t* padding_w1 = as_const_int(pad_right);
109  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
110  ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
111 
112  if (pool_type == kMaxPool) {
113  Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
114  ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
115  ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
116 
117  auto windowh =
118  tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
119  auto windoww =
120  tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
121 
122  auto argmax = MakeArgmaxReducer();
123  auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
124 
125  auto mp_argmax = tvm::te::compute(
126  out_shape,
127  [&](const Array<Var>& inds) {
128  Array<PrimExpr> window_inds{inds.begin(), inds.end()};
129  window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
130  window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
131  auto idx = detail::RavelIndex(window_inds, ravel_shape);
132  return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
133  },
134  "maxpool_grad_argmax", kCommReduceIdx);
135 
136  auto mp_inds = mp_argmax[0];
137 
138  return tvm::te::compute(
139  data_shape,
140  [&](const Array<Var>& inds) {
141  Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
142  pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
143  pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
144  auto idx = detail::RavelIndex(pad_inds, ravel_shape);
145 
146  Array<PrimExpr> out_idx{inds.begin(), inds.end()};
147  out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
148  out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
149 
150  PrimExpr out_idx_lower_h = tir::Select(
151  pad_inds[height_axis] < kernel_height, make_const(DataType::Int(32), 0),
152  (pad_inds[height_axis] - kernel_height) / stride_height + 1);
153  PrimExpr out_idx_lower_w = tir::Select(
154  pad_inds[width_axis] < kernel_width, make_const(DataType::Int(32), 0),
155  (pad_inds[width_axis] - kernel_width) / stride_width + 1);
156 
157  return tvm::sum(
158  tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
159  out_idx[width_axis] >= out_idx_lower_w),
160  mp_inds(out_idx) == idx),
161  out_grad(out_idx), make_const(x->dtype, 0)),
162  {windowh, windoww});
163  },
164  "T_pool_grad", "pool_grad_max");
165  } else if (pool_type == kAvgPool) {
166  auto windowh =
167  tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
168  auto windoww =
169  tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
170  return tvm::te::compute(
171  data_shape,
172  [&](const Array<Var>& inds) {
173  PrimExpr pad_h_idx = inds[height_axis] + pad_top;
174  PrimExpr pad_w_idx = inds[width_axis] + pad_left;
175 
176  // output indices whose pooling windows cover the current input element (may be out of bounds)
177  Array<PrimExpr> out_idx{inds.begin(), inds.end()};
178  out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
179  out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
180 
181  PrimExpr out_idx_lower_h =
182  tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
183  (pad_h_idx - kernel_height) / stride_height + 1);
184  PrimExpr out_idx_lower_w =
185  tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
186  (pad_w_idx - kernel_width) / stride_width + 1);
187 
188  PrimExpr divide_factor; // number of pooled elements
189  if (count_include_pad) {
190  divide_factor = kernel_height * kernel_width;
191  } else {
192  PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
193  PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
194 
195  PrimExpr h_end = min(h_start + kernel_height, height);
196  PrimExpr w_end = min(w_start + kernel_width, width);
197  h_start = max(h_start, make_const(DataType::Int(32), 0));
198  w_start = max(w_start, make_const(DataType::Int(32), 0));
199  divide_factor =
200  max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
201  }
202  return tvm::sum(
203  tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
204  out_idx[height_axis] < out_height),
205  tir::And(out_idx[width_axis] >= out_idx_lower_w,
206  out_idx[width_axis] < out_width)),
207  out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
208  {windowh, windoww});
209  },
210  "T_pool_grad", "pool_grad_avg");
211  } else {
212  LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
213  return Tensor();
214  }
215 }
216 
217 inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
218  int* width_axis) {
219  *depth_axis = -1;
220  *height_axis = -1;
221  *width_axis = -1;
222  int curr_idx = 0;
223  for (size_t i = 0; i < layout.size(); ++i) {
224  if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) {
225  if (layout[i] == 'D') {
226  if (*depth_axis != -1) return false;
227  *depth_axis = curr_idx;
228  } else if (layout[i] == 'H') {
229  if (*height_axis != -1) return false;
230  *height_axis = curr_idx;
231  } else if (layout[i] == 'W') {
232  if (*width_axis != -1) return false;
233  *width_axis = curr_idx;
234  } else if (layout[i] == 'd' || layout[i] == 'h' || layout[i] == 'w') {
235  // splitting the depth, height or width axis is not supported, e.g., NCHW16w
236  return false;
237  }
238  ++curr_idx;
239  }
240  }
241  if (*depth_axis == -1 || *height_axis == -1 || *width_axis == -1) return false;
242  return true;
243 }
244 
245 inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
246  int dummy;
247  ICHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);  // a 2-D layout has no depth axis
248  if (*height_axis != -1 && *width_axis != -1) {
249  return true;
250  }
251  return false;
252 }
253 
254 inline bool find_width(const std::string& layout, int* width_axis) {
255  int dummy;
256  ICHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);  // a 1-D layout has no depth/height axes
257  if (*width_axis != -1) {
258  return true;
259  }
260  return false;
261 }
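// Illustrative sketch (results follow from the parsing loop above; the layout strings are
// examples only) of how the layout helpers resolve axis positions.
//
//   int d = -1, h = -1, w = -1;
//   find_height_width("NCHW", &h, &w);              // true: h == 2, w == 3
//   find_height_width("NCHW16c", &h, &w);           // true: h == 2, w == 3 (the "16c" split block is skipped)
//   find_height_width("NHWC", &h, &w);              // true: h == 1, w == 2
//   find_depth_height_width("NCDHW", &d, &h, &w);   // true: d == 2, h == 3, w == 4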
262 
293 inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
294  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
295  PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
296  bool count_include_pad = true) {
297  int height_axis = -1, width_axis = -1;
298  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
299  return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
300  height_axis, width_axis, count_include_pad);
301 }
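// A minimal usage sketch for pool_grad, assuming a float32 NCHW placeholder input and the
// gradient of its 2x2 / stride-2 max-pool output; names and shapes are illustrative only.
//
//   Tensor x  = placeholder({1, 3, 32, 32}, DataType::Float(32), "x");
//   Tensor dy = placeholder({1, 3, 16, 16}, DataType::Float(32), "dy");  // grad w.r.t. pooled output
//   Tensor dx = pool_grad(dy, x,
//                         {2, 2},        // kernel_size (kH, kW)
//                         {2, 2},        // stride_size
//                         {0, 0, 0, 0},  // padding_size (top, left, bottom, right)
//                         kMaxPool, /*ceil_mode=*/false, "NCHW");
//   // dx has the same shape as x; for max pooling each output gradient is routed back to the
//   // argmax position of its window.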
302 
303 inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
304  return indexdiv(out_index * idim, odim);
305 }
306 
307 inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
308  PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
309  return tvm::tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1);
310 }
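// Worked example for the adaptive-pooling index helpers above: mapping an input extent
// idim = 10 onto an output extent odim = 4, output index 1 covers
//   start_index(1, 4, 10) = floor(1 * 10 / 4) = 2
//   end_index(1, 4, 10)   = ceil(2 * 10 / 4)  = 5
// i.e. its reduction window spans input elements [2, 5).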
311 
322 inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
323  PoolType pool_type, const std::vector<int>& axes) {
324  const auto n_dim = output_size.size();
325  ICHECK_EQ(axes.size(), n_dim) << "The number of axes must equal the number of output dimensions";
326 
327  Array<PrimExpr> data_shape = x->shape;
328  for (size_t i = 0; i < data_shape.size(); ++i) {
329  data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
330  }
331  Array<PrimExpr> out_shape = data_shape;
332  Array<PrimExpr> in_size, out_size;
333  for (size_t i = 0; i < n_dim; ++i) {
334  in_size.push_back(data_shape[axes[i]]);
335  out_size.push_back(cast(DataType::Int(32), output_size[i]));
336  out_shape.Set(axes[i], out_size[i]);
337  }
338 
339  auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
340  Array<PrimExpr> indices;
341  for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
342  Array<tir::IterVar> reduce_axes;
343  for (size_t i = 0; i < n_dim; ++i) {
344  auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
345  auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
346  auto rv_name = "rv" + std::to_string(i);
347  auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
348  reduce_axes.push_back(rv_axis);
349  if (reduce_indices) {
350  indices.Set(axes[i], i_start + rv_axis);
351  }
352  }
353  return std::make_tuple(indices, reduce_axes);
354  };
355 
356  if (pool_type == kMaxPool) {
357  return tvm::te::compute(
358  out_shape,
359  [&](const Array<Var>& output) {
360  Array<PrimExpr> indices;
361  Array<tir::IterVar> reduce_axes;
362  std::tie(indices, reduce_axes) = get_iter_vars(output, true);
363  return tvm::max(x(indices), reduce_axes); // NOLINT(*)
364  },
365  "tensor", "adaptive_pool_max");
366  } else if (pool_type == kAvgPool) {
367  auto pool_sum = tvm::te::compute(
368  out_shape,
369  [&](const Array<Var>& output) {
370  Array<PrimExpr> indices;
371  Array<tir::IterVar> reduce_axes;
372  std::tie(indices, reduce_axes) = get_iter_vars(output, true);
373  return tvm::sum(x(indices), reduce_axes);
374  },
375  "tensor", "adaptive_pool_sum");
376 
377  return tvm::te::compute(
378  out_shape,
379  [&](const Array<Var>& output) {
380  Array<PrimExpr> indices;
381  Array<tir::IterVar> reduce_axes;
382  std::tie(indices, reduce_axes) = get_iter_vars(output, false);
383 
384  PrimExpr divide_factor = tvm::cast(x->dtype, 1);
385  for (size_t i = 0; i < n_dim; ++i) {
386  divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
387  }
388 
389  return div(pool_sum(indices), divide_factor);
390  },
391  "tensor", kElementWise);
392  } else {
393  LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
394  return x;
395  }
396 }
397 
424 inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size, PoolType pool_type,
425  const std::string& layout = "NCHW") {
426  int height_axis = -1, width_axis = -1;
427  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
428  return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
429 }
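// A minimal usage sketch for adaptive_pool, assuming a float32 NCHW placeholder;
// shapes are illustrative only.
//
//   Tensor data = placeholder({1, 64, 17, 29}, DataType::Float(32), "data");
//   // Average-pool to a fixed 7x7 spatial output regardless of the input extents.
//   Tensor out = adaptive_pool(data, {7, 7}, kAvgPool, "NCHW");  // shape {1, 64, 7, 7}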
430 
439 inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
440  PoolType pool_type, const std::string& layout = "NCDHW") {
441  int depth_axis = -1, height_axis = -1, width_axis = -1;
442  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
443  << "Unsupported layout " << layout;
444  return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
445 }
446 
455 inline Tensor adaptive_pool1d(const Tensor& x, const Array<PrimExpr>& output_size,
456  PoolType pool_type, const std::string& layout = "NCW") {
457  int width_axis = -1;
458  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
459  return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
460 }
461 
487 inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
488  return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
489 }
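// global_pool is adaptive_pool with output_size {1, 1}: it reduces the whole spatial extent
// per channel. A minimal sketch, assuming a float32 NCHW placeholder:
//
//   Tensor data = placeholder({1, 64, 17, 29}, DataType::Float(32), "data");
//   Tensor out = global_pool(data, kAvgPool, "NCHW");  // shape {1, 64, 1, 1}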
490 
507 inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
508  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
509  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
510  const std::vector<int>& axis, bool count_include_pad) {
511  int k_size = kernel_size.size();
512  int x_size = x->shape.size();
513  ICHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have as many elements as kernel";
514  ICHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must have twice as many"
515  " elements as kernel";
516  ICHECK_EQ(axis.size(), k_size) << "axis must have as many elements as kernel";
517 
518  Array<IterVar> daxis;
519  std::vector<PrimExpr> kernel(k_size);
520  std::vector<PrimExpr> stride(k_size);
521  std::vector<PrimExpr> dilation(k_size);
522  std::vector<PrimExpr> pad_head(k_size);
523  std::vector<PrimExpr> pad_tail(k_size);
524  std::vector<PrimExpr> offset(k_size, 0);
525  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
526  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
527  Array<PrimExpr> data_shape = x->shape;
528  for (size_t i = 0; i < data_shape.size(); ++i) {
529  data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
530  }
531  Array<PrimExpr> out_shape = data_shape;
532 
533  bool do_pad = false;
534  for (int i = 0; i < k_size; i++) {
535  int ii = axis[i];
536  kernel[i] = cast(DataType::Int(32), kernel_size[i]);
537  stride[i] = cast(DataType::Int(32), stride_size[i]);
538  dilation[i] = cast(DataType::Int(32), dilation_size[i]);
539  pad_head[i] = cast(DataType::Int(32), padding_size[i]);
540  pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);
541 
542  if (ceil_mode) {
543  // The offset[i] is an additional padding to ensure we do ceil instead of floor when
544  // dividing by stride.
545  // In the case of ceil_mode=True and count_include_pad=True,
546  // in order to obtain the correct boundary,
547  // we also need to use the offset[i] to eliminate this extra padding.
548  offset[i] = stride[i] - 1;
549  pad_tail[i] += offset[i];
550  }
551 
552  const int64_t* padding0 = as_const_int(pad_head[i]);
553  const int64_t* padding1 = as_const_int(pad_tail[i]);
554  do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);
555 
556  daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));
557 
558  pad_before.Set(ii, pad_head[i]);
559  pad_after.Set(ii, pad_tail[i]);
560 
561  arith::Analyzer analyzer;
562 
563  PrimExpr numerator =
564  data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
565  auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
566  out_shape.Set(ii, out_dim);
567  }
568 
569  if (pool_type == kMaxPool) {
570  auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
571  return tvm::te::compute(
572  out_shape,
573  [&](const Array<Var>& output) {
574  Array<PrimExpr> indices;
575  for (const Var& var : output) indices.push_back(var);
576 
577  for (int i = 0; i < k_size; i++) {
578  int ii = axis[i];
579  indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
580  }
581  return tvm::max(temp(indices), daxis);
582  },
583  "tensor", "pool_max");
584  } else if (pool_type == kAvgPool) {
585  // Pad the inputs
586  auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
587 
588  // TVM compute for summing the pooling window.
589  auto pool_sum = tvm::te::compute(
590  out_shape,
591  [&](const Array<Var>& output) {
592  Array<PrimExpr> indices;
593  for (const Var& var : output) indices.push_back(var);
594 
595  for (int i = 0; i < k_size; i++) {
596  int ii = axis[i];
597  indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
598  }
599  return tvm::sum(temp(indices), daxis);
600  },
601  "tensor", "pool_sum");
602 
603  // TVM compute for dividing the reduced window sum by kernel size.
604  return tvm::te::compute(
605  out_shape,
606  [&](const Array<Var>& output) {
607  Array<PrimExpr> indices;
608  for (const Var& var : output) indices.push_back(var);
609  if (count_include_pad) {
610  std::vector<PrimExpr> start(k_size);
611  std::vector<PrimExpr> end(k_size);
612  auto num_el = make_const(DataType::Int(32), 1);
613  for (int i = 0; i < k_size; i++) {
614  int ii = axis[i];
615  start[i] = output[ii] * stride[i] - pad_head[i];
616  // When computing the output shape in ceil_mode,
617  // we have added the extra padding of offset[i],
618  // so now, in order to calculate the correct boundary,
619  // we need to subtract offset[i].
620  end[i] = start[i] + (kernel[i] - 1) * dilation[i];
621  end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
622  num_el *= (end[i] - start[i]) / dilation[i] + 1;
623  }
624  return div(pool_sum(indices), num_el);
625  } else {
626  std::vector<PrimExpr> start(k_size);
627  std::vector<PrimExpr> end(k_size);
628  auto num_el = make_const(DataType::Int(32), 1);
629  for (int i = 0; i < k_size; i++) {
630  int ii = axis[i];
631 
632  // Let start and end contain the first and last index of our Tensor
633  // along the relevant dimension we use in our calculation.
634  // Assume indices -1, -2 represent the padding before (head) and
635  // len(arr), len(arr) + 1 represent the padding after (tail).
636  start[i] = output[ii] * stride[i] - pad_head[i];
637  end[i] = start[i] + (kernel[i] - 1) * dilation[i];
638 
639  // If start[i] < 0, i.e. we start on a head-padded element, this will be a positive
640  // number that represents the number of steps along the dilated kernel needed to reach a
641  // non-padded value. Otherwise this should be 0.
642  PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
643  jumps_to_non_pad = max(jumps_to_non_pad, make_const(DataType::Int(32), 0));
644 
645  end[i] = min(end[i], data_shape[ii] - 1);
646  num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
647  }
648 
649  PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1));
650  return div(pool_sum(indices), divide_factor);
651  }
652  },
653  "tensor", kElementWise);
654  } else {
655  LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
656  return x;
657  }
658 }
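// Worked example of the output-extent arithmetic above (1-D, illustrative numbers):
// input extent 10, kernel 3, stride 2, dilation 1, no padding.
//   floor mode: (10 - 3) / 2 + 1 = 4 output elements
//   ceil mode:  offset = stride - 1 = 1 is added to the tail padding, giving
//               (10 - 3 + 1) / 2 + 1 = 5 output elements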
659 
690 inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
691  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
692  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
693  const std::string& layout = "NCW", bool count_include_pad = true) {
694  int width_axis = -1;
695  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
696  std::vector<int> axis = {width_axis};
697  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
698  ceil_mode, axis, count_include_pad);
699 }
700 
731 inline Tensor pool2d(const Tensor& x, const Array<PrimExpr>& kernel_size,
732  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
733  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
734  const std::string& layout = "NCHW", bool count_include_pad = true) {
735  int height_axis = -1, width_axis = -1;
736  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
737  std::vector<int> axis = {height_axis, width_axis};
738  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
739  ceil_mode, axis, count_include_pad);
740 }
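// A minimal usage sketch for pool2d, assuming a float32 NCHW placeholder input;
// names and shapes are illustrative only.
//
//   Tensor data = placeholder({1, 3, 32, 32}, DataType::Float(32), "data");
//   Tensor out = pool2d(data,
//                       {3, 3},        // kernel_size (kH, kW)
//                       {2, 2},        // stride_size
//                       {1, 1},        // dilation_size
//                       {1, 1, 1, 1},  // padding_size (top, left, bottom, right)
//                       kAvgPool, /*ceil_mode=*/false, "NCHW",
//                       /*count_include_pad=*/false);  // shape {1, 3, 16, 16}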
741 
773 inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
774  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
775  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
776  const std::string& layout = "NCDHW", bool count_include_pad = true) {
777  int depth_axis = -1, height_axis = -1, width_axis = -1;
778  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
779  << "Unsupported layout " << layout;
780  std::vector<int> axis = {depth_axis, height_axis, width_axis};
781  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
782  ceil_mode, axis, count_include_pad);
783 }
784 
785 } // namespace nn
786 } // namespace topi
787 } // namespace tvm
788 #endif // TVM_TOPI_NN_POOLING_H_