tvm
dilate.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TOPI_NN_DILATE_H_
25 #define TVM_TOPI_NN_DILATE_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/te/operation.h>
29 #include <tvm/topi/tags.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 namespace topi {
35 namespace nn {
36 
37 using namespace tvm::te;
38 
48  ICHECK_GT(args.size(), 0) << "all requires at least one argument";
49 
50  PrimExpr ret = args[0];
51  for (size_t i = 1; i < args.size(); ++i) {
52  ret = ret && args[i];
53  }
54  return ret;
55 }
56 
70 inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, double dilation_value,
71  std::string name = "tensor", std::string tag = kInjective) {
72  auto n = x->shape.size();
73  ICHECK_EQ(n, strides.size()) << "strides size (" << strides.size()
74  << ") must match dimension of x (" << n << ")";
75 
76  Array<PrimExpr> out_shape;
77  arith::Analyzer analyzer;
78  for (size_t i = 0; i < n; ++i) {
79  out_shape.push_back(analyzer.Simplify((x->shape[i] - 1) * (strides[i] + 1)));
80  }
81 
82  return tvm::te::compute(
83  out_shape,
84  [&](const Array<Var>& indices) {
85  Array<PrimExpr> not_zero;
86  Array<PrimExpr> index_tuple;
87  for (size_t i = 0; i < n; ++i) {
88  if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
89  index_tuple.push_back(indices[i]);
90  } else {
91  index_tuple.push_back(indexdiv(indices[i], strides[i]));
92  not_zero.push_back((indexmod(indices[i], strides[i])) == 0);
93  }
94  }
95  if (not_zero.size() > 0) {
96  auto all_not_zero = all(not_zero);
97  return tvm::if_then_else(all_not_zero, x(index_tuple),
98  make_const(x->dtype, dilation_value));
99  }
100  return x(index_tuple);
101  },
102  name, tag);
103 }
104 
105 } // namespace nn
106 } // namespace topi
107 } // namespace tvm
108 #endif // TVM_TOPI_NN_DILATE_H_
Algebra expression simplifications.
Reference to PrimExprNode.
Definition: expr.h:115
Analyzer that contains bunch of sub-analyzers.
Definition: analyzer.h:629
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
size_t size() const
Definition: array.h:420
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
PrimExpr all(Array< PrimExpr > args)
Create a new expression of the logical and of all conditions in the arguments.
Definition: dilate.h:47
Tensor dilate(const Tensor &x, Array< PrimExpr > strides, double dilation_value, std::string name="tensor", std::string tag=kInjective)
Dilate data with the given dilation value.
Definition: dilate.h:70
constexpr auto kInjective
Definition: tags.h:33
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of floor division, a - floor(a / b) * b, where a and b are non-negative.
Operation node can generate one or multiple Tensors.
External function interface to rocBLAS libraries.