tvm
operation.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TE_OPERATION_H_
25 #define TVM_TE_OPERATION_H_
26 
27 #include <tvm/arith/analyzer.h>
28 #include <tvm/ffi/reflection/registry.h>
29 #include <tvm/ir/cow.h>
30 #include <tvm/te/tensor.h>
31 #include <tvm/tirx/buffer.h>
32 #include <tvm/tirx/expr.h>
33 #include <tvm/tirx/op.h>
34 
35 #include <string>
36 #include <unordered_map>
37 #include <vector>
38 
39 namespace tvm {
41 namespace te {
42 
/*!
 * \brief Temporary data structure to store the union of bounds of each axis of a Tensor.
 */
47 struct TensorDom {
48  // Constructor: sizes `data` to ndim (one initially empty IntSet list per dimension).
  // NOTE(review): ndim is an int handed to std::vector's size constructor; a negative value
  // would convert to a huge size_t, so callers presumably pass a non-negative rank.
49  explicit TensorDom(int ndim) : data(ndim) {}
51  std::vector<std::vector<IntSet>> data;  // The domain data; outer size is ndim.
52 };
53 
/*!
 * \brief Base class of all operation nodes.
 *
 * Subclasses implement the four pure-virtual queries below (output count, per-output
 * dtype/shape, and input tensors).
 */
57 class TVM_DLL OperationNode : public ffi::Object {
58  public:
60  std::string name;  // Optional name of the operation.
62  std::string tag;  // Optional tag of the operation.
64  ffi::Map<ffi::String, ffi::Any> attrs;  // Additional attributes of the operation.
65  // Virtual destructor, so derived operation nodes are destroyed correctly through an
  // OperationNode pointer.
66  virtual ~OperationNode() {}
68  virtual int num_outputs() const = 0;  // \return The number of outputs of the operation.
74  virtual DataType output_dtype(size_t i) const = 0;  // Get the data type of the i-th output tensor.
80  virtual ffi::Array<PrimExpr> output_shape(size_t i) const = 0;  // Get the shape of the i-th output tensor.
85  virtual ffi::Array<Tensor> InputTensors() const = 0;  // List all the input Tensors.
86 
  // Registers name, tag and attrs as read-only (def_ro) reflected fields.
87  static void RegisterReflection() {
88  namespace refl = tvm::ffi::reflection;
89  refl::ObjectDef<OperationNode>()
90  .def_ro("name", &OperationNode::name)
91  .def_ro("tag", &OperationNode::tag)
92  .def_ro("attrs", &OperationNode::attrs);
93  }
  // Declares the FFI type key "te.Operation" with parent type ffi::Object.
94  TVM_FFI_DECLARE_OBJECT_INFO("te.Operation", OperationNode, ffi::Object);
95 };
96 
// NOTE(review): this listing omits source line 100, the header of class PlaceholderOpNode
// ("A placeholder op represents an input placeholder."). The object-info cross-reference for
// this class names OperationNode as its parent.
101  public:
103  ffi::Array<PrimExpr> shape;  // The shape of the input.
  // NOTE(review): this listing omits source line 105, `DataType dtype;` (the data type of the
  // input), which RegisterReflection below exposes as "dtype".
106  // Override behavior: implements OperationNode's pure-virtual queries (marked final).
107  int num_outputs() const final;
108  DataType output_dtype(size_t i) const final;
109  ffi::Array<PrimExpr> output_shape(size_t i) const final;
110  ffi::Array<Tensor> InputTensors() const final;
111 
  // Registers shape and dtype as read-only reflected fields.
112  static void RegisterReflection() {
113  namespace refl = tvm::ffi::reflection;
114  refl::ObjectDef<PlaceholderOpNode>()
115  .def_ro("shape", &PlaceholderOpNode::shape)
116  .def_ro("dtype", &PlaceholderOpNode::dtype);
117  }
  // NOTE(review): the cross-reference lists
  // TVM_FFI_DECLARE_OBJECT_INFO("te.PlaceholderOp", PlaceholderOpNode, OperationNode)
  // for this class; it is not shown in this listing (presumably line 118).
119 };
120 
/*!
 * \brief Managed reference to PlaceholderOpNode.
 */
125 class PlaceholderOp : public Operation {
126  public:
  // Constructs a placeholder operation from a name, a shape and a data type.
127  TVM_DLL PlaceholderOp(std::string name, ffi::Array<PrimExpr> shape, DataType dtype);
128 
  // NOTE(review): the cross-reference lists
  // TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PlaceholderOp, Operation, PlaceholderOpNode)
  // for this class; it is not shown in this listing (presumably line 129).
130 };
131 
/*!
 * \brief A Compute op that computes a tensor on a certain domain; the base class for ComputeOp.
 *
 * Holds the iteration axes and implements output_shape. num_outputs, output_dtype and
 * InputTensors are not overridden here; they are left to subclasses.
 */
136 class TVM_DLL BaseComputeOpNode : public OperationNode {
137  public:
139  ffi::Array<IterVar> axis;  // IterVar on each axis.
141  ffi::Array<IterVar> reduce_axis;  // IterVar on each reduction axis, if the body is a Reduce.
142  // Override functions: shape of the idx-th output tensor (final for all subclasses).
143  ffi::Array<PrimExpr> output_shape(size_t idx) const final;
144 
  // Registers axis and reduce_axis as read-only reflected fields.
145  static void RegisterReflection() {
146  namespace refl = tvm::ffi::reflection;
147  refl::ObjectDef<BaseComputeOpNode>()
148  .def_ro("axis", &BaseComputeOpNode::axis)
149  .def_ro("reduce_axis", &BaseComputeOpNode::reduce_axis);
150  }
  // NOTE(review): the cross-reference lists
  // TVM_FFI_DECLARE_OBJECT_INFO("te.BaseComputeOp", BaseComputeOpNode, OperationNode)
  // for this class; it is not shown in this listing (presumably line 151).
152 };
153 
/*!
 * \brief A Compute op that computes a tensor on a certain domain.
 */
157 class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
158  public:
160  ffi::Array<PrimExpr> body;  // The compute expression.
  // NOTE(review): this listing omits source line 162, the constructor ComputeOpNode().
163  // Override functions: the queries BaseComputeOpNode leaves unimplemented.
164  int num_outputs() const final;
165  DataType output_dtype(size_t i) const final;
166  ffi::Array<Tensor> InputTensors() const final;
167 
  // Registers body as a read-only reflected field; axis and reduce_axis are registered in
  // BaseComputeOpNode::RegisterReflection.
168  static void RegisterReflection() {
169  namespace refl = tvm::ffi::reflection;
170  refl::ObjectDef<ComputeOpNode>().def_ro("body", &ComputeOpNode::body);
171  }
  // NOTE(review): the cross-reference lists
  // TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.ComputeOp", ComputeOpNode, BaseComputeOpNode)
  // for this class; it is not shown in this listing (presumably line 172).
173 };
174 
/*!
 * \brief Managed reference to ComputeOpNode.
 */
