tvm
ravel_unravel.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_DETAIL_RAVEL_UNRAVEL_H_
25 #define TVM_TOPI_DETAIL_RAVEL_UNRAVEL_H_
26 
27 #include <tvm/te/operation.h>
28 
29 #include <vector>
30 
31 namespace tvm {
32 namespace topi {
33 namespace detail {
34 
35 using namespace tvm::te;
36 
45 inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) {
46  ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
47  if (indices.size() == 0U) {
48  return 0;
49  }
50  PrimExpr idx;
51  for (size_t i = 0; i < indices.size(); ++i) {
52  if (i == 0) {
53  idx = indices[i];
54  } else {
55  idx = idx * shape[i] + indices[i];
56  }
57  }
58  return idx;
59 }
60 
69 inline Array<PrimExpr> UnravelIndex(PrimExpr idx, Array<PrimExpr> shape) {
70  std::vector<PrimExpr> indices;
71 
72  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
73  indices.push_back(indexmod(idx, shape[i]));
74  idx = indexdiv(idx, shape[i]);
75  }
76  std::reverse(indices.begin(), indices.end());
77  return indices;
78 }
79 
80 } // namespace detail
81 } // namespace topi
82 } // namespace tvm
83 #endif // TVM_TOPI_DETAIL_RAVEL_UNRAVEL_H_
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of a / b (i.e. a - floor(a / b) * b) where a and b are non-negative.
Operation node can generate one or multiple Tensors.