tvm
bnn.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_BNN_H_
25 #define TVM_TOPI_NN_BNN_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/te/operation.h>
30 #include <tvm/topi/tags.h>
31 
32 #include <string>
33 
34 namespace tvm {
35 namespace topi {
36 namespace nn {
37 
38 using namespace tvm::te;
39 
51 inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis,
52  std::string name = "PackedInput",
53  std::string tag = "binarize_pack") {
54  auto ishape = data->shape;
55  ICHECK_EQ(GetConstInt(ishape[axis]) % 32, 0)
56  << "binarize_pack: axis size must be a multiple of 32";
57 
58  arith::Analyzer analyzer;
59  auto n = ishape.size();
60  Array<PrimExpr> oshape;
61  for (size_t i = 0; i < n; ++i) {
62  oshape.push_back(i == static_cast<size_t>(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32))
63  : ishape[i]);
64  }
65 
66  return tvm::te::compute(
67  oshape,
68  [&](const Array<Var>& indices) {
69  Array<PrimExpr> start_idx;
70  for (size_t i = 0; i < n; ++i) {
71  start_idx.push_back(i == static_cast<size_t>(axis) ? indices[i] * 32
72  : static_cast<PrimExpr>(indices[i]));
73  }
74  auto packed = make_const(DataType::UInt(32), 0);
75  for (size_t j = 0; j < 32; ++j) {
76  Array<PrimExpr> idx;
77  for (size_t i = 0; i < n; ++i) {
78  idx.push_back(i == static_cast<size_t>(axis) ? start_idx[i] + static_cast<int>(j)
79  : start_idx[i]);
80  }
81  auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0);
82  packed = (packed | sign);
83  if (j == 31) {
84  return packed;
85  }
86  packed = packed << 1;
87  }
88  return packed; // never reached, but suppress compiler warning
89  },
90  name, tag);
91 }
92 
101 inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) {
102  ICHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data";
103  ICHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight";
104  ICHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data";
105  ICHECK_EQ(weight->dtype, DataType::UInt(32)) << "binary_dense requires uint32 weight";
106 
107  auto batch = data->shape[0];
108  auto in_dim = data->shape[1];
109  auto out_dim = weight->shape[0];
110 
111  auto k = tvm::te::reduce_axis(Range(0, in_dim), "k");
112  auto matmul = tvm::te::compute(
113  {batch, out_dim},
114  [&](Var i, Var j) { return tvm::sum(popcount(data(i, k) ^ weight(j, k)), {k}); }, "tensor",
115  "binary_dense");
116 
117  return tvm::te::compute(
118  {batch, out_dim}, [&](Var i, Var j) { return 32 * in_dim - 2.0f * matmul(i, j); }, "tensor",
119  kElementWise);
120 }
121 
122 } // namespace nn
123 } // namespace topi
124 } // namespace tvm
125 #endif // TVM_TOPI_NN_BNN_H_
Algebra expression simplifications.
Reference to PrimExprNode.
Definition: expr.h:115
Range container
Definition: expr.h:725
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:629
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
static DataType UInt(int bits, int lanes=1, bool is_scalable=false)
Construct an uint type.
Definition: data_type.h:227
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
a named variable in TIR
Definition: var.h:89
Utility functions for handling constants in TVM expressions.
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:962
tvm::te::Tensor binary_dense(const tvm::te::Tensor &data, const tvm::te::Tensor &weight)
Binary matrix multiplication using xor and bit-count.
Definition: bnn.h:101
tvm::te::Tensor binarize_pack(const tvm::te::Tensor &data, int axis, std::string name="PackedInput", std::string tag="binarize_pack")
Binarization and bit-packing along a certain axis.
Definition: bnn.h:51
constexpr auto kElementWise
Definition: tags.h:32
tvm::te::Tensor matmul(const tvm::te::Tensor &A, const tvm::te::Tensor &B, bool trans_a=false, bool trans_b=false, std::string name="T_matmul", std::string tag=kMatMul)
Creates an operation that calculates a matrix multiplication (row-major notation): A(i,...
Definition: transform.h:1557
Tensor sign(const Tensor &x, std::string name="T_sign", std::string tag=kElementWise)
Returns the sign of the tensor.
Definition: elemwise.h:212
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr popcount(PrimExpr x, Span span=Span())
Definition: op.h:718
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Operation node can generate one or multiple Tensors.
External function interface to rocBLAS libraries.