tvm
traced_object.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_
27 #define TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_
28 
29 #include <tvm/node/object_path.h>
30 #include <tvm/node/reflection.h>
31 #include <tvm/runtime/object.h>
32 
33 #include <string>
34 #include <utility>
35 
36 namespace tvm {
37 
38 template <typename RefT>
40 template <typename K, typename V>
41 class TracedMap;
42 template <typename T>
44 template <typename T>
46 template <typename T>
48 
49 namespace detail {
50 
51 template <typename T, bool IsObject = std::is_base_of<ObjectRef, T>::value>
53 
54 template <typename T>
55 struct TracedObjectWrapperSelector<T, false> {
57 };
58 
59 template <typename T>
60 struct TracedObjectWrapperSelector<T, true> {
62 };
63 
64 template <typename K, typename V>
65 struct TracedObjectWrapperSelector<Map<K, V>, true> {
67 };
68 
69 template <typename T>
72 };
73 
74 template <typename T>
77 };
78 
79 } // namespace detail
80 
84 template <typename RefT>
85 class TracedObject {
86  using ObjectType = typename RefT::ContainerType;
87 
88  public:
89  using ObjectRefType = RefT;
90 
91  // Don't use this directly. For convenience, call MakeTraced() instead.
92  explicit TracedObject(const RefT& object_ref, ObjectPath path)
93  : ref_(object_ref), path_(std::move(path)) {}
94 
95  // Implicit conversion from a derived reference class
96  template <typename DerivedRef>
98  : ref_(derived.Get()), path_(derived.GetPath()) {}
99 
103  template <typename T, typename BaseType>
104  typename detail::TracedObjectWrapperSelector<T>::Type GetAttr(T BaseType::*member_ptr) const {
105  using WrapperType = typename detail::TracedObjectWrapperSelector<T>::Type;
106  const ObjectType* node = static_cast<const ObjectType*>(ref_.get());
107  const T& attr = node->*member_ptr;
108  Optional<String> attr_key = ICHECK_NOTNULL(GetAttrKeyByAddress(node, &attr));
109  return WrapperType(attr, path_->Attr(attr_key));
110  }
111 
115  const RefT& Get() const { return ref_; }
116 
120  template <typename RefU>
121  bool IsInstance() const {
122  return ref_->template IsInstance<typename RefU::ContainerType>();
123  }
124 
128  bool defined() const { return ref_.defined(); }
129 
135  template <typename RefU>
137  return TracedObject<RefU>(tvm::runtime::Downcast<RefU>(ref_), path_);
138  }
139 
145  template <typename RefU>
147  if (ref_->template IsInstance<typename RefU::ContainerType>()) {
148  return Downcast<RefU>();
149  } else {
150  return TracedOptional<RefU>(NullOpt, path_);
151  }
152  }
153 
157  const ObjectPath& GetPath() const { return path_; }
158 
159  private:
160  RefT ref_;
161  ObjectPath path_;
162 };
163 
167 template <typename K, typename V>
169  public:
171  using MapIter = typename Map<K, V>::iterator;
172 
173  using iterator_category = std::bidirectional_iterator_tag;
174  using difference_type = ptrdiff_t;
175  using value_type = const std::pair<K, WrappedV>;
176  using pointer = value_type*;
178 
179  explicit TracedMapIterator(MapIter iter, ObjectPath map_path)
180  : iter_(iter), map_path_(std::move(map_path)) {}
181 
182  bool operator==(const TracedMapIterator& other) const { return iter_ == other.iter_; }
183 
184  bool operator!=(const TracedMapIterator& other) const { return iter_ != other.iter_; }
185 
186  pointer operator->() const = delete;
187 
189  auto kv = *iter_;
190  return std::make_pair(kv.first, WrappedV(kv.second, map_path_->MapValue(kv.first)));
191  }
192 
194  ++iter_;
195  return *this;
196  }
197 
199  TracedMapIterator copy = *this;
200  ++(*this);
201  return copy;
202  }
203 
204  private:
205  MapIter iter_;
206  ObjectPath map_path_;
207 };
208 
212 template <typename K, typename V>
213 class TracedMap {
214  public:
216 
218 
219  // Don't use this directly. For convenience, call MakeTraced() instead.
220  explicit TracedMap(Map<K, V> map, ObjectPath path)
221  : map_(std::move(map)), path_(std::move(path)) {}
222 
226  WrappedV at(const K& key) const {
227  auto it = map_.find(key);
228  ICHECK(it != map_.end()) << "No such key in Map";
229  auto kv = *it;
230  return WrappedV(kv.second, path_->MapValue(kv.first));
231  }
232 
236  const Map<K, V>& Get() const { return map_; }
237 
241  const ObjectPath& GetPath() const { return path_; }
242 
246  iterator begin() const { return iterator(map_.begin(), path_); }
247 
251  iterator end() const { return iterator(map_.end(), path_); }
252 
256  bool empty() const { return map_.empty(); }
257 
258  private:
259  Map<K, V> map_;
260  ObjectPath path_;
261 };
262 
266 template <typename T>
268  public:
270 
271  using difference_type = ptrdiff_t;
273  using pointer = WrappedT*;
274  using reference = WrappedT&;
275  using iterator_category = std::random_access_iterator_tag;
276 
277  explicit TracedArrayIterator(Array<T> array, size_t index, ObjectPath array_path)
278  : array_(array), index_(index), array_path_(array_path) {}
279 
281  ++index_;
282  return *this;
283  }
285  --index_;
286  return *this;
287  }
289  TracedArrayIterator copy = *this;
290  ++index_;
291  return copy;
292  }
294  TracedArrayIterator copy = *this;
295  --index_;
296  return copy;
297  }
298 
300  return TracedArrayIterator(array_, index_ + offset, array_path_);
301  }
302 
304  return TracedArrayIterator(array_, index_ - offset, array_path_);
305  }
306 
307  difference_type operator-(const TracedArrayIterator& rhs) const { return index_ - rhs.index_; }
308 
309  bool operator==(TracedArrayIterator other) const {
310  return array_.get() == other.array_.get() && index_ == other.index_;
311  }
312  bool operator!=(TracedArrayIterator other) const { return !(*this == other); }
313  value_type operator*() const { return WrappedT(array_[index_], array_path_->ArrayIndex(index_)); }
314 
315  private:
316  Array<T> array_;
317  size_t index_;
318  ObjectPath array_path_;
319 };
320 
324 template <typename T>
325 class TracedArray {
326  public:
328 
330 
331  // Don't use this directly. For convenience, call MakeTraced() instead.
332  explicit TracedArray(Array<T> array, ObjectPath path)
333  : array_(std::move(array)), path_(std::move(path)) {}
334 
338  const Array<T>& Get() const { return array_; }
339 
343  const ObjectPath& GetPath() const { return path_; }
344 
348  WrappedT operator[](size_t index) const {
349  return WrappedT(array_[index], path_->ArrayIndex(index));
350  }
351 
357  iterator begin() const { return iterator(array_, 0, path_); }
358 
364  iterator end() const { return iterator(array_, array_.size(), path_); }
365 
369  bool empty() const { return array_.empty(); }
370 
374  size_t size() const { return array_.size(); }
375 
376  private:
377  Array<T> array_;
378  ObjectPath path_;
379 };
380 
384 template <typename T>
385 class TracedOptional {
386  public:
388 
392  TracedOptional(const WrappedT& value) // NOLINT(runtime/explicit)
393  : optional_(value.Get().defined() ? value.Get() : Optional<T>(NullOpt)),
394  path_(value.GetPath()) {}
395 
396  // Don't use this directly. For convenience, call MakeTraced() instead.
397  explicit TracedOptional(Optional<T> optional, ObjectPath path)
398  : optional_(std::move(optional)), path_(std::move(path)) {}
399 
403  const Optional<T>& Get() const { return optional_; }
404 
408  const ObjectPath& GetPath() const { return path_; }
409 
413  bool defined() const { return optional_.defined(); }
414 
418  WrappedT value() const { return WrappedT(optional_.value(), path_); }
419 
423  explicit operator bool() const { return optional_.defined(); }
424 
425  private:
426  Optional<T> optional_;
427  ObjectPath path_;
428 };
429 
433 template <typename T>
434 class TracedBasicValue {
435  public:
436  explicit TracedBasicValue(const T& value, ObjectPath path)
437  : value_(value), path_(std::move(path)) {}
438 
442  const T& Get() const { return value_; }
443 
447  const ObjectPath& GetPath() const { return path_; }
448 
452  template <typename F>
454  ApplyFunc(F&& f) const {
455  return MakeTraced(f(value_), path_);
456  }
457 
458  private:
459  T value_;
460  ObjectPath path_;
461 };
462 
466 template <typename RefT>
468  using WrappedT = typename detail::TracedObjectWrapperSelector<RefT>::Type;
469  return WrappedT(object, ObjectPath::Root());
470 }
471 
475 template <typename RefT>
477  ObjectPath path) {
478  using WrappedT = typename detail::TracedObjectWrapperSelector<RefT>::Type;
479  return WrappedT(object, std::move(path));
480 }
481 
482 } // namespace tvm
483 
484 #endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_
TracedOptional< RefU > TryDowncast() const
Convert the wrapped reference type to a subtype.
Definition: traced_object.h:146
typename detail::TracedObjectWrapperSelector< T >::Type WrappedT
Definition: traced_object.h:387
Traced wrapper for regular (non-container) TVM objects.
Definition: traced_object.h:39
bool empty() const
Returns true iff the wrapped array is empty.
Definition: traced_object.h:369
bool IsInstance() const
Check if the reference to the wrapped object can be converted to RefU.
Definition: traced_object.h:121
TracedMapIterator & operator++()
Definition: traced_object.h:193
const T & Get() const
Access the wrapped value.
Definition: traced_object.h:442
bool defined() const
Returns true iff the object is present.
Definition: traced_object.h:413
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
static ObjectPath Root()
Create a path that represents the root object itself.
TracedArrayIterator operator+(difference_type offset) const
Definition: traced_object.h:299
bool defined() const
Same as Get().defined().
Definition: traced_object.h:128
bool operator!=(const TracedMapIterator &other) const
Definition: traced_object.h:184
value_type reference
Definition: traced_object.h:177
Definition: loop_state.h:456
TracedArrayIterator & operator--()
Definition: traced_object.h:284
WrappedV at(const K &key) const
Get a value by its key, wrapped in a traced wrapper.
Definition: traced_object.h:226
typename detail::TracedObjectWrapperSelector< T >::Type WrappedT
Definition: traced_object.h:327
const RefT & Get() const
Access the wrapped object.
Definition: traced_object.h:115
TracedArrayIterator(Array< T > array, size_t index, ObjectPath array_path)
Definition: traced_object.h:277
iterator end() const
Get an iterator to the end of the map.
Definition: traced_object.h:251
bool empty() const
Returns true iff the wrapped map is empty.
Definition: traced_object.h:256
WrappedT value_type
Definition: traced_object.h:272
ptrdiff_t difference_type
Definition: traced_object.h:271
const std::pair< K, WrappedV > value_type
Definition: traced_object.h:175
const ObjectPath & GetPath() const
Get the path of the wrapped object.
Definition: traced_object.h:241
Iterator class for TracedMap<K, V>
Definition: traced_object.h:168
typename detail::TracedObjectWrapperSelector< V >::Type WrappedV
Definition: traced_object.h:170
const ObjectPath & GetPath() const
Get the path of the wrapped array object.
Definition: traced_object.h:343
const Map< K, V > & Get() const
Access the wrapped map object.
Definition: traced_object.h:236
ptrdiff_t difference_type
Definition: traced_object.h:174
bool operator==(const TracedMapIterator &other) const
Definition: traced_object.h:182
WrappedT * pointer
Definition: traced_object.h:273
WrappedT value() const
Returns a non-optional traced wrapper, throws if defined() is false.
Definition: traced_object.h:418
Traced wrapper for Map objects.
Definition: traced_object.h:41
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
detail::TracedObjectWrapperSelector< T >::Type GetAttr(T BaseType::*member_ptr) const
Get a traced wrapper for an attribute of the wrapped object.
Definition: traced_object.h:104
TracedObject(const RefT &object_ref, ObjectPath path)
Definition: traced_object.h:92
iterator end() const
Get an iterator to the end of the array.
Definition: traced_object.h:364
std::bidirectional_iterator_tag iterator_category
Definition: traced_object.h:173
WrappedT & reference
Definition: traced_object.h:274
TracedArrayIterator & operator++()
Definition: traced_object.h:280
const Object * get() const
Definition: object.h:546
TracedObject< RefU > Downcast() const
Convert the wrapped reference type to a subtype.
Definition: traced_object.h:136
const ObjectPath & GetPath() const
Get the path of the wrapped object.
Definition: traced_object.h:157
TracedArrayIterator operator--(int)
Definition: traced_object.h:293
detail::TracedObjectWrapperSelector< typename std::invoke_result< F, const T & >::type >::Type ApplyFunc(F &&f) const
Transform the wrapped value without changing its path.
Definition: traced_object.h:454
difference_type operator-(const TracedArrayIterator &rhs) const
Definition: traced_object.h:307
detail::TracedObjectWrapperSelector< RefT >::Type MakeTraced(const RefT &object)
Wrap the given root object in an appropriate traced wrapper class.
Definition: traced_object.h:467
Traced wrapper for Array objects.
Definition: traced_object.h:43
Definition: traced_object.h:52
typename detail::TracedObjectWrapperSelector< T >::Type WrappedT
Definition: traced_object.h:269
typename Map< K, V >::iterator MapIter
Definition: traced_object.h:171
std::random_access_iterator_tag iterator_category
Definition: traced_object.h:275
TracedArrayIterator operator-(difference_type offset) const
Definition: traced_object.h:303
RefT ObjectRefType
Definition: traced_object.h:89
value_type operator*() const
Definition: traced_object.h:313
typename detail::TracedObjectWrapperSelector< V >::Type WrappedV
Definition: traced_object.h:215
A managed object in the TVM runtime.
WrappedT operator[](size_t index) const
Get an element by index, wrapped in a traced wrapper.
Definition: traced_object.h:348
Iterator class for TracedArray<T>
Definition: traced_object.h:267
TracedBasicValue(const T &value, ObjectPath path)
Definition: traced_object.h:436
Definition: object_path.h:122
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
TracedMapIterator(MapIter iter, ObjectPath map_path)
Definition: traced_object.h:179
iterator begin() const
Get an iterator to the first array element.
Definition: traced_object.h:357
const ObjectPath & GetPath() const
Get the path of the wrapped value.
Definition: traced_object.h:447
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Optional< String > GetAttrKeyByAddress(const Object *object, const void *attr_address)
Given an object and an address of its attribute, return the key of the attribute. ...
Managed reference to TypeNode.
Definition: type.h:93
const Optional< T > & Get() const
Access the wrapped optional object.
Definition: traced_object.h:403
TracedMapIterator operator++(int)
Definition: traced_object.h:198
Traced wrapper for basic values (i.e. non-TVM objects)
Definition: traced_object.h:47
TracedMap(Map< K, V > map, ObjectPath path)
Definition: traced_object.h:220
iterator begin() const
Get an iterator to the first item of the map.
Definition: traced_object.h:246
const ObjectPath & GetPath() const
Get the path of the wrapped optional object.
Definition: traced_object.h:408
TracedOptional(Optional< T > optional, ObjectPath path)
Definition: traced_object.h:397
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
Reflection and serialization of compiler IR/AST nodes.
TracedArrayIterator operator++(int)
Definition: traced_object.h:288
bool operator!=(TracedArrayIterator other) const
Definition: traced_object.h:312
Traced wrapper for Optional objects.
Definition: traced_object.h:45
TracedArray(Array< T > array, ObjectPath path)
Definition: traced_object.h:332
bool operator==(TracedArrayIterator other) const
Definition: traced_object.h:309
const Array< T > & Get() const
Access the wrapped array object.
Definition: traced_object.h:338
value_type * pointer
Definition: traced_object.h:176
size_t size() const
Get the size of the wrapped array.
Definition: traced_object.h:374
TracedObject(const TracedObject< DerivedRef > &derived)
Definition: traced_object.h:97
reference operator*() const
Definition: traced_object.h:188
TracedOptional(const WrappedT &value)
Implicit conversion from the corresponding non-optional traced wrapper.
Definition: traced_object.h:392