#ifndef TVM_TOPI_NN_SOFTMAX_H_
#define TVM_TOPI_NN_SOFTMAX_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;
/*!
 * \brief Softmax activation
 *
 * \param x The input tensor. Can be any dimension
 * \param axis The channel axis along which softmax is performed
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the softmax operation
 */
inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor",
                      std::string tag = "softmax_output") {
  auto input_shape = x->shape;
  auto ndim = input_shape.size();
  if (axis < 0) {
    axis = ndim + axis;
  }
  ICHECK_LT(axis, ndim) << "axis parameter should be less than input dim";

  auto k1 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k1");
  auto k2 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k2");
  auto reduced_shape = MakeReduceTargetShape({axis}, x, /*keepdims=*/false, /*atleast1d=*/false);

  tvm::Map<String, ObjectRef> attrs;
  attrs.Set("axis", Integer(axis));
  // Build the full index list for x, substituting the reduction IterVar at
  // the softmax axis and the remaining loop indices everywhere else.
  auto insert_reduce_index = [axis, ndim](const Array<Var>& indices, const IterVar& reduce_index) {
    Array<PrimExpr> eval_range;
    int arg_counter = 0;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) == axis)
        eval_range.push_back(reduce_index);
      else
        eval_range.push_back(indices[arg_counter++]);
    }
    return eval_range;
  };
  // Drop the softmax axis from the index list so the reduced tensors
  // (max_elem, expsum) can be addressed from the full-rank indices.
  auto get_non_reduce_indices = [axis, ndim](const Array<Var>& indices) {
    Array<PrimExpr> non_reduce_indices;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) != axis) non_reduce_indices.push_back(indices[i]);
    }
    return non_reduce_indices;
  };
  // Maximum over the softmax axis, used for numerical stability.
  auto _compute_max = [&](const Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k1);
    return topi::MaxOp(x(eval_range), {k1});
  };
  // exp(x - max) for every element.
  auto _compute_exp = [&](const Tensor& max_elem, const Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return tvm::exp(x(indices) - max_elem(non_reduce_indices));
  };
  // Sum of the exponentials over the softmax axis.
  auto _compute_expsum = [&](const Tensor& exp, const Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k2);
    return tvm::sum(exp(eval_range), {k2});
  };
  // Normalize: exp(x - max) / sum(exp(x - max)).
  auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return exp(indices) / expsum(non_reduce_indices);
  };
  auto max_elem = tvm::te::compute(reduced_shape, _compute_max);
  auto exp = tvm::te::compute(
      input_shape, [&](const Array<Var>& indices) { return _compute_exp(max_elem, indices); });
  auto expsum = tvm::te::compute(
      reduced_shape, [&](const Array<Var>& indices) { return _compute_expsum(exp, indices); });
  return tvm::te::compute(
      input_shape, [&](const Array<Var>& indices) { return _normalize(exp, expsum, indices); },
      name, tag, attrs);
}
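// Usage sketch (illustrative, not part of the original header): wiring softmax
// into a tvm::te compute graph. The placeholder shape, dtype, and variable
// names below are assumptions for the example only.
//
//   tvm::te::Tensor data =
//       tvm::te::placeholder({16, 1000}, tvm::DataType::Float(32), "data");
//   tvm::te::Tensor probs = tvm::topi::nn::softmax(data, /*axis=*/-1);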
/*!
 * \brief Log softmax activation
 *
 * \param x The input tensor. 2-D where log softmax is performed along the second dimension
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the log softmax operation
 */
inline Tensor log_softmax(const Tensor& x, std::string name = "tensor",
                          std::string tag = "log_softmax_output") {
  ICHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input";
  PrimExpr m = x->shape[0];
  PrimExpr n = x->shape[1];

  // Numerically stable formulation: subtract the row max before exponentiation.
  auto k = tvm::te::reduce_axis(Range(0, n), "k");
  auto max_elem =
      tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array<IterVar>{k}); });
  k = tvm::te::reduce_axis(Range(0, n), "k");

  auto expsum =
      tvm::te::compute({m}, [&](Var i) { return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), {k}); });

  return tvm::te::compute(
      x->shape, [&](Var i, Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name,
      tag);
}
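// Usage sketch (illustrative, not part of the original header): log_softmax is
// restricted to 2-D input; the batch and class sizes here are assumptions.
//
//   tvm::te::Tensor logits =
//       tvm::te::placeholder({32, 10}, tvm::DataType::Float(32), "logits");
//   tvm::te::Tensor log_probs = tvm::topi::nn::log_softmax(logits);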
}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_SOFTMAX_H_