tvm
manipulate.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_ATTRS_MANIPULATE_H_
25 #define TVM_RELAX_ATTRS_MANIPULATE_H_
26 
27 #include <tvm/relax/expr.h>
28 #include <tvm/tir/index_map.h>
29 
30 namespace tvm {
31 namespace relax {
32 
34 struct ConcatAttrs : public tvm::AttrsNode<ConcatAttrs> {
36 
37  TVM_DECLARE_ATTRS(ConcatAttrs, "relax.attrs.ConcatAttrs") {
38  TVM_ATTR_FIELD(axis).describe(
39  "The axis at which the input arrays are concatenated."
40  "Should lie in range `[-ndim, ndim)`.");
41  }
42 }; // struct ConcatAttrs
43 
45 struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
47 
48  TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relax.attrs.ExpandDimsAttrs") {
49  TVM_ATTR_FIELD(axis).describe(
50  "The axes at which the input array are expanded. "
51  "All values are required to lie in range `[-data.ndim - 1, data.ndim]`, "
52  "with the convention of negative indexing.");
53  }
54 }; // struct ExpandDimsAttrs
55 
57 struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
59  // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This
60  // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future.
75 
76  TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
77  TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
78  TVM_ATTR_FIELD(pad_value).describe(
79  "The specific value to be used to pad if the layout transform would result in implicit "
80  "padding. If not specified, the compiler is free to choose any value.");
82  .describe("The separators between input axes when generating flat output axes");
84  .describe("The separators between axes to regenerate output");
85  }
86 }; // struct LayoutTransformAttrs
87 
89 struct PermuteDimsAttrs : public tvm::AttrsNode<PermuteDimsAttrs> {
91 
92  TVM_DECLARE_ATTRS(PermuteDimsAttrs, "relax.attrs.PermuteDimsAttrs") {
93  TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified.");
94  }
95 }; // struct PermuteDimsAttrs
96 
98 struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
100  int axis;
101 
102  TVM_DECLARE_ATTRS(SplitAttrs, "relax.attrs.SplitAttrs") {
104  .describe("The input array of indices or the number of split sections.");
105  TVM_ATTR_FIELD(axis).describe("The axis to be splitted");
106  }
107 }; // struct SplitAttrs
108 
110 struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
112 
113  TVM_DECLARE_ATTRS(SqueezeAttrs, "relax.attrs.SqueezeAttrs") {
114  TVM_ATTR_FIELD(axis).describe(
115  "The axis to squeeze in the input tensor."
116  "If `axis = None`, all axis of dimension 1 get squeezed;"
117  "Else, the dimension in axes get squeezed."
118  "It is an error if an axis does not has dimension 1.");
119  }
120 }; // struct SqueezeAttrs
121 
123 struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
124  int repeats;
126 
127  TVM_DECLARE_ATTRS(RepeatAttrs, "relax.attrs.RepeatAttrs") {
128  TVM_ATTR_FIELD(repeats).describe("The number of repetitions.");
129  TVM_ATTR_FIELD(axis).describe(
130  "The axis along which to repeat values. The negative numbers are interpreted "
131  "counting from the backward. By default, use the flattened input array, and "
132  "return a flat output array.");
133  }
134 }; // struct RepeatAttrs
135 
137 struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
139 
140  TVM_DECLARE_ATTRS(TileAttrs, "relax.attrs.TileAttrs") {
141  TVM_ATTR_FIELD(repeats).describe("The number of repetitions of data along each axis.");
142  }
143 }; // struct TileAttrs
144 
146 struct FlipAttrs : public tvm::AttrsNode<FlipAttrs> {
148  TVM_DECLARE_ATTRS(FlipAttrs, "relax.attrs.FlipAttrs") {
150  .set_default(NullValue<Integer>())
151  .describe("The axis along which to flip over.");
152  }
153 }; // struct FlipAttrs
154 
156 struct GatherElementsAttrs : public tvm::AttrsNode<GatherElementsAttrs> {
158 
159  TVM_DECLARE_ATTRS(GatherElementsAttrs, "relax.attrs.GatherElementsAttrs") {
160  TVM_ATTR_FIELD(axis).set_default(0).describe("The axis along which to index.");
161  }
162 }; // struct GatherElementsAttrs
163 
165 struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
167  TVM_DECLARE_ATTRS(GatherNDAttrs, "relax.attrs.GatherNDAttrs") {
168  TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dims.");
169  }
170 }; // struct GatherNDAttrs
171 
173 struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
176 
177  TVM_DECLARE_ATTRS(ScatterElementsAttrs, "relax.attrs.ScatterElementsAttrs") {
178  TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
179  TVM_ATTR_FIELD(reduction).set_default("update").describe(
180  "Reduction mode of the scatter elements, "
181  "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".");
182  }
183 }; // struct ScatterElementsAttrs
184 
186 struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
188 
189  TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs") {
190  TVM_ATTR_FIELD(reduction).set_default("update").describe(
191  "Accumulation mode of the ScatterND, "
192  "either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
193  }
194 }; // struct ScatterNDAttrs
195 
197 struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
198  int depth;
199  int axis;
200 
201  TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs") {
202  TVM_ATTR_FIELD(depth).describe("Depth of the one hot dimension.");
203  TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill.");
204  }
205 }; // struct OneHotAttrs
206 
207 } // namespace relax
208 } // namespace tvm
209 
210 #endif // TVM_RELAX_ATTRS_MANIPULATE_H_
The base class of all attribute nodes; uses the "curiously recurring template pattern" (CRTP).
Definition: attrs.h:870
Container of constant int that adds more constructors.
Definition: expr.h:632
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Base class of all object reference.
Definition: object.h:519
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Reference to string objects.
Definition: string.h:98
Definition: index_map.h:176
Defines a remapping of buffer indices.
#define TVM_ATTR_FIELD(FieldName)
Declare an attribute field.
Definition: attrs.h:76
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Attributes used in concat operators.
Definition: manipulate.h:34
Optional< Integer > axis
Definition: manipulate.h:35
TVM_DECLARE_ATTRS(ConcatAttrs, "relax.attrs.ConcatAttrs")
Definition: manipulate.h:37
Attributes used in expand_dims operators.
Definition: manipulate.h:45
TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relax.attrs.ExpandDimsAttrs")
Definition: manipulate.h:48
Array< Integer > axis
Definition: manipulate.h:46
Attributes used in flip operators.
Definition: manipulate.h:146
TVM_DECLARE_ATTRS(FlipAttrs, "relax.attrs.FlipAttrs")
Definition: manipulate.h:148
Integer axis
Definition: manipulate.h:147
Attributes used in gather_elements operators.
Definition: manipulate.h:156
TVM_DECLARE_ATTRS(GatherElementsAttrs, "relax.attrs.GatherElementsAttrs")
Definition: manipulate.h:159
Integer axis
Definition: manipulate.h:157
Attributes used in gather_nd operators.
Definition: manipulate.h:165
Integer batch_dims
Definition: manipulate.h:166
TVM_DECLARE_ATTRS(GatherNDAttrs, "relax.attrs.GatherNDAttrs")
Definition: manipulate.h:167
Attributes used in layout_transform operator.
Definition: manipulate.h:57
Optional< PrimValue > pad_value
Definition: manipulate.h:61
TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs")
Definition: manipulate.h:76
Optional< Array< IntImm > > input_axis_separators
Definition: manipulate.h:74
tir::IndexMap index_map
Definition: manipulate.h:58
Optional< Array< IntImm > > axis_separators
Definition: manipulate.h:68
Attributes used in one_hot operator.
Definition: manipulate.h:197
int axis
Definition: manipulate.h:199
TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs")
Definition: manipulate.h:201
int depth
Definition: manipulate.h:198
Attributes used in permute_dims operator.
Definition: manipulate.h:89
TVM_DECLARE_ATTRS(PermuteDimsAttrs, "relax.attrs.PermuteDimsAttrs")
Definition: manipulate.h:92
Optional< Array< Integer > > axes
Definition: manipulate.h:90
Attributes used in repeat operators.
Definition: manipulate.h:123
Optional< Integer > axis
Definition: manipulate.h:125
TVM_DECLARE_ATTRS(RepeatAttrs, "relax.attrs.RepeatAttrs")
Definition: manipulate.h:127
int repeats
Definition: manipulate.h:124
Attributes used in scatter_elements operators.
Definition: manipulate.h:173
String reduction
Definition: manipulate.h:175
Integer axis
Definition: manipulate.h:174
TVM_DECLARE_ATTRS(ScatterElementsAttrs, "relax.attrs.ScatterElementsAttrs")
Definition: manipulate.h:177
Attributes used in scatter_nd operators.
Definition: manipulate.h:186
String reduction
Definition: manipulate.h:187
TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs")
Definition: manipulate.h:189
Attributes used in split operator.
Definition: manipulate.h:98
ObjectRef indices_or_sections
Definition: manipulate.h:99
TVM_DECLARE_ATTRS(SplitAttrs, "relax.attrs.SplitAttrs")
Definition: manipulate.h:102
int axis
Definition: manipulate.h:100
Attributes used in squeeze operators.
Definition: manipulate.h:110
Optional< Array< Integer > > axis
Definition: manipulate.h:111
TVM_DECLARE_ATTRS(SqueezeAttrs, "relax.attrs.SqueezeAttrs")
Definition: manipulate.h:113
Attributes used in tile operators.
Definition: manipulate.h:137
Array< Integer > repeats
Definition: manipulate.h:138
TVM_DECLARE_ATTRS(TileAttrs, "relax.attrs.TileAttrs")
Definition: manipulate.h:140