tvm
attrs.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
44 #ifndef TVM_IR_ATTRS_H_
45 #define TVM_IR_ATTRS_H_
46 
47 #include <dmlc/common.h>
48 #include <tvm/ir/expr.h>
52 
53 #include <functional>
54 #include <string>
55 #include <type_traits>
56 #include <unordered_map>
57 #include <utility>
58 #include <vector>
59 
60 namespace tvm {
/*!
 * \brief Declare the boilerplate of an attribute node class.
 *
 * Expands to the class's _type_key, the final object-info declaration
 * (parent ::tvm::BaseAttrsNode), and the signature of the templated
 * _tvm_VisitAttrs function. The use site must therefore be followed by the
 * body of _tvm_VisitAttrs, which lists the fields with TVM_ATTR_FIELD.
 *
 * \param ClassName The name of the attribute node class.
 * \param TypeKey The type key string for the class.
 */
#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                     \
  static constexpr const char* _type_key = TypeKey;               \
  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
  template <typename FVisit>                                      \
  void _tvm_VisitAttrs(FVisit& _tvm_fvisit)  // NOLINT(*)

/*!
 * \brief Declare one attribute field inside the _tvm_VisitAttrs body.
 *
 * Calls the visitor with the stringified field name and the field's address.
 * The visitor's return value allows chaining calls such as
 * describe()/set_default().
 *
 * \param FieldName The member field to expose.
 */
#define TVM_ATTR_FIELD(FieldName) _tvm_fvisit(#FieldName, &FieldName)
77 
83 template <typename TObjectRef>
84 inline TObjectRef NullValue() {
85  static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
86  return TObjectRef(ObjectPtr<Object>(nullptr));
87 }
88 
89 template <>
91  return DataType(DataType::kHandle, 0, 0);
92 }
93 
95 struct AttrError : public Error {
100  explicit AttrError(std::string msg) : Error("AttributeError:" + msg) {}
101 };
102 
106 class AttrFieldInfoNode : public Object {
107  public:
114 
116  v->Visit("name", &name);
117  v->Visit("type_info", &type_info);
118  v->Visit("description", &description);
119  }
120 
121  static constexpr const char* _type_key = "AttrFieldInfo";
122  static constexpr bool _type_has_method_sequal_reduce = false;
123  static constexpr bool _type_has_method_shash_reduce = false;
125 };
126 
128 class AttrFieldInfo : public ObjectRef {
129  public:
131 };
132 
139 class BaseAttrsNode : public Object {
140  public:
144  virtual ~BaseAttrsNode() {}
145  // visit function
146  virtual void VisitAttrs(AttrVisitor* v) {}
152  template <typename... Args>
153  inline void InitBySeq(Args&&... args);
158  inline void PrintDocString(std::ostream& os) const; // NOLINT(*)
165  TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
170  TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
178  TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
179 
180  static constexpr const bool _type_has_method_sequal_reduce = true;
181  static constexpr const bool _type_has_method_shash_reduce = true;
182  static constexpr const char* _type_key = "Attrs";
184 };
185 
190 class Attrs : public ObjectRef {
191  public:
193 };
194 
201 class DictAttrsNode : public BaseAttrsNode {
202  public:
205 
206  bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
207  return equal(dict, other->dict);
208  }
209 
210  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); }
211 
212  // implementations
213  void VisitAttrs(AttrVisitor* v) final;
215  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
217 
218  // type info
219  static constexpr const char* _type_key = "DictAttrs";
221 };
222 
227 class DictAttrs : public Attrs {
228  public:
233  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict = {});
234 
235  // Utils for accessing attributes
236  // This needs to be on DictAttrs, not DictAttrsNode because we return the default
237  // value if DictAttrsNode is not defined.
257  template <typename TObjectRef>
259  const std::string& attr_key,
260  Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
261  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
262  "Can only call GetAttr with ObjectRef types.");
263  if (!defined()) return default_value;
264  const DictAttrsNode* node = this->as<DictAttrsNode>();
265 
266  auto it = node->dict.find(attr_key);
267  if (it != node->dict.end()) {
268  return Downcast<Optional<TObjectRef>>((*it).second);
269  } else {
270  return default_value;
271  }
272  }
273  // variant that uses TObjectRef to enable implicit conversion to default value.
274  template <typename TObjectRef>
275  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
276  return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
277  }
297  bool HasNonzeroAttr(const std::string& attr_key) const {
298  return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
299  }
300 
303 };
304 
310 template <typename TAttrs>
311 inline TAttrs AttrsWithDefaultValues() {
312  static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes");
313  auto n = make_object<typename TAttrs::ContainerType>();
314  n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false);
315  return TAttrs(n);
316 }
317 
345 template <typename TFunc>
346 inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
347  using TNode = typename TFunc::ContainerType;
348  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
349  TNode* node = input.CopyOnWrite();
350  if (node->attrs.defined()) {
351  node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
352  } else {
353  Map<String, ObjectRef> dict = {{attr_key, attr_value}};
354  node->attrs = DictAttrs(dict);
355  }
356  return input;
357 }
358 
369 template <typename TFunc>
370 inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
371  using TNode = typename TFunc::ContainerType;
372  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
373  TNode* node = input.CopyOnWrite();
374  if (node->attrs.defined()) {
375  for (const auto& pair : attrs) {
376  node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
377  }
378  } else {
379  node->attrs = DictAttrs(std::move(attrs));
380  }
381  return input;
382 }
383 
/*!
 * \brief Copy the function or module, but remove the specified attribute.
 *
 * If \p input has no attrs, or its attrs do not contain \p attr_key, \p input
 * is returned unchanged and no copy-on-write is triggered.
 *
 * \param input The function or module (its ContainerType must be final).
 * \param attr_key The attribute key to remove.
 * \return The updated function or module.
 */
