tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
attrs.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
44 #ifndef TVM_IR_ATTRS_H_
45 #define TVM_IR_ATTRS_H_
46 
47 #include <dmlc/common.h>
48 #include <tvm/ir/expr.h>
52 
53 #include <functional>
54 #include <string>
55 #include <type_traits>
56 #include <unordered_map>
57 #include <utility>
58 #include <vector>
59 
60 namespace tvm {
// TVM_DECLARE_ATTRS(ClassName, TypeKey) expands to three things:
//   - `_type_key = TypeKey`;
//   - TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode);
//   - the signature of the templated `_tvm_VisitAttrs(FVisit& _tvm_fvisit)`.
// The macro stops at that signature, so the user writes the body right after
// it and lists each field there with TVM_ATTR_FIELD.
66 #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
67  static constexpr const char* _type_key = TypeKey; \
68  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
69  template <typename FVisit> \
70  void _tvm_VisitAttrs(FVisit& _tvm_fvisit) // NOLINT(*)
71 
// TVM_ATTR_FIELD(FieldName) is used inside a _tvm_VisitAttrs body. It calls
// the visitor parameter with the stringified field name and the field's
// address. Each visitor in this file returns an entry object whose
// describe()/set_default()/set_lower_bound()/set_upper_bound() return
// TSelf&, so those calls can be chained after the macro.
76 #define TVM_ATTR_FIELD(FieldName) _tvm_fvisit(#FieldName, &FieldName)
77 
83 template <typename TObjectRef>
84 inline TObjectRef NullValue() {
85  static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
86  return TObjectRef(ObjectPtr<Object>(nullptr));
87 }
88 
// Specialization of NullValue for DataType (listed in the index as
// `DataType NullValue< DataType >()`). It returns a DataType built from
// DataType::kHandle and two zero arguments. Those arguments are presumably
// bits=0 and lanes=0; confirm against data_type.h.
// NOTE(review): the signature line (header line 90) is missing from this listing.
89 template <>
91  return DataType(DataType::kHandle, 0, 0);
92 }
93 
95 struct AttrError : public Error {
100  explicit AttrError(std::string msg) : Error("AttributeError:" + msg) {}
101 };
102 
// Reflection node that describes one attribute field. Per the index at the end
// of this page it holds `String name`, `String type_info` and
// `String description`. The doc visitor (whose operator() returns AttrDocEntry,
// header line 760) allocates one of these per visited field.
// NOTE(review): header lines 108-113 (member declarations), 115 (the
// VisitAttrs signature) and 124 (object-info macro) are missing from this listing.
106 class AttrFieldInfoNode : public Object {
107  public:
114 
// Expose each of the three fields to a reflection visitor under its own name.
116  v->Visit("name", &name);
117  v->Visit("type_info", &type_info);
118  v->Visit("description", &description);
119  }
120 
121  static constexpr const char* _type_key = "AttrFieldInfo";
// Both flags are false: by their names, this node provides no
// SEqualReduce/SHashReduce methods.
122  static constexpr bool _type_has_method_sequal_reduce = false;
123  static constexpr bool _type_has_method_shash_reduce = false;
125 };
126 
// Managed reference to AttrFieldInfoNode.
// NOTE(review): header line 130 is missing from this listing. The index lists
// TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode) for it.
128 class AttrFieldInfo : public ObjectRef {
129  public:
131 };
132 
// Abstract base of every attribute node. In this file it is implemented by
// DictAttrsNode and by the CRTP template AttrsNode<DerivedType>, both of which
// override the pure virtuals below with `final`.
// NOTE(review): header line 183 (TVM_DECLARE_BASE_OBJECT_INFO per the index)
// and the doc-comment lines are missing from this listing.
139 class BaseAttrsNode : public Object {
140  public:
144  virtual ~BaseAttrsNode() {}
145  // visit function
// The default implementation visits nothing.
146  virtual void VisitAttrs(AttrVisitor* v) {}
// Defined later in this header. It wraps InitByPackedArgs in a packed
// function so positional arguments can be forwarded to it.
152  template <typename... Args>
153  inline void InitBySeq(Args&&... args);
// Defined later in this header. It writes one "name : type_info" line per
// field, plus an indented description line when one exists.
158  inline void PrintDocString(std::ostream& os) const; // NOLINT(*)
// Visit only the fields whose current value differs from their declared default.
165  TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
// Return one AttrFieldInfo per declared field.
170  TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
// Initialize fields from alternating key/value packed arguments. AttrsNode
// checks that the count is even and that every key is a string. When
// allow_unknown is false, unknown keys cause an AttrError.
178  TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
179 
180  static constexpr const bool _type_has_method_sequal_reduce = true;
181  static constexpr const bool _type_has_method_shash_reduce = true;
182  static constexpr const char* _type_key = "Attrs";
184 };
185 
// Managed reference to BaseAttrsNode.
// NOTE(review): header line 192 is missing from this listing. The index lists
// TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode) for it.
190 class Attrs : public ObjectRef {
191  public:
193 };
194 
// Attribute node backed by a `Map<String, ObjectRef> dict` (member per the
// index; its declaration, header lines 203-204, is missing from this listing,
// as are lines 214, 216 and 220).
201 class DictAttrsNode : public BaseAttrsNode {
202  public:
205 
// Two DictAttrsNodes are structurally equal iff their dicts are.
206  bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
207  return equal(dict, other->dict);
208  }
209 
// The structural hash is the hash of the dict.
210  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); }
211 
212  // implementations
// Out-of-line overrides; their definitions are not in this header listing.
213  void VisitAttrs(AttrVisitor* v) final;
215  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
217 
218  // type info
219  static constexpr const char* _type_key = "DictAttrs";
221 };
222 
// Managed reference to DictAttrsNode, with typed lookup helpers.
// NOTE(review): header lines 229-233, 238-256, 258 (the
// `Optional<TObjectRef> GetAttr(` line of the first overload), 278-296 and
// 301-302 are missing from this listing.
227 class DictAttrs : public Attrs {
228  public:
234 
235  // Utils for accessing attributes
236  // This needs to be on DictAttrs, not DictAttrsNode because we return the default
237  // value if DictAttrsNode is not defined.
// GetAttr: return default_value when this reference is undefined or the key is
// absent. Otherwise Downcast the stored value to Optional<TObjectRef>.
// Presumably the Downcast fails at runtime on a type mismatch; confirm.
257  template <typename TObjectRef>
259  const std::string& attr_key,
260  Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
261  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
262  "Can only call GetAttr with ObjectRef types.");
263  if (!defined()) return default_value;
264  const DictAttrsNode* node = this->as<DictAttrsNode>();
265 
266  auto it = node->dict.find(attr_key);
267  if (it != node->dict.end()) {
268  return Downcast<Optional<TObjectRef>>((*it).second);
269  } else {
270  return default_value;
271  }
272  }
273  // variant that uses TObjectRef to enable implicit conversion to default value.
274  template <typename TObjectRef>
275  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
276  return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
277  }
// True iff attr_key is present as an Integer whose IntValue() is non-zero.
// A missing key is treated as 0.
297  bool HasNonzeroAttr(const std::string& attr_key) const {
298  return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
299  }
300 
303 };
304 
310 template <typename TAttrs>
311 inline TAttrs AttrsWithDefaultValues() {
312  static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes");
313  auto n = make_object<typename TAttrs::ContainerType>();
314  n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false);
315  return TAttrs(n);
316 }
317 
345 template <typename TFunc>
346 inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
347  using TNode = typename TFunc::ContainerType;
348  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
349  TNode* node = input.CopyOnWrite();
350  if (node->attrs.defined()) {
351  node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
352  } else {
353  Map<String, ObjectRef> dict = {{attr_key, attr_value}};
354  node->attrs = DictAttrs(dict);
355  }
356  return input;
357 }
358 
369 template <typename TFunc>
370 inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
371  using TNode = typename TFunc::ContainerType;
372  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
373  TNode* node = input.CopyOnWrite();
374  if (node->attrs.defined()) {
375  for (const auto& pair : attrs) {
376  node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
377  }
378  } else {
379  node->attrs = DictAttrs(std::move(attrs));
380  }
381  return input;
382 }
383 
410 template <typename TFunc>
411 inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
412  using TNode = typename TFunc::ContainerType;
413  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
414 
415  if (input->attrs.defined()) {
416  TNode* node = input.CopyOnWrite();
417  node->attrs.CopyOnWrite()->dict.erase(attr_key);
418  if (node->attrs->dict.size() == 0) {
419  node->attrs = NullValue<DictAttrs>();
420  }
421  }
422  return input;
423 }
424 
425 // Namespace containing detail implementations
426 namespace detail {
428 
// Entry returned by visitors that ignore field metadata. Every builder method
// accepts its argument, does nothing, and returns *this so calls can be chained.
// NOTE(review): header line 431 (presumably the TSelf alias) is missing from this listing.
429 // helper entry that does nothing in set_default/bound/describe calls.
430 struct AttrNopEntry {
432 
433  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
434  template <typename T>
435  TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
436  return *this;
437  }
438  template <typename T>
439  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
440  return *this;
441  }
442  template <typename T>
443  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
444  return *this;
445  }
446 };
447 
// Adapts an AttrVisitor to the _tvm_VisitAttrs calling convention. Every field
// is forwarded to visitor_->Visit(key, value); the returned AttrNopEntry
// swallows any chained metadata calls.
// NOTE(review): the class header (line 449) is missing from this listing.
448 // Wrapper for normal visitor.
450  public:
451  explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
452  template <typename T>
453  AttrNopEntry operator()(const char* key, T* value) {
454  visitor_->Visit(key, value);
455  return AttrNopEntry();
456  }
457 
458  private:
// Non-owning; must outlive this wrapper.
459  AttrVisitor* visitor_;
460 };
461 
// Field-by-field structural equality. For each visited lhs field it finds the
// matching rhs field by reusing the field's byte offset from lhs_. This is only
// valid when lhs_ and rhs_ have the same concrete type; AttrsNode::SEqualReduce
// passes two DerivedType pointers.
// NOTE(review): the class header (line 462) is missing from this listing.
463  public:
// Becomes false at the first unequal field; later fields are then skipped.
464  bool result_{true};
465  // constructor
466  AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
467  : lhs_(lhs), rhs_(rhs), equal_(equal) {}
468  template <typename T>
469  AttrNopEntry operator()(const char* key, T* lhs_value) {
470  if (!result_) return AttrNopEntry();
// rhs field address = rhs_ + (lhs field address - lhs_).
471  const T* rhs_value = reinterpret_cast<const T*>(
472  reinterpret_cast<const char*>(rhs_) +
473  (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_)));
474  if (!equal_(*lhs_value, *rhs_value)) {
475  result_ = false;
476  }
477  return AttrNopEntry();
478  }
479 
480  private:
481  const Object* lhs_;
482  const Object* rhs_;
// Stored by reference: the reducer must outlive this visitor.
483  const SEqualReducer& equal_;
484 };
485 
// Feeds every visited field value, in visit order, into a structural hash reducer.
// NOTE(review): the class header (line 486) is missing from this listing.
487  public:
488  explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {}
489 
490  template <typename T>
491  AttrNopEntry operator()(const char* key, T* value) {
492  hash_reducer_(*value);
493  return AttrNopEntry();
494  }
495 
496  private:
// Stored by reference: the reducer must outlive this visitor.
497  const SHashReducer& hash_reducer_;
498 };
499 
// Per-field entry returned by AttrInitVisitor. It applies defaults and bound
// checks, and its destructor throws AttrError if the field is still missing.
// Because the error is raised in the destructor, a chained expression like
// `TVM_ATTR_FIELD(x).set_default(v)` gets a chance to fill the value first.
// NOTE(review): header lines 502 (struct header), 504 (presumably the TSelf
// alias) and 519 (move-constructor signature; the index lists
// `AttrInitEntry(AttrInitEntry &&other)`) are missing from this listing.
// NOTE(review): DMLC_THROW_EXCEPTION is presumably noexcept(false); a throw
// from this destructor while another exception is unwinding would call
// std::terminate. Confirm.
500 // helper entry that does initialization, set default.
501 template <typename T>
503  // The attributes
505  // The type key
506  const char* type_key_;
507  // field name
508  const char* key_;
509  // internal value.
510  T* value_;
511  // whether the value is missing.
512  // NOTE: initialize to false so that the destructor does not throw unless
513  // AttrInitVisitor::operator() is committed to returning an instance of this class.
514  // It is expected not to set this to true until that is true.
515  bool value_missing_{false};
516 
517  AttrInitEntry() = default;
518 
// Move constructor body: copy every field, then clear the source's missing
// flag so only the destination can throw.
520  type_key_ = other.type_key_;
521  key_ = other.key_;
522  value_ = other.value_;
523  value_missing_ = other.value_missing_;
524  // avoid unexpected throw
525  other.value_missing_ = false;
526  }
527 
528  // If the value is still missing in destruction time throw an error.
529  ~AttrInitEntry() DMLC_THROW_EXCEPTION {
530  if (value_missing_) {
531  std::ostringstream os;
532  os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. "
533  << "If the key is defined check that its type matches the declared type.";
534  throw AttrError(os.str());
535  }
536  }
537  // override fields.
538  // This function sets the lower bound of the attribute
// Skipped while the value is missing, so a bound placed before set_default()
// in a chain does not check the default.
539  TSelf& set_lower_bound(const T& begin) {
540  if (this->value_missing_) return *this;
541  const T& val = *value_;
542  if (begin > val) {
543  std::ostringstream os;
544  os << type_key_ << "." << key_ << ": "
545  << "value " << val << " is smaller than the lower bound " << begin;
546  throw AttrError(os.str());
547  }
548  return *this;
549  }
550  // This function sets the upper bound of the attribute
// Same missing-value skip as set_lower_bound.
551  TSelf& set_upper_bound(const T& end) {
552  if (this->value_missing_) return *this;
553  const T& val = *value_;
554  if (val > end) {
555  std::ostringstream os;
556  os << type_key_ << "." << key_ << ": "
557  << "value " << val << " is bigger than the upper bound " << end;
558  throw AttrError(os.str());
559  }
560  return *this;
561  }
562  // set default when
// Write the default only if no value was supplied, and mark the field present.
563  TSelf& set_default(const T& value) {
564  if (!value_missing_) return *this;
565  *value_ = value;
566  value_missing_ = false;
567  return *this;
568  }
569  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
570 };
571 
572 // Template function to allow smart conversion
573 // from Expr types into the constants.
574 template <typename T>
575 inline void SetValue(T* ptr, const TVMArgValue& val) {
576  *ptr = val.operator T();
577 }
578 
579 template <typename T>
580 inline void SetIntValue(T* ptr, const TVMArgValue& val) {
581  if (val.type_code() == kDLInt) {
582  *ptr = static_cast<T>(val.value().v_int64);
583  } else {
584  IntImm expr = val;
585  *ptr = static_cast<T>(expr->value);
586  }
587 }
588 
589 // Workaround for GCC8.1 / GCC8.2
590 template <>
591 inline void SetValue<DataType>(DataType* ptr, const TVMArgValue& val) {
592  *ptr = val.operator DataType();
593 }
594 
595 template <>
596 inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
597  if (String::CanConvertFrom(val)) {
598  *ptr = val.operator std::string();
599  } else {
600  LOG(FATAL) << "Expect str";
601  }
602 }
603 
604 template <>
605 inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
606  if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
607  *ptr = val.operator double();
608  } else {
609  ObjectRef expr = val;
610  ICHECK(expr.defined());
611  if (const IntImmNode* op = expr.as<IntImmNode>()) {
612  *ptr = static_cast<double>(op->value);
613  } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
614  *ptr = static_cast<double>(op->value);
615  } else {
616  LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
617  }
618  }
619 }
620 template <>
621 inline void SetValue<int>(int* ptr, const TVMArgValue& val) {
622  SetIntValue(ptr, val);
623 }
624 template <>
625 inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) {
626  SetIntValue(ptr, val);
627 }
628 template <>
629 inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) {
630  SetIntValue(ptr, val);
631 }
632 template <>
633 inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) {
634  SetIntValue(ptr, val);
635 }
636 
// Visitor used by AttrsNode::InitByPackedArgs. For each field it asks
// ffind_(key, &val) for a supplied value. If one is found, the field is set
// through SetValue and hit_count_ is incremented. Otherwise the returned
// AttrInitEntry is marked missing, so a chained set_default() can still fill it.
// NOTE(review): the class header (line 639) is missing from this listing.
637 // Visitor for value initialization
638 template <typename FFind>
640  public:
641  // Counter of number of matched attributes during visit.
642  // This is used to decide if there is additional unmatched attributes.
643  size_t hit_count_{0};
644  // constructor
645  AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
646 
647  template <typename T>
648  AttrInitEntry<T> operator()(const char* key, T* value) {
649  TVMArgValue val;
650  AttrInitEntry<T> opt;
651  opt.type_key_ = type_key_;
652  opt.key_ = key;
653  opt.value_ = value;
// SetValue runs while value_missing_ is still false, so if the conversion
// fails, opt's destructor does not add a second error.
654  if (ffind_(key, &val)) {
655  SetValue(value, val);
656  opt.value_missing_ = false;
657  ++hit_count_;
658  } else {
659  opt.value_missing_ = true;
660  }
// NOTE(review): these pragmas are never paired with push/pop, so the
// -Wpessimizing-move suppression stays active for the rest of the translation
// unit. Returning `opt` by name would already move or elide.
661 #if defined(__GNUC__)
662 #pragma GCC diagnostic ignored "-Wpragmas"
663 #pragma GCC diagnostic ignored "-Wpessimizing-move"
664 #endif
665  return std::move(opt);
666  }
667 
668  private:
669  // the type key
670  const char* type_key_;
671  FFind ffind_;
672 };
673 
674 template <typename FFind>
675 inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) {
676  return AttrInitVisitor<FFind>(type_key, ffind);
677 }
678 
/*!
 * \brief Helper struct to get the type name known to tvm.
 *
 * The primary template reports the `_type_key` of T's container (object) type.
 * The specializations give fixed names to non-object field types.
 */
