#ifndef TVM_TOPI_TRANSFORM_H_
#define TVM_TOPI_TRANSFORM_H_

#include <unordered_set>

namespace tvm {
namespace topi {

using namespace tvm::te;
using namespace topi::detail;
/*! \brief Creates an operation to slide a window over the input x. */
inline Tensor sliding_window(const Tensor& x, int axis, Array<Integer> window_shape,
                             Array<Integer> strides, std::string name = "T_sliding_window",
                             std::string tag = "") {
  auto _axis = size_t(axis);
  CHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x.";
  CHECK_EQ(x->shape.size() - _axis, window_shape.size())
      << "There must be a window shape for every dimension of x "
      << "over which we are sliding the window.";
  CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length.";

  // Compute the new shape. Dimensions up to `axis` are unchanged.
  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < _axis; ++i) {
    new_shape.push_back(x->shape[i]);
  }

  // One new dimension per window dimension, holding the number of window positions.
  for (size_t i = 0; i < window_shape.size(); ++i) {
    // Length of the input along this dimension.
    auto dim_len = x->shape[_axis + i];
    // Length of the window along this dimension.
    auto window_len = window_shape[i];
    // Stride along this dimension.
    auto stride = strides[i];
    new_shape.push_back(floordiv(dim_len - (window_len - 1) + (stride - 1), stride));
  }

  // Dimensions comprising the window itself.
  for (size_t i = 0; i < window_shape.size(); ++i) {
    new_shape.push_back(window_shape[i]);
  }

  ICHECK(new_shape.size() == _axis + 2 * window_shape.size());

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        // The index at which to read the input tensor x.
        Array<PrimExpr> idx;
        for (size_t i = 0; i < _axis; ++i) {
          idx.push_back(indices[i]);
        }
        for (size_t i = 0; i < window_shape.size(); ++i) {
          // Which window position we are indexing along this dimension.
          auto window_idx = indices[_axis + i];
          // The offset within the window.
          auto idx_within_window = indices[_axis + window_shape.size() + i];
          auto stride = strides[i];
          idx.push_back(window_idx * stride + idx_within_window);
        }
        ICHECK(idx.size() == x->shape.size());
        return x(idx);
      },
      name, tag);
}
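// Illustrative usage (a hedged sketch; the placeholder tensor, shapes, and variable
// names below are assumptions, not part of this header):
//   tvm::te::Tensor x = tvm::te::placeholder({4, 4}, tvm::DataType::Float(32), "x");
//   tvm::te::Tensor w = topi::sliding_window(x, /*axis=*/0, {2, 2}, {1, 1});
//   // w->shape is {3, 3, 2, 2}: floordiv(4 - (2 - 1) + (1 - 1), 1) = 3 positions per axis.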
/*! \brief Creates an operation to insert new dimensions of length 1. */
inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1,
                          std::string name = "T_expand_dims", std::string tag = kBroadcast) {
  int ndim = static_cast<int>(x->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;
  ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
                           << ", but got num_newaxis = " << num_newaxis;
  if (axis < 0) {
    // Calculate offset from the last dimension.
    axis = ndim + axis + 1;
  }
  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
    new_shape.push_back(x->shape[i]);
  }
  for (size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) {
    new_shape.push_back(1);
  }
  for (size_t i = axis; i < x->shape.size(); ++i) {
    new_shape.push_back(x->shape[i]);
  }

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> idx;
        for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
          idx.push_back(indices[i]);
        }
        for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
          idx.push_back(indices[i]);
        }
        return x(idx);
      },
      name, tag);
}
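// Illustrative usage (hedged sketch; names and shapes are assumptions):
//   tvm::te::Tensor x = tvm::te::placeholder({2, 3}, tvm::DataType::Float(32), "x");
//   tvm::te::Tensor y = topi::expand_dims(x, /*axis=*/0);      // shape {1, 2, 3}
//   tvm::te::Tensor z = topi::expand_dims(x, /*axis=*/-1, 2);  // shape {2, 3, 1, 1}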
/*! \brief Permute the dimensions of an array. */
inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name = "T_transpose",
                        std::string tag = kInjective) {
  if (!axes.defined() || axes.size() == 0) {
    axes = Array<Integer>();
    // Default: reverse all dimensions.
    for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
      axes.push_back(i);
    }
  }

  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < axes.size(); ++i) {
    int axis = static_cast<int>(axes[i]->value);
    int new_axis = axis;
    if (axis < 0) {
      new_axis = static_cast<int>(x->shape.size()) + axis;
      axes.Set(i, new_axis);
    }
    ICHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
        << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size())
        << "-dimensional input tensor";

    for (size_t j = 0; j < axes.size(); ++j) {
      if (i != j) {
        ICHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
      }
    }
    new_shape.push_back(x->shape[new_axis]);
  }

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        std::vector<PrimExpr> idx;
        for (size_t i = 0; i < axes.size(); ++i) {
          idx.push_back(1);
        }
        for (size_t i = 0; i < axes.size(); ++i) {
          int axis = static_cast<int>(axes[i]->value);
          idx[axis] = indices[i];
        }
        return x(idx);
      },
      name, tag);
}
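// Illustrative usage (hedged sketch; the placeholder and axes are assumptions):
//   tvm::te::Tensor x = tvm::te::placeholder({2, 3, 4}, tvm::DataType::Float(32), "x");
//   tvm::te::Tensor y = topi::transpose(x, {1, 2, 0});  // shape {3, 4, 2}
//   tvm::te::Tensor r = topi::transpose(x, {});         // default: reverse to {4, 3, 2}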
/*! \brief Reverse the tensor for variable length slices along the sequence axis. */
inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
                               int batch_axis = 0, std::string name = "T_reverse_sequence",
                               std::string tag = kInjective) {
  size_t src_tensor_dim = x->shape.size();
  int seq_axis_inp = seq_axis;

  if (seq_lengths.defined()) {
    size_t seq_lengths_dim = seq_lengths->shape.size();
    int batch_axis_inp = batch_axis;
    if (batch_axis < 0) {
      batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
    }

    ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector";

    ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
        << "For reverse_sequence, seq_lengths size should match the dimension of the batch axis"
        << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
        << ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]);

    ICHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
        << "batch_axis=" << batch_axis_inp << " is invalid for the "
        << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
  }

  if (seq_axis < 0) {
    seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
  }
  ICHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
      << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
      << "-dimensional input tensor";

  auto func = [&](const Array<Var>& indices) {
    Array<PrimExpr> real_indices;
    for (size_t i = 0; i < src_tensor_dim; ++i) {
      if (i == static_cast<size_t>(seq_axis)) {
        if (seq_lengths.defined()) {
          auto len = seq_lengths(indices[batch_axis]);
          auto idx = if_then_else(
              len <= 1 || len <= indices[i], indices[i],
              if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
          real_indices.push_back(idx);
        } else {
          real_indices.push_back(x->shape[i] - 1 - indices[i]);
        }
      } else {
        real_indices.push_back(indices[i]);
      }
    }
    return x(real_indices);
  };

  return compute(x->shape, func, name, tag);
}
/*! \brief Reshape a tensor. */
inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string name = "T_reshape",
                      std::string tag = kInjective) {
  auto x_shape = x->shape;
  Array<PrimExpr> target_shape;
  for (const auto& ele : newshape) {
    target_shape.push_back(ele);
  }
  // If either the input or the target shape is empty, produce a tensor of zeros.
  if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) {
    return compute(
        target_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  }
  return compute(
      target_shape,
      [&](const Array<Var>& indices) {
        return x(UnravelIndex(
            RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape), x_shape));
      },
      name, tag);
}
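// Illustrative usage (hedged sketch; shapes are assumptions): reshape relays out
// the same 12 elements in row-major order:
//   tvm::te::Tensor x = tvm::te::placeholder({2, 6}, tvm::DataType::Float(32), "x");
//   tvm::te::Tensor y = topi::reshape(x, {3, 4});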
/*! \brief Converts a flat index or array of flat indices into a tuple of coordinate arrays. */
inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
                            std::string tag = kInjective) {
  auto x_shape = x->shape;
  auto shape_shape = shape->shape;

  Array<PrimExpr> oshape;
  oshape.push_back(shape_shape[0]);
  if (x_shape.size() != 0) {
    oshape.push_back(x_shape[0]);
  }

  auto func = [&](const Array<Var>& indices) {
    auto i = indices[0];
    std::vector<PrimExpr> indices_divs;  // Order is [x, x / s_n, x / (s_n * s_{n-1}), ...]
    PrimExpr ret = 0;
    PrimExpr cur_val = 0;
    PrimExpr index_val = 0;
    if (x_shape.size() != 0) {
      index_val = x[indices[1]];
    } else {
      index_val = x();
    }
    indices_divs.push_back(index_val);
    for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
      ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
      cur_val = indexdiv(indices_divs.back(), shape[v]);
      indices_divs.push_back(cur_val);
    }
    return ret;
  };

  return compute(oshape, func, name, tag);
}
/*! \brief Remove size-1 dimensions from the shape of a tensor. */
inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false,
                      std::string name = "T_squeeze", std::string tag = kInjective) {
  auto ndim = x->shape.size();
  std::vector<int> axis_val;
  if (!axis.defined() || axis.size() == 0) {
    for (size_t i = 0; i < ndim; ++i) {
      if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
        axis_val.push_back(static_cast<int>(i));
      }
    }
  } else {
    for (size_t i = 0; i < axis.size(); ++i) {
      int64_t val = axis[i]->value;
      if (val < 0) {
        val += static_cast<int>(x->shape.size());
      }
      if (IsConstInt(x->shape[val])) {
        ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
      }
      axis_val.push_back(val);
    }
  }

  std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());

  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < ndim; ++i) {
    if (axis_set.count(static_cast<int>(i)) == 0) {
      out_shape.push_back(x->shape[i]);
    }
  }
  if (out_shape.size() == 0 && atleast1d) {
    out_shape.push_back(1);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> real_indices;
        int flag = 0;
        for (size_t i = 0; i < ndim; ++i) {
          if (axis_set.count(static_cast<int>(i)) == 0) {
            real_indices.push_back(indices[i - flag]);
          } else {
            real_indices.push_back(0);
            flag += 1;
          }
        }
        return x(real_indices);
      },
      name, tag);
}
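// Illustrative usage (hedged sketch; shapes are assumptions):
//   tvm::te::Tensor x = tvm::te::placeholder({1, 3, 1, 4}, tvm::DataType::Float(32), "x");
//   tvm::te::Tensor y = topi::squeeze(x, Array<Integer>());  // drop all size-1 dims: {3, 4}
//   tvm::te::Tensor z = topi::squeeze(x, {0});               // drop dim 0 only: {3, 1, 4}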
/*! \brief Join a sequence of tensors along an existing axis. */
inline Tensor concatenate(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_concat",
                          std::string tag = kInjective) {
  int ndim = static_cast<int>(inputs[0]->shape.size());
  ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)"
                                       << ", but got axis = " << axis << ", and ndim = " << ndim;
  if (axis < 0) {
    axis += ndim;
  }
  ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds";

  Array<PrimExpr> axis_sizes;
  for (auto t : inputs) {
    axis_sizes.push_back(t->shape[axis]);
  }
  arith::Analyzer analyzer;
  PrimExpr join_size = axis_sizes[0];
  for (size_t i = 1; i < axis_sizes.size(); ++i) {
    join_size += axis_sizes[i];
  }
  join_size = analyzer.Simplify(join_size);
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
    out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        auto ret = inputs[0](indices);
        auto ind = indices[axis];
        for (size_t i = 0; i < inputs.size() - 1; ++i) {
          ind -= axis_sizes[i];

          Array<PrimExpr> idx;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            idx.push_back(indices[j]);
          }
          idx.push_back(ind);
          for (size_t j = axis + 1; j < indices.size(); ++j) {
            idx.push_back(indices[j]);
          }

          ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
        }
        return ret;
      },
      name, tag);
}
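// Illustrative usage (hedged sketch; tensors are assumptions):
//   tvm::te::Tensor a = tvm::te::placeholder({2, 3}, tvm::DataType::Float(32), "a");
//   tvm::te::Tensor b = tvm::te::placeholder({4, 3}, tvm::DataType::Float(32), "b");
//   tvm::te::Tensor c = topi::concatenate({a, b}, /*axis=*/0);  // shape {6, 3}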
/*! \brief Join a sequence of tensors along a new axis. */
inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack",
                    std::string tag = kInjective) {
  int ndim = static_cast<int>(inputs[0]->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "stack only accepts `axis` in [-ndim - 1, ndim]"
      << ", but got axis = " << axis << ", and ndim = " << ndim;
  if (axis < 0) {
    axis += ndim + 1;
  }
  ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";

  const int stack_size = static_cast<int>(inputs.size());
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
  out_shape.push_back(stack_size);
  for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
    out_shape.push_back(inputs[0]->shape[i]);

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> idx;
        for (size_t i = 0; i < indices.size(); ++i)
          if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]);
        auto ind = indices[axis];
        auto ret = inputs[0](idx);
        for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
          ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
        }
        return ret;
      },
      name, tag);
}
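// Illustrative usage (hedged sketch; unlike concatenate, stack adds a new axis, so
// all inputs must share one shape):
//   tvm::te::Tensor p = tvm::te::placeholder({2, 3}, tvm::DataType::Float(32), "p");
//   tvm::te::Tensor q = tvm::te::placeholder({2, 3}, tvm::DataType::Float(32), "q");
//   tvm::te::Tensor s = topi::stack({p, q}, /*axis=*/0);  // shape {2, 2, 3}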
/*! \brief Split a tensor into multiple sub-tensors at the given indices. */
inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
                           std::string name = "T_split", std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(x->shape.size());
  }
  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";

  auto src_axis_size = x->shape[axis];
  std::vector<PrimExpr> begin_ids;
  begin_ids.push_back(0);

  for (auto idx : split_indices) {
    auto idx_node = idx.as<IntImmNode>();
    auto back_node = begin_ids.back().as<IntImmNode>();
    if (idx_node && back_node) {
      ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
    }
    begin_ids.push_back(idx);
  }

  Array<Array<PrimExpr>> out_shapes;
  for (size_t i = 0; i < begin_ids.size(); ++i) {
    PrimExpr out_axis_size;
    if (i == begin_ids.size() - 1) {
      out_axis_size = src_axis_size - begin_ids[i];
    } else {
      out_axis_size = begin_ids[i + 1] - begin_ids[i];
    }

    Array<PrimExpr> shape;
    for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
      shape.push_back(x->shape[j]);
    }
    shape.push_back(out_axis_size);
    for (size_t j = axis + 1; j < x->shape.size(); ++j) {
      shape.push_back(x->shape[j]);
    }

    out_shapes.push_back(shape);
  }

  Array<Tensor> result;
  for (size_t i = 0; i < begin_ids.size(); ++i) {
    result.push_back(compute(
        out_shapes[i],
        [&](const Array<Var>& indices) {
          auto begin = begin_ids[i];
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            real_indices.push_back(indices[j]);
          }
          real_indices.push_back(indices[axis] + begin);
          for (size_t j = axis + 1; j < indices.size(); ++j) {
            real_indices.push_back(indices[j]);
          }
          return x(real_indices);
        },
        name, tag));
  }

  return result;
}
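// Illustrative usage (hedged sketch; `x6` is a hypothetical {6, 4} input):
//   Array<tvm::te::Tensor> parts = topi::split(x6, {2, 4}, /*axis=*/0);
//   // parts[0], parts[1], parts[2] all have shape {2, 4}.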
/*! \brief strided_slice of a tensor where begin/end/stride can be mixed static and dynamic. */
inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
                                    const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
                                    std::string name = "T_dynamic_strided_slice",
                                    std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  ICHECK_LE(begin.size(), src_tensor_dim);
  ICHECK_LE(end.size(), src_tensor_dim);
  ICHECK_LE(strides.size(), src_tensor_dim);
  ICHECK_EQ(begin.size(), end.size());
  ICHECK_EQ(begin.size(), strides.size());

  const size_t num_slice_axes = begin.size();
  Array<PrimExpr> out_shape;

  for (size_t i = 0; i < num_slice_axes; ++i) {
    auto d = indexdiv(end[i] - begin[i], strides[i]);
    if (d->IsInstance<tvm::IntImmNode>()) {
      // Preserve a static dimension where possible.
      out_shape.push_back(Downcast<tvm::IntImm>(d));
    } else {
      out_shape.push_back(tvm::tir::Var("dim"));
    }
  }

  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
    out_shape.push_back(x->shape[i]);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < num_slice_axes; ++i) {
          real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
        }
        // Dimensions not covered by begin/end/strides pass through unchanged.
        for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
          real_indices.push_back(indices[i]);
        }
        return x(real_indices);
      },
      name, tag);
}
/*! \brief strided_slice of a tensor with begin/end/strides given as rank-1 tensors. */
inline Tensor dynamic_strided_slice(const Tensor& x, const Tensor& begin, const Tensor& end,
                                    const Tensor& strides,
                                    std::string name = "T_strided_slice_dynamic",
                                    std::string tag = kInjective) {
  DataType index_dtype = begin->shape[0]->dtype;
  const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;

  Array<PrimExpr> begin_expr, end_expr, strides_expr;
  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
    auto ind = IntImm(index_dtype, i);
    begin_expr.push_back(begin(ind));
    end_expr.push_back(end(ind));
    strides_expr.push_back(strides(ind));
  }
  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
}

/*! \brief Calculate the output shape of strided_slice; the entry point for the Relay
 *  type relation. */
inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape,
                                               const Array<Integer>& begin,
                                               const Array<Integer>& end,
                                               const Array<Integer>& strides,
                                               const Array<Integer>& axes,
                                               const std::string& slice_mode) {
  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
                                                           begin[0]->dtype, slice_mode);
  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
                                 begin_canonicalized, true);
}
/*! \brief strided_slice of a tensor along the given axes only. */
inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
                                      const Array<Integer>& end, const Array<Integer>& strides,
                                      const Array<Integer>& axes, std::string slice_mode = "end",
                                      std::string name = "T_strided_slice_with_axes",
                                      std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  ICHECK(axes.size() <= src_tensor_dim);

  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);

  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
                                                  begin[0]->dtype, slice_mode);
  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
                                           slice_mode, begin_expr);

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
        for (size_t i = 0; i < axes.size(); ++i) {
          auto stride = make_const(strides[i].dtype(), strides_vec[i]);
          PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
          real_indices.Set(axes[i].IntValue(), ind);
        }
        return x(real_indices);
      },
      name, tag);
}
/*! \brief strided_slice of a tensor over all dimensions. */
inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
                            const Array<Integer>& strides, std::string slice_mode = "end",
                            std::string name = "T_strided_slice", std::string tag = kInjective) {
  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
  Array<Integer> axes;
  for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
  Array<Integer> begin_full(begin);
  Array<Integer> end_full(end);
  Array<Integer> strides_full(strides);

  DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64);
  const IntImm one = IntImm(index_dtype, 1);
  const IntImm zero = IntImm(index_dtype, 0);
  const IntImm max_range = Downcast<IntImm>(max_value(index_dtype));

  // Pad begin/end/strides to the full rank with defaults that select everything.
  for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
    strides_full.push_back(one);
  }
  for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
    begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
  }
  for (size_t i = end.size(); i < src_tensor_dim; ++i) {
    end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
  }

  return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
                                 tag);
}
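// Illustrative usage (hedged sketch; shapes are assumptions):
//   tvm::te::Tensor x = tvm::te::placeholder({10}, tvm::DataType::Float(32), "x");
//   tvm::te::Tensor y = topi::strided_slice(x, {1}, {9}, {2});  // picks 1,3,5,7 -> shape {4}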
/*! \brief Split a tensor into a number of equally sized sub-tensors. */
inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
                                    std::string name = "T_split_sections",
                                    std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(x->shape.size());
  }
  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";

  auto src_axis_size = x->shape[axis];

  ICHECK_GT(num_sections, 0) << "Slice count must be > 0";

  if (auto node = src_axis_size.as<IntImmNode>()) {
    ICHECK_EQ(node->value % num_sections, 0)
        << "num_sections must be an integer factor of the size of axis " << axis << " ("
        << node->value << ")";
  }

  Array<PrimExpr> split_indices;
  auto seg_size = indexdiv(src_axis_size, num_sections);
  for (int i = 0; i < num_sections; ++i) {
    // Section i covers [i * seg_size, (i + 1) * seg_size).
    if (i != 0) {
      split_indices.push_back(seg_size * i);
    }
  }

  return split(x, split_indices, axis, name, tag);
}
/*! \brief Take elements from a flattened input array (axis is None). */
inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
                   std::string mode = "clip", std::string name = "T_take",
                   std::string tag = kInjective) {
  Array<PrimExpr> a_shape = a->shape;
  Array<PrimExpr> out_shape = indices->shape;
  PrimExpr a_size = 1;
  for (size_t i = 0; i < a_shape.size(); ++i) {
    a_size = a_size * a_shape[i];
  }

  if (mode == "clip") {
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
          return a(UnravelIndex(idx, a_shape));
        },
        name, tag);
  } else if (mode == "fast") {
    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                    "Make sure input indices are in bounds.";
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
        name, tag);
  } else {  // mode == "wrap"
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
          return a(UnravelIndex(idx, a_shape));
        },
        name, tag);
  }
}
/*! \brief Mask the out-of-boundary elements of each sequence. */
inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value,
                            int axis, std::string name = "T_sequence_mask",
                            std::string tag = kInjective) {
  ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
  ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
  auto length_dim = data->shape[axis];
  auto batch_dim = data->shape[1 - axis];
  Array<PrimExpr> out_shape = data->shape;
  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        Array<PrimExpr> len_index;
        auto tid = out_index[axis];
        auto bid = out_index[1 - axis];
        len_index.push_back(bid);
        return tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
                                 tvm::tir::make_const(data->dtype, mask_value), data(out_index));
      },
      name, tag);
}
/*! \brief Take elements from an input array along the given axis. */
inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
                   std::string mode = "clip", std::string name = "T_take",
                   std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(a->shape.size());
  }
  ICHECK_GE(axis, 0) << "axis out of bounds";
  ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
  auto axis_dim = a->shape[axis];
  int indices_len = static_cast<int>(indices->shape.size());

  int batch_dims_ = batch_dims;
  if (batch_dims_ != 0) {
    ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";

    if (batch_dims_ < 0) {
      batch_dims_ = indices->shape.size() + batch_dims_;
    }

    ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
    for (int i = 0; i < batch_dims_; ++i) {
      auto v1 = a->shape[i].as<IntImmNode>()->value;
      auto v2 = indices->shape[i].as<IntImmNode>()->value;
      ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
    }
  }

  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] + a.shape[axis + 1:].
  Array<PrimExpr> out_shape;
  for (int i = 0; i < batch_dims_; ++i) {
    out_shape.push_back(a->shape[i]);
  }
  for (int i = batch_dims_; i < axis; ++i) {
    out_shape.push_back(a->shape[i]);
  }
  for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
    out_shape.push_back(indices->shape[i]);
  }
  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
    out_shape.push_back(a->shape[i]);
  }

  if (mode == "clip") {
    if (batch_dims_ == 0) {
      return compute(
          out_shape,
          [&](const Array<Var>& out_index) {
            Array<PrimExpr> indices_position;
            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
              indices_position.push_back(out_index[j]);
            }
            Array<PrimExpr> real_indices;
            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
              real_indices.push_back(out_index[j]);
            }
            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
            real_indices.push_back(idx);
            for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
              real_indices.push_back(out_index[j]);
            }
            return a(real_indices);
          },
          name, tag);
    } else {
      return compute(
          out_shape,
          [&](const Array<Var>& out_index) {
            Array<PrimExpr> indices_position;
            for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
              indices_position.push_back(out_index[j]);
            }
            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
              indices_position.push_back(out_index[j]);
            }
            Array<PrimExpr> real_indices;
            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
              real_indices.push_back(out_index[j]);
            }
            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
            real_indices.push_back(idx);
            for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
              real_indices.push_back(out_index[j]);
            }
            return a(real_indices);
          },
          name, tag);
    }
  } else if (mode == "fast") {
    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                    "Make sure input indices are in bounds.";
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          Array<PrimExpr> indices_position;
          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
            indices_position.push_back(out_index[j]);
          }
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            real_indices.push_back(out_index[j]);
          }
          real_indices.push_back(indices(indices_position));
          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
            real_indices.push_back(out_index[j]);
          }
          return a(real_indices);
        },
        name, tag);
  } else {  // mode == "wrap"
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          Array<PrimExpr> indices_position;
          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
            indices_position.push_back(out_index[j]);
          }
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            real_indices.push_back(out_index[j]);
          }
          auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
          real_indices.push_back(idx);
          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
            real_indices.push_back(out_index[j]);
          }
          return a(real_indices);
        },
        name, tag);
  }
}
/*! \brief Return elements from x or y depending on condition, with broadcasting. */
inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
                    std::string name = "T_where", std::string tag = kBroadcast) {
  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
                                << y->dtype;
  auto get_out_shape = [&]() {
    auto bh1 = detail::BroadcastShape(x->shape, y->shape);
    Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
    auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
    Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
    return common_shape2;
  };

  auto oshape = get_out_shape();

  auto c_bh = detail::BroadcastShape(condition->shape, oshape);
  auto x_bh = detail::BroadcastShape(x->shape, oshape);
  auto y_bh = detail::BroadcastShape(y->shape, oshape);

  auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
    auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
    auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
    auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
    return tvm::tir::Select(c != 0, true_val, false_val);
  };

  return compute(oshape, select, name, tag);
}
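// Illustrative usage (hedged sketch; cond/a/b are hypothetical broadcast-compatible
// tensors): topi::where(cond, a, b) yields a[i] where cond[i] != 0 and b[i] elsewhere.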
/*! \brief Creates an operation to repeat elements of an array. */
inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat",
                     std::string tag = kBroadcast) {
  int ndim = static_cast<int>(x->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;
  ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
                       << ", but got repeats = " << repeats;
  if (axis < 0) {
    // Calculate offset from the last dimension.
    axis += ndim;
  }
  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
    new_shape.push_back(x->shape[i]);
  }
  new_shape.push_back(repeats * x->shape[axis]);
  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
    new_shape.push_back(x->shape[i]);
  }

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> idx;
        for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
          idx.push_back(indices[i]);
        }
        idx.push_back(indexdiv(indices[axis], repeats));
        for (size_t i = axis + 1; i < indices.size(); ++i) {
          idx.push_back(indices[i]);
        }
        return x(idx);
      },
      name, tag);
}
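// Illustrative usage (hedged sketch): repeating {a, b, c} with repeats=2 on axis 0
// gives {a, a, b, b, c, c}; each output index maps back via indexdiv(i, repeats).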
/*! \brief Creates an operation to tile elements of an array. */
inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_tile",
                   std::string tag = kBroadcast) {
  size_t ndim = x->shape.size();
  size_t rdim = reps.size();
  size_t tdim = (ndim > rdim) ? ndim : rdim;
  Array<PrimExpr> data_shape;
  Array<PrimExpr> reps_shape;
  Array<PrimExpr> new_shape;
  if (ndim == rdim) {
    for (size_t i = 0; i < ndim; ++i) {
      data_shape.push_back(x->shape[i]);
      reps_shape.push_back(reps[i]);
    }
  } else if (ndim > rdim) {
    for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
    for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
    for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
  } else {
    for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
    for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
    for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
  }
  for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);

  if (is_empty_shape(new_shape)) {
    return compute(
        new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  } else {
    return compute(
        new_shape,
        [&](const Array<Var>& indices) {
          Array<PrimExpr> idx;
          if (ndim >= rdim) {
            for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
          } else {
            for (size_t i = 0; i < ndim; ++i)
              idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
          }
          return x(idx);
        },
        name, tag);
  }
}
/*! \brief Creates an operation to tile elements of an array when the target shape is dynamic. */
inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
                       std::string name = "T_tile", std::string tag = kBroadcast) {
  size_t ndim = x->shape.size();
  if (is_empty_shape(new_shape)) {
    return compute(
        new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  } else {
    return compute(
        new_shape,
        [&](const Array<Var>& indices) {
          Array<PrimExpr> idx;
          if (ndim >= rdim) {
            for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
          } else {
            for (size_t i = 0; i < ndim; ++i)
              idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
          }
          return x(idx);
        },
        name, tag);
  }
}
/*! \brief Gather values along the given axis from the given indices. */
inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
                     std::string name = "T_gather", std::string tag = kInjective) {
  size_t ndim_d = data->shape.size();
  size_t ndim_i = indices->shape.size();
  ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
  ICHECK_EQ(ndim_d, ndim_i);
  if (axis < 0) {
    axis += ndim_d;
  }
  ICHECK_GE(axis, 0);
  ICHECK_LT(axis, ndim_d);
  size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
  ICHECK_GE(indices_dim_i, 1);
  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());

  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < ndim_i; ++i) {
    out_shape.push_back(indices->shape[i]);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        Array<PrimExpr> indices_position;
        for (size_t i = 0; i < ndim_i; ++i) {
          indices_position.push_back(out_index[i]);
        }
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < ndim_i; ++i) {
          if (i == static_cast<size_t>(axis)) {
            real_indices.push_back(indices(indices_position));
          } else {
            real_indices.push_back(indices_position[i]);
          }
        }
        return data(real_indices);
      },
      name, tag);
}
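// Illustrative usage (hedged sketch; shapes are assumptions): with data of shape
// {3, 4}, int32 indices of shape {3, 2}, and axis=1,
//   out(i, j) = data(i, indices(i, j)), so the result has shape {3, 2}.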
/*! \brief Gather elements from an n-dimensional array. */
inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
                        std::string name = "T_gather_nd", std::string tag = kInjective) {
  size_t ndim_d = data->shape.size();
  size_t ndim_i = indices->shape.size();
  ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimension";
  size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
  ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
                                  << "than dimensions of data tensor";

  Array<PrimExpr> out_shape;
  for (size_t i = 1; i < ndim_i; ++i) {
    out_shape.push_back(indices->shape[i]);
  }
  for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
    out_shape.push_back(data->shape[i]);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        Array<PrimExpr> indices_position;
        indices_position.push_back(0);
        for (size_t i = 0; i < ndim_i - 1; ++i) {
          indices_position.push_back(out_index[i]);
        }
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
          real_indices.push_back(out_index[i]);
        }
        for (size_t i = 0; i < indices_dim0; ++i) {
          indices_position.Set(0, make_const(DataType::Int(32), i));
          if (indices->dtype.is_int() || indices->dtype.is_uint()) {
            real_indices.push_back(indices(indices_position));
          } else {
            real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
          }
        }
        if (real_indices.size() == ndim_d) {
          return data(real_indices);
        }
        for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
          real_indices.push_back(out_index[i]);
        }
        return data(real_indices);
      },
      name, tag);
}
/*! \brief Creates an operation that calculates a matrix multiplication (row-major). */
inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B,
                              bool trans_a = false, bool trans_b = false,
                              std::string name = "T_matmul", std::string tag = kMatMul) {
  tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
  auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
  auto l = [&](tvm::tir::Var i, tvm::tir::Var j) {
    return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
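// Illustrative usage (hedged sketch; shapes are assumptions): with A{m, k}, B{k, n},
//   tvm::te::Tensor C = topi::matmul(A, B);  // C{m, n}, C[i][j] = sum_k A[i][k] * B[k][j]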
/*! \brief A generalization of matrix multiplication to tensors. */
inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2,
                        std::string name = "T_tensordot", std::string tag = kMatMul) {
  ICHECK_GE(A->shape.size(), axes);
  ICHECK_GE(B->shape.size(), axes);

  Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
  for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);

  Array<IterVar> iter_vars;
  for (int i = 0; i < axes; ++i)
    iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));

  auto func = [&A, &B, &iter_vars, axes](const Array<Var>& input_indices) {
    Array<PrimExpr> A_indices(input_indices.begin(),
                              input_indices.begin() + (A->shape.size() - axes));
    for (auto& v : iter_vars) A_indices.push_back(v);

    Array<PrimExpr> B_indices;
    for (auto& v : iter_vars) B_indices.push_back(v);

    auto it = input_indices.begin() + (A->shape.size() - axes);
    for (; it != input_indices.end(); ++it) B_indices.push_back(*it);

    // Some passes don't like reductions with an empty axis, so avoid it here.
    if (iter_vars.empty()) {
      return A(A_indices) * B(B_indices);
    } else {
      return sum(A(A_indices) * B(B_indices), iter_vars);
    }
  };

  return compute(output_shape, func, name, tag);
}
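// Illustrative usage (hedged sketch): tensordot(A{2, 3, 4}, B{3, 4, 5}, axes=2)
// contracts the trailing {3, 4} of A against the leading {3, 4} of B, giving {2, 5}.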
/*! \brief A generalization of matrix multiplication to tensors, with explicit axes. */
inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExpr> A_axes,
                        Array<PrimExpr> B_axes, std::string name = "T_tensordot",
                        std::string tag = kMatMul) {
  ICHECK_EQ(A_axes.size(), B_axes.size());

  auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
  auto B_axes_val = GetConstIntValues(B_axes, "B_axes");

  Array<PrimExpr> output_shape;
  for (unsigned i = 0; i < A->shape.size(); ++i)
    if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
      output_shape.push_back(A->shape[i]);
  for (unsigned i = 0; i < B->shape.size(); ++i)
    if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
      output_shape.push_back(B->shape[i]);

  Array<IterVar> iter_vars;
  for (unsigned i = 0; i < B_axes_val.size(); ++i)
    iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));

  auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array<Var>& input_indices) {
    int idx_input = 0;
    Array<PrimExpr> A_indices;
    for (unsigned i = 0; i < A->shape.size(); ++i) {
      auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
      if (axes_pos == A_axes_val.end()) {
        A_indices.push_back(input_indices[idx_input++]);
      } else {
        A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
      }
    }

    Array<PrimExpr> B_indices;
    for (unsigned i = 0; i < B->shape.size(); ++i) {
      auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
      if (axes_pos == B_axes_val.end()) {
        B_indices.push_back(input_indices[idx_input++]);
      } else {
        B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
      }
    }
    return sum(A(A_indices) * B(B_indices), iter_vars);
  };
  return compute(output_shape, func, name, tag);
}
/*! \brief Evenly spaced values within a given interval. */
inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
                     DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
  PrimExpr num_elem = tvm::cast(
      tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
  return compute(
      {num_elem},
      [&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name,
      tag);
}
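// Illustrative usage (hedged sketch): arange(0, 10, 2, DataType::Int(32)) has
// ceil((10 - 0) / 2) = 5 elements: {0, 2, 4, 6, 8}.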
/*! \brief Produce grids by expanding inputs over dimensions defined by the other inputs. */
inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& indexing,
                              std::string name = "T_meshgrid", std::string tag = kInjective) {
  const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2;
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
    out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]);
  }
  Array<Tensor> result;
  for (size_t i = 0; i < inputs.size(); ++i) {
    result.push_back(compute(
        out_shape,
        [&](const Array<Var>& indices) {
          const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
          auto ndim = inputs[i]->GetShape().size();
          Array<PrimExpr> real_indices = {0};
          if (ndim > 0) {
            real_indices = {indices[src_index]};
          }
          return inputs[i](real_indices);
        },
        name, tag));
  }
  return result;
}
/*! \brief Transform the layout according to src_layout and dst_layout. */
inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
                               const std::string& dst_layout,
                               const std::string schedule_rule = "None",
                               const std::string name = "T_layout_trans",
                               const std::string tag = kInjective) {
  Layout src_layout_struct(src_layout);
  Layout dst_layout_struct(dst_layout);

  if (src_layout_struct.Equals(dst_layout_struct)) {
    return src;
  }

  ICHECK(src_layout_struct.defined() && dst_layout_struct.defined())
      << "cannot convert from/to undefined layout";

  auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
  ICHECK(layout_converter.defined())
      << "cannot convert from " << src_layout << " to " << dst_layout;

  Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);

  Map<String, ObjectRef> attrs = {{"schedule_rule", String(schedule_rule)},
                                  // Layout information needed by the schedule rule.
                                  {"src_layout", String(src_layout)},
                                  {"dst_layout", String(dst_layout)},
                                  {"input_shape", src->shape}};

  return compute(
      dst_shape,
      [&](const Array<Var>& dst_indices) {
        Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
        Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
        PrimExpr in_range = PrimExpr(1) > 0;  // initialize with boolean dtype
        for (size_t i = 0; i < src.ndim(); ++i) {
          in_range = in_range && (src_indices[i] < src->shape[i]);
        }
        return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
      },
      name, tag, attrs);
}
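// Illustrative usage (hedged sketch; shapes are assumptions): transforming "NCHW" to
// "NCHW4c" packs channels in groups of four, e.g. {1, 8, 4, 4} -> {1, 2, 4, 4, 4}.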
/*! \brief Utility function for auto_scheduler_layout_transform: parse a layout string
 *  such as "NCHW4c" into per-axis shapes and names. */
inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape,
                                        std::vector<std::string>* axes) {
  int32_t factor = 0;
  std::string axis = "";
  for (char c : std::string(layout)) {
    if (c >= 'A' && c <= 'z') {
      axis += c;
      if (factor != 0) {
        shape->push_back(factor);
        factor = 0;
      }
    } else if (c >= '0' && c <= '9') {
      factor = factor * 10 + c - '0';
      if (!axis.empty()) {
        axes->push_back(axis);
        axis = "";
      }
    } else {
      LOG(FATAL) << "Invalid layout " << layout;
    }
  }
  if (!axis.empty()) {
    axes->push_back(axis);
  }
}
/*! \brief Transform the auto-scheduler generated layout according to src_layout and dst_layout. */
inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout,
                                              const String& dst_layout,
                                              const String name = "T_auto_scheduler_layout_trans",
                                              const String tag = kInjective) {
  Array<PrimExpr> src_shape;
  std::vector<std::string> src_axes;
  Array<PrimExpr> dst_shape;
  std::vector<std::string> dst_axes;

  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
  return compute(
      dst_shape,
      [&](const Array<Var>& dst_indices) {
        Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
        Array<PrimExpr> src_indices;
        for (const std::string& src_axis : src_axes) {
          PrimExpr src_index = 0;
          CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
          for (size_t i = 0; i < dst_axes.size(); ++i) {
            if (dst_axes[i] == src_axis) {
              src_index = src_index * dst_shape[i] + dst_indices_expr[i];
            }
          }
          src_indices.push_back(src_index);
        }
        return src(src_indices);
      },
      name, tag);
}
/*! \brief Transform the meta-schedule generated layout according to TIR's IndexMap. */
inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map,
                                             const String name = "T_meta_schedule_layout_trans",
                                             const String tag = kInjective) {
  Array<Range> iter_domain;
  iter_domain.reserve(src->shape.size());
  for (const PrimExpr& e : src->shape) {
    iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
  }
  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape);
  return compute(
      post_transform_shape,
      [src, inv = index_map.Inverse(iter_domain)](const Array<Var>& indices) -> PrimExpr {
        return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}));
      },
      name, tag);
}
/*! \brief Get the shape of the input tensor. */
inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape",
                    const std::string tag = kInjective) {
  int ndim = static_cast<int>(src->shape.size());
  Array<PrimExpr> out_shape{ndim};
  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        auto idx = indices[0];
        PrimExpr ret = 0;
        for (int i = 0; i < ndim; ++i) {
          ret = tvm::if_then_else(idx == i, src->shape[i], ret);
        }
        return tvm::cast(dtype, ret);
      },
      name, tag);
}
/*! \brief Get the total number of elements of the input tensor. */
inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
                           const std::string& name = "ndarray_size",
                           const std::string& tag = kInjective) {
  int ndim = static_cast<int>(src->shape.size());
  Array<PrimExpr> out_shape = {};
  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        PrimExpr ret = 1;
        for (int i = 0; i < ndim; ++i) {
          ret *= src->shape[i];
        }
        return tvm::cast(dtype, ret);
      },
      name, tag);
}
/*! \brief Returns a one-hot tensor: locations represented by indices take value
 *  on_value, all other locations take value off_value. */
inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
                      int depth, int axis, const DataType& dtype,
                      Array<PrimExpr> oshape = Array<PrimExpr>(),
                      const std::string name = "T_one_hot", const std::string tag = kInjective) {
  int true_axis = (axis == -1) ? indices->shape.size() : axis;
  if (oshape.size() == 0) {
    int ndim = indices->shape.size() + 1;
    int indices_index = 0;
    for (int i = 0; i < ndim; i++) {
      if (i == true_axis) {
        oshape.push_back(Integer(depth));
      } else {
        oshape.push_back(indices->shape[indices_index++]);
      }
    }
  }

  PrimExpr on_value_cast = cast(dtype, on_value);
  PrimExpr off_value_cast = cast(dtype, off_value);
  return compute(
      oshape,
      [&](const Array<Var>& iter_vars) {
        Array<Var> indices_indices;
        for (size_t i = 0; i < iter_vars.size(); i++) {
          if (static_cast<int>(i) == true_axis) {
            continue;
          }
          indices_indices.push_back(iter_vars[i]);
        }

        auto idx = iter_vars[true_axis];
        return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast);
      },
      name, tag);
}
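// Illustrative usage (hedged sketch): one_hot of indices {0, 2} with depth=3,
// axis=-1, on=1, off=0 yields {{1, 0, 0}, {0, 0, 1}} of shape {2, 3}.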
/*! \brief Scatter sparse values into a dense tensor. */
inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr>& output_shape,
                              const Tensor& sparse_values, const PrimExpr& default_value,
                              const std::string name = "T_sparse_to_dense",
                              const std::string tag = kInjective) {
  ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values";
  ICHECK_LE(sparse_indices->shape.size(), 3)
      << "sparse_indices tensor should be 0D, 1D, or 2D only";
  ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only";

  const auto rank_sparse_indices = static_cast<int>(sparse_indices->shape.size());
  Array<PrimExpr> oshape;
  for (auto l : output_shape) {
    oshape.push_back(l);
  }
  return compute(
      oshape,
      [&](const Array<Var>& indices) {
        PrimExpr ret = default_value;
        if (0 == rank_sparse_indices) {
          ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret);
        } else if (1 == rank_sparse_indices) {
          for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
            ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
          }
        } else {
          for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
            PrimExpr aggregate_condition;
            for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
              PrimExpr comparison = indices[k] == sparse_indices[j][k];
              aggregate_condition = 0 == k ? comparison : aggregate_condition && comparison;
            }
            ret = if_then_else(aggregate_condition, sparse_values[j], ret);
          }
        }
        return ret;
      },
      name, tag);
}
/*! \brief Returns a tensor with the diagonals of the input tensor replaced with the
 *  provided diagonal values. */
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
                              bool super_diag_right_align, bool sub_diag_right_align,
                              const std::string name = "T_matrix_set_diag",
                              const std::string tag = kInjective) {
  size_t ndim = input->shape.size() - 1;

  bool only_one_diagonal = k1 == k2;

  return compute(
      input->shape,
      [&](const Array<Var>& iter_vars) {
        auto get_diag = [&]() {
          Array<PrimExpr> diagonal_indices;
          PrimExpr k, offset = 0;
          for (size_t i = 0; i < ndim - 1; i++) {
            diagonal_indices.push_back(iter_vars[i]);
          }
          if (only_one_diagonal) {
            k = k1;
          } else {
            // Determining which diagonal/sub-diagonal/super-diagonal it is.
            k = iter_vars[ndim] - iter_vars[ndim - 1];
            diagonal_indices.push_back(k2 - k);

            // Calculating the offset in the diagonal tensor for this diagonal.
            auto get_offset = [&](PrimExpr M, PrimExpr N) {
              // offset = max_diagonal_length - diagonal_length
              return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
            };
            offset = if_then_else(
                k >= 0,
                super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
                                       : 0,
                sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
                                     : 0);
          }
          diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
                                     offset);
          return diagonal(diagonal_indices);
        };
        return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
                            if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
                                         get_diag(), input(iter_vars)),
                            input(iter_vars));
      },
      name, tag);
}
/*! \brief Numpy-style advanced indexing with tensors. */
inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
                        const std::string name = "advanced_index",
                        const std::string tag = kInjective) {
  ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!";
  Array<PrimExpr> oshape;
  Array<PrimExpr> broadcast_shape;
  Array<Tensor> bindices;

  broadcast_shape = indices[0]->shape;
  for (size_t i = 1; i < indices.size(); ++i) {
    auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape);
    broadcast_shape = Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
  }
  if (indices.size() == 1) {
    // Fast path: a single index tensor needs no broadcasting.
    bindices = indices;
  } else {
    // Broadcast all index tensors to the common shape.
    for (size_t i = 0; i < indices.size(); ++i) {
      bindices.push_back(broadcast_to(indices[i], broadcast_shape));
    }
  }

  for (const auto& dim : broadcast_shape) {
    oshape.push_back(dim);
  }
  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
    oshape.push_back(data->shape[i]);
  }

  return compute(
      oshape,
      [&](const Array<Var>& iter_var) {
        Array<PrimExpr> tensor_indices;
        for (size_t i = 0; i < broadcast_shape.size(); ++i) {
          tensor_indices.push_back(iter_var[i]);
        }

        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < bindices.size(); ++i) {
          real_indices.push_back(bindices[i](tensor_indices));
        }
        for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
          real_indices.push_back(iter_var[i]);
        }
        return data(real_indices);
      },
      name, tag);
}
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_TRANSFORM_H_