/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file detail/broadcast.h
 * \brief Detail broadcast.
 */
#ifndef TVM_TOPI_DETAIL_BROADCAST_H_
#define TVM_TOPI_DETAIL_BROADCAST_H_

#include <tvm/te/operation.h>
#include <tvm/topi/detail/constant_utils.h>

#include <algorithm>
#include <deque>
#include <string>

namespace tvm {
namespace topi {
namespace detail {

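/*!
 * \brief Holds the deduced common broadcast shape along with the iteration
 * variables that index it: all_vars has one variable per output dimension,
 * while vars1 and vars2 hold the subset of those variables that each input
 * actually iterates over (a broadcast dimension of extent 1 contributes no
 * variable for that input). Populated by BroadcastShape below.
 */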
struct BroadcastHelper {
  std::deque<tvm::PrimExpr> common_shape;
  std::deque<tvm::tir::Var> all_vars;
  std::deque<tvm::tir::Var> vars1;
  std::deque<tvm::tir::Var> vars2;
};

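/*!
 * \brief Deduce the common scalar type of two operands: same type code, the
 * wider bit width. For example, CommonType(DataType::Int(32),
 * DataType::Int(64)) yields DataType::Int(64); mixing type codes
 * (e.g. int with float) fails the ICHECK.
 */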
static inline DataType CommonType(DataType type1, DataType type2) {
  ICHECK(type1.is_scalar() && type2.is_scalar());
  ICHECK(type1.code() == type2.code());
  return DataType(type1.code(), std::max(type1.bits(), type2.bits()), /*lanes=*/1);
}

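/*!
 * \brief Deduce the common shape of two operands under NumPy broadcasting
 * rules, walking the shapes right to left: equal extents match, an extent of
 * 1 broadcasts against the other side, and two fully dynamic extents fall
 * back to max(e1, e2). As an illustrative example, shapes (2, 3, 4) and
 * (3, 1) yield common_shape (2, 3, 4), with vars2 holding only the variable
 * of the middle output dimension.
 */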
inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
                                      const tvm::Array<tvm::PrimExpr>& shape2) {
  BroadcastHelper bh;
  int s1_size = shape1.size();
  int s2_size = shape2.size();
  tvm::PrimExpr one(1);
  int i;

  auto cast_if_needed = [](DataType to_type, PrimExpr expr) {
    return to_type != expr.dtype() ? cast(to_type, expr) : expr;
  };

  for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
    // TODO(@icemelon9): Need to revisit this part
    const IntImmNode* static_size1 = shape1[s1_size - i].as<IntImmNode>();
    const IntImmNode* static_size2 = shape2[s2_size - i].as<IntImmNode>();
    DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype());

    bh.all_vars.push_front(tvm::tir::Var("dim", common_type));
    if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
      bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
      bh.vars1.push_front(bh.all_vars[0]);
      bh.vars2.push_front(bh.all_vars[0]);
    } else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) {
      ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i]));
      bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
      bh.vars2.push_front(bh.all_vars[0]);
    } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
      bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
      bh.vars1.push_front(bh.all_vars[0]);
    } else if (!static_size1 && !static_size2) {
      bh.common_shape.push_front(
          cast_if_needed(common_type, max(shape1[s1_size - i], shape2[s2_size - i])));
      bh.vars1.push_front(bh.all_vars[0]);
      bh.vars2.push_front(bh.all_vars[0]);
    } else if (!static_size1) {
      bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
      bh.vars2.push_front(bh.all_vars[0]);
      bh.vars1.push_front(bh.all_vars[0]);
    } else if (!static_size2) {
      bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
      bh.vars1.push_front(bh.all_vars[0]);
      bh.vars2.push_front(bh.all_vars[0]);
    } else {
      ICHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
                    << shape2[s2_size - i]
                    << " in: " << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end()) << " and "
                    << tvm::Array<tvm::PrimExpr>(shape2.begin(), shape2.end());
    }
  }
  // Remaining dimensions on whichever of shape1 or shape2 is longer can always be completed
  auto max_size = std::max(s1_size, s2_size);
  auto& shape = (s1_size > s2_size) ? shape1 : shape2;
  auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
  for (; i <= max_size; ++i) {
    bh.all_vars.push_front(tvm::tir::Var("v", shape[max_size - i].dtype()));
    bh.common_shape.push_front(shape[max_size - i]);
    vars.push_front(bh.all_vars[0]);
  }
  return bh;
}

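/*!
 * \brief Map the output (broadcast) iteration variables ovars back to an
 * index into input tensor T. Output dimensions in which T participates keep
 * their variable; dimensions T broadcasts along (extent 1) become a constant
 * 0; leading dimensions T lacks are dropped. For example, with output vars
 * (i, j, k) over shape (2, 3, 4) and T of shape (3, 1) that only iterates j,
 * the resulting index is (j, 0).
 */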
inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
    const tvm::Array<tvm::tir::Var>& ovars, const tvm::te::Tensor& T,
    const std::deque<tvm::tir::Var>& my_vars, const std::deque<tvm::tir::Var>& all_vars) {
  tvm::Array<tvm::PrimExpr> ivars;
  ICHECK_EQ(ovars.size(), all_vars.size());
  // N^2, could use a map but NBD.
  size_t expected_dims = T->shape.size();
  for (size_t i = 0; i < ovars.size(); ++i) {
    bool found = false;
    for (size_t j = 0; j < my_vars.size(); ++j) {
      if (all_vars[i].same_as(my_vars[j])) {
        ivars.push_back(ovars[i]);
        found = true;
        break;
      }
    }
    // Only inject 0 here if we have not yet reached the dimension of I
    // (i.e. this must be a 1)
    if (!found && (ovars.size() - i) <= expected_dims) {
      ivars.push_back(tvm::tir::make_zero(ovars[i].dtype()));
    }
  }
  ICHECK(expected_dims == ivars.size());
  return ivars;
}

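/*!
 * \brief Compute A op B elementwise with NumPy-style broadcasting of the two
 * input shapes.
 *
 * A minimal usage sketch (the tensor names below are illustrative, not part
 * of this header):
 *
 * \code
 *   tvm::te::Tensor A = tvm::te::placeholder({2, 3}, DataType::Float(32), "A");
 *   tvm::te::Tensor B = tvm::te::placeholder({3}, DataType::Float(32), "B");
 *   tvm::te::Tensor C = WithBroadcast(
 *       [](PrimExpr a, PrimExpr b) { return a + b; }, A, B);
 *   // C has shape {2, 3} and C(i, j) == A(i, j) + B(j).
 * \endcode
 */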
template <typename FBinaryExpr>
inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A,
                                     const tvm::te::Tensor& B, const std::string& name = "tensor",
                                     const std::string& tag = "") {
  auto bh = BroadcastShape(A->shape, B->shape);
  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
    return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
              B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
  };
  return tvm::te::compute(tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
                          l, name, tag);
}

}  // namespace detail
}  // namespace topi
}  // namespace tvm

#endif  // TVM_TOPI_DETAIL_BROADCAST_H_