template <typename T>
struct TypeName {
  static constexpr const char* value = T::ContainerType::_type_key;
};

template <>
struct TypeName<int> {
  static constexpr const char* value = "int";
};

template <>
struct TypeName<int64_t> {
  static constexpr const char* value = "int64";
};

// Spelled "uint64" (previously "uint64_t") so the unsigned name matches the
// "int64" spelling used for the signed 64-bit type.
template <>
struct TypeName<uint64_t> {
  static constexpr const char* value = "uint64";
};
702 
// DataType fields are reported as "DataType".
// NOTE(review): the struct header (line 704) is missing from this listing.
703 template <>
705  static constexpr const char* value = "DataType";
706 };
707 
708 template <>
709 struct TypeName<std::string> {
710  static constexpr const char* value = "str";
711 };
712 
713 template <>
714 struct TypeName<bool> {
715  static constexpr const char* value = "bool";
716 };
717 
718 template <>
719 struct TypeName<void*> {
720  static constexpr const char* value = "handle";
721 };
722 
723 template <>
724 struct TypeName<double> {
725  static constexpr const char* value = "double";
726 };
727 
// Entry returned by the doc visitor. It writes metadata into a shared
// AttrFieldInfoNode:
// - describe() sets the description.
// - set_default() appends ", default=<value>" to type_info.
// - The bound setters are ignored.
// NOTE(review): the class header (line 728), line 730 (presumably the TSelf
// alias) and line 754 (the info_ member) are missing from this listing.
729  public:
731 
732  explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) {}
733  TSelf& describe(const char* str) {
734  info_->description = str;
735  return *this;
736  }
737  template <typename T>
738  TSelf& set_default(const T& value) {
739  std::ostringstream os;
740  os << info_->type_info << ", default=" << value;
741  info_->type_info = os.str();
742  return *this;
743  }
744  template <typename T>
745  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
746  return *this;
747  }
748  template <typename T>
749  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
750  return *this;
751  }
752 
753  private:
755 };
756 
// Doc visitor. For each field it allocates an AttrFieldInfoNode holding the
// field name and TypeName<T>::value, and appends it to fields_. It then hands
// the same node to an AttrDocEntry so chained describe()/set_default() calls
// can enrich it.
// NOTE(review): the class header (line 757) and the fields_ member (line 768)
// are missing from this listing.
758  public:
759  template <typename T>
760  AttrDocEntry operator()(const char* key, T* v) {
761  ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>();
762  info->name = key;
763  info->type_info = TypeName<T>::value;
764  fields_.push_back(AttrFieldInfo(info));
765  return AttrDocEntry(info);
766  }
767 
769 };
770 
// Visitor that sets exist_ once a visited field name equals key_. Callers set
// key_ before visiting and read exist_ afterwards.
// NOTE(review): the class header (line 771) is missing from this listing.
772  public:
773  std::string key_;
774  bool exist_{false};
775 
776  template <typename T>
777  AttrNopEntry operator()(const char* key, T* v) {
778  if (exist_) return AttrNopEntry();
// const char* vs std::string: uses std::string's operator==, a content comparison.
779  if (key == key_) exist_ = true;
780  return AttrNopEntry();
781  }
782 };
783 
// Entry used by VisitNonDefaultAttrs. Its destructor forwards the field to the
// AttrVisitor unless set_default() found the current value structurally equal
// to the default. As a result, only non-default fields are visited, at the end
// of each chained TVM_ATTR_FIELD statement.
// NOTE(review): lines 785-786 (presumably the struct header and TSelf alias)
// are missing from this listing. No move constructor is visible, so
// visit-exactly-once relies on the prvalue returned by
// AttrNonDefaultVisitor::operator() being elided (guaranteed since C++17).
784 template <typename T>
787  // constructor
788  AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data)
789  : visitor_(visitor), key_(key), data_(data) {}
790 
791  ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
792  if (trigger_) {
793  visitor_->Visit(key_, data_);
794  }
795  }
796  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
797  TSelf& set_default(const T& value) {
798  if (tvm::StructuralEqual()(value, *data_)) {
799  trigger_ = false;
800  }
801  return *this;
802  }
803  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; }
804  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; }
805 
806  private:
807  AttrVisitor* visitor_;
808  const char* key_;
809  T* data_;
810  bool trigger_{true};
811 };
812 
// Produces one AttrTriggerNonDefaultEntry per field. The entry decides at
// destruction time whether the field reaches visitor_.
// NOTE(review): the class header (line 813) is missing from this listing.
814  public:
815  explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
816  template <typename T>
817  AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) {
818  return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
819  }
820 
821  private:
822  AttrVisitor* visitor_;
823 };
824 } // namespace detail
825 
// CRTP base for concrete attribute classes. Every BaseAttrsNode operation is
// implemented by running DerivedType::_tvm_VisitAttrs with a suitable visitor.
// NOTE(review): lines 835-836, 840-841, 887 and 914-915 (method signatures and
// the slow-path visitor declaration) are missing from this listing.
832 template <typename DerivedType>
833 class AttrsNode : public BaseAttrsNode {
834  public:
837  self()->_tvm_VisitAttrs(vis);
838  }
839 
842  self()->_tvm_VisitAttrs(vis);
843  }
844 
// Initialize from alternating (string key, value) packed arguments.
// - Fewer than kLinearSearchBound arguments: each key is found by a linear scan.
// - Otherwise: the arguments are first copied into a hash map.
// - If not every pair was consumed and allow_unknown is false, each key is
//   re-checked and the first unknown one raises AttrError with the field docs.
// NOTE(review): with duplicate keys, the linear path keeps the FIRST
// occurrence while the map path keeps the LAST. In both paths duplicates are
// not reported as errors.
845  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
846  ICHECK_EQ(args.size() % 2, 0);
847  const int kLinearSearchBound = 16;
848  int hit_count = 0;
849  // applies two strategies to lookup
850  if (args.size() < kLinearSearchBound) {
851  // linear search.
852  auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
853  for (int i = 0; i < args.size(); i += 2) {
854  ICHECK_EQ(args.type_codes[i], kTVMStr);
855  if (!std::strcmp(key, args.values[i].v_str)) {
856  *val = args[i + 1];
857  return true;
858  }
859  }
860  return false;
861  };
862  auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
863  self()->_tvm_VisitAttrs(vis);
864  hit_count = vis.hit_count_;
865  } else {
866  // construct a map then do lookup.
867  std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
868  for (int i = 0; i < args.size(); i += 2) {
869  ICHECK_EQ(args.type_codes[i], kTVMStr);
870  kwargs[args[i].operator std::string()] = args[i + 1];
871  }
872  auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) {
873  auto it = kwargs.find(key);
874  if (it != kwargs.end()) {
875  *val = it->second;
876  return true;
877  }
878  return false;
879  };
880  auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
881  self()->_tvm_VisitAttrs(vis);
882  hit_count = vis.hit_count_;
883  }
884  // error handling, slow path
885  if (hit_count * 2 != args.size() && !allow_unknown) {
886  for (int i = 0; i < args.size(); i += 2) {
888  visitor.key_ = args[i].operator std::string();
889  self()->_tvm_VisitAttrs(visitor);
890  if (!visitor.exist_) {
891  std::ostringstream os;
892  os << DerivedType::_type_key << ": does not have field \'" << visitor.key_
893  << "\', Possible fields:\n";
894  os << "----------------\n";
895  this->PrintDocString(os);
896  throw AttrError(os.str());
897  }
898  }
899  }
900  }
901 
// Field-wise structural equality against another node of the same DerivedType.
902  bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
903  DerivedType* pself = self();
904  ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
905  self()->_tvm_VisitAttrs(visitor);
906  return visitor.result_;
907  }
908 
// Hashes every field in declaration (visit) order.
909  void SHashReduce(SHashReducer hash_reducer) const {
910  ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
911  self()->_tvm_VisitAttrs(visitor);
912  }
913 
916  self()->_tvm_VisitAttrs(visitor);
917  return visitor.fields_;
918  }
919 
920  private:
// _tvm_VisitAttrs is a non-const member template, so const callers reach it
// through this const_cast.
921  DerivedType* self() const {
922  return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
923  }
924 };
925 
// Packs the given arguments (expected to alternate key, value) through a
// packed function whose body forwards them to InitByPackedArgs with the
// default allow_unknown = false.
// NOTE(review): line 928 (presumably the PackedFunc `pf(` declaration) is
// missing from this listing.
926 template <typename... Args>
927 inline void BaseAttrsNode::InitBySeq(Args&&... args) {
929  [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); });
930  pf(std::forward<Args>(args)...);
931 }
932 
933 inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*)
934  Array<AttrFieldInfo> entry = this->ListFieldInfo();
935  for (AttrFieldInfo info : entry) {
936  os << info->name << " : " << info->type_info << '\n';
937  if (info->description.length() != 0) {
938  os << " " << info->description << '\n';
939  }
940  }
941 }
942 
943 } // namespace tvm
944 #endif // TVM_IR_ATTRS_H_
@ kTVMStr
Definition: c_runtime_api.h:185
Information about attribute fields in string representations.
Definition: attrs.h:106
static constexpr bool _type_has_method_shash_reduce
Definition: attrs.h:123
String description
detailed description of the type
Definition: attrs.h:113
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object)
String type_info
type docstring information in str.
Definition: attrs.h:111
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:115
String name
name of the field
Definition: attrs.h:109
static constexpr bool _type_has_method_sequal_reduce
Definition: attrs.h:122
static constexpr const char * _type_key
Definition: attrs.h:121
AttrFieldInfo.
Definition: attrs.h:128
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode)
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
The base class of all attribute classes; it uses the "curiously recurring template pattern" (CRTP).
Definition: attrs.h:833
bool SEqualReduce(const DerivedType *other, SEqualReducer equal) const
Definition: attrs.h:902
void SHashReduce(SHashReducer hash_reducer) const
Definition: attrs.h:909
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:835
void VisitNonDefaultAttrs(AttrVisitor *v)
Visit attributes that do not equal the default value.
Definition: attrs.h:840
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
Definition: attrs.h:914
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
Definition: attrs.h:845
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode)
Base class of all attribute class.
Definition: attrs.h:139
virtual ~BaseAttrsNode()
virtual destructor
Definition: attrs.h:144
virtual void InitByPackedArgs(const TVMArgs &kwargs, bool allow_unknown=false)=0
Initialize the attributes by arguments.
virtual Array< AttrFieldInfo > ListFieldInfo() const =0
Get the field information.
void PrintDocString(std::ostream &os) const
Print a readable docstring to the ostream, ending each line with a newline.
Definition: attrs.h:933
static constexpr const char * _type_key
Definition: attrs.h:182
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object)
void InitBySeq(Args &&... args)
Initialize the attributes by sequence of arguments.
Definition: attrs.h:927
virtual void VisitNonDefaultAttrs(AttrVisitor *v)=0
Visit attributes that do not equal the default value.
static constexpr const bool _type_has_method_sequal_reduce
Definition: attrs.h:180
static constexpr const bool _type_has_method_shash_reduce
Definition: attrs.h:181
virtual void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:146
Specialized attribute type that is backed by a map. The DictAttrsNode implements the Attrs behavior,...
Definition: attrs.h:201
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode)
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
void VisitAttrs(AttrVisitor *v) final
static constexpr const char * _type_key
Definition: attrs.h:219
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
void VisitNonDefaultAttrs(AttrVisitor *v) final
Visit attributes that do not equal the default value.
Map< String, ObjectRef > dict
internal attrs map
Definition: attrs.h:204
void SHashReduce(SHashReducer hash_reduce) const
Definition: attrs.h:210
bool SEqualReduce(const DictAttrsNode *other, SEqualReducer equal) const
Definition: attrs.h:206
Managed reference to DictAttrsNode.
Definition: attrs.h:227
DictAttrs(Map< String, ObjectRef > dict)
Construct an Attrs backed by DictAttrsNode.
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attr.
Definition: attrs.h:297
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:258
TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode)
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: attrs.h:275
Constant floating point literals in the program.
Definition: expr.h:538
Constant integer literals in the program.
Definition: expr.h:491
Managed reference class to IntImmNode.
Definition: expr.h:520
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:103
Definition: attrs.h:728
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin)
Definition: attrs.h:745
TSelf & set_default(const T &value)
Definition: attrs.h:738
AttrDocEntry(ObjectPtr< AttrFieldInfoNode > info)
Definition: attrs.h:732
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end)
Definition: attrs.h:749
TSelf & describe(const char *str)
Definition: attrs.h:733
Definition: attrs.h:757
AttrDocEntry operator()(const char *key, T *v)
Definition: attrs.h:760
Array< AttrFieldInfo > fields_
Definition: attrs.h:768
Definition: attrs.h:771
AttrNopEntry operator()(const char *key, T *v)
Definition: attrs.h:777
std::string key_
Definition: attrs.h:773
bool exist_
Definition: attrs.h:774
Definition: attrs.h:639
AttrInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:645
AttrInitEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:648
Definition: attrs.h:813
AttrNonDefaultVisitor(AttrVisitor *visitor)
Definition: attrs.h:815
AttrTriggerNonDefaultEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:817
Definition: attrs.h:449
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:453
AttrNormalVisitor(AttrVisitor *visitor)
Definition: attrs.h:451
Definition: attrs.h:462
AttrNopEntry operator()(const char *key, T *lhs_value)
Definition: attrs.h:469
AttrsSEqualVisitor(const Object *lhs, const Object *rhs, const SEqualReducer &equal)
Definition: attrs.h:466
bool result_
Definition: attrs.h:464
Definition: attrs.h:486
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:491
AttrsSHashVisitor(const SHashReducer &hash_reducer)
Definition: attrs.h:488
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:42
@ kHandle
Definition: data_type.h:56
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:360
Base class of all object reference.
Definition: object.h:517
bool defined() const
Definition: object.h:550
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:894
base class of all object containers.
Definition: object.h:169
std::string GetTypeKey() const
Definition: object.h:182
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:139
Reference to string objects.
Definition: string.h:98
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:2194
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:649
const TVMValue & value() const
Definition: packed_func.h:691
Arguments into TVM functions.
Definition: packed_func.h:392
int type_code() const
Definition: packed_func.h:613
Return Value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlyi...
Definition: packed_func.h:799
struct TVMArgs TVMArgs
Base expr nodes in TVM.
void SetValue< int >(int *ptr, const TVMArgValue &val)
Definition: attrs.h:621
void SetValue< double >(double *ptr, const TVMArgValue &val)
Definition: attrs.h:605
void SetValue< DataType >(DataType *ptr, const TVMArgValue &val)
Definition: attrs.h:591
AttrInitVisitor< FFind > CreateInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:675
void SetValue< uint64_t >(uint64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:629
void SetValue< int64_t >(int64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:625
void SetValue< bool >(bool *ptr, const TVMArgValue &val)
Definition: attrs.h:633
void SetValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:575
void SetIntValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:580
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
DataType NullValue< DataType >()
Definition: attrs.h:90
TFunc WithAttr(TFunc input, const std::string &attr_key, ObjectRef attr_value)
Copy the function or module, but overrides the attribute value key with the value.
Definition: attrs.h:346
TAttrs AttrsWithDefaultValues()
Create an Attr object with all default values.
Definition: attrs.h:311
runtime::DataType DataType
Definition: data_type.h:433
TFunc WithoutAttr(TFunc input, const std::string &attr_key)
Copy the function or module, but removes the specified attribute.
Definition: attrs.h:411
TFunc WithAttrs(TFunc input, Map< String, ObjectRef > attrs)
Copy the function or module, but overrides the attributes with the entries from attrs.
Definition: attrs.h:370
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:84
Type-erased function used across TVM API.
Error thrown during attribute checking.
Definition: attrs.h:95
AttrError(std::string msg)
constructor
Definition: attrs.h:100
Definition: attrs.h:502
TSelf & set_lower_bound(const T &begin)
Definition: attrs.h:539
const char * type_key_
Definition: attrs.h:506
const char * key_
Definition: attrs.h:508
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:569
TSelf & set_upper_bound(const T &end)
Definition: attrs.h:551
~AttrInitEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:529
bool value_missing_
Definition: attrs.h:515
TSelf & set_default(const T &value)
Definition: attrs.h:563
T * value_
Definition: attrs.h:510
AttrInitEntry(AttrInitEntry &&other)
Definition: attrs.h:519
Definition: attrs.h:430
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:439
TSelf & set_default(DMLC_ATTRIBUTE_UNUSED const T &value)
Definition: attrs.h:435
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:433
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:443
AttrTriggerNonDefaultEntry(AttrVisitor *visitor, const char *key, T *data)
Definition: attrs.h:788
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:803
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:804
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:796
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:791
TSelf & set_default(const T &value)
Definition: attrs.h:797
Helper struct to get the type name known to tvm.
Definition: attrs.h:684
Structural equality comparison.
int64_t v_int64
Definition: c_runtime_api.h:209