24 #ifndef TVM_TOPI_EINSUM_H_ 25 #define TVM_TOPI_EINSUM_H_ 27 #define LABELRANGE 128 28 #define NPY_MAXDIMS 16 29 #define NPY_MAXARGS 16 43 #include <unordered_set> 50 using namespace topi::detail;
60 size_t ndim = shape.
size();
63 for (
int i = ndim - 1; i >= 0; i--) {
65 prod = prod * GetConstInt(shape[i]);
79 int ndim = shape.
size();
82 for (
int idim = 0; idim < ndim; ++idim) {
83 ret.
Set(idim, shape[idim]);
106 char* op_labels,
char* label_counts,
int* min_label,
113 for (i = 0; i < length; ++i) {
114 int label = subscripts[i];
117 if (label > 0 && isalpha(label)) {
119 CHECK(idim < ndim) <<
"einstein sum subscripts string contains " 120 <<
"too many subscripts for operand " << iop;
122 op_labels[idim++] = label;
123 if (label < *min_label) {
126 if (label > *max_label) {
129 label_counts[label]++;
130 }
else if (label ==
'.') {
134 !(ellipsis != -1 || i + 2 >= length || subscripts[++i] !=
'.' || subscripts[++i] !=
'.'))
135 <<
"einstein sum subscripts string contains a " 136 <<
"'.' that is not part of an ellipsis ('...') " 137 <<
"in operand " << iop;
141 CHECK(label ==
' ') <<
"invalid subscript '" <<
static_cast<char>(label)
142 <<
"' in einstein sum " 143 <<
"subscripts string, subscripts must " 149 if (ellipsis == -1) {
150 CHECK(idim == ndim) <<
"operand has more dimensions than subscripts " 151 <<
"given in einstein sum, but no '...' ellipsis " 152 <<
"provided to broadcast the extra dimensions.";
153 }
else if (idim < ndim) {
156 for (i = 0; i < idim - ellipsis; ++i) {
157 op_labels[ndim - i - 1] = op_labels[idim - i - 1];
160 for (i = 0; i < ndim - idim; ++i) {
161 op_labels[ellipsis + i] = 0;
173 for (idim = 0; idim < ndim - 1; ++idim) {
174 int label = op_labels[idim];
178 char* next =
reinterpret_cast<char*
>(memchr(op_labels + idim + 1, label, ndim - idim - 1));
180 while (next !=
nullptr) {
182 *next =
static_cast<char>((op_labels + idim) - next);
184 next =
reinterpret_cast<char*
>(memchr(next + 1, label, op_labels + ndim - 1 - next));
205 const char* label_counts,
char* out_labels) {
211 for (i = 0; i < length; ++i) {
212 int label = subscripts[i];
215 if (label > 0 && isalpha(label)) {
217 CHECK(memchr(subscripts + i + 1, label, length - i - 1) ==
nullptr)
218 <<
"einstein sum subscripts string includes " 219 <<
"output subscript '" <<
static_cast<char>(label) <<
"' multiple times";
222 CHECK(label_counts[label] != 0)
223 <<
"einstein sum subscripts string included " 224 <<
"output subscript '" <<
static_cast<char>(label) <<
"' which never appeared " 228 CHECK(ndim <
NPY_MAXDIMS) <<
"einstein sum subscripts string contains " 229 <<
"too many subscripts in the output";
231 out_labels[ndim++] = label;
232 }
else if (label ==
'.') {
235 CHECK(!(ellipsis || i + 2 >= length || subscripts[++i] !=
'.' || subscripts[++i] !=
'.'))
236 <<
"einstein sum subscripts string " 237 <<
"contains a '.' that is not part of " 238 <<
"an ellipsis ('...') in the output";
241 CHECK(ndim + ndim_broadcast <=
NPY_MAXDIMS) <<
"einstein sum subscripts string contains " 242 <<
"too many subscripts in the output";
245 for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
246 out_labels[ndim++] = 0;
249 CHECK(label ==
' ') <<
"invalid subscript '" <<
static_cast<char>(label)
250 <<
"' in einstein sum " 251 <<
"subscripts string, subscripts must " 257 CHECK(!(!ellipsis && ndim_broadcast > 0)) <<
"output has more dimensions than subscripts " 258 <<
"given in einstein sum, but no '...' ellipsis " 259 <<
"provided to broadcast the extra dimensions.";
280 int idim, ndim, icombine, combineoffset;
287 newdim = newshape->
size();
290 for (idim = 0; idim < newdim; ++idim) {
291 newshape->
Set(idim, 0);
292 newstride->
Set(idim, 0);
297 for (idim = 0; idim < ndim; ++idim) {
302 int label = (
signed char)labels[idim];
305 combineoffset = label;
306 label = labels[idim + label];
309 if (icombine != idim) {
310 labels[icombine] = labels[idim];
312 icombinemap[idim] = icombine;
316 newshape->
Set(icombine, shape[idim]);
317 newstride->
Set(icombine, stride[idim]);
320 int i = icombinemap[idim + combineoffset];
321 CHECK(!((combineoffset < 0) &&
322 GetConstInt((*newshape)[i] != 0 && (*newshape)[i] != shape[idim])))
323 <<
"dimensions in operand " << iop <<
" for collapsing index '" << label
324 <<
"' don't match (" << GetConstInt((*newshape)[i]) <<
" != " << shape[idim] <<
")";
325 newshape->
Set(i, shape[idim]);
326 newstride->
Set(i, (*newstride)[i] + stride[idim]);
330 if (combineoffset == 0) {
346 inline static int PrepareOpAxes(
int ndim,
int iop,
char* labels,
int* axes,
int ndim_iter,
348 int i, label, ibroadcast;
350 ibroadcast = ndim - 1;
351 for (i = ndim_iter - 1; i >= 0; --i) {
352 label = iter_labels[i];
358 while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
365 if (ibroadcast < 0) {
369 axes[i] = ibroadcast;
374 char* match =
reinterpret_cast<char*
>(memchr(labels, label, ndim));
376 if (match ==
nullptr) {
380 axes[i] = match - labels;
396 std::string::size_type pos = 0;
397 while ((pos = str.find(sub, pos)) != std::string::npos) {
410 inline std::bitset<LABELRANGE>
Str2Set(
const std::string& str) {
411 std::bitset<LABELRANGE>
ret;
412 for (
const char& c : str) {
413 ret.set(static_cast<int>(c));
425 inline std::vector<std::string>
Split(
const std::string& str,
const std::string&
sub) {
426 std::string::size_type pos = 0;
427 std::string::size_type start = 0;
428 std::vector<std::string>
ret;
429 while ((pos = str.find(sub, start)) != std::string::npos) {
430 ret.push_back(str.substr(start, pos - start));
431 start = pos + sub.length();
433 ret.push_back(str.substr(start));
447 std::string subscripts,
const std::vector<
Array<PrimExpr>>& operands) {
448 const std::string einsum_symbols =
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
449 std::bitset<LABELRANGE> einsum_symbols_set;
450 for (
const char& c : einsum_symbols) {
451 einsum_symbols_set.set(c);
454 CHECK_NE(operands.size(), 0U) <<
"No input operands";
456 auto end_pos = std::remove(subscripts.begin(), subscripts.end(),
' ');
457 subscripts.erase(end_pos, subscripts.end());
460 for (
const char& c : subscripts) {
461 if (c ==
'.' || c ==
',' || c ==
'-' || c ==
'>') {
464 CHECK(einsum_symbols_set.test(c)) <<
"Character " << c <<
" is not a valid symbol.";
468 if (subscripts.find(
'-') != std::string::npos || subscripts.find(
'>') != std::string::npos) {
469 bool invalid = (std::count(subscripts.begin(), subscripts.end(),
'-') > 1 ||
470 std::count(subscripts.begin(), subscripts.end(),
'>') > 1);
472 <<
"Subscripts can only contain one '->'.";
476 if (subscripts.find(
'.') != std::string::npos) {
477 std::string used = subscripts;
479 std::remove_if(used.begin(), used.end(),
480 [](
const char& c) {
return c ==
'.' || c ==
',' || c ==
'-' || c ==
'>'; }),
483 std::bitset<LABELRANGE> used_set =
Str2Set(used);
484 std::string ellipse_inds =
"";
485 for (
const char& c : einsum_symbols) {
486 if (!used_set.test(static_cast<int>(c))) {
487 ellipse_inds.append(1, c);
491 std::string input_tmp, output_sub;
492 std::vector<std::string> split_subscripts;
495 if (subscripts.find(
"->") != std::string::npos) {
496 std::vector<std::string> tmp =
Split(subscripts,
"->");
499 split_subscripts =
Split(input_tmp,
",");
502 split_subscripts =
Split(subscripts,
",");
506 size_t size_split_subscripts = split_subscripts.size();
508 for (
size_t i = 0; i < size_split_subscripts; ++i) {
509 const std::string&
sub = split_subscripts[i];
510 if (sub.find(
'.') != std::string::npos) {
511 CHECK_EQ(std::count(sub.begin(), sub.end(),
'.'), 3) <<
"Invalid Ellipses";
515 int ellipse_count = 0;
516 if (operands[i].size() == 0) {
519 ellipse_count =
std::max(operands[i].size(), static_cast<size_t>(1));
520 ellipse_count -= sub.length() - 3;
523 if (ellipse_count > longest) {
524 longest = ellipse_count;
527 CHECK_GE(ellipse_count, 0) <<
"Ellipses lengths do not match.";
528 if (ellipse_count == 0) {
529 split_subscripts[i].erase(sub.find(
"..."), 3);
531 std::string rep_inds = ellipse_inds.substr(ellipse_inds.length() - ellipse_count);
532 split_subscripts[i].replace(sub.find(
"..."), 3, rep_inds);
535 subscripts += split_subscripts[i];
536 if (i + 1 < size_split_subscripts) {
540 std::string out_ellipse;
544 out_ellipse = ellipse_inds.substr(ellipse_inds.length() - longest);
548 output_sub.replace(output_sub.find(
"..."), 3, out_ellipse);
549 subscripts +=
"->" + output_sub;
552 std::bitset<LABELRANGE> out_ellipse_set =
Str2Set(out_ellipse);
553 std::string tmp_subscripts = subscripts, output_subscript =
"";
554 size_t len_tmp_subscripts = tmp_subscripts.length();
555 std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
556 for (
size_t i = 0; i < len_tmp_subscripts; ++i) {
557 const char& c = tmp_subscripts[i];
561 CHECK(einsum_symbols_set.test(c)) <<
"Character " << c <<
" is not a valid symbol.";
562 if ((i == 0 || tmp_subscripts[i - 1] != c) &&
563 (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c) &&
564 !out_ellipse_set.test(c)) {
565 output_subscript.append(1, c);
568 subscripts +=
"->" + out_ellipse + output_subscript;
573 std::tuple<std::string, std::string>
ret;
574 if (subscripts.find(
"->") != std::string::npos) {
575 std::vector<std::string> tmp(2);
576 tmp =
Split(subscripts,
"->");
577 ret = std::make_tuple(tmp[0], tmp[1]);
579 std::string first = subscripts;
580 std::string second =
"";
582 std::string tmp_subscripts = subscripts;
583 size_t len_tmp_subscripts = tmp_subscripts.length();
584 std::sort(tmp_subscripts.begin(), tmp_subscripts.end());
585 for (
size_t i = 0; i < len_tmp_subscripts; ++i) {
586 const char& c = tmp_subscripts[i];
590 CHECK(einsum_symbols_set.test(c)) <<
"Character " << c <<
" is not a valid symbol.";
591 if ((i == 0 || tmp_subscripts[i - 1] != c) &&
592 (i == len_tmp_subscripts - 1 || tmp_subscripts[i + 1] != c)) {
596 ret = std::make_tuple(first, second);
600 std::bitset<LABELRANGE> input_subscripts_set =
Str2Set(std::get<0>(ret));
601 for (
const char& c : std::get<1>(ret)) {
602 CHECK(input_subscripts_set.test(c))
603 <<
"Output character " << c <<
" did not appear in the input";
607 CHECK_EQ(std::count(std::get<0>(ret).begin(), std::get<0>(ret).end(),
',') + 1, operands.size())
608 <<
"Number of einsum subscripts must be equal to the " 609 <<
"number of operands.";
624 std::tuple<std::string, std::string> parsed_subscripts =
ParseEinsumInput(subscripts, operands);
627 std::vector<std::string> input_list =
Split(std::get<0>(parsed_subscripts),
",");
628 size_t isize = input_list.size();
632 memset(dimension_dict, -1,
sizeof(dimension_dict));
633 for (
size_t i = 0; i < isize; ++i) {
634 const std::string& term = input_list[i];
636 CHECK_EQ(sh.
size(), term.length())
637 <<
"Einstein sum subscript " << input_list[i] <<
" does not contain the " 638 <<
"correct number of indices for operand " << i <<
".";
639 size_t len_term = term.length();
640 for (
size_t j = 0; j < len_term; ++j) {
641 int64_t dim = GetConstInt(sh[j]);
642 const char& c = term[j];
644 if (dimension_dict[static_cast<int>(c)] != -1) {
646 if (dimension_dict[static_cast<int>(c)] == 1) {
647 dimension_dict[
static_cast<int>(c)] = dim;
649 CHECK(dim == 1 || dim == dimension_dict[static_cast<int>(c)])
650 <<
"Size of label '" << c <<
"' for operand " << i <<
" (" 651 << dimension_dict[
static_cast<int>(c)] <<
") does not match previous terms (" << dim
654 dimension_dict[
static_cast<int>(c)] = dim;
660 const std::string& output_str = std::get<1>(parsed_subscripts);
661 size_t odim = output_str.size();
663 for (
size_t i = 0; i < odim; ++i) {
664 oshape.
Set(i, dimension_dict[static_cast<int>(output_str[i])]);
682 std::string name =
"T_einsum", std::string tag =
kEinsum) {
684 const char* subscripts = subscripts_str.data();
685 const char* head = subscripts;
686 const int nop = inputs.
size();
689 int iop, idim, min_label =
LABELRANGE - 1, max_label = 0;
691 memset(label_counts, 0,
sizeof(label_counts));
692 for (iop = 0; iop < nop; ++iop) {
693 int length =
static_cast<int>(strcspn(subscripts,
",-"));
695 CHECK(!(iop == nop - 1 && subscripts[length] ==
','))
696 <<
"more operands provided to einstein sum function " 697 <<
"than specified in the subscripts string";
698 CHECK(!(iop < nop - 1 && subscripts[length] !=
','))
699 <<
"fewer operands provided to einstein sum function " 700 <<
"than specified in the subscripts string";
702 op_labels[iop], label_counts, &min_label, &max_label),
706 subscripts += length;
709 CHECK_LT(subscripts - head, subscripts_str.length()) <<
"subscripts out of range";
717 int ndim_broadcast = 0;
718 for (iop = 0; iop < nop; ++iop) {
721 char* labels = op_labels[iop];
723 ndim = inputs[iop + back].ndim();
724 for (idim = 0; idim < ndim; ++idim) {
725 if (labels[idim] == 0) {
730 if (count_zeros > ndim_broadcast) {
731 ndim_broadcast = count_zeros;
739 int label, ndim_output;
741 if (subscripts[0] ==
'\0') {
743 for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
744 output_labels[ndim_output] = 0;
746 for (label = min_label; label <= max_label; ++label) {
747 if (label_counts[label] == 1) {
748 CHECK(ndim_output <
NPY_MAXDIMS) <<
"einstein sum subscript string has too many " 749 <<
"distinct labels";
750 output_labels[ndim_output++] = label;
754 CHECK(subscripts[0] ==
'-' && subscripts[1] ==
'>') <<
"einstein sum subscript string does not " 755 <<
"contain proper '->' output specified";
760 label_counts, output_labels);
761 CHECK_GE(ndim_output, 0);
769 std::vector<Array<PrimExpr>> opshape(nop), opstride_true(nop);
770 for (iop = 0; iop < nop; ++iop) {
771 char* labels = op_labels[iop];
774 ndim = inputs[iop + back].ndim();
783 for (idim = 0; idim < ndim; ++idim) {
784 if ((
signed char)labels[idim] < 0) {
793 opshape[iop] = tshape;
794 opstride_true[iop] = tstride;
797 opshape[iop] = inputs[iop + back]->shape;
798 opstride_true[iop] =
GetStride(opshape[iop]);
807 char* iter_labels = output_labels;
808 int ndim_iter = ndim_output;
809 for (label = min_label; label <= max_label; ++label) {
810 if (label_counts[label] > 0 && memchr(output_labels, label, ndim_output) ==
nullptr) {
811 CHECK(ndim_iter <
NPY_MAXDIMS) <<
"too many subscripts in einsum";
812 iter_labels[ndim_iter++] = label;
817 std::vector<Array<PrimExpr>> iterstride(nop + 1,
821 std::vector<Array<PrimExpr>> operands;
822 for (
size_t i = 0; i < inputs.
size(); i++) {
823 operands.push_back(inputs[i]->
shape);
828 std::vector<Array<PrimExpr>> remainshape(nop);
831 for (iop = 0; iop < nop; ++iop) {
832 op_axes[iop] = op_axes_arrays[iop];
833 CHECK_GE(PrepareOpAxes(opshape[iop].size(), iop, op_labels[iop], op_axes[iop], ndim_iter,
836 for (idim = 0; idim < ndim_iter; idim++) {
837 if (op_axes[iop][idim] != -1) {
838 iterstride[iop].Set(idim, opstride_true[iop][op_axes[iop][idim]]);
839 if (GetConstInt(itershape[idim]) != -1) {
840 if (GetConstInt(itershape[idim]) == 1) {
841 itershape.
Set(idim, opshape[iop][op_axes[iop][idim]]);
844 itershape.
Set(idim, opshape[iop][op_axes[iop][idim]]);
849 for (idim = 0; idim < ndim_output; ++idim) {
850 iterstride[nop].Set(idim, ostride_true[idim]);
852 reduceshape =
Array<PrimExpr>(
static_cast<size_t>(ndim_iter - ndim_output), 0);
853 for (idim = ndim_output; idim < ndim_iter; ++idim) {
854 reduceshape.
Set(idim - ndim_output, itershape[idim]);
856 for (iop = 0; iop < nop; iop++) {
858 for (idim = 0; idim < ndim_iter; idim++) {
859 if (op_axes_arrays[iop][idim] == -1) {
860 rsh.
push_back(GetConstInt(itershape[idim]));
862 if (GetConstInt(itershape[idim] != opshape[iop][op_axes_arrays[iop][idim]])) {
863 rsh.
push_back(GetConstInt(itershape[idim]));
870 if (ndim_iter == 0) {
873 itershape =
Pad(itershape, ndim_iter);
874 for (iop = 0; iop <= nop; ++iop) {
875 iterstride[iop] =
Pad(iterstride[iop], ndim_iter);
878 reduceshape =
Pad(reduceshape, ndim_iter);
879 for (iop = 0; iop < nop; ++iop) {
880 opshape[iop] =
Pad(opshape[iop], ndim_iter);
881 remainshape[iop] =
Pad(remainshape[iop], ndim_iter);
887 for (iop = 0; iop < nop; ++iop) {
890 for (idim = 0; idim < ndim_iter; ++idim) {
891 otmp.
Set(idim, idim < ndim_output ? iterstride[iop][idim] : 1);
892 rtmp.
Set(idim, idim < ndim_iter - ndim_output ? iterstride[iop][idim + ndim_output] : 1);
899 auto func = [inputs, oshape, ostride, reduceshape, ndim_iter, rstride,
901 for (
int rdim = 0; rdim < ndim_iter; ++rdim) {
902 if (GetConstInt(reduceshape[rdim]) == 0) {
909 bool rec_flag =
false;
912 for (
int iop = 0; iop < nop; ++iop) {
916 for (
size_t i = 0; i < input_indices.size(); ++i) {
917 k += input_indices[i] * ostride[iop][i];
919 for (
size_t i = 0; i < ridx.
size(); ++i) {
920 k += ridx[i] * rstride[iop][i];
923 tmp = tmp * inputs[iop](temp_indices);
927 ridx.
Set(ridx.
size() - 1, ridx[ridx.
size() - 1] + 1);
928 for (
int i = static_cast<int>(ridx.
size() - 1);
929 (i > 0) && GetConstInt(ridx[i] >= reduceshape[i]); --i) {
930 ridx.
Set(i, ridx[i] - reduceshape[i]);
931 ridx.
Set(i - 1, ridx[i - 1] + 1);
933 rec_flag = GetConstInt(ridx[0] < reduceshape[0]);
938 return compute(oshape, func, name, tag);
943 #endif // TVM_TOPI_EINSUM_H_ #define NPY_MAXDIMS
Definition: einsum.h:28
constexpr auto kEinsum
Definition: tags.h:44
int ParseOutputSubscripts(const char *subscripts, int length, int ndim_broadcast, const char *label_counts, char *out_labels)
Parse the subscripts for the output into an output that includes 'ndim_broadcast' unlabeled dimensions.
Definition: einsum.h:204
Array< PrimExpr > Pad(const Array< PrimExpr > shape, int odim)
Pad the shape with 1.
Definition: einsum.h:78
Tensor einsum(const std::string &subscripts_str, const Array< Tensor > inputs, std::string name="T_einsum", std::string tag=kEinsum)
Evaluates the Einstein summation convention on the operands.
Definition: einsum.h:681
std::bitset< 128 > Str2Set(const std::string &str)
Transfer a string to a bitset with one bit set for each character it contains.
Definition: einsum.h:410
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr sub(PrimExpr a, PrimExpr b, Span span=Span())
subtraction operator
Tensor expression language DSL.
Definition: extracted_task.h:33
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
Array< PrimExpr > NumpyEinsumShape(const std::string subscripts, const std::vector< Array< PrimExpr >> &operands)
Compute the shape of the output.
Definition: einsum.h:621
size_t ndim() const
Definition: tensor.h:214
void Set(int64_t i, T value)
set i-th element of the array.
Definition: array.h:586
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:455
Tensor sum(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:326
Utility functions for handling constants in TVM expressions.
size_t size() const
Definition: array.h:418
Utility functions for handling tensor.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void GetCombinedDimsView(const Tensor &op, int iop, char *labels, Array< PrimExpr > *newshape, Array< PrimExpr > *newstride)
If any dimensions are combined, create a view that combines them. The resulting shape and strides are written to newshape and newstride.
Definition: einsum.h:278
int ParseOperandSubscripts(const char *subscripts, int length, int ndim, int iop, char *op_labels, char *label_counts, int *min_label, int *max_label)
Parse the subscripts for one operand into an output of 'ndim' labels.
Definition: einsum.h:105
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1758
iterator end() const
Definition: array.h:388
std::tuple< std::string, std::string > ParseEinsumInput(std::string subscripts, const std::vector< Array< PrimExpr >> &operands)
Parse the input subscripts into a tuple of (input subscripts, output subscripts) strings.
Definition: einsum.h:446
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
iterator begin() const
Definition: array.h:385
Operation node can generate one or multiple Tensors.
#define NPY_MAXARGS
Definition: einsum.h:29
int CountSubstring(const std::string &str, const std::string &sub)
Count the occurrences of a substring in a string.
Definition: einsum.h:394
Tensor prod(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates a product operation over the given axis.
Definition: reduction.h:568
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
Array< PrimExpr > GetStride(const Array< PrimExpr > shape)
Compute the stride of the given shape.
Definition: einsum.h:59
Reference to PrimExprNode.
Definition: expr.h:112
Layout expression to describe the data organization of a tensor. And BijectiveLayout to mapping two d...
std::vector< std::string > Split(const std::string &str, const std::string &sub)
Split str according to substring.
Definition: einsum.h:425
Index ravel and unravel operations.
#define LABELRANGE
Definition: einsum.h:27