dense.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda/dense.h
 * \brief CUDA schedule for dense operation
 */
#ifndef TVM_TOPI_CUDA_DENSE_H_
#define TVM_TOPI_CUDA_DENSE_H_

#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/contrib/cublas.h>
#include <tvm/topi/detail/array_utils.h>
#include <tvm/topi/generic/extern.h>
#include <tvm/topi/nn/dense.h>
#include <tvm/topi/tags.h>

namespace tvm {
namespace topi {

using namespace tvm::te;

namespace cuda {
/*!
 * \brief Implementation of dense for CUDA backend
 *
 * \param target The target device
 * \param data Tensor with shape [batch, in_dim]
 * \param weight Tensor with shape [out_dim, in_dim]
 * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
 * \param out_dtype Output data type. Used for mixed precision.
 *
 * \return Tensor with shape [batch, out_dim]
 */
inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& data,
                                  const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
                                  const DataType& out_dtype) {
  ICHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
  ICHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
  if (bias.defined()) {
    ICHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
  }

  auto batch = data->shape[0];
  auto in_dim = data->shape[1];
  auto out_dim = weight->shape[0];

  if (target->GetLibs().count("cublas")) {
    ICHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
    auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
    if (bias.defined()) {
      mm = tvm::te::compute(
          {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast);
    }

    return mm;
  } else {
    return topi::nn::dense(data, weight, bias, out_dtype);
  }
}
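
// A minimal usage sketch, assuming a cuBLAS-enabled CUDA target; the shapes,
// names, and the "-libs=cublas" target string below are illustrative only and
// not part of this header:
//
//   Target target("cuda -libs=cublas");
//   auto data = tvm::te::placeholder({16, 512}, DataType::Float(32), "data");
//   auto weight = tvm::te::placeholder({128, 512}, DataType::Float(32), "weight");
//   // Pass an undefined Tensor() to skip the bias addition.
//   auto out = dense_cuda(target, data, weight, tvm::te::Tensor(), DataType::Float(32));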

/*!
 * \brief Create a CUDA schedule for dense
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_dense(const Target& target, const Array<Tensor>& outs) {
  if (target->kind->name == "cuda" && target->GetLibs().count("cublas")) {
    return topi::generic::schedule_extern(target, outs);
  }

  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  auto _schedule = [&](const Tensor& dense) {
    auto num_thread = 64;
    auto k = dense->op.as<ComputeOpNode>()->reduce_axis[0];
    IterVar ko, kf;
    s[dense].split(k, num_thread, &ko, &kf);
    auto dense_f = s.rfactor(dense, kf)[0];

    Tensor out;
    if (detail::contains(s->outputs, dense->op)) {
      out = dense;
    } else {
      out = outs[0]->op.output(0);
      s[dense].compute_at(s[out], s[out]->op.as<ComputeOpNode>()->axis[1]);
    }
    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[0],
                tvm::te::thread_axis(Range(), "blockIdx.y"));
    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[1],
                tvm::te::thread_axis(Range(), "blockIdx.x"));

    auto tx = s[dense]->op.as<ComputeOpNode>()->reduce_axis[0];
    auto thread_x = tvm::te::thread_axis(Range(), "threadIdx.x");
    s[dense].bind(tx, thread_x);
    s[dense_f].compute_at(s[dense], tx);
    s[dense].set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
    s[out].set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output)
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag == "dense") {
      // Apply the split/rfactor dense schedule defined above
      auto dense = op.output(0);
      _schedule(dense);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}
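
// A minimal scheduling sketch, continuing the dense_cuda example above (target
// and out are the illustrative placeholders assumed there):
//
//   auto sched = schedule_dense(target, {out});
//
// With "cublas" among the target libs this dispatches to the generic extern
// schedule; otherwise it applies the split/rfactor strategy in _schedule.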

}  // namespace cuda
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_CUDA_DENSE_H_