tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
nn.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_H_
25 #define TVM_TOPI_NN_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/te/operation.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/op.h>
32 #include <tvm/topi/reduction.h>
33 #include <tvm/topi/tags.h>
34 #include <tvm/topi/transform.h>
35 
36 #include <algorithm>
37 #include <string>
38 
39 namespace tvm {
40 namespace topi {
41 
42 using namespace tvm::te;
43 
54 template <typename T>
55 inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast<T>(0),
56  std::string name = "T_relu", std::string tag = kElementWise) {
57  return tvm::te::compute(
58  t->shape,
59  [&](const tvm::Array<tvm::tir::Var>& i) {
60  auto threshold_const = tvm::tir::make_const(t->dtype, threshold);
61  return tvm::max(t(i), threshold_const);
62  },
63  name, tag);
64 }
65 
76 inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1,
77  std::string name = "T_leaky_relu",
78  std::string tag = kElementWise) {
79  return tvm::te::compute(
80  t->shape,
81  [&](const tvm::Array<tvm::tir::Var>& i) {
82  auto value = t(i);
83  auto calpha = tvm::tir::make_const(value.dtype(), alpha);
84  return tvm::tir::Select(value > 0, value, value * calpha);
85  },
86  name, tag);
87 }
88 
100 inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope,
101  const int axis = 1, std::string name = "T_prelu",
102  std::string tag = kBroadcast) {
103  ICHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. ";
104  ICHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis]))
105  << "Wrong slope shape received.";
106 
107  return tvm::te::compute(
108  x->shape,
109  [&](const tvm::Array<tvm::tir::Var>& indices) {
110  auto xval = x(indices);
111  return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis]));
112  },
113  name, tag);
114 }
115 
157  PrimExpr pad_value = PrimExpr(), std::string name = "T_pad",
158  std::string tag = kElementWise, std::string pad_mode = "constant",
159  const Array<PrimExpr>* dyn_output_shape = nullptr) {
     // Body of the operation that performs padding.
     // NOTE(review): the first lines of the signature (declaring t, pad_before and pad_after)
     // are not part of this listing; the names are taken from their uses below.
     //
     // When fewer trailing pads than leading pads are supplied, the missing trailing entries
     // are filled in with the leading pad of the same axis.
160  if (pad_after.size() < pad_before.size()) {
161  for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
162  pad_after.push_back(pad_before[i]);
163  }
164  }
165 
     // The analyzer simplifies the shape and bound expressions built below.
166  arith::Analyzer analyzer;
167  ICHECK_GE(pad_before.size(), 1);
168  ICHECK_EQ(pad_before.size(), pad_after.size());
169  tvm::Array<tvm::PrimExpr> pad_before_int32;
170  tvm::Array<tvm::PrimExpr> pad_after_int32;
171 
     // Cast every pad amount to int32.
172  for (const auto& ele : pad_before) {
173  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
174  }
175  for (const auto& ele : pad_after) {
176  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
177  }
178 
     // Output extents: input extent + both pads for axes that have pad entries, the plain
     // input extent for the remaining axes. A caller-supplied dyn_output_shape replaces this
     // computation entirely and is copied verbatim.
179  tvm::Array<tvm::PrimExpr> output_shape;
180  if (dyn_output_shape == nullptr) {
181  for (size_t i = 0; i < t->shape.size(); ++i) {
182  if (i >= pad_before.size()) {
183  output_shape.push_back(t->shape[i]);
184  } else {
185  output_shape.push_back(
186  analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
187  }
188  }
189  } else {
190  for (size_t i = 0; i < dyn_output_shape->size(); i++) {
191  output_shape.push_back((*dyn_output_shape)[i]);
192  }
193  }
194 
     // An undefined pad_value defaults to zero in the input dtype.
195  if (!pad_value.defined()) {
196  pad_value = tvm::tir::make_const(t->dtype, 0);
197  }
198 
     // Per-element rule, evaluated on the output index variables.
199  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
     // NOTE(review): indices, sel and pad_idx are used below but their declarations are not
     // visible in this listing - presumably arrays of PrimExpr; confirm against the full file.
203  for (size_t i = 0; i < t->shape.size(); ++i) {
     // Axes without pad entries are read at the unchanged position. Because of the continue,
     // such axes never receive a pad_idx entry.
     // NOTE(review): confirm "edge"/"reflect" callers always supply a pad for every axis.
204  if (i >= pad_before_int32.size()) {
205  indices.push_back(ovars[i]);
206  continue;
207  }
     // Non-zero leading pad: require ovars[i] >= pad_before and shift the read index back.
208  if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
209  sel.push_back(ovars[i] >= pad_before_int32[i]);
210  indices.push_back(ovars[i] - pad_before_int32[i]);
211  } else {
212  indices.push_back(ovars[i]);
213  }
     // Non-zero trailing pad: require the index to stay below pad_before + input extent.
214  if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
215  sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
216  }
     // "edge": the fallback read index is clamped to the first / last input element.
