tvm
transform.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_TRANSFORM_H_
25 #define TVM_TOPI_TRANSFORM_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/s_tir/data_layout.h>
29 #include <tvm/te/operation.h>
30 #include <tvm/tirx/index_map.h>
31 #include <tvm/topi/broadcast.h>
37 #include <tvm/topi/tags.h>
38 
39 #include <algorithm>
40 #include <iterator>
41 #include <limits>
42 #include <string>
43 #include <unordered_set>
44 #include <utility>
45 #include <vector>
46 
47 #include "tvm/ir/expr.h"
48 #include "tvm/runtime/data_type.h"
49 #include "tvm/tirx/expr.h"
50 #include "tvm/tirx/op.h"
51 #include "tvm/tirx/var.h"
52 
53 namespace tvm {
54 namespace topi {
55 
56 using namespace tvm::te;
57 using namespace topi::detail;
58 
/*!
 * \brief Creates an operation that slides a multi-dimensional window over x.
 *
 * \param x The source tensor.
 * \param axis The first dimension over which the window slides; the window
 *        spans dimensions [axis, x.ndim), so `window_shape` and `strides`
 *        must both have length `x.ndim - axis` (checked below).
 * \param window_shape The extent of the window in each windowed dimension.
 * \param strides The slide step of the window in each windowed dimension,
 *        parallel to `window_shape`.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A tensor of rank `axis + 2 * window_shape.size()`:
 *         the leading `axis` dims of x, then one "number of windows" dim per
 *         windowed axis, then the window dims themselves.
 */
inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array<Integer> window_shape,
                             ffi::Array<Integer> strides, std::string name = "T_sliding_window",
                             std::string tag = "") {
  TVM_FFI_ICHECK_GE(axis, 0);
  auto _axis = size_t(axis);
  TVM_FFI_ICHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x.";
  TVM_FFI_ICHECK_EQ(x->shape.size() - _axis, window_shape.size())
      << "There must be a window shape for every dimension of x "
      << "over which we are sliding the window.";
  TVM_FFI_ICHECK_EQ(strides.size(), window_shape.size())
      << "Windows and strides should be the same length.";

  // Compute the new shape.
  ffi::Array<PrimExpr> new_shape;
  // Dimensions up until `axis` remain the same.
  for (size_t i = 0; i < _axis; ++i) {
    new_shape.push_back(x->shape[i]);
  }

  // New dimensions which result from sliding the window in each dimension. One new dimension per
  // window dimension.
  for (size_t i = 0; i < window_shape.size(); ++i) {
    // Length of the shape along this dimension.
    auto dim_len = x->shape[_axis + i];
    // Length of the window along this dimension.
    auto window_len = window_shape[i];
    // Strides along this dimension.
    auto stride = strides[i];

    // Number of window positions: ceil((dim_len - (window_len - 1)) / stride).
    new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride));
  }

  // Dimensions comprising the window.
  for (size_t i = 0; i < window_shape.size(); ++i) {
    new_shape.push_back(window_shape[i]);
  }

  TVM_FFI_ICHECK(new_shape.size() == _axis + 2 * window_shape.size());

  return compute(
      new_shape,
      [&](const ffi::Array<Var>& indices) {
        // The index at which to index the old tensor x.
        ffi::Array<PrimExpr> idx;

        // Dimensions up until `axis` remain the same.
        for (size_t i = 0; i < _axis; ++i) {
          idx.push_back(indices[i]);
        }

        for (size_t i = 0; i < window_shape.size(); ++i) {
          // Which window in this dimension we are indexing.
          auto window_idx = indices[_axis + i];
          // Which index within the window we are indexing.
          auto idx_within_window = indices[_axis + window_shape.size() + i];
          // Stride value for this dimension.
          auto stride = strides[i];

          // Source position = window start (window_idx * stride) + offset inside window.
          idx.push_back(window_idx * stride + idx_within_window);
        }

        TVM_FFI_ICHECK(idx.size() == x->shape.size());

        return x(idx);
      },
      name, tag);
}
143 
156 inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1,
157  std::string name = "T_expand_dims", std::string tag = kBroadcast) {
158  int ndim = static_cast<int>(x->shape.size());
159  TVM_FFI_ICHECK(-ndim - 1 <= axis && axis <= ndim)
160  << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
161  << ", but got axis = " << axis << ", and data.ndim = " << ndim;
162  TVM_FFI_ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
163  << ", but got num_newaxis = " << num_newaxis;
164  if (axis < 0) {
165  // Calculate offset from last dimension
166  axis = ndim + axis + 1;
167  }
168  ffi::Array<PrimExpr> new_shape;
169  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
170  new_shape.push_back(x->shape[i]);
171  }
172  for (size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) {
173  new_shape.push_back(1);
174  }
175  for (size_t i = axis; i < x->shape.size(); ++i) {
176  new_shape.push_back(x->shape[i]);
177  }
178 
179  return compute(
180  new_shape,
181  [&](const ffi::Array<Var>& indices) {
182  ffi::Array<PrimExpr> idx;
183  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
184  idx.push_back(indices[i]);
185  }
186  for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
187  idx.push_back(indices[i]);
188  }
189  return x(idx);
190  },
191  name, tag);
192 }
193 
/*!
 * \brief Permute the dimensions of a tensor.
 *
 * \param x The input tensor.
 * \param opt_axes The target order of the axes. When absent or empty, the
 *        axes are fully reversed (the classic transpose). Negative entries
 *        count from the last dimension.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A tensor whose i-th output dimension is x's axes[i]-th dimension.
 */
inline Tensor transpose(const Tensor& x, ffi::Optional<ffi::Array<Integer>> opt_axes,
                        std::string name = "T_transpose", std::string tag = kInjective) {
  ffi::Array<Integer> axes = opt_axes.value_or({});
  if (axes.size() == 0) {
    // Default: reverse all axes.
    for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
      axes.push_back(i);
    }
  }

  ffi::Array<PrimExpr> new_shape;
  for (size_t i = 0; i < axes.size(); ++i) {
    int axis = static_cast<int>(axes[i]->value);
    int new_axis = axis;
    if (axis < 0) {
      // Normalize the negative axis and store it back so the compute lambda
      // below only ever sees non-negative axes.
      new_axis = static_cast<int>(x->shape.size()) + axis;
      axes.Set(i, new_axis);
    }
    TVM_FFI_ICHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
        << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size())
        << "-dimensional input tensor";

    // Reject duplicate axes (each input dim must be used exactly once).
    for (size_t j = 0; j < axes.size(); ++j) {
      if (i != j) {
        TVM_FFI_ICHECK(new_axis != static_cast<int>(axes[j]->value))
            << "repeated axis in transpose";
      }
    }
    new_shape.push_back(x->shape[new_axis]);
  }

  return compute(
      new_shape,
      [&](const ffi::Array<Var>& indices) {
        // Scatter the output indices back into source order. The vector is
        // first filled with placeholder 1s, then every slot is overwritten.
        std::vector<PrimExpr> idx;
        for (size_t i = 0; i < axes.size(); ++i) {
          idx.push_back(1);
        }
        for (size_t i = 0; i < axes.size(); ++i) {
          int axis = static_cast<int>(axes[i]->value);
          idx[axis] = indices[i];
        }
        return x(idx);
      },
      name, tag);
}
250 
/*!
 * \brief Reverse the tensor along `seq_axis`, per batch entry, up to the
 *        per-entry length given in seq_lengths.
 *
 * \param x The input tensor.
 * \param seq_lengths 1-D tensor of valid sequence lengths, one per element of
 *        the batch axis. When undefined, the entire seq_axis is reversed for
 *        every batch entry.
 * \param seq_axis The axis along which elements are reversed.
 * \param batch_axis The axis used to index into seq_lengths. Both axes may be
 *        negative, counting from the end.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A tensor with the same shape as x.
 */
inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
                               int batch_axis = 0, std::string name = "T_reverse_sequence",
                               std::string tag = kInjective) {
  size_t src_tensor_dim = x->shape.size();
  // Keep the caller-supplied value for error messages after normalization.
  int seq_axis_inp = seq_axis;

  if (seq_lengths.defined()) {
    size_t seq_lengths_dim = seq_lengths->shape.size();
    int batch_axis_inp = batch_axis;
    if (batch_axis < 0) {
      batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
    }

    TVM_FFI_ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector";

    TVM_FFI_ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
        << "For reverse_sequnece seq_lengths size should match with dimension of batch axis"
        << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
        << ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]);

    TVM_FFI_ICHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
        << "batch_axis=" << batch_axis_inp << " is invalid for the "
        << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
  }

  if (seq_axis < 0) {
    seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
  }
  TVM_FFI_ICHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
      << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
      << "-dimensional input tensor";

  auto func = [&](const ffi::Array<Var>& indices) {
    ffi::Array<PrimExpr> real_indices;
    for (size_t i = 0; i < src_tensor_dim; ++i) {
      if (i == static_cast<size_t>(seq_axis)) {
        if (seq_lengths.defined()) {
          auto len = seq_lengths(indices[batch_axis]);
          // Positions at/after `len` (and degenerate lengths <= 1) are left
          // untouched; a `len` larger than the axis reverses the whole axis;
          // otherwise only the first `len` elements are mirrored.
          auto idx = if_then_else(
              len <= 1 || len <= indices[i], indices[i],
              if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
          real_indices.push_back(idx);
        } else {
          // No lengths given: mirror the full axis.
          real_indices.push_back(x->shape[i] - 1 - indices[i]);
        }
      } else {
        real_indices.push_back(indices[i]);
      }
    }
    return x(real_indices);
  };

  return compute(x->shape, func, name, tag);
}
319 
330 inline Tensor reshape(const Tensor& x, ffi::Array<PrimExpr> newshape,
331  std::string name = "T_reshape", std::string tag = kInjective) {
332  auto x_shape = x->shape;
333  ffi::Array<PrimExpr> target_shape;
334 
335  for (const auto& ele : newshape) {
336  target_shape.push_back(ele);
337  }
338 
339  // If either the input shape or the target shape contains a zero, return an empty tensor.
340  if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) {
341  return compute(
342  target_shape, [&](const ffi::Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name,
343  tag);
344  } else {
345  return compute(
346  target_shape,
347  [&](const ffi::Array<Var>& indices) {
348  return x(UnravelIndex(
349  RavelIndex(ffi::Array<PrimExpr>{indices.begin(), indices.end()}, target_shape),
350  x_shape));
351  },
352  name, tag);
353  }
354 }
355 
/*!
 * \brief Convert flat indices into per-dimension coordinates for a shape.
 *
 * \param x Tensor of flat indices; either a scalar or a 1-D tensor.
 * \param shape 1-D tensor holding the target shape to unravel against.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A tensor of shape (shape[0],) for scalar x, or (shape[0], x[0])
 *         otherwise: one coordinate row per dimension of `shape`.
 */
inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
                            std::string tag = kInjective) {
  auto x_shape = x->shape;
  auto shape_shape = shape->shape;

  ffi::Array<PrimExpr> oshape;
  oshape.push_back(shape_shape[0]);
  if (x_shape.size() != 0) {
    oshape.push_back(x_shape[0]);
  }

  auto func = [&](const ffi::Array<Var>& indices) {
    // indices[0] selects which coordinate (dimension) is being produced.
    auto i = indices[0];
    std::vector<PrimExpr> indices_divs;  // successive quotients of the flat index
    PrimExpr ret = 0;
    PrimExpr cur_val = 0;
    PrimExpr index_val = 0;

    if (x_shape.size() != 0) {
      index_val = x[indices[1]];
    } else {
      // Scalar input: call the 0-d tensor directly.
      index_val = x();
    }
    indices_divs.push_back(index_val);
    // Walk dimensions from innermost to outermost; at each step the remainder
    // is that dimension's coordinate and the quotient carries to the next.
    // The if_then_else keeps only the coordinate for the queried dimension i.
    for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
      ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
      cur_val = indexdiv(indices_divs.back(), shape[v]);
      indices_divs.push_back(cur_val);
    }
    return ret;
  };

  return compute(oshape, func, name, tag);
}
401 
/*!
 * \brief Remove length-1 dimensions from the shape of a tensor.
 *
 * \param x The input tensor.
 * \param opt_axes The axes to squeeze. When absent, every dimension that is a
 *        statically-known 1 is removed. Negative axes count from the end.
 * \param atleast1d Whether the output must keep at least one dimension
 *        (a single length-1 dim is emitted if everything was squeezed).
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A tensor with the selected unit dimensions removed.
 */
inline Tensor squeeze(const Tensor& x, ffi::Optional<ffi::Array<Integer>> opt_axes,
                      bool atleast1d = false, std::string name = "T_squeeze",
                      std::string tag = kInjective) {
  auto ndim = x->shape.size();
  std::vector<int> axis_val;
  if (!opt_axes.has_value()) {
    // No axes given: squeeze every dimension that is a constant 1.
    for (size_t i = 0; i < ndim; ++i) {
      if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
        axis_val.push_back(static_cast<int>(i));
      }
    }
  } else {
    ffi::Array<Integer> axis = *std::move(opt_axes);
    for (size_t i = 0; i < axis.size(); ++i) {
      int64_t val = axis[i]->value;
      if (val < 0) {
        // Negative axes count from the last dimension.
        val += static_cast<int>(x->shape.size());
      }
      // If a dimension is not 1, silently skip it (no-op).
      // Dimensions whose extent is not a compile-time constant are squeezed
      // as requested.
      bool is_const = IsConstInt(x->shape[val]);
      if ((is_const && GetConstInt(x->shape[val]) == 1) || !is_const) {
        axis_val.push_back(val);
      }
    }
  }

  std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());

  // Keep only the dimensions not being squeezed.
  ffi::Array<PrimExpr> out_shape;
  for (size_t i = 0; i < ndim; ++i) {
    if (axis_set.count(static_cast<int>(i)) == 0) {
      out_shape.push_back(x->shape[i]);
    }
  }
  if (out_shape.size() == 0 && atleast1d) {
    out_shape.push_back(1);
  }

  return compute(
      out_shape,
      [&](const ffi::Array<Var>& indices) {
        ffi::Array<PrimExpr> real_indices;
        // `flag` counts squeezed dims seen so far, so indices[i - flag] maps
        // each surviving input dim back to its output coordinate; squeezed
        // dims are indexed at 0.
        int flag = 0;
        for (size_t i = 0; i < ndim; ++i) {
          if (axis_set.count(static_cast<int>(i)) == 0) {
            real_indices.push_back(indices[i - flag]);
          } else {
            real_indices.push_back(0);
            flag += 1;
          }
        }
        return x(real_indices);
      },
      name, tag);
}
470 
481 inline Tensor concatenate(const ffi::Array<Tensor>& inputs, int axis = 0,
482  std::string name = "T_concat", std::string tag = kInjective) {
483  int ndim = static_cast<int>(inputs[0]->shape.size());
484  TVM_FFI_ICHECK(-ndim <= axis && axis < ndim)
485  << "concatenate only accepts `axis` in [-ndim, ndim)"
486  << ", but got axis = " << axis << ", and ndim = " << ndim;
487  if (axis < 0) {
488  axis += ndim;
489  }
490  TVM_FFI_ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds";
491 
492  ffi::Array<PrimExpr> axis_sizes;
493  for (auto t : inputs) {
494  axis_sizes.push_back(t->shape[axis]);
495  }
496  arith::Analyzer analyzer;
497  PrimExpr join_size = axis_sizes[0];
498  for (size_t i = 1; i < axis_sizes.size(); ++i) {
499  join_size += axis_sizes[i];
500  }
501  join_size = analyzer.Simplify(join_size);
502  ffi::Array<PrimExpr> out_shape;
503  for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
504  out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
505  }
506 
507  return compute(
508  out_shape,
509  [&](const ffi::Array<Var>& indices) {
510  auto ret = inputs[0](indices);
511  auto ind = indices[axis];
512  for (size_t i = 0; i < inputs.size() - 1; ++i) {
513  ind -= axis_sizes[i];
514 
515  ffi::Array<PrimExpr> idx;
516  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
517  idx.push_back(indices[i]);
518  }
519  idx.push_back(ind);
520  for (size_t i = axis + 1; i < indices.size(); ++i) {
521  idx.push_back(indices[i]);
522  }
523 
524  ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
525  }
526  return ret;
527  },
528  name, tag);
529 }
530 
541 inline Tensor stack(const ffi::Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack",
542  std::string tag = kInjective) {
543  int ndim = static_cast<int>(inputs[0]->shape.size());
544  TVM_FFI_ICHECK(-ndim - 1 <= axis && axis <= ndim)
545  << "stack only accepts `axis` in [-ndim, ndim)"
546  << ", but got axis = " << axis << ", and ndim = " << ndim;
547  if (axis < 0) {
548  axis += ndim + 1;
549  }
550  TVM_FFI_ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";
551 
552  const int stack_size = static_cast<int>(inputs.size());
553  ffi::Array<PrimExpr> out_shape;
554  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
555  out_shape.push_back(stack_size);
556  for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
557  out_shape.push_back(inputs[0]->shape[i]);
558 
559  return compute(
560  out_shape,
561  [&](const ffi::Array<Var>& indices) {
562  ffi::Array<PrimExpr> idx;
563  for (size_t i = 0; i < indices.size(); ++i)
564  if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]);
565  auto ind = indices[axis];
566  auto ret = inputs[0](idx);
567  for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
568  ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
569  }
570  return ret;
571  },
572  name, tag);
573 }
574 
587 inline ffi::Array<Tensor> split_indices_array(const Tensor& x, ffi::Array<PrimExpr> split_indices,
588  int axis, std::string name = "T_split",
589  std::string tag = kInjective) {
590  if (axis < 0) {
591  axis += static_cast<int>(x->shape.size());
592  }
593  TVM_FFI_ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
594 
595  auto src_axis_size = x->shape[axis];
596  std::vector<PrimExpr> begin_ids;
597  begin_ids.push_back(0);
598 
599  for (auto idx : split_indices) {
600  auto idx_node = idx.as<IntImmNode>();
601  auto back_node = begin_ids.back().as<IntImmNode>();
602  if (idx_node && back_node) {
603  TVM_FFI_ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
604  }
605  begin_ids.push_back(idx);
606  }
607 
608  ffi::Array<ffi::Array<PrimExpr>> out_shapes;
609  for (size_t i = 0; i < begin_ids.size(); ++i) {
610  PrimExpr out_axis_size;
611  if (i == begin_ids.size() - 1) {
612  out_axis_size = src_axis_size - begin_ids[i];
613  } else {
614  out_axis_size = begin_ids[i + 1] - begin_ids[i];
615  }
616 
617  ffi::Array<PrimExpr> shape;
618  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
619  shape.push_back(x->shape[i]);
620  }
621  shape.push_back(out_axis_size);
622  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
623  shape.push_back(x->shape[i]);
624  }
625 
626  out_shapes.push_back(shape);
627  }
628 
629  ffi::Array<Tensor> result;
630  for (size_t i = 0; i < begin_ids.size(); ++i) {
631  result.push_back(compute(
632  out_shapes[i],
633  [&](const ffi::Array<Var>& indices) {
634  auto begin = begin_ids[i];
635  ffi::Array<PrimExpr> real_indices;
636  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
637  real_indices.push_back(indices[j]);
638  }
639  real_indices.push_back(indices[axis] + begin);
640  for (size_t j = axis + 1; j < indices.size(); ++j) {
641  real_indices.push_back(indices[j]);
642  }
643 
644  return x(real_indices);
645  },
646  name, tag));
647  }
648 
649  return result;
650 }
651 
653  auto idx_var = index.as<tvm::tirx::VarNode>();
654  auto extent_var = extent.as<tvm::tirx::VarNode>();
655 
656  if (idx_var && extent_var && idx_var->name_hint == extent_var->name_hint) {
657  return index;
658  }
659 
660  PrimExpr begin_range = tvm::if_then_else(stride < 0, -1, 0);
661  PrimExpr end_range = tvm::if_then_else(stride < 0, extent - 1, extent);
662 
663  if (!(index->IsInstance<tvm::IntImmNode>() && GetConstInt(index) >= 0)) {
664  index = tvm::if_then_else(index < 0, index + extent, index);
665  }
666 
667  return tvm::min(tvm::max(index, begin_range), end_range);
668 }
669 
/*!
 * \brief Clamp a (possibly negative) slice index into the valid range for an
 *        axis of the given extent, entirely at compile time.
 *
 * \param index The index; negative values count back from the end.
 * \param extent The extent of the axis being sliced.
 * \param stride The slice stride; a negative stride permits stopping one
 *        element before the start (-1), a positive one permits stopping one
 *        past the end (extent).
 *
 * \return The clamped index.
 */
inline int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) {
  // Resolve a negative index relative to the end of the axis first.
  if (index < 0) {
    index += extent;
  }
  const bool reverse = stride < 0;
  const int64_t lo = reverse ? -1 : 0;
  const int64_t hi = reverse ? extent - 1 : extent;
  return std::min(std::max(index, lo), hi);
}
678 
/*!
 * \brief Canonicalize a slice index, folding to an Int(64) constant when all
 *        operands are compile-time constants and deferring to
 *        DynamicCanonicalizeIndex otherwise.
 *
 * \param index The index expression; may be negative.
 * \param extent The extent of the axis being sliced.
 * \param stride The slice stride; its sign selects the clamping range.
 *
 * \return The canonicalized index expression.
 */
inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) {
  // Fully constant operands: do the arithmetic now instead of emitting IR.
  if (index->IsInstance<tvm::IntImmNode>() && extent->IsInstance<tvm::IntImmNode>() &&
      stride->IsInstance<tvm::IntImmNode>()) {
    return tvm::IntImm(
        tvm::DataType::Int(64),
        StaticCanonicalizeIndex(GetConstInt(index), GetConstInt(extent), GetConstInt(stride)));
  }
  return DynamicCanonicalizeIndex(index, extent, stride);
}
688 
689 inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent,
690  bool assume_inbound = true) {
691  if (assume_inbound) {
692  return ceildiv(end - begin, stride);
693  } else {
694  begin = CanonicalizeIndex(begin, extent, stride);
695  end = CanonicalizeIndex(end, extent, stride);
696  return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride),
697  ceildiv(end - begin, stride));
698  }
699 }
700 
717  const te::Tensor& x, const ffi::Array<PrimExpr>& begin, const ffi::Array<PrimExpr>& end,
718  const ffi::Array<PrimExpr>& strides, const ffi::Array<Integer>& axes,
719  bool assume_inbound = true, std::string name = "T_dynamic_strided_slice_with_axes",
720  std::string tag = kInjective) {
721  const size_t src_tensor_dim = x->shape.size();
722  TVM_FFI_ICHECK_EQ(begin.size(), end.size());
723  TVM_FFI_ICHECK_EQ(begin.size(), strides.size());
724  TVM_FFI_ICHECK_EQ(begin.size(), axes.size());
725  TVM_FFI_ICHECK_LE(begin.size(), src_tensor_dim);
726 
727  for (const auto& axis_imm : axes) {
728  int axis = axis_imm->value;
729  TVM_FFI_ICHECK_LT(axis, src_tensor_dim);
730  }
731 
732  arith::Analyzer analyzer;
733 
734  ffi::Array<PrimExpr> out_shape = x->shape;
735  for (size_t i = 0; i < begin.size(); i++) {
736  int axis = axes[i]->value;
737  PrimExpr new_shape =
738  analyzer.Simplify(GetLength(begin[i], end[i], strides[i], out_shape[axis], assume_inbound));
739  out_shape.Set(axis, new_shape);
740  }
741 
742  return te::compute(
743  out_shape,
744  [&](const ffi::Array<tvm::tirx::Var>& indices) {
745  ffi::Array<PrimExpr> real_indices =
746  indices.Map([](const auto& var) -> PrimExpr { return var; });
747 
748  for (size_t i = 0; i < begin.size(); i++) {
749  int axis = axes[i]->value;
750  PrimExpr new_index = indices[axis] * strides[i] + begin[i];
751  real_indices.Set(axis, new_index);
752  }
753 
754  return x(real_indices);
755  },
756  name, tag);
757 }
758 
/*!
 * \brief Strided slice of a tensor where begin/end/strides are arbitrary
 *        expressions. Only the first begin.size() axes are sliced; the
 *        remaining axes are passed through unchanged.
 *
 * \param x The input tensor to be sliced.
 * \param begin Per-axis start indices.
 * \param end Per-axis end indices (exclusive); same length as begin.
 * \param strides Per-axis step sizes; same length as begin.
 * \param assume_inbound Whether to assume the bounds are valid, skipping the
 *        clamping done inside GetLength.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the dynamic_strided_slice operation.
 */
inline Tensor dynamic_strided_slice(const Tensor& x, const ffi::Array<PrimExpr>& begin,
                                    const ffi::Array<PrimExpr>& end,
                                    const ffi::Array<PrimExpr>& strides, bool assume_inbound = true,
                                    std::string name = "T_dynamic_strided_slice",
                                    std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  TVM_FFI_ICHECK_LE(begin.size(), src_tensor_dim);
  TVM_FFI_ICHECK_LE(end.size(), src_tensor_dim);
  TVM_FFI_ICHECK_LE(strides.size(), src_tensor_dim);
  TVM_FFI_ICHECK_EQ(begin.size(), end.size());
  TVM_FFI_ICHECK_EQ(begin.size(), strides.size());

  const size_t num_slice_axes = begin.size();
  ffi::Array<PrimExpr> out_shape;

  arith::Analyzer analyzer;
  for (size_t i = 0; i < num_slice_axes; ++i) {
    // Check ProducerLoad to keep backward compatibility for Relax.
    if (!begin[i]->IsInstance<ProducerLoadNode>() && !end[i]->IsInstance<ProducerLoadNode>() &&
        !strides[i]->IsInstance<ProducerLoadNode>()) {
      out_shape.push_back(
          analyzer.Simplify(GetLength(begin[i], end[i], strides[i], x->shape[i], assume_inbound)));
    } else {
      // The extent depends on a tensor value: leave it as an opaque variable.
      out_shape.push_back(tvm::tirx::Var("dim"));
    }
  }

  // Axes beyond the sliced prefix keep their original extents.
  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
    out_shape.push_back(x->shape[i]);
  }

  return te::compute(
      out_shape,
      [&](const ffi::Array<tvm::tirx::Var>& indices) {
        ffi::Array<PrimExpr> real_indices;
        for (size_t i = 0; i < num_slice_axes; ++i) {
          // The start offset is clamped to the last valid index of the axis.
          real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
        }
        // keep input dim
        for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
          real_indices.push_back(indices[i]);
        }
        return x(real_indices);
      },
      name, tag);
}
819 
835  const te::Tensor& end, const te::Tensor& strides,
836  bool assume_inbound = true,
837  std::string name = "T_strided_slice_dynamic",
838  std::string tag = topi::kInjective) {
839  DataType index_dtype = begin->shape[0]->dtype;
840  const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
841  TVM_FFI_ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
842  TVM_FFI_ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
843 
844  ffi::Array<PrimExpr> begin_expr, end_expr, strides_expr;
845  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
846  auto ind = make_const(index_dtype, i);
847  begin_expr.push_back(begin(ind));
848  end_expr.push_back(end(ind));
849  strides_expr.push_back(strides(ind));
850  }
851  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, assume_inbound, name, tag);
852 }
853 
/*!
 * \brief Compute the output shape of a strided slice without materializing
 *        the slice itself.
 *
 * \param ishape The input tensor shape.
 * \param begin The start indices, one per entry of axes.
 * \param end The end indices (exclusive), one per entry of axes.
 * \param strides The step sizes, one per entry of axes.
 * \param axes The axes the slice applies to.
 * \param slice_mode The slice mode string passed through to the helpers.
 *
 * \return The shape of the sliced tensor.
 *
 * NOTE(review): `begin[0]->dtype` is read below, so `begin` is assumed
 * non-empty — confirm callers guarantee this.
 */
inline ffi::Array<PrimExpr> StridedSliceOutputShape(const ffi::Array<PrimExpr>& ishape,
                                                    const ffi::Array<Integer>& begin,
                                                    const ffi::Array<Integer>& end,
                                                    const ffi::Array<Integer>& strides,
                                                    const ffi::Array<Integer>& axes,
                                                    const std::string& slice_mode) {
  TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
                 axes.size() == strides.size());
  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
  // Resolve the effective start positions, then delegate to the full overload.
  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
                                                           begin[0]->dtype, slice_mode);
  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
                                 begin_canonicalized, true);
}
883 
/*!
 * \brief Strided slice applied only to the listed axes; all other axes are
 *        passed through unchanged.
 *
 * \param x The input tensor.
 * \param begin The start indices, one per entry of axes.
 * \param end The end indices (exclusive), one per entry of axes.
 * \param strides The step sizes, one per entry of axes.
 * \param axes The axes to slice; negative values count from the end.
 * \param slice_mode The slice mode string forwarded to the shape helpers.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the strided_slice operation.
 */
inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array<Integer>& begin,
                                      const ffi::Array<Integer>& end,
                                      const ffi::Array<Integer>& strides,
                                      const ffi::Array<Integer>& axes,
                                      std::string slice_mode = "end",
                                      std::string name = "T_strided_slice_with_axes",
                                      std::string tag = kInjective) {
  const int64_t src_tensor_dim = static_cast<int64_t>(x->shape.size());
  TVM_FFI_ICHECK(static_cast<int64_t>(axes.size()) <= src_tensor_dim);
  TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
                 axes.size() == strides.size());

  // Normalize negative axes
  ffi::Array<Integer> normalized_axes;
  for (size_t i = 0; i < axes.size(); ++i) {
    int64_t axis = axes[i].IntValue();
    if (axis < 0) {
      axis += src_tensor_dim;
    }
    TVM_FFI_ICHECK(axis >= 0 && axis < src_tensor_dim)
        << "Axis " << axes[i].IntValue() << " is out of bounds for tensor with " << src_tensor_dim
        << " dimensions";
    normalized_axes.push_back(Integer(axis));
  }

  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);

  // Resolve effective start indices, then derive the sliced output shape.
  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, normalized_axes,
                                                  begin[0]->dtype, slice_mode);
  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec,
                                           normalized_axes, slice_mode, begin_expr);

  return te::compute(
      out_shape,
      [&](const ffi::Array<tirx::Var>& indices) {
        // Start from the identity mapping, then rewrite the sliced axes.
        ffi::Array<PrimExpr> real_indices;
        for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
        for (size_t i = 0; i < normalized_axes.size(); ++i) {
          auto stride = make_const(strides[i].dtype(), strides_vec[i]);
          PrimExpr ind = indices[normalized_axes[i].IntValue()] * stride + begin_expr[i];
          real_indices.Set(normalized_axes[i].IntValue(), ind);
        }
        return x(real_indices);
      },
      name, tag);
}
947 
/*!
 * \brief Strided slice over every axis of a tensor. Missing begin/end/stride
 *        entries are padded with defaults, then the work is delegated to
 *        strided_slice_with_axes.
 *
 * \param x The input tensor.
 * \param begin The start indices (may be shorter than x.ndim).
 * \param end The end indices, exclusive (may be shorter than x.ndim).
 * \param strides The step sizes (may be shorter than x.ndim).
 * \param slice_mode The slice mode string forwarded to the helpers.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the strided_slice operation.
 */
