softmax.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

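/*!
 * \brief Softmax op constructions
 * \file nn/softmax.h
 */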
#ifndef TVM_TOPI_NN_SOFTMAX_H_
#define TVM_TOPI_NN_SOFTMAX_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

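/*!
 * \brief Softmax activation
 *
 * \param x The input tensor. Can be any dimension
 * \param axis The channel axis along which softmax is performed
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the softmax operation
 */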
inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor",
                      std::string tag = "softmax_output") {
  auto input_shape = x->shape;
  auto ndim = input_shape.size();
  if (axis < 0) {
    axis = ndim + axis;
  }
  ICHECK_LT(axis, ndim) << "axis parameter should be less than input dim";

  auto k1 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k1");
  auto k2 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k2");
  auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);

  tvm::ffi::Map<ffi::String, ffi::Any> attrs;
  attrs.Set("axis", Integer(axis));

  // Build an index tuple for the input by splicing the reduction variable
  // into position `axis` among the non-reduced output indices.
  auto insert_reduce_index = [axis, ndim](const ffi::Array<Var>& indices,
                                          const IterVar& reduce_index) {
    ffi::Array<PrimExpr> eval_range;
    int arg_counter = 0;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) == axis) {
        eval_range.push_back(reduce_index);
      } else {
        eval_range.push_back(indices[arg_counter++]);
      }
    }
    return eval_range;
  };

  auto get_non_reduce_indices = [axis, ndim](const ffi::Array<Var>& indices) {
    ffi::Array<PrimExpr> non_reduce_indices;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) != axis) non_reduce_indices.push_back(indices[i]);
    }
    return non_reduce_indices;
  };

  auto _compute_max = [&](const ffi::Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k1);
    return topi::MaxOp(x(eval_range), {k1});
  };

  auto _compute_exp = [&](const Tensor& max_elem, const ffi::Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return tvm::exp(x(indices) - max_elem(non_reduce_indices));
  };

  auto _compute_expsum = [&](const Tensor& exp, const ffi::Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k2);
    return tvm::sum(exp(eval_range), {k2});
  };

  auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const ffi::Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return exp(indices) / expsum(non_reduce_indices);
  };

  // Subtract the per-slice max before exponentiating so exp() stays numerically stable.
  auto max_elem = tvm::te::compute(reduced_shape, _compute_max);
  auto exp = tvm::te::compute(
      input_shape, [&](const ffi::Array<Var>& indices) { return _compute_exp(max_elem, indices); });
  auto expsum = tvm::te::compute(
      reduced_shape, [&](const ffi::Array<Var>& indices) { return _compute_expsum(exp, indices); });
  return tvm::te::compute(
      input_shape, [&](const ffi::Array<Var>& indices) { return _normalize(exp, expsum, indices); },
      name, tag, attrs);
}

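/*!
 * \brief Log softmax activation
 *
 * \param x The input tensor. 2-D where log softmax is performed along the second dimension
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the log softmax operation
 */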
inline Tensor log_softmax(const Tensor& x, std::string name = "tensor",
                          std::string tag = "log_softmax_output") {
  ICHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input";

  PrimExpr m = x->shape[0];
  PrimExpr n = x->shape[1];

  auto k = tvm::te::reduce_axis(Range(0, n), "k");
  auto max_elem =
      tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), ffi::Array<IterVar>{k}); });
  // A reduce axis belongs to a single reduction, so make a fresh one for the sum.
  k = tvm::te::reduce_axis(Range(0, n), "k");

  auto expsum =
      tvm::te::compute({m}, [&](Var i) { return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), {k}); });

  return tvm::te::compute(
      x->shape, [&](Var i, Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name,
      tag);
}

}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_SOFTMAX_H_
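
A minimal usage sketch (not part of the header above): it applies softmax along the last axis of a [batch, classes] placeholder. The MakeSoftmax wrapper, the [batch, 128] shape, and the float32 dtype are illustrative assumptions, not anything this header prescribes.

#include <tvm/te/operation.h>
#include <tvm/topi/nn/softmax.h>

// Hypothetical helper: declares a [batch, 128] float32 input and builds the
// softmax compute over the class axis (axis = -1, i.e. the last dimension).
inline tvm::te::Tensor MakeSoftmax() {
  using namespace tvm;
  auto batch = tir::Var("batch");
  auto x = te::placeholder({batch, 128}, DataType::Float(32), "x");
  return topi::nn::softmax(x, /*axis=*/-1);
}

The returned tensor carries the "softmax_output" tag and the "axis" attribute set above, which downstream schedules can use to recognize and specialize the pattern.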