24 #ifndef TVM_TOPI_NN_H_
25 #define TVM_TOPI_NN_H_
56 std::string name =
"T_relu", std::string tag =
kElementWise) {
59 [&](
const tvm::ffi::Array<tvm::tirx::Var>& i) {
60 auto threshold_const = tvm::tirx::make_const(t->dtype, threshold);
61 return tvm::max(t(i), threshold_const);
77 std::string name =
"T_leaky_relu",
81 [&](
const tvm::ffi::Array<tvm::tirx::Var>& i) {
83 auto calpha = tvm::tirx::make_const(value.dtype(), alpha);
84 return tvm::tirx::Select(value > 0, value, value * calpha);
101 const int axis = 1, std::string name =
"T_prelu",
103 TVM_FFI_ICHECK((
size_t)axis < x->
shape.size()) <<
"Wrong axis (" << axis <<
")value. ";
104 TVM_FFI_ICHECK(topi::detail::GetConstInt(slope->shape[0]) ==
105 topi::detail::GetConstInt(x->shape[axis]))
106 <<
"Wrong slope shape received.";
110 [&](
const tvm::ffi::Array<tvm::tirx::Var>& indices) {
111 auto xval = x(indices);
112 return tvm::tirx::Select(xval > 0, xval, xval * slope(indices[axis]));
157 const tvm::te::Tensor& t,
const tvm::ffi::Array<tvm::PrimExpr>& pad_before,
158 tvm::ffi::Array<tvm::PrimExpr> pad_after = tvm::ffi::Array<tvm::PrimExpr>(),
160 std::string pad_mode =
"constant",
const ffi::Array<PrimExpr>* dyn_output_shape =
nullptr) {
161 if (pad_after.size() < pad_before.size()) {
162 for (
size_t i = pad_after.size(); i < pad_before.size(); ++i) {
163 pad_after.push_back(pad_before[i]);
168 TVM_FFI_ICHECK_GE(pad_before.size(), 1);
169 TVM_FFI_ICHECK_EQ(pad_before.size(), pad_after.size());
170 tvm::ffi::Array<tvm::PrimExpr> pad_before_int32;
171 tvm::ffi::Array<tvm::PrimExpr> pad_after_int32;
173 for (
const auto& ele : pad_before) {
176 for (
const auto& ele : pad_after) {
180 tvm::ffi::Array<tvm::PrimExpr> output_shape;
181 if (dyn_output_shape ==
nullptr) {
182 for (
size_t i = 0; i < t->shape.size(); ++i) {
183 if (i >= pad_before.size()) {
184 output_shape.push_back(t->shape[i]);
186 output_shape.push_back(
187 analyzer.
Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
191 for (
size_t i = 0; i < dyn_output_shape->size(); i++) {
192 output_shape.push_back((*dyn_output_shape)[i]);
196 if (!pad_value.defined()) {
200 auto l = [&](tvm::ffi::Array<tvm::tirx::Var> ovars) {
201 tvm::ffi::Array<tvm::PrimExpr> indices;
202 tvm::ffi::Array<tvm::PrimExpr> sel;
203 tvm::ffi::Array<tvm::PrimExpr> pad_idx;
204 for (
size_t i = 0; i < t->shape.size(); ++i) {
205 if (i >= pad_before_int32.size()) {
206 indices.push_back(ovars[i]);
209 if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
210 sel.push_back(ovars[i] >= pad_before_int32[i]);
211 indices.push_back(ovars[i] - pad_before_int32[i]);
213 indices.push_back(ovars[i]);
215 if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
216 sel.push_back(analyzer.
Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
218 if (pad_mode ==
"edge") {
222 t->shape[i] - 1, ovars[i] - pad_before[i])));
223 }
else if (pad_mode ==
"reflect") {
227 t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
228 ovars[i] - pad_before[i])));
231 if (sel.size() != 0) {
232 if (pad_mode ==
"constant") {
236 t(indices), pad_value);
237 }
else if (pad_mode ==
"edge" || pad_mode ==
"reflect") {
241 t(indices), t(pad_idx));
270 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
271 std::string name =
"T_conv2d_nchw",
273 TVM_FFI_ICHECK_EQ(4, I->shape.size());
274 TVM_FFI_ICHECK_EQ(4, W->shape.size());
275 auto pH = I->shape[2];
276 auto pW = I->shape[3];
277 tvm::ffi::Array<tvm::PrimExpr> output_shape{
280 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,
281 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1
289 return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw});
314 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
315 std::string name =
"T_conv2d_hwcn",
317 TVM_FFI_ICHECK_EQ(4, I->shape.size());
318 TVM_FFI_ICHECK_EQ(4, W->shape.size());
319 auto pH = I->shape[2];
320 auto pW = I->shape[3];
321 tvm::ffi::Array<tvm::PrimExpr> output_shape{
322 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,
323 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1,
330 auto T = (pad_h == 0 && pad_w == 0) ? I :
pad(I, {pad_h, pad_w});
332 return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw});
358 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
360 std::string name =
"T_depthwise_conv2d_nchw",
362 TVM_FFI_ICHECK_EQ(4, I->shape.size());
363 TVM_FFI_ICHECK_EQ(4, W->shape.size());
364 auto pH = I->shape[2];
365 auto pW = I->shape[3];
366 auto pCM = W->shape[1];
367 tvm::ffi::Array<tvm::PrimExpr> output_shape{
370 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,
371 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1
379 return tvm::sum(T(b,
indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
387 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
389 std::string name =
"T_depthwise_conv2d_nhwc",
391 TVM_FFI_ICHECK_EQ(4, I->shape.size());
392 TVM_FFI_ICHECK_EQ(4, W->shape.size());
393 auto pH = I->shape[1];
394 auto pW = I->shape[2];
395 auto pCM = W->shape[1];
396 tvm::ffi::Array<tvm::PrimExpr> output_shape{
398 indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1,
399 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1,
408 return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw,
indexdiv(i, pCM)) *
436 int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
438 std::string name =
"T_group_conv2d_ngchw",
440 TVM_FFI_ICHECK_EQ(5, I->shape.size());
441 TVM_FFI_ICHECK_EQ(5, W->shape.size());
442 auto pH = I->shape[2];
443 auto pW = I->shape[3];
444 tvm::ffi::Array<tvm::PrimExpr> output_shape{
448 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1,
449 indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1
455 auto T = (pad_h == 0 && pad_w == 0)
458 auto l = [&](tvm::ffi::Array<tvm::tirx::Var> args) {
464 return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
484 const tvm::ffi::Array<Integer>& block_shape,
485 const tvm::ffi::Array<tvm::PrimExpr>& pad_before,
486 const tvm::ffi::Array<tvm::PrimExpr>& pad_after,
488 std::string name =
"space_to_batch_nd",
491 TVM_FFI_ICHECK_EQ(pad_before.size(), pad_after.size());
492 TVM_FFI_ICHECK_EQ(block_shape.size(), pad_before.size())
493 <<
"Paddings must be provided for each spatial dimension";
494 tvm::ffi::Array<tvm::PrimExpr> pad_before_int32;
495 tvm::ffi::Array<tvm::PrimExpr> pad_after_int32;
501 for (
const auto& ele : pad_before) {
504 for (
const auto& ele : pad_after) {
509 if (!pad_value.defined()) {
512 padded_t =
pad(data, pad_before_int32, pad_after_int32, pad_value);
514 auto input_shape = data->shape;
515 auto padded_shape = padded_t->shape;
518 tvm::ffi::Array<PrimExpr> r_shape;
519 tvm::ffi::Array<Integer> axis;
520 tvm::ffi::Array<PrimExpr> o_shape;
522 size_t num_block_dims = block_shape.size();
523 int batch =
static_cast<int>(GetConstInt(input_shape[0]));
525 r_shape.push_back(batch);
527 for (
size_t i = 1; i <= num_block_dims; i++) {
528 int padded_input =
static_cast<int>(GetConstInt(padded_shape[i]));
529 int block_size =
static_cast<int>(GetConstInt(block_shape[i - 1]));
530 TVM_FFI_ICHECK_EQ((padded_input % block_size), 0)
533 "Input dimension after padding ("
534 << padded_input <<
")"
535 <<
" must be divisible by its block size (" << block_size <<
")";
537 r_shape.push_back(
div(padded_shape[i], block_shape[i - 1]));
538 r_shape.push_back(block_shape[i - 1]);
539 block_shape_prod *= block_shape[i - 1];
540 axis.push_back(
Integer(r_shape.size() - 1));
543 size_t n = axis.size();
546 for (
size_t i = 0; i < n; i++) {
547 axis.push_back(
static_cast<int>(GetConstInt(axis[i] - 1)));
550 for (
size_t i = 1; i <= num_block_dims; i++) {
551 o_shape.push_back(
div(padded_shape[i], block_shape[i - 1]));
554 for (
size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
555 r_shape.push_back(input_shape[i]);
556 axis.push_back(
Integer(r_shape.size() - 1));
557 o_shape.push_back(input_shape[i]);
562 output =
reshape(output, o_shape);
580 const tvm::ffi::Array<Integer>& block_shape,
581 const tvm::ffi::Array<tvm::PrimExpr>& crop_begin_list,
582 const tvm::ffi::Array<tvm::PrimExpr>& crop_end_list,
583 std::string name =
"batch_to_space_nd",
586 ffi::Array<PrimExpr> in_shape = data->shape;
587 ffi::Array<PrimExpr> r_shape;
588 ffi::Array<Integer> axis;
589 size_t num_block_dims = block_shape.size();
590 size_t num_input_dims = in_shape.size();
592 int batch =
static_cast<int>(GetConstInt(in_shape[0]));
594 for (
size_t i = 0; i < num_block_dims; i++) {
595 r_shape.push_back(block_shape[i]);
596 block_shape_prod *= block_shape[i];
598 axis.push_back(
Integer(r_shape.size()));
599 r_shape.push_back(batch / block_shape_prod);
601 for (
size_t i = 1; i < num_input_dims; i++) {
602 axis.push_back(
Integer(r_shape.size()));
603 if (axis.size() < (num_block_dims + num_input_dims)) {
604 axis.push_back(
Integer(r_shape.size() - (num_block_dims + 1)));
606 r_shape.push_back(in_shape[i]);
609 ffi::Array<PrimExpr> r_p_shape;
610 r_p_shape.push_back(batch / block_shape_prod);
611 for (
size_t i = 1; i <= num_block_dims; i++) {
612 r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
614 for (
size_t i = num_block_dims + 1; i < num_input_dims; i++) {
615 r_p_shape.push_back(in_shape[i]);
624 ffi::Array<Integer> begin_idx, end_idx, strides;
625 for (
size_t i = 0; i < r_p_shape.size(); ++i) {
627 if (i > 0 && i <= num_block_dims) {
629 int begin_i =
static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
630 int end_i =
static_cast<int>(GetConstInt(crop_end_list[i - 1]));
631 int out_i =
static_cast<int>(GetConstInt(r_p_shape[i]));
632 TVM_FFI_ICHECK_GT(out_i, (begin_i + end_i))
633 <<
"Incorrect crop sizes for (" << i <<
")th dim, can not crop more than"
634 <<
" output size" << out_i <<
" vs " << (begin_i + end_i);
635 begin_idx.push_back(begin_i);
636 end_idx.push_back(out_i - end_i);
639 begin_idx.push_back(
Integer(0));
640 end_idx.push_back(
static_cast<int>(GetConstInt(r_p_shape[i])));
662 std::string reduction =
"mean",
int ignore_index = -100,
663 const std::string name =
"nll_loss",
const std::string tag =
kBroadcast) {
664 if (predictions.ndim() == 1) {
669 [&](
const tvm::ffi::Array<tvm::tirx::Var>& target_indices) {
675 if (reduction ==
"mean") {
678 [&](
const tvm::ffi::Array<tvm::tirx::Var>& target_indices) {
691 [&](
const tvm::ffi::Array<tvm::tirx::Var>& target_indices) {
692 auto c = targets(target_indices);
693 tvm::ffi::Array<tvm::PrimExpr> pred_indices;
694 pred_indices.push_back(target_indices[0]);
695 pred_indices.push_back(c);
696 for (size_t i = 1; i < target_indices.size(); i++) {
697 pred_indices.push_back(target_indices[i]);
699 return tvm::tirx::Select(c != ignore_index, -predictions(pred_indices) * weights(c),
703 TVM_FFI_ICHECK(T->shape.size() != 0);
704 if (reduction ==
"mean") {
707 [&](
const tvm::ffi::Array<tvm::tirx::Var>& target_indices) {
708 auto c = targets(target_indices);
709 return tvm::tirx::Select(c != ignore_index, weights(c),
710 tvm::tirx::make_const(predictions->dtype, 0));
714 topi::sum(W, tvm::ffi::Array<Integer>(
nullptr)));
715 }
else if (reduction ==
"sum") {
716 return topi::sum(T, tvm::ffi::Array<Integer>(
nullptr));
Algebraic expression simplifications.
Container of constant int that adds more constructors.
Definition: expr.h:601
Reference to PrimExprNode.
Definition: expr.h:126
Range container
Definition: expr.h:690
Definition: source_map.h:111
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:634
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:54
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Managed reference to SelectNode.
Definition: expr.h:514
a named variable in TIR
Definition: var.h:76
Utility functions for handling constants in TVM expressions.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1007
PrimExpr const_true(int lanes=1, Span span=Span())
Make a constant true expression.
Definition: op.h:830
PrimExpr foldl(FReduce freduce, PrimExpr init_value, const ffi::Array< PrimExpr > &values, Span span=Span())
Left fold.
Definition: op.h:912
constexpr auto kElementWise
Definition: tags.h:32
Tensor reshape(const Tensor &x, ffi::Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:330
constexpr auto kBroadcast
Definition: tags.h:36
constexpr auto kInjective
Definition: tags.h:33
tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor &data, const tvm::ffi::Array< Integer > &block_shape, const tvm::ffi::Array< tvm::PrimExpr > &pad_before, const tvm::ffi::Array< tvm::PrimExpr > &pad_after, PrimExpr pad_value=PrimExpr(), std::string name="space_to_batch_nd", std::string tag=kInjective)
Divide spatial dimensions of the input into a grid of blocks.
Definition: nn.h:483
constexpr auto kConv2dNCHW
Definition: tags.h:38
tvm::te::Tensor prelu(const tvm::te::Tensor &x, const tvm::te::Tensor &slope, const int axis=1, std::string name="T_prelu", std::string tag=kBroadcast)
Creates an operation that performs a parametric rectified linear unit.
Definition: nn.h:100
tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor &data, const tvm::ffi::Array< Integer > &block_shape, const tvm::ffi::Array< tvm::PrimExpr > &crop_begin_list, const tvm::ffi::Array< tvm::PrimExpr > &crop_end_list, std::string name="batch_to_space_nd", std::string tag=kInjective)
Reshape the batch dimension into spatial dimensions.
Definition: nn.h:579
tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_group_conv2d_ngchw", std::string tag=kGroupConv2d)
Creates an operation that performs a 2-D group convolution with an NGCHW-layout.
Definition: nn.h:435
tvm::te::Tensor leaky_relu(const tvm::te::Tensor &t, double alpha=0.1, std::string name="T_leaky_relu", std::string tag=kElementWise)
Creates an operation that performs a leaky rectified linear unit.
Definition: nn.h:76
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:241
Tensor transpose(const Tensor &x, ffi::Optional< ffi::Array< Integer >> opt_axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:205
constexpr auto kDepthwiseConv2dNCHW
Definition: tags.h:40
tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nchw", std::string tag=kDepthwiseConv2dNCHW)
Creates an operation that performs a 2-D depthwise convolution with an NCHW-layout.
Definition: nn.h:357
constexpr auto kGroupConv2d
Definition: tags.h:45
tvm::te::Tensor pad(const tvm::te::Tensor &t, const tvm::ffi::Array< tvm::PrimExpr > &pad_before, tvm::ffi::Array< tvm::PrimExpr > pad_after=tvm::ffi::Array< tvm::PrimExpr >(), PrimExpr pad_value=PrimExpr(), std::string name="T_pad", std::string tag=kElementWise, std::string pad_mode="constant", const ffi::Array< PrimExpr > *dyn_output_shape=nullptr)
Creates an operation that performs padding.
Definition: nn.h:156
constexpr auto kConv2dHWCN
Definition: tags.h:39
constexpr auto kDepthwiseConv2dNHWC
Definition: tags.h:41
Tensor strided_slice(const Tensor &x, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:962
tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_nchw", std::string tag=kConv2dNCHW)
Creates an operation that performs a 2-D convolution with an NCHW-layout.
Definition: nn.h:269
Tensor sum(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:328
tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_conv2d_hwcn", std::string tag=kConv2dHWCN)
Creates an operation for a 2-D convolution layer with an HWCN-layout.
Definition: nn.h:313
tvm::te::Tensor relu(const tvm::te::Tensor &t, T threshold=static_cast< T >(0), std::string name="T_relu", std::string tag=kElementWise)
Creates an operation that performs a rectified linear unit.
Definition: nn.h:55
Tensor nll_loss(const Tensor &predictions, const Tensor &targets, const Tensor &weights, std::string reduction="mean", int ignore_index=-100, const std::string name="nll_loss", const std::string tag=kBroadcast)
Negative log likelihood loss.
Definition: nn.h:661
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor &I, const tvm::te::Tensor &W, int pad_h=0, int pad_w=0, int stride_h=1, int stride_w=1, std::string name="T_depthwise_conv2d_nhwc", std::string tag=kDepthwiseConv2dNHWC)
Definition: nn.h:386
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span=Span())
logical and of two expressions
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr sum(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floor division, i.e. a - floor(a / b) * b, where a and b are non-negative.
Operation node can generate one or multiple Tensors.
Reduction op constructors.
Common operators defined for Expr.