inline Tensor strided_slice(const Tensor& x, const ffi::Array<Integer>& begin,
                            const ffi::Array<Integer>& end, const ffi::Array<Integer>& strides,
                            std::string slice_mode = "end", std::string name = "T_strided_slice",
                            std::string tag = kInjective) {
  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
  ffi::Array<Integer> axes;
  for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
  ffi::Array<Integer> begin_full(begin);
  ffi::Array<Integer> end_full(end);
  ffi::Array<Integer> strides_full(strides);

  DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64);
  const IntImm one = IntImm(index_dtype, 1);
  const IntImm zero = IntImm(index_dtype, 0);
  const IntImm max_range = Downcast<IntImm>(max_value(index_dtype));

  // Strides must be padded first: the default begin/end below depend on the
  // sign of the (possibly just-padded) stride for each axis.
  for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
    strides_full.push_back(one);
  }
  // Default begin: axis start for forward strides, axis end for reverse.
  for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
    begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
  }
  // Default end: axis end for forward strides, axis start for reverse.
  for (size_t i = end.size(); i < src_tensor_dim; ++i) {
    end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
  }

  return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
                                 tag);
}
991 
1004 inline ffi::Array<Tensor> split_n_sections(const Tensor& x, int num_sections, int axis,
1005  std::string name = "T_split_sections",
1006  std::string tag = kInjective) {
1007  if (axis < 0) {
1008  axis += static_cast<int>(x->shape.size());
1009  }
1010  TVM_FFI_ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
1011 
1012  auto src_axis_size = x->shape[axis];
1013 
1014  TVM_FFI_ICHECK_GT(num_sections, 0) << "Slice count must be > 0";
1015 
1016  ffi::Array<PrimExpr> split_indices;
1017  auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections);
1018  for (int i = 0; i < num_sections; ++i) {
1019  // region at index 0 is added by split()
1020  if (i != 0) {
1021  split_indices.push_back(seg_size * i);
1022  }
1023  }
1024 
1025  return split_indices_array(x, split_indices, axis, name, tag);
1026 }
1027 
/*!
 * \brief Take elements from a flattened view of the input tensor.
 *
 * \param a The source tensor; it is indexed as if flattened to 1-D.
 * \param indices Tensor of flat indices into a.
 * \param batch_dims Accepted for interface parity but not referenced by this
 *        overload (the tensor is always fully flattened).
 * \param mode Out-of-bounds handling: "clip" clamps indices into
 *        [0, a_size - 1]; "fast" does no bounds handling; "nan" substitutes a
 *        NaN for out-of-range positions; any other value wraps indices modulo
 *        a_size.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A tensor with the same shape as indices.
 */
inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
                   std::string mode = "fast", std::string name = "T_take",
                   std::string tag = kInjective) {
  ffi::Array<PrimExpr> a_shape = a->shape;
  ffi::Array<PrimExpr> out_shape = indices->shape;
  // Total number of elements of `a`; the bound for flat indices.
  PrimExpr a_size = 1;
  for (size_t i = 0; i < a_shape.size(); ++i) {
    a_size = a_size * a_shape[i];
  }

  if (mode == "clip") {
    // Clamp each flat index into [0, a_size - 1].
    return compute(
        out_shape,
        [&](const ffi::Array<Var>& out_index) {
          auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
          return a(UnravelIndex(idx, a_shape));
        },
        name, tag);
  } else if (mode == "fast") {
    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                    "Make sure input indices are in bound";
    return compute(
        out_shape,
        [&](const ffi::Array<Var>& out_index) {
          return a(UnravelIndex(indices(out_index), a_shape));
        },
        name, tag);
  } else if (mode == "nan") {
    // NOTE(review): for out-of-range positions this substitutes a NaN
    // FloatImm which is then itself used as the flat index fed to
    // UnravelIndex — confirm this matches the intended "nan" semantics.
    return compute(
        out_shape,
        [&](const ffi::Array<Var>& out_index) {
          auto idx = tvm::if_then_else(
              indices(out_index) < 0 || indices(out_index) >= a_size,
              tvm::FloatImm(a->dtype, std::numeric_limits<float>::quiet_NaN()), indices(out_index));
          return a(UnravelIndex(idx, a_shape));
        },
        name, tag);
  } else {  // mode == "wrap"
    // Double truncmod maps any index, including negatives, into [0, a_size).
    return compute(
        out_shape,
        [&](const ffi::Array<Var>& out_index) {
          auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
          return a(UnravelIndex(idx, a_shape));
        },
        name, tag);
  }
}
1087 
1100 inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value,
1101  int axis, std::string name = "T_sequence_mask",
1102  std::string tag = kInjective) {
1103  TVM_FFI_ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
1104  TVM_FFI_ICHECK_EQ(valid_length->shape.size(), 1)
1105  << "valid_length must have ndim=1, i.e., (batch_size,).";
1106  auto length_dim = data->shape[axis];
1107  auto batch_dim = data->shape[1 - axis];
1108  ffi::Array<PrimExpr> out_shape = data->shape;
1109  Tensor out = compute(
1110  out_shape,
1111  [&](const ffi::Array<Var>& out_index) {
1112  ffi::Array<PrimExpr> len_index;
1113  auto tid = out_index[axis];
1114  auto bid = out_index[1 - axis];
1115  len_index.push_back(bid);
1116  PrimExpr ret =
1117  tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
1118  tvm::tirx::make_const(data->dtype, mask_value), data(out_index));
1119  return ret;
1120  },
1121  name, tag);
1122  return out;
1123 }
1124 
1139 inline Tensor take(const Tensor& a, ffi::Variant<Tensor, PrimExpr> indices, int batch_dims,
1140  int axis, std::string mode = "fast", std::string name = "T_take",
1141  std::string tag = kInjective) {
1142  if (axis < 0) {
1143  axis += static_cast<int>(a->shape.size());
1144  }
1145  TVM_FFI_ICHECK_GE(axis, 0) << "axis out of bounds";
1146  TVM_FFI_ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
1147  auto axis_dim = a->shape[axis];
1148  auto indices_shape = [&]() -> ffi::Array<PrimExpr> {
1149  if (auto tensor = indices.as<TensorNode>()) {
1150  return tensor->shape;
1151  } else {
1152  return {};
1153  }
1154  }();
1155 
1156  int indices_len = static_cast<int>(indices_shape.size());
1157 
1158  int batch_dims_ = batch_dims;
1159  if (batch_dims_ != 0) {
1160  TVM_FFI_ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds";
1161  TVM_FFI_ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds";
1162 
1163  if (batch_dims_ < 0) {
1164  batch_dims_ = indices_len + batch_dims_;
1165  }
1166 
1167  TVM_FFI_ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
1168  TVM_FFI_ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
1169  for (int i = 0; i < batch_dims_; ++i) {
1170  auto addr1 = a->shape[i];
1171  auto addr2 = indices_shape[i];
1172  auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
1173  auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
1174  TVM_FFI_ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i
1175  << "]";
1176  }
1177  }
1178 
1179  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
1180  // a.shape[axis + 1:].
1181 
1182  ffi::Array<PrimExpr> out_shape;
1183  for (int i = 0; i < batch_dims_; ++i) {
1184  out_shape.push_back(a->shape[i]);
1185  }
1186  for (int i = batch_dims_; i < axis; ++i) {
1187  out_shape.push_back(a->shape[i]);
1188  }
1189  for (int i = batch_dims_; i < indices_len; ++i) {
1190  out_shape.push_back(indices_shape[i]);
1191  }
1192  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
1193  out_shape.push_back(a->shape[i]);
1194  }
1195 
1196  auto get_index = [&](const ffi::Array<PrimExpr>& indices_position) -> PrimExpr {
1197  if (auto tensor = indices.as<Tensor>()) {
1198  return tensor.value()(indices_position);
1199  } else if (auto prim = indices.as<PrimExpr>()) {
1200  TVM_FFI_ICHECK_EQ(indices_position.size(), 0);
1201  return prim.value();
1202  } else {
1203  TVM_FFI_THROW(InternalError) << "Variant did not contain either allowed type";
1204  }
1205  };
1206 
1207  if (mode == "clip") {
1208  if (batch_dims_ == 0) {
1209  return compute(
1210  out_shape,
1211  [&](const ffi::Array<Var>& out_index) {
1212  ffi::Array<PrimExpr> indices_position;
1213  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1214  indices_position.push_back(out_index[j]);
1215  }
1216  ffi::Array<PrimExpr> real_indices;
1217  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1218  real_indices.push_back(out_index[j]);
1219  }
1220  auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
1221  real_indices.push_back(idx);
1222  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1223  real_indices.push_back(out_index[j]);
1224  }
1225  return a(real_indices);
1226  },
1227  name, tag);
1228  } else {
1229  return compute(
1230  out_shape,
1231  [&](const ffi::Array<Var>& out_index) {
1232  ffi::Array<PrimExpr> indices_position;
1233  for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
1234  indices_position.push_back(out_index[j]);
1235  }
1236  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
1237  indices_position.push_back(out_index[j]);
1238  }
1239  ffi::Array<PrimExpr> real_indices;
1240  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1241  real_indices.push_back(out_index[j]);
1242  }
1243  auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
1244  real_indices.push_back(idx);
1245  for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
1246  real_indices.push_back(out_index[j]);
1247  }
1248  return a(real_indices);
1249  },
1250  name, tag);
1251  }
1252  } else if (mode == "fast") {
1253  LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
1254  "Make sure input indices are in bound";
1255  return compute(
1256  out_shape,
1257  [&](const ffi::Array<Var>& out_index) {
1258  ffi::Array<PrimExpr> indices_position;
1259  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1260  indices_position.push_back(out_index[j]);
1261  }
1262  ffi::Array<PrimExpr> real_indices;
1263  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1264  real_indices.push_back(out_index[j]);
1265  }
1266  real_indices.push_back(get_index(indices_position));
1267  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1268  real_indices.push_back(out_index[j]);
1269  }
1270  return a(real_indices);
1271  },
1272  name, tag);
1273  } else if (mode == "nan") {
1274  return compute(
1275  out_shape,
1276  [&](const ffi::Array<Var>& out_index) {
1277  ffi::Array<PrimExpr> indices_position;
1278  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1279  indices_position.push_back(out_index[j]);
1280  }
1281  ffi::Array<PrimExpr> real_indices;
1282  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1283  real_indices.push_back(out_index[j]);
1284  }
1285  PrimExpr idx = get_index(indices_position);
1286  real_indices.push_back(idx);
1287  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1288  real_indices.push_back(out_index[j]);
1289  }
1290  PrimExpr in_bounds = idx >= 0 && idx < axis_dim;
1291  return tvm::if_then_else(
1292  in_bounds, a(real_indices),
1293  tvm::tirx::make_const(a->dtype, std::numeric_limits<float>::quiet_NaN()));
1294  },
1295  name, tag);
1296  } else { // mode == "wrap"
1297  return compute(
1298  out_shape,
1299  [&](const ffi::Array<Var>& out_index) {
1300  ffi::Array<PrimExpr> indices_position;
1301  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1302  indices_position.push_back(out_index[j]);
1303  }
1304  ffi::Array<PrimExpr> real_indices;
1305  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1306  real_indices.push_back(out_index[j]);
1307  }
1308  auto idx = truncmod(truncmod(get_index(indices_position), axis_dim) + axis_dim, axis_dim);
1309  real_indices.push_back(idx);
1310  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1311  real_indices.push_back(out_index[j]);
1312  }
1313  return a(real_indices);
1314  },
1315  name, tag);
1316  }
1317 }
1318 
inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
                    std::string name = "T_where", std::string tag = kBroadcast) {
  TVM_FFI_ICHECK_EQ(x->dtype, y->dtype)
      << "x and y must have the same dtype: " << x->dtype << " vs " << y->dtype;
  // Output shape is the broadcast of all three inputs: first broadcast x
  // against y, then broadcast the condition against that common shape.
  auto get_out_shape = [&]() {
    auto bh1 = detail::BroadcastShape(x->shape, y->shape);
    ffi::Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
    auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
    ffi::Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
    return common_shape2;
  };

  auto oshape = get_out_shape();

  // Per-input broadcast helpers: map an output coordinate back to each
  // input's own (possibly lower-rank) coordinate.
  auto c_bh = detail::BroadcastShape(condition->shape, oshape);
  auto x_bh = detail::BroadcastShape(x->shape, oshape);
  auto y_bh = detail::BroadcastShape(y->shape, oshape);

  auto select = [&](tvm::ffi::Array<tvm::tirx::Var> ovars) {
    // A non-zero condition element selects from x, otherwise from y.
    auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
    auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
    auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
    return tvm::tirx::Select(c != 0, true_val, false_val);
  };

  return compute(oshape, select, name, tag);
}
1357 
inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat",
                     std::string tag = kBroadcast) {
  int ndim = static_cast<int>(x->shape.size());
  TVM_FFI_ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;
  TVM_FFI_ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
                               << ", but got repeats = " << repeats;
  if (axis < 0) {
    // Calculate offset from last dimension
    // NOTE(review): axis == -ndim - 1 would normalize to -1 here and axis ==
    // ndim would index past x->shape below, even though both pass the check
    // above — confirm the intended accepted range.
    axis += ndim;
  }
  // Output shape: same as x, except the repeated axis grows by `repeats`.
  ffi::Array<PrimExpr> new_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
    new_shape.push_back(x->shape[i]);
  }
  new_shape.push_back(repeats * x->shape[axis]);
  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
    new_shape.push_back(x->shape[i]);
  }

  return compute(
      new_shape,
      [&](const ffi::Array<Var>& indices) {
        ffi::Array<PrimExpr> idx;
        for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
          idx.push_back(indices[i]);
        }
        // `repeats` consecutive output positions map back to one input element.
        idx.push_back(indexdiv(indices[axis], repeats));
        for (size_t i = axis + 1; i < indices.size(); ++i) {
          idx.push_back(indices[i]);
        }
        return x(idx);
      },
      name, tag);
}
1406 
inline Tensor tile(const Tensor& x, ffi::Array<Integer> reps, std::string name = "T_tile",
                   std::string tag = kBroadcast) {
  size_t ndim = x->shape.size();
  size_t rdim = reps.size();
  // Output rank is the larger of input rank and reps rank; the shorter of the
  // two is conceptually left-padded with 1s before multiplying extents.
  size_t tdim = (ndim > rdim) ? ndim : rdim;
  ffi::Array<PrimExpr> data_shape;
  ffi::Array<PrimExpr> reps_shape;
  ffi::Array<PrimExpr> new_shape;
  if (ndim == rdim) {
    for (size_t i = 0; i < ndim; ++i) {
      data_shape.push_back(x->shape[i]);
      reps_shape.push_back(reps[i]);
    }
  } else if (ndim > rdim) {
    // Left-pad reps with 1s up to the input rank.
    for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
    for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
    for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
  } else {
    // Left-pad the data shape with 1s up to the reps rank.
    for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
    for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
    for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
  }
  // Each output extent is the padded input extent times its repetition count.
  for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);

  if (is_empty_shape(new_shape)) {
    // Empty output: the body is never read, so any value of x->dtype works.
    return compute(
        new_shape, [&](const ffi::Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name,
        tag);
  } else {
    return compute(
        new_shape,
        [&](const ffi::Array<Var>& indices) {
          ffi::Array<PrimExpr> idx;
          if (ndim >= rdim) {
            // Input dims align one-to-one with output dims.
            for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
          } else {
            // Input dims align with the trailing output dims.
            for (size_t i = 0; i < ndim; ++i)
              idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
          }
          return x(idx);
        },
        name, tag);
  }
}
1461 
1473 inline Tensor dyn_tile(const Tensor& x, ffi::Array<PrimExpr> new_shape, size_t rdim,
1474  std::string name = "T_tile", std::string tag = kBroadcast) {
1475  size_t ndim = x->shape.size();
1476  if (is_empty_shape(new_shape)) {
1477  return compute(
1478  new_shape, [&](const ffi::Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name,
1479  tag);
1480  } else {
1481  return compute(
1482  new_shape,
1483  [&](const ffi::Array<Var>& indices) {
1484  ffi::Array<PrimExpr> idx;
1485  if (ndim >= rdim) {
1486  for (size_t i = 0; i < ndim; ++i) {
1487  idx.push_back(indexmod(indices[i], x->shape[i]));
1488  }
1489  } else {
1490  for (size_t i = 0; i < ndim; ++i) {
1491  idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
1492  }
1493  }
1494  return x(idx);
1495  },
1496  name, tag);
1497  }
1498 }
1499 
inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
                     std::string name = "T_gather", std::string tag = kInjective) {
  size_t ndim_d = data->shape.size();
  size_t ndim_i = indices->shape.size();
  TVM_FFI_ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
  TVM_FFI_ICHECK_EQ(ndim_d, ndim_i);
  if (axis < 0) {
    axis += ndim_d;
  }
  TVM_FFI_ICHECK_GE(axis, 0);
  TVM_FFI_ICHECK_LT(axis, ndim_d);
  // If the indexed extent is statically known, it must be at least 1.
  if (indices->shape[axis].as<IntImmNode>()) {
    size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
    TVM_FFI_ICHECK_GE(indices_dim_i, 1);
  }
  TVM_FFI_ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());

  // Output has exactly the shape of `indices`.
  ffi::Array<PrimExpr> out_shape;
  for (size_t i = 0; i < ndim_i; ++i) {
    out_shape.push_back(indices->shape[i]);
  }

  return compute(
      out_shape,
      [&](const ffi::Array<Var>& out_index) {
        ffi::Array<PrimExpr> indices_position;
        for (size_t i = 0; i < ndim_i; ++i) {
          indices_position.push_back(out_index[i]);
        }
        // Copy the output coordinates, but on `axis` substitute the value
        // looked up from `indices`.
        ffi::Array<PrimExpr> real_indices;
        for (size_t i = 0; i < ndim_i; ++i) {
          if (i == static_cast<size_t>(axis)) {
            real_indices.push_back(indices(indices_position));
          } else {
            real_indices.push_back(indices_position[i]);
          }
        }
        return data(real_indices);
      },
      name, tag);
}
1552 
inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
                        std::string name = "T_gather_nd", std::string tag = kInjective) {
  size_t ndim_d = data->shape.size();
  size_t ndim_i = indices->shape.size();
  TVM_FFI_ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions";
  // indices->shape[0] = number of coordinate components per gathered element;
  // it must be a static constant.
  size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
  TVM_FFI_ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
                                          << "than dimensions of data tensor";
  // Output shape: indices.shape[1:] followed by the data dims not consumed by
  // the coordinate components (after the batch dims).
  ffi::Array<PrimExpr> out_shape;
  for (size_t i = 1; i < ndim_i; ++i) {
    out_shape.push_back(indices->shape[i]);
  }
  for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
    out_shape.push_back(data->shape[i]);
  }
  return compute(
      out_shape,
      [&](const ffi::Array<Var>& out_index) {
        // indices_position[0] is rewritten to each component id in turn; the
        // remaining entries address which gathered element is being produced.
        ffi::Array<PrimExpr> indices_position;
        indices_position.push_back(0);
        for (size_t i = 0; i < ndim_i - 1; ++i) {
          indices_position.push_back(out_index[i]);
        }
        ffi::Array<PrimExpr> real_indices;
        // Batch dims pass straight through from the output coordinates.
        for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
          real_indices.push_back(out_index[i]);
        }
        for (size_t i = 0; i < indices_dim0; ++i) {
          indices_position.Set(0, make_const(DataType::Int(32), i));
          // Non-integer index dtypes are cast to int32 before use.
          if (indices->dtype.is_int() || indices->dtype.is_uint()) {
            real_indices.push_back(indices(indices_position));
          } else {
            real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
          }
        }
        if (real_indices.size() == ndim_d) {
          return data(real_indices);
        }
        // Remaining data dims are taken straight from the output coordinates.
        for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
          real_indices.push_back(out_index[i]);
        }
        return data(real_indices);
      },
      name, tag);
}
1609 
1626  bool trans_a = false, bool trans_b = false,
1627  std::string name = "T_matmul", std::string tag = kMatMul) {
1628  tvm::ffi::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
1629  auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
1630  auto l = [&](tvm::tirx::Var i, tvm::tirx::Var j) {
1631  return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
1632  };
1633  return tvm::te::compute(output_shape, l, name, tag);
1634 }
1635 
inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2,
                        std::string name = "T_tensordot", std::string tag = kMatMul) {
  TVM_FFI_ICHECK_GE(A->shape.size(), axes);
  TVM_FFI_ICHECK_GE(B->shape.size(), axes);

  // Output shape: A's dims minus its trailing `axes` dims, then B's dims
  // minus its leading `axes` dims (those pairs are contracted).
  ffi::Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
  for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);

  // One reduction axis per contracted dimension.
  ffi::Array<IterVar> iter_vars;
  for (int i = 0; i < axes; ++i)
    iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));

  auto func = [&A, &B, &iter_vars, axes](const ffi::Array<Var>& input_indices) {
    // A is indexed by the leading output coordinates plus the reduction vars.
    ffi::Array<PrimExpr> A_indices(input_indices.begin(),
                                   input_indices.begin() + (A->shape.size() - axes));
    for (auto& v : iter_vars) A_indices.push_back(v);

    // B is indexed by the reduction vars plus the trailing output coordinates.
    ffi::Array<PrimExpr> B_indices;
    for (auto& v : iter_vars) B_indices.push_back(v);

    auto it = input_indices.begin() + (A->shape.size() - axes);
    for (; it != input_indices.end(); ++it) B_indices.push_back(*it);

    // Some passes don't like reductions with empty axis, so avoid it here
    if (iter_vars.empty()) {
      return A(A_indices) * B(B_indices);
    } else {
      return sum(A(A_indices) * B(B_indices), iter_vars);
    }
  };

  return compute(output_shape, func, name, tag);
}
1680 
inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, ffi::Array<PrimExpr> A_axes,
                        ffi::Array<PrimExpr> B_axes, std::string name = "T_tensordot",
                        std::string tag = kMatMul) {
  TVM_FFI_ICHECK_EQ(A_axes.size(), B_axes.size());

  // Contraction axes must be compile-time constants.
  auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
  auto B_axes_val = GetConstIntValues(B_axes, "B_axes");

  // Output shape: the non-contracted dims of A followed by those of B.
  ffi::Array<PrimExpr> output_shape;
  for (unsigned i = 0; i < A->shape.size(); ++i)
    if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
      output_shape.push_back(A->shape[i]);
  for (unsigned i = 0; i < B->shape.size(); ++i)
    if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
      output_shape.push_back(B->shape[i]);

  // One reduction variable per contracted axis pair.
  ffi::Array<IterVar> iter_vars;
  for (unsigned i = 0; i < B_axes_val.size(); ++i)
    iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));

  auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const ffi::Array<Var>& input_indices) {
    // Walk A's dims: contracted dims take their reduction var, the rest
    // consume the next output coordinate in order.
    int idx_input = 0;
    ffi::Array<PrimExpr> A_indices;
    for (unsigned i = 0; i < A->shape.size(); ++i) {
      auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
      if (axes_pos == A_axes_val.end()) {
        A_indices.push_back(input_indices[idx_input++]);
      } else {
        A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
      }
    }

    // Same for B, continuing from where A left off in the output coordinates.
    ffi::Array<PrimExpr> B_indices;
    for (unsigned i = 0; i < B->shape.size(); ++i) {
      auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
      if (axes_pos == B_axes_val.end()) {
        B_indices.push_back(input_indices[idx_input++]);
      } else {
        B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
      }
    }
    return sum(A(A_indices) * B(B_indices), iter_vars);
  };
  return compute(output_shape, func, name, tag);
}
1738 
inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
                     DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
  arith::Analyzer analyzer;
  // Number of elements in the half-open range [start, stop) with stride step.
  PrimExpr num_elem;
  bool is_all_int = start.dtype().is_int() && stop.dtype().is_int() && step.dtype().is_int();
  if (is_all_int && analyzer.CanProveGreaterEqual(step, 1)) {
    // fast path for integer arange when step is positive
    num_elem = tvm::floordiv((stop - start + step - 1), step);
  } else if (is_all_int && analyzer.CanProveLess(step, 0)) {
    // fast path for integer arange when step is negative
    num_elem = tvm::floordiv((start - stop - step - 1), -step);
  } else {
    // fallback path for non-integer or step of unknown sign
    // NOTE(review): the float32 cast can lose precision for very large
    // integer ranges — confirm this is acceptable for expected inputs.
    num_elem = tvm::cast(DefaultIndexType(),
                         tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
  }
  num_elem = analyzer.Simplify(num_elem);

  return compute(
      {num_elem},
      // Element i is start + i * step, cast to the requested dtype.
      [&](const ffi::Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); },
      name, tag);
}
1762 
inline ffi::Array<Tensor> meshgrid(const ffi::Array<Tensor>& inputs, const std::string& indexing,
                                   std::string name = "T_meshgrid", std::string tag = kInjective) {
  // Cartesian ("xy") indexing swaps the first two axes relative to matrix
  // ("ij") indexing; with fewer than two inputs there is nothing to swap.
  const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2;
  ffi::Array<PrimExpr> out_shape;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
    // A 0-D input contributes an extent-1 axis to the common output shape.
    out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]);
  }
  ffi::Array<Tensor> result;
  for (size_t i = 0; i < inputs.size(); ++i) {
    result.push_back(compute(
        out_shape,
        [&](const ffi::Array<Var>& indices) {
          const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
          auto ndim = inputs[i]->GetShape().size();
          ffi::Array<PrimExpr> real_indices = {};
          if (ndim > 0) {
            // Each output tensor varies only along its own (possibly swapped)
            // axis; 0-D inputs are broadcast as-is.
            real_indices = {indices[src_index]};
          }
          return inputs[i](real_indices);
        },
        name, tag));
  }
  return result;
}
1798 
inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
                               const std::string& dst_layout,
                               const std::string schedule_rule = "None",
                               const std::string name = "T_layout_trans",
                               const std::string tag = kInjective) {
  Layout src_layout_struct(src_layout);
  Layout dst_layout_struct(dst_layout);

  // Identical layouts: nothing to do, return the input unchanged.
  if (src_layout_struct.Equals(dst_layout_struct)) {
    return src;
  }

  TVM_FFI_ICHECK(src_layout_struct.defined() && dst_layout_struct.defined())
      << "cannot convert from/to undefined layout";

  auto layout_converter = tirx::BijectiveLayout(src_layout_struct, dst_layout_struct);
  TVM_FFI_ICHECK(layout_converter.defined())
      << "cannot convert from " << src_layout << " to " << dst_layout;

  ffi::Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);

  // Attributes consumed by schedule rules that specialize layout transforms.
  ffi::Map<ffi::String, ffi::Any> attrs = {{"schedule_rule", ffi::String(schedule_rule)},
                                           // Information about layouts needed for the schedule rule
                                           {"src_layout", ffi::String(src_layout)},
                                           {"dst_layout", ffi::String(dst_layout)},
                                           {"input_shape", src->shape}};

  return compute(
      dst_shape,
      [&](const ffi::Array<Var>& dst_indices) {
        ffi::Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
        ffi::Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
        PrimExpr in_range = PrimExpr(1) > PrimExpr(0);  // init with dtype=bool and value=true
        // Destination positions that map outside the source extents (padding
        // introduced by the new layout) are filled with zero.
        for (size_t i = 0; i < src.ndim(); ++i) {
          in_range = in_range && (src_indices[i] < src->shape[i]);
        }
        return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
      },
      name, tag, attrs);
}
1849 
1851 inline void parse_auto_scheduler_layout(const ffi::String& layout, ffi::Array<PrimExpr>* shape,
1852  std::vector<std::string>* axes) {
1853  int32_t factor = 0;
1854  std::string axis = "";
1855  for (char c : std::string(layout)) {
1856  if (c >= 'A' && c <= 'z') {
1857  axis += c;
1858  if (factor != 0) {
1859  shape->push_back(factor);
1860  factor = 0;
1861  }
1862  } else if (c >= '0' && c <= '9') {
1863  factor = factor * 10 + c - '0';
1864  if (!axis.empty()) {
1865  axes->push_back(axis);
1866  axis = "";
1867  }
1868  } else {
1869  TVM_FFI_THROW(InternalError) << "Invalid layout " << layout;
1870  }
1871  }
1872  if (!axis.empty()) {
1873  axes->push_back(axis);
1874  }
1875 }
1876 
1888  const Tensor& src, const ffi::String& src_layout, const ffi::String& dst_layout,
1889  const ffi::String name = "T_auto_scheduler_layout_trans", const ffi::String tag = kInjective) {
1890  ffi::Array<PrimExpr> src_shape;
1891  std::vector<std::string> src_axes;
1892  ffi::Array<PrimExpr> dst_shape;
1893  std::vector<std::string> dst_axes;
1894 
1895  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
1896  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
1897  return compute(
1898  dst_shape,
1899  [&](const ffi::Array<Var>& dst_indices) {
1900  ffi::Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
1901  ffi::Array<PrimExpr> src_indices;
1902  for (const std::string& src_axis : src_axes) {
1903  PrimExpr src_index = 0;
1904  TVM_FFI_ICHECK_EQ(dst_indices_expr.size(), dst_axes.size());
1905  for (size_t i = 0; i < dst_axes.size(); ++i) {
1906  if (dst_axes[i] == src_axis) {
1907  src_index = src_index * dst_shape[i] + dst_indices_expr[i];
1908  }
1909  }
1910  src_indices.push_back(src_index);
1911  }
1912  return src(src_indices);
1913  },
1914  name, tag);
1915 }
1916 
1954  const Tensor& src, const tirx::IndexMap& index_map,
1955  const ffi::String name = "T_meta_schedule_layout_trans", const ffi::String tag = kInjective) {
1956  arith::Analyzer analyzer;
1957  ffi::Array<Range> iter_domain;
1958  iter_domain.reserve(src->shape.size());
1959  for (const PrimExpr& e : src->shape) {
1960  iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
1961  }
1962  ffi::Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape, &analyzer);
1963  return compute(
1964  post_transform_shape,
1965  [src, inv = index_map.Inverse(iter_domain, &analyzer),
1966  &analyzer](const ffi::Array<Var>& indices) -> PrimExpr {
1967  return src(
1968  inv->MapIndices(ffi::Array<PrimExpr>{indices.begin(), indices.end()}, &analyzer));
1969  },
1970  name, tag);
1971 }
1972 
1981 inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape",
1982  const std::string tag = kInjective) {
1983  int ndim = static_cast<int>(src->shape.size());
1984  ffi::Array<PrimExpr> out_shape{ndim};
1985  return compute(
1986  out_shape,
1987  [&](const ffi::Array<Var>& indices) {
1988  auto idx = indices[0];
1989  PrimExpr ret = 0;
1990  for (int i = 0; i < ndim; ++i) {
1991  ret = tvm::if_then_else(idx == i, src->shape[i], ret);
1992  }
1993  return tvm::cast(dtype, ret);
1994  },
1995  name, tag);
1996 }
1997 
2006 inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype,
2007  const std::string& name = "tensor_size",
2008  const std::string& tag = kInjective) {
2009  int ndim = static_cast<int>(src->shape.size());
2010  ffi::Array<PrimExpr> out_tensor_size = {};
2011  return compute(
2012  out_tensor_size,
2013  [&](const ffi::Array<Var>& indices) {
2014  PrimExpr ret = 1;
2015  for (int i = 0; i < ndim; ++i) {
2016  ret *= src->shape[i];
2017  }
2018  return tvm::cast(dtype, ret);
2019  },
2020  name, tag);
2021 }
2022 
inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
                      int depth, int axis, const DataType& dtype,
                      ffi::Array<PrimExpr> oshape = ffi::Array<PrimExpr>(),
                      const std::string name = "T_one_hot", const std::string tag = kInjective) {
  // axis == -1 places the new depth axis last.
  int true_axis = (axis == -1) ? indices->shape.size() : axis;
  if (oshape.size() == 0) {
    // Derive the output shape: indices.shape with `depth` inserted at
    // true_axis.
    int ndim = indices->shape.size() + 1;
    int indices_index = 0;
    for (int i = 0; i < ndim; i++) {
      if (i == true_axis) {
        oshape.push_back(Integer(depth));
      } else {
        oshape.push_back(indices->shape[indices_index++]);
      }
    }
  }

  PrimExpr on_value_cast = cast(dtype, on_value);
  PrimExpr off_value_cast = cast(dtype, off_value);
  return compute(
      oshape,
      [&](const ffi::Array<Var>& iter_vars) {
        // Drop the depth axis to recover the coordinates into `indices`.
        ffi::Array<Var> indices_indices;
        for (size_t i = 0; i < iter_vars.size(); i++) {
          if (static_cast<int>(i) == true_axis) {
            continue;
          }

          indices_indices.push_back(iter_vars[i]);
        }

        // The position along the depth axis selects on_value iff it equals
        // the stored index.
        auto idx = iter_vars[true_axis];
        return tirx::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast);
      },
      name, tag);
}
2073 
inline Tensor sparse_to_dense(const Tensor& sparse_indices,
                              const ffi::Array<PrimExpr>& output_shape, const Tensor& sparse_values,
                              const PrimExpr& default_value,
                              const std::string name = "T_sparse_to_dense",
                              const std::string tag = kInjective) {
  TVM_FFI_ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values";
  TVM_FFI_ICHECK_LE(sparse_indices->shape.size(), 3)
      << "sparse_indices tensor should be 0D, 1D, or 2D only";
  TVM_FFI_ICHECK_LE(sparse_values->shape.size(), 2)
      << "sparse_values tensor should be 0D or 1D only";

  const auto rank_sparse_indices = static_cast<int>(sparse_indices->shape.size());
  ffi::Array<PrimExpr> oshape;
  for (auto l : output_shape) {
    oshape.push_back(l);
  }
  return compute(
      oshape,
      [&](const ffi::Array<Var>& indices) {
        // Start from the default and overwrite the positions named by
        // sparse_indices. Later entries win on duplicate positions.
        PrimExpr ret = default_value;
        if (0 == rank_sparse_indices) {
          // Scalar index: one position receives the single sparse value.
          ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret);
        } else if (1 == rank_sparse_indices) {
          // 1-D indices: each entry names a position in a 1-D output.
          for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
            ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
          }
        } else {
          // 2-D indices: each row is a full coordinate into the output; match
          // only when every component agrees.
          for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
            PrimExpr aggregate_condition;
            for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
              PrimExpr comparision = indices[k] == sparse_indices[j][k];
              aggregate_condition = 0 == k ? comparision : aggregate_condition && comparision;
            }
            ret = if_then_else(aggregate_condition, sparse_values[j], ret);
          }
        }
        return ret;
      },
      name, tag);
}
2124 
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
                              bool super_diag_right_align, bool sub_diag_right_align,
                              const std::string name = "T_matrix_set_diag",
                              const std::string tag = kInjective) {
  // `ndim` is the index of the last axis (columns); ndim - 1 is the row axis.
  size_t ndim = input->shape.size() - 1;

  bool only_one_diagonal = k1 == k2;

  return compute(
      input->shape,
      [&](const ffi::Array<Var>& iter_vars) {
        auto get_diag = [&]() {
          ffi::Array<PrimExpr> diagonal_indices;
          PrimExpr k, offset = 0;
          // Leading batch coordinates are shared with the diagonal tensor.
          for (size_t i = 0; i < ndim - 1; i++) {
            diagonal_indices.push_back(iter_vars[i]);
          }
          if (only_one_diagonal) {
            k = k1;
          } else {
            // Determining which diagonal/sub-diagonal/super-diagonal it is
            k = iter_vars[ndim] - iter_vars[ndim - 1];
            diagonal_indices.push_back(k2 - k);

            // Calculating the offset in diagonal tensor for this diagonal
            auto get_offset = [&](PrimExpr M, PrimExpr N) {
              // offset = max_diagonal_length - diagonal_length
              return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
            };
            // Right-aligned diagonals start later in the padded diagonal row;
            // which alignment applies depends on the diagonal's sign.
            offset = if_then_else(
                k >= 0,
                super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
                                       : 0,
                sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
                                     : 0);
          }
          diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
                                     offset);
          return diagonal(diagonal_indices);
        };
        // Positions within the [k1, k2] diagonal band take values from
        // `diagonal`; everything else passes `input` through unchanged.
        return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
                            if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
                                         get_diag(), input(iter_vars)),
                            input(iter_vars));
      },
      name, tag);
}
2184 
inline Tensor adv_index(const Tensor& data, const ffi::Array<Tensor>& indices,
                        const std::string name = "advanced_index",
                        const std::string tag = kInjective) {
  TVM_FFI_ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!";
  ffi::Array<PrimExpr> oshape;
  ffi::Array<PrimExpr> broadcast_shape;
  ffi::Array<Tensor> bindices;

  // Broadcast all index tensors to one common shape.
  broadcast_shape = indices[0]->shape;
  for (size_t i = 1; i < indices.size(); ++i) {
    auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape);
    broadcast_shape = ffi::Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
  }
  if (indices.size() == 1) {
    // quick path
    bindices = indices;
  } else {
    // Do broadcast for indices
    for (size_t i = 0; i < indices.size(); ++i) {
      bindices.push_back(broadcast_to(indices[i], broadcast_shape));
    }
  }

  // Output shape: the broadcast index shape, then the data dims not indexed.
  for (const auto& dim : broadcast_shape) {
    oshape.push_back(dim);
  }
  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
    oshape.push_back(data->shape[i]);
  }

  return compute(
      oshape,
      [&](const ffi::Array<Var>& iter_var) {
        // Coordinates into each (broadcast) index tensor.
        ffi::Array<PrimExpr> tensor_indices;
        for (size_t i = 0; i < broadcast_shape.size(); ++i) {
          tensor_indices.push_back(iter_var[i]);
        }
        // The first data dims come from the index tensors; the remaining
        // trailing dims pass through from the output coordinates.
        ffi::Array<PrimExpr> real_indices;
        for (size_t i = 0; i < bindices.size(); ++i) {
          real_indices.push_back(bindices[i](tensor_indices));
        }
        for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
          real_indices.push_back(iter_var[i]);
        }

        return data(real_indices);
      },
      name, tag);
}
2242 
2243 namespace relax {
2244 // relax dynamic slice
2246  const te::Tensor& end, const te::Tensor& strides,
2247  ffi::Array<PrimExpr> output_shape,
2248  std::string name = "T_strided_slice_dynamic",
2249  std::string tag = kInjective) {
2250  const size_t num_dynamic_axes = x.ndim();
2251  TVM_FFI_ICHECK_EQ(begin.ndim(), 1);
2252  TVM_FFI_ICHECK_EQ(end.ndim(), 1);
2253  TVM_FFI_ICHECK_EQ(strides.ndim(), 1);
2254  const auto* len_begin = begin->shape[0].as<IntImmNode>();
2255  const auto* len_end = end->shape[0].as<IntImmNode>();
2256  const auto* len_strides = strides->shape[0].as<IntImmNode>();
2257  TVM_FFI_ICHECK(len_begin);
2258  TVM_FFI_ICHECK(len_end);
2259  TVM_FFI_ICHECK(len_strides);
2260  TVM_FFI_ICHECK_EQ(len_begin->value, num_dynamic_axes);
2261  TVM_FFI_ICHECK_EQ(len_end->value, num_dynamic_axes);
2262  TVM_FFI_ICHECK_EQ(len_strides->value, num_dynamic_axes);
2263 
2264  return te::compute(
2265  output_shape,
2266  [&](const ffi::Array<tvm::tirx::Var>& indices) {
2267  ffi::Array<PrimExpr> real_indices;
2268  for (size_t i = 0; i < num_dynamic_axes; ++i) {
2269  auto ind = make_const(DataType::Int(64), i);
2270  real_indices.push_back(indices[i] * strides(ind) + tvm::min(begin(ind), x->shape[i] - 1));
2271  }
2272  return x(real_indices);
2273  },
2274  name, tag);
2275 }
2276 
2277 } // namespace relax
2278 
2279 } // namespace topi
2280 } // namespace tvm
2281 #endif // TVM_TOPI_TRANSFORM_H_
Algebra expression simplifications.
Broadcast op constructions.
Managed reference class to FloatImmNode.
Definition: expr.h:546
Constant integer literals in the program.
Definition: expr.h:494
int64_t value
The internal value.
Definition: expr.h:497
Managed reference class to IntImmNode.
Definition: expr.h:511
Container of constant int that adds more constructors.
Definition: expr.h:601
Reference to PrimExprNode.
Definition: expr.h:126
DataType dtype() const
Definition: expr.h:140
Range container
Definition: expr.h:690
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span=Span())
construct a new range with min and extent The corresponding constructor is removed,...
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:634
bool CanProveGreaterEqual(const PrimExpr &expr, int64_t lower_bound)
Whether can we prove expr >= val.
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
bool CanProveLess(const PrimExpr &expr, int64_t upper_bound)
Whether can we prove expr < val.
Runtime primitive data type.
Definition: data_type.h:47
static DataType Float(int bits, int lanes=1)
Construct an float type.
Definition: data_type.h:295
bool is_int() const
Definition: data_type.h:194
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:278
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:54
Node to represent a tensor.
Definition: tensor.h:70
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
size_t ndim() const
Definition: tensor.h:212
Bijective function mapping for data layout transformation. Given two Layout, BijectiveLayout build an...
Definition: data_layout.h:386
Definition: index_map.h:170
IndexMap Inverse(ffi::Array< Range > initial_ranges, arith::Analyzer *analyzer) const
Generate the inverse mapping.
Managed reference to LayoutNode.
Definition: data_layout.h:126
bool Equals(const Layout &rhs) const
Whether the two layouts are equal.
Definition: data_layout.h:332
Managed reference to SelectNode.
Definition: expr.h:514
A variable node in the IR.
Definition: var.h:47
ffi::String name_hint
The hint to the variable name.
Definition: var.h:53
a named variable in TIR
Definition: var.h:76
Utility functions for handling constants in TVM expressions.
Layout expression to describe the data organization of a tensor. And BijectiveLayout to mapping two d...
Detail broadcast.
Defines a remapping of buffer indices.
Base expr nodes in TVM.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1007
DataType DefaultIndexType()
if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return int32
Definition: buffer.h:43
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:1021
te::Tensor dynamic_strided_slice(const te::Tensor &x, const te::Tensor &begin, const te::Tensor &end, const te::Tensor &strides, ffi::Array< PrimExpr > output_shape, std::string name="T_strided_slice_dynamic", std::string tag=kInjective)
Definition: transform.h:2245
PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent, bool assume_inbound=true)
Definition: transform.h:689
Tensor sequence_mask(const Tensor &data, const Tensor &valid_length, double mask_value, int axis, std::string name="T_sequence_mask", std::string tag=kInjective)
Mask the out-of-boundary elements of each sequence.
Definition: transform.h:1100
Tensor gather_nd(const Tensor &data, const Tensor &indices, int batch_dims=0, std::string name="T_gather_nd", std::string tag=kInjective)
Gather elements from a n-dimension array.
Definition: transform.h:1564
int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t stride)
Definition: transform.h:670
Tensor reshape(const Tensor &x, ffi::Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:330
Tensor one_hot(const Tensor &indices, const PrimExpr on_value, const PrimExpr off_value, int depth, int axis, const DataType &dtype, ffi::Array< PrimExpr > oshape=ffi::Array< PrimExpr >(), const std::string name="T_one_hot", const std::string tag=kInjective)
Returns a one-hot tensor where the locations represented by indices take value on_value,...
Definition: transform.h:2037
tvm::te::Tensor broadcast_to(const tvm::te::Tensor &t, const tvm::ffi::Array< tvm::PrimExpr > &output_shape, std::string name="T_broadcast_to", std::string tag=kBroadcast)
Creates an operation that broadcasts a tensor into a compatible shape according to numpy's rules.
Definition: broadcast.h:48
constexpr auto kBroadcast
Definition: tags.h:36
Tensor arange(const PrimExpr &start, const PrimExpr &stop, const PrimExpr &step, DataType dtype, std::string name="T_arange", std::string tag=kInjective)
Definition: transform.h:1739
constexpr auto kInjective
Definition: tags.h:33
Tensor stack(const ffi::Array< Tensor > &inputs, int axis=0, std::string name="T_stack", std::string tag=kInjective)
Join a sequence of tensors along a new axis.
Definition: transform.h:541
Tensor auto_scheduler_layout_transform(const Tensor &src, const ffi::String &src_layout, const ffi::String &dst_layout, const ffi::String name="T_auto_scheduler_layout_trans", const ffi::String tag=kInjective)
Transform the auto-scheduler generated layout according to src_layout and dst_layout.
Definition: transform.h:1887
ffi::Array< PrimExpr > StridedSliceOutputShape(const ffi::Array< PrimExpr > &ishape, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, const ffi::Array< Integer > &axes, const std::string &slice_mode)
Calculate the output shape of strided_slice, the entry point for Relax type relation.
Definition: transform.h:868
PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride)
Definition: transform.h:679
te::Tensor dynamic_strided_slice_with_axes(const te::Tensor &x, const ffi::Array< PrimExpr > &begin, const ffi::Array< PrimExpr > &end, const ffi::Array< PrimExpr > &strides, const ffi::Array< Integer > &axes, bool assume_inbound=true, std::string name="T_dynamic_strided_slice_with_axes", std::string tag=kInjective)
strided_slice of a tensor where begin/end/stride can be mixed static and dynamic
Definition: transform.h:716
Tensor transpose(const Tensor &x, ffi::Optional< ffi::Array< Integer >> opt_axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:205
void parse_auto_scheduler_layout(const ffi::String &layout, ffi::Array< PrimExpr > *shape, std::vector< std::string > *axes)
Utility function for auto_scheduler_layout_transform.
Definition: transform.h:1851
Tensor squeeze(const Tensor &x, ffi::Optional< ffi::Array< Integer >> opt_axes, bool atleast1d=false, std::string name="T_squeeze", std::string tag=kInjective)
Remove size 1 dimensions from the shape of a tensor. The removed dimensions must have a constant size...
Definition: transform.h:415
ffi::Array< Tensor > split_n_sections(const Tensor &x, int num_sections, int axis, std::string name="T_split_sections", std::string tag=kInjective)
Split a tensor into a number of sub-tensors.
Definition: transform.h:1004
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type,...
Definition: elemwise.h:277
Tensor expand_dims(const Tensor &x, int axis, int num_newaxis=1, std::string name="T_expand_dims", std::string tag=kBroadcast)
Creates an operation to insert new dimensions of length 1.
Definition: transform.h:156
Tensor sparse_to_dense(const Tensor &sparse_indices, const ffi::Array< PrimExpr > &output_shape, const Tensor &sparse_values, const PrimExpr &default_value, const std::string name="T_sparse_to_dense", const std::string tag=kInjective)
Get a dense tensor.
Definition: transform.h:2084
Tensor unravel_index(const Tensor &x, const Tensor &shape, std::string name="T_unravel", std::string tag=kInjective)
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Definition: transform.h:367
Tensor layout_transform(const Tensor &src, const std::string &src_layout, const std::string &dst_layout, const std::string schedule_rule="None", const std::string name="T_layout_trans", const std::string tag=kInjective)
Transform the layout according to src_layout and dst_layout.
Definition: transform.h:1809
Tensor adv_index(const Tensor &data, const ffi::Array< Tensor > &indices, const std::string name="advanced_index", const std::string tag=kInjective)
Numpy style advanced indexing with tensor.
Definition: transform.h:2193
Tensor strided_slice(const Tensor &x, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:962
Tensor concatenate(const ffi::Array< Tensor > &inputs, int axis=0, std::string name="T_concat", std::string tag=kInjective)
Join a sequence of tensors along an existing axis.
Definition: transform.h:481
ffi::Array< Tensor > meshgrid(const ffi::Array< Tensor > &inputs, const std::string &indexing, std::string name="T_meshgrid", std::string tag=kInjective)
Produce grids by expanding input over dimensions defined by other inputs.
Definition: transform.h:1773
constexpr auto kMatMul
Definition: tags.h:37
Tensor strided_slice_with_axes(const Tensor &x, const ffi::Array< Integer > &begin, const ffi::Array< Integer > &end, const ffi::Array< Integer > &strides, const ffi::Array< Integer > &axes, std::string slice_mode="end", std::string name="T_strided_slice_with_axes", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:900
Tensor dyn_tile(const Tensor &x, ffi::Array< PrimExpr > new_shape, size_t rdim, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1473
Tensor reverse_sequence(const Tensor &x, const Tensor &seq_lengths, int seq_axis=1, int batch_axis=0, std::string name="T_reverse_sequence", std::string tag=kInjective)
Reverse the tensor for variable length slices. Input is first sliced along batch axis and then elemen...
Definition: transform.h:265
ffi::Array< Tensor > split_indices_array(const Tensor &x, ffi::Array< PrimExpr > split_indices, int axis, std::string name="T_split", std::string tag=kInjective)
Split a tensor into multiple sub-tensors.
Definition: transform.h:587
Tensor tensordot(const Tensor &A, const tvm::te::Tensor &B, int axes=2, std::string name="T_tensordot", std::string tag=kMatMul)
A generalization of matrix multiplication to tensors.
Definition: transform.h:1647
Tensor sum(const Tensor &data, const ffi::Optional< ffi::Array< Integer >> &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:328
Tensor tile(const Tensor &x, ffi::Array< Integer > reps, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1417
Tensor meta_schedule_layout_transform(const Tensor &src, const tirx::IndexMap &index_map, const ffi::String name="T_meta_schedule_layout_trans", const ffi::String tag=kInjective)
Transform the meta-schedule generated layout according to TIR's IndexMap.
Definition: transform.h:1953
Tensor take(const Tensor &a, const Tensor &indices, int batch_dims, std::string mode="fast", std::string name="T_take", std::string tag=kInjective)
Take elements from a flattened input array when axis is None.
Definition: transform.h:1040
PrimExpr DynamicCanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride)
Definition: transform.h:652
tvm::te::Tensor matmul(const tvm::te::Tensor &A, const tvm::te::Tensor &B, bool trans_a=false, bool trans_b=false, std::string name="T_matmul", std::string tag=kMatMul)
Creates an operation that calculates a matrix multiplication (row-major notation): A(i,...
Definition: transform.h:1625
Tensor dynamic_strided_slice(const Tensor &x, const ffi::Array< PrimExpr > &begin, const ffi::Array< PrimExpr > &end, const ffi::Array< PrimExpr > &strides, bool assume_inbound=true, std::string name="T_dynamic_strided_slice", std::string tag=kInjective)
strided_slice of a tensor where begin/end/stride can be mixed static and dynamic
Definition: transform.h:773
Tensor matrix_set_diag(const Tensor &input, const Tensor &diagonal, int k1, int k2, bool super_diag_right_align, bool sub_diag_right_align, const std::string name="T_matrix_set_diag", const std::string tag=kInjective)
Returns a tensor with the diagonal of input tensor replaced with the provided diagonals.
Definition: transform.h:2137
Tensor where(const Tensor &condition, const Tensor &x, const Tensor &y, std::string name="T_where", std::string tag=kBroadcast)
Return the elements, either from x or y, depending on the condition.
Definition: transform.h:1330
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
Tensor gather(const Tensor &data, int axis, const Tensor &indices, std::string name="T_gather", std::string tag=kInjective)
Gather values along given axis from given indices.
Definition: transform.h:1511
Tensor sliding_window(const Tensor &x, int axis, ffi::Array< Integer > window_shape, ffi::Array< Integer > strides, std::string name="T_sliding_window", std::string tag="")
Creates an operation to slide a window over the input x.
Definition: transform.h:76
te::Tensor tensor_size(const te::Tensor &src, const DataType &dtype, const std::string &name="tensor_size", const std::string &tag=kInjective)
Get the size of input tensor.
Definition: transform.h:2006
Tensor repeat(const Tensor &x, int repeats, int axis, std::string name="T_repeat", std::string tag=kBroadcast)
Creates an operation to repeat elements of an array.
Definition: transform.h:1370
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span=Span())
compute ceil(a / b)
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr max_value(const DataType &dtype, Span span=Span())
PrimExpr ceil(PrimExpr x, Span span=Span())
Calculate ceil(x)
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
PrimExpr sum(PrimExpr source, ffi::Array< tirx::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder floor(a / b) where a and b are non-negative.
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
Operation node can generate one or multiple Tensors.
Index ravel and unravel operations.
Utility functions for strided_slice op.
Tag definitions.
Utility functions for handling tensor.
TIR expressions.
Common operators defined for Expr.
Variables in the TIR.