44 #ifndef TVM_IR_ATTRS_H_ 45 #define TVM_IR_ATTRS_H_ 47 #include <dmlc/common.h> 55 #include <type_traits> 56 #include <unordered_map> 66 #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ 67 static constexpr const char* _type_key = TypeKey; \ 68 TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ 69 template <typename FVisit> \ 70 void _tvm_VisitAttrs(FVisit& _tvm_fvisit) // NOLINT(*) 76 #define TVM_ATTR_FIELD(FieldName) _tvm_fvisit(#FieldName, &FieldName) 83 template <
typename TObjectRef>
85 static_assert(TObjectRef::_type_is_nullable,
"Can only get NullValue for nullable types");
100 explicit AttrError(std::string msg) : Error(
"AttributeError:" + msg) {}
116 v->Visit(
"name", &name);
117 v->Visit(
"type_info", &type_info);
118 v->Visit(
"description", &description);
121 static constexpr
const char* _type_key =
"AttrFieldInfo";
122 static constexpr
bool _type_has_method_sequal_reduce =
false;
123 static constexpr
bool _type_has_method_shash_reduce =
false;
152 template <
typename... Args>
153 inline void InitBySeq(Args&&... args);
158 inline void PrintDocString(std::ostream& os)
const;
165 TVM_DLL
virtual void VisitNonDefaultAttrs(
AttrVisitor* v) = 0;
178 TVM_DLL
virtual void InitByPackedArgs(
const TVMArgs& kwargs,
bool allow_unknown =
false) = 0;
180 static constexpr
const bool _type_has_method_sequal_reduce =
true;
181 static constexpr
const bool _type_has_method_shash_reduce =
true;
182 static constexpr
const char* _type_key =
"Attrs";
215 void InitByPackedArgs(
const runtime::TVMArgs& args,
bool allow_unknown)
final;
219 static constexpr
const char* _type_key =
"DictAttrs";
258 template <
typename TObjectRef>
260 const std::string& attr_key,
262 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
263 "Can only call GetAttr with ObjectRef types.");
264 if (!defined())
return default_value;
267 auto it = node->
dict.find(attr_key);
268 if (it != node->
dict.end()) {
269 return Downcast<Optional<TObjectRef>>((*it).second);
271 return default_value;
275 template <
typename TObjectRef>
299 return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
311 template <
typename TAttrs>
313 static_assert(std::is_base_of<Attrs, TAttrs>::value,
"Can only take attr nodes");
314 auto n = make_object<typename TAttrs::ContainerType>();
346 template <
typename TFunc>
348 using TNode =
typename TFunc::ContainerType;
349 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
350 TNode* node = input.CopyOnWrite();
351 if (node->attrs.defined()) {
352 node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
355 node->attrs = DictAttrs(dict);
370 template <
typename TFunc>
372 using TNode =
typename TFunc::ContainerType;
373 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
374 TNode* node = input.CopyOnWrite();
375 if (node->attrs.defined()) {
376 for (
const auto& pair : attrs) {
377 node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
380 node->attrs = DictAttrs(std::move(attrs));
411 template <
typename TFunc>
412 inline TFunc
WithoutAttr(TFunc input,
const std::string& attr_key) {
413 using TNode =
typename TFunc::ContainerType;
414 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
416 if (input->attrs.defined()) {
417 TNode* node = input.CopyOnWrite();
418 node->attrs.CopyOnWrite()->dict.erase(attr_key);
419 if (node->attrs->dict.size() == 0) {
420 node->attrs = NullValue<DictAttrs>();
435 template <
typename T>
439 template <
typename T>
443 template <
typename T>
453 template <
typename T>
455 visitor_->Visit(key, value);
468 : lhs_(lhs), rhs_(rhs), equal_(equal) {}
469 template <
typename T>
472 const T* rhs_value =
reinterpret_cast<const T*
>(
473 reinterpret_cast<const char*
>(rhs_) +
474 (reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*
>(lhs_)));
475 if (!equal_(*lhs_value, *rhs_value)) {
491 template <
typename T>
493 hash_reducer_(*value);
502 template <
typename T>
516 bool value_missing_{
false};
521 type_key_ = other.type_key_;
523 value_ = other.value_;
524 value_missing_ = other.value_missing_;
526 other.value_missing_ =
false;
531 if (value_missing_) {
532 std::ostringstream os;
533 os << type_key_ <<
": Cannot find required field \'" << key_ <<
"\' during initialization. " 534 <<
"If the key is defined check that its type matches the declared type.";
541 if (this->value_missing_)
return *
this;
542 const T& val = *value_;
544 std::ostringstream os;
545 os << type_key_ <<
"." << key_ <<
": " 546 <<
"value " << val <<
" is smaller than the lower bound " << begin;
553 if (this->value_missing_)
return *
this;
554 const T& val = *value_;
556 std::ostringstream os;
557 os << type_key_ <<
"." << key_ <<
": " 558 <<
"value " << val <<
" is bigger than the upper bound " << end;
565 if (!value_missing_)
return *
this;
567 value_missing_ =
false;
575 template <
typename T>
577 *ptr = val.operator T();
580 template <
typename T>
586 *ptr =
static_cast<T
>(expr->value);
597 inline void SetValue<std::string>(std::string* ptr,
const TVMArgValue& val) {
599 *ptr = val.operator std::string();
601 LOG(FATAL) <<
"Expect str";
607 if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
608 *ptr = val.operator double();
613 *ptr =
static_cast<double>(op->value);
615 *ptr =
static_cast<double>(op->value);
617 LOG(FATAL) <<
"Expect float value, but get " << expr->
GetTypeKey();
639 template <
typename FFind>
644 size_t hit_count_{0};
646 AttrInitVisitor(
const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
648 template <
typename T>
655 if (ffind_(key, &val)) {
662 #if defined(__GNUC__) 663 #pragma GCC diagnostic ignored "-Wpragmas" 664 #pragma GCC diagnostic ignored "-Wpessimizing-move" 666 return std::move(opt);
671 const char* type_key_;
675 template <
typename FFind>
684 template <
typename T>
686 static constexpr
const char* value = T::ContainerType::_type_key;
691 static constexpr
const char* value =
"int";
696 static constexpr
const char* value =
"int64";
701 static constexpr
const char* value =
"uint64_t";
706 static constexpr
const char* value =
"DataType";
711 static constexpr
const char* value =
"str";
716 static constexpr
const char* value =
"bool";
721 static constexpr
const char* value =
"handle";
726 static constexpr
const char* value =
"double";
735 info_->description = str;
738 template <
typename T>
740 std::ostringstream os;
741 os << info_->type_info <<
", default=" << value;
742 info_->type_info = os.str();
745 template <
typename T>
749 template <
typename T>
760 template <
typename T>
777 template <
typename T>
780 if (key == key_) exist_ =
true;
785 template <
typename T>
790 : visitor_(visitor), key_(key), data_(data) {}
794 visitor_->Visit(key_, data_);
817 template <
typename T>
833 template <
typename DerivedType>
838 self()->_tvm_VisitAttrs(vis);
843 self()->_tvm_VisitAttrs(vis);
847 ICHECK_EQ(args.size() % 2, 0);
848 const int kLinearSearchBound = 16;
851 if (args.size() < kLinearSearchBound) {
854 for (
int i = 0; i < args.size(); i += 2) {
855 ICHECK_EQ(args.type_codes[i],
kTVMStr);
856 if (!std::strcmp(key, args.values[i].v_str)) {
864 self()->_tvm_VisitAttrs(vis);
865 hit_count = vis.hit_count_;
868 std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
869 for (
int i = 0; i < args.size(); i += 2) {
870 ICHECK_EQ(args.type_codes[i],
kTVMStr);
871 kwargs[args[i].operator std::string()] = args[i + 1];
874 auto it = kwargs.find(key);
875 if (it != kwargs.end()) {
882 self()->_tvm_VisitAttrs(vis);
883 hit_count = vis.hit_count_;
886 if (hit_count * 2 != args.size() && !allow_unknown) {
887 for (
int i = 0; i < args.size(); i += 2) {
889 visitor.
key_ = args[i].operator std::string();
890 self()->_tvm_VisitAttrs(visitor);
892 std::ostringstream os;
893 os << DerivedType::_type_key <<
": does not have field \'" << visitor.
key_ 894 <<
"\', Possible fields:\n";
895 os <<
"----------------\n";
896 this->PrintDocString(os);
904 DerivedType* pself =
self();
906 self()->_tvm_VisitAttrs(visitor);
912 self()->_tvm_VisitAttrs(visitor);
917 self()->_tvm_VisitAttrs(visitor);
922 DerivedType*
self()
const {
923 return const_cast<DerivedType*
>(
static_cast<const DerivedType*
>(
this));
927 template <
typename... Args>
931 pf(std::forward<Args>(args)...);
937 os << info->name <<
" : " << info->type_info <<
'\n';
938 if (info->description.length() != 0) {
939 os <<
" " << info->description <<
'\n';
945 #endif // TVM_IR_ATTRS_H_ int64_t v_int64
Definition: c_runtime_api.h:209
Map< String, ObjectRef > dict
internal attrs map
Definition: attrs.h:204
AttrError(std::string msg)
constructor
Definition: attrs.h:100
TSelf & set_lower_bound(const T &begin)
Definition: attrs.h:540
bool value_missing_
Definition: attrs.h:516
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlying...
Definition: packed_func.h:799
void SetValue< uint64_t >(uint64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:630
void SetValue< bool >(bool *ptr, const TVMArgValue &val)
Definition: attrs.h:634
void SetValue< DataType >(DataType *ptr, const TVMArgValue &val)
Definition: attrs.h:592
A custom smart pointer for Object.
Definition: object.h:358
AttrNormalVisitor(AttrVisitor *visitor)
Definition: attrs.h:452
TSelf & describe(const char *str)
Definition: attrs.h:734
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:115
void PrintDocString(std::ostream &os) const
Print a readable docstring to ostream, adding a newline.
Definition: attrs.h:934
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
void SetIntValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:581
void SHashReduce(SHashReducer hash_reduce) const
Definition: attrs.h:210
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
AttrFieldInfo.
Definition: attrs.h:128
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Definition: data_type.h:55
TFunc WithAttr(TFunc input, const std::string &attr_key, ObjectRef attr_value)
Copy the function or module, setting attribute attr_key to attr_value (overriding any existing value for that key).
Definition: attrs.h:347
Structural equality comparison.
AttrInitEntry(AttrInitEntry &&other)
Definition: attrs.h:520
Constant floating point literals in the program.
Definition: expr.h:538
void SetValue< double >(double *ptr, const TVMArgValue &val)
Definition: attrs.h:606
bool exist_
Definition: attrs.h:775
AttrTriggerNonDefaultEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:818
Definition: loop_state.h:456
void SetValue< int >(int *ptr, const TVMArgValue &val)
Definition: attrs.h:622
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin)
Definition: attrs.h:746
AttrTriggerNonDefaultEntry(AttrVisitor *visitor, const char *key, T *data)
Definition: attrs.h:789
bool SEqualReduce(const DerivedType *other, SEqualReducer equal) const
Definition: attrs.h:903
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:440
const char * key_
Definition: attrs.h:509
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
Definition: attrs.h:915
Information about attribute fields in string representations.
Definition: attrs.h:106
AttrNopEntry operator()(const char *key, T *v)
Definition: attrs.h:778
Managed reference to DictAttrsNode.
Definition: attrs.h:227
base class of all object containers.
Definition: object.h:167
Specialized attribute type that is backed by a map. The DictAttrsNode implements the Attrs behavior...
Definition: attrs.h:201
AttrDocEntry operator()(const char *key, T *v)
Definition: attrs.h:761
Constant integer literals in the program.
Definition: expr.h:491
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:103
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: attrs.h:276
Helper struct to get the type name known to tvm.
Definition: attrs.h:685
TSelf & set_default(const T &value)
Definition: attrs.h:564
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:434
String name
name of the field
Definition: attrs.h:109
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
Definition: attrs.h:846
bool defined() const
Definition: object.h:544
Runtime primitive data type.
Definition: data_type.h:41
AttrInitVisitor< FFind > CreateInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:676
Arguments into TVM functions.
Definition: packed_func.h:391
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:259
String description
detailed description of the type
Definition: attrs.h:113
Error thrown during attribute checking.
Definition: attrs.h:95
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end)
Definition: attrs.h:750
virtual ~BaseAttrsNode()
virtual destructor
Definition: attrs.h:144
TFunc WithoutAttr(TFunc input, const std::string &attr_key)
Copy the function or module with the specified attribute removed.
Definition: attrs.h:412
Managed reference class to IntImmNode.
Definition: expr.h:520
TSelf & set_default(DMLC_ATTRIBUTE_UNUSED const T &value)
Definition: attrs.h:436
Reference to string objects.
Definition: string.h:98
AttrInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:646
void InitBySeq(Args &&... args)
Initialize the attributes from a sequence of arguments.
Definition: attrs.h:928
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:805
AttrDocEntry(ObjectPtr< AttrFieldInfoNode > info)
Definition: attrs.h:733
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
const char * type_key_
Definition: attrs.h:507
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:792
std::string key_
Definition: attrs.h:774
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:1981
Base class of all object reference.
Definition: object.h:511
DataType NullValue< DataType >()
Definition: attrs.h:90
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)
Define CopyOnWrite function in an ObjectRef.
Definition: object.h:785
void SHashReduce(SHashReducer hash_reducer) const
Definition: attrs.h:910
std::string GetTypeKey() const
Definition: object.h:180
String type_info
type docstring information in str.
Definition: attrs.h:111
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:671
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:570
virtual void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:146
bool SEqualReduce(const DictAttrsNode *other, SEqualReducer equal) const
Definition: attrs.h:206
The base class of all attribute nodes. Uses the "curiously recurring template pattern".
Definition: attrs.h:834
void VisitNonDefaultAttrs(AttrVisitor *v)
Visit attributes that do not equal the default value.
Definition: attrs.h:841
TSelf & set_default(const T &value)
Definition: attrs.h:798
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1271
int type_code() const
Definition: packed_func.h:610
A single argument value to PackedFunc, containing both type_code and TVMValue.
Definition: packed_func.h:646
Packed function is a type-erased function. The arguments are passed in packed format.
Definition: packed_func.h:138
Optional container used to represent a nullable variant of T.
Definition: optional.h:51
const TVMValue & value() const
Definition: packed_func.h:691
Array< AttrFieldInfo > fields_
Definition: attrs.h:769
void SetValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:576
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:454
T * value_
Definition: attrs.h:511
TSelf & set_upper_bound(const T &end)
Definition: attrs.h:552
Base class of all attribute class.
Definition: attrs.h:139
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:84
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:797
TFunc WithAttrs(TFunc input, Map< String, ObjectRef > attrs)
Copy the function or module, overriding its attributes with the entries from attrs.
Definition: attrs.h:371
void SetValue< int64_t >(int64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:626
AttrNonDefaultVisitor(AttrVisitor *visitor)
Definition: attrs.h:816
~AttrInitEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:530
TSelf & set_default(const T &value)
Definition: attrs.h:739
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:836
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:444
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:804
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
helper macro to declare a base object type that can be inherited.
Definition: object.h:648
runtime::DataType DataType
Definition: data_type.h:398
Definition: c_runtime_api.h:185
AttrInitEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:649
TAttrs AttrsWithDefaultValues()
Create an Attr object with all default values.
Definition: attrs.h:312
Type-erased function used across the TVM API.
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attribute.
Definition: attrs.h:298