tvm
rms_norm.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_RMS_NORM_H_
25 #define TVM_TOPI_NN_RMS_NORM_H_
26 
27 #include <tvm/te/operation.h>
28 #include <tvm/topi/reduction.h>
29 #include <tvm/topi/tags.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 namespace topi {
35 namespace nn {
36 
37 using namespace tvm::te;
38 
50 inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
51  double epsilon, std::string name = "T_rms_norm",
52  std::string tag = kInjective) {
53  const auto& data_type = data->dtype;
54  const auto& weight_type = weight.defined() ? weight->dtype : data_type;
55  ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
56 
57  const auto& data_fp32 = cast(data, DataType::Float(32));
58  const auto& weight_fp32 = cast(weight, DataType::Float(32));
59 
60  auto square = multiply(data_fp32, data_fp32);
61  auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
62 
63  auto ndim = data_fp32->shape.size();
64  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
65  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
66  auto reduce_extent = make_const(data_fp32->dtype, 1);
67  for (int i : real_axis) {
68  reduce_extent *= data_fp32->shape[i];
69  }
70  auto rsqrt_func = [&](const Array<Var>& indices) {
71  Array<Var> non_reduce_indices;
72  for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
73  if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
74  non_reduce_indices.push_back(indices[i]);
75  }
76  }
77  auto output =
78  tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
79  return output;
80  };
81  auto rsqrt_shape = Array<PrimExpr>();
82  for (int i = 0, n = static_cast<int>(data_fp32->shape.size()); i < n; ++i) {
83  if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
84  rsqrt_shape.push_back(data_fp32->shape[i]);
85  }
86  }
87  auto rsqrt = tvm::te::compute(rsqrt_shape, rsqrt_func, "rsqrt", tag);
88 
89  auto rms_norm_func = [&](const Array<Var>& indices) {
90  Array<Var> reduce_indices, non_reduce_indices;
91  for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
92  if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
93  reduce_indices.push_back(indices[i]);
94  } else {
95  non_reduce_indices.push_back(indices[i]);
96  }
97  }
98  auto output = rsqrt(non_reduce_indices) * data_fp32(indices) * weight_fp32(reduce_indices);
99  return output;
100  };
101  auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag);
102 
103  return cast(rms_norm, data_type);
104 }
105 
106 } // namespace nn
107 } // namespace topi
108 } // namespace tvm
109 
110 #endif // TVM_TOPI_NN_RMS_NORM_H_
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:236
bool defined() const
Definition: object.h:552
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:962
Tensor rms_norm(const Tensor &data, const Tensor &weight, const Array< Integer > &axis, double epsilon, std::string name="T_rms_norm", std::string tag=kInjective)
Root mean square normalization.
Definition: rms_norm.h:50
constexpr auto kInjective
Definition: tags.h:33
tvm::PrimExpr multiply(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:225
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type,...
Definition: elemwise.h:281
Tensor rsqrt(const Tensor &x, std::string name="tensor", std::string tag=kElementWise)
Creates an operation that returns rsqrt of a given tensor.
Definition: elemwise.h:235
std::vector< int > GetRealAxis(int ndim, const Array< Integer > &axis)
Convert a reduction axis which could be empty or have negative elements into a real axis with valid d...
Definition: reduction.h:65
Tensor sum(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that sums array elements over a given axis.
Definition: reduction.h:326
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr rsqrt(PrimExpr x, Span span=Span())
Definition: op.h:713
Operation node can generate one or multiple Tensors.
Reduction op constructors.
External function interface to rocBLAS libraries.