template <typename TFunc>
inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
  using TNode = typename TFunc::ContainerType;
  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");

  // Only detach (copy) the node and its attrs when there is something to erase.
  if (input->attrs.defined() && input->attrs->dict.count(attr_key) != 0) {
    TNode* node = input.CopyOnWrite();
    node->attrs.CopyOnWrite()->dict.erase(attr_key);
  }
  return input;
}
421 
422 // Namespace containing detail implementations
423 namespace detail {
425 
426 // helper entry that does nothing in set_default/bound/describe calls.
427 struct AttrNopEntry {
429 
430  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
431  template <typename T>
432  TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
433  return *this;
434  }
435  template <typename T>
436  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
437  return *this;
438  }
439  template <typename T>
440  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
441  return *this;
442  }
443 };
444 
445 // Wrapper for normal visitor.
447  public:
448  explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
449  template <typename T>
450  AttrNopEntry operator()(const char* key, T* value) {
451  visitor_->Visit(key, value);
452  return AttrNopEntry();
453  }
454 
455  private:
456  AttrVisitor* visitor_;
457 };
458 
460  public:
461  bool result_{true};
462  // constructor
463  AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
464  : lhs_(lhs), rhs_(rhs), equal_(equal) {}
465  template <typename T>
466  AttrNopEntry operator()(const char* key, T* lhs_value) {
467  if (!result_) return AttrNopEntry();
468  const T* rhs_value = reinterpret_cast<const T*>(
469  reinterpret_cast<const char*>(rhs_) +
470  (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_)));
471  if (!equal_(*lhs_value, *rhs_value)) {
472  result_ = false;
473  }
474  return AttrNopEntry();
475  }
476 
477  private:
478  const Object* lhs_;
479  const Object* rhs_;
480  const SEqualReducer& equal_;
481 };
482 
484  public:
485  explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {}
486 
487  template <typename T>
488  AttrNopEntry operator()(const char* key, T* value) {
489  hash_reducer_(*value);
490  return AttrNopEntry();
491  }
492 
493  private:
494  const SHashReducer& hash_reducer_;
495 };
496 
497 // helper entry that does initialization, set default.
498 template <typename T>
500  // The attributes
502  // The type key
503  const char* type_key_;
504  // field name
505  const char* key_;
506  // internal value.
507  T* value_;
508  // whether the value is missing.
509  // NOTE: initialize to false so that the destructor does not throw unless
510  // AttrInitVisitor::operator() is committed to returning an instance of this class.
511  // It is expected not to set this to true until that is true.
512  bool value_missing_{false};
513 
514  AttrInitEntry() = default;
515 
517  type_key_ = other.type_key_;
518  key_ = other.key_;
519  value_ = other.value_;
520  value_missing_ = other.value_missing_;
521  // avoid unexpected throw
522  other.value_missing_ = false;
523  }
524 
525  // If the value is still missing in destruction time throw an error.
526  ~AttrInitEntry() DMLC_THROW_EXCEPTION {
527  if (value_missing_) {
528  std::ostringstream os;
529  os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. "
530  << "If the key is defined check that its type matches the declared type.";
531  throw AttrError(os.str());
532  }
533  }
534  // override fields.
535  // This function sets the lower bound of the attribute
536  TSelf& set_lower_bound(const T& begin) {
537  if (this->value_missing_) return *this;
538  const T& val = *value_;
539  if (begin > val) {
540  std::ostringstream os;
541  os << type_key_ << "." << key_ << ": "
542  << "value " << val << " is smaller than the lower bound " << begin;
543  throw AttrError(os.str());
544  }
545  return *this;
546  }
547  // This function sets the upper bound of the attribute
548  TSelf& set_upper_bound(const T& end) {
549  if (this->value_missing_) return *this;
550  const T& val = *value_;
551  if (val > end) {
552  std::ostringstream os;
553  os << type_key_ << "." << key_ << ": "
554  << "value " << val << " is bigger than the upper bound " << end;
555  throw AttrError(os.str());
556  }
557  return *this;
558  }
559  // set default when
560  TSelf& set_default(const T& value) {
561  if (!value_missing_) return *this;
562  *value_ = value;
563  value_missing_ = false;
564  return *this;
565  }
566  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
567 };
568 
569 // Template function to allow smart conversion
570 // from Expr types into the constants.
571 template <typename T>
572 inline void SetValue(T* ptr, const TVMArgValue& val) {
573  *ptr = val.operator T();
574 }
575 
576 template <typename T>
577 inline void SetIntValue(T* ptr, const TVMArgValue& val) {
578  if (val.type_code() == kDLInt) {
579  *ptr = static_cast<T>(val.value().v_int64);
580  } else {
581  IntImm expr = val;
582  *ptr = static_cast<T>(expr->value);
583  }
584 }
585 
586 // Workaround for GCC8.1 / GCC8.2
587 template <>
588 inline void SetValue<DataType>(DataType* ptr, const TVMArgValue& val) {
589  *ptr = val.operator DataType();
590 }
591 
592 template <>
593 inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
594  if (String::CanConvertFrom(val)) {
595  *ptr = val.operator std::string();
596  } else {
597  LOG(FATAL) << "Expect str";
598  }
599 }
600 
601 template <>
602 inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
603  if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
604  *ptr = val.operator double();
605  } else {
606  ObjectRef expr = val;
607  ICHECK(expr.defined());
608  if (const IntImmNode* op = expr.as<IntImmNode>()) {
609  *ptr = static_cast<double>(op->value);
610  } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
611  *ptr = static_cast<double>(op->value);
612  } else {
613  LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
614  }
615  }
616 }
617 template <>
618 inline void SetValue<int>(int* ptr, const TVMArgValue& val) {
619  SetIntValue(ptr, val);
620 }
621 template <>
622 inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) {
623  SetIntValue(ptr, val);
624 }
625 template <>
626 inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) {
627  SetIntValue(ptr, val);
628 }
629 template <>
630 inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) {
631  SetIntValue(ptr, val);
632 }
633 
634 // Visitor for value initialization
635 template <typename FFind>
637  public:
638  // Counter of number of matched attributes during visit.
639  // This is used to decide if there is additional unmatched attributes.
640  size_t hit_count_{0};
641  // constructor
642  AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
643 
644  template <typename T>
645  AttrInitEntry<T> operator()(const char* key, T* value) {
646  TVMArgValue val;
647  AttrInitEntry<T> opt;
648  opt.type_key_ = type_key_;
649  opt.key_ = key;
650  opt.value_ = value;
651  if (ffind_(key, &val)) {
652  SetValue(value, val);
653  opt.value_missing_ = false;
654  ++hit_count_;
655  } else {
656  opt.value_missing_ = true;
657  }
658 #if defined(__GNUC__)
659 #pragma GCC diagnostic ignored "-Wpragmas"
660 #pragma GCC diagnostic ignored "-Wpessimizing-move"
661 #endif
662  return std::move(opt);
663  }
664 
665  private:
666  // the type key
667  const char* type_key_;
668  FFind ffind_;
669 };
670 
671 template <typename FFind>
672 inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) {
673  return AttrInitVisitor<FFind>(type_key, ffind);
674 }
675 
/*!
 * \brief Helper struct to get the type name known to tvm.
 *
 * For object-reference types the name is the container's _type_key. The
 * explicit specializations that follow give names for common integer types.
 */
