tvm
broadcast.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_DETAIL_BROADCAST_H_
25 #define TVM_TOPI_DETAIL_BROADCAST_H_
26 
#include <tvm/te/operation.h>
#include <tvm/topi/detail/constant_utils.h>

#include <algorithm>
#include <deque>
#include <string>
33 
34 namespace tvm {
35 namespace topi {
36 namespace detail {
37 
/*!
 * \brief Result bundle produced by BroadcastShape.
 *
 * Deques are used because BroadcastShape builds the shape from the trailing
 * (innermost) dimension outward via push_front.
 */
struct BroadcastHelper {
  // The broadcast (output) shape, one PrimExpr per output dimension.
  std::deque<tvm::PrimExpr> common_shape;
  // One fresh iteration Var per output dimension, aligned with common_shape.
  std::deque<tvm::tir::Var> all_vars;
  // Subset of all_vars used to index the first input; a dimension broadcast
  // from size 1 contributes no entry here.
  std::deque<tvm::tir::Var> vars1;
  // Same as vars1, but for the second input.
  std::deque<tvm::tir::Var> vars2;
};
44 
/*!
 * \brief Compute the NumPy-style broadcast shape of two input shapes.
 *
 * Dimensions are matched from the trailing end. For each pair of aligned
 * dimensions: equal sizes pass through; a size-1 dimension broadcasts to the
 * other side's size; dynamic (non-IntImm) sizes are resolved optimistically
 * (see branch comments). Leading dimensions present in only one shape are
 * copied into the output unchanged.
 *
 * \param shape1 Shape of the first operand.
 * \param shape2 Shape of the second operand.
 * \return A BroadcastHelper holding the output shape, the per-dimension
 *         iteration vars, and the per-input subsets of those vars.
 */
inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
                                      const tvm::Array<tvm::PrimExpr>& shape2) {
  BroadcastHelper bh;
  int s1_size = shape1.size();
  int s2_size = shape2.size();
  tvm::PrimExpr one(1);
  int i;
  // Walk the overlapping trailing dimensions, innermost first; every
  // push_front keeps the deques in output-dimension order.
  for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
    // TODO(@icemelon9): Need to revisit this part
    const IntImmNode* static_size1 = shape1[s1_size - i].as<IntImmNode>();
    const IntImmNode* static_size2 = shape2[s2_size - i].as<IntImmNode>();
    bh.all_vars.push_front(tvm::tir::Var());
    if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
      // Same extent on both sides: both inputs are indexed by this var.
      bh.common_shape.push_front(shape1[s1_size - i]);
      bh.vars1.push_front(bh.all_vars[0]);
      bh.vars2.push_front(bh.all_vars[0]);
    } else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) {
      // shape1 broadcasts along this dim: it gets no var here (index 0 is
      // injected later by InputIndexFromBroadcast).
      ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i]));
      bh.common_shape.push_front(shape2[s2_size - i]);
      bh.vars2.push_front(bh.all_vars[0]);
    } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
      // Mirror case: shape2 broadcasts along this dim.
      bh.common_shape.push_front(shape1[s1_size - i]);
      bh.vars1.push_front(bh.all_vars[0]);
    } else if (!static_size1 && !static_size2) {
      // Both extents are dynamic: take the symbolic max and assume both
      // inputs can be indexed by the shared var.
      bh.common_shape.push_front(max(shape1[s1_size - i], shape2[s2_size - i]));
      bh.vars1.push_front(bh.all_vars[0]);
      bh.vars2.push_front(bh.all_vars[0]);
    } else if (!static_size1) {
      // Only shape1 is dynamic: optimistically assume it equals shape2's
      // static extent.
      bh.common_shape.push_front(shape2[s2_size - i]);
      bh.vars2.push_front(bh.all_vars[0]);
      bh.vars1.push_front(bh.all_vars[0]);
    } else if (!static_size2) {
      // Only shape2 is dynamic: mirror of the previous case.
      bh.common_shape.push_front(shape1[s1_size - i]);
      bh.vars1.push_front(bh.all_vars[0]);
      bh.vars2.push_front(bh.all_vars[0]);
    } else {
      // Two distinct static extents, neither of which is 1: not broadcastable.
      ICHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
                    << shape2[s2_size - i]
                    << " in: " << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end()) << " and "
                    << tvm::Array<tvm::PrimExpr>(shape2.begin(), shape2.end());
    }
  }
  // Remaining dimensions whether on shape1 or shape2 can always be completed
  auto max_size = std::max(s1_size, s2_size);
  auto& shape = (s1_size > s2_size) ? shape1 : shape2;
  auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
  for (; i <= max_size; ++i) {
    bh.all_vars.push_front(tvm::tir::Var());
    bh.common_shape.push_front(shape[max_size - i]);
    vars.push_front(bh.all_vars[0]);
  }
  return bh;
}
98 
99 inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
100  const tvm::Array<tvm::tir::Var>& ovars, const tvm::te::Tensor& T,
101  const std::deque<tvm::tir::Var>& my_vars, const std::deque<tvm::tir::Var>& all_vars) {
103  ICHECK_EQ(ovars.size(), all_vars.size());
104  // N^2, could use a map but NBD.
105  size_t expected_dims = T->shape.size();
106  for (size_t i = 0; i < ovars.size(); ++i) {
107  bool found = false;
108  for (size_t j = 0; j < my_vars.size(); ++j) {
109  if (all_vars[i].same_as(my_vars[j])) {
110  ivars.push_back(ovars[i]);
111  found = true;
112  break;
113  }
114  }
115  // Only inject 0 here if we have not yet reached the dimension of I
116  // (i.e. this must be a 1)
117  if (!found && (ovars.size() - i) <= expected_dims) {
118  ivars.push_back(tvm::tir::make_zero(ovars[i].dtype()));
119  }
120  }
121  ICHECK(expected_dims == ivars.size());
122  return ivars;
123 }
124 
125 template <typename FBinaryExpr>
126 inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A,
127  const tvm::te::Tensor& B, const std::string& name = "tensor",
128  const std::string& tag = "") {
129  auto bh = BroadcastShape(A->shape, B->shape);
130  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
131  return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
132  B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
133  };
134  return tvm::te::compute(tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
135  l, name, tag);
136 }
137 
138 } // namespace detail
139 } // namespace topi
140 } // namespace tvm
141 
142 #endif // TVM_TOPI_DETAIL_BROADCAST_H_
Tensor max(const Tensor &data, const Array< Integer > &axis, bool keepdims=false, bool atleast1d=false)
Creates an operation that finds the maximum of elements over a given axis.
Definition: reduction.h:429
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
a named variable in TIR
Definition: var.h:88
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:436
Utility functions for handling constants in TVM expressions.
size_t size() const
Definition: array.h:399
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
tvm::IntImmNode IntImmNode
Definition: expr.h:49
PrimExpr make_zero(DataType t, Span span=Span())
Make a const zero expr.
Definition: op.h:1138
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1758
iterator end() const
Definition: array.h:369
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
iterator begin() const
Definition: array.h:366
Operation node can generate one or multiple Tensors.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Reference to PrimExprNode.
Definition: expr.h:112
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865