24 #ifndef TVM_TOPI_CUDA_REDUCTION_H_ 25 #define TVM_TOPI_CUDA_REDUCTION_H_ 51 bool is_idx_reduce =
false) {
62 auto out_stage = sch[data_out];
64 <<
"reduce_axis must be greater than zero";
68 IterVar block_x, thread_x, thread_y;
73 if (target->kind->name ==
"opencl" || target->kind->name ==
"metal") {
83 num_thread = target->GetAttr<
Integer>(
"max_num_threads").value().
IntValue();
90 out_stage.split(fused_reduce, num_thread, &ko, &ki);
91 auto data_out_rf = sch.
rfactor(data_out, ki)[0];
93 out_stage.bind(tx, thread_x);
94 sch[data_out_rf].compute_at(out_stage, tx);
97 Tensor temp_idx_input, temp_val_input;
99 real_output = op.
output(0);
100 temp_idx_input = data_out->op.output(0);
101 temp_val_input = data_out->op.output(1);
103 real_output = data_out;
106 auto stage_real = sch[real_output];
111 stage_real.split(fused_outer, num_thread, &bx, &outer_in);
114 stage_real.bind(outer_in, thread_y);
115 stage_real.bind(bx, block_x);
117 sch[temp_idx_input].compute_at(stage_real, outer_in);
118 sch[temp_val_input].compute_at(stage_real, outer_in);
122 sch[temp_idx_input].compute_at(stage_real, stage_real->op.as<
ComputeOpNode>()->
axis[0]);
123 sch[temp_val_input].compute_at(stage_real, stage_real->op.as<
ComputeOpNode>()->
axis[0]);
127 stage_real.set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
142 s[op].compute_inline();
147 LOG(ERROR) <<
"Unsupported operator " << op->
tag;
161 LOG(ERROR) <<
"Elementwise op after reduce is not yet supported";
169 for (
auto tensor : op->
InputTensors()[0]->op->InputTensors()) {
173 LOG(ERROR) <<
"Unsupported operator " << op->
tag;
186 ICHECK_EQ(outs.
size(), 1) <<
"outs must have size 1";
188 for (
auto t : outs) {
199 #endif // TVM_TOPI_CUDA_REDUCTION_H_ IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
void TraverseBeforeReduce(Schedule s, Operation op)
Recursively traverse operator inputs, setting injective inputs to be computed inline.
Definition: reduction.h:138
Tensor output(size_t i) const
get the i-th output of the operation.
Global schedule container For operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:318
constexpr auto kCommReduceIdx
Definition: tags.h:35
Schedule ScheduleReduce(const Target &target, Operation op, Schedule sch, bool is_idx_reduce=false)
Schedule a given reduce operation.
Definition: reduction.h:50
Schedule create_schedule(Array< Operation > ops)
Create a schedule for array of ops(and their dependencies).
Definition: schedule.h:695
std::string tag
optional tag of the operation
Definition: operation.h:61
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Tensor expression language DSL.
Definition: extracted_task.h:33
Operation that produces tensors.
Definition: tensor.h:47
bool is_injective(std::string tag)
Definition: tags.h:51
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:308
A placeholder op represents an input placeholder.
Definition: operation.h:152
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
bool IsInstance() const
Definition: object.h:829
Collection of Schedule pass functions.
Range container.
Definition: expr.h:715
virtual Array< Tensor > InputTensors() const =0
List all the input Tensors.
size_t size() const
Definition: array.h:420
int64_t IntValue() const
convert to int64_t
Definition: expr.h:659
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
A Compute op that compute a tensor on certain domain.
Definition: operation.h:226
Array< IterVar > axis
IterVar on each axis.
Definition: operation.h:207
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
constexpr auto kCommReduce
Definition: tags.h:34
Managed reference class to TargetNode.
Definition: target.h:183
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Operation node can generate one or multiple Tensors.
Managed reference to FuseNode.
Definition: schedule.h:815
Array< Tensor > rfactor(const Tensor &tensor, const IterVar &axis, int factor_axis=0)
Factor a reduction axis in tensor's schedule to be an explicit axis. This will create a new stage tha...
bool is_broadcast(std::string tag)
Definition: tags.h:47
Schedule schedule_reduce(const Target &target, Array< Tensor > outs)
Create a CUDA schedule for a reduce operation.
Definition: reduction.h:47
void TraverseAfterReduce(const Target &target, Schedule s, Operation op)
Schedule a reduce op, then invoke TraverseBeforeReduce on each of the op's inputs.
Definition: reduction.h:159
Array< IterVar > reduce_axis
IterVar on each reduction axis, if the body is a Reduce.
Definition: operation.h:209
Generic function that can be specialized on a per target basis.
Container of constant int that adds more constructors.
Definition: expr.h:622