44 #ifndef TVM_IR_ATTRS_H_
45 #define TVM_IR_ATTRS_H_
47 #include <dmlc/common.h>
55 #include <type_traits>
56 #include <unordered_map>
66 #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
67 static constexpr const char* _type_key = TypeKey; \
68 TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
69 template <typename FVisit> \
70 void _tvm_VisitAttrs(FVisit& _tvm_fvisit)
76 #define TVM_ATTR_FIELD(FieldName) _tvm_fvisit(#FieldName, &FieldName)
83 template <
typename TObjectRef>
85 static_assert(TObjectRef::_type_is_nullable,
"Can only get NullValue for nullable types");
100 explicit AttrError(std::string msg) : Error(
"AttributeError:" + msg) {}
116 v->Visit(
"name", &
name);
121 static constexpr
const char*
_type_key =
"AttrFieldInfo";
152 template <
typename... Args>
257 template <
typename TObjectRef>
259 const std::string& attr_key,
261 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
262 "Can only call GetAttr with ObjectRef types.");
263 if (!defined())
return default_value;
266 auto it = node->
dict.find(attr_key);
267 if (it != node->
dict.end()) {
279 return default_value;
283 template <
typename TObjectRef>
307 return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
319 template <
typename TAttrs>
321 static_assert(std::is_base_of<Attrs, TAttrs>::value,
"Can only take attr nodes");
322 auto n = make_object<typename TAttrs::ContainerType>();
394 template <
typename TFunc>
396 using TNode =
typename TFunc::ContainerType;
397 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
398 TNode* node = input.CopyOnWrite();
399 node->attrs =
WithAttr(std::move(node->attrs), attr_key, attr_value);
414 template <
typename TFunc>
416 using TNode =
typename TFunc::ContainerType;
417 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
418 TNode* node = input.CopyOnWrite();
420 node->attrs =
WithAttrs(std::move(node->attrs), attrs);
451 template <
typename TFunc>
452 inline TFunc
WithoutAttr(TFunc input,
const std::string& attr_key) {
453 using TNode =
typename TFunc::ContainerType;
454 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
456 TNode* node = input.CopyOnWrite();
457 node->attrs =
WithoutAttr(std::move(node->attrs), attr_key);
471 template <
typename T>
475 template <
typename T>
479 template <
typename T>
489 template <
typename T>
491 visitor_->Visit(key, value);
504 : lhs_(lhs), rhs_(rhs), equal_(
equal) {}
505 template <
typename T>
508 const T* rhs_value =
reinterpret_cast<const T*
>(
509 reinterpret_cast<const char*
>(rhs_) +
510 (
reinterpret_cast<const char*
>(lhs_value) -
reinterpret_cast<const char*
>(lhs_)));
511 if (!equal_(*lhs_value, *rhs_value)) {
527 template <
typename T>
529 hash_reducer_(*value);
538 template <
typename T>
552 bool value_missing_{
false};
557 type_key_ = other.type_key_;
559 value_ = other.value_;
560 value_missing_ = other.value_missing_;
562 other.value_missing_ =
false;
567 if (value_missing_) {
568 std::ostringstream os;
569 os << type_key_ <<
": Cannot find required field \'" << key_ <<
"\' during initialization. "
570 <<
"If the key is defined check that its type matches the declared type.";
577 if (this->value_missing_)
return *
this;
578 const T& val = *value_;
580 std::ostringstream os;
581 os << type_key_ <<
"." << key_ <<
": "
582 <<
"value " << val <<
" is smaller than the lower bound " << begin;
589 if (this->value_missing_)
return *
this;
590 const T& val = *value_;
592 std::ostringstream os;
593 os << type_key_ <<
"." << key_ <<
": "
594 <<
"value " << val <<
" is bigger than the upper bound " << end;
601 if (!value_missing_)
return *
this;
603 value_missing_ =
false;
611 template <
typename T>
613 *ptr = val.operator T();
616 template <
typename T>
622 *ptr =
static_cast<T
>(expr->value);
633 inline void SetValue<std::string>(std::string* ptr,
const TVMArgValue& val) {
635 *ptr = val.operator std::string();
637 LOG(FATAL) <<
"Expect str";
644 *ptr = val.operator double();
649 *ptr =
static_cast<double>(op->value);
651 *ptr =
static_cast<double>(op->value);
653 LOG(FATAL) <<
"Expect float value, but get " << expr->
GetTypeKey();
675 template <
typename FFind>
680 size_t hit_count_{0};
682 AttrInitVisitor(
const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
684 template <
typename T>
691 if (ffind_(key, &val)) {
698 #if defined(__GNUC__)
699 #pragma GCC diagnostic ignored "-Wpragmas"
700 #pragma GCC diagnostic ignored "-Wpessimizing-move"
702 return std::move(opt);
707 const char* type_key_;
711 template <
typename FFind>
720 template <
typename T>
722 static constexpr
const char* value = T::ContainerType::_type_key;
727 static constexpr
const char* value =
"int";
732 static constexpr
const char* value =
"int64";
737 static constexpr
const char* value =
"uint64_t";
742 static constexpr
const char* value =
"DataType";
747 static constexpr
const char* value =
"str";
752 static constexpr
const char* value =
"bool";
757 static constexpr
const char* value =
"handle";
762 static constexpr
const char* value =
"double";
771 info_->description = str;
774 template <
typename T>
776 std::ostringstream os;
777 os << info_->type_info <<
", default=" << value;
778 info_->type_info = os.str();
781 template <
typename T>
785 template <
typename T>
796 template <
typename T>
813 template <
typename T>
816 if (key == key_) exist_ =
true;
821 template <
typename T>
826 : visitor_(visitor), key_(key), data_(data) {}
830 visitor_->Visit(key_, data_);
853 template <
typename T>
869 template <
typename DerivedType>
874 self()->_tvm_VisitAttrs(vis);
879 self()->_tvm_VisitAttrs(vis);
883 ICHECK_EQ(args.size() % 2, 0);
884 const int kLinearSearchBound = 16;
887 if (args.size() < kLinearSearchBound) {
890 for (
int i = 0; i < args.size(); i += 2) {
891 ICHECK_EQ(args.type_codes[i],
kTVMStr);
892 if (!std::strcmp(key, args.values[i].v_str)) {
900 self()->_tvm_VisitAttrs(vis);
901 hit_count = vis.hit_count_;
904 std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
905 for (
int i = 0; i < args.size(); i += 2) {
906 ICHECK_EQ(args.type_codes[i],
kTVMStr);
907 kwargs[args[i].operator std::string()] = args[i + 1];
910 auto it = kwargs.find(key);
911 if (it != kwargs.end()) {
918 self()->_tvm_VisitAttrs(vis);
919 hit_count = vis.hit_count_;
922 if (hit_count * 2 != args.size() && !allow_unknown) {
923 for (
int i = 0; i < args.size(); i += 2) {
925 visitor.
key_ = args[i].operator std::string();
926 self()->_tvm_VisitAttrs(visitor);
928 std::ostringstream os;
929 os << DerivedType::_type_key <<
": does not have field \'" << visitor.
key_
930 <<
"\', Possible fields:\n";
931 os <<
"----------------\n";
940 DerivedType* pself =
self();
942 self()->_tvm_VisitAttrs(visitor);
948 self()->_tvm_VisitAttrs(visitor);
953 self()->_tvm_VisitAttrs(visitor);
958 DerivedType*
self()
const {
959 return const_cast<DerivedType*
>(
static_cast<const DerivedType*
>(
this));
963 template <
typename... Args>
967 pf(std::forward<Args>(args)...);
973 os << info->name <<
" : " << info->type_info <<
'\n';
974 if (info->description.length() != 0) {
975 os <<
" " << info->description <<
'\n';
@ kTVMStr
Definition: c_runtime_api.h:186
Information about attribute fields in string representations.
Definition: attrs.h:106
static constexpr bool _type_has_method_shash_reduce
Definition: attrs.h:123
String description
detailed description of the type
Definition: attrs.h:113
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object)
String type_info
type docstring information in str.
Definition: attrs.h:111
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:115
String name
name of the field
Definition: attrs.h:109
static constexpr bool _type_has_method_sequal_reduce
Definition: attrs.h:122
static constexpr const char * _type_key
Definition: attrs.h:121
AttrFieldInfo.
Definition: attrs.h:128
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode)
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
The base class of the all the Use "curiously recurring template pattern".
Definition: attrs.h:870
bool SEqualReduce(const DerivedType *other, SEqualReducer equal) const
Definition: attrs.h:939
void SHashReduce(SHashReducer hash_reducer) const
Definition: attrs.h:946
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:872
void VisitNonDefaultAttrs(AttrVisitor *v)
Visit attributes that do not equal the default value.
Definition: attrs.h:877
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
Definition: attrs.h:951
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
Definition: attrs.h:882
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode)
Base class of all attribute class.
Definition: attrs.h:139
virtual ~BaseAttrsNode()
virtual destructor
Definition: attrs.h:144
virtual void InitByPackedArgs(const TVMArgs &kwargs, bool allow_unknown=false)=0
Initialize the attributes by arguments.
virtual Array< AttrFieldInfo > ListFieldInfo() const =0
Get the field information.
void PrintDocString(std::ostream &os) const
Print readible docstring to ostream, add newline.
Definition: attrs.h:970
static constexpr const char * _type_key
Definition: attrs.h:182
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object)
void InitBySeq(Args &&... args)
Initialize the attributes by sequence of arguments.
Definition: attrs.h:964
virtual void VisitNonDefaultAttrs(AttrVisitor *v)=0
Visit attributes that do not equal the default value.
static constexpr const bool _type_has_method_sequal_reduce
Definition: attrs.h:180
static constexpr const bool _type_has_method_shash_reduce
Definition: attrs.h:181
virtual void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:146
Specialized attribute type that is backed by a map. The DictAttrsNode implements the Attrs behavior,...
Definition: attrs.h:201
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode)
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
void VisitAttrs(AttrVisitor *v) final
static constexpr const char * _type_key
Definition: attrs.h:219
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
void VisitNonDefaultAttrs(AttrVisitor *v) final
Visit attributes that do not equal the default value.
Map< String, ObjectRef > dict
internal attrs map
Definition: attrs.h:204
void SHashReduce(SHashReducer hash_reduce) const
Definition: attrs.h:210
bool SEqualReduce(const DictAttrsNode *other, SEqualReducer equal) const
Definition: attrs.h:206
Managed reference to DictAttrsNode.
Definition: attrs.h:227
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has an non-zero integer attr.
Definition: attrs.h:306
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:258
DictAttrs(Map< String, ObjectRef > dict={})
Consruct a Attrs backed by DictAttrsNode.
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode)
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: attrs.h:284
TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(DictAttrs, Attrs, DictAttrsNode)
Constant floating point literals in the program.
Definition: expr.h:548
Constant integer literals in the program.
Definition: expr.h:501
Managed reference class to IntImmNode.
Definition: expr.h:530
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:114
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin)
Definition: attrs.h:782
TSelf & set_default(const T &value)
Definition: attrs.h:775
AttrDocEntry(ObjectPtr< AttrFieldInfoNode > info)
Definition: attrs.h:769
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end)
Definition: attrs.h:786
TSelf & describe(const char *str)
Definition: attrs.h:770
AttrDocEntry operator()(const char *key, T *v)
Definition: attrs.h:797
Array< AttrFieldInfo > fields_
Definition: attrs.h:805
AttrNopEntry operator()(const char *key, T *v)
Definition: attrs.h:814
std::string key_
Definition: attrs.h:810
bool exist_
Definition: attrs.h:811
AttrInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:682
AttrInitEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:685
AttrNonDefaultVisitor(AttrVisitor *visitor)
Definition: attrs.h:852
AttrTriggerNonDefaultEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:854
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:490
AttrNormalVisitor(AttrVisitor *visitor)
Definition: attrs.h:488
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:43
@ kHandle
Definition: data_type.h:57
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
Reference to string objects.
Definition: string.h:98
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:2683
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:796
const TVMValue & value() const
Definition: packed_func.h:838
Arguments into TVM functions.
Definition: packed_func.h:394
int type_code() const
Definition: packed_func.h:656
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:946
void SetValue< int >(int *ptr, const TVMArgValue &val)
Definition: attrs.h:658
void SetValue< double >(double *ptr, const TVMArgValue &val)
Definition: attrs.h:642
void SetValue< DataType >(DataType *ptr, const TVMArgValue &val)
Definition: attrs.h:628
AttrInitVisitor< FFind > CreateInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:712
void SetValue< uint64_t >(uint64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:666
void SetValue< int64_t >(int64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:662
void SetValue< bool >(bool *ptr, const TVMArgValue &val)
Definition: attrs.h:670
void SetValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:612
void SetIntValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:617
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value)
Copy the DictAttrs, but overrides a single attribute.
DictAttrs WithoutAttr(DictAttrs attrs, const std::string &key)
Copy the DictAttrs, but without a specific attribute.
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
DataType NullValue< DataType >()
Definition: attrs.h:90
TAttrs AttrsWithDefaultValues()
Create an Attr object with all default values.
Definition: attrs.h:320
runtime::DataType DataType
Definition: data_type.h:493
DictAttrs WithAttrs(DictAttrs attrs, Map< String, ObjectRef > new_attrs)
Copy the DictAttrs, but overrides attributes with the entries from attrs.
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:84
Type-erased function used across TVM API.
Error thrown during attribute checking.
Definition: attrs.h:95
AttrError(std::string msg)
constructor
Definition: attrs.h:100
TSelf & set_lower_bound(const T &begin)
Definition: attrs.h:576
const char * type_key_
Definition: attrs.h:543
const char * key_
Definition: attrs.h:545
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:606
TSelf & set_upper_bound(const T &end)
Definition: attrs.h:588
~AttrInitEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:566
bool value_missing_
Definition: attrs.h:552
TSelf & set_default(const T &value)
Definition: attrs.h:600
T * value_
Definition: attrs.h:547
AttrInitEntry(AttrInitEntry &&other)
Definition: attrs.h:556
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:476
TSelf & set_default(DMLC_ATTRIBUTE_UNUSED const T &value)
Definition: attrs.h:472
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:470
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:480
AttrTriggerNonDefaultEntry(AttrVisitor *visitor, const char *key, T *data)
Definition: attrs.h:825
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:840
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:841
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:833
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:828
TSelf & set_default(const T &value)
Definition: attrs.h:834
Helper struct to get the type name known to tvm.
Definition: attrs.h:721
Structural equality comparison.
int64_t v_int64
Definition: c_runtime_api.h:211