softmax.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

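/*!
 * \file cuda/softmax.h
 * \brief CUDA schedule for softmax operation
 */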
#ifndef TVM_TOPI_CUDA_SOFTMAX_H_
#define TVM_TOPI_CUDA_SOFTMAX_H_

#include <tvm/target/target.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/detail/fuse.h>
#include <tvm/topi/tags.h>

namespace tvm {
namespace topi {

using namespace tvm::te;

namespace cuda {

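/*!
 * \brief Create a CUDA schedule for the given softmax output tensors.
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */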
inline Schedule schedule_softmax(const Target& target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);

  auto softmax = outs[0];
  tvm::te::Tensor max_elem;
  tvm::te::Tensor expsum;
  tvm::te::Tensor exp;
  bool has_exp = false;

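  // Recover the intermediate stages from the producer tag of the output:
  // softmax keeps exp(x - max) and the per-row sum as separate input
  // tensors, while log-softmax consumes max_elem and expsum directly.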
  auto tag = softmax->op.as<ComputeOpNode>()->tag;
  if (tag == "softmax_output") {
    expsum = softmax->op->InputTensors()[1];
    exp = softmax->op->InputTensors()[0];
    max_elem = s[exp]->op->InputTensors()[1];
    has_exp = true;
  } else if (tag == "log_softmax_output") {
    max_elem = softmax->op->InputTensors()[1];
    expsum = softmax->op->InputTensors()[2];
  } else {
    LOG(ERROR) << "Tag is expected to be softmax_output or log_softmax_output. Got " << tag;
  }

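  // Launch configuration: one thread block per outer (batch) row, with
  // num_thread threads cooperating along the softmax axis.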
  int num_thread = 64;
  auto block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
  auto thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");

  if (has_exp) {
    s[exp].bind(exp->op.as<ComputeOpNode>()->axis[0], block_x);
  }

  s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);

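  // Cross-thread reduction of the exp-sum: split the reduction axis by
  // num_thread and rfactor it so each thread accumulates a partial sum;
  // the store predicate lets only thread 0 write the final value.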
  auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
  IterVar ko, ki;
  s[expsum].split(k, num_thread, &ko, &ki);
  auto EF = s.rfactor(expsum, ki)[0];
  s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->axis[0], block_x);
  s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
  s[EF].compute_at(s[expsum], s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0]);
  s[expsum].set_store_predicate(thread_x->var == 0);

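  // Finally, distribute the normalization step over the same threads by
  // splitting the second axis of the output into num_thread parts.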
  IterVar tx, xi;
  s[softmax].split_by_nparts(softmax->op.as<ComputeOpNode>()->axis[1], num_thread, &tx, &xi);
  s[softmax].bind(tx, thread_x);

  return s;
}

}  // namespace cuda
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_CUDA_SOFTMAX_H_
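
For reference, here is a minimal sketch of how this schedule might be driven from C++. It is not part of the header: the input shape, the function name MakeCudaSoftmaxSchedule, and the use of topi::nn::softmax from <tvm/topi/nn/softmax.h> are illustrative assumptions. The point is that the compute must carry the "softmax_output" (or "log_softmax_output") tag that schedule_softmax dispatches on, which topi's softmax helper sets by default.

// Sketch only: names and shapes below are assumptions, not part of this header.
#include <tvm/te/operation.h>
#include <tvm/topi/cuda/softmax.h>
#include <tvm/topi/nn/softmax.h>

tvm::te::Schedule MakeCudaSoftmaxSchedule() {
  // 2-D input; softmax runs along the last axis, so axis 0 is the batch
  // dimension that the schedule binds to blockIdx.x.
  auto x = tvm::te::placeholder({128, 1024}, tvm::DataType::Float(32), "x");
  auto y = tvm::topi::nn::softmax(x, /*axis=*/-1);  // tagged "softmax_output"
  return tvm::topi::cuda::schedule_softmax(tvm::Target("cuda"), {y});
}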