tvm
broadcast.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_BROADCAST_H_
25 #define TVM_TOPI_BROADCAST_H_
26 
29 #include <tvm/topi/tags.h>
30 
31 #include <algorithm>
32 #include <string>
33 
34 namespace tvm {
35 namespace topi {
36 
// broadcast_to (remaining signature and body): builds a tensor whose shape is derived
// from `output_shape` and whose elements are read from the input tensor `t`.
// NOTE(review): the first line of the signature (return type, function name and the
// tensor parameter `t`) is not present in this listing; the gutter jumps from 36 to 49.
49  const tvm::Array<tvm::PrimExpr>& output_shape,
50  std::string name = "T_broadcast_to",
51  std::string tag = kBroadcast) {
// The requested shape must have at least as many dimensions as the input tensor.
52  ICHECK_GE(output_shape.size(), t->shape.size())
53  << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape
54  << "\nvs\ninput: " << t;
// detail::BroadcastShape is defined outside this listing; its result is used below
// through the fields common_shape, vars2 and all_vars.
55  auto bh = detail::BroadcastShape(output_shape, t->shape);
// The computed common shape must have exactly the requested number of dimensions.
56  ICHECK_EQ(output_shape.size(), bh.common_shape.size());
57  Array<PrimExpr> oshape;
// Per dimension: when the requested extent is not an IntImmNode it is kept as given;
// when it is an IntImmNode it must pass EqualCheck against the computed common extent,
// and the common extent is the one stored.
58  for (size_t i = 0; i < output_shape.size(); ++i) {
59  if (output_shape[i].as<tir::IntImmNode>() == nullptr) {
60  oshape.push_back(output_shape[i]);
61  } else {
62  ICHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
63  oshape.push_back(bh.common_shape[i]);
64  }
65  }
// Compute rule: translate the output index variables into indices of `t` through
// detail::InputIndexFromBroadcast and read `t` at those indices.
66  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
67  return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
68  };
// Build the result tensor over `oshape` with the caller-supplied name and tag.
69  return tvm::te::compute(oshape, l, name, tag);
70 }
71 
/*!
 * \brief Define the overload set of a binary operation called `Name`.
 *
 * `ComputeRule` is a statement that returns a value computed from two PrimExpr
 * operands named `a` and `b`. The macro expands to four inline functions:
 *  - Name(PrimExpr a, PrimExpr b): the body is `ComputeRule` itself.
 *  - Name(Tensor A, Tensor B, name, tag): wraps `ComputeRule` in a lambda and passes it
 *    to detail::WithBroadcast; `tag` defaults to kBroadcast.
 *  - Name(Tensor A, PrimExpr B, name, tag): calls tvm::te::compute over A->shape,
 *    applying the rule to A(i) and B; `tag` defaults to kElementWise.
 *  - Name(PrimExpr A, Tensor B, name, tag): calls tvm::te::compute over B->shape,
 *    applying the rule to A and B(i); `tag` defaults to kElementWise.
 * In the three tensor overloads `name` defaults to "T_" followed by the stringized Name.
 *
 * NOTE(review): the helper lambda in the (PrimExpr, Tensor) overload is declared with
 * `[&]` while the other two use `[]`; nothing visible in the macro needs the capture.
 *
 * \param Name Identifier of the functions to define.
 * \param ComputeRule Statement returning the per-element result in terms of `a` and `b`.
 */
72 #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
73  inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
74  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \
75  std::string name = "T_" #Name, std::string tag = kBroadcast) { \
76  auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
77  return detail::WithBroadcast(l, A, B, name, tag); \
78  } \
79  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \
80  std::string name = "T_" #Name, std::string tag = kElementWise) { \
81  auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
82  return tvm::te::compute( \
83  A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \
84  } \
85  inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \
86  std::string name = "T_" #Name, std::string tag = kElementWise) { \
87  auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
88  return tvm::te::compute( \
89  B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \
90  }
91 
/*!
 * \brief Define three inline overloads of `Name` that forward to topi::OpName(A, B).
 *
 * The generated overloads take (Tensor, Tensor), (PrimExpr, Tensor) and
 * (Tensor, PrimExpr); no (PrimExpr, PrimExpr) overload is generated. Each one returns
 * the tvm::te::Tensor produced by topi::OpName with its default name and tag.
 *
 * NOTE(review): no invocation of this macro is visible in this listing; presumably it
 * is used to map operators onto the named ops -- confirm against the full header.
 *
 * \param Name Identifier of the overloads to define.
 * \param OpName Name of the topi function the overloads forward to.
 */
