tvm
struct_info_functor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RELAX_STRUCT_INFO_FUNCTOR_H_
25 #define TVM_RELAX_STRUCT_INFO_FUNCTOR_H_
26 
#include <tvm/node/functor.h>
#include <tvm/relax/distributed/struct_info.h>
#include <tvm/relax/struct_info.h>

#include <utility>
32 
33 namespace tvm {
34 namespace relax {
35 
36 template <typename FStructInfo>
38 
39 // functions to be overriden.
40 #define STRUCT_INFO_FUNCTOR_DEFAULT \
41  { return VisitStructInfoDefault_(op, std::forward<Args>(args)...); }
42 
43 #define TVM_STRUCT_INFO_FUNCTOR_DISPATCH(OP) \
44  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
45  return self->VisitStructInfo_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
46  });
47 
48 template <typename R, typename... Args>
49 class StructInfoFunctor<R(const StructInfo& n, Args...)> {
50  private:
51  using TSelf = StructInfoFunctor<R(const StructInfo& n, Args...)>;
52  using FStructInfo = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
53 
54  public:
56  using result_type = R;
58  virtual ~StructInfoFunctor() {}
65  R operator()(const StructInfo& n, Args... args) {
66  return VisitStructInfo(n, std::forward<Args>(args)...);
67  }
74  virtual R VisitStructInfo(const StructInfo& n, Args... args) {
75  ICHECK(n.defined());
76  static FStructInfo vtable = InitVTable();
77  return vtable(n, this, std::forward<Args>(args)...);
78  }
79  // Functions that can be overriden by subclass
81  Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
82  virtual R VisitStructInfo_(const PrimStructInfoNode* op,
83  Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
85  Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
87  Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
89  Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
91  Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
92  virtual R VisitStructInfo_(const FuncStructInfoNode* op,
93  Args... args) STRUCT_INFO_FUNCTOR_DEFAULT;
94  virtual R VisitStructInfoDefault_(const Object* op, Args...) {
95  LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
96  throw; // unreachable, written to stop compiler warning
97  }
98 
99  private:
100  // initialize the vtable.
101  static FStructInfo InitVTable() {
102  FStructInfo vtable;
103  // Set dispatch
111  return vtable;
112  }
113 };
114 
115 #undef TVM_STRUCT_INFO_FUNCTOR_DISPATCH
116 
120 class TVM_DLL StructInfoVisitor : public StructInfoFunctor<void(const StructInfo& n)> {
121  public:
122  void VisitStructInfo_(const ObjectStructInfoNode* op) override;
123  void VisitStructInfo_(const PrimStructInfoNode* op) override;
124  void VisitStructInfo_(const ShapeStructInfoNode* op) override;
125  void VisitStructInfo_(const TensorStructInfoNode* op) override;
127  void VisitStructInfo_(const TupleStructInfoNode* op) override;
128  void VisitStructInfo_(const FuncStructInfoNode* op) override;
129 
130  protected:
131  // two functions to override when visit expr fields in struct info.
132  virtual void VisitStructInfoExprField(const Expr& expr) {}
133  virtual void VisitStructInfoExprField(const PrimExpr& expr) {}
134 };
135 
139 class TVM_DLL StructInfoMutator : public StructInfoFunctor<StructInfo(const StructInfo& n)> {
140  public:
148 
149  protected:
150  // two functions to override when visit expr fields in struct info.
151  virtual Expr VisitStructInfoExprField(const Expr& expr) { return expr; }
152  virtual PrimExpr VisitStructInfoExprField(const PrimExpr& expr) { return expr; }
153 };
154 
155 } // namespace relax
156 } // namespace tvm
157 #endif // TVM_RELAX_STRUCT_INFO_FUNCTOR_H_
A dynamically dispatched functor on the type of the first argument.
Definition: functor.h:64
Reference to PrimExprNode.
Definition: expr.h:115
Managed reference to RelayExprNode.
Definition: expr.h:442
Structure information about function.
Definition: struct_info.h:303
Opaque object.
Definition: struct_info.h:35
Primitive value.
Definition: struct_info.h:61
StructInfo of shape value.
Definition: struct_info.h:106
virtual R VisitStructInfo_(const FuncStructInfoNode *op, Args... args)
Definition: struct_info_functor.h:92
virtual R VisitStructInfo_(const TensorStructInfoNode *op, Args... args)
Definition: struct_info_functor.h:86
virtual R VisitStructInfo_(const PrimStructInfoNode *op, Args... args)
Definition: struct_info_functor.h:82
virtual R VisitStructInfoDefault_(const Object *op, Args...)
Definition: struct_info_functor.h:94
virtual R VisitStructInfo_(const ShapeStructInfoNode *op, Args... args)
Definition: struct_info_functor.h:84
virtual R VisitStructInfo(const StructInfo &n, Args... args)
The functor call.
Definition: struct_info_functor.h:74
virtual R VisitStructInfo_(const TupleStructInfoNode *op, Args... args)
Definition: struct_info_functor.h:90
virtual R VisitStructInfo_(const ObjectStructInfoNode *op, Args... args)
Definition: struct_info_functor.h:80
virtual ~StructInfoFunctor()
virtual destructor
Definition: struct_info_functor.h:58
virtual R VisitStructInfo_(const distributed::DTensorStructInfoNode *op, Args... args)
Definition: struct_info_functor.h:88
R result_type
the result type of this functor
Definition: struct_info_functor.h:56
R operator()(const StructInfo &n, Args... args)
Same as call.
Definition: struct_info_functor.h:65
Definition: struct_info_functor.h:37
StructInfoMutator that mutates struct info.
Definition: struct_info_functor.h:139
StructInfo VisitStructInfo_(const TupleStructInfoNode *op) override
StructInfo VisitStructInfo_(const FuncStructInfoNode *op) override
StructInfo VisitStructInfo_(const ObjectStructInfoNode *op) override
StructInfo VisitStructInfo_(const ShapeStructInfoNode *op) override
StructInfo VisitStructInfo_(const PrimStructInfoNode *op) override
virtual Expr VisitStructInfoExprField(const Expr &expr)
Definition: struct_info_functor.h:151
StructInfo VisitStructInfo_(const distributed::DTensorStructInfoNode *op) override
virtual PrimExpr VisitStructInfoExprField(const PrimExpr &expr)
Definition: struct_info_functor.h:152
StructInfo VisitStructInfo_(const TensorStructInfoNode *op) override
A struct info visitor.
Definition: struct_info_functor.h:120
void VisitStructInfo_(const FuncStructInfoNode *op) override
virtual void VisitStructInfoExprField(const PrimExpr &expr)
Definition: struct_info_functor.h:133
void VisitStructInfo_(const TupleStructInfoNode *op) override
void VisitStructInfo_(const ObjectStructInfoNode *op) override
virtual void VisitStructInfoExprField(const Expr &expr)
Definition: struct_info_functor.h:132
void VisitStructInfo_(const PrimStructInfoNode *op) override
void VisitStructInfo_(const distributed::DTensorStructInfoNode *op) override
void VisitStructInfo_(const ShapeStructInfoNode *op) override
void VisitStructInfo_(const TensorStructInfoNode *op) override
Managed reference to StructInfoNode.
Definition: expr.h:129
StructInfo of Tensor.
Definition: struct_info.h:163
StructInfo of Tuple.
Definition: struct_info.h:253
StructInfo of DTensor (Distributed Tensor).
Definition: struct_info.h:132
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
Struct info for DTensor (Distributed Tensor)
Defines the Functor data structures.
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
#define STRUCT_INFO_FUNCTOR_DEFAULT
Definition: struct_info_functor.h:40
#define TVM_STRUCT_INFO_FUNCTOR_DISPATCH(OP)
Definition: struct_info_functor.h:43