24 #ifndef TVM_TOPI_NN_GROUP_NORM_H_
25 #define TVM_TOPI_NN_GROUP_NORM_H_
41 double epsilon, std::string name =
"T_group_norm",
43 const auto& data_type = data->dtype;
44 const auto& gamma_type = gamma.
defined() ? gamma->dtype : data_type;
45 const auto& beta_type = beta.
defined() ? beta->dtype : data_type;
46 ICHECK(data_type == gamma_type && data_type == beta_type)
47 <<
"group_norm: data, gamma and beta must have the same type";
49 <<
"group_norm: only support float32 and float16 for now";
52 int ndim = data->shape.size();
53 channel_axis =
GetRealAxis(
static_cast<int>(ndim), {channel_axis})[0];
55 auto shape = data->shape;
58 for (
int i = 0; i < ndim; ++i) {
59 if (i == channel_axis) {
60 new_shape.push_back(num_groups);
61 new_shape.push_back(group_size);
63 new_shape.push_back(
shape[i]);
70 data_reshaped =
reshape(data, new_shape);
75 gamma_reshaped =
reshape(gamma, {num_groups, group_size});
79 beta_reshaped =
reshape(beta, {num_groups, group_size});
83 std::vector<int> new_axes{channel_axis + 1};
84 for (
auto axis : axes) {
85 int new_axis =
GetRealAxis(
static_cast<int>(ndim), {axis})[0];
86 if (new_axis < channel_axis) {
87 new_axes.push_back(new_axis);
88 }
else if (new_axis > channel_axis) {
89 new_axes.push_back(new_axis + 1);
91 ICHECK(
false) <<
"axes can not contain channel axis";
94 std::sort(new_axes.begin(), new_axes.end());
97 ndim = data_reshaped->shape.size();
103 auto compute = [ndim, &new_axes, &reduce_axes, &func, &data_reshaped](
const Array<Var>& indices) {
108 for (
int i = 0; i < ndim; ++i) {
109 if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) {
111 eval_range.
push_back(reduce_axes[red_counter]);
114 eval_range.
push_back(indices[arg_counter]);
118 auto square = [](
const PrimExpr& x) {
return x * x; };
119 return func({data_reshaped(eval_range), square(data_reshaped(eval_range))}, reduce_axes,
126 auto temp_x = temp_x_x2[0];
127 auto temp_x2 = temp_x_x2[1];
129 for (
auto axis : new_axes) {
130 reduce_extent *= data_reshaped->shape[axis];
132 auto group_norm_func = [&](
const Array<Var>& indices) {
133 Array<Var> reduce_indices, non_reduce_indices, gamma_indices;
134 for (
int i = 0, n =
static_cast<int>(indices.size()); i < n; ++i) {
135 if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) {
138 non_reduce_indices.
push_back(indices[i]);
141 gamma_indices = {indices[channel_axis], indices[channel_axis + 1]};
142 auto mean = temp_x(non_reduce_indices) / reduce_extent;
143 auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
157 auto group_norm_out =
tvm::te::compute(data_reshaped->shape, group_norm_func, name, tag);
158 auto group_norm_out_reshaped =
reshape(group_norm_out,
shape);
159 return group_norm_out_reshaped;
Reference to PrimExprNode.
Definition: expr.h:115
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
Push a new item to the back of the list.
Definition: array.h:457
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:236
bool defined() const
Definition: object.h:552
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Managed reference to CastNode.
Definition: expr.h:117
Tensor expression language DSL.
Definition: extracted_task.h:33
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:962
Tensor group_norm(const Tensor &data, const Tensor &gamma, const Tensor &beta, int num_groups, int channel_axis, const Array< Integer > &axes, double epsilon, std::string name="T_group_norm", std::string tag=kInjective)
Definition: group_norm.h:39
FCommReduce MakeTupleSumReducer()
Create commutative reducer summing over tuples.
Definition: reduction.h:587
constexpr auto kInjective
Definition: tags.h:33
Tensor reshape(const Tensor &x, Array< PrimExpr > newshape, std::string name="T_reshape", std::string tag=kInjective)
Reshape a tensor.
Definition: transform.h:327
tvm::PrimExpr multiply(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:225
constexpr auto kCommReduce
Definition: tags.h:34
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type, a vector is generated.
Definition: elemwise.h:281
Array< IterVar > MakeReduceAxes(const std::vector< int > &real_axis, const Tensor &data)
Enumerate the axes for a reduce op.
Definition: reduction.h:89
std::vector< int > GetRealAxis(int ndim, const Array< Integer > &axis)
Convert a reduction axis which could be empty or have negative elements into a real axis with valid dimension indices.
Definition: reduction.h:65
tvm::PrimExpr add(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:197
Array< PrimExpr > MakeReduceTargetShape(const std::vector< int > &real_axis, const Tensor &data, bool keepdims, bool atleast1d)
Calculate the target shape for a reduce op.
Definition: reduction.h:99
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr rsqrt(PrimExpr x, Span span=Span())
Definition: op.h:713
PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b)
Operation node can generate one or multiple Tensors.