tvm
operation.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TE_OPERATION_H_
25 #define TVM_TE_OPERATION_H_
26 
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
37 
38 namespace tvm {
40 namespace te {
41 
46 struct TensorDom {
47  // constructor
48  explicit TensorDom(int ndim) : data(ndim) {}
50  std::vector<std::vector<IntSet>> data;
51 };
52 
56 class TVM_DLL OperationNode : public Object {
57  public:
59  std::string name;
61  std::string tag;
63  Map<String, ffi::Any> attrs;
64  // virtual destructor.
65  virtual ~OperationNode() {}
67  virtual int num_outputs() const = 0;
73  virtual DataType output_dtype(size_t i) const = 0;
79  virtual Array<PrimExpr> output_shape(size_t i) const = 0;
84  virtual Array<Tensor> InputTensors() const = 0;
85 
86  static void RegisterReflection() {
87  namespace refl = tvm::ffi::reflection;
88  refl::ObjectDef<OperationNode>()
89  .def_ro("name", &OperationNode::name)
90  .def_ro("tag", &OperationNode::tag)
91  .def_ro("attrs", &OperationNode::attrs);
92  }
93 
94  static constexpr const char* _type_key = "te.Operation";
95 
97 };
98 
103  public:
105  Array<PrimExpr> shape;
108  // override behavior.
109  int num_outputs() const final;
110  DataType output_dtype(size_t i) const final;
111  Array<PrimExpr> output_shape(size_t i) const final;
112  Array<Tensor> InputTensors() const final;
113 
114  static void RegisterReflection() {
115  namespace refl = tvm::ffi::reflection;
116  refl::ObjectDef<PlaceholderOpNode>()
117  .def_ro("shape", &PlaceholderOpNode::shape)
118  .def_ro("dtype", &PlaceholderOpNode::dtype);
119  }
120 
121  static constexpr const char* _type_key = "te.PlaceholderOp";
122 
124 };
125 
130 class PlaceholderOp : public Operation {
131  public:
132  TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);
133 
135 };
136 
141 class TVM_DLL BaseComputeOpNode : public OperationNode {
142  public:
144  Array<IterVar> axis;
146  Array<IterVar> reduce_axis;
147  // override functions
148  Array<PrimExpr> output_shape(size_t idx) const final;
149 
150  static void RegisterReflection() {
151  namespace refl = tvm::ffi::reflection;
152  refl::ObjectDef<BaseComputeOpNode>()
153  .def_ro("axis", &BaseComputeOpNode::axis)
154  .def_ro("reduce_axis", &BaseComputeOpNode::reduce_axis);
155  }
156 
157  static constexpr const char* _type_key = "te.BaseComputeOp";
158 
160 };
161 
165 class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
166  public:
168  Array<PrimExpr> body;
171  // override functions
172  int num_outputs() const final;
173  DataType output_dtype(size_t i) const final;
174  Array<Tensor> InputTensors() const final;
175 
176  static void RegisterReflection() {
177  namespace refl = tvm::ffi::reflection;
178  refl::ObjectDef<ComputeOpNode>().def_ro("body", &ComputeOpNode::body);
179  }
180 
181  static constexpr const char* _type_key = "te.ComputeOp";
182 
184 };
185 
190 class ComputeOp : public Operation {
191  public:
192  TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ffi::Any> attrs,
193  Array<IterVar> axis, Array<PrimExpr> body);
194 
197 };
198 
202 class ScanOpNode : public OperationNode {
203  public:
207  Array<Tensor> init;
209  Array<Tensor> update;
211  Array<Tensor> state_placeholder;
216  Array<Tensor> inputs;
226  Array<IterVar> spatial_axis_;
229  // override behavior.
230  int num_outputs() const final;
231  DataType output_dtype(size_t i) const final;
232  Array<PrimExpr> output_shape(size_t i) const final;
233  Array<Tensor> InputTensors() const final;
234 
235  static void RegisterReflection() {
236  namespace refl = tvm::ffi::reflection;
237  refl::ObjectDef<ScanOpNode>()
238  .def_ro("scan_axis", &ScanOpNode::scan_axis)
239  .def_ro("init", &ScanOpNode::init)
240  .def_ro("update", &ScanOpNode::update)
241  .def_ro("state_placeholder", &ScanOpNode::state_placeholder)
242  .def_ro("inputs", &ScanOpNode::inputs)
243  .def_ro("spatial_axis_", &ScanOpNode::spatial_axis_);
244  }
245 
246  static constexpr const char* _type_key = "te.ScanOp";
247 
249 };
250 
255 class ScanOp : public Operation {
256  public:
257  TVM_DLL ScanOp(std::string name, std::string tag, Optional<Map<String, ffi::Any>> attrs,
258  IterVar axis, Array<Tensor> init, Array<Tensor> update,
259  Array<Tensor> state_placeholder, Array<Tensor> input);
260 
262 };
263 
267 class ExternOpNode : public OperationNode {
268  public:
270  Array<Tensor> inputs;
272  Array<Buffer> input_placeholders;
274  Array<Buffer> output_placeholders;
277 
280  // override functions
281  int num_outputs() const final;
282  DataType output_dtype(size_t i) const final;
283  Array<PrimExpr> output_shape(size_t i) const final;
284  Array<Tensor> InputTensors() const final;
285 
286  static void RegisterReflection() {
287  namespace refl = tvm::ffi::reflection;
288  refl::ObjectDef<ExternOpNode>()
289  .def_ro("inputs", &ExternOpNode::inputs)
290  .def_ro("input_placeholders", &ExternOpNode::input_placeholders)
291  .def_ro("output_placeholders", &ExternOpNode::output_placeholders)
292  .def_ro("body", &ExternOpNode::body);
293  }
294 
295  static constexpr const char* _type_key = "te.ExternOp";
296 
298 };
299 
304 class ExternOp : public Operation {
305  public:
306  TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ffi::Any> attrs,
307  Array<Tensor> inputs, Array<Buffer> input_placeholders,
308  Array<Buffer> output_placeholders, Stmt body);
309 
311 };
312 
318 TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));
319 
326 TVM_DLL IterVar thread_axis(Range dom, std::string tag);
327 
334 TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
335 
337 using FCompute = std::function<PrimExpr(const Array<Var>& i)>;
338 
340 using FBatchCompute = std::function<Array<PrimExpr>(const Array<Var>& i)>;
341 
348 TVM_DLL Tensor placeholder(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
349  std::string name = "placeholder");
350 
360 TVM_DLL Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name = "tensor",
361  std::string tag = "", Map<String, ffi::Any> attrs = {});
362 
372 TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute,
373  std::string name = "tensor", std::string tag = "",
374  Map<String, ffi::Any> attrs = {});
375 
388 TVM_DLL Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update,
389  Array<Tensor> state_placeholder, Array<Tensor> inputs = Array<Tensor>(),
390  std::string name = "scan", std::string tag = "",
391  Map<String, ffi::Any> attrs = {});
392 
393 // same as compute, specialized for different fcompute function
394 inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var)> f,
395  std::string name = "tensor", std::string tag = "",
396  Map<String, ffi::Any> attrs = {}) {
397  FCompute fc = [f](const Array<Var>& i) { return f(i[0]); };
398  return compute(shape, fc, name, tag, attrs);
399 }
400 inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var)> f,
401  std::string name = "tensor", std::string tag = "",
402  Map<String, ffi::Any> attrs = {}) {
403  FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1]); };
404  return compute(shape, fc, name, tag, attrs);
405 }
406 inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var)> f,
407  std::string name = "tensor", std::string tag = "",
408  Map<String, ffi::Any> attrs = {}) {
409  FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); };
410  return compute(shape, fc, name, tag, attrs);
411 }
412 inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var, Var)> f,
413  std::string name = "tensor", std::string tag = "",
414  Map<String, ffi::Any> attrs = {}) {
415  FCompute fc = [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
416  return compute(shape, fc, name, tag, attrs);
417 }
418 
419 // inline function.
420 inline const OperationNode* Operation::operator->() const {
421  return static_cast<const OperationNode*>(get());
422 }
423 } // namespace te
424 } // namespace tvm
425 #endif // TVM_TE_OPERATION_H_
Algebra expression simplifications.
Symbolic n-dimensional array, to represent a memory buffer.
Reference to PrimExprNode.
Definition: expr.h:129
Range container
Definition: expr.h:698
Runtime primitive data type.
Definition: data_type.h:47
static DataType Float(int bits, int lanes=1)
Construct an float type.
Definition: data_type.h:291
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:274
A compute op that computes a tensor on a certain domain. This is the base class for ComputeOp (operating...
Definition: operation.h:141
Array< IterVar > axis
IterVar on each axis.
Definition: operation.h:144
static void RegisterReflection()
Definition: operation.h:150
Array< PrimExpr > output_shape(size_t idx) const final
Get shape of i-th output tensor.
Array< IterVar > reduce_axis
IterVar on each reduction axis, if the body is a Reduce.
Definition: operation.h:146
TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode)
A compute op that computes a tensor on a certain domain.
Definition: operation.h:165
Array< PrimExpr > body
the compute expression
Definition: operation.h:168
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode)
ComputeOpNode()
constructor
Definition: operation.h:170
int num_outputs() const final
Managed reference to ComputeOpNode.
Definition: operation.h:190
TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode)
ComputeOp(std::string name, std::string tag, Map< String, ffi::Any > attrs, Array< IterVar > axis, Array< PrimExpr > body)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode)
External computation that cannot be split.
Definition: operation.h:267
Array< Buffer > output_placeholders
Symbolic placeholder representation of outputs.
Definition: operation.h:274
Array< Tensor > inputs
The input tensors.
Definition: operation.h:270
int num_outputs() const final
TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode)
static void RegisterReflection()
Definition: operation.h:286
Stmt body
the statement that generates the computation.
Definition: operation.h:276
Array< PrimExpr > output_shape(size_t i) const final
Get shape of i-th output tensor.
ExternOpNode()
constructor
Definition: operation.h:279
static constexpr const char * _type_key
Definition: operation.h:295
Array< Buffer > input_placeholders
Symbolic placeholder representation of inputs.
Definition: operation.h:272
Array< Tensor > InputTensors() const final
List all the input Tensors.
DataType output_dtype(size_t i) const final
Get the data type of the i-th output tensor.
Managed reference to ExternOpNode.
Definition: operation.h:304
ExternOp(std::string name, std::string tag, Map< String, ffi::Any > attrs, Array< Tensor > inputs, Array< Buffer > input_placeholders, Array< Buffer > output_placeholders, Stmt body)
TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode)
Base class of all operation nodes.
Definition: operation.h:56
virtual ~OperationNode()
Definition: operation.h:65
TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object)
static void RegisterReflection()
Definition: operation.h:86
virtual DataType output_dtype(size_t i) const =0
Get the data type of the i-th output tensor.
Map< String, ffi::Any > attrs
additional attributes of the operation
Definition: operation.h:63
virtual Array< PrimExpr > output_shape(size_t i) const =0
Get shape of i-th output tensor.
virtual int num_outputs() const =0
virtual Array< Tensor > InputTensors() const =0
List all the input Tensors.
std::string name
optional name of the operation
Definition: operation.h:59
std::string tag
optional tag of the operation
Definition: operation.h:61
Operation that produces tensors.
Definition: tensor.h:48
const OperationNode * operator->() const
access the internal node container
Definition: operation.h:420
A placeholder op represents an input placeholder.
Definition: operation.h:102
Array< PrimExpr > output_shape(size_t i) const final
Get shape of i-th output tensor.
Array< PrimExpr > shape
The shape of the input.
Definition: operation.h:105
TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode)
DataType dtype
The data type of the input.
Definition: operation.h:107
static void RegisterReflection()
Definition: operation.h:114
static constexpr const char * _type_key
Definition: operation.h:121
Array< Tensor > InputTensors() const final
List all the input Tensors.
int num_outputs() const final
DataType output_dtype(size_t i) const final
Get the data type of the i-th output tensor.
Managed reference to PlaceholderOpNode.
Definition: operation.h:130
TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode)
PlaceholderOp(std::string name, Array< PrimExpr > shape, DataType dtype)
Symbolic scan.
Definition: operation.h:202
ScanOpNode()
constructor
Definition: operation.h:228
int num_outputs() const final
TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode)
DataType output_dtype(size_t i) const final
Get the data type of the i-th output tensor.
Array< Tensor > state_placeholder
The placeholder to refer as states in update.
Definition: operation.h:211
Array< Tensor > init
the initialization tensors
Definition: operation.h:207
static void RegisterReflection()
Definition: operation.h:235
IterVar scan_axis
IterVar to scan over.
Definition: operation.h:205
Array< IterVar > spatial_axis_
Spatial axis to indicate the spatial dimension of each output. They correspond to flattened spatial axis...
Definition: operation.h:226
Array< Tensor > inputs
The inputs to the scan. These are optional, but they can be helpful to provide hints to spe...
Definition: operation.h:216
Array< Tensor > InputTensors() const final
List all the input Tensors.
Array< Tensor > update
the update function represented by tensor
Definition: operation.h:209
static constexpr const char * _type_key
Definition: operation.h:246
Array< PrimExpr > output_shape(size_t i) const final
Get shape of i-th output tensor.
Managed reference to ScanOpNode.
Definition: operation.h:255
ScanOp(std::string name, std::string tag, Optional< Map< String, ffi::Any >> attrs, IterVar axis, Array< Tensor > init, Array< Tensor > update, Array< Tensor > state_placeholder, Array< Tensor > input)
TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode)
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:298
Container of all statements.
Definition: stmt.h:64
a named variable in TIR
Definition: var.h:78
Definition: repr_printer.h:91
Tensor placeholder(Array< PrimExpr > shape, DataType dtype=DataType::Float(32), std::string name="placeholder")
Create a placeholder tensor.
Array< Tensor > scan(Array< Tensor > init, Array< Tensor > update, Array< Tensor > state_placeholder, Array< Tensor > inputs=Array< Tensor >(), std::string name="scan", std::string tag="", Map< String, ffi::Any > attrs={})
Construct new tensors by scan.
std::function< PrimExpr(const Array< Var > &i)> FCompute
The compute function to specify the input source of a Tensor.
Definition: operation.h:337
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
std::function< Array< PrimExpr >(const Array< Var > &i)> FBatchCompute
The compute function to specify the inputs source of Tensors.
Definition: operation.h:340
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor compute(Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", Map< String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1945
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
Temporary data structure to store union of bounds of each axis of Tensor.
Definition: operation.h:46
TensorDom(int ndim)
Definition: operation.h:48
std::vector< std::vector< IntSet > > data
The domain data.
Definition: operation.h:50
Dataflow tensor object.
TIR expressions.
Common operators defined for Expr.