softmax.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file topi/nn/softmax.h
 * \brief Softmax op constructions
 */
#ifndef TVM_TOPI_NN_SOFTMAX_H_
#define TVM_TOPI_NN_SOFTMAX_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
 * \brief Softmax activation
 *
 * \param x The input tensor. Can be any dimension
 * \param axis The channel axis along which softmax is performed
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the softmax operation
 */
inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor",
                      std::string tag = "softmax_output") {
  auto input_shape = x->shape;
  auto ndim = input_shape.size();
  if (axis < 0) {
    axis = ndim + axis;
  }
  ICHECK_LT(axis, ndim) << "axis parameter should be less than input dim";

  auto k1 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k1");
  auto k2 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k2");
  auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);
  tvm::Map<String, ObjectRef> attrs;
  attrs.Set("axis", Integer(axis));

  auto insert_reduce_index = [axis, ndim](const Array<Var>& indices, const IterVar& reduce_index) {
    Array<PrimExpr> eval_range;
    int arg_counter = 0;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) == axis)
        eval_range.push_back(reduce_index);
      else
        eval_range.push_back(indices[arg_counter++]);
    }
    return eval_range;
  };
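
  // For illustration: with ndim = 3 and axis = 1, reduced-shape indices
  // (i, j) expand to the eval range (i, k, j); the reduction variable is
  // spliced in at the softmax axis.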

  auto get_non_reduce_indices = [axis, ndim](const Array<Var>& indices) {
    Array<PrimExpr> non_reduce_indices;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) != axis) non_reduce_indices.push_back(indices[i]);
    }
    return non_reduce_indices;
  };
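
  // The stages below implement the numerically stable formulation
  //   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)),
  // which equals exp(x_i) / sum_j exp(x_j) but avoids overflow in exp
  // for large inputs.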

  auto _compute_max = [&](const Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k1);
    return topi::MaxOp(x(eval_range), {k1});
  };

  auto _compute_exp = [&](const Tensor& max_elem, const Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return tvm::exp(x(indices) - max_elem(non_reduce_indices));
  };

  auto _compute_expsum = [&](const Tensor& exp, const Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k2);
    return tvm::sum(exp(eval_range), {k2});
  };

  auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return exp(indices) / expsum(non_reduce_indices);
  };

  auto max_elem = tvm::te::compute(reduced_shape, _compute_max);
  auto exp = tvm::te::compute(
      input_shape, [&](const Array<Var>& indices) { return _compute_exp(max_elem, indices); });
  auto expsum = tvm::te::compute(
      reduced_shape, [&](const Array<Var>& indices) { return _compute_expsum(exp, indices); });
  return tvm::te::compute(
      input_shape, [&](const Array<Var>& indices) { return _normalize(exp, expsum, indices); },
      name, tag, attrs);
}
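
// Usage sketch (illustrative, not part of the original header): given a 2-D
// placeholder X, a softmax stage over the last axis could be built as
//
//   te::Tensor X = te::placeholder({m, n}, DataType::Float(32), "X");
//   te::Tensor Y = topi::nn::softmax(X, /*axis=*/-1);
//
// Y then has the same shape as X, with each row summing to one.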

/*!
 * \brief Log softmax activation
 *
 * \param x The input tensor. 2-D where log softmax is performed along the second dimension
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the log softmax operation
 */
inline Tensor log_softmax(const Tensor& x, std::string name = "tensor",
                          std::string tag = "log_softmax_output") {
  ICHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input";

  PrimExpr m = x->shape[0];
  PrimExpr n = x->shape[1];

  auto k = tvm::te::reduce_axis(Range(0, n), "k");
  auto max_elem =
      tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array<IterVar>{k}); });
  k = tvm::te::reduce_axis(Range(0, n), "k");

  auto expsum =
      tvm::te::compute({m}, [&](Var i) { return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), {k}); });

  return tvm::te::compute(
      x->shape, [&](Var i, Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name,
      tag);
}
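
// Note: this computes x - max(x) - log(sum(exp(x - max(x)))), the numerically
// stable form of log(softmax(x)); unlike softmax above, it only accepts 2-D
// input and always reduces along the second dimension.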

}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_SOFTMAX_H_
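
A minimal end-to-end sketch of driving this header from C++ (assuming a TVM
build with the te and topi headers on the include path; the variable names
and main() scaffolding below are illustrative, not part of softmax.h):

#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/topi/nn/softmax.h>

int main() {
  using namespace tvm;
  // Symbolic 2-D input of shape (m, n).
  auto m = te::var("m");
  auto n = te::var("n");
  te::Tensor X = te::placeholder({m, n}, DataType::Float(32), "X");
  // Softmax along the last axis; log_softmax requires exactly this 2-D shape.
  te::Tensor Y = topi::nn::softmax(X, /*axis=*/-1);
  te::Tensor Z = topi::nn::log_softmax(X);
  // Build a default schedule over the resulting ops.
  te::Schedule s = te::create_schedule({Y->op, Z->op});
  return 0;
}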