24 #ifndef TVM_TOPI_TRANSFORM_H_
25 #define TVM_TOPI_TRANSFORM_H_
43 #include <unordered_set>
57 using namespace topi::detail;
77 ffi::Array<Integer> strides, std::string name =
"T_sliding_window",
78 std::string tag =
"") {
79 TVM_FFI_ICHECK_GE(axis, 0);
80 auto _axis = size_t(axis);
81 TVM_FFI_ICHECK_LT(_axis, x->shape.size()) <<
"axis must be a valid dimension index of x.";
82 TVM_FFI_ICHECK_EQ(x->shape.size() - _axis, window_shape.size())
83 <<
"There must be a window shape for every dimension of x "
84 <<
"over which we are sliding the window.";
85 TVM_FFI_ICHECK_EQ(strides.size(), window_shape.size())
86 <<
"Windows and strides should be the same length.";
89 ffi::Array<PrimExpr> new_shape;
91 for (
size_t i = 0; i < _axis; ++i) {
92 new_shape.push_back(x->shape[i]);
97 for (
size_t i = 0; i < window_shape.size(); ++i) {
99 auto dim_len = x->shape[_axis + i];
101 auto window_len = window_shape[i];
103 auto stride = strides[i];
105 new_shape.push_back(
floordiv(dim_len - (window_len - 1) + stride - 1, stride));
109 for (
size_t i = 0; i < window_shape.size(); ++i) {
110 new_shape.push_back(window_shape[i]);
113 TVM_FFI_ICHECK(new_shape.size() == _axis + 2 * window_shape.size());
117 [&](
const ffi::Array<Var>& indices) {
119 ffi::Array<PrimExpr> idx;
122 for (
size_t i = 0; i < _axis; ++i) {
123 idx.push_back(indices[i]);
126 for (
size_t i = 0; i < window_shape.size(); ++i) {
128 auto window_idx = indices[_axis + i];
130 auto idx_within_window = indices[_axis + window_shape.size() + i];
132 auto stride = strides[i];
134 idx.push_back(window_idx * stride + idx_within_window);
137 TVM_FFI_ICHECK(idx.size() == x->shape.size());
157 std::string name =
"T_expand_dims", std::string tag =
kBroadcast) {
158 int ndim =
static_cast<int>(x->shape.size());
159 TVM_FFI_ICHECK(-ndim - 1 <= axis && axis <= ndim)
160 <<
"expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
161 <<
", but got axis = " << axis <<
", and data.ndim = " << ndim;
162 TVM_FFI_ICHECK(num_newaxis >= 0) <<
"expand_dims only accepts `num_newaxis >= 0`"
163 <<
", but got num_newaxis = " << num_newaxis;
166 axis = ndim + axis + 1;
168 ffi::Array<PrimExpr> new_shape;
169 for (
size_t i = 0; i < static_cast<size_t>(axis); ++i) {
170 new_shape.push_back(x->shape[i]);
172 for (
size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) {
173 new_shape.push_back(1);
175 for (
size_t i = axis; i < x->shape.size(); ++i) {
176 new_shape.push_back(x->shape[i]);
181 [&](
const ffi::Array<Var>& indices) {
182 ffi::Array<PrimExpr> idx;
183 for (
size_t i = 0; i < static_cast<size_t>(axis); ++i) {
184 idx.push_back(indices[i]);
186 for (
size_t i = axis + num_newaxis; i < indices.size(); ++i) {
187 idx.push_back(indices[i]);
// NOTE(review): the region below is a garbled extraction — stray original line numbers
// are interleaved with the code, statements are wrapped mid-token, and (per gaps in the
// embedded numbering) blank/comment/brace lines and some code lines were dropped.
// The text is kept byte-identical; reconstruct against the upstream header before use.
// -- fragment: appears to be `transpose` (tail of signature onward) --
206 std::string name =
"T_transpose", std::string tag =
kInjective) {
207 ffi::Array<Integer> axes = opt_axes.value_or({});
// Empty `axes` presumably means "reverse all dimensions" — the loop body is missing here.
208 if (axes.size() == 0) {
209 for (
int i =
static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
214 ffi::Array<PrimExpr> new_shape;
215 for (
size_t i = 0; i < axes.size(); ++i) {
216 int axis =
static_cast<int>(axes[i]->value);
// Negative-axis normalization; the enclosing `if (axis < 0)` guard line appears dropped.
219 new_axis =
static_cast<int>(x->shape.size()) + axis;
220 axes.Set(i, new_axis);
222 TVM_FFI_ICHECK((new_axis >= 0) && (new_axis <
static_cast<int>(x->shape.size())))
223 <<
"axis=" << axis <<
" is invalid for the " <<
static_cast<int>(x->shape.size())
224 <<
"-dimensional input tensor";
226 for (
size_t j = 0; j < axes.size(); ++j) {
// Duplicate-axis check; an `if (i != j)` style guard line appears dropped.
228 TVM_FFI_ICHECK(new_axis !=
static_cast<int>(axes[j]->value))
229 <<
"repeated axis in transpose";
232 new_shape.push_back(x->shape[new_axis]);
237 [&](
const ffi::Array<Var>& indices) {
238 std::vector<PrimExpr> idx;
239 for (
size_t i = 0; i < axes.size(); ++i) {
242 for (
size_t i = 0; i < axes.size(); ++i) {
243 int axis =
static_cast<int>(axes[i]->value);
244 idx[axis] = indices[i];
// -- fragment: appears to be `reverse_sequence` (tail of signature onward) --
266 int batch_axis = 0, std::string name =
"T_reverse_sequence",
268 size_t src_tensor_dim = x->shape.size();
269 int seq_axis_inp = seq_axis;
271 if (seq_lengths.defined()) {
272 size_t seq_lengths_dim = seq_lengths->shape.size();
273 int batch_axis_inp = batch_axis;
274 if (batch_axis < 0) {
275 batch_axis =
static_cast<int>(x->shape.size()) + batch_axis;
278 TVM_FFI_ICHECK(seq_lengths_dim == 1) <<
"seq_lengths should be 1D vector";
280 TVM_FFI_ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
281 <<
"For reverse_sequnece seq_lengths size should match with dimension of batch axis"
282 <<
", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
283 <<
", and seq_length size = " << GetConstInt(seq_lengths->shape[0]);
285 TVM_FFI_ICHECK((0 <= batch_axis) && (batch_axis <
static_cast<int>(x->shape.size())))
286 <<
"batch_axis=" << batch_axis_inp <<
" is invalid for the "
287 <<
static_cast<int>(x->shape.size()) <<
"-dimensional input tensor";
// Negative seq_axis normalization; its `if (seq_axis < 0)` guard line appears dropped.
291 seq_axis =
static_cast<int>(x->shape.size()) + seq_axis;
293 TVM_FFI_ICHECK((0 <= seq_axis) && (seq_axis <
static_cast<int>(x->shape.size())))
294 <<
"seq_axis=" << seq_axis_inp <<
" is invalid for the " <<
static_cast<int>(x->shape.size())
295 <<
"-dimensional input tensor";
297 auto func = [&](
const ffi::Array<Var>& indices) {
298 ffi::Array<PrimExpr> real_indices;
299 for (
size_t i = 0; i < src_tensor_dim; ++i) {
300 if (i ==
static_cast<size_t>(seq_axis)) {
301 if (seq_lengths.defined()) {
302 auto len = seq_lengths(indices[batch_axis]);
// An `auto idx = if_then_else(` head line appears dropped before the next two lines.
304 len <= 1 || len <= indices[i], indices[i],
305 if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
306 real_indices.push_back(idx);
308 real_indices.push_back(x->shape[i] - 1 - indices[i]);
311 real_indices.push_back(indices[i]);
314 return x(real_indices);
317 return compute(x->shape, func, name, tag);
// -- fragment: appears to be `reshape` (tail of signature onward) --
331 std::string name =
"T_reshape", std::string tag =
kInjective) {
332 auto x_shape = x->shape;
333 ffi::Array<PrimExpr> target_shape;
335 for (
const auto& ele : newshape) {
336 target_shape.push_back(ele);
340 if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) {
342 target_shape, [&](
const ffi::Array<Var>& indices) {
return tvm::cast(x->dtype, 0); }, name,
347 [&](
const ffi::Array<Var>& indices) {
348 return x(UnravelIndex(
349 RavelIndex(ffi::Array<PrimExpr>{indices.begin(), indices.end()}, target_shape),
// -- fragment: appears to be `unravel_index` (body only; signature lines dropped) --
369 auto x_shape = x->shape;
370 auto shape_shape =
shape->shape;
372 ffi::Array<PrimExpr> oshape;
373 oshape.push_back(shape_shape[0]);
374 if (x_shape.size() != 0) {
375 oshape.push_back(x_shape[0]);
378 auto func = [&](
const ffi::Array<Var>& indices) {
380 std::vector<PrimExpr> indices_divs;
385 if (x_shape.size() != 0) {
386 index_val = x[indices[1]];
390 indices_divs.push_back(index_val);
391 for (
int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
394 indices_divs.push_back(cur_val);
399 return compute(oshape, func, name, tag);
// -- fragment: appears to be `squeeze` (tail of signature onward) --
416 bool atleast1d =
false, std::string name =
"T_squeeze",
418 auto ndim = x->shape.size();
419 std::vector<int> axis_val;
420 if (!opt_axes.has_value()) {
421 for (
size_t i = 0; i < ndim; ++i) {
422 if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
423 axis_val.push_back(
static_cast<int>(i));
427 ffi::Array<Integer> axis = *std::move(opt_axes);
428 for (
size_t i = 0; i < axis.size(); ++i) {
429 int64_t val = axis[i]->value;
431 val +=
static_cast<int>(x->shape.size());
434 bool is_const = IsConstInt(x->shape[val]);
435 if ((is_const && GetConstInt(x->shape[val]) == 1) || !is_const) {
436 axis_val.push_back(val);
441 std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());
443 ffi::Array<PrimExpr> out_shape;
444 for (
size_t i = 0; i < ndim; ++i) {
445 if (axis_set.count(
static_cast<int>(i)) == 0) {
446 out_shape.push_back(x->shape[i]);
449 if (out_shape.size() == 0 && atleast1d) {
450 out_shape.push_back(1);
455 [&](
const ffi::Array<Var>& indices) {
456 ffi::Array<PrimExpr> real_indices;
// A `flag` counter declaration/update (tracking squeezed dims) appears dropped here.
458 for (
size_t i = 0; i < ndim; ++i) {
459 if (axis_set.count(
static_cast<int>(i)) == 0) {
460 real_indices.push_back(indices[i - flag]);
462 real_indices.push_back(0);
466 return x(real_indices);
// -- fragment: appears to be `concatenate` (tail of signature onward) --
482 std::string name =
"T_concat", std::string tag =
kInjective) {
483 int ndim =
static_cast<int>(inputs[0]->shape.size());
484 TVM_FFI_ICHECK(-ndim <= axis && axis < ndim)
485 <<
"concatenate only accepts `axis` in [-ndim, ndim)"
486 <<
", but got axis = " << axis <<
", and ndim = " << ndim;
490 TVM_FFI_ICHECK_LT(axis, inputs[0]->
shape.size()) <<
"axis out of bounds";
492 ffi::Array<PrimExpr> axis_sizes;
493 for (
auto t : inputs) {
494 axis_sizes.push_back(t->shape[axis]);
// An `arith::Analyzer analyzer;` + join_size initialization appear dropped here.
498 for (
size_t i = 1; i < axis_sizes.size(); ++i) {
499 join_size += axis_sizes[i];
501 join_size = analyzer.
Simplify(join_size);
502 ffi::Array<PrimExpr> out_shape;
503 for (
size_t i = 0; i < inputs[0]->shape.size(); ++i) {
504 out_shape.push_back(i ==
static_cast<size_t>(axis) ? join_size : inputs[0]->
shape[i]);
509 [&](
const ffi::Array<Var>& indices) {
510 auto ret = inputs[0](indices);
511 auto ind = indices[axis];
512 for (
size_t i = 0; i < inputs.size() - 1; ++i) {
513 ind -= axis_sizes[i];
515 ffi::Array<PrimExpr> idx;
516 for (
size_t i = 0; i < static_cast<size_t>(axis); ++i) {
517 idx.push_back(indices[i]);
// The push of the adjusted axis index and the `if_then_else` accumulation appear dropped.
520 for (
size_t i = axis + 1; i < indices.size(); ++i) {
521 idx.push_back(indices[i]);
541 inline Tensor stack(
const ffi::Array<Tensor>& inputs,
int axis = 0, std::string name =
"T_stack",
543 int ndim =
static_cast<int>(inputs[0]->shape.size());
544 TVM_FFI_ICHECK(-ndim - 1 <= axis && axis <= ndim)
545 <<
"stack only accepts `axis` in [-ndim, ndim)"
546 <<
", but got axis = " << axis <<
", and ndim = " << ndim;
550 TVM_FFI_ICHECK_LT(axis, inputs[0]->
shape.size() + 1) <<
"axis out of bounds";
552 const int stack_size =
static_cast<int>(inputs.size());
553 ffi::Array<PrimExpr> out_shape;
554 for (
size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
555 out_shape.push_back(stack_size);
556 for (
size_t i =
static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
557 out_shape.push_back(inputs[0]->shape[i]);
561 [&](
const ffi::Array<Var>& indices) {
562 ffi::Array<PrimExpr> idx;
563 for (
size_t i = 0; i < indices.size(); ++i)
564 if (i !=
static_cast<size_t>(axis)) idx.push_back(indices[i]);
565 auto ind = indices[axis];
566 auto ret = inputs[0](idx);
567 for (
int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
// NOTE(review): the region below is a garbled extraction — stray original line numbers
// interleaved, statements wrapped mid-token, and (per gaps in the embedded numbering)
// blank/comment/brace lines plus some code lines were dropped. Kept byte-identical;
// reconstruct against the upstream header before use.
// -- fragment: appears to be `split` (tail of signature onward) --
588 int axis, std::string name =
"T_split",
// Negative-axis normalization; its `if (axis < 0)` guard line appears dropped.
591 axis +=
static_cast<int>(x->shape.size());
593 TVM_FFI_ICHECK_LT(axis, x->shape.size()) <<
"axis out of bounds";
595 auto src_axis_size = x->shape[axis];
596 std::vector<PrimExpr> begin_ids;
597 begin_ids.push_back(0);
599 for (
auto idx : split_indices) {
// An `auto idx_node = idx.as<IntImmNode>();` line appears dropped before back_node.
601 auto back_node = begin_ids.back().as<
IntImmNode>();
602 if (idx_node && back_node) {
603 TVM_FFI_ICHECK_GT(idx_node->value, back_node->
value) <<
"split_indices must be sorted";
605 begin_ids.push_back(idx);
608 ffi::Array<ffi::Array<PrimExpr>> out_shapes;
609 for (
size_t i = 0; i < begin_ids.size(); ++i) {
// A `PrimExpr out_axis_size;` declaration appears dropped before this branch.
611 if (i == begin_ids.size() - 1) {
612 out_axis_size = src_axis_size - begin_ids[i];
614 out_axis_size = begin_ids[i + 1] - begin_ids[i];
617 ffi::Array<PrimExpr>
shape;
618 for (
size_t i = 0; i < static_cast<size_t>(axis); ++i) {
619 shape.push_back(x->shape[i]);
621 shape.push_back(out_axis_size);
622 for (
size_t i = axis + 1; i < x->shape.size(); ++i) {
623 shape.push_back(x->shape[i]);
626 out_shapes.push_back(
shape);
629 ffi::Array<Tensor> result;
630 for (
size_t i = 0; i < begin_ids.size(); ++i) {
633 [&](
const ffi::Array<Var>& indices) {
634 auto begin = begin_ids[i];
635 ffi::Array<PrimExpr> real_indices;
636 for (
size_t j = 0; j < static_cast<size_t>(axis); ++j) {
637 real_indices.push_back(indices[j]);
639 real_indices.push_back(indices[axis] + begin);
640 for (
size_t j = axis + 1; j < indices.size(); ++j) {
641 real_indices.push_back(indices[j]);
644 return x(real_indices);
// -- fragment: strided-slice index helpers (bodies heavily truncated) --
663 if (!(index->IsInstance<
tvm::IntImmNode>() && GetConstInt(index) >= 0)) {
671 int64_t begin_range = stride < 0 ? -1 : 0;
672 int64_t end_range = stride < 0 ? extent - 1 : extent;
// -- fragment: appears to be a GetLength-style helper --
690 bool assume_inbound =
true) {
691 if (assume_inbound) {
692 return ceildiv(end - begin, stride);
// -- fragment: appears to be `dynamic_strided_slice` with explicit axes --
717 const te::Tensor& x,
const ffi::Array<PrimExpr>& begin,
const ffi::Array<PrimExpr>& end,
718 const ffi::Array<PrimExpr>& strides,
const ffi::Array<Integer>& axes,
719 bool assume_inbound =
true, std::string name =
"T_dynamic_strided_slice_with_axes",
721 const size_t src_tensor_dim = x->shape.size();
722 TVM_FFI_ICHECK_EQ(begin.size(), end.size());
723 TVM_FFI_ICHECK_EQ(begin.size(), strides.size());
724 TVM_FFI_ICHECK_EQ(begin.size(), axes.size());
725 TVM_FFI_ICHECK_LE(begin.size(), src_tensor_dim);
727 for (
const auto& axis_imm : axes) {
728 int axis = axis_imm->value;
729 TVM_FFI_ICHECK_LT(axis, src_tensor_dim);
// An `arith::Analyzer analyzer;` declaration appears dropped before its use below.
734 ffi::Array<PrimExpr> out_shape = x->shape;
735 for (
size_t i = 0; i < begin.size(); i++) {
736 int axis = axes[i]->value;
// A `PrimExpr new_shape =` head line appears dropped before the Simplify call.
738 analyzer.
Simplify(
GetLength(begin[i], end[i], strides[i], out_shape[axis], assume_inbound));
739 out_shape.Set(axis, new_shape);
744 [&](
const ffi::Array<tvm::tirx::Var>& indices) {
745 ffi::Array<PrimExpr> real_indices =
748 for (
size_t i = 0; i < begin.size(); i++) {
749 int axis = axes[i]->value;
750 PrimExpr new_index = indices[axis] * strides[i] + begin[i];
751 real_indices.Set(axis, new_index);
754 return x(real_indices);
// -- fragment: appears to be `dynamic_strided_slice` (leading axes form) --
774 const ffi::Array<PrimExpr>& end,
775 const ffi::Array<PrimExpr>& strides,
bool assume_inbound =
true,
776 std::string name =
"T_dynamic_strided_slice",
778 const size_t src_tensor_dim = x->shape.size();
779 TVM_FFI_ICHECK_LE(begin.size(), src_tensor_dim);
780 TVM_FFI_ICHECK_LE(end.size(), src_tensor_dim);
781 TVM_FFI_ICHECK_LE(strides.size(), src_tensor_dim);
782 TVM_FFI_ICHECK_EQ(begin.size(), end.size());
783 TVM_FFI_ICHECK_EQ(begin.size(), strides.size());
785 const size_t num_slice_axes = begin.size();
786 ffi::Array<PrimExpr> out_shape;
789 for (
size_t i = 0; i < num_slice_axes; ++i) {
791 if (!begin[i]->IsInstance<ProducerLoadNode>() && !end[i]->IsInstance<ProducerLoadNode>() &&
792 !strides[i]->IsInstance<ProducerLoadNode>()) {
// An `out_shape.push_back(` head line appears dropped before the Simplify call.
794 analyzer.
Simplify(
GetLength(begin[i], end[i], strides[i], x->shape[i], assume_inbound)));
800 for (
size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
801 out_shape.push_back(x->shape[i]);
806 [&](
const ffi::Array<tvm::tirx::Var>& indices) {
807 ffi::Array<PrimExpr> real_indices;
808 for (
size_t i = 0; i < num_slice_axes; ++i) {
809 real_indices.push_back(indices[i] * strides[i] +
tvm::min(begin[i], x->shape[i] - 1));
812 for (
size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
813 real_indices.push_back(indices[i]);
815 return x(real_indices);
// -- fragment: appears to be `strided_slice` with 1-D begin/end/strides tensors --
836 bool assume_inbound =
true,
837 std::string name =
"T_strided_slice_dynamic",
839 DataType index_dtype = begin->shape[0]->dtype;
840 const int64_t num_dynamic_axes = begin->shape[0].as<
IntImmNode>()->value;
841 TVM_FFI_ICHECK_EQ(end->shape[0].as<
IntImmNode>()->
value, num_dynamic_axes);
842 TVM_FFI_ICHECK_EQ(strides->shape[0].as<
IntImmNode>()->
value, num_dynamic_axes);
844 ffi::Array<PrimExpr> begin_expr, end_expr, strides_expr;
845 for (int64_t i = 0; i < num_dynamic_axes; ++i) {
// An `auto ind = make_const(index_dtype, i);` style line appears dropped here.
847 begin_expr.push_back(begin(ind));
848 end_expr.push_back(end(ind));
849 strides_expr.push_back(strides(ind));
// -- fragment: appears to be a StridedSliceOutputShape-style helper --
869 const ffi::Array<Integer>& begin,
870 const ffi::Array<Integer>& end,
871 const ffi::Array<Integer>& strides,
872 const ffi::Array<Integer>& axes,
873 const std::string& slice_mode) {
874 TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
875 axes.size() == strides.size());
876 std::vector<int64_t> begin_vec, end_vec, strides_vec;
877 std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
878 auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
879 begin[0]->dtype, slice_mode);
881 begin_canonicalized,
true);
// -- fragment: appears to be `strided_slice_with_axes` --
901 const ffi::Array<Integer>& end,
902 const ffi::Array<Integer>& strides,
903 const ffi::Array<Integer>& axes,
904 std::string slice_mode =
"end",
905 std::string name =
"T_strided_slice_with_axes",
907 const int64_t src_tensor_dim =
static_cast<int64_t
>(x->shape.size());
908 TVM_FFI_ICHECK(
static_cast<int64_t
>(axes.size()) <= src_tensor_dim);
909 TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
910 axes.size() == strides.size());
913 ffi::Array<Integer> normalized_axes;
914 for (
size_t i = 0; i < axes.size(); ++i) {
915 int64_t axis = axes[i].IntValue();
// Negative-axis normalization; its `if (axis < 0)` guard line appears dropped.
917 axis += src_tensor_dim;
919 TVM_FFI_ICHECK(axis >= 0 && axis < src_tensor_dim)
920 <<
"Axis " << axes[i].IntValue() <<
" is out of bounds for tensor with " << src_tensor_dim
922 normalized_axes.push_back(
Integer(axis));
925 std::vector<int64_t> begin_vec, end_vec, strides_vec;
926 std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
928 auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, normalized_axes,
929 begin[0]->dtype, slice_mode);
931 normalized_axes, slice_mode, begin_expr);
935 [&](
const ffi::Array<tirx::Var>& indices) {
936 ffi::Array<PrimExpr> real_indices;
937 for (
size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
938 for (
size_t i = 0; i < normalized_axes.size(); ++i) {
939 auto stride =
make_const(strides[i].dtype(), strides_vec[i]);
940 PrimExpr ind = indices[normalized_axes[i].IntValue()] * stride + begin_expr[i];
941 real_indices.Set(normalized_axes[i].IntValue(), ind);
943 return x(real_indices);
// -- fragment: appears to be all-axes `strided_slice` --
963 const ffi::Array<Integer>& end,
const ffi::Array<Integer>& strides,
964 std::string slice_mode =
"end", std::string name =
"T_strided_slice",
966 size_t src_tensor_dim =
static_cast<size_t>(x->shape.size());
967 ffi::Array<Integer> axes;
968 for (
size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
969 ffi::Array<Integer> begin_full(begin);
970 ffi::Array<Integer> end_full(end);
971 ffi::Array<Integer> strides_full(strides);
// Declarations of `one`/`zero`/`max_range` constants appear dropped before these loops.
978 for (
size_t i = strides.size(); i < src_tensor_dim; ++i) {
979 strides_full.push_back(one);
981 for (
size_t i = begin.size(); i < src_tensor_dim; ++i) {
982 begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
984 for (
size_t i = end.size(); i < src_tensor_dim; ++i) {
985 end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
// -- fragment: appears to be `split_sections` --
1005 std::string name =
"T_split_sections",
// Negative-axis normalization; its `if (axis < 0)` guard line appears dropped.
1008 axis +=
static_cast<int>(x->shape.size());
1010 TVM_FFI_ICHECK_LT(axis, x->shape.size()) <<
"axis out of bounds";
1012 auto src_axis_size = x->shape[axis];
1014 TVM_FFI_ICHECK_GT(num_sections, 0) <<
"Slice count must be > 0";
1016 ffi::Array<PrimExpr> split_indices;
1017 auto seg_size =
indexdiv(src_axis_size + num_sections - 1, num_sections);
1018 for (
int i = 0; i < num_sections; ++i) {
1021 split_indices.push_back(seg_size * i);
// -- fragment: appears to be flat (ravelled) `take` with clip/fast/nan/wrap modes --
1041 std::string mode =
"fast", std::string name =
"T_take",
1043 ffi::Array<PrimExpr> a_shape = a->shape;
1044 ffi::Array<PrimExpr> out_shape = indices->shape;
1046 for (
size_t i = 0; i < a_shape.size(); ++i) {
1047 a_size = a_size * a_shape[i];
1050 if (mode ==
"clip") {
1053 [&](
const ffi::Array<Var>& out_index) {
// The clamped-index computation line(s) appear dropped before this return.
1055 return a(UnravelIndex(idx, a_shape));
1058 }
else if (mode ==
"fast") {
1059 LOG(WARNING) <<
"Fast mode segfaults when there are out-of-bounds indices. "
1060 "Make sure input indices are in bound";
1063 [&](
const ffi::Array<Var>& out_index) {
1064 return a(UnravelIndex(indices(out_index), a_shape));
1067 }
else if (mode ==
"nan") {
1070 [&](
const ffi::Array<Var>& out_index) {
1072 indices(out_index) < 0 || indices(out_index) >= a_size,
1073 tvm::FloatImm(a->dtype, std::numeric_limits<float>::quiet_NaN()), indices(out_index));
1074 return a(UnravelIndex(idx, a_shape));
1080 [&](
const ffi::Array<Var>& out_index) {
1081 auto idx =
truncmod(
truncmod(indices(out_index), a_size) + a_size, a_size);
1082 return a(UnravelIndex(idx, a_shape));
// -- fragment: appears to be `sequence_mask` --
1101 int axis, std::string name =
"T_sequence_mask",
1103 TVM_FFI_ICHECK(axis == 0 || axis == 1) <<
"axis must be either 0 or 1";
1104 TVM_FFI_ICHECK_EQ(valid_length->shape.size(), 1)
1105 <<
"valid_length must have ndim=1, i.e., (batch_size,).";
1106 auto length_dim = data->shape[axis];
1107 auto batch_dim = data->shape[1 - axis];
1108 ffi::Array<PrimExpr> out_shape = data->shape;
1111 [&](
const ffi::Array<Var>& out_index) {
1112 ffi::Array<PrimExpr> len_index;
1113 auto tid = out_index[axis];
1114 auto bid = out_index[1 - axis];
1115 len_index.push_back(bid);
// -- fragment: appears to be axis-wise `take` with batch_dims and clip/fast/nan/wrap --
1140 int axis, std::string mode =
"fast", std::string name =
"T_take",
// Negative-axis normalization; its `if (axis < 0)` guard line appears dropped.
1143 axis +=
static_cast<int>(a->shape.size());
1145 TVM_FFI_ICHECK_GE(axis, 0) <<
"axis out of bounds";
1146 TVM_FFI_ICHECK_LT(axis, a->shape.size()) <<
"axis out of bounds";
1147 auto axis_dim = a->shape[axis];
1148 auto indices_shape = [&]() -> ffi::Array<PrimExpr> {
1149 if (
auto tensor = indices.as<
TensorNode>()) {
1150 return tensor->shape;
1156 int indices_len =
static_cast<int>(indices_shape.size());
1158 int batch_dims_ = batch_dims;
1159 if (batch_dims_ != 0) {
1160 TVM_FFI_ICHECK_GE(batch_dims_, -indices_len) <<
"batch_dims out of bounds";
1161 TVM_FFI_ICHECK_LE(batch_dims_, indices_len) <<
"batch_dims out of bounds";
1163 if (batch_dims_ < 0) {
1164 batch_dims_ = indices_len + batch_dims_;
1167 TVM_FFI_ICHECK_LT(batch_dims_, a->shape.size()) <<
"batch_dims out of bounds";
1168 TVM_FFI_ICHECK_LE(batch_dims_, axis) <<
"batch_dims must be less than or equal to axis";
1169 for (
int i = 0; i < batch_dims_; ++i) {
1170 auto addr1 = a->shape[i];
1171 auto addr2 = indices_shape[i];
1172 auto v1 =
static_cast<IntImm*
>(&addr1)->get()->value;
1173 auto v2 =
static_cast<IntImm*
>(&addr2)->get()->value;
1174 TVM_FFI_ICHECK_EQ(v1, v2) <<
"a.shape[" << i <<
"] should be equal to indices.shape[" << i
1182 ffi::Array<PrimExpr> out_shape;
1183 for (
int i = 0; i < batch_dims_; ++i) {
1184 out_shape.push_back(a->shape[i]);
1186 for (
int i = batch_dims_; i < axis; ++i) {
1187 out_shape.push_back(a->shape[i]);
1189 for (
int i = batch_dims_; i < indices_len; ++i) {
1190 out_shape.push_back(indices_shape[i]);
1192 for (
size_t i = axis + 1; i < a->shape.size(); ++i) {
1193 out_shape.push_back(a->shape[i]);
1196 auto get_index = [&](
const ffi::Array<PrimExpr>& indices_position) ->
PrimExpr {
1197 if (
auto tensor = indices.as<
Tensor>()) {
1198 return tensor.value()(indices_position);
1199 }
else if (
auto prim = indices.as<
PrimExpr>()) {
1200 TVM_FFI_ICHECK_EQ(indices_position.size(), 0);
1201 return prim.value();
1203 TVM_FFI_THROW(InternalError) <<
"Variant did not contain either allowed type";
1207 if (mode ==
"clip") {
1208 if (batch_dims_ == 0) {
1211 [&](
const ffi::Array<Var>& out_index) {
1212 ffi::Array<PrimExpr> indices_position;
1213 for (
size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1214 indices_position.push_back(out_index[j]);
1216 ffi::Array<PrimExpr> real_indices;
1217 for (
size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1218 real_indices.push_back(out_index[j]);
1220 auto idx =
tvm::min(
tvm::max(0, get_index(indices_position)), axis_dim - 1);
1221 real_indices.push_back(idx);
1222 for (
size_t j = axis + indices_len; j < out_index.size(); ++j) {
1223 real_indices.push_back(out_index[j]);
1225 return a(real_indices);
1231 [&](
const ffi::Array<Var>& out_index) {
1232 ffi::Array<PrimExpr> indices_position;
1233 for (
size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
1234 indices_position.push_back(out_index[j]);
1236 for (
size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
1237 indices_position.push_back(out_index[j]);
1239 ffi::Array<PrimExpr> real_indices;
1240 for (
size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1241 real_indices.push_back(out_index[j]);
1243 auto idx =
tvm::min(
tvm::max(0, get_index(indices_position)), axis_dim - 1);
1244 real_indices.push_back(idx);
1245 for (
size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
1246 real_indices.push_back(out_index[j]);
1248 return a(real_indices);
1252 }
else if (mode ==
"fast") {
1253 LOG(WARNING) <<
"Fast mode segfaults when there are out-of-bounds indices. "
1254 "Make sure input indices are in bound";
1257 [&](
const ffi::Array<Var>& out_index) {
1258 ffi::Array<PrimExpr> indices_position;
1259 for (
size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1260 indices_position.push_back(out_index[j]);
1262 ffi::Array<PrimExpr> real_indices;
1263 for (
size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1264 real_indices.push_back(out_index[j]);
1266 real_indices.push_back(get_index(indices_position));
1267 for (
size_t j = axis + indices_len; j < out_index.size(); ++j) {
1268 real_indices.push_back(out_index[j]);
1270 return a(real_indices);
1273 }
else if (mode ==
"nan") {
1276 [&](
const ffi::Array<Var>& out_index) {
1277 ffi::Array<PrimExpr> indices_position;
1278 for (
size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1279 indices_position.push_back(out_index[j]);
1281 ffi::Array<PrimExpr> real_indices;
1282 for (
size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1283 real_indices.push_back(out_index[j]);
1285 PrimExpr idx = get_index(indices_position);
1286 real_indices.push_back(idx);
1287 for (
size_t j = axis + indices_len; j < out_index.size(); ++j) {
1288 real_indices.push_back(out_index[j]);
1290 PrimExpr in_bounds = idx >= 0 && idx < axis_dim;
// The NaN-producing false branch of this select appears dropped after the next line.
1292 in_bounds, a(real_indices),
1299 [&](
const ffi::Array<Var>& out_index) {
1300 ffi::Array<PrimExpr> indices_position;
1301 for (
size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1302 indices_position.push_back(out_index[j]);
1304 ffi::Array<PrimExpr> real_indices;
1305 for (
size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1306 real_indices.push_back(out_index[j]);
1308 auto idx =
truncmod(
truncmod(get_index(indices_position), axis_dim) + axis_dim, axis_dim);
1309 real_indices.push_back(idx);
1310 for (
size_t j = axis + indices_len; j < out_index.size(); ++j) {
1311 real_indices.push_back(out_index[j]);
1313 return a(real_indices);
1331 std::string name =
"T_where", std::string tag =
kBroadcast) {
1332 TVM_FFI_ICHECK_EQ(x->dtype, y->dtype)
1333 <<
"x and y must have the same dtype: " << x->dtype <<
" vs " << y->dtype;
1334 auto get_out_shape = [&]() {
1335 auto bh1 = detail::BroadcastShape(x->shape, y->shape);
1336 ffi::Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
1337 auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
1338 ffi::Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
1339 return common_shape2;
1342 auto oshape = get_out_shape();
1344 auto c_bh = detail::BroadcastShape(condition->shape, oshape);
1345 auto x_bh = detail::BroadcastShape(x->shape, oshape);
1346 auto y_bh = detail::BroadcastShape(y->shape, oshape);
1348 auto select = [&](tvm::ffi::Array<tvm::tirx::Var> ovars) {
1349 auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
1350 auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
1351 auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
1355 return compute(oshape, select, name, tag);
// NOTE(review): the region below is a garbled extraction — stray original line numbers
// interleaved, statements wrapped mid-token, and (per gaps in the embedded numbering)
// blank/comment/brace lines plus some code lines were dropped. Kept byte-identical;
// reconstruct against the upstream header before use.
// -- fragment: appears to be `repeat` (body only; signature lines dropped) --
1372 int ndim =
static_cast<int>(x->shape.size());
1373 TVM_FFI_ICHECK(-ndim - 1 <= axis && axis <= ndim)
1374 <<
"repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
1375 <<
", but got axis = " << axis <<
", and data.ndim = " << ndim;
1376 TVM_FFI_ICHECK(repeats >= 1) <<
"repeat only accepts `repeats >= 1`"
1377 <<
", but got repeats = " << repeats;
// Negative-axis normalization appears dropped between the checks and the shape build.
1382 ffi::Array<PrimExpr> new_shape;
1383 for (
size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1384 new_shape.push_back(x->shape[i]);
1386 new_shape.push_back(repeats * x->shape[axis]);
1387 for (
size_t i = axis + 1; i < x->shape.size(); ++i) {
1388 new_shape.push_back(x->shape[i]);
1393 [&](
const ffi::Array<Var>& indices) {
1394 ffi::Array<PrimExpr> idx;
1395 for (
size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1396 idx.push_back(indices[i]);
1398 idx.push_back(
indexdiv(indices[axis], repeats));
1399 for (
size_t i = axis + 1; i < indices.size(); ++i) {
1400 idx.push_back(indices[i]);
// -- fragment: appears to be `tile` (body only; signature lines dropped) --
1419 size_t ndim = x->shape.size();
1420 size_t rdim = reps.size();
1421 size_t tdim = (ndim > rdim) ? ndim : rdim;
1422 ffi::Array<PrimExpr> data_shape;
1423 ffi::Array<PrimExpr> reps_shape;
1424 ffi::Array<PrimExpr> new_shape;
// The `if (ndim == rdim) {` head of this three-way branch appears dropped.
1426 for (
size_t i = 0; i < ndim; ++i) {
1427 data_shape.push_back(x->shape[i]);
1428 reps_shape.push_back(reps[i]);
1430 }
else if (ndim > rdim) {
1431 for (
size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1432 for (
size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
1433 for (
size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1435 for (
size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
1436 for (
size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1437 for (
size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1439 for (
size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);
1441 if (is_empty_shape(new_shape)) {
1443 new_shape, [&](
const ffi::Array<Var>& indices) {
return tvm::cast(x->dtype, 0); }, name,
1448 [&](
const ffi::Array<Var>& indices) {
1449 ffi::Array<PrimExpr> idx;
// The `if (ndim >= rdim) {` branch head appears dropped before the next loop.
1451 for (
size_t i = 0; i < ndim; ++i) idx.push_back(
indexmod(indices[i], x->shape[i]));
1453 for (
size_t i = 0; i < ndim; ++i)
1454 idx.push_back(
indexmod(indices[rdim - ndim + i], x->shape[i]));
// -- fragment: appears to be a tile variant taking a target `new_shape` --
1474 std::string name =
"T_tile", std::string tag =
kBroadcast) {
1475 size_t ndim = x->shape.size();
1476 if (is_empty_shape(new_shape)) {
1478 new_shape, [&](
const ffi::Array<Var>& indices) {
return tvm::cast(x->dtype, 0); }, name,
1483 [&](
const ffi::Array<Var>& indices) {
1484 ffi::Array<PrimExpr> idx;
// An `if (ndim >= rdim)` style branch head appears dropped before the next loop.
1486 for (
size_t i = 0; i < ndim; ++i) {
1487 idx.push_back(
indexmod(indices[i], x->shape[i]));
1490 for (
size_t i = 0; i < ndim; ++i) {
1491 idx.push_back(
indexmod(indices[rdim - ndim + i], x->shape[i]));
1512 std::string name =
"T_gather", std::string tag =
kInjective) {
1513 size_t ndim_d = data->shape.size();
1514 size_t ndim_i = indices->shape.size();
1515 TVM_FFI_ICHECK_GE(ndim_d, 1) <<
"Cannot gather from a scalar.";
1516 TVM_FFI_ICHECK_EQ(ndim_d, ndim_i);
1520 TVM_FFI_ICHECK_GE(axis, 0);
1521 TVM_FFI_ICHECK_LT(axis, ndim_d);
1523 size_t indices_dim_i =
static_cast<size_t>(GetConstInt(indices->shape[axis]));
1524 TVM_FFI_ICHECK_GE(indices_dim_i, 1);
1526 TVM_FFI_ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());
1528 ffi::Array<PrimExpr> out_shape;
1529 for (
size_t i = 0; i < ndim_i; ++i) {
1530 out_shape.push_back(indices->shape[i]);
1535 [&](
const ffi::Array<Var>& out_index) {
1536 ffi::Array<PrimExpr> indices_position;
1537 for (
size_t i = 0; i < ndim_i; ++i) {
1538 indices_position.push_back(out_index[i]);
1540 ffi::Array<PrimExpr> real_indices;
1541 for (
size_t i = 0; i < ndim_i; ++i) {
1542 if (i ==
static_cast<size_t>(axis)) {
1543 real_indices.push_back(indices(indices_position));
1545 real_indices.push_back(indices_position[i]);
1548 return data(real_indices);
// NOTE(review): garbled extraction of what appears to be `gather_nd` — stray original
// line numbers interleaved, statements wrapped mid-token, blank/comment/brace lines and
// at least one branch (the non-integer index path after line 1594) dropped. Kept
// byte-identical; reconstruct against the upstream header before use.
1565 std::string name =
"T_gather_nd", std::string tag =
kInjective) {
1566 size_t ndim_d = data->shape.size();
1567 size_t ndim_i = indices->shape.size();
1568 TVM_FFI_ICHECK_GE(ndim_i, 1) <<
"indices tensor must have at least 1 dimensions";
1569 size_t indices_dim0 =
static_cast<size_t>(GetConstInt(indices->shape[0]));
1570 TVM_FFI_ICHECK_LE(indices_dim0, ndim_d) <<
"dim 0 of indices tensor must be no more "
1571 <<
"than dimensions of data tensor";
1572 ffi::Array<PrimExpr> out_shape;
1573 for (
size_t i = 1; i < ndim_i; ++i) {
1574 out_shape.push_back(indices->shape[i]);
1576 for (
size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
1577 out_shape.push_back(data->shape[i]);
1581 [&](
const ffi::Array<Var>& out_index) {
1582 ffi::Array<PrimExpr> indices_position;
1583 indices_position.push_back(0);
1584 for (
size_t i = 0; i < ndim_i - 1; ++i) {
1585 indices_position.push_back(out_index[i]);
1587 ffi::Array<PrimExpr> real_indices;
1588 for (
size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
1589 real_indices.push_back(out_index[i]);
1591 for (
size_t i = 0; i < indices_dim0; ++i) {
// An `indices_position.Set(0, ...)` update appears dropped at the top of this loop.
1593 if (indices->dtype.is_int() || indices->dtype.is_uint()) {
1594 real_indices.push_back(indices(indices_position));
1599 if (real_indices.size() == ndim_d) {
1600 return data(real_indices);
1602 for (
size_t i = ndim_i - 1; i < out_index.size(); ++i) {
1603 real_indices.push_back(out_index[i]);
1605 return data(real_indices);
1626 bool trans_a =
false,
bool trans_b =
false,
1627 std::string name =
"T_matmul", std::string tag =
kMatMul) {
1628 tvm::ffi::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
1631 return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
1648 std::string name =
"T_tensordot", std::string tag =
kMatMul) {
1649 TVM_FFI_ICHECK_GE(A->shape.size(), axes);
1650 TVM_FFI_ICHECK_GE(B->shape.size(), axes);
1652 ffi::Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
1653 for (
auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);
1655 ffi::Array<IterVar> iter_vars;
1656 for (
int i = 0; i < axes; ++i)
1657 iter_vars.push_back(
reduce_axis(
Range(0, B->shape[i]),
"k" + std::to_string(i)));
1659 auto func = [&A, &B, &iter_vars, axes](
const ffi::Array<Var>& input_indices) {
1660 ffi::Array<PrimExpr> A_indices(input_indices.begin(),
1661 input_indices.begin() + (A->shape.size() - axes));
1662 for (
auto& v : iter_vars) A_indices.push_back(v);
1664 ffi::Array<PrimExpr> B_indices;
1665 for (
auto& v : iter_vars) B_indices.push_back(v);
1667 auto it = input_indices.begin() + (A->shape.size() - axes);
1668 for (; it != input_indices.end(); ++it) B_indices.push_back(*it);
1671 if (iter_vars.empty()) {
1672 return A(A_indices) * B(B_indices);
1674 return sum(A(A_indices) * B(B_indices), iter_vars);
1678 return compute(output_shape, func, name, tag);
1694 ffi::Array<PrimExpr> B_axes, std::string name =
"T_tensordot",
1696 TVM_FFI_ICHECK_EQ(A_axes.size(), B_axes.size());
1698 auto A_axes_val = GetConstIntValues(A_axes,
"A_axes");
1699 auto B_axes_val = GetConstIntValues(B_axes,
"B_axes");
1701 ffi::Array<PrimExpr> output_shape;
1702 for (
unsigned i = 0; i < A->shape.size(); ++i)
1703 if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
1704 output_shape.push_back(A->shape[i]);
1705 for (
unsigned i = 0; i < B->shape.size(); ++i)
1706 if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
1707 output_shape.push_back(B->shape[i]);
1709 ffi::Array<IterVar> iter_vars;
1710 for (
unsigned i = 0; i < B_axes_val.size(); ++i)
1711 iter_vars.push_back(
reduce_axis(
Range(0, B->shape[B_axes_val[i]]),
"k" + std::to_string(i)));
1713 auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](
const ffi::Array<Var>& input_indices) {
1715 ffi::Array<PrimExpr> A_indices;
1716 for (
unsigned i = 0; i < A->shape.size(); ++i) {
1717 auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
1718 if (axes_pos == A_axes_val.end()) {
1719 A_indices.push_back(input_indices[idx_input++]);
1721 A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
1725 ffi::Array<PrimExpr> B_indices;
1726 for (
unsigned i = 0; i < B->shape.size(); ++i) {
1727 auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
1728 if (axes_pos == B_axes_val.end()) {
1729 B_indices.push_back(input_indices[idx_input++]);
1731 B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
1734 return sum(A(A_indices) * B(B_indices), iter_vars);
1736 return compute(output_shape, func, name, tag);
1747 }
else if (is_all_int && analyzer.
CanProveLess(step, 0)) {
1755 num_elem = analyzer.
Simplify(num_elem);
1759 [&](
const ffi::Array<Var>& indices) {
return tvm::cast(dtype, start + step * indices[0]); },
1773 inline ffi::Array<Tensor>
meshgrid(
const ffi::Array<Tensor>& inputs,
const std::string& indexing,
1774 std::string name =
"T_meshgrid", std::string tag =
kInjective) {
1775 const bool cartesian_indexing = indexing ==
"xy" && inputs.size() >= 2;
1776 ffi::Array<PrimExpr> out_shape;
1777 for (
size_t i = 0; i < inputs.size(); ++i) {
1778 const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
1779 out_shape.push_back(inputs[src_index]->
shape.size() == 0 ? 1 : inputs[src_index]->shape[0]);
1781 ffi::Array<Tensor> result;
1782 for (
size_t i = 0; i < inputs.size(); ++i) {
1785 [&](
const ffi::Array<Var>& indices) {
1786 const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
1787 auto ndim = inputs[i]->GetShape().size();
1788 ffi::Array<PrimExpr> real_indices = {};
1790 real_indices = {indices[src_index]};
1792 return inputs[i](real_indices);
1810 const std::string& dst_layout,
1811 const std::string schedule_rule =
"None",
1812 const std::string name =
"T_layout_trans",
1814 Layout src_layout_struct(src_layout);
1815 Layout dst_layout_struct(dst_layout);
1817 if (src_layout_struct.
Equals(dst_layout_struct)) {
1821 TVM_FFI_ICHECK(src_layout_struct.defined() && dst_layout_struct.defined())
1822 <<
"cannot convert from/to undefined layout";
1825 TVM_FFI_ICHECK(layout_converter.defined())
1826 <<
"cannot convert from " << src_layout <<
" to " << dst_layout;
1828 ffi::Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);
1830 ffi::Map<ffi::String, ffi::Any> attrs = {{
"schedule_rule", ffi::String(schedule_rule)},
1832 {
"src_layout", ffi::String(src_layout)},
1833 {
"dst_layout", ffi::String(dst_layout)},
1834 {
"input_shape", src->shape}};
1838 [&](
const ffi::Array<Var>& dst_indices) {
1839 ffi::Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
1840 ffi::Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
1842 for (
size_t i = 0; i < src.ndim(); ++i) {
1843 in_range = in_range && (src_indices[i] < src->shape[i]);
1852 std::vector<std::string>* axes) {
1854 std::string axis =
"";
1855 for (
char c : std::string(layout)) {
1856 if (c >=
'A' && c <=
'z') {
1859 shape->push_back(factor);
1862 }
else if (c >=
'0' && c <=
'9') {
1863 factor = factor * 10 + c -
'0';
1864 if (!axis.empty()) {
1865 axes->push_back(axis);
1869 TVM_FFI_THROW(InternalError) <<
"Invalid layout " << layout;
1872 if (!axis.empty()) {
1873 axes->push_back(axis);
1888 const Tensor& src,
const ffi::String& src_layout,
const ffi::String& dst_layout,
1889 const ffi::String name =
"T_auto_scheduler_layout_trans",
const ffi::String tag =
kInjective) {
1890 ffi::Array<PrimExpr> src_shape;
1891 std::vector<std::string> src_axes;
1892 ffi::Array<PrimExpr> dst_shape;
1893 std::vector<std::string> dst_axes;
1899 [&](
const ffi::Array<Var>& dst_indices) {
1900 ffi::Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
1901 ffi::Array<PrimExpr> src_indices;
1902 for (
const std::string& src_axis : src_axes) {
1904 TVM_FFI_ICHECK_EQ(dst_indices_expr.size(), dst_axes.size());
1905 for (
size_t i = 0; i < dst_axes.size(); ++i) {
1906 if (dst_axes[i] == src_axis) {
1907 src_index = src_index * dst_shape[i] + dst_indices_expr[i];
1910 src_indices.push_back(src_index);
1912 return src(src_indices);
1955 const ffi::String name =
"T_meta_schedule_layout_trans",
const ffi::String tag =
kInjective) {
1957 ffi::Array<Range> iter_domain;
1958 iter_domain.reserve(src->shape.size());
1959 for (
const PrimExpr& e : src->shape) {
1962 ffi::Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape, &analyzer);
1964 post_transform_shape,
1965 [src, inv = index_map.
Inverse(iter_domain, &analyzer),
1966 &analyzer](
const ffi::Array<Var>& indices) ->
PrimExpr {
1968 inv->MapIndices(ffi::Array<PrimExpr>{indices.begin(), indices.end()}, &analyzer));
1983 int ndim =
static_cast<int>(src->shape.size());
1984 ffi::Array<PrimExpr> out_shape{ndim};
1987 [&](
const ffi::Array<Var>& indices) {
1988 auto idx = indices[0];
1990 for (
int i = 0; i < ndim; ++i) {
2007 const std::string& name =
"tensor_size",
2009 int ndim =
static_cast<int>(src->shape.size());
2010 ffi::Array<PrimExpr> out_tensor_size = {};
2013 [&](
const ffi::Array<Var>& indices) {
2015 for (
int i = 0; i < ndim; ++i) {
2016 ret *= src->shape[i];
2038 int depth,
int axis,
const DataType& dtype,
2039 ffi::Array<PrimExpr> oshape = ffi::Array<PrimExpr>(),
2040 const std::string name =
"T_one_hot",
const std::string tag =
kInjective) {
2041 int true_axis = (axis == -1) ? indices->shape.size() : axis;
2042 if (oshape.size() == 0) {
2043 int ndim = indices->shape.size() + 1;
2044 int indices_index = 0;
2045 for (
int i = 0; i < ndim; i++) {
2046 if (i == true_axis) {
2047 oshape.push_back(
Integer(depth));
2049 oshape.push_back(indices->shape[indices_index++]);
2058 [&](
const ffi::Array<Var>& iter_vars) {
2059 ffi::Array<Var> indices_indices;
2060 for (
size_t i = 0; i < iter_vars.size(); i++) {
2061 if (
static_cast<int>(i) == true_axis) {
2065 indices_indices.push_back(iter_vars[i]);
2068 auto idx = iter_vars[true_axis];
2069 return tirx::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast);
2085 const ffi::Array<PrimExpr>& output_shape,
const Tensor& sparse_values,
2087 const std::string name =
"T_sparse_to_dense",
2089 TVM_FFI_ICHECK(sparse_indices->dtype.is_int()) <<
"sparse_indices only accepts integer values";
2090 TVM_FFI_ICHECK_LE(sparse_indices->shape.size(), 3)
2091 <<
"sparse_indices tensor should be 0D, 1D, or 2D only";
2092 TVM_FFI_ICHECK_LE(sparse_values->shape.size(), 2)
2093 <<
"sparse_values tensor should be 0D or 1D only";
2095 const auto rank_sparse_indices =
static_cast<int>(sparse_indices->shape.size());
2096 ffi::Array<PrimExpr> oshape;
2097 for (
auto l : output_shape) {
2098 oshape.push_back(l);
2102 [&](
const ffi::Array<Var>& indices) {
2104 if (0 == rank_sparse_indices) {
2106 }
else if (1 == rank_sparse_indices) {
2107 for (
int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
2111 for (
int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
2113 for (
int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
2114 PrimExpr comparision = indices[k] == sparse_indices[j][k];
2115 aggregate_condition = 0 == k ? comparision : aggregate_condition && comparision;
2138 bool super_diag_right_align,
bool sub_diag_right_align,
2139 const std::string name =
"T_matrix_set_diag",
2141 size_t ndim = input->shape.size() - 1;
2143 bool only_one_diagonal = k1 == k2;
2147 [&](
const ffi::Array<Var>& iter_vars) {
2148 auto get_diag = [&]() {
2149 ffi::Array<PrimExpr> diagonal_indices;
2150 PrimExpr k, offset = 0;
2151 for (size_t i = 0; i < ndim - 1; i++) {
2152 diagonal_indices.push_back(iter_vars[i]);
2154 if (only_one_diagonal) {
2158 k = iter_vars[ndim] - iter_vars[ndim - 1];
2159 diagonal_indices.push_back(k2 - k);
2162 auto get_offset = [&](PrimExpr M, PrimExpr N) {
2164 return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
2166 offset = if_then_else(
2168 super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
2170 sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
2173 diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
2175 return diagonal(diagonal_indices);
2179 get_diag(), input(iter_vars)),
2194 const std::string name =
"advanced_index",
2196 TVM_FFI_ICHECK_LE(indices.size(), data->shape.size()) <<
"too many indices for data!";
2197 ffi::Array<PrimExpr> oshape;
2198 ffi::Array<PrimExpr> broadcast_shape;
2199 ffi::Array<Tensor> bindices;
2201 broadcast_shape = indices[0]->shape;
2202 for (
size_t i = 1; i < indices.size(); ++i) {
2203 auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->
shape);
2204 broadcast_shape = ffi::Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
2206 if (indices.size() == 1) {
2211 for (
size_t i = 0; i < indices.size(); ++i) {
2212 bindices.push_back(
broadcast_to(indices[i], broadcast_shape));
2216 for (
const auto& dim : broadcast_shape) {
2217 oshape.push_back(dim);
2219 for (
size_t i = indices.size(); i < data->
shape.size(); ++i) {
2220 oshape.push_back(data->shape[i]);
2225 [&](
const ffi::Array<Var>& iter_var) {
2226 ffi::Array<PrimExpr> tensor_indices;
2227 for (
size_t i = 0; i < broadcast_shape.size(); ++i) {
2228 tensor_indices.push_back(iter_var[i]);
2230 ffi::Array<PrimExpr> real_indices;
2231 for (
size_t i = 0; i < bindices.size(); ++i) {
2232 real_indices.push_back(bindices[i](tensor_indices));
2234 for (
size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
2235 real_indices.push_back(iter_var[i]);
2238 return data(real_indices);
2247 ffi::Array<PrimExpr> output_shape,
2248 std::string name =
"T_strided_slice_dynamic",
2250 const size_t num_dynamic_axes = x.
ndim();
2251 TVM_FFI_ICHECK_EQ(begin.
ndim(), 1);
2252 TVM_FFI_ICHECK_EQ(end.
ndim(), 1);
2253 TVM_FFI_ICHECK_EQ(strides.
ndim(), 1);
2254 const auto* len_begin = begin->shape[0].as<
IntImmNode>();
2255 const auto* len_end = end->shape[0].as<
IntImmNode>();
2256 const auto* len_strides = strides->shape[0].as<
IntImmNode>();
2257 TVM_FFI_ICHECK(len_begin);
2258 TVM_FFI_ICHECK(len_end);
2259 TVM_FFI_ICHECK(len_strides);
2260 TVM_FFI_ICHECK_EQ(len_begin->value, num_dynamic_axes);
2261 TVM_FFI_ICHECK_EQ(len_end->
value, num_dynamic_axes);
2262 TVM_FFI_ICHECK_EQ(len_strides->
value, num_dynamic_axes);
2266 [&](
const ffi::Array<tvm::tirx::Var>& indices) {
2267 ffi::Array<PrimExpr> real_indices;
2268 for (
size_t i = 0; i < num_dynamic_axes; ++i) {
2270 real_indices.push_back(indices[i] * strides(ind) +
tvm::min(begin(ind), x->shape[i] - 1));
2272 return x(real_indices);
Algebra expression simplifications.
Broadcast op constructions.
Managed reference class to FloatImmNode.
Definition: expr.h:546
Constant integer literals in the program.
Definition: expr.h:494
int64_t value
the Internal value.
Definition: expr.h:497
Managed reference class to IntImmNode.
Definition: expr.h:511
Container of constant int that adds more constructors.
Definition: expr.h:601
Reference to PrimExprNode.
Definition: expr.h:126
DataType dtype() const
Definition: expr.h:140
Range container
Definition: expr.h:690
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
construct a new range with min and extent The corresponding constructor is removed,...
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:634
bool CanProveGreaterEqual(const PrimExpr &expr, int64_t lower_bound)
Whether can we prove expr >= val.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
bool CanProveLess(const PrimExpr &expr, int64_t upper_bound)
Whether can we prove expr < val.
Runtime primitive data type.
Definition: data_type.h:47
static DataType Float(int bits, int lanes=1)
Construct an float type.
Definition: data_type.h:295
bool is_int() const
Definition: data_type.h:194
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:54
Node to represent a tensor.
Definition: tensor.h:70
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
size_t ndim() const
Definition: tensor.h:212
Bijective function mapping for data layout transformation. Given two Layout, BijectiveLayout build an...
Definition: data_layout.h:386
Definition: index_map.h:170
IndexMap Inverse(ffi::Array< Range > initial_ranges, arith::Analyzer *analyzer) const
Generate the inverse mapping.
Managed reference to LayoutNode.
Definition: data_layout.h:126
bool Equals(const Layout &rhs) const
Whether the two layouts are equal.
Definition: data_layout.h:332
Managed reference to SelectNode.
Definition: expr.h:514
A variable node in the IR.
Definition: var.h:47
ffi::String name_hint
The hint to the variable name.
Definition: var.h:53
a named variable in TIR
Definition: var.h:76
Utility functions for handling constants in TVM expressions.
Layout expression to describe the data organization of a tensor. And BijectiveLayout to mapping two d...
Defines a remapping of buffer indices.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1007
DataType DefaultIndexType()
if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return int32
Definition: buffer.h:43
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:1021
te::Tensor dynamic_strided_slice(const te::Tensor &x, const te::Tensor &begin, const te::Tensor &end, const te::Tensor &strides, ffi::Array< PrimExpr > output_shape, std::string name="T_strided_slice_dynamic", std::string tag=kInjective)
Definition: transform.h:2245
PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent, bool assume_inbound=true)
Definition: transform.h:689
Tensor sequence_mask(const Tensor &data, const Tensor &valid_length, double mask_value, int axis, std::string name="T_sequence_mask", std::string tag=kInjective)
Mask the out-of-boundary elements of each sequence.
Definition: transform.h:1100
Tensor gather_nd(const Tensor &data, const Tensor &indices, int batch_dims=0, std::string name="T_gather_nd", std::string tag=kInjective)
Gather elements from a n-dimension array.
Definition: transform.h:1564
int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t stride)
Definition: transform.h:670
Tensor reshape(const Tensor &x, ffi::Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:330
Tensor one_hot(const Tensor &indices, const PrimExpr on_value, const PrimExpr off_value, int depth, int axis, const DataType &dtype, ffi::Array< PrimExpr > oshape=ffi::Array< PrimExpr >(), const std::string name="T_one_hot", const std::string tag=kInjective)
Returns a one-hot tensor where the locations repsented by indices take value on_value,...
Definition: transform.h:2037
tvm::te::Tensor broadcast_to(const tvm::te::Tensor &t, const tvm::ffi::Array< tvm::PrimExpr > &output_shape, std::string name="T_broadcast_to", std::string tag=kBroadcast)
Creates an operation that broadcasts a tensor into a compatible shape according to numpy's rules.
Definition: broadcast.h:48
constexpr auto kBroadcast
Definition: tags.h:36
Tensor arange(const PrimExpr &start, const PrimExpr &stop, const PrimExpr &step, DataType dtype, std::string name="T_arange", std::string tag=kInjective)
Definition: transform.h:1739
constexpr auto kInjective
Definition: tags.h:33
Tensor stack(const ffi::Array< Tensor > &inputs, int axis=0, std::string name="T_stack", std::string tag=kInjective)
Join a sequence of tensors along a new axis.
Definition: transform.h:541
Tensor auto_scheduler_layout_transform(const Tensor &src, const ffi::String &src_layout, const ffi::String &dst_layout, const ffi::String name="T_auto_scheduler_layout_trans", const ffi::String tag=kInjective)
Transform the auto-scheduler generated layout according to src_layout and dst_layout.
Definition: transform.h:1887
ffi::Array< PrimExpr > StridedSliceOutputShape(const ffi::Array< PrimExpr > &ishape, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, const ffi::Array< Integer > &axes, const std::string &slice_mode)
Calculate the output shape of strided_slice, the entry point for Relax type relation.
Definition: transform.h:868
PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride)
Definition: transform.h:679
te::Tensor dynamic_strided_slice_with_axes(const te::Tensor &x, const ffi::Array< PrimExpr > &begin, const ffi::Array< PrimExpr > &end, const ffi::Array< PrimExpr > &strides, const ffi::Array< Integer > &axes, bool assume_inbound=true, std::string name="T_dynamic_strided_slice_with_axes", std::string tag=kInjective)
strided_slice of a tensor where begin/end/stride can be mixed static and dynamic
Definition: transform.h:716
Tensor transpose(const Tensor &x, ffi::Optional< ffi::Array< Integer >> opt_axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:205
void parse_auto_scheduler_layout(const ffi::String &layout, ffi::Array< PrimExpr > *shape, std::vector< std::string > *axes)
Utility function for auto_scheduler_layout_transform.
Definition: transform.h:1851
Tensor squeeze(const Tensor &x, ffi::Optional< ffi::Array< Integer >> opt_axes, bool atleast1d=false, std::string name="T_squeeze", std::string tag=kInjective)
Remove size 1 dimensions from the shape of a tensor. The removed dimensions must have a constant size...
Definition: transform.h:415
ffi::Array< Tensor > split_n_sections(const Tensor &x, int num_sections, int axis, std::string name="T_split_sections", std::string tag=kInjective)
Split a tensor into a number of sub-tensors.
Definition: transform.h:1004
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type,...
Definition: elemwise.h:277
Tensor expand_dims(const Tensor &x, int axis, int num_newaxis=1, std::string name="T_expand_dims", std::string tag=kBroadcast)
Creates an operation to insert new dimensions of length 1.
Definition: transform.h:156
Tensor sparse_to_dense(const Tensor &sparse_indices, const ffi::Array< PrimExpr > &output_shape, const Tensor &sparse_values, const PrimExpr &default_value, const std::string name="T_sparse_to_dense", const std::string tag=kInjective)
Get a dense tensor.
Definition: transform.h:2084
Tensor unravel_index(const Tensor &x, const Tensor &shape, std::string name="T_unravel", std::string tag=kInjective)
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Definition: transform.h:367
Tensor layout_transform(const Tensor &src, const std::string &src_layout, const std::string &dst_layout, const std::string schedule_rule="None", const std::string name="T_layout_trans", const std::string tag=kInjective)
Transform the layout according to src_layout and dst_layout.
Definition: transform.h:1809
Tensor adv_index(const Tensor &data, const ffi::Array< Tensor > &indices, const std::string name="advanced_index", const std::string tag=kInjective)
Numpy style advanced indexing with tensor.
Definition: transform.h:2193
Tensor strided_slice(const Tensor &x, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:962
Tensor concatenate(const ffi::Array< Tensor > &inputs, int axis=0, std::string name="T_concat", std::string tag=kInjective)
Join a sequence of tensors along an existing axis.
Definition: transform.h:481
ffi::Array< Tensor > meshgrid(const ffi::Array< Tensor > &inputs, const std::string &indexing, std::string name="T_meshgrid", std::string tag=kInjective)
Produce grids by expanding input over dimensions defined by other inputs.
Definition: transform.h:1773
constexpr auto kMatMul
Definition: tags.h:37
Tensor strided_slice_with_axes(const Tensor &x, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, const ffi::Array< Integer > &axes, std::string slice_mode="end", std::string name="T_strided_slice_with_axes", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:900
Tensor dyn_tile(const Tensor &x, ffi::Array< PrimExpr > new_shape, size_t rdim, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1473
Tensor reverse_sequence(const Tensor &x, const Tensor &seq_lengths, int seq_axis=1, int batch_axis=0, std::string name="T_reverse_sequence", std::string tag=kInjective)
Reverse the tensor for variable length slices. Input is first sliced along batch axis and then elemen...
Definition: transform.h:265
ffi::Array< Tensor > split_indices_array(const Tensor &x, ffi::Array< PrimExpr > split_indices, int axis, std::string name="T_split", std::string tag=kInjective)
Split a tensor into multiple sub-tensors.
Definition: transform.h:587
Tensor tensordot(const Tensor &A, const tvm::te::Tensor &B, int axes=2, std::string name="T_tensordot", std::string tag=kMatMul)
A generalization of matrix multiplication to tensors.
Definition: transform.h:1647
Tensor sum(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:328
Tensor tile(const Tensor &x, ffi::Array< Integer > reps, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1417
Tensor meta_schedule_layout_transform(const Tensor &src, const tirx::IndexMap &index_map, const ffi::String name="T_meta_schedule_layout_trans", const ffi::String tag=kInjective)
Transform the meta-schedule generated layout according to TIR's IndexMap.
Definition: transform.h:1953
Tensor take(const Tensor &a, const Tensor &indices, int batch_dims, std::string mode="fast", std::string name="T_take", std::string tag=kInjective)
Take elements from an flattened input array when axis is None.
Definition: transform.h:1040
PrimExpr DynamicCanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride)
Definition: transform.h:652
tvm::te::Tensor matmul(const tvm::te::Tensor &A, const tvm::te::Tensor &B, bool trans_a=false, bool trans_b=false, std::string name="T_matmul", std::string tag=kMatMul)
Creates an operation that calculates a matrix multiplication (row-major notation): A(i,...
Definition: transform.h:1625
Tensor dynamic_strided_slice(const Tensor &x, const ffi::Array< PrimExpr > &begin, const ffi::Array< PrimExpr > &end, const ffi::Array< PrimExpr > &strides, bool assume_inbound=true, std::string name="T_dynamic_strided_slice", std::string tag=kInjective)
strided_slice of a tensor where begin/end/stride can be mixed static and dynamic
Definition: transform.h:773
Tensor matrix_set_diag(const Tensor &input, const Tensor &diagonal, int k1, int k2, bool super_diag_right_align, bool sub_diag_right_align, const std::string name="T_matrix_set_diag", const std::string tag=kInjective)
Returns a tensor with the diagonal of input tensor replaced with the provided diagonals.
Definition: transform.h:2137
Tensor where(const Tensor &condition, const Tensor &x, const Tensor &y, std::string name="T_where", std::string tag=kBroadcast)
Return the elements, either from x or y, depending on the condition.
Definition: transform.h:1330
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
Tensor gather(const Tensor &data, int axis, const Tensor &indices, std::string name="T_gather", std::string tag=kInjective)
Gather values along given axis from given indices.
Definition: transform.h:1511
Tensor sliding_window(const Tensor &x, int axis, ffi::Array< Integer > window_shape, ffi::Array< Integer > strides, std::string name="T_sliding_window", std::string tag="")
Creates an operation to slide a window over the input x.
Definition: transform.h:76
te::Tensor tensor_size(const te::Tensor &src, const DataType &dtype, const std::string &name="tensor_size", const std::string &tag=kInjective)
Get the size of input tensor.
Definition: transform.h:2006
Tensor repeat(const Tensor &x, int repeats, int axis, std::string name="T_repeat", std::string tag=kBroadcast)
Creates an operation to repeat elements of an array.
Definition: transform.h:1370
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span=Span())
compute ceil(a / b)
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr max_value(const DataType &dtype, Span span=Span())
PrimExpr ceil(PrimExpr x, Span span=Span())
Calculate ceil(x)
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
PrimExpr sum(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder floor(a / b) where a and b are non-negative.
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
Operation node can generate one or multiple Tensors.
Index ravel and unraval operations.
Utility functions for strided_slice op.
Utility functions for handling tensor.
Common operators defined for Expr.