tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
group_norm.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_GROUP_NORM_H_
25 #define TVM_TOPI_NN_GROUP_NORM_H_
26 
27 #include <tvm/te/operation.h>
28 #include <tvm/topi/tags.h>
29 
30 #include <algorithm>
31 #include <string>
32 #include <vector>
33 
34 namespace tvm {
35 namespace topi {
36 namespace nn {
37 
38 using namespace tvm::te;
39 
40 inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,
41  int num_groups, int channel_axis, const Array<Integer>& axes,
42  double epsilon, std::string name = "T_group_norm",
43  std::string tag = kInjective) {
44  // reshape data C -> G, C/G
45  int ndim = data->shape.size();
46  channel_axis = GetRealAxis(ndim, {channel_axis})[0];
47 
48  auto shape = data->shape;
49  auto group_size = floordiv(shape[channel_axis], num_groups);
50  auto new_shape = Array<PrimExpr>();
51  for (int i = 0; i < ndim; ++i) {
52  if (i == channel_axis) {
53  new_shape.push_back(num_groups);
54  new_shape.push_back(group_size);
55  } else {
56  new_shape.push_back(shape[i]);
57  }
58  }
59  auto data_reshaped = reshape(data, new_shape);
60  // reshape gamma and beta, C -> G, C/G
61  Tensor gamma_reshaped;
62  if (gamma.defined()) {
63  gamma_reshaped = reshape(gamma, {num_groups, group_size});
64  }
65  Tensor beta_reshaped;
66  if (beta.defined()) {
67  beta_reshaped = reshape(beta, {num_groups, group_size});
68  }
69 
70  // get the new axes to normalize after reshape
71  std::vector<int> new_axes{channel_axis + 1};
72  for (auto axis : axes) {
73  int new_axis = GetRealAxis(ndim, {axis})[0];
74  if (new_axis < channel_axis) {
75  new_axes.push_back(new_axis);
76  } else if (new_axis > channel_axis) {
77  new_axes.push_back(new_axis + 1);
78  } else {
79  ICHECK(false) << "axes can not contain channel axis";
80  }
81  }
82  std::sort(new_axes.begin(), new_axes.end());
83 
84  // sum x and x^2
85  ndim = data_reshaped->shape.size();
86  auto reduce_axes = MakeReduceAxes(new_axes, data_reshaped);
87  auto target_shape =
88  MakeReduceTargetShape(new_axes, data_reshaped, /*keepdims=*/false, /*atleast1d=*/true);
89  auto func = MakeTupleSumReducer();
90 
91  auto compute = [ndim, &new_axes, &reduce_axes, &func, &data_reshaped](const Array<Var>& indices) {
92  Array<PrimExpr> eval_range;
93  int arg_counter = 0;
94  int red_counter = 0;
95 
96  for (int i = 0; i < ndim; ++i) {
97  if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) {
98  // new_axes contains i
99  eval_range.push_back(reduce_axes[red_counter]);
100  red_counter++;
101  } else {
102  eval_range.push_back(indices[arg_counter]);
103  arg_counter++;
104  }
105  }
106  auto square = [](const PrimExpr& x) { return x * x; };
107  return func({data_reshaped(eval_range), square(data_reshaped(eval_range))}, reduce_axes,
108  nullptr);
109  };
110 
111  auto temp_x_x2 =
112  tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);
113 
114  auto temp_x = temp_x_x2[0];
115  auto temp_x2 = temp_x_x2[1];
116  auto reduce_extent = make_const(data->dtype, 1);
117  for (auto axis : new_axes) {
118  reduce_extent *= data_reshaped->shape[axis];
119  }
120  auto group_norm_func = [&](const Array<Var>& indices) {
121  Array<Var> reduce_indices, non_reduce_indices, gamma_indices;
122  for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
123  if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) {
124  reduce_indices.push_back(indices[i]);
125  } else {
126  non_reduce_indices.push_back(indices[i]);
127  }
128  }
129  gamma_indices = {indices[channel_axis], indices[channel_axis + 1]};
130  auto mean = temp_x(non_reduce_indices) / reduce_extent;
131  auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
132  auto group_norm =
133  (data_reshaped(indices) - mean) * tvm::rsqrt(var + make_const(data->dtype, epsilon));
134  if (gamma.defined()) {
135  group_norm = topi::multiply(group_norm, gamma_reshaped(gamma_indices));
136  }
137  if (beta.defined()) {
138  group_norm = topi::add(group_norm, beta_reshaped(gamma_indices));
139  }
140  return group_norm;
141  };
142  auto group_norm_out = tvm::te::compute(data_reshaped->shape, group_norm_func, name, tag);
143  auto group_norm_out_reshaped = reshape(group_norm_out, shape);
144  return group_norm_out_reshaped;
145 }
146 
147 } // namespace nn
148 } // namespace topi
149 } // namespace tvm
150 
151 #endif // TVM_TOPI_NN_GROUP_NORM_H_
Reference to PrimExprNode.
Definition: expr.h:114
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
bool defined() const
Definition: object.h:550
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Tensor expression language DSL.
Definition: extracted_task.h:33
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis)
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:961
Tensor group_norm(const Tensor &data, const Tensor &gamma, const Tensor &beta, int num_groups, int channel_axis, const Array< Integer > &axes, double epsilon, std::string name="T_group_norm", std::string tag=kInjective)
Definition: group_norm.h:40
FCommReduce MakeTupleSumReducer()
Create commutative reducer summing over tuples.
Definition: reduction.h:587
constexpr auto kInjective
Definition: tags.h:33
Tensor reshape(const Tensor &x, Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:321
tvm::PrimExpr multiply(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:225
constexpr auto kCommReduce
Definition: tags.h:34
Array< IterVar > MakeReduceAxes(const std::vector< int > &real_axis, const Tensor &data)
Enumerate the axes for a reduce op.
Definition: reduction.h:89
std::vector< int > GetRealAxis(int ndim, const Array< Integer > &axis)
Convert a reduction axis which could be empty or have negative elements into a real axis with valid dimensions.
Definition: reduction.h:65
tvm::PrimExpr add(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:197
Array< PrimExpr > MakeReduceTargetShape(const std::vector< int > &real_axis, const Tensor &data, bool keepdims, bool atleast1d)
Calculate the target shape for a reduce op.
Definition: reduction.h:99
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1766
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr rsqrt(PrimExpr x, Span span=Span())
Definition: op.h:712
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
Operation node can generate one or multiple Tensors.
External function interface to rocBLAS libraries.