tvm
einsum.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_EINSUM_H_
25 #define TVM_TOPI_EINSUM_H_
26 
27 #define LABELRANGE 128
28 #define NPY_MAXDIMS 16
29 #define NPY_MAXARGS 16
30 
31 #include <tvm/te/operation.h>
32 #include <tvm/tir/data_layout.h>
36 #include <tvm/topi/tags.h>
37 
38 #include <algorithm>
39 #include <bitset>
40 #include <iterator>
41 #include <string>
42 #include <tuple>
43 #include <unordered_set>
44 #include <vector>
45 
46 namespace tvm {
47 namespace topi {
48 
49 using namespace tvm::te;
50 using namespace topi::detail;
51 
60  size_t ndim = shape.size();
61  int prod = 1;
62  Array<PrimExpr> stride = Array<PrimExpr>(ndim, -1);
63  for (int i = ndim - 1; i >= 0; i--) {
64  stride.Set(i, if_then_else(shape[i] > 1, prod, 0));
65  prod = prod * GetConstInt(shape[i]);
66  }
67  return stride;
68 }
69 
78 inline Array<PrimExpr> Pad(const Array<PrimExpr> shape, int odim) {
79  int ndim = shape.size();
80  CHECK_GE(odim, ndim);
81  Array<PrimExpr> ret(static_cast<size_t>(odim), 1);
82  for (int idim = 0; idim < ndim; ++idim) {
83  ret.Set(idim, shape[idim]);
84  }
85  return ret;
86 }
87 
105 inline int ParseOperandSubscripts(const char* subscripts, int length, int ndim, int iop,
106  char* op_labels, char* label_counts, int* min_label,
107  int* max_label) {
108  int i;
109  int idim = 0;
110  int ellipsis = -1;
111 
112  /* Process all labels for this operand */
113  for (i = 0; i < length; ++i) {
114  int label = subscripts[i];
115 
116  /* A proper label for an axis. */
117  if (label > 0 && isalpha(label)) {
118  /* Check we don't exceed the operator dimensions. */
119  CHECK(idim < ndim) << "einstein sum subscripts string contains "
120  << "too many subscripts for operand " << iop;
121 
122  op_labels[idim++] = label;
123  if (label < *min_label) {
124  *min_label = label;
125  }
126  if (label > *max_label) {
127  *max_label = label;
128  }
129  label_counts[label]++;
130  } else if (label == '.') {
131  /* The beginning of the ellipsis. */
132  /* Check it's a proper ellipsis. */
133  CHECK(
134  !(ellipsis != -1 || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
135  << "einstein sum subscripts string contains a "
136  << "'.' that is not part of an ellipsis ('...') "
137  << "in operand " << iop;
138 
139  ellipsis = idim;
140  } else {
141  CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
142  << "' in einstein sum "
143  << "subscripts string, subscripts must "
144  << "be letters";
145  }
146  }
147 
148  /* No ellipsis found, labels must match dimensions exactly. */
149  if (ellipsis == -1) {
150  CHECK(idim == ndim) << "operand has more dimensions than subscripts "
151  << "given in einstein sum, but no '...' ellipsis "
152  << "provided to broadcast the extra dimensions.";
153  } else if (idim < ndim) {
154  /* Ellipsis found, may have to add broadcast dimensions. */
155  /* Move labels after ellipsis to the end. */
156  for (i = 0; i < idim - ellipsis; ++i) {
157  op_labels[ndim - i - 1] = op_labels[idim - i - 1];
158  }
159  /* Set all broadcast dimensions to zero. */
160  for (i = 0; i < ndim - idim; ++i) {
161  op_labels[ellipsis + i] = 0;
162  }
163  }
164 
165  /*
166  * Find any labels duplicated for this operand, and turn them
167  * into negative offsets to the axis to merge with.
168  *
169  * In C, the char type may be signed or unsigned, but with
170  * twos complement arithmetic the char is ok either way here, and
171  * later where it matters the char is cast to a signed char.
172  */
173  for (idim = 0; idim < ndim - 1; ++idim) {
174  int label = op_labels[idim];
175  /* If it is a proper label, find any duplicates of it. */
176  if (label > 0) {
177  /* Search for the next matching label. */
178  char* next = reinterpret_cast<char*>(memchr(op_labels + idim + 1, label, ndim - idim - 1));
179 
180  while (next != nullptr) {
181  /* The offset from next to op_labels[idim] (negative). */
182  *next = static_cast<char>((op_labels + idim) - next);
183  /* Search for the next matching label. */
184  next = reinterpret_cast<char*>(memchr(next + 1, label, op_labels + ndim - 1 - next));
185  }
186  }
187  }
188  return 0;
189 }
190 
204 inline int ParseOutputSubscripts(const char* subscripts, int length, int ndim_broadcast,
205  const char* label_counts, char* out_labels) {
206  int i, bdim;
207  int ndim = 0;
208  int ellipsis = 0;
209 
210  /* Process all the output labels. */
211  for (i = 0; i < length; ++i) {
212  int label = subscripts[i];
213 
214  /* A proper label for an axis. */
215  if (label > 0 && isalpha(label)) {
216  /* Check that it doesn't occur again. */
217  CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
218  << "einstein sum subscripts string includes "
219  << "output subscript '" << static_cast<char>(label) << "' multiple times";
220 
221  /* Check that it was used in the inputs. */
222  CHECK(label_counts[label] != 0)
223  << "einstein sum subscripts string included "
224  << "output subscript '" << static_cast<char>(label) << "' which never appeared "
225  << "in an input";
226 
227  /* Check that there is room in out_labels for this label. */
228  CHECK(ndim < NPY_MAXDIMS) << "einstein sum subscripts string contains "
229  << "too many subscripts in the output";
230 
231  out_labels[ndim++] = label;
232  } else if (label == '.') {
233  /* The beginning of the ellipsis. */
234  /* Check it is a proper ellipsis. */
235  CHECK(!(ellipsis || i + 2 >= length || subscripts[++i] != '.' || subscripts[++i] != '.'))
236  << "einstein sum subscripts string "
237  << "contains a '.' that is not part of "
238  << "an ellipsis ('...') in the output";
239 
240  /* Check there is room in out_labels for broadcast dims. */
241  CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS) << "einstein sum subscripts string contains "
242  << "too many subscripts in the output";
243 
244  ellipsis = 1;
245  for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
246  out_labels[ndim++] = 0;
247  }
248  } else {
249  CHECK(label == ' ') << "invalid subscript '" << static_cast<char>(label)
250  << "' in einstein sum "
251  << "subscripts string, subscripts must "
252  << "be letters";
253  }
254  }
255 
256  /* If no ellipsis was found there should be no broadcast dimensions. */
257  CHECK(!(!ellipsis && ndim_broadcast > 0)) << "output has more dimensions than subscripts "
258  << "given in einstein sum, but no '...' ellipsis "
259  << "provided to broadcast the extra dimensions.";
260 
261  return ndim;
262 }
263 
278 inline void GetCombinedDimsView(const Tensor& op, int iop, char* labels, Array<PrimExpr>* newshape,
279  Array<PrimExpr>* newstride) {
280  int idim, ndim, icombine, combineoffset;
281  int icombinemap[NPY_MAXDIMS];
282  int newdim;
283 
284  Array<PrimExpr> shape = op->shape;
285  Array<PrimExpr> stride = GetStride(shape);
286  ndim = op.ndim();
287  newdim = newshape->size();
288 
289  /* Initialize the dimensions and strides to zero */
290  for (idim = 0; idim < newdim; ++idim) {
291  newshape->Set(idim, 0);
292  newstride->Set(idim, 0);
293  }
294 
295  /* Copy the dimensions and strides, except when collapsing */
296  icombine = 0;
297  for (idim = 0; idim < ndim; ++idim) {
298  /*
299  * The char type may be either signed or unsigned, we
300  * need it to be signed here.
301  */
302  int label = (signed char)labels[idim];
303  /* If this label says to merge axes, get the actual label */
304  if (label < 0) {
305  combineoffset = label;
306  label = labels[idim + label];
307  } else {
308  combineoffset = 0;
309  if (icombine != idim) {
310  labels[icombine] = labels[idim];
311  }
312  icombinemap[idim] = icombine;
313  }
314  /* If the label is 0, it's an unlabeled broadcast dimension */
315  if (label == 0) {
316  newshape->Set(icombine, shape[idim]);
317  newstride->Set(icombine, stride[idim]);
318  } else {
319  /* Update the combined axis dimensions and strides */
320  int i = icombinemap[idim + combineoffset];
321  CHECK(!((combineoffset < 0) &&
322  GetConstInt((*newshape)[i] != 0 && (*newshape)[i] != shape[idim])))
323  << "dimensions in operand " << iop << " for collapsing index '" << label
324  << "' don't match (" << GetConstInt((*newshape)[i]) << " != " << shape[idim] << ")";
325  newshape->Set(i, shape[idim]);
326  newstride->Set(i, (*newstride)[i] + stride[idim]);
327  }
328 
329  /* If the label didn't say to combine axes, increment dest i */
330  if (combineoffset == 0) {
331  icombine++;
332  }
333  }
334 }
335 
/*!
 * \brief Map every iterator dimension to the operand axis that feeds it.
 *
 * Iterator dimensions are visited from last to first. A labeled dimension maps to
 * the first operand axis carrying the same label; an unlabeled (0) dimension takes
 * the next unlabeled operand axis, also scanning from the back. Dimensions with no
 * matching operand axis get -1.
 *
 * \param ndim Number of dimensions of the operand.
 * \param iop Operand index (not used by this function).
 * \param labels The `ndim` axis labels of the operand.
 * \param axes Output: `ndim_iter` entries, each an operand axis index or -1.
 * \param ndim_iter Number of iterator dimensions.
 * \param iter_labels The `ndim_iter` labels of the iterator.
 * \return Always 0.
 */
inline static int PrepareOpAxes(int ndim, int iop, char* labels, int* axes, int ndim_iter,
                                char* iter_labels) {
  int next_broadcast = ndim - 1;

  for (int pos = ndim_iter; pos-- > 0;) {
    const int wanted = iter_labels[pos];

    if (wanted != 0) {
      /* Labeled dimension: use the operand axis with that label, if there is one. */
      const void* hit = memchr(labels, wanted, ndim);
      if (hit != nullptr) {
        axes[pos] = static_cast<const char*>(hit) - labels;
      } else {
        axes[pos] = -1;
      }
      continue;
    }

    /* Unlabeled dimension: skip back over labeled operand axes. */
    while (next_broadcast >= 0 && labels[next_broadcast] != 0) {
      --next_broadcast;
    }
    if (next_broadcast >= 0) {
      axes[pos] = next_broadcast;
      --next_broadcast;
    } else {
      /* The operand ran out of broadcast axes; treat this one as a new axis. */
      axes[pos] = -1;
    }
  }
  return 0;
}
386 
/*!
 * \brief Count the non-overlapping occurrences of `sub` in `str`.
 *
 * \param str The string to search.
 * \param sub The substring to look for.
 * \return The number of non-overlapping matches, scanning left to right.
 *         An empty `sub` yields 0.
 */
inline int CountSubstring(const std::string& str, const std::string& sub) {
  // An empty pattern matches at every position without advancing the cursor,
  // so the search loop below would never terminate.
  if (sub.empty()) {
    return 0;
  }
  int count = 0;
  for (std::string::size_type pos = str.find(sub); pos != std::string::npos;
       pos = str.find(sub, pos + sub.length())) {
    ++count;
  }
  return count;
}
403 
410 inline std::bitset<LABELRANGE> Str2Set(const std::string& str) {
411  std::bitset<LABELRANGE> ret;
412  for (const char& c : str) {
413  ret.set(static_cast<int>(c));
414  }
415  return ret;
416 }
417 
/*!
 * \brief Split `str` at every occurrence of the separator `sub`.
 *
 * \param str The string to split.
 * \param sub The separator.
 * \return The pieces between separators, in order; empty pieces are kept, so the
 *         result always has (number of separators + 1) entries. An empty
 *         separator never matches, so the result is then `str` as a single piece.
 */
inline std::vector<std::string> Split(const std::string& str, const std::string& sub) {
  std::vector<std::string> pieces;
  // An empty separator matches at every position without advancing the cursor,
  // which would push empty pieces forever.
  if (sub.empty()) {
    pieces.push_back(str);
    return pieces;
  }
  std::string::size_type start = 0;
  for (std::string::size_type pos = str.find(sub, start); pos != std::string::npos;
       pos = str.find(sub, start)) {
    pieces.push_back(str.substr(start, pos - start));
    start = pos + sub.length();
  }
  pieces.push_back(str.substr(start));
  return pieces;
}
436 
446 inline std::tuple<std::string, std::string> ParseEinsumInput(
447  std::string subscripts, const std::vector<Array<PrimExpr>>& operands) {
448  const std::string einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
449  std::bitset<LABELRANGE> einsum_symbols_set;
450  for (const char& c : einsum_symbols) {
451  einsum_symbols_set.set(c);
452  }
453 
454  CHECK_NE(operands.size(), 0U) << "No input operands";
455 
456  auto end_pos = std::remove(subscripts.begin(), subscripts.end(), ' ');
457  subscripts.erase(end_pos, subscripts.end());
458 
459  // Ensure all characters are valid
460  for (const char& c : subscripts) {
461  if (c == '.' || c == ',' || c == '-' || c == '>') {
462  continue;
463  }
464  CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
465  }
466 
467  // Check for proper "->"
468  if (subscripts.find('-') != std::string::npos || subscripts.find('>') != std::string::npos) {
469  bool invalid = (std::count(subscripts.begin(), subscripts.end(), '-') > 1 ||
470  std::count(subscripts.begin(), subscripts.end(), '>') > 1);
471  CHECK(!invalid && CountSubstring(subscripts, "->") == 1)
472  << "Subscripts can only contain one '->'.";
473  }
474 
475  // Parse ellipses
476  if (subscripts.find('.') != std::string::npos) {
477  std::string used = subscripts;
478  used.erase(
479  std::remove_if(used.begin(), used.end(),
480  [](const char& c) { return c == '.' || c == ',' || c == '-' || c == '>'; }),
481  used.end());
482 
483  std::bitset<LABELRANGE> used_set = Str2Set(used);
484  std::string ellipse_inds = "";
485  for (const char& c : einsum_symbols) {
486  if (!used_set.test(static_cast<int>(c))) {
487  ellipse_inds.append(1, c);
488  }
489  }
490  int longest = 0;
491  std::string input_tmp, output_sub;
492  std::vector<std::string> split_subscripts;
493  bool out_sub;
494 
495  if (subscripts.find("->") != std::string::npos) {
496  std::vector<std::string> tmp = Split(subscripts, "->");
497  input_tmp = tmp[0];
498  output_sub = tmp[1];
499  split_subscripts = Split(input_tmp, ",");
500  out_sub = true;
501  } else {
502  split_subscripts = Split(subscripts, ",");
503  out_sub = false;
504  }
505 
506  size_t size_split_subscripts = split_subscripts.size();
507  subscripts = "";
508  for (size_t i = 0; i < size_split_subscripts; ++i) {
509  const std::string& sub = split_subscripts[i];
510  if (sub.find('.') != std::string::npos) {
511  CHECK_EQ(std::count(sub.begin(), sub.end(), '.'), 3) << "Invalid Ellipses";
512  CHECK_EQ(CountSubstring(sub, "..."), 1) << "Invalid Ellipses";
513 
514  // Take into account numerical values
515  int ellipse_count = 0;
516  if (operands[i].size() == 0) {
517  ellipse_count = 0;
518  } else {
519  ellipse_count = std::max(operands[i].size(), static_cast<size_t>(1));
520  ellipse_count -= sub.length() - 3;
521  }
522 
523  if (ellipse_count > longest) {
524  longest = ellipse_count;
525  }
526 
527  CHECK_GE(ellipse_count, 0) << "Ellipses lengths do not match.";
528  if (ellipse_count == 0) {
529  split_subscripts[i].erase(sub.find("..."), 3);
530  } else {
531  std::string rep_inds = ellipse_inds.substr(ellipse_inds.length() - ellipse_count);
532  split_subscripts[i].replace(sub.find("..."), 3, rep_inds);
533  }
534  }
535  subscripts += split_subscripts[i];
536  if (i + 1 < size_split_subscripts) {
537  subscripts += ",";
538  }
539  }
540  std::string out_ellipse;
541  if (longest == 0) {
542  out_ellipse = "";
543  } else {
544  out_ellipse = ellipse_inds.substr(ellipse_inds.length() - longest);
545  }
546 
547  if (out_sub) {
548  output_sub.replace(output_sub.find("..."), 3, out_ellipse);
549  subscripts += "->" + output_sub;
550  } else {
551  // Special care for outputless ellipses
552  std::bitset<LABELRANGE> out_ellipse_set = Str2Set(out_ellipse);
553  std::string tmp_subscripts = subscripts, output_subscript = "";
554  size_t len_tmp_subscripts = tmp_subscripts.length();
555  std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
556  for (size_t i = 0; i < len_tmp_subscripts; ++i) {
557  const char& c = tmp_subscripts[i];
558  if (c == ',') {
559  continue;
560  }
561  CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
562  if ((i == 0 || tmp_subscripts[i - 1] != c) &&
563  (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c) &&
564  !out_ellipse_set.test(c)) {
565  output_subscript.append(1, c);
566  }
567  }
568  subscripts += "->" + out_ellipse + output_subscript;
569  }
570  }
571 
572  // Build output string if does not exist
573  std::tuple<std::string, std::string> ret;
574  if (subscripts.find("->") != std::string::npos) {
575  std::vector<std::string> tmp(2);
576  tmp = Split(subscripts, "->");
577  ret = std::make_tuple(tmp[0], tmp[1]);
578  } else {
579  std::string first = subscripts;
580  std::string second = "";
581  // Build output subscripts
582  std::string tmp_subscripts = subscripts;
583  size_t len_tmp_subscripts = tmp_subscripts.length();
584  std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
585  for (size_t i = 0; i < len_tmp_subscripts; ++i) {
586  const char& c = tmp_subscripts[i];
587  if (c == ',') {
588  continue;
589  }
590  CHECK(einsum_symbols_set.test(c)) << "Character " << c << " is not a valid symbol.";
591  if ((i == 0 || tmp_subscripts[i - 1] != c) &&
592  (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) {
593  second.append(1, c);
594  }
595  }
596  ret = std::make_tuple(first, second);
597  }
598 
599  // Make sure output subscripts are in the input
600  std::bitset<LABELRANGE> input_subscripts_set = Str2Set(std::get<0>(ret));
601  for (const char& c : std::get<1>(ret)) {
602  CHECK(input_subscripts_set.test(c))
603  << "Output character " << c << " did not appear in the input";
604  }
605 
606  // Make sure number operands is equivalent to the number of terms
607  CHECK_EQ(std::count(std::get<0>(ret).begin(), std::get<0>(ret).end(), ',') + 1, operands.size())
608  << "Number of einsum subscripts must be equal to the "
609  << "number of operands.";
610 
611  return ret;
612 }
613 
621 inline Array<PrimExpr> NumpyEinsumShape(const std::string subscripts,
622  const std::vector<Array<PrimExpr>>& operands) {
623  // Parsing
624  std::tuple<std::string, std::string> parsed_subscripts = ParseEinsumInput(subscripts, operands);
625 
626  // Build a few useful list and sets
627  std::vector<std::string> input_list = Split(std::get<0>(parsed_subscripts), ",");
628  size_t isize = input_list.size();
629 
630  // Get length of each unique dimension and ensure all dimensions are correct
631  int dimension_dict[LABELRANGE];
632  memset(dimension_dict, -1, sizeof(dimension_dict));
633  for (size_t i = 0; i < isize; ++i) {
634  const std::string& term = input_list[i];
635  const Array<PrimExpr>& sh = operands[i];
636  CHECK_EQ(sh.size(), term.length())
637  << "Einstein sum subscript " << input_list[i] << " does not contain the "
638  << "correct number of indices for operand " << i << ".";
639  size_t len_term = term.length();
640  for (size_t j = 0; j < len_term; ++j) {
641  int64_t dim = GetConstInt(sh[j]);
642  const char& c = term[j];
643 
644  if (dimension_dict[static_cast<int>(c)] != -1) {
645  // For broadcasting cases we always want the largest dim size
646  if (dimension_dict[static_cast<int>(c)] == 1) {
647  dimension_dict[static_cast<int>(c)] = dim;
648  }
649  CHECK(dim == 1 || dim == dimension_dict[static_cast<int>(c)])
650  << "Size of label '" << c << "' for operand " << i << " ("
651  << dimension_dict[static_cast<int>(c)] << ") does not match previous terms (" << dim
652  << ").";
653  } else {
654  dimension_dict[static_cast<int>(c)] = dim;
655  }
656  }
657  }
658 
659  // Get oshape
660  const std::string& output_str = std::get<1>(parsed_subscripts);
661  size_t odim = output_str.size();
662  Array<PrimExpr> oshape(odim, -1);
663  for (size_t i = 0; i < odim; ++i) {
664  oshape.Set(i, dimension_dict[static_cast<int>(output_str[i])]);
665  }
666  // Neglecting oshape assign check temporally
667  return oshape;
668 }
669 
681 inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs,
682  std::string name = "T_einsum", std::string tag = kEinsum) {
683  bool back = false;
684  const char* subscripts = subscripts_str.data();
685  const char* head = subscripts;
686  const int nop = inputs.size();
687 
688  /* Step 1: Parse the subscripts string into label_counts and op_labels */
689  int iop, idim, min_label = LABELRANGE - 1, max_label = 0;
690  char label_counts[LABELRANGE], op_labels[NPY_MAXARGS][NPY_MAXDIMS];
691  memset(label_counts, 0, sizeof(label_counts));
692  for (iop = 0; iop < nop; ++iop) {
693  int length = static_cast<int>(strcspn(subscripts, ",-"));
694 
695  CHECK(!(iop == nop - 1 && subscripts[length] == ','))
696  << "more operands provided to einstein sum function "
697  << "than specified in the subscripts string";
698  CHECK(!(iop < nop - 1 && subscripts[length] != ','))
699  << "fewer operands provided to einstein sum function "
700  << "than specified in the subscripts string";
701  CHECK_EQ(ParseOperandSubscripts(subscripts, length, inputs[iop + back].ndim(), iop,
702  op_labels[iop], label_counts, &min_label, &max_label),
703  0);
704 
705  /* Move subscripts to the start of the labels for the next op */
706  subscripts += length;
707 
708  if (iop < nop - 1) {
709  CHECK_LT(subscripts - head, subscripts_str.length()) << "subscripts out of range";
710  subscripts++;
711  }
712  }
713  /*
714  * Find the number of broadcast dimensions, which is the maximum
715  * number of labels == 0 in an op_labels array.
716  */
717  int ndim_broadcast = 0;
718  for (iop = 0; iop < nop; ++iop) {
719  int count_zeros = 0;
720  int ndim;
721  char* labels = op_labels[iop];
722 
723  ndim = inputs[iop + back].ndim();
724  for (idim = 0; idim < ndim; ++idim) {
725  if (labels[idim] == 0) {
726  ++count_zeros;
727  }
728  }
729 
730  if (count_zeros > ndim_broadcast) {
731  ndim_broadcast = count_zeros;
732  }
733  }
734 
735  /*
736  * If there is no output signature, fill output_labels and ndim_output
737  * using each label that appeared once, in alphabetical order.
738  */
739  int label, ndim_output;
740  char output_labels[NPY_MAXDIMS];
741  if (subscripts[0] == '\0') {
742  /* If no output was specified, always broadcast left, as usual. */
743  for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
744  output_labels[ndim_output] = 0;
745  }
746  for (label = min_label; label <= max_label; ++label) {
747  if (label_counts[label] == 1) {
748  CHECK(ndim_output < NPY_MAXDIMS) << "einstein sum subscript string has too many "
749  << "distinct labels";
750  output_labels[ndim_output++] = label;
751  }
752  }
753  } else {
754  CHECK(subscripts[0] == '-' && subscripts[1] == '>') << "einstein sum subscript string does not "
755  << "contain proper '->' output specified";
756  subscripts += 2;
757 
758  /* Parse the output subscript string. */
759  ndim_output = ParseOutputSubscripts(subscripts, strlen(subscripts), ndim_broadcast,
760  label_counts, output_labels);
761  CHECK_GE(ndim_output, 0);
762  }
763 
764  /*
765  * Step 2:
766  * Process all the input ops, combining dimensions into their
767  * diagonal where specified.
768  */
769  std::vector<Array<PrimExpr>> opshape(nop), opstride_true(nop);
770  for (iop = 0; iop < nop; ++iop) {
771  char* labels = op_labels[iop];
772  int combine, ndim;
773 
774  ndim = inputs[iop + back].ndim();
775 
776  /*
777  * Check whether any dimensions need to be combined
778  *
779  * The char type may be either signed or unsigned, we
780  * need it to be signed here.
781  */
782  combine = 0;
783  for (idim = 0; idim < ndim; ++idim) {
784  if ((signed char)labels[idim] < 0) {
785  combine++;
786  }
787  }
788  /* If any dimensions are combined, create a view which combines them */
789  if (combine) {
790  Array<PrimExpr> tshape(static_cast<size_t>(ndim - combine), -1);
791  Array<PrimExpr> tstride(static_cast<size_t>(ndim - combine), -1);
792  GetCombinedDimsView(inputs[iop + back], iop, labels, &tshape, &tstride);
793  opshape[iop] = tshape;
794  opstride_true[iop] = tstride;
795  } else {
796  /* No combining needed */
797  opshape[iop] = inputs[iop + back]->shape;
798  opstride_true[iop] = GetStride(opshape[iop]);
799  }
800  }
801  /*
802  * Step 3:
803  * Set up the labels for the iterator (output + combined labels).
804  * Can just share the output_labels memory, because iter_labels
805  * is output_labels with some more labels appended.
806  */
807  char* iter_labels = output_labels;
808  int ndim_iter = ndim_output;
809  for (label = min_label; label <= max_label; ++label) {
810  if (label_counts[label] > 0 && memchr(output_labels, label, ndim_output) == nullptr) {
811  CHECK(ndim_iter < NPY_MAXDIMS) << "too many subscripts in einsum";
812  iter_labels[ndim_iter++] = label;
813  }
814  }
815  /* Step 4: Set up the op_axes for the iterator */
816  Array<PrimExpr> itershape(static_cast<size_t>(ndim_iter), -1);
817  std::vector<Array<PrimExpr>> iterstride(nop + 1,
818  Array<PrimExpr>(static_cast<size_t>(ndim_iter), 0));
819 
820  // output_shape
821  std::vector<Array<PrimExpr>> operands;
822  for (size_t i = 0; i < inputs.size(); i++) {
823  operands.push_back(inputs[i]->shape);
824  }
825  Array<PrimExpr> oshape = NumpyEinsumShape(subscripts_str, operands);
826  Array<PrimExpr> ostride_true = GetStride(oshape);
827  Array<PrimExpr> reduceshape;
828  std::vector<Array<PrimExpr>> remainshape(nop);
829  int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
830  int* op_axes[NPY_MAXARGS];
831  for (iop = 0; iop < nop; ++iop) {
832  op_axes[iop] = op_axes_arrays[iop];
833  CHECK_GE(PrepareOpAxes(opshape[iop].size(), iop, op_labels[iop], op_axes[iop], ndim_iter,
834  iter_labels),
835  0);
836  for (idim = 0; idim < ndim_iter; idim++) {
837  if (op_axes[iop][idim] != -1) {
838  iterstride[iop].Set(idim, opstride_true[iop][op_axes[iop][idim]]);
839  if (GetConstInt(itershape[idim]) != -1) {
840  if (GetConstInt(itershape[idim]) == 1) {
841  itershape.Set(idim, opshape[iop][op_axes[iop][idim]]);
842  }
843  } else {
844  itershape.Set(idim, opshape[iop][op_axes[iop][idim]]);
845  }
846  }
847  }
848  }
849  for (idim = 0; idim < ndim_output; ++idim) {
850  iterstride[nop].Set(idim, ostride_true[idim]);
851  }
852  reduceshape = Array<PrimExpr>(static_cast<size_t>(ndim_iter - ndim_output), 0);
853  for (idim = ndim_output; idim < ndim_iter; ++idim) {
854  reduceshape.Set(idim - ndim_output, itershape[idim]);
855  }
856  for (iop = 0; iop < nop; iop++) {
857  Array<Integer> rsh;
858  for (idim = 0; idim < ndim_iter; idim++) {
859  if (op_axes_arrays[iop][idim] == -1) {
860  rsh.push_back(GetConstInt(itershape[idim]));
861  } else {
862  if (GetConstInt(itershape[idim] != opshape[iop][op_axes_arrays[iop][idim]])) {
863  rsh.push_back(GetConstInt(itershape[idim]));
864  }
865  }
866  }
867  remainshape[iop] = Array<PrimExpr>(rsh.begin(), rsh.end());
868  }
869  // exclude the 0-dim case
870  if (ndim_iter == 0) {
871  ndim_iter = 1;
872  }
873  itershape = Pad(itershape, ndim_iter);
874  for (iop = 0; iop <= nop; ++iop) {
875  iterstride[iop] = Pad(iterstride[iop], ndim_iter);
876  }
877  // oshape = Pad(oshape, ndim_iter);
878  reduceshape = Pad(reduceshape, ndim_iter);
879  for (iop = 0; iop < nop; ++iop) {
880  opshape[iop] = Pad(opshape[iop], ndim_iter);
881  remainshape[iop] = Pad(remainshape[iop], ndim_iter);
882  }
883  // ostride and rstride
884  Array<Array<PrimExpr>> ostride;
885  Array<Array<PrimExpr>> rstride;
886 
887  for (iop = 0; iop < nop; ++iop) {
888  Array<PrimExpr> otmp(static_cast<size_t>(ndim_iter), 0);
889  Array<PrimExpr> rtmp(static_cast<size_t>(ndim_iter), 0);
890  for (idim = 0; idim < ndim_iter; ++idim) {
891  otmp.Set(idim, idim < ndim_output ? iterstride[iop][idim] : 1);
892  rtmp.Set(idim, idim < ndim_iter - ndim_output ? iterstride[iop][idim + ndim_output] : 1);
893  }
894  ostride.push_back(otmp);
895  rstride.push_back(rtmp);
896  }
897 
898  // func: input indices => return cooresponding value
899  auto func = [inputs, oshape, ostride, reduceshape, ndim_iter, rstride,
900  nop](const Array<Var>& input_indices) -> PrimExpr {
901  for (int rdim = 0; rdim < ndim_iter; ++rdim) {
902  if (GetConstInt(reduceshape[rdim]) == 0) {
903  return 0; //
904  }
905  }
906  Array<PrimExpr> ridx = UnravelIndex(0, reduceshape);
907 
908  PrimExpr sum = 0;
909  bool rec_flag = false;
910  do {
911  PrimExpr tmp = 1;
912  for (int iop = 0; iop < nop; ++iop) {
913  if (iop != -1) {
914  PrimExpr k = 0;
915 
916  for (size_t i = 0; i < input_indices.size(); ++i) {
917  k += input_indices[i] * ostride[iop][i];
918  }
919  for (size_t i = 0; i < ridx.size(); ++i) {
920  k += ridx[i] * rstride[iop][i];
921  }
922  Array<PrimExpr> temp_indices = UnravelIndex(k, inputs[iop]->shape);
923  tmp = tmp * inputs[iop](temp_indices);
924  }
925  }
926  sum += tmp;
927  ridx.Set(ridx.size() - 1, ridx[ridx.size() - 1] + 1);
928  for (int i = static_cast<int>(ridx.size() - 1);
929  (i > 0) && GetConstInt(ridx[i] >= reduceshape[i]); --i) {
930  ridx.Set(i, ridx[i] - reduceshape[i]);
931  ridx.Set(i - 1, ridx[i - 1] + 1);
932  }
933  rec_flag = GetConstInt(ridx[0] < reduceshape[0]);
934  } while (rec_flag);
935  return sum;
936  };
937 
938  return compute(oshape, func, name, tag);
939 }
940 
941 } // namespace topi
942 } // namespace tvm
943 #endif // TVM_TOPI_EINSUM_H_
#define NPY_MAXDIMS
Definition: einsum.h:28
constexpr auto kEinsum
Definition: tags.h:44
int ParseOutputSubscripts(const char *subscripts, int length, int ndim_broadcast, const char *label_counts, char *out_labels)
Parse the subscripts for the output into an output that includes 'ndim_broadcast' unlabeled dimension...
Definition: einsum.h:204
Array< PrimExpr > Pad(const Array< PrimExpr > shape, int odim)
Pad the shape with 1.
Definition: einsum.h:78
Tensor einsum(const std::string &subscripts_str, const Array< Tensor > inputs, std::string name="T_einsum", std::string tag=kEinsum)
Evaluates the Einstein summation convention on the operands.
Definition: einsum.h:681
std::bitset< 128 > Str2Set(const std::string &str)
Transfer string to.
Definition: einsum.h:410
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr sub(PrimExpr a, PrimExpr b, Span span=Span())
subtraction operator
Tensor expression language DSL.
Definition: extracted_task.h:33
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
Array< PrimExpr > NumpyEinsumShape(const std::string subscripts, const std::vector< Array< PrimExpr >> &operands)
Compute the shape of the output.
Definition: einsum.h:621
size_t ndim() const
Definition: tensor.h:214
void Set(int64_t i, T value)
set i-th element of the array.
Definition: array.h:567
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:436
Tensor sum(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:326
Utility functions for handling constants in TVM expressions.
size_t size() const
Definition: array.h:399
Utility functions for handling tensor.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
void GetCombinedDimsView(const Tensor &op, int iop, char *labels, Array< PrimExpr > *newshape, Array< PrimExpr > *newstride)
If any dimensions are combined, create a view that combines them. Shows in newshape and newstride...
Definition: einsum.h:278
int ParseOperandSubscripts(const char *subscripts, int length, int ndim, int iop, char *op_labels, char *label_counts, int *min_label, int *max_label)
Parse the subscripts for one operand into an output of 'ndim' labels.
Definition: einsum.h:105
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1758
iterator end() const
Definition: array.h:369
std::tuple< std::string, std::string > ParseEinsumInput(std::string subscripts, const std::vector< Array< PrimExpr >> &operands)
Parse the input subscripts into a vector of strings.
Definition: einsum.h:446
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
iterator begin() const
Definition: array.h:366
Operation node can generate one or multiple Tensors.
#define NPY_MAXARGS
Definition: einsum.h:29
int CountSubstring(const std::string &str, const std::string &sub)
Count SubString.
Definition: einsum.h:394
Tensor prod(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates product operation over given axis.
Definition: reduction.h:568
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
External function interface to rocBLAS libraries.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Array< PrimExpr > GetStride(const Array< PrimExpr > shape)
Compute the stride of the given shape.
Definition: einsum.h:59
Reference to PrimExprNode.
Definition: expr.h:112
Layout expression to describe the data organization of a tensor. And BijectiveLayout to mapping two d...
std::vector< std::string > Split(const std::string &str, const std::string &sub)
Split str according to substring.
Definition: einsum.h:425
Index ravel and unravel operations.
#define LABELRANGE
Definition: einsum.h:27