tvm
tensor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TE_TENSOR_H_
25 #define TVM_TE_TENSOR_H_
26 
27 #include <tvm/arith/bound.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/op.h>
30 
31 #include <string>
32 #include <type_traits>
33 #include <utility>
34 #include <vector>
35 
36 namespace tvm {
37 namespace te {
38 
39 using arith::IntSet;
40 using namespace tvm::tir;
41 
42 // internal node container for Operation
43 class OperationNode;
44 class Tensor;
45 
47 class Operation : public ObjectRef {
48  public:
50  Operation() {}
56  inline const OperationNode* operator->() const;
62  TVM_DLL Tensor output(size_t i) const;
65 };
66 
68 class TensorNode : public DataProducerNode {
69  public:
77  int value_index{0};
80 
82  v->Visit("shape", &shape);
83  v->Visit("dtype", &dtype);
84  v->Visit("op", &op);
85  v->Visit("value_index", &value_index);
86  }
87 
88  Array<PrimExpr> GetShape() const final { return shape; }
89 
90  DataType GetDataType() const final { return dtype; }
91 
92  TVM_DLL String GetNameHint() const final;
93 
94  static constexpr const char* _type_key = "Tensor";
96 };
97 
102 class Tensor : public DataProducer {
103  private:
110  inline PrimExpr IndexTensor(Array<PrimExpr> indices, bool support_negative_indices) const;
111 
112  public:
113  TVM_DLL Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);
119  inline bool operator==(const Tensor& other) const;
125  inline bool operator!=(const Tensor& other) const;
127  inline size_t ndim() const;
133  template <typename... Args>
134  inline PrimExpr operator()(Args&&... args) const {
135  Array<PrimExpr> indices{std::forward<Args>(args)...};
136  return operator()(indices);
137  }
143  TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
149  TVM_DLL PrimExpr operator()(Array<Var> indices) const;
155  template <typename... Args>
156  TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const {
157  Array<PrimExpr> indices{std::forward<Args>(args)...};
158  return IndexWithNegativeIndices(indices);
159  }
172 
177  class Slice {
178  public:
179  // construct via tensor and indices
180  Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
181  : tensor_(tensor), indices_(indices) {}
188  std::vector<PrimExpr> other = indices_;
189  other.emplace_back(i);
190  return Slice(tensor_, other);
191  }
197  inline operator PrimExpr() const { return tensor_(indices_); }
198 
199  private:
200  const Tensor& tensor_;
201  std::vector<PrimExpr> indices_;
202  };
208  inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); }
209 
211 };
212 
213 // Implementations of inline functions
214 inline size_t Tensor::ndim() const { return (*this)->shape.size(); }
215 
216 inline bool Tensor::operator==(const Tensor& other) const {
217  if (get() == other.get()) return true;
218  if (get() == nullptr || other.get() == nullptr) return false;
219  if ((*this)->op.defined() || other->op.defined()) {
220  return (*this)->op == other->op && (*this)->value_index == other->value_index;
221  } else {
222  return false;
223  }
224 }
225 
226 inline bool Tensor::operator!=(const Tensor& other) const { return !(*this == other); }
227 
/*!
 * \brief Declare a unary operator on Tensor::Slice. The slice is first
 *  converted to the PrimExpr it denotes, then the operator is applied.
 * \param Op The unary operator token, e.g. `-` or `!`.
 */
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)                    \
  inline PrimExpr operator Op(const Tensor::Slice& operand) { \
    return Op(operand.operator PrimExpr());                   \
  }

/*!
 * \brief Declare binary operator overloads involving Tensor::Slice. Every
 *  slice operand is converted to the PrimExpr it denotes before the operator
 *  is applied. Three overloads are emitted: slice/slice, slice/T and T/slice.
 *  The non-template slice/slice overload is the one chosen when both operands
 *  are slices.
 * \param Op The binary operator token, e.g. `+` or `==`.
 */
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)                                         \
  inline PrimExpr operator Op(const Tensor::Slice& lhs, const Tensor::Slice& rhs) { \
    return (lhs.operator PrimExpr()) Op (rhs.operator PrimExpr());                  \
  }                                                                                 \
  template <typename T>                                                             \
  inline PrimExpr operator Op(const Tensor::Slice& lhs, const T& rhs) {             \
    return (lhs.operator PrimExpr()) Op rhs;                                        \
  }                                                                                 \
  template <typename T>                                                             \
  inline PrimExpr operator Op(const T& lhs, const Tensor::Slice& rhs) {             \
    return lhs Op (rhs.operator PrimExpr());                                        \
  }
