tvm
elemwise.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_ELEMWISE_H_
25 #define TVM_TOPI_ELEMWISE_H_
26 
27 #include <tvm/tirx/builtin.h>
28 #include <tvm/tirx/expr.h>
29 #include <tvm/tirx/op.h>
30 #include <tvm/topi/tags.h>
31 
32 #include <algorithm>
33 #include <string>
34 
35 #include "broadcast.h"
36 
37 namespace tvm {
38 namespace topi {
39 
40 using namespace tvm::te;
41 
42 // Unary intrinsic operators
/*!
 * \brief Declare an element-wise TOPI wrapper around the intrinsic ::tvm::OpName.
 *
 * The generated function is named OpName, takes a tensor \p x and builds a new
 * tensor of the same shape whose elements are ::tvm::OpName applied to the
 * corresponding element of \p x. The default operation name is "T_" followed
 * by the operator name, and the default tag is kElementWise.
 */
#define TOPI_DECLARE_UNARY_OP(OpName)                                        \
  inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName,     \
                       std::string tag = kElementWise) {                     \
    auto apply_intrinsic = [&](const ffi::Array<Var>& indices) {             \
      return ::tvm::OpName(x(indices));                                      \
    };                                                                       \
    return compute(x->shape, apply_intrinsic, name, tag);                    \
  }
49 
77 
82 inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string tag) {
83  // Clamp the inputs to the range [-9, 9] since anything outside
84  // this range is +/-1.0f in single-precision.
85  auto x = maximum(make_const(in->dtype, -9.0), minimum(make_const(in->dtype, 9.0), in));
86 
87  // The monomial coefficients of the numerator polynomial (odd).
88  auto alpha_1 = make_const(in->dtype, 4.89352455891786e-03);
89  auto alpha_3 = make_const(in->dtype, 6.37261928875436e-04);
90  auto alpha_5 = make_const(in->dtype, 1.48572235717979e-05);
91  auto alpha_7 = make_const(in->dtype, 5.12229709037114e-08);
92  auto alpha_9 = make_const(in->dtype, -8.60467152213735e-11);
93  auto alpha_11 = make_const(in->dtype, 2.00018790482477e-13);
94  auto alpha_13 = make_const(in->dtype, -2.76076847742355e-16);
95 
96  // The monomial coefficients of the denominator polynomial (even).
97  auto beta_0 = make_const(in->dtype, 4.89352518554385e-03);
98  auto beta_2 = make_const(in->dtype, 2.26843463243900e-03);
99  auto beta_4 = make_const(in->dtype, 1.18534705686654e-04);
100  auto beta_6 = make_const(in->dtype, 1.19825839466702e-06);
101 
102  return compute(
103  x->shape,
104  [&](const ffi::Array<Var>& i) {
105  auto x2 = x(i) * x(i);
106  auto p = x2 * alpha_13 + alpha_11;
107  p = x2 * p + alpha_9;
108  p = x2 * p + alpha_7;
109  p = x2 * p + alpha_5;
110  p = x2 * p + alpha_3;
111  p = x2 * p + alpha_1;
112  p = x(i) * p;
113 
114  auto q = x2 * beta_6 + beta_4;
115  q = x2 * q + beta_2;
116  q = x2 * q + beta_0;
117  return p / q;
118  },
119  name, tag);
120 }
121 
131 inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh",
132  std::string tag = kElementWise) {
133  if (x->dtype == DataType::Float(32)) {
134  // invoke fast_tanh_float implementation
135  return fast_tanh_float(x, name, tag);
136  } else {
137  // fallback to default implementation
138  return compute(
139  x->shape, [&](const ffi::Array<Var>& i) { return ::tvm::tanh(x(i)); }, name, tag);
140  }
141 }
142 
152 inline Tensor identity(const Tensor& x, std::string name = "T_identity",
153  std::string tag = kElementWise) {
154  return compute(x->shape, [&](const ffi::Array<Var>& i) { return x(i); }, name, tag);
155 }
156 
166 inline Tensor negative(const Tensor& x, std::string name = "T_negative",
167  std::string tag = kElementWise) {
168  return compute(x->shape, [&](const ffi::Array<Var>& i) { return -x(i); }, name, tag);
169 }
170 
180 inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not",
181  std::string tag = kElementWise) {
182  return compute(x->shape, [&](const ffi::Array<Var>& i) { return !x(i); }, name, tag);
183 }
184 
194 inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not",
195  std::string tag = kElementWise) {
196  return compute(x->shape, [&](const ffi::Array<Var>& i) { return ~x(i); }, name, tag);
197 }
198 
208 inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag = kElementWise) {
209  return compute(
210  x->shape,
211  [&](const ffi::Array<Var>& i) {
212  PrimExpr zero = make_zero(x->dtype);
213  PrimExpr one = make_const(x->dtype, 1);
214  PrimExpr minus_one = make_const(x->dtype, -1);
215  auto s1 = tvm::tirx::Select((x(i) < zero), minus_one, zero);
216  auto s2 = tvm::tirx::Select((x(i) > zero), one, s1);
217  return s2;
218  },
219  name, tag);
220 }
221 
231 inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string tag = kElementWise) {
232  return compute(
233  x->shape,
234  [&](const ffi::Array<Var>& i) {
235  PrimExpr one = make_const(x->dtype, 1);
236  return one / tvm::sqrt(x(i));
237  },
238  name, tag);
239 }
240 
253 inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max,
254  std::string name = "T_clip", std::string tag = kElementWise) {
255  return compute(
256  x->shape,
257  [&](const ffi::Array<Var>& i) {
258  auto min_val = tvm::cast(x->dtype, a_min);
259  auto max_val = tvm::cast(x->dtype, a_max);
260  return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*)
261  },
262  name, tag);
263 }
264 
277 inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast",
278  std::string tag = kElementWise) {
279  return compute(
280  x->shape,
281  [&](const ffi::Array<Var>& i) -> PrimExpr {
282  auto expr = x(i);
283  if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) {
284  if (expr.dtype().lanes() == type.lanes()) {
285  return expr;
286  } else if (expr.dtype().lanes() == 1 && type.is_vector()) {
287  return tvm::tirx::Broadcast(expr, type.lanes());
288  }
289  }
290 
291  return tvm::cast(type, x(i));
292  },
293  name, tag);
294 }
295 
306 inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor",
307  std::string tag = kElementWise) {
308  return compute(
309  x->shape, [&](const ffi::Array<Var>& i) { return reinterpret(type, x(i)); }, name, tag);
310 }
311 
321 inline Tensor elemwise_sum(const ffi::Array<Tensor>& xs, std::string name = "T_elemwise_sum",
322  std::string tag = kElementWise) {
323  TVM_FFI_ICHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor.";
324  return compute(
325  xs[0]->shape,
326  [&](const ffi::Array<Var>& i) {
327  auto sum_expr = xs[0](i);
328  for (size_t j = 1; j < xs.size(); j++) {
329  sum_expr = sum_expr + xs[j](i);
330  }
331  return sum_expr;
332  },
333  name, tag);
334 }
335 
347 inline Tensor full(const ffi::Array<PrimExpr>& shape, DataType dtype, const PrimExpr fill_value,
348  std::string name = "T_full", std::string tag = kElementWise) {
349  PrimExpr ev = cast(dtype, fill_value);
350  if (!ev.defined()) {
351  LOG(ERROR) << "Can't cast fill_value to " << dtype;
352  }
353  return compute(shape, [&](const ffi::Array<Var>& i) { return ev; }, name, tag);
354 }
355 
367 inline Tensor full_like(const Tensor& x, const PrimExpr fill_value,
368  std::string name = "T_full_like", std::string tag = kElementWise) {
369  PrimExpr ev = cast(x->dtype, fill_value);
370  return compute(x->shape, [&](const ffi::Array<Var>& i) { return ev; }, name, tag);
371 }
372 
394 inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) {
395  auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
396  auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
397  auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
398  auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f);
399  PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f),
400  make_const(DataType::Float(32), 1.3981999507E-3f),
401  make_const(DataType::Float(32), 8.3334519073E-3f),
402  make_const(DataType::Float(32), 4.1665795894E-2f),
403  make_const(DataType::Float(32), 1.6666665459E-1f),
404  make_const(DataType::Float(32), 5.0000001201E-1f)};
405  auto one = make_const(DataType::Float(32), 1.0f);
406  auto one_half = make_const(DataType::Float(32), 0.5f);
407  auto b = make_const(DataType::Float(32), 127.0f);
408 
409  return compute(
410  _x->shape,
411  [&](const ffi::Array<Var>& i) {
412  // clamp x
413  auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
414  // integer part
415  auto n = ::tvm::floor(x * log2e + one_half);
416  // fractional part
417  auto f = x - n * ln2;
418  auto y =
419  (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f + p[5]) * f * f + f + one;
420  // Return 2^m * exp(r).
421  auto ef =
422  tvm::reinterpret(DataType::Float(32), ::tvm::cast(DataType::Int(32), n + b) << 23);
423  return ::tvm::max(ef * y, _x(i)); // NOLINT(*)
424  },
425  name, tag);
426 }
427 
438 inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp",
439  std::string tag = kElementWise) {
440  if (x->dtype == DataType::Float(32)) {
441  auto ret = fast_exp_float32(x, name, tag);
442  return ret;
443  } else {
444  return compute(x->shape, [&](const ffi::Array<Var>& i) { return ::tvm::exp(x(i)); }, name, tag);
445  }
446 }
447 
451 inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) {
452  return compute(
453  data->shape, [&](const ffi::Array<Var>& i) { return fast_erf_float_expr(data(i), 32); }, name,
454  tag);
455 }
456 
460 inline Tensor fast_erf_float16(const Tensor& data, std::string name, std::string tag) {
461  return compute(
462  data->shape, [&](const ffi::Array<Var>& i) { return fast_erf_float_expr(data(i), 16); }, name,
463  tag);
464 }
465 
475 inline Tensor fast_erf(const Tensor& x, std::string name = "T_fast_erf",
476  std::string tag = kElementWise) {
477  if (x->dtype == DataType::Float(32)) {
478  auto ret = fast_erf_float32(x, name, tag);
479  return ret;
480  } else if (x->dtype == DataType::Float(16)) {
481  auto ret = fast_erf_float16(x, name, tag);
482  return ret;
483  } else {
484  return topi::erf(x);
485  }
486 }
487 
488 } // namespace topi
489 } // namespace tvm
490 #endif // TVM_TOPI_ELEMWISE_H_
Reference to PrimExprNode.
Definition: expr.h:126
Runtime primitive data type.
Definition: data_type.h:47
static DataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:295
Managed Tensor. The array is backed by reference counted blocks.
Definition: tensor.h:54
Detail broadcast.
#define TOPI_DECLARE_UNARY_OP(OpName)
Definition: elemwise.h:43
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
PrimExpr make_const(DataType t, ValueType value, Span span=Span())
Make a const value with certain data type.
Definition: op.h:1007
constexpr auto kElementWise
Definition: tags.h:32
Tensor clip(const Tensor &x, const PrimExpr &a_min, const PrimExpr &a_max, std::string name="T_clip", std::string tag=kElementWise)
Creates an operation that clips each element of a tensor to the interval [a_min, a_max].
Definition: elemwise.h:253
Tensor ceil(const Tensor &x, std::string name="T_" "ceil", std::string tag=kElementWise)
Definition: elemwise.h:58
Tensor isinf(const Tensor &x, std::string name="T_" "isinf", std::string tag=kElementWise)
Definition: elemwise.h:76
Tensor round(const Tensor &x, std::string name="T_" "round", std::string tag=kElementWise)
Definition: elemwise.h:59
Tensor tan(const Tensor &x, std::string name="T_" "tan", std::string tag=kElementWise)
Definition: elemwise.h:64
Tensor logical_not(const Tensor &x, std::string name="T_logical_not", std::string tag=kElementWise)
Creates an operation that returns the logical NOT of a given tensor.
Definition: elemwise.h:180
Tensor fast_erf_float32(const Tensor &data, std::string name, std::string tag)
Fast_erf_float expression from Eigen.
Definition: elemwise.h:451
Tensor reinterpret(const Tensor &x, DataType type, std::string name="tensor", std::string tag=kElementWise)
Reinterpret each element of x to the given type.
Definition: elemwise.h:306
Tensor floor(const Tensor &x, std::string name="T_" "floor", std::string tag=kElementWise)
Definition: elemwise.h:57
Tensor full(const ffi::Array< PrimExpr > &shape, DataType dtype, const PrimExpr fill_value, std::string name="T_full", std::string tag=kElementWise)
Creates an operation that fills a tensor with fill_value.
Definition: elemwise.h:347
Tensor trunc(const Tensor &x, std::string name="T_" "trunc", std::string tag=kElementWise)
Definition: elemwise.h:60
Tensor isnan(const Tensor &x, std::string name="T_" "isnan", std::string tag=kElementWise)
Definition: elemwise.h:73
Tensor fast_erf_float16(const Tensor &data, std::string name, std::string tag)
Fast_erf_float expression from Eigen for float16.
Definition: elemwise.h:460
Tensor exp(const Tensor &x, std::string name="T_" "exp", std::string tag=kElementWise)
Definition: elemwise.h:50
Tensor acos(const Tensor &x, std::string name="T_" "acos", std::string tag=kElementWise)
Definition: elemwise.h:67
Tensor asinh(const Tensor &x, std::string name="T_" "asinh", std::string tag=kElementWise)
Definition: elemwise.h:70
Tensor log(const Tensor &x, std::string name="T_" "log", std::string tag=kElementWise)
Definition: elemwise.h:54
Tensor fast_tanh_float(const Tensor &in, std::string name, std::string tag)
Fast_tanh_float implementation from Eigen https://github.com/eigenteam/eigen-git-mirror/blob/master/E...
Definition: elemwise.h:82
Tensor fast_tanh(const Tensor &x, std::string name="T_fast_tanh", std::string tag=kElementWise)
Creates an operation that returns hyperbolic tanh of a given tensor.
Definition: elemwise.h:131
Tensor fast_erf(const Tensor &x, std::string name="T_fast_erf", std::string tag=kElementWise)
Fast erf implementation.
Definition: elemwise.h:475
Tensor cos(const Tensor &x, std::string name="T_" "cos", std::string tag=kElementWise)
Definition: elemwise.h:62
tvm::PrimExpr minimum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:372
Tensor cast(const Tensor &x, DataType type, std::string name="T_cast", std::string tag=kElementWise)
Cast each element of x to the given type. If expr is scalar and type is a corresponding vector type,...
Definition: elemwise.h:277
Tensor sqrt(const Tensor &x, std::string name="T_" "sqrt", std::string tag=kElementWise)
Definition: elemwise.h:53
Tensor rsqrt(const Tensor &x, std::string name="tensor", std::string tag=kElementWise)
Creates an operation that returns rsqrt of a given tensor.
Definition: elemwise.h:231
Tensor abs(const Tensor &x, std::string name="T_" "abs", std::string tag=kElementWise)
Definition: elemwise.h:61
Tensor asin(const Tensor &x, std::string name="T_" "asin", std::string tag=kElementWise)
Definition: elemwise.h:69
Tensor bitwise_not(const Tensor &x, std::string name="T_bitwise_not", std::string tag=kElementWise)
Creates an operation that returns the bitwise NOT of a given tensor.
Definition: elemwise.h:194
Tensor atan(const Tensor &x, std::string name="T_" "atan", std::string tag=kElementWise)
Definition: elemwise.h:71
Tensor atanh(const Tensor &x, std::string name="T_" "atanh", std::string tag=kElementWise)
Definition: elemwise.h:72
Tensor log10(const Tensor &x, std::string name="T_" "log10", std::string tag=kElementWise)
Definition: elemwise.h:56
Tensor sigmoid(const Tensor &x, std::string name="T_" "sigmoid", std::string tag=kElementWise)
Definition: elemwise.h:52
Tensor identity(const Tensor &x, std::string name="T_identity", std::string tag=kElementWise)
Creates an operation that returns identity of a given tensor.
Definition: elemwise.h:152
Tensor isfinite(const Tensor &x, std::string name="T_" "isfinite", std::string tag=kElementWise)
Definition: elemwise.h:75
Tensor log2(const Tensor &x, std::string name="T_" "log2", std::string tag=kElementWise)
Definition: elemwise.h:55
Tensor elemwise_sum(const ffi::Array< Tensor > &xs, std::string name="T_elemwise_sum", std::string tag=kElementWise)
Creates an operation that sums the input tensors element by element.
Definition: elemwise.h:321
Tensor cosh(const Tensor &x, std::string name="T_" "cosh", std::string tag=kElementWise)
Definition: elemwise.h:63
Tensor acosh(const Tensor &x, std::string name="T_" "acosh", std::string tag=kElementWise)
Definition: elemwise.h:68
Tensor sin(const Tensor &x, std::string name="T_" "sin", std::string tag=kElementWise)
Definition: elemwise.h:65
Tensor tanh(const Tensor &x, std::string name="T_" "tanh", std::string tag=kElementWise)
Definition: elemwise.h:74
Tensor full_like(const Tensor &x, const PrimExpr fill_value, std::string name="T_full_like", std::string tag=kElementWise)
Creates an operation that construct a tensor with same shape as input tensor, then fill a tensor with...
Definition: elemwise.h:367
Tensor erf(const Tensor &x, std::string name="T_" "erf", std::string tag=kElementWise)
Definition: elemwise.h:51
Tensor fast_exp(const Tensor &x, std::string name="T_fast_exp", std::string tag=kElementWise)
Fast exponential function implementation.
Definition: elemwise.h:438
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
Tensor sign(const Tensor &x, std::string name="T_sign", std::string tag=kElementWise)
Returns the sign of the tensor.
Definition: elemwise.h:208
Tensor negative(const Tensor &x, std::string name="T_negative", std::string tag=kElementWise)
Creates an operation that returns the negation of a given tensor.
Definition: elemwise.h:166
Tensor sinh(const Tensor &x, std::string name="T_" "sinh", std::string tag=kElementWise)
Definition: elemwise.h:66
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:359
Tensor fast_exp_float32(const Tensor &_x, std::string name, std::string tag)
Fast exponential function implementation.
Definition: elemwise.h:394
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr cast(const DataType &t, PrimExpr value, Span span=Span())
cast value to type.
Tag definitions.
TIR builtin intrinsics.
TIR expressions.
Common operators defined for Expr.