tvm
einsum.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_EINSUM_H_
25 #define TVM_TOPI_EINSUM_H_
26 
27 #define LABELRANGE 128
28 #define NPY_MAXDIMS 16
29 #define NPY_MAXARGS 16
30 
31 #include <tvm/te/operation.h>
35 #include <tvm/topi/tags.h>
36 
37 #include <algorithm>
38 #include <bitset>
39 #include <iterator>
40 #include <string>
41 #include <tuple>
42 #include <unordered_set>
43 #include <vector>
44 
45 namespace tvm {
46 namespace topi {
47 
48 using namespace tvm::te;
49 using namespace topi::detail;
50 
58 ffi::Array<PrimExpr> InferEinsumShape(const std::string& subscripts,
59  const std::vector<ffi::Array<PrimExpr>>& operands);
60 
72 Tensor einsum(const std::string& subscripts_str, const ffi::Array<Tensor> inputs,
73  std::string name = "T_einsum", std::string tag = kEinsum);
74 
81  static EinsumEquation FromString(const std::string& equation);
82  using Label = char;
83  using Subscript = std::vector<Label>;
84  // Special label value for ellipsis. The value is chosen to be less than any other letter to make
85  // sorting easier.
86  static constexpr Label kEllipsis = '\0';
87  // The input subscripts for each operand of the Einsum operator.
88  std::vector<Subscript> inputs;
89  // The output subscript of the Einsum equation.
90  Subscript output;
91 };
92 
93 } // namespace topi
94 } // namespace tvm
95 #endif // TVM_TOPI_EINSUM_H_
Managed Tensor. The array is backed by reference-counted blocks.
Definition: tensor.h:54
Utility functions for handling constants in TVM expressions.
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor einsum(const std::string &subscripts_str, const ffi::Array< Tensor > inputs, std::string name="T_einsum", std::string tag=kEinsum)
Evaluates the Einstein summation convention on the operands.
ffi::Array< PrimExpr > InferEinsumShape(const std::string &subscripts, const std::vector< ffi::Array< PrimExpr >> &operands)
Compute the shape of the output.
constexpr auto kEinsum
Definition: tags.h:44
An object that builds and maintains block scope and StmtSref mapping for dependence analysis.
Definition: analyzer.h:37
Operation node can generate one or multiple Tensors.
Index ravel and unravel operations.
Definition: einsum.h:75
static EinsumEquation FromString(const std::string &equation)
Create EinsumEquation from a string. The result will be converted to the explicit mode of Einsum if it is in implicit mode.
char Label
Definition: einsum.h:82
std::vector< Label > Subscript
Definition: einsum.h:83
std::vector< Subscript > inputs
Definition: einsum.h:88
Subscript output
Definition: einsum.h:90
Tag definitions.
Utility functions for handling tensor.