92 #define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \
93  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) { \
94  return topi::OpName(A, B); \
95  } \
96  inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) { \
97  return topi::OpName(A, B); \
98  } \
99  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) { \
100  return topi::OpName(A, B); \
101  }
102 
// Each invocation below expands TOPI_DEFINE_BCAST_OP into the overload set named by its
// first argument; the second argument is the per-element rule on PrimExpr `a` and `b`.
// logical_and: a && b.
114 TOPI_DEFINE_BCAST_OP(logical_and, { return a && b; });
116 
// logical_or: a || b.
128 TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
130 
// logical_xor: a ^ b -- the same expression that bitwise_xor uses.
142 TOPI_DEFINE_BCAST_OP(logical_xor, { return a ^ b; });
143 
// bitwise_and: a & b.
155 TOPI_DEFINE_BCAST_OP(bitwise_and, { return a & b; });
157 
// bitwise_or: a | b.
169 TOPI_DEFINE_BCAST_OP(bitwise_or, { return a | b; });
171 
// bitwise_xor: a ^ b.
183 TOPI_DEFINE_BCAST_OP(bitwise_xor, { return a ^ b; });
185 
// add: a + b.
197 TOPI_DEFINE_BCAST_OP(add, { return a + b; });
199 
// subtract: a - b.
211 TOPI_DEFINE_BCAST_OP(subtract, { return a - b; });
213 
// multiply: a * b.
225 TOPI_DEFINE_BCAST_OP(multiply, { return a * b; });
227 
// divide: div(a, b); this page's index describes div as division in C semantics.
239 TOPI_DEFINE_BCAST_OP(divide, { return div(a, b); });
240 
// Rule body of floor_divide (this page's index places floor_divide at broadcast.h:258).
// NOTE(review): the opening line of the macro invocation is not present in this listing.
// Only the dtype of `a` selects the branch: signed/unsigned integers use floordiv,
// every other dtype applies floor to div(a, b).
253  if (a.dtype().is_int() || a.dtype().is_uint()) {
254  return floordiv(a, b);
255  } else {
256  return floor(div(a, b));
257  }
258 });
259 
// Rule body of trunc_divide (this page's index places trunc_divide at broadcast.h:277).
// NOTE(review): the opening line of the macro invocation is not present in this listing.
// Only the dtype of `a` selects the branch: signed/unsigned integers use truncdiv,
// every other dtype applies trunc to div(a, b).
272  if (a.dtype().is_int() || a.dtype().is_uint()) {
273  return truncdiv(a, b);
274  } else {
275  return trunc(div(a, b));
276  }
277 });
278 
// mod: truncmod(a, b) regardless of dtype; this page's index describes truncmod as the
// remainder of truncdiv.
290 TOPI_DEFINE_BCAST_OP(mod, { return truncmod(a, b); });
291 
// Rule body of floor_mod (this page's index places floor_mod at broadcast.h:309).
// NOTE(review): the opening line of the macro invocation is not present in this listing.
// Only the dtype of `a` selects the branch: signed/unsigned integers use floormod,
// every other dtype computes a - floor_divide(a, b) * b.
304  if (a.dtype().is_int() || a.dtype().is_uint()) {
305  return floormod(a, b);
306  } else {
307  return a - floor_divide(a, b) * b;
308  }
309 });
310 
// Rule body of trunc_mod (this page's index places trunc_mod at broadcast.h:328).
// NOTE(review): the opening line of the macro invocation is not present in this listing.
// Only the dtype of `a` selects the branch: signed/unsigned integers use truncmod,
// every other dtype computes a - trunc_divide(a, b) * b.
323  if (a.dtype().is_int() || a.dtype().is_uint()) {
324  return truncmod(a, b);
325  } else {
326  return a - trunc_divide(a, b) * b;
327  }
328 });
329 
342 
355 
// Each invocation below expands TOPI_DEFINE_BCAST_OP into the overload set named by its
// first argument; the second argument is the per-element rule on PrimExpr `a` and `b`.
// power: tvm::pow(a, b).
367 TOPI_DEFINE_BCAST_OP(power, { return tvm::pow(a, b); });
368 
// left_shift: a << b.
380 TOPI_DEFINE_BCAST_OP(left_shift, { return a << b; });
382 
// right_shift: a >> b.
394 TOPI_DEFINE_BCAST_OP(right_shift, { return a >> b; });
396 
// greater: a > b.
408 TOPI_DEFINE_BCAST_OP(greater, { return (a > b); });
409 
// less: a < b.
421 TOPI_DEFINE_BCAST_OP(less, { return (a < b); });
422 
// equal: a == b.
434 TOPI_DEFINE_BCAST_OP(equal, { return (a == b); });
435 
// not_equal: a != b.
447 TOPI_DEFINE_BCAST_OP(not_equal, { return (a != b); });
448 
// greater_equal: a >= b.
460 TOPI_DEFINE_BCAST_OP(greater_equal, { return (a >= b); });
461 
// less_equal: a <= b.
473 TOPI_DEFINE_BCAST_OP(less_equal, { return (a <= b); });
474 
475 } // namespace topi
476 } // namespace tvm
477 
478 #endif // TVM_TOPI_BROADCAST_H_
#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName)
Definition: broadcast.h:92
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)
Definition: broadcast.h:72
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
size_t size() const
Definition: array.h:420
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Utility functions for handling constants in TVM expressions.
Detail broadcast.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis)
tvm::PrimExpr floor_mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:309
constexpr auto kBroadcast
Definition: tags.h:36
tvm::PrimExpr bitwise_or(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:169
tvm::PrimExpr not_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:447
tvm::PrimExpr less(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:421
tvm::PrimExpr subtract(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:211
tvm::PrimExpr logical_or(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:128
tvm::PrimExpr left_shift(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:380
tvm::PrimExpr trunc_mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:328
tvm::PrimExpr greater_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:460
tvm::te::Tensor broadcast_to(const tvm::te::Tensor &t, const tvm::Array< tvm::PrimExpr > &output_shape, std::string name="T_broadcast_to", std::string tag=kBroadcast)
Creates an operation that broadcasts a tensor into a compatible shape according to numpy's rules.
Definition: broadcast.h:48
tvm::PrimExpr bitwise_and(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:155
tvm::PrimExpr divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:239
tvm::PrimExpr multiply(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:225
tvm::PrimExpr floor_divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:258
tvm::PrimExpr minimum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:354
tvm::PrimExpr logical_and(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:114
tvm::PrimExpr trunc_divide(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:277
tvm::PrimExpr right_shift(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:394
tvm::PrimExpr add(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:197
tvm::PrimExpr mod(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:290
tvm::PrimExpr less_equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:473
tvm::PrimExpr equal(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:434
tvm::PrimExpr bitwise_xor(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:183
tvm::PrimExpr greater(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:408
tvm::PrimExpr power(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:367
tvm::PrimExpr logical_xor(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:142
tvm::PrimExpr maximum(const tvm::PrimExpr &a, const tvm::PrimExpr &b)
Definition: broadcast.h:341
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
PrimExpr div(PrimExpr a, PrimExpr b, Span span=Span())
compute division in C semantics.
PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of truncdiv
PrimExpr pow(PrimExpr x, PrimExpr y, Span span=Span())
Calculate power(x, y)
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
External function interface to rocBLAS libraries.