tvm
base.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef TVM_SCRIPT_IR_BUILDER_BASE_H_
20 #define TVM_SCRIPT_IR_BUILDER_BASE_H_
21 
22 #include <tvm/ffi/reflection/registry.h>
23 #include <tvm/ir/cast.h>
24 #include <tvm/ir/expr.h>
25 #include <tvm/ir/function.h>
26 
27 #include <vector>
28 
29 namespace tvm {
30 namespace script {
31 namespace ir_builder {
32 
34 
65 class IRBuilderFrameNode : public ffi::Object {
66  public:
68  std::vector<ffi::TypedFunction<void()>> callbacks;
69 
70  static void RegisterReflection() {
71  namespace refl = tvm::ffi::reflection;
72  refl::ObjectDef<IRBuilderFrameNode>();
73  // `callbacks` is not registered as it's not visited.
74  }
75 
76  static constexpr const bool _type_mutable = true;
77  TVM_FFI_DECLARE_OBJECT_INFO("script.ir_builder.IRBuilderFrame", IRBuilderFrameNode, ffi::Object);
78 
79  public:
81  virtual ~IRBuilderFrameNode() = default;
86  virtual void EnterWithScope();
91  virtual void ExitWithScope();
96  void AddCallback(ffi::TypedFunction<void()> callback);
97 };
98 
103 class IRBuilderFrame : public ffi::ObjectRef {
104  public:
106 
107  protected:
109  IRBuilderFrame() = default;
110  explicit IRBuilderFrame(ffi::ObjectPtr<IRBuilderFrameNode> data) : ffi::ObjectRef(data) {}
111 
112  public:
117  inline void EnterWithScope() {
118  TVM_FFI_ICHECK(data_ != nullptr);
119  static_cast<IRBuilderFrameNode*>(data_.get())->EnterWithScope();
120  }
125  inline void ExitWithScope() {
126  TVM_FFI_ICHECK(data_ != nullptr);
127  static_cast<IRBuilderFrameNode*>(data_.get())->ExitWithScope();
128  data_.reset();
129  }
130 };
131 
133 
158 class IRBuilderNode : public ffi::Object {
159  public:
161  ffi::Array<IRBuilderFrame> frames;
163  ffi::Optional<ffi::ObjectRef> result;
164 
165  static void RegisterReflection() {
166  namespace refl = tvm::ffi::reflection;
167  refl::ObjectDef<IRBuilderNode>()
168  .def_ro("frames", &IRBuilderNode::frames)
169  .def_ro("result", &IRBuilderNode::result);
170  }
171 
172  static constexpr const bool _type_mutable = true;
173  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.IRBuilder", IRBuilderNode, ffi::Object);
174 
175  public:
181  template <typename TFrame>
182  inline ffi::Optional<TFrame> FindFrame() const;
189  template <typename TFrame>
190  inline ffi::Optional<TFrame> GetLastFrame() const;
196  template <typename TObjectRef>
197  inline TObjectRef Get() const;
198 };
199 
204 class IRBuilder : public ffi::ObjectRef {
205  public:
209 
210  public:
243  static IRBuilder Current();
245  static bool IsInScope();
252  template <class TObjectRef>
253  inline static TObjectRef Name(ffi::String name, TObjectRef obj);
254 };
255 
257 
258 namespace details {
259 
260 class Namer {
261  public:
262  using FType = NodeFunctor<void(const ffi::ObjectRef&, ffi::String)>;
263  static FType& vtable();
264  static void Name(ffi::ObjectRef node, ffi::String name);
265 };
266 
267 } // namespace details
268 
269 template <class TObjectRef>
270 inline TObjectRef IRBuilder::Name(ffi::String name, TObjectRef obj) {
271  details::Namer::Name(obj, name);
272  return Downcast<TObjectRef>(obj);
273 }
274 
275 template <typename TFrame>
276 inline ffi::Optional<TFrame> IRBuilderNode::FindFrame() const {
277  using TFrameNode = typename TFrame::ContainerType;
278  for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
279  if (const TFrameNode* p = (*it).template as<TFrameNode>()) {
280  return ffi::GetRef<TFrame>(p);
281  }
282  }
283  return std::nullopt;
284 }
285 
286 template <typename TFrame>
287 inline ffi::Optional<TFrame> IRBuilderNode::GetLastFrame() const {
288  using TFrameNode = typename TFrame::ContainerType;
289  if (!frames.empty() && frames.back()->IsInstance<TFrameNode>()) {
290  return Downcast<TFrame>(frames.back());
291  }
292  return std::nullopt;
293 }
294 
295 template <typename TObjectRef>
296 inline TObjectRef IRBuilderNode::Get() const {
297  using TObject = typename TObjectRef::ContainerType;
298  TVM_FFI_CHECK(result.defined(), IndexError) << "No result exists in IRBuilder yet";
299  const auto* n = result.as<TObject>();
300  TVM_FFI_CHECK(n != nullptr, TypeError)
301  << "IRBuilder result is not of type: " << TObject::_type_key;
302  return ffi::GetRef<TObjectRef>(n);
303 }
304 
305 } // namespace ir_builder
306 } // namespace script
307 } // namespace tvm
308 
309 #endif // TVM_SCRIPT_IR_BUILDER_BASE_H_
Value casting helpers.
A dynamically dispatched functor on the type of the first argument.
Definition: node_functor.h:62
TVM_FFI_DECLARE_OBJECT_INFO("script.ir_builder.IRBuilderFrame", IRBuilderFrameNode, ffi::Object)
static constexpr const bool _type_mutable
Definition: base.h:76
std::vector< ffi::TypedFunction< void()> > callbacks
A list of callbacks used when exiting the frame.
Definition: base.h:68
virtual ~IRBuilderFrameNode()=default
Default destructor.
virtual void ExitWithScope()
The method called when exiting RAII scope.
virtual void EnterWithScope()
The method called when entering RAII scope.
void AddCallback(ffi::TypedFunction< void()> callback)
Add a callback method invoked when exiting the RAII scope.
static void RegisterReflection()
Definition: base.h:70
Managed reference to an IRBuilderFrameNode.
Definition: base.h:103
void EnterWithScope()
Redirected to IRBuilderFrameNode::EnterWithScope.
Definition: base.h:117
IRBuilderFrame(ffi::ObjectPtr< IRBuilderFrameNode > data)
Definition: base.h:110
IRBuilderFrame()=default
Disallow direct construction of this object.
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IRBuilderFrame, ffi::ObjectRef, IRBuilderFrameNode)
void ExitWithScope()
Redirected to IRBuilderFrameNode::ExitWithScope.
Definition: base.h:125
A dialect-agnostic IRBuilder that constructs any IR of TVM. An idiomatic use of this class is to put ...
Definition: base.h:158
ffi::Optional< TFrame > FindFrame() const
Find a frame of the given type in the stack this->frames from top to bottom.
Definition: base.h:276
ffi::Optional< TFrame > GetLastFrame() const
Get the frame on top of the stack this->frames if its type is TFrame.
Definition: base.h:287
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.IRBuilder", IRBuilderNode, ffi::Object)
TObjectRef Get() const
Get the IR being constructed.
Definition: base.h:296
ffi::Optional< ffi::ObjectRef > result
The outcome of IR construction.
Definition: base.h:163
static constexpr const bool _type_mutable
Definition: base.h:172
ffi::Array< IRBuilderFrame > frames
A stack of context frames in the IRBuilder.
Definition: base.h:161
static void RegisterReflection()
Definition: base.h:165
Managed reference to an IRBuilderNode.
Definition: base.h:204
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(IRBuilder, ffi::ObjectRef, IRBuilderNode)
static TObjectRef Name(ffi::String name, TObjectRef obj)
Give a string name to the object.
Definition: base.h:270
void ExitWithScope()
Exit the RAII scope.
static IRBuilder Current()
Get the current IRBuilder in the current thread-local scope.
void EnterWithScope()
Puts the current IRBuilder into a thread-local scope, which can be retrieved using IRBuilder::Current().
static bool IsInScope()
See if the current thread-local scope has an IRBuilder.
IRBuilder()
Creates an IRBuilder.
static void Name(ffi::ObjectRef node, ffi::String name)
Base expr nodes in TVM.
Function nodes.
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37