24 #ifndef TVM_TOPI_NN_H_
25 #define TVM_TOPI_NN_H_
56 std::string name =
"T_relu", std::string tag =
kElementWise) {
60 auto threshold_const = tvm::tir::make_const(t->dtype, threshold);
61 return tvm::max(t(i), threshold_const);
77 std::string name =
"T_leaky_relu",
83 auto calpha = tvm::tir::make_const(value.dtype(), alpha);
84 return tvm::tir::Select(value > 0, value, value * calpha);
101 const int axis = 1, std::string name =
"T_prelu",
103 ICHECK((
size_t)axis < x->
shape.size()) <<
"Wrong axis (" << axis <<
")value. ";
104 ICHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis]))
105 <<
"Wrong slope shape received.";
110 auto xval = x(indices);
111 return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis]));
158 std::string tag =
kElementWise, std::string pad_mode =
"constant",
160 if (pad_after.size() < pad_before.
size()) {
161 for (
size_t i = pad_after.size(); i < pad_before.
size(); ++i) {
162 pad_after.push_back(pad_before[i]);
167 ICHECK_GE(pad_before.
size(), 1);
168 ICHECK_EQ(pad_before.
size(), pad_after.size());
172 for (
const auto& ele : pad_before) {
175 for (
const auto& ele : pad_after) {
180 if (dyn_output_shape ==
nullptr) {
181 for (
size_t i = 0; i < t->shape.size(); ++i) {
182 if (i >= pad_before.size()) {
186 analyzer.
Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
190 for (
size_t i = 0; i < dyn_output_shape->size(); i++) {
191 output_shape.
push_back((*dyn_output_shape)[i]);
195 if (!pad_value.defined()) {
203 for (
size_t i = 0; i < t->shape.size(); ++i) {
204 if (i >= pad_before_int32.
size()) {
208 if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
209 sel.
push_back(ovars[i] >= pad_before_int32[i]);
210 indices.
push_back(ovars[i] - pad_before_int32[i]);
214 if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
217 if (pad_mode ==
"edge") {
221 t->shape[i] - 1, ovars[i] - pad_before[i])));
222 }
else if (pad_mode ==
"reflect") {
226 t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
227 ovars[i] - pad_before[i])));
230 if (sel.
size() != 0) {
231 if (pad_mode ==
"constant") {
235 t(indices), pad_value);
236 }
else if (pad_mode ==
"edge" || pad_mode ==
"reflect") {
240 t(indices), t(pad_idx));
269 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
270 std::string name =
"T_conv2d_nchw",
272 ICHECK_EQ(4, I->shape.size());
273 ICHECK_EQ(4, W->shape.size());
274 auto pH = I->shape[2];
275 auto pW = I->shape[3];
279 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,
280 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1
288 return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw});
313 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
314 std::string name =
"T_conv2d_hwcn",
316 ICHECK_EQ(4, I->shape.size());
317 ICHECK_EQ(4, W->shape.size());
318 auto pH = I->shape[2];
319 auto pW = I->shape[3];
321 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,
322 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1,
329 auto T = (pad_h == 0 && pad_w == 0) ? I :
pad(I, {pad_h, pad_w});
331 return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw});
357 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
359 std::string name =
"T_depthwise_conv2d_nchw",
361 ICHECK_EQ(4, I->shape.size());
362 ICHECK_EQ(4, W->shape.size());
363 auto pH = I->shape[2];
364 auto pW = I->shape[3];
365 auto pCM = W->shape[1];
369 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,
370 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1
378 return tvm::sum(T(b,
indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
386 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
388 std::string name =
"T_depthwise_conv2d_nhwc",
390 ICHECK_EQ(4, I->shape.size());
391 ICHECK_EQ(4, W->shape.size());
392 auto pH = I->shape[1];
393 auto pW = I->shape[2];
394 auto pCM = W->shape[1];
397 indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1,
398 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1,
407 return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw,
indexdiv(i, pCM)) *
435 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
437 std::string name =
"T_group_conv2d_ngchw",
439 ICHECK_EQ(5, I->shape.size());
440 ICHECK_EQ(5, W->shape.size());
441 auto pH = I->shape[2];
442 auto pW = I->shape[3];
447 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1,
448 indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1
454 auto T = (pad_h == 0 && pad_w == 0)
463 return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
487 std::string name =
"space_to_batch_nd",
490 CHECK_EQ(pad_before.
size(), pad_after.
size());
491 CHECK_EQ(block_shape.
size(), pad_before.
size())
492 <<
"Paddings must be provided for each spatial dimension";
500 for (
const auto& ele : pad_before) {
503 for (
const auto& ele : pad_after) {
508 if (!pad_value.defined()) {
511 padded_t =
pad(data, pad_before_int32, pad_after_int32, pad_value);
513 auto input_shape = data->shape;
514 auto padded_shape = padded_t->shape;
521 size_t num_block_dims = block_shape.
size();
522 int batch =
static_cast<int>(GetConstInt(input_shape[0]));
526 for (
size_t i = 1; i <= num_block_dims; i++) {
527 int padded_input =
static_cast<int>(GetConstInt(padded_shape[i]));
528 int block_size =
static_cast<int>(GetConstInt(block_shape[i - 1]));
529 CHECK_EQ((padded_input % block_size), 0)
532 "Input dimension after padding ("
533 << padded_input <<
")"
534 <<
" must be divisible by its block size (" << block_size <<
")";
536 r_shape.
push_back(
div(padded_shape[i], block_shape[i - 1]));
538 block_shape_prod *= block_shape[i - 1];
542 size_t n = axis.
size();
545 for (
size_t i = 0; i < n; i++) {
546 axis.
push_back(
static_cast<int>(GetConstInt(axis[i] - 1)));
549 for (
size_t i = 1; i <= num_block_dims; i++) {
550 o_shape.
push_back(
div(padded_shape[i], block_shape[i - 1]));
553 for (
size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
561 output =
reshape(output, o_shape);
582 std::string name =
"batch_to_space_nd",
588 size_t num_block_dims = block_shape.
size();
589 size_t num_input_dims = in_shape.
size();
591 int batch =
static_cast<int>(GetConstInt(in_shape[0]));
593 for (
size_t i = 0; i < num_block_dims; i++) {
595 block_shape_prod *= block_shape[i];
598 r_shape.
push_back(batch / block_shape_prod);
600 for (
size_t i = 1; i < num_input_dims; i++) {
602 if (axis.
size() < (num_block_dims + num_input_dims)) {
609 r_p_shape.
push_back(batch / block_shape_prod);
610 for (
size_t i = 1; i <= num_block_dims; i++) {
611 r_p_shape.
push_back(in_shape[i] * block_shape[i - 1]);
613 for (
size_t i = num_block_dims + 1; i < num_input_dims; i++) {
624 for (
size_t i = 0; i < r_p_shape.
size(); ++i) {
626 if (i > 0 && i <= num_block_dims) {
628 int begin_i =
static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
629 int end_i =
static_cast<int>(GetConstInt(crop_end_list[i - 1]));
630 int out_i =
static_cast<int>(GetConstInt(r_p_shape[i]));
631 CHECK_GT(out_i, (begin_i + end_i))
632 <<
"Incorrect crop sizes for (" << i <<
")th dim, can not crop more than"
633 <<
" output size" << out_i <<
" vs " << (begin_i + end_i);
639 end_idx.
push_back(
static_cast<int>(GetConstInt(r_p_shape[i])));
661 std::string reduction =
"mean",
int ignore_index = -100,
662 const std::string name =
"nll_loss",
const std::string tag =
kBroadcast) {
663 if (predictions.
ndim() == 1) {
674 if (reduction ==
"mean") {
691 auto c = targets(target_indices);
692 tvm::Array<tvm::PrimExpr> pred_indices;
693 pred_indices.push_back(target_indices[0]);
694 pred_indices.push_back(c);
695 for (size_t i = 1; i < target_indices.size(); i++) {
696 pred_indices.push_back(target_indices[i]);
698 return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c),
702 ICHECK(T->shape.size() != 0);
703 if (reduction ==
"mean") {
707 auto c = targets(target_indices);
708 return tvm::tir::Select(c != ignore_index, weights(c),
709 tvm::tir::make_const(predictions->dtype, 0));
714 }
else if (reduction ==
"sum") {
Algebra expression simplifications.
Container of constant int that adds more constructors.
Definition: expr.h:632
Reference to PrimExprNode.
Definition: expr.h:115
Range container
Definition: expr.h:725
Definition: source_map.h:120
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:629
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
size_t size() const
Definition: array.h:420
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:219
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
size_t ndim() const
Definition: tensor.h:214
Managed reference to SelectNode.
Definition: expr.h:609
a named variable in TIR
Definition: var.h:89
Utility functions for handling constants in TVM expressions.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:962
PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array< PrimExpr > &values, Span span=Span())
Left fold.
Definition: op.h:868
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:786
constexpr auto kElementWise
Definition: tags.h:32
constexpr auto kBroadcast
Definition: tags.h:36
Tensor transpose(const Tensor &x, Array< Integer > axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:203
tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor &data, const tvm::Array< Integer > &block_shape, const tvm::Array< tvm::PrimExpr > &crop_begin_list, const tvm::Array< tvm::PrimExpr > &crop_end_list, std::string name="batch_to_space_nd", std::string tag=kInjective)
Reshape the batch dimension into spatial dimensions.
Definition: nn.h:578
Tensor strided_slice(const Tensor &x, const Array< Integer > &begin, const Array< Integer > &end, const Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:930
constexpr auto kInjective
Definition: tags.h:33
constexpr auto kConv2dNCHW
Definition: tags.h:38
tvm::te::Tensor prelu(const tvm::te::Tensor &x, const tvm::te::Tensor &slope, const int axis=1, std::string name="T_prelu", std::string tag=kBroadcast)
Creates an operation that performs a parametric rectified linear unit.
Definition: nn.h:100
tvm::te::Tensor pad(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &pad_before, tvm::Array< tvm::PrimExpr > pad_after=tvm::Array< tvm::PrimExpr >(), PrimExpr pad_value=PrimExpr(), std::string name="T_pad", std::string tag=kElementWise, std::string pad_mode="constant", const Array< PrimExpr > *dyn_output_shape=nullptr)
Creates an operation that performs padding.
Definition: nn.h:155
Tensor reshape(const Tensor &x, Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:327
tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_group_conv2d_ngchw", std::string tag=kGroupConv2d)
Creates an operation that performs a 2-D group convolution with an NGCHW-layout.
Definition: nn.h:434
tvm::te::Tensor leaky_relu(const tvm::te::Tensor &t, double alpha=0.1, std::string name="T_leaky_relu", std::string tag=kElementWise)
Creates an operation that performs a leaky rectified linear unit.
Definition: nn.h:76
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:239
constexpr auto kDepthwiseConv2dNCHW
Definition: tags.h:40
tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nchw", std::string tag=kDepthwiseConv2dNCHW)
Creates an operation that performs a 2-D depthwise convolution with an NCHW-layout.
Definition: nn.h:356
constexpr auto kGroupConv2d
Definition: tags.h:45
tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor &data, const tvm::Array< Integer > &block_shape, const tvm::Array< tvm::PrimExpr > &pad_before, const tvm::Array< tvm::PrimExpr > &pad_after, PrimExpr pad_value=PrimExpr(), std::string name="space_to_batch_nd", std::string tag=kInjective)
Divide spatial dimensions of the input into a grid of blocks.
Definition: nn.h:482
constexpr auto kConv2dHWCN
Definition: tags.h:39
constexpr auto kDepthwiseConv2dNHWC
Definition: tags.h:41
tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_nchw", std::string tag=kConv2dNCHW)
Creates an operation that performs a 2-D convolution with an NCHW-layout.
Definition: nn.h:268
Tensor sum(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:326
tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_hwcn", std::string tag=kConv2dHWCN)
Creates an operation for 2-D convolution layer with an HWCN-layout.
Definition: nn.h:312
tvm::te::Tensor relu(const tvm::te::Tensor &t, T threshold=static_cast< T >(0), std::string name="T_relu", std::string tag=kElementWise)
Creates an operation that performs a rectified linear unit.
Definition: nn.h:55
Tensor nll_loss(const Tensor &predictions, const Tensor &targets, const Tensor &weights, std::string reduction="mean", int ignore_index=-100, const std::string name="nll_loss", const std::string tag=kBroadcast)
Negative log likelihood loss.
Definition: nn.h:660
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nhwc", std::string tag=kDepthwiseConv2dNHWC)
Definition: nn.h:385
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span=Span())
compute the logical AND of a and b
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floor division, i.e. a - floor(a / b) * b, where a and b are non-negative.
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Operation node can generate one or multiple Tensors.
Reduction op constructors.
Common operators defined for Expr.