24 #ifndef TVM_TE_OPERATION_H_
25 #define TVM_TE_OPERATION_H_
28 #include <tvm/ffi/reflection/registry.h>
35 #include <unordered_map>
50 std::vector<std::vector<IntSet>>
data;
88 refl::ObjectDef<OperationNode>()
94 static constexpr
const char* _type_key =
"te.Operation";
116 refl::ObjectDef<PlaceholderOpNode>()
121 static constexpr
const char*
_type_key =
"te.PlaceholderOp";
152 refl::ObjectDef<BaseComputeOpNode>()
157 static constexpr
const char* _type_key =
"te.BaseComputeOp";
174 Array<
Tensor> InputTensors() const final;
176 static
void RegisterReflection() {
181 static constexpr
const char* _type_key =
"te.ComputeOp";
192 TVM_DLL
ComputeOp(std::string name, std::string tag, Map<String, ffi::Any> attrs,
193 Array<IterVar> axis, Array<PrimExpr> body);
237 refl::ObjectDef<ScanOpNode>()
257 TVM_DLL
ScanOp(std::string name, std::string tag, Optional<Map<String, ffi::Any>> attrs,
258 IterVar axis, Array<Tensor> init, Array<Tensor> update,
259 Array<Tensor> state_placeholder, Array<Tensor> input);
288 refl::ObjectDef<ExternOpNode>()
306 TVM_DLL
ExternOp(std::string name, std::string tag, Map<String, ffi::Any> attrs,
307 Array<Tensor> inputs, Array<Buffer> input_placeholders,
308 Array<Buffer> output_placeholders,
Stmt body);
349 std::string name =
"placeholder");
361 std::string tag =
"", Map<String, ffi::Any> attrs = {});
373 std::string name =
"tensor", std::string tag =
"",
374 Map<String, ffi::Any> attrs = {});
388 TVM_DLL Array<Tensor>
scan(Array<Tensor> init, Array<Tensor> update,
389 Array<Tensor> state_placeholder, Array<Tensor> inputs = Array<Tensor>(),
390 std::string name =
"scan", std::string tag =
"",
391 Map<String, ffi::Any> attrs = {});
395 std::string name =
"tensor", std::string tag =
"",
396 Map<String, ffi::Any> attrs = {}) {
397 FCompute fc = [f](
const Array<Var>& i) {
return f(i[0]); };
401 std::string name =
"tensor", std::string tag =
"",
402 Map<String, ffi::Any> attrs = {}) {
403 FCompute fc = [f](
const Array<Var>& i) {
return f(i[0], i[1]); };
407 std::string name =
"tensor", std::string tag =
"",
408 Map<String, ffi::Any> attrs = {}) {
409 FCompute fc = [f](
const Array<Var>& i) {
return f(i[0], i[1], i[2]); };
413 std::string name =
"tensor", std::string tag =
"",
414 Map<String, ffi::Any> attrs = {}) {
415 FCompute fc = [f](
const Array<Var>& i) {
return f(i[0], i[1], i[2], i[3]); };
Algebra expression simplifications.
Symbolic n-dimensional array, to represent a memory buffer.
Reference to PrimExprNode.
Definition: expr.h:129
Range container
Definition: expr.h:698
Runtime primitive data type.
Definition: data_type.h:47
static DataType Float(int bits, int lanes=1)
Construct an float type.
Definition: data_type.h:291
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:274
A compute op that computes a tensor on a certain domain. This is the base class for ComputeOp (operating...
Definition: operation.h:141
Array< IterVar > axis
IterVar on each axis.
Definition: operation.h:144
static void RegisterReflection()
Definition: operation.h:150
Array< PrimExpr > output_shape(size_t idx) const final
Get shape of i-th output tensor.
Array< IterVar > reduce_axis
IterVar on each reduction axis, if the body is a Reduce.
Definition: operation.h:146
TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode)
A compute op that computes a tensor on a certain domain.
Definition: operation.h:165
Array< PrimExpr > body
the compute expression
Definition: operation.h:168
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode)
ComputeOpNode()
constructor
Definition: operation.h:170
int num_outputs() const final
Managed reference to ComputeOpNode.
Definition: operation.h:190
TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode)
ComputeOp(std::string name, std::string tag, Map< String, ffi::Any > attrs, Array< IterVar > axis, Array< PrimExpr > body)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode)
External computation that cannot be split.
Definition: operation.h:267
Array< Buffer > output_placeholders
Symbolic placeholder representation of outputs.
Definition: operation.h:274
Array< Tensor > inputs
The input tensors.
Definition: operation.h:270
int num_outputs() const final
TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode)
static void RegisterReflection()
Definition: operation.h:286
Stmt body
the statement that generates the computation.
Definition: operation.h:276
Array< PrimExpr > output_shape(size_t i) const final
Get shape of i-th output tensor.
ExternOpNode()
constructor
Definition: operation.h:279
static constexpr const char * _type_key
Definition: operation.h:295
Array< Buffer > input_placeholders
Symbolic placeholder representation of inputs.
Definition: operation.h:272
Array< Tensor > InputTensors() const final
List all the input Tensors.
DataType output_dtype(size_t i) const final
Get data type of i-th output tensor.
Managed reference to ExternOpNode.
Definition: operation.h:304
ExternOp(std::string name, std::string tag, Map< String, ffi::Any > attrs, Array< Tensor > inputs, Array< Buffer > input_placeholders, Array< Buffer > output_placeholders, Stmt body)
TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode)
Base class of all operation nodes.
Definition: operation.h:56
virtual ~OperationNode()
Definition: operation.h:65
TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object)
static void RegisterReflection()
Definition: operation.h:86
virtual DataType output_dtype(size_t i) const =0
Get data type of i-th output tensor.
Map< String, ffi::Any > attrs
additional attributes of the operation
Definition: operation.h:63
virtual Array< PrimExpr > output_shape(size_t i) const =0
Get shape of i-th output tensor.
virtual int num_outputs() const =0
virtual Array< Tensor > InputTensors() const =0
List all the input Tensors.
std::string name
optional name of the operation
Definition: operation.h:59
std::string tag
optional tag of the operation
Definition: operation.h:61
Operation that produces tensors.
Definition: tensor.h:48
const OperationNode * operator->() const
access the internal node container
Definition: operation.h:420
A placeholder op represents an input placeholder.
Definition: operation.h:102
Array< PrimExpr > output_shape(size_t i) const final
Get shape of i-th output tensor.
Array< PrimExpr > shape
The shape of the input.
Definition: operation.h:105
TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode)
DataType dtype
The data type of the input.
Definition: operation.h:107
static void RegisterReflection()
Definition: operation.h:114
static constexpr const char * _type_key
Definition: operation.h:121
Array< Tensor > InputTensors() const final
List all the input Tensors.
int num_outputs() const final
DataType output_dtype(size_t i) const final
Get data type of i-th output tensor.
Managed reference to PlaceholderOpNode.
Definition: operation.h:130
TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode)
PlaceholderOp(std::string name, Array< PrimExpr > shape, DataType dtype)
Symbolic scan.
Definition: operation.h:202
ScanOpNode()
constructor
Definition: operation.h:228
int num_outputs() const final
TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode)
DataType output_dtype(size_t i) const final
Get data type of i-th output tensor.
Array< Tensor > state_placeholder
The placeholders that are referred to as states in update.
Definition: operation.h:211
Array< Tensor > init
the initialization tensors
Definition: operation.h:207
static void RegisterReflection()
Definition: operation.h:235
IterVar scan_axis
IterVar to scan over.
Definition: operation.h:205
Array< IterVar > spatial_axis_
Spatial axes that indicate the spatial dimension of each output. They correspond to flattened spatial axis...
Definition: operation.h:226
Array< Tensor > inputs
The inputs to the scan; these are optionally provided, but they can be helpful to provide hints to spe...
Definition: operation.h:216
Array< Tensor > InputTensors() const final
List all the input Tensors.
Array< Tensor > update
the update function represented by tensor
Definition: operation.h:209
static constexpr const char * _type_key
Definition: operation.h:246
Array< PrimExpr > output_shape(size_t i) const final
Get shape of i-th output tensor.
Managed reference to ScanOpNode.
Definition: operation.h:255
ScanOp(std::string name, std::string tag, Optional< Map< String, ffi::Any >> attrs, IterVar axis, Array< Tensor > init, Array< Tensor > update, Array< Tensor > state_placeholder, Array< Tensor > input)
TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode)
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:298
Container of all statements.
Definition: stmt.h:64
a named variable in TIR
Definition: var.h:78
Definition: repr_printer.h:91
Tensor placeholder(Array< PrimExpr > shape, DataType dtype=DataType::Float(32), std::string name="placeholder")
create a place holder tensor.
Array< Tensor > scan(Array< Tensor > init, Array< Tensor > update, Array< Tensor > state_placeholder, Array< Tensor > inputs=Array< Tensor >(), std::string name="scan", std::string tag="", Map< String, ffi::Any > attrs={})
Construct new tensors by scan.
std::function< PrimExpr(const Array< Var > &i)> FCompute
The compute function to specify the input source of a Tensor.
Definition: operation.h:337
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
std::function< Array< PrimExpr >(const Array< Var > &i)> FBatchCompute
The compute function to specify the inputs source of Tensors.
Definition: operation.h:340
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1945
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Temporary data structure to store union of bounds of each axis of Tensor.
Definition: operation.h:46
TensorDom(int ndim)
Definition: operation.h:48
std::vector< std::vector< IntSet > > data
The domain data.
Definition: operation.h:50
Common operators defined for Expr.