tvm / topi / nn.h
Documentation view of this file (NN operator constructors for the TVM tensor-expression DSL).
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_H_
25 #define TVM_TOPI_NN_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/te/operation.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/op.h>
32 #include <tvm/topi/reduction.h>
33 #include <tvm/topi/tags.h>
34 #include <tvm/topi/transform.h>
35 
36 #include <algorithm>
37 #include <string>
38 
39 namespace tvm {
40 namespace topi {
41 
42 using namespace tvm::te;
43 
54 template <typename T>
55 inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast<T>(0),
56  std::string name = "T_relu", std::string tag = kElementWise) {
57  return tvm::te::compute(
58  t->shape,
59  [&](const tvm::Array<tvm::tir::Var>& i) {
60  auto threshold_const = tvm::tir::make_const(t->dtype, threshold);
61  return tvm::max(t(i), threshold_const);
62  },
63  name, tag);
64 }
65 
76 inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1,
77  std::string name = "T_leaky_relu",
78  std::string tag = kElementWise) {
79  return tvm::te::compute(
80  t->shape,
81  [&](const tvm::Array<tvm::tir::Var>& i) {
82  auto value = t(i);
83  auto calpha = tvm::tir::make_const(value.dtype(), alpha);
84  return tvm::tir::Select(value > 0, value, value * calpha);
85  },
86  name, tag);
87 }
88 
100 inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope,
101  const int axis = 1, std::string name = "T_prelu",
102  std::string tag = kBroadcast) {
103  ICHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. ";
104  ICHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis]))
105  << "Wrong slope shape received.";
106 
107  return tvm::te::compute(
108  x->shape,
109  [&](const tvm::Array<tvm::tir::Var>& indices) {
110  auto xval = x(indices);
111  return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis]));
112  },
113  name, tag);
114 }
115 
// NOTE(review): the opening of pad()'s signature (the tensor `t` and the
// pad_before/pad_after arrays) is not visible in this documentation view; the
// remaining parameters continue below.  pad() surrounds the leading
// pad_before.size() dimensions of `t` with pad_value (mode "constant"), or with
// clamped/reflected source values (modes "edge"/"reflect").
                           PrimExpr pad_value = PrimExpr(), std::string name = "T_pad",
                           std::string tag = kElementWise, std::string pad_mode = "constant",
                           const Array<PrimExpr>* dyn_output_shape = nullptr) {
  // If fewer trailing pads were supplied than leading pads, mirror the leading ones.
  if (pad_after.size() < pad_before.size()) {
    for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
      pad_after.push_back(pad_before[i]);
    }
  }

  arith::Analyzer analyzer;
  ICHECK_GE(pad_before.size(), 1);
  ICHECK_EQ(pad_before.size(), pad_after.size());
  // Normalize every pad width to int32 so the index arithmetic below uses one dtype.
  tvm::Array<tvm::PrimExpr> pad_before_int32;
  tvm::Array<tvm::PrimExpr> pad_after_int32;

  for (const auto& ele : pad_before) {
    pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }
  for (const auto& ele : pad_after) {
    pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }

  // Output shape: taken verbatim when supplied (dynamic case), otherwise
  // dim + pad_before + pad_after for each padded dimension.
  tvm::Array<tvm::PrimExpr> output_shape;
  if (dyn_output_shape == nullptr) {
    for (size_t i = 0; i < t->shape.size(); ++i) {
      if (i >= pad_before.size()) {
        output_shape.push_back(t->shape[i]);
      } else {
        output_shape.push_back(
            analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
      }
    }
  } else {
    for (size_t i = 0; i < dyn_output_shape->size(); i++) {
      output_shape.push_back((*dyn_output_shape)[i]);
    }
  }

  // Default fill value is zero of the input dtype.
  if (!pad_value.defined()) {
    pad_value = tvm::tir::make_const(t->dtype, 0);
  }

  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
    // NOTE(review): the declarations of `indices`, `sel` and `pad_idx` were elided
    // by the documentation extractor.  From their uses below: `indices` maps output
    // coordinates back into `t`, `sel` accumulates the in-bounds conditions, and
    // `pad_idx` holds clamped/reflected coordinates for the edge/reflect modes.
    for (size_t i = 0; i < t->shape.size(); ++i) {
      // Dimensions beyond the pad arrays are passed through untouched.
      if (i >= pad_before_int32.size()) {
        indices.push_back(ovars[i]);
        continue;
      }
      // Shift the coordinate by the leading pad and record the bound check.
      if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
        sel.push_back(ovars[i] >= pad_before_int32[i]);
        indices.push_back(ovars[i] - pad_before_int32[i]);
      } else {
        indices.push_back(ovars[i]);
      }
      if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
        sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
      }
      if (pad_mode == "edge") {
        // Clamp out-of-range coordinates to the nearest valid border element.
        pad_idx.push_back(
            tvm::if_then_else(ovars[i] < pad_before[i], 0,
                              tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
                                                t->shape[i] - 1, ovars[i] - pad_before[i])));
      } else if (pad_mode == "reflect") {
        // Mirror out-of-range coordinates back into range (border not repeated).
        pad_idx.push_back(
            tvm::if_then_else(ovars[i] < pad_before[i], pad_before[i] - ovars[i],
                              tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
                                                t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
                                                ovars[i] - pad_before[i])));
      }
    }
    if (sel.size() != 0) {
      // AND all bound checks together; inside-bounds reads t, outside reads the
      // fill value (constant mode) or the clamped/reflected coordinate.
      if (pad_mode == "constant") {
        return tvm::if_then_else(
            foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
                  const_true(1), sel),
            t(indices), pad_value);
      } else if (pad_mode == "edge" || pad_mode == "reflect") {
        return tvm::if_then_else(
            foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
                  const_true(1), sel),
            t(indices), t(pad_idx));
      }
    }
    return t(indices);
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
247 
// NOTE(review): the opening line of conv2d_nchw()'s signature (input tensor I and
// weight tensor W) is not visible in this documentation view.  2-D convolution,
// NCHW data layout, OIHW weight layout.
                                int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
                                std::string name = "T_conv2d_nchw",
                                std::string tag = kConv2dNCHW) {
  // Both operands must be 4-D.
  ICHECK_EQ(4, I->shape.size());
  ICHECK_EQ(4, W->shape.size());
  auto pH = I->shape[2];  // NOTE(review): computed but never used below.
  auto pW = I->shape[3];  // NOTE(review): computed but never used below.
  // Standard conv output extent: (in - kernel + 2*pad) / stride + 1 per spatial dim.
  tvm::Array<tvm::PrimExpr> output_shape{
      I->shape[0],                                                    // B
      W->shape[0],                                                    // O
      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
      indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1   // W
  };
  // Reduction axes: input channel and the two kernel extents.
  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
  // Zero-pad the two spatial dimensions only when padding is requested.
  auto T =
      (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
    return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
292 
// NOTE(review): the opening line of conv2d_hwcn()'s signature (input tensor I and
// weight tensor W) is not visible in this documentation view.  2-D convolution,
// HWCN data layout, HWIO weight layout.
                                int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
                                std::string name = "T_conv2d_hwcn",
                                std::string tag = kConv2dHWCN) {
  // Both operands must be 4-D.
  ICHECK_EQ(4, I->shape.size());
  ICHECK_EQ(4, W->shape.size());
  auto pH = I->shape[2];  // NOTE(review): computed but never used below.
  auto pW = I->shape[3];  // NOTE(review): computed but never used below.
  tvm::Array<tvm::PrimExpr> output_shape{
      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
      indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1,  // W
      I->shape[2],                                                    // B
      W->shape[3]                                                     // O
  };
  // NOTE(review): in an HWCN layout the spatial extents are I->shape[0]/[1] and the
  // batch is I->shape[3], yet the shape above reads I->shape[2]/[3] (also labelled
  // "B" on shape[2]).  This looks suspect — confirm against callers and the Python
  // topi implementation before relying on the inferred output shape.
  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
  // Pad only the two leading (H, W) dimensions.
  auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w});
  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
    return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
