tvm/topi/transform.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef TVM_TOPI_TRANSFORM_H_
#define TVM_TOPI_TRANSFORM_H_

#include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/index_map.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/detail/broadcast.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/detail/ravel_unravel.h>
#include <tvm/topi/detail/strided_slice.h>
#include <tvm/topi/detail/tensor_utils.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <string>
#include <unordered_set>
#include <vector>

#include "tvm/ir/expr.h"
#include "tvm/runtime/data_type.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/op.h"
#include "tvm/tir/var.h"

namespace tvm {
namespace topi {

using namespace tvm::te;
using namespace topi::detail;

inline Tensor sliding_window(const Tensor& x, int axis, Array<Integer> window_shape,
                             Array<Integer> strides, std::string name = "T_sliding_window",
                             std::string tag = "") {
  CHECK_GE(axis, 0);
  auto _axis = size_t(axis);
  CHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x.";
  CHECK_EQ(x->shape.size() - _axis, window_shape.size())
      << "There must be a window shape for every dimension of x "
      << "over which we are sliding the window.";
  CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length.";

  // Compute the new shape.
  Array<PrimExpr> new_shape;
  // Dimensions up until `axis` remain the same.
  for (size_t i = 0; i < _axis; ++i) {
    new_shape.push_back(x->shape[i]);
  }

  // New dimensions which result from sliding the window in each dimension. One new dimension per
  // window dimension.
  for (size_t i = 0; i < window_shape.size(); ++i) {
    // Length of the shape along this dimension.
    auto dim_len = x->shape[_axis + i];
    // Length of the window along this dimension.
    auto window_len = window_shape[i];
    // Strides along this dimension.
    auto stride = strides[i];

    new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride));
  }

  // Dimensions comprising the window.
  for (size_t i = 0; i < window_shape.size(); ++i) {
    new_shape.push_back(window_shape[i]);
  }

  ICHECK(new_shape.size() == _axis + 2 * window_shape.size());

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        // The index at which to index the old tensor x.
        Array<PrimExpr> idx;

        // Dimensions up until `axis` remain the same.
        for (size_t i = 0; i < _axis; ++i) {
          idx.push_back(indices[i]);
        }

        for (size_t i = 0; i < window_shape.size(); ++i) {
          // Which window in this dimension we are indexing.
          auto window_idx = indices[_axis + i];
          // Which index within the window we are indexing.
          auto idx_within_window = indices[_axis + window_shape.size() + i];
          // Stride value for this dimension.
          auto stride = strides[i];

          idx.push_back(window_idx * stride + idx_within_window);
        }

        ICHECK(idx.size() == x->shape.size());

        return x(idx);
      },
      name, tag);
}
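
// Illustrative usage (editor's sketch; `data` is a hypothetical placeholder, not
// part of this header): sliding a 3x3 window with stride 2 over both axes of a
// 10x10 tensor gives floordiv(10 - (3 - 1) + 2 - 1, 2) = 4 positions per axis,
// so the result shape is {4, 4, 3, 3}:
//
//   te::Tensor data = te::placeholder({10, 10}, DataType::Float(32), "data");
//   te::Tensor windows = sliding_window(data, /*axis=*/0, {3, 3}, {2, 2});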

inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1,
                          std::string name = "T_expand_dims", std::string tag = kBroadcast) {
  int ndim = static_cast<int>(x->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;
  ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
                           << ", but got num_newaxis = " << num_newaxis;
  if (axis < 0) {
    // Calculate offset from last dimension
    axis = ndim + axis + 1;
  }
  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
    new_shape.push_back(x->shape[i]);
  }
  for (size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) {
    new_shape.push_back(1);
  }
  for (size_t i = axis; i < x->shape.size(); ++i) {
    new_shape.push_back(x->shape[i]);
  }

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> idx;
        for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
          idx.push_back(indices[i]);
        }
        for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
          idx.push_back(indices[i]);
        }
        return x(idx);
      },
      name, tag);
}
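
// Illustrative usage (editor's sketch; `x` is a hypothetical placeholder):
// inserting two new axes at position 1 of a {3, 4} tensor yields {3, 1, 1, 4}:
//
//   te::Tensor x = te::placeholder({3, 4}, DataType::Float(32), "x");
//   te::Tensor y = expand_dims(x, /*axis=*/1, /*num_newaxis=*/2);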

inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name = "T_transpose",
                        std::string tag = kInjective) {
  if (!axes.defined() || axes.size() == 0) {
    axes = Array<Integer>();
    for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
      axes.push_back(i);
    }
  }

  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < axes.size(); ++i) {
    int axis = static_cast<int>(axes[i]->value);
    int new_axis = axis;
    if (axis < 0) {
      new_axis = static_cast<int>(x->shape.size()) + axis;
      axes.Set(i, new_axis);
    }
    ICHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
        << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size())
        << "-dimensional input tensor";

    for (size_t j = 0; j < axes.size(); ++j) {
      if (i != j) {
        ICHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
      }
    }
    new_shape.push_back(x->shape[new_axis]);
  }

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        std::vector<PrimExpr> idx;
        for (size_t i = 0; i < axes.size(); ++i) {
          idx.push_back(1);
        }
        for (size_t i = 0; i < axes.size(); ++i) {
          int axis = static_cast<int>(axes[i]->value);
          idx[axis] = indices[i];
        }
        return x(idx);
      },
      name, tag);
}
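
// Illustrative usage (editor's sketch; `x` is a hypothetical placeholder):
//
//   te::Tensor x = te::placeholder({2, 3, 4}, DataType::Float(32), "x");
//   te::Tensor y = transpose(x, {1, 2, 0});         // y->shape == {3, 4, 2}
//   te::Tensor r = transpose(x, Array<Integer>());  // empty axes reverse: {4, 3, 2}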

inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
                               int batch_axis = 0, std::string name = "T_reverse_sequence",
                               std::string tag = kInjective) {
  size_t src_tensor_dim = x->shape.size();
  int seq_axis_inp = seq_axis;

  if (seq_lengths.defined()) {
    size_t seq_lengths_dim = seq_lengths->shape.size();
    int batch_axis_inp = batch_axis;
    if (batch_axis < 0) {
      batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
    }

    ICHECK(seq_lengths_dim == 1) << "seq_lengths should be a 1D vector";

    ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
        << "For reverse_sequence, the size of seq_lengths should match the dimension of the "
        << "batch axis, but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
        << ", and seq_lengths size = " << GetConstInt(seq_lengths->shape[0]);

    ICHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
        << "batch_axis=" << batch_axis_inp << " is invalid for the "
        << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
  }

  if (seq_axis < 0) {
    seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
  }
  ICHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
      << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
      << "-dimensional input tensor";

  auto func = [&](const Array<Var>& indices) {
    Array<PrimExpr> real_indices;
    for (size_t i = 0; i < src_tensor_dim; ++i) {
      if (i == static_cast<size_t>(seq_axis)) {
        if (seq_lengths.defined()) {
          auto len = seq_lengths(indices[batch_axis]);
          auto idx = if_then_else(
              len <= 1 || len <= indices[i], indices[i],
              if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
          real_indices.push_back(idx);
        } else {
          real_indices.push_back(x->shape[i] - 1 - indices[i]);
        }
      } else {
        real_indices.push_back(indices[i]);
      }
    }
    return x(real_indices);
  };

  return compute(x->shape, func, name, tag);
}
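
// Worked example (editor's illustration): with x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
// seq_lengths = [1, 2, 3], seq_axis = 1, and batch_axis = 0, the first
// seq_lengths[b] elements of each batch row b are reversed:
//   [[1, 2, 3],
//    [5, 4, 6],
//    [9, 8, 7]]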

inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string name = "T_reshape",
                      std::string tag = kInjective) {
  auto x_shape = x->shape;
  Array<PrimExpr> target_shape;

  for (const auto& ele : newshape) {
    target_shape.push_back(ele);
  }

  // If either the input shape or the target shape contains a zero, return an empty tensor.
  if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) {
    return compute(
        target_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  } else {
    return compute(
        target_shape,
        [&](const Array<Var>& indices) {
          return x(UnravelIndex(
              RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape), x_shape));
        },
        name, tag);
  }
}
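
// Illustrative usage (editor's sketch; `x` is a hypothetical placeholder): the
// 24 elements are reinterpreted in row-major order:
//
//   te::Tensor x = te::placeholder({2, 3, 4}, DataType::Float(32), "x");
//   te::Tensor y = reshape(x, {4, 6});  // y->shape == {4, 6}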

inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
                            std::string tag = kInjective) {
  auto x_shape = x->shape;
  auto shape_shape = shape->shape;

  Array<PrimExpr> oshape;
  oshape.push_back(shape_shape[0]);
  if (x_shape.size() != 0) {
    oshape.push_back(x_shape[0]);
  }

  auto func = [&](const Array<Var>& indices) {
    auto i = indices[0];
    std::vector<PrimExpr> indices_divs;
    PrimExpr ret = 0;
    PrimExpr cur_val = 0;
    PrimExpr index_val = 0;

    if (x_shape.size() != 0) {
      index_val = x[indices[1]];
    } else {
      index_val = x();
    }
    indices_divs.push_back(index_val);
    for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
      ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
      cur_val = indexdiv(indices_divs.back(), shape[v]);
      indices_divs.push_back(cur_val);
    }
    return ret;
  };

  return compute(oshape, func, name, tag);
}

inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false,
                      std::string name = "T_squeeze", std::string tag = kInjective) {
  auto ndim = x->shape.size();
  std::vector<int> axis_val;
  if (!axis.defined()) {
    for (size_t i = 0; i < ndim; ++i) {
      if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
        axis_val.push_back(static_cast<int>(i));
      }
    }
  } else {
    for (size_t i = 0; i < axis.size(); ++i) {
      int64_t val = axis[i]->value;
      if (val < 0) {
        val += static_cast<int>(x->shape.size());
      }
      if (IsConstInt(x->shape[val])) {
        ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
      }
      axis_val.push_back(val);
    }
  }

  std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());

  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < ndim; ++i) {
    if (axis_set.count(static_cast<int>(i)) == 0) {
      out_shape.push_back(x->shape[i]);
    }
  }
  if (out_shape.size() == 0 && atleast1d) {
    out_shape.push_back(1);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> real_indices;
        int flag = 0;
        for (size_t i = 0; i < ndim; ++i) {
          if (axis_set.count(static_cast<int>(i)) == 0) {
            real_indices.push_back(indices[i - flag]);
          } else {
            real_indices.push_back(0);
            flag += 1;
          }
        }
        return x(real_indices);
      },
      name, tag);
}
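
// Illustrative usage (editor's sketch; `x` is a hypothetical placeholder). Note
// that only size-1 dimensions may be named:
//
//   te::Tensor x = te::placeholder({1, 3, 1, 4}, DataType::Float(32), "x");
//   te::Tensor y = squeeze(x, {0, 2});  // y->shape == {3, 4}
//   te::Tensor z = squeeze(x, {0});     // z->shape == {3, 1, 4}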

inline Tensor concatenate(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_concat",
                          std::string tag = kInjective) {
  int ndim = static_cast<int>(inputs[0]->shape.size());
  ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)"
                                       << ", but got axis = " << axis << ", and ndim = " << ndim;
  if (axis < 0) {
    axis += ndim;
  }
  ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds";

  Array<PrimExpr> axis_sizes;
  for (auto t : inputs) {
    axis_sizes.push_back(t->shape[axis]);
  }
  arith::Analyzer analyzer;
  PrimExpr join_size = axis_sizes[0];
  for (size_t i = 1; i < axis_sizes.size(); ++i) {
    join_size += axis_sizes[i];
  }
  join_size = analyzer.Simplify(join_size);
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
    out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        auto ret = inputs[0](indices);
        auto ind = indices[axis];
        for (size_t i = 0; i < inputs.size() - 1; ++i) {
          ind -= axis_sizes[i];

          Array<PrimExpr> idx;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            idx.push_back(indices[j]);
          }
          idx.push_back(ind);
          for (size_t j = axis + 1; j < indices.size(); ++j) {
            idx.push_back(indices[j]);
          }

          ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
        }
        return ret;
      },
      name, tag);
}
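
// Illustrative usage (editor's sketch; `a` and `b` are hypothetical placeholders):
// all non-concatenated dimensions must match:
//
//   te::Tensor a = te::placeholder({2, 3}, DataType::Float(32), "a");
//   te::Tensor b = te::placeholder({2, 5}, DataType::Float(32), "b");
//   te::Tensor c = concatenate({a, b}, /*axis=*/1);  // c->shape == {2, 8}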

inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack",
                    std::string tag = kInjective) {
  int ndim = static_cast<int>(inputs[0]->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "stack only accepts `axis` in [-ndim - 1, ndim]"
      << ", but got axis = " << axis << ", and ndim = " << ndim;
  if (axis < 0) {
    axis += ndim + 1;
  }
  ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";

  const int stack_size = static_cast<int>(inputs.size());
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
  out_shape.push_back(stack_size);
  for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
    out_shape.push_back(inputs[0]->shape[i]);

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> idx;
        for (size_t i = 0; i < indices.size(); ++i)
          if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]);
        auto ind = indices[axis];
        auto ret = inputs[0](idx);
        for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
          ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
        }
        return ret;
      },
      name, tag);
}
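
// Illustrative usage (editor's sketch; `a` and `b` are hypothetical placeholders):
// stacking two {2, 3} tensors along a new axis 1 yields {2, 2, 3}:
//
//   te::Tensor a = te::placeholder({2, 3}, DataType::Float(32), "a");
//   te::Tensor b = te::placeholder({2, 3}, DataType::Float(32), "b");
//   te::Tensor s = stack({a, b}, /*axis=*/1);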

inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
                           std::string name = "T_split", std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(x->shape.size());
  }
  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";

  auto src_axis_size = x->shape[axis];
  std::vector<PrimExpr> begin_ids;
  begin_ids.push_back(0);

  for (auto idx : split_indices) {
    auto idx_node = idx.as<IntImmNode>();
    auto back_node = begin_ids.back().as<IntImmNode>();
    if (idx_node && back_node) {
      ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
    }
    begin_ids.push_back(idx);
  }

  Array<Array<PrimExpr>> out_shapes;
  for (size_t i = 0; i < begin_ids.size(); ++i) {
    PrimExpr out_axis_size;
    if (i == begin_ids.size() - 1) {
      out_axis_size = src_axis_size - begin_ids[i];
    } else {
      out_axis_size = begin_ids[i + 1] - begin_ids[i];
    }

    Array<PrimExpr> shape;
    for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
      shape.push_back(x->shape[j]);
    }
    shape.push_back(out_axis_size);
    for (size_t j = axis + 1; j < x->shape.size(); ++j) {
      shape.push_back(x->shape[j]);
    }

    out_shapes.push_back(shape);
  }

  Array<Tensor> result;
  for (size_t i = 0; i < begin_ids.size(); ++i) {
    result.push_back(compute(
        out_shapes[i],
        [&](const Array<Var>& indices) {
          auto begin = begin_ids[i];
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            real_indices.push_back(indices[j]);
          }
          real_indices.push_back(indices[axis] + begin);
          for (size_t j = axis + 1; j < indices.size(); ++j) {
            real_indices.push_back(indices[j]);
          }

          return x(real_indices);
        },
        name, tag));
  }

  return result;
}
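
// Illustrative usage (editor's sketch; `x` is a hypothetical placeholder):
// splitting axis 0 of a {6, 4} tensor at rows 2 and 5:
//
//   te::Tensor x = te::placeholder({6, 4}, DataType::Float(32), "x");
//   Array<te::Tensor> parts = split(x, {2, 5}, /*axis=*/0);
//   // parts[0]->shape == {2, 4}, parts[1]->shape == {3, 4}, parts[2]->shape == {1, 4}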

inline PrimExpr DynamicCanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) {
  auto idx_var = index.as<tvm::tir::VarNode>();
  auto extent_var = extent.as<tvm::tir::VarNode>();

  if (idx_var && extent_var && idx_var->name_hint == extent_var->name_hint) {
    return index;
  }

  PrimExpr begin_range = tvm::if_then_else(stride < 0, -1, 0);
  PrimExpr end_range = tvm::if_then_else(stride < 0, extent - 1, extent);

  if (!(index->IsInstance<tvm::IntImmNode>() && GetConstInt(index) >= 0)) {
    index = tvm::if_then_else(index < 0, index + extent, index);
  }

  return tvm::min(tvm::max(index, begin_range), end_range);
}

inline int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) {
  int64_t begin_range = stride < 0 ? -1 : 0;
  int64_t end_range = stride < 0 ? extent - 1 : extent;
  if (index < 0) {
    index += extent;
  }
  return std::min(std::max(index, begin_range), end_range);
}

inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) {
  if (index->IsInstance<tvm::IntImmNode>() && extent->IsInstance<tvm::IntImmNode>() &&
      stride->IsInstance<tvm::IntImmNode>()) {
    return tvm::IntImm(
        tvm::DataType::Int(64),
        StaticCanonicalizeIndex(GetConstInt(index), GetConstInt(extent), GetConstInt(stride)));
  }
  return DynamicCanonicalizeIndex(index, extent, stride);
}

inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent,
                          bool assume_inbound = true) {
  if (assume_inbound) {
    return ceildiv(end - begin, stride);
  } else {
    begin = CanonicalizeIndex(begin, extent, stride);
    end = CanonicalizeIndex(end, extent, stride);
    return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride),
                             ceildiv(end - begin, stride));
  }
}
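
// Worked example (editor's note): for extent 10 and stride 1,
// StaticCanonicalizeIndex(-1, 10, 1) wraps -1 to 9 and clamps it into [0, 10];
// GetLength(2, 8, 2, 10) with assume_inbound = true returns
// ceildiv(8 - 2, 2) = 3, i.e. the slice visits indices 2, 4, and 6.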

inline Tensor dynamic_strided_slice_with_axes(
    const Tensor& x, const Array<PrimExpr>& begin, const Array<PrimExpr>& end,
    const Array<PrimExpr>& strides, const Array<Integer>& axes, bool assume_inbound = true,
    std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  ICHECK_EQ(begin.size(), end.size());
  ICHECK_EQ(begin.size(), strides.size());
  ICHECK_EQ(begin.size(), axes.size());
  ICHECK_LE(begin.size(), src_tensor_dim);

  for (const auto& axis_imm : axes) {
    int axis = axis_imm->value;
    ICHECK_LT(axis, src_tensor_dim);
  }

  arith::Analyzer analyzer;

  Array<PrimExpr> out_shape = x->shape;
  for (size_t i = 0; i < begin.size(); i++) {
    int axis = axes[i]->value;
    PrimExpr new_shape =
        analyzer.Simplify(GetLength(begin[i], end[i], strides[i], out_shape[axis], assume_inbound));
    out_shape.Set(axis, new_shape);
  }

  return te::compute(
      out_shape,
      [&](const Array<tvm::tir::Var>& indices) {
        Array<PrimExpr> real_indices = indices.Map([](const auto& var) -> PrimExpr { return var; });

        for (size_t i = 0; i < begin.size(); i++) {
          int axis = axes[i]->value;
          PrimExpr new_index = indices[axis] * strides[i] + begin[i];
          real_indices.Set(axis, new_index);
        }

        return x(real_indices);
      },
      name, tag);
}

inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
                                    const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
                                    bool assume_inbound = true,
                                    std::string name = "T_dynamic_strided_slice",
                                    std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  ICHECK_LE(begin.size(), src_tensor_dim);
  ICHECK_LE(end.size(), src_tensor_dim);
  ICHECK_LE(strides.size(), src_tensor_dim);
  ICHECK_EQ(begin.size(), end.size());
  ICHECK_EQ(begin.size(), strides.size());

  const size_t num_slice_axes = begin.size();
  Array<PrimExpr> out_shape;

  arith::Analyzer analyzer;
  for (size_t i = 0; i < num_slice_axes; ++i) {
    // Check ProducerLoad to keep backward compatibility for Relay.
    if (!begin[i]->IsInstance<ProducerLoadNode>() && !end[i]->IsInstance<ProducerLoadNode>() &&
        !strides[i]->IsInstance<ProducerLoadNode>()) {
      out_shape.push_back(
          analyzer.Simplify(GetLength(begin[i], end[i], strides[i], x->shape[i], assume_inbound)));
    } else {
      out_shape.push_back(tvm::tir::Var("dim"));
    }
  }

  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
    out_shape.push_back(x->shape[i]);
  }

  return te::compute(
      out_shape,
      [&](const Array<tvm::tir::Var>& indices) {
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < num_slice_axes; ++i) {
          real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
        }
        // keep input dim
        for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
          real_indices.push_back(indices[i]);
        }
        return x(real_indices);
      },
      name, tag);
}

inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
                                        const te::Tensor& end, const te::Tensor& strides,
                                        bool assume_inbound = true,
                                        std::string name = "T_strided_slice_dynamic",
                                        std::string tag = topi::kInjective) {
  DataType index_dtype = begin->shape[0]->dtype;
  const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
  ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
  ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);

  Array<PrimExpr> begin_expr, end_expr, strides_expr;
  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
    auto ind = make_const(index_dtype, i);
    begin_expr.push_back(begin(ind));
    end_expr.push_back(end(ind));
    strides_expr.push_back(strides(ind));
  }
  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, assume_inbound, name, tag);
}

inline Array<PrimExpr> StridedSliceOutputShape(
    const Array<PrimExpr>& ishape, const Array<Integer>& begin, const Array<Integer>& end,
    const Array<Integer>& strides, const Array<Integer>& axes, const std::string& slice_mode) {
  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
                                                           begin[0]->dtype, slice_mode);
  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
                                 begin_canonicalized, true);
}

inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
                                      const Array<Integer>& end, const Array<Integer>& strides,
                                      const Array<Integer>& axes, std::string slice_mode = "end",
                                      std::string name = "T_strided_slice_with_axes",
                                      std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  ICHECK(axes.size() <= src_tensor_dim);
  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());

  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);

  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
                                                  begin[0]->dtype, slice_mode);
  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
                                           slice_mode, begin_expr);

  return te::compute(
      out_shape,
      [&](const Array<tir::Var>& indices) {
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
        for (size_t i = 0; i < axes.size(); ++i) {
          auto stride = make_const(strides[i].dtype(), strides_vec[i]);
          PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
          real_indices.Set(axes[i].IntValue(), ind);
        }
        return x(real_indices);
      },
      name, tag);
}

inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
                            const Array<Integer>& strides, std::string slice_mode = "end",
                            std::string name = "T_strided_slice", std::string tag = kInjective) {
  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
  Array<Integer> axes;
  for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
  Array<Integer> begin_full(begin);
  Array<Integer> end_full(end);
  Array<Integer> strides_full(strides);

  DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64);
  const IntImm one = IntImm(index_dtype, 1);
  const IntImm zero = IntImm(index_dtype, 0);
  const IntImm max_range = Downcast<IntImm>(max_value(index_dtype));

  for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
    strides_full.push_back(one);
  }
  for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
    begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
  }
  for (size_t i = end.size(); i < src_tensor_dim; ++i) {
    end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
  }

  return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
                                 tag);
}
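
// Illustrative usage (editor's sketch; `x` is a hypothetical placeholder):
// begin/end/strides are padded with defaults for unspecified trailing axes:
//
//   te::Tensor x = te::placeholder({10}, DataType::Float(32), "x");
//   te::Tensor y = strided_slice(x, {1}, {8}, {3});  // indices 1, 4, 7 -> shape {3}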

inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
                                    std::string name = "T_split_sections",
                                    std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(x->shape.size());
  }
  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";

  auto src_axis_size = x->shape[axis];

  ICHECK_GT(num_sections, 0) << "Slice count must be > 0";

  if (auto node = src_axis_size.as<IntImmNode>()) {
    ICHECK_EQ(node->value % num_sections, 0)
        << "num_sections must be an integer factor of the size of axis " << axis << " ("
        << node->value << ")";
  }

  Array<PrimExpr> split_indices;
  auto seg_size = indexdiv(src_axis_size, num_sections);
  for (int i = 0; i < num_sections; ++i) {
    // region at index 0 is added by split()
    if (i != 0) {
      split_indices.push_back(seg_size * i);
    }
  }

  return split(x, split_indices, axis, name, tag);
}

inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
                   std::string mode = "clip", std::string name = "T_take",
                   std::string tag = kInjective) {
  Array<PrimExpr> a_shape = a->shape;
  Array<PrimExpr> out_shape = indices->shape;
  PrimExpr a_size = 1;
  for (size_t i = 0; i < a_shape.size(); ++i) {
    a_size = a_size * a_shape[i];
  }

  if (mode == "clip") {
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
          return a(UnravelIndex(idx, a_shape));
        },
        name, tag);
  } else if (mode == "fast") {
    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                    "Make sure the input indices are in bounds.";
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
        name, tag);
  } else {  // mode == "wrap"
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
          return a(UnravelIndex(idx, a_shape));
        },
        name, tag);
  }
}
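
// Worked example (editor's note): `a` is addressed as a flat buffer, so for a of
// shape {2, 3} (a_size = 6) an index value of 7 in "wrap" mode maps to
// truncmod(truncmod(7, 6) + 6, 6) = 1, i.e. element a(0, 1); in "clip" mode the
// same index clamps to 5, i.e. a(1, 2).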

inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value,
                            int axis, std::string name = "T_sequence_mask",
                            std::string tag = kInjective) {
  ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
  ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
  auto length_dim = data->shape[axis];
  auto batch_dim = data->shape[1 - axis];
  Array<PrimExpr> out_shape = data->shape;
  Tensor out = compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        Array<PrimExpr> len_index;
        auto tid = out_index[axis];
        auto bid = out_index[1 - axis];
        len_index.push_back(bid);
        PrimExpr ret =
            tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
                              tvm::tir::make_const(data->dtype, mask_value), data(out_index));
        return ret;
      },
      name, tag);
  return out;
}

inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch_dims, int axis,
                   std::string mode = "clip", std::string name = "T_take",
                   std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(a->shape.size());
  }
  ICHECK_GE(axis, 0) << "axis out of bounds";
  ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
  auto axis_dim = a->shape[axis];
  auto indices_shape = [&]() -> Array<PrimExpr> {
    if (auto tensor = indices.as<TensorNode>()) {
      return tensor->shape;
    } else {
      return {};
    }
  }();

  int indices_len = static_cast<int>(indices_shape.size());

  int batch_dims_ = batch_dims;
  if (batch_dims_ != 0) {
    ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds";

    if (batch_dims_ < 0) {
      batch_dims_ = indices_len + batch_dims_;
    }

    ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
    for (int i = 0; i < batch_dims_; ++i) {
      auto v1 = a->shape[i].as<IntImmNode>()->value;
      auto v2 = indices_shape[i].as<IntImmNode>()->value;
      ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
    }
  }

  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
  // a.shape[axis + 1:].

  Array<PrimExpr> out_shape;
  for (int i = 0; i < batch_dims_; ++i) {
    out_shape.push_back(a->shape[i]);
  }
  for (int i = batch_dims_; i < axis; ++i) {
    out_shape.push_back(a->shape[i]);
  }
  for (int i = batch_dims_; i < indices_len; ++i) {
    out_shape.push_back(indices_shape[i]);
  }
  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
    out_shape.push_back(a->shape[i]);
  }

  auto get_index = [&](const Array<PrimExpr>& indices_position) -> PrimExpr {
    if (auto tensor = indices.as<Tensor>()) {
      return tensor.value()(indices_position);
    } else if (auto prim = indices.as<PrimExpr>()) {
      ICHECK_EQ(indices_position.size(), 0);
      return prim.value();
    } else {
      LOG(FATAL) << "Variant did not contain either allowed type";
    }
  };

  if (mode == "clip") {
    if (batch_dims_ == 0) {
      return compute(
          out_shape,
          [&](const Array<Var>& out_index) {
            Array<PrimExpr> indices_position;
            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
              indices_position.push_back(out_index[j]);
            }
            Array<PrimExpr> real_indices;
            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
              real_indices.push_back(out_index[j]);
            }
            auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
            real_indices.push_back(idx);
            for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
              real_indices.push_back(out_index[j]);
            }
            return a(real_indices);
          },
          name, tag);
    } else {
      return compute(
          out_shape,
          [&](const Array<Var>& out_index) {
            Array<PrimExpr> indices_position;
            for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
              indices_position.push_back(out_index[j]);
            }
            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
              indices_position.push_back(out_index[j]);
            }
            Array<PrimExpr> real_indices;
            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
              real_indices.push_back(out_index[j]);
            }
            auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
            real_indices.push_back(idx);
            for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
              real_indices.push_back(out_index[j]);
            }
            return a(real_indices);
          },
          name, tag);
    }
  } else if (mode == "fast") {
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          Array<PrimExpr> indices_position;
          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
            indices_position.push_back(out_index[j]);
          }
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            real_indices.push_back(out_index[j]);
          }
          real_indices.push_back(get_index(indices_position));
          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
            real_indices.push_back(out_index[j]);
          }
          return a(real_indices);
        },
        name, tag);
  } else {  // mode == "wrap"
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          Array<PrimExpr> indices_position;
          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
            indices_position.push_back(out_index[j]);
          }
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            real_indices.push_back(out_index[j]);
          }
          auto idx = truncmod(truncmod(get_index(indices_position), axis_dim) + axis_dim, axis_dim);
          real_indices.push_back(idx);
          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
            real_indices.push_back(out_index[j]);
          }
          return a(real_indices);
        },
        name, tag);
  }
}

inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
                    std::string name = "T_where", std::string tag = kBroadcast) {
  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
                                << y->dtype;
  auto get_out_shape = [&]() {
    auto bh1 = detail::BroadcastShape(x->shape, y->shape);
    Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
    auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
    Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
    return common_shape2;
  };

  auto oshape = get_out_shape();

  auto c_bh = detail::BroadcastShape(condition->shape, oshape);
  auto x_bh = detail::BroadcastShape(x->shape, oshape);
  auto y_bh = detail::BroadcastShape(y->shape, oshape);

  auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
    auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
    auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
    auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
    return tvm::tir::Select(c != 0, true_val, false_val);
  };

  return compute(oshape, select, name, tag);
}

inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat",
                     std::string tag = kBroadcast) {
  int ndim = static_cast<int>(x->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;
  ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
                       << ", but got repeats = " << repeats;
  if (axis < 0) {
    // Calculate offset from last dimension
    axis += ndim;
  }
  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
    new_shape.push_back(x->shape[i]);
  }
  new_shape.push_back(repeats * x->shape[axis]);
  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
    new_shape.push_back(x->shape[i]);
  }

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> idx;
        for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
          idx.push_back(indices[i]);
        }
        idx.push_back(indexdiv(indices[axis], repeats));
        for (size_t i = axis + 1; i < indices.size(); ++i) {
          idx.push_back(indices[i]);
        }
        return x(idx);
      },
      name, tag);
}
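
// Illustrative usage (editor's sketch; `x` is a hypothetical placeholder):
//
//   te::Tensor x = te::placeholder({2, 3}, DataType::Float(32), "x");
//   te::Tensor y = repeat(x, /*repeats=*/2, /*axis=*/1);
//   // y->shape == {2, 6} and y(i, j) == x(i, j / 2): each column appears twice.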

inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_tile",
                   std::string tag = kBroadcast) {
  size_t ndim = x->shape.size();
  size_t rdim = reps.size();
  size_t tdim = (ndim > rdim) ? ndim : rdim;
  Array<PrimExpr> data_shape;
  Array<PrimExpr> reps_shape;
  Array<PrimExpr> new_shape;
  if (ndim == rdim) {
    for (size_t i = 0; i < ndim; ++i) {
      data_shape.push_back(x->shape[i]);
      reps_shape.push_back(reps[i]);
    }
  } else if (ndim > rdim) {
    for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
    for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
    for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
  } else {
    for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
    for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
    for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
  }
  for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);

  if (is_empty_shape(new_shape)) {
    return compute(
        new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  } else {
    return compute(
        new_shape,
        [&](const Array<Var>& indices) {
          Array<PrimExpr> idx;
          if (ndim >= rdim) {
            for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
          } else {
            for (size_t i = 0; i < ndim; ++i)
              idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
          }
          return x(idx);
        },
        name, tag);
  }
}
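
// Illustrative usage (editor's sketch; `x` is a hypothetical placeholder):
//
//   te::Tensor x = te::placeholder({2, 3}, DataType::Float(32), "x");
//   te::Tensor y = tile(x, {2, 2});
//   // y->shape == {4, 6} and y(i, j) == x(i % 2, j % 3)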

inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
                       std::string name = "T_tile", std::string tag = kBroadcast) {
  size_t ndim = x->shape.size();
  if (is_empty_shape(new_shape)) {
    return compute(
        new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  } else {
    return compute(
        new_shape,
        [&](const Array<Var>& indices) {
          Array<PrimExpr> idx;
          if (ndim >= rdim) {
            for (size_t i = 0; i < ndim; ++i) {
              idx.push_back(indexmod(indices[i], x->shape[i]));
            }
          } else {
            for (size_t i = 0; i < ndim; ++i) {
              idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
            }
          }
          return x(idx);
        },
        name, tag);
  }
}

inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
                     std::string name = "T_gather", std::string tag = kInjective) {
  size_t ndim_d = data->shape.size();
  size_t ndim_i = indices->shape.size();
  ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
  ICHECK_EQ(ndim_d, ndim_i);
  if (axis < 0) {
    axis += ndim_d;
  }
  ICHECK_GE(axis, 0);
  ICHECK_LT(axis, ndim_d);
  if (indices->shape[axis].as<IntImmNode>()) {
    size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
    ICHECK_GE(indices_dim_i, 1);
  }
  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());

  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < ndim_i; ++i) {
    out_shape.push_back(indices->shape[i]);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        Array<PrimExpr> indices_position;
        for (size_t i = 0; i < ndim_i; ++i) {
          indices_position.push_back(out_index[i]);
        }
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < ndim_i; ++i) {
          if (i == static_cast<size_t>(axis)) {
            real_indices.push_back(indices(indices_position));
          } else {
            real_indices.push_back(indices_position[i]);
          }
        }
        return data(real_indices);
      },
      name, tag);
}
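
// Illustrative usage (editor's sketch; `data` and `idx` are hypothetical
// placeholders): indices must have the same rank as data:
//
//   te::Tensor data = te::placeholder({3, 4}, DataType::Float(32), "data");
//   te::Tensor idx = te::placeholder({3, 2}, DataType::Int(32), "idx");
//   te::Tensor out = gather(data, /*axis=*/1, idx);
//   // out->shape == {3, 2} and out(i, j) == data(i, idx(i, j))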

inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
                        std::string name = "T_gather_nd", std::string tag = kInjective) {
  size_t ndim_d = data->shape.size();
  size_t ndim_i = indices->shape.size();
  ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimension";
  size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
  ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of the indices tensor must be no greater "
                                  << "than the rank of the data tensor";
  Array<PrimExpr> out_shape;
  for (size_t i = 1; i < ndim_i; ++i) {
    out_shape.push_back(indices->shape[i]);
  }
  for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
    out_shape.push_back(data->shape[i]);
  }
  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        Array<PrimExpr> indices_position;
        indices_position.push_back(0);
        for (size_t i = 0; i < ndim_i - 1; ++i) {
          indices_position.push_back(out_index[i]);
        }
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
          real_indices.push_back(out_index[i]);
        }
        for (size_t i = 0; i < indices_dim0; ++i) {
          indices_position.Set(0, make_const(DataType::Int(32), i));
          if (indices->dtype.is_int() || indices->dtype.is_uint()) {
            real_indices.push_back(indices(indices_position));
          } else {
            real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
          }
        }
        if (real_indices.size() == ndim_d) {
          return data(real_indices);
        }
        for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
          real_indices.push_back(out_index[i]);
        }
        return data(real_indices);
      },
      name, tag);
}
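
// Worked example (editor's illustration): with data of shape {4, 5, 6}, indices of
// shape {2, 3}, and batch_dims = 0, each of the 3 lookups consumes 2 index
// components, so out->shape == {3, 6} and
// out(j, k) == data(indices(0, j), indices(1, j), k).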

inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B,
                              bool trans_a = false, bool trans_b = false,
                              std::string name = "T_matmul", std::string tag = kMatMul) {
  tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
  auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
  auto l = [&](tvm::tir::Var i, tvm::tir::Var j) {
    return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
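
// Illustrative usage (editor's sketch; `A` and `B` are hypothetical placeholders):
//
//   te::Tensor A = te::placeholder({32, 64}, DataType::Float(32), "A");
//   te::Tensor B = te::placeholder({64, 16}, DataType::Float(32), "B");
//   te::Tensor C = matmul(A, B);  // C->shape == {32, 16}, C[i][j] = sum_k A[i][k] * B[k][j]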

inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2,
                        std::string name = "T_tensordot", std::string tag = kMatMul) {
  ICHECK_GE(A->shape.size(), axes);
  ICHECK_GE(B->shape.size(), axes);

  Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
  for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);

  Array<IterVar> iter_vars;
  for (int i = 0; i < axes; ++i)
    iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));

  auto func = [&A, &B, &iter_vars, axes](const Array<Var>& input_indices) {
    Array<PrimExpr> A_indices(input_indices.begin(),
                              input_indices.begin() + (A->shape.size() - axes));
    for (auto& v : iter_vars) A_indices.push_back(v);

    Array<PrimExpr> B_indices;
    for (auto& v : iter_vars) B_indices.push_back(v);

    auto it = input_indices.begin() + (A->shape.size() - axes);
    for (; it != input_indices.end(); ++it) B_indices.push_back(*it);

    // Some passes don't like reductions with empty axis, so avoid it here
    if (iter_vars.empty()) {
      return A(A_indices) * B(B_indices);
    } else {
      return sum(A(A_indices) * B(B_indices), iter_vars);
    }
  };

  return compute(output_shape, func, name, tag);
}

inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExpr> A_axes,
                        Array<PrimExpr> B_axes, std::string name = "T_tensordot",
                        std::string tag = kMatMul) {
  ICHECK_EQ(A_axes.size(), B_axes.size());

  auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
  auto B_axes_val = GetConstIntValues(B_axes, "B_axes");

  Array<PrimExpr> output_shape;
  for (unsigned i = 0; i < A->shape.size(); ++i)
    if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
      output_shape.push_back(A->shape[i]);
  for (unsigned i = 0; i < B->shape.size(); ++i)
    if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
      output_shape.push_back(B->shape[i]);

  Array<IterVar> iter_vars;
  for (unsigned i = 0; i < B_axes_val.size(); ++i)
    iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));

  auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array<Var>& input_indices) {
    int idx_input = 0;
    Array<PrimExpr> A_indices;
    for (unsigned i = 0; i < A->shape.size(); ++i) {
      auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
      if (axes_pos == A_axes_val.end()) {
        A_indices.push_back(input_indices[idx_input++]);
      } else {
        A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
      }
    }

    Array<PrimExpr> B_indices;
    for (unsigned i = 0; i < B->shape.size(); ++i) {
      auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
      if (axes_pos == B_axes_val.end()) {
        B_indices.push_back(input_indices[idx_input++]);
      } else {
        B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
      }
    }
    return sum(A(A_indices) * B(B_indices), iter_vars);
  };
  return compute(output_shape, func, name, tag);
}

inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
                     DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
  arith::Analyzer analyzer;
  PrimExpr num_elem;
  bool is_all_int = start.dtype().is_int() && stop.dtype().is_int() && step.dtype().is_int();
  if (is_all_int && analyzer.CanProveGreaterEqual(step, 1)) {
    // fast path for integer arange when step is positive
    num_elem = tvm::floordiv((stop - start + step - 1), step);
  } else if (is_all_int && analyzer.CanProveLess(step, 0)) {
    // fast path for integer arange when step is negative
    num_elem = tvm::floordiv((start - stop - step - 1), -step);
  } else {
    // fallback path for non-integer or step of unknown sign
    num_elem = tvm::cast(DefaultIndexType(),
                         tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
  }
  num_elem = analyzer.Simplify(num_elem);

  return compute(
      {num_elem},
      [&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name,
      tag);
}
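
// Worked example (editor's note): arange(0, 10, 3, DataType::Int(32)) takes the
// positive-step integer fast path, so num_elem = floordiv(10 - 0 + 3 - 1, 3) = 4
// and the result is [0, 3, 6, 9].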

inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& indexing,
                              std::string name = "T_meshgrid", std::string tag = kInjective) {
  const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2;
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
    out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]);
  }
  Array<Tensor> result;
  for (size_t i = 0; i < inputs.size(); ++i) {
    result.push_back(compute(
        out_shape,
        [&](const Array<Var>& indices) {
          const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
          auto ndim = inputs[i]->GetShape().size();
          Array<PrimExpr> real_indices = {};
          if (ndim > 0) {
            real_indices = {indices[src_index]};
          }
          return inputs[i](real_indices);
        },
        name, tag));
  }
  return result;
}

inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
                               const std::string& dst_layout,
                               const std::string schedule_rule = "None",
                               const std::string name = "T_layout_trans",
                               const std::string tag = kInjective) {
  Layout src_layout_struct(src_layout);
  Layout dst_layout_struct(dst_layout);

  if (src_layout_struct.Equals(dst_layout_struct)) {
    return src;
  }

  ICHECK(src_layout_struct.defined() && dst_layout_struct.defined())
      << "cannot convert from/to undefined layout";

  auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
  ICHECK(layout_converter.defined())
      << "cannot convert from " << src_layout << " to " << dst_layout;

  Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);

  Map<String, ObjectRef> attrs = {{"schedule_rule", String(schedule_rule)},
                                  // Information about layouts needed for the schedule rule
                                  {"src_layout", String(src_layout)},
                                  {"dst_layout", String(dst_layout)},
                                  {"input_shape", src->shape}};

  return compute(
      dst_shape,
      [&](const Array<Var>& dst_indices) {
        Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
        Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
        PrimExpr in_range = PrimExpr(1) > PrimExpr(0);  // init with dtype=bool and value=true
        for (size_t i = 0; i < src.ndim(); ++i) {
          in_range = in_range && (src_indices[i] < src->shape[i]);
        }
        return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
      },
      name, tag, attrs);
}
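
// Illustrative usage (editor's sketch; `x` is a hypothetical placeholder):
// converting NCHW to NCHW4c splits the channel axis into 8 / 4 = 2 chunks of 4:
//
//   te::Tensor x = te::placeholder({1, 8, 4, 4}, DataType::Float(32), "x");
//   te::Tensor y = layout_transform(x, "NCHW", "NCHW4c");  // y->shape == {1, 2, 4, 4, 4}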

inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape,
                                        std::vector<std::string>* axes) {
  int32_t factor = 0;
  std::string axis = "";
  for (char c : std::string(layout)) {
    if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
      axis += c;
      if (factor != 0) {
        shape->push_back(factor);
        factor = 0;
      }
    } else if (c >= '0' && c <= '9') {
      factor = factor * 10 + c - '0';
      if (!axis.empty()) {
        axes->push_back(axis);
        axis = "";
      }
    } else {
      LOG(FATAL) << "Invalid layout " << layout;
    }
  }
  if (!axis.empty()) {
    axes->push_back(axis);
  }
}
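
// Worked example (editor's note): the auto-scheduler layout string interleaves an
// integer factor with each axis name, so parsing "2n4c" yields
// shape = [2, 4] and axes = ["n", "c"].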

inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout,
                                              const String& dst_layout,
                                              const String name = "T_auto_scheduler_layout_trans",
                                              const String tag = kInjective) {
  Array<PrimExpr> src_shape;
  std::vector<std::string> src_axes;
  Array<PrimExpr> dst_shape;
  std::vector<std::string> dst_axes;

  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
  return compute(
      dst_shape,
      [&](const Array<Var>& dst_indices) {
        Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
        Array<PrimExpr> src_indices;
        for (const std::string& src_axis : src_axes) {
          PrimExpr src_index = 0;
          CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
          for (size_t i = 0; i < dst_axes.size(); ++i) {
            if (dst_axes[i] == src_axis) {
              src_index = src_index * dst_shape[i] + dst_indices_expr[i];
            }
          }
          src_indices.push_back(src_index);
        }
        return src(src_indices);
      },
      name, tag);
}

inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map,
                                             const String name = "T_meta_schedule_layout_trans",
                                             const String tag = kInjective) {
  arith::Analyzer analyzer;
  Array<Range> iter_domain;
  iter_domain.reserve(src->shape.size());
  for (const PrimExpr& e : src->shape) {
    iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
  }
  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape, &analyzer);
  return compute(
      post_transform_shape,
      [src, inv = index_map.Inverse(iter_domain, &analyzer),
       &analyzer](const Array<Var>& indices) -> PrimExpr {
        return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}, &analyzer));
      },
      name, tag);
}

inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape",
                    const std::string tag = kInjective) {
  int ndim = static_cast<int>(src->shape.size());
  Array<PrimExpr> out_shape{ndim};
  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        auto idx = indices[0];
        PrimExpr ret = 0;
        for (int i = 0; i < ndim; ++i) {
          ret = tvm::if_then_else(idx == i, src->shape[i], ret);
        }
        return tvm::cast(dtype, ret);
      },
      name, tag);
}

inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
                           const std::string& name = "ndarray_size",
                           const std::string& tag = kInjective) {
  int ndim = static_cast<int>(src->shape.size());
  Array<PrimExpr> out_ndarray_size = {};
  return compute(
      out_ndarray_size,
      [&](const Array<Var>& indices) {
        PrimExpr ret = 1;
        for (int i = 0; i < ndim; ++i) {
          ret *= src->shape[i];
        }
        return tvm::cast(dtype, ret);
      },
      name, tag);
}

inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
                      int depth, int axis, const DataType& dtype,
                      Array<PrimExpr> oshape = Array<PrimExpr>(),
                      const std::string name = "T_one_hot", const std::string tag = kInjective) {
  int true_axis = (axis == -1) ? indices->shape.size() : axis;
  if (oshape.size() == 0) {
    int ndim = indices->shape.size() + 1;
    int indices_index = 0;
    for (int i = 0; i < ndim; i++) {
      if (i == true_axis) {
        oshape.push_back(Integer(depth));
      } else {
        oshape.push_back(indices->shape[indices_index++]);
      }
    }
  }

  PrimExpr on_value_cast = cast(dtype, on_value);
  PrimExpr off_value_cast = cast(dtype, off_value);
  return compute(
      oshape,
      [&](const Array<Var>& iter_vars) {
        Array<Var> indices_indices;
        for (size_t i = 0; i < iter_vars.size(); i++) {
          if (static_cast<int>(i) == true_axis) {
            continue;
          }

          indices_indices.push_back(iter_vars[i]);
        }

        auto idx = iter_vars[true_axis];
        return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast);
      },
      name, tag);
}
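
// Worked example (editor's illustration): indices = [0, 2] with depth = 3,
// axis = -1, on_value = 1, and off_value = 0 produces a {2, 3} tensor:
//   [[1, 0, 0],
//    [0, 0, 1]]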
2005 
2016 inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr>& output_shape,
2017  const Tensor& sparse_values, const PrimExpr& default_value,
2018  const std::string name = "T_sparse_to_dense",
2019  const std::string tag = kInjective) {
2020  ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values";
2021  ICHECK_LE(sparse_indices->shape.size(), 3)
2022  << "sparse_indices tensor should be 0D, 1D, or 2D only";
2023  ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only";
2024 
2025  const auto rank_sparse_indices = static_cast<int>(sparse_indices->shape.size());
2026  Array<PrimExpr> oshape;
2027  for (auto l : output_shape) {
2028  oshape.push_back(l);
2029  }
2030  return compute(
2031  oshape,
2032  [&](const Array<Var>& indices) {
2033  PrimExpr ret = default_value;
2034  if (0 == rank_sparse_indices) {
2035  ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret);
2036  } else if (1 == rank_sparse_indices) {
2037  for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
2038  ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
2039  }
2040  } else {
2041  for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
2042  PrimExpr aggregate_condition;
2043  for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
2044  PrimExpr comparison = indices[k] == sparse_indices[j][k];
2045  aggregate_condition = 0 == k ? comparison : aggregate_condition && comparison;
2046  }
2047  ret = if_then_else(aggregate_condition, sparse_values[j], ret);
2048  }
2049  }
2050  return ret;
2051  },
2052  name, tag);
2053 }
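A minimal usage sketch of sparse_to_dense(): scatter two values into a dense length-5 vector, with the default value everywhere else. The shape of sparse_indices must be static, since the loop bounds above come from GetConstInt. Names and shapes are illustrative:

te::Tensor sp_idx = te::placeholder({2}, DataType::Int(32), "sparse_indices");
te::Tensor sp_val = te::placeholder({2}, DataType::Float(32), "sparse_values");
te::Tensor dense = sparse_to_dense(sp_idx, {5}, sp_val, make_const(DataType::Float(32), 0.0));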
2054 
2067 inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
2068  bool super_diag_right_align, bool sub_diag_right_align,
2069  const std::string name = "T_matrix_set_diag",
2070  const std::string tag = kInjective) {
2071  size_t ndim = input->shape.size() - 1;  // index of the last (column) axis
2072 
2073  bool only_one_diagonal = k1 == k2;
2074 
2075  return compute(
2076  input->shape,
2077  [&](const Array<Var>& iter_vars) {
2078  auto get_diag = [&]() {
2079  Array<PrimExpr> diagonal_indices;
2080  PrimExpr k, offset = 0;
2081  for (size_t i = 0; i < ndim - 1; i++) {
2082  diagonal_indices.push_back(iter_vars[i]);
2083  }
2084  if (only_one_diagonal) {
2085  k = k1;
2086  } else {
2087  // Determine which diagonal (main, super-, or sub-) this element lies on.
2088  k = iter_vars[ndim] - iter_vars[ndim - 1];
2089  diagonal_indices.push_back(k2 - k);
2090 
2091  // Compute the offset into the diagonal tensor for this diagonal.
2092  auto get_offset = [&](PrimExpr M, PrimExpr N) {
2093  // offset = max_diagonal_length - diagonal_length
2094  return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
2095  };
2096  offset = if_then_else(
2097  k >= 0,
2098  super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
2099  : 0,
2100  sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
2101  : 0);
2102  }
2103  diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
2104  offset);
2105  return diagonal(diagonal_indices);
2106  };
2107  return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
2108  if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
2109  get_diag(), input(iter_vars)),
2110  input(iter_vars));
2111  },
2112  name, tag);
2113 }
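A minimal usage sketch of matrix_set_diag(): replacing the main diagonal (k1 == k2 == 0) of a square matrix; with a single diagonal the alignment flags have no effect. Shapes are illustrative:

te::Tensor m = te::placeholder({4, 4}, DataType::Float(32), "input");
te::Tensor d = te::placeholder({4}, DataType::Float(32), "diagonal");
te::Tensor out = matrix_set_diag(m, d, /*k1=*/0, /*k2=*/0,
                                 /*super_diag_right_align=*/true,
                                 /*sub_diag_right_align=*/false);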
2114 
2123 inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
2124  const std::string name = "advanced_index",
2125  const std::string tag = kInjective) {
2126  ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!";
2127  Array<PrimExpr> oshape;
2128  Array<PrimExpr> broadcast_shape;
2129  Array<Tensor> bindices;
2130 
2131  broadcast_shape = indices[0]->shape;
2132  for (size_t i = 1; i < indices.size(); ++i) {
2133  auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape);
2134  broadcast_shape = Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
2135  }
2136  if (indices.size() == 1) {
2137  // Fast path: a single index tensor needs no broadcasting.
2138  bindices = indices;
2139  } else {
2140  // Broadcast every index tensor to the common shape.
2141  for (size_t i = 0; i < indices.size(); ++i) {
2142  bindices.push_back(broadcast_to(indices[i], broadcast_shape));
2143  }
2144  }
2145 
2146  for (const auto& dim : broadcast_shape) {
2147  oshape.push_back(dim);
2148  }
2149  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
2150  oshape.push_back(data->shape[i]);
2151  }
2152 
2153  return compute(
2154  oshape,
2155  [&](const Array<Var>& iter_var) {
2156  Array<PrimExpr> tensor_indices;
2157  for (size_t i = 0; i < broadcast_shape.size(); ++i) {
2158  tensor_indices.push_back(iter_var[i]);
2159  }
2160  Array<PrimExpr> real_indices;
2161  for (size_t i = 0; i < bindices.size(); ++i) {
2162  real_indices.push_back(bindices[i](tensor_indices));
2163  }
2164  for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
2165  real_indices.push_back(iter_var[i]);
2166  }
2167 
2168  return data(real_indices);
2169  },
2170  name, tag);
2171 }
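A minimal usage sketch of adv_index(), mirroring NumPy's data[rows, cols]: two (3,) index tensors broadcast to a common shape, so the result has shape (3,). All names and shapes are illustrative:

te::Tensor data = te::placeholder({4, 5}, DataType::Float(32), "data");
te::Tensor rows = te::placeholder({3}, DataType::Int(32), "rows");
te::Tensor cols = te::placeholder({3}, DataType::Int(32), "cols");
te::Tensor out = adv_index(data, {rows, cols});  // shape (3,)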
2172 
2173 namespace relax {
2174 // Relax dynamic slice: begin/end/strides are 1-D runtime tensors of length x.ndim().
2175 inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
2176  const te::Tensor& end, const te::Tensor& strides,
2177  Array<PrimExpr> output_shape,
2178  std::string name = "T_strided_slice_dynamic",
2179  std::string tag = kInjective) {
2180  const size_t num_dynamic_axes = x.ndim();
2181  ICHECK_EQ(begin.ndim(), 1);
2182  ICHECK_EQ(end.ndim(), 1);
2183  ICHECK_EQ(strides.ndim(), 1);
2184  const auto* len_begin = begin->shape[0].as<IntImmNode>();
2185  const auto* len_end = end->shape[0].as<IntImmNode>();
2186  const auto* len_strides = strides->shape[0].as<IntImmNode>();
2187  ICHECK(len_begin);
2188  ICHECK(len_end);
2189  ICHECK(len_strides);
2190  ICHECK_EQ(len_begin->value, num_dynamic_axes);
2191  ICHECK_EQ(len_end->value, num_dynamic_axes);
2192  ICHECK_EQ(len_strides->value, num_dynamic_axes);
2193 
2194  return te::compute(
2195  output_shape,
2196  [&](const Array<tvm::tir::Var>& indices) {
2197  Array<PrimExpr> real_indices;
2198  for (size_t i = 0; i < num_dynamic_axes; ++i) {
2199  auto ind = make_const(DataType::Int(64), i);
2200  real_indices.push_back(indices[i] * strides(ind) + tvm::min(begin(ind), x->shape[i] - 1));
2201  }
2202  return x(real_indices);
2203  },
2204  name, tag);
2205 }
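A minimal usage sketch of the Relax dynamic_strided_slice(): begin/end/strides are runtime tensors whose static length must equal x.ndim(), and the caller supplies the (here symbolic) output shape. Names and shapes are illustrative:

te::Tensor x = te::placeholder({8, 8}, DataType::Float(32), "x");
te::Tensor b = te::placeholder({2}, DataType::Int(64), "begin");
te::Tensor e = te::placeholder({2}, DataType::Int(64), "end");
te::Tensor s = te::placeholder({2}, DataType::Int(64), "strides");
Array<PrimExpr> oshape = {te::var("m", DataType::Int(64)), te::var("n", DataType::Int(64))};
te::Tensor y = dynamic_strided_slice(x, b, e, s, oshape);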
2206 
2207 } // namespace relax
2208 
2209 } // namespace topi
2210 } // namespace tvm
2211 #endif // TVM_TOPI_TRANSFORM_H_