217  if (pad_mode == "edge") {
218  pad_idx.push_back(
219  tvm::if_then_else(ovars[i] < pad_before[i], 0,
220  tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
221  t->shape[i] - 1, ovars[i] - pad_before[i])));
     // "reflect": the fallback read index is mirrored around the first / last input element.
222  } else if (pad_mode == "reflect") {
223  pad_idx.push_back(
224  tvm::if_then_else(ovars[i] < pad_before[i], pad_before[i] - ovars[i],
225  tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
226  t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
227  ovars[i] - pad_before[i])));
228  }
229  }
     // With at least one bound condition: when all conditions hold the input is read at
     // indices; otherwise the result is pad_value ("constant") or the input at pad_idx
     // ("edge" / "reflect"). Any other pad_mode falls through to the plain read below.
230  if (sel.size() != 0) {
231  if (pad_mode == "constant") {
232  return tvm::if_then_else(
233  foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
234  const_true(1), sel),
235  t(indices), pad_value);
236  } else if (pad_mode == "edge" || pad_mode == "reflect") {
237  return tvm::if_then_else(
238  foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
239  const_true(1), sel),
240  t(indices), t(pad_idx));
241  }
242  }
     // No bound conditions were collected (or pad_mode was not recognised): plain read.
243  return t(indices);
244  };
245  return tvm::te::compute(output_shape, l, name, tag);
246 }
247 
269  int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
270  std::string name = "T_conv2d_nchw",
271  std::string tag = kConv2dNCHW) {
     // Body of the 2-D convolution with an NCHW layout.
     // NOTE(review): the first line of the signature (declaring I and W) is not part of
     // this listing.
     //
     // Input and weight must both be 4-D.
272  ICHECK_EQ(4, I->shape.size());
273  ICHECK_EQ(4, W->shape.size());
     // NOTE(review): pH and pW are not referenced again in this function.
274  auto pH = I->shape[2];
275  auto pW = I->shape[3];
     // Spatial output extent: (input extent - kernel extent + 2 * pad) / stride + 1.
276  tvm::Array<tvm::PrimExpr> output_shape{
277  I->shape[0], // B
278  W->shape[0], // O
279  indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
280  indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
281  };
     // Reduction axes: i over I->shape[1], kh / kw over W->shape[2] / W->shape[3].
282  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
283  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
284  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
     // Pad only when a non-zero pad was requested. Only leading pads are passed here; pad()
     // reuses them as trailing pads when none are given.
285  auto T =
286  (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
     // Each output element sums input * weight over i, kh and kw.
287  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
288  return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw});
289  };
290  return tvm::te::compute(output_shape, l, name, tag);
291 }
292 
313  int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
314  std::string name = "T_conv2d_hwcn",
315  std::string tag = kConv2dHWCN) {
     // Body of the 2-D convolution with an HWCN layout.
     // NOTE(review): the first line of the signature (declaring I and W) is not part of
     // this listing.
     //
     // Input and weight must both be 4-D.
316  ICHECK_EQ(4, I->shape.size());
317  ICHECK_EQ(4, W->shape.size());
     // NOTE(review): pH and pW are not referenced again in this function.
318  auto pH = I->shape[2];
319  auto pW = I->shape[3];
     // NOTE(review): the inline comments label the output as (H, W, B, O), yet the H / W
     // extents are derived from I->shape[2] / I->shape[3] and W->shape[2] / W->shape[3],
     // B is I->shape[2], and the reduction axes below take i from I->shape[3] and kh / kw
     // from W->shape[0] / W->shape[1]. These index choices look mutually inconsistent for
     // an HWCN layout - confirm against the intended layout before relying on the shapes.
320  tvm::Array<tvm::PrimExpr> output_shape{
321  indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
322  indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W
323  I->shape[2], // B
324  W->shape[3] // O
325  };
326  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
327  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
328  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
     // Pad only when a non-zero pad was requested; pad entries are given for the first two
     // axes only, so pad() leaves the remaining axes unchanged.
329  auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w});
     // NOTE(review): the lambda parameters are named (b, o, h, w) but bind positionally to
     // the four output axes, which the comments above label (H, W, B, O) - confirm that the
     // index usage in the sum matches the intended axes.
