tvm
elemwise.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_ELEMWISE_H_
25 #define TVM_TOPI_ELEMWISE_H_
26 
27 #include <tvm/tir/builtin.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/topi/tags.h>
30 
31 #include <algorithm>
32 #include <string>
33 
34 #include "broadcast.h"
35 
36 namespace tvm {
37 namespace topi {
38 
39 using namespace tvm::te;
40 
41 // Unary intrinsic operators
/*!
 * \brief Defines an element-wise TOPI operator `OpName(x, name, tag)` that
 *        applies the scalar intrinsic `::tvm::OpName` to every element of x.
 *
 * The generated function's output name defaults to "T_" followed by the
 * operator name, and its tag defaults to kElementWise.
 */
#define TOPI_DECLARE_UNARY_OP(OpName)                                             \
  inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName,          \
                       std::string tag = kElementWise) {                          \
    auto apply_elem = [&](const Array<Var>& idx) { return ::tvm::OpName(x(idx)); }; \
    return compute(x->shape, apply_elem, name, tag);                              \
  }
48 
76 
81 inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string tag) {
82  // Clamp the inputs to the range [-9, 9] since anything outside
83  // this range is +/-1.0f in single-precision.
84  auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0));
85 
86  // The monomial coefficients of the numerator polynomial (odd).
87  auto alpha_1 = make_const(in->dtype, 4.89352455891786e-03);
88  auto alpha_3 = make_const(in->dtype, 6.37261928875436e-04);
89  auto alpha_5 = make_const(in->dtype, 1.48572235717979e-05);
90  auto alpha_7 = make_const(in->dtype, 5.12229709037114e-08);
91  auto alpha_9 = make_const(in->dtype, -8.60467152213735e-11);
92  auto alpha_11 = make_const(in->dtype, 2.00018790482477e-13);
93  auto alpha_13 = make_const(in->dtype, -2.76076847742355e-16);
94 
95  // The monomial coefficients of the denominator polynomial (even).
96  auto beta_0 = make_const(in->dtype, 4.89352518554385e-03);
97  auto beta_2 = make_const(in->dtype, 2.26843463243900e-03);
98  auto beta_4 = make_const(in->dtype, 1.18534705686654e-04);
99  auto beta_6 = make_const(in->dtype, 1.19825839466702e-06);
100 
101  return compute(
102  x->shape,
103  [&](const Array<Var>& i) {
104  auto x2 = x(i) * x(i);
105  auto p = x2 * alpha_13 + alpha_11;
106  p = x2 * p + alpha_9;
107  p = x2 * p + alpha_7;
108  p = x2 * p + alpha_5;
109  p = x2 * p + alpha_3;
110  p = x2 * p + alpha_1;
111  p = x(i) * p;
112 
113  auto q = x2 * beta_6 + beta_4;
114  q = x2 * q + beta_2;
115  q = x2 * q + beta_0;
116  return p / q;
117  },
118  name, tag);
119 }
120 
130 inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh",
131  std::string tag = kElementWise) {
132  if (x->dtype == DataType::Float(32)) {
133  // invoke fast_tanh_float implementation
134  return fast_tanh_float(x, name, tag);
135  } else {
136  // fallback to default implementation
137  return compute(
138  x->shape, [&](const Array<Var>& i) { return ::tvm::tanh(x(i)); }, name, tag);
139  }
140 }
141 
151 inline Tensor identity(const Tensor& x, std::string name = "T_identity",
152  std::string tag = kElementWise) {
153  return compute(
154  x->shape, [&](const Array<Var>& i) { return x(i); }, name, tag);
155 }
156 
166 inline Tensor negative(const Tensor& x, std::string name = "T_negative",
167  std::string tag = kElementWise) {
168  return compute(
169  x->shape, [&](const Array<Var>& i) { return -x(i); }, name, tag);
170 }
171 
181 inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not",
182  std::string tag = kElementWise) {
183  return compute(
184  x->shape, [&](const Array<Var>& i) { return !x(i); }, name, tag);
185 }
186 
196 inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not",
197  std::string tag = kElementWise) {
198  return compute(
199  x->shape, [&](const Array<Var>& i) { return ~x(i); }, name, tag);
200 }
201 
211 inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag = kElementWise) {
212  return compute(
213  x->shape,
214  [&](const Array<Var>& i) {
215  PrimExpr zero = make_zero(x->dtype);
216  PrimExpr one = make_const(x->dtype, 1);
217  PrimExpr minus_one = make_const(x->dtype, -1);
218  auto s1 = tvm::tir::Select((x(i) < zero), minus_one, zero);
219  auto s2 = tvm::tir::Select((x(i) > zero), one, s1);
220  return s2;
221  },
222  name, tag);
223 }
224 
234 inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string tag = kElementWise) {
235  return compute(
236  x->shape,
237  [&](const Array<Var>& i) {
238  PrimExpr one = make_const(x->dtype, 1);
239  return one / tvm::sqrt(x(i));
240  },
241  name, tag);
242 }
243 
256 inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max,
257  std::string name = "T_clip", std::string tag = kElementWise) {
258  return compute(
259  x->shape,
260  [&](const Array<Var>& i) {
261  auto min_val = tvm::cast(x->dtype, a_min);
262  auto max_val = tvm::cast(x->dtype, a_max);
263  return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*)
264  },
265  name, tag);
266 }
267 
280 inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast",
281  std::string tag = kElementWise) {
282  return compute(
283  x->shape,
284  [&](const Array<Var>& i) -> PrimExpr {
285  auto expr = x(i);
286  if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) {
287  if (expr.dtype().lanes() == type.lanes()) {
288  return expr;
289  } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
290  return tvm::tir::Broadcast(expr, type.lanes());
291  }
292  }
293 
294  return tvm::cast(type, x(i));
295  },
296  name, tag);
297 }
298 
309 inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor",
310  std::string tag = kElementWise) {
311  return compute(
312  x->shape,
313  [&](const Array<Var>& i) {
314  return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)});
315  },
316  name, tag);
317 }
318 
328 inline Tensor elemwise_sum(const Array<Tensor>& xs, std::string name = "T_elemwise_sum",
329  std::string tag = kElementWise) {
330  ICHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor.";
331  return compute(
332  xs[0]->shape,
333  [&](const Array<Var>& i) {
334  auto sum_expr = xs[0](i);
335  for (size_t j = 1; j < xs.size(); j++) {
336  sum_expr = sum_expr + xs[j](i);
337  }
338  return sum_expr;
339  },
340  name, tag);
341 }
342 
354 inline Tensor full(const Array<PrimExpr>& shape, DataType dtype, const PrimExpr fill_value,
355  std::string name = "T_full", std::string tag = kElementWise) {
356  PrimExpr ev = cast(dtype, fill_value);
357  if (!ev.defined()) {
358  LOG(ERROR) << "Can't cast fill_value to " << dtype;
359  }
360  return compute(
361  shape, [&](const Array<Var>& i) { return ev; }, name, tag);
362 }
363 
375 inline Tensor full_like(const Tensor& x, const PrimExpr fill_value,
376  std::string name = "T_full_like", std::string tag = kElementWise) {
377  PrimExpr ev = cast(x->dtype, fill_value);
378  return compute(
379  x->shape, [&](const Array<Var>& i) { return ev; }, name, tag);
380 }
381 
403 inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) {
404  auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
405  auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
406  auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
407  auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f);
408  PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f),
409  make_const(DataType::Float(32), 1.3981999507E-3f),
410  make_const(DataType::Float(32), 8.3334519073E-3f),
411  make_const(DataType::Float(32), 4.1665795894E-2f),
412  make_const(DataType::Float(32), 1.6666665459E-1f),
413  make_const(DataType::Float(32), 5.0000001201E-1f)};
414  auto one = make_const(DataType::Float(32), 1.0f);
415  auto one_half = make_const(DataType::Float(32), 0.5f);
416  auto b = make_const(DataType::Float(32), 127.0f);
417 
418  return compute(
419  _x->shape,
420  [&](const Array<Var>& i) {
421  // clamp x
422  auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
423  // integer part
424  auto n = ::tvm::floor(x * log2e + one_half);
425  // fractional part
426  auto f = x - n * ln2;
427  auto y =
428  (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f + p[5]) * f * f + f + one;
429  // Return 2^m * exp(r).
430  auto ef =
431  tvm::reinterpret(DataType::Float(32), ::tvm::cast(DataType::Int(32), n + b) << 23);
432  return ::tvm::max(ef * y, _x(i)); // NOLINT(*)
433  },
434  name, tag);
435 }
436 
447 inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp",
448  std::string tag = kElementWise) {
449  if (x->dtype == DataType::Float(32)) {
450  auto ret = fast_exp_float32(x, name, tag);
451  return ret;
452  } else {
453  return compute(
454  x->shape, [&](const Array<Var>& i) { return ::tvm::exp(x(i)); }, name, tag);
455  }
456 }
457 
464 inline PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) {
465  auto plus_4 = make_const(DataType::Float(bits), 4.f);
466  auto minus_4 = make_const(DataType::Float(bits), -4.f);
467 
468  // The monomial coefficients of the numerator polynomial (odd).
469  auto alpha_1 = make_const(DataType::Float(bits), -1.60960333262415e-02f);
470  auto alpha_3 = make_const(DataType::Float(bits), -2.95459980854025e-03f);
471  auto alpha_5 = make_const(DataType::Float(bits), -7.34990630326855e-04f);
472  auto alpha_7 = make_const(DataType::Float(bits), -5.69250639462346e-05f);
473  auto alpha_9 = make_const(DataType::Float(bits), -2.10102402082508e-06f);
474  auto alpha_11 = make_const(DataType::Float(bits), 2.77068142495902e-08f);
475  auto alpha_13 = make_const(DataType::Float(bits), -2.72614225801306e-10f);
476 
477  // The monomial coefficients of the denominator polynomial (even).
478  auto beta_0 = make_const(DataType::Float(bits), -1.42647390514189e-02f);
479  auto beta_2 = make_const(DataType::Float(bits), -7.37332916720468e-03f);
480  auto beta_4 = make_const(DataType::Float(bits), -1.68282697438203e-03f);
481  auto beta_6 = make_const(DataType::Float(bits), -2.13374055278905e-04f);
482  auto beta_8 = make_const(DataType::Float(bits), -1.45660718464996e-05f);
483 
484  // clamp x
485  auto x = tvm::max(tvm::min(arg, plus_4), minus_4);
486  auto x2 = x * x;
487 
488  // Evaluate the numerator polynomial p.
489  auto p = x2 * alpha_13 + alpha_11;
490  p = x2 * p + alpha_9;
491  p = x2 * p + alpha_7;
492  p = x2 * p + alpha_5;
493  p = x2 * p + alpha_3;
494  p = x2 * p + alpha_1;
495  p = x * p;
496 
497  // Evaluate the denominator polynomial p.
498  auto q = x2 * beta_8 + beta_6;
499  q = x2 * q + beta_4;
500  q = x2 * q + beta_2;
501  q = x2 * q + beta_0;
502 
503  return p / q;
504 }
505 
509 inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) {
510  return compute(
511  data->shape, [&](const Array<Var>& i) { return fast_erf_float_expr(data(i), 32); }, name,
512  tag);
513 }
514 
524 inline Tensor fast_erf(const Tensor& x, std::string name = "T_fast_erf",
525  std::string tag = kElementWise) {
526  if (x->dtype == DataType::Float(32)) {
527  auto ret = fast_erf_float32(x, name, tag);
528  return ret;
529  } else {
530  return topi::erf(x);
531  }
532 }
533 
534 } // namespace topi
535 } // namespace tvm
536 #endif // TVM_TOPI_ELEMWISE_H_
Tensor full_like(const Tensor &x, const PrimExpr fill_value, std::string name="T_full_like", std::string tag=kElementWise)
Creates an operation that constructs a tensor with the same shape as the input tensor, then fills it with...
Definition: elemwise.h:375
Tensor negative(const Tensor &x, std::string name="T_negative", std::string tag=kElementWise)
Creates an operation that returns the negation of a given tensor.
Definition: elemwise.h:166
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
Take the minimum of two values.
tvm::PrimExpr minimum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:316
Tensor abs(const Tensor &x, std::string name="T_" "abs", std::string tag=kElementWise)
Definition: elemwise.h:60
Tensor isinf(const Tensor &x, std::string name="T_" "isinf", std::string tag=kElementWise)
Definition: elemwise.h:75
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1109
Tensor full(const Array< PrimExpr > &shape, DataType dtype, const PrimExpr fill_value, std::string name="T_full", std::string tag=kElementWise)
Creates an operation that fills a tensor with fill_value.
Definition: elemwise.h:354
Tensor sqrt(const Tensor &x, std::string name="T_" "sqrt", std::string tag=kElementWise)
Definition: elemwise.h:52
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
Tensor expression language DSL.
Definition: autodiff.h:35
Tensor elemwise_sum(const Array< Tensor > &xs, std::string name="T_elemwise_sum", std::string tag=kElementWise)
Creates an operation that computes the element-wise sum of a list of tensors.
Definition: elemwise.h:328
Tensor tan(const Tensor &x, std::string name="T_" "tan", std::string tag=kElementWise)
Definition: elemwise.h:63
Tensor fast_tanh_float(const Tensor &in, std::string name, std::string tag)
Fast_tanh_float implementation from Eigen https://github.com/eigenteam/eigen-git-mirror/blob/master/E...
Definition: elemwise.h:81
Tensor floor(const Tensor &x, std::string name="T_" "floor", std::string tag=kElementWise)
Definition: elemwise.h:56
Tensor erf(const Tensor &x, std::string name="T_" "erf", std::string tag=kElementWise)
Definition: elemwise.h:50
Tensor round(const Tensor &x, std::string name="T_" "round", std::string tag=kElementWise)
Definition: elemwise.h:58
Tensor atan(const Tensor &x, std::string name="T_" "atan", std::string tag=kElementWise)
Definition: elemwise.h:70
Tensor asinh(const Tensor &x, std::string name="T_" "asinh", std::string tag=kElementWise)
Definition: elemwise.h:69
Tensor clip(const Tensor &x, const PrimExpr &a_min, const PrimExpr &a_max, std::string name="T_clip", std::string tag=kElementWise)
Creates an operation that clips each element of a tensor to the interval [a_min, a_max].
Definition: elemwise.h:256
Tensor acosh(const Tensor &x, std::string name="T_" "acosh", std::string tag=kElementWise)
Definition: elemwise.h:67
Tensor fast_exp_float32(const Tensor &_x, std::string name, std::string tag)
Fast exponential function implementation.
Definition: elemwise.h:403
Tensor sin(const Tensor &x, std::string name="T_" "sin", std::string tag=kElementWise)
Definition: elemwise.h:64
Tensor log10(const Tensor &x, std::string name="T_" "log10", std::string tag=kElementWise)
Definition: elemwise.h:55
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
Tensor cos(const Tensor &x, std::string name="T_" "cos", std::string tag=kElementWise)
Definition: elemwise.h:61
Tensor sign(const Tensor &x, std::string name="T_sign", std::string tag=kElementWise)
Returns the sign of the tensor.
Definition: elemwise.h:211
Tensor isnan(const Tensor &x, std::string name="T_" "isnan", std::string tag=kElementWise)
Definition: elemwise.h:72
Tensor ceil(const Tensor &x, std::string name="T_" "ceil", std::string tag=kElementWise)
Definition: elemwise.h:57
size_t size() const
Definition: array.h:399
bool defined() const
Definition: object.h:537
Runtime primitive data type.
Definition: data_type.h:41
Tensor rsqrt(const Tensor &x, std::string name="tensor", std::string tag=kElementWise)
Creates an operation that returns the reciprocal square root (rsqrt) of a given tensor.
Definition: elemwise.h:234
Tensor identity(const Tensor &x, std::string name="T_identity", std::string tag=kElementWise)
Creates an operation that returns identity of a given tensor.
Definition: elemwise.h:151
TIR expressions.
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:168
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
constexpr auto kElementWise
Definition: tags.h:32
Tensor log(const Tensor &x, std::string name="T_" "log", std::string tag=kElementWise)
Definition: elemwise.h:53
Tensor fast_exp(const Tensor &x, std::string name="T_fast_exp", std::string tag=kElementWise)
Fast exponential function implementation.
Definition: elemwise.h:447
Tensor sinh(const Tensor &x, std::string name="T_" "sinh", std::string tag=kElementWise)
Definition: elemwise.h:65
Tensor bitwise_not(const Tensor &x, std::string name="T_bitwise_not", std::string tag=kElementWise)
Creates an operation that returns the bitwise NOT of a given tensor.
Definition: elemwise.h:196
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
Take the maximum of two values.
Tensor tanh(const Tensor &x, std::string name="T_" "tanh", std::string tag=kElementWise)
Definition: elemwise.h:73
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1608
Tensor fast_erf_float32(const Tensor &data, std::string name, std::string tag)
Fast_erf_float expression from Eigen.
Definition: elemwise.h:509
Tensor acos(const Tensor &x, std::string name="T_" "acos", std::string tag=kElementWise)
Definition: elemwise.h:66
Tensor reinterpret(const Tensor &x, DataType type, std::string name="tensor", std::string tag=kElementWise)
Reinterpret each element of x to the given type.
Definition: elemwise.h:309
Tensor log2(const Tensor &x, std::string name="T_" "log2", std::string tag=kElementWise)
Definition: elemwise.h:54
Tensor sigmoid(const Tensor &x, std::string name="T_" "sigmoid", std::string tag=kElementWise)
Definition: elemwise.h:51
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Tensor trunc(const Tensor &x, std::string name="T_" "trunc", std::string tag=kElementWise)
Definition: elemwise.h:59
Tensor exp(const Tensor &x, std::string name="T_" "exp", std::string tag=kElementWise)
Definition: elemwise.h:49
Tensor fast_erf(const Tensor &x, std::string name="T_fast_erf", std::string tag=kElementWise)
Fast erf implementation.
Definition: elemwise.h:524
Tensor logical_not(const Tensor &x, std::string name="T_logical_not", std::string tag=kElementWise)
Creates an operation that returns the logical NOT of a given tensor.
Definition: elemwise.h:181
Tensor cosh(const Tensor &x, std::string name="T_" "cosh", std::string tag=kElementWise)
Definition: elemwise.h:62
Tensor asin(const Tensor &x, std::string name="T_" "asin", std::string tag=kElementWise)
Definition: elemwise.h:68
#define TOPI_DECLARE_UNARY_OP(OpName)
Definition: elemwise.h:42
Tensor isfinite(const Tensor &x, std::string name="T_" "isfinite", std::string tag=kElementWise)
Definition: elemwise.h:74
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:303
Tensor fast_tanh(const Tensor &x, std::string name="T_fast_tanh", std::string tag=kElementWise)
Creates an operation that returns the hyperbolic tangent (tanh) of a given tensor.
Definition: elemwise.h:130
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
TIR builtin intrinsics.
External function interface to rocBLAS libraries.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type...
Definition: elemwise.h:280
Broadcast op constructions.
Reference to PrimExprNode.
Definition: expr.h:109
Tensor atanh(const Tensor &x, std::string name="T_" "atanh", std::string tag=kElementWise)
Definition: elemwise.h:71
PrimExpr fast_erf_float_expr(PrimExpr arg, int bits)
Fast_erf_float expression from Eigen https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupp...
Definition: elemwise.h:464