tvm
attrs.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
44 #ifndef TVM_IR_ATTRS_H_
45 #define TVM_IR_ATTRS_H_
46 
47 #include <dmlc/common.h>
48 #include <tvm/ir/expr.h>
52 
53 #include <functional>
54 #include <string>
55 #include <type_traits>
56 #include <unordered_map>
57 #include <utility>
58 #include <vector>
59 
60 namespace tvm {
/*!
 * \brief Declare the boilerplate of an attribute node class.
 *
 * Expands to the `_type_key` constant, the final-object info with
 * ::tvm::BaseAttrsNode as parent, and the declarator of the templated
 * field visitor `_tvm_VisitAttrs`. The use site supplies the function body,
 * usually a sequence of TVM_ATTR_FIELD(...) statements.
 *
 * \param ClassName The name of the attribute class being declared.
 * \param TypeKey The type key string stored in `_type_key`.
 */
#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                    \
  static constexpr const char* _type_key = TypeKey;              \
  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
  template <typename FVisit>                                     \
  void _tvm_VisitAttrs(FVisit& _tvm_fvisit)  // NOLINT(*)

/*!
 * \brief Visit a single attribute field inside a TVM_DECLARE_ATTRS body.
 *
 * Expands to `_tvm_fvisit("FieldName", &FieldName)`. The visitor's return
 * value is an entry object, so calls such as describe() / set_default() /
 * set_lower_bound() / set_upper_bound() can be chained after it.
 *
 * \param FieldName The member variable to visit.
 */
#define TVM_ATTR_FIELD(FieldName) _tvm_fvisit(#FieldName, &FieldName)
77 
83 template <typename TObjectRef>
84 inline TObjectRef NullValue() {
85  static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
86  return TObjectRef(ObjectPtr<Object>(nullptr));
87 }
88 
89 template <>
91  return DataType(DataType::kHandle, 0, 0);
92 }
93 
95 struct AttrError : public Error {
100  explicit AttrError(std::string msg) : Error("AttributeError:" + msg) {}
101 };
102 
106 class AttrFieldInfoNode : public Object {
107  public:
114 
116  v->Visit("name", &name);
117  v->Visit("type_info", &type_info);
118  v->Visit("description", &description);
119  }
120 
121  static constexpr const char* _type_key = "AttrFieldInfo";
122  static constexpr bool _type_has_method_sequal_reduce = false;
123  static constexpr bool _type_has_method_shash_reduce = false;
125 };
126 
128 class AttrFieldInfo : public ObjectRef {
129  public:
131 };
132 
139 class BaseAttrsNode : public Object {
140  public:
144  virtual ~BaseAttrsNode() {}
145  // visit function
146  virtual void VisitAttrs(AttrVisitor* v) {}
152  template <typename... Args>
153  inline void InitBySeq(Args&&... args);
158  inline void PrintDocString(std::ostream& os) const; // NOLINT(*)
165  TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
170  TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
178  TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
179 
180  static constexpr const bool _type_has_method_sequal_reduce = true;
181  static constexpr const bool _type_has_method_shash_reduce = true;
182  static constexpr const char* _type_key = "Attrs";
184 };
185 
190 class Attrs : public ObjectRef {
191  public:
193 };
194 
201 class DictAttrsNode : public BaseAttrsNode {
202  public:
205 
206  bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
207  return equal(dict, other->dict);
208  }
209 
210  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); }
211 
212  // implementations
213  void VisitAttrs(AttrVisitor* v) final;
215  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
217 
218  // type info
219  static constexpr const char* _type_key = "DictAttrs";
221 };
222 
227 class DictAttrs : public Attrs {
228  public:
233  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict = {});
234 
235  // Utils for accessing attributes
236  // This needs to be on DictAttrs, not DictAttrsNode because we return the default
237  // value if DictAttrsNode is not defined.
257  template <typename TObjectRef>
259  const std::string& attr_key,
260  Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
261  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
262  "Can only call GetAttr with ObjectRef types.");
263  if (!defined()) return default_value;
264  const DictAttrsNode* node = this->as<DictAttrsNode>();
265 
266  auto it = node->dict.find(attr_key);
267  if (it != node->dict.end()) {
268  // For backwards compatibility, return through TVMRetValue.
269  // This triggers any automatic conversions registered with
270  // PackedFuncValueConverter. Importantly, this allows use of
271  // `GetAttr<Integer>` and `GetAttr<Bool>` for properties that
272  // are stored internally as `runtime::Box<int64_t>` and
273  // `runtime::Box<bool>`.
275  ret = (*it).second;
277  return obj;
278  } else {
279  return default_value;
280  }
281  }
282  // variant that uses TObjectRef to enable implicit conversion to default value.
283  template <typename TObjectRef>
284  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
285  return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
286  }
306  bool HasNonzeroAttr(const std::string& attr_key) const {
307  return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
308  }
309 
312 };
313 
319 template <typename TAttrs>
320 inline TAttrs AttrsWithDefaultValues() {
321  static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes");
322  auto n = make_object<typename TAttrs::ContainerType>();
323  n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false);
324  return TAttrs(n);
325 }
326 
338 
351 
352 inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) {
353  return WithAttr(std::move(attrs), String(key), std::move(value));
354 }
355 
365 DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key);
366 
394 template <typename TFunc>
395 inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
396  using TNode = typename TFunc::ContainerType;
397  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
398  TNode* node = input.CopyOnWrite();
399  node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value);
400 
401  return input;
402 }
403 
414 template <typename TFunc>
415 inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
416  using TNode = typename TFunc::ContainerType;
417  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
418  TNode* node = input.CopyOnWrite();
419 
420  node->attrs = WithAttrs(std::move(node->attrs), attrs);
421 
422  return input;
423 }
424 
/*!
 * \brief Copy a function-like object, removing the attribute \p attr_key.
 *
 * \param input The object to update (obtained mutably via CopyOnWrite()).
 * \param attr_key The attribute key to drop.
 * \tparam TFunc A reference type whose ContainerType is a final (leaf) node with an `attrs` member.
 * \return The updated object.
 */
