tvm
tensor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TE_TENSOR_H_
25 #define TVM_TE_TENSOR_H_
26 
27 #include <tvm/arith/bound.h>
28 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/tir/expr.h>
30 #include <tvm/tir/op.h>
31 
32 #include <string>
33 #include <type_traits>
34 #include <utility>
35 #include <vector>
36 
37 namespace tvm {
38 namespace te {
39 
40 using arith::IntSet;
41 using namespace tvm::tir;
42 
43 // internal node container for Operation
44 class OperationNode;
45 class Tensor;
46 
48 class Operation : public ObjectRef {
49  public:
51  Operation() {}
52  explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {}
53  explicit Operation(ffi::UnsafeInit tag) : ObjectRef(tag) {}
58  inline const OperationNode* operator->() const;
64  TVM_DLL Tensor output(size_t i) const;
67 };
68 
70 class TensorNode : public DataProducerNode {
71  public:
73  ffi::Array<PrimExpr> shape;
79  int value_index{0};
80 
81  static void RegisterReflection();
82 
83  ffi::Array<PrimExpr> GetShape() const final { return shape; }
84 
85  DataType GetDataType() const final { return dtype; }
86 
87  TVM_DLL PrimExpr ToPrimExpr() const final;
88 
89  TVM_DLL ffi::String GetNameHint() const final;
90 
91  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
92 
93  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.Tensor", TensorNode, DataProducerNode);
94 };
95 
100 class Tensor : public DataProducer {
101  private:
108  inline PrimExpr IndexTensor(ffi::Array<PrimExpr> indices, bool support_negative_indices) const;
109 
110  public:
111  TVM_DLL Tensor(ffi::Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);
117  inline bool operator==(const Tensor& other) const;
123  inline bool operator!=(const Tensor& other) const;
125  inline size_t ndim() const;
131  template <typename... Args>
132  inline PrimExpr operator()(Args&&... args) const {
133  ffi::Array<PrimExpr> indices{std::forward<Args>(args)...};
134  return operator()(indices);
135  }
141  TVM_DLL PrimExpr operator()(ffi::Array<PrimExpr> indices) const;
147  TVM_DLL PrimExpr operator()(ffi::Array<Var> indices) const;
153  template <typename... Args>
154  TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const {
155  ffi::Array<PrimExpr> indices{std::forward<Args>(args)...};
156  return IndexWithNegativeIndices(indices);
157  }
163  TVM_DLL PrimExpr IndexWithNegativeIndices(ffi::Array<PrimExpr> indices) const;
169  TVM_DLL PrimExpr IndexWithNegativeIndices(ffi::Array<Var> indices) const;
170 
175  class Slice {
176  public:
177  // construct via tensor and indices
178  Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
179  : tensor_(tensor), indices_(indices) {}
186  std::vector<PrimExpr> other = indices_;
187  other.emplace_back(i);
188  return Slice(tensor_, other);
189  }
195  inline operator PrimExpr() const { return tensor_(indices_); }
196 
197  private:
198  const Tensor& tensor_;
199  std::vector<PrimExpr> indices_;
200  };
206  inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); }
207 
209 };
210 
211 // Implementations of inline functions
212 inline size_t Tensor::ndim() const { return (*this)->shape.size(); }
213 
214 inline bool Tensor::operator==(const Tensor& other) const {
215  if (get() == other.get()) return true;
216  if (get() == nullptr || other.get() == nullptr) return false;
217  if ((*this)->op.defined() || other->op.defined()) {
218  return (*this)->op == other->op && (*this)->value_index == other->value_index;
219  } else {
220  return false;
221  }
222 }
223 
224 inline bool Tensor::operator!=(const Tensor& other) const { return !(*this == other); }
225 
// Macros that let operators act on Tensor::Slice by first converting each
// slice operand to a PrimExpr (a tensor read).
// Unary form: applies Op to the converted slice.
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)                    \
  inline PrimExpr operator Op(const Tensor::Slice& operand) { \
    return Op operand.operator PrimExpr();                    \
  }
