tvm
dense.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_ROCM_DENSE_H_
25 #define TVM_TOPI_ROCM_DENSE_H_
26 
28 #include <tvm/te/operation.h>
30 #include <tvm/topi/cuda/dense.h>
33 #include <tvm/topi/nn/dense.h>
34 #include <tvm/topi/tags.h>
35 
36 namespace tvm {
37 namespace topi {
38 
39 using namespace tvm::te;
40 
41 namespace rocm {
53 inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& data,
54  const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
55  const DataType& out_dtype) {
56  ICHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
57  ICHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
58  if (bias.defined()) {
59  ICHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
60  }
61 
62  auto batch = data->shape[0];
63  auto in_dim = data->shape[1];
64  auto out_dim = weight->shape[0];
65 
66  if (target->GetLibs().count("rocblas")) {
67  ICHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
68  auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
69  if (bias.defined()) {
70  mm = tvm::te::compute(
71  {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast);
72  }
73 
74  return mm;
75  } else {
76  return topi::nn::dense(data, weight, bias, out_dtype);
77  }
78 }
79 
88 inline Schedule schedule_dense(const Target& target, const Array<Tensor>& outs) {
89  if (target->kind->name == "rocm" && target->GetLibs().count("rocblas")) {
90  return topi::generic::schedule_extern(target, outs);
91  }
92 
93  return topi::cuda::schedule_dense(target, outs);
94 }
95 
96 } // namespace rocm
97 } // namespace topi
98 } // namespace tvm
99 #endif // TVM_TOPI_ROCM_DENSE_H_
Utility functions for handling arrays.
Managed reference class to TargetNode.
Definition: target.h:200
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
bool defined() const
Definition: object.h:552
Global schedule container for operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:326
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
a named variable in TIR
Definition: var.h:89
CUDA schedule for dense operation.
Schedule for extern followed by injective ops.
Generic function that can be specialized on a per target basis.
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Tensor rocblas_matmul(const Tensor &lhs, const Tensor &rhs, bool transa, bool transb)
Create an op that multiplies lhs and rhs with rocBLAS.
Definition: rocblas.h:45
Schedule schedule_dense(const Target &target, const Array< Tensor > &outs)
Create a CUDA schedule for dense.
Definition: dense.h:88
Schedule schedule_extern(const Target &target, const Array< Tensor > &outs)
Schedule an extern op followed by injective operations.
Definition: extern.h:48
tvm::te::Tensor dense(const tvm::te::Tensor &data, const tvm::te::Tensor &weight, const tvm::te::Tensor &bias, const DataType &out_dtype)
Creates an operation that calculates data * weight^T + bias.
Definition: dense.h:48
tvm::te::Tensor dense_rocm(const Target &target, const tvm::te::Tensor &data, const tvm::te::Tensor &weight, const tvm::te::Tensor &bias, const DataType &out_dtype)
Implementation of dense for rocm backend.
Definition: dense.h:53
Schedule schedule_dense(const Target &target, const Array< Tensor > &outs)
Create a rocm schedule for dense.
Definition: dense.h:88
constexpr auto kBroadcast
Definition: tags.h:36
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Dense op constructions.
Operation node can generate one or multiple Tensors.
External function interface to rocBLAS libraries.