tvm
attrs.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
44 #ifndef TVM_IR_ATTRS_H_
45 #define TVM_IR_ATTRS_H_
46 
47 #include <dmlc/common.h>
48 #include <tvm/ir/expr.h>
52 
53 #include <functional>
54 #include <string>
55 #include <type_traits>
56 #include <unordered_map>
57 #include <utility>
58 #include <vector>
59 
60 namespace tvm {
// TVM_DECLARE_ATTRS(ClassName, TypeKey): boilerplate for an attributes class.
// It defines _type_key as TypeKey, declares final-object type info with
// ::tvm::BaseAttrsNode as the parent, and ends with the signature of the
// templated __VisitAttrs__(FVisit& __fvisit__) member, so the macro use must be
// followed by that member's body (typically a sequence of TVM_ATTR_FIELD(...)).
66 #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
67  static constexpr const char* _type_key = TypeKey; \
68  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
69  template <typename FVisit> \
70  void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)
71 
// TVM_ATTR_FIELD(FieldName): inside a __VisitAttrs__ body, calls the visitor
// __fvisit__ with the field's name as a string and a pointer to the field.
// The visitor's return value lets declarations chain calls such as
// .set_default(...), .set_lower_bound(...) or .describe(...).
76 #define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName)
77 
// Returns a null (undefined) reference of type TObjectRef by wrapping an empty
// ObjectPtr. Only nullable reference types are accepted; the static_assert
// rejects all others at compile time.
83 template <typename TObjectRef>
84 inline TObjectRef NullValue() {
85  static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
86  return TObjectRef(ObjectPtr<Object>(nullptr));
87 }
88 
89 template <>
91  return DataType(DataType::kHandle, 0, 0);
92 }
93 
95 struct AttrError : public Error {
100  explicit AttrError(std::string msg) : Error("AttributeError:" + msg) {}
101 };
102 
106 class AttrFieldInfoNode : public Object {
107  public:
114 
116  v->Visit("name", &name);
117  v->Visit("type_info", &type_info);
118  v->Visit("description", &description);
119  }
120 
121  static constexpr const char* _type_key = "AttrFieldInfo";
122  static constexpr bool _type_has_method_sequal_reduce = false;
123  static constexpr bool _type_has_method_shash_reduce = false;
125 };
126 
128 class AttrFieldInfo : public ObjectRef {
129  public:
131 };
132 
139 class BaseAttrsNode : public Object {
140  public:
144  virtual ~BaseAttrsNode() {}
145  // visit function
146  virtual void VisitAttrs(AttrVisitor* v) {}
152  template <typename... Args>
153  inline void InitBySeq(Args&&... args);
158  inline void PrintDocString(std::ostream& os) const; // NOLINT(*)
165  TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
170  TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
178  TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
179 
180  static constexpr const bool _type_has_method_sequal_reduce = true;
181  static constexpr const bool _type_has_method_shash_reduce = true;
182  static constexpr const char* _type_key = "Attrs";
184 };
185 
190 class Attrs : public ObjectRef {
191  public:
193 };
194 
201 class DictAttrsNode : public BaseAttrsNode {
202  public:
205 
206  bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
207  return equal(dict, other->dict);
208  }
209 
210  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); }
211 
212  // implementations
213  void VisitAttrs(AttrVisitor* v) final;
214  void VisitNonDefaultAttrs(AttrVisitor* v) final;
215  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
216  Array<AttrFieldInfo> ListFieldInfo() const final;
217 
218  // type info
219  static constexpr const char* _type_key = "DictAttrs";
221 };
222 
227 class DictAttrs : public Attrs {
228  public:
234  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);
235 
236  // Utils for accessing attributes
237  // This needs to be on DictAttrs, not DictAttrsNode because we return the default
238  // value if DictAttrsNode is not defined.
258  template <typename TObjectRef>
260  const std::string& attr_key,
261  Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
262  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
263  "Can only call GetAttr with ObjectRef types.");
264  if (!defined()) return default_value;
265  const DictAttrsNode* node = this->as<DictAttrsNode>();
266 
267  auto it = node->dict.find(attr_key);
268  if (it != node->dict.end()) {
269  return Downcast<Optional<TObjectRef>>((*it).second);
270  } else {
271  return default_value;
272  }
273  }
274  // variant that uses TObjectRef to enable implicit conversion to default value.
275  template <typename TObjectRef>
276  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
277  return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
278  }
298  bool HasNonzeroAttr(const std::string& attr_key) const {
299  return GetAttr<Integer>(attr_key, 0) != 0;
300  }
301 
304 };
305 
// Creates a TAttrs by running InitByPackedArgs on a freshly allocated container
// with an empty argument list (allow_unknown = false). For attrs built on
// AttrsNode, this leaves every field at the value from its set_default(...)
// declaration. TAttrs must derive from Attrs (checked by static_assert).
// NOTE(review): in the AttrsNode path, a field without set_default stays
// "missing" and AttrInitEntry's destructor throws AttrError. This helper
// therefore presumably requires every field to have a default; confirm with callers.
311 template <typename TAttrs>
312 inline TAttrs AttrsWithDefaultValues() {
313  static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes");
314  auto n = make_object<typename TAttrs::ContainerType>();
315  n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false);
316  return TAttrs(n);
317 }
318 
// Returns `input` with attribute `attr_key` set to `attr_value`.
// The node is obtained through input.CopyOnWrite() and its attrs through
// attrs.CopyOnWrite(); the value is then stored via dict.Set(attr_key, attr_value).
// When the node has no attrs yet, a new DictAttrs containing just this entry
// is created instead.
// TFunc::ContainerType must be a final (leaf) node type (static_assert).
346 template <typename TFunc>
347 inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
348  using TNode = typename TFunc::ContainerType;
349  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
350  TNode* node = input.CopyOnWrite();
351  if (node->attrs.defined()) {
352  node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
353  } else {
354  Map<String, ObjectRef> dict = {{attr_key, attr_value}};
355  node->attrs = DictAttrs(dict);
356  }
357  return input;
358 }
359 
// Returns `input` with every entry of `attrs` written into its attrs dict.
// If the node already has attrs, each (key, value) pair is stored via
// dict.Set on the attrs obtained through CopyOnWrite(). Otherwise `attrs`
// itself is moved into a new DictAttrs.
// TFunc::ContainerType must be a final (leaf) node type (static_assert).
370 template <typename TFunc>
371 inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
372  using TNode = typename TFunc::ContainerType;
373  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
374  TNode* node = input.CopyOnWrite();
375  if (node->attrs.defined()) {
376  for (const auto& pair : attrs) {
377  node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
378  }
379  } else {
380  node->attrs = DictAttrs(std::move(attrs));
381  }
382  return input;
383 }
384 
385 // Namespace containing detail implementations
386 namespace detail {
388 
389 // helper entry that does nothing in set_default/bound/describe calls.
390 struct AttrNopEntry {
392 
393  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
394  template <typename T>
395  TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
396  return *this;
397  }
398  template <typename T>
399  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
400  return *this;
401  }
402  template <typename T>
403  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
404  return *this;
405  }
406 };
407 
408 // Wrapper for normal visitor.
410  public:
411  explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
412  template <typename T>
413  AttrNopEntry operator()(const char* key, T* value) {
414  visitor_->Visit(key, value);
415  return AttrNopEntry();
416  }
417 
418  private:
419  AttrVisitor* visitor_;
420 };
421 
423  public:
424  bool result_{true};
425  // constructor
426  AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
427  : lhs_(lhs), rhs_(rhs), equal_(equal) {}
428  template <typename T>
429  AttrNopEntry operator()(const char* key, T* lhs_value) {
430  if (!result_) return AttrNopEntry();
431  const T* rhs_value = reinterpret_cast<const T*>(
432  reinterpret_cast<const char*>(rhs_) +
433  (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_)));
434  if (!equal_(*lhs_value, *rhs_value)) {
435  result_ = false;
436  }
437  return AttrNopEntry();
438  }
439 
440  private:
441  const Object* lhs_;
442  const Object* rhs_;
443  const SEqualReducer& equal_;
444 };
445 
447  public:
448  explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {}
449 
450  template <typename T>
451  AttrNopEntry operator()(const char* key, T* value) {
452  hash_reducer_(*value);
453  return AttrNopEntry();
454  }
455 
456  private:
457  const SHashReducer& hash_reducer_;
458 };
459 
460 // helper entry that does initialization, set default.
461 template <typename T>
463  // The attributes
465  // The type key
466  const char* type_key_;
467  // field name
468  const char* key_;
469  // internal value.
470  T* value_;
471  // whether the value is missing.
472  // NOTE: initialize to false so that the destructor does not throw unless
473  // AttrInitVisitor::operator() is committed to returning an instance of this class.
474  // It is expected not to set this to true until that is true.
475  bool value_missing_{false};
476 
477  AttrInitEntry() = default;
478 
480  type_key_ = other.type_key_;
481  key_ = other.key_;
482  value_ = other.value_;
483  value_missing_ = other.value_missing_;
484  // avoid unexpected throw
485  other.value_missing_ = false;
486  }
487 
488  // If the value is still missing in destruction time throw an error.
489  ~AttrInitEntry() DMLC_THROW_EXCEPTION {
490  if (value_missing_) {
491  std::ostringstream os;
492  os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. "
493  << "If the key is defined check that its type matches the declared type.";
494  throw AttrError(os.str());
495  }
496  }
497  // override fields.
498  // This function sets the lower bound of the attribute
499  TSelf& set_lower_bound(const T& begin) {
500  if (this->value_missing_) return *this;
501  const T& val = *value_;
502  if (begin > val) {
503  std::ostringstream os;
504  os << type_key_ << "." << key_ << ": "
505  << "value " << val << " is smaller than the lower bound " << begin;
506  throw AttrError(os.str());
507  }
508  return *this;
509  }
510  // This function sets the upper bound of the attribute
511  TSelf& set_upper_bound(const T& end) {
512  if (this->value_missing_) return *this;
513  const T& val = *value_;
514  if (val > end) {
515  std::ostringstream os;
516  os << type_key_ << "." << key_ << ": "
517  << "value " << val << " is bigger than the upper bound " << end;
518  throw AttrError(os.str());
519  }
520  return *this;
521  }
522  // set default when
523  TSelf& set_default(const T& value) {
524  if (!value_missing_) return *this;
525  *value_ = value;
526  value_missing_ = false;
527  return *this;
528  }
529  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
530 };
531 
532 // Template function to allow smart conversion
533 // from Expr types into the constants.
// Primary template: assigns the result of TVMArgValue's conversion operator
// to T. Explicit specializations below cover DataType, std::string, double,
// int, int64_t, uint64_t and bool.
534 template <typename T>
535 inline void SetValue(T* ptr, const TVMArgValue& val) {
536  *ptr = val.operator T();
537 }
538 
// Writes an integral argument into *ptr, cast to T. A plain integer argument
// (type code kDLInt) is read from its v_int64 payload; any other argument is
// converted to an IntImm and its `value` is used.
// NOTE(review): non-integer arguments presumably fail inside the IntImm
// conversion; confirm the error path in TVMArgValue.
539 template <typename T>
540 inline void SetIntValue(T* ptr, const TVMArgValue& val) {
541  if (val.type_code() == kDLInt) {
542  *ptr = static_cast<T>(val.value().v_int64);
543  } else {
544  IntImm expr = val;
545  *ptr = static_cast<T>(expr->value);
546  }
547 }
548 
549 // Workaround for GCC8.1 / GCC8.2
// Same body as the primary SetValue template; this explicit specialization
// exists only as the compiler workaround noted above.
550 template <>
551 inline void SetValue<DataType>(DataType* ptr, const TVMArgValue& val) {
552  *ptr = val.operator DataType();
553 }
554 
// Stores the argument as std::string when String::CanConvertFrom accepts it;
// any other argument is a fatal error ("Expect str").
555 template <>
556 inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
557  if (String::CanConvertFrom(val)) {
558  *ptr = val.operator std::string();
559  } else {
560  LOG(FATAL) << "Expect str";
561  }
562 }
563 
// Accepts plain float or int arguments directly. Otherwise the argument must
// be a defined object. An IntImm or FloatImm has its value converted to double;
// any other object type is a fatal error that names its type key.
564 template <>
565 inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
566  if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
567  *ptr = val.operator double();
568  } else {
569  ObjectRef expr = val;
570  ICHECK(expr.defined());
571  if (const IntImmNode* op = expr.as<IntImmNode>()) {
572  *ptr = static_cast<double>(op->value);
573  } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
574  *ptr = static_cast<double>(op->value);
575  } else {
576  LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
577  }
578  }
579 }
// Integral and boolean fields all delegate to SetIntValue, so each accepts
// either a plain integer argument or an IntImm.
// NOTE(review): uint64_t values pass through the signed v_int64 payload (or
// IntImm value) before the cast; confirm values above INT64_MAX are not expected.
580 template <>
581 inline void SetValue<int>(int* ptr, const TVMArgValue& val) {
582  SetIntValue(ptr, val);
583 }
584 template <>
585 inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) {
586  SetIntValue(ptr, val);
587 }
588 template <>
589 inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) {
590  SetIntValue(ptr, val);
591 }
592 template <>
593 inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) {
594  SetIntValue(ptr, val);
595 }
596 
597 // Visitor for value initialization
598 template <typename FFind>
600  public:
601  // Counter of number of matched attributes during visit.
602  // This is used to decide if there is additional unmatched attributes.
603  size_t hit_count_{0};
604  // constructor
605  AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
606 
607  template <typename T>
608  AttrInitEntry<T> operator()(const char* key, T* value) {
609  TVMArgValue val;
610  AttrInitEntry<T> opt;
611  opt.type_key_ = type_key_;
612  opt.key_ = key;
613  opt.value_ = value;
614  if (ffind_(key, &val)) {
615  SetValue(value, val);
616  opt.value_missing_ = false;
617  ++hit_count_;
618  } else {
619  opt.value_missing_ = true;
620  }
621 #if defined(__GNUC__)
622 #pragma GCC diagnostic ignored "-Wpragmas"
623 #pragma GCC diagnostic ignored "-Wpessimizing-move"
624 #endif
625  return std::move(opt);
626  }
627 
628  private:
629  // the type key
630  const char* type_key_;
631  FFind ffind_;
632 };
633 
// Builds an AttrInitVisitor for `type_key`, deducing FFind from the argument so
// that a lookup lambda can be passed directly (as InitByPackedArgs does).
634 template <typename FFind>
635 inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) {
636  return AttrInitVisitor<FFind>(type_key, ffind);
637 }
638 
// TypeName<T>::value is the type name that the documentation visitor writes
// into AttrFieldInfoNode::type_info. For object reference types it is the
// container's _type_key; the explicit specializations that follow name
// primitive field types.
643 template <typename T>
644 struct TypeName {
645  static constexpr const char* value = T::ContainerType::_type_key;
646 };
647 
// Documentation type name for int fields.
648 template <>
649 struct TypeName<int> {
650  static constexpr const char* value = "int";
651 };
652 
// Documentation type name for int64_t fields.
653 template <>
654 struct TypeName<int64_t> {
655  static constexpr const char* value = "int64";
656 };
657 
// Documentation type name for uint64_t fields.
// NOTE(review): spelled "uint64_t", unlike the "int64" used for int64_t;
// confirm whether this inconsistency is intended.
658 template <>
659 struct TypeName<uint64_t> {
660  static constexpr const char* value = "uint64_t";
661 };
662 
663 template <>
665  static constexpr const char* value = "DataType";
666 };
667 
// std::string fields are documented as "str".
668 template <>
669 struct TypeName<std::string> {
670  static constexpr const char* value = "str";
671 };
672 
// Documentation type name for bool fields.
673 template <>
674 struct TypeName<bool> {
675  static constexpr const char* value = "bool";
676 };
677 
// Opaque pointer (void*) fields are documented as "handle".
678 template <>
679 struct TypeName<void*> {
680  static constexpr const char* value = "handle";
681 };
682 
// Documentation type name for double fields.
683 template <>
684 struct TypeName<double> {
685  static constexpr const char* value = "double";
686 };
687 
689  public:
691 
692  explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) {}
693  TSelf& describe(const char* str) {
694  info_->description = str;
695  return *this;
696  }
697  template <typename T>
698  TSelf& set_default(const T& value) {
699  std::ostringstream os;
700  os << info_->type_info << ", default=" << value;
701  info_->type_info = os.str();
702  return *this;
703  }
704  template <typename T>
705  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
706  return *this;
707  }
708  template <typename T>
709  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
710  return *this;
711  }
712 
713  private:
715 };
716 
718  public:
719  template <typename T>
720  AttrDocEntry operator()(const char* key, T* v) {
721  ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>();
722  info->name = key;
723  info->type_info = TypeName<T>::value;
724  fields_.push_back(AttrFieldInfo(info));
725  return AttrDocEntry(info);
726  }
727 
729 };
730 
732  public:
733  std::string key_;
734  bool exist_{false};
735 
736  template <typename T>
737  AttrNopEntry operator()(const char* key, T* v) {
738  if (exist_) return AttrNopEntry();
739  if (key == key_) exist_ = true;
740  return AttrNopEntry();
741  }
742 };
743 
744 template <typename T>
747  // constructor
748  AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data)
749  : visitor_(visitor), key_(key), data_(data) {}
750 
751  ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
752  if (trigger_) {
753  visitor_->Visit(key_, data_);
754  }
755  }
756  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
757  TSelf& set_default(const T& value) {
758  if (tvm::StructuralEqual()(value, *data_)) {
759  trigger_ = false;
760  }
761  return *this;
762  }
763  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; }
764  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; }
765 
766  private:
767  AttrVisitor* visitor_;
768  const char* key_;
769  T* data_;
770  bool trigger_{true};
771 };
772 
774  public:
775  explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
776  template <typename T>
777  AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) {
778  return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
779  }
780 
781  private:
782  AttrVisitor* visitor_;
783 };
784 } // namespace detail
785 
792 template <typename DerivedType>
793 class AttrsNode : public BaseAttrsNode {
794  public:
797  self()->__VisitAttrs__(vis);
798  }
799 
802  self()->__VisitAttrs__(vis);
803  }
804 
805  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
806  ICHECK_EQ(args.size() % 2, 0);
807  const int kLinearSearchBound = 16;
808  int hit_count = 0;
809  // applies two strategies to lookup
810  if (args.size() < kLinearSearchBound) {
811  // linear search.
812  auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
813  for (int i = 0; i < args.size(); i += 2) {
814  ICHECK_EQ(args.type_codes[i], kTVMStr);
815  if (!std::strcmp(key, args.values[i].v_str)) {
816  *val = args[i + 1];
817  return true;
818  }
819  }
820  return false;
821  };
822  auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
823  self()->__VisitAttrs__(vis);
824  hit_count = vis.hit_count_;
825  } else {
826  // construct a map then do lookup.
827  std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
828  for (int i = 0; i < args.size(); i += 2) {
829  ICHECK_EQ(args.type_codes[i], kTVMStr);
830  kwargs[args[i].operator std::string()] = args[i + 1];
831  }
832  auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) {
833  auto it = kwargs.find(key);
834  if (it != kwargs.end()) {
835  *val = it->second;
836  return true;
837  }
838  return false;
839  };
840  auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
841  self()->__VisitAttrs__(vis);
842  hit_count = vis.hit_count_;
843  }
844  // error handling, slow path
845  if (hit_count * 2 != args.size() && !allow_unknown) {
846  for (int i = 0; i < args.size(); i += 2) {
848  visitor.key_ = args[i].operator std::string();
849  self()->__VisitAttrs__(visitor);
850  if (!visitor.exist_) {
851  std::ostringstream os;
852  os << DerivedType::_type_key << ": does not have field \'" << visitor.key_
853  << "\', Possible fields:\n";
854  os << "----------------\n";
855  this->PrintDocString(os);
856  throw AttrError(os.str());
857  }
858  }
859  }
860  }
861 
862  bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
863  DerivedType* pself = self();
864  ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
865  self()->__VisitAttrs__(visitor);
866  return visitor.result_;
867  }
868 
869  void SHashReduce(SHashReducer hash_reducer) const {
870  ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
871  self()->__VisitAttrs__(visitor);
872  }
873 
876  self()->__VisitAttrs__(visitor);
877  return visitor.fields_;
878  }
879 
880  private:
881  DerivedType* self() const {
882  return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
883  }
884 };
885 
886 template <typename... Args>
887 inline void BaseAttrsNode::InitBySeq(Args&&... args) {
889  [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); });
890  pf(std::forward<Args>(args)...);
891 }
892 
// Prints one "name : type_info" line for each entry returned by ListFieldInfo().
// When a field has a non-empty description, that description is printed on
// the following line, after a leading-whitespace prefix.
893 inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*)
894  Array<AttrFieldInfo> entry = this->ListFieldInfo();
895  for (AttrFieldInfo info : entry) {
896  os << info->name << " : " << info->type_info << '\n';
897  if (info->description.length() != 0) {
898  os << " " << info->description << '\n';
899  }
900  }
901 }
902 
903 } // namespace tvm
904 #endif // TVM_IR_ATTRS_H_
int64_t v_int64
Definition: c_runtime_api.h:145
Map< String, ObjectRef > dict
internal attrs map
Definition: attrs.h:204
AttrError(std::string msg)
constructor
Definition: attrs.h:100
TSelf & set_lower_bound(const T &begin)
Definition: attrs.h:499
bool value_missing_
Definition: attrs.h:475
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:734
void SetValue< uint64_t >(uint64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:589
void SetValue< bool >(bool *ptr, const TVMArgValue &val)
Definition: attrs.h:593
void SetValue< DataType >(DataType *ptr, const TVMArgValue &val)
Definition: attrs.h:551
A custom smart pointer for Object.
Definition: object.h:356
AttrNormalVisitor(AttrVisitor *visitor)
Definition: attrs.h:411
TSelf & describe(const char *str)
Definition: attrs.h:693
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:115
void PrintDocString(std::ostream &os) const
Print a readable docstring to ostream, adding a newline after each line.
Definition: attrs.h:893
Definition: attrs.h:390
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
void SetIntValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:540
void SHashReduce(SHashReducer hash_reduce) const
Definition: attrs.h:210
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
Base expr nodes in TVM.
Definition: attrs.h:688
AttrFieldInfo.
Definition: attrs.h:128
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
struct TVMArgs TVMArgs
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:101
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Definition: data_type.h:55
TFunc WithAttr(TFunc input, const std::string &attr_key, ObjectRef attr_value)
Copy the function or module, but override the attribute value stored under the given key with the given value.
Definition: attrs.h:347
Structural equality comparison.
AttrInitEntry(AttrInitEntry &&other)
Definition: attrs.h:479
Constant floating point literals in the program.
Definition: expr.h:279
void SetValue< double >(double *ptr, const TVMArgValue &val)
Definition: attrs.h:565
bool exist_
Definition: attrs.h:734
Definition: attrs.h:717
AttrTriggerNonDefaultEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:777
Definition: loop_state.h:456
void SetValue< int >(int *ptr, const TVMArgValue &val)
Definition: attrs.h:581
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin)
Definition: attrs.h:705
AttrTriggerNonDefaultEntry(AttrVisitor *visitor, const char *key, T *data)
Definition: attrs.h:748
bool SEqualReduce(const DerivedType *other, SEqualReducer equal) const
Definition: attrs.h:862
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:399
const char * key_
Definition: attrs.h:468
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
Definition: attrs.h:874
Information about attribute fields in string representations.
Definition: attrs.h:106
AttrNopEntry operator()(const char *key, T *v)
Definition: attrs.h:737
Managed reference to DictAttrsNode.
Definition: attrs.h:227
base class of all object containers.
Definition: object.h:165
AttrsSEqualVisitor(const Object *lhs, const Object *rhs, const SEqualReducer &equal)
Definition: attrs.h:426
Specialized attribute type that is backed by a map. The DictAttrsNode implements the Attrs behavior...
Definition: attrs.h:201
AttrDocEntry operator()(const char *key, T *v)
Definition: attrs.h:720
Constant integer literals in the program.
Definition: expr.h:233
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:81
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: attrs.h:276
Helper struct to get the type name known to tvm.
Definition: attrs.h:644
AttrNopEntry operator()(const char *key, T *lhs_value)
Definition: attrs.h:429
TSelf & set_default(const T &value)
Definition: attrs.h:523
Definition: attrs.h:731
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:393
String name
name of the field
Definition: attrs.h:109
Definition: attrs.h:599
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
Definition: attrs.h:805
bool defined() const
Definition: object.h:537
Runtime primitive data type.
Definition: data_type.h:41
AttrInitVisitor< FFind > CreateInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:635
Arguments into TVM functions.
Definition: packed_func.h:335
Definition: attrs.h:462
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:259
String description
detailed description of the type
Definition: attrs.h:113
Error thrown during attribute checking.
Definition: attrs.h:95
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end)
Definition: attrs.h:709
virtual ~BaseAttrsNode()
virtual destructor
Definition: attrs.h:144
Managed reference class to IntImmNode.
Definition: expr.h:262
TSelf & set_default(DMLC_ATTRIBUTE_UNUSED const T &value)
Definition: attrs.h:395
Reference to string objects.
Definition: string.h:129
AttrInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:605
void InitBySeq(Args &&... args)
Initialize the attributes by sequence of arguments.
Definition: attrs.h:887
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:764
AttrDocEntry(ObjectPtr< AttrFieldInfoNode > info)
Definition: attrs.h:692
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:706
const char * type_key_
Definition: attrs.h:466
Definition: attrs.h:446
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:751
std::string key_
Definition: attrs.h:733
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:1696
Base class of all object reference.
Definition: object.h:504
DataType NullValue< DataType >()
Definition: attrs.h:90
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:778
void SHashReduce(SHashReducer hash_reducer) const
Definition: attrs.h:869
std::string GetTypeKey() const
Definition: object.h:178
String type_info
type docstring information in str.
Definition: attrs.h:111
Definition: attrs.h:409
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:664
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:529
virtual void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:146
bool SEqualReduce(const DictAttrsNode *other, SEqualReducer equal) const
Definition: attrs.h:206
The base class of attribute classes, using the "curiously recurring template pattern" (CRTP): DerivedType supplies __VisitAttrs__.
Definition: attrs.h:793
void VisitNonDefaultAttrs(AttrVisitor *v)
Visit attributes that do not equal the default value.
Definition: attrs.h:800
TSelf & set_default(const T &value)
Definition: attrs.h:757
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1235
int type_code() const
Definition: packed_func.h:547
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:583
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:68
Optional container that represents a nullable variant of T.
Definition: optional.h:51
const TVMValue & value() const
Definition: packed_func.h:632
Array< AttrFieldInfo > fields_
Definition: attrs.h:728
void SetValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:535
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:413
T * value_
Definition: attrs.h:470
TSelf & set_upper_bound(const T &end)
Definition: attrs.h:511
Base class of all attribute class.
Definition: attrs.h:139
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:84
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:756
Definition: attrs.h:422
TFunc WithAttrs(TFunc input, Map< String, ObjectRef > attrs)
Copy the function or module, but override the attributes with the entries from attrs.
Definition: attrs.h:371
void SetValue< int64_t >(int64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:585
AttrNonDefaultVisitor(AttrVisitor *visitor)
Definition: attrs.h:775
~AttrInitEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:489
TSelf & set_default(const T &value)
Definition: attrs.h:698
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:858
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:795
bool result_
Definition: attrs.h:424
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:403
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:763
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:641
Definition: attrs.h:773
runtime::DataType DataType
Definition: data_type.h:389
Definition: c_runtime_api.h:121
AttrInitEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:608
TAttrs AttrsWithDefaultValues()
Create an Attr object with all default values.
Definition: attrs.h:312
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:451
Type-erased function used across TVM API.
AttrsSHashVisitor(const SHashReducer &hash_reducer)
Definition: attrs.h:448
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attr.
Definition: attrs.h:298