229 
// Binary form: defines Slice-Slice, Slice-T and T-Slice overloads of Op,
// converting every Slice operand to a PrimExpr before applying Op.
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)                                         \
  inline PrimExpr operator Op(const Tensor::Slice& lhs, const Tensor::Slice& rhs) { \
    return lhs.operator PrimExpr() Op rhs.operator PrimExpr();                      \
  }                                                                                 \
  template <typename T>                                                             \
  inline PrimExpr operator Op(const Tensor::Slice& lhs, const T& rhs) {             \
    return lhs.operator PrimExpr() Op rhs;                                          \
  }                                                                                 \
  template <typename T>                                                             \
  inline PrimExpr operator Op(const T& lhs, const Tensor::Slice& rhs) {             \
    return lhs Op rhs.operator PrimExpr();                                          \
  }
242 
258 
259 } // namespace te
260 } // namespace tvm
261 
262 namespace std {
263 template <>
264 struct hash<::tvm::te::Operation> : public ::tvm::ObjectPtrHash {};
265 
266 template <>
267 struct hash<::tvm::te::Tensor> {
268  std::size_t operator()(const ::tvm::te::Tensor& k) const {
269  ::tvm::ObjectPtrHash hasher;
270  if (k.defined() && k->op.defined()) {
271  return hasher(k->op);
272  } else {
273  return hasher(k);
274  }
275  }
276 };
277 } // namespace std
278 #endif // TVM_TE_TENSOR_H_
Bound deducers.
Reference to PrimExprNode.
Definition: expr.h:124
Runtime primitive data type.
Definition: data_type.h:47
Base class of all operation nodes.
Definition: operation.h:56
Operation that produces tensors.
Definition: tensor.h:48
Tensor output(size_t i) const
get the i-th output of the operation.
Operation(ObjectPtr< Object > n)
Definition: tensor.h:52
Operation()
default constructor
Definition: tensor.h:51
Operation(ffi::UnsafeInit tag)
Definition: tensor.h:53
Node to represent a tensor.
Definition: tensor.h:70
PrimExpr ToPrimExpr() const final
DataType GetDataType() const final
Get the data type of the result.
Definition: tensor.h:85
DataType dtype
data type in the content of the tensor
Definition: tensor.h:75
ffi::Array< PrimExpr > shape
The shape of the tensor.
Definition: tensor.h:73
static void RegisterReflection()
Operation op
the source operation, can be None
Definition: tensor.h:77
ffi::Array< PrimExpr > GetShape() const final
Get the shape of the result.
Definition: tensor.h:83
data structure to represent a slice that fixes the first k coordinates. This is used to enable syntax sugar such as Tensor[x][y][z] for reading an element.
Definition: tensor.h:175
Slice operator[](PrimExpr i)
get the i-th slice from the current slice.
Definition: tensor.h:185
Slice(const Tensor &tensor, std::vector< PrimExpr > indices)
Definition: tensor.h:178
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
PrimExpr IndexWithNegativeIndices(ffi::Array< PrimExpr > indices) const
Take elements from the tensor with support for negative indices.
bool operator==(const Tensor &other) const
check if two tensors equal each other.
Definition: tensor.h:214
Tensor(ffi::Array< PrimExpr > shape, DataType dtype, Operation op, int value_index)
PrimExpr IndexWithNegativeIndices(Args &&... args) const
Take elements from the tensor with support for negative indices.
Definition: tensor.h:154
Slice operator[](PrimExpr i) const
get the i-th slice from the current Tensor.
Definition: tensor.h:206
PrimExpr IndexWithNegativeIndices(ffi::Array< Var > indices) const
Take elements from the tensor with support for negative indices.
size_t ndim() const
Definition: tensor.h:212
PrimExpr operator()(ffi::Array< PrimExpr > indices) const
Take elements from the tensor.
bool operator!=(const Tensor &other) const
check if two tensors are different.
Definition: tensor.h:224
PrimExpr operator()(Args &&... args) const
Take elements from the tensor.
Definition: tensor.h:132
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, DataProducer, TensorNode)
PrimExpr operator()(ffi::Array< Var > indices) const
Take elements from the tensor.
Base node for data producers.
Definition: buffer.h:260
Managed reference to DataProducerNode.
Definition: buffer.h:286
PrimExpr operator==(const Tensor::Slice &a, const T &b)
Definition: tensor.h:248
PrimExpr operator!=(const Tensor::Slice &a, const T &b)
Definition: tensor.h:251
Definition: extracted_task.h:30
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1960
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)
Definition: tensor.h:227
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)
Definition: tensor.h:230
TIR expressions.
Common operators defined for Expr.