24 #ifndef TVM_TE_OPERATION_H_ 25 #define TVM_TE_OPERATION_H_ 35 #include <unordered_map> 50 std::vector<std::vector<IntSet>>
data;
67 virtual int num_outputs()
const = 0;
78 virtual DataType output_dtype(
size_t i)
const = 0;
98 const std::unordered_map<Tensor, Tensor>& rmap)
const = 0;
109 const std::unordered_map<const VarNode*, IntSet>& dom_map,
110 std::unordered_map<Tensor, TensorDom>* out_dom_map)
const = 0;
119 virtual void GatherBound(
const Operation&
self,
120 const std::unordered_map<Tensor, TensorDom>& tensor_dom,
121 std::unordered_map<IterVar, Range>* out_dom_map)
const = 0;
131 virtual Stmt BuildRealize(
const Stage& stage,
132 const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body,
133 String storage_scope =
"")
const = 0;
141 virtual Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
142 bool debug_keep_trivial_loop)
const = 0;
144 static constexpr
const char* _type_key =
"Operation";
159 int num_outputs() const final;
161 DataType output_dtype(
size_t i) const final;
165 const
std::unordered_map<
Tensor, Tensor>& rmap) const final;
166 void PropBoundToInputs(const
Operation& self, arith::Analyzer* analyzer,
168 std::unordered_map<Tensor,
TensorDom>* out_dom_map) const final;
171 Stmt BuildRealize(const
Stage& stage, const
std::unordered_map<IterVar,
Range>& realize_map,
172 const
Stmt& body,
String storage_scope = "") const final;
173 Stmt BuildProvide(const
Stage& stage, const
std::unordered_map<IterVar,
Range>& dom_map,
174 bool debug_keep_trivial_loop) const final;
177 v->Visit(
"name", &name);
178 v->Visit(
"tag", &tag);
179 v->Visit(
"attrs", &attrs);
180 v->Visit(
"shape", &shape);
181 v->Visit(
"dtype", &dtype);
184 static constexpr
const char* _type_key =
"PlaceholderOp";
215 Stmt BuildRealize(const
Stage& stage, const
std::unordered_map<IterVar,
Range>& realize_map,
216 const
Stmt& body,
String storage_scope = "") const final;
217 virtual
size_t num_schedulable_dims() const = 0;
219 static constexpr const
char* _type_key = "BaseComputeOp";
233 int num_outputs()
const final;
234 DataType output_dtype(
size_t i)
const final;
237 const std::unordered_map<Tensor, Tensor>& rmap)
const final;
239 const std::unordered_map<const VarNode*, IntSet>& dom_map,
240 std::unordered_map<Tensor, TensorDom>* out_dom_map)
const final;
241 Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
242 bool debug_keep_trivial_loop)
const final;
243 size_t num_schedulable_dims()
const final;
246 v->Visit(
"name", &name);
247 v->Visit(
"tag", &tag);
248 v->Visit(
"attrs", &attrs);
249 v->Visit(
"axis", &axis);
250 v->Visit(
"reduce_axis", &reduce_axis);
251 v->Visit(
"body", &body);
254 static constexpr
const char* _type_key =
"ComputeOp";
289 int num_outputs()
const final;
290 DataType output_dtype(
size_t i)
const final;
293 const std::unordered_map<Tensor, Tensor>& rmap)
const final;
295 const std::unordered_map<const VarNode*, IntSet>& dom_map,
296 std::unordered_map<Tensor, TensorDom>* out_dom_map)
const final;
297 Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
298 bool debug_keep_trivial_loop)
const final;
299 size_t num_schedulable_dims()
const final;
302 v->Visit(
"name", &name);
303 v->Visit(
"tag", &tag);
304 v->Visit(
"axis", &axis);
305 v->Visit(
"reduce_axis", &reduce_axis);
306 v->Visit(
"schedulable_ndim", &schedulable_ndim);
307 v->Visit(
"intrin", &intrin);
308 v->Visit(
"inputs", &inputs);
309 v->Visit(
"input_regions", &input_regions);
310 v->Visit(
"scalar_inputs", &scalar_inputs);
313 static constexpr
const char* _type_key =
"TensorComputeOp";
362 int num_outputs()
const final;
364 DataType output_dtype(
size_t i)
const final;
368 const std::unordered_map<Tensor, Tensor>& rmap)
const final;
370 const std::unordered_map<const VarNode*, IntSet>& dom_map,
371 std::unordered_map<Tensor, TensorDom>* out_dom_map)
const final;
372 void GatherBound(
const Operation&
self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
373 std::unordered_map<IterVar, Range>* out_dom_map)
const final;
374 Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
375 const Stmt& body,
String storage_scope =
"")
const final;
376 Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
377 bool debug_keep_trivial_loop)
const final;
380 v->Visit(
"name", &name);
381 v->Visit(
"tag", &tag);
382 v->Visit(
"attrs", &attrs);
383 v->Visit(
"scan_axis", &scan_axis);
384 v->Visit(
"init", &init);
385 v->Visit(
"update", &update);
386 v->Visit(
"state_placeholder", &state_placeholder);
387 v->Visit(
"inputs", &inputs);
388 v->Visit(
"spatial_axis_", &spatial_axis_);
391 static constexpr
const char* _type_key =
"ScanOp";
425 int num_outputs()
const final;
427 DataType output_dtype(
size_t i)
const final;
431 const std::unordered_map<Tensor, Tensor>& rmap)
const final;
433 const std::unordered_map<const VarNode*, IntSet>& dom_map,
434 std::unordered_map<Tensor, TensorDom>* out_dom_map)
const final;
435 void GatherBound(
const Operation&
self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
436 std::unordered_map<IterVar, Range>* out_dom_map)
const final;
437 Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
438 const Stmt& body,
String storage_scope =
"")
const final;
439 Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
440 bool debug_keep_trivial_loop)
const final;
443 v->Visit(
"name", &name);
444 v->Visit(
"tag", &tag);
445 v->Visit(
"attrs", &attrs);
446 v->Visit(
"inputs", &inputs);
447 v->Visit(
"input_placeholders", &input_placeholders);
448 v->Visit(
"output_placeholders", &output_placeholders);
449 v->Visit(
"body", &body);
452 static constexpr
const char* _type_key =
"ExternOp";
490 int num_outputs()
const final;
492 DataType output_dtype(
size_t i)
const final;
496 const std::unordered_map<Tensor, Tensor>& rmap)
const final;
498 const std::unordered_map<const VarNode*, IntSet>& dom_map,
499 std::unordered_map<Tensor, TensorDom>* out_dom_map)
const final;
500 void GatherBound(
const Operation&
self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
501 std::unordered_map<IterVar, Range>* out_dom_map)
const final;
502 Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
503 const Stmt& body,
String storage_scope =
"")
const final;
504 Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
505 bool debug_keep_trivial_loop)
const final;
508 v->Visit(
"name", &name);
509 v->Visit(
"tag", &tag);
510 v->Visit(
"attrs", &attrs);
511 v->Visit(
"inputs", &inputs);
512 v->Visit(
"outputs", &outputs);
513 v->Visit(
"axis", &axis);
514 v->Visit(
"body", &body);
517 static constexpr
const char* _type_key =
"HybridOp";
557 using FCompute = std::function<PrimExpr(const Array<Var>& i)>;
569 std::string name =
"placeholder");
593 std::string name =
"tensor", std::string tag =
"",
610 std::string name =
"scan", std::string tag =
"",
615 std::string name =
"tensor", std::string tag =
"",
618 return compute(shape, fc, name, tag, attrs);
621 std::string name =
"tensor", std::string tag =
"",
624 return compute(shape, fc, name, tag, attrs);
627 std::string name =
"tensor", std::string tag =
"",
630 return compute(shape, fc, name, tag, attrs);
633 std::string name =
"tensor", std::string tag =
"",
636 return compute(shape, fc, name, tag, attrs);
645 #endif // TVM_TE_OPERATION_H_ IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
Array< Tensor > outputs
Symbolic placeholder representation of outputs.
Definition: operation.h:477
ComputeOpNode()
constructor
Definition: operation.h:231
Stmt body
the statement that generates the computation. This is slightly different from the body in ExternOpNod...
Definition: operation.h:485
std::string name
optional name of the operation
Definition: operation.h:59
Tensor placeholder(Array< PrimExpr > shape, DataType dtype=DataType::Float(32), std::string name="placeholder")
create a place holder tensor.
Managed reference to TensorComputeOpNode.
Definition: operation.h:321
Base class of all operation nodes.
Definition: operation.h:56
Managed reference to PlaceholderOpNode.
Definition: operation.h:192
int schedulable_ndim
number of axes that can be scheduled
Definition: operation.h:277
External computation that cannot be split.
Definition: operation.h:411
A computation operator that is generated by hybrid script.
Definition: operation.h:472
std::string tag
optional tag of the operation
Definition: operation.h:61
IterVar scan_axis
IterVar to scan over.
Definition: operation.h:337
Stmt body
the statement that generates the computation.
Definition: operation.h:420
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
Array< IterVar > axis
The axis of iterations.
Definition: operation.h:479
A Compute op that computes a tensor on a certain domain. This is the base class for ComputeOp (operating...
Definition: operation.h:204
Operation that produces tensors.
Definition: tensor.h:47
a named variable in TIR
Definition: var.h:88
Algebra expression simplifications.
HybridOpNode()
constructor
Definition: operation.h:488
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:308
A placeholder op represents an input placeholder.
Definition: operation.h:152
Definition: loop_state.h:456
A variable node in the IR.
Definition: var.h:47
Stage, contains scheduling for a stage of computation.
Definition: schedule.h:58
DataType dtype
The data type of the input.
Definition: operation.h:157
Symbolic scan.
Definition: operation.h:334
base class of all object containers.
Definition: object.h:167
Common operators defined for Expr.
ExternOpNode()
constructor
Definition: operation.h:423
Map< String, ObjectRef > attrs
additional attributes of the operation
Definition: operation.h:63
Array< Buffer > output_placeholders
Symbolic placeholder representation of outputs.
Definition: operation.h:418
std::function< PrimExpr(const Array< Var > &i)> FCompute
The compute function to specify the input source of a Tensor.
Definition: operation.h:557
Array< Tensor > inputs
The input tensors.
Definition: operation.h:475
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Range container.
Definition: expr.h:715
void VisitAttrs(AttrVisitor *v)
Definition: operation.h:507
Managed reference to TensorIntrinNode.
Definition: tensor_intrin.h:93
TensorComputeOpNode()
constructor
Definition: operation.h:287
Runtime primitive data type.
Definition: data_type.h:41
std::vector< std::vector< IntSet > > data
The domain data.
Definition: operation.h:50
static DataType Float(int bits, int lanes=1)
Construct an float type.
Definition: data_type.h:178
Managed reference to IntSetNode.
Definition: int_set.h:68
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void VisitAttrs(AttrVisitor *v)
Definition: operation.h:379
std::function< Array< PrimExpr >(const Array< Var > &i)> FBatchCompute
The compute function to specify the inputs source of Tensors.
Definition: operation.h:560
Container of all statements.
Definition: stmt.h:59
A Compute op that computes a tensor on a certain domain.
Definition: operation.h:226
Array< IterVar > axis
IterVar on each axis.
Definition: operation.h:207
Managed reference to ExternOpNode.
Definition: operation.h:460
Reference to string objects.
Definition: string.h:98
Array< Tensor > inputs
input tensors of intrin
Definition: operation.h:281
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1768
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
Managed reference to ScanOpNode.
Definition: operation.h:399
ScanOpNode()
constructor
Definition: operation.h:360
Array< Tensor > state_placeholder
The placeholder to refer as states in update.
Definition: operation.h:343
A TensorCompute op that computes a tensor with a tensor intrinsic.
Definition: operation.h:274
virtual ~OperationNode()
Definition: operation.h:65
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
void VisitAttrs(AttrVisitor *v)
Definition: operation.h:245
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:102
void VisitAttrs(AttrVisitor *v)
Definition: operation.h:301
Array< IterVar > spatial_axis_
Spatial axis to indicate the spatial dimension of each output. They correspond to the flattened spatial axis...
Definition: operation.h:358
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
Symbolic n-dimensional array, to represent a memory buffer.
Array< Tensor > init
the initialization tensors
Definition: operation.h:339
Array< Region > input_regions
region of input tensors
Definition: operation.h:283
Array< Tensor > inputs
The input tensors.
Definition: operation.h:414
Array< Buffer > input_placeholders
Symbolic placeholder representation of inputs.
Definition: operation.h:416
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
Array< Tensor > update
the update function represented by tensor
Definition: operation.h:341
Array< Tensor > inputs
the inputs to the scan; these are optionally provided, but they can be helpful to provide hints to spe...
Definition: operation.h:348
Array< PrimExpr > shape
The shape of the input.
Definition: operation.h:155
Managed reference to HybridOpNode.
Definition: operation.h:525
Array< Tensor > scan(Array< Tensor > init, Array< Tensor > update, Array< Tensor > state_placeholder, Array< Tensor > inputs=Array< Tensor >(), std::string name="scan", std::string tag="", Map< String, ObjectRef > attrs={})
Construct new tensors by scan.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ObjectRef > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Managed reference to ComputeOpNode.
Definition: operation.h:262
const OperationNode * operator->() const
access the internal node container
Definition: operation.h:640
TensorIntrin intrin
TensorIntrin used to compute.
Definition: operation.h:279
Reference to PrimExprNode.
Definition: expr.h:114
Array< PrimExpr > scalar_inputs
scalar expression inputs
Definition: operation.h:285
Temporary data structure to store union of bounds of each axis of Tensor.
Definition: operation.h:46
void VisitAttrs(AttrVisitor *v)
Definition: operation.h:442
TensorDom(int ndim)
Definition: operation.h:48
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:648
Array< IterVar > reduce_axis
IterVar on each reduction axis, if the body is a Reduce.
Definition: operation.h:209
Array< PrimExpr > body
the compute expression
Definition: operation.h:229
Analyzer that contains a bunch of sub-analyzers.
Definition: analyzer.h:579
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:164