179 class ComputeOp : public Operation {
180  public:
  // Constructs a compute operation from its name, tag, attributes, iteration axes and body
  // expressions.
181  TVM_DLL ComputeOp(std::string name, std::string tag, ffi::Map<ffi::String, ffi::Any> attrs,
182  ffi::Array<IterVar> axis, ffi::Array<PrimExpr> body);
183 
  // NOTE(review): source lines 184-185 are not shown. The cross-reference lists
  // TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ComputeOp, Operation, ComputeOpNode) and
  // TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode) as members of this class.
186 };
187 
/*!
 * \brief Symbolic scan.
 */
// NOTE(review): unlike OperationNode, BaseComputeOpNode and ComputeOpNode, this class is not
// declared TVM_DLL; confirm that is intentional for its out-of-line member functions.
191 class ScanOpNode : public OperationNode {
192  public:
  // NOTE(review): this listing omits source line 194, `IterVar scan_axis;` (the IterVar to
  // scan over), which RegisterReflection below exposes as "scan_axis".
196  ffi::Array<Tensor> init;  // The initialization tensors.
198  ffi::Array<Tensor> update;  // The update function, represented by tensors.
200  ffi::Array<Tensor> state_placeholder;  // The placeholders referred to as states in update.
  // The inputs to the scan. These are optionally provided, but they can be helpful to provide
  // hints (the full description is truncated in this listing).
205  ffi::Array<Tensor> inputs;
  // Spatial axes indicating the spatial dimension of each output; they correspond to
  // flattened spatial axes (the full description is truncated in this listing).
215  ffi::Array<IterVar> spatial_axis_;
  // NOTE(review): this listing omits source line 217, the constructor ScanOpNode().
218  // Override behavior: implements all four OperationNode queries (marked final).
219  int num_outputs() const final;
220  DataType output_dtype(size_t i) const final;
221  ffi::Array<PrimExpr> output_shape(size_t i) const final;
222  ffi::Array<Tensor> InputTensors() const final;
223 
  // Registers scan_axis, init, update, state_placeholder, inputs and spatial_axis_ as
  // read-only reflected fields.
224  static void RegisterReflection() {
225  namespace refl = tvm::ffi::reflection;
226  refl::ObjectDef<ScanOpNode>()
227  .def_ro("scan_axis", &ScanOpNode::scan_axis)
228  .def_ro("init", &ScanOpNode::init)
229  .def_ro("update", &ScanOpNode::update)
230  .def_ro("state_placeholder", &ScanOpNode::state_placeholder)
231  .def_ro("inputs", &ScanOpNode::inputs)
232  .def_ro("spatial_axis_", &ScanOpNode::spatial_axis_);
233  }
  // NOTE(review): the cross-reference lists
  // TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.ScanOp", ScanOpNode, OperationNode)
  // for this class; it is not shown in this listing (presumably line 234).
235 };
236 
/*!
 * \brief Managed reference to ScanOpNode.
 */
