tvm
manipulate.h
Go to the documentation of this file.
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 
24 #ifndef TVM_RELAX_ATTRS_MANIPULATE_H_
25 #define TVM_RELAX_ATTRS_MANIPULATE_H_
26 
27 #include <tvm/relax/expr.h>
28 #include <tvm/tir/index_map.h>
29 
30 namespace tvm {
31 namespace relax {
32 
34 struct ConcatAttrs : public AttrsNodeReflAdapter<ConcatAttrs> {
35  ffi::Optional<int64_t> axis;
36 
37  static void RegisterReflection() {
38  namespace refl = tvm::ffi::reflection;
39  refl::ObjectDef<ConcatAttrs>().def_ro("axis", &ConcatAttrs::axis,
40  "The axis at which the input arrays are concatenated."
41  "Should lie in range `[-ndim, ndim)`.");
42  }
44 }; // struct ConcatAttrs
45 
47 struct ExpandDimsAttrs : public AttrsNodeReflAdapter<ExpandDimsAttrs> {
48  ffi::Array<Integer> axis;
49 
50  static void RegisterReflection() {
51  namespace refl = tvm::ffi::reflection;
52  refl::ObjectDef<ExpandDimsAttrs>().def_ro(
53  "axis", &ExpandDimsAttrs::axis,
54  "The axes at which the input array are expanded. "
55  "All values are required to lie in range `[-data.ndim - 1, data.ndim]`, "
56  "with the convention of negative indexing.");
57  }
59 }; // struct ExpandDimsAttrs
60 
62 struct LayoutTransformAttrs : public AttrsNodeReflAdapter<LayoutTransformAttrs> {
64  // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This
65  // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future.
66  ffi::Optional<PrimValue> pad_value;
73  ffi::Optional<ffi::Array<IntImm>> axis_separators;
79  ffi::Optional<ffi::Array<IntImm>> input_axis_separators;
80 
81  static void RegisterReflection() {
82  namespace refl = tvm::ffi::reflection;
83  refl::ObjectDef<LayoutTransformAttrs>()
84  .def_ro("index_map", &LayoutTransformAttrs::index_map,
85  "The layout transformation to apply.")
86  .def_ro(
87  "pad_value", &LayoutTransformAttrs::pad_value,
88  "The specific value to be used to pad if the layout transform would result in implicit "
89  "padding. If not specified, the compiler is free to choose any value.")
90  .def_ro("axis_separators", &LayoutTransformAttrs::axis_separators,
91  "The separators between input axes when generating flat output axes")
92  .def_ro("input_axis_separators", &LayoutTransformAttrs::input_axis_separators,
93  "The separators between axes to regenerate output");
94  }
95  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayoutTransformAttrs", LayoutTransformAttrs,
97 }; // struct LayoutTransformAttrs
98 
100 struct PermuteDimsAttrs : public AttrsNodeReflAdapter<PermuteDimsAttrs> {
101  ffi::Optional<ffi::Array<Integer>> axes;
102 
103  static void RegisterReflection() {
104  namespace refl = tvm::ffi::reflection;
105  refl::ObjectDef<PermuteDimsAttrs>().def_ro(
106  "axes", &PermuteDimsAttrs::axes, "The target axes order, reverse order if not specified.");
107  }
108  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PermuteDimsAttrs", PermuteDimsAttrs,
109  BaseAttrsNode);
110 }; // struct PermuteDimsAttrs
111 
113 struct SplitAttrs : public AttrsNodeReflAdapter<SplitAttrs> {
115  int axis;
116 
117  static void RegisterReflection() {
118  namespace refl = tvm::ffi::reflection;
119  refl::ObjectDef<SplitAttrs>()
120  .def_ro("indices_or_sections", &SplitAttrs::indices_or_sections,
121  "The input array of indices or the number of split sections.")
122  .def_ro("axis", &SplitAttrs::axis, "The axis to be splitted");
123  }
125 }; // struct SplitAttrs
126 
128 struct SqueezeAttrs : public AttrsNodeReflAdapter<SqueezeAttrs> {
129  ffi::Optional<ffi::Array<Integer>> axis;
130 
131  static void RegisterReflection() {
132  namespace refl = tvm::ffi::reflection;
133  refl::ObjectDef<SqueezeAttrs>().def_ro("axis", &SqueezeAttrs::axis,
134  "The axis to squeeze in the input tensor."
135  "If `axis = None`, all axis of dimension 1 get squeezed;"
136  "Else, the dimension in axes get squeezed."
137  "It is an error if an axis does not has dimension 1.");
138  }
140 }; // struct SqueezeAttrs
141 
143 struct StackAttrs : public AttrsNodeReflAdapter<StackAttrs> {
144  ffi::Optional<Integer> axis;
145 
146  static void RegisterReflection() {
147  namespace refl = tvm::ffi::reflection;
148  refl::ObjectDef<StackAttrs>().def_ro(
149  "axis", &StackAttrs::axis,
150  "The axis along which to stack the input tensors. "
151  "The axis will be inserted at this position in the output, "
152  "so it must be in range [-ndim-1, ndim] where ndim is the "
153  "number of dimensions of the input tensors.");
154  }
156 }; // struct StackAttrs
157 
159 struct RepeatAttrs : public AttrsNodeReflAdapter<RepeatAttrs> {
160  int repeats;
161  ffi::Optional<int64_t> axis;
162 
163  static void RegisterReflection() {
164  namespace refl = tvm::ffi::reflection;
165  refl::ObjectDef<RepeatAttrs>()
166  .def_ro("repeats", &RepeatAttrs::repeats, "The number of repetitions.")
167  .def_ro("axis", &RepeatAttrs::axis,
168  "The axis along which to repeat values. The negative numbers are interpreted "
169  "counting from the backward. By default, use the flattened input array, and "
170  "return a flat output array.");
171  }
173 }; // struct RepeatAttrs
174 
176 struct TileAttrs : public AttrsNodeReflAdapter<TileAttrs> {
177  ffi::Array<Integer> repeats;
178 
179  static void RegisterReflection() {
180  namespace refl = tvm::ffi::reflection;
181  refl::ObjectDef<TileAttrs>().def_ro("repeats", &TileAttrs::repeats,
182  "The number of repetitions of data along each axis.");
183  }
185 }; // struct TileAttrs
186 
188 struct FlipAttrs : public AttrsNodeReflAdapter<FlipAttrs> {
190 
191  static void RegisterReflection() {
192  namespace refl = tvm::ffi::reflection;
193  refl::ObjectDef<FlipAttrs>().def_ro("axis", &FlipAttrs::axis,
194  "The axis along which to flip over.",
195  refl::DefaultValue(NullValue<Integer>()));
196  }
198 }; // struct FlipAttrs
199 
201 struct GatherElementsAttrs : public AttrsNodeReflAdapter<GatherElementsAttrs> {
203 
204  static void RegisterReflection() {
205  namespace refl = tvm::ffi::reflection;
206  refl::ObjectDef<GatherElementsAttrs>().def_ro("axis", &GatherElementsAttrs::axis,
207  "The axis along which to index.",
208  refl::DefaultValue(0));
209  }
210  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherElementsAttrs", GatherElementsAttrs,
211  BaseAttrsNode);
212 }; // struct GatherElementsAttrs
213 
215 struct GatherNDAttrs : public AttrsNodeReflAdapter<GatherNDAttrs> {
217 
218  static void RegisterReflection() {
219  namespace refl = tvm::ffi::reflection;
220  refl::ObjectDef<GatherNDAttrs>().def_ro("batch_dims", &GatherNDAttrs::batch_dims,
221  "The number of batch dims.", refl::DefaultValue(0));
222  }
224 }; // struct GatherNDAttrs
225 
227 struct IndexPutAttrs : public AttrsNodeReflAdapter<IndexPutAttrs> {
229 
230  static void RegisterReflection() {
231  namespace refl = tvm::ffi::reflection;
232  refl::ObjectDef<IndexPutAttrs>().def_ro(
233  "accumulate", &IndexPutAttrs::accumulate,
234  "Whether to accumulate (add) values rather than replace. "
235  "If true, performs tensor[indices] += values, "
236  "otherwise performs tensor[indices] = values.",
237  refl::DefaultValue(false));
238  }
240 }; // struct IndexPutAttrs
241 
243 struct MeshgridAttrs : public AttrsNodeReflAdapter<MeshgridAttrs> {
244  ffi::Optional<ffi::String> indexing;
245 
246  static void RegisterReflection() {
247  namespace refl = tvm::ffi::reflection;
248  refl::ObjectDef<MeshgridAttrs>().def_ro("indexing", &MeshgridAttrs::indexing,
249  "Specifies how the grid dimensions are ordered.");
250  }
252 };
253 
255 struct ScatterElementsAttrs : public AttrsNodeReflAdapter<ScatterElementsAttrs> {
257  ffi::String reduction;
258 
259  static void RegisterReflection() {
260  namespace refl = tvm::ffi::reflection;
261  refl::ObjectDef<ScatterElementsAttrs>()
262  .def_ro("axis", &ScatterElementsAttrs::axis, "The axis over which to select values.",
263  refl::DefaultValue(0))
264  .def_ro("reduction", &ScatterElementsAttrs::reduction,
265  "Reduction mode of the scatter elements, "
266  "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".",
267  refl::DefaultValue("update"));
268  }
269  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterElementsAttrs", ScatterElementsAttrs,
270  BaseAttrsNode);
271 }; // struct ScatterElementsAttrs
272 
274 struct ScatterNDAttrs : public AttrsNodeReflAdapter<ScatterNDAttrs> {
275  ffi::String reduction;
276 
277  static void RegisterReflection() {
278  namespace refl = tvm::ffi::reflection;
279  refl::ObjectDef<ScatterNDAttrs>().def_ro(
280  "reduction", &ScatterNDAttrs::reduction,
281  "Accumulation mode of the ScatterND, "
282  "either \"update\", \"add\", \"mul\", \"min\" or \"max\".",
283  refl::DefaultValue("update"));
284  }
286 }; // struct ScatterNDAttrs
287 
289 struct SliceScatterAttrs : public AttrsNodeReflAdapter<SliceScatterAttrs> {
290  int axis;
291 
292  static void RegisterReflection() {
293  namespace refl = tvm::ffi::reflection;
294  refl::ObjectDef<SliceScatterAttrs>().def_ro("axis", &SliceScatterAttrs::axis,
295  "the dimension to insert the slice into ",
296  refl::DefaultValue(0));
297  }
298  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SliceScatterAttrs", SliceScatterAttrs,
299  BaseAttrsNode);
300 }; // struct SliceScatterAttrs
301 
303 struct OneHotAttrs : public AttrsNodeReflAdapter<OneHotAttrs> {
304  int depth;
305  int axis;
306 
307  static void RegisterReflection() {
308  namespace refl = tvm::ffi::reflection;
309  refl::ObjectDef<OneHotAttrs>()
310  .def_ro("depth", &OneHotAttrs::depth, "Depth of the one hot dimension.")
311  .def_ro("axis", &OneHotAttrs::axis, "Axis to fill.", refl::DefaultValue(-1));
312  }
314 }; // struct OneHotAttrs
315 
316 } // namespace relax
317 } // namespace tvm
318 
319 #endif // TVM_RELAX_ATTRS_MANIPULATE_H_
Adapter for AttrsNode with the new reflection API.
Definition: attrs.h:385
Base class of all attribute class.
Definition: attrs.h:102
Container of constant int that adds more constructors.
Definition: expr.h:600
Definition: index_map.h:169
Defines a remapping of buffer indices.
Definition: repr_printer.h:91
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Attributes used in concat operators.
Definition: manipulate.h:34
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ConcatAttrs", ConcatAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: manipulate.h:37
ffi::Optional< int64_t > axis
Definition: manipulate.h:35
Attributes used in expand_dims operators.
Definition: manipulate.h:47
static void RegisterReflection()
Definition: manipulate.h:50
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ExpandDimsAttrs", ExpandDimsAttrs, BaseAttrsNode)
ffi::Array< Integer > axis
Definition: manipulate.h:48
Attributes used in flip operators.
Definition: manipulate.h:188
static void RegisterReflection()
Definition: manipulate.h:191
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, BaseAttrsNode)
Integer axis
Definition: manipulate.h:189
Attributes used in gather_elements operators.
Definition: manipulate.h:201
static void RegisterReflection()
Definition: manipulate.h:204
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherElementsAttrs", GatherElementsAttrs, BaseAttrsNode)
Integer axis
Definition: manipulate.h:202
Attributes used in gather_nd operators.
Definition: manipulate.h:215
Integer batch_dims
Definition: manipulate.h:216
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GatherNDAttrs", GatherNDAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: manipulate.h:218
Attributes used in index_put operator.
Definition: manipulate.h:227
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.IndexPutAttrs", IndexPutAttrs, BaseAttrsNode)
bool accumulate
Definition: manipulate.h:228
static void RegisterReflection()
Definition: manipulate.h:230
Attributes used in layout_transform operator.
Definition: manipulate.h:62
ffi::Optional< ffi::Array< IntImm > > axis_separators
Definition: manipulate.h:73
ffi::Optional< PrimValue > pad_value
Definition: manipulate.h:66
static void RegisterReflection()
Definition: manipulate.h:81
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.LayoutTransformAttrs", LayoutTransformAttrs, BaseAttrsNode)
ffi::Optional< ffi::Array< IntImm > > input_axis_separators
Definition: manipulate.h:79
tir::IndexMap index_map
Definition: manipulate.h:63
Attribute used in meshgrid operator.
Definition: manipulate.h:243
ffi::Optional< ffi::String > indexing
Definition: manipulate.h:244
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MeshgridAttrs", MeshgridAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: manipulate.h:246
Attributes used in one_hot operator.
Definition: manipulate.h:303
static void RegisterReflection()
Definition: manipulate.h:307
int axis
Definition: manipulate.h:305
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.OneHotAttrs", OneHotAttrs, BaseAttrsNode)
int depth
Definition: manipulate.h:304
Attributes used in permute_dims operator.
Definition: manipulate.h:100
static void RegisterReflection()
Definition: manipulate.h:103
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.PermuteDimsAttrs", PermuteDimsAttrs, BaseAttrsNode)
ffi::Optional< ffi::Array< Integer > > axes
Definition: manipulate.h:101
Attributes used in repeat operators.
Definition: manipulate.h:159
ffi::Optional< int64_t > axis
Definition: manipulate.h:161
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.RepeatAttrs", RepeatAttrs, BaseAttrsNode)
int repeats
Definition: manipulate.h:160
static void RegisterReflection()
Definition: manipulate.h:163
Attributes used in scatter_elements operators.
Definition: manipulate.h:255
Integer axis
Definition: manipulate.h:256
ffi::String reduction
Definition: manipulate.h:257
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterElementsAttrs", ScatterElementsAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: manipulate.h:259
Attributes used in scatter_nd operators.
Definition: manipulate.h:274
static void RegisterReflection()
Definition: manipulate.h:277
ffi::String reduction
Definition: manipulate.h:275
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ScatterNDAttrs", ScatterNDAttrs, BaseAttrsNode)
Attributes used in slice_scatter operator.
Definition: manipulate.h:289
static void RegisterReflection()
Definition: manipulate.h:292
int axis
Definition: manipulate.h:290
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SliceScatterAttrs", SliceScatterAttrs, BaseAttrsNode)
Attributes used in split operator.
Definition: manipulate.h:113
static void RegisterReflection()
Definition: manipulate.h:117
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SplitAttrs", SplitAttrs, BaseAttrsNode)
ObjectRef indices_or_sections
Definition: manipulate.h:114
int axis
Definition: manipulate.h:115
Attributes used in squeeze operators.
Definition: manipulate.h:128
static void RegisterReflection()
Definition: manipulate.h:131
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.SqueezeAttrs", SqueezeAttrs, BaseAttrsNode)
ffi::Optional< ffi::Array< Integer > > axis
Definition: manipulate.h:129
Attributes used in stack operators.
Definition: manipulate.h:143
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.StackAttrs", StackAttrs, BaseAttrsNode)
ffi::Optional< Integer > axis
Definition: manipulate.h:144
static void RegisterReflection()
Definition: manipulate.h:146
Attributes used in tile operators.
Definition: manipulate.h:176
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TileAttrs", TileAttrs, BaseAttrsNode)
static void RegisterReflection()
Definition: manipulate.h:179
ffi::Array< Integer > repeats
Definition: manipulate.h:177