tvm
attrs.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
28 #ifndef TVM_IR_ATTRS_H_
29 #define TVM_IR_ATTRS_H_
30 
31 #include <tvm/ffi/container/map.h>
32 #include <tvm/ffi/extra/structural_equal.h>
33 #include <tvm/ffi/extra/structural_hash.h>
34 #include <tvm/ffi/function.h>
35 #include <tvm/ffi/reflection/accessor.h>
36 #include <tvm/ffi/reflection/registry.h>
37 #include <tvm/ir/cow.h>
38 #include <tvm/ir/expr.h>
39 
40 #include <functional>
41 #include <string>
42 #include <type_traits>
43 #include <unordered_map>
44 #include <utility>
45 #include <vector>
46 
47 namespace tvm {
48 
54 template <typename TObjectRef>
55 inline TObjectRef NullValue() {
56  static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
57  return TObjectRef(ffi::ObjectPtr<typename TObjectRef::ContainerType>(nullptr));
58 }
59 
60 template <>
62  return DataType(DataType::kHandle, 0, 0);
63 }
64 
68 class AttrFieldInfoNode : public ffi::Object {
69  public:
71  ffi::String name;
73  ffi::String type_info;
75  ffi::String description;
76 
77  static void RegisterReflection() {
78  namespace rfl = ffi::reflection;
79  rfl::ObjectDef<AttrFieldInfoNode>()
80  .def_ro("name", &AttrFieldInfoNode::name)
81  .def_ro("type_info", &AttrFieldInfoNode::type_info)
82  .def_ro("description", &AttrFieldInfoNode::description);
83  }
84 
85  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
86 
87  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.AttrFieldInfo", AttrFieldInfoNode, ffi::Object);
88 };
89 
91 class AttrFieldInfo : public ffi::ObjectRef {
92  public:
94 };
95 
102 class BaseAttrsNode : public ffi::Object {
103  public:
105  virtual ~BaseAttrsNode() {}
111  template <typename... Args>
112  inline void InitBySeq(Args&&... args);
120  TVM_DLL virtual void InitByPackedArgs(const ffi::PackedArgs& kwargs,
121  bool allow_unknown = false) = 0;
122 
123  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
124  TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object);
125 };
126 
131 class Attrs : public ffi::ObjectRef {
132  public:
134 };
135 
142 class DictAttrsNode : public BaseAttrsNode {
143  public:
145  ffi::Map<ffi::String, ffi::Any> dict;
146 
147  static void RegisterReflection() {
148  namespace rfl = ffi::reflection;
149  rfl::ObjectDef<DictAttrsNode>().def_ro("__dict__", &DictAttrsNode::dict);
150  }
151 
152  void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final;
153 
154  // type info
156 };
157 
/*!
 * \brief Managed reference to DictAttrsNode.
 */
162 class DictAttrs : public Attrs {
163  public:
  /*!
   * \brief constructor with UnsafeInit
   * \param tag Forwarded unchanged to the Attrs constructor.
   */
167  explicit DictAttrs(ffi::UnsafeInit tag) : Attrs(tag) {}
  /*!
   * \brief Construct an Attrs backed by DictAttrsNode.
   * \param dict The attribute map; defaults to an empty map. Defined out of line.
   */
172  TVM_DLL explicit DictAttrs(ffi::Map<ffi::String, Any> dict = {});
173 
174  // Utils for accessing attributes
175  // This needs to be on DictAttrs, not DictAttrsNode because we return the default
176  // value if DictAttrsNode is not defined.
  /*!
   * \brief Get an attribute by key.
   * \tparam TObjectRef The type the stored value is cast to.
   * \param attr_key The key to look up in the underlying `dict`.
   * \param default_value Returned when this reference is not defined or the key
   *        is absent; defaults to an empty Optional.
   * \return The stored value cast to TObjectRef, or default_value.
   */
196  template <typename TObjectRef>
197  ffi::Optional<TObjectRef> GetAttr(
198  const std::string& attr_key,
199  ffi::Optional<TObjectRef> default_value = ffi::Optional<TObjectRef>(std::nullopt)) const {
  // An undefined reference has no node to search, so fall back immediately.
200  if (!defined()) return default_value;
201  const DictAttrsNode* node = this->as<DictAttrsNode>();
202  auto it = node->dict.find(attr_key);
203  if (it != node->dict.end()) {
204  return (*it).second.cast<TObjectRef>();
205  } else {
206  return default_value;
207  }
208  }
209  // variant that uses TObjectRef to enable implicit conversion to default value.
  // It wraps default_value in an Optional and delegates to the overload above.
210  template <typename TObjectRef>
211  ffi::Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
212  return GetAttr<TObjectRef>(attr_key, ffi::Optional<TObjectRef>(default_value));
213  }
  /*!
   * \brief Check whether there is a non-zero integer attr under the given key.
   * \param attr_key The key to look up.
   * \return true iff the Integer obtained through GetAttr (with 0 as the
   *         fallback for a missing value) has a non-zero IntValue().
   */
