tvm
type_relation.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_IR_TYPE_RELATION_H_
25 #define TVM_IR_TYPE_RELATION_H_
26 
27 #include <tvm/ir/attrs.h>
28 #include <tvm/ir/diagnostic.h>
29 #include <tvm/ir/env_func.h>
30 #include <tvm/ir/module.h>
31 #include <tvm/ir/type.h>
32 #include <tvm/runtime/logging.h>
33 
34 namespace tvm {
35 
40 class TypeCallNode : public TypeNode {
41  public:
48 
50  v->Visit("func", &func);
51  v->Visit("args", &args);
52  v->Visit("span", &span);
53  }
54 
55  bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const {
56  return equal(func, other->func) && equal(args, other->args);
57  }
58 
59  void SHashReduce(SHashReducer hash_reduce) const {
60  hash_reduce(func);
61  hash_reduce(args);
62  }
63 
64  static constexpr const char* _type_key = "TypeCall";
66 };
67 
72 class TypeCall : public Type {
73  public:
79  TVM_DLL TypeCall(Type func, Array<Type> args);
80 
82 };
83 
88 class TypeReporterNode : public Object {
89  public:
91  virtual ~TypeReporterNode() {}
99  TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
100 
108  TVM_DLL virtual bool Assert(const PrimExpr& cond) = 0;
116  TVM_DLL virtual bool AssertEQ(const PrimExpr& lhs, const PrimExpr& rhs) = 0;
117 
122  TVM_DLL virtual void SetSpan(const Span& span) = 0;
123 
124  TVM_DLL virtual Span GetSpan() = 0;
125 
126  TVM_DLL virtual DiagnosticContext GetDiagCtx() = 0;
127 
132  TVM_DLL virtual IRModule GetModule() = 0;
133 
134  // solver is not serializable.
136 
137  static constexpr const char* _type_key = "TypeReporter";
139 };
140 
145 class TypeReporter : public ObjectRef {
146  public:
150  return const_cast<TypeReporterNode*>(static_cast<const TypeReporterNode*>(get()));
151  }
153 };
154 
174 using TypeRelationFn = TypedEnvFunc<bool(const Array<Type>& args, int num_inputs,
175  const Attrs& attrs, const TypeReporter& reporter)>;
176 
186  public:
199 
201  v->Visit("func", &func);
202  v->Visit("args", &args);
203  v->Visit("num_inputs", &num_inputs);
204  v->Visit("attrs", &attrs);
205  v->Visit("span", &span);
206  }
207 
208  bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const {
209  return equal(func, other->func) && equal(args, other->args) &&
210  equal(num_inputs, other->num_inputs) && equal(attrs, other->attrs);
211  }
212 
213  void SHashReduce(SHashReducer hash_reduce) const {
214  hash_reduce(func);
215  hash_reduce(args);
216  hash_reduce(num_inputs);
217  hash_reduce(attrs);
218  }
219 
220  static constexpr const char* _type_key = "TypeRelation";
222 };
223 
228 class TypeRelation : public TypeConstraint {
229  public:
238  TVM_DLL TypeRelation(TypeRelationFn func, Array<Type> args, int num_inputs, Attrs attrs);
239 
241 };
242 } // namespace tvm
243 #endif // TVM_IR_TYPE_RELATION_H_
Visitor class to get the attributes of an AST/IR node. It is called once for each field of the node.
Definition: reflection.h:52
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
Definition: diagnostic.h:217
Managed reference class to IRModuleNode.
Definition: module.h:366
Reference to PrimExprNode.
Definition: expr.h:115
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Definition: source_map.h:120
Type function application.
Definition: type_relation.h:40
Array< Type > args
The arguments.
Definition: type_relation.h:47
static constexpr const char * _type_key
Definition: type_relation.h:64
void VisitAttrs(AttrVisitor *v)
Definition: type_relation.h:49
bool SEqualReduce(const TypeCallNode *other, SEqualReducer equal) const
Definition: type_relation.h:55
Type func
The type-level function (ADT that takes type params).
Definition: type_relation.h:45
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode)
void SHashReduce(SHashReducer hash_reduce) const
Definition: type_relation.h:59
Managed reference to TypeCallNode.
Definition: type_relation.h:72
TypeCall(Type func, Array< Type > args)
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode)
Potential Constraints in a function.
Definition: type.h:412
Managed reference to TypeConstraintNode.
Definition: type.h:423
Type is the base type of all types.
Definition: type.h:74
Span span
Span that points to the original source code. Reserved debug information.
Definition: type.h:80
User defined type relation, it is an input-output relation on types.
Definition: type_relation.h:185
TypeRelationFn func
The function over the input and output variables; it is not directly serializable,...
Definition: type_relation.h:192
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode)
Attrs attrs
Attributes passed to the relation function.
Definition: type_relation.h:198
Array< Type > args
The type arguments to the type function.
Definition: type_relation.h:194
int num_inputs
Number of input arguments.
Definition: type_relation.h:196
bool SEqualReduce(const TypeRelationNode *other, SEqualReducer equal) const
Definition: type_relation.h:208
static constexpr const char * _type_key
Definition: type_relation.h:220
void SHashReduce(SHashReducer hash_reduce) const
Definition: type_relation.h:213
void VisitAttrs(AttrVisitor *v)
Definition: type_relation.h:200
Managed reference to TypeRelationNode.
Definition: type_relation.h:228
TypeRelation(TypeRelationFn func, Array< Type > args, int num_inputs, Attrs attrs)
Constructor.
TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode)
Reporter that reports type resolution information back to the type solver.
Definition: type_relation.h:88
virtual Span GetSpan()=0
void VisitAttrs(AttrVisitor *v)
Definition: type_relation.h:135
virtual ~TypeReporterNode()
Virtual destructor.
Definition: type_relation.h:91
virtual DiagnosticContext GetDiagCtx()=0
virtual bool Assert(const PrimExpr &cond)=0
Assert a shape expression comparison.
static constexpr const char * _type_key
Definition: type_relation.h:137
virtual void Assign(const Type &dst, const Type &src)=0
Create a type equality constraint.
virtual void SetSpan(const Span &span)=0
Set the location at which to report unification errors.
virtual bool AssertEQ(const PrimExpr &lhs, const PrimExpr &rhs)=0
Assert that two shape expressions are equal.
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object)
virtual IRModule GetModule()=0
Retrieve the current global module.
Managed reference (container class) to TypeReporterNode.
Definition: type_relation.h:145
TypeReporterNode * operator->() const
Definition: type_relation.h:149
TypeReporter()
Definition: type_relation.h:147
TypeReporter(ObjectPtr< Object > n)
Definition: type_relation.h:148
Managed reference to TypeNode.
Definition: type.h:93
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object references.
Definition: object.h:519
const Object * get() const
Definition: object.h:554
Base class of all object containers.
Definition: object.h:171
A new diagnostic interface for TVM error reporting.
Serializable global function used in IR.
Helpers for attribute objects.
IRModule that holds the functions and type definitions.
IR/AST nodes for the unified type system in TVM.
tvm::TypeReporterNode TypeReporterNode
Definition: type.h:71
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal