24 #ifndef TVM_TOPI_CUDA_REDUCTION_H_    25 #define TVM_TOPI_CUDA_REDUCTION_H_    51                         bool is_idx_reduce = 
false) {
    62   auto out_stage = sch[data_out];
    64       << 
"reduce_axis must be greater than zero";
    68   IterVar block_x, thread_x, thread_y;
    73     if (target->kind->name == 
"opencl" || target->kind->name == 
"metal") {
    83     num_thread = target->GetAttr<
Integer>(
"max_num_threads").value().
IntValue();
    90   out_stage.split(fused_reduce, num_thread, &ko, &ki);
    91   auto data_out_rf = sch.
rfactor(data_out, ki)[0];
    93   out_stage.bind(tx, thread_x);
    94   sch[data_out_rf].compute_at(out_stage, tx);
    97   Tensor temp_idx_input, temp_val_input;
    99     real_output = op.
output(0);
   100     temp_idx_input = data_out->op.output(0);
   101     temp_val_input = data_out->op.output(1);
   103     real_output = data_out;
   106   auto stage_real = sch[real_output];
   111     stage_real.split(fused_outer, num_thread, &bx, &outer_in);
   114     stage_real.bind(outer_in, thread_y);
   115     stage_real.bind(bx, block_x);
   117       sch[temp_idx_input].compute_at(stage_real, outer_in);
   118       sch[temp_val_input].compute_at(stage_real, outer_in);
   122       sch[temp_idx_input].compute_at(stage_real, stage_real->op.as<
ComputeOpNode>()->
axis[0]);
   123       sch[temp_val_input].compute_at(stage_real, stage_real->op.as<
ComputeOpNode>()->
axis[0]);
   127   stage_real.set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
   142     s[op].compute_inline();
   147     LOG(ERROR) << 
"Unsupported operator " << op->
tag;
   161     LOG(ERROR) << 
"Elementwise op after reduce is not yet supported";
   169     for (
auto tensor : op->
InputTensors()[0]->op->InputTensors()) {
   173     LOG(ERROR) << 
"Unsupported operator " << op->
tag;
   186   ICHECK_EQ(outs.
size(), 1) << 
"outs must have size 1";
   188   for (
auto t : outs) {
   199 #endif  // TVM_TOPI_CUDA_REDUCTION_H_ IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread. 
void TraverseBeforeReduce(Schedule s, Operation op)
Recursively traverse operator inputs, setting injective inputs to be computed inline. 
Definition: reduction.h:138
Tensor output(size_t i) const
get the i-th output of the operation. 
Global schedule container for operations and all the operations they depend on. The schedule per Oper...
Definition: schedule.h:317
constexpr auto kCommReduceIdx
Definition: tags.h:35
Schedule ScheduleReduce(const Target &target, Operation op, Schedule sch, bool is_idx_reduce=false)
Schedule a given reduce operation. 
Definition: reduction.h:50
Schedule create_schedule(Array< Operation > ops)
Create a schedule for array of ops(and their dependencies). 
Definition: schedule.h:654
std::string tag
optional tag of the operation 
Definition: operation.h:61
runtime implementation for LibTorch/TorchScript. 
Definition: analyzer.h:36
Tensor expression language DSL. 
Definition: extracted_task.h:33
Operation that produces tensors. 
Definition: tensor.h:47
bool is_injective(std::string tag)
Definition: tags.h:51
Iteration Variable, represents an iteration over an integer interval. 
Definition: var.h:301
A placeholder op represents an input placeholder. 
Definition: operation.h:152
void push_back(const T &item)
push a new item to the back of the list 
Definition: array.h:457
bool IsInstance() const
Definition: object.h:829
Collection of Schedule pass functions. 
Range container. 
Definition: expr.h:713
virtual Array< Tensor > InputTensors() const =0
List all the input Tensors. 
size_t size() const
Definition: array.h:420
int64_t IntValue() const
convert to int64_t 
Definition: expr.h:657
Array, container representing a contiguous sequence of ObjectRefs. 
Definition: array.h:289
A Compute op that compute a tensor on certain domain. 
Definition: operation.h:226
Array< IterVar > axis
IterVar on each axis. 
Definition: operation.h:207
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations. 
constexpr auto kCommReduce
Definition: tags.h:34
Managed reference class to TargetNode. 
Definition: target.h:183
Tensor structure representing a possible input, or intermediate computation result. 
Definition: tensor.h:102
Operation node can generate one or multiple Tensors. 
Managed reference to FuseNode. 
Definition: schedule.h:774
Array< Tensor > rfactor(const Tensor &tensor, const IterVar &axis, int factor_axis=0)
Factor a reduction axis in tensor's schedule to be an explicit axis. This will create a new stage tha...
bool is_broadcast(std::string tag)
Definition: tags.h:47
Schedule schedule_reduce(const Target &target, Array< Tensor > outs)
Create a CUDA schedule for a reduce operation. 
Definition: reduction.h:47
void TraverseAfterReduce(const Target &target, Schedule s, Operation op)
Schedule a reduce op, then invoke TraverseBeforeReduce on each of the op's inputs. 
Definition: reduction.h:159
Array< IterVar > reduce_axis
IterVar on each reduction axis, if the body is a Reduce. 
Definition: operation.h:209
Generic function that can be specialized on a per target basis. 
Container of constant int that adds more constructors. 
Definition: expr.h:620