layer_norm.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

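/*!
 * \file layer_norm.h
 * \brief layer normalization op constructions
 */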
#ifndef TVM_TOPI_NN_LAYER_NORM_H_
#define TVM_TOPI_NN_LAYER_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

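/*!
 * \brief Layer normalization.
 * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
 * \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis)
 *              and d_{axis_k} == r_k
 * \param beta Optional, K-D tensor with the same shape constraint as gamma
 * \param axis The axes to normalize over.
 * \param epsilon The epsilon value to avoid division by zero.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 * \return The normalized tensor, with the same shape as data.
 */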
inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,
                         const Array<Integer>& axis, double epsilon,
                         std::string name = "T_layer_norm", std::string tag = kInjective) {
  const auto& data_type = data->dtype;
  const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type;
  const auto& beta_type = beta.defined() ? beta->dtype : data_type;
  ICHECK(data_type == gamma_type && data_type == beta_type)
      << "layer_norm: data, gamma and beta must have the same type";
  ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
      << "layer_norm: only support float32 and float16 for now";
  bool is_float16 = data_type == DataType::Float(16);
  // sum x and x^2
  auto ndim = data->shape.size();
  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
  auto reduce_axes = MakeReduceAxes(real_axis, data);
  auto target_shape =
      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true);
  auto func = MakeTupleSumReducer();

  auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
                  &data](const Array<Var>& indices) {
    Array<PrimExpr> eval_range;
    int arg_counter = 0;
    int red_counter = 0;

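    // Splice together the full input index: positions being reduced take the
    // fresh reduce itervars, all other positions take the output indices in order.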
    for (size_t i = 0; i < ndim; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        // real_axis contains i
        eval_range.push_back(reduce_axes[red_counter]);
        red_counter++;
      } else {
        eval_range.push_back(indices[arg_counter]);
        arg_counter++;
      }
    }
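    // For fp16 inputs, square in fp32 so that sum(x^2) neither overflows nor
    // loses precision over a long reduction.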
    auto square = [is_float16](const PrimExpr& x) {
      if (is_float16) {
        return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x);
      }
      return x * x;
    };
    if (is_float16) {
      return func({Cast(DataType::Float(32), data(eval_range)), square(data(eval_range))},
                  reduce_axes, nullptr);
    } else {
      return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr);
    }
  };

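  // A single fused reduction yields a tuple per output element:
  // temp_x = sum(x) and temp_x2 = sum(x^2) over the reduce axes.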
  auto temp_x_x2 =
      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);

  auto temp_x = temp_x_x2[0];
  auto temp_x2 = temp_x_x2[1];

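  // reduce_extent is N, the number of elements reduced per output, so that
  // mean = sum(x) / N and var = sum(x^2) / N - mean^2, i.e. E[x^2] - E[x]^2.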
  auto reduce_extent = make_const(data->dtype, 1);
  for (int i : real_axis) {
    reduce_extent *= data->shape[i];
  }
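  // Normalize each element as (x - mean) * rsqrt(var + epsilon), then apply the
  // affine scale (gamma) and shift (beta).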
  auto layer_norm_func = [&](const Array<Var>& indices) {
    Array<Var> reduce_indices, non_reduce_indices;
    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        reduce_indices.push_back(indices[i]);
      } else {
        non_reduce_indices.push_back(indices[i]);
      }
    }
    auto mean = temp_x(non_reduce_indices) / reduce_extent;
    auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
    auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon));
    if (is_float16) {
      // The statistics were accumulated in fp32; cast the normalized value back
      // to fp16 before the affine transform.
      layer_norm = Cast(DataType::Float(16), layer_norm);
    }
    if (gamma.defined()) {
      layer_norm = topi::multiply(layer_norm, gamma(reduce_indices));
    }
    if (beta.defined()) {
      layer_norm = topi::add(layer_norm, beta(reduce_indices));
    }
    return layer_norm;
  };
  return tvm::te::compute(data->shape, layer_norm_func, name, tag);
}

}  // namespace nn
}  // namespace topi
}  // namespace tvm

#endif  // TVM_TOPI_NN_LAYER_NORM_H_
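
A minimal usage sketch, not part of the header itself: it normalizes a float32
[batch, hidden] placeholder over its last axis. The shapes, tensor names, and the
helper function are illustrative assumptions; only layer_norm comes from this file.

#include <tvm/te/operation.h>
#include <tvm/topi/nn/layer_norm.h>

tvm::te::Tensor make_layer_norm_example() {
  using tvm::DataType;
  using tvm::Integer;
  using tvm::te::placeholder;
  // Illustrative shapes: 8 rows, each normalized over 128 features.
  auto data = placeholder({8, 128}, DataType::Float(32), "data");
  auto gamma = placeholder({128}, DataType::Float(32), "gamma");
  auto beta = placeholder({128}, DataType::Float(32), "beta");
  // axis = {-1} is resolved by GetRealAxis to the last dimension.
  return tvm::topi::nn::layer_norm(data, gamma, beta, {Integer(-1)}, /*epsilon=*/1e-5);
}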