template <typename TFunc>
inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
  using TNode = typename TFunc::ContainerType;
  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");

  TNode* const mutable_node = input.CopyOnWrite();
  auto remaining = WithoutAttr(std::move(mutable_node->attrs), attr_key);
  mutable_node->attrs = std::move(remaining);
  return input;
}
461 
462 // Namespace containing detail implementations
463 namespace detail {
465 
466 // helper entry that does nothing in set_default/bound/describe calls.
467 struct AttrNopEntry {
469 
470  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
471  template <typename T>
472  TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
473  return *this;
474  }
475  template <typename T>
476  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
477  return *this;
478  }
479  template <typename T>
480  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
481  return *this;
482  }
483 };
484 
485 // Wrapper for normal visitor.
487  public:
488  explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
489  template <typename T>
490  AttrNopEntry operator()(const char* key, T* value) {
491  visitor_->Visit(key, value);
492  return AttrNopEntry();
493  }
494 
495  private:
496  AttrVisitor* visitor_;
497 };
498 
500  public:
501  bool result_{true};
502  // constructor
503  AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
504  : lhs_(lhs), rhs_(rhs), equal_(equal) {}
505  template <typename T>
506  AttrNopEntry operator()(const char* key, T* lhs_value) {
507  if (!result_) return AttrNopEntry();
508  const T* rhs_value = reinterpret_cast<const T*>(
509  reinterpret_cast<const char*>(rhs_) +
510  (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_)));
511  if (!equal_(*lhs_value, *rhs_value)) {
512  result_ = false;
513  }
514  return AttrNopEntry();
515  }
516 
517  private:
518  const Object* lhs_;
519  const Object* rhs_;
520  const SEqualReducer& equal_;
521 };
522 
524  public:
525  explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {}
526 
527  template <typename T>
528  AttrNopEntry operator()(const char* key, T* value) {
529  hash_reducer_(*value);
530  return AttrNopEntry();
531  }
532 
533  private:
534  const SHashReducer& hash_reducer_;
535 };
536 
537 // helper entry that does initialization, set default.
538 template <typename T>
540  // The attributes
542  // The type key
543  const char* type_key_;
544  // field name
545  const char* key_;
546  // internal value.
547  T* value_;
548  // whether the value is missing.
549  // NOTE: initialize to false so that the destructor does not throw unless
550  // AttrInitVisitor::operator() is committed to returning an instance of this class.
551  // It is expected not to set this to true until that is true.
552  bool value_missing_{false};
553 
554  AttrInitEntry() = default;
555 
557  type_key_ = other.type_key_;
558  key_ = other.key_;
559  value_ = other.value_;
560  value_missing_ = other.value_missing_;
561  // avoid unexpected throw
562  other.value_missing_ = false;
563  }
564 
565  // If the value is still missing in destruction time throw an error.
566  ~AttrInitEntry() DMLC_THROW_EXCEPTION {
567  if (value_missing_) {
568  std::ostringstream os;
569  os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. "
570  << "If the key is defined check that its type matches the declared type.";
571  throw AttrError(os.str());
572  }
573  }
574  // override fields.
575  // This function sets the lower bound of the attribute
576  TSelf& set_lower_bound(const T& begin) {
577  if (this->value_missing_) return *this;
578  const T& val = *value_;
579  if (begin > val) {
580  std::ostringstream os;
581  os << type_key_ << "." << key_ << ": "
582  << "value " << val << " is smaller than the lower bound " << begin;
583  throw AttrError(os.str());
584  }
585  return *this;
586  }
587  // This function sets the upper bound of the attribute
588  TSelf& set_upper_bound(const T& end) {
589  if (this->value_missing_) return *this;
590  const T& val = *value_;
591  if (val > end) {
592  std::ostringstream os;
593  os << type_key_ << "." << key_ << ": "
594  << "value " << val << " is bigger than the upper bound " << end;
595  throw AttrError(os.str());
596  }
597  return *this;
598  }
599  // set default when
600  TSelf& set_default(const T& value) {
601  if (!value_missing_) return *this;
602  *value_ = value;
603  value_missing_ = false;
604  return *this;
605  }
606  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
607 };
608 
609 // Template function to allow smart conversion
610 // from Expr types into the constants.
611 template <typename T>
612 inline void SetValue(T* ptr, const TVMArgValue& val) {
613  *ptr = val.operator T();
614 }
615 
616 template <typename T>
617 inline void SetIntValue(T* ptr, const TVMArgValue& val) {
618  if (val.type_code() == kDLInt) {
619  *ptr = static_cast<T>(val.value().v_int64);
620  } else {
621  IntImm expr = val;
622  *ptr = static_cast<T>(expr->value);
623  }
624 }
625 
626 // Workaround for GCC8.1 / GCC8.2
627 template <>
628 inline void SetValue<DataType>(DataType* ptr, const TVMArgValue& val) {
629  *ptr = val.operator DataType();
630 }
631 
632 template <>
633 inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
634  if (String::CanConvertFrom(val)) {
635  *ptr = val.operator std::string();
636  } else {
637  LOG(FATAL) << "Expect str";
638  }
639 }
640 
641 template <>
642 inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
643  if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
644  *ptr = val.operator double();
645  } else {
646  ObjectRef expr = val;
647  ICHECK(expr.defined());
648  if (const IntImmNode* op = expr.as<IntImmNode>()) {
649  *ptr = static_cast<double>(op->value);
650  } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
651  *ptr = static_cast<double>(op->value);
652  } else {
653  LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
654  }
655  }
656 }
657 template <>
658 inline void SetValue<int>(int* ptr, const TVMArgValue& val) {
659  SetIntValue(ptr, val);
660 }
661 template <>
662 inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) {
663  SetIntValue(ptr, val);
664 }
665 template <>
666 inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) {
667  SetIntValue(ptr, val);
668 }
669 template <>
670 inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) {
671  SetIntValue(ptr, val);
672 }
673 
674 // Visitor for value initialization
675 template <typename FFind>
677  public:
678  // Counter of number of matched attributes during visit.
679  // This is used to decide if there is additional unmatched attributes.
680  size_t hit_count_{0};
681  // constructor
682  AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
683 
684  template <typename T>
685  AttrInitEntry<T> operator()(const char* key, T* value) {
686  TVMArgValue val;
687  AttrInitEntry<T> opt;
688  opt.type_key_ = type_key_;
689  opt.key_ = key;
690  opt.value_ = value;
691  if (ffind_(key, &val)) {
692  SetValue(value, val);
693  opt.value_missing_ = false;
694  ++hit_count_;
695  } else {
696  opt.value_missing_ = true;
697  }
698 #if defined(__GNUC__)
699 #pragma GCC diagnostic ignored "-Wpragmas"
700 #pragma GCC diagnostic ignored "-Wpessimizing-move"
701 #endif
702  return std::move(opt);
703  }
704 
705  private:
706  // the type key
707  const char* type_key_;
708  FFind ffind_;
709 };
710 
711 template <typename FFind>
712 inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) {
713  return AttrInitVisitor<FFind>(type_key, ffind);
714 }
715 
/*!
 * \brief Helper struct to get the type name known to tvm.
 *
 * The primary template reads `T::ContainerType::_type_key`, so it only applies
 * to object-reference types. Plain value types are handled by specializations.
 * \tparam T The field type.
 */
