tvm
nn.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_ATTRS_NN_H_
25 #define TVM_RELAX_ATTRS_NN_H_
26 
27 #include <tvm/relax/expr.h>
28 
29 namespace tvm {
30 namespace relax {
31 
33 struct Conv1DAttrs : public AttrsNodeReflAdapter<Conv1DAttrs> {
34  Array<IntImm> strides;
35  Array<IntImm> padding;
36  Array<IntImm> dilation;
37  int groups;
38  String data_layout;
39  String kernel_layout;
40  String out_layout;
42 
43  static void RegisterReflection() {
44  namespace refl = tvm::ffi::reflection;
45  refl::ObjectDef<Conv1DAttrs>()
46  .def_ro("strides", &Conv1DAttrs::strides, "Specifies the strides of the convolution.")
47  .def_ro("padding", &Conv1DAttrs::padding,
48  "If padding is non-zero, then the input is implicitly zero-padded"
49  "Padding support both symmetric and asymmetric as"
50  "one int : same padding used on both sides"
51  "two int : padding width in the order of (left, right)")
52  .def_ro("dilation", &Conv1DAttrs::dilation,
53  "Specifies the dilation rate to use for dilated convolution.")
54  .def_ro("groups", &Conv1DAttrs::groups,
55  "Number of groups to split the input into for grouped convolution. The number of "
56  "input and "
57  "output channels should be divisible by the number of groups.")
58  .def_ro("data_layout", &Conv1DAttrs::data_layout,
59  "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
60  "'N', 'C', 'W' stands for batch, channel, width"
61  "dimensions respectively. Convolution is applied on the 'W' dimensions.")
62  .def_ro("kernel_layout", &Conv1DAttrs::kernel_layout,
63  "Dimension ordering of weight. Can be 'OIW', 'IOW', etc."
64  "'O', 'I', 'W' stands for num_filter, input_channel, and width"
65  "dimensions respectively.")
66  .def_ro("out_layout", &Conv1DAttrs::out_layout,
67  "Dimension ordering of output. Can be 'NCW', 'NWC', etc."
68  "'N', 'C', 'W' stands for batch, channel, and width"
69  "dimensions respectively. Default to be same as input layout.")
70  .def_ro("out_dtype", &Conv1DAttrs::out_dtype,
71  "Output data type, set to explicit type under mixed precision setting");
72  }
73 
74  static constexpr const char* _type_key = "relax.attrs.Conv1DAttrs";
76 }; // struct Conv1dAttrs
77 
79 struct Conv2DAttrs : public AttrsNodeReflAdapter<Conv2DAttrs> {
80  Array<IntImm> strides;
81  Array<IntImm> padding;
82  Array<IntImm> dilation;
83  int groups;
84  String data_layout;
85  String kernel_layout;
86  String out_layout;
88 
89  static void RegisterReflection() {
90  namespace refl = tvm::ffi::reflection;
91  refl::ObjectDef<Conv2DAttrs>()
92  .def_ro("strides", &Conv2DAttrs::strides, "Specifies the strides of the convolution.")
93  .def_ro("padding", &Conv2DAttrs::padding,
94  "If padding is non-zero, then the input is implicitly zero-padded"
95  "Padding support both symmetric and asymmetric as"
96  "one int : same padding used on all sides"
97  "two int : bottom, right will use same padding as top, left"
98  "four int : padding width in the order of (top, left, bottom, right)")
99  .def_ro("dilation", &Conv2DAttrs::dilation,
100  "Specifies the dilation rate to use for dilated convolution.")
101  .def_ro("groups", &Conv2DAttrs::groups,
102  "Number of groups to split the input into for grouped convolution. The number of "
103  "input and "
104  "output channels should be divisible by the number of groups.")
105  .def_ro("data_layout", &Conv2DAttrs::data_layout,
106  "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
107  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
108  "dimensions respectively. Convolution is applied on the 'H' and"
109  "'W' dimensions.")
110  .def_ro("kernel_layout", &Conv2DAttrs::kernel_layout,
111  "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
112  "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
113  "dimensions respectively.")
114  .def_ro("out_layout", &Conv2DAttrs::out_layout,
115  "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
116  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
117  "dimensions respectively. Default to be same as input layout.")
118  .def_ro("out_dtype", &Conv2DAttrs::out_dtype,
119  "Output data type, set to explicit type under mixed precision setting");
120  }
121 
122  static constexpr const char* _type_key = "relax.attrs.Conv2DAttrs";
124 }; // struct Conv2dAttrs
125 
127 struct Conv3DAttrs : public AttrsNodeReflAdapter<Conv3DAttrs> {
128  Array<IntImm> strides;
129  Array<IntImm> padding;
130  Array<IntImm> dilation;
131  int groups;
132  String data_layout;
134  String out_layout;
136 
137  static void RegisterReflection() {
138  namespace refl = tvm::ffi::reflection;
139  refl::ObjectDef<Conv3DAttrs>()
140  .def_ro("strides", &Conv3DAttrs::strides, "Specifies the strides of the convolution.")
141  .def_ro(
142  "padding", &Conv3DAttrs::padding,
143  "If padding is non-zero, then the input is implicitly zero-padded"
144  "Padding support both symmetric and asymmetric as"
145  "one int : same padding used on all sides"
146  "two int : bottom, right will use same padding as top, left"
147  "four int : padding width in the order of (forward, back, top, left, bottom, right)")
148  .def_ro("dilation", &Conv3DAttrs::dilation,
149  "Specifies the dilation rate to use for dilated convolution.")
150  .def_ro("groups", &Conv3DAttrs::groups,
151  "Number of groups to split the input into for grouped convolution. The number of "
152  "input and "
153  "output channels should be divisible by the number of groups.")
154  .def_ro("data_layout", &Conv3DAttrs::data_layout,
155  "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
156  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
157  "dimensions respectively. Convolution is applied on the 'D', 'H', and"
158  "'W' dimensions.")
159  .def_ro(
160  "kernel_layout", &Conv3DAttrs::kernel_layout,
161  "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
162  "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height, and width"
163  "dimensions respectively.")
164  .def_ro("out_layout", &Conv3DAttrs::out_layout,
165  "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
166  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
167  "dimensions respectively. Default to be same as input layout.")
168  .def_ro("out_dtype", &Conv3DAttrs::out_dtype,
169  "Output data type, set to explicit type under mixed precision setting");
170  }
171 
172  static constexpr const char* _type_key = "relax.attrs.Conv3DAttrs";
174 }; // struct Conv3dAttrs
175 
177 struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter<Conv1DTransposeAttrs> {
178  Array<IntImm> strides;
179  Array<IntImm> padding;
180  Array<IntImm> output_padding;
181  Array<IntImm> dilation;
182  int groups;
183  String data_layout;
185  String out_layout;
187 
188  static void RegisterReflection() {
189  namespace refl = tvm::ffi::reflection;
190  refl::ObjectDef<Conv1DTransposeAttrs>()
191  .def_ro("strides", &Conv1DTransposeAttrs::strides,
192  "Specifies the strides of the convolution.")
193  .def_ro("padding", &Conv1DTransposeAttrs::padding,
194  "If padding is non-zero, then the input is implicitly zero-padded"
195  "Padding support both symmetric and asymmetric as"
196  "one int : same padding used on both sides"
197  "two int : padding width in the order of (left, right)")
198  .def_ro("output_padding", &Conv1DTransposeAttrs::output_padding,
199  "Used to disambiguate the output shape.")
200  .def_ro("dilation", &Conv1DTransposeAttrs::dilation,
201  "Specifies the dilation rate to use for dilated convolution.")
202  .def_ro("groups", &Conv1DTransposeAttrs::groups,
203  "Number of groups to split the input into for grouped convolution. The number of "
204  "input and "
205  "output channels should be divisible by the number of groups.")
206  .def_ro("data_layout", &Conv1DTransposeAttrs::data_layout,
207  "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
208  "'N', 'C', 'W' stands for batch, channel, width"
209  "dimensions respectively. Convolution is applied on the 'W' dimensions.")
210  .def_ro("kernel_layout", &Conv1DTransposeAttrs::kernel_layout,
211  "Dimension ordering of weight. Can be 'OIW', 'IOW', etc."
212  "'O', 'I', 'W' stands for num_filter, input_channel, and width"
213  "dimensions respectively.")
214  .def_ro("out_layout", &Conv1DTransposeAttrs::out_layout,
215  "Dimension ordering of output. Can be 'NCW', 'NWC', etc."
216  "'N', 'C', 'W' stands for batch, channel, and width"
217  "dimensions respectively. Default to be same as input layout.")
218  .def_ro("out_dtype", &Conv1DTransposeAttrs::out_dtype,
219  "Output data type, set to explicit type under mixed precision setting");
220  }
221 
222  static constexpr const char* _type_key = "relax.attrs.Conv1DTransposeAttrs";
224 }; // struct Conv1DTransposeAttrs
225 
227 struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter<Conv2DTransposeAttrs> {
228  Array<IntImm> strides;
229  Array<IntImm> padding;
230  Array<IntImm> output_padding;
231  Array<IntImm> dilation;
232  int groups;
233  String data_layout;
235  String out_layout;
237 
238  static void RegisterReflection() {
239  namespace refl = tvm::ffi::reflection;
240  refl::ObjectDef<Conv2DTransposeAttrs>()
241  .def_ro("strides", &Conv2DTransposeAttrs::strides,
242  "Specifies the strides of the convolution.")
243  .def_ro("padding", &Conv2DTransposeAttrs::padding,
244  "If padding is non-zero, then the input is implicitly zero-padded"
245  "Padding support both symmetric and asymmetric as"
246  "one int : same padding used on all sides"
247  "two int : bottom, right will use same padding as top, left"
248  "four int : padding width in the order of (top, left, bottom, right)")
249  .def_ro("output_padding", &Conv2DTransposeAttrs::output_padding,
250  "Used to disambiguate the output shape.")
251  .def_ro("dilation", &Conv2DTransposeAttrs::dilation,
252  "Specifies the dilation rate to use for dilated convolution.")
253  .def_ro("groups", &Conv2DTransposeAttrs::groups,
254  "Number of groups to split the input into for grouped convolution. The number of "
255  "input and "
256  "output channels should be divisible by the number of groups.")
257  .def_ro("data_layout", &Conv2DTransposeAttrs::data_layout,
258  "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
259  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
260  "dimensions respectively. Convolution is applied on the 'H' and"
261  "'W' dimensions.")
262  .def_ro("kernel_layout", &Conv2DTransposeAttrs::kernel_layout,
263  "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
264  "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
265  "dimensions respectively.")
266  .def_ro("out_layout", &Conv2DTransposeAttrs::out_layout,
267  "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
268  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
269  "dimensions respectively. Default to be same as input layout.")
270  .def_ro("out_dtype", &Conv2DTransposeAttrs::out_dtype,
271  "Output data type, set to explicit type under mixed precision setting");
272  }
273 
274  static constexpr const char* _type_key = "relax.attrs.Conv2DTransposeAttrs";
276 }; // struct Conv2DTransposeAttrs
277 
279 struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
280  Array<IntImm> pool_size;
281  Array<IntImm> strides;
282  Array<IntImm> padding;
283  Array<IntImm> dilation;
284  bool ceil_mode;
286  String layout;
287  String out_layout;
288 
289  static void RegisterReflection() {
290  namespace refl = tvm::ffi::reflection;
291  refl::ObjectDef<Pool1DAttrs>()
292  .def_ro("pool_size", &Pool1DAttrs::pool_size, "Size of the pooling windows.")
293  .def_ro("strides", &Pool1DAttrs::strides, "Specifies the strides of the convolution.")
294  .def_ro("dilation", &Pool1DAttrs::dilation, "Specifies the dilation of the convolution.")
295  .def_ro("padding", &Pool1DAttrs::padding,
296  "If padding is non-zero, then the input is implicitly zero-padded"
297  "Padding support both symmetric and asymmetric as"
298  "one int : same padding used on all sides"
299  "two int : padding width in the order of (left, right)")
300  .def_ro(
301  "ceil_mode", &Pool1DAttrs::ceil_mode,
302  "A boolean indicating if use ceil or floor to compute the output shape. By using ceil, "
303  "every element in the input tensor will be covered by a sliding window.")
304  .def_ro("count_include_pad", &Pool1DAttrs::count_include_pad,
305  "When true, will include padding to compute the average")
306  .def_ro("layout", &Pool1DAttrs::layout,
307  "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
308  "'N', 'C', 'W' stands for batch, channel, and width"
309  "dimensions respectively. Pooling is applied on the 'W' dimensions.",
310  refl::DefaultValue("NCW"))
311  .def_ro("out_layout", &Pool1DAttrs::out_layout,
312  "Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
313  "'N', 'C', 'W' stands for batch, channel, and width"
314  "dimensions respectively. Pooling is applied on the 'W' dimensions.");
315  }
316 
317  static constexpr const char* _type_key = "relax.attrs.Pool1DAttrs";
319 }; // struct Pool1dAttrs
320 
322 struct Pool2DAttrs : public AttrsNodeReflAdapter<Pool2DAttrs> {
323  Array<IntImm> pool_size;
324  Array<IntImm> strides;
325  Array<IntImm> padding;
326  Array<IntImm> dilation;
327  bool ceil_mode;
329  String layout;
330  String out_layout;
331 
332  static void RegisterReflection() {
333  namespace refl = tvm::ffi::reflection;
334  refl::ObjectDef<Pool2DAttrs>()
335  .def_ro("pool_size", &Pool2DAttrs::pool_size, "Size of the pooling windows.")
336  .def_ro("strides", &Pool2DAttrs::strides, "Specifies the strides of the convolution.")
337  .def_ro("dilation", &Pool2DAttrs::dilation, "Specifies the dilation of the convolution.")
338  .def_ro("padding", &Pool2DAttrs::padding,
339  "If padding is non-zero, then the input is implicitly zero-padded"
340  "Padding support both symmetric and asymmetric as"
341  "one int : same padding used on all sides"
342  "two int : bottom, right will use same padding as top, left"
343  "four int : padding width in the order of (top, left, bottom, right)")
344  .def_ro(
345  "ceil_mode", &Pool2DAttrs::ceil_mode,
346  "A boolean indicating if use ceil or floor to compute the output shape. By using ceil, "
347  "every element in the input tensor will be covered by a sliding window.")
348  .def_ro("count_include_pad", &Pool2DAttrs::count_include_pad,
349  "When true, will include padding to compute the average")
350  .def_ro("layout", &Pool2DAttrs::layout,
351  "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
352  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
353  "dimensions respectively. Pooling is applied on the 'H' and"
354  "'W' dimensions.")
355  .def_ro("out_layout", &Pool2DAttrs::out_layout,
356  "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
357  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
358  "dimensions respectively. Pooling is applied on the 'H' and"
359  "'W' dimensions.");
360  }
361 
362  static constexpr const char* _type_key = "relax.attrs.Pool2DAttrs";
364 }; // struct Pool2dAttrs
365 
367 struct Pool3DAttrs : public AttrsNodeReflAdapter<Pool3DAttrs> {
368  Array<IntImm> pool_size;
369  Array<IntImm> strides;
370  Array<IntImm> padding;
371  Array<IntImm> dilation;
372  bool ceil_mode;
374  String layout;
375  String out_layout;
376 
377  static void RegisterReflection() {
378  namespace refl = tvm::ffi::reflection;
379  refl::ObjectDef<Pool3DAttrs>()
380  .def_ro("pool_size", &Pool3DAttrs::pool_size, "Size of the pooling windows.")
381  .def_ro("strides", &Pool3DAttrs::strides, "Specifies the strides of the convolution.")
382  .def_ro("dilation", &Pool3DAttrs::dilation, "Specifies the dilation of the convolution.")
383  .def_ro("padding", &Pool3DAttrs::padding,
384  "If padding is non-zero, then the input is implicitly zero-padded"
385  "Padding support both symmetric and asymmetric as"
386  "one int : same padding used on all sides"
387  "three int : back, bottom, right will use same padding as front, top, left"
388  "four int : padding width in the order of (front, top, left, back, bottom, right)")
389  .def_ro(
390  "ceil_mode", &Pool3DAttrs::ceil_mode,
391  "A boolean indicating if use ceil or floor to compute the output shape. By using ceil, "
392  "every element in the input tensor will be covered by a sliding window.")
393  .def_ro("count_include_pad", &Pool3DAttrs::count_include_pad,
394  "When true, will include padding to compute the average")
395  .def_ro("layout", &Pool3DAttrs::layout,
396  "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
397  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
398  "dimensions respectively. Pooling is applied on the 'D', 'H' and"
399  "'W' dimensions.")
400  .def_ro("out_layout", &Pool3DAttrs::out_layout,
401  "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
402  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
403  "dimensions respectively. Pooling is applied on the 'D', 'H' and"
404  "'W' dimensions.");
405  }
406 
407  static constexpr const char* _type_key = "relax.attrs.Pool3DAttrs";
409 }; // struct Pool3dAttrs
410 
412 struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter<AdaptivePool1DAttrs> {
413  Optional<Array<IntImm>> output_size;
414  String layout;
415  String out_layout;
416 
417  static void RegisterReflection() {
418  namespace refl = tvm::ffi::reflection;
419  refl::ObjectDef<AdaptivePool1DAttrs>()
420  .def_ro("output_size", &AdaptivePool1DAttrs::output_size, "Output width.")
421  .def_ro("layout", &AdaptivePool1DAttrs::layout,
422  "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
423  "'N', 'C', 'W' stands for batch, channel and width"
424  "dimensions respectively. Pooling is applied on the"
425  "'W' dimensions.")
426  .def_ro("out_layout", &AdaptivePool1DAttrs::out_layout,
427  "Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
428  "'N', 'C', 'W' stands for batch, channel and width"
429  "dimensions respectively. Pooling is applied on the"
430  "'W' dimensions.");
431  }
432 
433  static constexpr const char* _type_key = "relax.attrs.AdaptivePool1DAttrs";
435 }; // struct AdaptivePool1DAttrs
436 
438 struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter<AdaptivePool2DAttrs> {
439  Optional<Array<IntImm>> output_size;
440  String layout;
441  String out_layout;
442 
443  static void RegisterReflection() {
444  namespace refl = tvm::ffi::reflection;
445  refl::ObjectDef<AdaptivePool2DAttrs>()
446  .def_ro("output_size", &AdaptivePool2DAttrs::output_size, "Output height and width.")
447  .def_ro("layout", &AdaptivePool2DAttrs::layout,
448  "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
449  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
450  "dimensions respectively. Pooling is applied on the 'H' and"
451  "'W' dimensions.")
452  .def_ro("out_layout", &AdaptivePool2DAttrs::out_layout,
453  "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
454  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
455  "dimensions respectively. Pooling is applied on the 'H' and"
456  "'W' dimensions.");
457  }
458 
459  static constexpr const char* _type_key = "relax.attrs.AdaptivePool2DAttrs";
461 }; // struct AdaptivePool2DAttrs
462 
464 struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter<AdaptivePool3DAttrs> {
465  Optional<Array<IntImm>> output_size;
466  String layout;
467  String out_layout;
468 
469  static void RegisterReflection() {
470  namespace refl = tvm::ffi::reflection;
471  refl::ObjectDef<AdaptivePool3DAttrs>()
472  .def_ro("output_size", &AdaptivePool3DAttrs::output_size, "Output depth, height and width.")
473  .def_ro("layout", &AdaptivePool3DAttrs::layout,
474  "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
475  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
476  "dimensions respectively. Pooling is applied on 'D', 'H' and"
477  "'W' dimensions.")
478  .def_ro("out_layout", &AdaptivePool3DAttrs::out_layout,
479  "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
480  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
481  "dimensions respectively. Pooling is applied on 'D', 'H' and"
482  "'W' dimensions.");
483  }
484 
485  static constexpr const char* _type_key = "relax.attrs.AdaptivePool3DAttrs";
487 }; // struct AdaptivePool3DAttrs
488 
490 struct SoftmaxAttrs : public AttrsNodeReflAdapter<SoftmaxAttrs> {
491  int axis;
492 
493  static void RegisterReflection() {
494  namespace refl = tvm::ffi::reflection;
495  refl::ObjectDef<SoftmaxAttrs>().def_ro("axis", &SoftmaxAttrs::axis,
496  "The axis to sum over when computing softmax.");
497  }
498 
499  static constexpr const char* _type_key = "relax.attrs.SoftmaxAttrs";
501 };
502 
504 struct LeakyReluAttrs : public AttrsNodeReflAdapter<LeakyReluAttrs> {
505  double alpha;
506 
507  static void RegisterReflection() {
508  namespace refl = tvm::ffi::reflection;
509  refl::ObjectDef<LeakyReluAttrs>().def_ro("alpha", &LeakyReluAttrs::alpha,
510  "The slope of the negative part.");
511  }
512 
513  static constexpr const char* _type_key = "relax.attrs.LeakyReluAttrs";
515 };
516 
518 struct SoftplusAttrs : public AttrsNodeReflAdapter<SoftplusAttrs> {
519  double beta;
520  double threshold;
521 
522  static void RegisterReflection() {
523  namespace refl = tvm::ffi::reflection;
524  refl::ObjectDef<SoftplusAttrs>()
525  .def_ro("beta", &SoftplusAttrs::beta,
526  "Scaling factor controlling the sharpness of the Softplus transition.")
527  .def_ro("threshold", &SoftplusAttrs::threshold,
528  "Value determining when to use linear approximation for numerical stability.");
529  }
530 
531  static constexpr const char* _type_key = "relax.attrs.SoftplusAttrs";
533 };
534 
536 struct PReluAttrs : public AttrsNodeReflAdapter<PReluAttrs> {
537  int axis;
538 
539  static void RegisterReflection() {
540  namespace refl = tvm::ffi::reflection;
541  refl::ObjectDef<PReluAttrs>().def_ro("axis", &PReluAttrs::axis,
542  "The axis along which the alpha values are applied.");
543  }
544 
545  static constexpr const char* _type_key = "relax.attrs.PReluAttrs";
547 };
548 
550 struct BatchNormAttrs : public AttrsNodeReflAdapter<BatchNormAttrs> {
551  int axis;
552  double epsilon;
553  bool center;
554  bool scale;
555  double momentum;
556  bool training;
557 
558  static void RegisterReflection() {
559  namespace refl = tvm::ffi::reflection;
560  refl::ObjectDef<BatchNormAttrs>()
561  .def_ro("axis", &BatchNormAttrs::axis, "The axis along which the normalization is applied.")
562  .def_ro("epsilon", &BatchNormAttrs::epsilon,
563  "Small float added to variance to avoid dividing by zero")
564  .def_ro("center", &BatchNormAttrs::center,
565  "Indicating if the beta offset will be added to the normalized tensor.")
566  .def_ro("scale", &BatchNormAttrs::scale,
567  "Indicating if the gamma scale will be multiplied.")
568  .def_ro("momentum", &BatchNormAttrs::momentum,
569  "The value used for the moving_mean and moving_var update.")
570  .def_ro("training", &BatchNormAttrs::training,
571  "Whether we are training (i.e., not in eval mode).");
572  }
573 
574  static constexpr const char* _type_key = "relax.attrs.BatchNormAttrs";
576 }; // struct BatchNormAttrs
577 
579 struct LayerNormAttrs : public AttrsNodeReflAdapter<LayerNormAttrs> {
580  Array<Integer> axes;
581  double epsilon;
582  bool center;
583  bool scale;
584 
585  static void RegisterReflection() {
586  namespace refl = tvm::ffi::reflection;
587  refl::ObjectDef<LayerNormAttrs>()
588  .def_ro("axes", &LayerNormAttrs::axes,
589  "The axes that along which the normalization is applied.")
590  .def_ro("epsilon", &LayerNormAttrs::epsilon,
591  "Small float added to variance to avoid dividing by zero")
592  .def_ro("center", &LayerNormAttrs::center,
593  "Indicating if the beta offset will be added to the normalized tensor.")
594  .def_ro("scale", &LayerNormAttrs::scale,
595  "Indicating if the gamma scale will be multiplied.");
596  }
597 
598  static constexpr const char* _type_key = "relax.attrs.LayerNormAttrs";
600 }; // struct LayerNormAttrs
601 
603 struct GroupNormAttrs : public AttrsNodeReflAdapter<GroupNormAttrs> {
606  Array<Integer> axes;
607  double epsilon;
608  bool center;
609  bool scale;
610 
611  static void RegisterReflection() {
612  namespace refl = tvm::ffi::reflection;
613  refl::ObjectDef<GroupNormAttrs>()
614  .def_ro("num_groups", &GroupNormAttrs::num_groups,
615  "The number of groups to separate the channels into.")
616  .def_ro("channel_axis", &GroupNormAttrs::channel_axis,
617  "The axis that represents the channel.")
618  .def_ro(
619  "axes", &GroupNormAttrs::axes,
620  "The axes that along which the normalization is applied (excluding the channel axis).")
621  .def_ro("epsilon", &GroupNormAttrs::epsilon,
622  "Small float added to variance to avoid dividing by zero")
623  .def_ro("center", &GroupNormAttrs::center,
624  "Indicating if the beta offset will be added to the normalized tensor.")
625  .def_ro("scale", &GroupNormAttrs::scale,
626  "Indicating if the gamma scale will be multiplied.");
627  }
628 
629  static constexpr const char* _type_key = "relax.attrs.GroupNormAttrs";
631 }; // struct GroupNormAttrs
632 
634 struct InstanceNormAttrs : public AttrsNodeReflAdapter<InstanceNormAttrs> {
636  Array<Integer> axes;
637  double epsilon;
638  bool center;
639  bool scale;
640 
641  static void RegisterReflection() {
642  namespace refl = tvm::ffi::reflection;
643  refl::ObjectDef<InstanceNormAttrs>()
644  .def_ro("channel_axis", &InstanceNormAttrs::channel_axis,
645  "The axis that represents the channel.")
646  .def_ro("axes", &InstanceNormAttrs::axes,
647  "The axes that along which the normalization is applied.")
648  .def_ro("epsilon", &InstanceNormAttrs::epsilon,
649  "Small float added to variance to avoid dividing by zero")
650  .def_ro("center", &InstanceNormAttrs::center,
651  "Indicating if the beta offset will be added to the normalized tensor.")
652  .def_ro("scale", &InstanceNormAttrs::scale,
653  "Indicating if the gamma scale will be multiplied.");
654  }
655 
656  static constexpr const char* _type_key = "relax.attrs.InstanceNormAttrs";
658 }; // struct InstanceNormAttrs
659 
661 struct RMSNormAttrs : public AttrsNodeReflAdapter<RMSNormAttrs> {
662  Array<Integer> axes;
663  double epsilon;
664 
665  static void RegisterReflection() {
666  namespace refl = tvm::ffi::reflection;
667  refl::ObjectDef<RMSNormAttrs>()
668  .def_ro("axes", &RMSNormAttrs::axes,
669  "The axes that along which the normalization is applied.")
670  .def_ro("epsilon", &RMSNormAttrs::epsilon,
671  "Small float added to variance to avoid dividing by zero");
672  }
673 
674  static constexpr const char* _type_key = "relax.attrs.RMSNormAttrs";
676 }; // struct RMSNormAttrs
677 
679 struct NLLLossAttrs : public AttrsNodeReflAdapter<NLLLossAttrs> {
680  String reduction;
682 
683  static void RegisterReflection() {
684  namespace refl = tvm::ffi::reflection;
685  refl::ObjectDef<NLLLossAttrs>()
686  .def_ro("reduction", &NLLLossAttrs::reduction,
687  "The reduction method to apply to the output. Can be"
688  "'none', 'mean' or 'sum'.",
689  refl::DefaultValue("mean"))
690  .def_ro("ignore_index", &NLLLossAttrs::ignore_index, "The target value to ignore.");
691  }
692 
693  static constexpr const char* _type_key = "relax.attrs.NLLLossAttrs";
695 }; // struct NLLLossAttrs
696 
698 struct DropoutAttrs : public AttrsNodeReflAdapter<DropoutAttrs> {
699  double rate;
700 
701  static void RegisterReflection() {
702  namespace refl = tvm::ffi::reflection;
703  refl::ObjectDef<DropoutAttrs>().def_ro(
704  "rate", &DropoutAttrs::rate,
705  "Fraction of the input that gets dropped out during training time");
706  }
707 
708  static constexpr const char* _type_key = "relax.attrs.DropoutAttrs";
710 }; // struct DropoutAttrs
711 
713 struct AttentionAttrs : public AttrsNodeReflAdapter<AttentionAttrs> {
714  Optional<FloatImm> scale;
715  Optional<String> causal_mask;
716  Optional<IntImm> window_size;
717 
718  static void RegisterReflection() {
719  namespace refl = tvm::ffi::reflection;
720  refl::ObjectDef<AttentionAttrs>()
721  .def_ro(
722  "scale", &AttentionAttrs::scale,
723  "The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).")
724  .def_ro("causal_mask", &AttentionAttrs::causal_mask,
725  "The type of the causal mask, i.e. 'TopLeft' and 'BottomRight'.")
726  .def_ro("window_size", &AttentionAttrs::window_size,
727  "The size of the window for sliding-window attention.");
728  }
729 
730  static constexpr const char* _type_key = "relax.attrs.AttentionAttrs";
732 }; // struct AttentionAttrs
733 
735 struct PadAttrs : public AttrsNodeReflAdapter<PadAttrs> {
736  Array<Integer> pad_width;
737  double pad_value = 0.0;
738  tvm::String pad_mode;
739 
740  static void RegisterReflection() {
741  namespace refl = tvm::ffi::reflection;
742  refl::ObjectDef<PadAttrs>()
743  .def_ro("pad_width", &PadAttrs::pad_width,
744  "Number of values padded to the edges of each axis, "
745  "in the format of (before_1, after_1, ..., before_N, after_N)")
746  .def_ro("pad_value", &PadAttrs::pad_value, "The value to fill in padded area with",
747  refl::DefaultValue(0.0))
748  .def_ro("pad_mode", &PadAttrs::pad_mode,
749  "Padding type to use. \"constant\" pads with constant_value, "
750  "\"edge\" pads using the edge values of the input array, "
751  "\"reflect\" pads by reflecting values with respect to the edges.",
752  refl::DefaultValue("constant"));
753  }
754 
755  static constexpr const char* _type_key = "relax.attrs.PadAttrs";
757 };
758 
760 struct PixelShuffleAttrs : public AttrsNodeReflAdapter<PixelShuffleAttrs> {
762 
763  static void RegisterReflection() {
764  namespace refl = tvm::ffi::reflection;
765  refl::ObjectDef<PixelShuffleAttrs>().def_ro("upscale_factor",
767  "Scale factor for spatial upsampling.");
768  }
769 
770  static constexpr const char* _type_key = "relax.attrs.PixelShuffleAttrs";
772 };
773 
774 } // namespace relax
775 } // namespace tvm
776 
777 #endif // TVM_RELAX_ATTRS_NN_H_
Adapter for AttrsNode with the new reflection API.
Definition: attrs.h:384
Base class of all attribute class.
Definition: attrs.h:103
Runtime primitive data type.
Definition: data_type.h:47
Definition: repr_printer.h:91
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Attributes for 1d adaptive pool operator.
Definition: nn.h:412
Optional< Array< IntImm > > output_size
Definition: nn.h:413
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AdaptivePool1DAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: nn.h:433
String layout
Definition: nn.h:414
static void RegisterReflection()
Definition: nn.h:417
String out_layout
Definition: nn.h:415
Attributes for 2d adaptive pool operator.
Definition: nn.h:438
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AdaptivePool2DAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: nn.h:459
static void RegisterReflection()
Definition: nn.h:443
Optional< Array< IntImm > > output_size
Definition: nn.h:439
String layout
Definition: nn.h:440
String out_layout
Definition: nn.h:441
Attributes for 3d adaptive pool operator.
Definition: nn.h:464
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AdaptivePool3DAttrs, BaseAttrsNode)
String layout
Definition: nn.h:466
static void RegisterReflection()
Definition: nn.h:469
static constexpr const char * _type_key
Definition: nn.h:485
String out_layout
Definition: nn.h:467
Optional< Array< IntImm > > output_size
Definition: nn.h:465
Attributes used in Attention operator.
Definition: nn.h:713
Optional< String > causal_mask
Definition: nn.h:715
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AttentionAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: nn.h:730
Optional< FloatImm > scale
Definition: nn.h:714
Optional< IntImm > window_size
Definition: nn.h:716
static void RegisterReflection()
Definition: nn.h:718
Attributes used in batch_norm operator.
Definition: nn.h:550
bool training
Definition: nn.h:556
bool scale
Definition: nn.h:554
static constexpr const char * _type_key
Definition: nn.h:574
static void RegisterReflection()
Definition: nn.h:558
double epsilon
Definition: nn.h:552
int axis
Definition: nn.h:551
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(BatchNormAttrs, BaseAttrsNode)
double momentum
Definition: nn.h:555
bool center
Definition: nn.h:553
Attributes used in Conv1d operator.
Definition: nn.h:33
Array< IntImm > dilation
Definition: nn.h:36
String out_layout
Definition: nn.h:40
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv1DAttrs, BaseAttrsNode)
int groups
Definition: nn.h:37
Array< IntImm > padding
Definition: nn.h:35
Array< IntImm > strides
Definition: nn.h:34
static void RegisterReflection()
Definition: nn.h:43
String data_layout
Definition: nn.h:38
DataType out_dtype
Definition: nn.h:41
static constexpr const char * _type_key
Definition: nn.h:74
String kernel_layout
Definition: nn.h:39
Attributes used in Conv1DTranspose operator.
Definition: nn.h:177
static void RegisterReflection()
Definition: nn.h:188
Array< IntImm > output_padding
Definition: nn.h:180
String data_layout
Definition: nn.h:183
Array< IntImm > dilation
Definition: nn.h:181
Array< IntImm > strides
Definition: nn.h:178
DataType out_dtype
Definition: nn.h:186
String out_layout
Definition: nn.h:185
static constexpr const char * _type_key
Definition: nn.h:222
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv1DTransposeAttrs, BaseAttrsNode)
Array< IntImm > padding
Definition: nn.h:179
String kernel_layout
Definition: nn.h:184
int groups
Definition: nn.h:182
Attributes used in Conv2d operator.
Definition: nn.h:79
String kernel_layout
Definition: nn.h:85
DataType out_dtype
Definition: nn.h:87
static constexpr const char * _type_key
Definition: nn.h:122
static void RegisterReflection()
Definition: nn.h:89
String data_layout
Definition: nn.h:84
int groups
Definition: nn.h:83
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv2DAttrs, BaseAttrsNode)
Array< IntImm > strides
Definition: nn.h:80
Array< IntImm > dilation
Definition: nn.h:82
Array< IntImm > padding
Definition: nn.h:81
String out_layout
Definition: nn.h:86
Attributes used in Conv2d operator.
Definition: nn.h:227
Array< IntImm > dilation
Definition: nn.h:231
Array< IntImm > output_padding
Definition: nn.h:230
Array< IntImm > padding
Definition: nn.h:229
Array< IntImm > strides
Definition: nn.h:228
String kernel_layout
Definition: nn.h:234
static constexpr const char * _type_key
Definition: nn.h:274
int groups
Definition: nn.h:232
String data_layout
Definition: nn.h:233
String out_layout
Definition: nn.h:235
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv2DTransposeAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: nn.h:238
DataType out_dtype
Definition: nn.h:236
Attributes used in Conv3d operator.
Definition: nn.h:127
String out_layout
Definition: nn.h:134
static constexpr const char * _type_key
Definition: nn.h:172
Array< IntImm > dilation
Definition: nn.h:130
String data_layout
Definition: nn.h:132
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Conv3DAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: nn.h:137
Array< IntImm > strides
Definition: nn.h:128
DataType out_dtype
Definition: nn.h:135
Array< IntImm > padding
Definition: nn.h:129
String kernel_layout
Definition: nn.h:133
int groups
Definition: nn.h:131
Attributes used in dropout operator.
Definition: nn.h:698
double rate
Definition: nn.h:699
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(DropoutAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: nn.h:701
static constexpr const char * _type_key
Definition: nn.h:708
Attributes used in group_norm operator.
Definition: nn.h:603
int num_groups
Definition: nn.h:604
int channel_axis
Definition: nn.h:605
double epsilon
Definition: nn.h:607
static constexpr const char * _type_key
Definition: nn.h:629
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(GroupNormAttrs, BaseAttrsNode)
Array< Integer > axes
Definition: nn.h:606
bool center
Definition: nn.h:608
static void RegisterReflection()
Definition: nn.h:611
bool scale
Definition: nn.h:609
Attributes used in instance_norm operator.
Definition: nn.h:634
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(InstanceNormAttrs, BaseAttrsNode)
bool center
Definition: nn.h:638
double epsilon
Definition: nn.h:637
bool scale
Definition: nn.h:639
static void RegisterReflection()
Definition: nn.h:641
Array< Integer > axes
Definition: nn.h:636
int channel_axis
Definition: nn.h:635
static constexpr const char * _type_key
Definition: nn.h:656
Attributes used in layer_norm operator.
Definition: nn.h:579
bool scale
Definition: nn.h:583
static void RegisterReflection()
Definition: nn.h:585
bool center
Definition: nn.h:582
Array< Integer > axes
Definition: nn.h:580
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(LayerNormAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: nn.h:598
double epsilon
Definition: nn.h:581
Attributes used in leaky_relu operator.
Definition: nn.h:504
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(LeakyReluAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: nn.h:513
static void RegisterReflection()
Definition: nn.h:507
double alpha
Definition: nn.h:505
Attributes used in nll_loss operator.
Definition: nn.h:679
static void RegisterReflection()
Definition: nn.h:683
static constexpr const char * _type_key
Definition: nn.h:693
int ignore_index
Definition: nn.h:681
String reduction
Definition: nn.h:680
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(NLLLossAttrs, BaseAttrsNode)
Attributes used in PReLU operator.
Definition: nn.h:536
static constexpr const char * _type_key
Definition: nn.h:545
static void RegisterReflection()
Definition: nn.h:539
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(PReluAttrs, BaseAttrsNode)
int axis
Definition: nn.h:537
Attributes used for the padding operator.
Definition: nn.h:735
static constexpr const char * _type_key
Definition: nn.h:755
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(PadAttrs, BaseAttrsNode)
double pad_value
Definition: nn.h:737
tvm::String pad_mode
Definition: nn.h:738
Array< Integer > pad_width
Definition: nn.h:736
static void RegisterReflection()
Definition: nn.h:740
Attributes used for the pixel shuffle operator.
Definition: nn.h:760
int upscale_factor
Definition: nn.h:761
static void RegisterReflection()
Definition: nn.h:763
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(PixelShuffleAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: nn.h:770
Attributes used in max_pool1d and avg_pool1d operator.
Definition: nn.h:279
static void RegisterReflection()
Definition: nn.h:289
Array< IntImm > padding
Definition: nn.h:282
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Pool1DAttrs, BaseAttrsNode)
bool count_include_pad
Definition: nn.h:285
String layout
Definition: nn.h:286
static constexpr const char * _type_key
Definition: nn.h:317
Array< IntImm > strides
Definition: nn.h:281
Array< IntImm > dilation
Definition: nn.h:283
bool ceil_mode
Definition: nn.h:284
Array< IntImm > pool_size
Definition: nn.h:280
String out_layout
Definition: nn.h:287
Attributes used in max_pool2d and avg_pool2d operator.
Definition: nn.h:322
Array< IntImm > padding
Definition: nn.h:325
static void RegisterReflection()
Definition: nn.h:332
String layout
Definition: nn.h:329
bool count_include_pad
Definition: nn.h:328
static constexpr const char * _type_key
Definition: nn.h:362
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Pool2DAttrs, BaseAttrsNode)
String out_layout
Definition: nn.h:330
Array< IntImm > dilation
Definition: nn.h:326
bool ceil_mode
Definition: nn.h:327
Array< IntImm > strides
Definition: nn.h:324
Array< IntImm > pool_size
Definition: nn.h:323
Attributes used in max_pool3d and avg_pool3d operator.
Definition: nn.h:367
bool ceil_mode
Definition: nn.h:372
Array< IntImm > dilation
Definition: nn.h:371
String layout
Definition: nn.h:374
static constexpr const char * _type_key
Definition: nn.h:407
static void RegisterReflection()
Definition: nn.h:377
bool count_include_pad
Definition: nn.h:373
String out_layout
Definition: nn.h:375
Array< IntImm > pool_size
Definition: nn.h:368
Array< IntImm > strides
Definition: nn.h:369
Array< IntImm > padding
Definition: nn.h:370
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(Pool3DAttrs, BaseAttrsNode)
Attributes used in rms_norm operator.
Definition: nn.h:661
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(RMSNormAttrs, BaseAttrsNode)
double epsilon
Definition: nn.h:663
static constexpr const char * _type_key
Definition: nn.h:674
static void RegisterReflection()
Definition: nn.h:665
Array< Integer > axes
Definition: nn.h:662
Attributes used in softmax operators.
Definition: nn.h:490
static constexpr const char * _type_key
Definition: nn.h:499
int axis
Definition: nn.h:491
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SoftmaxAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: nn.h:493
Attributes used in softplus operators.
Definition: nn.h:518
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SoftplusAttrs, BaseAttrsNode)
double threshold
Definition: nn.h:520
double beta
Definition: nn.h:519
static void RegisterReflection()
Definition: nn.h:522
static constexpr const char * _type_key
Definition: nn.h:531