tvm
tensor_utils.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_DETAIL_TENSOR_UTILS_H_
25 #define TVM_TOPI_DETAIL_TENSOR_UTILS_H_
26 
27 #include <tvm/te/operation.h>
28 
29 #include <vector>
30 namespace tvm {
31 namespace topi {
32 namespace detail {
33 
34 using namespace tvm::te;
35 
43 inline bool is_empty_shape(const Array<PrimExpr>& x) {
44  bool is_empty = false;
45  for (const auto& dim : x) {
46  if (auto int_dim = dim.as<IntImmNode>()) {
47  if (int_dim->value == 0) {
48  is_empty = true;
49  break;
50  }
51  }
52  }
53  return is_empty;
54 }
55 
66 inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices,
67  const PrimExpr max_y, const PrimExpr max_x) {
68  auto batch_id = indices[0];
69  auto channel_id = indices[1];
70  auto in_y = indices[2];
71  auto in_x = indices[3];
72 
73  auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
74  auto y_high = y_low + 1;
75 
76  auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
77  auto x_high = x_low + 1;
78 
79  auto wy_h = in_y - y_low;
80  auto wx_h = in_x - x_low;
81  auto wy_l = 1 - wy_h;
82  auto wx_l = 1 - wx_h;
83 
84  PrimExpr val = 0;
85  std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
86  std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
87  for (auto wx_xp_ele : wx_xp) {
88  for (auto wy_yp_ele : wy_yp) {
89  auto wx = wx_xp_ele[0];
90  auto xp = wx_xp_ele[1];
91  auto wy = wy_yp_ele[0];
92  auto yp = wy_yp_ele[1];
93  val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
94  wx * wy * input(batch_id, channel_id, yp, xp), 0);
95  }
96  }
97  return val;
98 }
99 
110 inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices,
111  const PrimExpr max_y, const PrimExpr max_x) {
112  auto batch_id = indices[0];
113  auto channel_id = indices[3];
114  auto in_y = indices[1];
115  auto in_x = indices[2];
116 
117  auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
118  auto y_high = y_low + 1;
119 
120  auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
121  auto x_high = x_low + 1;
122 
123  auto wy_h = in_y - y_low;
124  auto wx_h = in_x - x_low;
125  auto wy_l = 1 - wy_h;
126  auto wx_l = 1 - wx_h;
127 
128  PrimExpr val = 0;
129  std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
130  std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
131  for (auto wx_xp_ele : wx_xp) {
132  for (auto wy_yp_ele : wy_yp) {
133  auto wx = wx_xp_ele[0];
134  auto xp = wx_xp_ele[1];
135  auto wy = wy_yp_ele[0];
136  auto yp = wy_yp_ele[1];
137  val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
138  wx * wy * input(batch_id, yp, xp, channel_id), 0);
139  }
140  }
141  return val;
142 }
143 
144 } // namespace detail
145 } // namespace topi
146 } // namespace tvm
147 #endif // TVM_TOPI_DETAIL_TENSOR_UTILS_H_
Constant integer literals in the program.
Definition: expr.h:501
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:219
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Tensor expression language DSL.
Definition: extracted_task.h:33
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr floor(PrimExpr x, Span span=Span())
Calculate floor(x)
Operation node can generate one or multiple Tensors.