template <typename T>
struct TypeName {
  static constexpr const char* value = T::ContainerType::_type_key;
};

template <>
struct TypeName<int> {
  static constexpr const char* value = "int";
};

template <>
struct TypeName<int64_t> {
  static constexpr const char* value = "int64";
};

// Reported as "uint64" (not the C spelling "uint64_t") to follow the same
// naming convention as "int" / "int64" above.
template <>
struct TypeName<uint64_t> {
  static constexpr const char* value = "uint64";
};
739 
740 template <>
742  static constexpr const char* value = "DataType";
743 };
744 
745 template <>
746 struct TypeName<std::string> {
747  static constexpr const char* value = "str";
748 };
749 
750 template <>
751 struct TypeName<bool> {
752  static constexpr const char* value = "bool";
753 };
754 
755 template <>
756 struct TypeName<void*> {
757  static constexpr const char* value = "handle";
758 };
759 
760 template <>
761 struct TypeName<double> {
762  static constexpr const char* value = "double";
763 };
764 
766  public:
768 
769  explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) {}
770  TSelf& describe(const char* str) {
771  info_->description = str;
772  return *this;
773  }
774  template <typename T>
775  TSelf& set_default(const T& value) {
776  std::ostringstream os;
777  os << info_->type_info << ", default=" << value;
778  info_->type_info = os.str();
779  return *this;
780  }
781  template <typename T>
782  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
783  return *this;
784  }
785  template <typename T>
786  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
787  return *this;
788  }
789 
790  private:
792 };
793 
795  public:
796  template <typename T>
797  AttrDocEntry operator()(const char* key, T* v) {
798  ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>();
799  info->name = key;
800  info->type_info = TypeName<T>::value;
801  fields_.push_back(AttrFieldInfo(info));
802  return AttrDocEntry(info);
803  }
804 
806 };
807 
809  public:
810  std::string key_;
811  bool exist_{false};
812 
813  template <typename T>
814  AttrNopEntry operator()(const char* key, T* v) {
815  if (exist_) return AttrNopEntry();
816  if (key == key_) exist_ = true;
817  return AttrNopEntry();
818  }
819 };
820 
821 template <typename T>
824  // constructor
825  AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data)
826  : visitor_(visitor), key_(key), data_(data) {}
827 
828  ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
829  if (trigger_) {
830  visitor_->Visit(key_, data_);
831  }
832  }
833  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
834  TSelf& set_default(const T& value) {
835  if (tvm::StructuralEqual()(value, *data_)) {
836  trigger_ = false;
837  }
838  return *this;
839  }
840  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; }
841  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; }
842 
843  private:
844  AttrVisitor* visitor_;
845  const char* key_;
846  T* data_;
847  bool trigger_{true};
848 };
849 
851  public:
852  explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
853  template <typename T>
854  AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) {
855  return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
856  }
857 
858  private:
859  AttrVisitor* visitor_;
860 };
861 } // namespace detail
862 
869 template <typename DerivedType>
870 class AttrsNode : public BaseAttrsNode {
871  public:
874  self()->_tvm_VisitAttrs(vis);
875  }
876 
879  self()->_tvm_VisitAttrs(vis);
880  }
881 
882  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
883  ICHECK_EQ(args.size() % 2, 0);
884  const int kLinearSearchBound = 16;
885  int hit_count = 0;
886  // applies two strategies to lookup
887  if (args.size() < kLinearSearchBound) {
888  // linear search.
889  auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
890  for (int i = 0; i < args.size(); i += 2) {
891  ICHECK_EQ(args.type_codes[i], kTVMStr);
892  if (!std::strcmp(key, args.values[i].v_str)) {
893  *val = args[i + 1];
894  return true;
895  }
896  }
897  return false;
898  };
899  auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
900  self()->_tvm_VisitAttrs(vis);
901  hit_count = vis.hit_count_;
902  } else {
903  // construct a map then do lookup.
904  std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
905  for (int i = 0; i < args.size(); i += 2) {
906  ICHECK_EQ(args.type_codes[i], kTVMStr);
907  kwargs[args[i].operator std::string()] = args[i + 1];
908  }
909  auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) {
910  auto it = kwargs.find(key);
911  if (it != kwargs.end()) {
912  *val = it->second;
913  return true;
914  }
915  return false;
916  };
917  auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
918  self()->_tvm_VisitAttrs(vis);
919  hit_count = vis.hit_count_;
920  }
921  // error handling, slow path
922  if (hit_count * 2 != args.size() && !allow_unknown) {
923  for (int i = 0; i < args.size(); i += 2) {
925  visitor.key_ = args[i].operator std::string();
926  self()->_tvm_VisitAttrs(visitor);
927  if (!visitor.exist_) {
928  std::ostringstream os;
929  os << DerivedType::_type_key << ": does not have field \'" << visitor.key_
930  << "\', Possible fields:\n";
931  os << "----------------\n";
932  this->PrintDocString(os);
933  throw AttrError(os.str());
934  }
935  }
936  }
937  }
938 
939  bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
940  DerivedType* pself = self();
941  ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
942  self()->_tvm_VisitAttrs(visitor);
943  return visitor.result_;
944  }
945 
946  void SHashReduce(SHashReducer hash_reducer) const {
947  ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
948  self()->_tvm_VisitAttrs(visitor);
949  }
950 
953  self()->_tvm_VisitAttrs(visitor);
954  return visitor.fields_;
955  }
956 
957  private:
958  DerivedType* self() const {
959  return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
960  }
961 };
962 
963 template <typename... Args>
964 inline void BaseAttrsNode::InitBySeq(Args&&... args) {
966  [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); });
967  pf(std::forward<Args>(args)...);
968 }
969 
970 inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*)
971  Array<AttrFieldInfo> entry = this->ListFieldInfo();
972  for (AttrFieldInfo info : entry) {
973  os << info->name << " : " << info->type_info << '\n';
974  if (info->description.length() != 0) {
975  os << " " << info->description << '\n';
976  }
977  }
978 }
979 
980 } // namespace tvm
981 #endif // TVM_IR_ATTRS_H_
@ kTVMStr
Definition: c_runtime_api.h:186
Information about attribute fields in string representations.
Definition: attrs.h:106
static constexpr bool _type_has_method_shash_reduce
Definition: attrs.h:123
String description
detailed description of the type
Definition: attrs.h:113
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object)
String type_info
type docstring information in str.
Definition: attrs.h:111
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:115
String name
name of the field
Definition: attrs.h:109
static constexpr bool _type_has_method_sequal_reduce
Definition: attrs.h:122
static constexpr const char * _type_key
Definition: attrs.h:121
AttrFieldInfo.
Definition: attrs.h:128
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode)
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
The base class of all attribute node classes; uses the "curiously recurring template pattern" (CRTP).
Definition: attrs.h:870
bool SEqualReduce(const DerivedType *other, SEqualReducer equal) const
Definition: attrs.h:939
void SHashReduce(SHashReducer hash_reducer) const
Definition: attrs.h:946
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:872
void VisitNonDefaultAttrs(AttrVisitor *v)
Visit attributes that do not equal the default value.
Definition: attrs.h:877
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
Definition: attrs.h:951
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
Definition: attrs.h:882
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode)
Base class of all attribute class.
Definition: attrs.h:139
virtual ~BaseAttrsNode()
virtual destructor
Definition: attrs.h:144
virtual void InitByPackedArgs(const TVMArgs &kwargs, bool allow_unknown=false)=0
Initialize the attributes by arguments.
virtual Array< AttrFieldInfo > ListFieldInfo() const =0
Get the field information.
void PrintDocString(std::ostream &os) const
Print a readable docstring to the ostream, ending each line with a newline.
Definition: attrs.h:970
static constexpr const char * _type_key
Definition: attrs.h:182
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object)
void InitBySeq(Args &&... args)
Initialize the attributes by sequence of arguments.
Definition: attrs.h:964
virtual void VisitNonDefaultAttrs(AttrVisitor *v)=0
Visit attributes that do not equal the default value.
static constexpr const bool _type_has_method_sequal_reduce
Definition: attrs.h:180
static constexpr const bool _type_has_method_shash_reduce
Definition: attrs.h:181
virtual void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:146
Specialized attribute type that is backed by a map. The DictAttrsNode implements the Attrs behavior,...
Definition: attrs.h:201
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode)
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
void VisitAttrs(AttrVisitor *v) final
static constexpr const char * _type_key
Definition: attrs.h:219
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
void VisitNonDefaultAttrs(AttrVisitor *v) final
Visit attributes that do not equal the default value.
Map< String, ObjectRef > dict
internal attrs map
Definition: attrs.h:204
void SHashReduce(SHashReducer hash_reduce) const
Definition: attrs.h:210
bool SEqualReduce(const DictAttrsNode *other, SEqualReducer equal) const
Definition: attrs.h:206
Managed reference to DictAttrsNode.
Definition: attrs.h:227
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attr.
Definition: attrs.h:306
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:258
DictAttrs(Map< String, ObjectRef > dict={})
Construct an Attrs backed by DictAttrsNode.
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode)
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: attrs.h:284
TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(DictAttrs, Attrs, DictAttrsNode)
Constant floating point literals in the program.
Definition: expr.h:548
Constant integer literals in the program.
Definition: expr.h:501
Managed reference class to IntImmNode.
Definition: expr.h:530
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:114
Definition: attrs.h:765
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin)
Definition: attrs.h:782
TSelf & set_default(const T &value)
Definition: attrs.h:775
AttrDocEntry(ObjectPtr< AttrFieldInfoNode > info)
Definition: attrs.h:769
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end)
Definition: attrs.h:786
TSelf & describe(const char *str)
Definition: attrs.h:770
Definition: attrs.h:794
AttrDocEntry operator()(const char *key, T *v)
Definition: attrs.h:797
Array< AttrFieldInfo > fields_
Definition: attrs.h:805
Definition: attrs.h:808
AttrNopEntry operator()(const char *key, T *v)
Definition: attrs.h:814
std::string key_
Definition: attrs.h:810
bool exist_
Definition: attrs.h:811
Definition: attrs.h:676
AttrInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:682
AttrInitEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:685
Definition: attrs.h:850
AttrNonDefaultVisitor(AttrVisitor *visitor)
Definition: attrs.h:852
AttrTriggerNonDefaultEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:854
Definition: attrs.h:486
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:490
AttrNormalVisitor(AttrVisitor *visitor)
Definition: attrs.h:488
Definition: attrs.h:499
AttrNopEntry operator()(const char *key, T *lhs_value)
Definition: attrs.h:506
AttrsSEqualVisitor(const Object *lhs, const Object *rhs, const SEqualReducer &equal)
Definition: attrs.h:503
bool result_
Definition: attrs.h:501
Definition: attrs.h:523
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:528
AttrsSHashVisitor(const SHashReducer &hash_reducer)
Definition: attrs.h:525
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
@ kHandle
Definition: data_type.h:57
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Reference to string objects.
Definition: string.h:98
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:2683
A single argument value to PackedFunc, containing both type_code and TVMValue.
Definition: packed_func.h:796
const TVMValue & value() const
Definition: packed_func.h:838
Arguments into TVM functions.
Definition: packed_func.h:394
int type_code() const
Definition: packed_func.h:656
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:946
Base expr nodes in TVM.
void SetValue< int >(int *ptr, const TVMArgValue &val)
Definition: attrs.h:658
void SetValue< double >(double *ptr, const TVMArgValue &val)
Definition: attrs.h:642
void SetValue< DataType >(DataType *ptr, const TVMArgValue &val)
Definition: attrs.h:628
AttrInitVisitor< FFind > CreateInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:712
void SetValue< uint64_t >(uint64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:666
void SetValue< int64_t >(int64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:662
void SetValue< bool >(bool *ptr, const TVMArgValue &val)
Definition: attrs.h:670
void SetValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:612
void SetIntValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:617
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value)
Copy the DictAttrs, but overrides a single attribute.
DictAttrs WithoutAttr(DictAttrs attrs, const std::string &key)
Copy the DictAttrs, but without a specific attribute.
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
DataType NullValue< DataType >()
Definition: attrs.h:90
TAttrs AttrsWithDefaultValues()
Create an Attr object with all default values.
Definition: attrs.h:320
runtime::DataType DataType
Definition: data_type.h:493
DictAttrs WithAttrs(DictAttrs attrs, Map< String, ObjectRef > new_attrs)
Copy the DictAttrs, but overrides attributes with the entries from attrs.
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:84
Type-erased function used across TVM API.
Error thrown during attribute checking.
Definition: attrs.h:95
AttrError(std::string msg)
constructor
Definition: attrs.h:100
Definition: attrs.h:539
TSelf & set_lower_bound(const T &begin)
Definition: attrs.h:576
const char * type_key_
Definition: attrs.h:543
const char * key_
Definition: attrs.h:545
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:606
TSelf & set_upper_bound(const T &end)
Definition: attrs.h:588
~AttrInitEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:566
bool value_missing_
Definition: attrs.h:552
TSelf & set_default(const T &value)
Definition: attrs.h:600
T * value_
Definition: attrs.h:547
AttrInitEntry(AttrInitEntry &&other)
Definition: attrs.h:556
Definition: attrs.h:467
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:476
TSelf & set_default(DMLC_ATTRIBUTE_UNUSED const T &value)
Definition: attrs.h:472
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:470
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:480
AttrTriggerNonDefaultEntry(AttrVisitor *visitor, const char *key, T *data)
Definition: attrs.h:825
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:840
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:841
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:833
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:828
TSelf & set_default(const T &value)
Definition: attrs.h:834
Helper struct to get the type name known to tvm.
Definition: attrs.h:721
Structural equality comparison.
int64_t v_int64
Definition: c_runtime_api.h:211