24 #ifndef TVM_TOPI_NN_INSTANCE_NORM_H_
25 #define TVM_TOPI_NN_INSTANCE_NORM_H_
54 int channel_axis,
const Array<Integer>& axis,
double epsilon,
55 std::string name =
"T_instance_norm", std::string tag =
kInjective) {
56 const auto& data_type = data->dtype;
57 const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type;
58 const auto& beta_type = beta.defined() ? beta->dtype : data_type;
59 ICHECK(data_type == gamma_type && data_type == beta_type)
60 <<
"instance_norm: data, gamma and beta must have the same type";
62 <<
"instance_norm: only support float32 and float16 for now";
65 auto ndim = data->shape.size();
66 ICHECK_NE(ndim, 0) <<
"Cannot reduce a 0 dim Tensor";
67 auto real_axis =
GetRealAxis(
static_cast<int>(ndim), axis);
73 auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
74 &data](
const Array<Var>& indices) {
75 Array<PrimExpr> eval_range;
79 for (
size_t i = 0; i < ndim; ++i) {
80 if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
82 eval_range.push_back(reduce_axes[red_counter]);
85 eval_range.push_back(indices[arg_counter]);
89 auto square = [is_float16](
const PrimExpr& x) {
97 reduce_axes,
nullptr);
99 return func({data(eval_range), square(data(eval_range))}, reduce_axes,
nullptr);
106 auto temp_x = temp_x_x2[0];
107 auto temp_x2 = temp_x_x2[1];
109 auto reduce_extent =
make_const(data->dtype, 1);
110 for (
int i : real_axis) {
111 reduce_extent *= data->shape[i];
113 auto instance_norm_func = [&](
const Array<Var>& indices) {
114 Array<Var> reduce_indices, non_reduce_indices;
116 for (
int i = 0, n =
static_cast<int>(indices.size()); i < n; ++i) {
117 if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
118 reduce_indices.push_back(indices[i]);
120 non_reduce_indices.push_back(indices[i]);
124 channel = indices[channel_axis];
125 auto mean = temp_x(non_reduce_indices) / reduce_extent;
126 auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
132 if (beta.defined()) {
DataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:111
Reference to PrimExprNode.
Definition: expr.h:129
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:291
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Managed reference to CastNode.
Definition: expr.h:100
a named variable in TIR
Definition: var.h:78
Tensor expression language DSL.
Definition: extracted_task.h:33
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:980
Tensor instance_norm(const Tensor &data, const Tensor &gamma, const Tensor &beta, int channel_axis, const Array< Integer > &axis, double epsilon, std::string name="T_instance_norm", std::string tag=kInjective)
Instance normalization.
Definition: instance_norm.h:53
FCommReduce MakeTupleSumReducer()
Create a commutative reducer summing over tuples.
Definition: reduction.h:589
constexpr auto kInjective
Definition: tags.h:33
tvm::PrimExpr multiply(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:225
constexpr auto kCommReduce
Definition: tags.h:34
Array< IterVar > MakeReduceAxes(const std::vector< int > &real_axis, const Tensor &data)
Enumerate the axes for a reduce op.
Definition: reduction.h:89
tvm::PrimExpr add(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:197
Array< PrimExpr > MakeReduceTargetShape(const std::vector< int > &real_axis, const Tensor &data, bool keepdims, bool atleast1d)
Calculate the target shape for a reduce op.
Definition: reduction.h:99
std::vector< int > GetRealAxis(int ndim, const Optional< Array< Integer >> &axis)
Convert a reduction axis which could be empty or have negative elements into a real axis with valid dimension indices.
Definition: reduction.h:65
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr rsqrt(PrimExpr x, Span span=Span())
Definition: op.h:731
Operation node can generate one or multiple Tensors.