tvm
algorithm.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
25 #define TVM_RELAY_ATTRS_ALGORITHM_H_
26 
27 #include <tvm/ir/attrs.h>
28 #include <tvm/relay/base.h>
29 #include <tvm/relay/expr.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 namespace relay {
35 
37 struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
38  int axis;
39  bool is_ascend;
41 
42  TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") {
43  TVM_ATTR_FIELD(axis).set_default(-1).describe(
44  "Axis along which to sort the input tensor."
45  "If not given, the flattened array is used.");
46  TVM_ATTR_FIELD(is_ascend).set_default(true).describe(
47  "Whether to sort in ascending or descending order."
48  "By default, sort in ascending order");
50  .set_default(NullValue<DataType>())
51  .describe("DType of the output indices.");
52  }
53 };
54 
55 struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
57  int axis;
58  bool is_ascend;
59  std::string ret_type;
61 
62  TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") {
63  TVM_ATTR_FIELD(k).describe("Number of top elements to select");
64  TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor.");
65  TVM_ATTR_FIELD(ret_type).set_default("both").describe(
66  "The return type [both, values, indices]."
67  "both - return both top k data and indices."
68  "values - return top k data only."
69  "indices - return top k indices only.");
70  TVM_ATTR_FIELD(is_ascend).set_default(false).describe(
71  "Whether to sort in ascending or descending order."
72  "By default, sort in descending order");
74  .set_default(NullValue<DataType>())
75  .describe("Data type of the output indices.");
76  }
77 };
78 
79 struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> {
80  bool right;
82 
83  TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") {
84  TVM_ATTR_FIELD(right).set_default(false).describe(
85  "Controls which index is returned if a value lands exactly on one of sorted values. If "
86  " false, the index of the first suitable location found is given. If true, return the "
87  "last such index. If there is no suitable index, return either 0 or N (where N is the "
88  "size of the innermost dimension).");
90  .set_default(DataType::Int(32))
91  .describe("Data type of the output indices.");
92  }
93 };
94 
95 } // namespace relay
96 } // namespace tvm
97 #endif // TVM_RELAY_ATTRS_ALGORITHM_H_
The base class of all attribute classes; uses the "curiously recurring template pattern" (CRTP).
Definition: attrs.h:870
Runtime primitive data type.
Definition: data_type.h:43
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:219
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Helpers for attribute objects.
#define TVM_ATTR_FIELD(FieldName)
Declare an attribute field.
Definition: attrs.h:76
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
DataType NullValue< DataType >()
Definition: attrs.h:90
Base classes for the Relay IR.
Relay expression language.
Attributes used in argsort operators.
Definition: algorithm.h:37
int axis
Definition: algorithm.h:38
TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs")
Definition: algorithm.h:42
DataType dtype
Definition: algorithm.h:40
bool is_ascend
Definition: algorithm.h:39
Definition: algorithm.h:79
DataType dtype
Definition: algorithm.h:81
bool right
Definition: algorithm.h:80
TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs")
Definition: algorithm.h:83
Definition: algorithm.h:55
std::string ret_type
Definition: algorithm.h:59
Optional< Integer > k
Definition: algorithm.h:56
int axis
Definition: algorithm.h:57
DataType dtype
Definition: algorithm.h:60
bool is_ascend
Definition: algorithm.h:58
TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs")
Definition: algorithm.h:62