44 #ifndef TVM_IR_ATTRS_H_ 45 #define TVM_IR_ATTRS_H_ 47 #include <dmlc/common.h> 55 #include <type_traits> 56 #include <unordered_map> 66 #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ 67 static constexpr const char* _type_key = TypeKey; \ 68 TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ 69 template <typename FVisit> \ 70 void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) 76 #define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName) 83 template <
typename TObjectRef>
85 static_assert(TObjectRef::_type_is_nullable,
"Can only get NullValue for nullable types");
100 explicit AttrError(std::string msg) : Error(
"AttributeError:" + msg) {}
116 v->Visit(
"name", &name);
117 v->Visit(
"type_info", &type_info);
118 v->Visit(
"description", &description);
121 static constexpr
const char* _type_key =
"AttrFieldInfo";
122 static constexpr
bool _type_has_method_sequal_reduce =
false;
123 static constexpr
bool _type_has_method_shash_reduce =
false;
152 template <
typename... Args>
153 inline void InitBySeq(Args&&... args);
158 inline void PrintDocString(std::ostream& os)
const;
165 TVM_DLL
virtual void VisitNonDefaultAttrs(
AttrVisitor* v) = 0;
178 TVM_DLL
virtual void InitByPackedArgs(
const TVMArgs& kwargs,
bool allow_unknown =
false) = 0;
180 static constexpr
const bool _type_has_method_sequal_reduce =
true;
181 static constexpr
const bool _type_has_method_shash_reduce =
true;
182 static constexpr
const char* _type_key =
"Attrs";
215 void InitByPackedArgs(
const runtime::TVMArgs& args,
bool allow_unknown)
final;
219 static constexpr
const char* _type_key =
"DictAttrs";
258 template <
typename TObjectRef>
260 const std::string& attr_key,
262 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
263 "Can only call GetAttr with ObjectRef types.");
264 if (!defined())
return default_value;
267 auto it = node->
dict.find(attr_key);
268 if (it != node->
dict.end()) {
269 return Downcast<Optional<TObjectRef>>((*it).second);
271 return default_value;
275 template <
typename TObjectRef>
299 return GetAttr<Integer>(attr_key, 0) != 0;
311 template <
typename TAttrs>
313 static_assert(std::is_base_of<Attrs, TAttrs>::value,
"Can only take attr nodes");
314 auto n = make_object<typename TAttrs::ContainerType>();
346 template <
typename TFunc>
348 using TNode =
typename TFunc::ContainerType;
349 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
350 TNode* node = input.CopyOnWrite();
351 if (node->attrs.defined()) {
352 node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
355 node->attrs = DictAttrs(dict);
370 template <
typename TFunc>
372 using TNode =
typename TFunc::ContainerType;
373 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
374 TNode* node = input.CopyOnWrite();
375 if (node->attrs.defined()) {
376 for (
const auto& pair : attrs) {
377 node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
380 node->attrs = DictAttrs(std::move(attrs));
394 template <
typename T>
398 template <
typename T>
402 template <
typename T>
412 template <
typename T>
414 visitor_->Visit(key, value);
427 : lhs_(lhs), rhs_(rhs), equal_(equal) {}
428 template <
typename T>
431 const T* rhs_value =
reinterpret_cast<const T*
>(
432 reinterpret_cast<const char*
>(rhs_) +
433 (reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*
>(lhs_)));
434 if (!equal_(*lhs_value, *rhs_value)) {
450 template <
typename T>
452 hash_reducer_(*value);
461 template <
typename T>
475 bool value_missing_{
false};
480 type_key_ = other.type_key_;
482 value_ = other.value_;
483 value_missing_ = other.value_missing_;
485 other.value_missing_ =
false;
490 if (value_missing_) {
491 std::ostringstream os;
492 os << type_key_ <<
": Cannot find required field \'" << key_ <<
"\' during initialization. " 493 <<
"If the key is defined check that its type matches the declared type.";
500 if (this->value_missing_)
return *
this;
501 const T& val = *value_;
503 std::ostringstream os;
504 os << type_key_ <<
"." << key_ <<
": " 505 <<
"value " << val <<
" is smaller than the lower bound " << begin;
512 if (this->value_missing_)
return *
this;
513 const T& val = *value_;
515 std::ostringstream os;
516 os << type_key_ <<
"." << key_ <<
": " 517 <<
"value " << val <<
" is bigger than the upper bound " << end;
524 if (!value_missing_)
return *
this;
526 value_missing_ =
false;
534 template <
typename T>
536 *ptr = val.operator T();
539 template <
typename T>
545 *ptr =
static_cast<T
>(expr->value);
556 inline void SetValue<std::string>(std::string* ptr,
const TVMArgValue& val) {
558 *ptr = val.operator std::string();
560 LOG(FATAL) <<
"Expect str";
566 if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
567 *ptr = val.operator double();
572 *ptr =
static_cast<double>(op->value);
574 *ptr =
static_cast<double>(op->value);
576 LOG(FATAL) <<
"Expect float value, but get " << expr->
GetTypeKey();
598 template <
typename FFind>
603 size_t hit_count_{0};
605 AttrInitVisitor(
const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
607 template <
typename T>
614 if (ffind_(key, &val)) {
621 #if defined(__GNUC__) 622 #pragma GCC diagnostic ignored "-Wpragmas" 623 #pragma GCC diagnostic ignored "-Wpessimizing-move" 625 return std::move(opt);
630 const char* type_key_;
634 template <
typename FFind>
643 template <
typename T>
645 static constexpr
const char* value = T::ContainerType::_type_key;
650 static constexpr
const char* value =
"int";
655 static constexpr
const char* value =
"int64";
660 static constexpr
const char* value =
"uint64_t";
665 static constexpr
const char* value =
"DataType";
670 static constexpr
const char* value =
"str";
675 static constexpr
const char* value =
"bool";
680 static constexpr
const char* value =
"handle";
685 static constexpr
const char* value =
"double";
694 info_->description = str;
697 template <
typename T>
699 std::ostringstream os;
700 os << info_->type_info <<
", default=" << value;
701 info_->type_info = os.str();
704 template <
typename T>
708 template <
typename T>
719 template <
typename T>
736 template <
typename T>
739 if (key == key_) exist_ =
true;
744 template <
typename T>
749 : visitor_(visitor), key_(key), data_(data) {}
753 visitor_->Visit(key_, data_);
776 template <
typename T>
792 template <
typename DerivedType>
797 self()->__VisitAttrs__(vis);
802 self()->__VisitAttrs__(vis);
806 ICHECK_EQ(args.size() % 2, 0);
807 const int kLinearSearchBound = 16;
810 if (args.size() < kLinearSearchBound) {
813 for (
int i = 0; i < args.size(); i += 2) {
814 ICHECK_EQ(args.type_codes[i],
kTVMStr);
815 if (!std::strcmp(key, args.values[i].v_str)) {
823 self()->__VisitAttrs__(vis);
824 hit_count = vis.hit_count_;
827 std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
828 for (
int i = 0; i < args.size(); i += 2) {
829 ICHECK_EQ(args.type_codes[i],
kTVMStr);
830 kwargs[args[i].operator std::string()] = args[i + 1];
833 auto it = kwargs.find(key);
834 if (it != kwargs.end()) {
841 self()->__VisitAttrs__(vis);
842 hit_count = vis.hit_count_;
845 if (hit_count * 2 != args.size() && !allow_unknown) {
846 for (
int i = 0; i < args.size(); i += 2) {
848 visitor.
key_ = args[i].operator std::string();
849 self()->__VisitAttrs__(visitor);
851 std::ostringstream os;
852 os << DerivedType::_type_key <<
": does not have field \'" << visitor.
key_ 853 <<
"\', Possible fields:\n";
854 os <<
"----------------\n";
855 this->PrintDocString(os);
863 DerivedType* pself =
self();
865 self()->__VisitAttrs__(visitor);
871 self()->__VisitAttrs__(visitor);
876 self()->__VisitAttrs__(visitor);
881 DerivedType*
self()
const {
882 return const_cast<DerivedType*
>(
static_cast<const DerivedType*
>(
this));
886 template <
typename... Args>
890 pf(std::forward<Args>(args)...);
896 os << info->name <<
" : " << info->type_info <<
'\n';
897 if (info->description.length() != 0) {
898 os <<
" " << info->description <<
'\n';
904 #endif // TVM_IR_ATTRS_H_ int64_t v_int64
Definition: c_runtime_api.h:145
Map< String, ObjectRef > dict
internal attrs map
Definition: attrs.h:204
AttrError(std::string msg)
constructor
Definition: attrs.h:100
TSelf & set_lower_bound(const T &begin)
Definition: attrs.h:499
bool value_missing_
Definition: attrs.h:475
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:734
void SetValue< uint64_t >(uint64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:589
void SetValue< bool >(bool *ptr, const TVMArgValue &val)
Definition: attrs.h:593
void SetValue< DataType >(DataType *ptr, const TVMArgValue &val)
Definition: attrs.h:551
A custom smart pointer for Object.
Definition: object.h:356
AttrNormalVisitor(AttrVisitor *visitor)
Definition: attrs.h:411
TSelf & describe(const char *str)
Definition: attrs.h:693
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:115
void PrintDocString(std::ostream &os) const
Print readable docstring to ostream, add newline.
Definition: attrs.h:893
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
void SetIntValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:540
void SHashReduce(SHashReducer hash_reduce) const
Definition: attrs.h:210
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
AttrFieldInfo.
Definition: attrs.h:128
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:101
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Definition: data_type.h:55
TFunc WithAttr(TFunc input, const std::string &attr_key, ObjectRef attr_value)
Copy the function or module, but overrides the attribute value key with the value.
Definition: attrs.h:347
Structural equality comparison.
AttrInitEntry(AttrInitEntry &&other)
Definition: attrs.h:479
Constant floating point literals in the program.
Definition: expr.h:279
void SetValue< double >(double *ptr, const TVMArgValue &val)
Definition: attrs.h:565
bool exist_
Definition: attrs.h:734
AttrTriggerNonDefaultEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:777
Definition: loop_state.h:456
void SetValue< int >(int *ptr, const TVMArgValue &val)
Definition: attrs.h:581
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin)
Definition: attrs.h:705
AttrTriggerNonDefaultEntry(AttrVisitor *visitor, const char *key, T *data)
Definition: attrs.h:748
bool SEqualReduce(const DerivedType *other, SEqualReducer equal) const
Definition: attrs.h:862
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:399
const char * key_
Definition: attrs.h:468
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
Definition: attrs.h:874
Information about attribute fields in string representations.
Definition: attrs.h:106
AttrNopEntry operator()(const char *key, T *v)
Definition: attrs.h:737
Managed reference to DictAttrsNode.
Definition: attrs.h:227
base class of all object containers.
Definition: object.h:165
Specialized attribute type that is backed by a map. The DictAttrsNode implements the Attrs behavior...
Definition: attrs.h:201
AttrDocEntry operator()(const char *key, T *v)
Definition: attrs.h:720
Constant integer literals in the program.
Definition: expr.h:233
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:81
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: attrs.h:276
Helper struct to get the type name known to tvm.
Definition: attrs.h:644
TSelf & set_default(const T &value)
Definition: attrs.h:523
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:393
String name
name of the field
Definition: attrs.h:109
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
Definition: attrs.h:805
bool defined() const
Definition: object.h:537
Runtime primitive data type.
Definition: data_type.h:41
AttrInitVisitor< FFind > CreateInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:635
Arguments into TVM functions.
Definition: packed_func.h:335
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:259
String description
detailed description of the type
Definition: attrs.h:113
Error thrown during attribute checking.
Definition: attrs.h:95
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end)
Definition: attrs.h:709
virtual ~BaseAttrsNode()
virtual destructor
Definition: attrs.h:144
Managed reference class to IntImmNode.
Definition: expr.h:262
TSelf & set_default(DMLC_ATTRIBUTE_UNUSED const T &value)
Definition: attrs.h:395
Reference to string objects.
Definition: string.h:129
AttrInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:605
void InitBySeq(Args &&... args)
Initialize the attributes by sequence of arguments.
Definition: attrs.h:887
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:764
AttrDocEntry(ObjectPtr< AttrFieldInfoNode > info)
Definition: attrs.h:692
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:706
const char * type_key_
Definition: attrs.h:466
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:751
std::string key_
Definition: attrs.h:733
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:1696
Base class of all object references.
Definition: object.h:504
DataType NullValue< DataType >()
Definition: attrs.h:90
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:778
void SHashReduce(SHashReducer hash_reducer) const
Definition: attrs.h:869
std::string GetTypeKey() const
Definition: object.h:178
String type_info
type docstring information in str.
Definition: attrs.h:111
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:664
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:529
virtual void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:146
bool SEqualReduce(const DictAttrsNode *other, SEqualReducer equal) const
Definition: attrs.h:206
The base class of all derived attribute types. Uses the "curiously recurring template pattern".
Definition: attrs.h:793
void VisitNonDefaultAttrs(AttrVisitor *v)
Visit attributes that do not equal the default value.
Definition: attrs.h:800
TSelf & set_default(const T &value)
Definition: attrs.h:757
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1235
int type_code() const
Definition: packed_func.h:547
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:583
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:68
Optional container that represents a nullable variant of T.
Definition: optional.h:51
const TVMValue & value() const
Definition: packed_func.h:632
Array< AttrFieldInfo > fields_
Definition: attrs.h:728
void SetValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:535
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:413
T * value_
Definition: attrs.h:470
TSelf & set_upper_bound(const T &end)
Definition: attrs.h:511
Base class of all attribute classes.
Definition: attrs.h:139
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:84
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:756
TFunc WithAttrs(TFunc input, Map< String, ObjectRef > attrs)
Copy the function or module, but overrides the attributes with the entries from attrs.
Definition: attrs.h:371
void SetValue< int64_t >(int64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:585
AttrNonDefaultVisitor(AttrVisitor *visitor)
Definition: attrs.h:775
~AttrInitEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:489
TSelf & set_default(const T &value)
Definition: attrs.h:698
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:858
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:795
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:403
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:763
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:641
runtime::DataType DataType
Definition: data_type.h:389
Definition: c_runtime_api.h:121
AttrInitEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:608
TAttrs AttrsWithDefaultValues()
Create an Attr object with all default values.
Definition: attrs.h:312
Type-erased function used across TVM API.
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attr.
Definition: attrs.h:298