tvm/topi/cuda/pooling.h

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda/pooling.h
 * \brief CUDA schedule for pooling operations
 */
#ifndef TVM_TOPI_CUDA_POOLING_H_
#define TVM_TOPI_CUDA_POOLING_H_

#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/detail/array_utils.h>
#include <tvm/topi/detail/fuse.h>
#include <tvm/topi/tags.h>

namespace tvm {
namespace topi {

using namespace tvm::te;

namespace cuda {

/*!
 * \brief Create a CUDA schedule for pool
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_pool(const Target& target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  auto _schedule = [&](const Tensor& padded_input, const Tensor& pool) {
    // Inline the padding stage so it fuses into the pooling computation.
    if (padded_input->op->IsInstance<ComputeOpNode>()) {
      s[padded_input].compute_inline();
    }
    int num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
    Tensor out;
    Tensor OL;
    // If the pool stage is itself an output, stage its computation through a
    // local cache; otherwise compute it in local scope under the real output.
    if (detail::contains(s->outputs, pool->op)) {
      out = pool;
      OL = s.cache_write(pool, "local");
    } else {
      out = outs[0]->op.output(0);
      s[pool].set_scope("local");
    }
    // Fuse all output axes into one, then split it across thread blocks.
    auto fused = detail::Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis);
    IterVar bx, tx;
    s[out].split(fused, num_thread, &bx, &tx);
    s[out].bind(bx, tvm::te::thread_axis(Range(), "blockIdx.x"));
    s[out].bind(tx, tvm::te::thread_axis(Range(), "threadIdx.x"));
    if (detail::contains(s->outputs, pool->op)) {
      s[OL].compute_at(s[out], tx);
    } else {
      s[pool].compute_at(s[out], tx);
    }
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output).
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag.rfind("pool", 0) == 0) {
      // An op whose tag starts with "pool" is the pooling stage; schedule it
      // together with its padded input.
      auto padded_input = op->InputTensors()[0];
      auto pool = op.output(0);
      _schedule(padded_input, pool);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}
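
/*
 * Example (a minimal sketch, not part of the original header): a 2x2 max
 * pool written directly with te::compute and tagged "pool", so that the
 * traversal in schedule_pool recognizes it. The shapes, names, and the
 * hand-rolled pool stage are illustrative assumptions; production code
 * would normally build the stage with the topi::nn pooling helpers.
 *
 *   using namespace tvm;
 *   using namespace tvm::te;
 *
 *   Tensor data = placeholder({1, 8, 32, 32}, DataType::Float(32), "data");
 *   IterVar rkh = reduce_axis(Range(0, 2), "rkh");
 *   IterVar rkw = reduce_axis(Range(0, 2), "rkw");
 *   // Stride-2 2x2 max pooling over the spatial axes; the "pool" tag is
 *   // what routes this op to _schedule above.
 *   Tensor pool = compute(
 *       {1, 8, 16, 16},
 *       [&](Var n, Var c, Var h, Var w) {
 *         return tvm::max(data(n, c, 2 * h + rkh, 2 * w + rkw), {rkh, rkw});
 *       },
 *       "max_pool", "pool");
 *   Schedule s = topi::cuda::schedule_pool(Target("cuda"), {pool});
 */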

/*!
 * \brief Create a CUDA schedule for global_pool
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_global_pool(const Target& target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  auto _schedule = [&](const Tensor& pool) {
    auto num_thread = 8;
    auto block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
    auto block_y = tvm::te::thread_axis(Range(), "blockIdx.y");
    auto thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
    auto thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y");
    Tensor out;
    Tensor OL;
    // If the pool stage is itself an output, stage its computation through a
    // local cache; otherwise compute it in local scope under the real output.
    if (detail::contains(s->outputs, pool->op)) {
      out = pool;
      OL = s.cache_write(pool, "local");
    } else {
      out = outs[0]->op.output(0);
      s[pool].set_scope("local");
    }

    // Tile the first two output axes over a 2D grid of 8x8 thread blocks.
    auto i = s[out]->op.as<ComputeOpNode>()->axis[0];
    auto c = s[out]->op.as<ComputeOpNode>()->axis[1];

    IterVar by, ty;
    s[out].split(i, num_thread, &by, &ty);
    IterVar bx, tx;
    s[out].split(c, num_thread, &bx, &tx);
    s[out].reorder({by, bx, ty, tx});
    s[out].bind(ty, thread_y);
    s[out].bind(tx, thread_x);
    s[out].bind(by, block_y);
    s[out].bind(bx, block_x);

    if (detail::contains(s->outputs, pool->op)) {
      s[OL].compute_at(s[out], tx);
    } else {
      s[pool].compute_at(s[out], tx);
    }
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output).
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag.rfind("global_pool", 0) == 0) {
      // An op whose tag starts with "global_pool" is the global pooling
      // stage; schedule it.
      auto pool = op.output(0);
      _schedule(pool);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}
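
/*
 * Example (a minimal sketch, not part of the original header): global
 * average pooling built with topi::nn::global_pool, then scheduled with
 * schedule_global_pool. It assumes <tvm/topi/nn/pooling.h> is included;
 * shapes and names are illustrative.
 *
 *   using namespace tvm;
 *   using namespace tvm::te;
 *
 *   Tensor data = placeholder({1, 64, 56, 56}, DataType::Float(32), "data");
 *   // global_pool emits an op tagged "global_pool", which the traversal
 *   // above dispatches to _schedule.
 *   Tensor gpool = topi::nn::global_pool(data, topi::nn::kAvgPool, "NCHW");
 *   Schedule s = topi::cuda::schedule_global_pool(Target("cuda"), {gpool});
 */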

}  // namespace cuda
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_CUDA_POOLING_H_