330  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
331  return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw});
332  };
333  return tvm::te::compute(output_shape, l, name, tag);
334 }
335 
357  int pad_h = 0, int pad_w = 0, int stride_h = 1,
358  int stride_w = 1,
359  std::string name = "T_depthwise_conv2d_nchw",
360  std::string tag = kDepthwiseConv2dNCHW) {
     // Body of the 2-D depthwise convolution with an NCHW layout.
     // NOTE(review): the first line of the signature (declaring I and W) is not part of
     // this listing.
     //
     // Input and weight must both be 4-D.
361  ICHECK_EQ(4, I->shape.size());
362  ICHECK_EQ(4, W->shape.size());
     // NOTE(review): pH and pW are not referenced again in this function.
363  auto pH = I->shape[2];
364  auto pW = I->shape[3];
365  auto pCM = W->shape[1]; // channel_multiplier
     // Spatial output extent: (input extent - kernel extent + 2 * pad) / stride + 1.
     // NOTE(review): the output channel extent is W->shape[1], the same value used as the
     // channel multiplier, and the channel axis i is reduced over below - confirm this
     // matches the intended depthwise semantics.
366  tvm::Array<tvm::PrimExpr> output_shape{
367  I->shape[0], // B
368  W->shape[1], // O
369  indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
370  indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
371  };
     // Reduction axes: i over I->shape[1], kh / kw over W->shape[2] / W->shape[3].
372  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
373  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
374  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
     // Pad only when a non-zero pad was requested. Only leading pads are passed here; pad()
     // reuses them as trailing pads when none are given.
375  auto T =
376  (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
     // The input is read at channel i / pCM, the weight at (i / pCM, o % pCM).
377  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
378  return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
379  W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
380  {i, kh, kw});
381  };
382  return tvm::te::compute(output_shape, l, name, tag);
383 }
384 
386  int pad_h = 0, int pad_w = 0, int stride_h = 1,
387  int stride_w = 1,
388  std::string name = "T_depthwise_conv2d_nhwc",
389  std::string tag = kDepthwiseConv2dNHWC) {
     // Body of the 2-D depthwise convolution whose name and tag refer to an NHWC layout.
     // NOTE(review): the first line of the signature (declaring I and W) is not part of
     // this listing.
     //
     // Input and weight must both be 4-D.
390  ICHECK_EQ(4, I->shape.size());
391  ICHECK_EQ(4, W->shape.size());
     // NOTE(review): pH and pW are not referenced again in this function.
392  auto pH = I->shape[1];
393  auto pW = I->shape[2];
     // NOTE(review): W->shape[1] serves as the channel multiplier here, as the kernel extent
     // in the H output formula, and as the range of the kw reduction axis below, while the
     // W output formula uses W->shape[2] and kh ranges over W->shape[0]. These uses look
     // mutually inconsistent - confirm the intended weight layout.
394  auto pCM = W->shape[1]; // channel_multiplier
395  tvm::Array<tvm::PrimExpr> output_shape{
396  I->shape[0], // B
397  indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H
398  indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W
399  W->shape[3], // O
400  };
     // Reduction axes: i over I->shape[3], kh / kw over W->shape[0] / W->shape[1].
401  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
402  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
403  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
     // Pad only when a non-zero pad was requested. Only leading pads are passed here; pad()
     // reuses them as trailing pads when none are given.
404  auto T =
405  (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
     // The input is read at channel i / pCM, the weight at (kh, kw, i / pCM, o % pCM).
406  auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) {
407  return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
408  W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
409  {kh, kw, i});
410  };
411  return tvm::te::compute(output_shape, l, name, tag);
412 }
413 
435  int pad_h = 0, int pad_w = 0, int stride_h = 1,
436  int stride_w = 1,
437  std::string name = "T_group_conv2d_ngchw",
438  std::string tag = kGroupConv2d) {
     // Body of the 2-D group convolution with an NGCHW layout.
     // NOTE(review): the first line of the signature (declaring I and W) is not part of
     // this listing.
     //
     // Input and weight must both be 5-D.
439  ICHECK_EQ(5, I->shape.size());
440  ICHECK_EQ(5, W->shape.size());
     // NOTE(review): pH and pW are not referenced again in this function.
441  auto pH = I->shape[2];
442  auto pW = I->shape[3];
     // Spatial output extent: (input extent - kernel extent + 2 * pad) / stride + 1.
443  tvm::Array<tvm::PrimExpr> output_shape{
444  I->shape[0], // B
445  I->shape[1], // G
446  W->shape[2], // O
447  indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H
448  indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W
449  };
     // Reduction axes: i over I->shape[2], kh / kw over W->shape[3] / W->shape[4].
450  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[2]}, "i");
451  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kh");
452  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[4]}, "kw");
453 
     // Padded copy of the input, built only when a non-zero pad was requested.
     // NOTE(review): T is not referenced in the lambda below - the products read I directly,
     // so the padding computed here appears to be unused while output_shape still accounts
     // for it. Confirm whether the lambda was meant to index T.
