transform.h

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file topi/transform.h
 * \brief Transform op constructors
 */
#ifndef TVM_TOPI_TRANSFORM_H_
#define TVM_TOPI_TRANSFORM_H_

#include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/index_map.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/detail/broadcast.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/detail/ravel_unravel.h>
#include <tvm/topi/detail/strided_slice.h>
#include <tvm/topi/detail/tensor_utils.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <string>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace topi {

using namespace tvm::te;
using namespace topi::detail;

inline Tensor sliding_window(const Tensor& x, int axis, Array<Integer> window_shape,
                             Array<Integer> strides, std::string name = "T_sliding_window",
                             std::string tag = "") {
  CHECK_GE(axis, 0);
  auto _axis = size_t(axis);
  CHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x.";
  CHECK_EQ(x->shape.size() - _axis, window_shape.size())
      << "There must be a window shape for every dimension of x "
      << "over which we are sliding the window.";
  CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length.";

  // Compute the new shape.
  Array<PrimExpr> new_shape;
  // Dimensions up until `axis` remain the same.
  for (size_t i = 0; i < _axis; ++i) {
    new_shape.push_back(x->shape[i]);
  }

  // New dimensions which result from sliding the window in each dimension. One new dimension per
  // window dimension.
  for (size_t i = 0; i < window_shape.size(); ++i) {
    // Length of the shape along this dimension.
    auto dim_len = x->shape[_axis + i];
    // Length of the window along this dimension.
    auto window_len = window_shape[i];
    // Strides along this dimension.
    auto stride = strides[i];

    new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride));
  }

  // Dimensions comprising the window.
  for (size_t i = 0; i < window_shape.size(); ++i) {
    new_shape.push_back(window_shape[i]);
  }

  ICHECK(new_shape.size() == _axis + 2 * window_shape.size());

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        // The index at which to index the old tensor x.
        Array<PrimExpr> idx;

        // Dimensions up until `axis` remain the same.
        for (size_t i = 0; i < _axis; ++i) {
          idx.push_back(indices[i]);
        }

        for (size_t i = 0; i < window_shape.size(); ++i) {
          // Which window in this dimension we are indexing.
          auto window_idx = indices[_axis + i];
          // Which index within the window we are indexing.
          auto idx_within_window = indices[_axis + window_shape.size() + i];
          // Stride value for this dimension.
          auto stride = strides[i];

          idx.push_back(window_idx * stride + idx_within_window);
        }

        ICHECK(idx.size() == x->shape.size());

        return x(idx);
      },
      name, tag);
}
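
// Usage sketch (illustrative, not part of the original header): sliding a 3x3
// window with stride 2 over a hypothetical 10x10 placeholder. Each windowed
// dimension contributes floordiv(10 - 2 + 1, 2) = 4 positions, so the result
// has shape (4, 4, 3, 3) with Y(i, j, p, q) == X(2*i + p, 2*j + q):
//
//   te::Tensor X = te::placeholder({10, 10}, DataType::Float(32), "X");
//   te::Tensor Y = sliding_window(X, /*axis=*/0, {3, 3}, {2, 2});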

inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1,
                          std::string name = "T_expand_dims", std::string tag = kBroadcast) {
  int ndim = static_cast<int>(x->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;
  ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
                           << ", but got num_newaxis = " << num_newaxis;
  if (axis < 0) {
    // Calculate offset from last dimension
    axis = ndim + axis + 1;
  }
  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
    new_shape.push_back(x->shape[i]);
  }
  for (size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) {
    new_shape.push_back(1);
  }
  for (size_t i = axis; i < x->shape.size(); ++i) {
    new_shape.push_back(x->shape[i]);
  }

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> idx;
        for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
          idx.push_back(indices[i]);
        }
        for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
          idx.push_back(indices[i]);
        }
        return x(idx);
      },
      name, tag);
}
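
// Usage sketch (illustrative): inserting two unit axes at position 1 of a
// hypothetical (3, 4) placeholder yields shape (3, 1, 1, 4):
//
//   te::Tensor X = te::placeholder({3, 4}, DataType::Float(32), "X");
//   te::Tensor Y = expand_dims(X, /*axis=*/1, /*num_newaxis=*/2);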

inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name = "T_transpose",
                        std::string tag = kInjective) {
  if (!axes.defined() || axes.size() == 0) {
    axes = Array<Integer>();
    for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
      axes.push_back(i);
    }
  }

  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < axes.size(); ++i) {
    int axis = static_cast<int>(axes[i]->value);
    int new_axis = axis;
    if (axis < 0) {
      new_axis = static_cast<int>(x->shape.size()) + axis;
      axes.Set(i, new_axis);
    }
    ICHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
        << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size())
        << "-dimensional input tensor";

    for (size_t j = 0; j < axes.size(); ++j) {
      if (i != j) {
        ICHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
      }
    }
    new_shape.push_back(x->shape[new_axis]);
  }

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        std::vector<PrimExpr> idx;
        for (size_t i = 0; i < axes.size(); ++i) {
          idx.push_back(1);
        }
        for (size_t i = 0; i < axes.size(); ++i) {
          int axis = static_cast<int>(axes[i]->value);
          idx[axis] = indices[i];
        }
        return x(idx);
      },
      name, tag);
}
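
// Usage sketch (illustrative): permuting a hypothetical (2, 3, 4) placeholder
// with axes (1, 2, 0) yields shape (3, 4, 2), with Y(i, j, k) == X(k, i, j):
//
//   te::Tensor X = te::placeholder({2, 3, 4}, DataType::Float(32), "X");
//   te::Tensor Y = transpose(X, {1, 2, 0});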

inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
                               int batch_axis = 0, std::string name = "T_reverse_sequence",
                               std::string tag = kInjective) {
  size_t src_tensor_dim = x->shape.size();
  int seq_axis_inp = seq_axis;

  if (seq_lengths.defined()) {
    size_t seq_lengths_dim = seq_lengths->shape.size();
    int batch_axis_inp = batch_axis;
    if (batch_axis < 0) {
      batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
    }

    ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector";

    ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
        << "For reverse_sequence, the size of seq_lengths must match the batch axis dimension"
        << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
        << ", and seq_lengths size = " << GetConstInt(seq_lengths->shape[0]);

    ICHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
        << "batch_axis=" << batch_axis_inp << " is invalid for the "
        << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
  }

  if (seq_axis < 0) {
    seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
  }
  ICHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
      << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
      << "-dimensional input tensor";

  auto func = [&](const Array<Var>& indices) {
    Array<PrimExpr> real_indices;
    for (size_t i = 0; i < src_tensor_dim; ++i) {
      if (i == static_cast<size_t>(seq_axis)) {
        if (seq_lengths.defined()) {
          auto len = seq_lengths(indices[batch_axis]);
          auto idx = if_then_else(
              len <= 1 || len <= indices[i], indices[i],
              if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
          real_indices.push_back(idx);
        } else {
          real_indices.push_back(x->shape[i] - 1 - indices[i]);
        }
      } else {
        real_indices.push_back(indices[i]);
      }
    }
    return x(real_indices);
  };

  return compute(x->shape, func, name, tag);
}

inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string name = "T_reshape",
                      std::string tag = kInjective) {
  auto x_shape = x->shape;
  Array<PrimExpr> target_shape;

  for (const auto& ele : newshape) {
    target_shape.push_back(ele);
  }

  // If either the input shape or the target shape contains a zero, return an empty tensor.
  if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) {
    return compute(
        target_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  } else {
    return compute(
        target_shape,
        [&](const Array<Var>& indices) {
          return x(UnravelIndex(
              RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape), x_shape));
        },
        name, tag);
  }
}
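
// Usage sketch (illustrative): a row-major relayout of a hypothetical
// (2, 3, 4) placeholder into shape (4, 6); the element count (24) must match:
//
//   te::Tensor X = te::placeholder({2, 3, 4}, DataType::Float(32), "X");
//   te::Tensor Y = reshape(X, {4, 6});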

inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
                            std::string tag = kInjective) {
  auto x_shape = x->shape;
  auto shape_shape = shape->shape;

  Array<PrimExpr> oshape;
  oshape.push_back(shape_shape[0]);
  if (x_shape.size() != 0) {
    oshape.push_back(x_shape[0]);
  }

  auto func = [&](const Array<Var>& indices) {
    auto i = indices[0];
    std::vector<PrimExpr> indices_divs;
    PrimExpr ret = 0;
    PrimExpr cur_val = 0;
    PrimExpr index_val = 0;

    if (x_shape.size() != 0) {
      index_val = x[indices[1]];
    } else {
      index_val = x();
    }
    indices_divs.push_back(index_val);
    for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
      ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
      cur_val = indexdiv(indices_divs.back(), shape[v]);
      indices_divs.push_back(cur_val);
    }
    return ret;
  };

  return compute(oshape, func, name, tag);
}

inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false,
                      std::string name = "T_squeeze", std::string tag = kInjective) {
  auto ndim = x->shape.size();
  std::vector<int> axis_val;
  if (!axis.defined()) {
    for (size_t i = 0; i < ndim; ++i) {
      if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
        axis_val.push_back(static_cast<int>(i));
      }
    }
  } else {
    for (size_t i = 0; i < axis.size(); ++i) {
      int64_t val = axis[i]->value;
      if (val < 0) {
        val += static_cast<int>(x->shape.size());
      }
      if (IsConstInt(x->shape[val])) {
        ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
      }
      axis_val.push_back(val);
    }
  }

  std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());

  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < ndim; ++i) {
    if (axis_set.count(static_cast<int>(i)) == 0) {
      out_shape.push_back(x->shape[i]);
    }
  }
  if (out_shape.size() == 0 && atleast1d) {
    out_shape.push_back(1);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> real_indices;
        int flag = 0;
        for (size_t i = 0; i < ndim; ++i) {
          if (axis_set.count(static_cast<int>(i)) == 0) {
            real_indices.push_back(indices[i - flag]);
          } else {
            real_indices.push_back(0);
            flag += 1;
          }
        }
        return x(real_indices);
      },
      name, tag);
}
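
// Usage sketch (illustrative): removing the two unit axes of a hypothetical
// (1, 3, 1, 4) placeholder yields shape (3, 4):
//
//   te::Tensor X = te::placeholder({1, 3, 1, 4}, DataType::Float(32), "X");
//   te::Tensor Y = squeeze(X, Array<Integer>{0, 2});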

inline Tensor concatenate(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_concat",
                          std::string tag = kInjective) {
  int ndim = static_cast<int>(inputs[0]->shape.size());
  ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)"
                                       << ", but got axis = " << axis << ", and ndim = " << ndim;
  if (axis < 0) {
    axis += ndim;
  }
  ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds";

  Array<PrimExpr> axis_sizes;
  for (auto t : inputs) {
    axis_sizes.push_back(t->shape[axis]);
  }
  arith::Analyzer analyzer;
  PrimExpr join_size = axis_sizes[0];
  for (size_t i = 1; i < axis_sizes.size(); ++i) {
    join_size += axis_sizes[i];
  }
  join_size = analyzer.Simplify(join_size);
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
    out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        auto ret = inputs[0](indices);
        auto ind = indices[axis];
        for (size_t i = 0; i < inputs.size() - 1; ++i) {
          ind -= axis_sizes[i];

          Array<PrimExpr> idx;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            idx.push_back(indices[j]);
          }
          idx.push_back(ind);
          for (size_t j = axis + 1; j < indices.size(); ++j) {
            idx.push_back(indices[j]);
          }

          ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
        }
        return ret;
      },
      name, tag);
}
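
// Usage sketch (illustrative): joining hypothetical (2, 3) and (4, 3)
// placeholders along axis 0 yields shape (6, 3):
//
//   te::Tensor A = te::placeholder({2, 3}, DataType::Float(32), "A");
//   te::Tensor B = te::placeholder({4, 3}, DataType::Float(32), "B");
//   te::Tensor C = concatenate({A, B}, /*axis=*/0);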

inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack",
                    std::string tag = kInjective) {
  int ndim = static_cast<int>(inputs[0]->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "stack only accepts `axis` in [-ndim - 1, ndim]"
      << ", but got axis = " << axis << ", and ndim = " << ndim;
  if (axis < 0) {
    axis += ndim + 1;
  }
  ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";

  const int stack_size = static_cast<int>(inputs.size());
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
  out_shape.push_back(stack_size);
  for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
    out_shape.push_back(inputs[0]->shape[i]);

  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> idx;
        for (size_t i = 0; i < indices.size(); ++i)
          if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]);
        auto ind = indices[axis];
        auto ret = inputs[0](idx);
        for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
          ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
        }
        return ret;
      },
      name, tag);
}
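
// Usage sketch (illustrative): stacking two hypothetical (2, 3) placeholders
// along a new leading axis yields shape (2, 2, 3):
//
//   te::Tensor A = te::placeholder({2, 3}, DataType::Float(32), "A");
//   te::Tensor B = te::placeholder({2, 3}, DataType::Float(32), "B");
//   te::Tensor S = stack({A, B}, /*axis=*/0);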

inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
                           std::string name = "T_split", std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(x->shape.size());
  }
  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";

  auto src_axis_size = x->shape[axis];
  std::vector<PrimExpr> begin_ids;
  begin_ids.push_back(0);

  for (auto idx : split_indices) {
    auto idx_node = idx.as<IntImmNode>();
    auto back_node = begin_ids.back().as<IntImmNode>();
    if (idx_node && back_node) {
      ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
    }
    begin_ids.push_back(idx);
  }

  Array<Array<PrimExpr>> out_shapes;
  for (size_t i = 0; i < begin_ids.size(); ++i) {
    PrimExpr out_axis_size;
    if (i == begin_ids.size() - 1) {
      out_axis_size = src_axis_size - begin_ids[i];
    } else {
      out_axis_size = begin_ids[i + 1] - begin_ids[i];
    }

    Array<PrimExpr> shape;
    for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
      shape.push_back(x->shape[j]);
    }
    shape.push_back(out_axis_size);
    for (size_t j = axis + 1; j < x->shape.size(); ++j) {
      shape.push_back(x->shape[j]);
    }

    out_shapes.push_back(shape);
  }

  Array<Tensor> result;
  for (size_t i = 0; i < begin_ids.size(); ++i) {
    result.push_back(compute(
        out_shapes[i],
        [&](const Array<Var>& indices) {
          auto begin = begin_ids[i];
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            real_indices.push_back(indices[j]);
          }
          real_indices.push_back(indices[axis] + begin);
          for (size_t j = axis + 1; j < indices.size(); ++j) {
            real_indices.push_back(indices[j]);
          }

          return x(real_indices);
        },
        name, tag));
  }

  return result;
}

inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
                                    const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
                                    std::string name = "T_dynamic_strided_slice",
                                    std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  ICHECK_LE(begin.size(), src_tensor_dim);
  ICHECK_LE(end.size(), src_tensor_dim);
  ICHECK_LE(strides.size(), src_tensor_dim);
  ICHECK_EQ(begin.size(), end.size());
  ICHECK_EQ(begin.size(), strides.size());

  const size_t num_slice_axes = begin.size();
  Array<PrimExpr> out_shape;

  for (size_t i = 0; i < num_slice_axes; ++i) {
    auto d = indexdiv(end[i] - begin[i], strides[i]);
    if (d->IsInstance<tvm::IntImmNode>()) {
      // Preserve static dimension if possible
      out_shape.push_back(d);
    } else {
      out_shape.push_back(tvm::tir::Var("dim"));
    }
  }

  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
    out_shape.push_back(x->shape[i]);
  }

  return te::compute(
      out_shape,
      [&](const Array<tvm::tir::Var>& indices) {
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < num_slice_axes; ++i) {
          real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
        }
        // keep input dim
        for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
          real_indices.push_back(indices[i]);
        }
        return x(real_indices);
      },
      name, tag);
}

inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
                                        const te::Tensor& end, const te::Tensor& strides,
                                        std::string name = "T_strided_slice_dynamic",
                                        std::string tag = topi::kInjective) {
  DataType index_dtype = begin->shape[0]->dtype;
  const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
  ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
  ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);

  Array<PrimExpr> begin_expr, end_expr, strides_expr;
  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
    auto ind = make_const(index_dtype, i);
    begin_expr.push_back(begin(ind));
    end_expr.push_back(end(ind));
    strides_expr.push_back(strides(ind));
  }
  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
}

inline Array<PrimExpr> StridedSliceOutputShape(
    const Array<PrimExpr>& ishape, const Array<Integer>& begin, const Array<Integer>& end,
    const Array<Integer>& strides, const Array<Integer>& axes, const std::string& slice_mode) {
  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
                                                           begin[0]->dtype, slice_mode);
  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
                                 begin_canonicalized, true);
}

inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
                                      const Array<Integer>& end, const Array<Integer>& strides,
                                      const Array<Integer>& axes, std::string slice_mode = "end",
                                      std::string name = "T_strided_slice_with_axes",
                                      std::string tag = kInjective) {
  const size_t src_tensor_dim = x->shape.size();
  ICHECK(axes.size() <= src_tensor_dim);
  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());

  std::vector<int64_t> begin_vec, end_vec, strides_vec;
  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);

  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
                                                  begin[0]->dtype, slice_mode);
  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
                                           slice_mode, begin_expr);

  return te::compute(
      out_shape,
      [&](const Array<tir::Var>& indices) {
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
        for (size_t i = 0; i < axes.size(); ++i) {
          auto stride = make_const(strides[i].dtype(), strides_vec[i]);
          PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
          real_indices.Set(axes[i].IntValue(), ind);
        }
        return x(real_indices);
      },
      name, tag);
}

inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
                            const Array<Integer>& strides, std::string slice_mode = "end",
                            std::string name = "T_strided_slice", std::string tag = kInjective) {
  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
  Array<Integer> axes;
  for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
  Array<Integer> begin_full(begin);
  Array<Integer> end_full(end);
  Array<Integer> strides_full(strides);

  DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64);
  const IntImm one = IntImm(index_dtype, 1);
  const IntImm zero = IntImm(index_dtype, 0);
  const IntImm max_range = Downcast<IntImm>(max_value(index_dtype));

  for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
    strides_full.push_back(one);
  }
  for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
    begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
  }
  for (size_t i = end.size(); i < src_tensor_dim; ++i) {
    end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
  }

  return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
                                 tag);
}
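
// Usage sketch (illustrative): with the default "end" slice mode, selecting
// elements 2, 4, 6 from a hypothetical length-10 placeholder gives shape (3,):
//
//   te::Tensor X = te::placeholder({10}, DataType::Float(32), "X");
//   te::Tensor Y = strided_slice(X, /*begin=*/{2}, /*end=*/{8}, /*strides=*/{2});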

inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
                                    std::string name = "T_split_sections",
                                    std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(x->shape.size());
  }
  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";

  auto src_axis_size = x->shape[axis];

  ICHECK_GT(num_sections, 0) << "Slice count must be > 0";

  if (auto node = src_axis_size.as<IntImmNode>()) {
    ICHECK_EQ(node->value % num_sections, 0)
        << "num_sections must be an integer factor of the size of axis " << axis << " ("
        << node->value << ")";
  }

  Array<PrimExpr> split_indices;
  auto seg_size = indexdiv(src_axis_size, num_sections);
  for (int i = 0; i < num_sections; ++i) {
    // region at index 0 is added by split()
    if (i != 0) {
      split_indices.push_back(seg_size * i);
    }
  }

  return split(x, split_indices, axis, name, tag);
}
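
// Usage sketch (illustrative): splitting a hypothetical (6, 4) placeholder
// into three equal sections along axis 0 yields three (2, 4) tensors:
//
//   te::Tensor X = te::placeholder({6, 4}, DataType::Float(32), "X");
//   Array<Tensor> parts = split_sections(X, /*num_sections=*/3, /*axis=*/0);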

inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
                   std::string mode = "clip", std::string name = "T_take",
                   std::string tag = kInjective) {
  Array<PrimExpr> a_shape = a->shape;
  Array<PrimExpr> out_shape = indices->shape;
  PrimExpr a_size = 1;
  for (size_t i = 0; i < a_shape.size(); ++i) {
    a_size = a_size * a_shape[i];
  }

  if (mode == "clip") {
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
          return a(UnravelIndex(idx, a_shape));
        },
        name, tag);
  } else if (mode == "fast") {
    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                    "Make sure input indices are in bound";
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
        name, tag);
  } else {  // mode == "wrap"
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
          return a(UnravelIndex(idx, a_shape));
        },
        name, tag);
  }
}

inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value,
                            int axis, std::string name = "T_sequence_mask",
                            std::string tag = kInjective) {
  ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
  ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
  auto length_dim = data->shape[axis];
  auto batch_dim = data->shape[1 - axis];
  Array<PrimExpr> out_shape = data->shape;
  Tensor out = compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        Array<PrimExpr> len_index;
        auto tid = out_index[axis];
        auto bid = out_index[1 - axis];
        len_index.push_back(bid);
        PrimExpr ret =
            tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
                              tvm::tir::make_const(data->dtype, mask_value), data(out_index));
        return ret;
      },
      name, tag);
  return out;
}

inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
                   std::string mode = "clip", std::string name = "T_take",
                   std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(a->shape.size());
  }
  ICHECK_GE(axis, 0) << "axis out of bounds";
  ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
  auto axis_dim = a->shape[axis];
  int indices_len = static_cast<int>(indices->shape.size());

  int batch_dims_ = batch_dims;
  if (batch_dims_ != 0) {
    ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";

    if (batch_dims_ < 0) {
      batch_dims_ = indices->shape.size() + batch_dims_;
    }

    ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
    ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
    for (int i = 0; i < batch_dims_; ++i) {
      auto addr1 = a->shape[i];
      auto addr2 = indices->shape[i];
      auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
      auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
      ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
    }
  }

  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
  // a.shape[axis + 1:].

  Array<PrimExpr> out_shape;
  for (int i = 0; i < batch_dims_; ++i) {
    out_shape.push_back(a->shape[i]);
  }
  for (int i = batch_dims_; i < axis; ++i) {
    out_shape.push_back(a->shape[i]);
  }
  for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
    out_shape.push_back(indices->shape[i]);
  }
  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
    out_shape.push_back(a->shape[i]);
  }

  if (mode == "clip") {
    if (batch_dims_ == 0) {
      return compute(
          out_shape,
          [&](const Array<Var>& out_index) {
            Array<PrimExpr> indices_position;
            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
              indices_position.push_back(out_index[j]);
            }
            Array<PrimExpr> real_indices;
            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
              real_indices.push_back(out_index[j]);
            }
            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
            real_indices.push_back(idx);
            for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
              real_indices.push_back(out_index[j]);
            }
            return a(real_indices);
          },
          name, tag);
    } else {
      return compute(
          out_shape,
          [&](const Array<Var>& out_index) {
            Array<PrimExpr> indices_position;
            for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
              indices_position.push_back(out_index[j]);
            }
            for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
              indices_position.push_back(out_index[j]);
            }
            Array<PrimExpr> real_indices;
            for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
              real_indices.push_back(out_index[j]);
            }
            auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
            real_indices.push_back(idx);
            for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
              real_indices.push_back(out_index[j]);
            }
            return a(real_indices);
          },
          name, tag);
    }
  } else if (mode == "fast") {
    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                    "Make sure input indices are in bound";
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          Array<PrimExpr> indices_position;
          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
            indices_position.push_back(out_index[j]);
          }
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            real_indices.push_back(out_index[j]);
          }
          real_indices.push_back(indices(indices_position));
          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
            real_indices.push_back(out_index[j]);
          }
          return a(real_indices);
        },
        name, tag);
  } else {  // mode == "wrap"
    return compute(
        out_shape,
        [&](const Array<Var>& out_index) {
          Array<PrimExpr> indices_position;
          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
            indices_position.push_back(out_index[j]);
          }
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
            real_indices.push_back(out_index[j]);
          }
          auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
          real_indices.push_back(idx);
          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
            real_indices.push_back(out_index[j]);
          }
          return a(real_indices);
        },
        name, tag);
  }
}
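
// Usage sketch (illustrative): gathering two columns of a hypothetical (3, 4)
// placeholder along axis 1 in the default "clip" mode yields shape (3, 2):
//
//   te::Tensor A = te::placeholder({3, 4}, DataType::Float(32), "A");
//   te::Tensor I = te::placeholder({2}, DataType::Int(32), "I");
//   te::Tensor Y = take(A, I, /*batch_dims=*/0, /*axis=*/1);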

inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
                    std::string name = "T_where", std::string tag = kBroadcast) {
  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
                                << y->dtype;
  auto get_out_shape = [&]() {
    auto bh1 = detail::BroadcastShape(x->shape, y->shape);
    Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
    auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
    Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
    return common_shape2;
  };

  auto oshape = get_out_shape();

  auto c_bh = detail::BroadcastShape(condition->shape, oshape);
  auto x_bh = detail::BroadcastShape(x->shape, oshape);
  auto y_bh = detail::BroadcastShape(y->shape, oshape);

  auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
    auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
    auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
    auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
    return tvm::tir::Select(c != 0, true_val, false_val);
  };

  return compute(oshape, select, name, tag);
}
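
// Usage sketch (illustrative): an element-wise select over broadcast-compatible
// hypothetical placeholders, Z(i, j) = C(i, j) != 0 ? X(i, j) : Y(i, j):
//
//   te::Tensor C = te::placeholder({2, 3}, DataType::Int(32), "C");
//   te::Tensor X = te::placeholder({2, 3}, DataType::Float(32), "X");
//   te::Tensor Y = te::placeholder({2, 3}, DataType::Float(32), "Y");
//   te::Tensor Z = where(C, X, Y);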

inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat",
                     std::string tag = kBroadcast) {
  int ndim = static_cast<int>(x->shape.size());
  ICHECK(-ndim - 1 <= axis && axis <= ndim)
      << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;
  ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
                       << ", but got repeats = " << repeats;
  if (axis < 0) {
    // Calculate offset from last dimension
    axis += ndim;
  }
  Array<PrimExpr> new_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
    new_shape.push_back(x->shape[i]);
  }
  new_shape.push_back(repeats * x->shape[axis]);
  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
    new_shape.push_back(x->shape[i]);
  }

  return compute(
      new_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> idx;
        for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
          idx.push_back(indices[i]);
        }
        idx.push_back(indexdiv(indices[axis], repeats));
        for (size_t i = axis + 1; i < indices.size(); ++i) {
          idx.push_back(indices[i]);
        }
        return x(idx);
      },
      name, tag);
}

inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_tile",
                   std::string tag = kBroadcast) {
  size_t ndim = x->shape.size();
  size_t rdim = reps.size();
  size_t tdim = (ndim > rdim) ? ndim : rdim;
  Array<PrimExpr> data_shape;
  Array<PrimExpr> reps_shape;
  Array<PrimExpr> new_shape;
  if (ndim == rdim) {
    for (size_t i = 0; i < ndim; ++i) {
      data_shape.push_back(x->shape[i]);
      reps_shape.push_back(reps[i]);
    }
  } else if (ndim > rdim) {
    for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
    for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
    for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
  } else {
    for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
    for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
    for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
  }
  for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);

  if (is_empty_shape(new_shape)) {
    return compute(
        new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  } else {
    return compute(
        new_shape,
        [&](const Array<Var>& indices) {
          Array<PrimExpr> idx;
          if (ndim >= rdim) {
            for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
          } else {
            for (size_t i = 0; i < ndim; ++i)
              idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
          }
          return x(idx);
        },
        name, tag);
  }
}
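
// Usage sketch (illustrative): tiling a hypothetical (2, 3) placeholder twice
// along each axis yields shape (4, 6):
//
//   te::Tensor X = te::placeholder({2, 3}, DataType::Float(32), "X");
//   te::Tensor Y = tile(X, {2, 2});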

inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
                       std::string name = "T_tile", std::string tag = kBroadcast) {
  size_t ndim = x->shape.size();
  if (is_empty_shape(new_shape)) {
    return compute(
        new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
  } else {
    return compute(
        new_shape,
        [&](const Array<Var>& indices) {
          Array<PrimExpr> idx;
          if (ndim >= rdim) {
            for (size_t i = 0; i < ndim; ++i) {
              idx.push_back(indexmod(indices[i], x->shape[i]));
            }
          } else {
            for (size_t i = 0; i < ndim; ++i) {
              idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
            }
          }
          return x(idx);
        },
        name, tag);
  }
}

inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
                     std::string name = "T_gather", std::string tag = kInjective) {
  size_t ndim_d = data->shape.size();
  size_t ndim_i = indices->shape.size();
  ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
  ICHECK_EQ(ndim_d, ndim_i);
  if (axis < 0) {
    axis += ndim_d;
  }
  ICHECK_GE(axis, 0);
  ICHECK_LT(axis, ndim_d);
  if (indices->shape[axis].as<IntImmNode>()) {
    size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
    ICHECK_GE(indices_dim_i, 1);
  }
  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());

  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < ndim_i; ++i) {
    out_shape.push_back(indices->shape[i]);
  }

  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        Array<PrimExpr> indices_position;
        for (size_t i = 0; i < ndim_i; ++i) {
          indices_position.push_back(out_index[i]);
        }
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < ndim_i; ++i) {
          if (i == static_cast<size_t>(axis)) {
            real_indices.push_back(indices(indices_position));
          } else {
            real_indices.push_back(indices_position[i]);
          }
        }
        return data(real_indices);
      },
      name, tag);
}

inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
                        std::string name = "T_gather_nd", std::string tag = kInjective) {
  size_t ndim_d = data->shape.size();
  size_t ndim_i = indices->shape.size();
  ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimension";
  size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
  ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
                                  << "than dimensions of data tensor";
  Array<PrimExpr> out_shape;
  for (size_t i = 1; i < ndim_i; ++i) {
    out_shape.push_back(indices->shape[i]);
  }
  for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
    out_shape.push_back(data->shape[i]);
  }
  return compute(
      out_shape,
      [&](const Array<Var>& out_index) {
        Array<PrimExpr> indices_position;
        indices_position.push_back(0);
        for (size_t i = 0; i < ndim_i - 1; ++i) {
          indices_position.push_back(out_index[i]);
        }
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
          real_indices.push_back(out_index[i]);
        }
        for (size_t i = 0; i < indices_dim0; ++i) {
          indices_position.Set(0, make_const(DataType::Int(32), i));
          if (indices->dtype.is_int() || indices->dtype.is_uint()) {
            real_indices.push_back(indices(indices_position));
          } else {
            real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
          }
        }
        if (real_indices.size() == ndim_d) {
          return data(real_indices);
        }
        for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
          real_indices.push_back(out_index[i]);
        }
        return data(real_indices);
      },
      name, tag);
}
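
// Usage sketch (illustrative): with a hypothetical (2, 3) data placeholder and
// a (2, 4) integer index tensor, output element i is D(I(0, i), I(1, i)), so
// the result has shape (4,):
//
//   te::Tensor D = te::placeholder({2, 3}, DataType::Float(32), "D");
//   te::Tensor I = te::placeholder({2, 4}, DataType::Int(32), "I");
//   te::Tensor Y = gather_nd(D, I);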

inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B,
                              bool trans_a = false, bool trans_b = false,
                              std::string name = "T_matmul", std::string tag = kMatMul) {
  tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
  auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
  auto l = [&](tvm::tir::Var i, tvm::tir::Var j) {
    return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
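
// Usage sketch (illustrative): a plain (32, 64) x (64, 16) product of two
// hypothetical placeholders, C(i, j) = sum_k A(i, k) * B(k, j), shape (32, 16):
//
//   te::Tensor A = te::placeholder({32, 64}, DataType::Float(32), "A");
//   te::Tensor B = te::placeholder({64, 16}, DataType::Float(32), "B");
//   te::Tensor C = matmul(A, B);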

inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2,
                        std::string name = "T_tensordot", std::string tag = kMatMul) {
  ICHECK_GE(A->shape.size(), axes);
  ICHECK_GE(B->shape.size(), axes);

  Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
  for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);

  Array<IterVar> iter_vars;
  for (int i = 0; i < axes; ++i)
    iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));

  auto func = [&A, &B, &iter_vars, axes](const Array<Var>& input_indices) {
    Array<PrimExpr> A_indices(input_indices.begin(),
                              input_indices.begin() + (A->shape.size() - axes));
    for (auto& v : iter_vars) A_indices.push_back(v);

    Array<PrimExpr> B_indices;
    for (auto& v : iter_vars) B_indices.push_back(v);

    auto it = input_indices.begin() + (A->shape.size() - axes);
    for (; it != input_indices.end(); ++it) B_indices.push_back(*it);

    // Some passes don't like reductions with empty axis, so avoid it here
    if (iter_vars.empty()) {
      return A(A_indices) * B(B_indices);
    } else {
      return sum(A(A_indices) * B(B_indices), iter_vars);
    }
  };

  return compute(output_shape, func, name, tag);
}

inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExpr> A_axes,
                        Array<PrimExpr> B_axes, std::string name = "T_tensordot",
                        std::string tag = kMatMul) {
  ICHECK_EQ(A_axes.size(), B_axes.size());

  auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
  auto B_axes_val = GetConstIntValues(B_axes, "B_axes");

  Array<PrimExpr> output_shape;
  for (unsigned i = 0; i < A->shape.size(); ++i)
    if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
      output_shape.push_back(A->shape[i]);
  for (unsigned i = 0; i < B->shape.size(); ++i)
    if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
      output_shape.push_back(B->shape[i]);

  Array<IterVar> iter_vars;
  for (unsigned i = 0; i < B_axes_val.size(); ++i)
    iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));

  auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array<Var>& input_indices) {
    int idx_input = 0;
    Array<PrimExpr> A_indices;
    for (unsigned i = 0; i < A->shape.size(); ++i) {
      auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
      if (axes_pos == A_axes_val.end()) {
        A_indices.push_back(input_indices[idx_input++]);
      } else {
        A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
      }
    }

    Array<PrimExpr> B_indices;
    for (unsigned i = 0; i < B->shape.size(); ++i) {
      auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
      if (axes_pos == B_axes_val.end()) {
        B_indices.push_back(input_indices[idx_input++]);
      } else {
        B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
      }
    }
    return sum(A(A_indices) * B(B_indices), iter_vars);
  };
  return compute(output_shape, func, name, tag);
}

inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
                     DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
  PrimExpr num_elem = tvm::cast(
      tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
  return compute(
      {num_elem},
      [&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name,
      tag);
}
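
// Usage sketch (illustrative): arange(0, 10, 2, ...) produces the five values
// 0, 2, 4, 6, 8, since ceil((10 - 0) / 2) == 5:
//
//   te::Tensor R = arange(0, 10, 2, DataType::Int(32));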

inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& indexing,
                              std::string name = "T_meshgrid", std::string tag = kInjective) {
  const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2;
  Array<PrimExpr> out_shape;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
    out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]);
  }
  Array<Tensor> result;
  for (size_t i = 0; i < inputs.size(); ++i) {
    result.push_back(compute(
        out_shape,
        [&](const Array<Var>& indices) {
          const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
          auto ndim = inputs[i]->GetShape().size();
          Array<PrimExpr> real_indices = {};
          if (ndim > 0) {
            real_indices = {indices[src_index]};
          }
          return inputs[i](real_indices);
        },
        name, tag));
  }
  return result;
}

inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
                               const std::string& dst_layout,
                               const std::string schedule_rule = "None",
                               const std::string name = "T_layout_trans",
                               const std::string tag = kInjective) {
  Layout src_layout_struct(src_layout);
  Layout dst_layout_struct(dst_layout);

  if (src_layout_struct.Equals(dst_layout_struct)) {
    return src;
  }

  ICHECK(src_layout_struct.defined() && dst_layout_struct.defined())
      << "cannot convert from/to undefined layout";

  auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
  ICHECK(layout_converter.defined())
      << "cannot convert from " << src_layout << " to " << dst_layout;

  Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);

  Map<String, ObjectRef> attrs = {{"schedule_rule", String(schedule_rule)},
                                  // Information about layouts needed for the schedule rule
                                  {"src_layout", String(src_layout)},
                                  {"dst_layout", String(dst_layout)},
                                  {"input_shape", src->shape}};

  return compute(
      dst_shape,
      [&](const Array<Var>& dst_indices) {
        Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
        Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
        PrimExpr in_range = PrimExpr(1) > PrimExpr(0);  // init with dtype=bool and value=true
        for (size_t i = 0; i < src.ndim(); ++i) {
          in_range = in_range && (src_indices[i] < src->shape[i]);
        }
        return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
      },
      name, tag, attrs);
}

inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape,
                                        std::vector<std::string>* axes) {
  int32_t factor = 0;
  std::string axis = "";
  for (char c : std::string(layout)) {
    if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
      axis += c;
      if (factor != 0) {
        shape->push_back(factor);
        factor = 0;
      }
    } else if (c >= '0' && c <= '9') {
      factor = factor * 10 + c - '0';
      if (!axis.empty()) {
        axes->push_back(axis);
        axis = "";
      }
    } else {
      LOG(FATAL) << "Invalid layout " << layout;
    }
  }
  if (!axis.empty()) {
    axes->push_back(axis);
  }
}

inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout,
                                              const String& dst_layout,
                                              const String name = "T_auto_scheduler_layout_trans",
                                              const String tag = kInjective) {
  Array<PrimExpr> src_shape;
  std::vector<std::string> src_axes;
  Array<PrimExpr> dst_shape;
  std::vector<std::string> dst_axes;

  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
  return compute(
      dst_shape,
      [&](const Array<Var>& dst_indices) {
        Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
        Array<PrimExpr> src_indices;
        for (const std::string& src_axis : src_axes) {
          PrimExpr src_index = 0;
          CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
          for (size_t i = 0; i < dst_axes.size(); ++i) {
            if (dst_axes[i] == src_axis) {
              src_index = src_index * dst_shape[i] + dst_indices_expr[i];
            }
          }
          src_indices.push_back(src_index);
        }
        return src(src_indices);
      },
      name, tag);
}

inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map,
                                             const String name = "T_meta_schedule_layout_trans",
                                             const String tag = kInjective) {
  arith::Analyzer analyzer;
  Array<Range> iter_domain;
  iter_domain.reserve(src->shape.size());
  for (const PrimExpr& e : src->shape) {
    iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
  }
  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape, &analyzer);
  return compute(
      post_transform_shape,
      [src, inv = index_map.Inverse(iter_domain, &analyzer),
       &analyzer](const Array<Var>& indices) -> PrimExpr {
        return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}, &analyzer));
      },
      name, tag);
}

inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape",
                    const std::string tag = kInjective) {
  int ndim = static_cast<int>(src->shape.size());
  Array<PrimExpr> out_shape{ndim};
  return compute(
      out_shape,
      [&](const Array<Var>& indices) {
        auto idx = indices[0];
        PrimExpr ret = 0;
        for (int i = 0; i < ndim; ++i) {
          ret = tvm::if_then_else(idx == i, src->shape[i], ret);
        }
        return tvm::cast(dtype, ret);
      },
      name, tag);
}

inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
                           const std::string& name = "ndarray_size",
                           const std::string& tag = kInjective) {
  int ndim = static_cast<int>(src->shape.size());
  Array<PrimExpr> out_ndarray_size = {};
  return compute(
      out_ndarray_size,
      [&](const Array<Var>& indices) {
        PrimExpr ret = 1;
        for (int i = 0; i < ndim; ++i) {
          ret *= src->shape[i];
        }
        return tvm::cast(dtype, ret);
      },
      name, tag);
}

inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
                      int depth, int axis, const DataType& dtype,
                      Array<PrimExpr> oshape = Array<PrimExpr>(),
                      const std::string name = "T_one_hot", const std::string tag = kInjective) {
  int true_axis = (axis == -1) ? indices->shape.size() : axis;
  if (oshape.size() == 0) {
    int ndim = indices->shape.size() + 1;
    int indices_index = 0;
    for (int i = 0; i < ndim; i++) {
      if (i == true_axis) {
        oshape.push_back(Integer(depth));
      } else {
        oshape.push_back(indices->shape[indices_index++]);
      }
    }
  }

  PrimExpr on_value_cast = cast(dtype, on_value);
  PrimExpr off_value_cast = cast(dtype, off_value);
  return compute(
      oshape,
      [&](const Array<Var>& iter_vars) {
        Array<Var> indices_indices;
        for (size_t i = 0; i < iter_vars.size(); i++) {
          if (static_cast<int>(i) == true_axis) {
            continue;
          }

          indices_indices.push_back(iter_vars[i]);
        }

        auto idx = iter_vars[true_axis];
        return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast);
      },
      name, tag);
}
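
// Usage sketch (illustrative): one-hot encoding a hypothetical length-3 index
// vector with depth 4 on the trailing axis yields shape (3, 4):
//
//   te::Tensor I = te::placeholder({3}, DataType::Int(32), "I");
//   te::Tensor Y = one_hot(I, /*on_value=*/1, /*off_value=*/0, /*depth=*/4,
//                          /*axis=*/-1, DataType::Float(32));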

inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr>& output_shape,
                              const Tensor& sparse_values, const PrimExpr& default_value,
                              const std::string name = "T_sparse_to_dense",
                              const std::string tag = kInjective) {
  ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values";
  ICHECK_LE(sparse_indices->shape.size(), 3)
      << "sparse_indices tensor should be 0D, 1D, or 2D only";
  ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only";

  const auto rank_sparse_indices = static_cast<int>(sparse_indices->shape.size());
  Array<PrimExpr> oshape;
  for (auto l : output_shape) {
    oshape.push_back(l);
  }
  return compute(
      oshape,
      [&](const Array<Var>& indices) {
        PrimExpr ret = default_value;
        if (0 == rank_sparse_indices) {
          ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret);
        } else if (1 == rank_sparse_indices) {
          for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
            ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
          }
        } else {
          for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
            PrimExpr aggregate_condition;
            for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
              PrimExpr comparison = indices[k] == sparse_indices[j][k];
              aggregate_condition = 0 == k ? comparison : aggregate_condition && comparison;
            }
            ret = if_then_else(aggregate_condition, sparse_values[j], ret);
          }
        }
        return ret;
      },
      name, tag);
}

inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
                              bool super_diag_right_align, bool sub_diag_right_align,
                              const std::string name = "T_matrix_set_diag",
                              const std::string tag = kInjective) {
  size_t ndim = input->shape.size() - 1;

  bool only_one_diagonal = k1 == k2;

  return compute(
      input->shape,
      [&](const Array<Var>& iter_vars) {
        auto get_diag = [&]() {
          Array<PrimExpr> diagonal_indices;
          PrimExpr k, offset = 0;
          for (size_t i = 0; i < ndim - 1; i++) {
            diagonal_indices.push_back(iter_vars[i]);
          }
          if (only_one_diagonal) {
            k = k1;
          } else {
            // Determining which diagonal/sub-diagonal/super-diagonal it is
            k = iter_vars[ndim] - iter_vars[ndim - 1];
            diagonal_indices.push_back(k2 - k);

            // Calculating the offset in diagonal tensor for this diagonal
            auto get_offset = [&](PrimExpr M, PrimExpr N) {
              // offset = max_diagonal_length - diagonal_length
              return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
            };
            offset = if_then_else(
                k >= 0,
                super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
                                       : 0,
                sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
                                     : 0);
          }
          diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
                                     offset);
          return diagonal(diagonal_indices);
        };
        return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
                            if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
                                         get_diag(), input(iter_vars)),
                            input(iter_vars));
      },
      name, tag);
}

inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
                        const std::string name = "advanced_index",
                        const std::string tag = kInjective) {
  ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!";
  Array<PrimExpr> oshape;
  Array<PrimExpr> broadcast_shape;
  Array<Tensor> bindices;

  broadcast_shape = indices[0]->shape;
  for (size_t i = 1; i < indices.size(); ++i) {
    auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape);
    broadcast_shape = Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
  }
  if (indices.size() == 1) {
    // quick path
    bindices = indices;
  } else {
    // Do broadcast for indices
    for (size_t i = 0; i < indices.size(); ++i) {
      bindices.push_back(broadcast_to(indices[i], broadcast_shape));
    }
  }

  for (const auto& dim : broadcast_shape) {
    oshape.push_back(dim);
  }
  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
    oshape.push_back(data->shape[i]);
  }

  return compute(
      oshape,
      [&](const Array<Var>& iter_var) {
        Array<PrimExpr> tensor_indices;
        for (size_t i = 0; i < broadcast_shape.size(); ++i) {
          tensor_indices.push_back(iter_var[i]);
        }

        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < bindices.size(); ++i) {
          real_indices.push_back(bindices[i](tensor_indices));
        }
        for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
          real_indices.push_back(iter_var[i]);
        }

        return data(real_indices);
      },
      name, tag);
}

}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_TRANSFORM_H_