tvm
manipulate.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_ATTRS_MANIPULATE_H_
25 #define TVM_RELAX_ATTRS_MANIPULATE_H_
26 
27 #include <tvm/relax/expr.h>
28 #include <tvm/tir/index_map.h>
29 
30 namespace tvm {
31 namespace relax {
32 
34 struct ConcatAttrs : public AttrsNodeReflAdapter<ConcatAttrs> {
35  Optional<int64_t> axis;
36 
37  static void RegisterReflection() {
38  namespace refl = tvm::ffi::reflection;
39  refl::ObjectDef<ConcatAttrs>().def_ro("axis", &ConcatAttrs::axis,
40  "The axis at which the input arrays are concatenated."
41  "Should lie in range `[-ndim, ndim)`.");
42  }
43 
44  static constexpr const char* _type_key = "relax.attrs.ConcatAttrs";
46 }; // struct ConcatAttrs
47 
49 struct ExpandDimsAttrs : public AttrsNodeReflAdapter<ExpandDimsAttrs> {
50  Array<Integer> axis;
51 
52  static void RegisterReflection() {
53  namespace refl = tvm::ffi::reflection;
54  refl::ObjectDef<ExpandDimsAttrs>().def_ro(
55  "axis", &ExpandDimsAttrs::axis,
56  "The axes at which the input array are expanded. "
57  "All values are required to lie in range `[-data.ndim - 1, data.ndim]`, "
58  "with the convention of negative indexing.");
59  }
60 
61  static constexpr const char* _type_key = "relax.attrs.ExpandDimsAttrs";
63 }; // struct ExpandDimsAttrs
64 
66 struct LayoutTransformAttrs : public AttrsNodeReflAdapter<LayoutTransformAttrs> {
68  // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This
69  // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future.
70  Optional<PrimValue> pad_value;
77  Optional<Array<IntImm>> axis_separators;
83  Optional<Array<IntImm>> input_axis_separators;
84 
85  static void RegisterReflection() {
86  namespace refl = tvm::ffi::reflection;
87  refl::ObjectDef<LayoutTransformAttrs>()
88  .def_ro("index_map", &LayoutTransformAttrs::index_map,
89  "The layout transformation to apply.")
90  .def_ro(
91  "pad_value", &LayoutTransformAttrs::pad_value,
92  "The specific value to be used to pad if the layout transform would result in implicit "
93  "padding. If not specified, the compiler is free to choose any value.")
94  .def_ro("axis_separators", &LayoutTransformAttrs::axis_separators,
95  "The separators between input axes when generating flat output axes")
96  .def_ro("input_axis_separators", &LayoutTransformAttrs::input_axis_separators,
97  "The separators between axes to regenerate output");
98  }
99 
100  static constexpr const char* _type_key = "relax.attrs.LayoutTransformAttrs";
102 }; // struct LayoutTransformAttrs
103 
105 struct PermuteDimsAttrs : public AttrsNodeReflAdapter<PermuteDimsAttrs> {
106  Optional<Array<Integer>> axes;
107 
108  static void RegisterReflection() {
109  namespace refl = tvm::ffi::reflection;
110  refl::ObjectDef<PermuteDimsAttrs>().def_ro(
111  "axes", &PermuteDimsAttrs::axes, "The target axes order, reverse order if not specified.");
112  }
113 
114  static constexpr const char* _type_key = "relax.attrs.PermuteDimsAttrs";
116 }; // struct PermuteDimsAttrs
117 
119 struct SplitAttrs : public AttrsNodeReflAdapter<SplitAttrs> {
121  int axis;
122 
123  static void RegisterReflection() {
124  namespace refl = tvm::ffi::reflection;
125  refl::ObjectDef<SplitAttrs>()
126  .def_ro("indices_or_sections", &SplitAttrs::indices_or_sections,
127  "The input array of indices or the number of split sections.")
128  .def_ro("axis", &SplitAttrs::axis, "The axis to be splitted");
129  }
130 
131  static constexpr const char* _type_key = "relax.attrs.SplitAttrs";
133 }; // struct SplitAttrs
134 
136 struct SqueezeAttrs : public AttrsNodeReflAdapter<SqueezeAttrs> {
137  Optional<Array<Integer>> axis;
138 
139  static void RegisterReflection() {
140  namespace refl = tvm::ffi::reflection;
141  refl::ObjectDef<SqueezeAttrs>().def_ro("axis", &SqueezeAttrs::axis,
142  "The axis to squeeze in the input tensor."
143  "If `axis = None`, all axis of dimension 1 get squeezed;"
144  "Else, the dimension in axes get squeezed."
145  "It is an error if an axis does not has dimension 1.");
146  }
147 
148  static constexpr const char* _type_key = "relax.attrs.SqueezeAttrs";
150 }; // struct SqueezeAttrs
151 
153 struct StackAttrs : public AttrsNodeReflAdapter<StackAttrs> {
154  Optional<Integer> axis;
155 
156  static void RegisterReflection() {
157  namespace refl = tvm::ffi::reflection;
158  refl::ObjectDef<StackAttrs>().def_ro(
159  "axis", &StackAttrs::axis,
160  "The axis along which to stack the input tensors. "
161  "The axis will be inserted at this position in the output, "
162  "so it must be in range [-ndim-1, ndim] where ndim is the "
163  "number of dimensions of the input tensors.");
164  }
165 
166  static constexpr const char* _type_key = "relax.attrs.StackAttrs";
168 }; // struct StackAttrs
169 
171 struct RepeatAttrs : public AttrsNodeReflAdapter<RepeatAttrs> {
172  int repeats;
173  Optional<int64_t> axis;
174 
175  static void RegisterReflection() {
176  namespace refl = tvm::ffi::reflection;
177  refl::ObjectDef<RepeatAttrs>()
178  .def_ro("repeats", &RepeatAttrs::repeats, "The number of repetitions.")
179  .def_ro("axis", &RepeatAttrs::axis,
180  "The axis along which to repeat values. The negative numbers are interpreted "
181  "counting from the backward. By default, use the flattened input array, and "
182  "return a flat output array.");
183  }
184 
185  static constexpr const char* _type_key = "relax.attrs.RepeatAttrs";
187 }; // struct RepeatAttrs
188 
190 struct TileAttrs : public AttrsNodeReflAdapter<TileAttrs> {
191  Array<Integer> repeats;
192 
193  static void RegisterReflection() {
194  namespace refl = tvm::ffi::reflection;
195  refl::ObjectDef<TileAttrs>().def_ro("repeats", &TileAttrs::repeats,
196  "The number of repetitions of data along each axis.");
197  }
198 
199  static constexpr const char* _type_key = "relax.attrs.TileAttrs";
201 }; // struct TileAttrs
202 
204 struct FlipAttrs : public AttrsNodeReflAdapter<FlipAttrs> {
206 
207  static void RegisterReflection() {
208  namespace refl = tvm::ffi::reflection;
209  refl::ObjectDef<FlipAttrs>().def_ro("axis", &FlipAttrs::axis,
210  "The axis along which to flip over.",
211  refl::DefaultValue(NullValue<Integer>()));
212  }
213 
214  static constexpr const char* _type_key = "relax.attrs.FlipAttrs";
216 }; // struct FlipAttrs
217 
219 struct GatherElementsAttrs : public AttrsNodeReflAdapter<GatherElementsAttrs> {
221 
222  static void RegisterReflection() {
223  namespace refl = tvm::ffi::reflection;
224  refl::ObjectDef<GatherElementsAttrs>().def_ro("axis", &GatherElementsAttrs::axis,
225  "The axis along which to index.",
226  refl::DefaultValue(0));
227  }
228 
229  static constexpr const char* _type_key = "relax.attrs.GatherElementsAttrs";
231 }; // struct GatherElementsAttrs
232 
234 struct GatherNDAttrs : public AttrsNodeReflAdapter<GatherNDAttrs> {
236 
237  static void RegisterReflection() {
238  namespace refl = tvm::ffi::reflection;
239  refl::ObjectDef<GatherNDAttrs>().def_ro("batch_dims", &GatherNDAttrs::batch_dims,
240  "The number of batch dims.", refl::DefaultValue(0));
241  }
242 
243  static constexpr const char* _type_key = "relax.attrs.GatherNDAttrs";
245 }; // struct GatherNDAttrs
246 
248 struct IndexPutAttrs : public AttrsNodeReflAdapter<IndexPutAttrs> {
250 
251  static void RegisterReflection() {
252  namespace refl = tvm::ffi::reflection;
253  refl::ObjectDef<IndexPutAttrs>().def_ro(
254  "accumulate", &IndexPutAttrs::accumulate,
255  "Whether to accumulate (add) values rather than replace. "
256  "If true, performs tensor[indices] += values, "
257  "otherwise performs tensor[indices] = values.",
258  refl::DefaultValue(false));
259  }
260 
261  static constexpr const char* _type_key = "relax.attrs.IndexPutAttrs";
263 }; // struct IndexPutAttrs
264 
266 struct MeshgridAttrs : public AttrsNodeReflAdapter<MeshgridAttrs> {
267  Optional<String> indexing;
268 
269  static void RegisterReflection() {
270  namespace refl = tvm::ffi::reflection;
271  refl::ObjectDef<MeshgridAttrs>().def_ro("indexing", &MeshgridAttrs::indexing,
272  "Specifies how the grid dimensions are ordered.");
273  }
274 
275  static constexpr const char* _type_key = "relax.attrs.MeshgridAttrs";
277 };
278 
280 struct ScatterElementsAttrs : public AttrsNodeReflAdapter<ScatterElementsAttrs> {
282  String reduction;
283 
284  static void RegisterReflection() {
285  namespace refl = tvm::ffi::reflection;
286  refl::ObjectDef<ScatterElementsAttrs>()
287  .def_ro("axis", &ScatterElementsAttrs::axis, "The axis over which to select values.",
288  refl::DefaultValue(0))
289  .def_ro("reduction", &ScatterElementsAttrs::reduction,
290  "Reduction mode of the scatter elements, "
291  "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".",
292  refl::DefaultValue("update"));
293  }
294 
295  static constexpr const char* _type_key = "relax.attrs.ScatterElementsAttrs";
297 }; // struct ScatterElementsAttrs
298 
300 struct ScatterNDAttrs : public AttrsNodeReflAdapter<ScatterNDAttrs> {
301  String reduction;
302 
303  static void RegisterReflection() {
304  namespace refl = tvm::ffi::reflection;
305  refl::ObjectDef<ScatterNDAttrs>().def_ro(
306  "reduction", &ScatterNDAttrs::reduction,
307  "Accumulation mode of the ScatterND, "
308  "either \"update\", \"add\", \"mul\", \"min\" or \"max\".",
309  refl::DefaultValue("update"));
310  }
311 
312  static constexpr const char* _type_key = "relax.attrs.ScatterNDAttrs";
314 }; // struct ScatterNDAttrs
315 
317 struct SliceScatterAttrs : public AttrsNodeReflAdapter<SliceScatterAttrs> {
318  int axis;
319 
320  static void RegisterReflection() {
321  namespace refl = tvm::ffi::reflection;
322  refl::ObjectDef<SliceScatterAttrs>().def_ro("axis", &SliceScatterAttrs::axis,
323  "the dimension to insert the slice into ",
324  refl::DefaultValue(0));
325  }
326 
327  static constexpr const char* _type_key = "relax.attrs.SliceScatterAttrs";
329 }; // struct SliceScatterAttrs
330 
332 struct OneHotAttrs : public AttrsNodeReflAdapter<OneHotAttrs> {
333  int depth;
334  int axis;
335 
336  static void RegisterReflection() {
337  namespace refl = tvm::ffi::reflection;
338  refl::ObjectDef<OneHotAttrs>()
339  .def_ro("depth", &OneHotAttrs::depth, "Depth of the one hot dimension.")
340  .def_ro("axis", &OneHotAttrs::axis, "Axis to fill.", refl::DefaultValue(-1));
341  }
342 
343  static constexpr const char* _type_key = "relax.attrs.OneHotAttrs";
345 }; // struct OneHotAttrs
346 
347 } // namespace relax
348 } // namespace tvm
349 
350 #endif // TVM_RELAX_ATTRS_MANIPULATE_H_
Adapter for AttrsNode with the new reflection API.
Definition: attrs.h:384
Base class of all attribute class.
Definition: attrs.h:103
Container of constant int that adds more constructors.
Definition: expr.h:612
Definition: index_map.h:169
Defines a remapping of buffer indices.
Definition: repr_printer.h:91
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Attributes used in concat operators.
Definition: manipulate.h:34
static void RegisterReflection()
Definition: manipulate.h:37
Optional< int64_t > axis
Definition: manipulate.h:35
static constexpr const char * _type_key
Definition: manipulate.h:44
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ConcatAttrs, BaseAttrsNode)
Attributes used in expand_dims operators.
Definition: manipulate.h:49
static void RegisterReflection()
Definition: manipulate.h:52
static constexpr const char * _type_key
Definition: manipulate.h:61
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ExpandDimsAttrs, BaseAttrsNode)
Array< Integer > axis
Definition: manipulate.h:50
Attributes used in flip operators.
Definition: manipulate.h:204
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(FlipAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: manipulate.h:207
static constexpr const char * _type_key
Definition: manipulate.h:214
Integer axis
Definition: manipulate.h:205
Attributes used in gather_elements operators.
Definition: manipulate.h:219
static void RegisterReflection()
Definition: manipulate.h:222
static constexpr const char * _type_key
Definition: manipulate.h:229
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(GatherElementsAttrs, BaseAttrsNode)
Integer axis
Definition: manipulate.h:220
Attributes used in gather_nd operators.
Definition: manipulate.h:234
Integer batch_dims
Definition: manipulate.h:235
static constexpr const char * _type_key
Definition: manipulate.h:243
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(GatherNDAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: manipulate.h:237
Attributes used in index_put operator.
Definition: manipulate.h:248
bool accumulate
Definition: manipulate.h:249
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(IndexPutAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: manipulate.h:261
static void RegisterReflection()
Definition: manipulate.h:251
Attributes used in layout_transform operator.
Definition: manipulate.h:66
static void RegisterReflection()
Definition: manipulate.h:85
Optional< PrimValue > pad_value
Definition: manipulate.h:70
static constexpr const char * _type_key
Definition: manipulate.h:100
Optional< Array< IntImm > > input_axis_separators
Definition: manipulate.h:83
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(LayoutTransformAttrs, BaseAttrsNode)
tir::IndexMap index_map
Definition: manipulate.h:67
Optional< Array< IntImm > > axis_separators
Definition: manipulate.h:77
Attribute used in meshgrid operator.
Definition: manipulate.h:266
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(MeshgridAttrs, BaseAttrsNode)
Optional< String > indexing
Definition: manipulate.h:267
static void RegisterReflection()
Definition: manipulate.h:269
static constexpr const char * _type_key
Definition: manipulate.h:275
Attributes used in one_hot operator.
Definition: manipulate.h:332
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(OneHotAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: manipulate.h:336
int axis
Definition: manipulate.h:334
static constexpr const char * _type_key
Definition: manipulate.h:343
int depth
Definition: manipulate.h:333
Attributes used in permute_dims operator.
Definition: manipulate.h:105
static void RegisterReflection()
Definition: manipulate.h:108
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(PermuteDimsAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: manipulate.h:114
Optional< Array< Integer > > axes
Definition: manipulate.h:106
Attributes used in repeat operators.
Definition: manipulate.h:171
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(RepeatAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: manipulate.h:185
Optional< int64_t > axis
Definition: manipulate.h:173
int repeats
Definition: manipulate.h:172
static void RegisterReflection()
Definition: manipulate.h:175
Attributes used in scatter_elements operators.
Definition: manipulate.h:280
String reduction
Definition: manipulate.h:282
Integer axis
Definition: manipulate.h:281
static constexpr const char * _type_key
Definition: manipulate.h:295
static void RegisterReflection()
Definition: manipulate.h:284
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ScatterElementsAttrs, BaseAttrsNode)
Attributes used in scatter_nd operators.
Definition: manipulate.h:300
String reduction
Definition: manipulate.h:301
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ScatterNDAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: manipulate.h:312
static void RegisterReflection()
Definition: manipulate.h:303
Attributes used in slice_scatter operator.
Definition: manipulate.h:317
static void RegisterReflection()
Definition: manipulate.h:320
static constexpr const char * _type_key
Definition: manipulate.h:327
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SliceScatterAttrs, BaseAttrsNode)
int axis
Definition: manipulate.h:318
Attributes used in split operator.
Definition: manipulate.h:119
static void RegisterReflection()
Definition: manipulate.h:123
ObjectRef indices_or_sections
Definition: manipulate.h:120
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SplitAttrs, BaseAttrsNode)
static constexpr const char * _type_key
Definition: manipulate.h:131
int axis
Definition: manipulate.h:121
Attributes used in squeeze operators.
Definition: manipulate.h:136
Optional< Array< Integer > > axis
Definition: manipulate.h:137
static constexpr const char * _type_key
Definition: manipulate.h:148
static void RegisterReflection()
Definition: manipulate.h:139
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SqueezeAttrs, BaseAttrsNode)
Attributes used in stack operators.
Definition: manipulate.h:153
Optional< Integer > axis
Definition: manipulate.h:154
static void RegisterReflection()
Definition: manipulate.h:156
static constexpr const char * _type_key
Definition: manipulate.h:166
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(StackAttrs, BaseAttrsNode)
Attributes used in tile operators.
Definition: manipulate.h:190
static void RegisterReflection()
Definition: manipulate.h:193
static constexpr const char * _type_key
Definition: manipulate.h:199
Array< Integer > repeats
Definition: manipulate.h:191
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TileAttrs, BaseAttrsNode)