softmax.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file softmax.h
 * \brief Softmax op constructions
 */
#ifndef TVM_TOPI_NN_SOFTMAX_H_
#define TVM_TOPI_NN_SOFTMAX_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
 * \brief Softmax activation.
 *
 * \param x The input tensor. Can have any number of dimensions.
 * \param axis The axis along which softmax is performed.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the softmax operation.
 */
inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor",
                      std::string tag = "softmax_output") {
  auto input_shape = x->shape;
  auto ndim = input_shape.size();
  if (axis < 0) {
    axis = ndim + axis;
  }
  ICHECK_LT(axis, ndim) << "axis parameter should be less than input dim";

  auto k1 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k1");
  auto k2 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k2");
  auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);

  tvm::Map<tvm::String, tvm::ObjectRef> attrs;
  attrs.Set("axis", Integer(axis));

  // Map output indices to input indices, substituting the reduction
  // variable at the softmax axis.
  auto insert_reduce_index = [axis, ndim](const Array<Var>& indices, const IterVar& reduce_index) {
    Array<PrimExpr> eval_range;
    int arg_counter = 0;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) == axis) {
        eval_range.push_back(reduce_index);
      } else {
        eval_range.push_back(indices[arg_counter++]);
      }
    }
    return eval_range;
  };

  // Drop the index at the softmax axis, yielding indices into the reduced shape.
  auto get_non_reduce_indices = [axis, ndim](const Array<Var>& indices) {
    Array<PrimExpr> non_reduce_indices;
    for (size_t i = 0; i < ndim; ++i) {
      if (static_cast<int>(i) != axis) non_reduce_indices.push_back(indices[i]);
    }
    return non_reduce_indices;
  };

  // Stage 1: the maximum along the softmax axis, used for numerical stability.
  auto _compute_max = [&](const Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k1);
    return topi::MaxOp(x(eval_range), {k1});
  };

  // Stage 2: exp(x - max), so the largest exponent is exp(0) = 1.
  auto _compute_exp = [&](const Tensor& max_elem, const Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return tvm::exp(x(indices) - max_elem(non_reduce_indices));
  };

  // Stage 3: the sum of the exponentials along the softmax axis.
  auto _compute_expsum = [&](const Tensor& exp, const Array<Var>& indices) {
    auto eval_range = insert_reduce_index(indices, k2);
    return tvm::sum(exp(eval_range), {k2});
  };

  // Stage 4: divide each exponential by the sum, producing probabilities.
  auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array<Var>& indices) {
    auto non_reduce_indices = get_non_reduce_indices(indices);
    return exp(indices) / expsum(non_reduce_indices);
  };

  auto max_elem = tvm::te::compute(reduced_shape, _compute_max);
  auto exp = tvm::te::compute(
      input_shape, [&](const Array<Var>& indices) { return _compute_exp(max_elem, indices); });
  auto expsum = tvm::te::compute(
      reduced_shape, [&](const Array<Var>& indices) { return _compute_expsum(exp, indices); });
  return tvm::te::compute(
      input_shape, [&](const Array<Var>& indices) { return _normalize(exp, expsum, indices); },
      name, tag, attrs);
}

/*!
 * \brief Log softmax activation.
 *
 * \param x The input tensor. Must be 2-D; log softmax is performed along the second dimension.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the log softmax operation.
 */
inline Tensor log_softmax(const Tensor& x, std::string name = "tensor",
                          std::string tag = "log_softmax_output") {
  ICHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input";

  PrimExpr m = x->shape[0];
  PrimExpr n = x->shape[1];

  // Row-wise maximum, for numerical stability.
  auto k = tvm::te::reduce_axis(Range(0, n), "k");
  auto max_elem =
      tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array<IterVar>{k}); });
  // A reduction axis can only be consumed once, so create a fresh one for the sum.
  k = tvm::te::reduce_axis(Range(0, n), "k");

  // Row-wise sum of exp(x - max).
  auto expsum =
      tvm::te::compute({m}, [&](Var i) { return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), {k}); });

  // log(softmax(x)) = x - max - log(sum(exp(x - max))).
  return tvm::te::compute(
      x->shape, [&](Var i, Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name,
      tag);
}

}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_SOFTMAX_H_
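
For reference, a minimal usage sketch of the two ops above. It is not part of the header; the names "logits", "probs", and "log_probs" and the 4x10 shape are illustrative, assuming only a standard TVM build.

#include <tvm/te/operation.h>
#include <tvm/topi/nn/softmax.h>

int main() {
  using namespace tvm;
  // A 2-D input: a batch of 4 rows with 10 classes each.
  te::Tensor logits = te::placeholder({4, 10}, DataType::Float(32), "logits");
  // Softmax over the trailing (class) axis; inputs of any rank are accepted.
  te::Tensor probs = topi::nn::softmax(logits, /*axis=*/-1);
  // log_softmax requires a 2-D input and reduces along the second dimension.
  te::Tensor log_probs = topi::nn::log_softmax(logits);
  // Both results are ordinary te::Tensors that can be scheduled and built
  // like any other tensor-expression output.
  return 0;
}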