454  auto T = (pad_h == 0 && pad_w == 0)
455  ? I
456  : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
     // The five output index variables arrive as one array: (b, g, o, h, w).
457  auto l = [&](tvm::Array<tvm::tir::Var> args) {
458  tvm::tir::Var b = args[0];
459  tvm::tir::Var g = args[1];
460  tvm::tir::Var o = args[2];
461  tvm::tir::Var h = args[3];
462  tvm::tir::Var w = args[4];
463  return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
464  {i, kh, kw});
465  };
466  return tvm::te::compute(output_shape, l, name, tag);
467 }
468 
483  const tvm::Array<Integer>& block_shape,
484  const tvm::Array<tvm::PrimExpr>& pad_before,
485  const tvm::Array<tvm::PrimExpr>& pad_after,
486  PrimExpr pad_value = PrimExpr(),
487  std::string name = "space_to_batch_nd",
488  std::string tag = kInjective) {
     // Divides the spatial dimensions of the input into a grid of blocks: the input is
     // padded, reshaped to r_shape, transposed by axis and reshaped to o_shape.
     // NOTE(review): the first line of the signature (declaring data) is not part of this
     // listing, and name / tag are not referenced in the body shown here.
489  tvm::te::Tensor padded_t;
490  CHECK_EQ(pad_before.size(), pad_after.size());
491  CHECK_EQ(block_shape.size(), pad_before.size())
492  << "Paddings must be provided for each spatial dimension";
493  tvm::Array<tvm::PrimExpr> pad_before_int32;
494  tvm::Array<tvm::PrimExpr> pad_after_int32;
495 
496  // pad size for batch dimension is 0
497  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
498  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
499  // insert pad sizes given for spatial dimensions
500  for (const auto& ele : pad_before) {
501  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
502  }
503  for (const auto& ele : pad_after) {
504  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
505  }
506 
507  // pad the input with paddings provided
     // An undefined pad_value defaults to zero in the input dtype.
508  if (!pad_value.defined()) {
509  pad_value = tvm::tir::make_const(data->dtype, 0);
510  }
511  padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value);
512 
513  auto input_shape = data->shape;
514  auto padded_shape = padded_t->shape;
515 
516  // infer shapes
     // r_shape: shape for the first reshape; axis: permutation for the transpose;
     // o_shape: final output shape.
517  tvm::Array<PrimExpr> r_shape;
518  tvm::Array<Integer> axis;
519  tvm::Array<PrimExpr> o_shape;
520 
521  size_t num_block_dims = block_shape.size();
     // The batch extent and the padded spatial extents are read through GetConstInt.
     // NOTE(review): presumably this requires constant extents - confirm.
522  int batch = static_cast<int>(GetConstInt(input_shape[0]));
523  tvm::PrimExpr block_shape_prod(1);
524  r_shape.push_back(batch);
525 
     // Split every padded spatial extent into (extent / block, block) and remember where
     // each block entry lands in r_shape.
