/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/topi/nn/pooling.h
 * \brief Pooling op constructors
 */
#ifndef TVM_TOPI_NN_POOLING_H_
#define TVM_TOPI_NN_POOLING_H_

#include <tvm/arith/analyzer.h>
#include <tvm/topi/detail/ravel_unravel.h>
#include <tvm/topi/nn.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>
#include <vector>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*! \brief Pooling type */
enum PoolType : int {
  kAvgPool,
  kMaxPool,
};

inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
                             const ffi::Array<PrimExpr>& kernel_size,
                             const ffi::Array<PrimExpr>& stride_size,
                             const ffi::Array<PrimExpr>& padding_size, PoolType pool_type,
                             bool ceil_mode, const size_t height_axis, const size_t width_axis,
                             bool count_include_pad) {
  ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must be at least 2-D (H, W)";
  ICHECK(x->shape.size() >= 2) << "Pooling input must be at least 2-D (H, W)";
  ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
  ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
  ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

  auto kernel_height = kernel_size[0];
  auto kernel_width = kernel_size[1];
  auto stride_height = stride_size[0];
  auto stride_width = stride_size[1];

  auto height = x->shape[height_axis];
  auto width = x->shape[width_axis];

  auto pad_top = padding_size[0];
  auto pad_left = padding_size[1];
  auto pad_bottom = padding_size[2];
  auto pad_right = padding_size[3];

  if (ceil_mode) {
    // Additional padding to ensure we do ceil instead of floor when
    // dividing by stride.
    pad_bottom += stride_height - 1;
    pad_right += stride_width - 1;
  }
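  // A worked example of the trick above (illustrative, not from the original
  // source): with height = 5, kernel = 2, stride = 2 and zero padding, floor
  // mode yields (5 - 2) / 2 + 1 = 2 output rows, whereas ceil mode should yield
  // ceil((5 - 2) / 2) + 1 = 3. Adding stride - 1 = 1 to pad_bottom turns the
  // floor division below into the required ceiling: (5 - 2 + 1) / 2 + 1 = 3.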

  ffi::Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_before.Set(height_axis, pad_top);
  pad_before.Set(width_axis, pad_left);

  ffi::Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_after.Set(height_axis, pad_bottom);
  pad_after.Set(width_axis, pad_right);
  arith::Analyzer analyzer;
  auto out_height =
      analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
  auto out_width =
      analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);

  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");

  ffi::Array<PrimExpr> data_shape = x->shape;
  ffi::Array<PrimExpr> out_shape = data_shape;
  out_shape.Set(height_axis, out_height);
  out_shape.Set(width_axis, out_width);

  const int64_t* padding_h0 = as_const_int(pad_top);
  const int64_t* padding_w0 = as_const_int(pad_left);
  const int64_t* padding_h1 = as_const_int(pad_bottom);
  const int64_t* padding_w1 = as_const_int(pad_right);
  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
                      ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

  if (pool_type == kMaxPool) {
    ffi::Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
    ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
    ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");

    auto argmax = MakeArgmaxReducer();
    auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;

    auto mp_argmax = tvm::te::compute(
        out_shape,
        [&](const ffi::Array<Var>& inds) {
          ffi::Array<PrimExpr> window_inds{inds.begin(), inds.end()};
          window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
          window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
          auto idx = detail::RavelIndex(window_inds, ravel_shape);
          return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
        },
        "maxpool_grad_argmax", kCommReduceIdx);

    auto mp_inds = mp_argmax[0];

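    // For every output element, mp_inds holds the raveled (flattened) index of
    // the input element that produced the max; the gradient pass below routes
    // out_grad back by comparing these raveled indices.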
    return tvm::te::compute(
        data_shape,
        [&](const ffi::Array<Var>& inds) {
          ffi::Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
          pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
          pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
          auto idx = detail::RavelIndex(pad_inds, ravel_shape);

          ffi::Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
          out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);

          PrimExpr out_idx_lower_h = tir::Select(
              pad_inds[height_axis] < kernel_height, make_const(pad_inds[height_axis].dtype(), 0),
              (pad_inds[height_axis] - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w = tir::Select(
              pad_inds[width_axis] < kernel_width, make_const(pad_inds[width_axis].dtype(), 0),
              (pad_inds[width_axis] - kernel_width) / stride_width + 1);

          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[width_axis] >= out_idx_lower_w),
                                         mp_inds(out_idx) == idx),
                                out_grad(out_idx), make_const(x->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_max");
  } else if (pool_type == kAvgPool) {
    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
    return tvm::te::compute(
        data_shape,
        [&](const ffi::Array<Var>& inds) {
          PrimExpr pad_h_idx = inds[height_axis] + pad_top;
          PrimExpr pad_w_idx = inds[width_axis] + pad_left;

          // Output indices whose pooling windows cover the current input element
          // (these may be out of bounds).
          ffi::Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
          out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));

          PrimExpr out_idx_lower_h =
              tir::Select(pad_h_idx < kernel_height, make_const(pad_h_idx.dtype(), 0),
                          (pad_h_idx - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_w_idx < kernel_width, make_const(pad_w_idx.dtype(), 0),
                          (pad_w_idx - kernel_width) / stride_width + 1);

          PrimExpr divide_factor;  // number of pooled elements
          if (count_include_pad) {
            divide_factor = kernel_height * kernel_width;
          } else {
            PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
            PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;

            PrimExpr h_end = min(h_start + kernel_height, height);
            PrimExpr w_end = min(w_start + kernel_width, width);
            h_start = max(h_start, make_const(h_start.dtype(), 0));
            w_start = max(w_start, make_const(w_start.dtype(), 0));
            divide_factor =
                max((h_end - h_start) * (w_end - w_start), make_const(h_end.dtype(), 1));
          }
          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[height_axis] < out_height),
                                         tir::And(out_idx[width_axis] >= out_idx_lower_w,
                                                  out_idx[width_axis] < out_width)),
                                out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_avg");
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return Tensor();
  }
}

/*!
 * \brief Find the index of the Depth, Height, or Width dimension in a layout string.
 */
inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
                                    int* width_axis) {
  if (depth_axis) *depth_axis = -1;
  if (height_axis) *height_axis = -1;
  if (width_axis) *width_axis = -1;
  int curr_idx = 0;
  for (size_t i = 0; i < layout.size(); ++i) {
    if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) {
      if (layout[i] == 'D' && depth_axis) {
        if (*depth_axis != -1) return false;
        *depth_axis = curr_idx;
      } else if (layout[i] == 'H' && height_axis) {
        if (*height_axis != -1) return false;
        *height_axis = curr_idx;
      } else if (layout[i] == 'W' && width_axis) {
        if (*width_axis != -1) return false;
        *width_axis = curr_idx;
      } else if (layout[i] == 'd' || layout[i] == 'h' || layout[i] == 'w') {
        // Splitting the depth, height, or width axis (e.g., NCHW16w) is not supported.
        return false;
      }
      ++curr_idx;
    }
  }
  if ((depth_axis && *depth_axis == -1) || (height_axis && *height_axis == -1) ||
      (width_axis && *width_axis == -1))
    return false;
  return true;
}

inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
  return find_depth_height_width(layout, /*depth_axis=*/nullptr, height_axis, width_axis);
}

inline bool find_width(const std::string& layout, int* width_axis) {
  return find_depth_height_width(layout, /*depth_axis=*/nullptr, /*height_axis=*/nullptr,
                                 width_axis);
}
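
// A minimal usage sketch for the three layout helpers above (illustrative, not
// part of the original header):
//
//   int d = -1, h = -1, w = -1;
//   find_depth_height_width("NCDHW", &d, &h, &w);  // true: d == 2, h == 3, w == 4
//   find_height_width("NHWC", &h, &w);             // true: h == 1, w == 2
//   find_width("NWC", &w);                         // true: w == 1
//   find_height_width("NCHW16w", &h, &w);          // false: split 'w' axis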

/*!
 * \brief Calculate the gradient of pooling on the height and width dimensions of data.
 *        The height and width axes are determined by the layout string.
 */
inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x,
                        const ffi::Array<PrimExpr>& kernel_size,
                        const ffi::Array<PrimExpr>& stride_size,
                        const ffi::Array<PrimExpr>& padding_size, PoolType pool_type,
                        bool ceil_mode, const std::string& layout = "NCHW",
                        bool count_include_pad = true) {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
                        height_axis, width_axis, count_include_pad);
}
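
// Example (illustrative sketch, not from the original file): the gradient of a
// 2x2 max pool with stride 2 over an NCHW tensor; the placeholder shapes here
// are arbitrary.
//
//   te::Tensor x = te::placeholder({1, 3, 32, 32}, DataType::Float(32), "x");
//   te::Tensor dy = te::placeholder({1, 3, 16, 16}, DataType::Float(32), "dy");
//   te::Tensor dx = pool_grad(dy, x, {2, 2}, {2, 2}, {0, 0, 0, 0}, kMaxPool,
//                             /*ceil_mode=*/false, "NCHW");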

inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  return indexdiv(out_index * idim, odim);
}

inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
  return tvm::tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1);
}
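
// A worked example of the two index helpers above (illustrative): for idim = 10
// input elements and odim = 3 output elements, start_index/end_index produce the
// windows [0, 4), [3, 7), [6, 10) for out_index = 0, 1, 2, so every input
// element is covered and adjacent windows differ in size by at most one.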

/*!
 * \brief Perform adaptive pooling on N-dimensional data.
 */
inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array<PrimExpr>& output_size,
                                 PoolType pool_type, const std::vector<int>& axes) {
  const auto n_dim = output_size.size();
  ICHECK_EQ(axes.size(), n_dim) << "The number of axes must equal the number of output dimensions";

  ffi::Array<PrimExpr> data_shape = x->shape;
  ffi::Array<PrimExpr> out_shape = data_shape;
  ffi::Array<PrimExpr> in_size, out_size;
  for (size_t i = 0; i < n_dim; ++i) {
    in_size.push_back(data_shape[axes[i]]);
    out_size.push_back(output_size[i]);
    out_shape.Set(axes[i], out_size[i]);
  }

  auto get_iter_vars = [=](const ffi::Array<Var>& output, bool reduce_indices) {
    ffi::Array<PrimExpr> indices;
    for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
    ffi::Array<tir::IterVar> reduce_axes;
    for (size_t i = 0; i < n_dim; ++i) {
      auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
      auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
      auto rv_name = "rv" + std::to_string(i);
      auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
      reduce_axes.push_back(rv_axis);
      if (reduce_indices) {
        indices.Set(axes[i], i_start + rv_axis);
      }
    }
    return std::make_tuple(indices, reduce_axes);
  };

  ffi::Map<ffi::String, ffi::Any> attrs;
  if (pool_type == kMaxPool) {
    attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.adaptive_pool_max"));
    return tvm::te::compute(
        out_shape,
        [&](const ffi::Array<Var>& output) {
          ffi::Array<PrimExpr> indices;
          ffi::Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::max(x(indices), reduce_axes);  // NOLINT(*)
        },
        "adaptive_pool_max", "adaptive_pool_max", attrs);
  } else if (pool_type == kAvgPool) {
    attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.adaptive_pool_avg"));
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const ffi::Array<Var>& output) {
          ffi::Array<PrimExpr> indices;
          ffi::Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::sum(x(indices), reduce_axes);
        },
        "adaptive_pool_sum", "adaptive_pool_sum");

    return tvm::te::compute(
        out_shape,
        [&](const ffi::Array<Var>& output) {
          ffi::Array<PrimExpr> indices;
          ffi::Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, false);

          PrimExpr divide_factor = tvm::cast(x->dtype, 1);
          for (size_t i = 0; i < n_dim; ++i) {
            divide_factor *= tvm::cast(DataType::Int(32), reduce_axes[i]->dom->extent);
          }

          return div(pool_sum(indices), divide_factor);
        },
        "adaptive_pool_avg", kElementWise, attrs);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}

/*!
 * \brief Adaptively perform pooling on the height and width dimensions of data.
 */
inline Tensor adaptive_pool(const Tensor& x, const ffi::Array<PrimExpr>& output_size,
                            PoolType pool_type, const std::string& layout = "NCHW") {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
}
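
// Example (illustrative): adaptive average pooling of an NCHW tensor down to a
// fixed 7x7 spatial size, independent of the input resolution.
//
//   te::Tensor x = te::placeholder({1, 512, 13, 13}, DataType::Float(32), "x");
//   te::Tensor y = adaptive_pool(x, {7, 7}, kAvgPool, "NCHW");  // (1, 512, 7, 7)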

/*!
 * \brief Adaptively perform pooling on three-dimensional data.
 *        See the two-dimensional version above for details.
 */
inline Tensor adaptive_pool3d(const Tensor& x, const ffi::Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCDHW") {
  int depth_axis = -1, height_axis = -1, width_axis = -1;
  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
      << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
}

/*!
 * \brief Adaptively perform pooling on one-dimensional data.
 *        See the two-dimensional version above for details.
 */
inline Tensor adaptive_pool1d(const Tensor& x, const ffi::Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCW") {
  int width_axis = -1;
  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
}

/*!
 * \brief Perform global pooling on the height and width dimensions of data.
 */
inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
  return adaptive_pool(x, ffi::Array<PrimExpr>{1, 1}, pool_type, layout);
}

/*!
 * \brief Perform pooling on the N dimensions of data given by \p axis.
 */
inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array<PrimExpr>& kernel_size,
                           const ffi::Array<PrimExpr>& stride_size,
                           const ffi::Array<PrimExpr>& dilation_size,
                           const ffi::Array<PrimExpr>& padding_size, PoolType pool_type,
                           bool ceil_mode, const std::vector<int>& axis, bool count_include_pad) {
  int k_size = kernel_size.size();
  int x_size = x->shape.size();
  ICHECK_EQ(stride_size.size(), k_size)
      << "Pooling stride_size must have the same number of elements as the kernel";
  ICHECK_EQ(padding_size.size(), k_size * 2)
      << "Pooling padding_size must have twice as many elements as the kernel";
  ICHECK_EQ(axis.size(), k_size) << "axis must have the same number of elements as the kernel";

  ffi::Array<IterVar> daxis;
  std::vector<PrimExpr> kernel(k_size);
  std::vector<PrimExpr> stride(k_size);
  std::vector<PrimExpr> dilation(k_size);
  std::vector<PrimExpr> pad_head(k_size);
  std::vector<PrimExpr> pad_tail(k_size);
  std::vector<PrimExpr> offset(k_size, 0);
  ffi::Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
  ffi::Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
  ffi::Array<PrimExpr> data_shape = x->shape;
  ffi::Array<PrimExpr> out_shape = data_shape;

  bool do_pad = false;
  for (int i = 0; i < k_size; i++) {
    int ii = axis[i];
    kernel[i] = kernel_size[i];
    stride[i] = stride_size[i];
    dilation[i] = dilation_size[i];
    pad_head[i] = padding_size[i];
    pad_tail[i] = padding_size[i + k_size];

    if (ceil_mode) {
      // offset[i] is additional padding that ensures we do ceil instead of floor
      // when dividing by stride. When ceil_mode=true and count_include_pad=true,
      // offset[i] is also needed later to eliminate this extra padding so that
      // the correct window boundary is obtained.
      offset[i] = stride[i] - 1;
      pad_tail[i] += offset[i];
    }

    const int64_t* padding0 = as_const_int(pad_head[i]);
    const int64_t* padding1 = as_const_int(pad_tail[i]);
    do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);

    daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));

    pad_before.Set(ii, pad_head[i]);
    pad_after.Set(ii, pad_tail[i]);

    arith::Analyzer analyzer;

    PrimExpr numerator =
        data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
    auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
    out_shape.Set(ii, out_dim);
  }

  ffi::Map<ffi::String, ffi::Any> attrs;
  if (pool_type == kMaxPool) {
    auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
    attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.pool_max"));
    return tvm::te::compute(
        out_shape,
        [&](const ffi::Array<Var>& output) {
          ffi::Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);

          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
          }
          return tvm::max(temp(indices), daxis);
        },
        "pool_max", "pool_max", attrs);
  } else if (pool_type == kAvgPool) {
    attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.pool_avg"));
    // Pad the inputs.
    auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;

    // TVM compute for summing the pooling window.
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const ffi::Array<Var>& output) {
          ffi::Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);

          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
          }
          return tvm::sum(temp(indices), daxis);
        },
        "pool_sum", "pool_sum");

    // TVM compute for dividing the reduced window sum by kernel size.
    return tvm::te::compute(
        out_shape,
        [&](const ffi::Array<Var>& output) {
          ffi::Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);
          if (count_include_pad) {
            std::vector<PrimExpr> start(k_size);
            std::vector<PrimExpr> end(k_size);
            auto num_el = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              int ii = axis[i];
              start[i] = output[ii] * stride[i] - pad_head[i];
              // When computing the output shape in ceil_mode we added the extra
              // padding of offset[i], so to recover the correct boundary here we
              // need to subtract offset[i] again.
              end[i] = start[i] + (kernel[i] - 1) * dilation[i];
              end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
              num_el *= (end[i] - start[i]) / dilation[i] + 1;
            }
            return div(pool_sum(indices), num_el);
          } else {
            std::vector<PrimExpr> start(k_size);
            std::vector<PrimExpr> end(k_size);
            auto num_el = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              int ii = axis[i];

              // Let start and end contain the first and last index of our Tensor
              // along the relevant dimension used in our calculation.
              // Assume indices -1, -2 represent the padding before (head) and
              // len(arr), len(arr) + 1 represent the padding after (tail).
              start[i] = output[ii] * stride[i] - pad_head[i];
              end[i] = start[i] + (kernel[i] - 1) * dilation[i];

              // If start[i] < 0, i.e. we start on a head-padded element, this is
              // the (positive) number of steps along the dilated kernel needed to
              // reach the first non-padded value; otherwise it is 0.
              PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
              jumps_to_non_pad = max(jumps_to_non_pad, make_const(jumps_to_non_pad.dtype(), 0));
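              // Numeric illustration: with dilation = 2 and start = -3, the
              // dilated taps fall at -3, -1, 1, 3, ...; jumps_to_non_pad =
              // (2 - 1 - (-3)) / 2 = 2, and start + 2 * dilation = 1 is indeed
              // the first in-bounds tap.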

              end[i] = min(end[i], data_shape[ii] - 1);
              num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
            }

            PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1));
            return div(pool_sum(indices), divide_factor);
          }
        },
        "pool_avg", kElementWise, attrs);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}

/*!
 * \brief Perform pooling on the width dimension of data.
 *        The width axis is determined by the layout string.
 */
inline Tensor pool1d(const Tensor& x, const ffi::Array<PrimExpr>& kernel_size,
                     const ffi::Array<PrimExpr>& stride_size,
                     const ffi::Array<PrimExpr>& dilation_size,
                     const ffi::Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCW", bool count_include_pad = true) {
  int width_axis = -1;
  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
  std::vector<int> axis = {width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}

/*!
 * \brief Perform pooling on the height and width dimensions of data.
 *        The height and width axes are determined by the layout string.
 */
inline Tensor pool2d(const Tensor& x, const ffi::Array<PrimExpr>& kernel_size,
                     const ffi::Array<PrimExpr>& stride_size,
                     const ffi::Array<PrimExpr>& dilation_size,
                     const ffi::Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCHW", bool count_include_pad = true) {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  std::vector<int> axis = {height_axis, width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}
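
// Example (illustrative): a 3x3 average pool with stride 1, dilation 1 and
// symmetric padding 1, which preserves the spatial shape of an NCHW tensor
// since (56 - 3 + 1 + 1) / 1 + 1 = 56.
//
//   te::Tensor x = te::placeholder({1, 64, 56, 56}, DataType::Float(32), "x");
//   te::Tensor y = pool2d(x, {3, 3}, {1, 1}, {1, 1}, {1, 1, 1, 1}, kAvgPool,
//                         /*ceil_mode=*/false, "NCHW", /*count_include_pad=*/false);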

/*!
 * \brief Perform pooling on the depth, height and width dimensions of data.
 *        The depth, height and width axes are determined by the layout string.
 */
inline Tensor pool3d(const Tensor& x, const ffi::Array<PrimExpr>& kernel_size,
                     const ffi::Array<PrimExpr>& stride_size,
                     const ffi::Array<PrimExpr>& dilation_size,
                     const ffi::Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCDHW", bool count_include_pad = true) {
  int depth_axis = -1, height_axis = -1, width_axis = -1;
  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
      << "Unsupported layout " << layout;
  std::vector<int> axis = {depth_axis, height_axis, width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}

}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_POOLING_H_