transform.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file topi/transform.h
22  * \brief Transform op constructors
23  */
24 #ifndef TVM_TOPI_TRANSFORM_H_
25 #define TVM_TOPI_TRANSFORM_H_
26 
27 #include <tvm/te/operation.h>
28 #include <tvm/tir/data_layout.h>
29 #include <tvm/topi/broadcast.h>
30 #include <tvm/topi/detail/broadcast.h>
31 #include <tvm/topi/detail/constant_utils.h>
32 #include <tvm/topi/detail/ravel_unravel.h>
33 #include <tvm/topi/detail/strided_slice.h>
34 #include <tvm/topi/detail/tensor_utils.h>
35 #include <tvm/topi/tags.h>
36 
37 #include <algorithm>
38 #include <iterator>
39 #include <limits>
40 #include <string>
41 #include <unordered_set>
42 #include <vector>
43 
44 namespace tvm {
45 namespace topi {
46 
47 using namespace tvm::te;
48 using namespace topi::detail;
49 
62 inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1,
63  std::string name = "T_expand_dims", std::string tag = kBroadcast) {
64  int ndim = static_cast<int>(x->shape.size());
65  ICHECK(-ndim - 1 <= axis && axis <= ndim)
66  << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
67  << ", but got axis = " << axis << ", and data.ndim = " << ndim;
68  ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
69  << ", but got num_newaxis = " << num_newaxis;
70  if (axis < 0) {
71  // Calculate offset from last dimension
72  axis = ndim + axis + 1;
73  }
74  Array<PrimExpr> new_shape;
75  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
76  new_shape.push_back(x->shape[i]);
77  }
78  for (size_t i = 0; i < static_cast<size_t>(num_newaxis); ++i) {
79  new_shape.push_back(1);
80  }
81  for (size_t i = axis; i < x->shape.size(); ++i) {
82  new_shape.push_back(x->shape[i]);
83  }
84 
85  return compute(
86  new_shape,
87  [&](const Array<Var>& indices) {
88  Array<PrimExpr> idx;
89  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
90  idx.push_back(indices[i]);
91  }
92  for (size_t i = axis + num_newaxis; i < indices.size(); ++i) {
93  idx.push_back(indices[i]);
94  }
95  return x(idx);
96  },
97  name, tag);
98 }
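// ---- Example (editorial sketch; not part of the original transform.h) ----
// Minimal use of expand_dims, assuming the standard te::placeholder API from
// <tvm/te/operation.h>; tensor names and the example_* wrapper are ours.
inline void example_expand_dims() {
  Tensor A = placeholder({2, 3}, DataType::Float(32), "A");
  Tensor B = expand_dims(A, /*axis=*/0);   // B->shape == {1, 2, 3}
  Tensor C = expand_dims(A, /*axis=*/-1);  // C->shape == {2, 3, 1}
}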
99 
111 inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name = "T_transpose",
112  std::string tag = kInjective) {
113  if (!axes.defined() || axes.size() == 0) {
114  axes = Array<Integer>();
115  for (int i = static_cast<int>(x->shape.size()) - 1; i >= 0; --i) {
116  axes.push_back(i);
117  }
118  }
119 
120  Array<PrimExpr> new_shape;
121  for (size_t i = 0; i < axes.size(); ++i) {
122  int axis = static_cast<int>(axes[i]->value);
123  int new_axis = axis;
124  if (axis < 0) {
125  new_axis = static_cast<int>(x->shape.size()) + axis;
126  axes.Set(i, new_axis);
127  }
128  ICHECK((new_axis >= 0) && (new_axis < static_cast<int>(x->shape.size())))
129  << "axis=" << axis << " is invalid for the " << static_cast<int>(x->shape.size())
130  << "-dimensional input tensor";
131 
132  for (size_t j = 0; j < axes.size(); ++j) {
133  if (i != j) {
134  ICHECK(new_axis != static_cast<int>(axes[j]->value)) << "repeated axis in transpose";
135  }
136  }
137  new_shape.push_back(x->shape[new_axis]);
138  }
139 
140  return compute(
141  new_shape,
142  [&](const Array<Var>& indices) {
143  std::vector<PrimExpr> idx;
144  for (size_t i = 0; i < axes.size(); ++i) {
145  idx.push_back(1);  // placeholder, overwritten by the loop below
146  }
147  for (size_t i = 0; i < axes.size(); ++i) {
148  int axis = static_cast<int>(axes[i]->value);
149  idx[axis] = indices[i];
150  }
151  return x(idx);
152  },
153  name, tag);
154 }
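// ---- Example (editorial sketch; not part of the original transform.h) ----
// transpose with explicit axes and with the default (reversed) axis order;
// shapes and names are illustrative.
inline void example_transpose() {
  Tensor A = placeholder({2, 3, 4}, DataType::Float(32), "A");
  Tensor B = transpose(A, {1, 2, 0});         // B->shape == {3, 4, 2}
  Tensor C = transpose(A, Array<Integer>());  // reversed axes: {4, 3, 2}
}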
155 
170 inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
171  int batch_axis = 0, std::string name = "T_reverse_sequence",
172  std::string tag = kInjective) {
173  size_t src_tensor_dim = x->shape.size();
174  int seq_axis_inp = seq_axis;
175 
176  if (seq_lengths.defined()) {
177  size_t seq_lengths_dim = seq_lengths->shape.size();
178  int batch_axis_inp = batch_axis;
179  if (batch_axis < 0) {
180  batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
181  }
182 
183  ICHECK(seq_lengths_dim == 1) << "seq_lengths should be a 1D vector";
184 
185  ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
186  << "For reverse_sequence, seq_lengths size should match the batch axis dimension"
187  << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
188  << ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]);
189 
190  ICHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
191  << "batch_axis=" << batch_axis_inp << " is invalid for the "
192  << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
193  }
194 
195  if (seq_axis < 0) {
196  seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
197  }
198  ICHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
199  << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
200  << "-dimensional input tensor";
201 
202  auto func = [&](const Array<Var>& indices) {
203  Array<PrimExpr> real_indices;
204  for (size_t i = 0; i < src_tensor_dim; ++i) {
205  if (i == static_cast<size_t>(seq_axis)) {
206  if (seq_lengths.defined()) {
207  auto len = seq_lengths(indices[batch_axis]);
208  auto idx = if_then_else(
209  len <= 1 || len <= indices[i], indices[i],
210  if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
211  real_indices.push_back(idx);
212  } else {
213  real_indices.push_back(x->shape[i] - 1 - indices[i]);
214  }
215  } else {
216  real_indices.push_back(indices[i]);
217  }
218  }
219  return x(real_indices);
220  };
221 
222  return compute(x->shape, func, name, tag);
223 }
224 
235 inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string name = "T_reshape",
236  std::string tag = kInjective) {
237  auto x_shape = x->shape;
238  Array<PrimExpr> target_shape;
239 
240  for (const auto& ele : newshape) {
241  if (ele.as<IntImmNode>()) {
242  target_shape.push_back(cast(DataType::Int(32), ele));
243  } else {
244  target_shape.push_back(ele);
245  }
246  }
247 
248  if (is_empty_shape(target_shape)) {
249  return compute(
250  target_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
251  } else {
252  return compute(
253  target_shape,
254  [&](const Array<Var>& indices) {
255  return x(UnravelIndex(
256  RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape), x_shape));
257  },
258  name, tag);
259  }
260 }
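// ---- Example (editorial sketch; not part of the original transform.h) ----
// reshape keeps the flat (row-major) element order; only the shape changes.
inline void example_reshape() {
  Tensor A = placeholder({2, 3, 4}, DataType::Float(32), "A");
  Tensor B = reshape(A, {6, 4});  // same 24 elements, now viewed as 6 x 4
}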
261 
273 inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel",
274  std::string tag = kInjective) {
275  auto x_shape = x->shape;
276  auto shape_shape = shape->shape;
277 
278  Array<PrimExpr> oshape;
279  oshape.push_back(shape_shape[0]);
280  if (x_shape.size() != 0) {
281  oshape.push_back(x_shape[0]);
282  }
283 
284  auto func = [&](const Array<Var>& indices) {
285  auto i = indices[0];
286  std::vector<PrimExpr> indices_divs;
287  PrimExpr ret = 0;
288  PrimExpr cur_val = 0;
289  PrimExpr index_val = 0;
290 
291  if (x_shape.size() != 0) {
292  index_val = x[indices[1]];
293  } else {
294  index_val = x();
295  }
296  indices_divs.push_back(index_val);
297  for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
298  ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
299  cur_val = indexdiv(indices_divs.back(), shape[v]);
300  indices_divs.push_back(cur_val);
301  }
302  return ret;
303  };
304 
305  return compute(oshape, func, name, tag);
306 }
307 
321 inline Tensor squeeze(const Tensor& x, Array<Integer> axis, bool atleast1d = false,
322  std::string name = "T_squeeze", std::string tag = kInjective) {
323  auto ndim = x->shape.size();
324  std::vector<int> axis_val;
325  if (!axis.defined() || axis.size() == 0) {
326  for (size_t i = 0; i < ndim; ++i) {
327  if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
328  axis_val.push_back(static_cast<int>(i));
329  }
330  }
331  } else {
332  for (size_t i = 0; i < axis.size(); ++i) {
333  int64_t val = axis[i]->value;
334  if (val < 0) {
335  val += static_cast<int>(x->shape.size());
336  }
337  if (IsConstInt(x->shape[val])) {
338  ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
339  }
340  axis_val.push_back(val);
341  }
342  }
343 
344  std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());
345 
346  Array<PrimExpr> out_shape;
347  for (size_t i = 0; i < ndim; ++i) {
348  if (axis_set.count(static_cast<int>(i)) == 0) {
349  out_shape.push_back(x->shape[i]);
350  }
351  }
352  if (out_shape.size() == 0 && atleast1d) {
353  out_shape.push_back(1);
354  }
355 
356  return compute(
357  out_shape,
358  [&](const Array<Var>& indices) {
359  Array<PrimExpr> real_indices;
360  int flag = 0;
361  for (size_t i = 0; i < ndim; ++i) {
362  if (axis_set.count(static_cast<int>(i)) == 0) {
363  real_indices.push_back(indices[i - flag]);
364  } else {
365  real_indices.push_back(0);
366  flag += 1;
367  }
368  }
369  return x(real_indices);
370  },
371  name, tag);
372 }
373 
384 inline Tensor concatenate(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_concat",
385  std::string tag = kInjective) {
386  int ndim = static_cast<int>(inputs[0]->shape.size());
387  ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)"
388  << ", but got axis = " << axis << ", and ndim = " << ndim;
389  if (axis < 0) {
390  axis += ndim;
391  }
392  ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds";
393 
394  Array<PrimExpr> axis_sizes;
395  for (auto t : inputs) {
396  axis_sizes.push_back(t->shape[axis]);
397  }
398  arith::Analyzer analyzer;
399  PrimExpr join_size = axis_sizes[0];
400  for (size_t i = 1; i < axis_sizes.size(); ++i) {
401  join_size += axis_sizes[i];
402  }
403  join_size = analyzer.Simplify(join_size);
404  Array<PrimExpr> out_shape;
405  for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
406  out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
407  }
408 
409  return compute(
410  out_shape,
411  [&](const Array<Var>& indices) {
412  auto ret = inputs[0](indices);
413  auto ind = indices[axis];
414  for (size_t i = 0; i < inputs.size() - 1; ++i) {
415  ind -= axis_sizes[i];
416 
417  Array<PrimExpr> idx;
418  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
419  idx.push_back(indices[j]);
420  }
421  idx.push_back(ind);
422  for (size_t j = axis + 1; j < indices.size(); ++j) {
423  idx.push_back(indices[j]);
424  }
425 
426  ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
427  }
428  return ret;
429  },
430  name, tag);
431 }
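// ---- Example (editorial sketch; not part of the original transform.h) ----
// concatenate joins tensors along an existing axis; all other extents must match.
inline void example_concatenate() {
  Tensor A = placeholder({2, 3}, DataType::Float(32), "A");
  Tensor B = placeholder({2, 5}, DataType::Float(32), "B");
  Tensor C = concatenate({A, B}, /*axis=*/1);  // C->shape == {2, 8}
}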
432 
443 inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name = "T_stack",
444  std::string tag = kInjective) {
445  int ndim = static_cast<int>(inputs[0]->shape.size());
446  ICHECK(-ndim - 1 <= axis && axis <= ndim)
447  << "stack only accepts `axis` in [-ndim - 1, ndim]"
448  << ", but got axis = " << axis << ", and ndim = " << ndim;
449  if (axis < 0) {
450  axis += ndim + 1;
451  }
452  ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";
453 
454  const int stack_size = static_cast<int>(inputs.size());
455  Array<PrimExpr> out_shape;
456  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
457  out_shape.push_back(stack_size);
458  for (size_t i = static_cast<size_t>(axis); i < static_cast<size_t>(ndim); ++i)
459  out_shape.push_back(inputs[0]->shape[i]);
460 
461  return compute(
462  out_shape,
463  [&](const Array<Var>& indices) {
464  Array<PrimExpr> idx;
465  for (size_t i = 0; i < indices.size(); ++i)
466  if (i != static_cast<size_t>(axis)) idx.push_back(indices[i]);
467  auto ind = indices[axis];
468  auto ret = inputs[0](idx);
469  for (int i = 0; i < static_cast<int>(inputs.size() - 1); ++i) {
470  ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
471  }
472  return ret;
473  },
474  name, tag);
475 }
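// ---- Example (editorial sketch; not part of the original transform.h) ----
// stack joins tensors of identical shape along a new axis, unlike concatenate.
inline void example_stack() {
  Tensor A = placeholder({2, 3}, DataType::Float(32), "A");
  Tensor B = placeholder({2, 3}, DataType::Float(32), "B");
  Tensor C = stack({A, B}, /*axis=*/0);  // C->shape == {2, 2, 3}
}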
476 
489 inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
490  std::string name = "T_split", std::string tag = kInjective) {
491  if (axis < 0) {
492  axis += static_cast<int>(x->shape.size());
493  }
494  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
495 
496  auto src_axis_size = x->shape[axis];
497  std::vector<PrimExpr> begin_ids;
498  begin_ids.push_back(0);
499 
500  for (auto idx : split_indices) {
501  auto idx_node = idx.as<IntImmNode>();
502  auto back_node = begin_ids.back().as<IntImmNode>();
503  if (idx_node && back_node) {
504  ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
505  }
506  begin_ids.push_back(idx);
507  }
508 
509  Array<Array<PrimExpr> > out_shapes;
510  for (size_t i = 0; i < begin_ids.size(); ++i) {
511  PrimExpr out_axis_size;
512  if (i == begin_ids.size() - 1) {
513  out_axis_size = src_axis_size - begin_ids[i];
514  } else {
515  out_axis_size = begin_ids[i + 1] - begin_ids[i];
516  }
517 
518  Array<PrimExpr> shape;
519  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
520  shape.push_back(x->shape[i]);
521  }
522  shape.push_back(out_axis_size);
523  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
524  shape.push_back(x->shape[i]);
525  }
526 
527  out_shapes.push_back(shape);
528  }
529 
530  Array<Tensor> result;
531  for (size_t i = 0; i < begin_ids.size(); ++i) {
532  result.push_back(compute(
533  out_shapes[i],
534  [&](const Array<Var>& indices) {
535  auto begin = begin_ids[i];
536  Array<PrimExpr> real_indices;
537  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
538  real_indices.push_back(indices[j]);
539  }
540  real_indices.push_back(indices[axis] + begin);
541  for (size_t j = axis + 1; j < indices.size(); ++j) {
542  real_indices.push_back(indices[j]);
543  }
544 
545  return x(real_indices);
546  },
547  name, tag));
548  }
549 
550  return result;
551 }
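// ---- Example (editorial sketch; not part of the original transform.h) ----
// split cuts one axis at the given indices, yielding len(split_indices) + 1 parts.
inline void example_split() {
  Tensor A = placeholder({6, 4}, DataType::Float(32), "A");
  Array<Tensor> parts = split(A, {2, 5}, /*axis=*/0);
  // parts[0] covers rows [0, 2), parts[1] rows [2, 5), parts[2] rows [5, 6)
}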
552 
566 inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
567  const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
568  std::string name = "T_dynamic_strided_slice",
569  std::string tag = kInjective) {
570  const size_t src_tensor_dim = x->shape.size();
571  ICHECK_LE(begin.size(), src_tensor_dim);
572  ICHECK_LE(end.size(), src_tensor_dim);
573  ICHECK_LE(strides.size(), src_tensor_dim);
574  ICHECK_EQ(begin.size(), end.size());
575  ICHECK_EQ(begin.size(), strides.size());
576 
577  const size_t num_slice_axes = begin.size();
578  Array<PrimExpr> out_shape;
579 
580  for (size_t i = 0; i < num_slice_axes; ++i) {
581  auto d = indexdiv(end[i] - begin[i], strides[i]);
582  if (d->IsInstance<tvm::IntImmNode>()) {
583  // Preserve static dimension if possible
584  out_shape.push_back(d);
585  } else {
586  out_shape.push_back(tvm::tir::Var("dim"));
587  }
588  }
589 
590  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
591  out_shape.push_back(x->shape[i]);
592  }
593 
594  return te::compute(
595  out_shape,
596  [&](const Array<tvm::tir::Var>& indices) {
597  Array<PrimExpr> real_indices;
598  for (size_t i = 0; i < num_slice_axes; ++i) {
599  real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
600  }
601  // keep input dim
602  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
603  real_indices.push_back(indices[i]);
604  }
605  return x(real_indices);
606  },
607  name, tag);
608 }
609 
623 inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
624  const te::Tensor& end, const te::Tensor& strides,
625  std::string name = "T_strided_slice_dynamic",
626  std::string tag = topi::kInjective) {
627  const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
628  ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
629  ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
630 
631  Array<PrimExpr> begin_expr, end_expr, strides_expr;
632  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
633  auto i64_ind = IntImm(DataType::Int(64), i);
634  begin_expr.push_back(begin(i64_ind));
635  end_expr.push_back(end(i64_ind));
636  strides_expr.push_back(strides(i64_ind));
637  }
638  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
639 }
640 
655 inline Array<PrimExpr> StridedSliceOutputShape(
656  const Array<PrimExpr>& ishape, const Array<Integer>& begin, const Array<Integer>& end,
657  const Array<Integer>& strides, const Array<Integer>& axes, const std::string& slice_mode) {
658  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
659  std::vector<int64_t> begin_vec, end_vec, strides_vec;
660  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
661  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
662  begin[0]->dtype, slice_mode);
663  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
664  begin_canonicalized, true);
665 }
666 
683 inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
684  const Array<Integer>& end, const Array<Integer>& strides,
685  const Array<Integer>& axes, std::string slice_mode = "end",
686  std::string name = "T_strided_slice_with_axes",
687  std::string tag = kInjective) {
688  const size_t src_tensor_dim = x->shape.size();
689  ICHECK(axes.size() <= src_tensor_dim);
690  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
691 
692  std::vector<int64_t> begin_vec, end_vec, strides_vec;
693  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
694 
695  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
696  begin[0]->dtype, slice_mode);
697  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
698  slice_mode, begin_expr);
699 
700  return te::compute(
701  out_shape,
702  [&](const Array<tir::Var>& indices) {
703  Array<PrimExpr> real_indices;
704  for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
705  for (size_t i = 0; i < axes.size(); ++i) {
706  auto stride = make_const(strides[i].dtype(), strides_vec[i]);
707  PrimExpr ind = indices[axes[i]] * stride + begin_expr[i];
708  real_indices.Set(axes[i], ind);
709  }
710  return x(real_indices);
711  },
712  name, tag);
713 }
714 
729 inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
730  const Array<Integer>& strides, std::string slice_mode = "end",
731  std::string name = "T_strided_slice", std::string tag = kInjective) {
732  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
733  Array<Integer> axes;
734  for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
735  Array<Integer> begin_full(begin);
736  Array<Integer> end_full(end);
737  Array<Integer> strides_full(strides);
738 
739  const IntImm one = IntImm(DataType::Int(64), 1);
740  const IntImm zero = IntImm(DataType::Int(64), 0);
741  const IntImm max_range = Downcast<IntImm>(max_value(DataType::Int(64)));
742 
743  for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
744  strides_full.push_back(one);
745  }
746  for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
747  begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
748  }
749  for (size_t i = end.size(); i < src_tensor_dim; ++i) {
750  end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
751  }
752 
753  return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
754  tag);
755 }
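// ---- Example (editorial sketch; not part of the original transform.h) ----
// strided_slice with fewer begin/end/strides entries than dimensions: the
// remaining axes are padded with defaults and kept whole.
inline void example_strided_slice() {
  Tensor A = placeholder({8, 4}, DataType::Float(32), "A");
  Tensor B = strided_slice(A, /*begin=*/{1}, /*end=*/{7}, /*strides=*/{2});
  // B->shape == {3, 4}: rows 1, 3, 5 of A, all columns
}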
756 
769 inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
770  std::string name = "T_split_sections",
771  std::string tag = kInjective) {
772  if (axis < 0) {
773  axis += static_cast<int>(x->shape.size());
774  }
775  ICHECK_LT(axis, x->shape.size()) << "axis out of bounds";
776 
777  auto src_axis_size = x->shape[axis];
778 
779  ICHECK_GT(num_sections, 0) << "Slice count must be > 0";
780 
781  if (auto node = src_axis_size.as<IntImmNode>()) {
782  ICHECK_EQ(node->value % num_sections, 0)
783  << "num_sections must be an integer factor of the size of axis " << axis << " ("
784  << node->value << ")";
785  }
786 
787  Array<PrimExpr> split_indices;
788  auto seg_size = indexdiv(src_axis_size, num_sections);
789  for (int i = 0; i < num_sections; ++i) {
790  // region at index 0 is added by split()
791  if (i != 0) {
792  split_indices.push_back(seg_size * i);
793  }
794  }
795 
796  return split(x, split_indices, axis, name, tag);
797 }
798 
812 inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
813  std::string mode = "clip", std::string name = "T_take",
814  std::string tag = kInjective) {
815  Array<PrimExpr> a_shape = a->shape;
816  Array<PrimExpr> out_shape = indices->shape;
817  PrimExpr a_size = 1;
818  for (size_t i = 0; i < a_shape.size(); ++i) {
819  a_size = a_size * a_shape[i];
820  }
821 
822  if (mode == "clip") {
823  return compute(
824  out_shape,
825  [&](const Array<Var>& out_index) {
826  auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
827  return a(UnravelIndex(idx, a_shape));
828  },
829  name, tag);
830  } else if (mode == "fast") {
831  LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
832  "Make sure input indices are in bounds";
833  return compute(
834  out_shape,
835  [&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
836  name, tag);
837  } else { // mode == "wrap"
838  return compute(
839  out_shape,
840  [&](const Array<Var>& out_index) {
841  auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
842  return a(UnravelIndex(idx, a_shape));
843  },
844  name, tag);
845  }
846 }
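// ---- Example (editorial sketch; not part of the original transform.h) ----
// The axis-less take treats `a` as a flat buffer; "clip" bounds every index
// into [0, a_size - 1].
inline void example_take_flat() {
  Tensor A = placeholder({2, 3}, DataType::Float(32), "A");
  Tensor I = placeholder({4}, DataType::Int(32), "I");
  Tensor B = take(A, I, /*batch_dims=*/0, /*mode=*/"clip");  // B->shape == {4}
}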
847 
860 inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value,
861  int axis, std::string name = "T_sequence_mask",
862  std::string tag = kInjective) {
863  ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1";
864  ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
865  auto length_dim = data->shape[axis];
866  auto batch_dim = data->shape[1 - axis];
867  Array<PrimExpr> out_shape = data->shape;
868  Tensor out = compute(
869  out_shape,
870  [&](const Array<Var>& out_index) {
871  Array<PrimExpr> len_index;
872  auto tid = out_index[axis];
873  auto bid = out_index[1 - axis];
874  len_index.push_back(bid);
875  PrimExpr ret =
876  tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
877  tvm::tir::make_const(data->dtype, mask_value), data(out_index));
878  return ret;
879  },
880  name, tag);
881  return out;
882 }
883 
898 inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
899  std::string mode = "clip", std::string name = "T_take",
900  std::string tag = kInjective) {
901  if (axis < 0) {
902  axis += static_cast<int>(a->shape.size());
903  }
904  ICHECK_GE(axis, 0) << "axis out of bounds";
905  ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
906  auto axis_dim = a->shape[axis];
907  int indices_len = static_cast<int>(indices->shape.size());
908 
909  int batch_dims_ = batch_dims;
910  if (batch_dims_ != 0) {
911  ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
912  ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";
913 
914  if (batch_dims_ < 0) {
915  batch_dims_ = indices->shape.size() + batch_dims_;
916  }
917 
918  ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
919  ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
920  for (int i = 0; i < batch_dims_; ++i) {
921  // Both extents must be static here; read them directly as IntImm values.
922  auto v1 = a->shape[i].as<IntImmNode>()->value;
923  auto v2 = indices->shape[i].as<IntImmNode>()->value;
924  ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
926  }
927  }
928 
929  // The result shape is a.shape[:axis] + indices.shape[batch_dims:] +
930  // a.shape[axis + 1:].
931 
932  Array<PrimExpr> out_shape;
933  for (int i = 0; i < batch_dims_; ++i) {
934  out_shape.push_back(a->shape[i]);
935  }
936  for (int i = batch_dims_; i < axis; ++i) {
937  out_shape.push_back(a->shape[i]);
938  }
939  for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
940  out_shape.push_back(indices->shape[i]);
941  }
942  for (size_t i = axis + 1; i < a->shape.size(); ++i) {
943  out_shape.push_back(a->shape[i]);
944  }
945 
946  if (mode == "clip") {
947  if (batch_dims_ == 0) {
948  return compute(
949  out_shape,
950  [&](const Array<Var>& out_index) {
951  Array<PrimExpr> indices_position;
952  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
953  indices_position.push_back(out_index[j]);
954  }
955  Array<PrimExpr> real_indices;
956  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
957  real_indices.push_back(out_index[j]);
958  }
959  auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
960  real_indices.push_back(idx);
961  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
962  real_indices.push_back(out_index[j]);
963  }
964  return a(real_indices);
965  },
966  name, tag);
967  } else {
968  return compute(
969  out_shape,
970  [&](const Array<Var>& out_index) {
971  Array<PrimExpr> indices_position;
972  for (size_t j = 0; j < static_cast<size_t>(batch_dims_); ++j) {
973  indices_position.push_back(out_index[j]);
974  }
975  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len - batch_dims_); ++j) {
976  indices_position.push_back(out_index[j]);
977  }
978  Array<PrimExpr> real_indices;
979  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
980  real_indices.push_back(out_index[j]);
981  }
982  auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
983  real_indices.push_back(idx);
984  for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
985  real_indices.push_back(out_index[j]);
986  }
987  return a(real_indices);
988  },
989  name, tag);
990  }
991  } else if (mode == "fast") {
992  LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
993  "Make sure input indices are in bounds";
994  return compute(
995  out_shape,
996  [&](const Array<Var>& out_index) {
997  Array<PrimExpr> indices_position;
998  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
999  indices_position.push_back(out_index[j]);
1000  }
1001  Array<PrimExpr> real_indices;
1002  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1003  real_indices.push_back(out_index[j]);
1004  }
1005  real_indices.push_back(indices(indices_position));
1006  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1007  real_indices.push_back(out_index[j]);
1008  }
1009  return a(real_indices);
1010  },
1011  name, tag);
1012  } else { // mode == "wrap"
1013  return compute(
1014  out_shape,
1015  [&](const Array<Var>& out_index) {
1016  Array<PrimExpr> indices_position;
1017  for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); ++j) {
1018  indices_position.push_back(out_index[j]);
1019  }
1020  Array<PrimExpr> real_indices;
1021  for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
1022  real_indices.push_back(out_index[j]);
1023  }
1024  auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
1025  real_indices.push_back(idx);
1026  for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
1027  real_indices.push_back(out_index[j]);
1028  }
1029  return a(real_indices);
1030  },
1031  name, tag);
1032  }
1033 }
1034 
1046 inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
1047  std::string name = "T_where", std::string tag = kBroadcast) {
1048  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
1049  << y->dtype;
1050  auto get_out_shape = [&]() {
1051  auto bh1 = detail::BroadcastShape(x->shape, y->shape);
1052  Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
1053  auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
1054  Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
1055  return common_shape2;
1056  };
1057 
1058  auto oshape = get_out_shape();
1059 
1060  auto c_bh = detail::BroadcastShape(condition->shape, oshape);
1061  auto x_bh = detail::BroadcastShape(x->shape, oshape);
1062  auto y_bh = detail::BroadcastShape(y->shape, oshape);
1063 
1064  auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
1065  auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
1066  auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
1067  auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
1068  return tvm::tir::Select(c != 0, true_val, false_val);
1069  };
1070 
1071  return compute(oshape, select, name, tag);
1072 }
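// ---- Example (editorial sketch; not part of the original transform.h) ----
// where broadcasts condition, x, and y to a common shape before selecting.
inline void example_where() {
  Tensor cond = placeholder({2, 3}, DataType::Int(32), "cond");
  Tensor X = placeholder({2, 3}, DataType::Float(32), "X");
  Tensor Y = placeholder({1, 3}, DataType::Float(32), "Y");  // broadcast over rows
  Tensor Z = where(cond, X, Y);  // Z->shape == {2, 3}
}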
1073 
1086 inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat",
1087  std::string tag = kBroadcast) {
1088  int ndim = static_cast<int>(x->shape.size());
1089  ICHECK(-ndim - 1 <= axis && axis <= ndim)
1090  << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
1091  << ", but got axis = " << axis << ", and data.ndim = " << ndim;
1092  ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
1093  << ", but got repeats = " << repeats;
1094  if (axis < 0) {
1095  // Calculate offset from last dimension
1096  axis += ndim;
1097  }
1098  Array<PrimExpr> new_shape;
1099  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1100  new_shape.push_back(x->shape[i]);
1101  }
1102  new_shape.push_back(repeats * x->shape[axis]);
1103  for (size_t i = axis + 1; i < x->shape.size(); ++i) {
1104  new_shape.push_back(x->shape[i]);
1105  }
1106 
1107  return compute(
1108  new_shape,
1109  [&](const Array<Var>& indices) {
1110  Array<PrimExpr> idx;
1111  for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
1112  idx.push_back(indices[i]);
1113  }
1114  idx.push_back(indexdiv(indices[axis], repeats));
1115  for (size_t i = axis + 1; i < indices.size(); ++i) {
1116  idx.push_back(indices[i]);
1117  }
1118  return x(idx);
1119  },
1120  name, tag);
1121 }
1122 
1133 inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_tile",
1134  std::string tag = kBroadcast) {
1135  size_t ndim = x->shape.size();
1136  size_t rdim = reps.size();
1137  size_t tdim = (ndim > rdim) ? ndim : rdim;
1138  Array<PrimExpr> data_shape;
1139  Array<PrimExpr> reps_shape;
1140  Array<PrimExpr> new_shape;
1141  if (ndim == rdim) {
1142  for (size_t i = 0; i < ndim; ++i) {
1143  data_shape.push_back(x->shape[i]);
1144  reps_shape.push_back(reps[i]);
1145  }
1146  } else if (ndim > rdim) {
1147  for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1148  for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1);
1149  for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1150  } else {
1151  for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1);
1152  for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]);
1153  for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]);
1154  }
1155  for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]);
1156 
1157  if (is_empty_shape(new_shape)) {
1158  return compute(
1159  new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
1160  } else {
1161  return compute(
1162  new_shape,
1163  [&](const Array<Var>& indices) {
1164  Array<PrimExpr> idx;
1165  if (ndim >= rdim) {
1166  for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i]));
1167  } else {
1168  for (size_t i = 0; i < ndim; ++i)
1169  idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
1170  }
1171  return x(idx);
1172  },
1173  name, tag);
1174  }
1175 }
1176 
1188 inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
1189  std::string name = "T_tile", std::string tag = kBroadcast) {
1190  size_t ndim = x->shape.size();
1191  if (is_empty_shape(new_shape)) {
1192  return compute(
1193  new_shape, [&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0); }, name, tag);
1194  } else {
1195  return compute(
1196  new_shape,
1197  [&](const Array<Var>& indices) {
1198  Array<PrimExpr> idx;
1199  if (ndim >= rdim) {
1200  for (size_t i = 0; i < ndim; ++i) {
1201  idx.push_back(indexmod(indices[i], x->shape[i]));
1202  }
1203  } else {
1204  for (size_t i = 0; i < ndim; ++i) {
1205  idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
1206  }
1207  }
1208  return x(idx);
1209  },
1210  name, tag);
1211  }
1212 }
1213 
1225 inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
1226  std::string name = "T_gather", std::string tag = kInjective) {
1227  size_t ndim_d = data->shape.size();
1228  size_t ndim_i = indices->shape.size();
1229  ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
1230  ICHECK_EQ(ndim_d, ndim_i);
1231  if (axis < 0) {
1232  axis += ndim_d;
1233  }
1234  ICHECK_GE(axis, 0);
1235  ICHECK_LT(axis, ndim_d);
1236  if (indices->shape[axis].as<IntImmNode>()) {
1237  size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
1238  ICHECK_GE(indices_dim_i, 1);
1239  }
1240  ICHECK(indices->dtype.is_int());
1241 
1242  Array<PrimExpr> out_shape;
1243  for (size_t i = 0; i < ndim_i; ++i) {
1244  out_shape.push_back(indices->shape[i]);
1245  }
1246 
1247  return compute(
1248  out_shape,
1249  [&](const Array<Var>& out_index) {
1250  Array<PrimExpr> indices_position;
1251  for (size_t i = 0; i < ndim_i; ++i) {
1252  indices_position.push_back(out_index[i]);
1253  }
1254  Array<PrimExpr> real_indices;
1255  for (size_t i = 0; i < ndim_i; ++i) {
1256  if (i == static_cast<size_t>(axis)) {
1257  real_indices.push_back(indices(indices_position));
1258  } else {
1259  real_indices.push_back(indices_position[i]);
1260  }
1261  }
1262  return data(real_indices);
1263  },
1264  name, tag);
1265 }
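// ---- Example (editorial sketch; not part of the original transform.h) ----
// gather needs indices with the same rank as data; the output takes the
// indices' shape.
inline void example_gather() {
  Tensor data = placeholder({3, 4}, DataType::Float(32), "data");
  Tensor idx = placeholder({3, 2}, DataType::Int(32), "idx");
  Tensor out = gather(data, /*axis=*/1, idx);  // out->shape == {3, 2}
}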
1266 
1278 inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
1279  std::string name = "T_gather_nd", std::string tag = kInjective) {
1280  size_t ndim_d = data->shape.size();
1281  size_t ndim_i = indices->shape.size();
1282  ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimension";
1283  size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
1284  ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
1285  << "than the rank of the data tensor";
1286  Array<PrimExpr> out_shape;
1287  for (size_t i = 1; i < ndim_i; ++i) {
1288  out_shape.push_back(indices->shape[i]);
1289  }
1290  for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
1291  out_shape.push_back(data->shape[i]);
1292  }
1293  return compute(
1294  out_shape,
1295  [&](const Array<Var>& out_index) {
1296  Array<PrimExpr> indices_position;
1297  indices_position.push_back(0);
1298  for (size_t i = 0; i < ndim_i - 1; ++i) {
1299  indices_position.push_back(out_index[i]);
1300  }
1301  Array<PrimExpr> real_indices;
1302  for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
1303  real_indices.push_back(out_index[i]);
1304  }
1305  for (size_t i = 0; i < indices_dim0; ++i) {
1306  indices_position.Set(0, make_const(DataType::Int(32), i));
1307  if (indices->dtype.is_int()) {
1308  real_indices.push_back(indices(indices_position));
1309  } else {
1310  real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
1311  }
1312  }
1313  if (real_indices.size() == ndim_d) {
1314  return data(real_indices);
1315  }
1316  for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
1317  real_indices.push_back(out_index[i]);
1318  }
1319  return data(real_indices);
1320  },
1321  name, tag);
1322 }
1323 
1339 inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B,
1340  bool trans_a = false, bool trans_b = false,
1341  std::string name = "T_matmul", std::string tag = kMatMul) {
1342  tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]};
1343  auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
1344  auto l = [&](tvm::tir::Var i, tvm::tir::Var j) {
1345  return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k});
1346  };
1347  return tvm::te::compute(output_shape, l, name, tag);
1348 }
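// ---- Example (editorial sketch; not part of the original transform.h) ----
// Plain 2-D matmul; trans_a/trans_b transpose an operand before multiplying.
inline void example_matmul() {
  Tensor A = placeholder({2, 3}, DataType::Float(32), "A");
  Tensor B = placeholder({3, 4}, DataType::Float(32), "B");
  Tensor C = matmul(A, B);  // C->shape == {2, 4}, reducing over k = 3
}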
1349 
1361 inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2,
1362  std::string name = "T_tensordot", std::string tag = kMatMul) {
1363  ICHECK_GE(A->shape.size(), axes);
1364  ICHECK_GE(B->shape.size(), axes);
1365 
1366  Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
1367  for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it);
1368 
1369  Array<IterVar> iter_vars;
1370  for (int i = 0; i < axes; ++i)
1371  iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i)));
1372 
1373  auto func = [&A, &B, &iter_vars, axes](const Array<Var>& input_indices) {
1374  Array<PrimExpr> A_indices(input_indices.begin(),
1375  input_indices.begin() + (A->shape.size() - axes));
1376  for (auto& v : iter_vars) A_indices.push_back(v);
1377 
1378  Array<PrimExpr> B_indices;
1379  for (auto& v : iter_vars) B_indices.push_back(v);
1380 
1381  auto it = input_indices.begin() + (A->shape.size() - axes);
1382  for (; it != input_indices.end(); ++it) B_indices.push_back(*it);
1383 
1384  // Some passes don't like reductions with empty axis, so avoid it here
1385  if (iter_vars.empty())
1386  return A(A_indices) * B(B_indices);
1387  else
1388  return sum(A(A_indices) * B(B_indices), iter_vars);
1389  };
1390 
1391  return compute(output_shape, func, name, tag);
1392 }
1393 
1406 inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array<PrimExpr> A_axes,
1407  Array<PrimExpr> B_axes, std::string name = "T_tensordot",
1408  std::string tag = kMatMul) {
1409  ICHECK_EQ(A_axes.size(), B_axes.size());
1410 
1411  auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
1412  auto B_axes_val = GetConstIntValues(B_axes, "B_axes");
1413 
1414  Array<PrimExpr> output_shape;
1415  for (unsigned i = 0; i < A->shape.size(); ++i)
1416  if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
1417  output_shape.push_back(A->shape[i]);
1418  for (unsigned i = 0; i < B->shape.size(); ++i)
1419  if (std::find(B_axes_val.begin(), B_axes_val.end(), i) == B_axes_val.end())
1420  output_shape.push_back(B->shape[i]);
1421 
1422  Array<IterVar> iter_vars;
1423  for (unsigned i = 0; i < B_axes_val.size(); ++i)
1424  iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i)));
1425 
1426  auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array<Var>& input_indices) {
1427  int idx_input = 0;
1428  Array<PrimExpr> A_indices;
1429  for (unsigned i = 0; i < A->shape.size(); ++i) {
1430  auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
1431  if (axes_pos == A_axes_val.end())
1432  A_indices.push_back(input_indices[idx_input++]);
1433  else
1434  A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
1435  }
1436 
1437  Array<PrimExpr> B_indices;
1438  for (unsigned i = 0; i < B->shape.size(); ++i) {
1439  auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
1440  if (axes_pos == B_axes_val.end())
1441  B_indices.push_back(input_indices[idx_input++]);
1442  else
1443  B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]);
1444  }
1445  return sum(A(A_indices) * B(B_indices), iter_vars);
1446  };
1447  return compute(output_shape, func, name, tag);
1448 }
1449 
1450 inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
1451  DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
1452  PrimExpr num_elem = tvm::cast(
1453  tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
1455  return compute(
1456  {num_elem},
1457  [&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name,
1458  tag);
1459 }
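// ---- Example (editorial sketch; not part of the original transform.h) ----
// arange computes ceil((stop - start) / step) elements, like numpy.arange.
inline void example_arange() {
  Tensor r = arange(0, 10, 2, DataType::Int(32));  // r == [0, 2, 4, 6, 8]
}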
1460 
1471 inline Array<Tensor> meshgrid(const Array<Tensor>& inputs, const std::string& indexing,
1472  std::string name = "T_meshgrid", std::string tag = kInjective) {
1473  const bool cartesian_indexing = indexing == "xy" && inputs.size() >= 2;
1474  Array<PrimExpr> out_shape;
1475  for (size_t i = 0; i < inputs.size(); ++i) {
1476  const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
1477  out_shape.push_back(inputs[src_index]->shape.size() == 0 ? 1 : inputs[src_index]->shape[0]);
1478  }
1479  Array<Tensor> result;
1480  for (size_t i = 0; i < inputs.size(); ++i) {
1481  result.push_back(compute(
1482  out_shape,
1483  [&](const Array<Var>& indices) {
1484  const int src_index = (cartesian_indexing && i < 2) ? 1 - i : i;
1485  Array<PrimExpr> real_indices = {indices[src_index]};
1486  return inputs[i](real_indices);
1487  },
1488  name, tag));
1489  }
1490  return result;
1491 }
1492 
1502 inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
1503  const std::string& dst_layout,
1504  const std::string name = "T_layout_trans",
1505  const std::string tag = kInjective) {
1506  Layout src_layout_struct(src_layout);
1507  Layout dst_layout_struct(dst_layout);
1508 
1509  if (src_layout_struct.Equals(dst_layout_struct)) {
1510  return src;
1511  }
1512 
1513  ICHECK(src_layout_struct.defined() && dst_layout_struct.defined())
1514  << "cannot convert from/to undefined layout";
1515 
1516  auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
1517  ICHECK(layout_converter.defined())
1518  << "cannot convert from " << src_layout << " to " << dst_layout;
1519 
1520  Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);
1521 
1522  return compute(
1523  dst_shape,
1524  [&](const Array<Var>& dst_indices) {
1525  Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
1526  Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
1527  return src(src_indices);
1528  },
1529  name, tag);
1530 }
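// ---- Example (editorial sketch; not part of the original transform.h) ----
// layout_transform splits/merges axes according to the layout strings; here
// 32 channels become 8 outer x 4 inner.
inline void example_layout_transform() {
  Tensor A = placeholder({1, 32, 8, 8}, DataType::Float(32), "A");
  Tensor B = layout_transform(A, "NCHW", "NCHW4c");  // B->shape == {1, 8, 8, 8, 4}
}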
1531 
1533 inline void parse_auto_scheduler_layout(const String& layout, Array<PrimExpr>* shape,
1534  std::vector<std::string>* axes) {
1535  int32_t factor = 0;
1536  std::string axis = "";
1537  for (char c : std::string(layout)) {
1538  if (c >= 'A' && c <= 'z') {
1539  axis += c;
1540  if (factor != 0) {
1541  shape->push_back(factor);
1542  factor = 0;
1543  }
1544  } else if (c >= '0' && c <= '9') {
1545  factor = factor * 10 + c - '0';
1546  if (!axis.empty()) {
1547  axes->push_back(axis);
1548  axis = "";
1549  }
1550  } else {
1551  LOG(FATAL) << "Invalid layout " << layout;
1552  }
1553  }
1554  if (!axis.empty()) {
1555  axes->push_back(axis);
1556  }
1557 }
1558 
1569 inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& src_layout,
1570  const String& dst_layout,
1571  const String name = "T_auto_scheduler_layout_trans",
1572  const String tag = kInjective) {
1573  Array<PrimExpr> src_shape;
1574  std::vector<std::string> src_axes;
1575  Array<PrimExpr> dst_shape;
1576  std::vector<std::string> dst_axes;
1577 
1578  parse_auto_scheduler_layout(src_layout, &src_shape, &src_axes);
1579  parse_auto_scheduler_layout(dst_layout, &dst_shape, &dst_axes);
1580  return compute(
1581  dst_shape,
1582  [&](const Array<Var>& dst_indices) {
1583  Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
1584  Array<PrimExpr> src_indices;
1585  for (const std::string& src_axis : src_axes) {
1586  PrimExpr src_index = 0;
1587  CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
1588  for (size_t i = 0; i < dst_axes.size(); ++i) {
1589  if (dst_axes[i] == src_axis) {
1590  src_index = src_index * dst_shape[i] + dst_indices_expr[i];
1591  }
1592  }
1593  src_indices.push_back(src_index);
1594  }
1595  return src(src_indices);
1596  },
1597  name, tag);
1598 }
1599 
1608 inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape",
1609  const std::string tag = kInjective) {
1610  int ndim = static_cast<int>(src->shape.size());
1611  Array<PrimExpr> out_shape{ndim};
1612  return compute(
1613  out_shape,
1614  [&](const Array<Var>& indices) {
1615  auto idx = indices[0];
1616  PrimExpr ret = 0;
1617  for (int i = 0; i < ndim; ++i) {
1618  ret = tvm::if_then_else(idx == i, src->shape[i], ret);
1619  }
1620  return tvm::cast(dtype, ret);
1621  },
1622  name, tag);
1623 }
1624 
1633 inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
1634  const std::string& name = "ndarray_size",
1635  const std::string& tag = kInjective) {
1636  int ndim = static_cast<int>(src->shape.size());
1637  Array<PrimExpr> out_ndarray_size = {};
1638  return compute(
1639  out_ndarray_size,
1640  [&](const Array<Var>& indices) {
1641  PrimExpr ret = 1;
1642  for (int i = 0; i < ndim; ++i) {
1643  ret *= src->shape[i];
1644  }
1645  return tvm::cast(dtype, ret);
1646  },
1647  name, tag);
1648 }
1649 
1664 inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
1665  int depth, int axis, const DataType& dtype,
1666  Array<PrimExpr> oshape = Array<PrimExpr>(),
1667  const std::string name = "T_one_hot", const std::string tag = kInjective) {
1668  int true_axis = (axis == -1) ? indices->shape.size() : axis;
1669  if (oshape.size() == 0) {
1670  int ndim = indices->shape.size() + 1;
1671  int indices_index = 0;
1672  for (int i = 0; i < ndim; i++) {
1673  if (i == true_axis) {
1674  oshape.push_back(Integer(depth));
1675  } else {
1676  oshape.push_back(indices->shape[indices_index++]);
1677  }
1678  }
1679  }
1680 
1681  PrimExpr on_value_cast = cast(dtype, on_value);
1682  PrimExpr off_value_cast = cast(dtype, off_value);
1683  return compute(
1684  oshape,
1685  [&](const Array<Var>& iter_vars) {
1686  Array<Var> indices_indices;
1687  for (size_t i = 0; i < iter_vars.size(); i++) {
1688  if (static_cast<int>(i) == true_axis) {
1689  continue;
1690  }
1691 
1692  indices_indices.push_back(iter_vars[i]);
1693  }
1694 
1695  auto idx = iter_vars[true_axis];
1696  return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast);
1697  },
1698  name, tag);
1699 }
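// ---- Example (editorial sketch; not part of the original transform.h) ----
// one_hot expands integer labels into indicator vectors along a new axis.
inline void example_one_hot() {
  Tensor labels = placeholder({4}, DataType::Int(32), "labels");
  Tensor oh = one_hot(labels, 1.0f, 0.0f, /*depth=*/3, /*axis=*/-1, DataType::Float(32));
  // oh->shape == {4, 3}; row i is 1 at column labels[i], 0 elsewhere
}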
1700 
1711 inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr>& output_shape,
1712  const Tensor& sparse_values, const PrimExpr& default_value,
1713  const std::string name = "T_sparse_to_dense",
1714  const std::string tag = kInjective) {
1715  ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values";
1716  ICHECK_LE(sparse_indices->shape.size(), 3)
1717  << "sparse_indices tensor should be 0D, 1D, or 2D only";
1718  ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only";
1719 
1720  const auto rank_sparse_indices = static_cast<int>(sparse_indices->shape.size());
1721  Array<PrimExpr> oshape;
1722  for (auto l : output_shape) {
1723  oshape.push_back(l);
1724  }
1725  return compute(
1726  oshape,
1727  [&](const Array<Var>& indices) {
1728  PrimExpr ret = default_value;
1729  if (0 == rank_sparse_indices) {
1730  ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret);
1731  } else if (1 == rank_sparse_indices) {
1732  for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
1733  ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
1734  }
1735  } else {
1736  for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
1737  PrimExpr aggregate_condition;
1738  for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
1739  PrimExpr comparison = indices[k] == sparse_indices[j][k];
1740  aggregate_condition = 0 == k ? comparison : aggregate_condition && comparison;
1741  }
1742  ret = if_then_else(aggregate_condition, sparse_values[j], ret);
1743  }
1744  }
1745  return ret;
1746  },
1747  name, tag);
1748 }
1749 
1762 inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
1763  bool super_diag_right_align, bool sub_diag_right_align,
1764  const std::string name = "T_matrix_set_diag",
1765  const std::string tag = kInjective) {
1766  size_t ndim = input->shape.size() - 1;
1767 
1768  bool only_one_diagonal = k1 == k2;
1769 
1770  return compute(
1771  input->shape,
1772  [&](const Array<Var>& iter_vars) {
1773  auto get_diag = [&]() {
1774  Array<PrimExpr> diagonal_indices;
1775  PrimExpr k, offset = 0;
1776  for (size_t i = 0; i < ndim - 1; i++) {
1777  diagonal_indices.push_back(iter_vars[i]);
1778  }
1779  if (only_one_diagonal) {
1780  k = k1;
1781  } else {
1782  // Determining which diagonal/sub-diagonal/super-diagonal it is
1783  k = iter_vars[ndim] - iter_vars[ndim - 1];
1784  diagonal_indices.push_back(k2 - k);
1785 
1786  // Calculating the offset in diagonal tensor for this diagonal
1787  auto get_offset = [&](PrimExpr M, PrimExpr N) {
1788  // offset = max_diagonal_length - diagonal_length
1789  return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
1790  };
1791  offset = if_then_else(
1792  k >= 0,
1793  super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
1794  : 0,
1795  sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
1796  : 0);
1797  }
1798  diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
1799  offset);
1800  return diagonal(diagonal_indices);
1801  };
1802  return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
1803  if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
1804  get_diag(), input(iter_vars)),
1805  input(iter_vars));
1806  },
1807  name, tag);
1808 }
1809 
1818 inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
1819  const std::string name = "advanced_index",
1820  const std::string tag = kInjective) {
1821  Array<PrimExpr> oshape;
1822  Array<PrimExpr> broadcast_shape;
1823  Array<Tensor> bindices;
1824  std::vector<int64_t> flatten_shape_lens;
1825  int64_t num_picked_elems = 1;
1826  bool has_dyn_shape = false;
1827 
1828  if (indices.size() == 1) {
1829  broadcast_shape = indices[0]->shape;
1830  bindices = indices;
1831  } else {
1832  for (const auto& index : indices) {
1833  int64_t flatten_len = 1;
1834  for (const auto& dim : index->shape) {
1835  const IntImmNode* axis_len = dim.as<IntImmNode>();
1836  if (!axis_len) {
1837  broadcast_shape = index->shape;
1838  has_dyn_shape = true;
1839  break;
1840  }
1841  flatten_len *= axis_len->value;
1842  }
1843  if (has_dyn_shape) break;
1844  flatten_shape_lens.push_back(flatten_len);
1845  if (flatten_len > num_picked_elems) {
1846  num_picked_elems = flatten_len;
1847  broadcast_shape = index->shape;
1848  }
1849  }
1850 
1851  // Do broadcast for indices
1852  for (size_t i = 0; i < indices.size(); ++i) {
1853  if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) {
1854  bindices.push_back(broadcast_to(indices[i], broadcast_shape));
1855  } else {
1856  bindices.push_back(indices[i]);
1857  }
1858  }
1859  }
1860 
1861  for (const auto& dim : broadcast_shape) {
1862  oshape.push_back(dim);
1863  }
1864  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
1865  oshape.push_back(data->shape[i]);
1866  }
1867 
1868  return compute(
1869  oshape,
1870  [&](const Array<Var>& iter_var) {
1871  Array<PrimExpr> tensor_indices;
1872  for (size_t i = 0; i < broadcast_shape.size(); ++i) {
1873  tensor_indices.push_back(iter_var[i]);
1874  }
1875 
1876  Array<PrimExpr> real_indices;
1877  for (size_t i = 0; i < bindices.size(); ++i) {
1878  real_indices.push_back(bindices[i](tensor_indices));
1879  }
1880  for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
1881  real_indices.push_back(iter_var[i]);
1882  }
1883 
1884  return data(real_indices);
1885  },
1886  name, tag);
1887 }
1888 
1889 } // namespace topi
1890 } // namespace tvm
1891 #endif // TVM_TOPI_TRANSFORM_H_
Managed reference to LayoutNode.
Definition: data_layout.h:123
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
bool Equals(const Layout &rhs) const
Whether the two layouts are equal.
Definition: data_layout.h:276
Tensor strided_slice_with_axes(const Tensor &x, const Array< Integer > &begin, const Array< Integer > &end, const Array< Integer > &strides, const Array< Integer > &axes, std::string slice_mode="end", std::string name="T_strided_slice_with_axes", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:683
Tensor sparse_to_dense(const Tensor &sparse_indices, const Array< PrimExpr > &output_shape, const Tensor &sparse_values, const PrimExpr &default_value, const std::string name="T_sparse_to_dense", const std::string tag=kInjective)
Get a dense tensor.
Definition: transform.h:1711
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder floor(a / b) where a and b are non-negative.
Array< PrimExpr > StridedSliceOutputShape(const Array< PrimExpr > &ishape, const Array< Integer > &begin, const Array< Integer > &end, const Array< Integer > &strides, const Array< Integer > &axes, const std::string &slice_mode)
Calcluate the output shape of strided_slice, the entry point for Relay type relation.
Definition: transform.h:655
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1109
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Tensor expression language DSL.
Definition: autodiff.h:35
Tensor tensordot(const Tensor &A, const tvm::te::Tensor &B, int axes=2, std::string name="T_tensordot", std::string tag=kMatMul)
A generalization of matrix multiplication to tensors.
Definition: transform.h:1361
Tensor dynamic_strided_slice(const Tensor &x, const Array< PrimExpr > &begin, const Array< PrimExpr > &end, const Array< PrimExpr > &strides, std::string name="T_dynamic_strided_slice", std::string tag=kInjective)
strided_slice of a tensor where begin/end/stride can be mixed static and dynamic
Definition: transform.h:566
PrimExpr ceil(PrimExpr x, Span span=Span())
Calculate ceil(x)
a named variable in TIR
Definition: var.h:88
Tensor where(const Tensor &condition, const Tensor &x, const Tensor &y, std::string name="T_where", std::string tag=kBroadcast)
Return the elements, either from x or y, depending on the condition.
Definition: transform.h:1046
Tensor one_hot(const Tensor &indices, const PrimExpr on_value, const PrimExpr off_value, int depth, int axis, const DataType &dtype, Array< PrimExpr > oshape=Array< PrimExpr >(), const std::string name="T_one_hot", const std::string tag=kInjective)
Returns a one-hot tensor where the locations repsented by indices take value on_value, other locations take value off_value.
Definition: transform.h:1664
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
Tensor matrix_set_diag(const Tensor &input, const Tensor &diagonal, int k1, int k2, bool super_diag_right_align, bool sub_diag_right_align, const std::string name="T_matrix_set_diag", const std::string tag=kInjective)
Returns a tensor with the diagonal of input tensor replaced with the provided diagonals.
Definition: transform.h:1762
constexpr auto kMatMul
Definition: tags.h:37
constexpr auto kInjective
Definition: tags.h:33
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Utility functions for strided_slice op.
Tensor unravel_index(const Tensor &x, const Tensor &shape, std::string name="T_unravel", std::string tag=kInjective)
Converts a flat index or array of flat indices into a tuple of coordinate arrays. ...
Definition: transform.h:273
Array< Tensor > split(const Tensor &x, Array< PrimExpr > split_indices, int axis, std::string name="T_split", std::string tag=kInjective)
Split a tensor into multiple sub-tensors.
Definition: transform.h:489
void parse_auto_scheduler_layout(const String &layout, Array< PrimExpr > *shape, std::vector< std::string > *axes)
Utility function for auto_scheduler_layout_transform.
Definition: transform.h:1533
Tensor gather(const Tensor &data, int axis, const Tensor &indices, std::string name="T_gather", std::string tag=kInjective)
Gather values along given axis from given indices.
Definition: transform.h:1225
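An illustrative sketch (demo_gather and the shapes are made up) showing the rank constraint on indices:

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_gather() {  // illustrative wrapper, not a TVM API
  Tensor data = placeholder({3, 4}, DataType::Float(32), "data");
  // indices must have the same rank as data; its values index along `axis`.
  Tensor indices = placeholder({3, 2}, DataType::Int(32), "indices");
  // The output takes the shape of indices: (3, 2).
  return tvm::topi::gather(data, /*axis=*/1, indices);
}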
Tensor dyn_tile(const Tensor &x, Array< PrimExpr > new_shape, size_t rdim, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1188
Tensor layout_transform(const Tensor &src, const std::string &src_layout, const std::string &dst_layout, const std::string name="T_layout_trans", const std::string tag=kInjective)
Transform the layout according to src_layout and dst_layout.
Definition: transform.h:1502
Tensor auto_scheduler_layout_transform(const Tensor &src, const String &src_layout, const String &dst_layout, const String name="T_auto_scheduler_layout_trans", const String tag=kInjective)
Transform the auto-scheduler generated layout according to src_layout and dst_layout.
Definition: transform.h:1569
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
Tensor squeeze(const Tensor &x, Array< Integer > axis, bool atleast1d=false, std::string name="T_squeeze", std::string tag=kInjective)
Remove size 1 dimensions from the shape of a tensor. The removed dimensions must have a constant size of 1.
Definition: transform.h:321
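A short sketch with invented shapes (demo_squeeze is not part of TVM):

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_squeeze() {  // illustrative wrapper, not a TVM API
  Tensor x = placeholder({1, 3, 1}, DataType::Float(32), "x");
  // Drop the two length-1 axes: (1, 3, 1) -> (3,).
  return tvm::topi::squeeze(x, {0, 2});
}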
void Set(int64_t i, T value)
set i-th element of the array.
Definition: array.h:567
Constant integer literals in the program.
Definition: expr.h:233
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:436
Tensor expand_dims(const Tensor &x, int axis, int num_newaxis=1, std::string name="T_expand_dims", std::string tag=kBroadcast)
Creates an operation to insert new dimensions of length 1.
Definition: transform.h:62
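A minimal sketch of the implementation shown at transform.h:62; demo_expand_dims and the shapes are illustrative:

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_expand_dims() {  // illustrative wrapper, not a TVM API
  Tensor x = placeholder({2, 3}, DataType::Float(32), "x");
  // Insert two length-1 axes at position 1: (2, 3) -> (2, 1, 1, 3).
  return tvm::topi::expand_dims(x, /*axis=*/1, /*num_newaxis=*/2);
}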
Tensor tile(const Tensor &x, Array< Integer > reps, std::string name="T_tile", std::string tag=kBroadcast)
Creates an operation to tile elements of an array.
Definition: transform.h:1133
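A sketch with invented shapes and repetition counts (demo_tile is illustrative):

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_tile() {  // illustrative wrapper, not a TVM API
  Tensor x = placeholder({2, 3}, DataType::Float(32), "x");
  // reps {2, 1}: repeat twice along axis 0, once along axis 1 -> (4, 3).
  return tvm::topi::tile(x, {2, 1});
}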
Tensor sum(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:326
Utility functions for handling constants in TVM expressions.
constexpr auto kBroadcast
Definition: tags.h:36
Range container.
Definition: expr.h:449
Tensor arange(const PrimExpr &start, const PrimExpr &stop, const PrimExpr &step, DataType dtype, std::string name="T_arange", std::string tag=kInjective)
Definition: transform.h:1450
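A brief sketch; the concrete start/stop/step values are invented (integer literals convert implicitly to PrimExpr):

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor;

Tensor demo_arange() {  // illustrative wrapper, not a TVM API
  // start 0, stop 10, step 2 -> the five values 0, 2, 4, 6, 8.
  return tvm::topi::arange(0, 10, 2, DataType::Int(32));
}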
size_t size() const
Definition: array.h:399
Runtime primitive data type.
Definition: data_type.h:41
bool defined() const
Definition: object.h:537
Tensor stack(const Array< Tensor > &inputs, int axis=0, std::string name="T_stack", std::string tag=kInjective)
Join a sequence of tensors along a new axis.
Definition: transform.h:443
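A sketch with two invented placeholders (demo_stack is illustrative):

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_stack() {  // illustrative wrapper, not a TVM API
  Tensor a = placeholder({2, 3}, DataType::Float(32), "a");
  Tensor b = placeholder({2, 3}, DataType::Float(32), "b");
  // A new leading axis joins the inputs: two (2, 3) tensors -> (2, 2, 3).
  return tvm::topi::stack({a, b}, /*axis=*/0);
}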
Utility functions for handling tensor.
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:168
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
Sum of the source expression over axis.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
Tensor concatenate(const Array< Tensor > &inputs, int axis=0, std::string name="T_concat", std::string tag=kInjective)
Join a sequence of tensors along an existing axis.
Definition: transform.h:384
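A sketch with invented shapes; note that the inputs must agree on every axis except the concatenation axis:

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_concatenate() {  // illustrative wrapper, not a TVM API
  Tensor a = placeholder({2, 3}, DataType::Float(32), "a");
  Tensor b = placeholder({4, 3}, DataType::Float(32), "b");
  // Join along the existing axis 0: (2, 3) + (4, 3) -> (6, 3).
  return tvm::topi::concatenate({a, b}, /*axis=*/0);
}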
Tensor take(const Tensor &a, const Tensor &indices, int batch_dims, std::string mode="clip", std::string name="T_take", std::string tag=kInjective)
Take elements from a flattened input array when axis is None.
Definition: transform.h:812
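A sketch of this flattened-take overload; demo_take and the shapes are illustrative:

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_take() {  // illustrative wrapper, not a TVM API
  Tensor a = placeholder({2, 3}, DataType::Float(32), "a");
  Tensor indices = placeholder({4}, DataType::Int(32), "indices");
  // Indexes the flattened length-6 view of `a`; mode "clip" clamps
  // out-of-range indices. Result shape matches indices: (4,).
  return tvm::topi::take(a, indices, /*batch_dims=*/0, /*mode=*/"clip");
}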
Managed reference class to IntImmNode.
Definition: expr.h:262
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
int64_t value
The internal value.
Definition: expr.h:236
Reference to string objects.
Definition: string.h:129
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1608
Array< Tensor > meshgrid(const Array< Tensor > &inputs, const std::string &indexing, std::string name="T_meshgrid", std::string tag=kInjective)
Produce grids by expanding input over dimensions defined by other inputs.
Definition: transform.h:1471
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
iterator end() const
Definition: array.h:369
Tensor structure representing a possible input or intermediate computation result.
Definition: tensor.h:102
iterator begin() const
Definition: array.h:366
An Operation node can generate one or multiple Tensors.
Managed reference to SelectNode.
Definition: expr.h:589
Bijective function mapping for data layout transformation. Given two Layouts, BijectiveLayout builds and stores the mapping rules between them.
Definition: data_layout.h:324
Tensor transpose(const Tensor &x, Array< Integer > axes, std::string name="T_transpose", std::string tag=kInjective)
Permute the dimensions of an array.
Definition: transform.h:111
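A minimal sketch; demo_transpose and the (2, 3, 4) shape are illustrative:

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_transpose() {  // illustrative wrapper, not a TVM API
  Tensor x = placeholder({2, 3, 4}, DataType::Float(32), "x");
  // Move the last axis to the front: (2, 3, 4) -> (4, 2, 3).
  // Passing an undefined/empty axes array reverses all dimensions instead.
  return tvm::topi::transpose(x, {2, 0, 1});
}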
Tensor ndarray_size(const Tensor &src, const DataType &dtype, const std::string &name="ndarray_size", const std::string &tag=kInjective)
Get the size of input tensor.
Definition: transform.h:1633
Tensor sequence_mask(const Tensor &data, const Tensor &valid_length, double mask_value, int axis, std::string name="T_sequence_mask", std::string tag=kInjective)
Mask the out-of-boundary elements of each sequence.
Definition: transform.h:860
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Tensor adv_index(const Tensor &data, const Array< Tensor > &indices, const std::string name="advanced_index", const std::string tag=kInjective)
Numpy style advanced indexing with tensor.
Definition: transform.h:1818
External function interface to the rocBLAS library.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
Tensor reverse_sequence(const Tensor &x, const Tensor &seq_lengths, int seq_axis=1, int batch_axis=0, std::string name="T_reverse_sequence", std::string tag=kInjective)
Reverse the tensor for variable length slices. Input is first sliced along the batch axis and then elements are reversed along the seq axis.
Definition: transform.h:170
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type...
Definition: elemwise.h:280
Tensor reshape(const Tensor &x, Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:235
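A short sketch with invented shapes (demo_reshape is illustrative; the element count must be preserved):

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_reshape() {  // illustrative wrapper, not a TVM API
  Tensor x = placeholder({2, 3}, DataType::Float(32), "x");
  // The same six elements viewed as (3, 2).
  return tvm::topi::reshape(x, {3, 2});
}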
tvm::te::Tensor broadcast_to(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &output_shape, std::string name="T_broadcast_to", std::string tag=kBroadcast)
Creates an operation that broadcasts a tensor into a compatible shape according to numpy's rules.
Definition: broadcast.h:48
Tensor repeat(const Tensor &x, int repeats, int axis, std::string name="T_repeat", std::string tag=kBroadcast)
Creates an operation to repeat elements of an array.
Definition: transform.h:1086
Tensor strided_slice(const Tensor &x, const Array< Integer > &begin, const Array< Integer > &end, const Array< Integer > &strides, std::string slice_mode="end", std::string name="T_strided_slice", std::string tag=kInjective)
strided_slice of a tensor
Definition: transform.h:729
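A sketch of the static overload; the begin/end/stride values and shapes are invented:

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_strided_slice() {  // illustrative wrapper, not a TVM API
  Tensor x = placeholder({4, 4}, DataType::Float(32), "x");
  // Rows 0 and 2 (stride 2), columns 1 and 2: result shape (2, 2).
  return tvm::topi::strided_slice(x, /*begin=*/{0, 1}, /*end=*/{4, 3}, /*strides=*/{2, 1});
}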
Broadcast op constructions.
Reference to PrimExprNode.
Definition: expr.h:109
Layout expression to describe the data organization of a tensor, and BijectiveLayout to map between two data layouts.
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:858
Tensor gather_nd(const Tensor &data, const Tensor &indices, int batch_dims=0, std::string name="T_gather_nd", std::string tag=kInjective)
Gather elements from a n-dimension array.
Definition: transform.h:1278
Array< Tensor > split_sections(const Tensor &x, int num_sections, int axis, std::string name="T_split_sections", std::string tag=kInjective)
Split a tensor into a number of sub-tensors.
Definition: transform.h:769
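A sketch with invented shapes; the split axis must divide evenly by num_sections:

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

tvm::Array<Tensor> demo_split_sections() {  // illustrative wrapper, not a TVM API
  Tensor x = placeholder({6, 3}, DataType::Float(32), "x");
  // Axis 0 (length 6) splits into three (2, 3) sub-tensors.
  return tvm::topi::split_sections(x, /*num_sections=*/3, /*axis=*/0);
}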
Detail broadcast.
Index ravel and unravel operations.
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:387
tvm::te::Tensor matmul(const tvm::te::Tensor &A, const tvm::te::Tensor &B, bool trans_a=false, bool trans_b=false, std::string name="T_matmul", std::string tag=kMatMul)
Creates an operation that calculates a matrix multiplication (row-major notation): A(i, k) * B(k, j).
Definition: transform.h:1339
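A minimal sketch with invented shapes (demo_matmul is illustrative; trans_a and trans_b default to false):

#include <tvm/topi/transform.h>
using tvm::DataType; using tvm::te::Tensor; using tvm::te::placeholder;

Tensor demo_matmul() {  // illustrative wrapper, not a TVM API
  Tensor A = placeholder({2, 3}, DataType::Float(32), "A");
  Tensor B = placeholder({3, 4}, DataType::Float(32), "B");
  // C(i, j) = sum over k of A(i, k) * B(k, j); result shape (2, 4).
  return tvm::topi::matmul(A, B);
}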
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:154
Container of constant int that adds more constructors.
Definition: expr.h:356