/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relax/attrs/nn.h
 * \brief Attributes for neural network operators.
 */
#ifndef TVM_RELAX_ATTRS_NN_H_
#define TVM_RELAX_ATTRS_NN_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in Conv1d operator */
struct Conv1DAttrs : public AttrsNodeReflAdapter<Conv1DAttrs> {
  ffi::Array<IntImm> strides;
  ffi::Array<IntImm> padding;
  ffi::Array<IntImm> dilation;
  int groups;
  ffi::String data_layout;
  ffi::String kernel_layout;
  ffi::String out_layout;
  DataType out_dtype;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<Conv1DAttrs>()
        .def_ro("strides", &Conv1DAttrs::strides, "Specifies the strides of the convolution.")
        .def_ro("padding", &Conv1DAttrs::padding,
                "If padding is non-zero, then the input is implicitly zero-padded. "
                "Padding supports both symmetric and asymmetric modes: "
                "one int : same padding used on both sides; "
                "two ints : padding widths in the order (left, right).")
        .def_ro("dilation", &Conv1DAttrs::dilation,
                "Specifies the dilation rate to use for dilated convolution.")
        .def_ro("groups", &Conv1DAttrs::groups,
                "Number of groups to split the input into for grouped convolution. The number of "
                "input and output channels must each be divisible by the number of groups.")
        .def_ro("data_layout", &Conv1DAttrs::data_layout,
                "Dimension ordering of input data. Can be 'NCW', 'NWC', etc. "
                "'N', 'C', and 'W' stand for the batch, channel, and width "
                "dimensions respectively. Convolution is applied on the 'W' dimension.")
        .def_ro("kernel_layout", &Conv1DAttrs::kernel_layout,
                "Dimension ordering of weight. Can be 'OIW', 'IOW', etc. "
                "'O', 'I', and 'W' stand for the num_filter, input_channel, and width "
                "dimensions respectively.")
        .def_ro("out_layout", &Conv1DAttrs::out_layout,
                "Dimension ordering of output. Can be 'NCW', 'NWC', etc. "
                "'N', 'C', and 'W' stand for the batch, channel, and width "
                "dimensions respectively. Defaults to the same layout as the input.")
        .def_ro("out_dtype", &Conv1DAttrs::out_dtype,
                "Output data type; set to an explicit type under a mixed-precision setting.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DAttrs", Conv1DAttrs, BaseAttrsNode);
};  // struct Conv1DAttrs
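
/*!
 * Usage note (a minimal sketch, not part of this header): RegisterReflection()
 * only describes the fields; it still has to be invoked once at load time. In
 * TVM this is conventionally done from the translation unit that implements the
 * operator, e.g. via a static initialization block (pattern assumed here):
 * \code
 * TVM_FFI_STATIC_INIT_BLOCK({ Conv1DAttrs::RegisterReflection(); });
 * \endcode
 */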

/*! \brief Attributes used in Conv2d operator */
struct Conv2DAttrs : public AttrsNodeReflAdapter<Conv2DAttrs> {
  ffi::Array<IntImm> strides;
  ffi::Array<IntImm> padding;
  ffi::Array<IntImm> dilation;
  int groups;
  ffi::String data_layout;
  ffi::String kernel_layout;
  ffi::String out_layout;
  DataType out_dtype;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<Conv2DAttrs>()
        .def_ro("strides", &Conv2DAttrs::strides, "Specifies the strides of the convolution.")
        .def_ro("padding", &Conv2DAttrs::padding,
                "If padding is non-zero, then the input is implicitly zero-padded. "
                "Padding supports both symmetric and asymmetric modes: "
                "one int : same padding used on all sides; "
                "two ints : bottom and right use the same padding as top and left; "
                "four ints : padding widths in the order (top, left, bottom, right).")
        .def_ro("dilation", &Conv2DAttrs::dilation,
                "Specifies the dilation rate to use for dilated convolution.")
        .def_ro("groups", &Conv2DAttrs::groups,
                "Number of groups to split the input into for grouped convolution. The number of "
                "input and output channels must each be divisible by the number of groups.")
        .def_ro("data_layout", &Conv2DAttrs::data_layout,
                "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', and 'W' stand for the batch, channel, height, and width "
                "dimensions respectively. Convolution is applied on the 'H' and 'W' dimensions.")
        .def_ro("kernel_layout", &Conv2DAttrs::kernel_layout,
                "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc. "
                "'O', 'I', 'H', and 'W' stand for the num_filter, input_channel, height, and "
                "width dimensions respectively.")
        .def_ro("out_layout", &Conv2DAttrs::out_layout,
                "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', and 'W' stand for the batch, channel, height, and width "
                "dimensions respectively. Defaults to the same layout as the input.")
        .def_ro("out_dtype", &Conv2DAttrs::out_dtype,
                "Output data type; set to an explicit type under a mixed-precision setting.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DAttrs", Conv2DAttrs, BaseAttrsNode);
};  // struct Conv2DAttrs
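
/*!
 * Example (a hedged sketch of filling these attributes from C++; the padding
 * interpretation follows the field documentation above, and the construction
 * via make_object is the usual TVM object pattern, assumed here):
 * \code
 * ObjectPtr<Conv2DAttrs> attrs = make_object<Conv2DAttrs>();
 * attrs->strides = {IntImm(DataType::Int(64), 1), IntImm(DataType::Int(64), 1)};
 * // Two ints {1, 2} mean top = bottom = 1 and left = right = 2; four ints are
 * // read as (top, left, bottom, right).
 * attrs->padding = {IntImm(DataType::Int(64), 1), IntImm(DataType::Int(64), 2)};
 * attrs->groups = 1;
 * attrs->data_layout = "NCHW";
 * attrs->kernel_layout = "OIHW";
 * \endcode
 */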

/*! \brief Attributes used in Conv3d operator */
struct Conv3DAttrs : public AttrsNodeReflAdapter<Conv3DAttrs> {
  ffi::Array<IntImm> strides;
  ffi::Array<IntImm> padding;
  ffi::Array<IntImm> dilation;
  int groups;
  ffi::String data_layout;
  ffi::String kernel_layout;
  ffi::String out_layout;
  DataType out_dtype;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<Conv3DAttrs>()
        .def_ro("strides", &Conv3DAttrs::strides, "Specifies the strides of the convolution.")
        .def_ro("padding", &Conv3DAttrs::padding,
                "If padding is non-zero, then the input is implicitly zero-padded. "
                "Padding supports both symmetric and asymmetric modes: "
                "one int : same padding used on all sides; "
                "three ints : back, bottom, and right use the same padding as front, top, and "
                "left; "
                "six ints : padding widths in the order (front, top, left, back, bottom, right).")
        .def_ro("dilation", &Conv3DAttrs::dilation,
                "Specifies the dilation rate to use for dilated convolution.")
        .def_ro("groups", &Conv3DAttrs::groups,
                "Number of groups to split the input into for grouped convolution. The number of "
                "input and output channels must each be divisible by the number of groups.")
        .def_ro("data_layout", &Conv3DAttrs::data_layout,
                "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc. "
                "'N', 'C', 'D', 'H', and 'W' stand for the batch, channel, depth, height, and "
                "width dimensions respectively. Convolution is applied on the 'D', 'H', and 'W' "
                "dimensions.")
        .def_ro("kernel_layout", &Conv3DAttrs::kernel_layout,
                "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc. "
                "'O', 'I', 'D', 'H', and 'W' stand for the num_filter, input_channel, depth, "
                "height, and width dimensions respectively.")
        .def_ro("out_layout", &Conv3DAttrs::out_layout,
                "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc. "
                "'N', 'C', 'D', 'H', and 'W' stand for the batch, channel, depth, height, and "
                "width dimensions respectively. Defaults to the same layout as the input.")
        .def_ro("out_dtype", &Conv3DAttrs::out_dtype,
                "Output data type; set to an explicit type under a mixed-precision setting.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DAttrs", Conv3DAttrs, BaseAttrsNode);
};  // struct Conv3DAttrs

/*! \brief Attributes used in Conv1DTranspose operator */
struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter<Conv1DTransposeAttrs> {
  ffi::Array<IntImm> strides;
  ffi::Array<IntImm> padding;
  ffi::Array<IntImm> output_padding;
  ffi::Array<IntImm> dilation;
  int groups;
  ffi::String data_layout;
  ffi::String kernel_layout;
  ffi::String out_layout;
  DataType out_dtype;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<Conv1DTransposeAttrs>()
        .def_ro("strides", &Conv1DTransposeAttrs::strides,
                "Specifies the strides of the convolution.")
        .def_ro("padding", &Conv1DTransposeAttrs::padding,
                "If padding is non-zero, then the input is implicitly zero-padded. "
                "Padding supports both symmetric and asymmetric modes: "
                "one int : same padding used on both sides; "
                "two ints : padding widths in the order (left, right).")
        .def_ro("output_padding", &Conv1DTransposeAttrs::output_padding,
                "Used to disambiguate the output shape.")
        .def_ro("dilation", &Conv1DTransposeAttrs::dilation,
                "Specifies the dilation rate to use for dilated convolution.")
        .def_ro("groups", &Conv1DTransposeAttrs::groups,
                "Number of groups to split the input into for grouped convolution. The number of "
                "input and output channels must each be divisible by the number of groups.")
        .def_ro("data_layout", &Conv1DTransposeAttrs::data_layout,
                "Dimension ordering of input data. Can be 'NCW', 'NWC', etc. "
                "'N', 'C', and 'W' stand for the batch, channel, and width "
                "dimensions respectively. Convolution is applied on the 'W' dimension.")
        .def_ro("kernel_layout", &Conv1DTransposeAttrs::kernel_layout,
                "Dimension ordering of weight. Can be 'OIW', 'IOW', etc. "
                "'O', 'I', and 'W' stand for the num_filter, input_channel, and width "
                "dimensions respectively.")
        .def_ro("out_layout", &Conv1DTransposeAttrs::out_layout,
                "Dimension ordering of output. Can be 'NCW', 'NWC', etc. "
                "'N', 'C', and 'W' stand for the batch, channel, and width "
                "dimensions respectively. Defaults to the same layout as the input.")
        .def_ro("out_dtype", &Conv1DTransposeAttrs::out_dtype,
                "Output data type; set to an explicit type under a mixed-precision setting.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv1DTransposeAttrs", Conv1DTransposeAttrs,
                                    BaseAttrsNode);
};  // struct Conv1DTransposeAttrs
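
/*!
 * Note on output_padding (a sketch of the usual transposed-convolution
 * arithmetic, stated here for orientation; the op's own shape inference is
 * authoritative):
 * \code
 * // out_w = (in_w - 1) * stride - (pad_left + pad_right)
 * //         + dilation * (kernel_w - 1) + output_padding + 1
 * \endcode
 * Several input widths can map to the same forward-convolution output width,
 * so output_padding selects among the otherwise ambiguous inverse shapes.
 */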

/*! \brief Attributes used in Conv2DTranspose operator */
struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter<Conv2DTransposeAttrs> {
  ffi::Array<IntImm> strides;
  ffi::Array<IntImm> padding;
  ffi::Array<IntImm> output_padding;
  ffi::Array<IntImm> dilation;
  int groups;
  ffi::String data_layout;
  ffi::String kernel_layout;
  ffi::String out_layout;
  DataType out_dtype;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<Conv2DTransposeAttrs>()
        .def_ro("strides", &Conv2DTransposeAttrs::strides,
                "Specifies the strides of the convolution.")
        .def_ro("padding", &Conv2DTransposeAttrs::padding,
                "If padding is non-zero, then the input is implicitly zero-padded. "
                "Padding supports both symmetric and asymmetric modes: "
                "one int : same padding used on all sides; "
                "two ints : bottom and right use the same padding as top and left; "
                "four ints : padding widths in the order (top, left, bottom, right).")
        .def_ro("output_padding", &Conv2DTransposeAttrs::output_padding,
                "Used to disambiguate the output shape.")
        .def_ro("dilation", &Conv2DTransposeAttrs::dilation,
                "Specifies the dilation rate to use for dilated convolution.")
        .def_ro("groups", &Conv2DTransposeAttrs::groups,
                "Number of groups to split the input into for grouped convolution. The number of "
                "input and output channels must each be divisible by the number of groups.")
        .def_ro("data_layout", &Conv2DTransposeAttrs::data_layout,
                "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', and 'W' stand for the batch, channel, height, and width "
                "dimensions respectively. Convolution is applied on the 'H' and 'W' dimensions.")
        .def_ro("kernel_layout", &Conv2DTransposeAttrs::kernel_layout,
                "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc. "
                "'O', 'I', 'H', and 'W' stand for the num_filter, input_channel, height, and "
                "width dimensions respectively.")
        .def_ro("out_layout", &Conv2DTransposeAttrs::out_layout,
                "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', and 'W' stand for the batch, channel, height, and width "
                "dimensions respectively. Defaults to the same layout as the input.")
        .def_ro("out_dtype", &Conv2DTransposeAttrs::out_dtype,
                "Output data type; set to an explicit type under a mixed-precision setting.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv2DTransposeAttrs", Conv2DTransposeAttrs,
                                    BaseAttrsNode);
};  // struct Conv2DTransposeAttrs

/*! \brief Attributes used in max_pool1d and avg_pool1d operator */
struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
  ffi::Array<IntImm> pool_size;
  ffi::Array<IntImm> strides;
  ffi::Array<IntImm> padding;
  ffi::Array<IntImm> dilation;
  bool ceil_mode;
  bool count_include_pad;
  ffi::String layout;
  ffi::String out_layout;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<Pool1DAttrs>()
        .def_ro("pool_size", &Pool1DAttrs::pool_size, "Size of the pooling windows.")
        .def_ro("strides", &Pool1DAttrs::strides, "Specifies the strides of the pooling.")
        .def_ro("dilation", &Pool1DAttrs::dilation, "Specifies the dilation of the pooling.")
        .def_ro("padding", &Pool1DAttrs::padding,
                "If padding is non-zero, then the input is implicitly zero-padded. "
                "Padding supports both symmetric and asymmetric modes: "
                "one int : same padding used on both sides; "
                "two ints : padding widths in the order (left, right).")
        .def_ro("ceil_mode", &Pool1DAttrs::ceil_mode,
                "A boolean indicating whether to use ceil or floor to compute the output shape. "
                "With ceil, every element in the input tensor is covered by a sliding window.")
        .def_ro("count_include_pad", &Pool1DAttrs::count_include_pad,
                "When true, padding is included when computing the average.")
        .def_ro("layout", &Pool1DAttrs::layout,
                "Dimension ordering of input data. Can be 'NCW', 'NWC', etc. "
                "'N', 'C', and 'W' stand for the batch, channel, and width "
                "dimensions respectively. Pooling is applied on the 'W' dimension.",
                refl::DefaultValue("NCW"))
        .def_ro("out_layout", &Pool1DAttrs::out_layout,
                "Dimension ordering of output data. Can be 'NCW', 'NWC', etc. "
                "'N', 'C', and 'W' stand for the batch, channel, and width "
                "dimensions respectively. Pooling is applied on the 'W' dimension.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool1DAttrs", Pool1DAttrs, BaseAttrsNode);
};  // struct Pool1DAttrs

/*! \brief Attributes used in max_pool2d and avg_pool2d operator */
struct Pool2DAttrs : public AttrsNodeReflAdapter<Pool2DAttrs> {
  ffi::Array<IntImm> pool_size;
  ffi::Array<IntImm> strides;
  ffi::Array<IntImm> padding;
  ffi::Array<IntImm> dilation;
  bool ceil_mode;
  bool count_include_pad;
  ffi::String layout;
  ffi::String out_layout;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<Pool2DAttrs>()
        .def_ro("pool_size", &Pool2DAttrs::pool_size, "Size of the pooling windows.")
        .def_ro("strides", &Pool2DAttrs::strides, "Specifies the strides of the pooling.")
        .def_ro("dilation", &Pool2DAttrs::dilation, "Specifies the dilation of the pooling.")
        .def_ro("padding", &Pool2DAttrs::padding,
                "If padding is non-zero, then the input is implicitly zero-padded. "
                "Padding supports both symmetric and asymmetric modes: "
                "one int : same padding used on all sides; "
                "two ints : bottom and right use the same padding as top and left; "
                "four ints : padding widths in the order (top, left, bottom, right).")
        .def_ro("ceil_mode", &Pool2DAttrs::ceil_mode,
                "A boolean indicating whether to use ceil or floor to compute the output shape. "
                "With ceil, every element in the input tensor is covered by a sliding window.")
        .def_ro("count_include_pad", &Pool2DAttrs::count_include_pad,
                "When true, padding is included when computing the average.")
        .def_ro("layout", &Pool2DAttrs::layout,
                "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', and 'W' stand for the batch, channel, height, and width "
                "dimensions respectively. Pooling is applied on the 'H' and 'W' dimensions.")
        .def_ro("out_layout", &Pool2DAttrs::out_layout,
                "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', and 'W' stand for the batch, channel, height, and width "
                "dimensions respectively. Pooling is applied on the 'H' and 'W' dimensions.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool2DAttrs", Pool2DAttrs, BaseAttrsNode);
};  // struct Pool2DAttrs
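
/*!
 * Note on ceil_mode (the standard pooling output-size arithmetic, given here
 * as a hedged reference; the op's shape inference is authoritative):
 * \code
 * // effective_window = dilation * (pool_size - 1) + 1
 * // out = floor((in + pad_total - effective_window) / stride) + 1  // ceil_mode = false
 * // out = ceil((in + pad_total - effective_window) / stride) + 1   // ceil_mode = true
 * \endcode
 */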

/*! \brief Attributes used in max_pool3d and avg_pool3d operator */
struct Pool3DAttrs : public AttrsNodeReflAdapter<Pool3DAttrs> {
  ffi::Array<IntImm> pool_size;
  ffi::Array<IntImm> strides;
  ffi::Array<IntImm> padding;
  ffi::Array<IntImm> dilation;
  bool ceil_mode;
  bool count_include_pad;
  ffi::String layout;
  ffi::String out_layout;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<Pool3DAttrs>()
        .def_ro("pool_size", &Pool3DAttrs::pool_size, "Size of the pooling windows.")
        .def_ro("strides", &Pool3DAttrs::strides, "Specifies the strides of the pooling.")
        .def_ro("dilation", &Pool3DAttrs::dilation, "Specifies the dilation of the pooling.")
        .def_ro("padding", &Pool3DAttrs::padding,
                "If padding is non-zero, then the input is implicitly zero-padded. "
                "Padding supports both symmetric and asymmetric modes: "
                "one int : same padding used on all sides; "
                "three ints : back, bottom, and right use the same padding as front, top, and "
                "left; "
                "six ints : padding widths in the order (front, top, left, back, bottom, right).")
        .def_ro("ceil_mode", &Pool3DAttrs::ceil_mode,
                "A boolean indicating whether to use ceil or floor to compute the output shape. "
                "With ceil, every element in the input tensor is covered by a sliding window.")
        .def_ro("count_include_pad", &Pool3DAttrs::count_include_pad,
                "When true, padding is included when computing the average.")
        .def_ro("layout", &Pool3DAttrs::layout,
                "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc. "
                "'N', 'C', 'D', 'H', and 'W' stand for the batch, channel, depth, height, and "
                "width dimensions respectively. Pooling is applied on the 'D', 'H', and 'W' "
                "dimensions.")
        .def_ro("out_layout", &Pool3DAttrs::out_layout,
                "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc. "
                "'N', 'C', 'D', 'H', and 'W' stand for the batch, channel, depth, height, and "
                "width dimensions respectively. Pooling is applied on the 'D', 'H', and 'W' "
                "dimensions.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Pool3DAttrs", Pool3DAttrs, BaseAttrsNode);
};  // struct Pool3DAttrs

/*! \brief Attributes for 1d adaptive pool operator */
struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter<AdaptivePool1DAttrs> {
  ffi::Optional<ffi::Array<IntImm>> output_size;
  ffi::String layout;
  ffi::String out_layout;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<AdaptivePool1DAttrs>()
        .def_ro("output_size", &AdaptivePool1DAttrs::output_size, "Output width.")
        .def_ro("layout", &AdaptivePool1DAttrs::layout,
                "Dimension ordering of input data. Can be 'NCW', 'NWC', etc. "
                "'N', 'C', and 'W' stand for the batch, channel, and width "
                "dimensions respectively. Pooling is applied on the 'W' dimension.")
        .def_ro("out_layout", &AdaptivePool1DAttrs::out_layout,
                "Dimension ordering of output data. Can be 'NCW', 'NWC', etc. "
                "'N', 'C', and 'W' stand for the batch, channel, and width "
                "dimensions respectively. Pooling is applied on the 'W' dimension.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool1DAttrs", AdaptivePool1DAttrs,
                                    BaseAttrsNode);
};  // struct AdaptivePool1DAttrs

/*! \brief Attributes for 2d adaptive pool operator */
struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter<AdaptivePool2DAttrs> {
  ffi::Optional<ffi::Array<IntImm>> output_size;
  ffi::String layout;
  ffi::String out_layout;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<AdaptivePool2DAttrs>()
        .def_ro("output_size", &AdaptivePool2DAttrs::output_size, "Output height and width.")
        .def_ro("layout", &AdaptivePool2DAttrs::layout,
                "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', and 'W' stand for the batch, channel, height, and width "
                "dimensions respectively. Pooling is applied on the 'H' and 'W' dimensions.")
        .def_ro("out_layout", &AdaptivePool2DAttrs::out_layout,
                "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', and 'W' stand for the batch, channel, height, and width "
                "dimensions respectively. Pooling is applied on the 'H' and 'W' dimensions.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool2DAttrs", AdaptivePool2DAttrs,
                                    BaseAttrsNode);
};  // struct AdaptivePool2DAttrs

/*! \brief Attributes for 3d adaptive pool operator */
struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter<AdaptivePool3DAttrs> {
  ffi::Optional<ffi::Array<IntImm>> output_size;
  ffi::String layout;
  ffi::String out_layout;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<AdaptivePool3DAttrs>()
        .def_ro("output_size", &AdaptivePool3DAttrs::output_size,
                "Output depth, height, and width.")
        .def_ro("layout", &AdaptivePool3DAttrs::layout,
                "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc. "
                "'N', 'C', 'D', 'H', and 'W' stand for the batch, channel, depth, height, and "
                "width dimensions respectively. Pooling is applied on the 'D', 'H', and 'W' "
                "dimensions.")
        .def_ro("out_layout", &AdaptivePool3DAttrs::out_layout,
                "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc. "
                "'N', 'C', 'D', 'H', and 'W' stand for the batch, channel, depth, height, and "
                "width dimensions respectively. Pooling is applied on the 'D', 'H', and 'W' "
                "dimensions.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AdaptivePool3DAttrs", AdaptivePool3DAttrs,
                                    BaseAttrsNode);
};  // struct AdaptivePool3DAttrs

/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public AttrsNodeReflAdapter<SoftmaxAttrs> {
  int axis;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<SoftmaxAttrs>().def_ro("axis", &SoftmaxAttrs::axis,
                                           "The axis to sum over when computing softmax.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftmaxAttrs", SoftmaxAttrs, BaseAttrsNode);
};

/*! \brief Attributes used in leaky_relu operator */
struct LeakyReluAttrs : public AttrsNodeReflAdapter<LeakyReluAttrs> {
  double alpha;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<LeakyReluAttrs>().def_ro("alpha", &LeakyReluAttrs::alpha,
                                             "The slope of the negative part.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LeakyReluAttrs", LeakyReluAttrs, BaseAttrsNode);
};

/*! \brief Attributes used in softplus operators */
struct SoftplusAttrs : public AttrsNodeReflAdapter<SoftplusAttrs> {
  double beta;
  double threshold;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<SoftplusAttrs>()
        .def_ro("beta", &SoftplusAttrs::beta,
                "Scaling factor controlling the sharpness of the Softplus transition.")
        .def_ro("threshold", &SoftplusAttrs::threshold,
                "Value determining when to use the linear approximation for numerical stability.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SoftplusAttrs", SoftplusAttrs, BaseAttrsNode);
};

/*! \brief Attributes used in PReLU operator */
struct PReluAttrs : public AttrsNodeReflAdapter<PReluAttrs> {
  int axis;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<PReluAttrs>().def_ro("axis", &PReluAttrs::axis,
                                         "The axis along which the alpha values are applied.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PReluAttrs", PReluAttrs, BaseAttrsNode);
};

/*! \brief Attributes used in batch_norm operator */
struct BatchNormAttrs : public AttrsNodeReflAdapter<BatchNormAttrs> {
  int axis;
  double epsilon;
  bool center;
  bool scale;
  double momentum;
  bool training;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<BatchNormAttrs>()
        .def_ro("axis", &BatchNormAttrs::axis,
                "The axis along which the normalization is applied.")
        .def_ro("epsilon", &BatchNormAttrs::epsilon,
                "Small float added to the variance to avoid dividing by zero.")
        .def_ro("center", &BatchNormAttrs::center,
                "Indicates whether the beta offset is added to the normalized tensor.")
        .def_ro("scale", &BatchNormAttrs::scale,
                "Indicates whether the normalized tensor is multiplied by the gamma scale.")
        .def_ro("momentum", &BatchNormAttrs::momentum,
                "The value used for the moving_mean and moving_var update.")
        .def_ro("training", &BatchNormAttrs::training,
                "Whether we are training (i.e., not in eval mode).");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.BatchNormAttrs", BatchNormAttrs, BaseAttrsNode);
};  // struct BatchNormAttrs
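
/*!
 * Note on momentum (a hedged sketch of one common moving-statistics update
 * convention; the exact convention is fixed by the op's implementation, not
 * by this struct):
 * \code
 * // moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
 * // moving_var  = moving_var  * momentum + batch_var  * (1 - momentum)
 * \endcode
 */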

/*! \brief Attributes used in layer_norm operator */
struct LayerNormAttrs : public AttrsNodeReflAdapter<LayerNormAttrs> {
  ffi::Array<Integer> axes;
  double epsilon;
  bool center;
  bool scale;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<LayerNormAttrs>()
        .def_ro("axes", &LayerNormAttrs::axes,
                "The axes along which the normalization is applied.")
        .def_ro("epsilon", &LayerNormAttrs::epsilon,
                "Small float added to the variance to avoid dividing by zero.")
        .def_ro("center", &LayerNormAttrs::center,
                "Indicates whether the beta offset is added to the normalized tensor.")
        .def_ro("scale", &LayerNormAttrs::scale,
                "Indicates whether the normalized tensor is multiplied by the gamma scale.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayerNormAttrs", LayerNormAttrs, BaseAttrsNode);
};  // struct LayerNormAttrs

/*! \brief Attributes used in group_norm operator */
struct GroupNormAttrs : public AttrsNodeReflAdapter<GroupNormAttrs> {
  int num_groups;
  int channel_axis;
  ffi::Array<Integer> axes;
  double epsilon;
  bool center;
  bool scale;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<GroupNormAttrs>()
        .def_ro("num_groups", &GroupNormAttrs::num_groups,
                "The number of groups to separate the channels into.")
        .def_ro("channel_axis", &GroupNormAttrs::channel_axis,
                "The axis that represents the channel.")
        .def_ro("axes", &GroupNormAttrs::axes,
                "The axes along which the normalization is applied (excluding the channel axis).")
        .def_ro("epsilon", &GroupNormAttrs::epsilon,
                "Small float added to the variance to avoid dividing by zero.")
        .def_ro("center", &GroupNormAttrs::center,
                "Indicates whether the beta offset is added to the normalized tensor.")
        .def_ro("scale", &GroupNormAttrs::scale,
                "Indicates whether the normalized tensor is multiplied by the gamma scale.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GroupNormAttrs", GroupNormAttrs, BaseAttrsNode);
};  // struct GroupNormAttrs

/*! \brief Attributes used in instance_norm operator */
struct InstanceNormAttrs : public AttrsNodeReflAdapter<InstanceNormAttrs> {
  int channel_axis;
  ffi::Array<Integer> axes;
  double epsilon;
  bool center;
  bool scale;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<InstanceNormAttrs>()
        .def_ro("channel_axis", &InstanceNormAttrs::channel_axis,
                "The axis that represents the channel.")
        .def_ro("axes", &InstanceNormAttrs::axes,
                "The axes along which the normalization is applied.")
        .def_ro("epsilon", &InstanceNormAttrs::epsilon,
                "Small float added to the variance to avoid dividing by zero.")
        .def_ro("center", &InstanceNormAttrs::center,
                "Indicates whether the beta offset is added to the normalized tensor.")
        .def_ro("scale", &InstanceNormAttrs::scale,
                "Indicates whether the normalized tensor is multiplied by the gamma scale.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.InstanceNormAttrs", InstanceNormAttrs,
                                    BaseAttrsNode);
};  // struct InstanceNormAttrs

/*! \brief Attributes used in rms_norm operator */
struct RMSNormAttrs : public AttrsNodeReflAdapter<RMSNormAttrs> {
  ffi::Array<Integer> axes;
  double epsilon;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<RMSNormAttrs>()
        .def_ro("axes", &RMSNormAttrs::axes,
                "The axes along which the normalization is applied.")
        .def_ro("epsilon", &RMSNormAttrs::epsilon,
                "Small float added to the variance to avoid dividing by zero.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RMSNormAttrs", RMSNormAttrs, BaseAttrsNode);
};  // struct RMSNormAttrs

/*! \brief Attributes used in nll_loss operator */
struct NLLLossAttrs : public AttrsNodeReflAdapter<NLLLossAttrs> {
  ffi::String reduction;
  int ignore_index;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<NLLLossAttrs>()
        .def_ro("reduction", &NLLLossAttrs::reduction,
                "The reduction method to apply to the output. Can be "
                "'none', 'mean', or 'sum'.",
                refl::DefaultValue("mean"))
        .def_ro("ignore_index", &NLLLossAttrs::ignore_index, "The target value to ignore.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NLLLossAttrs", NLLLossAttrs, BaseAttrsNode);
};  // struct NLLLossAttrs

/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public AttrsNodeReflAdapter<DropoutAttrs> {
  double rate;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<DropoutAttrs>().def_ro(
        "rate", &DropoutAttrs::rate,
        "Fraction of the input that gets dropped out during training time.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.DropoutAttrs", DropoutAttrs, BaseAttrsNode);
};  // struct DropoutAttrs

/*! \brief Attributes used in Attention operator */
struct AttentionAttrs : public AttrsNodeReflAdapter<AttentionAttrs> {
  ffi::Optional<FloatImm> scale;
  ffi::Optional<ffi::String> causal_mask;
  ffi::Optional<IntImm> window_size;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<AttentionAttrs>()
        .def_ro("scale", &AttentionAttrs::scale,
                "The custom scale applied before the softmax. The default value is "
                "1 / sqrt(head_dim).")
        .def_ro("causal_mask", &AttentionAttrs::causal_mask,
                "The type of the causal mask, i.e. 'TopLeft' or 'BottomRight'.")
        .def_ro("window_size", &AttentionAttrs::window_size,
                "The size of the window for sliding-window attention.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AttentionAttrs", AttentionAttrs, BaseAttrsNode);
};  // struct AttentionAttrs

/*! \brief Attributes used for the padding operator */
struct PadAttrs : public AttrsNodeReflAdapter<PadAttrs> {
  ffi::Array<Integer> pad_width;
  double pad_value = 0.0;
  tvm::ffi::String pad_mode;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<PadAttrs>()
        .def_ro("pad_width", &PadAttrs::pad_width,
                "Number of values padded to the edges of each axis, "
                "in the format (before_1, after_1, ..., before_N, after_N).")
        .def_ro("pad_value", &PadAttrs::pad_value, "The value to fill the padded area with.",
                refl::DefaultValue(0.0))
        .def_ro("pad_mode", &PadAttrs::pad_mode,
                "Padding type to use. \"constant\" pads with pad_value, "
                "\"edge\" pads using the edge values of the input array, "
                "\"reflect\" pads by reflecting values with respect to the edges.",
                refl::DefaultValue("constant"));
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PadAttrs", PadAttrs, BaseAttrsNode);
};
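
/*!
 * Example (illustrating the flattened pad_width format documented above):
 * \code
 * // pad_width = {before_1, after_1, ..., before_N, after_N}, one pair per axis.
 * // For a 2-D tensor, {0, 0, 1, 2} pads axis 0 by (before=0, after=0) and
 * // axis 1 by (before=1, after=2).
 * \endcode
 */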

/*! \brief Attributes used for the pixel shuffle operator */
struct PixelShuffleAttrs : public AttrsNodeReflAdapter<PixelShuffleAttrs> {
  int upscale_factor;

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<PixelShuffleAttrs>().def_ro("upscale_factor",
                                                &PixelShuffleAttrs::upscale_factor,
                                                "Scale factor for spatial upsampling.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PixelShuffleAttrs", PixelShuffleAttrs,
                                    BaseAttrsNode);
};
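
/*!
 * Shape note (standard pixel-shuffle semantics, stated here for reference):
 * \code
 * // With upscale_factor r, an input of shape (N, C * r * r, H, W)
 * // is rearranged into an output of shape (N, C, H * r, W * r).
 * \endcode
 */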

}  // namespace relax
}  // namespace tvm

#endif  // TVM_RELAX_ATTRS_NN_H_