#ifndef TVM_TOPI_CUDA_REDUCTION_H_
#define TVM_TOPI_CUDA_REDUCTION_H_

// ... (#include directives and the tvm::topi::cuda namespace opening, which
//      brings tvm::te into scope, are elided in this excerpt) ...

/*!
 * \brief Schedule a given reduce operation.
 * \param is_idx_reduce Pass true to schedule a reduce op that returns an
 *        index, such as argmax or argmin.
 */
Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
                        bool is_idx_reduce = false) {
  // For index reductions, the stage carrying the reduce axes is the
  // two-output temporary feeding `op`; otherwise it is op's own output.
  Tensor data_out = is_idx_reduce ? op->InputTensors()[0] : op.output(0);
  auto out_stage = sch[data_out];
  ICHECK_GT(out_stage->op.as<ComputeOpNode>()->reduce_axis.size(), 0)
      << "reduce_axis must be greater than zero";

  bool all_reduce = out_stage->op.as<ComputeOpNode>()->axis.size() == 0;
  int num_thread;
  IterVar block_x, thread_x, thread_y;
  if (!all_reduce) {
    num_thread = 32;
    if (target->kind->name == "opencl" || target->kind->name == "metal") {
      // Smaller work groups avoid CL_INVALID_WORK_GROUP_SIZE on these targets.
      num_thread = 16;
    }
    block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
    thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
    thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y");
  } else {
    num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
    thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
  }

  // Fuse the reduce axes, split off a thread-sized chunk, and rfactor so that
  // each thread accumulates an independent partial result.
  auto fused_reduce =
      detail::Fuse(out_stage, out_stage->op.as<ComputeOpNode>()->reduce_axis);
  IterVar ko, ki;
  out_stage.split(fused_reduce, num_thread, &ko, &ki);
  auto data_out_rf = sch.rfactor(data_out, ki)[0];
  auto tx = out_stage->op.as<ComputeOpNode>()->reduce_axis[0];
  out_stage.bind(tx, thread_x);
  sch[data_out_rf].compute_at(out_stage, tx);

  Tensor real_output;
  Tensor temp_idx_input, temp_val_input;
  if (is_idx_reduce) {
    real_output = op.output(0);
    temp_idx_input = data_out->op.output(0);  // index half of the temporary
    temp_val_input = data_out->op.output(1);  // value half of the temporary
  } else {
    real_output = data_out;
  }

  auto stage_real = sch[real_output];
  if (!all_reduce) {
    // Fuse the remaining spatial axes and spread them over blocks and threads.
    auto fused_outer =
        detail::Fuse(stage_real, stage_real->op.as<ComputeOpNode>()->axis);
    IterVar bx, outer_in;
    stage_real.split(fused_outer, num_thread, &bx, &outer_in);
    stage_real.bind(outer_in, thread_y);
    stage_real.bind(bx, block_x);
    if (is_idx_reduce) {
      sch[temp_idx_input].compute_at(stage_real, outer_in);
      sch[temp_val_input].compute_at(stage_real, outer_in);
    }
  } else if (is_idx_reduce) {
    sch[temp_idx_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
    sch[temp_val_input].compute_at(stage_real, stage_real->op.as<ComputeOpNode>()->axis[0]);
  }

  // Only the first thread along threadIdx.x writes the final reduced value.
  stage_real.set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
  return sch;
}
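
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of reduction.h: the split + rfactor +
// thread-binding + store-predicate pattern that ScheduleReduce automates,
// written by hand against the TE schedule API for a 1-D sum. The function
// name, the tensor names, the extent 4096 and the thread count 32 are
// assumptions made for this example only.
inline Schedule ExampleManualCrossThreadSum() {
  Tensor A = placeholder({4096}, DataType::Float(32), "A");
  IterVar k = reduce_axis(Range(0, 4096), "k");
  Tensor B = compute(
      {1}, [&](const Array<Var>&) { return tvm::sum(A(k->var), {k}); }, "B");

  Schedule s = create_schedule({B->op});
  Stage sb = s[B];
  IterVar ko, ki;
  sb.split(k, 32, &ko, &ki);                    // thread-sized inner chunk
  Tensor BF = s.rfactor(B, ki)[0];              // per-thread partial sums
  IterVar tx = tvm::te::thread_axis(Range(0, 32), "threadIdx.x");
  IterVar kb = sb->op.as<ComputeOpNode>()->reduce_axis[0];
  sb.bind(kb, tx);                              // cross-thread reduction
  s[BF].compute_at(sb, kb);
  sb.set_store_predicate(static_cast<PrimExpr>(tx) == 0);  // thread 0 stores
  return s;
}
// ---------------------------------------------------------------------------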
/*! \brief Recursively traverse operator inputs, setting injective inputs to be computed inline. */
void TraverseBeforeReduce(Schedule s, Operation op) {
  if (op->IsInstance<PlaceholderOpNode>()) {
    return;
  } else if (is_injective(op->tag)) {
    s[op].compute_inline();
    for (auto tensor : op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else {
    LOG(ERROR) << "Unsupported operator " << op->tag;
  }
}

/*! \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each of the op's inputs. */
void TraverseAfterReduce(const Target& target, Schedule s, Operation op) {
  if (is_broadcast(op->tag)) {
    LOG(ERROR) << "Elementwise op after reduce is not yet supported";
  } else if (op->tag == kCommReduce) {
    ScheduleReduce(target, op, s, false);
    for (auto tensor : op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else if (op->tag == kCommReduceIdx) {
    ScheduleReduce(target, op, s, true);
    // Index reductions feed through a two-output temporary; walk its inputs.
    for (auto tensor : op->InputTensors()[0]->op->InputTensors()) {
      TraverseBeforeReduce(s, tensor->op);
    }
  } else {
    LOG(ERROR) << "Unsupported operator " << op->tag;
  }
}
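
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of reduction.h: a producer chain the traversal
// above is designed for. topi::exp carries an injective tag, so
// TraverseBeforeReduce inlines it into the sum that ScheduleReduce schedules.
// Assumes tvm/topi/elemwise.h and tvm/topi/reduction.h are also included;
// the function name, tensor names and shapes are example-only.
inline Schedule ExampleTraverseExpSum() {
  Tensor A = placeholder({1024, 256}, DataType::Float(32), "A");
  Tensor E = topi::exp(A);        // injective producer, gets compute_inline()
  Tensor B = topi::sum(E, {1});   // tagged kCommReduce
  Schedule s = create_schedule({B->op});
  TraverseAfterReduce(Target("cuda"), s, B->op);
  return s;
}
// ---------------------------------------------------------------------------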
/*! \brief Create a CUDA schedule for a reduce operation. */
Schedule schedule_reduce(const Target& target, Array<Tensor> outs) {
  ICHECK_EQ(outs.size(), 1) << "outs must have size 1";
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);
  TraverseAfterReduce(target, s, outs[0]->op);
  return s;
}
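
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of reduction.h: end-to-end use of
// schedule_reduce, here on the index-reduction path (is_idx_reduce = true).
// Assumes tvm/topi/reduction.h is also included; the function name, tensor
// names and shapes are example-only.
inline Schedule ExampleScheduleArgmax() {
  Tensor A = placeholder({1024, 256}, DataType::Float(32), "A");
  Tensor idx = topi::argmax(A, {1});   // tagged kCommReduceIdx
  return schedule_reduce(Target("cuda"), {idx});
}
// ---------------------------------------------------------------------------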