tvm
dense.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_DENSE_H_
25 #define TVM_TOPI_NN_DENSE_H_
26 
27 #include <tvm/te/operation.h>
28 #include <tvm/topi/tags.h>
29 
30 #include <string>
31 
32 namespace tvm {
33 namespace topi {
34 namespace nn {
35 
36 using namespace tvm::te;
37 
48 inline tvm::te::Tensor dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight,
49  const tvm::te::Tensor& bias, const DataType& out_dtype) {
50  ICHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
51  ICHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
52  if (bias.defined()) {
53  ICHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
54  }
55 
56  auto batch = data->shape[0];
57  auto in_dim = data->shape[1];
58  auto out_dim = weight->shape[0];
59 
60  auto k = tvm::te::reduce_axis(Range(0, in_dim), "k");
61  auto matmul = tvm::te::compute(
62  {batch, out_dim},
63  [&](Var i, Var j) {
64  return tvm::sum(tvm::cast(out_dtype, data(i, k)) * tvm::cast(out_dtype, weight(j, k)), {k});
65  },
66  "tensor", "dense");
67 
68  if (bias.defined()) {
70  {batch, out_dim},
71  [&](Var i, Var j) { return matmul(i, j) + tvm::cast(out_dtype, bias(j)); }, "tensor",
72  kBroadcast);
73  }
74 
75  return matmul;
76 }
77 
78 } // namespace nn
79 } // namespace topi
80 } // namespace tvm
81 #endif // TVM_TOPI_NN_DENSE_H_
Range container
Definition: expr.h:725
Runtime primitive data type.
Definition: data_type.h:43
bool defined() const
Definition: object.h:552
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
a named variable in TIR
Definition: var.h:89
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis)
tvm::te::Tensor dense(const tvm::te::Tensor &data, const tvm::te::Tensor &weight, const tvm::te::Tensor &bias, const DataType &out_dtype)
Creates an operation that calculates data * weight^T + bias.
Definition: dense.h:48
constexpr auto kBroadcast
Definition: tags.h:36
tvm::te::Tensor matmul(const tvm::te::Tensor &A, const tvm::te::Tensor &B, bool trans_a=false, bool trans_b=false, std::string name="T_matmul", std::string tag=kMatMul)
Creates an operation that calculates a matrix multiplication (row-major notation): A(i,...
Definition: transform.h:1557
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
PrimExpr sum(PrimExpr source, Array< tir::IterVar > axis, Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Operation node can generate one or multiple Tensors.
External function interface to rocBLAS libraries.