526  for (size_t i = 1; i <= num_block_dims; i++) {
527  int padded_input = static_cast<int>(GetConstInt(padded_shape[i]));
528  int block_size = static_cast<int>(GetConstInt(block_shape[i - 1]));
529  CHECK_EQ((padded_input % block_size), 0)
530  << "(" << i
531  << ")th "
532  "Input dimension after padding ("
533  << padded_input << ")"
534  << " must be divisible by its block size (" << block_size << ")";
535 
536  r_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
537  r_shape.push_back(block_shape[i - 1]);
538  block_shape_prod *= block_shape[i - 1];
539  axis.push_back(Integer(r_shape.size() - 1)); // index of block_shape[i - 1]
540  }
541 
     // Permutation order: all block axes, then the batch axis, then the quotient axes
     // (each sits directly before its block axis in r_shape).
542  size_t n = axis.size();
543  axis.push_back(0); // batch is at index 0
544  // index of (padded_shape[i] / block_shape[i - 1]) in r_shape
545  for (size_t i = 0; i < n; i++) {
546  axis.push_back(static_cast<int>(GetConstInt(axis[i] - 1)));
547  }
     // Output shape: batch * product of block sizes, followed by the quotient extents.
548  o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
549  for (size_t i = 1; i <= num_block_dims; i++) {
550  o_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
551  }
552  // append remaining shape
553  for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
554  r_shape.push_back(input_shape[i]);
555  axis.push_back(Integer(r_shape.size() - 1)); // index of remaining shape in r_shape
556  o_shape.push_back(input_shape[i]);
557  }
558 
     // reshape -> transpose -> reshape
