tvm
function.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAY_FUNCTION_H_
25 #define TVM_RELAY_FUNCTION_H_
26 
27 #include <tvm/ir/function.h>
28 #include <tvm/relay/expr.h>
29 
30 #include <string>
31 
32 namespace tvm {
33 namespace relay {
34 
39 class FunctionNode : public BaseFuncNode {  // Relay Function container; relay::Function is its managed reference.
40  public:
60 
62  v->Visit("params", &params);  // reflection visits every field, including span, virtual_device_ and checked_type_
63  v->Visit("body", &body);
64  v->Visit("ret_type", &ret_type);
65  v->Visit("type_params", &type_params);
66  v->Visit("attrs", &attrs);
67  v->Visit("virtual_device_", &virtual_device_);
68  v->Visit("span", &span);
69  v->Visit("_checked_type_", &checked_type_);
70  }
71 
72  bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {  // structural equality with another FunctionNode
73  // Important to make def equal first: params and type_params are compared as definitions (DefEqual) before ret_type, attrs and body, which may refer to them.
74  equal->MarkGraphNode();
75  return equal.DefEqual(params, other->params) &&
76  equal.DefEqual(type_params, other->type_params) && equal(ret_type, other->ret_type) &&
77  equal(attrs, other->attrs) && equal(body, other->body);  // note: span, virtual_device_ and checked_type_ are not compared
78  }
79 
80  void SHashReduce(SHashReducer hash_reduce) const {  // must stay consistent with SEqualReduce: same fields, same roles
81  hash_reduce->MarkGraphNode();
82  hash_reduce.DefHash(params);  // definitions, mirroring DefEqual in SEqualReduce
83  hash_reduce.DefHash(type_params);
84  hash_reduce(ret_type);
85  hash_reduce(attrs);
86  hash_reduce(body);  // span, virtual_device_ and checked_type_ are not hashed, matching SEqualReduce
87  }
88 
95  TVM_DLL FuncType func_type_annotation() const;  // Returns the derived function type annotation of this expression.
96 
97  static constexpr const char* _type_key = "relay.Function";
99 };
100 
105 class Function : public BaseFunc {  // Managed reference to FunctionNode.
106  public:
116  TVM_DLL Function(tvm::Array<Var> params, Expr body, Type ret_type, tvm::Array<TypeVar> ty_params,  // Constructor: parameters, body, user-annotated return type, type parameters,
117  tvm::DictAttrs attrs = DictAttrs(), Span span = Span());  // plus optional attributes (defaults to DictAttrs()) and source span (defaults to Span()).
118 
121 };
122 
129  Optional<Expr> opt_body = Optional<Expr>(),
130  Optional<Type> opt_ret_type = Optional<Type>(),
131  Optional<Array<TypeVar>> opt_ty_params = Optional<Array<TypeVar>>(),
133  Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
134  Optional<Span> opt_span = Optional<Span>());
135 
136 /*!
137  * \brief Returns the Relay FunctionNode represented by base_func if it should be optimized,
138  * otherwise returns nullptr.
139  *
140  * In particular, returns nullptr:
141  * - For PrimFuncs, since they are not Relay Functions.
142  * - For Functions marked for external compilation (with "Compiler").
143  * - For Functions marked as already having an external definition (with "ExternalSymbol").
144  * - For Functions marked as not to be optimized (with "SkipOptimization").
145  *
146  * TODO(mbs): Audit all enumerations of IRModule::functions to use this or some family of such.
147  */
149 
153 namespace attr {  // Attribute keys recognized on Relay Functions.
154 
164 constexpr const char* kPrimitive = "Primitive";  //!< Mark the function as representing a sub-graph which is to be lowered or compiled as a unit.
165 
173 constexpr const char* kExtern = "Extern";  //!< Mark the function as externally implemented, i.e. bound in a runtime::Module held by the IRModule.
174 
182 constexpr const char* kCompiler = "Compiler";  //!< Name of the external codegen 'compiler' that should be used to lower or compile the function.
183 
185 constexpr const char* kClosure = "Closure";  //!< Indicate if the function is a closure.
187 constexpr const char* kParams = "__params__";  //!< Store a Var to parameter/Constant mapping on a Function.
189 constexpr const char* kSkipOptimization = "SkipOptimization";  //!< Mark the function as not to be optimized.
191 constexpr const char* kComposite = "Composite";  //!< Treat the function as a composite operator.
193 constexpr const char* kInline = "Inline";  //!< Mark the function to be inlined.
195 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";  //!< Indicate the function was created by the Pattern Partitioning Pass.
197 constexpr const char* kReshapeOnly = "relay.reshape_only";  //!< Mark the function as only composed of reshape operations.
198 
199 } // namespace attr
200 
201 } // namespace relay
202 } // namespace tvm
203 #endif // TVM_RELAY_FUNCTION_H_
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Span span
Span that points to the original source code. Reserved debug information.
Definition: expr.h:56
Base node of all functions.
Definition: function.h:139
DictAttrs attrs
Additional attributes storing the meta-data.
Definition: function.h:142
Managed reference to BaseFuncNode.
Definition: function.h:230
Managed reference to DictAttrsNode.
Definition: attrs.h:227
Managed reference to FuncTypeNode.
Definition: type.h:481
ObjectRef virtual_device_
The virtual device (VirtualDevice) for this node (the result of device planning). For first-order exp...
Definition: expr.h:418
Type checked_type_
Stores the result of type inference (type checking).
Definition: expr.h:370
Managed reference to RelayExprNode.
Definition: expr.h:442
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
virtual void MarkGraphNode()=0
Mark current comparison as graph node in hashing. Graph node hash will depend on the graph structure...
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
void DefHash(const ObjectRef &key) const
Push hash of key to the current sequence of hash values.
Definition: structural_hash.h:198
Definition: source_map.h:120
Managed reference to TypeNode.
Definition: type.h:93
Relay Function container.
Definition: function.h:39
Type ret_type
User annotated return type of the function.
Definition: function.h:51
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode)
Expr body
The expression which represents the computation of the function, the expression may reference the par...
Definition: function.h:49
tvm::Array< TypeVar > type_params
Type parameters of the function. Enables the function to vary its type based on these....
Definition: function.h:59
bool SEqualReduce(const FunctionNode *other, SEqualReducer equal) const
Definition: function.h:72
tvm::Array< Var > params
Function parameters.
Definition: function.h:42
static constexpr const char * _type_key
Definition: function.h:97
void VisitAttrs(tvm::AttrVisitor *v)
Definition: function.h:61
FuncType func_type_annotation() const
Return the derived function annotation of this expression.
void SHashReduce(SHashReducer hash_reduce) const
Definition: function.h:80
Managed reference to FunctionNode.
Definition: function.h:105
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode)
Function(tvm::Array< Var > params, Expr body, Type ret_type, tvm::Array< TypeVar > ty_params, tvm::DictAttrs attrs=DictAttrs(), Span span=Span())
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode)
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Function nodes.
constexpr const char * kComposite
Treat the function as a composite operator.
Definition: function.h:191
constexpr const char * kCompiler
Indicates the name of the external codegen 'compiler' that should be used to lower or compile the fun...
Definition: function.h:182
constexpr const char * kParams
Store a Var to parameter/Constant mapping on a Function.
Definition: function.h:187
constexpr const char * kExtern
Mark the function as externally implemented, i.e. bound in a runtime::Module within the IRModule's "ext...
Definition: function.h:173
constexpr const char * kReshapeOnly
Mark the function as only composed of reshape operations.
Definition: function.h:197
constexpr const char * kPrimitive
Mark the function as representing a sub-graph which is to be lowered or compiled as a unit....
Definition: function.h:164
constexpr const char * kInline
Mark the function to be inlined.
Definition: function.h:193
constexpr const char * kSkipOptimization
Mark that the function should not be optimized.
Definition: function.h:189
constexpr const char * kClosure
Indicate if the function is a closure.
Definition: function.h:185
constexpr const char * kPartitionedFromPattern
Indicate the function was created by the Pattern Partitioning Pass.
Definition: function.h:195
const FunctionNode * AsOptimizableFunctionNode(const BaseFunc &base_func)
Clause WithFields(Clause clause, Optional< Pattern > opt_lhs=Optional< Pattern >(), Optional< Expr > opt_rhs=Optional< Expr >())
Returns clause with the given properties. A null property denotes 'no change'. Returns clause if all ...
tvm::Span Span
Definition: base.h:65
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Relay expression language.