tvm
nn.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_H_
25 #define TVM_TOPI_NN_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/te/operation.h>
29 #include <tvm/tirx/expr.h>
30 #include <tvm/tirx/op.h>
32 #include <tvm/topi/reduction.h>
33 #include <tvm/topi/tags.h>
34 #include <tvm/topi/transform.h>
35 
36 #include <algorithm>
37 #include <string>
38 
39 namespace tvm {
40 namespace topi {
41 
42 using namespace tvm::te;
43 
54 template <typename T>
55 inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast<T>(0),
56  std::string name = "T_relu", std::string tag = kElementWise) {
57  return tvm::te::compute(
58  t->shape,
59  [&](const tvm::ffi::Array<tvm::tirx::Var>& i) {
60  auto threshold_const = tvm::tirx::make_const(t->dtype, threshold);
61  return tvm::max(t(i), threshold_const);
62  },
63  name, tag);
64 }
65 
76 inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1,
77  std::string name = "T_leaky_relu",
78  std::string tag = kElementWise) {
79  return tvm::te::compute(
80  t->shape,
81  [&](const tvm::ffi::Array<tvm::tirx::Var>& i) {
82  auto value = t(i);
83  auto calpha = tvm::tirx::make_const(value.dtype(), alpha);
84  return tvm::tirx::Select(value > 0, value, value * calpha);
85  },
86  name, tag);
87 }
88 
100 inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope,
101  const int axis = 1, std::string name = "T_prelu",
102  std::string tag = kBroadcast) {
103  TVM_FFI_ICHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. ";
104  TVM_FFI_ICHECK(topi::detail::GetConstInt(slope->shape[0]) ==
105  topi::detail::GetConstInt(x->shape[axis]))
106  << "Wrong slope shape received.";
107 
108  return tvm::te::compute(
109  x->shape,
110  [&](const tvm::ffi::Array<tvm::tirx::Var>& indices) {
111  auto xval = x(indices);
112  return tvm::tirx::Select(xval > 0, xval, xval * slope(indices[axis]));
113  },
114  name, tag);
115 }
116 
157  const tvm::te::Tensor& t, const tvm::ffi::Array<tvm::PrimExpr>& pad_before,
158  tvm::ffi::Array<tvm::PrimExpr> pad_after = tvm::ffi::Array<tvm::PrimExpr>(),
159  PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", std::string tag = kElementWise,
160  std::string pad_mode = "constant", const ffi::Array<PrimExpr>* dyn_output_shape = nullptr) {
161  if (pad_after.size() < pad_before.size()) {
162  for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
163  pad_after.push_back(pad_before[i]);
164  }
165  }
166 
167  arith::Analyzer analyzer;
168  TVM_FFI_ICHECK_GE(pad_before.size(), 1);
169  TVM_FFI_ICHECK_EQ(pad_before.size(), pad_after.size());
170  tvm::ffi::Array<tvm::PrimExpr> pad_before_int32;
171  tvm::ffi::Array<tvm::PrimExpr> pad_after_int32;
172 
173  for (const auto& ele : pad_before) {
174  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
175  }
176  for (const auto& ele : pad_after) {
177  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
178  }
179 
180  tvm::ffi::Array<tvm::PrimExpr> output_shape;
181  if (dyn_output_shape == nullptr) {
182  for (size_t i = 0; i < t->shape.size(); ++i) {
183  if (i >= pad_before.size()) {
184  output_shape.push_back(t->shape[i]);
185  } else {
186  output_shape.push_back(
187  analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
188  }
189  }
190  } else {
191  for (size_t i = 0; i < dyn_output_shape->size(); i++) {
192  output_shape.push_back((*dyn_output_shape)[i]);
193  }
194  }
195 
196  if (!pad_value.defined()) {
197  pad_value = tvm::tirx::make_const(t->dtype, 0);
198  }
199 
200  auto l = [&](tvm::ffi::Array<tvm::tirx::Var> ovars) {
201  tvm::ffi::Array<tvm::PrimExpr> indices;
202  tvm::ffi::Array<tvm::PrimExpr> sel;
203  tvm::ffi::Array<tvm::PrimExpr> pad_idx;
204  for (size_t i = 0; i < t->shape.size(); ++i) {
205  if (i >= pad_before_int32.size()) {
206  indices.push_back(ovars[i]);
207  continue;
208  }
209  if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
210  sel.push_back(ovars[i] >= pad_before_int32[i]);
211  indices.push_back(ovars[i] - pad_before_int32[i]);
212  } else {
213  indices.push_back(ovars[i]);
214  }
215  if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
216  sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
217  }
218  if (pad_mode == "edge") {
219  pad_idx.push_back(
220  tvm::if_then_else(ovars[i] < pad_before[i], 0,
221  tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
222  t->shape[i] - 1, ovars[i] - pad_before[i])));
223  } else if (pad_mode == "reflect") {
224  pad_idx.push_back(
225  tvm::if_then_else(ovars[i] < pad_before[i], pad_before[i] - ovars[i],
226  tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
227  t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
228  ovars[i] - pad_before[i])));
229  }
230  }
231  if (sel.size() != 0) {
232  if (pad_mode == "constant") {
233  return tvm::if_then_else(
234  foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
235  const_true(1), sel),
236  t(indices), pad_value);
237  } else if (pad_mode == "edge" || pad_mode == "reflect") {
238  return tvm::if_then_else(
239  foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
240  const_true(1), sel),
241  t(indices), t(pad_idx));
242  }
243  }
244  return t(indices);
245  };
246  return tvm::te::compute(output_shape, l, name, tag);
247 }
248 
270  int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
271  std::string name = "T_conv2d_nchw",
272  std::string tag = kConv2dNCHW) {
273  TVM_FFI_ICHECK_EQ(4, I->shape.size());
274  TVM_FFI_ICHECK_EQ(4, W->shape.size());
275  auto pH = I->shape[2];
276  auto pW = I->shape[3];
277  tvm::ffi::Array<tvm::PrimExpr> output_shape{
278  I->shape[0], // B
279  W->shape[0], // O
280  indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
281  indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
282  };
283  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
284  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
285  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
286  auto T =
287  (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
289  return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw});
290  };
291  return tvm::te::compute(output_shape, l, name, tag);
292 }
293 
314  int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
315  std::string name = "T_conv2d_hwcn",
316  std::string tag = kConv2dHWCN) {
317  TVM_FFI_ICHECK_EQ(4, I->shape.size());
318  TVM_FFI_ICHECK_EQ(4, W->shape.size());
319  auto pH = I->shape[2];
320  auto pW = I->shape[3];
321  tvm::ffi::Array<tvm::PrimExpr> output_shape{
322  indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
323  indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W
324  I->shape[2], // B
325  W->shape[3] // O
326  };
327  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
328  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
329  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
330  auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w});
332  return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw});
333  };
334  return tvm::te::compute(output_shape, l, name, tag);
335 }
336 
358  int pad_h = 0, int pad_w = 0, int stride_h = 1,
359  int stride_w = 1,
360  std::string name = "T_depthwise_conv2d_nchw",
361  std::string tag = kDepthwiseConv2dNCHW) {
362  TVM_FFI_ICHECK_EQ(4, I->shape.size());
363  TVM_FFI_ICHECK_EQ(4, W->shape.size());
364  auto pH = I->shape[2];
365  auto pW = I->shape[3];
366  auto pCM = W->shape[1]; // channel_multiplier
367  tvm::ffi::Array<tvm::PrimExpr> output_shape{
368  I->shape[0], // B
369  W->shape[1], // O
370  indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
371  indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
372  };
373  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
374  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
375  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
376  auto T =
377  (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
379  return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
380  W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
381  {i, kh, kw});
382  };
383  return tvm::te::compute(output_shape, l, name, tag);
384 }
385 
387  int pad_h = 0, int pad_w = 0, int stride_h = 1,
388  int stride_w = 1,
389  std::string name = "T_depthwise_conv2d_nhwc",
390  std::string tag = kDepthwiseConv2dNHWC) {
391  TVM_FFI_ICHECK_EQ(4, I->shape.size());
392  TVM_FFI_ICHECK_EQ(4, W->shape.size());
393  auto pH = I->shape[1];
394  auto pW = I->shape[2];
395  auto pCM = W->shape[1]; // channel_multiplier
396  tvm::ffi::Array<tvm::PrimExpr> output_shape{
397  I->shape[0], // B
398  indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H
399  indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W
400  W->shape[3], // O
401  };
402  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
403  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
404  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
405  auto T =
406  (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
408  return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
409  W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
410  {kh, kw, i});
411  };
412  return tvm::te::compute(output_shape, l, name, tag);
413 }
414 
436  int pad_h = 0, int pad_w = 0, int stride_h = 1,
437  int stride_w = 1,
438  std::string name = "T_group_conv2d_ngchw",
439  std::string tag = kGroupConv2d) {
440  TVM_FFI_ICHECK_EQ(5, I->shape.size());
441  TVM_FFI_ICHECK_EQ(5, W->shape.size());
442  auto pH = I->shape[2];
443  auto pW = I->shape[3];
444  tvm::ffi::Array<tvm::PrimExpr> output_shape{
445  I->shape[0], // B
446  I->shape[1], // G
447  W->shape[2], // O
448  indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H
449  indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W
450  };
451  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[2]}, "i");
452  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kh");
453  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[4]}, "kw");
454 
455  auto T = (pad_h == 0 && pad_w == 0)
456  ? I
457  : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
458  auto l = [&](tvm::ffi::Array<tvm::tirx::Var> args) {
459  tvm::tirx::Var b = args[0];
460  tvm::tirx::Var g = args[1];
461  tvm::tirx::Var o = args[2];
462  tvm::tirx::Var h = args[3];
463  tvm::tirx::Var w = args[4];
464  return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
465  {i, kh, kw});
466  };
467  return tvm::te::compute(output_shape, l, name, tag);
468 }
469 
484  const tvm::ffi::Array<Integer>& block_shape,
485  const tvm::ffi::Array<tvm::PrimExpr>& pad_before,
486  const tvm::ffi::Array<tvm::PrimExpr>& pad_after,
487  PrimExpr pad_value = PrimExpr(),
488  std::string name = "space_to_batch_nd",
489  std::string tag = kInjective) {
490  tvm::te::Tensor padded_t;
491  TVM_FFI_ICHECK_EQ(pad_before.size(), pad_after.size());
492  TVM_FFI_ICHECK_EQ(block_shape.size(), pad_before.size())
493  << "Paddings must be provided for each spatial dimension";
494  tvm::ffi::Array<tvm::PrimExpr> pad_before_int32;
495  tvm::ffi::Array<tvm::PrimExpr> pad_after_int32;
496 
497  // pad size for batch dimension is 0
498  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
499  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
500  // insert pad sizes given for spatial dimensions
501  for (const auto& ele : pad_before) {
502  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
503  }
504  for (const auto& ele : pad_after) {
505  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
506  }
507 
508  // pad the input with paddings provided
509  if (!pad_value.defined()) {
510  pad_value = tvm::tirx::make_const(data->dtype, 0);
511  }
512  padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value);
513 
514  auto input_shape = data->shape;
515  auto padded_shape = padded_t->shape;
516 
517  // infer shapes
518  tvm::ffi::Array<PrimExpr> r_shape;
519  tvm::ffi::Array<Integer> axis;
520  tvm::ffi::Array<PrimExpr> o_shape;
521 
522  size_t num_block_dims = block_shape.size();
523  int batch = static_cast<int>(GetConstInt(input_shape[0]));
524  tvm::PrimExpr block_shape_prod(1);
525  r_shape.push_back(batch);
526 
527  for (size_t i = 1; i <= num_block_dims; i++) {
528  int padded_input = static_cast<int>(GetConstInt(padded_shape[i]));
529  int block_size = static_cast<int>(GetConstInt(block_shape[i - 1]));
530  TVM_FFI_ICHECK_EQ((padded_input % block_size), 0)
531  << "(" << i
532  << ")th "
533  "Input dimension after padding ("
534  << padded_input << ")"
535  << " must be divisible by its block size (" << block_size << ")";
536 
537  r_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
538  r_shape.push_back(block_shape[i - 1]);
539  block_shape_prod *= block_shape[i - 1];
540  axis.push_back(Integer(r_shape.size() - 1)); // index of block_shape[i - 1]
541  }
542 
543  size_t n = axis.size();
544  axis.push_back(0); // batch is at index 0
545  // index of (padded_shape[i] / block_shape[i - 1]) in r_shape
546  for (size_t i = 0; i < n; i++) {
547  axis.push_back(static_cast<int>(GetConstInt(axis[i] - 1)));
548  }
549  o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
550  for (size_t i = 1; i <= num_block_dims; i++) {
551  o_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
552  }
553  // append remaining shape
554  for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
555  r_shape.push_back(input_shape[i]);
556  axis.push_back(Integer(r_shape.size() - 1)); // index of remaining shape in r_shape
557  o_shape.push_back(input_shape[i]);
558  }
559 
560  tvm::te::Tensor output = reshape(padded_t, r_shape);
561  output = transpose(output, axis);
562  output = reshape(output, o_shape);
563 
564  return output;
565 }
566 
580  const tvm::ffi::Array<Integer>& block_shape,
581  const tvm::ffi::Array<tvm::PrimExpr>& crop_begin_list,
582  const tvm::ffi::Array<tvm::PrimExpr>& crop_end_list,
583  std::string name = "batch_to_space_nd",
584  std::string tag = kInjective) {
585  // Construct shapes for reshape and transpose operation
586  ffi::Array<PrimExpr> in_shape = data->shape;
587  ffi::Array<PrimExpr> r_shape;
588  ffi::Array<Integer> axis;
589  size_t num_block_dims = block_shape.size();
590  size_t num_input_dims = in_shape.size();
591  tvm::PrimExpr block_shape_prod(1);
592  int batch = static_cast<int>(GetConstInt(in_shape[0]));
593 
594  for (size_t i = 0; i < num_block_dims; i++) {
595  r_shape.push_back(block_shape[i]);
596  block_shape_prod *= block_shape[i];
597  }
598  axis.push_back(Integer(r_shape.size())); // axis of (batch / block_shape_prod)
599  r_shape.push_back(batch / block_shape_prod);
600 
601  for (size_t i = 1; i < num_input_dims; i++) {
602  axis.push_back(Integer(r_shape.size())); // axis of in_shape[i]
603  if (axis.size() < (num_block_dims + num_input_dims)) {
604  axis.push_back(Integer(r_shape.size() - (num_block_dims + 1))); // axis of block_shape[i]
605  }
606  r_shape.push_back(in_shape[i]);
607  }
608 
609  ffi::Array<PrimExpr> r_p_shape;
610  r_p_shape.push_back(batch / block_shape_prod);
611  for (size_t i = 1; i <= num_block_dims; i++) {
612  r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
613  }
614  for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {
615  r_p_shape.push_back(in_shape[i]);
616  }
617 
618  tvm::te::Tensor out;
619  out = reshape(data, r_shape);
620  out = transpose(out, axis);
621  out = reshape(out, r_p_shape);
622 
623  // Crop the start and end of dimensions of out
624  ffi::Array<Integer> begin_idx, end_idx, strides;
625  for (size_t i = 0; i < r_p_shape.size(); ++i) {
626  strides.push_back(Integer(1));
627  if (i > 0 && i <= num_block_dims) {
628  // prepare begin and end index for spatial dimensions
629  int begin_i = static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
630  int end_i = static_cast<int>(GetConstInt(crop_end_list[i - 1]));
631  int out_i = static_cast<int>(GetConstInt(r_p_shape[i]));
632  TVM_FFI_ICHECK_GT(out_i, (begin_i + end_i))
633  << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than"
634  << " output size" << out_i << " vs " << (begin_i + end_i);
635  begin_idx.push_back(begin_i);
636  end_idx.push_back(out_i - end_i);
637  } else {
638  // ignore the batch and remaining dimension
639  begin_idx.push_back(Integer(0));
640  end_idx.push_back(static_cast<int>(GetConstInt(r_p_shape[i])));
641  }
642  }
643 
644  out = strided_slice(out, begin_idx, end_idx, strides);
645  return out;
646 }
647 
661 inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
662  std::string reduction = "mean", int ignore_index = -100,
663  const std::string name = "nll_loss", const std::string tag = kBroadcast) {
664  if (predictions.ndim() == 1) {
665  // corner case: no batch in shape
666  // prediction->shape = (C,), targets->shape = (), weights->shape = (C,)
667  auto T = tvm::te::compute(
668  {},
669  [&](const tvm::ffi::Array<tvm::tirx::Var>& target_indices) {
670  auto c = targets();
671  return tvm::tirx::Select(c != ignore_index, -predictions(c) * weights(c),
672  tvm::tirx::make_const(predictions->dtype, 0));
673  },
674  name, tag);
675  if (reduction == "mean") {
676  auto W = tvm::te::compute(
677  {},
678  [&](const tvm::ffi::Array<tvm::tirx::Var>& target_indices) {
679  auto c = targets();
680  return tvm::tirx::Select(c != ignore_index, weights(c),
681  tvm::tirx::make_const(predictions->dtype, 0));
682  },
683  name, tag);
684  return topi::divide(T, W);
685  } else {
686  return T;
687  }
688  }
689  auto T = tvm::te::compute(
690  targets->shape,
691  [&](const tvm::ffi::Array<tvm::tirx::Var>& target_indices) {
692  auto c = targets(target_indices);
693  tvm::ffi::Array<tvm::PrimExpr> pred_indices;
694  pred_indices.push_back(target_indices[0]); // batch index
695  pred_indices.push_back(c); // class index
696  for (size_t i = 1; i < target_indices.size(); i++) {
697  pred_indices.push_back(target_indices[i]); // indices for multidimensional loss
698  }
699  return tvm::tirx::Select(c != ignore_index, -predictions(pred_indices) * weights(c),
700  tvm::tirx::make_const(predictions->dtype, 0));
701  },
702  name, tag);
703  TVM_FFI_ICHECK(T->shape.size() != 0);
704  if (reduction == "mean") {
705  auto W = tvm::te::compute(
706  targets->shape,
707  [&](const tvm::ffi::Array<tvm::tirx::Var>& target_indices) {
708  auto c = targets(target_indices);
709  return tvm::tirx::Select(c != ignore_index, weights(c),
710  tvm::tirx::make_const(predictions->dtype, 0));
711  },
712  name, tag);
713  return topi::divide(topi::sum(T, tvm::ffi::Array<Integer>(nullptr)),
714  topi::sum(W, tvm::ffi::Array<Integer>(nullptr)));
715  } else if (reduction == "sum") {
716  return topi::sum(T, tvm::ffi::Array<Integer>(nullptr));
717  } else { // reduction == "none"
718  return T;
719  }
720 }
721 
722 } // namespace topi
723 } // namespace tvm
724 #endif // TVM_TOPI_NN_H_
Algebraic expression simplifications.
Container of constant int that adds more constructors.
Definition: expr.h:601
Reference to PrimExprNode.
Definition: expr.h:126
Range container
Definition: expr.h:690
Definition: source_map.h:111
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:634
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:54
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Managed reference to SelectNode.
Definition: expr.h:514
a named variable in TIR
Definition: var.h:76
Utility functions for handling constants in TVM expressions.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1007
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:830
PrimExpr foldl(FReduce freduce, PrimExpr init_value, const ffi::Array< PrimExpr > &values, Span span=Span())
Left fold.
Definition: op.h:912
constexpr auto kElementWise
Definition: tags.h:32
Tensor reshape(const Tensor &x, ffi::Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:330
constexpr auto kBroadcast
Definition: tags.h:36
constexpr auto kInjective
Definition: tags.h:33
tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor &data, const tvm::ffi::Array< Integer > &block_shape, const tvm::ffi::Array< tvm::PrimExpr > &pad_before, const tvm::ffi::Array< tvm::PrimExpr > &pad_after, PrimExpr pad_value=PrimExpr(), std::string name="space_to_batch_nd", std::string tag=kInjective)
Divide spatial dimensions of the input into a grid of blocks.
Definition: nn.h:483
constexpr auto kConv2dNCHW
Definition: tags.h:38
tvm::te::Tensor prelu(const tvm::te::Tensor &x, const tvm::te::Tensor &slope, const int axis=1, std::string name="T_prelu", std::string tag=kBroadcast)
Creates an operation that performs a parametric rectified linear unit.
Definition: nn.h:100
tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor &data, const tvm::ffi::Array< Integer > &block_shape, const tvm::ffi::Array< tvm::PrimExpr > &crop_begin_list, const tvm::ffi::Array< tvm::PrimExpr > &crop_end_list, std::string name="batch_to_space_nd", std::string tag=kInjective)
Reshape the batch dimension into spatial dimensions.
Definition: nn.h:579
tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_group_conv2d_ngchw", std::string tag=kGroupConv2d)
Creates an operation that performs a 2-D group convolution with an NGCHW-layout.
Definition: nn.h:435
tvm::te::Tensor leaky_relu(const tvm::te::Tensor &t, double alpha=0.1, std::string name="T_leaky_relu", std::string tag=kElementWise)
Creates an operation that performs a leaky rectified linear unit.
Definition: nn.h:76
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:241
Tensor transpose(const Tensor &x, ffi::Optional< ffi::Array< Integer >> opt_axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:205
constexpr auto kDepthwiseConv2dNCHW
Definition: tags.h:40
tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nchw", std::string tag=kDepthwiseConv2dNCHW)
Creates an operation that performs a 2-D depthwise convolution with an NCHW-layout.
Definition: nn.h:357
constexpr auto kGroupConv2d
Definition: tags.h:45
tvm::te::Tensor pad(const tvm::te::Tensor &t, const tvm::ffi::Array< tvm::PrimExpr > &pad_before, tvm::ffi::Array< tvm::PrimExpr > pad_after=tvm::ffi::Array< tvm::PrimExpr >(), PrimExpr pad_value=PrimExpr(), std::string name="T_pad", std::string tag=kElementWise, std::string pad_mode="constant", const ffi::Array< PrimExpr > *dyn_output_shape=nullptr)
Creates an operation that performs padding.
Definition: nn.h:156
constexpr auto kConv2dHWCN
Definition: tags.h:39
constexpr auto kDepthwiseConv2dNHWC
Definition: tags.h:41
Tensor strided_slice(const Tensor &x, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:962
tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_nchw", std::string tag=kConv2dNCHW)
Creates an operation that performs a 2-D convolution with an NCHW-layout.
Definition: nn.h:269
Tensor sum(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:328
tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_hwcn", std::string tag=kConv2dHWCN)
Creates an operation for 2-D convolution layer with an HWCN-layout.
Definition: nn.h:313
tvm::te::Tensor relu(const tvm::te::Tensor &t, T threshold=static_cast< T >(0), std::string name="T_relu", std::string tag=kElementWise)
Creates an operation that performs a rectified linear unit.
Definition: nn.h:55
Tensor nll_loss(const Tensor &predictions, const Tensor &targets, const Tensor &weights, std::string reduction="mean", int ignore_index=-100, const std::string name="nll_loss", const std::string tag=kBroadcast)
Negative log likelihood loss.
Definition: nn.h:661
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nhwc", std::string tag=kDepthwiseConv2dNHWC)
Definition: nn.h:386
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span=Span())
and
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr sum(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floor division, i.e. a - floor(a / b) * b, where a and b are non-negative.
Operation node can generate one or multiple Tensors.
Reduction op constructors.
Tag definitions.
TIR expressions.
Common operators defined for Expr.
Transform op constructors.