559  tvm::te::Tensor output = reshape(padded_t, r_shape);
560  output = transpose(output, axis);
561  output = reshape(output, o_shape);
562 
563  return output;
564 }
565 
579  const tvm::Array<Integer>& block_shape,
580  const tvm::Array<tvm::PrimExpr>& crop_begin_list,
581  const tvm::Array<tvm::PrimExpr>& crop_end_list,
582  std::string name = "batch_to_space_nd",
583  std::string tag = kInjective) {
     // Reshapes the batch dimension into spatial dimensions: the input is reshaped to
     // r_shape, transposed by axis, reshaped to r_p_shape and finally cropped.
     // NOTE(review): the first line of the signature (declaring data) is not part of this
     // listing, and name / tag are not referenced in the body shown here.
584  // Construct shapes for reshape and transpose operation
585  Array<PrimExpr> in_shape = data->shape;
586  Array<PrimExpr> r_shape;
587  Array<Integer> axis;
588  size_t num_block_dims = block_shape.size();
589  size_t num_input_dims = in_shape.size();
590  tvm::PrimExpr block_shape_prod(1);
     // The batch extent is read through GetConstInt.
     // NOTE(review): presumably this requires a constant extent - confirm.
591  int batch = static_cast<int>(GetConstInt(in_shape[0]));
592 
     // r_shape starts with the block sizes, followed by batch / product of block sizes.
593  for (size_t i = 0; i < num_block_dims; i++) {
594  r_shape.push_back(block_shape[i]);
595  block_shape_prod *= block_shape[i];
596  }
597  axis.push_back(Integer(r_shape.size())); // axis of (batch / block_shape_prod)
598  r_shape.push_back(batch / block_shape_prod);
599 
     // Append the remaining input extents to r_shape and interleave, in the permutation,
     // each input axis with a block axis.
     // NOTE(review): the guard compares against the final permutation length rather than
     // the number of block axes already emitted; verify the result for inputs with more
     // than one trailing non-block dimension.
600  for (size_t i = 1; i < num_input_dims; i++) {
601  axis.push_back(Integer(r_shape.size())); // axis of in_shape[i]
602  if (axis.size() < (num_block_dims + num_input_dims)) {
603  axis.push_back(Integer(r_shape.size() - (num_block_dims + 1))); // axis of block_shape[i]
604  }
605  r_shape.push_back(in_shape[i]);
606  }
607 
     // Shape after merging each spatial extent with its block size; the remaining extents
     // are copied unchanged.
608  Array<PrimExpr> r_p_shape;
609  r_p_shape.push_back(batch / block_shape_prod);
610  for (size_t i = 1; i <= num_block_dims; i++) {
611  r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
612  }
613  for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {
614  r_p_shape.push_back(in_shape[i]);
615  }
616 
     // reshape -> transpose -> reshape
617  tvm::te::Tensor out;
618  out = reshape(data, r_shape);
619  out = transpose(out, axis);
620  out = reshape(out, r_p_shape);
621 
622  // Crop the start and end of dimensions of out
623  Array<Integer> begin_idx, end_idx, strides;
624  for (size_t i = 0; i < r_p_shape.size(); ++i) {
625  strides.push_back(Integer(1));
626  if (i > 0 && i <= num_block_dims) {
627  // prepare begin and end index for spatial dimensions
628  int begin_i = static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
629  int end_i = static_cast<int>(GetConstInt(crop_end_list[i - 1]));
630  int out_i = static_cast<int>(GetConstInt(r_p_shape[i]));
     // The two crops together must be strictly smaller than the extent.
631  CHECK_GT(out_i, (begin_i + end_i))
632  << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than"
633  << " output size" << out_i << " vs " << (begin_i + end_i);
634  begin_idx.push_back(begin_i);
635  end_idx.push_back(out_i - end_i);
636  } else {
637  // ignore the batch and remaining dimension
638  begin_idx.push_back(Integer(0));
639  end_idx.push_back(static_cast<int>(GetConstInt(r_p_shape[i])));
640  }
641  }
642 
643  out = strided_slice(out, begin_idx, end_idx, strides);
644  return out;
645 }
646 
660 inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
661  std::string reduction = "mean", int ignore_index = -100,
662  const std::string name = "nll_loss", const std::string tag = kBroadcast) {
663  if (predictions.ndim() == 1) {
664  // corner case: no batch in shape
665  // prediction->shape = (C,), targets->shape = (), weights->shape = (C,)
666  auto T = tvm::te::compute(
667  {},
668  [&](const tvm::Array<tvm::tir::Var>& target_indices) {
669  auto c = targets();
670  return tvm::tir::Select(c != ignore_index, -predictions(c) * weights(c),
671  tvm::tir::make_const(predictions->dtype, 0));
672  },
673  name, tag);
674  if (reduction == "mean") {
675  auto W = tvm::te::compute(
676  {},
677  [&](const tvm::Array<tvm::tir::Var>& target_indices) {
678  auto c = targets();
679  return tvm::tir::Select(c != ignore_index, weights(c),
680  tvm::tir::make_const(predictions->dtype, 0));
681  },
682  name, tag);
683  return topi::divide(T, W);
684  } else {
685  return T;
686  }
687  }
688 
689  auto T = tvm::te::compute(
690  targets->shape,
691  [&](const tvm::Array<tvm::tir::Var>& target_indices) {
692  auto c = targets(target_indices);
693  tvm::Array<tvm::PrimExpr> pred_indices;
694  pred_indices.push_back(target_indices[0]); // batch index
695  pred_indices.push_back(c); // class index
696  for (size_t i = 1; i < target_indices.size(); i++) {
697  pred_indices.push_back(target_indices[i]); // indices for multidimensional loss
698  }
699  return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c),
700  tvm::tir::make_const(predictions->dtype, 0));
701  },
702  name, tag);
703  ICHECK(T->shape.size() != 0);
704  if (reduction == "mean") {
705  auto W = tvm::te::compute(
706  targets->shape,
707  [&](const tvm::Array<tvm::tir::Var>& target_indices) {
708  auto c = targets(target_indices);
709  return tvm::tir::Select(c != ignore_index, weights(c),
710  tvm::tir::make_const(predictions->dtype, 0));
711  },
712  name, tag);
713  return topi::divide(topi::sum(T, {}), topi::sum(W, {}));
714  } else if (reduction == "sum") {
715  return topi::sum(T, {});
716  } else { // reduction == "none"
717  return T;
718  }
719 }
720 
721 } // namespace topi
722 } // namespace tvm
723 #endif // TVM_TOPI_NN_H_
tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_hwcn", std::string tag=kConv2dHWCN)
Creates an operation for 2-D convolution layer with an HWCN-layout.
Definition: nn.h:312
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:239
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floor(a / b), where a and b are non-negative.
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:954
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Tensor expression language DSL.
Definition: extracted_task.h:33
a named variable in TIR
Definition: var.h:88
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
Algebra expression simplifications.
constexpr auto kGroupConv2d
Definition: tags.h:45
constexpr auto kInjective
Definition: tags.h:33
Reduction op constructors.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array< PrimExpr > &values, Span span=Span())
Left fold.
Definition: op.h:970
size_t ndim() const
Definition: tensor.h:214
Common operators defined for Expr.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
Tensor nll_loss(const Tensor &predictions, const Tensor &targets, const Tensor &weights, std::string reduction="mean", int ignore_index=-100, const std::string name="nll_loss", const std::string tag=kBroadcast)
Negative log likelihood loss.
Definition: nn.h:660
tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor &data, const tvm::Array< Integer > &block_shape, const tvm::Array< tvm::PrimExpr > &crop_begin_list, const tvm::Array< tvm::PrimExpr > &crop_end_list, std::string name="batch_to_space_nd", std::string tag=kInjective)
Reshape the batch dimension into spatial dimensions.
Definition: nn.h:578
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
Tensor sum(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:326
tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor &data, const tvm::Array< Integer > &block_shape, const tvm::Array< tvm::PrimExpr > &pad_before, const tvm::Array< tvm::PrimExpr > &pad_after, PrimExpr pad_value=PrimExpr(), std::string name="space_to_batch_nd", std::string tag=kInjective)
Divide spatial dimensions of the input into a grid of blocks.
Definition: nn.h:482
tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_nchw", std::string tag=kConv2dNCHW)
Creates an operation that performs a 2-D convolution with an NCHW-layout.
Definition: nn.h:268
Utility functions for handling constants in TVM expressions.
constexpr auto kDepthwiseConv2dNHWC
Definition: tags.h:41
constexpr auto kBroadcast
Definition: tags.h:36
Range container.
Definition: expr.h:715
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:785
Definition: source_map.h:120
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
size_t size() const
Definition: array.h:420
TIR expressions.
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
constexpr auto kElementWise
Definition: tags.h:32
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nchw", std::string tag=kDepthwiseConv2dNCHW)
Creates an operation that performs a 2-D depthwise convolution with an NCHW-layout.
Definition: nn.h:356
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1768
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
tvm::te::Tensor pad(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &pad_before, tvm::Array< tvm::PrimExpr > pad_after=tvm::Array< tvm::PrimExpr >(), PrimExpr pad_value=PrimExpr(), std::string name="T_pad", std::string tag=kElementWise, std::string pad_mode="constant", const Array< PrimExpr > *dyn_output_shape=nullptr)
Creates an operation that performs padding.
Definition: nn.h:155
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span=Span())
and
tvm::te::Tensor leaky_relu(const tvm::te::Tensor &t, double alpha=0.1, std::string name="T_leaky_relu", std::string tag=kElementWise)
Creates an operation that performs a leaky rectified linear unit.
Definition: nn.h:76
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
constexpr auto kConv2dNCHW
Definition: tags.h:38
Operation node can generate one or multiple Tensors.
constexpr auto kDepthwiseConv2dNCHW
Definition: tags.h:40
Managed reference to SelectNode.
Definition: expr.h:609
Tensor transpose(const Tensor &x, Array< Integer > axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:196
Transform op constructors.
tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nhwc", std::string tag=kDepthwiseConv2dNHWC)
Definition: nn.h:385
constexpr auto kConv2dHWCN
Definition: tags.h:39
External function interface to rocBLAS libraries.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Tensor reshape(const Tensor &x, Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:320
tvm::te::Tensor relu(const tvm::te::Tensor &t, T threshold=static_cast< T >(0), std::string name="T_relu", std::string tag=kElementWise)
Creates an operation that performs a rectified linear unit.
Definition: nn.h:55
Tensor strided_slice(const Tensor &x, const Array< Integer > &begin, const Array< Integer > &end, const Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:816
Reference to PrimExprNode.
Definition: expr.h:114
tvm::te::Tensor prelu(const tvm::te::Tensor &x, const tvm::te::Tensor &slope, const int axis=1, std::string name="T_prelu", std::string tag=kBroadcast)
Creates an operation that performs a parametric rectified linear unit.
Definition: nn.h:100
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:579
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:164
Container of constant int that adds more constructors.
Definition: expr.h:622
tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_group_conv2d_ngchw", std::string tag=kGroupConv2d)
Creates an operation that performs a 2-D group convolution with an NGCHW-layout.
Definition: nn.h:434