24 #ifndef TVM_TOPI_DETAIL_STRIDED_SLICE_H_
25 #define TVM_TOPI_DETAIL_STRIDED_SLICE_H_
44 int64_t begin_range = stride < 0 ? -1 : 0;
45 int64_t end_range = stride < 0 ? extent - 1 : extent;
52 inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> ConvertToVec(
53 const Array<Integer>& begin,
const Array<Integer>& end,
const Array<Integer>& strides,
54 std::string slice_mode) {
55 std::vector<int64_t> stride_vec(strides.size(), 1);
56 if (slice_mode ==
"end") {
57 for (
size_t i = 0; i < strides.size(); ++i) {
58 ICHECK(strides[i].defined());
59 stride_vec[i] = GetConstInt(strides[i]);
63 std::vector<int64_t> begin_vec;
64 for (
size_t i = 0; i < begin.size(); ++i) {
65 if (!begin[i].defined()) {
67 begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
69 begin_vec.push_back(GetConstInt(begin[i]));
72 std::vector<int64_t> end_vec;
73 for (
size_t i = 0; i < end.size(); ++i) {
75 if (!end[i].defined()) {
76 end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
77 }
else if (slice_mode ==
"size") {
78 int64_t end_val = GetConstInt(end[i]);
80 end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
82 end_vec.push_back(begin_vec[i] + end_val);
85 end_vec.push_back(GetConstInt(end[i]));
88 return std::make_tuple(begin_vec, end_vec, stride_vec);
91 inline Array<PrimExpr> StridedSliceCanonicalizeBegin(
const Array<PrimExpr>& ishape,
92 const std::vector<int64_t>& begin,
93 const std::vector<int64_t>& strides,
94 const Array<Integer>& axes,
DataType dtype,
95 std::string slice_mode =
"end") {
96 Array<PrimExpr> begin_expr;
97 for (
size_t i = 0; i < axes.size(); ++i) {
98 if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
99 int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
101 begin_expr.push_back(
make_const(dtype, begin_i));
103 auto idim = ishape[axes[i].IntValue()];
105 PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
112 begin_expr.push_back(b);
119 const std::vector<int64_t>& begin,
120 const std::vector<int64_t>& end,
121 const std::vector<int64_t>& strides,
122 const Array<Integer>& axes, std::string slice_mode,
123 const Array<PrimExpr>& begin_canonicalized,
124 bool use_any =
false) {
125 const size_t src_tensor_dim = ishape.size();
126 Array<PrimExpr> out_shape;
127 for (
size_t i = 0; i < src_tensor_dim; ++i) {
128 out_shape.push_back(ishape[i]);
131 for (
size_t i = 0; i < axes.size(); ++i) {
132 if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
133 const int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
134 ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>());
135 int64_t begin_i = GetConstInt(begin_canonicalized[i]);
137 int interval =
std::abs(end_i - begin_i);
139 static_cast<int>((interval +
std::abs(strides[i]) - 1) /
std::abs(strides[i]));
140 ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
141 <<
": Input [Begin=" << begin[i] <<
", End=" << end[i] <<
"] is invalid for axis=" << i;
142 out_shape.Set(axes[i].IntValue(),
cast(out_shape[i].dtype(), PrimExpr(slice_size)));
143 }
else if (use_any) {
146 out_shape.Set(axes[i].IntValue(),
tvm::tir::Var(
"dim", out_shape[i]->dtype));
Managed reference to AnyNode.
Definition: expr.h:1131
a named variable in TIR
Definition: var.h:89
Utility functions for handling constants in TVM expressions.
Tensor expression language DSL.
Definition: extracted_task.h:33
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:962
PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride)
Definition: transform.h:669
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type,...
Definition: elemwise.h:281
Array< PrimExpr > StridedSliceOutputShape(const Array< PrimExpr > &ishape, const Array< Integer > &begin, const Array< Integer > &end, const Array< Integer > &strides, const Array< Integer > &axes, const std::string &slice_mode)
Calculate the output shape of strided_slice, the entry point for Relay type relation.
Definition: transform.h:856
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
runtime::DataType DataType
Definition: data_type.h:493
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
PrimExpr abs(PrimExpr x, Span span=Span())
Calculate absolute value of x.