244 
260 
261 } // namespace te
262 } // namespace tvm
263 
264 namespace std {
265 template <>
266 struct hash<::tvm::te::Operation> : public ::tvm::ObjectPtrHash {};
267 
268 template <>
269 struct hash<::tvm::te::Tensor> {
270  std::size_t operator()(const ::tvm::te::Tensor& k) const {
271  ::tvm::ObjectPtrHash hasher;
272  if (k.defined() && k->op.defined()) {
273  return hasher(k->op);
274  } else {
275  return hasher(k);
276  }
277  }
278 };
279 } // namespace std
280 #endif // TVM_TE_TENSOR_H_
Bound deducers.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each field.
Definition: reflection.h:52
Reference to PrimExprNode.
Definition: expr.h:115
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
const Object * get() const
Definition: object.h:554
base class of all object containers.
Definition: object.h:171
Reference to string objects.
Definition: string.h:98
Base class of all operation nodes.
Definition: operation.h:56
Operation that produces tensors.
Definition: tensor.h:47
Tensor output(size_t i) const
get the i-th output of the operation.
Operation(ObjectPtr< Object > n)
Definition: tensor.h:51
Operation()
default constructor
Definition: tensor.h:50
Node to represent a tensor.
Definition: tensor.h:68
Array< PrimExpr > shape
The shape of the tensor.
Definition: tensor.h:71
TensorNode()
constructor
Definition: tensor.h:79
void VisitAttrs(AttrVisitor *v)
Definition: tensor.h:81
DataType GetDataType() const final
Get the data type of the result.
Definition: tensor.h:90
String GetNameHint() const final
Get the name hint of the data producer.
Array< PrimExpr > GetShape() const final
Get the shape of the result.
Definition: tensor.h:88
DataType dtype
data type in the content of the tensor
Definition: tensor.h:73
Operation op
the source operation, can be None
Definition: tensor.h:75
Data structure to represent a slice that fixes the first k coordinates. This is used to enable the syntax sugar Tensor[x][y][z] for taking an element.
Definition: tensor.h:177
Slice operator[](PrimExpr i)
get i-th slice from the current slice.
Definition: tensor.h:187
Slice(const Tensor &tensor, std::vector< PrimExpr > indices)
Definition: tensor.h:180
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
bool operator==(const Tensor &other) const
check if two tensors equal each other.
Definition: tensor.h:216
PrimExpr operator()(Array< Var > indices) const
Take elements from the tensor.
PrimExpr IndexWithNegativeIndices(Args &&... args) const
Take elements from the tensor with support for negative indices.
Definition: tensor.h:156
Slice operator[](PrimExpr i) const
get i-th slice from the current Tensor.
Definition: tensor.h:208
PrimExpr IndexWithNegativeIndices(Array< PrimExpr > indices) const
Take elements from the tensor with support for negative indices.
size_t ndim() const
Definition: tensor.h:214
PrimExpr IndexWithNegativeIndices(Array< Var > indices) const
Take elements from the tensor with support for negative indices.
bool operator!=(const Tensor &other) const
check if two tensors are different.
Definition: tensor.h:226
PrimExpr operator()(Array< PrimExpr > indices) const
Take elements from the tensor.
PrimExpr operator()(Args &&... args) const
Take elements from the tensor.
Definition: tensor.h:134
TVM_DEFINE_OBJECT_REF_METHODS(Tensor, DataProducer, TensorNode)
Tensor(Array< PrimExpr > shape, DataType dtype, Operation op, int value_index)
Base node for data producers.
Definition: buffer.h:276
Managed reference to DataProducerNode.
Definition: buffer.h:313
PrimExpr operator==(const Tensor::Slice &a, const T &b)
Definition: tensor.h:250
PrimExpr operator!=(const Tensor::Slice &a, const T &b)
Definition: tensor.h:253
Definition: extracted_task.h:30
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1913
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:702
ObjectRef hash functor.
Definition: object.h:655
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)
Definition: tensor.h:229
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)
Definition: tensor.h:232
TIR expressions.
Common operators defined for Expr.