233  bool HasNonzeroAttr(const std::string& attr_key) const {
234  return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
235  }
236 
  // Reference-class boilerplate: construction from a node pointer, defaulted
  // copy/move operations, and const accessors to the underlying DictAttrsNode.
237  explicit DictAttrs(::tvm::ffi::ObjectPtr<DictAttrsNode> n) : Attrs(n) {}
238  DictAttrs(const DictAttrs&) = default;
239  DictAttrs(DictAttrs&&) = default;
240  DictAttrs& operator=(const DictAttrs&) = default;
  // Returns data_.get() statically cast to a const DictAttrsNode pointer.
242  const DictAttrsNode* operator->() const { return static_cast<const DictAttrsNode*>(data_.get()); }
  // Same pointer as operator->().
243  const DictAttrsNode* get() const { return operator->(); }
246 };
247 
/*!
 * \brief Copy the DictAttrs, but override attributes with the entries from new_attrs.
 * \param attrs The DictAttrs to start from.
 * \param new_attrs The entries that override or extend those in attrs.
 * \return The resulting DictAttrs.
 */
258 DictAttrs WithAttrs(DictAttrs attrs, ffi::Map<ffi::String, Any> new_attrs);
259 
/*!
 * \brief Copy the DictAttrs, but override a single attribute.
 * \param attrs The DictAttrs to start from.
 * \param key The key of the attribute to set.
 * \param value The value to store under key.
 * \return The resulting DictAttrs.
 */
271 DictAttrs WithAttr(DictAttrs attrs, ffi::String key, Any value);
272 
273 inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, Any value) {
274  return WithAttr(std::move(attrs), ffi::String(key), std::move(value));
275 }
276 
/*!
 * \brief Copy the DictAttrs, but without a specific attribute.
 * \param attrs The DictAttrs to start from.
 * \param key The key of the attribute to leave out.
 * \return The resulting DictAttrs.
 */
286 DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key);
287 
315 template <typename TFunc>
316 inline TFunc WithAttr(TFunc input, const std::string& attr_key, Any attr_value) {
317  using TNode = typename TFunc::ContainerType;
318  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
319  TNode* node = input.CopyOnWrite();
320  node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value);
321  return input;
322 }
323 
334 template <typename TFunc>
335 inline TFunc WithAttrs(TFunc input, ffi::Map<ffi::String, Any> attrs) {
336  using TNode = typename TFunc::ContainerType;
337  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
338  TNode* node = input.CopyOnWrite();
339 
340  node->attrs = WithAttrs(std::move(node->attrs), attrs);
341 
342  return input;
343 }
344 
/*!
 * \brief Return `input` with one attribute removed from its `attrs` member.
 *
 * Obtains a writable node through input.CopyOnWrite(), replaces the node's
 * `attrs` with the result of WithoutAttr applied to the old `attrs`, and
 * returns the (possibly copied) input.
 *
 * \tparam TFunc The reference type. Its ContainerType must be a final (leaf)
 *         node type; this is enforced at compile time.
 * \param input The object to update.
 * \param attr_key The key of the attribute to remove.
 * \return The updated object.
 */
