tvm
strided_slice.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_DETAIL_STRIDED_SLICE_H_
25 #define TVM_TOPI_DETAIL_STRIDED_SLICE_H_
26 
27 #include <tvm/tir/expr.h>
28 
29 #include <algorithm>
30 #include <limits>
31 #include <string>
32 #include <tuple>
33 #include <vector>
34 
35 #include "constant_utils.h"
36 
37 namespace tvm {
38 namespace topi {
39 namespace detail {
40 
41 using namespace tvm::te;
42 
/*!
 * \brief Map a possibly-negative slice index onto the valid range of one axis.
 *
 * A negative \p index is first shifted by \p extent (Python-style wrap). The
 * result is then clamped into the window the traversal direction allows:
 * [0, extent] for a non-negative stride, and [-1, extent - 1] for a negative
 * stride (one position lower, so a backward walk can stop just before element 0).
 *
 * \param index Raw begin/end index for the axis.
 * \param extent Length of the axis.
 * \param stride Slice step; only its sign is inspected.
 * \return The clamped index.
 */
inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) {
  const bool walks_backward = stride < 0;
  const int64_t lowest = walks_backward ? -1 : 0;
  const int64_t highest = walks_backward ? extent - 1 : extent;
  const int64_t wrapped = (index < 0) ? index + extent : index;
  // Apply the floor before the ceiling so the ceiling wins if the window is empty.
  const int64_t floored = (wrapped < lowest) ? lowest : wrapped;
  return (floored > highest) ? highest : floored;
}
51 
52 inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> ConvertToVec(
53  const ffi::Array<Integer>& begin, const ffi::Array<Integer>& end,
54  const ffi::Array<Integer>& strides, std::string slice_mode) {
55  std::vector<int64_t> stride_vec(strides.size(), 1);
56  if (slice_mode == "end") {
57  for (size_t i = 0; i < strides.size(); ++i) {
58  ICHECK(strides[i].defined());
59  stride_vec[i] = GetConstInt(strides[i]);
60  }
61  }
62  const int64_t max_range = std::numeric_limits<int64_t>::max();
63  std::vector<int64_t> begin_vec;
64  for (size_t i = 0; i < begin.size(); ++i) {
65  if (!begin[i].defined()) {
66  // value=None
67  begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
68  } else {
69  begin_vec.push_back(GetConstInt(begin[i]));
70  }
71  }
72  std::vector<int64_t> end_vec;
73  for (size_t i = 0; i < end.size(); ++i) {
74  // allow end to be None
75  if (!end[i].defined()) {
76  end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
77  } else if (slice_mode == "size") {
78  int64_t end_val = GetConstInt(end[i]);
79  if (end_val < 0) {
80  end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
81  } else {
82  end_vec.push_back(begin_vec[i] + end_val);
83  }
84  } else {
85  end_vec.push_back(GetConstInt(end[i]));
86  }
87  }
88  return std::make_tuple(begin_vec, end_vec, stride_vec);
89 }
90 
91 inline ffi::Array<PrimExpr> StridedSliceCanonicalizeBegin(const ffi::Array<PrimExpr>& ishape,
92  const std::vector<int64_t>& begin,
93  const std::vector<int64_t>& strides,
94  const ffi::Array<Integer>& axes,
95  DataType dtype,
96  std::string slice_mode = "end") {
97  ffi::Array<PrimExpr> begin_expr;
98  for (size_t i = 0; i < axes.size(); ++i) {
99  if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
100  int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
101  int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
102  begin_expr.push_back(make_const(dtype, begin_i));
103  } else {
104  auto idim = ishape[axes[i].IntValue()];
105  auto b_expr = make_const(dtype, begin[i]);
106  PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
107  auto s = strides[i];
108  if (s < 0) {
109  b = tvm::min(b, idim - 1);
110  } else {
111  b = tvm::if_then_else(b < 0, 0, b);
112  }
113  begin_expr.push_back(b);
114  }
115  }
116  return begin_expr;
117 }
118 
119 inline ffi::Array<PrimExpr> StridedSliceOutputShape(
120  const ffi::Array<PrimExpr>& ishape, const std::vector<int64_t>& begin,
121  const std::vector<int64_t>& end, const std::vector<int64_t>& strides,
122  const ffi::Array<Integer>& axes, std::string slice_mode,
123  const ffi::Array<PrimExpr>& begin_canonicalized, bool use_any = false) {
124  ICHECK(!use_any) << "StridedSliceOutputShape does not legacy use_any";
125  const size_t src_tensor_dim = ishape.size();
126  ffi::Array<PrimExpr> out_shape;
127  for (size_t i = 0; i < src_tensor_dim; ++i) {
128  out_shape.push_back(ishape[i]);
129  }
130 
131  for (size_t i = 0; i < axes.size(); ++i) {
132  if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
133  const int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
134  ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>());
135  int64_t begin_i = GetConstInt(begin_canonicalized[i]);
136  int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]);
137  int interval = std::abs(end_i - begin_i);
138  int slice_size =
139  static_cast<int>((interval + std::abs(strides[i]) - 1) / std::abs(strides[i]));
140  ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
141  << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i;
142  out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size)));
143  } else {
144  out_shape.Set(axes[i].IntValue(), tvm::tir::Var("dim", out_shape[i]->dtype));
145  }
146  }
147 
148  return out_shape;
149 }
150 
151 } // namespace detail
152 } // namespace topi
153 } // namespace tvm
154 #endif // TVM_TOPI_DETAIL_STRIDED_SLICE_H_
a named variable in TIR
Definition: var.h:77
Utility functions for handling constants in TVM expressions.
Tensor expression language DSL.
Definition: extracted_task.h:33
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:994
ffi::Array< PrimExpr > StridedSliceOutputShape(const ffi::Array< PrimExpr > &ishape, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, const ffi::Array< Integer > &axes, const std::string &slice_mode)
Calculate the output shape of strided_slice, the entry point for Relax type relation.
Definition: transform.h:864
PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride)
Definition: transform.h:675
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type,...
Definition: elemwise.h:281
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
runtime::DataType DataType
Definition: data_type.h:458
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
PrimExpr abs(PrimExpr x, Span span=Span())
Calculate absolute value of x.
TIR expressions.