241 class ScanOp : public Operation {
242  public:
  // Constructs a scan operation over `axis` from init/update/state_placeholder/input tensors.
  // NOTE(review): attrs is an ffi::Optional here, whereas the ComputeOp and ExternOp
  // constructors take a plain ffi::Map. The last parameter is named `input`, while the node
  // field and the scan() helper use `inputs`.
243  TVM_DLL ScanOp(std::string name, std::string tag,
244  ffi::Optional<ffi::Map<ffi::String, ffi::Any>> attrs, IterVar axis,
245  ffi::Array<Tensor> init, ffi::Array<Tensor> update,
246  ffi::Array<Tensor> state_placeholder, ffi::Array<Tensor> input);
247 
  // NOTE(review): the cross-reference lists
  // TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScanOp, Operation, ScanOpNode)
  // for this class; it is not shown in this listing (presumably line 248).
249 };
250 
/*!
 * \brief External computation that cannot be split.
 */
// NOTE(review): unlike ComputeOpNode, this class is not declared TVM_DLL; confirm that is
// intentional for its out-of-line member functions.
254 class ExternOpNode : public OperationNode {
255  public:
257  ffi::Array<Tensor> inputs;  // The input tensors.
259  ffi::Array<Buffer> input_placeholders;  // Symbolic placeholder representation of inputs.
261  ffi::Array<Buffer> output_placeholders;  // Symbolic placeholder representation of outputs.
  // NOTE(review): this listing omits source line 263, `Stmt body;` (the statement that
  // generates the computation), which RegisterReflection below exposes as "body".
264 
  // NOTE(review): this listing omits source line 266, the constructor ExternOpNode().
267  // Override functions: implements all four OperationNode queries (marked final).
268  int num_outputs() const final;
269  DataType output_dtype(size_t i) const final;
270  ffi::Array<PrimExpr> output_shape(size_t i) const final;
271  ffi::Array<Tensor> InputTensors() const final;
272 
  // Registers inputs, input_placeholders, output_placeholders and body as read-only
  // reflected fields.
273  static void RegisterReflection() {
274  namespace refl = tvm::ffi::reflection;
275  refl::ObjectDef<ExternOpNode>()
276  .def_ro("inputs", &ExternOpNode::inputs)
277  .def_ro("input_placeholders", &ExternOpNode::input_placeholders)
278  .def_ro("output_placeholders", &ExternOpNode::output_placeholders)
279  .def_ro("body", &ExternOpNode::body);
280  }
  // NOTE(review): the cross-reference lists
  // TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.ExternOp", ExternOpNode, OperationNode)
  // for this class; it is not shown in this listing (presumably line 281).
282 };
283 
/*!
 * \brief Managed reference to ExternOpNode.
 */
