24 #ifndef TVM_TE_TENSOR_H_ 25 #define TVM_TE_TENSOR_H_ 32 #include <type_traits> 62 TVM_DLL
Tensor output(
size_t i)
const;
82 v->Visit(
"shape", &shape);
83 v->Visit(
"dtype", &dtype);
85 v->Visit(
"value_index", &value_index);
92 TVM_DLL
String GetNameHint()
const final;
94 static constexpr
const char* _type_key =
"Tensor";
119 inline bool operator==(
const Tensor& other)
const;
125 inline bool operator!=(
const Tensor& other)
const;
127 inline size_t ndim()
const;
133 template <
typename... Args>
136 return operator()(indices);
155 template <
typename... Args>
158 return IndexWithNegativeIndices(indices);
180 Slice(
const Tensor& tensor, std::vector<PrimExpr> indices)
181 : tensor_(tensor), indices_(indices) {}
188 std::vector<PrimExpr> other = indices_;
189 other.emplace_back(i);
190 return Slice(tensor_, other);
197 inline operator PrimExpr()
const {
return tensor_(indices_); }
200 const Tensor& tensor_;
201 std::vector<PrimExpr> indices_;
217 if (
get() == other.
get())
return true;
218 if (
get() ==
nullptr || other.
get() ==
nullptr)
return false;
219 if ((*this)->op.defined() || other->op.
defined()) {
220 return (*this)->op == other->op && (*this)->value_index == other->value_index;
229 #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ 230 inline PrimExpr operator Op(const Tensor::Slice& a) { return Op a.operator PrimExpr(); } 232 #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ 233 template <typename T> \ 234 inline PrimExpr operator Op(const Tensor::Slice& a, const T& b) { \ 235 return a.operator PrimExpr() Op b; \ 237 template <typename T> \ 238 inline PrimExpr operator Op(const T& a, const Tensor::Slice& b) { \ 239 return a Op b.operator PrimExpr(); \ 241 inline PrimExpr operator Op(const Tensor::Slice& a, const Tensor::Slice& b) { \ 242 return a.operator PrimExpr() Op b.operator PrimExpr(); \ 270 std::size_t operator()(const ::tvm::te::Tensor& k)
const {
272 if (k.defined() && k->op.defined()) {
273 return hasher(k->op);
280 #endif // TVM_TE_TENSOR_H_ Slice operator[](PrimExpr i) const
get i-th slice from the current Tensor.
Definition: tensor.h:208
Node to represent a tensor.
Definition: tensor.h:68
Operation(ObjectPtr< Object > n)
Definition: tensor.h:51
A custom smart pointer for Object.
Definition: object.h:358
bool operator!=(const Tensor &other) const
check if two tensors are different.
Definition: tensor.h:226
Base class of all operation nodes.
Definition: operation.h:56
bool operator==(const Tensor &other) const
check if two tensors equals each other.
Definition: tensor.h:216
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)
Definition: tensor.h:232
PrimExpr operator()(Args &&... args) const
Take elements from the tensor.
Definition: tensor.h:134
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Operation op
the source operation, can be None
Definition: tensor.h:75
Operation that produces tensors.
Definition: tensor.h:47
Definition: loop_state.h:456
PrimExpr operator!=(const Tensor::Slice &a, const T &b)
Definition: tensor.h:253
Array< PrimExpr > shape
The shape of the tensor.
Definition: tensor.h:71
base class of all object containers.
Definition: object.h:167
size_t ndim() const
Definition: tensor.h:214
Common operators defined for Expr.
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
bool defined() const
Definition: object.h:544
Runtime primitive data type.
Definition: data_type.h:41
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
ObjectRef hash functor.
Definition: object.h:624
DataType GetDataType() const final
Get the data type of the result.
Definition: tensor.h:90
Operation()
default constructor
Definition: tensor.h:50
Reference to string objects.
Definition: string.h:98
TensorNode()
constructor
Definition: tensor.h:79
Slice operator[](PrimExpr i)
get i-th slice from the current slice.
Definition: tensor.h:187
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1768
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
const Object * get() const
Definition: object.h:546
Base class of all object reference.
Definition: object.h:511
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
Managed reference to DataProducerNode.
Definition: buffer.h:293
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
Base node for data producers.
Definition: buffer.h:256
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)
Definition: tensor.h:229
PrimExpr operator==(const Tensor::Slice &a, const T &b)
Definition: tensor.h:250
data structure to represent a slice that fixes first k coordinates. This is used to enable syntax sug...
Definition: tensor.h:177
void VisitAttrs(AttrVisitor *v)
Definition: tensor.h:81
Array< PrimExpr > GetShape() const final
Get the shape of the result.
Definition: tensor.h:88
Definition: extracted_task.h:30
DataType dtype
data type in the content of the tensor
Definition: tensor.h:73
PrimExpr IndexWithNegativeIndices(Args &&... args) const
Take elements from the tensor with support for negative indices.
Definition: tensor.h:156
Reference to PrimExprNode.
Definition: expr.h:114
Slice(const Tensor &tensor, std::vector< PrimExpr > indices)
Definition: tensor.h:180