pooling.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_POOLING_H_
25 #define TVM_TOPI_NN_POOLING_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/topi/detail/pad_utils.h>
29 #include <tvm/topi/nn.h>
30 #include <tvm/topi/reduction.h>
31 #include <tvm/topi/tags.h>
32 
33 #include <algorithm>
34 #include <string>
35 #include <vector>
36 
37 namespace tvm {
38 namespace topi {
39 namespace nn {
40 
41 using namespace tvm::te;
42 
43 /*! \brief Pooling type */
44 enum PoolType : int {
45   kAvgPool,
46   kMaxPool,
47 };
48 
49 inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
50  const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
51  const Array<PrimExpr>& padding_size, PoolType pool_type,
52  bool ceil_mode, const size_t height_axis, const size_t width_axis,
53  bool count_include_pad) {
54  ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must be >= 2-D (H, W)";
55  ICHECK(x->shape.size() >= 2) << "Pooling input must be >= 2-D (H, W)";
56  ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
57  ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
58  ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
59 
60  auto kernel_height = cast(DataType::Int(32), kernel_size[0]);
61  auto kernel_width = cast(DataType::Int(32), kernel_size[1]);
62  auto stride_height = cast(DataType::Int(32), stride_size[0]);
63  auto stride_width = cast(DataType::Int(32), stride_size[1]);
64 
65  auto height = cast(DataType::Int(32), x->shape[height_axis]);
66  auto width = cast(DataType::Int(32), x->shape[width_axis]);
67 
68  auto pad_top = cast(DataType::Int(32), padding_size[0]);
69  auto pad_left = cast(DataType::Int(32), padding_size[1]);
70  auto pad_bottom = cast(DataType::Int(32), padding_size[2]);
71  auto pad_right = cast(DataType::Int(32), padding_size[3]);
72 
73  if (ceil_mode) {
74  // Additional padding to ensure we do ceil instead of floor when
75  // dividing by stride.
76  pad_bottom += stride_height - 1;
77  pad_right += stride_width - 1;
78  }
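  // Worked example of the ceil-mode adjustment above (illustrative values, not
  // from the original source): with height = 6, kernel_height = 3, no padding and
  // stride_height = 2, floor division gives (6 - 3) / 2 + 1 = 2 output rows, while
  // ceil mode should give 3. Growing pad_bottom by stride_height - 1 = 1 yields
  // (6 - 3 + 1) / 2 + 1 = 3, matching the ceil-mode result.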
79 
80  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
81  pad_before.Set(height_axis, pad_top);
82  pad_before.Set(width_axis, pad_left);
83 
84  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
85  pad_after.Set(height_axis, pad_bottom);
86  pad_after.Set(width_axis, pad_right);
87  arith::Analyzer analyzer;
88  auto out_height =
89  analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
90  auto out_width =
91  analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
92 
93  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
94  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");
95 
96  Array<PrimExpr> data_shape = x->shape;
97  for (size_t i = 0; i < data_shape.size(); ++i) {
98  data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
99  }
100 
101  Array<PrimExpr> out_shape = data_shape;
102  out_shape.Set(height_axis, out_height);
103  out_shape.Set(width_axis, out_width);
104 
105  const int64_t* padding_h0 = as_const_int(pad_top);
106  const int64_t* padding_w0 = as_const_int(pad_left);
107  const int64_t* padding_h1 = as_const_int(pad_bottom);
108  const int64_t* padding_w1 = as_const_int(pad_right);
109  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
110  ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
111 
112  if (pool_type == kMaxPool) {
113  Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
114  ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
115  ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
116 
117  auto windowh =
118  tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
119  auto windoww =
120  tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
121 
122  auto argmax = MakeArgmaxReducer();
123  auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
124 
125  auto mp_argmax = tvm::te::compute(
126  out_shape,
127  [&](const Array<Var>& inds) {
128  Array<PrimExpr> window_inds{inds.begin(), inds.end()};
129  window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
130  window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
131  auto idx = detail::RavelIndex(window_inds, ravel_shape);
132  return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
133  },
134  "maxpool_grad_argmax", kCommReduceIdx);
135 
136  auto mp_inds = mp_argmax[0];
137 
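  // The gradient is scattered back by re-deriving, for every input element, the
  // set of output positions (via windowh, windoww) whose pooling windows could
  // contain it, and summing out_grad over exactly those positions whose recorded
  // argmax (mp_inds) equals this element's flattened index in the padded input.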
138  return tvm::te::compute(
139  data_shape,
140  [&](const Array<Var>& inds) {
141  Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
142  pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
143  pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
144  auto idx = detail::RavelIndex(pad_inds, ravel_shape);
145 
146  Array<PrimExpr> out_idx{inds.begin(), inds.end()};
147  out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
148  out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
149 
150  PrimExpr out_idx_lower_h = tir::Select(
151  pad_inds[height_axis] < kernel_height, make_const(DataType::Int(32), 0),
152  (pad_inds[height_axis] - kernel_height) / stride_height + 1);
153  PrimExpr out_idx_lower_w = tir::Select(
154  pad_inds[width_axis] < kernel_width, make_const(DataType::Int(32), 0),
155  (pad_inds[width_axis] - kernel_width) / stride_width + 1);
156 
157  return tvm::sum(
158  tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
159  out_idx[width_axis] >= out_idx_lower_w),
160  mp_inds(out_idx) == idx),
161  out_grad(out_idx), make_const(x->dtype, 0)),
162  {windowh, windoww});
163  },
164  "T_pool_grad", "pool_grad_max");
165  } else if (pool_type == kAvgPool) {
166  auto windowh =
167  tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
168  auto windoww =
169  tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
170  return tvm::te::compute(
171  data_shape,
172  [&](const Array<Var>& inds) {
173  PrimExpr pad_h_idx = inds[height_axis] + pad_top;
174  PrimExpr pad_w_idx = inds[width_axis] + pad_left;
175 
176  // output indices whose pooling windows cover current input element (can be out-of-bound)
177  Array<PrimExpr> out_idx{inds.begin(), inds.end()};
178  out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
179  out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
180 
181  PrimExpr out_idx_lower_h =
182  tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
183  (pad_h_idx - kernel_height) / stride_height + 1);
184  PrimExpr out_idx_lower_w =
185  tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
186  (pad_w_idx - kernel_width) / stride_width + 1);
187 
188  PrimExpr divide_factor; // number of pooled elements
189  if (count_include_pad) {
190  divide_factor = kernel_height * kernel_width;
191  } else {
192  PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
193  PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
194 
195  PrimExpr h_end = min(h_start + kernel_height, height);
196  PrimExpr w_end = min(w_start + kernel_width, width);
197  h_start = max(h_start, make_const(DataType::Int(32), 0));
198  w_start = max(w_start, make_const(DataType::Int(32), 0));
199  divide_factor =
200  max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
201  }
202  return tvm::sum(
203  tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
204  out_idx[height_axis] < out_height),
205  tir::And(out_idx[width_axis] >= out_idx_lower_w,
206  out_idx[width_axis] < out_width)),
207  out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
208  {windowh, windoww});
209  },
210  "T_pool_grad", "pool_grad_avg");
211  } else {
212  LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
213  return Tensor();
214  }
215 }
216 
217 inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
218  int* width_axis) {
219  *depth_axis = -1;
220  *height_axis = -1;
221  *width_axis = -1;
222  int curr_idx = 0;
223  for (size_t i = 0; i < layout.size(); ++i) {
224  if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) {
225  if (layout[i] == 'D') {
226  if (*depth_axis != -1) return false;
227  *depth_axis = curr_idx;
228  } else if (layout[i] == 'H') {
229  if (*height_axis != -1) return false;
230  *height_axis = curr_idx;
231  } else if (layout[i] == 'W') {
232  if (*width_axis != -1) return false;
233  *width_axis = curr_idx;
234  } else if (layout[i] == 'd' || layout[i] == 'h' || layout[i] == 'w') {
235  // do not support split on height or width, e.g., NCHW16w
236  return false;
237  }
238  ++curr_idx;
239  }
240  }
241  if (*depth_axis == -1 || *height_axis == -1 || *width_axis == -1) return false;
242  return true;
243 }
244 
245 inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
246  int dummy;
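  // A 2-D pooling layout must not contain a depth axis, so the full
  // depth/height/width search below is expected to fail (return false) while
  // still filling *height_axis and *width_axis when 'H' and 'W' are present.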
247  ICHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);
248  if (*height_axis != -1 && *width_axis != -1) {
249  return true;
250  }
251  return false;
252 }
253 
254 inline bool find_width(const std::string& layout, int* width_axis) {
255  int dummy;
256  ICHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);
257  if (*width_axis != -1) {
258  return true;
259  }
260  return false;
261 }
262 
293 inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
294  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
295  PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
296  bool count_include_pad = true) {
297  int height_axis = -1, width_axis = -1;
298  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
299  return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
300  height_axis, width_axis, count_include_pad);
301 }
302 
303 inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
304  return indexdiv(out_index * idim, odim);
305 }
306 
307 inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
308  PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
309  return tvm::tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1);
310 }
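// Worked example for the adaptive-pooling window bounds above (illustrative
// values): with idim = 10 and odim = 3, the three output positions cover the
// input ranges [0, 4), [3, 7) and [6, 10), i.e. start_index/end_index split the
// input into roughly equal, possibly overlapping windows.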
311 
322 inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
323  PoolType pool_type, const std::vector<int>& axes) {
324  const auto n_dim = output_size.size();
325  ICHECK_EQ(axes.size(), n_dim) << "The number of axes does not match the in/out dimension";
326 
327  Array<PrimExpr> data_shape = x->shape;
328  for (size_t i = 0; i < data_shape.size(); ++i) {
329  data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
330  }
331  Array<PrimExpr> out_shape = data_shape;
332  Array<PrimExpr> in_size, out_size;
333  for (size_t i = 0; i < n_dim; ++i) {
334  in_size.push_back(data_shape[axes[i]]);
335  out_size.push_back(cast(DataType::Int(32), output_size[i]));
336  out_shape.Set(axes[i], out_size[i]);
337  }
338 
339  auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
340  Array<PrimExpr> indices;
341  for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
342  Array<tir::IterVar> reduce_axes;
343  for (size_t i = 0; i < n_dim; ++i) {
344  auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
345  auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
346  auto rv_name = "rv" + std::to_string(i);
347  auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
348  reduce_axes.push_back(rv_axis);
349  if (reduce_indices) {
350  indices.Set(axes[i], i_start + rv_axis);
351  }
352  }
353  return std::make_tuple(indices, reduce_axes);
354  };
355 
356  Map<String, ObjectRef> attrs;
357  if (pool_type == kMaxPool) {
358  attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_max"));
359  return tvm::te::compute(
360  out_shape,
361  [&](const Array<Var>& output) {
362  Array<PrimExpr> indices;
363  Array<tir::IterVar> reduce_axes;
364  std::tie(indices, reduce_axes) = get_iter_vars(output, true);
365  return tvm::max(x(indices), reduce_axes); // NOLINT(*)
366  },
367  "adaptive_pool_max", "adaptive_pool_max", attrs);
368  } else if (pool_type == kAvgPool) {
369  attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_avg"));
370  auto pool_sum = tvm::te::compute(
371  out_shape,
372  [&](const Array<Var>& output) {
373  Array<PrimExpr> indices;
374  Array<tir::IterVar> reduce_axes;
375  std::tie(indices, reduce_axes) = get_iter_vars(output, true);
376  return tvm::sum(x(indices), reduce_axes);
377  },
378  "adaptive_pool_sum", "adaptive_pool_sum");
379 
380  return tvm::te::compute(
381  out_shape,
382  [&](const Array<Var>& output) {
383  Array<PrimExpr> indices;
384  Array<tir::IterVar> reduce_axes;
385  std::tie(indices, reduce_axes) = get_iter_vars(output, false);
386 
387  PrimExpr divide_factor = tvm::cast(x->dtype, 1);
388  for (size_t i = 0; i < n_dim; ++i) {
389  divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
390  }
391 
392  return div(pool_sum(indices), divide_factor);
393  },
394  "adaptive_pool_avg", kElementWise, attrs);
395  } else {
396  LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
397  return x;
398  }
399 }
400 
427 inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size, PoolType pool_type,
428  const std::string& layout = "NCHW") {
429  int height_axis = -1, width_axis = -1;
430  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
431  return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
432 }
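// Illustrative usage sketch (not part of the original header; the tensor name
// `data` and the 7x7 target size are assumptions):
//
//   te::Tensor data = te::placeholder({1, 512, 13, 13}, DataType::Float(32), "data");
//   te::Tensor out = adaptive_pool(data, {7, 7}, kAvgPool, "NCHW");
//
// This pools each 13x13 feature map down to 7x7, with the per-output-pixel
// window bounds chosen by start_index/end_index above.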
433 
442 inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
443  PoolType pool_type, const std::string& layout = "NCDHW") {
444  int depth_axis = -1, height_axis = -1, width_axis = -1;
445  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
446  << "Unsupported layout " << layout;
447  return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
448 }
449 
458 inline Tensor adaptive_pool1d(const Tensor& x, const Array<PrimExpr>& output_size,
459  PoolType pool_type, const std::string& layout = "NCW") {
460  int width_axis = -1;
461  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
462  return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
463 }
464 
490 inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
491  return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
492 }
493 
510 inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
511  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
512  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
513  const std::vector<int>& axis, bool count_include_pad) {
514  int k_size = kernel_size.size();
515  int x_size = x->shape.size();
516  ICHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have the same number of elements as kernel";
517  ICHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must have twice as many elements as"
518  " kernel";
519  ICHECK_EQ(axis.size(), k_size) << "axis must have the same number of elements as kernel";
520 
521  Array<IterVar> daxis;
522  std::vector<PrimExpr> kernel(k_size);
523  std::vector<PrimExpr> stride(k_size);
524  std::vector<PrimExpr> dilation(k_size);
525  std::vector<PrimExpr> pad_head(k_size);
526  std::vector<PrimExpr> pad_tail(k_size);
527  std::vector<PrimExpr> offset(k_size, 0);
528  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
529  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
530  Array<PrimExpr> data_shape = x->shape;
531  for (size_t i = 0; i < data_shape.size(); ++i) {
532  data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
533  }
534  Array<PrimExpr> out_shape = data_shape;
535 
536  bool do_pad = false;
537  for (int i = 0; i < k_size; i++) {
538  int ii = axis[i];
539  kernel[i] = cast(DataType::Int(32), kernel_size[i]);
540  stride[i] = cast(DataType::Int(32), stride_size[i]);
541  dilation[i] = cast(DataType::Int(32), dilation_size[i]);
542  pad_head[i] = cast(DataType::Int(32), padding_size[i]);
543  pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);
544 
545  if (ceil_mode) {
546  // The offset[i] is an additional padding to ensure we do ceil instead of floor when
547  // dividing by stride.
548  // In the case of ceil_mode=True and count_include_pad=True,
549  // in order to obtain the correct boundary,
550  // we also need to use the offset[i] to eliminate this extra padding.
551  offset[i] = stride[i] - 1;
552  pad_tail[i] += offset[i];
553  }
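  // E.g. (illustrative) with stride[i] = 2, the extra offset of 1 turns the floor
  // division in the output-size computation below into a ceiling division; the same
  // offset is subtracted again when clamping end[i] in the count_include_pad branch
  // of the average pool, so that positions covering only padding are not counted.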
554 
555  const int64_t* padding0 = as_const_int(pad_head[i]);
556  const int64_t* padding1 = as_const_int(pad_tail[i]);
557  do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);
558 
559  daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));
560 
561  pad_before.Set(ii, pad_head[i]);
562  pad_after.Set(ii, pad_tail[i]);
563 
564  arith::Analyzer analyzer;
565 
566  PrimExpr numerator =
567  data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
568  auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
569  out_shape.Set(ii, out_dim);
570  }
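  // For reference, the extent computed in the loop above is the usual pooling
  // output size:
  //   out = floor((in + pad_head + pad_tail - dilation * (kernel - 1) - 1) / stride) + 1.
  // E.g. (illustrative numbers) in = 32, kernel = 3, dilation = 1, pad = 1 + 1,
  // stride = 2 gives floor(31 / 2) + 1 = 16.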
571 
572  Map<String, ObjectRef> attrs;
573  if (pool_type == kMaxPool) {
574  auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
575  attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_max"));
576  return tvm::te::compute(
577  out_shape,
578  [&](const Array<Var>& output) {
579  Array<PrimExpr> indices;
580  for (const Var& var : output) indices.push_back(var);
581 
582  for (int i = 0; i < k_size; i++) {
583  int ii = axis[i];
584  indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
585  }
586  return tvm::max(temp(indices), daxis);
587  },
588  "pool_max", "pool_max", attrs);
589  } else if (pool_type == kAvgPool) {
590  attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_avg"));
591  // Pad the inputs
592  auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
593 
594  // TVM compute for summing the pooling window.
595  auto pool_sum = tvm::te::compute(
596  out_shape,
597  [&](const Array<Var>& output) {
598  Array<PrimExpr> indices;
599  for (const Var& var : output) indices.push_back(var);
600 
601  for (int i = 0; i < k_size; i++) {
602  int ii = axis[i];
603  indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
604  }
605  return tvm::sum(temp(indices), daxis);
606  },
607  "pool_sum", "pool_sum");
608 
609  // TVM compute for dividing the reduced window sum by kernel size.
610  return tvm::te::compute(
611  out_shape,
612  [&](const Array<Var>& output) {
613  Array<PrimExpr> indices;
614  for (const Var& var : output) indices.push_back(var);
615  if (count_include_pad) {
616  std::vector<PrimExpr> start(k_size);
617  std::vector<PrimExpr> end(k_size);
618  auto num_el = make_const(DataType::Int(32), 1);
619  for (int i = 0; i < k_size; i++) {
620  int ii = axis[i];
621  start[i] = output[ii] * stride[i] - pad_head[i];
622  // When computing the output shape in ceil_mode,
623  // we have added the extra padding of offset[i],
624  // so now, in order to calculate the correct boundary,
625  // we need to subtract the offset[i].
626  end[i] = start[i] + (kernel[i] - 1) * dilation[i];
627  end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
628  num_el *= (end[i] - start[i]) / dilation[i] + 1;
629  }
630  return div(pool_sum(indices), num_el);
631  } else {
632  std::vector<PrimExpr> start(k_size);
633  std::vector<PrimExpr> end(k_size);
634  auto num_el = make_const(DataType::Int(32), 1);
635  for (int i = 0; i < k_size; i++) {
636  int ii = axis[i];
637 
638  // Let start and end contain the first and last index of our Tensor
639  // along the relevant dimension we use in our calculation.
640  // Assume indices -1, -2 represent the padding before the data (head) and
641  // len(arr), len(arr) + 1 represent the padding after it (tail).
642  start[i] = output[ii] * stride[i] - pad_head[i];
643  end[i] = start[i] + (kernel[i] - 1) * dilation[i];
644 
645  // If start[i] < 0, i.e. we start on a head-padded element, this will be the
646  // positive number of steps along the dilated kernel needed to reach the first
647  // non-padded value. Otherwise this should be 0.
648  PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
649  jumps_to_non_pad = max(jumps_to_non_pad, make_const(DataType::Int(32), 0));
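  // Illustrative numbers (not from the original source): stride = 1, dilation = 2,
  // kernel = 3, pad_head = 2 and output index 0 give start = -2 and end = 2, so the
  // kernel taps land at -2, 0, 2. Then jumps_to_non_pad = (2 - 1 - (-2)) / 2 = 1,
  // the first in-bounds tap is start + 2 * 1 = 0, and num_el counts
  // (2 - 0) / 2 + 1 = 2 valid elements (assuming the axis has at least 3 elements,
  // so end is not clamped).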
650 
651  end[i] = min(end[i], data_shape[ii] - 1);
652  num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
653  }
654 
655  PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1));
656  return div(pool_sum(indices), divide_factor);
657  }
658  },
659  "pool_avg", kElementWise, attrs);
660  } else {
661  LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
662  return x;
663  }
664 }
665 
696 inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
697  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
698  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
699  const std::string& layout = "NCW", bool count_include_pad = true) {
700  int width_axis = -1;
701  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
702  std::vector<int> axis = {width_axis};
703  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
704  ceil_mode, axis, count_include_pad);
705 }
706 
737 inline Tensor pool2d(const Tensor& x, const Array<PrimExpr>& kernel_size,
738  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
739  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
740  const std::string& layout = "NCHW", bool count_include_pad = true) {
741  int height_axis = -1, width_axis = -1;
742  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
743  std::vector<int> axis = {height_axis, width_axis};
744  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
745  ceil_mode, axis, count_include_pad);
746 }
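// Illustrative usage sketch (not part of the original header; the tensor name
// `data` and the parameter values are assumptions):
//
//   te::Tensor data = te::placeholder({1, 64, 56, 56}, DataType::Float(32), "data");
//   // 3x3 max pooling, stride 2, padding 1, no dilation -> a 1x64x28x28 result.
//   te::Tensor pooled = pool2d(data, {3, 3}, {2, 2}, {1, 1}, {1, 1, 1, 1},
//                              kMaxPool, /*ceil_mode=*/false, "NCHW");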
747 
779 inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
780  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
781  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
782  const std::string& layout = "NCDHW", bool count_include_pad = true) {
783  int depth_axis = -1, height_axis = -1, width_axis = -1;
784  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
785  << "Unsupported layout " << layout;
786  std::vector<int> axis = {depth_axis, height_axis, width_axis};
787  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
788  ceil_mode, axis, count_include_pad);
789 }
790 
791 } // namespace nn
792 } // namespace topi
793 } // namespace tvm
794 #endif // TVM_TOPI_NN_POOLING_H_