tvm::topi::nn Namespace Reference

Enumerations

enum  PoolType : int { kAvgPool , kMaxPool }
 Pooling type. More...
 

Functions

tvm::te::Tensor bias_add (const tvm::te::Tensor &data, const tvm::te::Tensor &bias, int axis)
 Creates an operation that calculates data + bias. More...
 
tvm::te::Tensor binarize_pack (const tvm::te::Tensor &data, int axis, std::string name="PackedInput", std::string tag="binarize_pack")
 Binarization and bit-packing along a certain axis. More...
 
tvm::te::Tensor binary_dense (const tvm::te::Tensor &data, const tvm::te::Tensor &weight)
 Binary matrix multiplication using xor and bit-count. More...
 
tvm::te::Tensor dense (const tvm::te::Tensor &data, const tvm::te::Tensor &weight, const tvm::te::Tensor &bias, const DataType &out_dtype)
 Creates an operation that calculates data * weight^T + bias. More...
 
PrimExpr all (Array< PrimExpr > args)
 Create a new expression of the logical and of all conditions in the arguments. More...
 
Tensor dilate (const Tensor &x, Array< PrimExpr > strides, double dilation_value, std::string name="tensor", std::string tag=kInjective)
 Dilate data with given dilation value (0 by default). More...
 
Tensor flatten (const Tensor &x, std::string name="tensor", std::string tag=kInjective)
 Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. This requires the input tensor to have constant sized dimensions. More...
 
Tensor group_norm (const Tensor &data, const Tensor &gamma, const Tensor &beta, int num_groups, int channel_axis, const Array< Integer > &axes, double epsilon, std::string name="T_group_norm", std::string tag=kInjective)
 
Tensor instance_norm (const Tensor &data, const Tensor &gamma, const Tensor &beta, const Array< Integer > &axis, double epsilon, std::string name="T_instance_norm", std::string tag=kInjective)
 Instance normalization. More...
 
Tensor layer_norm (const Tensor &data, const Tensor &gamma, const Tensor &beta, const Array< Integer > &axis, double epsilon, std::string name="T_layer_norm", std::string tag=kInjective)
 Layer normalization. More...
 
Tensor lrn (const Tensor &data, int size, int axis=1, float alpha=0.0001, float beta=0.75, float bias=2, std::string name="tensor", std::string tag=kBroadcast)
 Local response normalization inference operator. More...
 
Tensor scale_shift_nchw (const Tensor &x, const Tensor &scale, const Tensor &shift, std::string name="ScaleShift", std::string tag=kBroadcast)
 Scale and shift with NCHW order. More...
 
Tensor scale_shift_nhwc (const Tensor &x, const Tensor &scale, const Tensor &shift, std::string name="ScaleShift", std::string tag=kBroadcast)
 Scale and shift with NHWC order. More...
 
Tensor pool_grad_impl (const Tensor &out_grad, const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad)
 
bool find_depth_height_width (const std::string &layout, int *depth_axis, int *height_axis, int *width_axis)
 Find index of Depth, Height or Width dimension in a layout string. More...
 
bool find_height_width (const std::string &layout, int *height_axis, int *width_axis)
 
bool find_width (const std::string &layout, int *width_axis)
 
Tensor pool_grad (const Tensor &out_grad, const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCHW", bool count_include_pad=true)
 Calculate gradient of pooling on height and width dimension of data. It decides the height and width dimension according to the layout string, in which 'W' and 'H' means width and height respectively. Width and height dimension cannot be split. For example, NCHW, NCHW16c, etc. are valid for pool, while NCHW16w, NCHW16h are not. See layout for more information of the layout string convention. More...
 
PrimExpr start_index (const Var &out_index, const PrimExpr &odim, const PrimExpr &idim)
 
PrimExpr end_index (const Var &out_index, const PrimExpr &odim, const PrimExpr &idim)
 
Tensor adaptive_pool_impl (const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::vector< int > &axes)
 Perform adaptive pooling on N dimensional data. More...
 
Tensor adaptive_pool (const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCHW")
 Adaptively perform pooling on height and width dimension of data. The pooling kernel and stride sizes are automatically chosen for desired output sizes. It decides the height and width dimension according to the layout string, in which 'W' and 'H' means width and height respectively. Width and height dimension cannot be split. For example, NCHW, NCHW16c, etc. are valid for pool, while NCHW16w, NCHW16h are not. See layout for more information of the layout string convention. More...
 
Tensor adaptive_pool3d (const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCDHW")
 Adaptively perform pooling on three dimensional data. See the two dimensional version above for details. More...
 
Tensor adaptive_pool1d (const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCW")
 Adaptively perform pooling on one dimensional data. See the two dimensional version above for details. More...
 
Tensor global_pool (const Tensor &x, PoolType pool_type, const std::string &layout="NCHW")
 Perform global pooling on height and width dimension of data. It decides the height and width dimension according to the layout string, in which 'W' and 'H' means width and height respectively. Width and height dimension cannot be split. For example, NCHW, NCHW16c, ... are valid for global_pool, while NCHW16w, NCHW16h are not. See layout for more information of the layout string convention. More...
 
Tensor pool_impl_nd (const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::vector< int > &axis, bool count_include_pad)
 Perform pooling on N-dimension of data. More...
 
Tensor pool1d (const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCW", bool count_include_pad=true)
 Perform pooling on the width dimension of data. Width axis is determined by the layout string in which 'W' means width. Width dimension cannot be split. For example, NCW, NCW16c, etc. are valid for pool, while NCW16w is not. See layout for more information of the layout string convention. More...
 
Tensor pool2d (const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCHW", bool count_include_pad=true)
 Perform pooling on height and width dimension of data. It decides the height and width dimension according to the layout string, in which 'W' and 'H' means width and height respectively. Width and height dimension cannot be split. For example, NCHW, NCHW16c, etc. are valid for pool, while NCHW16w, NCHW16h are not. See layout for more information of the layout string convention. More...
 
Tensor pool3d (const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCDHW", bool count_include_pad=true)
 Perform pooling on depth, height and width dimension of data. It decides the depth, height and width dimension according to the layout string, in which 'D', 'W' and 'H' means depth, width and height respectively. Depth, Width and height dimension cannot be split. For example, NCDHW, NCDHW16c, etc. are valid for pool, while NCDHW16d, NCDHW16w or NCDHW16h are not. See layout for more information of the layout string convention. More...
 
Tensor rms_norm (const Tensor &data, const Tensor &weight, const Array< Integer > &axis, double epsilon, std::string name="T_rms_norm", std::string tag=kInjective)
 Root mean square normalization. More...
 
Tensor softmax (const Tensor &x, int axis=-1, std::string name="tensor", std::string tag="softmax_output")
 Softmax activation. More...
 
Tensor log_softmax (const Tensor &x, std::string name="tensor", std::string tag="log_softmax_output")
 Log softmax activation. More...
 

Enumeration Type Documentation

◆ PoolType

Pooling type.

Enumerator
kAvgPool 
kMaxPool 

Function Documentation

◆ adaptive_pool()

Tensor tvm::topi::nn::adaptive_pool (const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCHW")
inline

Adaptively perform pooling on height and width dimension of data. The pooling kernel and stride sizes are automatically chosen for desired output sizes. It decides the height and width dimension according to the layout string, in which 'W' and 'H' means width and height respectively. Width and height dimension cannot be split. For example, NCHW, NCHW16c, etc. are valid for pool, while NCHW16w, NCHW16h are not. See layout for more information of the layout string convention.

Parameters
x: The input tensor
output_size: Vector of two ints: {output_height, output_width}
pool_type: The type of pooling operator
layout: The input layout. Pooling supports any layout as long as 'H' and 'W' appear. The layout string is composed of upper case letters, lower case letters and (optional) numbers, where an upper case letter indicates a dimension and the corresponding lower case letter (with factor size) indicates the split dimension. For example, NCHW16c describes a 5-D tensor of [batch_size, channel, height, width, channel_block]; the factor size 16 is not used by pooling itself, but other operators can use it to decide the output shape. Since pooling does not care about the factor size of dimensions other than H and W, one can pass NCHWc as well.
Returns
The output tensor in the same layout
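
A minimal usage sketch (the header paths, shapes, and names here are illustrative assumptions, not taken from this reference): adaptively average-pool a 32x32 NCHW feature map down to a fixed 7x7 output.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/pooling.h>

    tvm::te::Tensor AdaptiveAvgPoolExample() {
      using namespace tvm;
      // NCHW input: [batch, channel, height, width].
      te::Tensor data = te::placeholder({1, 64, 32, 32}, DataType::Float(32), "data");
      // Adaptive average pooling to a fixed 7x7 spatial output; the kernel and
      // stride sizes are derived internally from the input and output extents.
      return topi::nn::adaptive_pool(data, {7, 7}, topi::nn::kAvgPool, "NCHW");
    }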

◆ adaptive_pool1d()

Tensor tvm::topi::nn::adaptive_pool1d (const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCW")
inline

Adaptively perform pooling on one dimensional data. See the two dimensional version above for details.

Parameters
x: The input tensor
output_size: Vector of one int: {output_width}
pool_type: The type of pooling operator
layout: The input layout. The default is "NCW".

◆ adaptive_pool3d()

Tensor tvm::topi::nn::adaptive_pool3d (const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::string &layout="NCDHW")
inline

Adaptively perform pooling on three dimensional data. See the two dimensional version above for details.

Parameters
x: The input tensor
output_size: Vector of three ints: {output_depth, output_height, output_width}
pool_type: The type of pooling operator
layout: The input layout. The default is "NCDHW".

◆ adaptive_pool_impl()

Tensor tvm::topi::nn::adaptive_pool_impl (const Tensor &x, const Array< PrimExpr > &output_size, PoolType pool_type, const std::vector< int > &axes)
inline

Perform adaptive pooling on N dimensional data.

Parameters
x: The input tensor
output_size: Vector of output sizes, one for each pooled dimension
pool_type: The type of pooling operator
axes: Indices of the dimensions to pool over
Returns
The output tensor in the same layout

◆ all()

PrimExpr tvm::topi::nn::all (Array< PrimExpr > args)

Create a new expression of the logical and of all conditions in the arguments.

Parameters
args: The arguments to find the logical conjunction of
Returns
The logical conjunction expression
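
A small sketch of a typical use (the include path, which assumes the declaration ships alongside dilate as in current TVM sources, and the variable names are assumptions): conjoining several bounds checks into one predicate.

    #include <tvm/tir/op.h>
    #include <tvm/tir/var.h>
    #include <tvm/topi/nn/dilate.h>

    tvm::PrimExpr InBoundsExample() {
      using namespace tvm;
      tir::Var i("i"), j("j");
      // Combine several boundary conditions into a single predicate, e.g. to
      // guard a padded or dilated region inside a compute expression.
      return topi::nn::all({i >= 0, i < 32, j >= 0, j < 32});
    }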

◆ bias_add()

tvm::te::Tensor tvm::topi::nn::bias_add (const tvm::te::Tensor &data, const tvm::te::Tensor &bias, int axis)
inline

Creates an operation that calculates data + bias.

Parameters
data: Tensor with shape [batch, in_dim]
bias: Tensor with shape [batch].
axis: The axis to add the bias to.
Returns
Tensor with shape [batch, in_dim]
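
A minimal sketch (shapes and names are illustrative assumptions): adding a 1-D bias along the last axis of a [batch, in_dim] tensor.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/bias_add.h>

    tvm::te::Tensor BiasAddExample() {
      using namespace tvm;
      te::Tensor data = te::placeholder({8, 128}, DataType::Float(32), "data");
      te::Tensor bias = te::placeholder({128}, DataType::Float(32), "bias");
      // bias has the length of the axis it is added along; axis = -1 selects
      // the last (in_dim) axis, and the result is broadcast back to [8, 128].
      return topi::nn::bias_add(data, bias, -1);
    }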

◆ binarize_pack()

tvm::te::Tensor tvm::topi::nn::binarize_pack (const tvm::te::Tensor &data, int axis, std::string name="PackedInput", std::string tag="binarize_pack")
inline

Binarization and bit-packing along a certain axis.

Parameters
data: N-D tensor, can be any layout
axis: The axis along which to do binarization and bit-packing. This axis must have a size equal to an integer multiple of 32.
name: The name of the operation
tag: The tag to mark the operation
Returns
Output tensor with dtype uint32

◆ binary_dense()

tvm::te::Tensor tvm::topi::nn::binary_dense (const tvm::te::Tensor &data, const tvm::te::Tensor &weight)
inline

Binary matrix multiplication using xor and bit-count.

Parameters
data: Tensor with shape [batch, in_dim], dtype is uint32
weight: Tensor with shape [out_dim, in_dim], dtype is uint32
Returns
Tensor with shape [batch, out_dim], dtype is float32
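
A sketch chaining the two binary operators above (shapes, names, and the header path are illustrative assumptions): pack float operands into uint32 bit-planes with binarize_pack, then multiply them with binary_dense.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/bnn.h>

    tvm::te::Tensor BinaryDenseExample() {
      using namespace tvm;
      // in_dim = 256 is a multiple of 32, so it can be bit-packed.
      te::Tensor data = te::placeholder({8, 256}, DataType::Float(32), "data");
      te::Tensor weight = te::placeholder({64, 256}, DataType::Float(32), "weight");
      // Pack along the inner dimension: [8, 256] -> [8, 8] and [64, 256] -> [64, 8], dtype uint32.
      te::Tensor packed_data = topi::nn::binarize_pack(data, 1);
      te::Tensor packed_weight = topi::nn::binarize_pack(weight, 1);
      // XOR + popcount matrix multiplication on the packed operands: [8, 64] float32.
      return topi::nn::binary_dense(packed_data, packed_weight);
    }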

◆ dense()

tvm::te::Tensor tvm::topi::nn::dense (const tvm::te::Tensor &data, const tvm::te::Tensor &weight, const tvm::te::Tensor &bias, const DataType &out_dtype)
inline

Creates an operation that calculates data * weight^T + bias.

Parameters
data: Tensor with shape [batch, in_dim]
weight: Tensor with shape [out_dim, in_dim]
bias: Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
out_dtype: Output data type. Used for mixed precision.
Returns
Tensor with shape [batch, out_dim]
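
A minimal sketch (shapes and names are illustrative assumptions): a dense layer with bias. Passing te::Tensor() instead of a bias tensor omits the bias term.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/dense.h>

    tvm::te::Tensor DenseExample() {
      using namespace tvm;
      te::Tensor data = te::placeholder({8, 128}, DataType::Float(32), "data");
      te::Tensor weight = te::placeholder({64, 128}, DataType::Float(32), "weight");
      te::Tensor bias = te::placeholder({64}, DataType::Float(32), "bias");
      // data * weight^T + bias -> [8, 64] in float32.
      return topi::nn::dense(data, weight, bias, DataType::Float(32));
    }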

◆ dilate()

Tensor tvm::topi::nn::dilate (const Tensor &x, Array< PrimExpr > strides, double dilation_value, std::string name="tensor", std::string tag=kInjective)
inline

Dilate data with given dilation value (0 by default).

Parameters
x: The input tensor; it can have any number of dimensions and any layout.
strides: Dilation stride for each dimension. Stride 1 means no dilation.
dilation_value: Value used to dilate the input.
name: The name of the operation
tag: The tag to mark the operation
Returns
The output tensor.
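
A minimal sketch (shapes are illustrative assumptions): dilating a 3x3 kernel by 2 on its spatial axes yields a 5x5 kernel with the dilation value (0.0) in the inserted positions.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/dilate.h>

    tvm::te::Tensor DilateExample() {
      using namespace tvm;
      // [1, 1, 3, 3] kernel; stride 2 on H and W produces a [1, 1, 5, 5] result.
      te::Tensor w = te::placeholder({1, 1, 3, 3}, DataType::Float(32), "w");
      return topi::nn::dilate(w, {1, 1, 2, 2}, 0.0);
    }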

◆ end_index()

PrimExpr tvm::topi::nn::end_index (const Var &out_index, const PrimExpr &odim, const PrimExpr &idim)
inline

◆ find_depth_height_width()

bool tvm::topi::nn::find_depth_height_width (const std::string &layout, int *depth_axis, int *height_axis, int *width_axis)
inline

Find index of Depth, Height or Width dimension in a layout string.

Parameters
layout: The layout string
depth_axis: Set to the index of the depth ('D') dimension if not nullptr.
height_axis: Set to the index of the height ('H') dimension if not nullptr.
width_axis: Set to the index of the width ('W') dimension if not nullptr.
Returns
true if the layout is valid (i.e., no tiling on D, H or W dimensions, no duplicates and if the requested dimensions are found), otherwise false.

◆ find_height_width()

bool tvm::topi::nn::find_height_width (const std::string &layout, int *height_axis, int *width_axis)
inline

◆ find_width()

bool tvm::topi::nn::find_width (const std::string &layout, int *width_axis)
inline

◆ flatten()

Tensor tvm::topi::nn::flatten (const Tensor &x, std::string name="tensor", std::string tag=kInjective)
inline

Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. This requires the input tensor to have constant sized dimensions.

Parameters
x: The input tensor.
name: The name of the operation
tag: The tag to mark the operation
Returns
A 2-D tensor.
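
A minimal sketch (shapes are illustrative assumptions): collapsing a [1, 64, 7, 7] activation into [1, 3136], e.g. ahead of a dense layer.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/flatten.h>

    tvm::te::Tensor FlattenExample() {
      using namespace tvm;
      // All dimensions after the first are collapsed: [1, 64, 7, 7] -> [1, 64 * 7 * 7].
      te::Tensor data = te::placeholder({1, 64, 7, 7}, DataType::Float(32), "data");
      return topi::nn::flatten(data);
    }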

◆ global_pool()

Tensor tvm::topi::nn::global_pool (const Tensor &x, PoolType pool_type, const std::string &layout="NCHW")
inline

Perform global pooling on height and width dimension of data. It decides the height and width dimension according to the layout string, in which 'W' and 'H' means width and height respectively. Width and height dimension cannot be split. For example, NCHW, NCHW16c, ... are valid for global_pool, while NCHW16w, NCHW16h are not. See layout for more information of the layout string convention.

Parameters
x: The input tensor in the given layout
pool_type: The type of pooling operator
layout: The input layout. Global pooling supports any layout as long as 'H' and 'W' appear. The layout string is composed of upper case letters, lower case letters and (optional) numbers, where an upper case letter indicates a dimension and the corresponding lower case letter (with factor size) indicates the split dimension. For example, NCHW16c describes a 5-D tensor of [batch_size, channel, height, width, channel_block]; the factor size 16 is not used by pooling itself, but other operators can use it to decide the output shape. Since pooling does not care about the factor size of dimensions other than H and W, one can pass NCHWc as well.
Returns
The output tensor in the same layout, with the height and width dimensions reduced to size 1, e.g. for NCHW the output shape will be [batch, channel, 1, 1]
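
A minimal sketch (shapes are illustrative assumptions): global average pooling of a [1, 512, 7, 7] NCHW tensor down to [1, 512, 1, 1].

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/pooling.h>

    tvm::te::Tensor GlobalPoolExample() {
      using namespace tvm;
      // Average over the entire H and W extents of each channel.
      te::Tensor data = te::placeholder({1, 512, 7, 7}, DataType::Float(32), "data");
      return topi::nn::global_pool(data, topi::nn::kAvgPool, "NCHW");
    }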

◆ group_norm()

Tensor tvm::topi::nn::group_norm (const Tensor &data, const Tensor &gamma, const Tensor &beta, int num_groups, int channel_axis, const Array< Integer > &axes, double epsilon, std::string name="T_group_norm", std::string tag=kInjective)
inline

◆ instance_norm()

Tensor tvm::topi::nn::instance_norm (const Tensor &data, const Tensor &gamma, const Tensor &beta, const Array< Integer > &axis, double epsilon, std::string name="T_instance_norm", std::string tag=kInjective)
inline

Instance normalization.

Parameters
data: N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
gamma: K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and d_{axis_k} == r_k
beta: Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where d_{axis_k} == r_k
axis: The axis to normalize over (the axis along which mean and variance are computed).
epsilon: The epsilon value to avoid division by zero.
name: The name of the operation.
tag: The tag to mark the operation.
Returns
The normalized tensor, with the same shape as data.

◆ layer_norm()

Tensor tvm::topi::nn::layer_norm (const Tensor &data, const Tensor &gamma, const Tensor &beta, const Array< Integer > &axis, double epsilon, std::string name="T_layer_norm", std::string tag=kInjective)
inline

Layer normalization.

Parameters
data: N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
gamma: K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and d_{axis_k} == r_k
beta: Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where d_{axis_k} == r_k
axis: The axis to normalize over.
epsilon: The epsilon value to avoid division by zero.
name: The name of the operation.
tag: The tag to mark the operation.
Returns
The normalized tensor, with the same shape as data.
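
A minimal sketch (the shapes, normalized axis, and epsilon are illustrative assumptions): layer normalization over the hidden axis of a [batch, seq, hidden] tensor.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/layer_norm.h>

    tvm::te::Tensor LayerNormExample() {
      using namespace tvm;
      te::Tensor data = te::placeholder({8, 128, 768}, DataType::Float(32), "data");
      // gamma and beta have the shape of the normalized axis (hidden = 768).
      te::Tensor gamma = te::placeholder({768}, DataType::Float(32), "gamma");
      te::Tensor beta = te::placeholder({768}, DataType::Float(32), "beta");
      // Normalize over axis 2 (the hidden dimension) with epsilon = 1e-5.
      return topi::nn::layer_norm(data, gamma, beta, {2}, 1e-5);
    }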

◆ log_softmax()

Tensor tvm::topi::nn::log_softmax (const Tensor &x, std::string name="tensor", std::string tag="log_softmax_output")
inline

Log softmax activation.

Parameters
x: The input tensor. 2-D where log softmax is performed along the second dimension
name: The name of the operation
tag: The tag to mark the operation
Returns
A Tensor whose op member is the log softmax operation
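
A minimal sketch (shapes are illustrative assumptions): log softmax over the class axis of a 2-D logits tensor.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/softmax.h>

    tvm::te::Tensor LogSoftmaxExample() {
      using namespace tvm;
      // 2-D input; log softmax is taken along the second dimension (classes).
      te::Tensor logits = te::placeholder({32, 1000}, DataType::Float(32), "logits");
      return topi::nn::log_softmax(logits);
    }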

◆ lrn()

Tensor tvm::topi::nn::lrn (const Tensor &data, int size, int axis=1, float alpha=0.0001, float beta=0.75, float bias=2, std::string name="tensor", std::string tag=kBroadcast)
inline

Local response normalization inference operator.

Parameters
data: The input tensor. 4-D shape NCHW or NHWC
size: Integer defining the normalization window size
axis: Input data layout channel axis
alpha: Float scaling factor
beta: Exponent value
bias: Offset to avoid dividing by zero
name: The name of the operation
tag: The tag to mark the operation
Returns
A Tensor whose op member is the Local response normalization operation
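
A minimal sketch (shapes and hyper-parameters are illustrative assumptions, roughly AlexNet-style): LRN over a window of 5 channels of an NCHW tensor.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/local_response_norm.h>

    tvm::te::Tensor LrnExample() {
      using namespace tvm;
      // NCHW input; the channel axis is 1, and the window size must be odd.
      te::Tensor data = te::placeholder({1, 96, 55, 55}, DataType::Float(32), "data");
      return topi::nn::lrn(data, /*size=*/5, /*axis=*/1,
                           /*alpha=*/1e-4f, /*beta=*/0.75f, /*bias=*/2.0f);
    }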

◆ pool1d()

Tensor tvm::topi::nn::pool1d (const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCW", bool count_include_pad=true)
inline

Perform pooling on the width dimension of data. Width axis is determined by the layout string in which 'W' means width. Width dimension cannot be split. For example, NCW, NCW16c, etc. are valid for pool, while NCW16w is not. See layout for more information of the layout string convention.

Parameters
x: The input tensor.
kernel_size: Vector of one int: {kernel_width}
stride_size: Vector of one int: {stride_width}
dilation_size: Vector of one int: {dilation_width}
padding_size: Vector of two ints: {head_pad_width, tail_pad_width}
pool_type: The type of pooling operator
ceil_mode: Whether to use ceil when calculating the output size
layout: The input layout. Pooling supports any layout as long as 'W' appears. The layout string is composed of upper case letters, lower case letters and (optional) numbers, where an upper case letter indicates a dimension and the corresponding lower case letter (with factor size) indicates the split dimension. For example, NCW16c describes a 4-D tensor of [batch_size, channel, width, channel_block]; the factor size 16 is not used by pooling itself, but other operators can use it to decide the output shape. Since pooling does not care about the factor size of dimensions other than W, one can pass NCWc as well.
count_include_pad: Whether to include padding in the calculation when pool_type is 'avg'
Returns
The output tensor in the same layout

◆ pool2d()

Tensor tvm::topi::nn::pool2d (const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCHW", bool count_include_pad=true)
inline

Perform pooling on height and width dimension of data. It decides the height and width dimension according to the layout string, in which 'W' and 'H' means width and height respectively. Width and height dimension cannot be split. For example, NCHW, NCHW16c, etc. are valid for pool, while NCHW16w, NCHW16h are not. See layout for more information of the layout string convention.

Parameters
x: The input tensor.
kernel_size: Vector of two ints: {kernel_height, kernel_width}
stride_size: Vector of two ints: {stride_height, stride_width}
dilation_size: Vector of two ints: {dilation_height, dilation_width}
padding_size: Vector of four ints: {head_pad_height, head_pad_width, tail_pad_height, tail_pad_width}
pool_type: The type of pooling operator
ceil_mode: Whether to use ceil when calculating the output size
layout: The input layout. Pooling supports any layout as long as 'H' and 'W' appear. The layout string is composed of upper case letters, lower case letters and (optional) numbers, where an upper case letter indicates a dimension and the corresponding lower case letter (with factor size) indicates the split dimension. For example, NCHW16c describes a 5-D tensor of [batch_size, channel, height, width, channel_block]; the factor size 16 is not used by pooling itself, but other operators can use it to decide the output shape. Since pooling does not care about the factor size of dimensions other than H and W, one can pass NCHWc as well.
count_include_pad: Whether to include padding in the calculation when pool_type is 'avg'
Returns
The output tensor in the same layout
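
A minimal sketch (shapes and names are illustrative assumptions): 2x2 max pooling with stride 2 on an NCHW tensor. The zero padding is written as four head/tail values per pooled axis, following the N-D convention described for pool_impl_nd below.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/pooling.h>

    tvm::te::Tensor Pool2dExample() {
      using namespace tvm;
      // [1, 64, 32, 32] NCHW input pooled down to [1, 64, 16, 16].
      te::Tensor data = te::placeholder({1, 64, 32, 32}, DataType::Float(32), "data");
      return topi::nn::pool2d(data,
                              /*kernel_size=*/{2, 2},
                              /*stride_size=*/{2, 2},
                              /*dilation_size=*/{1, 1},
                              /*padding_size=*/{0, 0, 0, 0},  // head then tail pads for H and W
                              topi::nn::kMaxPool,
                              /*ceil_mode=*/false,
                              "NCHW",
                              /*count_include_pad=*/true);
    }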

◆ pool3d()

Tensor tvm::topi::nn::pool3d (const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCDHW", bool count_include_pad=true)
inline

Perform pooling on depth, height and width dimension of data. It decides the depth, height and width dimension according to the layout string, in which 'D', 'W' and 'H' means depth, width and height respectively. Depth, Width and height dimension cannot be split. For example, NCDHW, NCDHW16c, etc. are valid for pool, while NCDHW16d, NCDHW16w or NCDHW16h are not. See layout for more information of the layout string convention.

Parameters
x: The input tensor.
kernel_size: Vector of three ints: {kernel_depth, kernel_height, kernel_width}
stride_size: Vector of three ints: {stride_depth, stride_height, stride_width}
dilation_size: Vector of three ints: {dilation_depth, dilation_height, dilation_width}
padding_size: Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width, tail_pad_depth, tail_pad_height, tail_pad_width}
pool_type: The type of pooling operator
ceil_mode: Whether to use ceil when calculating the output size
layout: The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear. The layout string is composed of upper case letters, lower case letters and (optional) numbers, where an upper case letter indicates a dimension and the corresponding lower case letter (with factor size) indicates the split dimension. For example, NCDHW16c describes a 6-D tensor of [batch_size, channel, depth, height, width, channel_block]; the factor size 16 is not used by pooling itself, but other operators can use it to decide the output shape. Since pooling does not care about the factor size of dimensions other than D, H and W, one can pass NCDHWc as well.
count_include_pad: Whether to include padding in the calculation when pool_type is 'avg'
Returns
The output tensor in the same layout

◆ pool_grad()

Tensor tvm::topi::nn::pool_grad (const Tensor &out_grad, const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::string &layout="NCHW", bool count_include_pad=true)
inline

Calculate gradient of pooling on height and width dimension of data. It decides the height and width dimension according to the layout string, in which 'W' and 'H' means width and height respectively. Width and height dimension cannot be split. For example, NCHW, NCHW16c, etc. are valid for pool, while NCHW16w, NCHW16h are not. See layout for more information of the layout string convention.

Parameters
out_grad: The output gradient tensor.
x: The input tensor.
kernel_size: Vector of two ints: {kernel_height, kernel_width}
stride_size: Vector of two ints: {stride_height, stride_width}
padding_size: Vector of four ints: {head_pad_height, head_pad_width, tail_pad_height, tail_pad_width}
pool_type: The type of pooling operator
ceil_mode: Whether to use ceil when calculating the output size
layout: The input layout. Pooling supports any layout as long as 'H' and 'W' appear. The layout string is composed of upper case letters, lower case letters and (optional) numbers, where an upper case letter indicates a dimension and the corresponding lower case letter (with factor size) indicates the split dimension. For example, NCHW16c describes a 5-D tensor of [batch_size, channel, height, width, channel_block]; the factor size 16 is not used by pooling itself, but other operators can use it to decide the output shape. Since pooling does not care about the factor size of dimensions other than H and W, one can pass NCHWc as well.
count_include_pad: Whether to include padding in the calculation when pool_type is 'avg'
Returns
The output tensor in the same layout

◆ pool_grad_impl()

Tensor tvm::topi::nn::pool_grad_impl (const Tensor &out_grad, const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad)
inline

◆ pool_impl_nd()

Tensor tvm::topi::nn::pool_impl_nd (const Tensor &x, const Array< PrimExpr > &kernel_size, const Array< PrimExpr > &stride_size, const Array< PrimExpr > &dilation_size, const Array< PrimExpr > &padding_size, PoolType pool_type, bool ceil_mode, const std::vector< int > &axis, bool count_include_pad)
inline

Perform pooling on N-dimension of data.

Parameters
x: The input tensor
kernel_size: Vector of N ints
stride_size: Vector of N ints
dilation_size: Vector of N ints
padding_size: Vector of N*2 ints: [head_pad_d1, head_pad_d2, ..., head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN]
pool_type: The type of pooling operator
ceil_mode: Whether to use ceil when calculating the output size
axis: Vector of indices for the N pooled dimensions
count_include_pad: Whether to include padding in the calculation
Returns
The output tensor in the same layout

◆ rms_norm()

Tensor tvm::topi::nn::rms_norm (const Tensor &data, const Tensor &weight, const Array< Integer > &axis, double epsilon, std::string name="T_rms_norm", std::string tag=kInjective)
inline

Root mean square normalization.

Parameters
data: N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
weight: K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and d_{axis_k} == r_k
axis: The axis to normalize over.
epsilon: The epsilon value to avoid division by zero.
name: The name of the operation.
tag: The tag to mark the operation.
Returns
The normalized tensor, with the same shape as data.

◆ scale_shift_nchw()

Tensor tvm::topi::nn::scale_shift_nchw (const Tensor &x, const Tensor &scale, const Tensor &shift, std::string name="ScaleShift", std::string tag=kBroadcast)
inline

Scale and shift with NCHW order.

Parameters
x: The input tensor.
scale: Scale tensor, 1-D of size channel
shift: Shift tensor, 1-D of size channel
name: The name of the operation
tag: The tag to mark the operation
Returns
A Tensor whose op member is the scale shift operation
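
A minimal sketch (shapes are illustrative assumptions): per-channel scale and shift of an NCHW tensor, e.g. a folded batch norm.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/mapping.h>

    tvm::te::Tensor ScaleShiftExample() {
      using namespace tvm;
      // scale and shift are 1-D with one value per channel (C = 64).
      te::Tensor x = te::placeholder({1, 64, 32, 32}, DataType::Float(32), "x");
      te::Tensor scale = te::placeholder({64}, DataType::Float(32), "scale");
      te::Tensor shift = te::placeholder({64}, DataType::Float(32), "shift");
      return topi::nn::scale_shift_nchw(x, scale, shift);
    }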

◆ scale_shift_nhwc()

Tensor tvm::topi::nn::scale_shift_nhwc (const Tensor &x, const Tensor &scale, const Tensor &shift, std::string name="ScaleShift", std::string tag=kBroadcast)
inline

Scale and shift with NHWC order.

Parameters
x: The input tensor.
scale: Scale tensor, 1-D of size channel
shift: Shift tensor, 1-D of size channel
name: The name of the operation
tag: The tag to mark the operation
Returns
A Tensor whose op member is the scale shift operation

◆ softmax()

Tensor tvm::topi::nn::softmax (const Tensor &x, int axis=-1, std::string name="tensor", std::string tag="softmax_output")
inline

Softmax activation.

Parameters
x: The input tensor. Can be any dimension
axis: The channel axis along which softmax is performed
name: The name of the operation
tag: The tag to mark the operation
Returns
A Tensor whose op member is the softmax operation
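
A minimal sketch (shapes are illustrative assumptions): softmax over the last axis of a [batch, classes] tensor.

    #include <tvm/te/operation.h>
    #include <tvm/topi/nn/softmax.h>

    tvm::te::Tensor SoftmaxExample() {
      using namespace tvm;
      // axis = -1 selects the last dimension (the class axis here).
      te::Tensor logits = te::placeholder({32, 1000}, DataType::Float(32), "logits");
      return topi::nn::softmax(logits, /*axis=*/-1);
    }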

◆ start_index()

PrimExpr tvm::topi::nn::start_index (const Var &out_index, const PrimExpr &odim, const PrimExpr &idim)
inline