tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
pooling.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_POOLING_H_
25 #define TVM_TOPI_NN_POOLING_H_
26 
#include <tvm/arith/analyzer.h>
#include <tvm/topi/detail/pad_utils.h>
#include <tvm/topi/nn.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>
#include <vector>
36 
37 namespace tvm {
38 namespace topi {
39 namespace nn {
40 
41 using namespace tvm::te;
42 
/*! \brief Pooling type */
enum PoolType : int {
  kAvgPool,  // average pooling: window sum divided by element count
  kMaxPool,  // max pooling: maximum element in the window
};
48 
49 inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
50  const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
51  const Array<PrimExpr>& padding_size, PoolType pool_type,
52  bool ceil_mode, const size_t height_axis, const size_t width_axis,
53  bool count_include_pad) {
54  ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)";
55  ICHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
56  ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
57  ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
58  ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
59 
60  auto kernel_height = cast(DataType::DataType::Int(32), kernel_size[0]);
61  auto kernel_width = cast(DataType::DataType::Int(32), kernel_size[1]);
62  auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
63  auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);
64 
65  auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]);
66  auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]);
67 
68  auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
69  auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);
70  auto pad_bottom = cast(DataType::DataType::Int(32), padding_size[2]);
71  auto pad_right = cast(DataType::DataType::Int(32), padding_size[3]);
72 
73  if (ceil_mode) {
74  // Additional padding to ensure we do ceil instead of floor when
75  // dividing by stride.
76  pad_bottom += stride_height - 1;
77  pad_right += stride_width - 1;
78  }
79 
80  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
81  pad_before.Set(height_axis, pad_top);
82  pad_before.Set(width_axis, pad_left);
83 
84  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
85  pad_after.Set(height_axis, pad_bottom);
86  pad_after.Set(width_axis, pad_right);
87  arith::Analyzer analyzer;
88  auto out_height =
89  analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
90  auto out_width =
91  analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
92 
93  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
94  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");
95 
96  Array<PrimExpr> data_shape = x->shape;
97  for (size_t i = 0; i < data_shape.size(); ++i) {
98  data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
99  }
100 
101  Array<PrimExpr> out_shape = data_shape;
102  out_shape.Set(height_axis, out_height);
103  out_shape.Set(width_axis, out_width);
104 
105  const int64_t* padding_h0 = as_const_int(pad_top);
106  const int64_t* padding_w0 = as_const_int(pad_left);
107  const int64_t* padding_h1 = as_const_int(pad_bottom);
108  const int64_t* padding_w1 = as_const_int(pad_right);
109  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
110  ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
111 
112  if (pool_type == kMaxPool) {
113  Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
114  ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
115  ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
116 
117  auto windowh =
118  tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
119  auto windoww =
120  tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
121 
122  auto argmax = MakeArgmaxReducer();
123  auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
124 
125  auto mp_argmax = tvm::te::compute(
126  out_shape,
127  [&](const Array<Var>& inds) {
128  Array<PrimExpr> window_inds{inds.begin(), inds.end()};
129  window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
130  window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
131  auto idx = detail::RavelIndex(window_inds, ravel_shape);
132  return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
133  },
134  "maxpool_grad_argmax", kCommReduceIdx);
135 
136  auto mp_inds = mp_argmax[0];
137 
138  return tvm::te::compute(
139  data_shape,
140  [&](const Array<Var>& inds) {
141  Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
142  pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
143  pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
144  auto idx = detail::RavelIndex(pad_inds, ravel_shape);
145 
146  Array<PrimExpr> out_idx{inds.begin(), inds.end()};
147  out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
148  out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
149 
150  PrimExpr out_idx_lower_h = tir::Select(
151  pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0),
152  (pad_inds[height_axis] - kernel_height) / stride_height + 1);
153  PrimExpr out_idx_lower_w = tir::Select(
154  pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0),
155  (pad_inds[width_axis] - kernel_width) / stride_width + 1);
156 
157  return tvm::sum(
158  tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
159  out_idx[width_axis] >= out_idx_lower_w),
160  mp_inds(out_idx) == idx),
161  out_grad(out_idx), make_const(x->dtype, 0)),
162  {windowh, windoww});
163  },
164  "T_pool_grad", "pool_grad_max");
165  } else if (pool_type == kAvgPool) {
166  auto windowh =
167  tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
168  auto windoww =
169  tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
170  return tvm::te::compute(
171  data_shape,
172  [&](const Array<Var>& inds) {
173  PrimExpr pad_h_idx = inds[height_axis] + pad_top;
174  PrimExpr pad_w_idx = inds[width_axis] + pad_left;
175 
176  // output indices whose pooling windows cover current input element (can be out-of-bound)
177  Array<PrimExpr> out_idx{inds.begin(), inds.end()};
178  out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
179  out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
180 
181  PrimExpr out_idx_lower_h =
182  tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
183  (pad_h_idx - kernel_height) / stride_height + 1);
184  PrimExpr out_idx_lower_w =
185  tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
186  (pad_w_idx - kernel_width) / stride_width + 1);
187 
188  PrimExpr divide_factor; // number of pooled elements
189  if (count_include_pad) {
190  divide_factor = kernel_height * kernel_width;
191  } else {
192  PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
193  PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
194 
195  PrimExpr h_end = min(h_start + kernel_height, height);
196  PrimExpr w_end = min(w_start + kernel_width, width);
197  h_start = max(h_start, make_const(DataType::Int(32), 0));
198  w_start = max(w_start, make_const(DataType::Int(32), 0));
199  divide_factor =
200  max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
201  }
202  return tvm::sum(
203  tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
204  out_idx[height_axis] < out_height),
205  tir::And(out_idx[width_axis] >= out_idx_lower_w,
206  out_idx[width_axis] < out_width)),
207  out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
208  {windowh, windoww});
209  },
210  "T_pool_grad", "pool_grad_avg");
211  } else {
212  LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
213  return Tensor();
214  }
215 }
216 
/*!
 * \brief Find the indices of the Depth, Height and Width dimensions in a layout string.
 *
 * Only axes that are requested (non-null out-pointer) are searched for. Split
 * spatial axes (lowercase 'd', 'h', 'w', e.g. NCHW16w) are not supported.
 *
 * \param layout The layout string, e.g. "NCHW", "NCDHW16c".
 * \param depth_axis Output: index of 'D', or nullptr if depth is not expected.
 * \param height_axis Output: index of 'H', or nullptr if height is not expected.
 * \param width_axis Output: index of 'W', or nullptr if width is not expected.
 *
 * \return true if every requested axis occurs exactly once and no spatial axis
 *         is split; false otherwise.
 */
inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
                                    int* width_axis) {
  if (depth_axis) *depth_axis = -1;
  if (height_axis) *height_axis = -1;
  if (width_axis) *width_axis = -1;
  int curr_idx = 0;
  for (size_t i = 0; i < layout.size(); ++i) {
    // Only alphabetic characters denote axes; digits (split factors) are skipped.
    if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) {
      if (layout[i] == 'D' && depth_axis) {
        if (*depth_axis != -1) return false;  // duplicate 'D'
        *depth_axis = curr_idx;
      } else if (layout[i] == 'H' && height_axis) {
        if (*height_axis != -1) return false;  // duplicate 'H'
        *height_axis = curr_idx;
      } else if (layout[i] == 'W' && width_axis) {
        if (*width_axis != -1) return false;  // duplicate 'W'
        *width_axis = curr_idx;
      } else if (layout[i] == 'd' || layout[i] == 'h' || layout[i] == 'w') {
        // do not support split on height, width or depth, e.g., NCHW16w
        return false;
      }
      ++curr_idx;
    }
  }
  // Every requested axis must have been found.
  if ((depth_axis && *depth_axis == -1) || (height_axis && *height_axis == -1) ||
      (width_axis && *width_axis == -1))
    return false;
  return true;
}
257 
258 inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
259  return find_depth_height_width(layout, /*depth_axis=*/nullptr, height_axis, width_axis);
260 }
261 
262 inline bool find_width(const std::string& layout, int* width_axis) {
263  return find_depth_height_width(layout, /*depth_axis=*/nullptr, /*height_axis=*/nullptr,
264  width_axis);
265 }
266 
297 inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
298  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
299  PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
300  bool count_include_pad = true) {
301  int height_axis = -1, width_axis = -1;
302  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
303  return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
304  height_axis, width_axis, count_include_pad);
305 }
306 
307 inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
308  return indexdiv(out_index * idim, odim);
309 }
310 
311 inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
312  PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
313  return tvm::tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1);
314 }
315 
326 inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
327  PoolType pool_type, const std::vector<int>& axes) {
328  const auto n_dim = output_size.size();
329  ICHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension";
330 
331  Array<PrimExpr> data_shape = x->shape;
332  for (size_t i = 0; i < data_shape.size(); ++i) {
333  data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
334  }
335  Array<PrimExpr> out_shape = data_shape;
336  Array<PrimExpr> in_size, out_size;
337  for (size_t i = 0; i < n_dim; ++i) {
338  in_size.push_back(data_shape[axes[i]]);
339  out_size.push_back(cast(DataType::Int(32), output_size[i]));
340  out_shape.Set(axes[i], out_size[i]);
341  }
342 
343  auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
344  Array<PrimExpr> indices;
345  for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
346  Array<tir::IterVar> reduce_axes;
347  for (size_t i = 0; i < n_dim; ++i) {
348  auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
349  auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
350  auto rv_name = "rv" + std::to_string(i);
351  auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
352  reduce_axes.push_back(rv_axis);
353  if (reduce_indices) {
354  indices.Set(axes[i], i_start + rv_axis);
355  }
356  }
357  return std::make_tuple(indices, reduce_axes);
358  };
359 
361  if (pool_type == kMaxPool) {
362  attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_max"));
363  return tvm::te::compute(
364  out_shape,
365  [&](const Array<Var>& output) {
366  Array<PrimExpr> indices;
367  Array<tir::IterVar> reduce_axes;
368  std::tie(indices, reduce_axes) = get_iter_vars(output, true);
369  return tvm::max(x(indices), reduce_axes); // NOLINT(*)
370  },
371  "adaptive_pool_max", "adaptive_pool_max", attrs);
372  } else if (pool_type == kAvgPool) {
373  attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_avg"));
374  auto pool_sum = tvm::te::compute(
375  out_shape,
376  [&](const Array<Var>& output) {
377  Array<PrimExpr> indices;
378  Array<tir::IterVar> reduce_axes;
379  std::tie(indices, reduce_axes) = get_iter_vars(output, true);
380  return tvm::sum(x(indices), reduce_axes);
381  },
382  "adaptive_pool_sum", "adaptive_pool_sum");
383 
384  return tvm::te::compute(
385  out_shape,
386  [&](const Array<Var>& output) {
387  Array<PrimExpr> indices;
388  Array<tir::IterVar> reduce_axes;
389  std::tie(indices, reduce_axes) = get_iter_vars(output, false);
390 
391  PrimExpr divide_factor = tvm::cast(x->dtype, 1);
392  for (size_t i = 0; i < n_dim; ++i) {
393  divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
394  }
395 
396  return div(pool_sum(indices), divide_factor);
397  },
398  "adaptive_pool_avg", kElementWise, attrs);
399  } else {
400  LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
401  return x;
402  }
403 }
404 
431 inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size, PoolType pool_type,
432  const std::string& layout = "NCHW") {
433  int height_axis = -1, width_axis = -1;
434  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
435  return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
436 }
437 
446 inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
447  PoolType pool_type, const std::string& layout = "NCDHW") {
448  int depth_axis = -1, height_axis = -1, width_axis = -1;
449  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
450  << "Unsupported layout " << layout;
451  return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
452 }
453 
462 inline Tensor adaptive_pool1d(const Tensor& x, const Array<PrimExpr>& output_size,
463  PoolType pool_type, const std::string& layout = "NCW") {
464  int width_axis = -1;
465  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
466  return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
467 }
468 
494 inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
495  return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
496 }
497 
514 inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
515  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
516  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
517  const std::vector<int>& axis, bool count_include_pad) {
518  int k_size = kernel_size.size();
519  int x_size = x->shape.size();
520  ICHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel";
521  ICHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must has double elements of"
522  " kernel";
523  ICHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel";
524 
525  Array<IterVar> daxis;
526  std::vector<PrimExpr> kernel(k_size);
527  std::vector<PrimExpr> stride(k_size);
528  std::vector<PrimExpr> dilation(k_size);
529  std::vector<PrimExpr> pad_head(k_size);
530  std::vector<PrimExpr> pad_tail(k_size);
531  std::vector<PrimExpr> offset(k_size, 0);
532  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
533  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
534  Array<PrimExpr> data_shape = x->shape;
535  for (size_t i = 0; i < data_shape.size(); ++i) {
536  data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
537  }
538  Array<PrimExpr> out_shape = data_shape;
539 
540  bool do_pad = false;
541  for (int i = 0; i < k_size; i++) {
542  int ii = axis[i];
543  kernel[i] = cast(DataType::Int(32), kernel_size[i]);
544  stride[i] = cast(DataType::Int(32), stride_size[i]);
545  dilation[i] = cast(DataType::Int(32), dilation_size[i]);
546  pad_head[i] = cast(DataType::Int(32), padding_size[i]);
547  pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);
548 
549  if (ceil_mode) {
550  // The offset[i] is an additional padding to ensure we do ceil instead of floor when
551  // dividing by stride.
552  // In the case of ceil_mode=True and count_include_pad=True,
553  // in order to obtain the correct boundary,
554  // we also need to use the offset[i] to eliminate this extra padding.
555  offset[i] = stride[i] - 1;
556  pad_tail[i] += offset[i];
557  }
558 
559  const int64_t* padding0 = as_const_int(pad_head[i]);
560  const int64_t* padding1 = as_const_int(pad_tail[i]);
561  do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);
562 
563  daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));
564 
565  pad_before.Set(ii, pad_head[i]);
566  pad_after.Set(ii, pad_tail[i]);
567 
568  arith::Analyzer analyzer;
569 
570  PrimExpr numerator =
571  data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
572  auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
573  out_shape.Set(ii, out_dim);
574  }
575 
577  if (pool_type == kMaxPool) {
578  auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
579  attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_max"));
580  return tvm::te::compute(
581  out_shape,
582  [&](const Array<Var>& output) {
583  Array<PrimExpr> indices;
584  for (const Var& var : output) indices.push_back(var);
585 
586  for (int i = 0; i < k_size; i++) {
587  int ii = axis[i];
588  indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
589  }
590  return tvm::max(temp(indices), daxis);
591  },
592  "pool_max", "pool_max", attrs);
593  } else if (pool_type == kAvgPool) {
594  attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_avg"));
595  // Pad the inputs
596  auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
597 
598  // TVM compute for summing the pooling window.
599  auto pool_sum = tvm::te::compute(
600  out_shape,
601  [&](const Array<Var>& output) {
602  Array<PrimExpr> indices;
603  for (const Var& var : output) indices.push_back(var);
604 
605  for (int i = 0; i < k_size; i++) {
606  int ii = axis[i];
607  indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
608  }
609  return tvm::sum(temp(indices), daxis);
610  },
611  "pool_sum", "pool_sum");
612 
613  // TVM compute for dividing the reduced window sum by kernel size.
614  return tvm::te::compute(
615  out_shape,
616  [&](const Array<Var>& output) {
617  Array<PrimExpr> indices;
618  for (const Var& var : output) indices.push_back(var);
619  if (count_include_pad) {
620  std::vector<PrimExpr> start(k_size);
621  std::vector<PrimExpr> end(k_size);
622  auto num_el = make_const(DataType::Int(32), 1);
623  for (int i = 0; i < k_size; i++) {
624  int ii = axis[i];
625  start[i] = output[ii] * stride[i] - pad_head[i];
626  // When computing the output shape in ceil_mode,
627  // we have added the extra padding of offset[i],
628  // so now in order to calculate the correct boundary ,
629  // we need to substract the offset[i].
630  end[i] = start[i] + (kernel[i] - 1) * dilation[i];
631  end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
632  num_el *= (end[i] - start[i]) / dilation[i] + 1;
633  }
634  return div(pool_sum(indices), num_el);
635  } else {
636  std::vector<PrimExpr> start(k_size);
637  std::vector<PrimExpr> end(k_size);
638  auto num_el = make_const(DataType::Int(32), 1);
639  for (int i = 0; i < k_size; i++) {
640  int ii = axis[i];
641 
642  // Let start and end contain the first and last index of our Tensor
643  // along the relevant dimension we use in our calculation.
644  // Assume indices -1, -2 represent the padding before (tail) and
645  // len(arr), len(arr) + 1 represent the padding after (head).
646  start[i] = output[ii] * stride[i] - pad_head[i];
647  end[i] = start[i] + (kernel[i] - 1) * dilation[i];
648 
649  // if start[i] < 0, e.g. we start on a tail padded number this will be a positive
650  // number that represents the number of steps along the dilated kernel to reach a
651  // non-padded value. Otherwise this should be 0.
652  PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
653  jumps_to_non_pad = max(jumps_to_non_pad, make_const(DataType::Int(32), 0));
654 
655  end[i] = min(end[i], data_shape[ii] - 1);
656  num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
657  }
658 
659  PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1));
660  return div(pool_sum(indices), divide_factor);
661  }
662  },
663  "pool_avg", kElementWise, attrs);
664  } else {
665  LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
666  return x;
667  }
668 }
669 
700 inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
701  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
702  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
703  const std::string& layout = "NCW", bool count_include_pad = true) {
704  int width_axis = -1;
705  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
706  std::vector<int> axis = {width_axis};
707  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
708  ceil_mode, axis, count_include_pad);
709 }
710 
741 inline Tensor pool2d(const Tensor& x, const Array<PrimExpr>& kernel_size,
742  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
743  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
744  const std::string& layout = "NCHW", bool count_include_pad = true) {
745  int height_axis = -1, width_axis = -1;
746  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
747  std::vector<int> axis = {height_axis, width_axis};
748  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
749  ceil_mode, axis, count_include_pad);
750 }
751 
783 inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
784  const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
785  const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
786  const std::string& layout = "NCDHW", bool count_include_pad = true) {
787  int depth_axis = -1, height_axis = -1, width_axis = -1;
788  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
789  << "Unsupported layout " << layout;
790  std::vector<int> axis = {depth_axis, height_axis, width_axis};
791  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
792  ceil_mode, axis, count_include_pad);
793 }
794 
795 } // namespace nn
796 } // namespace topi
797 } // namespace tvm
798 #endif // TVM_TOPI_NN_POOLING_H_
Tensor max(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the maximum of elements over a given axis.
Definition: reduction.h:433
constexpr auto kCommReduceIdx
Definition: tags.h:35
Tensor pool1d(const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCW", bool count_include_pad=true)
Perform pooling on the width dimension of data. Width axis is determined by the layout string in whic...
Definition: pooling.h:700
PrimExpr min_value(const DataType &dtype, Span span=Span())
FCommReduce MakeArgmaxReducer(bool select_last_index=false)
Definition: reduction.h:499
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder floor(a / b) where a and b are non-negative.
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:954
Tensor min(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the minimum of elements over a given axis.
Definition: reduction.h:414
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Tensor adaptive_pool_impl(const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::vector< int > &axes)
Perform adaptive pooling on N dimensional data.
Definition: pooling.h:326
Tensor expression language DSL.
Definition: extracted_task.h:33
a named variable in TIR
Definition: var.h:88
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
Algebra expression simplifications.
Tensor pool3d(const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCDHW", bool count_include_pad=true)
Perform pooling on depth, height and width dimension of data. It decides the depth, height and width dimension according to the layout string, in which 'D', 'W' and 'H' mean depth, width and height respectively. Depth, width and height dimensions cannot be split. For example, NCDHW, NCDHW16c, etc. are valid for pool, while NCDHW16d, NCDHW16w or NCDHW16h are not. See layout for more information of the layout string convention.
Definition: pooling.h:783
Reduction op constructors.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
PoolType
Pooling type.
Definition: pooling.h:44
const int64_t * as_const_int(const PrimExpr &x)
Get x as constant int expression.
Definition: op.h:803
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
void Set(int64_t i, T value)
set i-th element of the array.
Definition: array.h:621
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
Range container.
Definition: expr.h:715
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
size_t size() const
Definition: array.h:420
Definition: pooling.h:45
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Tensor pool_impl_nd(const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::vector< int > &axis, bool count_include_pad)
Perform pooling on N-dimension of data.
Definition: pooling.h:514
Padding helpers.
constexpr auto kElementWise
Definition: tags.h:32
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
Reference to string objects.
Definition: string.h:98
Tensor global_pool(const Tensor &x, PoolType pool_type, const std::string &layout="NCHW")
Perform global pooling on height and width dimension of data. It decides the height and width dimensi...
Definition: pooling.h:494
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
tvm::te::Tensor pad(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &pad_before, tvm::Array< tvm::PrimExpr > pad_after=tvm::Array< tvm::PrimExpr >(), PrimExpr pad_value=PrimExpr(), std::string name="T_pad", std::string tag=kElementWise, std::string pad_mode="constant", const Array< PrimExpr > *dyn_output_shape=nullptr)
Creates an operation that performs padding.
Definition: nn.h:155
Tensor argmax(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false, bool select_last_index=false)
Creates an operation that finds the indices of the maximum values over a given axis.
Definition: reduction.h:553
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
iterator end() const
Definition: array.h:390
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
iterator begin() const
Definition: array.h:387
PrimExpr end_index(const Var &out_index, const PrimExpr &odim, const PrimExpr &idim)
Definition: pooling.h:311
Tensor pool2d(const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCHW", bool count_include_pad=true)
Perform pooling on height and width dimension of data. It decides the height and width dimension acco...
Definition: pooling.h:741
PrimExpr start_index(const Var &out_index, const PrimExpr &odim, const PrimExpr &idim)
Definition: pooling.h:307
Tensor adaptive_pool3d(const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCDHW")
Adaptively perform pooling on three dimensional data. See the two dimensional version above for detai...
Definition: pooling.h:446
Tensor adaptive_pool1d(const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCW")
Adaptively perform pooling on one dimensional data. See the two dimensional version above for details...
Definition: pooling.h:462
Tensor adaptive_pool(const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCHW")
Adaptively perform pooling on height and width dimension of data. The pooling kernel and stride sizes...
Definition: pooling.h:431
Managed reference to SelectNode.
Definition: expr.h:609
Tensor pool_grad(const Tensor &out_grad, const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCHW", bool count_include_pad=true)
Calculate gradient of pooling on height and width dimension of data. It decides the height and width ...
Definition: pooling.h:297
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
Managed reference to AndNode.
Definition: expr.h:482
Tensor pool_grad_impl(const Tensor &out_grad, const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad)
Definition: pooling.h:49
External function interface to rocBLAS libraries.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type...
Definition: elemwise.h:281
Reference to PrimExprNode.
Definition: expr.h:114
void Set(const K &key, const V &value)
set the Map.
Definition: map.h:1374
Definition: pooling.h:46
bool find_depth_height_width(const std::string &layout, int *depth_axis, int *height_axis, int *width_axis)
Find index of Depth, Height or Width dimension in a layout string.
Definition: pooling.h:228
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:579
bool find_height_width(const std::string &layout, int *height_axis, int *width_axis)
Definition: pooling.h:258
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:164
NN op constructions.
bool find_width(const std::string &layout, int *width_axis)
Definition: pooling.h:262