44 #ifndef TVM_IR_ATTRS_H_
45 #define TVM_IR_ATTRS_H_
47 #include <dmlc/common.h>
55 #include <type_traits>
56 #include <unordered_map>
// TVM_DECLARE_ATTRS(ClassName, TypeKey)
// Boilerplate placed inside an attribute class body. The expansion:
//   - defines `static constexpr const char* _type_key = TypeKey;`
//   - invokes TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode)
//   - ends with the declarator of the member template
//       template <typename FVisit> void _tvm_VisitAttrs(FVisit& _tvm_fvisit)
//     with NO body. The macro must therefore be followed immediately by a
//     `{ ... }` block, typically a sequence of TVM_ATTR_FIELD(...) statements.
// Generic code in this header calls `self()->_tvm_VisitAttrs(visitor)` with
// different visitor objects (normal, non-default, init, doc, hash/equal).
// Usage sketch (type key is illustrative only):
//   TVM_DECLARE_ATTRS(MyAttrs, "example.MyAttrs") { TVM_ATTR_FIELD(axis); }
66 #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
67 static constexpr const char* _type_key = TypeKey; \
68 TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
69 template <typename FVisit> \
70 void _tvm_VisitAttrs(FVisit& _tvm_fvisit)
// TVM_ATTR_FIELD(FieldName)
// Reports one field to the current visitor. It expands to
//   _tvm_fvisit("FieldName", &FieldName)
// which passes the stringized field name and the field's address.
// It is only valid where `_tvm_fvisit` is in scope, i.e. inside the body that
// follows TVM_DECLARE_ATTRS. The visitors in this header return an entry
// object, so calls such as .set_default(...), .describe(...) or
// .set_lower_bound(...) can be chained onto the expansion.
76 #define TVM_ATTR_FIELD(FieldName) _tvm_fvisit(#FieldName, &FieldName)
83 template <
typename TObjectRef>
85 static_assert(TObjectRef::_type_is_nullable,
"Can only get NullValue for nullable types");
// Constructs an AttrError whose message is `msg` prefixed with
// "AttributeError:". No separator is inserted between the prefix and `msg`.
100 explicit AttrError(std::string msg) : Error(
"AttributeError:" + msg) {}
116 v->Visit(
"name", &
name);
121 static constexpr
const char*
_type_key =
"AttrFieldInfo";
152 template <
typename... Args>
257 template <
typename TObjectRef>
259 const std::string& attr_key,
261 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
262 "Can only call GetAttr with ObjectRef types.");
263 if (!defined())
return default_value;
266 auto it = node->
dict.find(attr_key);
267 if (it != node->
dict.end()) {
268 return Downcast<Optional<TObjectRef>>((*it).second);
270 return default_value;
274 template <
typename TObjectRef>
298 return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
310 template <
typename TAttrs>
312 static_assert(std::is_base_of<Attrs, TAttrs>::value,
"Can only take attr nodes");
313 auto n = make_object<typename TAttrs::ContainerType>();
345 template <
typename TFunc>
347 using TNode =
typename TFunc::ContainerType;
348 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
349 TNode* node = input.CopyOnWrite();
350 if (node->attrs.defined()) {
351 node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
369 template <
typename TFunc>
371 using TNode =
typename TFunc::ContainerType;
372 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
373 TNode* node = input.CopyOnWrite();
374 if (node->attrs.defined()) {
375 for (
const auto& pair : attrs) {
376 node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
379 node->attrs =
DictAttrs(std::move(attrs));
410 template <
typename TFunc>
411 inline TFunc
WithoutAttr(TFunc input,
const std::string& attr_key) {
412 using TNode =
typename TFunc::ContainerType;
413 static_assert(TNode::_type_final,
"Can only operate on the leaf nodes");
415 if (input->attrs.defined()) {
416 TNode* node = input.CopyOnWrite();
417 node->attrs.CopyOnWrite()->dict.erase(attr_key);
418 if (node->attrs->dict.size() == 0) {
419 node->attrs = NullValue<DictAttrs>();
434 template <
typename T>
438 template <
typename T>
442 template <
typename T>
452 template <
typename T>
454 visitor_->Visit(key, value);
467 : lhs_(lhs), rhs_(rhs), equal_(
equal) {}
468 template <
typename T>
471 const T* rhs_value =
reinterpret_cast<const T*
>(
472 reinterpret_cast<const char*
>(rhs_) +
473 (
reinterpret_cast<const char*
>(lhs_value) -
reinterpret_cast<const char*
>(lhs_)));
474 if (!equal_(*lhs_value, *rhs_value)) {
490 template <
typename T>
492 hash_reducer_(*value);
501 template <
typename T>
515 bool value_missing_{
false};
520 type_key_ = other.type_key_;
522 value_ = other.value_;
523 value_missing_ = other.value_missing_;
525 other.value_missing_ =
false;
530 if (value_missing_) {
531 std::ostringstream os;
532 os << type_key_ <<
": Cannot find required field \'" << key_ <<
"\' during initialization. "
533 <<
"If the key is defined check that its type matches the declared type.";
540 if (this->value_missing_)
return *
this;
541 const T& val = *value_;
543 std::ostringstream os;
544 os << type_key_ <<
"." << key_ <<
": "
545 <<
"value " << val <<
" is smaller than the lower bound " << begin;
552 if (this->value_missing_)
return *
this;
553 const T& val = *value_;
555 std::ostringstream os;
556 os << type_key_ <<
"." << key_ <<
": "
557 <<
"value " << val <<
" is bigger than the upper bound " << end;
564 if (!value_missing_)
return *
this;
566 value_missing_ =
false;
574 template <
typename T>
576 *ptr = val.operator T();
579 template <
typename T>
585 *ptr =
static_cast<T
>(expr->value);
596 inline void SetValue<std::string>(std::string* ptr,
const TVMArgValue& val) {
598 *ptr = val.operator std::string();
600 LOG(FATAL) <<
"Expect str";
607 *ptr = val.operator double();
612 *ptr =
static_cast<double>(op->value);
614 *ptr =
static_cast<double>(op->value);
616 LOG(FATAL) <<
"Expect float value, but get " << expr->
GetTypeKey();
638 template <
typename FFind>
643 size_t hit_count_{0};
645 AttrInitVisitor(
const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
647 template <
typename T>
654 if (ffind_(key, &val)) {
661 #if defined(__GNUC__)
662 #pragma GCC diagnostic ignored "-Wpragmas"
663 #pragma GCC diagnostic ignored "-Wpessimizing-move"
665 return std::move(opt);
670 const char* type_key_;
674 template <
typename FFind>
683 template <
typename T>
685 static constexpr
const char* value = T::ContainerType::_type_key;
690 static constexpr
const char* value =
"int";
695 static constexpr
const char* value =
"int64";
700 static constexpr
const char* value =
"uint64_t";
705 static constexpr
const char* value =
"DataType";
710 static constexpr
const char* value =
"str";
715 static constexpr
const char* value =
"bool";
720 static constexpr
const char* value =
"handle";
725 static constexpr
const char* value =
"double";
734 info_->description = str;
737 template <
typename T>
739 std::ostringstream os;
740 os << info_->type_info <<
", default=" << value;
741 info_->type_info = os.str();
744 template <
typename T>
748 template <
typename T>
759 template <
typename T>
776 template <
typename T>
779 if (key == key_) exist_ =
true;
784 template <
typename T>
789 : visitor_(visitor), key_(key), data_(data) {}
793 visitor_->Visit(key_, data_);
816 template <
typename T>
832 template <
typename DerivedType>
837 self()->_tvm_VisitAttrs(vis);
842 self()->_tvm_VisitAttrs(vis);
846 ICHECK_EQ(args.size() % 2, 0);
847 const int kLinearSearchBound = 16;
850 if (args.size() < kLinearSearchBound) {
853 for (
int i = 0; i < args.size(); i += 2) {
854 ICHECK_EQ(args.type_codes[i],
kTVMStr);
855 if (!std::strcmp(key, args.values[i].v_str)) {
863 self()->_tvm_VisitAttrs(vis);
864 hit_count = vis.hit_count_;
867 std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
868 for (
int i = 0; i < args.size(); i += 2) {
869 ICHECK_EQ(args.type_codes[i],
kTVMStr);
870 kwargs[args[i].operator std::string()] = args[i + 1];
873 auto it = kwargs.find(key);
874 if (it != kwargs.end()) {
881 self()->_tvm_VisitAttrs(vis);
882 hit_count = vis.hit_count_;
885 if (hit_count * 2 != args.size() && !allow_unknown) {
886 for (
int i = 0; i < args.size(); i += 2) {
888 visitor.
key_ = args[i].operator std::string();
889 self()->_tvm_VisitAttrs(visitor);
891 std::ostringstream os;
892 os << DerivedType::_type_key <<
": does not have field \'" << visitor.
key_
893 <<
"\', Possible fields:\n";
894 os <<
"----------------\n";
903 DerivedType* pself =
self();
905 self()->_tvm_VisitAttrs(visitor);
911 self()->_tvm_VisitAttrs(visitor);
916 self()->_tvm_VisitAttrs(visitor);
921 DerivedType*
self()
const {
922 return const_cast<DerivedType*
>(
static_cast<const DerivedType*
>(
this));
926 template <
typename... Args>
930 pf(std::forward<Args>(args)...);
936 os << info->name <<
" : " << info->type_info <<
'\n';
937 if (info->description.length() != 0) {
938 os <<
" " << info->description <<
'\n';
@ kTVMStr
Definition: c_runtime_api.h:185
Information about attribute fields in string representations.
Definition: attrs.h:106
static constexpr bool _type_has_method_shash_reduce
Definition: attrs.h:123
String description
detailed description of the type
Definition: attrs.h:113
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object)
String type_info
type docstring information in str.
Definition: attrs.h:111
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:115
String name
name of the field
Definition: attrs.h:109
static constexpr bool _type_has_method_sequal_reduce
Definition: attrs.h:122
static constexpr const char * _type_key
Definition: attrs.h:121
AttrFieldInfo.
Definition: attrs.h:128
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode)
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
The base class of all attribute types; uses the "curiously recurring template pattern".
Definition: attrs.h:833
bool SEqualReduce(const DerivedType *other, SEqualReducer equal) const
Definition: attrs.h:902
void SHashReduce(SHashReducer hash_reducer) const
Definition: attrs.h:909
void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:835
void VisitNonDefaultAttrs(AttrVisitor *v)
Visit attributes that do not equal the default value.
Definition: attrs.h:840
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
Definition: attrs.h:914
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
Definition: attrs.h:845
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode)
Base class of all attribute classes.
Definition: attrs.h:139
virtual ~BaseAttrsNode()
virtual destructor
Definition: attrs.h:144
virtual void InitByPackedArgs(const TVMArgs &kwargs, bool allow_unknown=false)=0
Initialize the attributes by arguments.
virtual Array< AttrFieldInfo > ListFieldInfo() const =0
Get the field information.
void PrintDocString(std::ostream &os) const
Print a readable docstring to the ostream, followed by a newline.
Definition: attrs.h:933
static constexpr const char * _type_key
Definition: attrs.h:182
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object)
void InitBySeq(Args &&... args)
Initialize the attributes by sequence of arguments.
Definition: attrs.h:927
virtual void VisitNonDefaultAttrs(AttrVisitor *v)=0
Visit attributes that do not equal the default value.
static constexpr const bool _type_has_method_sequal_reduce
Definition: attrs.h:180
static constexpr const bool _type_has_method_shash_reduce
Definition: attrs.h:181
virtual void VisitAttrs(AttrVisitor *v)
Definition: attrs.h:146
Specialized attribute type that is backed by a map. The DictAttrsNode implements the Attrs behavior,...
Definition: attrs.h:201
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode)
void InitByPackedArgs(const runtime::TVMArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
void VisitAttrs(AttrVisitor *v) final
static constexpr const char * _type_key
Definition: attrs.h:219
Array< AttrFieldInfo > ListFieldInfo() const final
Get the field information.
void VisitNonDefaultAttrs(AttrVisitor *v) final
Visit attributes that do not equal the default value.
Map< String, ObjectRef > dict
internal attrs map
Definition: attrs.h:204
void SHashReduce(SHashReducer hash_reduce) const
Definition: attrs.h:210
bool SEqualReduce(const DictAttrsNode *other, SEqualReducer equal) const
Definition: attrs.h:206
Managed reference to DictAttrsNode.
Definition: attrs.h:227
DictAttrs(Map< String, ObjectRef > dict)
Construct an Attrs backed by a DictAttrsNode.
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attr.
Definition: attrs.h:297
Optional< TObjectRef > GetAttr(const std::string &attr_key, Optional< TObjectRef > default_value=Optional< TObjectRef >(nullptr)) const
Get a function attribute.
Definition: attrs.h:258
TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode)
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode)
Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: attrs.h:275
Constant floating point literals in the program.
Definition: expr.h:538
Constant integer literals in the program.
Definition: expr.h:491
Managed reference class to IntImmNode.
Definition: expr.h:520
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:124
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:110
Content-aware structural equality comparator for objects.
Definition: structural_equal.h:103
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin)
Definition: attrs.h:745
TSelf & set_default(const T &value)
Definition: attrs.h:738
AttrDocEntry(ObjectPtr< AttrFieldInfoNode > info)
Definition: attrs.h:732
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end)
Definition: attrs.h:749
TSelf & describe(const char *str)
Definition: attrs.h:733
AttrDocEntry operator()(const char *key, T *v)
Definition: attrs.h:760
Array< AttrFieldInfo > fields_
Definition: attrs.h:768
AttrNopEntry operator()(const char *key, T *v)
Definition: attrs.h:777
std::string key_
Definition: attrs.h:773
bool exist_
Definition: attrs.h:774
AttrInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:645
AttrInitEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:648
AttrNonDefaultVisitor(AttrVisitor *visitor)
Definition: attrs.h:815
AttrTriggerNonDefaultEntry< T > operator()(const char *key, T *value)
Definition: attrs.h:817
AttrNopEntry operator()(const char *key, T *value)
Definition: attrs.h:453
AttrNormalVisitor(AttrVisitor *visitor)
Definition: attrs.h:451
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:42
@ kHandle
Definition: data_type.h:56
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
A custom smart pointer for Object.
Definition: object.h:360
Base class of all object reference.
Definition: object.h:517
bool defined() const
Definition: object.h:550
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:894
base class of all object containers.
Definition: object.h:169
std::string GetTypeKey() const
Definition: object.h:182
Optional container that represents a nullable variant of T.
Definition: optional.h:51
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:139
Reference to string objects.
Definition: string.h:98
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:2194
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:649
const TVMValue & value() const
Definition: packed_func.h:691
Arguments into TVM functions.
Definition: packed_func.h:392
int type_code() const
Definition: packed_func.h:613
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:799
void SetValue< int >(int *ptr, const TVMArgValue &val)
Definition: attrs.h:621
void SetValue< double >(double *ptr, const TVMArgValue &val)
Definition: attrs.h:605
void SetValue< DataType >(DataType *ptr, const TVMArgValue &val)
Definition: attrs.h:591
AttrInitVisitor< FFind > CreateInitVisitor(const char *type_key, FFind ffind)
Definition: attrs.h:675
void SetValue< uint64_t >(uint64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:629
void SetValue< int64_t >(int64_t *ptr, const TVMArgValue &val)
Definition: attrs.h:625
void SetValue< bool >(bool *ptr, const TVMArgValue &val)
Definition: attrs.h:633
void SetValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:575
void SetIntValue(T *ptr, const TVMArgValue &val)
Definition: attrs.h:580
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
DataType NullValue< DataType >()
Definition: attrs.h:90
TFunc WithAttr(TFunc input, const std::string &attr_key, ObjectRef attr_value)
Copy the function or module, but override the value of attribute attr_key with attr_value.
Definition: attrs.h:346
TAttrs AttrsWithDefaultValues()
Create an Attr object with all default values.
Definition: attrs.h:311
runtime::DataType DataType
Definition: data_type.h:433
TFunc WithoutAttr(TFunc input, const std::string &attr_key)
Copy the function or module, but removes the specified attribute.
Definition: attrs.h:411
TFunc WithAttrs(TFunc input, Map< String, ObjectRef > attrs)
Copy the function or module, but overrides the attributes with the entries from attrs.
Definition: attrs.h:370
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:84
Type-erased function used across TVM API.
Error thrown during attribute checking.
Definition: attrs.h:95
AttrError(std::string msg)
constructor
Definition: attrs.h:100
TSelf & set_lower_bound(const T &begin)
Definition: attrs.h:539
const char * type_key_
Definition: attrs.h:506
const char * key_
Definition: attrs.h:508
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:569
TSelf & set_upper_bound(const T &end)
Definition: attrs.h:551
~AttrInitEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:529
bool value_missing_
Definition: attrs.h:515
TSelf & set_default(const T &value)
Definition: attrs.h:563
T * value_
Definition: attrs.h:510
AttrInitEntry(AttrInitEntry &&other)
Definition: attrs.h:519
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:439
TSelf & set_default(DMLC_ATTRIBUTE_UNUSED const T &value)
Definition: attrs.h:435
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:433
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:443
AttrTriggerNonDefaultEntry(AttrVisitor *visitor, const char *key, T *data)
Definition: attrs.h:788
TSelf & set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T &begin)
Definition: attrs.h:803
TSelf & set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T &end)
Definition: attrs.h:804
TSelf & describe(DMLC_ATTRIBUTE_UNUSED const char *str)
Definition: attrs.h:796
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION
Definition: attrs.h:791
TSelf & set_default(const T &value)
Definition: attrs.h:797
Helper struct to get the type name known to tvm.
Definition: attrs.h:684
Structural equality comparison.
int64_t v_int64
Definition: c_runtime_api.h:209