288 class ExternOp : public Operation {
289  public:
  // Constructs an extern operation from its name, tag, attributes, input tensors, input and
  // output placeholder buffers, and the body statement that generates the computation.
290  TVM_DLL ExternOp(std::string name, std::string tag, ffi::Map<ffi::String, ffi::Any> attrs,
291  ffi::Array<Tensor> inputs, ffi::Array<Buffer> input_placeholders,
292  ffi::Array<Buffer> output_placeholders, Stmt body);
293 
  // NOTE(review): the cross-reference lists
  // TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternOp, Operation, ExternOpNode)
  // for this class; it is not shown in this listing (presumably line 294).
295 };
296 
/*!
 * \brief Construct a new Var expression.
 * \param name_hint The name hint for the variable.
 * \param t The data type of the variable; defaults to DataType::Int(32).
 * \return The created Var.
 */
302 TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));
303 
/*!
 * \brief Create a new IterVar that represents an axis in thread.
 * \param dom The domain of the axis.
 * \param tag The tag attached to the axis (presumably a thread tag such as "threadIdx.x";
 *        confirm against the implementation).
 * \return The created IterVar.
 */
310 TVM_DLL IterVar thread_axis(Range dom, std::string tag);
311 
/*!
 * \brief Create a new IterVar for reduction operations.
 * \param dom The domain of the reduction axis.
 * \param name The name of the reduction axis; defaults to "rv".
 * \return The created IterVar.
 */
318 TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
319 
/*!
 * \brief The compute function to specify the input source of a Tensor.
 *
 * It receives an array of index variables (presumably one per dimension of the compute
 * shape) and returns the element expression.
 */
// NOTE(review): std::function is used here, but <functional> is not among this header's
// direct standard includes (<string>, <unordered_map>, <vector>). It is presumably reached
// transitively; consider including it directly.
321 using FCompute = std::function<PrimExpr(const ffi::Array<Var>& i)>;
322 
/*!
 * \brief The compute function to specify the input sources of several Tensors.
 *
 * Like FCompute, but it returns an array of expressions (presumably one per output tensor).
 */
324 using FBatchCompute = std::function<ffi::Array<PrimExpr>(const ffi::Array<Var>& i)>;
325 
/*!
 * \brief Create a placeholder tensor.
 * \param shape The shape of the tensor.
 * \param dtype The data type of the tensor; defaults to DataType::Float(32).
 * \param name The name of the tensor; defaults to "placeholder".
 * \return The placeholder tensor.
 */
332 TVM_DLL Tensor placeholder(ffi::Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
333  std::string name = "placeholder");
334 
/*!
 * \brief Construct a new tensor by computing over shape, using the computation rule
 *        result_tensor[axis] = fcompute(axis).
 * \param shape The shape of the tensor.
 * \param fcompute The compute function that creates the tensor's elements.
 * \param name The name of the tensor; defaults to "tensor".
 * \param tag The tag of the tensor; defaults to "".
 * \param attrs Additional attributes of the compute; defaults to an empty map.
 * \return The created tensor.
 */
344 TVM_DLL Tensor compute(ffi::Array<PrimExpr> shape, FCompute fcompute, std::string name = "tensor",
345  std::string tag = "", ffi::Map<ffi::String, ffi::Any> attrs = {});
346 
/*!
 * \brief Construct several tensors by computing over shape with a batch compute function.
 * \param shape The shape of the tensors.
 * \param fcompute The batch compute function, returning one expression array per index.
 * \param name The name of the tensors; defaults to "tensor".
 * \param tag The tag of the tensors; defaults to "".
 * \param attrs Additional attributes of the compute; defaults to an empty map.
 * \return The created tensors.
 */
356 TVM_DLL ffi::Array<Tensor> compute(ffi::Array<PrimExpr> shape, FBatchCompute fcompute,
357  std::string name = "tensor", std::string tag = "",
358  ffi::Map<ffi::String, ffi::Any> attrs = {});
359 
/*!
 * \brief Construct new tensors by scan.
 * \param init The initialization tensors.
 * \param update The update tensors.
 * \param state_placeholder The placeholders referred to as states in update.
 * \param inputs The inputs to the scan; optional (defaults to an empty array), but they can
 *        provide hints.
 * \param name The name of the scan op; defaults to "scan".
 * \param tag The tag of the scan op; defaults to "".
 * \param attrs Additional attributes; defaults to an empty map.
 * \return The created tensors.
 */
