tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_TRANSFORM_H_
25 #define TVM_TOPI_TRANSFORM_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/te/operation.h>
29 #include <tvm/tir/data_layout.h>
30 #include <tvm/tir/index_map.h>
31 #include <tvm/topi/broadcast.h>
37 #include <tvm/topi/tags.h>
38 
39 #include <algorithm>
40 #include <iterator>
41 #include <limits>
42 #include <string>
43 #include <unordered_set>
44 #include <utility>
45 #include <vector>
46 
47 #include "tvm/ir/expr.h"
48 #include "tvm/runtime/data_type.h"
49 #include "tvm/tir/expr.h"
50 #include "tvm/tir/op.h"
51 #include "tvm/tir/var.h"
52 
53 namespace tvm {
54 namespace topi {
55 
56 using namespace tvm::te;
57 using namespace topi::detail;
58 
76 inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array<Integer> window_shape,
77  ffi::Array<Integer> strides, std::string name = "T_sliding_window",
78  std::string tag = "") {
79  CHECK_GE(axis, 0);
80  auto _axis = size_t(axis);
81  CHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x.";
82  CHECK_EQ(x->shape.size() - _axis, window_shape.size())
83  << "There must be a window shape for every dimension of x "
84  << "over which we are sliding the window.";
85  CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length.";
86 
87  // Compute the new shape.
88  ffi::Array<PrimExpr> new_shape;
89  // Dimensions up until `axis` remain the same.
90  for (size_t i = 0; i < _axis; ++i) {
91  new_shape.push_back(x->shape[i]);
92  }
93 
94  // New dimensions which result from sliding the window in each dimension. One new dimension per
95  // window dimension.
96  for (size_t i = 0; i < window_shape.size(); ++i) {
97  // Length of the shape along this dimension.
98  auto dim_len = x->shape[_axis + i];
99  // Length of the window along this dimension.
100  auto window_len = window_shape[i];
101  // Strides along this dimension.
102  auto stride = strides[i];
103 
104  new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride));
105  }
106 
107  // Dimensions comprising the window.
108  for (size_t i = 0; i < window_shape.size(); ++i) {
109  new_shape.push_back(window_shape[i]);
110  }
111 
112  ICHECK(new_shape.size() == _axis + 2 * window_shape.size());
113 
114  return compute(
115  new_shape,
116  [&](const ffi::Array<Var>& indices) {
117  // The index at which to index the old tensor x.
118  ffi::Array<PrimExpr> idx;
119 
120  // Dimensions up until `axis` remain the same.
121  for (size_t i = 0; i < _axis; ++i) {
122  idx.push_back(indices[i]);
123  }
124 
125  for (size_t i = 0; i < window_shape.size(); ++i) {
126  // Which window in this dimension we are indexing.
127  auto window_idx = indices[_axis + i];
128  // Which index within the window we are indexing.
129  auto idx_within_window = indices[_axis + window_shape.size() + i];
130  // Stride value for this dimension.
131  auto stride = strides[i];
132 
133  idx.push_back(window_idx * stride + idx_within_window);
134  }
135 
136  ICHECK(idx.size() == x->shape.size());
137 
138  return x(idx);
139  },
140  name, tag);
141 }
142 
155 inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1,
156  std::string name = "T_expand_dims", std::string tag = kBroadcast) {
157  int ndim = static_cast<int>(x->shape.size());
158  ICHECK(-ndim - 1 <= axis && axis <= ndim)
159  << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
160  << ", but got axis = " << axis << ", and data.ndim = " << ndim;
161  ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
162  << ", but got num_newaxis = " << num_newaxis;
163  if (axis < 0) {
164  // Calculate offset from last dimension
165  axis = ndim + axis + 1;
166  }
167  ffi::Array<PrimExpr> new_shape;
168  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
169  new_shape.push_back(x->shape[i]);
170  }
171  for (size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) {
172  new_shape.push_back(1);
173  }
174  for (size_t i = axis; i < x->shape.size(); ++i) {
175  new_shape.push_back(x->shape[i]);
176  }
177 
178  return compute(
179  new_shape,
180  [&](const ffi::Array<Var>& indices) {
181  ffi::Array<PrimExpr> idx;
182  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
183  idx.push_back(indices[i]);
184  }
185  for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
186  idx.push_back(indices[i]);
187  }
188  return x(idx);
189  },
190  name, tag);
191 }
192 
204 inline Tensor transpose(const Tensor& x, ffi::Optional<ffi::Array<Integer>> opt_axes,
205  std::string name = "T_transpose", std::string tag = kInjective) {
206  ffi::Array<Integer> axes = opt_axes.value_or({});
207  if (axes.size() == 0) {
208  for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
209  axes.push_back(i);
210  }
211  }
212 
213  ffi::Array<PrimExpr> new_shape;
214  for (size_t i = 0; i < axes.size(); ++i) {
215  int axis = static_cast<int>(axes[i]->value);
216  int new_axis = axis;
217  if (axis < 0) {
218  new_axis = static_cast<int>(x->shape.size()) + axis;
219  axes.Set(i, new_axis);
220  }
221  ICHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
222  << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size())
223  << "-dimensional input tensor";
224 
225  for (size_t j = 0; j < axes.size(); ++j) {
226  if (i != j) {
227  ICHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
228  }
229  }
230  new_shape.push_back(x->shape[new_axis]);
231  }
232 
233  return compute(
234  new_shape,
235  [&](const ffi::Array<Var>& indices) {
236  std::vector<PrimExpr> idx;
237  for (size_t i = 0; i < axes.size(); ++i) {
238  idx.push_back(1);
239  }
240  for (size_t i = 0; i < axes.size(); ++i) {
241  int axis = static_cast<int>(axes[i]->value);
242  idx[axis] = indices[i];
243  }
244  return x(idx);
245  },
246  name, tag);
247 }
248 
263 inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
264  int batch_axis = 0, std::string name = "T_reverse_sequence",
265  std::string tag = kInjective) {
266  size_t src_tensor_dim = x->shape.size();
267  int seq_axis_inp = seq_axis;
268 
269  if (seq_lengths.defined()) {
270  size_t seq_lengths_dim = seq_lengths->shape.size();
271  int batch_axis_inp = batch_axis;
272  if (batch_axis < 0) {
273  batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
274  }
275 
276  ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector";
277 
278  ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
279  << "For reverse_sequnece seq_lengths size should match with dimension of batch axis"
280  << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
281  << ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]);
282 
283  ICHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
284  << "batch_axis=" << batch_axis_inp << " is invalid for the "
285  << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
286  }
287 
288  if (seq_axis < 0) {
289  seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
290  }
291  ICHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
292  << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
293  << "-dimensional input tensor";
294 
295  auto func = [&](const ffi::Array<Var>& indices) {
296  ffi::Array<PrimExpr> real_indices;
297  for (size_t i = 0; i < src_tensor_dim; ++i) {
298  if (i == static_cast<size_t>(seq_axis)) {
299  if (seq_lengths.defined()) {
300  auto len = seq_lengths(indices[batch_axis]);
301  auto idx = if_then_else(
302  len <= 1 || len <= indices[i], indices[i],
303  if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
304  real_indices.push_back(idx);
305  } else {
306  real_indices.push_back(x->shape[i] - 1 - indices[i]);
307  }
308  } else {
309  real_indices.push_back(indices[i]);
310  }
311  }
312  return x(real_indices);
313  };
314 
315  return compute(x->shape, func, name, tag);
316 }
317 
328 inline Tensor reshape(const Tensor& x, ffi::Array<PrimExpr> newshape,
329  std::string name = "T_reshape", std::string tag = kInjective) {
330  auto x_shape = x->shape;
331  ffi::Array<PrimExpr> target_shape;
332 
333  for (const auto& ele : newshape) {
334  target_shape.push_back(ele);
335  }
336 
337  // If either the input shape or the target shape contains a zero, return an empty tensor.
338  if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) {
339  return compute(
340  target_shape, [&](const ffi::Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name,
341  tag);
342  } else {
343  return compute(
344  target_shape,
345  [&](const ffi::Array<Var>& indices) {
346  return x(UnravelIndex(
347  RavelIndex(ffi::Array<PrimExpr>{indices.begin(), indices.end()}, target_shape),
348  x_shape));
349  },
350  name, tag);
351  }
352 }
353 
365 inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
366  std::string tag = kInjective) {
367  auto x_shape = x->shape;
368  auto shape_shape = shape->shape;
369 
370  ffi::Array<PrimExpr> oshape;
371  oshape.push_back(shape_shape[0]);
372  if (x_shape.size() != 0) {
373  oshape.push_back(x_shape[0]);
374  }
375 
376  auto func = [&](const ffi::Array<Var>& indices) {
377  auto i = indices[0];
378  std::vector<PrimExpr> indices_divs;
379  PrimExpr ret = 0;
380  PrimExpr cur_val = 0;
381  PrimExpr index_val = 0;
382 
383  if (x_shape.size() != 0) {
384  index_val = x[indices[1]];
385  } else {
386  index_val = x();
387  }
388  indices_divs.push_back(index_val);
389  for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
390  ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
391  cur_val = indexdiv(indices_divs.back(), shape[v]);
392  indices_divs.push_back(cur_val);
393  }
394  return ret;
395  };
396 
397  return compute(oshape, func, name, tag);
398 }
399 
413 inline Tensor squeeze(const Tensor& x, ffi::Optional<ffi::Array<Integer>> opt_axes,
414  bool atleast1d = false, std::string name = "T_squeeze",
415  std::string tag = kInjective) {
416  auto ndim = x->shape.size();
417  std::vector<int> axis_val;
418  if (!opt_axes.has_value()) {
419  for (size_t i = 0; i < ndim; ++i) {
420  if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
421  axis_val.push_back(static_cast<int>(i));
422  }
423  }
424  } else {
425  ffi::Array<Integer> axis = *std::move(opt_axes);
426  for (size_t i = 0; i < axis.size(); ++i) {
427  int64_t val = axis[i]->value;
428  if (val < 0) {
429  val += static_cast<int>(x->shape.size());
430  }
431  // If a dimension is not 1, silently skip it (no-op).
432  bool is_const = IsConstInt(x->shape[val]);
433  if ((is_const && GetConstInt(x->shape[val]) == 1) || !is_const) {
434  axis_val.push_back(val);
435  }
436  }
437  }
438 
439  std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());
440 
441  ffi::Array<PrimExpr> out_shape;
442  for (size_t i = 0; i < ndim; ++i) {
443  if (axis_set.count(static_cast<int>(i)) == 0) {
444  out_shape.push_back(x->shape[i]);
445  }
446  }
447  if (out_shape.size() == 0 && atleast1d) {
448  out_shape.push_back(1);
449  }
450 
451  return compute(
452  out_shape,
453  [&](const ffi::Array<Var>& indices) {
454  ffi::Array<PrimExpr> real_indices;
455  int flag = 0;
456  for (size_t i = 0; i < ndim; ++i) {
457  if (axis_set.count(static_cast<int>(i)) == 0) {
458  real_indices.push_back(indices[i - flag]);
459  } else {
460  real_indices.push_back(0);
461  flag += 1;
462  }
463  }
464  return x(real_indices);
465  },
466  name, tag);
467 }
468 
479 inline Tensor concatenate(const ffi::Array<Tensor>& inputs, int axis = 0,
480  std::string name = "T_concat", std::string tag = kInjective) {
481  int ndim = static_cast<int>(inputs[0]->shape.size());
482  ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)"
483  << ", but got axis = " << axis << ", and ndim = " << ndim;
484  if (axis < 0) {
485  axis += ndim;
486  }
487  ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds";
488 
489  ffi::Array<PrimExpr> axis_sizes;
490  for (auto t : inputs) {
491  axis_sizes.push_back(t->shape[axis]);
492  }
493  arith::Analyzer analyzer;
494  PrimExpr join_size = axis_sizes[0];
495  for (size_t i = 1; i < axis_sizes.size(); ++i) {
496  join_size += axis_sizes[i];
497  }
498  join_size = analyzer.Simplify(join_size);
499  ffi::Array<PrimExpr> out_shape;
500  for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
501  out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
502  }
503 
504  return compute(
505  out_shape,
506  [&](const ffi::Array<Var>& indices) {
507  auto ret = inputs[0](indices);
508  auto ind = indices[axis];
509  for (size_t i = 0; i < inputs.size() - 1; ++i) {
510  ind -= axis_sizes[i];
511 
512  ffi::Array<PrimExpr> idx;
513  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
514  idx.push_back(indices[i]);
515  }
516  idx.push_back(ind);
517  for (size_t i = axis + 1; i < indices.size(); ++i) {
518  idx.push_back(indices[i]);
519  }
520 
521  ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
522  }
523  return ret;
524  },
525  name, tag);
526 }
527 
538 inline Tensor stack(const ffi::Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack",
539  std::string tag = kInjective) {
540  int ndim = static_cast<int>(inputs[0]->shape.size());
541  ICHECK(-ndim - 1 <= axis && axis <= ndim)
542  << "stack only accepts `axis` in [-ndim, ndim)"
543  << ", but got axis = " << axis << ", and ndim = " << ndim;
544  if (axis < 0) {
545  axis += ndim + 1;
546  }
547  ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";
548 
549  const int stack_size = static_cast<int>(inputs.size());
550  ffi::Array<PrimExpr> out_shape;
551  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
552  out_shape.push_back(stack_size);
553  for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
554  out_shape.push_back(inputs[0]->shape[i]);
555 
556  return compute(
557  out_shape,
558  [&](const ffi::Array<Var>& indices) {
559  ffi::Array<PrimExpr> idx;
560  for (size_t i = 0; i < indices.size(); ++i)
561  if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]);
562  auto ind = indices[axis];
563  auto ret = inputs[0](idx);
564  for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
565  ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
566  }
567  return ret;
568  },
569  name, tag);
570 }
571 
584 inline ffi::Array<Tensor> split_indices_array(const Tensor& x, ffi::Array<PrimExpr> split_indices,
585  int axis, std::string name = "T_split",
586  std::string tag = kInjective) {
587  if (axis < 0) {
588  axis += static_cast<int>(x->shape.size());
589  }
590  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
591 
592  auto src_axis_size = x->shape[axis];
593  std::vector<PrimExpr> begin_ids;
594  begin_ids.push_back(0);
595 
596  for (auto idx : split_indices) {
597  auto idx_node = idx.as<IntImmNode>();
598  auto back_node = begin_ids.back().as<IntImmNode>();
599  if (idx_node && back_node) {
600  ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
601  }
602  begin_ids.push_back(idx);
603  }
604 
605  ffi::Array<ffi::Array<PrimExpr>> out_shapes;
606  for (size_t i = 0; i < begin_ids.size(); ++i) {
607  PrimExpr out_axis_size;
608  if (i == begin_ids.size() - 1) {
609  out_axis_size = src_axis_size - begin_ids[i];
610  } else {
611  out_axis_size = begin_ids[i + 1] - begin_ids[i];
612  }
613 
614  ffi::Array<PrimExpr> shape;
615  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
616  shape.push_back(x->shape[i]);
617  }
618  shape.push_back(out_axis_size);
619  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
620  shape.push_back(x->shape[i]);
621  }
622 
623  out_shapes.push_back(shape);
624  }
625 
626  ffi::Array<Tensor> result;
627  for (size_t i = 0; i < begin_ids.size(); ++i) {
628  result.push_back(compute(
629  out_shapes[i],
630  [&](const ffi::Array<Var>& indices) {
631  auto begin = begin_ids[i];
632  ffi::Array<PrimExpr> real_indices;
633  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
634  real_indices.push_back(indices[j]);
635  }
636  real_indices.push_back(indices[axis] + begin);
637  for (size_t j = axis + 1; j < indices.size(); ++j) {
638  real_indices.push_back(indices[j]);
639  }
640 
641  return x(real_indices);
642  },
643  name, tag));
644  }
645 
646  return result;
647 }
648 
650  auto idx_var = index.as<tvm::tir::VarNode>();
651  auto extent_var = extent.as<tvm::tir::VarNode>();
652 
653  if (idx_var && extent_var && idx_var->name_hint == extent_var->name_hint) {
654  return index;
655  }
656 
657  PrimExpr begin_range = tvm::if_then_else(stride < 0, -1, 0);
658  PrimExpr end_range = tvm::if_then_else(stride < 0, extent - 1, extent);
659 
660  if (!(index->IsInstance<tvm::IntImmNode>() && GetConstInt(index) >= 0)) {
661  index = tvm::if_then_else(index < 0, index + extent, index);
662  }
663 
664  return tvm::min(tvm::max(index, begin_range), end_range);
665 }
666 
/*!
 * \brief Normalize a constant slice index against a dimension extent.
 *
 * A negative \p index is first shifted by \p extent. The result is then clamped to
 * [0, extent] when \p stride >= 0, or to [-1, extent - 1] when \p stride < 0.
 *
 * \param index The slice index to normalize.
 * \param extent The extent of the sliced dimension.
 * \param stride The slice step; only its sign is used.
 * \return The clamped index.
 */
