tvm
bitserial.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_RELAY_ATTRS_BITSERIAL_H_
26 #define TVM_RELAY_ATTRS_BITSERIAL_H_
27 
28 #include <tvm/ir/attrs.h>
29 #include <tvm/relay/base.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 namespace relay {
35 
37 struct BitPackAttrs : public tvm::AttrsNode<BitPackAttrs> {
38  int bits;
39  int pack_axis;
40  int bit_axis;
42  std::string name;
43 
44  TVM_DECLARE_ATTRS(BitPackAttrs, "relay.attrs.BitPackAttrs") {
45  TVM_ATTR_FIELD(bits).set_default(1).describe("Number of bits to quantize with.");
46  TVM_ATTR_FIELD(pack_axis).set_default(1).describe(
47  "Axis that should be compressed, typically channels.");
48  TVM_ATTR_FIELD(bit_axis).set_default(-1).describe("New axis for packed bits.");
50  .set_default(NullValue<DataType>())
51  .describe("Type of int to pack bits into.");
52  TVM_ATTR_FIELD(name).set_default("BitPack").describe("Name of operation.");
53  }
54 };
55 
57 struct BinaryConv2DAttrs : public tvm::AttrsNode<BinaryConv2DAttrs> {
64  std::string data_layout;
65  std::string kernel_layout;
68  bool unipolar;
69 
70  TVM_DECLARE_ATTRS(BinaryConv2DAttrs, "relay.attrs.BinaryConv2DAttrs") {
72  .set_default(Array<IndexExpr>({1, 1}))
73  .describe("Specifies the strides of the convolution.");
75  .set_default(Array<IndexExpr>({0, 0}))
76  .describe(
77  "If padding is non-zero the input is implicitly zero-padded"
78  "on both sides for padding number of points.");
80  .set_default(Array<IndexExpr>({3, 3}))
81  .describe("Specifies the dimensions of the convolution window.");
83  .set_default(NullValue<IndexExpr>())
84  .describe("Number of output channels, needed for shape inference.");
86  .set_default(1)
87  .describe("Number of bits activation should be packed with.");
89  .set_default(1)
90  .describe("Number of bits kernel should be packed with.");
92  .set_default("NCHW")
93  .describe("Dimension ordering of input data, can be 'NCHW' or NHWC'.");
95  .set_default("OIHW")
96  .describe("Dimension ordering of kernel data, can be 'OIHW' or HWIO'.");
98  .set_default(NullValue<DataType>())
99  .describe("Datatype to pack bits into.");
100  TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output datatype.");
101  TVM_ATTR_FIELD(unipolar).set_default(true).describe(
102  "Whether to use unipolar or bipolar quantization.");
103  }
104 };
105 
106 /*~ \brief Attributes for bitserial dense operator */
107 struct BinaryDenseAttrs : public tvm::AttrsNode<BinaryDenseAttrs> {
113  bool unipolar;
114 
115  TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") {
116  TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
117  TVM_ATTR_FIELD(data_bits).set_default(1).describe(
118  "Number of bits to pack for incoming tensor.");
120  .set_default(1)
121  .describe("Number of bits to pack for weight tensor.");
123  .set_default(NullValue<DataType>())
124  .describe("Datatype to pack bits into before computation.");
125  TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
126  TVM_ATTR_FIELD(unipolar).set_default(true).describe(
127  "Whether to use unipolar or bipolar quantization for inputs.");
128  }
129 };
130 
131 } // namespace relay
132 } // namespace tvm
133 #endif // TVM_RELAY_ATTRS_BITSERIAL_H_
The base class of all attribute classes. Uses the "curiously recurring template pattern".
Definition: attrs.h:870
Reference to PrimExprNode.
Definition: expr.h:115
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
Helpers for attribute objects.
#define TVM_ATTR_FIELD(FieldName)
Declare an attribute field.
Definition: attrs.h:76
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
DataType NullValue< DataType >()
Definition: attrs.h:90
Base classes for the Relay IR.
Attributes used in bitserial convolution operators.
Definition: bitserial.h:57
Array< IndexExpr > kernel_size
Definition: bitserial.h:61
int activation_bits
Definition: bitserial.h:62
Array< IndexExpr > strides
Definition: bitserial.h:58
DataType out_dtype
Definition: bitserial.h:67
DataType pack_dtype
Definition: bitserial.h:66
std::string data_layout
Definition: bitserial.h:64
int weight_bits
Definition: bitserial.h:63
Array< IndexExpr > padding
Definition: bitserial.h:59
bool unipolar
Definition: bitserial.h:68
std::string kernel_layout
Definition: bitserial.h:65
TVM_DECLARE_ATTRS(BinaryConv2DAttrs, "relay.attrs.BinaryConv2DAttrs")
Definition: bitserial.h:70
IndexExpr channels
Definition: bitserial.h:60
Definition: bitserial.h:107
int data_bits
Definition: bitserial.h:109
IndexExpr units
Definition: bitserial.h:108
int weight_bits
Definition: bitserial.h:110
DataType pack_dtype
Definition: bitserial.h:111
DataType out_dtype
Definition: bitserial.h:112
TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs")
Definition: bitserial.h:115
bool unipolar
Definition: bitserial.h:113
Attributes used in bitpack operators.
Definition: bitserial.h:37
std::string name
Definition: bitserial.h:42
int bits
Definition: bitserial.h:38
int bit_axis
Definition: bitserial.h:40
int pack_axis
Definition: bitserial.h:39
TVM_DECLARE_ATTRS(BitPackAttrs, "relay.attrs.BitPackAttrs")
Definition: bitserial.h:44
DataType pack_type
Definition: bitserial.h:41