template <typename TFunc>
inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
  using NodeType = typename TFunc::ContainerType;
  static_assert(NodeType::_type_final, "Can only operate on the leaf nodes");

  NodeType* writable = input.CopyOnWrite();
  auto remaining = WithoutAttr(std::move(writable->attrs), attr_key);
  writable->attrs = std::move(remaining);

  return input;
}
381 
390 template <typename DerivedType>
392  public:
393  void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final {
394  TVM_FFI_THROW(InternalError) << "`" << DerivedType::_type_key
395  << "` uses new reflection mechanism for init";
396  }
397 
398  private:
399  DerivedType* self() const {
400  return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
401  }
402 };
403 
409 template <typename TAttrs>
410 inline TAttrs AttrsWithDefaultValues() {
411  static_assert(std::is_base_of_v<Attrs, TAttrs>, "Can only take attr nodes");
412  using ContainerType = typename TAttrs::ContainerType;
413  if constexpr (std::is_base_of_v<AttrsNodeReflAdapter<ContainerType>, ContainerType>) {
414  static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs");
415  AnyView packed_args[1];
416  packed_args[0] = ContainerType::RuntimeTypeIndex();
417  ffi::Any rv;
418  finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv);
419  return rv.cast<TAttrs>();
420  } else {
421  auto n = ffi::make_object<ContainerType>();
422  n->InitByPackedArgs(ffi::PackedArgs(nullptr, 0), false);
423  return TAttrs(n);
424  }
425 }
426 
427 } // namespace tvm
428 #endif // TVM_IR_ATTRS_H_
Information about attribute fields in string representations.
Definition: attrs.h:68
ffi::String name
name of the field
Definition: attrs.h:71
ffi::String type_info
type docstring information in str.
Definition: attrs.h:73
ffi::String description
detailed description of the type
Definition: attrs.h:75
static void RegisterReflection()
Definition: attrs.h:77
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.AttrFieldInfo", AttrFieldInfoNode, ffi::Object)
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: attrs.h:85
AttrFieldInfo.
Definition: attrs.h:91
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttrFieldInfo, ffi::ObjectRef, AttrFieldInfoNode)
Adapter for AttrsNode with the new reflection API.
Definition: attrs.h:391
void InitByPackedArgs(const ffi::PackedArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
Definition: attrs.h:393
Managed reference to BaseAttrsNode.
Definition: attrs.h:131
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Attrs, ffi::ObjectRef, BaseAttrsNode)
Base class of all attribute classes.
Definition: attrs.h:102
virtual ~BaseAttrsNode()
virtual destructor
Definition: attrs.h:105
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: attrs.h:123
void InitBySeq(Args &&... args)
Initialize the attributes by sequence of arguments.
virtual void InitByPackedArgs(const ffi::PackedArgs &kwargs, bool allow_unknown=false)=0
Initialize the attributes by arguments.
TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object)
Specialized attribute type that is backed by a map. The DictAttrsNode implements the Attrs behavior,...
Definition: attrs.h:142
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, BaseAttrsNode)
static void RegisterReflection()
Definition: attrs.h:147
void InitByPackedArgs(const ffi::PackedArgs &args, bool allow_unknown) final
Initialize the attributes by arguments.
ffi::Map< ffi::String, ffi::Any > dict
internal attrs map
Definition: attrs.h:145
Managed reference to DictAttrsNode.
Definition: attrs.h:162
DictAttrs(DictAttrs &&)=default
DictAttrs & operator=(DictAttrs &&)=default
const DictAttrsNode * get() const
Definition: attrs.h:243
DictAttrs(::tvm::ffi::ObjectPtr< DictAttrsNode > n)
Definition: attrs.h:237
bool HasNonzeroAttr(const std::string &attr_key) const
Check whether the function has a non-zero integer attr.
Definition: attrs.h:233
DictAttrs & operator=(const DictAttrs &)=default
DictAttrs(ffi::Map< ffi::String, Any > dict={})
Construct an Attrs backed by DictAttrsNode.
const DictAttrsNode * operator->() const
Definition: attrs.h:242
DictAttrs(const DictAttrs &)=default
ffi::Optional< TObjectRef > GetAttr(const std::string &attr_key, ffi::Optional< TObjectRef > default_value=ffi::Optional< TObjectRef >(std::nullopt)) const
Get a function attribute.
Definition: attrs.h:197
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode)
DictAttrs(ffi::UnsafeInit tag)
constructor with UnsafeInit
Definition: attrs.h:167
ffi::Optional< TObjectRef > GetAttr(const std::string &attr_key, TObjectRef default_value) const
Definition: attrs.h:211
Runtime primitive data type.
Definition: data_type.h:45
@ kHandle
Definition: data_type.h:59
Copy-on-write helper macro for IR ffi::ObjectRef types.
Base expr nodes in TVM.
An object that builds and maintains block scope and StmtSref mapping for Dependence analysis.
Definition: analyzer.h:37
DictAttrs WithoutAttr(DictAttrs attrs, const std::string &key)
Copy the DictAttrs, but without a specific attribute.
DataType NullValue< DataType >()
Definition: attrs.h:61
TAttrs AttrsWithDefaultValues()
Create an Attr object with all default values.
Definition: attrs.h:410
runtime::DataType DataType
Definition: data_type.h:457
DictAttrs WithAttrs(DictAttrs attrs, ffi::Map< ffi::String, Any > new_attrs)
Copy the DictAttrs, but override attributes with the entries from new_attrs.
TObjectRef NullValue()
Create a NodeRef type that represents null.
Definition: attrs.h:55
DictAttrs WithAttr(DictAttrs attrs, ffi::String key, Any value)
Copy the DictAttrs, but overrides a single attribute.