tvm
sorting.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_ATTRS_SORTING_H_
25 #define TVM_RELAX_ATTRS_SORTING_H_
26 
27 #include <tvm/relax/expr.h>
28 #include <tvm/tir/index_map.h>
29 
30 namespace tvm {
31 namespace relax {
32 
34 struct SortAttrs : public tvm::AttrsNode<SortAttrs> {
35  int axis;
36  bool descending;
37 
38  TVM_DECLARE_ATTRS(SortAttrs, "relax.attrs.SortAttrs") {
39  TVM_ATTR_FIELD(axis).set_default(-1).describe(
40  "Axis along which the sort is computed."
41  "The default the last axis is used.");
43  .set_default(false)
44  .describe(
45  "Whether to sort in descending order."
46  "If it is not specified, it defaults to the ascending order.");
47  }
48 }; // struct SortAttrs
49 
51 struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
52  int axis;
53  bool descending;
55 
56  TVM_DECLARE_ATTRS(ArgsortAttrs, "relax.attrs.ArgsortAttrs") {
57  TVM_ATTR_FIELD(axis).set_default(-1).describe(
58  "Axis along which the argsort is computed."
59  "The default the last axis is used.");
61  .set_default(false)
62  .describe(
63  "Whether to argsort in descending order."
64  "If it is not specified, it defaults to the ascending order.");
66  .set_default(NullValue<DataType>())
67  .describe("DType of the output indices.");
68  }
69 }; // struct ArgsortAttrs
70 
72 struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
73  int k;
74  int axis;
75  bool largest;
78 
79  TVM_DECLARE_ATTRS(TopKAttrs, "relax.attrs.TopKAttrs") {
80  TVM_ATTR_FIELD(k).describe("Number of top elements to select");
81  TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor.");
82  TVM_ATTR_FIELD(ret_type).set_default("both").describe(
83  "The return type [both, values, indices]."
84  "both - return both top k data and indices."
85  "values - return top k data only."
86  "indices - return top k indices only.");
87  TVM_ATTR_FIELD(largest).set_default(true).describe(
88  "Whether to return largest or smallest elements."
89  "By default, return the largest k elements.");
91  .set_default(NullValue<DataType>())
92  .describe("Data type of the output indices.");
93  }
94 }; // struct TopKAttrs
95 
96 } // namespace relax
97 } // namespace tvm
98 
99 #endif // TVM_RELAX_ATTRS_SORTING_H_
The base class of all attribute nodes; it uses the "curiously recurring template pattern" (CRTP).
Definition: attrs.h:870
Runtime primitive data type.
Definition: data_type.h:43
Reference to string objects.
Definition: string.h:98
Defines a remapping of buffer indices.
#define TVM_ATTR_FIELD(FieldName)
Declare an attribute field.
Definition: attrs.h:76
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
DataType NullValue< DataType >()
Definition: attrs.h:90
Attributes used in argsort operator.
Definition: sorting.h:51
DataType dtype
Definition: sorting.h:54
int axis
Definition: sorting.h:52
TVM_DECLARE_ATTRS(ArgsortAttrs, "relax.attrs.ArgsortAttrs")
Definition: sorting.h:56
bool descending
Definition: sorting.h:53
Attributes used in sort operator.
Definition: sorting.h:34
bool descending
Definition: sorting.h:36
int axis
Definition: sorting.h:35
TVM_DECLARE_ATTRS(SortAttrs, "relax.attrs.SortAttrs")
Definition: sorting.h:38
Attributes used in topk operator.
Definition: sorting.h:72
TVM_DECLARE_ATTRS(TopKAttrs, "relax.attrs.TopKAttrs")
Definition: sorting.h:79
int k
Definition: sorting.h:73
String ret_type
Definition: sorting.h:76
DataType dtype
Definition: sorting.h:77
int axis
Definition: sorting.h:74
bool largest
Definition: sorting.h:75