372 TVM_DLL ffi::Array<Tensor> scan(ffi::Array<Tensor> init, ffi::Array<Tensor> update,
373  ffi::Array<Tensor> state_placeholder,
374  ffi::Array<Tensor> inputs = ffi::Array<Tensor>(),
375  std::string name = "scan", std::string tag = "",
376  ffi::Map<ffi::String, ffi::Any> attrs = {});
377 
378 // Same as compute(), specialized for fcompute functions that take 1 to 4 Var arguments.
// 1-argument form: wraps f (captured by value) into an FCompute that passes i[0], then
// forwards to compute(shape, FCompute, name, tag, attrs).
// NOTE(review): i[0] is read unconditionally, so shape presumably has at least 1 dimension.
// Whether an out-of-range index is caught depends on ffi::Array::operator[].
379 inline Tensor compute(ffi::Array<PrimExpr> shape, std::function<PrimExpr(Var)> f,
380  std::string name = "tensor", std::string tag = "",
381  ffi::Map<ffi::String, ffi::Any> attrs = {}) {
382  FCompute fc = [f](const ffi::Array<Var>& i) { return f(i[0]); };
383  return compute(shape, fc, name, tag, attrs);
384 }
// 2-argument form of compute(): wraps f (captured by value) into an FCompute that passes
// i[0] and i[1], then forwards to compute(shape, FCompute, name, tag, attrs).
// NOTE(review): shape presumably has at least 2 dimensions, since i[1] is read unconditionally.
385 inline Tensor compute(ffi::Array<PrimExpr> shape, std::function<PrimExpr(Var, Var)> f,
386  std::string name = "tensor", std::string tag = "",
387  ffi::Map<ffi::String, ffi::Any> attrs = {}) {
388  FCompute fc = [f](const ffi::Array<Var>& i) { return f(i[0], i[1]); };
389  return compute(shape, fc, name, tag, attrs);
390 }
// 3-argument form of compute(): wraps f (captured by value) into an FCompute that passes
// i[0], i[1] and i[2], then forwards to compute(shape, FCompute, name, tag, attrs).
// NOTE(review): shape presumably has at least 3 dimensions, since i[2] is read unconditionally.
391 inline Tensor compute(ffi::Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var)> f,
392  std::string name = "tensor", std::string tag = "",
393  ffi::Map<ffi::String, ffi::Any> attrs = {}) {
394  FCompute fc = [f](const ffi::Array<Var>& i) { return f(i[0], i[1], i[2]); };
395  return compute(shape, fc, name, tag, attrs);
396 }
// 4-argument form of compute(): wraps f (captured by value) into an FCompute that passes
// i[0] to i[3], then forwards to compute(shape, FCompute, name, tag, attrs).
// NOTE(review): shape presumably has at least 4 dimensions, since i[3] is read unconditionally.
397 inline Tensor compute(ffi::Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var, Var)> f,
398  std::string name = "tensor", std::string tag = "",
399  ffi::Map<ffi::String, ffi::Any> attrs = {}) {
400  FCompute fc = [f](const ffi::Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
401  return compute(shape, fc, name, tag, attrs);
402 }
403 
404 // Inline definition of Operation::operator-> (the class itself is declared in tensor.h).
// Accesses the internal node container by downcasting get() to const OperationNode* with an
// unchecked static_cast. It is presumably defined here, rather than in tensor.h, because the
// downcast needs OperationNode to be a complete type.
405 inline const OperationNode* Operation::operator->() const {
406  return static_cast<const OperationNode*>(get());
407 }
408 } // namespace te
409 } // namespace tvm
410 #endif // TVM_TE_OPERATION_H_
Algebraic expression simplifications.
Symbolic n-dimensional array, to represent a memory buffer.
Reference to PrimExprNode.
Definition: expr.h:126
Range container
Definition: expr.h:690
Runtime primitive data type.
Definition: data_type.h:45
static DataType Float(int bits, int lanes=1)
Construct an float type.
Definition: data_type.h:293
static DataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:276
A Compute op that computes a tensor on a certain domain. This is the base class for ComputeOp (operating...
Definition: operation.h:136
ffi::Array< IterVar > reduce_axis
IterVar on each reduction axis, if the body is a Reduce.
Definition: operation.h:141
static void RegisterReflection()
Definition: operation.h:145
ffi::Array< PrimExpr > output_shape(size_t idx) const final
Get shape of i-th output tensor.
TVM_FFI_DECLARE_OBJECT_INFO("te.BaseComputeOp", BaseComputeOpNode, OperationNode)
ffi::Array< IterVar > axis
IterVar on each axis.
Definition: operation.h:139
A Compute op that computes a tensor on a certain domain.
Definition: operation.h:157
ComputeOpNode()
constructor
Definition: operation.h:162
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.ComputeOp", ComputeOpNode, BaseComputeOpNode)
int num_outputs() const final
ffi::Array< PrimExpr > body
the compute expression
Definition: operation.h:160
Managed reference to ComputeOpNode.
Definition: operation.h:179
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ComputeOp, Operation, ComputeOpNode)
ComputeOp(std::string name, std::string tag, ffi::Map< ffi::String, ffi::Any > attrs, ffi::Array< IterVar > axis, ffi::Array< PrimExpr > body)
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode)
External computation that cannot be split.
Definition: operation.h:254
int num_outputs() const final
ffi::Array< Tensor > inputs
The input tensors.
Definition: operation.h:257
static void RegisterReflection()
Definition: operation.h:273
ffi::Array< PrimExpr > output_shape(size_t i) const final
Get shape of i-th output tensor.
ffi::Array< Buffer > input_placeholders
Symbolic placeholder representation of inputs.
Definition: operation.h:259
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.ExternOp", ExternOpNode, OperationNode)
ffi::Array< Tensor > InputTensors() const final
List all the input Tensors.
Stmt body
the statement that generates the computation.
Definition: operation.h:263
ExternOpNode()
constructor
Definition: operation.h:266
ffi::Array< Buffer > output_placeholders
Symbolic placeholder representation of outputs.
Definition: operation.h:261
DataType output_dtype(size_t i) const final
Get the data type of the i-th output tensor.
Managed reference to ExternOpNode.
Definition: operation.h:288
ExternOp(std::string name, std::string tag, ffi::Map< ffi::String, ffi::Any > attrs, ffi::Array< Tensor > inputs, ffi::Array< Buffer > input_placeholders, ffi::Array< Buffer > output_placeholders, Stmt body)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExternOp, Operation, ExternOpNode)
Base class of all operation nodes.
Definition: operation.h:57
TVM_FFI_DECLARE_OBJECT_INFO("te.Operation", OperationNode, ffi::Object)
virtual ~OperationNode()
Definition: operation.h:66
virtual ffi::Array< PrimExpr > output_shape(size_t i) const =0
Get shape of i-th output tensor.
static void RegisterReflection()
Definition: operation.h:87
virtual ffi::Array< Tensor > InputTensors() const =0
List all the input Tensors.
virtual DataType output_dtype(size_t i) const =0
Get the data type of the i-th output tensor.
virtual int num_outputs() const =0
ffi::Map< ffi::String, ffi::Any > attrs
additional attributes of the operation
Definition: operation.h:64
std::string name
optional name of the operation
Definition: operation.h:60
std::string tag
optional tag of the operation
Definition: operation.h:62
Operation that produces tensors.
Definition: tensor.h:48
const OperationNode * operator->() const
access the internal node container
Definition: operation.h:405
A placeholder op represents an input placeholder.
Definition: operation.h:100
TVM_FFI_DECLARE_OBJECT_INFO("te.PlaceholderOp", PlaceholderOpNode, OperationNode)
DataType dtype
The data type of the input.
Definition: operation.h:105
static void RegisterReflection()
Definition: operation.h:112
ffi::Array< PrimExpr > shape
The shape of the input.
Definition: operation.h:103
int num_outputs() const final
DataType output_dtype(size_t i) const final
Get the data type of the i-th output tensor.
ffi::Array< Tensor > InputTensors() const final
List all the input Tensors.
ffi::Array< PrimExpr > output_shape(size_t i) const final
Get shape of i-th output tensor.
Managed reference to PlaceholderOpNode.
Definition: operation.h:125
PlaceholderOp(std::string name, ffi::Array< PrimExpr > shape, DataType dtype)
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PlaceholderOp, Operation, PlaceholderOpNode)
Symbolic scan.
Definition: operation.h:191
ffi::Array< Tensor > state_placeholder
The placeholders referred to as states in update.
Definition: operation.h:200
ScanOpNode()
constructor
Definition: operation.h:217
int num_outputs() const final
ffi::Array< Tensor > init
the initialization tensors
Definition: operation.h:196
DataType output_dtype(size_t i) const final
Get the data type of the i-th output tensor.
ffi::Array< Tensor > update
The update function, represented by tensors.
Definition: operation.h:198
ffi::Array< PrimExpr > output_shape(size_t i) const final
Get shape of i-th output tensor.
ffi::Array< Tensor > inputs
The inputs to the scan. These are optionally provided, but they can be helpful to provide hints to spe...
Definition: operation.h:205
ffi::Array< Tensor > InputTensors() const final
List all the input Tensors.
static void RegisterReflection()
Definition: operation.h:224
IterVar scan_axis
IterVar to scan over.
Definition: operation.h:194
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("te.ScanOp", ScanOpNode, OperationNode)
ffi::Array< IterVar > spatial_axis_
Spatial axes indicating the spatial dimension of each output. They correspond to flattened spatial axis...
Definition: operation.h:215
Managed reference to ScanOpNode.
Definition: operation.h:241
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ScanOp, Operation, ScanOpNode)
ScanOp(std::string name, std::string tag, ffi::Optional< ffi::Map< ffi::String, ffi::Any >> attrs, IterVar axis, ffi::Array< Tensor > init, ffi::Array< Tensor > update, ffi::Array< Tensor > state_placeholder, ffi::Array< Tensor > input)
Tensor structure representing a possible input, or intermediate computation result.
Definition: tensor.h:100
Iteration Variable, represents an iteration over an integer interval.
Definition: var.h:297
Container of all statements.
Definition: stmt.h:67
a named variable in TIR
Definition: var.h:77
Copy-on-write helper macro for IR ffi::ObjectRef types.
std::function< PrimExpr(const ffi::Array< Var > &i)> FCompute
The compute function to specify the input source of a Tensor.
Definition: operation.h:321
ffi::Array< Tensor > scan(ffi::Array< Tensor > init, ffi::Array< Tensor > update, ffi::Array< Tensor > state_placeholder, ffi::Array< Tensor > inputs=ffi::Array< Tensor >(), std::string name="scan", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct new tensors by scan.
std::function< ffi::Array< PrimExpr >(const ffi::Array< Var > &i)> FBatchCompute
The compute function to specify the input sources of Tensors.
Definition: operation.h:324
IterVar thread_axis(Range dom, std::string tag)
Create a new IterVar that represents an axis in thread.
IterVar reduce_axis(Range dom, std::string name="rv")
Create a new IterVar for reduction operations.
Tensor compute(ffi::Array< PrimExpr > shape, FCompute fcompute, std::string name="tensor", std::string tag="", ffi::Map< ffi::String, ffi::Any > attrs={})
Construct a new tensor by computing over shape, using the computation rule: result_tensor[axis] = fco...
Var var(std::string name_hint, DataType t=DataType::Int(32))
Construct a new Var expression.
Tensor placeholder(ffi::Array< PrimExpr > shape, DataType dtype=DataType::Float(32), std::string name="placeholder")
Create a placeholder tensor.
Tensor shape(const Tensor &src, DataType dtype, const std::string name="T_shape", const std::string tag=kInjective)
Get the shape of input tensor.
Definition: transform.h:1981
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
Temporary data structure to store union of bounds of each axis of Tensor.
Definition: operation.h:47
TensorDom(int ndim)
Definition: operation.h:49
std::vector< std::vector< IntSet > > data
The domain data.
Definition: operation.h:51
Dataflow tensor object.
TIR expressions.
Common operators defined for Expr.