24 #ifndef TVM_TOPI_DETAIL_TENSOR_UTILS_H_
25 #define TVM_TOPI_DETAIL_TENSOR_UTILS_H_
43 inline bool is_empty_shape(
const Array<PrimExpr>& x) {
44 bool is_empty =
false;
45 for (
const auto& dim : x) {
47 if (int_dim->value == 0) {
66 inline PrimExpr bilinear_sample_nchw(
const Tensor& input,
const Array<PrimExpr>& indices,
67 const PrimExpr max_y,
const PrimExpr max_x) {
68 auto batch_id = indices[0];
69 auto channel_id = indices[1];
70 auto in_y = indices[2];
71 auto in_x = indices[3];
74 auto y_high = y_low + 1;
77 auto x_high = x_low + 1;
79 auto wy_h = in_y - y_low;
80 auto wx_h = in_x - x_low;
85 std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
86 std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
87 for (
auto wx_xp_ele : wx_xp) {
88 for (
auto wy_yp_ele : wy_yp) {
89 auto wx = wx_xp_ele[0];
90 auto xp = wx_xp_ele[1];
91 auto wy = wy_yp_ele[0];
92 auto yp = wy_yp_ele[1];
94 wx * wy * input(batch_id, channel_id, yp, xp), 0);
110 inline PrimExpr bilinear_sample_nhwc(
const Tensor& input,
const Array<PrimExpr>& indices,
111 const PrimExpr max_y,
const PrimExpr max_x) {
112 auto batch_id = indices[0];
113 auto channel_id = indices[3];
114 auto in_y = indices[1];
115 auto in_x = indices[2];
118 auto y_high = y_low + 1;
121 auto x_high = x_low + 1;
123 auto wy_h = in_y - y_low;
124 auto wx_h = in_x - x_low;
125 auto wy_l = 1 - wy_h;
126 auto wx_l = 1 - wx_h;
129 std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
130 std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
131 for (
auto wx_xp_ele : wx_xp) {
132 for (
auto wy_yp_ele : wy_yp) {
133 auto wx = wx_xp_ele[0];
134 auto xp = wx_xp_ele[1];
135 auto wy = wy_yp_ele[0];
136 auto yp = wy_yp_ele[1];
138 wx * wy * input(batch_id, yp, xp, channel_id), 0);
Constant integer literals in the program.
Definition: expr.h:501
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:219
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Tensor expression language DSL.
Definition: extracted_task.h:33
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span=Span())
Conditional expression.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr floor(PrimExpr x, Span span=Span())
Calculate floor(x)
Operation node can generate one or multiple Tensors.