24 #ifndef TVM_TOPI_NN_DILATE_H_
25 #define TVM_TOPI_NN_DILATE_H_
48 ICHECK_GT(args.size(), 0) <<
"all requires at least one argument";
51 for (
size_t i = 1; i < args.size(); ++i) {
71 std::string name =
"tensor", std::string tag =
kInjective) {
72 auto n = x->shape.size();
73 ICHECK_EQ(n, strides.size()) <<
"strides size (" << strides.size()
74 <<
") must match dimension of x (" << n <<
")";
76 Array<PrimExpr> out_shape;
78 for (
size_t i = 0; i < n; ++i) {
79 out_shape.push_back(analyzer.
Simplify((x->shape[i] - 1) * (strides[i] + 1)));
84 [&](
const Array<Var>& indices) {
85 Array<PrimExpr> not_zero;
86 Array<PrimExpr> index_tuple;
87 for (
size_t i = 0; i < n; ++i) {
88 if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
89 index_tuple.push_back(indices[i]);
91 index_tuple.push_back(
indexdiv(indices[i], strides[i]));
92 not_zero.push_back((
indexmod(indices[i], strides[i])) == 0);
95 if (not_zero.size() > 0) {
96 auto all_not_zero = all(not_zero);
97 return tvm::if_then_else(all_not_zero, x(index_tuple),
98 make_const(x->dtype, dilation_value));
100 return x(index_tuple);
Algebra expression simplifications.
Reference to PrimExprNode.
Definition: expr.h:129
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:636
PrimExpr Simplify(const PrimExpr &expr, int steps=2)
Simplify expr.
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Tensor expression language DSL.
Definition: extracted_task.h:33
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fcompute(axis).
PrimExpr all(Array< PrimExpr > args)
Create a new expression of the logical and of all conditions in the arguments.
Definition: dilate.h:47
Tensor dilate(const Tensor &x, Array< PrimExpr > strides, double dilation_value, std::string name="tensor", std::string tag=kInjective)
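Dilate data, filling every position inserted between the original elements with the given dilation value (this C++ overload takes dilation_value explicitly; it has no default).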
Definition: dilate.h:70
constexpr auto kInjective
Definition: tags.h:33
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span=Span())
compute floor(a / b) where a and b are non-negative.
PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span=Span())
compute the remainder of a / b (i.e. a - floor(a / b) * b) where a and b are non-negative.
Operation node can generate one or multiple Tensors.