template <typename T>
struct TypeName {
  static constexpr const char* value{T::ContainerType::_type_key};
};

template <>
struct TypeName<int> {
  static constexpr const char* value{"int"};
};

template <>
struct TypeName<int64_t> {
  static constexpr const char* value{"int64"};
};

template <>
struct TypeName<uint64_t> {
  static constexpr const char* value{"uint64_t"};
};
699 
700 template <>
702  static constexpr const char* value = "DataType";
703 };
704 
705 template <>
706 struct TypeName<std::string> {
707  static constexpr const char* value = "str";
708 };
709 
710 template <>
711 struct TypeName<bool> {
712  static constexpr const char* value = "bool";
713 };
714 
715 template <>
716 struct TypeName<void*> {
717  static constexpr const char* value = "handle";
718 };
719 
720 template <>
721 struct TypeName<double> {
722  static constexpr const char* value = "double";
723 };
724 
726  public:
728 
729  explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) {}
730  TSelf& describe(const char* str) {
731  info_->description = str;
732  return *this;
733  }
734  template <typename T>
735  TSelf& set_default(const T& value) {
736  std::ostringstream os;
737  os << info_->type_info << ", default=" << value;
738  info_->type_info = os.str();
739  return *this;
740  }
741  template <typename T>
742  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
743  return *this;
744  }
745  template <typename T>
746  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
747  return *this;
748  }
749 
750  private:
752 };
753 
755  public:
756  template <typename T>
757  AttrDocEntry operator()(const char* key, T* v) {
758  ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>();
759  info->name = key;
760  info->type_info = TypeName<T>::value;
761  fields_.push_back(AttrFieldInfo(info));
762  return AttrDocEntry(info);
763  }
764 
766 };
767 
769  public:
770  std::string key_;
771  bool exist_{false};
772 
773  template <typename T>
774  AttrNopEntry operator()(const char* key, T* v) {
775  if (exist_) return AttrNopEntry();
776  if (key == key_) exist_ = true;
777  return AttrNopEntry();
778  }
779 };
780 
781 template <typename T>
784  // constructor
785  AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data)
786  : visitor_(visitor), key_(key), data_(data) {}
787 
788  ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
789  if (trigger_) {
790  visitor_->Visit(key_, data_);
791  }
792  }
793  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
794  TSelf& set_default(const T& value) {
795  if (tvm::StructuralEqual()(value, *data_)) {
796  trigger_ = false;
797  }
798  return *this;
799  }
800  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; }
801  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; }
802 
803  private:
804  AttrVisitor* visitor_;
805  const char* key_;
806  T* data_;
807  bool trigger_{true};
808 };
809 
811  public:
812  explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
813  template <typename T>
814  AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) {
815  return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
816  }
817 
818  private:
819  AttrVisitor* visitor_;
820 };
821 } // namespace detail
822 
829 template <typename DerivedType>
830 class AttrsNode : public BaseAttrsNode {
831  public:
834  self()->_tvm_VisitAttrs(vis);
835  }
836 
839  self()->_tvm_VisitAttrs(vis);
840  }
841 
842  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
843  ICHECK_EQ(args.size() % 2, 0);
844  const int kLinearSearchBound = 16;
845  int hit_count = 0;
846  // applies two strategies to lookup
847  if (args.size() < kLinearSearchBound) {
848  // linear search.
849  auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
850  for (int i = 0; i < args.size(); i += 2) {
851  ICHECK_EQ(args.type_codes[i], kTVMStr);
852  if (!std::strcmp(key, args.values[i].v_str)) {
853  *val = args[i + 1];
854  return true;
855  }
856  }
857  return false;
858  };
859  auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
860  self()->_tvm_VisitAttrs(vis);
861  hit_count = vis.hit_count_;
862  } else {
863  // construct a map then do lookup.
864  std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
865  for (int i = 0; i < args.size(); i += 2) {
866  ICHECK_EQ(args.type_codes[i], kTVMStr);
867  kwargs[args[i].operator std::string()] = args[i + 1];
868  }
869  auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) {
870  auto it = kwargs.find(key);
871  if (it != kwargs.end()) {
872  *val = it->second;
873  return true;
874  }
875  return false;
876  };
877  auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
878  self()->_tvm_VisitAttrs(vis);
879  hit_count = vis.hit_count_;
880  }
881  // error handling, slow path
882  if (hit_count * 2 != args.size() && !allow_unknown) {
883  for (int i = 0; i < args.size(); i += 2) {
885  visitor.key_ = args[i].operator std::string();
886  self()->_tvm_VisitAttrs(visitor);
887  if (!visitor.exist_) {
888  std::ostringstream os;
889  os << DerivedType::_type_key << ": does not have field \'" << visitor.key_
890  << "\', Possible fields:\n";
891  os << "----------------\n";
892  this->PrintDocString(os);
893  throw AttrError(os.str());
894  }
895  }
896  }
897  }
898 
899  bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
900  DerivedType* pself = self();
901  ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
902  self()->_tvm_VisitAttrs(visitor);
903  return visitor.result_;
904  }
905 
906  void SHashReduce(SHashReducer hash_reducer) const {
907  ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
908  self()->_tvm_VisitAttrs(visitor);
909  }
910 
913  self()->_tvm_VisitAttrs(visitor);
914  return visitor.fields_;
915  }
916 
917  private:
918  DerivedType* self() const {
919  return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
920  }
921 };
922 
923 template <typename... Args>
924 inline void BaseAttrsNode::InitBySeq(Args&&... args) {
926  [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); });
927  pf(std::forward<Args>(args)...);
928 }
929 
930 inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*)
931  Array<AttrFieldInfo> entry = this->ListFieldInfo();
932  for (AttrFieldInfo info : entry) {
933  os << info->name << " : " << info->type_info << '\n';
934  if (info->description.length() != 0) {
935  os << " " << info->description << '\n';
936  }
937  }
938 }
939 
940 } // namespace tvm
941 #endif // TVM_IR_ATTRS_H_
@ kTVMStr
Definition: c_runtime_api.h:185
Information about attribute fields in string representations.
Definition: attrs.h:106
static constexpr bool _type_has_method_shash_reduce
Definition: attrs.h:123
String description
detailed description of the type
Definition: attrs.h:113
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object)
String type_info
type docstring information in str.
Definition: attrs.h:111
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:115
String name
name of the field
Definition: attrs.h:109
static constexpr bool _type_has_method_sequal_reduce
Definition: attrs.h:122
static constexpr const char * _type_key
Definition: attrs.h:121
AttrFieldInfo.
Definition: attrs.h:128
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode)
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
CRTP base class for attribute classes: DerivedType uses the "curiously recurring template pattern".
Definition: attrs.h:830
bool SEqualReduce(const DerivedType *other, SEqualReducer equal) const
Definition: attrs.h:899
void SHashReduce(SHashReducer hash_reducer) const
Definition: attrs.h:906
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:832
void VisitNonDefaultAttrs(AttrVisitor *v)
Visit attributes that do not equal the default value.
Definition: attrs.h:837
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
Definition: attrs.h:911
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
Definition: attrs.h:842
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode)
Base class of all attribute class.
Definition: attrs.h:139
virtual ~BaseAttrsNode()
virtual destructor
Definition: attrs.h:144
virtual void InitByPackedArgs(const TVMArgs &kwargs, bool allow_unknown=false)=0
Initialize the attributes by arguments.
virtual Array< AttrFieldInfo > ListFieldInfo() const =0
Get the field information.
void PrintDocString(std::ostream &os) const
Print a readable docstring to the ostream, ending each line with a newline.
Definition: attrs.h:930
static constexpr const char * _type_key
Definition: attrs.h:182
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object)
void InitBySeq(Args &&... args)
Initialize the attributes by sequence of arguments.
Definition: attrs.h:924
virtual void VisitNonDefaultAttrs(AttrVisitor *v)=0
Visit attributes that do not equal the default value.
static constexpr const bool _type_has_method_sequal_reduce
Definition: attrs.h:180
static constexpr const bool _type_has_method_shash_reduce
Definition: attrs.h:181
virtual void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:146
Specialized attribute type that is backed by a map. The DictAttrsNode implements the Attrs behavior,...
Definition: attrs.h:201
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode)
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
void VisitAttrs(AttrVisitor *v) final
static constexpr const char * _type_key
Definition: attrs.h:219
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
void VisitNonDefaultAttrs(AttrVisitor *v) final
Visit attributes that do not equal the default value.
Map< String, ObjectRef > dict
internal attrs map
Definition: attrs.h:204
void SHashReduce(SHashReducer hash_reduce) const
Definition: attrs.h:210
bool SEqualReduce(const DictAttrsNode *other, SEqualReducer equal) const
Definition: attrs.h:206
Managed reference to DictAttrsNode.
Definition: attrs.h:227
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attr.
Definition: attrs.h:297
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:258
DictAttrs(Map< String, ObjectRef > dict={})
Construct an Attrs backed by DictAttrsNode.
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode)
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: attrs.h:275
TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(DictAttrs, Attrs, DictAttrsNode)
Constant floating point literals in the program.
Definition: expr.h:548
Constant integer literals in the program.
Definition: expr.h:501
Managed reference class to IntImmNode.
Definition: expr.h:530
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:126
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:103
Definition: attrs.h:725
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin)
Definition: attrs.h:742
TSelf & set_default(const T &value)
Definition: attrs.h:735
AttrDocEntry(ObjectPtr< AttrFieldInfoNode > info)
Definition: attrs.h:729
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end)
Definition: attrs.h:746
TSelf & describe(const char *str)
Definition: attrs.h:730
Definition: attrs.h:754
AttrDocEntry operator()(const char *key, T *v)
Definition: attrs.h:757
Array< AttrFieldInfo > fields_
Definition: attrs.h:765
Definition: attrs.h:768
AttrNopEntry operator()(const char *key, T *v)
Definition: attrs.h:774
std::string key_
Definition: attrs.h:770
bool exist_
Definition: attrs.h:771
Definition: attrs.h:636
AttrInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:642
AttrInitEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:645
Definition: attrs.h:810
AttrNonDefaultVisitor(AttrVisitor *visitor)
Definition: attrs.h:812
AttrTriggerNonDefaultEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:814
Definition: attrs.h:446
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:450
AttrNormalVisitor(AttrVisitor *visitor)
Definition: attrs.h:448
Definition: attrs.h:459
AttrNopEntry operator()(const char *key, T *lhs_value)
Definition: attrs.h:466
AttrsSEqualVisitor(const Object *lhs, const Object *rhs, const SEqualReducer &equal)
Definition: attrs.h:463
bool result_
Definition: attrs.h:461
Definition: attrs.h:483
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:488
AttrsSHashVisitor(const SHashReducer &hash_reducer)
Definition: attrs.h:485
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
@ kHandle
Definition: data_type.h:57
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:906
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:139
Reference to string objects.
Definition: string.h:98
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:2221
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:657
const TVMValue & value() const
Definition: packed_func.h:699
Arguments into TVM functions.
Definition: packed_func.h:392
int type_code() const
Definition: packed_func.h:621
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlying...
Definition: packed_func.h:807
struct TVMArgs TVMArgs
Base expr nodes in TVM.
void SetValue< int >(int *ptr, const TVMArgValue &val)
Definition: attrs.h:618
void SetValue< double >(double *ptr, const TVMArgValue &val)
Definition: attrs.h:602
void SetValue< DataType >(DataType *ptr, const TVMArgValue &val)
Definition: attrs.h:588
AttrInitVisitor< FFind > CreateInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:672
void SetValue< uint64_t >(uint64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:626
void SetValue< int64_t >(int64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:622
void SetValue< bool >(bool *ptr, const TVMArgValue &val)
Definition: attrs.h:630
void SetValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:572
void SetIntValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:577
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
DataType NullValue< DataType >()
Definition: attrs.h:90
TFunc WithAttr(TFunc input, const std::string &attr_key, ObjectRef attr_value)
Copy the function or module, but override the attribute value for the key with the given value.
Definition: attrs.h:346
TAttrs AttrsWithDefaultValues()
Create an Attr object with all default values.
Definition: attrs.h:311
runtime::DataType DataType
Definition: data_type.h:491
TFunc WithoutAttr(TFunc input, const std::string &attr_key)
Copy the function or module, but remove the specified attribute.
Definition: attrs.h:411
TFunc WithAttrs(TFunc input, Map< String, ObjectRef > attrs)
Copy the function or module, but override the attributes with the entries from attrs.
Definition: attrs.h:370
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:84
Type-erased function used across TVM API.
Error thrown during attribute checking.
Definition: attrs.h:95
AttrError(std::string msg)
constructor
Definition: attrs.h:100
Definition: attrs.h:499
TSelf & set_lower_bound(const T &begin)
Definition: attrs.h:536
const char * type_key_
Definition: attrs.h:503
const char * key_
Definition: attrs.h:505
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:566
TSelf & set_upper_bound(const T &end)
Definition: attrs.h:548
~AttrInitEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:526
bool value_missing_
Definition: attrs.h:512
TSelf & set_default(const T &value)
Definition: attrs.h:560
T * value_
Definition: attrs.h:507
AttrInitEntry(AttrInitEntry &&other)
Definition: attrs.h:516
Definition: attrs.h:427
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:436
TSelf & set_default(DMLC_ATTRIBUTE_UNUSED const T &value)
Definition: attrs.h:432
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:430
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:440
AttrTriggerNonDefaultEntry(AttrVisitor *visitor, const char *key, T *data)
Definition: attrs.h:785
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:800
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:801
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:793
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:788
TSelf & set_default(const T &value)
Definition: attrs.h:794
Helper struct to get the type name known to tvm.
Definition: attrs.h:681
Structural equality comparison.
int64_t v_int64
Definition: c_runtime_api.h:209