transform.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_TRANSFORM_H_
25 #define TVM_TOPI_TRANSFORM_H_
26 
27 #include <tvm/te/operation.h>
28 #include <tvm/tir/data_layout.h>
29 #include <tvm/tir/index_map.h>
30 #include <tvm/topi/broadcast.h>
31 #include <tvm/topi/detail/broadcast.h>
32 #include <tvm/topi/detail/constant_utils.h>
33 #include <tvm/topi/detail/ravel_unravel.h>
34 #include <tvm/topi/detail/strided_slice.h>
35 #include <tvm/topi/detail/tensor_utils.h>
36 #include <tvm/topi/tags.h>
37 
38 #include <algorithm>
39 #include <iterator>
40 #include <limits>
41 #include <string>
42 #include <unordered_set>
43 #include <vector>
44 
45 namespace tvm {
46 namespace topi {
47 
48 using namespace tvm::te;
49 using namespace topi::detail;
50 
68 inline Tensor sliding_window(const Tensor& x, int axis, Array<Integer> window_shape,
69  Array<Integer> strides, std::string name = "T_sliding_window",
70  std::string tag = "") {
71  CHECK_GE(axis, 0);
72  auto _axis = size_t(axis);
73  CHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x.";
74  CHECK_EQ(x->shape.size() - _axis, window_shape.size())
75  << "There must be a window shape for every dimension of x "
76  << "over which we are sliding the window.";
77  CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length.";
78 
79  // Compute the new shape.
80  Array<PrimExpr> new_shape;
81  // Dimensions up until `axis` remain the same.
82  for (size_t i = 0; i < _axis; ++i) {
83  new_shape.push_back(x->shape[i]);
84  }
85 
86  // New dimensions which result from sliding the window in each dimension. One new dimension per
87  // window dimension.
88  for (size_t i = 0; i < window_shape.size(); ++i) {
89  // Length of the shape along this dimension.
90  auto dim_len = x->shape[_axis + i];
91  // Length of the window along this dimension.
92  auto window_len = window_shape[i];
93  // Strides along this dimension.
94  auto stride = strides[i];
95 
96  new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride));
97  }
98 
99  // Dimensions comprising the window.
100  for (size_t i = 0; i < window_shape.size(); ++i) {
101  new_shape.push_back(window_shape[i]);
102  }
103 
104  ICHECK(new_shape.size() == _axis + 2 * window_shape.size());
105 
106  return compute(
107  new_shape,
108  [&](const Array<Var>& indices) {
109  // The index at which to index the old tensor x.
110  Array<PrimExpr> idx;
111 
112  // Dimensions up until `axis` remain the same.
113  for (size_t i = 0; i < _axis; ++i) {
114  idx.push_back(indices[i]);
115  }
116 
117  for (size_t i = 0; i < window_shape.size(); ++i) {
118  // Which window in this dimension we are indexing.
119  auto window_idx = indices[_axis + i];
120  // Which index within the window we are indexing.
121  auto idx_within_window = indices[_axis + window_shape.size() + i];
122  // Stride value for this dimension.
123  auto stride = strides[i];
124 
125  idx.push_back(window_idx * stride + idx_within_window);
126  }
127 
128  ICHECK(idx.size() == x->shape.size());
129 
130  return x(idx);
131  },
132  name, tag);
133 }
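/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape is an assumption):
 *
 *   te::Tensor x = te::placeholder({4, 6}, DataType::Float(32), "x");
 *   te::Tensor y = sliding_window(x, 0, {2, 3}, {1, 1});
 *   // y->shape is (3, 4, 2, 3): a 3x4 grid of window positions, each a
 *   // 2x3 patch, with y(i, j, p, q) == x(i*1 + p, j*1 + q).
 */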
134 
147 inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1,
148  std::string name = "T_expand_dims", std::string tag = kBroadcast) {
149  int ndim = static_cast<int>(x->shape.size());
150  ICHECK(-ndim - 1 <= axis && axis <= ndim)
151  << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
152  << ", but got axis = " << axis << ", and data.ndim = " << ndim;
153  ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
154  << ", but got num_newaxis = " << num_newaxis;
155  if (axis < 0) {
156  // Calculate offset from last dimension
157  axis = ndim + axis + 1;
158  }
159  Array<PrimExpr> new_shape;
160  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
161  new_shape.push_back(x->shape[i]);
162  }
163  for (size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) {
164  new_shape.push_back(1);
165  }
166  for (size_t i = axis; i < x->shape.size(); ++i) {
167  new_shape.push_back(x->shape[i]);
168  }
169 
170  return compute(
171  new_shape,
172  [&](const Array<Var>& indices) {
173  Array<PrimExpr> idx;
174  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
175  idx.push_back(indices[i]);
176  }
177  for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
178  idx.push_back(indices[i]);
179  }
180  return x(idx);
181  },
182  name, tag);
183 }
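/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape is an assumption):
 *
 *   te::Tensor x = te::placeholder({3, 4}, DataType::Float(32), "x");
 *   te::Tensor y = expand_dims(x, 1, 2);
 *   // y->shape is (3, 1, 1, 4); y(i, 0, 0, j) == x(i, j).
 */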
184 
196 inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name = "T_transpose",
197  std::string tag = kInjective) {
198  if (!axes.defined() || axes.size() == 0) {
199  axes = Array<Integer>();
200  for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
201  axes.push_back(i);
202  }
203  }
204 
205  Array<PrimExpr> new_shape;
206  for (size_t i = 0; i < axes.size(); ++i) {
207  int axis = static_cast<int>(axes[i]->value);
208  int new_axis = axis;
209  if (axis < 0) {
210  new_axis = static_cast<int>(x->shape.size()) + axis;
211  axes.Set(i, new_axis);
212  }
213  ICHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
214  << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size())
215  << "-dimensional input tensor";
216 
217  for (size_t j = 0; j < axes.size(); ++j) {
218  if (i != j) {
219  ICHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
220  }
221  }
222  new_shape.push_back(x->shape[new_axis]);
223  }
224 
225  return compute(
226  new_shape,
227  [&](const Array<Var>& indices) {
228  std::vector<PrimExpr> idx;
229  for (size_t i = 0; i < axes.size(); ++i) {
230  idx.push_back(1);
231  }
232  for (size_t i = 0; i < axes.size(); ++i) {
233  int axis = static_cast<int>(axes[i]->value);
234  idx[axis] = indices[i];
235  }
236  return x(idx);
237  },
238  name, tag);
239 }
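/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape is an assumption):
 *
 *   te::Tensor x = te::placeholder({2, 3, 4}, DataType::Float(32), "x");
 *   te::Tensor y = transpose(x, {1, 2, 0});
 *   // y->shape is (3, 4, 2) and y(i, j, k) == x(k, i, j).
 *   // Passing an empty/undefined `axes` reverses all dimensions.
 */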
240 
255 inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
256  int batch_axis = 0, std::string name = "T_reverse_sequence",
257  std::string tag = kInjective) {
258  size_t src_tensor_dim = x->shape.size();
259  int seq_axis_inp = seq_axis;
260 
261  if (seq_lengths.defined()) {
262  size_t seq_lengths_dim = seq_lengths->shape.size();
263  int batch_axis_inp = batch_axis;
264  if (batch_axis < 0) {
265  batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
266  }
267 
268  ICHECK(seq_lengths_dim == 1) << "seq_lengths should be a 1D vector";
269 
270  ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
271  << "For reverse_sequence, seq_lengths size should match the dimension of the batch axis"
272  << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
273  << ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]);
274 
275  ICHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
276  << "batch_axis=" << batch_axis_inp << " is invalid for the "
277  << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
278  }
279 
280  if (seq_axis < 0) {
281  seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
282  }
283  ICHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
284  << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
285  << "-dimensional input tensor";
286 
287  auto func = [&](const Array<Var>& indices) {
288  Array<PrimExpr> real_indices;
289  for (size_t i = 0; i < src_tensor_dim; ++i) {
290  if (i == static_cast<size_t>(seq_axis)) {
291  if (seq_lengths.defined()) {
292  auto len = seq_lengths(indices[batch_axis]);
293  auto idx = if_then_else(
294  len <= 1 || len <= indices[i], indices[i],
295  if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
296  real_indices.push_back(idx);
297  } else {
298  real_indices.push_back(x->shape[i] - 1 - indices[i]);
299  }
300  } else {
301  real_indices.push_back(indices[i]);
302  }
303  }
304  return x(real_indices);
305  };
306 
307  return compute(x->shape, func, name, tag);
308 }
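/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shapes are assumptions):
 *
 *   te::Tensor x = te::placeholder({2, 4}, DataType::Float(32), "x");
 *   te::Tensor lens = te::placeholder({2}, DataType::Int(32), "lens");
 *   te::Tensor y = reverse_sequence(x, lens);  // seq_axis=1, batch_axis=0
 *   // Within row b, the first lens(b) elements are reversed:
 *   // y(b, i) == x(b, lens(b) - 1 - i) for i < lens(b), else x(b, i).
 */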
309 
320 inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string name = "T_reshape",
321  std::string tag = kInjective) {
322  auto x_shape = x->shape;
323  Array<PrimExpr> target_shape;
324 
325  for (const auto& ele : newshape) {
326  if (ele.as<IntImmNode>()) {
327  target_shape.push_back(cast(DataType::Int(32), ele));
328  } else {
329  target_shape.push_back(ele);
330  }
331  }
332 
333  // If either the input shape or the target shape contains a zero, return an empty tensor.
334  if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) {
335  return compute(
336  target_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
337  } else {
338  return compute(
339  target_shape,
340  [&](const Array<Var>& indices) {
341  return x(UnravelIndex(
342  RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape), x_shape));
343  },
344  name, tag);
345  }
346 }
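/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape is an assumption):
 *
 *   te::Tensor x = te::placeholder({2, 3, 4}, DataType::Float(32), "x");
 *   te::Tensor y = reshape(x, {4, 6});
 *   // Elements keep their row-major (ravelled) order, so conceptually
 *   // y(i, j) reads the flat element at offset i*6 + j of x.
 */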
347 
359 inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
360  std::string tag = kInjective) {
361  auto x_shape = x->shape;
362  auto shape_shape = shape->shape;
363 
364  Array<PrimExpr> oshape;
365  oshape.push_back(shape_shape[0]);
366  if (x_shape.size() != 0) {
367  oshape.push_back(x_shape[0]);
368  }
369 
370  auto func = [&](const Array<Var>& indices) {
371  auto i = indices[0];
372  std::vector<PrimExpr> indices_divs;
373  PrimExpr ret = 0;
374  PrimExpr cur_val = 0;
375  PrimExpr index_val = 0;
376 
377  if (x_shape.size() != 0) {
378  index_val = x[indices[1]];
379  } else {
380  index_val = x();
381  }
382  indices_divs.push_back(index_val);
383  for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
384  ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
385  cur_val = indexdiv(indices_divs.back(), shape[v]);
386  indices_divs.push_back(cur_val);
387  }
388  return ret;
389  };
390 
391  return compute(oshape, func, name, tag);
392 }
393 
407 inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false,
408  std::string name = "T_squeeze", std::string tag = kInjective) {
409  auto ndim = x->shape.size();
410  std::vector<int> axis_val;
411  if (!axis.defined()) {
412  for (size_t i = 0; i < ndim; ++i) {
413  if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
414  axis_val.push_back(static_cast<int>(i));
415  }
416  }
417  } else {
418  for (size_t i = 0; i < axis.size(); ++i) {
419  int64_t val = axis[i]->value;
420  if (val < 0) {
421  val += static_cast<int>(x->shape.size());
422  }
423  if (IsConstInt(x->shape[val])) {
424  ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
425  }
426  axis_val.push_back(val);
427  }
428  }
429 
430  std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());
431 
432  Array<PrimExpr> out_shape;
433  for (size_t i = 0; i < ndim; ++i) {
434  if (axis_set.count(static_cast<int>(i)) == 0) {
435  out_shape.push_back(x->shape[i]);
436  }
437  }
438  if (out_shape.size() == 0 && atleast1d) {
439  out_shape.push_back(1);
440  }
441 
442  return compute(
443  out_shape,
444  [&](const Array<Var>& indices) {
445  Array<PrimExpr> real_indices;
446  int flag = 0;
447  for (size_t i = 0; i < ndim; ++i) {
448  if (axis_set.count(static_cast<int>(i)) == 0) {
449  real_indices.push_back(indices[i - flag]);
450  } else {
451  real_indices.push_back(0);
452  flag += 1;
453  }
454  }
455  return x(real_indices);
456  },
457  name, tag);
458 }
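/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape is an assumption):
 *
 *   te::Tensor x = te::placeholder({1, 3, 1, 4}, DataType::Float(32), "x");
 *   te::Tensor y = squeeze(x, Array<Integer>{0, 2});
 *   // y->shape is (3, 4). With an undefined `axis`, every static
 *   // size-1 dimension is removed instead.
 */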
459 
470 inline Tensor concatenate(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_concat",
471  std::string tag = kInjective) {
472  int ndim = static_cast<int>(inputs[0]->shape.size());
473  ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)"
474  << ", but got axis = " << axis << ", and ndim = " << ndim;
475  if (axis < 0) {
476  axis += ndim;
477  }
478  ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds";
479 
480  Array<PrimExpr> axis_sizes;
481  for (auto t : inputs) {
482  axis_sizes.push_back(t->shape[axis]);
483  }
484  arith::Analyzer analyzer;
485  PrimExpr join_size = axis_sizes[0];
486  for (size_t i = 1; i < axis_sizes.size(); ++i) {
487  join_size += axis_sizes[i];
488  }
489  join_size = analyzer.Simplify(join_size);
490  Array<PrimExpr> out_shape;
491  for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
492  out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
493  }
494 
495  return compute(
496  out_shape,
497  [&](const Array<Var>& indices) {
498  auto ret = inputs[0](indices);
499  auto ind = indices[axis];
500  for (size_t i = 0; i < inputs.size() - 1; ++i) {
501  ind -= axis_sizes[i];
502 
503  Array<PrimExpr> idx;
504  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
505  idx.push_back(indices[i]);
506  }
507  idx.push_back(ind);
508  for (size_t i = axis + 1; i < indices.size(); ++i) {
509  idx.push_back(indices[i]);
510  }
511 
512  ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
513  }
514  return ret;
515  },
516  name, tag);
517 }
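/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shapes are assumptions):
 *
 *   te::Tensor a = te::placeholder({2, 3}, DataType::Float(32), "a");
 *   te::Tensor b = te::placeholder({2, 5}, DataType::Float(32), "b");
 *   te::Tensor y = concatenate({a, b}, 1);
 *   // y->shape is (2, 8); columns 0-2 come from a and columns 3-7 from b,
 *   // selected by the chain of if_then_else expressions above.
 */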
518 
529 inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack",
530  std::string tag = kInjective) {
531  int ndim = static_cast<int>(inputs[0]->shape.size());
532  ICHECK(-ndim - 1 <= axis && axis <= ndim)
533  << "stack only accepts `axis` in [-ndim - 1, ndim]"
534  << ", but got axis = " << axis << ", and ndim = " << ndim;
535  if (axis < 0) {
536  axis += ndim + 1;
537  }
538  ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";
539 
540  const int stack_size = static_cast<int>(inputs.size());
541  Array<PrimExpr> out_shape;
542  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
543  out_shape.push_back(stack_size);
544  for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
545  out_shape.push_back(inputs[0]->shape[i]);
546 
547  return compute(
548  out_shape,
549  [&](const Array<Var>& indices) {
550  Array<PrimExpr> idx;
551  for (size_t i = 0; i < indices.size(); ++i)
552  if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]);
553  auto ind = indices[axis];
554  auto ret = inputs[0](idx);
555  for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
556  ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
557  }
558  return ret;
559  },
560  name, tag);
561 }
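/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shapes are assumptions):
 *
 *   te::Tensor a = te::placeholder({2, 3}, DataType::Float(32), "a");
 *   te::Tensor b = te::placeholder({2, 3}, DataType::Float(32), "b");
 *   te::Tensor y = stack({a, b}, 0);
 *   // y->shape is (2, 2, 3); y(0, i, j) == a(i, j), y(1, i, j) == b(i, j).
 */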
562 
575 inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
576  std::string name = "T_split", std::string tag = kInjective) {
577  if (axis < 0) {
578  axis += static_cast<int>(x->shape.size());
579  }
580  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
581 
582  auto src_axis_size = x->shape[axis];
583  std::vector<PrimExpr> begin_ids;
584  begin_ids.push_back(0);
585 
586  for (auto idx : split_indices) {
587  auto idx_node = idx.as<IntImmNode>();
588  auto back_node = begin_ids.back().as<IntImmNode>();
589  if (idx_node && back_node) {
590  ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
591  }
592  begin_ids.push_back(idx);
593  }
594 
595  Array<Array<PrimExpr>> out_shapes;
596  for (size_t i = 0; i < begin_ids.size(); ++i) {
597  PrimExpr out_axis_size;
598  if (i == begin_ids.size() - 1) {
599  out_axis_size = src_axis_size - begin_ids[i];
600  } else {
601  out_axis_size = begin_ids[i + 1] - begin_ids[i];
602  }
603 
604    Array<PrimExpr> shape;
605  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
606  shape.push_back(x->shape[i]);
607  }
608  shape.push_back(out_axis_size);
609  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
610  shape.push_back(x->shape[i]);
611  }
612 
613  out_shapes.push_back(shape);
614  }
615 
616  Array<Tensor> result;
617  for (size_t i = 0; i < begin_ids.size(); ++i) {
618  result.push_back(compute(
619  out_shapes[i],
620  [&](const Array<Var>& indices) {
621  auto begin = begin_ids[i];
622  Array<PrimExpr> real_indices;
623  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
624  real_indices.push_back(indices[j]);
625  }
626  real_indices.push_back(indices[axis] + begin);
627  for (size_t j = axis + 1; j < indices.size(); ++j) {
628  real_indices.push_back(indices[j]);
629  }
630 
631  return x(real_indices);
632  },
633  name, tag));
634  }
635 
636  return result;
637 }
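/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape is an assumption):
 *
 *   te::Tensor x = te::placeholder({6, 4}, DataType::Float(32), "x");
 *   Array<Tensor> parts = split(x, {2, 5}, 0);
 *   // The axis is cut at indices 2 and 5: parts[0]->shape is (2, 4),
 *   // parts[1]->shape is (3, 4), parts[2]->shape is (1, 4).
 */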
638 
652 inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
653  const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
654  std::string name = "T_dynamic_strided_slice",
655  std::string tag = kInjective) {
656  const size_t src_tensor_dim = x->shape.size();
657  ICHECK_LE(begin.size(), src_tensor_dim);
658  ICHECK_LE(end.size(), src_tensor_dim);
659  ICHECK_LE(strides.size(), src_tensor_dim);
660  ICHECK_EQ(begin.size(), end.size());
661  ICHECK_EQ(begin.size(), strides.size());
662 
663  const size_t num_slice_axes = begin.size();
664  Array<PrimExpr> out_shape;
665 
666  for (size_t i = 0; i < num_slice_axes; ++i) {
667  auto d = indexdiv(end[i] - begin[i], strides[i]);
668  if (d->IsInstance<tvm::IntImmNode>()) {
669  // Preserve static dimension if possible
670  out_shape.push_back(d);
671  } else {
672  out_shape.push_back(tvm::tir::Var("dim"));
673  }
674  }
675 
676  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
677  out_shape.push_back(x->shape[i]);
678  }
679 
680  return te::compute(
681  out_shape,
682  [&](const Array<tvm::tir::Var>& indices) {
683  Array<PrimExpr> real_indices;
684  for (size_t i = 0; i < num_slice_axes; ++i) {
685  real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
686  }
687  // keep input dim
688  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
689  real_indices.push_back(indices[i]);
690  }
691  return x(real_indices);
692  },
693  name, tag);
694 }
695 
709 inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
710  const te::Tensor& end, const te::Tensor& strides,
711  std::string name = "T_strided_slice_dynamic",
712  std::string tag = topi::kInjective) {
713  const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
714  ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
715  ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
716 
717  Array<PrimExpr> begin_expr, end_expr, strides_expr;
718  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
719  auto i64_ind = IntImm(DataType::Int(64), i);
720  begin_expr.push_back(begin(i64_ind));
721  end_expr.push_back(end(i64_ind));
722  strides_expr.push_back(strides(i64_ind));
723  }
724  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
725 }
726 
741 inline Array<PrimExpr> StridedSliceOutputShape(
742  const Array<PrimExpr>& ishape, const Array<Integer>& begin, const Array<Integer>& end,
743  const Array<Integer>& strides, const Array<Integer>& axes, const std::string& slice_mode) {
744  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
745  std::vector<int64_t> begin_vec, end_vec, strides_vec;
746  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
747  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
748  begin[0]->dtype, slice_mode);
749  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
750  begin_canonicalized, true);
751 }
752 
769 inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
770  const Array<Integer>& end, const Array<Integer>& strides,
771  const Array<Integer>& axes, std::string slice_mode = "end",
772  std::string name = "T_strided_slice_with_axes",
773  std::string tag = kInjective) {
774  const size_t src_tensor_dim = x->shape.size();
775  ICHECK(axes.size() <= src_tensor_dim);
776  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
777 
778  std::vector<int64_t> begin_vec, end_vec, strides_vec;
779  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
780 
781  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
782  begin[0]->dtype, slice_mode);
783  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
784  slice_mode, begin_expr);
785 
786  return te::compute(
787  out_shape,
788  [&](const Array<tir::Var>& indices) {
789  Array<PrimExpr> real_indices;
790  for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
791  for (size_t i = 0; i < axes.size(); ++i) {
792  auto stride = make_const(strides[i].dtype(), strides_vec[i]);
793  PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
794  real_indices.Set(axes[i].IntValue(), ind);
795  }
796  return x(real_indices);
797  },
798  name, tag);
799 }
800 
815 inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
816  const Array<Integer>& strides, std::string slice_mode = "end",
817  std::string name = "T_strided_slice", std::string tag = kInjective) {
818  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
819  Array<Integer> axes;
820  for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
821  Array<Integer> begin_full(begin);
822  Array<Integer> end_full(end);
823  Array<Integer> strides_full(strides);
824 
825  const IntImm one = IntImm(DataType::Int(64), 1);
826  const IntImm zero = IntImm(DataType::Int(64), 0);
827  const IntImm max_range = Integer(std::numeric_limits<int64_t>::max());
828 
829  for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
830  strides_full.push_back(one);
831  }
832  for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
833  begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
834  }
835  for (size_t i = end.size(); i < src_tensor_dim; ++i) {
836  end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
837  }
838 
839  return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
840  tag);
841 }
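/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape is an assumption):
 *
 *   te::Tensor x = te::placeholder({10}, DataType::Float(32), "x");
 *   te::Tensor y = strided_slice(x, {1}, {8}, {2});
 *   // y->shape is (4); y(i) == x(1 + 2*i), i.e. elements 1, 3, 5, 7.
 */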
842 
855 inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
856  std::string name = "T_split_sections",
857  std::string tag = kInjective) {
858  if (axis < 0) {
859  axis += static_cast<int>(x->shape.size());
860  }
861  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
862 
863  auto src_axis_size = x->shape[axis];
864 
865  ICHECK_GT(num_sections, 0) << "Slice count must be > 0";
866 
867  if (auto node = src_axis_size.as<IntImmNode>()) {
868  ICHECK_EQ(node->value % num_sections, 0)
869  << "num_sections must be an integer factor of the size of axis " << axis << " ("
870  << node->value << ")";
871  }
872 
873  Array<PrimExpr> split_indices;
874  auto seg_size = indexdiv(src_axis_size, num_sections);
875  for (int i = 0; i < num_sections; ++i) {
876  // region at index 0 is added by split()
877  if (i != 0) {
878  split_indices.push_back(seg_size * i);
879  }
880  }
881 
882  return split(x, split_indices, axis, name, tag);
883 }
884 
898 inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
899  std::string mode = "clip", std::string name = "T_take",
900  std::string tag = kInjective) {
901  Array<PrimExpr> a_shape = a->shape;
902  Array<PrimExpr> out_shape = indices->shape;
903  PrimExpr a_size = 1;
904  for (size_t i = 0; i < a_shape.size(); ++i) {
905  a_size = a_size * a_shape[i];
906  }
907 
908  if (mode == "clip") {
909  return compute(
910  out_shape,
911  [&](const Array<Var>& out_index) {
912  auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
913  return a(UnravelIndex(idx, a_shape));
914  },
915  name, tag);
916  } else if (mode == "fast") {
917  LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
918  "Make sure input indices are in bounds";
919  return compute(
920  out_shape,
921  [&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
922  name, tag);
923  } else { // mode == "wrap"
924  return compute(
925  out_shape,
926  [&](const Array<Var>& out_index) {
927  auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
928  return a(UnravelIndex(idx, a_shape));
929  },
930  name, tag);
931  }
932 }
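/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shapes are assumptions):
 *
 *   te::Tensor a = te::placeholder({2, 3}, DataType::Float(32), "a");
 *   te::Tensor idx = te::placeholder({4}, DataType::Int(32), "idx");
 *   te::Tensor y = take(a, idx, /*batch_dims=*/0);  // mode = "clip"
 *   // a is treated as flat (6 elements); y(i) reads the flat element at
 *   // min(max(idx(i), 0), 5), unravelled back into a's (2, 3) shape.
 */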
933 
946 inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value,
947  int axis, std::string name = "T_sequence_mask",
948  std::string tag = kInjective) {
949  ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
950  ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
951  auto length_dim = data->shape[axis];
952  auto batch_dim = data->shape[1 - axis];
953  Array<PrimExpr> out_shape = data->shape;
954  Tensor out = compute(
955  out_shape,
956  [&](const Array<Var>& out_index) {
957  Array<PrimExpr> len_index;
958  auto tid = out_index[axis];
959  auto bid = out_index[1 - axis];
960  len_index.push_back(bid);
961  PrimExpr ret =
962  tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
963  tvm::tir::make_const(data->dtype, mask_value), data(out_index));
964  return ret;
965  },
966  name, tag);
967  return out;
968 }
969 
984 inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
985  std::string mode = "clip", std::string name = "T_take",
986  std::string tag = kInjective) {
987  if (axis < 0) {
988  axis += static_cast<int>(a->shape.size());
989  }
990  ICHECK_GE(axis, 0) << "axis out of bounds";
991  ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
992  auto axis_dim = a->shape[axis];
993  int indices_len = static_cast<int>(indices->shape.size());
994 
995  int batch_dims_ = batch_dims;
996  if (batch_dims_ != 0) {
997  ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
998  ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";
999 
1000  if (batch_dims_ < 0) {
1001  batch_dims_ = indices->shape.size() + batch_dims_;
1002  }
1003 
1004  ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
1005  ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
1006  for (int i = 0; i < batch_dims_; ++i) {
1007  auto addr1 = a->shape[i];
1008  auto addr2 = indices->shape[i];
1009  auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
1010  auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
1011  ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
1012  }
1013  }
1014 
1015  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
1016  // a.shape[axis + 1:].
1017 
1018  Array<PrimExpr> out_shape;
1019  for (int i = 0; i < batch_dims_; ++i) {
1020  out_shape.push_back(a->shape[i]);
1021  }
1022  for (int i = batch_dims_; i < axis; ++i) {
1023  out_shape.push_back(a->shape[i]);
1024  }
1025  for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
1026  out_shape.push_back(indices->shape[i]);
1027  }
1028  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
1029  out_shape.push_back(a->shape[i]);
1030  }
1031 
1032  if (mode == "clip") {
1033  if (batch_dims_ == 0) {
1034  return compute(
1035  out_shape,
1036  [&](const Array<Var>& out_index) {
1037  Array<PrimExpr> indices_position;
1038  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1039  indices_position.push_back(out_index[j]);
1040  }
1041  Array<PrimExpr> real_indices;
1042  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1043  real_indices.push_back(out_index[j]);
1044  }
1045  auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
1046  real_indices.push_back(idx);
1047  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1048  real_indices.push_back(out_index[j]);
1049  }
1050  return a(real_indices);
1051  },
1052  name, tag);
1053  } else {
1054  return compute(
1055  out_shape,
1056  [&](const Array<Var>& out_index) {
1057  Array<PrimExpr> indices_position;
1058  for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
1059  indices_position.push_back(out_index[j]);
1060  }
1061  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
1062  indices_position.push_back(out_index[j]);
1063  }
1064  Array<PrimExpr> real_indices;
1065  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1066  real_indices.push_back(out_index[j]);
1067  }
1068  auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
1069  real_indices.push_back(idx);
1070  for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
1071  real_indices.push_back(out_index[j]);
1072  }
1073  return a(real_indices);
1074  },
1075  name, tag);
1076  }
1077  } else if (mode == "fast") {
1078  LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
1079  "Make sure input indices are in bounds";
1080  return compute(
1081  out_shape,
1082  [&](const Array<Var>& out_index) {
1083  Array<PrimExpr> indices_position;
1084  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1085  indices_position.push_back(out_index[j]);
1086  }
1087  Array<PrimExpr> real_indices;
1088  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1089  real_indices.push_back(out_index[j]);
1090  }
1091  real_indices.push_back(indices(indices_position));
1092  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1093  real_indices.push_back(out_index[j]);
1094  }
1095  return a(real_indices);
1096  },
1097  name, tag);
1098  } else { // mode == "wrap"
1099  return compute(
1100  out_shape,
1101  [&](const Array<Var>& out_index) {
1102  Array<PrimExpr> indices_position;
1103  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1104  indices_position.push_back(out_index[j]);
1105  }
1106  Array<PrimExpr> real_indices;
1107  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1108  real_indices.push_back(out_index[j]);
1109  }
1110  auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
1111  real_indices.push_back(idx);
1112  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1113  real_indices.push_back(out_index[j]);
1114  }
1115  return a(real_indices);
1116  },
1117  name, tag);
1118  }
1119 }
1120 
1132 inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
1133  std::string name = "T_where", std::string tag = kBroadcast) {
1134  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
1135  << y->dtype;
1136  auto get_out_shape = [&]() {
1137  auto bh1 = detail::BroadcastShape(x->shape, y->shape);
1138  Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
1139  auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
1140  Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
1141  return common_shape2;
1142  };
1143 
1144  auto oshape = get_out_shape();
1145 
1146  auto c_bh = detail::BroadcastShape(condition->shape, oshape);
1147  auto x_bh = detail::BroadcastShape(x->shape, oshape);
1148  auto y_bh = detail::BroadcastShape(y->shape, oshape);
1149 
1150  auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
1151  auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
1152  auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
1153  auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
1154  return tvm::tir::Select(c != 0, true_val, false_val);
1155  };
1156 
1157  return compute(oshape, select, name, tag);
1158 }
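/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shapes are assumptions):
 *
 *   te::Tensor c = te::placeholder({3}, DataType::Int(32), "c");
 *   te::Tensor x = te::placeholder({2, 3}, DataType::Float(32), "x");
 *   te::Tensor y = te::placeholder({2, 3}, DataType::Float(32), "y");
 *   te::Tensor z = where(c, x, y);
 *   // c broadcasts along the leading dim to (2, 3);
 *   // z(i, j) == x(i, j) when c(j) != 0, else y(i, j).
 */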
1159 
1172 inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat",
1173  std::string tag = kBroadcast) {
1174  int ndim = static_cast<int>(x->shape.size());
1175  ICHECK(-ndim - 1 <= axis && axis <= ndim)
1176  << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
1177  << ", but got axis = " << axis << ", and data.ndim = " << ndim;
1178  ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
1179  << ", but got repeats = " << repeats;
1180  if (axis < 0) {
1181  // Calculate offset from last dimension
1182  axis += ndim;
1183  }
1184  Array<PrimExpr> new_shape;
1185  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1186  new_shape.push_back(x->shape[i]);
1187  }
1188  new_shape.push_back(repeats * x->shape[axis]);
1189  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
1190  new_shape.push_back(x->shape[i]);
1191  }
1192 
1193  return compute(
1194  new_shape,
1195  [&](const Array<Var>& indices) {
1196  Array<PrimExpr> idx;
1197  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1198  idx.push_back(indices[i]);
1199  }
1200  idx.push_back(indexdiv(indices[axis], repeats));
1201  for (size_t i = axis + 1; i < indices.size(); ++i) {
1202  idx.push_back(indices[i]);
1203  }
1204  return x(idx);
1205  },
1206  name, tag);
1207 }
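/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape is an assumption):
 *
 *   te::Tensor x = te::placeholder({2, 3}, DataType::Float(32), "x");
 *   te::Tensor y = repeat(x, 2, 1);
 *   // y->shape is (2, 6); each element is repeated twice along axis 1:
 *   // y(i, j) == x(i, j / 2).
 */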
1208 
1219 inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_tile",
1220  std::string tag = kBroadcast) {
1221  size_t ndim = x->shape.size();
1222  size_t rdim = reps.size();
1223  size_t tdim = (ndim > rdim) ? ndim : rdim;
1224  Array<PrimExpr> data_shape;
1225  Array<PrimExpr> reps_shape;
1226  Array<PrimExpr> new_shape;
1227  if (ndim == rdim) {
1228  for (size_t i = 0; i < ndim; ++i) {
1229  data_shape.push_back(x->shape[i]);
1230  reps_shape.push_back(reps[i]);
1231  }
1232  } else if (ndim > rdim) {
1233  for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1234  for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
1235  for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1236  } else {
1237  for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
1238  for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1239  for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1240  }
1241  for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);
1242 
1243  if (is_empty_shape(new_shape)) {
1244  return compute(
1245  new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
1246  } else {
1247  return compute(
1248  new_shape,
1249  [&](const Array<Var>& indices) {
1250  Array<PrimExpr> idx;
1251  if (ndim >= rdim) {
1252  for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
1253  } else {
1254  for (size_t i = 0; i < ndim; ++i)
1255  idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
1256  }
1257  return x(idx);
1258  },
1259  name, tag);
1260  }
1261 }
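/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape is an assumption):
 *
 *   te::Tensor x = te::placeholder({2, 3}, DataType::Float(32), "x");
 *   te::Tensor y = tile(x, {2, 2});
 *   // y->shape is (4, 6); y(i, j) == x(i % 2, j % 3).
 */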
1262 
1274 inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
1275  std::string name = "T_tile", std::string tag = kBroadcast) {
1276  size_t ndim = x->shape.size();
1277  if (is_empty_shape(new_shape)) {
1278  return compute(
1279  new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
1280  } else {
1281  return compute(
1282  new_shape,
1283  [&](const Array<Var>& indices) {
1284  Array<PrimExpr> idx;
1285  if (ndim >= rdim) {
1286  for (size_t i = 0; i < ndim; ++i) {
1287  idx.push_back(indexmod(indices[i], x->shape[i]));
1288  }
1289  } else {
1290  for (size_t i = 0; i < ndim; ++i) {
1291  idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
1292  }
1293  }
1294  return x(idx);
1295  },
1296  name, tag);
1297  }
1298 }
1299 
1311 inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
1312  std::string name = "T_gather", std::string tag = kInjective) {
1313  size_t ndim_d = data->shape.size();
1314  size_t ndim_i = indices->shape.size();
1315  ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
1316  ICHECK_EQ(ndim_d, ndim_i);
1317  if (axis < 0) {
1318  axis += ndim_d;
1319  }
1320  ICHECK_GE(axis, 0);
1321  ICHECK_LT(axis, ndim_d);
1322  if (indices->shape[axis].as<IntImmNode>()) {
1323  size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
1324  ICHECK_GE(indices_dim_i, 1);
1325  }
1326  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());
1327 
1328  Array<PrimExpr> out_shape;
1329  for (size_t i = 0; i < ndim_i; ++i) {
1330  out_shape.push_back(indices->shape[i]);
1331  }
1332 
1333  return compute(
1334  out_shape,
1335  [&](const Array<Var>& out_index) {
1336  Array<PrimExpr> indices_position;
1337  for (size_t i = 0; i < ndim_i; ++i) {
1338  indices_position.push_back(out_index[i]);
1339  }
1340  Array<PrimExpr> real_indices;
1341  for (size_t i = 0; i < ndim_i; ++i) {
1342  if (i == static_cast<size_t>(axis)) {
1343  real_indices.push_back(indices(indices_position));
1344  } else {
1345  real_indices.push_back(indices_position[i]);
1346  }
1347  }
1348  return data(real_indices);
1349  },
1350  name, tag);
1351 }
1352 
1364 inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
1365  std::string name = "T_gather_nd", std::string tag = kInjective) {
1366  size_t ndim_d = data->shape.size();
1367  size_t ndim_i = indices->shape.size();
1368  ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimension";
1369  size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
1370  ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of the indices tensor must not exceed "
1371  << "the number of dimensions of the data tensor";
1372  Array<PrimExpr> out_shape;
1373  for (size_t i = 1; i < ndim_i; ++i) {
1374  out_shape.push_back(indices->shape[i]);
1375  }
1376  for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
1377  out_shape.push_back(data->shape[i]);
1378  }
1379  return compute(
1380  out_shape,
1381  [&](const Array<Var>& out_index) {
1382  Array<PrimExpr> indices_position;
1383  indices_position.push_back(0);
1384  for (size_t i = 0; i < ndim_i - 1; ++i) {
1385  indices_position.push_back(out_index[i]);
1386  }
1387  Array<PrimExpr> real_indices;
1388  for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
1389  real_indices.push_back(out_index[i]);
1390  }
1391  for (size_t i = 0; i < indices_dim0; ++i) {
1392  indices_position.Set(0, make_const(DataType::Int(32), i));
1393  if (indices->dtype.is_int() || indices->dtype.is_uint()) {
1394  real_indices.push_back(indices(indices_position));
1395  } else {
1396  real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
1397  }
1398  }
1399  if (real_indices.size() == ndim_d) {
1400  return data(real_indices);
1401  }
1402  for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
1403  real_indices.push_back(out_index[i]);
1404  }
1405  return data(real_indices);
1406  },
1407  name, tag);
1408 }
1409 
1425 inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B,
1426  bool trans_a = false, bool trans_b = false,
1427  std::string name = "T_matmul", std::string tag = kMatMul) {
1428  tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
1429  auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
1430  auto l = [&](tvm::tir::Var i, tvm::tir::Var j) {
1431  return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
1432  };
1433  return tvm::te::compute(output_shape, l, name, tag);
1434 }
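/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shapes are assumptions):
 *
 *   te::Tensor A = te::placeholder({2, 3}, DataType::Float(32), "A");
 *   te::Tensor B = te::placeholder({3, 4}, DataType::Float(32), "B");
 *   te::Tensor C = matmul(A, B);  // trans_a = trans_b = false
 *   // C->shape is (2, 4); C(i, j) == sum over k of A(i, k) * B(k, j).
 */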
1435 
1447 inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2,
1448  std::string name = "T_tensordot", std::string tag = kMatMul) {
1449  ICHECK_GE(A->shape.size(), axes);
1450  ICHECK_GE(B->shape.size(), axes);
1451 
1452  Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
1453  for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);
1454 
1455  Array<IterVar> iter_vars;
1456  for (int i = 0; i < axes; ++i)
1457  iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));
1458 
1459  auto func = [&A, &B, &iter_vars, axes](const Array<Var>& input_indices) {
1460  Array<PrimExpr> A_indices(input_indices.begin(),
1461  input_indices.begin() + (A->shape.size() - axes));
1462  for (auto& v : iter_vars) A_indices.push_back(v);
1463 
1464  Array<PrimExpr> B_indices;
1465  for (auto& v : iter_vars) B_indices.push_back(v);
1466 
1467  auto it = input_indices.begin() + (A->shape.size() - axes);
1468  for (; it != input_indices.end(); ++it) B_indices.push_back(*it);
1469 
1470  // Some passes don't like reductions with empty axis, so avoid it here
1471  if (iter_vars.empty()) {
1472  return A(A_indices) * B(B_indices);
1473  } else {
1474  return sum(A(A_indices) * B(B_indices), iter_vars);
1475  }
1476  };
1477 
1478  return compute(output_shape, func, name, tag);
1479 }
1480 
1493 inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExpr> A_axes,
1494  Array<PrimExpr> B_axes, std::string name = "T_tensordot",
1495  std::string tag = kMatMul) {
1496  ICHECK_EQ(A_axes.size(), B_axes.size());
1497 
1498  auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
1499  auto B_axes_val = GetConstIntValues(B_axes, "B_axes");
1500 
1501  Array<PrimExpr> output_shape;
1502  for (unsigned i = 0; i < A->shape.size(); ++i)
1503  if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
1504  output_shape.push_back(A->shape[i]);
1505  for (unsigned i = 0; i < B->shape.size(); ++i)
1506  if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
1507  output_shape.push_back(B->shape[i]);
1508 
1509  Array<IterVar> iter_vars;
1510  for (unsigned i = 0; i < B_axes_val.size(); ++i)
1511  iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));
1512 
1513  auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array<Var>& input_indices) {
1514  int idx_input = 0;
1515  Array<PrimExpr> A_indices;
1516  for (unsigned i = 0; i < A->shape.size(); ++i) {
1517  auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
1518  if (axes_pos == A_axes_val.end()) {
1519  A_indices.push_back(input_indices[idx_input++]);
1520  } else {
1521  A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
1522  }
1523  }
1524 
1525  Array<PrimExpr> B_indices;
1526  for (unsigned i = 0; i < B->shape.size(); ++i) {
1527  auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
1528  if (axes_pos == B_axes_val.end()) {
1529  B_indices.push_back(input_indices[idx_input++]);
1530  } else {
1531  B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
1532  }
1533  }
1534  return sum(A(A_indices) * B(B_indices), iter_vars);
1535  };
1536  return compute(output_shape, func, name, tag);
1537 }
1538 
1539 inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
1540  DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
1541  PrimExpr num_elem = tvm::cast(
1542  tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
1544  return compute(
1545  {num_elem},
1546  [&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name,
1547  tag);
1548 }
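/* Editor's usage sketch (illustrative, not part of the original header; the
 * arguments are assumptions):
 *
 *   te::Tensor r = arange(0, 10, 2, DataType::Int(32));
 *   // num_elem == ceil((10 - 0) / 2) == 5, so r->shape is (5) and
 *   // r(i) == 2 * i, i.e. {0, 2, 4, 6, 8}.
 */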
1549 
1560 inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& indexing,
1561  std::string name = "T_meshgrid", std::string tag = kInjective) {
1562  const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2;
1563  Array<PrimExpr> out_shape;
1564  for (size_t i = 0; i < inputs.size(); ++i) {
1565  const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
1566  out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]);
1567  }
1568  Array<Tensor> result;
1569  for (size_t i = 0; i < inputs.size(); ++i) {
1570  result.push_back(compute(
1571  out_shape,
1572  [&](const Array<Var>& indices) {
1573  const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
1574  auto ndim = inputs[i]->GetShape().size();
1575  Array<PrimExpr> real_indices = {};
1576  if (ndim > 0) {
1577  real_indices = {indices[src_index]};
1578  }
1579  return inputs[i](real_indices);
1580  },
1581  name, tag));
1582  }
1583  return result;
1584 }
1585 
1595 inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
1596  const std::string& dst_layout,
1597  const std::string name = "T_layout_trans",
1598  const std::string tag = kInjective) {
1599  Layout src_layout_struct(src_layout);
1600  Layout dst_layout_struct(dst_layout);
1601 
1602  if (src_layout_struct.Equals(dst_layout_struct)) {
1603  return src;
1604  }
1605 
1606  ICHECK(src_layout_struct.defined() && dst_layout_struct.defined())
1607  << "cannot convert from/to undefined layout";
1608 
1609  auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
1610  ICHECK(layout_converter.defined())
1611  << "cannot convert from " << src_layout << " to " << dst_layout;
1612 
1613  Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);
1614 
1615  return compute(
1616  dst_shape,
1617  [&](const Array<Var>& dst_indices) {
1618  Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
1619  Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
1620  PrimExpr in_range = PrimExpr(1) > PrimExpr(0); // init with dtype=bool and value=true
1621  for (size_t i = 0; i < src.ndim(); ++i) {
1622  in_range = in_range && (src_indices[i] < src->shape[i]);
1623  }
1624  return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0)));
1625  },
1626  name, tag);
1627 }
1628 
1630 inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape,
1631  std::vector<std::string>* axes) {
1632  int32_t factor = 0;
1633  std::string axis = "";
1634  for (char c : std::string(layout)) {
1635  if (c >= 'A' && c <= 'z') {
1636  axis += c;
1637  if (factor != 0) {
1638  shape->push_back(factor);
1639  factor = 0;
1640  }
1641  } else if (c >= '0' && c <= '9') {
1642  factor = factor * 10 + c - '0';
1643  if (!axis.empty()) {
1644  axes->push_back(axis);
1645  axis = "";
1646  }
1647  } else {
1648  LOG(FATAL) << "Invalid layout " << layout;
1649  }
1650  }
1651  if (!axis.empty()) {
1652  axes->push_back(axis);
1653  }
1654 }
1655 
1666 inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout,
1667  const String& dst_layout,
1668  const String name = "T_auto_scheduler_layout_trans",
1669  const String tag = kInjective) {
1670  Array<PrimExpr> src_shape;
1671  std::vector<std::string> src_axes;
1672  Array<PrimExpr> dst_shape;
1673  std::vector<std::string> dst_axes;
1674 
1675  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
1676  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
1677  return compute(
1678  dst_shape,
1679  [&](const Array<Var>& dst_indices) {
1680  Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
1681  Array<PrimExpr> src_indices;
1682  for (const std::string& src_axis : src_axes) {
1683  PrimExpr src_index = 0;
1684  CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
1685  for (size_t i = 0; i < dst_axes.size(); ++i) {
1686  if (dst_axes[i] == src_axis) {
1687  src_index = src_index * dst_shape[i] + dst_indices_expr[i];
1688  }
1689  }
1690  src_indices.push_back(src_index);
1691  }
1692  return src(src_indices);
1693  },
1694  name, tag);
1695 }
1696 
1733 inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map,
1734  const String name = "T_meta_schedule_layout_trans",
1735  const String tag = kInjective) {
1736  Array<Range> iter_domain;
1737  iter_domain.reserve(src->shape.size());
1738  for (const PrimExpr& e : src->shape) {
1739  iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
1740  }
1741  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape);
1742  return compute(
1743  post_transform_shape,
1744  [src, inv = index_map.Inverse(iter_domain)](const Array<Var>& indices) -> PrimExpr {
1745  return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}));
1746  },
1747  name, tag);
1748 }
1749 
1758 inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape",
1759  const std::string tag = kInjective) {
1760  int ndim = static_cast<int>(src->shape.size());
1761  Array<PrimExpr> out_shape{ndim};
1762  return compute(
1763  out_shape,
1764  [&](const Array<Var>& indices) {
1765  auto idx = indices[0];
1766  PrimExpr ret = 0;
1767  for (int i = 0; i < ndim; ++i) {
1768  ret = tvm::if_then_else(idx == i, src->shape[i], ret);
1769  }
1770  return tvm::cast(dtype, ret);
1771  },
1772  name, tag);
1773 }
1774 
1783 inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
1784  const std::string& name = "ndarray_size",
1785  const std::string& tag = kInjective) {
1786  int ndim = static_cast<int>(src->shape.size());
1787  Array<PrimExpr> out_ndarray_size = {};
1788  return compute(
1789  out_ndarray_size,
1790  [&](const Array<Var>& indices) {
1791  PrimExpr ret = 1;
1792  for (int i = 0; i < ndim; ++i) {
1793  ret *= src->shape[i];
1794  }
1795  return tvm::cast(dtype, ret);
1796  },
1797  name, tag);
1798 }
1799 
1814 inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
1815  int depth, int axis, const DataType& dtype,
1816  Array<PrimExpr> oshape = Array<PrimExpr>(),
1817  const std::string name = "T_one_hot", const std::string tag = kInjective) {
1818  int true_axis = (axis == -1) ? indices->shape.size() : axis;
1819  if (oshape.size() == 0) {
1820  int ndim = indices->shape.size() + 1;
1821  int indices_index = 0;
1822  for (int i = 0; i < ndim; i++) {
1823  if (i == true_axis) {
1824  oshape.push_back(Integer(depth));
1825  } else {
1826  oshape.push_back(indices->shape[indices_index++]);
1827  }
1828  }
1829  }
1830 
1831  PrimExpr on_value_cast = cast(dtype, on_value);
1832  PrimExpr off_value_cast = cast(dtype, off_value);
1833  return compute(
1834  oshape,
1835  [&](const Array<Var>& iter_vars) {
1836  Array<Var> indices_indices;
1837  for (size_t i = 0; i < iter_vars.size(); i++) {
1838  if (static_cast<int>(i) == true_axis) {
1839  continue;
1840  }
1841 
1842  indices_indices.push_back(iter_vars[i]);
1843  }
1844 
1845  auto idx = iter_vars[true_axis];
1846  return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast);
1847  },
1848  name, tag);
1849 }
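/* Editor's usage sketch (illustrative, not part of the original header; the
 * placeholder shape and values are assumptions):
 *
 *   te::Tensor idx = te::placeholder({3}, DataType::Int(32), "idx");
 *   te::Tensor y = one_hot(idx, 1.0f, 0.0f, 4, -1, DataType::Float(32));
 *   // y->shape is (3, 4); y(i, j) == 1.0 when idx(i) == j, else 0.0.
 */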
1850 
1861 inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr>& output_shape,
1862  const Tensor& sparse_values, const PrimExpr& default_value,
1863  const std::string name = "T_sparse_to_dense",
1864  const std::string tag = kInjective) {
1865  ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values";
1866  ICHECK_LE(sparse_indices->shape.size(), 3)
1867  << "sparse_indices tensor should be 0D, 1D, or 2D only";
1868  ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only";
1869 
1870  const auto rank_sparse_indices = static_cast<int>(sparse_indices->shape.size());
1871  Array<PrimExpr> oshape;
1872  for (auto l : output_shape) {
1873  oshape.push_back(l);
1874  }
1875  return compute(
1876  oshape,
1877  [&](const Array<Var>& indices) {
1878  PrimExpr ret = default_value;
1879  if (0 == rank_sparse_indices) {
1880  ret = if_then_else(indices[0] == sparse_indices(), sparse_values(), ret);
1881  } else if (1 == rank_sparse_indices) {
1882  for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
1883  ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
1884  }
1885  } else {
1886  for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
1887  PrimExpr aggregate_condition;
1888  for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
1889  PrimExpr comparison = indices[k] == sparse_indices[j][k];
1890  aggregate_condition = 0 == k ? comparison : aggregate_condition && comparison;
1891  }
1892  ret = if_then_else(aggregate_condition, sparse_values[j], ret);
1893  }
1894  }
1895  return ret;
1896  },
1897  name, tag);
1898 }
1899 
1912 inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
1913  bool super_diag_right_align, bool sub_diag_right_align,
1914  const std::string name = "T_matrix_set_diag",
1915  const std::string tag = kInjective) {
1916  size_t ndim = input->shape.size() - 1;
1917 
1918  bool only_one_diagonal = k1 == k2;
1919 
1920  return compute(
1921  input->shape,
1922  [&](const Array<Var>& iter_vars) {
1923  auto get_diag = [&]() {
1924  Array<PrimExpr> diagonal_indices;
1925  PrimExpr k, offset = 0;
1926  for (size_t i = 0; i < ndim - 1; i++) {
1927  diagonal_indices.push_back(iter_vars[i]);
1928  }
1929  if (only_one_diagonal) {
1930  k = k1;
1931  } else {
1932  // Determining which diagonal/sub-diagonal/super-diagonal it is
1933  k = iter_vars[ndim] - iter_vars[ndim - 1];
1934  diagonal_indices.push_back(k2 - k);
1935 
1936  // Calculating the offset in diagonal tensor for this diagonal
1937  auto get_offset = [&](PrimExpr M, PrimExpr N) {
1938  // offset = max_diagonal_length - diagonal_length
1939  return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
1940  };
1941  offset = if_then_else(
1942  k >= 0,
1943  super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
1944  : 0,
1945  sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
1946  : 0);
1947  }
1948  diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
1949  offset);
1950  return diagonal(diagonal_indices);
1951  };
1952  return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
1953  if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
1954  get_diag(), input(iter_vars)),
1955  input(iter_vars));
1956  },
1957  name, tag);
1958 }
1959 
1968 inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
1969  const std::string name = "advanced_index",
1970  const std::string tag = kInjective) {
1971  ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!";
1972  Array<PrimExpr> oshape;
1973  Array<PrimExpr> broadcast_shape;
1974  Array<Tensor> bindices;
1975 
1976  broadcast_shape = indices[0]->shape;
1977  for (size_t i = 1; i < indices.size(); ++i) {
1978  auto bh = detail::BroadcastShape(broadcast_shape, indices[i]->shape);
1979  broadcast_shape = Array<PrimExpr>(bh.common_shape.begin(), bh.common_shape.end());
1980  }
1981  if (indices.size() == 1) {
1982  // quick path
1983  bindices = indices;
1984  } else {
1985  // Do broadcast for indices
1986  for (size_t i = 0; i < indices.size(); ++i) {
1987  bindices.push_back(broadcast_to(indices[i], broadcast_shape));
1988  }
1989  }
1990 
1991  for (const auto& dim : broadcast_shape) {
1992  oshape.push_back(dim);
1993  }
1994  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
1995  oshape.push_back(data->shape[i]);
1996  }
1997 
1998  return compute(
1999  oshape,
2000  [&](const Array<Var>& iter_var) {
2001  Array<PrimExpr> tensor_indices;
2002  for (size_t i = 0; i < broadcast_shape.size(); ++i) {
2003  tensor_indices.push_back(iter_var[i]);
2004  }
2005 
2006  Array<PrimExpr> real_indices;
2007  for (size_t i = 0; i < bindices.size(); ++i) {
2008  real_indices.push_back(bindices[i](tensor_indices));
2009  }
2010  for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
2011  real_indices.push_back(iter_var[i]);
2012  }
2013 
2014  return data(real_indices);
2015  },
2016  name, tag);
2017 }
2018 
2019 } // namespace topi
2020 } // namespace tvm
2021 #endif // TVM_TOPI_TRANSFORM_H_
Converts a flat index or array of flat indices into a tuple of coordinate arrays. ...
Definition: transform.h:359
Array< Tensor > split(const Tensor &x, Array< PrimExpr > split_indices, int axis, std::string name="T_split", std::string tag=kInjective)
Split a tensor into multiple sub-tensors.
Definition: transform.h:575
DataType dtype() const
Definition: expr.h:126
void parse_auto_scheduler_layout(const String &layout, Array< PrimExpr > *shape, std::vector< std::string > *axes)
Utility function for auto_scheduler_layout_transform.
Definition: transform.h:1630
Tensor gather(const Tensor &data, int axis, const Tensor &indices, std::string name="T_gather", std::string tag=kInjective)
Gather values along given axis from given indices.
Definition: transform.h:1311
size_t ndim() const
Definition: tensor.h:214
Tensor dyn_tile(const Tensor &x, Array< PrimExpr > new_shape, size_t rdim, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1274
Tensor layout_transform(const Tensor &src, const std::string &src_layout, const std::string &dst_layout, const std::string name="T_layout_trans", const std::string tag=kInjective)
Transform the layout according to src_layout and dst_layout.
Definition: transform.h:1595
Tensor auto_scheduler_layout_transform(const Tensor &src, const String &src_layout, const String &dst_layout, const String name="T_auto_scheduler_layout_trans", const String tag=kInjective)
Transform the auto-scheduler generated layout according to src_layout and dst_layout.
Definition: transform.h:1666
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
Tensor squeeze(const Tensor &x, Array< Integer > axis, bool atleast1d=false, std::string name="T_squeeze", std::string tag=kInjective)
Remove size 1 dimensions from the shape of a tensor. The removed dimensions must have a constant size...
Definition: transform.h:407
void Set(int64_t i, T value)
set i-th element of the array.
Definition: array.h:586
Constant integer literals in the program.
Definition: expr.h:489
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:455
Defines a remapping of buffer indices.
Tensor expand_dims(const Tensor &x, int axis, int num_newaxis=1, std::string name="T_expand_dims", std::string tag=kBroadcast)
Creates an operation to insert new dimensions of length 1.
Definition: transform.h:147
Tensor tile(const Tensor &x, Array< Integer > reps, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1219
Tensor sum(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:326
Utility functions for handling constants in TVM expressions.
constexpr auto kBroadcast
Definition: tags.h:36
Range constainer.
Definition: expr.h:711
Tensor arange(const PrimExpr &start, const PrimExpr &stop, const PrimExpr &step, DataType dtype, std::string name="T_arange", std::string tag=kInjective)
Definition: transform.h:1539
size_t size() const
Definition: array.h:418
Runtime primitive data type.
Definition: data_type.h:41
bool defined() const
Definition: object.h:544
Tensor stack(const Array< Tensor > &inputs, int axis=0, std::string name="T_stack", std::string tag=kInjective)
Join a sequence of tensors along a new axis.
Definition: transform.h:529
Utility functions for handling tensor.
static DataType Float(int bits, int lanes=1)
Construct an float type.
Definition: data_type.h:168
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Definition: index_map.h:167
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
Tensor concatenate(const Array< Tensor > &inputs, int axis=0, std::string name="T_concat", std::string tag=kInjective)
Join a sequence of tensors along an existing axis.
Definition: transform.h:470
Tensor take(const Tensor &a, const Tensor &indices, int batch_dims, std::string mode="clip", std::string name="T_take", std::string tag=kInjective)
Take elements from an flattened input array when axis is None.
Definition: transform.h:898
Managed reference class to IntImmNode.
Definition: expr.h:518
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
int64_t value
the Internal value.
Definition: expr.h:492
Reference to string objects.
Definition: string.h:97
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:951
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1758
Array< Tensor > meshgrid(const Array< Tensor > &inputs, const std::string &indexing, std::string name="T_meshgrid", std::string tag=kInjective)
Produce grids by expanding input over dimensions defined by other inputs.
Definition: transform.h:1560
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
iterator end() const
Definition: array.h:388
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
iterator begin() const
Definition: array.h:385
Operation node can generate one or multiple Tensors.
Tensor meta_schedule_layout_transform(const Tensor &src, const tir::IndexMap &index_map, const String name="T_meta_schedule_layout_trans", const String tag=kInjective)
Transform the meta-schedule generated layout according to TIR&#39;s IndexMap.
Definition: transform.h:1733
Managed reference to SelectNode.
Definition: expr.h:589
Bijective function mapping for data layout transformation. Given two Layout, BijectiveLayout build an...
Definition: data_layout.h:330
Tensor transpose(const Tensor &x, Array< Integer > axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:196
Tensor ndarray_size(const Tensor &src, const DataType &dtype, const std::string &name="ndarray_size", const std::string &tag=kInjective)
Get the size of input tensor.
Definition: transform.h:1783
Tensor sequence_mask(const Tensor &data, const Tensor &valid_length, double mask_value, int axis, std::string name="T_sequence_mask", std::string tag=kInjective)
Mask the out-of-boundary elements of each sequence.
Definition: transform.h:946
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Tensor adv_index(const Tensor &data, const Array< Tensor > &indices, const std::string name="advanced_index", const std::string tag=kInjective)
Numpy style advanced indexing with tensor.
Definition: transform.h:1968
External function interface to rocBLAS libraries.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Tensor reverse_sequence(const Tensor &x, const Tensor &seq_lengths, int seq_axis=1, int batch_axis=0, std::string name="T_reverse_sequence", std::string tag=kInjective)
Reverse the tensor for variable length slices. Input is first sliced along batch axis and then elemen...
Definition: transform.h:255
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type...
Definition: elemwise.h:280
Tensor reshape(const Tensor &x, Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:320
tvm::te::Tensor broadcast_to(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &output_shape, std::string name="T_broadcast_to", std::string tag=kBroadcast)
Creates an operation that broadcasts a tensor into a compatible shape according to numpy&#39;s rules...
Definition: broadcast.h:48
Tensor repeat(const Tensor &x, int repeats, int axis, std::string name="T_repeat", std::string tag=kBroadcast)
Creates an operation to repeat elements of an array.
Definition: transform.h:1172
Tensor strided_slice(const Tensor &x, const Array< Integer > &begin, const Array< Integer > &end, const Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:815
Broadcast op constructions.
Reference to PrimExprNode.
Definition: expr.h:112
Layout expression to describe the data organization of a tensor. And BijectiveLayout to mapping two d...
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
Tensor gather_nd(const Tensor &data, const Tensor &indices, int batch_dims=0, std::string name="T_gather_nd", std::string tag=kInjective)
Gather elements from a n-dimension array.
Definition: transform.h:1364
Array< Tensor > split_sections(const Tensor &x, int num_sections, int axis, std::string name="T_split_sections", std::string tag=kInjective)
Split a tensor into a number of sub-tensors.
Definition: transform.h:855
Detail broadcast.
Index ravel and unraval operations.
IndexMap Inverse(Array< Range > initial_ranges) const
Generate the inverse mapping.
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:423
tvm::te::Tensor matmul(const tvm::te::Tensor &A, const tvm::te::Tensor &B, bool trans_a=false, bool trans_b=false, std::string name="T_matmul", std::string tag=kMatMul)
Creates an operation that calculates a matrix multiplication (row-major notation): A(i...
Definition: transform.h:1425
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:154
Container of constant int that adds more constructors.
Definition: expr.h:618