inline int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) {
  const bool reverse = stride < 0;
  const int64_t lowest = reverse ? -1 : 0;
  const int64_t highest = reverse ? extent - 1 : extent;
  const int64_t shifted = index < 0 ? index + extent : index;
  return std::min(std::max(shifted, lowest), highest);
}
675 
676 inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) {
677  if (index->IsInstance<tvm::IntImmNode>() && extent->IsInstance<tvm::IntImmNode>() &&
678  stride->IsInstance<tvm::IntImmNode>()) {
679  return tvm::IntImm(
680  tvm::DataType::Int(64),
681  StaticCanonicalizeIndex(GetConstInt(index), GetConstInt(extent), GetConstInt(stride)));
682  }
683  return DynamicCanonicalizeIndex(index, extent, stride);
684 }
685 
686 inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent,
687  bool assume_inbound = true) {
688  if (assume_inbound) {
689  return ceildiv(end - begin, stride);
690  } else {
691  begin = CanonicalizeIndex(begin, extent, stride);
692  end = CanonicalizeIndex(end, extent, stride);
693  return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride),
694  ceildiv(end - begin, stride));
695  }
696 }
697 
714  const te::Tensor& x, const ffi::Array<PrimExpr>& begin, const ffi::Array<PrimExpr>& end,
715  const ffi::Array<PrimExpr>& strides, const ffi::Array<Integer>& axes,
716  bool assume_inbound = true, std::string name = "T_dynamic_strided_slice_with_axes",
717  std::string tag = kInjective) {
718  const size_t src_tensor_dim = x->shape.size();
719  ICHECK_EQ(begin.size(), end.size());
720  ICHECK_EQ(begin.size(), strides.size());
721  ICHECK_EQ(begin.size(), axes.size());
722  ICHECK_LE(begin.size(), src_tensor_dim);
723 
724  for (const auto& axis_imm : axes) {
725  int axis = axis_imm->value;
726  ICHECK_LT(axis, src_tensor_dim);
727  }
728 
729  arith::Analyzer analyzer;
730 
731  ffi::Array<PrimExpr> out_shape = x->shape;
732  for (size_t i = 0; i < begin.size(); i++) {
733  int axis = axes[i]->value;
734  PrimExpr new_shape =
735  analyzer.Simplify(GetLength(begin[i], end[i], strides[i], out_shape[axis], assume_inbound));
736  out_shape.Set(axis, new_shape);
737  }
738 
739  return te::compute(
740  out_shape,
741  [&](const ffi::Array<tvm::tir::Var>& indices) {
742  ffi::Array<PrimExpr> real_indices =
743  indices.Map([](const auto& var) -> PrimExpr { return var; });
744 
745  for (size_t i = 0; i < begin.size(); i++) {
746  int axis = axes[i]->value;
747  PrimExpr new_index = indices[axis] * strides[i] + begin[i];
748  real_indices.Set(axis, new_index);
749  }
750 
751  return x(real_indices);
752  },
753  name, tag);
754 }
755 
770 inline Tensor dynamic_strided_slice(const Tensor& x, const ffi::Array<PrimExpr>& begin,
771  const ffi::Array<PrimExpr>& end,
772  const ffi::Array<PrimExpr>& strides, bool assume_inbound = true,
773  std::string name = "T_dynamic_strided_slice",
774  std::string tag = kInjective) {
775  const size_t src_tensor_dim = x->shape.size();
776  ICHECK_LE(begin.size(), src_tensor_dim);
777  ICHECK_LE(end.size(), src_tensor_dim);
778  ICHECK_LE(strides.size(), src_tensor_dim);
779  ICHECK_EQ(begin.size(), end.size());
780  ICHECK_EQ(begin.size(), strides.size());
781 
782  const size_t num_slice_axes = begin.size();
783  ffi::Array<PrimExpr> out_shape;
784 
785  arith::Analyzer analyzer;
786  for (size_t i = 0; i < num_slice_axes; ++i) {
787  // Check ProducerLoad to keep backward compatibility for Relax.
788  if (!begin[i]->IsInstance<ProducerLoadNode>() && !end[i]->IsInstance<ProducerLoadNode>() &&
789  !strides[i]->IsInstance<ProducerLoadNode>()) {
790  out_shape.push_back(
791  analyzer.Simplify(GetLength(begin[i], end[i], strides[i], x->shape[i], assume_inbound)));
792  } else {
793  out_shape.push_back(tvm::tir::Var("dim"));
794  }
795  }
796 
797  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
798  out_shape.push_back(x->shape[i]);
799  }
800 
801  return te::compute(
802  out_shape,
803  [&](const ffi::Array<tvm::tir::Var>& indices) {
804  ffi::Array<PrimExpr> real_indices;
805  for (size_t i = 0; i < num_slice_axes; ++i) {
806  real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
807  }
808  // keep input dim
809  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
810  real_indices.push_back(indices[i]);
811  }
812  return x(real_indices);
813  },
814  name, tag);
815 }
816 
832  const te::Tensor& end, const te::Tensor& strides,
833  bool assume_inbound = true,
834  std::string name = "T_strided_slice_dynamic",
835  std::string tag = topi::kInjective) {
836  DataType index_dtype = begin->shape[0]->dtype;
837  const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
838  ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
839  ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
840 
841  ffi::Array<PrimExpr> begin_expr, end_expr, strides_expr;
842  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
843  auto ind = make_const(index_dtype, i);
844  begin_expr.push_back(begin(ind));
845  end_expr.push_back(end(ind));
846  strides_expr.push_back(strides(ind));
847  }
848  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, assume_inbound, name, tag);
849 }
850 
865 inline ffi::Array<PrimExpr> StridedSliceOutputShape(const ffi::Array<PrimExpr>& ishape,
866  const ffi::Array<Integer>& begin,
867  const ffi::Array<Integer>& end,
868  const ffi::Array<Integer>& strides,
869  const ffi::Array<Integer>& axes,
870  const std::string& slice_mode) {
871  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
872  std::vector<int64_t> begin_vec, end_vec, strides_vec;
873  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
874  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
875  begin[0]->dtype, slice_mode);
876  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
877  begin_canonicalized, true);
878 }
879 
896 inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array<Integer>& begin,
897  const ffi::Array<Integer>& end,
898  const ffi::Array<Integer>& strides,
899  const ffi::Array<Integer>& axes,
900  std::string slice_mode = "end",
901  std::string name = "T_strided_slice_with_axes",
902  std::string tag = kInjective) {
903  const size_t src_tensor_dim = x->shape.size();
904  ICHECK(axes.size() <= src_tensor_dim);
905  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
906 
907  std::vector<int64_t> begin_vec, end_vec, strides_vec;
908  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
909 
910  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
911  begin[0]->dtype, slice_mode);
912  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
913  slice_mode, begin_expr);
914 
915  return te::compute(
916  out_shape,
917  [&](const ffi::Array<tir::Var>& indices) {
918  ffi::Array<PrimExpr> real_indices;
919  for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
920  for (size_t i = 0; i < axes.size(); ++i) {
921  auto stride = make_const(strides[i].dtype(), strides_vec[i]);
922  PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
923  real_indices.Set(axes[i].IntValue(), ind);
924  }
925  return x(real_indices);
926  },
927  name, tag);
928 }
929 
944 inline Tensor strided_slice(const Tensor& x, const ffi::Array<Integer>& begin,
945  const ffi::Array<Integer>& end, const ffi::Array<Integer>& strides,
946  std::string slice_mode = "end", std::string name = "T_strided_slice",
947  std::string tag = kInjective) {
948  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
949  ffi::Array<Integer> axes;
950  for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
951  ffi::Array<Integer> begin_full(begin);
952  ffi::Array<Integer> end_full(end);
953  ffi::Array<Integer> strides_full(strides);
954 
955  DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64);
956  const IntImm one = IntImm(index_dtype, 1);
957  const IntImm zero = IntImm(index_dtype, 0);
958  const IntImm max_range = Downcast<IntImm>(max_value(index_dtype));
959 
960  for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
961  strides_full.push_back(one);
962  }
963  for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
964  begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
965  }
966  for (size_t i = end.size(); i < src_tensor_dim; ++i) {
967  end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
968  }
969 
970  return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
971  tag);
972 }
973 
986 inline ffi::Array<Tensor> split_n_sections(const Tensor& x, int num_sections, int axis,
987  std::string name = "T_split_sections",
988  std::string tag = kInjective) {
989  if (axis < 0) {
990  axis += static_cast<int>(x->shape.size());
991  }
992  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
993 
994  auto src_axis_size = x->shape[axis];
995 
996  ICHECK_GT(num_sections, 0) << "Slice count must be > 0";
997 
998  ffi::Array<PrimExpr> split_indices;
999  auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections);
1000  for (int i = 0; i < num_sections; ++i) {
1001  // region at index 0 is added by split()
1002  if (i != 0) {
1003  split_indices.push_back(seg_size * i);
1004  }
1005  }
1006 
1007  return split_indices_array(x, split_indices, axis, name, tag);
1008 }
1009 
1022 inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
1023  std::string mode = "fast", std::string name = "T_take",
1024  std::string tag = kInjective) {
1025  ffi::Array<PrimExpr> a_shape = a->shape;
1026  ffi::Array<PrimExpr> out_shape = indices->shape;
1027  PrimExpr a_size = 1;
1028  for (size_t i = 0; i < a_shape.size(); ++i) {
1029  a_size = a_size * a_shape[i];
1030  }
1031 
1032  if (mode == "clip") {
1033  return compute(
1034  out_shape,
1035  [&](const ffi::Array<Var>& out_index) {
1036  auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
1037  return a(UnravelIndex(idx, a_shape));
1038  },
1039  name, tag);
1040  } else if (mode == "fast") {
1041  LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
1042  "Make sure input indices are in bound";
1043  return compute(
1044  out_shape,
1045  [&](const ffi::Array<Var>& out_index) {
1046  return a(UnravelIndex(indices(out_index), a_shape));
1047  },
1048  name, tag);
1049  } else if (mode == "nan") {
1050  return compute(
1051  out_shape,
1052  [&](const ffi::Array<Var>& out_index) {
1053  auto idx = tvm::if_then_else(
1054  indices(out_index) < 0 || indices(out_index) >= a_size,
1055  tvm::FloatImm(a->dtype, std::numeric_limits<float>::quiet_NaN()), indices(out_index));
1056  return a(UnravelIndex(idx, a_shape));
1057  },
1058  name, tag);
1059  } else { // mode == "wrap"
1060  return compute(
1061  out_shape,
1062  [&](const ffi::Array<Var>& out_index) {
1063  auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
1064  return a(UnravelIndex(idx, a_shape));
1065  },
1066  name, tag);
1067  }
1068 }
1069 
1082 inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value,
1083  int axis, std::string name = "T_sequence_mask",
1084  std::string tag = kInjective) {
1085  ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
1086  ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
1087  auto length_dim = data->shape[axis];
1088  auto batch_dim = data->shape[1 - axis];
1089  ffi::Array<PrimExpr> out_shape = data->shape;
1090  Tensor out = compute(
1091  out_shape,
1092  [&](const ffi::Array<Var>& out_index) {
1093  ffi::Array<PrimExpr> len_index;
1094  auto tid = out_index[axis];
1095  auto bid = out_index[1 - axis];
1096  len_index.push_back(bid);
1097  PrimExpr ret =
1098  tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
1099  tvm::tir::make_const(data->dtype, mask_value), data(out_index));
1100  return ret;
1101  },
1102  name, tag);
1103  return out;
1104 }
1105 
1120 inline Tensor take(const Tensor& a, ffi::Variant<Tensor, PrimExpr> indices, int batch_dims,
1121  int axis, std::string mode = "fast", std::string name = "T_take",
1122  std::string tag = kInjective) {
1123  if (axis < 0) {
1124  axis += static_cast<int>(a->shape.size());
1125  }
1126  ICHECK_GE(axis, 0) << "axis out of bounds";
1127  ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
1128  auto axis_dim = a->shape[axis];
1129  auto indices_shape = [&]() -> ffi::Array<PrimExpr> {
1130  if (auto tensor = indices.as<TensorNode>()) {
1131  return tensor->shape;
1132  } else {
1133  return {};
1134  }
1135  }();
1136 
1137  int indices_len = static_cast<int>(indices_shape.size());
1138 
1139  int batch_dims_ = batch_dims;
1140  if (batch_dims_ != 0) {
1141  ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds";
1142  ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds";
1143 
1144  if (batch_dims_ < 0) {
1145  batch_dims_ = indices_len + batch_dims_;
1146  }
1147 
1148  ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
1149  ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
1150  for (int i = 0; i < batch_dims_; ++i) {
1151  auto addr1 = a->shape[i];
1152  auto addr2 = indices_shape[i];
1153  auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
1154  auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
1155  ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
1156  }
1157  }
1158 
1159  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
1160  // a.shape[axis + 1:].
1161 
1162  ffi::Array<PrimExpr> out_shape;
1163  for (int i = 0; i < batch_dims_; ++i) {
1164  out_shape.push_back(a->shape[i]);
1165  }
1166  for (int i = batch_dims_; i < axis; ++i) {
1167  out_shape.push_back(a->shape[i]);
1168  }
1169  for (int i = batch_dims_; i < indices_len; ++i) {
1170  out_shape.push_back(indices_shape[i]);
1171  }
1172  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
1173  out_shape.push_back(a->shape[i]);
1174  }
1175 
1176  auto get_index = [&](const ffi::Array<PrimExpr>& indices_position) -> PrimExpr {
1177  if (auto tensor = indices.as<Tensor>()) {
1178  return tensor.value()(indices_position);
1179  } else if (auto prim = indices.as<PrimExpr>()) {
1180  ICHECK_EQ(indices_position.size(), 0);
1181  return prim.value();
1182  } else {
1183  LOG(FATAL) << "Variant did not contain either allowed type";
1184  }
1185  };
1186 
1187  if (mode == "clip") {
1188  if (batch_dims_ == 0) {
1189  return compute(
1190  out_shape,
1191  [&](const ffi::Array<Var>& out_index) {
1192  ffi::Array<PrimExpr> indices_position;
1193  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1194  indices_position.push_back(out_index[j]);
1195  }
1196  ffi::Array<PrimExpr> real_indices;
1197  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1198  real_indices.push_back(out_index[j]);
1199  }
1200  auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
1201  real_indices.push_back(idx);
1202  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1203  real_indices.push_back(out_index[j]);
1204  }
1205  return a(real_indices);
1206  },
1207  name, tag);
1208  } else {
1209  return compute(
1210  out_shape,
1211  [&](const ffi::Array<Var>& out_index) {
1212  ffi::Array<PrimExpr> indices_position;
1213  for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
1214  indices_position.push_back(out_index[j]);
1215  }
1216  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
1217  indices_position.push_back(out_index[j]);
1218  }
1219  ffi::Array<PrimExpr> real_indices;
1220  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1221  real_indices.push_back(out_index[j]);
1222  }
1223  auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
1224  real_indices.push_back(idx);
1225  for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
1226  real_indices.push_back(out_index[j]);
1227  }
1228  return a(real_indices);
1229  },
1230  name, tag);
1231  }
1232  } else if (mode == "fast") {
1233  LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
1234  "Make sure input indices are in bound";
1235  return compute(
1236  out_shape,
1237  [&](const ffi::Array<Var>& out_index) {
1238  ffi::Array<PrimExpr> indices_position;
1239  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1240  indices_position.push_back(out_index[j]);
1241  }
1242  ffi::Array<PrimExpr> real_indices;
1243  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1244  real_indices.push_back(out_index[j]);
1245  }
1246  real_indices.push_back(get_index(indices_position));
1247  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1248  real_indices.push_back(out_index[j]);
1249  }
1250  return a(real_indices);
1251  },
1252  name, tag);
1253  } else if (mode == "nan") {
1254  return compute(
1255  out_shape,
1256  [&](const ffi::Array<Var>& out_index) {
1257  ffi::Array<PrimExpr> indices_position;
1258  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1259  indices_position.push_back(out_index[j]);
1260  }
1261  ffi::Array<PrimExpr> real_indices;
1262  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1263  real_indices.push_back(out_index[j]);
1264  }
1265  PrimExpr idx = get_index(indices_position);
1266  real_indices.push_back(idx);
1267  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1268  real_indices.push_back(out_index[j]);
1269  }
1270  PrimExpr in_bounds = idx >= 0 && idx < axis_dim;
1271  return tvm::if_then_else(
1272  in_bounds, a(real_indices),
1273  tvm::tir::make_const(a->dtype, std::numeric_limits<float>::quiet_NaN()));
1274  },
1275  name, tag);
1276  } else { // mode == "wrap"
1277  return compute(
1278  out_shape,
1279  [&](const ffi::Array<Var>& out_index) {
1280  ffi::Array<PrimExpr> indices_position;
1281  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1282  indices_position.push_back(out_index[j]);
1283  }
1284  ffi::Array<PrimExpr> real_indices;
1285  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1286  real_indices.push_back(out_index[j]);
1287  }
1288  auto idx = truncmod(truncmod(get_index(indices_position), axis_dim) + axis_dim, axis_dim);
1289  real_indices.push_back(idx);
1290  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1291  real_indices.push_back(out_index[j]);
1292  }
1293  return a(real_indices);
1294  },
1295  name, tag);
1296  }
1297 }
1298 
1310 inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
1311  std::string name = "T_where", std::string tag = kBroadcast) {
1312  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
1313  << y->dtype;
1314  auto get_out_shape = [&]() {
1315  auto bh1 = detail::BroadcastShape(x->shape, y->shape);
1316  ffi::Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
1317  auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
1318  ffi::Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
1319  return common_shape2;
1320  };
1321 
1322  auto oshape = get_out_shape();
1323 
1324  auto c_bh = detail::BroadcastShape(condition->shape, oshape);
1325  auto x_bh = detail::BroadcastShape(x->shape, oshape);
1326  auto y_bh = detail::BroadcastShape(y->shape, oshape);
1327 
1328  auto select = [&](tvm::ffi::Array<tvm::tir::Var> ovars) {
1329  auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
1330  auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
1331  auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
1332  return tvm::tir::Select(c != 0, true_val, false_val);
1333  };
1334 
1335  return compute(oshape, select, name, tag);
1336 }
1337 
1350 inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat",
1351  std::string tag = kBroadcast) {
1352  int ndim = static_cast<int>(x->shape.size());
1353  ICHECK(-ndim - 1 <= axis && axis <= ndim)
1354  << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
1355  << ", but got axis = " << axis << ", and data.ndim = " << ndim;
1356  ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
1357  << ", but got repeats = " << repeats;
1358  if (axis < 0) {
1359  // Calculate offset from last dimension
1360  axis += ndim;
1361  }
1362  ffi::Array<PrimExpr> new_shape;
1363  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1364  new_shape.push_back(x->shape[i]);
1365  }
1366  new_shape.push_back(repeats * x->shape[axis]);
1367  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
1368  new_shape.push_back(x->shape[i]);
1369  }
1370 
1371  return compute(
1372  new_shape,
1373  [&](const ffi::Array<Var>& indices) {
1374  ffi::Array<PrimExpr> idx;
1375  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1376  idx.push_back(indices[i]);
1377  }
1378  idx.push_back(indexdiv(indices[axis], repeats));
1379  for (size_t i = axis + 1; i < indices.size(); ++i) {
1380  idx.push_back(indices[i]);
1381  }
1382  return x(idx);
1383  },
1384  name, tag);
1385 }
1386 
1397 inline Tensor tile(const Tensor& x, ffi::Array<Integer> reps, std::string name = "T_tile",
1398  std::string tag = kBroadcast) {
1399  size_t ndim = x->shape.size();
1400  size_t rdim = reps.size();
1401  size_t tdim = (ndim > rdim) ? ndim : rdim;
1402  ffi::Array<PrimExpr> data_shape;
1403  ffi::Array<PrimExpr> reps_shape;
1404  ffi::Array<PrimExpr> new_shape;
1405  if (ndim == rdim) {
1406  for (size_t i = 0; i < ndim; ++i) {
1407  data_shape.push_back(x->shape[i]);
1408  reps_shape.push_back(reps[i]);
1409  }
1410  } else if (ndim > rdim) {
1411  for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1412  for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
1413  for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1414  } else {
1415  for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
1416  for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1417  for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1418  }
1419  for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);
1420 
1421  if (is_empty_shape(new_shape)) {
1422  return compute(
1423  new_shape, [&](const ffi::Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name,
1424  tag);
1425  } else {
1426  return compute(
1427  new_shape,
1428  [&](const ffi::Array<Var>& indices) {
1429  ffi::Array<PrimExpr> idx;
1430  if (ndim >= rdim) {
1431  for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
1432  } else {
1433  for (size_t i = 0; i < ndim; ++i)
1434  idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
1435  }
1436  return x(idx);
1437  },
1438  name, tag);
1439  }
1440 }
1441 
1453 inline Tensor dyn_tile(const Tensor& x, ffi::Array<PrimExpr> new_shape, size_t rdim,
1454  std::string name = "T_tile", std::string tag = kBroadcast) {
1455  size_t ndim = x->shape.size();
1456  if (is_empty_shape(new_shape)) {
1457  return compute(
1458  new_shape, [&](const ffi::Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name,
1459  tag);
1460  } else {
1461  return compute(
1462  new_shape,
1463  [&](const ffi::Array<Var>& indices) {
1464  ffi::Array<PrimExpr> idx;
1465  if (ndim >= rdim) {
1466  for (size_t i = 0; i < ndim; ++i) {
1467  idx.push_back(indexmod(indices[i], x->shape[i]));
1468  }
1469  } else {
1470  for (size_t i = 0; i < ndim; ++i) {
1471  idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
1472  }
1473  }
1474  return x(idx);
1475  },
1476  name, tag);
1477  }
1478 }
1479 
1491 inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
1492  std::string name = "T_gather", std::string tag = kInjective) {
1493  size_t ndim_d = data->shape.size();
1494  size_t ndim_i = indices->shape.size();
1495  ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
1496  ICHECK_EQ(ndim_d, ndim_i);
1497  if (axis < 0) {
1498  axis += ndim_d;
1499  }
1500  ICHECK_GE(axis, 0);
1501  ICHECK_LT(axis, ndim_d);
1502  if (indices->shape[axis].as<IntImmNode>()) {
1503  size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
1504  ICHECK_GE(indices_dim_i, 1);
1505  }
1506  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());
1507 
1508  ffi::Array<PrimExpr> out_shape;
1509  for (size_t i = 0; i < ndim_i; ++i) {
1510  out_shape.push_back(indices->shape[i]);
1511  }
1512 
1513  return compute(
1514  out_shape,
1515  [&](const ffi::Array<Var>& out_index) {
1516  ffi::Array<PrimExpr> indices_position;
1517  for (size_t i = 0; i < ndim_i; ++i) {
1518  indices_position.push_back(out_index[i]);
1519  }
1520  ffi::Array<PrimExpr> real_indices;
1521  for (size_t i = 0; i < ndim_i; ++i) {
1522  if (i == static_cast<size_t>(axis)) {
1523  real_indices.push_back(indices(indices_position));
1524  } else {
1525  real_indices.push_back(indices_position[i]);
1526  }
1527  }
1528  return data(real_indices);
1529  },
1530  name, tag);
1531 }
1532 
1544 inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
1545  std::string name = "T_gather_nd", std::string tag = kInjective) {
1546  size_t ndim_d = data->shape.size();
1547  size_t ndim_i = indices->shape.size();
1548  ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions";
1549  size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
1550  ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
1551  << "than dimensions of data tensor";
1552  ffi::Array<PrimExpr> out_shape;
1553  for (size_t i = 1; i < ndim_i; ++i) {
1554  out_shape.push_back(indices->shape[i]);
1555  }
1556  for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
1557  out_shape.push_back(data->shape[i]);
1558  }
1559  return compute(
1560  out_shape,
1561  [&](const ffi::Array<Var>& out_index) {
1562  ffi::Array<PrimExpr> indices_position;
1563  indices_position.push_back(0);
1564  for (size_t i = 0; i < ndim_i - 1; ++i) {
1565  indices_position.push_back(out_index[i]);
1566  }
1567  ffi::Array<PrimExpr> real_indices;
1568  for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
1569  real_indices.push_back(out_index[i]);
1570  }
1571  for (size_t i = 0; i < indices_dim0; ++i) {
1572  indices_position.Set(0, make_const(DataType::Int(32), i));
1573  if (indices->dtype.is_int() || indices->dtype.is_uint()) {
1574  real_indices.push_back(indices(indices_position));
1575  } else {
1576  real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
1577  }
1578  }
1579  if (real_indices.size() == ndim_d) {
1580  return data(real_indices);
1581  }
1582  for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
1583  real_indices.push_back(out_index[i]);
1584  }
1585  return data(real_indices);
1586  },
1587  name, tag);
1588 }
1589 
1606  bool trans_a = false, bool trans_b = false,
1607  std::string name = "T_matmul", std::string tag = kMatMul) {
1608  tvm::ffi::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
1609  auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
1610  auto l = [&](tvm::tir::Var i, tvm::tir::Var j) {
1611  return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
1612  };
1613  return tvm::te::compute(output_shape, l, name, tag);
1614 }
1615 
1627 inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2,
1628  std::string name = "T_tensordot", std::string tag = kMatMul) {
1629  ICHECK_GE(A->shape.size(), axes);
1630  ICHECK_GE(B->shape.size(), axes);
1631 
1632  ffi::Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
1633  for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);
1634 
1635  ffi::Array<IterVar> iter_vars;
1636  for (int i = 0; i < axes; ++i)
1637  iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));
1638 
1639  auto func = [&A, &B, &iter_vars, axes](const ffi::Array<Var>& input_indices) {
1640  ffi::Array<PrimExpr> A_indices(input_indices.begin(),
1641  input_indices.begin() + (A->shape.size() - axes));
1642  for (auto& v : iter_vars) A_indices.push_back(v);
1643 
1644  ffi::Array<PrimExpr> B_indices;
1645  for (auto& v : iter_vars) B_indices.push_back(v);
1646 
1647  auto it = input_indices.begin() + (A->shape.size() - axes);
1648  for (; it != input_indices.end(); ++it) B_indices.push_back(*it);
1649 
1650  // Some passes don't like reductions with empty axis, so avoid it here
1651  if (iter_vars.empty()) {
1652  return A(A_indices) * B(B_indices);
1653  } else {
1654  return sum(A(A_indices) * B(B_indices), iter_vars);
1655  }
1656  };
1657 
1658  return compute(output_shape, func, name, tag);
1659 }
1660 
1673 inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, ffi::Array<PrimExpr> A_axes,
1674  ffi::Array<PrimExpr> B_axes, std::string name = "T_tensordot",
1675  std::string tag = kMatMul) {
1676  ICHECK_EQ(A_axes.size(), B_axes.size());
1677 
1678  auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
1679  auto B_axes_val = GetConstIntValues(B_axes, "B_axes");
1680 
1681  ffi::Array<PrimExpr> output_shape;
1682  for (unsigned i = 0; i < A->shape.size(); ++i)
1683  if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
1684  output_shape.push_back(A->shape[i]);
1685  for (unsigned i = 0; i < B->shape.size(); ++i)
1686  if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
1687  output_shape.push_back(B->shape[i]);
1688 
1689  ffi::Array<IterVar> iter_vars;
1690  for (unsigned i = 0; i < B_axes_val.size(); ++i)
1691  iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));
1692 
1693  auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const ffi::Array<Var>& input_indices) {
1694  int idx_input = 0;
1695  ffi::Array<PrimExpr> A_indices;
1696  for (unsigned i = 0; i < A->shape.size(); ++i) {
1697  auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
1698  if (axes_pos == A_axes_val.end()) {
1699  A_indices.push_back(input_indices[idx_input++]);
1700  } else {
1701  A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
1702  }
1703  }
1704 
1705  ffi::Array<PrimExpr> B_indices;
1706  for (unsigned i = 0; i < B->shape.size(); ++i) {
1707  auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
1708  if (axes_pos == B_axes_val.end()) {
1709  B_indices.push_back(input_indices[idx_input++]);
1710  } else {
1711  B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
1712  }
1713  }
1714  return sum(A(A_indices) * B(B_indices), iter_vars);
1715  };
1716  return compute(output_shape, func, name, tag);
1717 }
1718 
1719 inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
1720  DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
1721  arith::Analyzer analyzer;
1722  PrimExpr num_elem;
1723  bool is_all_int = start.dtype().is_int() && stop.dtype().is_int() && step.dtype().is_int();
1724  if (is_all_int && analyzer.CanProveGreaterEqual(step, 1)) {
1725  // fast path for integer arange when step is positive
1726  num_elem = tvm::floordiv((stop - start + step - 1), step);
1727  } else if (is_all_int && analyzer.CanProveLess(step, 0)) {
1728  // fast path for integer arange when step is negative
1729  num_elem = tvm::floordiv((start - stop - step - 1), -step);
1730  } else {
1731  // fallback path for non-integer or step of unknown sign
1732  num_elem = tvm::cast(DefaultIndexType(),
1733  tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
1734  }
1735  num_elem = analyzer.Simplify(num_elem);
1736 
1737  return compute(
1738  {num_elem},
1739  [&](const ffi::Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); },
1740  name, tag);
1741 }
1742 
1753 inline ffi::Array<Tensor> meshgrid(const ffi::Array<Tensor>& inputs, const std::string& indexing,
1754  std::string name = "T_meshgrid", std::string tag = kInjective) {
1755  const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2;
1756  ffi::Array<PrimExpr> out_shape;
1757  for (size_t i = 0; i < inputs.size(); ++i) {
1758  const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
1759  out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]);
1760  }
1761  ffi::Array<Tensor> result;
1762  for (size_t i = 0; i < inputs.size(); ++i) {
1763  result.push_back(compute(
1764  out_shape,
1765  [&](const ffi::Array<Var>& indices) {
1766  const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
1767  auto ndim = inputs[i]->GetShape().size();
1768  ffi::Array<PrimExpr> real_indices = {};
1769  if (ndim > 0) {
1770  real_indices = {indices[src_index]};
1771  }
1772  return inputs[i](real_indices);
1773  },
1774  name, tag));
1775  }
1776  return result;
1777 }
1778 
1789 inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
1790  const std::string& dst_layout,
1791  const std::string schedule_rule = "None",
1792  const std::string name = "T_layout_trans",
1793  const std::string tag = kInjective) {
1794  Layout src_layout_struct(src_layout);
1795  Layout dst_layout_struct(dst_layout);
1796 
1797  if (src_layout_struct.Equals(dst_layout_struct)) {
1798  return src;
1799  }
1800 
1801  ICHECK(src_layout_struct.defined() && dst_layout_struct.defined())
1802  << "cannot convert from/to undefined layout";
1803 
1804  auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
1805  ICHECK(layout_converter.defined())
1806  << "cannot convert from " << src_layout << " to " << dst_layout;
1807 
1808  ffi::Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);
1809 
1810  ffi::Map<ffi::String, ffi::Any> attrs = {{"schedule_rule", ffi::String(schedule_rule)},
1811  // Information about layouts needed for the schedule rule
1812  {"src_layout", ffi::String(src_layout)},
1813  {"dst_layout", ffi::String(dst_layout)},
1814  {"input_shape", src->shape}};
1815 
1816  return compute(
1817  dst_shape,
1818  [&](const ffi::Array<Var>& dst_indices) {
1819  ffi::Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
1820  ffi::Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
1821  PrimExpr in_range = PrimExpr(1) > PrimExpr(0); // init with dtype=bool and value=true
1822  for (size_t i = 0; i < src.ndim(); ++i) {
1823  in_range = in_range && (src_indices[i] < src->shape[i]);
1824  }
1825  return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
1826  },
1827  name, tag, attrs);
1828 }
1829 
1831 inline void parse_auto_scheduler_layout(const ffi::String& layout, ffi::Array<PrimExpr>* shape,
1832  std::vector<std::string>* axes) {
1833  int32_t factor = 0;
1834  std::string axis = "";
1835  for (char c : std::string(layout)) {
1836  if (c >= 'A' && c <= 'z') {
1837  axis += c;
1838  if (factor != 0) {
1839  shape->push_back(factor);
1840  factor = 0;
1841  }
1842  } else if (c >= '0' && c <= '9') {
1843  factor = factor * 10 + c - '0';
1844  if (!axis.empty()) {
1845  axes->push_back(axis);
1846  axis = "";
1847  }
1848  } else {
1849  LOG(FATAL) << "Invalid layout " << layout;
1850  }
1851  }
1852  if (!axis.empty()) {
1853  axes->push_back(axis);
1854  }
1855 }
1856 
1868  const Tensor& src, const ffi::String& src_layout, const ffi::String& dst_layout,
1869  const ffi::String name = "T_auto_scheduler_layout_trans", const ffi::String tag = kInjective) {
1870  ffi::Array<PrimExpr> src_shape;
1871  std::vector<std::string> src_axes;
1872  ffi::Array<PrimExpr> dst_shape;
1873  std::vector<std::string> dst_axes;
1874 
1875  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
1876  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
1877  return compute(
1878  dst_shape,
1879  [&](const ffi::Array<Var>& dst_indices) {
1880  ffi::Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
1881  ffi::Array<PrimExpr> src_indices;
1882  for (const std::string& src_axis : src_axes) {
1883  PrimExpr src_index = 0;
1884  CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
1885  for (size_t i = 0; i < dst_axes.size(); ++i) {
1886  if (dst_axes[i] == src_axis) {
1887  src_index = src_index * dst_shape[i] + dst_indices_expr[i];
1888  }
1889  }
1890  src_indices.push_back(src_index);
1891  }
1892  return src(src_indices);
1893  },
1894  name, tag);
1895 }
1896 
1934  const Tensor& src, const tir::IndexMap& index_map,
1935  const ffi::String name = "T_meta_schedule_layout_trans", const ffi::String tag = kInjective) {
1936  arith::Analyzer analyzer;
1937  ffi::Array<Range> iter_domain;
1938  iter_domain.reserve(src->shape.size());
1939  for (const PrimExpr& e : src->shape) {
1940  iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
1941  }
1942  ffi::Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape, &analyzer);
1943  return compute(
1944  post_transform_shape,
1945  [src, inv = index_map.Inverse(iter_domain, &analyzer),
1946  &analyzer](const ffi::Array<Var>& indices) -> PrimExpr {
1947  return src(
1948  inv->MapIndices(ffi::Array<PrimExpr>{indices.begin(), indices.end()}, &analyzer));
1949  },
1950  name, tag);
1951 }
1952 
1961 inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape",
1962  const std::string tag = kInjective) {
1963  int ndim = static_cast<int>(src->shape.size());
1964  ffi::Array<PrimExpr> out_shape{ndim};
1965  return compute(
1966  out_shape,
1967  [&](const ffi::Array<Var>& indices) {
1968  auto idx = indices[0];
1969  PrimExpr ret = 0;
1970  for (int i = 0; i < ndim; ++i) {
1971  ret = tvm::if_then_else(idx == i, src->shape[i], ret);
1972  }
1973  return tvm::cast(dtype, ret);
1974  },
1975  name, tag);
1976 }
1977 
1986 inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype,
1987  const std::string& name = "tensor_size",
1988  const std::string& tag = kInjective) {
1989  int ndim = static_cast<int>(src->shape.size());
1990  ffi::Array<PrimExpr> out_tensor_size = {};
1991  return compute(
1992  out_tensor_size,
1993  [&](const ffi::Array<Var>& indices) {
1994  PrimExpr ret = 1;
1995  for (int i = 0; i < ndim; ++i) {
1996  ret *= src->shape[i];
1997  }
1998  return tvm::cast(dtype, ret);
1999  },
2000  name, tag);
2001 }
2002 
2017 inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
2018  int depth, int axis, const DataType& dtype,
2019  ffi::Array<PrimExpr> oshape = ffi::Array<PrimExpr>(),
2020  const std::string name = "T_one_hot", const std::string tag = kInjective) {
2021  int true_axis = (axis == -1) ? indices->shape.size() : axis;
2022  if (oshape.size() == 0) {
2023  int ndim = indices->shape.size() + 1;
2024  int indices_index = 0;
2025  for (int i = 0; i < ndim; i++) {
2026  if (i == true_axis) {
2027  oshape.push_back(Integer(depth));
2028  } else {
2029  oshape.push_back(indices->shape[indices_index++]);
2030  }
2031  }
2032  }
2033 
2034  PrimExpr on_value_cast = cast(dtype, on_value);
2035  PrimExpr off_value_cast = cast(dtype, off_value);
2036  return compute(
2037  oshape,
2038  [&](const ffi::Array<Var>& iter_vars) {
2039  ffi::Array<Var> indices_indices;
2040  for (size_t i = 0; i < iter_vars.size(); i++) {
2041  if (static_cast<int>(i) == true_axis) {
2042  continue;
2043  }
2044 
2045  indices_indices.push_back(iter_vars[i]);
2046  }
2047 
2048  auto idx = iter_vars[true_axis];
2049  return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast);
2050  },
2051  name, tag);
2052 }
2053 
2064 inline Tensor sparse_to_dense(const Tensor& sparse_indices,
2065  const ffi::Array<PrimExpr>& output_shape, const Tensor& sparse_values,
2066  const PrimExpr& default_value,
2067  const std::string name = "T_sparse_to_dense",
2068  const std::string tag = kInjective) {
2069  ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values";
2070  ICHECK_LE(sparse_indices->shape.size(), 3)
2071  << "sparse_indices tensor should be 0D, 1D, or 2D only";
2072  ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only";
2073 
2074  const auto rank_sparse_indices = static_cast<int>(sparse_indices->shape.size());
2075  ffi::Array<PrimExpr> oshape;
2076  for (auto l : output_shape) {
2077  oshape.push_back(l);
2078  }
2079  return compute(
2080  oshape,
2081  [&](const ffi::Array<Var>& indices) {
2082  PrimExpr ret = default_value;
2083  if (0 == rank_sparse_indices) {
2084  ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret);
2085  } else if (1 == rank_sparse_indices) {
2086  for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
2087  ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
2088  }
2089  } else {
2090  for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
2091  PrimExpr aggregate_condition;
2092  for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
2093  PrimExpr comparision = indices[k] == sparse_indices[j][k];
2094  aggregate_condition = 0 == k ? comparision : aggregate_condition && comparision;
2095  }
2096  ret = if_then_else(aggregate_condition, sparse_values[j], ret);
2097  }
2098  }
2099  return ret;
2100  },
2101  name, tag);
2102 }
2103 
2116 inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
2117  bool super_diag_right_align, bool sub_diag_right_align,
2118  const std::string name = "T_matrix_set_diag",
2119  const std::string tag = kInjective) {
2120  size_t ndim = input->shape.size() - 1;
2121 
2122  bool only_one_diagonal = k1 == k2;
2123 
2124  return compute(
2125  input->shape,
2126  [&](const ffi::Array<Var>& iter_vars) {
2127  auto get_diag = [&]() {
2128  ffi::Array<PrimExpr> diagonal_indices;
2129  PrimExpr k, offset = 0;
2130  for (size_t i = 0; i < ndim - 1; i++) {
2131  diagonal_indices.push_back(iter_vars[i]);
2132  }
2133  if (only_one_diagonal) {
2134  k = k1;
2135  } else {
2136  // Determining which diagonal/sub-diagonal/super-diagonal it is
2137  k = iter_vars[ndim] - iter_vars[ndim - 1];
2138  diagonal_indices.push_back(k2 - k);
2139 
2140  // Calculating the offset in diagonal tensor for this diagonal
2141  auto get_offset = [&](PrimExpr M, PrimExpr N) {
2142  // offset = max_diagonal_length - diagonal_length
2143  return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
2144  };
2145  offset = if_then_else(
2146  k >= 0,
2147  super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
2148  : 0,
2149  sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
2150  : 0);
2151  }
2152  diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
2153  offset);
2154  return diagonal(diagonal_indices);
2155  };
2156  return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
2157  if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
2158  get_diag(), input(iter_vars)),
2159  input(iter_vars));
2160  },
2161  name, tag);
2162 }
2163 
2172 inline Tensor adv_index(const Tensor& data, const ffi::Array<Tensor>& indices,
2173  const std::string name = "advanced_index",
2174  const std::string tag = kInjective) {
2175  ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!";
2176  ffi::Array<PrimExpr> oshape;
2177  ffi::Array<PrimExpr> broadcast_shape;
2178  ffi::Array<Tensor> bindices;
2179 
2180  broadcast_shape = indices[0]->shape;
2181  for (size_t i = 1; i < indices.size(); ++i) {
2182  auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape);
2183  broadcast_shape = ffi::Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
2184  }
2185  if (indices.size() == 1) {
2186  // quick path
2187  bindices = indices;
2188  } else {
2189  // Do broadcast for indices
2190  for (size_t i = 0; i < indices.size(); ++i) {
2191  bindices.push_back(broadcast_to(indices[i], broadcast_shape));
2192  }
2193  }
2194 
2195  for (const auto& dim : broadcast_shape) {
2196  oshape.push_back(dim);
2197  }
2198  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
2199  oshape.push_back(data->shape[i]);
2200  }
2201 
2202  return compute(
2203  oshape,
2204  [&](const ffi::Array<Var>& iter_var) {
2205  ffi::Array<PrimExpr> tensor_indices;
2206  for (size_t i = 0; i < broadcast_shape.size(); ++i) {
2207  tensor_indices.push_back(iter_var[i]);
2208  }
2209  ffi::Array<PrimExpr> real_indices;
2210  for (size_t i = 0; i < bindices.size(); ++i) {
2211  real_indices.push_back(bindices[i](tensor_indices));
2212  }
2213  for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
2214  real_indices.push_back(iter_var[i]);
2215  }
2216 
2217  return data(real_indices);
2218  },
2219  name, tag);
2220 }
2221 
2222 namespace relax {
2223 // relax dynamic slice
2225  const te::Tensor& end, const te::Tensor& strides,
2226  ffi::Array<PrimExpr> output_shape,
2227  std::string name = "T_strided_slice_dynamic",
2228  std::string tag = kInjective) {
2229  const size_t num_dynamic_axes = x.ndim();
2230  ICHECK_EQ(begin.ndim(), 1);
2231  ICHECK_EQ(end.ndim(), 1);
2232  ICHECK_EQ(strides.ndim(), 1);
2233  const auto* len_begin = begin->shape[0].as<IntImmNode>();
2234  const auto* len_end = end->shape[0].as<IntImmNode>();
2235  const auto* len_strides = strides->shape[0].as<IntImmNode>();
2236  ICHECK(len_begin);
2237  ICHECK(len_end);
2238  ICHECK(len_strides);
2239  ICHECK_EQ(len_begin->value, num_dynamic_axes);
2240  ICHECK_EQ(len_end->value, num_dynamic_axes);
2241  ICHECK_EQ(len_strides->value, num_dynamic_axes);
2242 
2243  return te::compute(
2244  output_shape,
2245  [&](const ffi::Array<tvm::tir::Var>& indices) {
2246  ffi::Array<PrimExpr> real_indices;
2247  for (size_t i = 0; i < num_dynamic_axes; ++i) {
2248  auto ind = make_const(DataType::Int(64), i);
2249  real_indices.push_back(indices[i] * strides(ind) + tvm::min(begin(ind), x->shape[i] - 1));
2250  }
2251  return x(real_indices);
2252  },
2253  name, tag);
2254 }
2255 
2256 } // namespace relax
2257 
2258 } // namespace topi
2259 } // namespace tvm
2260 #endif // TVM_TOPI_TRANSFORM_H_
Algebra expression simplifications.
Broadcast op constructions.
Managed reference class to FloatImmNode.
Definition: expr.h:545
Constant integer literals in the program.
Definition: expr.h:493
int64_t value
the Internal value.
Definition: expr.h:496
Managed reference class to IntImmNode.
Definition: expr.h:510
Container of constant int that adds more constructors.
Definition: expr.h:600
Reference to PrimExprNode.
Definition: expr.h:124
DataType dtype() const
Definition: expr.h:138
Range container
Definition: expr.h:689
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
Construct a new range with min and extent. The corresponding constructor is removed,...
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:634
bool CanProveGreaterEqual(const PrimExpr &expr, int64_t lower_bound)
Whether we can prove expr >= lower_bound.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
bool CanProveLess(const PrimExpr &expr, int64_t upper_bound)
Whether we can prove expr < upper_bound.
Runtime primitive data type.
Definition: data_type.h:47
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:294
bool is_int() const
Definition: data_type.h:193
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:277
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:53
Node to represent a tensor.
Definition: tensor.h:70
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
size_t ndim() const
Definition: tensor.h:212
Bijective function mapping for data layout transformation. Given two Layouts, BijectiveLayout builds an...
Definition: data_layout.h:333
Definition: index_map.h:169
IndexMap Inverse(ffi::Array< Range > initial_ranges, arith::Analyzer *analyzer) const
Generate the inverse mapping.
Managed reference to LayoutNode.
Definition: data_layout.h:124
bool Equals(const Layout &rhs) const
Whether the two layouts are equal.
Definition: data_layout.h:279
Managed reference to SelectNode.
Definition: expr.h:515
A variable node in the IR.
Definition: var.h:48
ffi::String name_hint
The hint to the variable name.
Definition: var.h:54
a named variable in TIR
Definition: var.h:77
Utility functions for handling constants in TVM expressions.
Layout expression to describe the data organization of a tensor, and BijectiveLayout to map two d...
Detail broadcast.
Defines a remapping of buffer indices.
Base expr nodes in TVM.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:994
DataType DefaultIndexType()
if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return int32
Definition: buffer.h:43
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:1008
te::Tensor dynamic_strided_slice(const te::Tensor &x, const te::Tensor &begin, const te::Tensor &end, const te::Tensor &strides, ffi::Array< PrimExpr > output_shape, std::string name="T_strided_slice_dynamic", std::string tag=kInjective)
Definition: transform.h:2224
PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent, bool assume_inbound=true)
Definition: transform.h:686
Tensor sequence_mask(const Tensor &data, const Tensor &valid_length, double mask_value, int axis, std::string name="T_sequence_mask", std::string tag=kInjective)
Mask the out-of-boundary elements of each sequence.
Definition: transform.h:1082
Tensor gather_nd(const Tensor &data, const Tensor &indices, int batch_dims=0, std::string name="T_gather_nd", std::string tag=kInjective)
Gather elements from a n-dimension array.
Definition: transform.h:1544
int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t stride)
Definition: transform.h:667
Tensor reshape(const Tensor &x, ffi::Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:328
Tensor one_hot(const Tensor &indices, const PrimExpr on_value, const PrimExpr off_value, int depth, int axis, const DataType &dtype, ffi::Array< PrimExpr > oshape=ffi::Array< PrimExpr >(), const std::string name="T_one_hot", const std::string tag=kInjective)
Returns a one-hot tensor where the locations represented by indices take value on_value,...
Definition: transform.h:2017
tvm::te::Tensor broadcast_to(const tvm::te::Tensor &t, const tvm::ffi::Array< tvm::PrimExpr > &output_shape, std::string name="T_broadcast_to", std::string tag=kBroadcast)
Creates an operation that broadcasts a tensor into a compatible shape according to numpy's rules.
Definition: broadcast.h:48
constexpr auto kBroadcast
Definition: tags.h:36
Tensor arange(const PrimExpr &start, const PrimExpr &stop, const PrimExpr &step, DataType dtype, std::string name="T_arange", std::string tag=kInjective)
Definition: transform.h:1719
constexpr auto kInjective
Definition: tags.h:33
Tensor stack(const ffi::Array< Tensor > &inputs, int axis=0, std::string name="T_stack", std::string tag=kInjective)
Join a sequence of tensors along a new axis.
Definition: transform.h:538
Tensor auto_scheduler_layout_transform(const Tensor &src, const ffi::String &src_layout, const ffi::String &dst_layout, const ffi::String name="T_auto_scheduler_layout_trans", const ffi::String tag=kInjective)
Transform the auto-scheduler generated layout according to src_layout and dst_layout.
Definition: transform.h:1867
ffi::Array< PrimExpr > StridedSliceOutputShape(const ffi::Array< PrimExpr > &ishape, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, const ffi::Array< Integer > &axes, const std::string &slice_mode)
Calculate the output shape of strided_slice, the entry point for Relax type relation.
Definition: transform.h:865
PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride)
Definition: transform.h:676
te::Tensor dynamic_strided_slice_with_axes(const te::Tensor &x, const ffi::Array< PrimExpr > &begin, const ffi::Array< PrimExpr > &end, const ffi::Array< PrimExpr > &strides, const ffi::Array< Integer > &axes, bool assume_inbound=true, std::string name="T_dynamic_strided_slice_with_axes", std::string tag=kInjective)
strided_slice of a tensor where begin/end/stride can be mixed static and dynamic
Definition: transform.h:713
Tensor transpose(const Tensor &x, ffi::Optional< ffi::Array< Integer >> opt_axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:204
void parse_auto_scheduler_layout(const ffi::String &layout, ffi::Array< PrimExpr > *shape, std::vector< std::string > *axes)
Utility function for auto_scheduler_layout_transform.
Definition: transform.h:1831
Tensor squeeze(const Tensor &x, ffi::Optional< ffi::Array< Integer >> opt_axes, bool atleast1d=false, std::string name="T_squeeze", std::string tag=kInjective)
Remove size 1 dimensions from the shape of a tensor. The removed dimensions must have a constant size...
Definition: transform.h:413
ffi::Array< Tensor > split_n_sections(const Tensor &x, int num_sections, int axis, std::string name="T_split_sections", std::string tag=kInjective)
Split a tensor into a number of sub-tensors.
Definition: transform.h:986
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type,...
Definition: elemwise.h:281
Tensor expand_dims(const Tensor &x, int axis, int num_newaxis=1, std::string name="T_expand_dims", std::string tag=kBroadcast)
Creates an operation to insert new dimensions of length 1.
Definition: transform.h:155
Tensor sparse_to_dense(const Tensor &sparse_indices, const ffi::Array< PrimExpr > &output_shape, const Tensor &sparse_values, const PrimExpr &default_value, const std::string name="T_sparse_to_dense", const std::string tag=kInjective)
Get a dense tensor.
Definition: transform.h:2064
Tensor unravel_index(const Tensor &x, const Tensor &shape, std::string name="T_unravel", std::string tag=kInjective)
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Definition: transform.h:365
Tensor layout_transform(const Tensor &src, const std::string &src_layout, const std::string &dst_layout, const std::string schedule_rule="None", const std::string name="T_layout_trans", const std::string tag=kInjective)
Transform the layout according to src_layout and dst_layout.
Definition: transform.h:1789
Tensor adv_index(const Tensor &data, const ffi::Array< Tensor > &indices, const std::string name="advanced_index", const std::string tag=kInjective)
Numpy style advanced indexing with tensor.
Definition: transform.h:2172
Tensor strided_slice(const Tensor &x, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:944
Tensor concatenate(const ffi::Array< Tensor > &inputs, int axis=0, std::string name="T_concat", std::string tag=kInjective)
Join a sequence of tensors along an existing axis.
Definition: transform.h:479
ffi::Array< Tensor > meshgrid(const ffi::Array< Tensor > &inputs, const std::string &indexing, std::string name="T_meshgrid", std::string tag=kInjective)
Produce grids by expanding input over dimensions defined by other inputs.
Definition: transform.h:1753
constexpr auto kMatMul
Definition: tags.h:37
Tensor strided_slice_with_axes(const Tensor &x, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, const ffi::Array< Integer > &axes, std::string slice_mode="end", std::string name="T_strided_slice_with_axes", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:896
Tensor dyn_tile(const Tensor &x, ffi::Array< PrimExpr > new_shape, size_t rdim, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1453
Tensor reverse_sequence(const Tensor &x, const Tensor &seq_lengths, int seq_axis=1, int batch_axis=0, std::string name="T_reverse_sequence", std::string tag=kInjective)
Reverse the tensor for variable length slices. Input is first sliced along batch axis and then elemen...
Definition: transform.h:263
ffi::Array< Tensor > split_indices_array(const Tensor &x, ffi::Array< PrimExpr > split_indices, int axis, std::string name="T_split", std::string tag=kInjective)
Split a tensor into multiple sub-tensors.
Definition: transform.h:584
Tensor tensordot(const Tensor &A, const tvm::te::Tensor &B, int axes=2, std::string name="T_tensordot", std::string tag=kMatMul)
A generalization of matrix multiplication to tensors.
Definition: transform.h:1627
Tensor sum(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:328
Tensor meta_schedule_layout_transform(const Tensor &src, const tir::IndexMap &index_map, const ffi::String name="T_meta_schedule_layout_trans", const ffi::String tag=kInjective)
Transform the meta-schedule generated layout according to TIR's IndexMap.
Definition: transform.h:1933
Tensor tile(const Tensor &x, ffi::Array< Integer > reps, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1397
Tensor take(const Tensor &a, const Tensor &indices, int batch_dims, std::string mode="fast", std::string name="T_take", std::string tag=kInjective)
Take elements from a flattened input array when axis is None.
Definition: transform.h:1022
PrimExpr DynamicCanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride)
Definition: transform.h:649
tvm::te::Tensor matmul(const tvm::te::Tensor &A, const tvm::te::Tensor &B, bool trans_a=false, bool trans_b=false, std::string name="T_matmul", std::string tag=kMatMul)
Creates an operation that calculates a matrix multiplication (row-major notation): A(i,...
Definition: transform.h:1605
Tensor dynamic_strided_slice(const Tensor &x, const ffi::Array< PrimExpr > &begin, const ffi::Array< PrimExpr > &end, const ffi::Array< PrimExpr > &strides, bool assume_inbound=true, std::string name="T_dynamic_strided_slice", std::string tag=kInjective)
strided_slice of a tensor where begin/end/stride can be mixed static and dynamic
Definition: transform.h:770
Tensor matrix_set_diag(const Tensor &input, const Tensor &diagonal, int k1, int k2, bool super_diag_right_align, bool sub_diag_right_align, const std::string name="T_matrix_set_diag", const std::string tag=kInjective)
Returns a tensor with the diagonal of input tensor replaced with the provided diagonals.
Definition: transform.h:2116
Tensor where(const Tensor &condition, const Tensor &x, const Tensor &y, std::string name="T_where", std::string tag=kBroadcast)
Return the elements, either from x or y, depending on the condition.
Definition: transform.h:1310
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1961
Tensor gather(const Tensor &data, int axis, const Tensor &indices, std::string name="T_gather", std::string tag=kInjective)
Gather values along given axis from given indices.
Definition: transform.h:1491
Tensor sliding_window(const Tensor &x, int axis, ffi::Array< Integer > window_shape, ffi::Array< Integer > strides, std::string name="T_sliding_window", std::string tag="")
Creates an operation to slide a window over the input x.
Definition: transform.h:76
te::Tensor tensor_size(const te::Tensor &src, const DataType &dtype, const std::string &name="tensor_size", const std::string &tag=kInjective)
Get the size of input tensor.
Definition: transform.h:1986
Tensor repeat(const Tensor &x, int repeats, int axis, std::string name="T_repeat", std::string tag=kBroadcast)
Creates an operation to repeat elements of an array.
Definition: transform.h:1350
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span=Span())
compute ceil(a / b)
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr max_value(const DataType &dtype, Span span=Span())
PrimExpr ceil(PrimExpr x, Span span=Span())
Calculate ceil(x)
PrimExpr sum(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of the floor division a / b, where a and b are non-negative.
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
Operation node can generate one or multiple Tensors.
Index ravel and unraval operations.
Utility functions for strided_slice op.
External function interface to rocBLAS libraries.
Utility functions for handling tensor.
TIR expressions.
Common operators defined for Expr.
Variables in the TIR.