#ifndef TVM_TOPI_CUDA_POOLING_H_
#define TVM_TOPI_CUDA_POOLING_H_

#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/detail/array_utils.h>
#include <tvm/topi/detail/fuse.h>
#include <tvm/topi/tags.h>

#include <functional>

namespace tvm {
namespace topi {

using namespace tvm::te;

namespace cuda {
/*! \brief Create a CUDA schedule for pool. */
inline Schedule schedule_pool(const Target& target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) out_ops.push_back(t->op);
  auto s = create_schedule(out_ops);

  auto _schedule = [&](const Tensor& padded_input, const Tensor& pool) {
    // Inline the padding stage when it is a compute op.
    if (padded_input->op->IsInstance<ComputeOpNode>()) {
      s[padded_input].compute_inline();
    }
    int num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
    Tensor out, OL;
    if (detail::contains(s->outputs, pool->op)) {
      out = pool;
      OL = s.cache_write(pool, "local");
    } else {
      out = outs[0]->op.output(0);
      s[pool].set_scope("local");
    }
    // Fuse all output axes, then split across thread blocks and threads.
    auto fused = detail::Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis);
    IterVar bx, tx;
    s[out].split(fused, num_thread, &bx, &tx);
    s[out].bind(bx, thread_axis(Range(), "blockIdx.x"));
    s[out].bind(tx, thread_axis(Range(), "threadIdx.x"));
    if (detail::contains(s->outputs, pool->op)) {
      s[OL].compute_at(s[out], tx);
    } else {
      s[pool].compute_at(s[out], tx);
    }
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output).
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag.rfind("pool", 0) == 0) {
      // The tag starts with "pool": schedule the pool stage and its padded input.
      auto padded_input = op->InputTensors()[0];
      auto pool = op.output(0);
      _schedule(padded_input, pool);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}

/*! \brief Create a CUDA schedule for global_pool. */
inline Schedule schedule_global_pool(const Target& target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) out_ops.push_back(t->op);
  auto s = create_schedule(out_ops);

  auto _schedule = [&](const Tensor& pool) {
    auto num_thread = 8;
    auto block_x = thread_axis(Range(), "blockIdx.x");
    auto block_y = thread_axis(Range(), "blockIdx.y");
    auto thread_x = thread_axis(Range(0, num_thread), "threadIdx.x");
    auto thread_y = thread_axis(Range(0, num_thread), "threadIdx.y");
    Tensor out, OL;
    if (detail::contains(s->outputs, pool->op)) {
      out = pool;
      OL = s.cache_write(pool, "local");
    } else {
      out = outs[0]->op.output(0);
      s[pool].set_scope("local");
    }
    // Split the first two output axes and bind them to a 2D grid of thread blocks.
    auto i = s[out]->op.as<ComputeOpNode>()->axis[0];
    auto c = s[out]->op.as<ComputeOpNode>()->axis[1];
    IterVar by, ty;
    s[out].split(i, num_thread, &by, &ty);
    IterVar bx, tx;
    s[out].split(c, num_thread, &bx, &tx);
    s[out].reorder({by, bx, ty, tx});
    s[out].bind(ty, thread_y);
    s[out].bind(tx, thread_x);
    s[out].bind(by, block_y);
    s[out].bind(bx, block_x);
    if (detail::contains(s->outputs, pool->op)) {
      s[OL].compute_at(s[out], tx);
    } else {
      s[pool].compute_at(s[out], tx);
    }
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output).
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag.rfind("global_pool", 0) == 0) {
      // The tag starts with "global_pool": schedule the global pooling stage.
      auto pool = op.output(0);
      _schedule(pool);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}

}  // namespace cuda
}  // namespace topi
}  // namespace tvm

#endif  // TVM_TOPI_CUDA_POOLING_H_
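A minimal usage sketch follows, assuming a max-pool stage written directly with te::compute and tagged "pool_max" so that the traversal above recognizes it; the function name, shapes, and the hand-written reduction are hypothetical stand-ins for what topi's pooling helpers would normally produce.

#include <tvm/te/operation.h>
#include <tvm/tir/op.h>
#include <tvm/topi/cuda/pooling.h>

// Illustrative only: builds a small NCHW max-pool compute and schedules it for CUDA.
inline tvm::te::Schedule ExamplePoolSchedule() {
  using namespace tvm;
  using namespace tvm::te;
  // 1x16x32x32 input feature map.
  Tensor data = placeholder({1, 16, 32, 32}, DataType::Float(32), "data");
  // 2x2 max pooling with stride 2; the "pool_max" tag starts with "pool",
  // which is what schedule_pool's traverse() checks for.
  IterVar rh = reduce_axis(Range(0, 2), "rh");
  IterVar rw = reduce_axis(Range(0, 2), "rw");
  Tensor pool = compute(
      {1, 16, 16, 16},
      [&](Var n, Var c, Var h, Var w) {
        return max(data(n, c, h * 2 + rh, w * 2 + rw), {rh, rw});
      },
      "pool", "pool_max");
  Target target("cuda");
  return topi::cuda::schedule_pool(target, {pool});
}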