24 #ifndef TVM_TE_TENSOR_H_
25 #define TVM_TE_TENSOR_H_
28 #include <tvm/ffi/reflection/registry.h>
33 #include <type_traits>
52 explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {}
53 explicit Operation(ffi::UnsafeInit tag) : ObjectRef(tag) {}
89 TVM_DLL ffi::String GetNameHint() const final;
91 static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
108 inline PrimExpr IndexTensor(ffi::Array<PrimExpr> indices,
bool support_negative_indices)
const;
125 inline size_t ndim()
const;
131 template <
typename... Args>
133 ffi::Array<PrimExpr> indices{std::forward<Args>(args)...};
134 return operator()(indices);
153 template <
typename... Args>
155 ffi::Array<PrimExpr> indices{std::forward<Args>(args)...};
156 return IndexWithNegativeIndices(indices);
179 : tensor_(tensor), indices_(indices) {}
186 std::vector<PrimExpr> other = indices_;
187 other.emplace_back(i);
188 return Slice(tensor_, other);
195 inline operator PrimExpr()
const {
return tensor_(indices_); }
199 std::vector<PrimExpr> indices_;
215 if (get() == other.get())
return true;
216 if (get() ==
nullptr || other.get() ==
nullptr)
return false;
217 if ((*this)->op.defined() || other->op.defined()) {
218 return (*this)->op == other->op && (*this)->value_index == other->value_index;
227 #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
228 inline PrimExpr operator Op(const Tensor::Slice& a) { return Op a.operator PrimExpr(); }
230 #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
231 template <typename T> \
232 inline PrimExpr operator Op(const Tensor::Slice& a, const T& b) { \
233 return a.operator PrimExpr() Op b; \
235 template <typename T> \
236 inline PrimExpr operator Op(const T& a, const Tensor::Slice& b) { \
237 return a Op b.operator PrimExpr(); \
239 inline PrimExpr operator Op(const Tensor::Slice& a, const Tensor::Slice& b) { \
240 return a.operator PrimExpr() Op b.operator PrimExpr(); \
268 std::size_t operator()(const ::tvm::te::Tensor& k)
const {
269 ::tvm::ObjectPtrHash hasher;
270 if (k.defined() && k->op.defined()) {
271 return hasher(k->op);
Reference to PrimExprNode.
Definition: expr.h:124
Runtime primitive data type.
Definition: data_type.h:47
Base class of all operation nodes.
Definition: operation.h:56
Operation that produces tensors.
Definition: tensor.h:48
Tensor output(size_t i) const
get the i-th output of the operation.
Operation(ObjectPtr< Object > n)
Definition: tensor.h:52
Operation()
default constructor
Definition: tensor.h:51
Operation(ffi::UnsafeInit tag)
Definition: tensor.h:53
Node to represent a tensor.
Definition: tensor.h:70
PrimExpr ToPrimExpr() const final
DataType GetDataType() const final
Get the data type of the result.
Definition: tensor.h:85
DataType dtype
data type in the content of the tensor
Definition: tensor.h:75
ffi::Array< PrimExpr > shape
The shape of the tensor.
Definition: tensor.h:73
static void RegisterReflection()
Operation op
the source operation, can be None
Definition: tensor.h:77
ffi::Array< PrimExpr > GetShape() const final
Get the shape of the result.
Definition: tensor.h:83
data structure to represent a slice that fixes first k coordinates. This is used to enable syntax sug...
Definition: tensor.h:175
Slice operator[](PrimExpr i)
get i-th slice from the current slice.
Definition: tensor.h:185
Slice(const Tensor &tensor, std::vector< PrimExpr > indices)
Definition: tensor.h:178
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
PrimExpr IndexWithNegativeIndices(ffi::Array< PrimExpr > indices) const
Take elements from the tensor with support for negative indices.
bool operator==(const Tensor &other) const
check if two tensors equals each other.
Definition: tensor.h:214
Tensor(ffi::Array< PrimExpr > shape, DataType dtype, Operation op, int value_index)
PrimExpr IndexWithNegativeIndices(Args &&... args) const
Take elements from the tensor with support for negative indices.
Definition: tensor.h:154
Slice operator[](PrimExpr i) const
get i-th slice from the current Tensor.
Definition: tensor.h:206
PrimExpr IndexWithNegativeIndices(ffi::Array< Var > indices) const
Take elements from the tensor with support for negative indices.
size_t ndim() const
Definition: tensor.h:212
PrimExpr operator()(ffi::Array< PrimExpr > indices) const
Take elements from the tensor.
bool operator!=(const Tensor &other) const
check if two tensors are different.
Definition: tensor.h:224
PrimExpr operator()(Args &&... args) const
Take elements from the tensor.
Definition: tensor.h:132
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, DataProducer, TensorNode)
PrimExpr operator()(ffi::Array< Var > indices) const
Take elements from the tensor.
Base node for data producers.
Definition: buffer.h:260
Managed reference to DataProducerNode.
Definition: buffer.h:286
PrimExpr operator==(const Tensor::Slice &a, const T &b)
Definition: tensor.h:248
PrimExpr operator!=(const Tensor::Slice &a, const T &b)
Definition: tensor.h:251
Definition: extracted_task.h:30
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1960
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)
Definition: tensor.h:227
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)
Definition: tensor.h:230
Common operators defined for Expr.