tvm
flatten.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_FLATTEN_H_
25 #define TVM_TOPI_NN_FLATTEN_H_
26 
27 #include <tvm/te/operation.h>
29 #include <tvm/topi/tags.h>
30 
31 #include <string>
32 #include <vector>
33 
34 namespace tvm {
35 namespace topi {
36 namespace nn {
37 
38 using namespace tvm::te;
39 
50 inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string tag = kInjective) {
51  auto ishape = x->shape;
52  PrimExpr dim = 1;
53  for (size_t i = 1; i < ishape.size(); ++i) {
54  dim = dim * ishape[i];
55  }
56 
57  Array<PrimExpr> oshape({ishape[0], dim});
58 
59  std::vector<PrimExpr> extra_shape;
60  for (size_t i = 1; i < ishape.size(); ++i) {
61  extra_shape.push_back(ishape[i]);
62  }
63  std::reverse(extra_shape.begin(), extra_shape.end());
64 
65  return tvm::te::compute(
66  oshape,
67  [&](Var i, Var j) {
68  PrimExpr idx = j;
69  std::vector<PrimExpr> index;
70  for (auto s : extra_shape) {
71  index.push_back(indexmod(idx, s));
72  idx = indexdiv(idx, s);
73  }
74  index.push_back(i);
75  std::reverse(index.begin(), index.end());
76  return x(index);
77  },
78  name, tag);
79 }
80 
81 } // namespace nn
82 } // namespace topi
83 } // namespace tvm
84 #endif // TVM_TOPI_NN_FLATTEN_H_
Reference to PrimExprNode.
Definition: expr.h:115
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
iterator end() const
Definition: array.h:390
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
iterator begin() const
Definition: array.h:387
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
a named variable in TIR
Definition: var.h:89
Utility functions for handling constants in TVM expressions.
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Tensor flatten(const Tensor &x, std::string name="tensor", std::string tag=kInjective)
Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. This requires the input ...
Definition: flatten.h:50
constexpr auto kInjective
Definition: tags.h:33
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of the floor division a / b, where a and b are non-negative.
Operation node can generate one or multiple Tensors.
External function interface to rocBLAS libraries.