tvm
local_response_norm.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_LOCAL_RESPONSE_NORM_H_
25 #define TVM_TOPI_NN_LOCAL_RESPONSE_NORM_H_
26 
27 #include <tvm/te/operation.h>
28 #include <tvm/topi/tags.h>
29 
30 #include <string>
31 
32 namespace tvm {
33 namespace topi {
34 namespace nn {
35 
36 using namespace tvm::te;
37 
52 inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.0001,
53  float beta = 0.75, float bias = 2, std::string name = "tensor",
54  std::string tag = kBroadcast) {
55  ICHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input";
56  ICHECK_EQ(size % 2, 1) << "size should be odd number";
57  ICHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC";
58  ICHECK(data->dtype.is_float()) << "datatype should be float";
59  auto input_shape = data->shape;
60  Array<PrimExpr> pad_before{0, 0, 0, 0};
61  Array<PrimExpr> pad_after{0, 0, 0, 0};
62  pad_before.Set(axis, static_cast<PrimExpr>(size / 2));
63  pad_after.Set(axis, static_cast<PrimExpr>(size / 2));
64  auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
65  auto rxs = tvm::te::reduce_axis(Range(0, size), "rxs");
66  Tensor sqr_sum;
67  if (axis == 1) {
68  sqr_sum = tvm::te::compute(
69  input_shape,
70  [&](Var i, Var l, Var j, Var k) {
71  return tvm::sum(pad_data(i, l + rxs, j, k) * pad_data(i, l + rxs, j, k), {rxs});
72  },
73  "tensor", "sqr_sum");
74  } else if (axis == 3) {
75  sqr_sum = tvm::te::compute(
76  input_shape,
77  [&](Var i, Var l, Var j, Var k) {
78  return tvm::sum(pad_data(i, l, j, k + rxs) * pad_data(i, l, j, k + rxs), {rxs});
79  },
80  "tensor", "sqr_sum");
81  }
82  PrimExpr alpha_imm = tvm::te::make_const(data->dtype, alpha);
83  PrimExpr beta_imm = tvm::te::make_const(data->dtype, beta);
84  PrimExpr bias_imm = tvm::te::make_const(data->dtype, bias);
85  auto sqrt_sum_up = tvm::te::compute(
86  input_shape,
87  [&](Var i, Var j, Var k, Var l) {
88  return tvm::pow(bias_imm + (div(alpha_imm * sqr_sum(i, j, k, l), size)), beta_imm);
89  },
90  "tensor", kElementWise);
91  return topi::divide(data, sqrt_sum_up);
92 }
93 } // namespace nn
94 } // namespace topi
95 } // namespace tvm
96 #endif // TVM_TOPI_NN_LOCAL_RESPONSE_NORM_H_
Reference to PrimExprNode.
Definition: expr.h:115
Range container
Definition: expr.h:725
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void Set(int64_t i, T value)
set i-th element of the array.
Definition: array.h:621
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
a named variable in TIR
Definition: var.h:89
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:962
Tensor lrn(const Tensor &data, int size, int axis=1, float alpha=0.0001, float beta=0.75, float bias=2, std::string name="tensor", std::string tag=kBroadcast)
Local response normalization inference operator.
Definition: local_response_norm.h:52
constexpr auto kElementWise
Definition: tags.h:32
constexpr auto kBroadcast
Definition: tags.h:36
tvm::te::Tensor pad(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &pad_before, tvm::Array< tvm::PrimExpr > pad_after=tvm::Array< tvm::PrimExpr >(), PrimExpr pad_value=PrimExpr(), std::string name="T_pad", std::string tag=kElementWise, std::string pad_mode="constant", const Array< PrimExpr > *dyn_output_shape=nullptr)
Creates an operation that performs padding.
Definition: nn.h:155
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:239
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr pow(PrimExpr x, PrimExpr y, Span span=Span())
Calculate power(x, y)
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Operation node can generate one or multiple Tensors.
External function interface to rocBLAS libraries.