tvm
reduction.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_CUDA_REDUCTION_H_
25 #define TVM_TOPI_CUDA_REDUCTION_H_
26 
28 #include <tvm/te/operation.h>
29 #include <tvm/te/schedule_pass.h>
30 #include <tvm/topi/detail/fuse.h>
31 #include <tvm/topi/tags.h>
32 
33 namespace tvm {
34 namespace topi {
35 
36 using namespace tvm::te;
37 
38 namespace cuda {
// NOTE(review): the preceding signature lines are missing from this extracted
// view; per the declaration elsewhere in the docs, the full signature is:
//   Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
//                           bool is_idx_reduce = false)
// Schedules a single (idx-)reduction op for GPU-style targets and returns the
// updated schedule. -- TODO confirm against the unextracted header.
                       bool is_idx_reduce = false) {
  Tensor data_out;
  Tensor data_in;

  if (!is_idx_reduce) {
    // Plain commutative reduce: `op` itself produces the reduced tensor.
    data_in = op->InputTensors()[0];
    data_out = op.output(0);
  } else {
    // Index reduce (e.g. argmax): the temporary reduce tensor is `op`'s input.
    data_out = op->InputTensors()[0];
  }

  auto out_stage = sch[data_out];
  ICHECK_GT(out_stage->op.as<ComputeOpNode>()->reduce_axis.size(), 0)
      << "reduce_axis must be greater than zero";

  bool all_reduce;
  int num_thread;
  IterVar block_x, thread_x, thread_y;

  if (out_stage->op.as<ComputeOpNode>()->axis.size() > 0) {
    // Some spatial axes survive the reduction: use a 2-D thread layout
    // (threadIdx.x over the reduction, threadIdx.y over the spatial part).
    all_reduce = false;
    num_thread = 32;
    if (target->kind->name == "opencl" || target->kind->name == "metal") {
      // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests.
      // Don't know why.
      num_thread = 16;
    }
    block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
    thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
    thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y");
  } else {
    // Everything is reduced to a scalar: use as many threads as the target
    // allows along threadIdx.x only.
    all_reduce = true;
    num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
    thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
  }

  // Fuse all reduction axes into one, split off `num_thread` inner
  // iterations, and rfactor on the inner part so each thread accumulates an
  // independent partial result that is combined afterwards.
  auto fused_reduce = detail::Fuse(out_stage, out_stage->op.as<ComputeOpNode>()->reduce_axis);

  IterVar ko, ki;
  out_stage.split(fused_reduce, num_thread, &ko, &ki);
  auto data_out_rf = sch.rfactor(data_out, ki)[0];
  // After rfactor the (single) remaining reduce axis is bound to threadIdx.x.
  auto tx = out_stage->op.as<ComputeOpNode>()->reduce_axis[0];
  out_stage.bind(tx, thread_x);
  sch[data_out_rf].compute_at(out_stage, tx);

  Tensor real_output;
  Tensor temp_idx_input, temp_val_input;
  if (is_idx_reduce) {
    // For idx-reduce the final result is `op`'s own output, while the
    // upstream op produces the temporary (index, value) pair.
    real_output = op.output(0);
    temp_idx_input = data_out->op.output(0);
    temp_val_input = data_out->op.output(1);
  } else {
    real_output = data_out;
  }

  auto stage_real = sch[real_output];
  if (!all_reduce) {
    // Fuse and split the axis
    auto fused_outer = detail::Fuse(stage_real, stage_real->op.as<ComputeOpNode>()->axis);
    IterVar bx, outer_in;
    stage_real.split(fused_outer, num_thread, &bx, &outer_in);

    // Bind the axes to threads and blocks
    stage_real.bind(outer_in, thread_y);
    stage_real.bind(bx, block_x);
    if (is_idx_reduce) {
      sch[temp_idx_input].compute_at(stage_real, outer_in);
      sch[temp_val_input].compute_at(stage_real, outer_in);
    }
  } else {
    if (is_idx_reduce) {
      sch[temp_idx_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
      sch[temp_val_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
    }
  }

  // Only the thread holding the final accumulated value stores the output.
  stage_real.set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
  return sch;
}
130 
  // NOTE(review): the signature line is missing from this extracted view; per
  // the declaration elsewhere in the docs it is
  //   void TraverseBeforeReduce(Schedule s, Operation op)
  // Recursively walks toward the inputs, inlining injective ops. -- TODO
  // confirm against the unextracted header.
  if (op->IsInstance<PlaceholderOpNode>()) {
    // Placeholders are graph inputs; nothing to schedule.
    return;
  } else if (is_injective(op->tag)) {
    // Injective (elementwise-like) producers are computed inline into their
    // consumers, then their own inputs are traversed the same way.
    s[op].compute_inline();
    for (auto tensor : op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else {
    // Anything else upstream of the reduction is not handled by this schedule.
    LOG(ERROR) << "Unsupported operator " << op->tag;
  }
}
150 
159 void TraverseAfterReduce(const Target& target, Schedule s, Operation op) {
160  if (is_broadcast(op->tag)) {
161  LOG(ERROR) << "Elementwise op after reduce is not yet supported";
162  } else if (op->tag == kCommReduce) {
163  ScheduleReduce(target, op, s, false);
164  for (auto tensor : op->InputTensors()) {
165  TraverseBeforeReduce(s, tensor->op);
166  }
167  } else if (op->tag == kCommReduceIdx) {
168  ScheduleReduce(target, op, s, true);
169  for (auto tensor : op->InputTensors()[0]->op->InputTensors()) {
170  TraverseBeforeReduce(s, tensor->op);
171  }
172  } else {
173  LOG(ERROR) << "Unsupported operator " << op->tag;
174  }
175 }
176 
  // NOTE(review): the signature line is missing from this extracted view; per
  // the declaration elsewhere in the docs it is
  //   Schedule schedule_reduce(const Target& target, Array<Tensor> outs)
  // Entry point: builds a schedule for a single reduce output. -- TODO
  // confirm against the unextracted header.
  ICHECK_EQ(outs.size(), 1) << "outs must have size 1";
  // Gather the output operations and create a fresh schedule over them.
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);
  // Schedule the reduction itself, then inline everything upstream of it.
  TraverseAfterReduce(target, s, outs[0]->op);
  return s;
}
195 
196 } // namespace cuda
197 } // namespace topi
198 } // namespace tvm
199 #endif // TVM_TOPI_CUDA_REDUCTION_H_
Container of constant int that adds more constructors.
Definition: expr.h:632
int64_t IntValue() const
convert to int64_t
Definition: expr.h:669
Reference to PrimExprNode.
Definition: expr.h:115
Range container
Definition: expr.h:725
Managed reference class to TargetNode.
Definition: target.h:200
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
size_t size() const
Definition: array.h:420
bool IsInstance() const
Definition: object.h:874
Array< IterVar > axis
IterVar on each axis.
Definition: operation.h:207
Array< IterVar > reduce_axis
IterVar on each reduction axis, if the body is a Reduce.
Definition: operation.h:209
A Compute op that compute a tensor on certain domain.
Definition: operation.h:226
Managed reference to FuseNode.
Definition: schedule.h:826
virtual Array< Tensor > InputTensors() const =0
List all the input Tensors.
std::string tag
optional tag of the operation
Definition: operation.h:61
Operation that produces tensors.
Definition: tensor.h:47
Tensor output(size_t i) const
get the i-th output of the operation.
A placeholder op represents an input placeholder.
Definition: operation.h:152
Global schedule container For operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:326
Array< Tensor > rfactor(const Tensor &tensor, const IterVar &axis, int factor_axis=0)
Factor a reduction axis in tensor's schedule to be an explicit axis. This will create a new stage tha...
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:315
Fuse operation.
Generic function that can be specialized on a per-target basis.
Tensor expression language DSL.
Definition: extracted_task.h:33
Schedule create_schedule(Array< Operation > ops)
Create a schedule for array of ops(and their dependencies).
Definition: schedule.h:702
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Schedule ScheduleReduce(const Target &target, Operation op, Schedule sch, bool is_idx_reduce=false)
Schedule a given reduce operation.
Definition: reduction.h:50
void TraverseAfterReduce(const Target &target, Schedule s, Operation op)
Schedule a reduce op, then invoke TraverseBeforeReduce on each of the op's inputs.
Definition: reduction.h:159
void TraverseBeforeReduce(Schedule s, Operation op)
Recursively traverse operator inputs, setting injective inputs to be computed inline.
Definition: reduction.h:138
Schedule schedule_reduce(const Target &target, Array< Tensor > outs)
Create a CUDA schedule for a reduce operation.
Definition: reduction.h:47
bool is_injective(std::string tag)
Definition: tags.h:51
constexpr auto kCommReduce
Definition: tags.h:34
constexpr auto kCommReduceIdx
Definition: tags.h:35
bool is_broadcast(std::string tag)
Definition: tags.h:47
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Operation node can generate one or multiple Tensors.
Collection of Schedule pass functions.
External function interface to rocBLAS libraries.