tvm
broadcast.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_BROADCAST_H_
25 #define TVM_TOPI_BROADCAST_H_
26 
29 #include <tvm/topi/tags.h>
30 
31 #include <algorithm>
32 #include <string>
33 
34 namespace tvm {
35 namespace topi {
36 
49  const tvm::ffi::Array<tvm::PrimExpr>& output_shape,
50  std::string name = "T_broadcast_to",
51  std::string tag = kBroadcast) {
52  ICHECK_GE(output_shape.size(), t->shape.size())
53  << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape
54  << "\nvs\ninput: " << t;
55  auto bh = detail::BroadcastShape(output_shape, t->shape);
56  ICHECK_EQ(output_shape.size(), bh.common_shape.size());
57  ffi::Array<PrimExpr> oshape;
58  for (size_t i = 0; i < output_shape.size(); ++i) {
59  if (output_shape[i].as<tir::IntImmNode>() == nullptr) {
60  oshape.push_back(output_shape[i]);
61  } else {
62  ICHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
63  oshape.push_back(bh.common_shape[i]);
64  }
65  }
66  auto l = [&](tvm::ffi::Array<tvm::tir::Var> ovars) {
67  return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
68  };
69  return tvm::te::compute(oshape, l, name, tag);
70 }
71 
// Generates four inline overloads named `Name` from a single compute rule.
//
// Macro parameters:
//   Name        - identifier of the generated functions; it is also stringized
//                 into the default tensor name "T_<Name>".
//   ComputeRule - statement(s) that read the two operands `a` and `b`
//                 (tvm::PrimExpr) and return the resulting expression.
//
// Overloads produced:
//   (PrimExpr a, PrimExpr b) -> PrimExpr : ComputeRule is the function body.
//   (Tensor A, Tensor B)     -> Tensor   : wraps ComputeRule in a lambda and hands it to
//                                          detail::WithBroadcast; default tag is kBroadcast.
//   (Tensor A, PrimExpr B)   -> Tensor   : te::compute over A->shape, evaluating the rule on
//                                          (A(i), B); default tag is kElementWise.
//   (PrimExpr A, Tensor B)   -> Tensor   : te::compute over B->shape, evaluating the rule on
//                                          (A, B(i)); default tag is kElementWise.
//
// NOTE(review): the rule lambda in the last overload is declared with [&] while the
// other two use []; confirm whether the difference is intentional.
72 #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
73  inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
74  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \
75  std::string name = "T_" #Name, std::string tag = kBroadcast) { \
76  auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
77  return detail::WithBroadcast(l, A, B, name, tag); \
78  } \
79  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \
80  std::string name = "T_" #Name, std::string tag = kElementWise) { \
81  auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
82  return tvm::te::compute( \
83  A->shape, [&](const ::tvm::ffi::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, \
84  tag); \
85  } \
86  inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \
87  std::string name = "T_" #Name, std::string tag = kElementWise) { \
88  auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
89  return tvm::te::compute( \
90  B->shape, [&](const ::tvm::ffi::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, \
91  tag); \
92  }
93 
// Generates three inline, Tensor-returning overloads named `Name` that simply
// forward their two arguments to topi::OpName.
//
// Macro parameters:
//   Name   - identifier of the generated forwarding functions.
//   OpName - name of the topi function that receives the forwarded call.
//
// Operand combinations covered: (Tensor, Tensor), (PrimExpr, Tensor) and
// (Tensor, PrimExpr). No (PrimExpr, PrimExpr) overload is generated.
94 #define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \
95  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) { \
96  return topi::OpName(A, B); \
97  } \
98  inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) { \
99  return topi::OpName(A, B); \
100  } \
101  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) { \
102  return topi::OpName(A, B); \
103  }
104 
// Defines the logical_and overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a && b`.
116 TOPI_DEFINE_BCAST_OP(logical_and, { return a && b; });
118 
// Defines the logical_or overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a || b`.
130 TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
132 
// Defines the logical_xor overloads (PrimExpr and Tensor forms); the rule is `a ^ b`,
// i.e. the same expression that bitwise_xor uses.
144 TOPI_DEFINE_BCAST_OP(logical_xor, { return a ^ b; });
145 
// Defines the bitwise_and overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a & b`.
157 TOPI_DEFINE_BCAST_OP(bitwise_and, { return a & b; });
159 
// Defines the bitwise_or overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a | b`.
171 TOPI_DEFINE_BCAST_OP(bitwise_or, { return a | b; });
173 
// Defines the bitwise_xor overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a ^ b`.
185 TOPI_DEFINE_BCAST_OP(bitwise_xor, { return a ^ b; });
187 
// Defines the add overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a + b`.
199 TOPI_DEFINE_BCAST_OP(add, { return a + b; });
201 
// Defines the subtract overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a - b`.
213 TOPI_DEFINE_BCAST_OP(subtract, { return a - b; });
215 
// Defines the multiply overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a * b`.
227 TOPI_DEFINE_BCAST_OP(multiply, { return a * b; });
229 
// Defines the divide overloads (PrimExpr and Tensor forms); the rule delegates to div(a, b).
241 TOPI_DEFINE_BCAST_OP(divide, { return div(a, b); });
242 
255  if (a.dtype().is_int() || a.dtype().is_uint()) {
256  return floordiv(a, b);
257  } else {
258  return floor(div(a, b));
259  }
260 });
261 
277 
290  if (a.dtype().is_int() || a.dtype().is_uint()) {
291  return truncdiv(a, b);
292  } else {
293  return trunc(div(a, b));
294  }
295 });
296 
// Defines the mod overloads (PrimExpr and Tensor forms); the rule delegates to truncmod(a, b).
308 TOPI_DEFINE_BCAST_OP(mod, { return truncmod(a, b); });
309 
322  if (a.dtype().is_int() || a.dtype().is_uint()) {
323  return floormod(a, b);
324  } else {
325  return a - floor_divide(a, b) * b;
326  }
327 });
328 
341  if (a.dtype().is_int() || a.dtype().is_uint()) {
342  return truncmod(a, b);
343  } else {
344  return a - trunc_divide(a, b) * b;
345  }
346 });
347 
360 
373 
// Defines the power overloads (PrimExpr and Tensor forms); the rule delegates to tvm::pow(a, b).
385 TOPI_DEFINE_BCAST_OP(power, { return tvm::pow(a, b); });
386 
// Defines the left_shift overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a << b`.
398 TOPI_DEFINE_BCAST_OP(left_shift, { return a << b; });
400 
// Defines the right_shift overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a >> b`.
412 TOPI_DEFINE_BCAST_OP(right_shift, { return a >> b; });
414 
// Defines the greater overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a > b`.
426 TOPI_DEFINE_BCAST_OP(greater, { return (a > b); });
427 
// Defines the less overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a < b`.
439 TOPI_DEFINE_BCAST_OP(less, { return (a < b); });
440 
// Defines the equal overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a == b`.
452 TOPI_DEFINE_BCAST_OP(equal, { return (a == b); });
453 
// Defines the not_equal overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a != b`.
465 TOPI_DEFINE_BCAST_OP(not_equal, { return (a != b); });
466 
// Defines the greater_equal overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a >= b`.
478 TOPI_DEFINE_BCAST_OP(greater_equal, { return (a >= b); });
479 
// Defines the less_equal overloads (PrimExpr and Tensor forms); the rule applied to each operand pair is `a <= b`.
491 TOPI_DEFINE_BCAST_OP(less_equal, { return (a <= b); });
492 
493 } // namespace topi
494 } // namespace tvm
495 
496 #endif // TVM_TOPI_BROADCAST_H_
#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName)
Definition: broadcast.h:94
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)
Definition: broadcast.h:72
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Utility functions for handling constants in TVM expressions.
Detail broadcast.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis)
tvm::PrimExpr floor_mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:327
tvm::te::Tensor broadcast_to(const tvm::te::Tensor &t, const tvm::ffi::Array< tvm::PrimExpr > &output_shape, std::string name="T_broadcast_to", std::string tag=kBroadcast)
Creates an operation that broadcasts a tensor into a compatible shape according to numpy's rules.
Definition: broadcast.h:48
constexpr auto kBroadcast
Definition: tags.h:36
tvm::PrimExpr bitwise_or(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:171
tvm::PrimExpr not_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:465
tvm::PrimExpr less(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:439
tvm::PrimExpr subtract(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:213
tvm::PrimExpr logical_or(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:130
tvm::PrimExpr left_shift(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:398
tvm::PrimExpr trunc_mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:346
tvm::PrimExpr greater_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:478
tvm::PrimExpr bitwise_and(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:157
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:241
tvm::PrimExpr multiply(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:227
tvm::PrimExpr floor_divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:260
tvm::PrimExpr log_add_exp(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:276
tvm::PrimExpr minimum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:372
tvm::PrimExpr logical_and(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:116
tvm::PrimExpr trunc_divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:295
tvm::PrimExpr right_shift(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:412
tvm::PrimExpr add(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:199
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:308
tvm::PrimExpr less_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:491
tvm::PrimExpr equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:452
tvm::PrimExpr bitwise_xor(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:185
tvm::PrimExpr greater(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:426
tvm::PrimExpr power(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:385
tvm::PrimExpr logical_xor(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:144
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:359
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
PrimExpr pow(PrimExpr x, PrimExpr y, Span span=Span())
Calculate power(x, y)
PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span=Span())
Compute log(exp(a) + exp(b)).
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
External function interface to rocBLAS libraries.