335 
// NOTE(review): the opening line of depthwise_conv2d_nchw()'s signature (input
// tensor I and weight tensor W) is not visible in this documentation view.
// 2-D depthwise convolution, NCHW data layout.
                                          int pad_h = 0, int pad_w = 0, int stride_h = 1,
                                          int stride_w = 1,
                                          std::string name = "T_depthwise_conv2d_nchw",
                                          std::string tag = kDepthwiseConv2dNCHW) {
  // Both operands must be 4-D.
  ICHECK_EQ(4, I->shape.size());
  ICHECK_EQ(4, W->shape.size());
  auto pH = I->shape[2];   // NOTE(review): computed but never used below.
  auto pW = I->shape[3];   // NOTE(review): computed but never used below.
  auto pCM = W->shape[1];  // channel_multiplier
  tvm::Array<tvm::PrimExpr> output_shape{
      I->shape[0],                                                    // B
      W->shape[1],                                                    // O
      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
      indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1   // W
  };
  // Reduction axes: i runs over I's channel dimension, kh/kw over the kernel.
  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
  // Zero-pad the two spatial dimensions only when padding is requested.
  auto T =
      (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
    // Each output channel o maps to input channel i/pCM with multiplier slot o%pCM.
    // NOTE(review): the reduction divides i by pCM on both reads — verify the
    // intended channel mapping against the Python topi depthwise implementation.
    return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
                        W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
                    {i, kh, kw});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
384 
// NOTE(review): the opening line of depthwise_conv2d_nhwc()'s signature (input
// tensor I and weight tensor W) is not visible in this documentation view.
// 2-D depthwise convolution, NHWC data layout.
                                          int pad_h = 0, int pad_w = 0, int stride_h = 1,
                                          int stride_w = 1,
                                          std::string name = "T_depthwise_conv2d_nhwc",
                                          std::string tag = kDepthwiseConv2dNHWC) {
  // Both operands must be 4-D.
  ICHECK_EQ(4, I->shape.size());
  ICHECK_EQ(4, W->shape.size());
  auto pH = I->shape[1];   // NOTE(review): computed but never used below.
  auto pW = I->shape[2];   // NOTE(review): computed but never used below.
  auto pCM = W->shape[1];  // channel_multiplier
  tvm::Array<tvm::PrimExpr> output_shape{
      I->shape[0],                                                    // B
      indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1,  // H
      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1,  // W
      W->shape[3],                                                    // O
  };
  // NOTE(review): the kernel extents below are W->shape[0]/[1], while the output
  // H/W above subtract W->shape[1]/[2] — confirm the intended weight layout.
  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
  // Zero-pad the two interior (H, W) dimensions only when padding is requested.
  auto T =
      (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
  auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) {
    return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
                        W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
                    {kh, kw, i});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
413 
435  int pad_h = 0, int pad_w = 0, int stride_h = 1,
436  int stride_w = 1,
437  std::string name = "T_group_conv2d_ngchw",
438  std::string tag = kGroupConv2d) {
439  ICHECK_EQ(5, I->shape.size());
440  ICHECK_EQ(5, W->shape.size());
441  auto pH = I->shape[2];
442  auto pW = I->shape[3];
443  tvm::Array<tvm::PrimExpr> output_shape{
444  I->shape[0], // B
445  I->shape[1], // G
446  W->shape[2], // O
447  indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H
448  indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W
449  };
450  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[2]}, "i");
451  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kh");
452  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[4]}, "kw");
453 
454  auto T = (pad_h == 0 && pad_w == 0)
455  ? I
456  : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
457  auto l = [&](tvm::Array<tvm::tir::Var> args) {
458  tvm::tir::Var b = args[0];
459  tvm::tir::Var g = args[1];
460  tvm::tir::Var o = args[2];
461  tvm::tir::Var h = args[3];
462  tvm::tir::Var w = args[4];
463  return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
464  {i, kh, kw});
465  };
466  return tvm::te::compute(output_shape, l, name, tag);
467 }
468 
// NOTE(review): the opening line of space_to_batch_nd()'s signature (the input
// tensor `data`) is not visible in this documentation view.  The op pads the
// spatial dimensions, splits each into (dim/block, block), and folds the block
// factors into the batch dimension via reshape -> transpose -> reshape.
                                        const tvm::Array<Integer>& block_shape,
                                        const tvm::Array<tvm::PrimExpr>& pad_before,
                                        const tvm::Array<tvm::PrimExpr>& pad_after,
                                        PrimExpr pad_value = PrimExpr(),
                                        std::string name = "space_to_batch_nd",
                                        std::string tag = kInjective) {
  tvm::te::Tensor padded_t;
  // One pad pair and one block factor per spatial dimension.
  CHECK_EQ(pad_before.size(), pad_after.size());
  CHECK_EQ(block_shape.size(), pad_before.size())
      << "Paddings must be provided for each spatial dimension";
  tvm::Array<tvm::PrimExpr> pad_before_int32;
  tvm::Array<tvm::PrimExpr> pad_after_int32;

  // pad size for batch dimension is 0
  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  // insert pad sizes given for spatial dimensions
  for (const auto& ele : pad_before) {
    pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }
  for (const auto& ele : pad_after) {
    pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }

  // pad the input with paddings provided (zero of data's dtype by default)
  if (!pad_value.defined()) {
    pad_value = tvm::tir::make_const(data->dtype, 0);
  }
  padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value);

  auto input_shape = data->shape;
  auto padded_shape = padded_t->shape;

  // infer shapes:
  //   r_shape - intermediate reshape that splits each spatial dim into
  //             (padded_dim / block, block)
  //   axis    - transpose permutation that moves the block factors up front
  //   o_shape - final output shape with batch multiplied by prod(block_shape)
  tvm::Array<PrimExpr> r_shape;
  tvm::Array<Integer> axis;
  tvm::Array<PrimExpr> o_shape;

  size_t num_block_dims = block_shape.size();
  int batch = static_cast<int>(GetConstInt(input_shape[0]));
  tvm::PrimExpr block_shape_prod(1);
  r_shape.push_back(batch);

  for (size_t i = 1; i <= num_block_dims; i++) {
    int padded_input = static_cast<int>(GetConstInt(padded_shape[i]));
    int block_size = static_cast<int>(GetConstInt(block_shape[i - 1]));
    // Each padded spatial extent must split evenly into blocks.
    CHECK_EQ((padded_input % block_size), 0)
        << "(" << i
        << ")th "
           "Input dimension after padding ("
        << padded_input << ")"
        << " must be divisible by its block size (" << block_size << ")";

    r_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
    r_shape.push_back(block_shape[i - 1]);
    block_shape_prod *= block_shape[i - 1];
    axis.push_back(Integer(r_shape.size() - 1));  // index of block_shape[i - 1]
  }

  size_t n = axis.size();
  axis.push_back(0);  // batch is at index 0
  // index of (padded_shape[i] / block_shape[i - 1]) in r_shape
  for (size_t i = 0; i < n; i++) {
    axis.push_back(static_cast<int>(GetConstInt(axis[i] - 1)));
  }
  o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    o_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
  }
  // append remaining shape (dimensions past the spatial block dims pass through)
  for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
    r_shape.push_back(input_shape[i]);
    axis.push_back(Integer(r_shape.size() - 1));  // index of remaining shape in r_shape
    o_shape.push_back(input_shape[i]);
  }

  // split -> permute block factors into batch -> merge
  tvm::te::Tensor output = reshape(padded_t, r_shape);
  output = transpose(output, axis);
  output = reshape(output, o_shape);

  return output;
}
565 
// NOTE(review): the opening line of batch_to_space_nd()'s signature (the input
// tensor `data`) is not visible in this documentation view.  Inverse of
// space_to_batch_nd: unfolds block factors from the batch dimension back into
// the spatial dimensions, then crops each spatial dimension.
                                        const tvm::Array<Integer>& block_shape,
                                        const tvm::Array<tvm::PrimExpr>& crop_begin_list,
                                        const tvm::Array<tvm::PrimExpr>& crop_end_list,
                                        std::string name = "batch_to_space_nd",
                                        std::string tag = kInjective) {
  // Construct shapes for reshape and transpose operation
  Array<PrimExpr> in_shape = data->shape;
  Array<PrimExpr> r_shape;
  Array<Integer> axis;
  size_t num_block_dims = block_shape.size();
  size_t num_input_dims = in_shape.size();
  tvm::PrimExpr block_shape_prod(1);
  int batch = static_cast<int>(GetConstInt(in_shape[0]));

  // r_shape starts with the block factors, then (batch / prod(block)), then the
  // remaining input dimensions.
  for (size_t i = 0; i < num_block_dims; i++) {
    r_shape.push_back(block_shape[i]);
    block_shape_prod *= block_shape[i];
  }
  axis.push_back(Integer(r_shape.size()));  // axis of (batch / block_shape_prod)
  r_shape.push_back(batch / block_shape_prod);

  // Interleave each spatial dim with its block factor in the permutation.
  for (size_t i = 1; i < num_input_dims; i++) {
    axis.push_back(Integer(r_shape.size()));  // axis of in_shape[i]
    if (axis.size() < (num_block_dims + num_input_dims)) {
      axis.push_back(Integer(r_shape.size() - (num_block_dims + 1)));  // axis of block_shape[i]
    }
    r_shape.push_back(in_shape[i]);
  }

  // r_p_shape: shape after merging each (spatial, block) pair back together.
  Array<PrimExpr> r_p_shape;
  r_p_shape.push_back(batch / block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
  }
  for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {
    r_p_shape.push_back(in_shape[i]);
  }

  // split batch -> permute block factors next to their spatial dims -> merge
  tvm::te::Tensor out;
  out = reshape(data, r_shape);
  out = transpose(out, axis);
  out = reshape(out, r_p_shape);

  // Crop the start and end of dimensions of out
  Array<Integer> begin_idx, end_idx, strides;
  for (size_t i = 0; i < r_p_shape.size(); ++i) {
    strides.push_back(Integer(1));
    if (i > 0 && i <= num_block_dims) {
      // prepare begin and end index for spatial dimensions
      int begin_i = static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
      int end_i = static_cast<int>(GetConstInt(crop_end_list[i - 1]));
      int out_i = static_cast<int>(GetConstInt(r_p_shape[i]));
      // The crops must leave at least one element in the dimension.
      CHECK_GT(out_i, (begin_i + end_i))
          << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than"
          << " output size" << out_i << " vs " << (begin_i + end_i);
      begin_idx.push_back(begin_i);
      end_idx.push_back(out_i - end_i);
    } else {
      // ignore the batch and remaining dimension
      begin_idx.push_back(Integer(0));
      end_idx.push_back(static_cast<int>(GetConstInt(r_p_shape[i])));
    }
  }

  out = strided_slice(out, begin_idx, end_idx, strides);
  return out;
}
646 
/*!
 * \brief Negative log likelihood loss.
 *
 * \param predictions Log-probabilities; indexed as (batch, class, d1, d2, ...)
 *        in the general case, or (class,) in the no-batch corner case.
 * \param targets Class indices; shape matches predictions with the class axis
 *        removed (scalar in the no-batch case).
 * \param weights Per-class weights, indexed by the target class.
 * \param reduction "mean" (weighted mean), "sum", or anything else for "none".
 * \param ignore_index Target value whose loss contribution is forced to zero.
 * \param name The name of the operation.
 * \param tag The operation tag.
 *
 * \return The (optionally reduced) loss tensor.
 */
inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
                       std::string reduction = "mean", int ignore_index = -100,
                       const std::string name = "nll_loss", const std::string tag = kBroadcast) {
  if (predictions.ndim() == 1) {
    // corner case: no batch in shape
    // prediction->shape = (C,), targets->shape = (), weights->shape = (C,)
    // Scalar loss: -predictions[c] * weights[c], zero when c == ignore_index.
    auto T = tvm::te::compute(
        {},
        [&](const tvm::Array<tvm::tir::Var>& target_indices) {
          auto c = targets();
          return tvm::tir::Select(c != ignore_index, -predictions(c) * weights(c),
                                  tvm::tir::make_const(predictions->dtype, 0));
        },
        name, tag);
    if (reduction == "mean") {
      // "mean" divides by the (possibly zeroed) weight of the single target.
      auto W = tvm::te::compute(
          {},
          [&](const tvm::Array<tvm::tir::Var>& target_indices) {
            auto c = targets();
            return tvm::tir::Select(c != ignore_index, weights(c),
                                    tvm::tir::make_const(predictions->dtype, 0));
          },
          name, tag);
      return topi::divide(T, W);
    } else {
      // "sum" and "none" coincide for a scalar loss.
      return T;
    }
  }
  // General case: per-element loss -predictions[b, c, d...] * weights[c],
  // zero where targets == ignore_index.
  auto T = tvm::te::compute(
      targets->shape,
      [&](const tvm::Array<tvm::tir::Var>& target_indices) {
        auto c = targets(target_indices);
        tvm::Array<tvm::PrimExpr> pred_indices;
        pred_indices.push_back(target_indices[0]);  // batch index
        pred_indices.push_back(c);                  // class index
        for (size_t i = 1; i < target_indices.size(); i++) {
          pred_indices.push_back(target_indices[i]);  // indices for multidimensional loss
        }
        return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c),
                                tvm::tir::make_const(predictions->dtype, 0));
      },
      name, tag);
  ICHECK(T->shape.size() != 0);
  if (reduction == "mean") {
    // Weighted mean: sum(loss) / sum(weights of non-ignored targets).
    auto W = tvm::te::compute(
        targets->shape,
        [&](const tvm::Array<tvm::tir::Var>& target_indices) {
          auto c = targets(target_indices);
          return tvm::tir::Select(c != ignore_index, weights(c),
                                  tvm::tir::make_const(predictions->dtype, 0));
        },
        name, tag);
    // A null axis array reduces over all axes.
    return topi::divide(topi::sum(T, tvm::Array<Integer>(nullptr)),
                        topi::sum(W, tvm::Array<Integer>(nullptr)));
  } else if (reduction == "sum") {
    return topi::sum(T, tvm::Array<Integer>(nullptr));
  } else {  // reduction == "none"
    return T;
  }
}
720 
721 } // namespace topi
722 } // namespace tvm
723 #endif // TVM_TOPI_NN_H_
Algebra expression simplifications.
Container of constant int that adds more constructors.
Definition: expr.h:632
Reference to PrimExprNode.
Definition: expr.h:115
Range container
Definition: expr.h:725
Definition: source_map.h:120
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:629
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
size_t size() const
Definition: array.h:420
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:219
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
size_t ndim() const
Definition: tensor.h:214
Managed reference to SelectNode.
Definition: expr.h:609
a named variable in TIR
Definition: var.h:89
Utility functions for handling constants in TVM expressions.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:962
PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array< PrimExpr > &values, Span span=Span())
Left fold.
Definition: op.h:868
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:786
constexpr auto kElementWise
Definition: tags.h:32
constexpr auto kBroadcast
Definition: tags.h:36
Tensor transpose(const Tensor &x, Array< Integer > axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:203
tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor &data, const tvm::Array< Integer > &block_shape, const tvm::Array< tvm::PrimExpr > &crop_begin_list, const tvm::Array< tvm::PrimExpr > &crop_end_list, std::string name="batch_to_space_nd", std::string tag=kInjective)
Reshape the batch dimension into spatial dimensions.
Definition: nn.h:578
Tensor strided_slice(const Tensor &x, const Array< Integer > &begin, const Array< Integer > &end, const Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:930
constexpr auto kInjective
Definition: tags.h:33
constexpr auto kConv2dNCHW
Definition: tags.h:38
tvm::te::Tensor prelu(const tvm::te::Tensor &x, const tvm::te::Tensor &slope, const int axis=1, std::string name="T_prelu", std::string tag=kBroadcast)
Creates an operation that performs a parametric rectified linear unit.
Definition: nn.h:100
tvm::te::Tensor pad(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &pad_before, tvm::Array< tvm::PrimExpr > pad_after=tvm::Array< tvm::PrimExpr >(), PrimExpr pad_value=PrimExpr(), std::string name="T_pad", std::string tag=kElementWise, std::string pad_mode="constant", const Array< PrimExpr > *dyn_output_shape=nullptr)
Creates an operation that performs padding.
Definition: nn.h:155
Tensor reshape(const Tensor &x, Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:327
tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_group_conv2d_ngchw", std::string tag=kGroupConv2d)
Creates an operation that performs a 2-D group convolution with an NGCHW-layout.
Definition: nn.h:434
tvm::te::Tensor leaky_relu(const tvm::te::Tensor &t, double alpha=0.1, std::string name="T_leaky_relu", std::string tag=kElementWise)
Creates an operation that performs a leaky rectified linear unit.
Definition: nn.h:76
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:239
constexpr auto kDepthwiseConv2dNCHW
Definition: tags.h:40
tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nchw", std::string tag=kDepthwiseConv2dNCHW)
Creates an operation that performs a 2-D depthwise convolution with an NCHW-layout.
Definition: nn.h:356
constexpr auto kGroupConv2d
Definition: tags.h:45
tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor &data, const tvm::Array< Integer > &block_shape, const tvm::Array< tvm::PrimExpr > &pad_before, const tvm::Array< tvm::PrimExpr > &pad_after, PrimExpr pad_value=PrimExpr(), std::string name="space_to_batch_nd", std::string tag=kInjective)
Divide spatial dimensions of the input into a grid of blocks.
Definition: nn.h:482
constexpr auto kConv2dHWCN
Definition: tags.h:39
constexpr auto kDepthwiseConv2dNHWC
Definition: tags.h:41
tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_nchw", std::string tag=kConv2dNCHW)
Creates an operation that performs a 2-D convolution with an NCHW-layout.
Definition: nn.h:268
Tensor sum(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:326
tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_hwcn", std::string tag=kConv2dHWCN)
Creates an operation for 2-D convolution layer with an HWCN-layout.
Definition: nn.h:312
tvm::te::Tensor relu(const tvm::te::Tensor &t, T threshold=static_cast< T >(0), std::string name="T_relu", std::string tag=kElementWise)
Creates an operation that performs a rectified linear unit.
Definition: nn.h:55
Tensor nll_loss(const Tensor &predictions, const Tensor &targets, const Tensor &weights, std::string reduction="mean", int ignore_index=-100, const std::string name="nll_loss", const std::string tag=kBroadcast)
Negative log likelihood loss.
Definition: nn.h:660
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nhwc", std::string tag=kDepthwiseConv2dNHWC)
Definition: nn.h:385
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span=Span())
and
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder floor(a / b) where a and b are non-negative.
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Operation node can generate one or multiple Tensors.
Reduction op constructors.
External function interface to rocBLAS libraries.
TIR expressions.
Common operators defined for Expr.
Transform op constructors.