24 #ifndef TVM_TOPI_NN_SOFTMAX_H_
25 #define TVM_TOPI_NN_SOFTMAX_H_
51 std::string tag =
"softmax_output") {
52 auto input_shape = x->shape;
53 auto ndim = input_shape.size();
57 ICHECK_LT(axis, ndim) <<
"axis parameter should be less than input dim";
63 tvm::ffi::Map<ffi::String, ffi::Any> attrs;
64 attrs.Set(
"axis",
Integer(axis));
66 auto insert_reduce_index = [axis, ndim](
const ffi::Array<Var>& indices,
68 ffi::Array<PrimExpr> eval_range;
70 for (
size_t i = 0; i < ndim; ++i) {
71 if (
static_cast<int>(i) == axis) {
72 eval_range.push_back(reduce_index);
74 eval_range.push_back(indices[arg_counter++]);
80 auto get_non_reduce_indices = [axis, ndim](
const ffi::Array<Var>& indices) {
81 ffi::Array<PrimExpr> non_reduce_indices;
82 for (
size_t i = 0; i < ndim; ++i) {
83 if (
static_cast<int>(i) != axis) non_reduce_indices.push_back(indices[i]);
85 return non_reduce_indices;
88 auto _compute_max = [&](
const ffi::Array<Var>& indices) {
89 auto eval_range = insert_reduce_index(indices, k1);
93 auto _compute_exp = [&](
const Tensor& max_elem,
const ffi::Array<Var>& indices) {
94 auto non_reduce_indices = get_non_reduce_indices(indices);
95 return tvm::exp(x(indices) - max_elem(non_reduce_indices));
98 auto _compute_expsum = [&](
const Tensor&
exp,
const ffi::Array<Var>& indices) {
99 auto eval_range = insert_reduce_index(indices, k2);
103 auto _normalize = [&](
const Tensor&
exp,
const Tensor& expsum,
const ffi::Array<Var>& indices) {
104 auto non_reduce_indices = get_non_reduce_indices(indices);
105 return exp(indices) / expsum(non_reduce_indices);
110 input_shape, [&](
const ffi::Array<Var>& indices) {
return _compute_exp(max_elem, indices); });
112 reduced_shape, [&](
const ffi::Array<Var>& indices) {
return _compute_expsum(
exp, indices); });
114 input_shape, [&](
const ffi::Array<Var>& indices) {
return _normalize(
exp, expsum, indices); },
128 std::string tag =
"log_softmax_output") {
129 ICHECK_EQ(x->shape.size(), 2) <<
"Log softmax requires 2-D input";
143 x->shape, [&](
Var i,
Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name,
Container of constant int that adds more constructors.
Definition: expr.h:600
Reference to PrimExprNode.
Definition: expr.h:124
Range container
Definition: expr.h:689
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:297
a named variable in TIR
Definition: var.h:77
Tensor expression language DSL.
Definition: extracted_task.h:33
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
Tensor softmax(const Tensor &x, int axis=-1, std::string name="tensor", std::string tag="softmax_output")
Softmax activation.
Definition: softmax.h:50
Tensor log_softmax(const Tensor &x, std::string name="tensor", std::string tag="log_softmax_output")
Log softmax activation.
Definition: softmax.h:127
ffi::Array< PrimExpr > MakeReduceTargetShape(const std::vector< int > &real_axis, const Tensor &data, bool keepdims, bool atleast1d)
Calculate the target shape for a reduce op.
Definition: reduction.h:99
Tensor exp(const Tensor &x, std::string name="T_" "exp", std::string tag=kElementWise)
Definition: elemwise.h:50
PrimExpr MaxOp(PrimExpr source, ffi::Array< IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
Wrap tvm::max to ensure we get the correct overload.
Definition: reduction.h:304
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr exp(PrimExpr x, Span span=Span())
Definition: op.h:738
PrimExpr sum(PrimExpr source, ffi::Array< tir::IterVar > axis, ffi::Array< PrimExpr > init={}, Span span=Span())
sum of source expression over axis
Operation node can generate one or multiple Tensors.
Reduction op constructors.