tvm
target_kind.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_TARGET_TARGET_KIND_H_
25 #define TVM_TARGET_TARGET_KIND_H_
26 
27 #include <tvm/ffi/function.h>
28 #include <tvm/ffi/reflection/registry.h>
30 #include <tvm/node/node.h>
31 
32 #include <memory>
33 #include <unordered_map>
34 #include <utility>
35 #include <vector>
36 
37 namespace tvm {
38 
39 class Target;
40 
44 using TargetFeatures = Map<String, ffi::Any>;
45 
53 using TargetJSON = Map<String, ffi::Any>;
54 using FTVMTargetParser = ffi::TypedFunction<TargetJSON(TargetJSON)>;
55 
56 namespace detail {
57 template <typename, typename, typename>
58 struct ValueTypeInfoMaker;
59 }
60 
61 class TargetInternal;
62 
63 template <typename>
64 class TargetKindAttrMap;
65 
/*!
 * \brief Target kind, specifies the kind of the target.
 *
 * NOTE(review): this listing omits some member declarations that the code below
 * refers to (e.g. `default_device_type`, used in RegisterReflection) and the
 * targets of two friend declarations; consult the full header for those.
 */
67 class TargetKindNode : public Object {
68  public:
  /*! \brief Name of the target kind. */
70  String name;
  /*! \brief Default keys of the target. */
74  Array<String> default_keys;
79 
  /*!
   * \brief Register the reflected fields of this node.
   *
   * Exposes `name`, `default_device_type` and `default_keys` through `def_ro`;
   * the latter two are attached with the `SEqHashIgnore` field flag.
   */
80  static void RegisterReflection() {
81  namespace refl = tvm::ffi::reflection;
82  refl::ObjectDef<TargetKindNode>()
83  .def_ro("name", &TargetKindNode::name)
84  .def_ro("default_device_type", &TargetKindNode::default_device_type,
85  refl::AttachFieldFlag::SEqHashIgnore())
86  .def_ro("default_keys", &TargetKindNode::default_keys,
87  refl::AttachFieldFlag::SEqHashIgnore());
88  }
89 
  /*! \brief Structural equality/hash kind declared for this node type. */
90  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUniqueInstance;
  /*! \brief Type key string of this node type. */
91  static constexpr const char* _type_key = "target.TargetKind";
93 
94  private:
  /*! \brief Return the stored registry index (`index_`). */
96  uint32_t AttrRegistryIndex() const { return index_; }
  /*! \brief Return the stored name (`name`). */
98  String AttrRegistryName() const { return name; }
  /*!
   * \brief Type descriptor of an attribute option value.
   *
   * `key` holds the element descriptor for Array values and the key descriptor
   * for Map values; `val` holds the mapped-value descriptor for Map values.
   * Both are null for non-container values (see detail::ValueTypeInfoMaker).
   */
100  struct ValueTypeInfo {
101  String type_key;
102  int32_t type_index;
103  std::unique_ptr<ValueTypeInfo> key;
104  std::unique_ptr<ValueTypeInfo> val;
105  };
  /*! \brief Attribute option key -> type descriptor, filled by add_attr_option. */
107  std::unordered_map<String, ValueTypeInfo> key2vtype_;
  /*! \brief Attribute option key -> default value, filled by add_attr_option. */
109  std::unordered_map<String, ffi::Any> key2default_;
  /*! \brief Registry index, assigned by the TargetKindRegEntry constructor. */
111  uint32_t index_;
112 
113  template <typename, typename, typename>
115  template <typename, typename>
116  friend class AttrRegistry;
117  template <typename>
119  friend class TargetKindRegEntry;
120  friend class TargetInternal;
121 };
122 
/*! \brief Managed reference class to TargetKindNode. */
127 class TargetKind : public ObjectRef {
128  public:
129  TargetKind() = default;
  /*!
   * \brief Get the attribute map given the attribute name.
   * \param attr_name Name of the attribute.
   * \return A TargetKindAttrMap built from the container returned by GetAttrMapContainer.
   */
131  template <typename ValueType>
132  static inline TargetKindAttrMap<ValueType> GetAttrMap(const String& attr_name);
  /*!
   * \brief Retrieve the TargetKind given its name.
   * \param target_kind_name Name of the target kind.
   * \return The TargetKind wrapped in an Optional.
   */
138  TVM_DLL static Optional<TargetKind> Get(const String& target_kind_name);
  /*! \brief Mutable access to the container class. */
140  TargetKindNode* operator->() { return static_cast<TargetKindNode*>(data_.get()); }
141 
143 
144  private:
  /*! \brief Look up the attribute container for `attr_name`; used by GetAttrMap. */
145  TVM_DLL static const AttrRegistryMapContainerMap<TargetKind>& GetAttrMapContainer(
146  const String& attr_name);
147  friend class TargetKindRegEntry;
148  friend class TargetInternal;
149 };
150 
/*!
 * \brief Map<TargetKind, ValueType> used to store meta-information about TargetKind.
 * \tparam ValueType The type of the stored attribute value.
 */
155 template <typename ValueType>
156 class TargetKindAttrMap : public AttrRegistryMap<TargetKind, ValueType> {
157  public:
  // Re-export the lookup members of the base map.
  // NOTE(review): the `TParent` alias is declared on a line not shown in this
  // listing; presumably it names AttrRegistryMap<TargetKind, ValueType> - confirm.
159  using TParent::count;
160  using TParent::get;
161  using TParent::operator[];
163 };
164 
/*! \brief String constant "c++". NOTE(review): presumably the value naming the C++ runtime - confirm. */
166 static constexpr const char* kTvmRuntimeCpp = "c++";
167 
/*! \brief String constant "c". NOTE(review): presumably the value naming the C runtime (CRT) - confirm. */
169 static constexpr const char* kTvmRuntimeCrt = "c";
170 
176  public:
  /*!
   * \brief Register additional attributes to target_kind.
   * \param attr_name Name of the attribute.
   * \param value Value to store; it is converted to ffi::Any.
   * \param plevel Priority level, defaults to 10; must be greater than 0.
   * \return Reference to this entry, for chaining.
   */
190  template <typename ValueType>
191  inline TargetKindRegEntry& set_attr(const String& attr_name, const ValueType& value,
192  int plevel = 10);
  /*!
   * \brief Set the default keys of the target kind.
   * \param keys The default keys.
   * \return Reference to this entry, for chaining.
   */
202  inline TargetKindRegEntry& set_default_keys(std::vector<String> keys);
  /*!
   * \brief Set the pre-processing function applied upon target creation.
   * \note Deprecated: the definition logs a warning asking callers to use
   *       set_target_parser instead.
   * \param f The pre-processing function.
   * \return Reference to this entry, for chaining.
   */
208  template <typename FLambda>
209  inline TargetKindRegEntry& set_attrs_preprocessor(FLambda f);
  /*!
   * \brief Register a valid configuration option and its ValueType for validation.
   * \param key The configuration key; registering the same key twice is an error.
   * \return Reference to this entry, for chaining.
   */
220  template <typename ValueType>
221  inline TargetKindRegEntry& add_attr_option(const String& key);
  /*!
   * \brief Register a valid configuration option and its ValueType for validation,
   *        together with a default value.
   * \param key The configuration key; registering the same key twice is an error.
   * \param default_value The default value stored for the key.
   * \return Reference to this entry, for chaining.
   */
228  template <typename ValueType>
229  inline TargetKindRegEntry& add_attr_option(const String& key, ffi::Any default_value);
  /*! \brief Set name of the TargetKind to be the same as registry if it is empty. */
231  inline TargetKindRegEntry& set_name();
  /*! \brief List all the entry names in the registry. */
236  TVM_DLL static Array<String> ListTargetKinds();
  /*! \brief Get all supported option names and types for a given Target kind. */
241  TVM_DLL static Map<String, String> ListTargetKindOptions(const TargetKind& kind);
242 
  /*!
   * \brief Register or get a new entry.
   * \param target_kind_name Name of the target kind.
   * \return Reference to the registry entry.
   */
248  TVM_DLL static TargetKindRegEntry& RegisterOrGet(const String& target_kind_name);
249 
250  private:
  /*! \brief The TargetKind object configured by this entry. */
251  TargetKind kind_;
  /*! \brief Entry name; set_name copies it into kind_->name when that is empty. */
252  String name;
253 
  /*! \brief Create an entry backed by a fresh TargetKindNode whose index_ is reg_index. */
255  explicit TargetKindRegEntry(uint32_t reg_index) : kind_(make_object<TargetKindNode>()) {
256  kind_->index_ = reg_index;
257  }
  /*! \brief Store an attribute value at the given priority level; called by set_attr. */
264  TVM_DLL void UpdateAttr(const String& key, ffi::Any value, int plevel);
265  template <typename, typename>
266  friend class AttrRegistry;
267  friend class TargetKind;
268 };
269 
270 namespace detail {
/*!
 * \brief Compile-time check of whether a type is an instantiation of a class template.
 *
 * Primary template: the type is not an instantiation of `Tmpl`, so the trait
 * derives from std::false_type and exposes it as `type`.
 *
 * \tparam T The type being inspected.
 * \tparam Tmpl The class template to test against.
 */
template <typename T, template <typename...> class Tmpl>
struct is_specialized : std::false_type {
  using type = std::false_type;
};

/*!
 * \brief Matching case: `Tmpl<Params...>` is an instantiation of `Tmpl`, so the
 *        trait derives from std::true_type and exposes it as `type`.
 */
template <template <typename...> class Tmpl, typename... Params>
struct is_specialized<Tmpl<Params...>, Tmpl> : std::true_type {
  using type = std::true_type;
};
280 
281 template <typename ValueType, typename IsArray = typename is_specialized<ValueType, Array>::type,
282  typename IsMap = typename is_specialized<ValueType, Map>::type>
284 
285 template <typename ValueType>
286 struct ValueTypeInfoMaker<ValueType, std::false_type, std::false_type> {
287  using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
288 
289  ValueTypeInfo operator()() const {
290  ValueTypeInfo info;
291  info.key = nullptr;
292  info.val = nullptr;
293  if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
294  int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
295  info.type_index = tindex;
296  info.type_key = runtime::Object::TypeIndex2Key(tindex);
297  return info;
298  } else if constexpr (std::is_same_v<ValueType, String>) {
299  // special handle string since it can be backed by multiple types.
300  info.type_index = ffi::TypeIndex::kTVMFFIStr;
301  info.type_key = ffi::TypeTraits<ValueType>::TypeStr();
302  return info;
303  } else {
304  // TODO(tqchen) consider upgrade to leverage any system to support union type
305  constexpr int32_t tindex = ffi::TypeToFieldStaticTypeIndex<ValueType>::value;
306  static_assert(tindex != ffi::TypeIndex::kTVMFFIAny, "Do not support union type for now");
307  info.type_index = tindex;
308  info.type_key = runtime::Object::TypeIndex2Key(tindex);
309  return info;
310  }
311  }
312 };
313 
314 template <typename ValueType>
315 struct ValueTypeInfoMaker<ValueType, std::true_type, std::false_type> {
316  using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
317 
318  ValueTypeInfo operator()() const {
319  using key_type = ValueTypeInfoMaker<typename ValueType::value_type>;
320  uint32_t tindex = ValueType::ContainerType::_GetOrAllocRuntimeTypeIndex();
321  ValueTypeInfo info;
322  info.type_index = tindex;
323  info.type_key = runtime::Object::TypeIndex2Key(tindex);
324  info.key = std::make_unique<ValueTypeInfo>(key_type()());
325  info.val = nullptr;
326  return info;
327  }
328 };
329 
330 template <typename ValueType>
331 struct ValueTypeInfoMaker<ValueType, std::false_type, std::true_type> {
332  using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
333  ValueTypeInfo operator()() const {
334  using key_type = ValueTypeInfoMaker<typename ValueType::key_type>;
335  using val_type = ValueTypeInfoMaker<typename ValueType::mapped_type>;
336  uint32_t tindex = ValueType::ContainerType::_GetOrAllocRuntimeTypeIndex();
337  ValueTypeInfo info;
338  info.type_index = tindex;
339  info.type_key = runtime::Object::TypeIndex2Key(tindex);
340  info.key = std::make_unique<ValueTypeInfo>(key_type()());
341  info.val = std::make_unique<ValueTypeInfo>(val_type()());
342  return info;
343  }
344 };
345 
346 } // namespace detail
347 
348 template <typename ValueType>
349 inline TargetKindAttrMap<ValueType> TargetKind::GetAttrMap(const String& attr_name) {
350  return TargetKindAttrMap<ValueType>(GetAttrMapContainer(attr_name));
351 }
352 
353 template <typename ValueType>
354 inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const String& attr_name,
355  const ValueType& value, int plevel) {
356  ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
357  ffi::Any rv;
358  rv = value;
359  UpdateAttr(attr_name, rv, plevel);
360  return *this;
361 }
362 
365  return *this;
366 }
367 
368 inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector<String> keys) {
369  kind_->default_keys = keys;
370  return *this;
371 }
372 
373 template <typename FLambda>
375  LOG(WARNING) << "set_attrs_preprocessor is deprecated please use set_target_parser instead";
376  kind_->preprocessor = ffi::Function::FromTyped(std::move(f));
377  return *this;
378 }
379 
381  kind_->target_parser = parser;
382  return *this;
383 }
384 
385 template <typename ValueType>
387  ICHECK(!kind_->key2vtype_.count(key))
388  << "AttributeError: add_attr_option failed because '" << key << "' has been set once";
389  kind_->key2vtype_[key] = detail::ValueTypeInfoMaker<ValueType>()();
390  return *this;
391 }
392 
393 template <typename ValueType>
395  Any default_value) {
396  add_attr_option<ValueType>(key);
397  kind_->key2default_[key] = default_value;
398  return *this;
399 }
400 
402  if (kind_->name.empty()) {
403  kind_->name = name;
404  }
405  return *this;
406 }
407 
/*!
 * \brief Declaration prefix for the static variable created by TVM_REGISTER_TARGET_KIND.
 *
 * Expands to the start of a static, unused-attributed `::tvm::TargetKindRegEntry&`
 * declaration whose identifier begins with `__make_TargetKind`; the registering
 * macro concatenates `__COUNTER__` onto it through TVM_STR_CONCAT.
 */
408 #define TVM_TARGET_KIND_REGISTER_VAR_DEF \
409  static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetKindRegEntry& __make_##TargetKind
410 
/*!
 * \brief Register a target kind, or fetch its existing entry, and apply the common setup.
 *
 * Calls TargetKindRegEntry::RegisterOrGet(TargetKindName), then set_name() and
 * set_default_device_type(DeviceType), and registers the attribute options
 * "keys" and "libs" (Array<String>), "tag", "device" and "model" (String),
 * "host" (Target), and "from_device" and "target_device_type" (int64_t).
 * The expansion ends without a semicolon, so further builder calls can be
 * chained at the use site.
 *
 * \param TargetKindName Name of the target kind.
 * \param DeviceType Default device type passed to set_default_device_type.
 */
428 #define TVM_REGISTER_TARGET_KIND(TargetKindName, DeviceType) \
429  TVM_STR_CONCAT(TVM_TARGET_KIND_REGISTER_VAR_DEF, __COUNTER__) = \
430  ::tvm::TargetKindRegEntry::RegisterOrGet(TargetKindName) \
431  .set_name() \
432  .set_default_device_type(DeviceType) \
433  .add_attr_option<Array<String>>("keys") \
434  .add_attr_option<String>("tag") \
435  .add_attr_option<String>("device") \
436  .add_attr_option<String>("model") \
437  .add_attr_option<Array<String>>("libs") \
438  .add_attr_option<Target>("host") \
439  .add_attr_option<int64_t>("from_device") \
440  .add_attr_option<int64_t>("target_device_type")
441 
442 } // namespace tvm
443 
444 #endif // TVM_TARGET_TARGET_KIND_H_
Attribute map used in registry.
Generic attribute map.
Definition: attr_registry_map.h:38
Map<Key, ValueType> used to store meta-data.
Definition: attr_registry_map.h:105
ValueType get(const TargetKind &key, ValueType def_value) const
get the corresponding value element at key with default value.
Definition: attr_registry_map.h:136
int count(const TargetKind &key) const
Check if the map has op as key.
Definition: attr_registry_map.h:117
Definition: instruction.h:30
Map<TargetKind, ValueType> used to store meta-information about TargetKind.
Definition: target_kind.h:156
TargetKindAttrMap(const AttrRegistryMapContainerMap< TargetKind > &map)
Definition: target_kind.h:162
Target kind, specifies the kind of the target.
Definition: target_kind.h:67
int default_device_type
Device type of target kind.
Definition: target_kind.h:72
String name
Name of the target kind.
Definition: target_kind.h:70
TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object)
FTVMTargetParser target_parser
Function used to parse a JSON target during creation.
Definition: target_kind.h:78
friend class TargetInternal
Definition: target_kind.h:120
Array< String > default_keys
Default keys of the target.
Definition: target_kind.h:74
static void RegisterReflection()
Definition: target_kind.h:80
static constexpr const char * _type_key
Definition: target_kind.h:91
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: target_kind.h:90
ffi::Function preprocessor
Function used to preprocess on target creation.
Definition: target_kind.h:76
Helper structure to register TargetKind.
Definition: target_kind.h:175
TargetKindRegEntry & set_attrs_preprocessor(FLambda f)
Set the pre-processing function applied upon target creation.
Definition: target_kind.h:374
TargetKindRegEntry & set_target_parser(FTVMTargetParser parser)
Set the parsing function applied upon target creation.
Definition: target_kind.h:380
TargetKindRegEntry & add_attr_option(const String &key, ffi::Any default_value)
Register a valid configuration option and its ValueType for validation.
TargetKindRegEntry & set_default_keys(std::vector< String > keys)
Set the default keys of the target.
Definition: target_kind.h:368
TargetKindRegEntry & set_name()
Set name of the TargetKind to be the same as registry if it is empty.
Definition: target_kind.h:401
static TargetKindRegEntry & RegisterOrGet(const String &target_kind_name)
Register or get a new entry.
TargetKindRegEntry & set_attr(const String &attr_name, const ValueType &value, int plevel=10)
Register additional attributes to target_kind.
Definition: target_kind.h:354
static Array< String > ListTargetKinds()
List all the entry names in the registry.
TargetKindRegEntry & set_default_device_type(int device_type)
Set DLPack's device_type of the target.
Definition: target_kind.h:363
TargetKindRegEntry & add_attr_option(const String &key)
Register a valid configuration option and its ValueType for validation.
Definition: target_kind.h:386
static Map< String, String > ListTargetKindOptions(const TargetKind &kind)
Get all supported option names and types for a given Target kind.
Managed reference class to TargetKindNode.
Definition: target_kind.h:127
static Optional< TargetKind > Get(const String &target_kind_name)
Retrieve the TargetKind given its name.
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode)
friend class TargetInternal
Definition: target_kind.h:148
static TargetKindAttrMap< ValueType > GetAttrMap(const String &attr_name)
Get the attribute map given the attribute name.
Definition: target_kind.h:349
TargetKindNode * operator->()
Mutable access to the container class
Definition: target_kind.h:140
TargetKind()=default
Managed reference to TypeNode.
Definition: type.h:101
Definition: repr_printer.h:91
tvm::relax::Function Function
Definition: transform.h:42
constexpr const char * device_type
The device type.
Definition: stmt.h:1092
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37
ffi::TypedFunction< TargetJSON(TargetJSON)> FTVMTargetParser
Definition: target_kind.h:54
Map< String, ffi::Any > TargetFeatures
Map containing parsed features of a specific Target.
Definition: target_kind.h:44
Map< String, ffi::Any > TargetJSON
TargetParser to apply on instantiation of a given TargetKind.
Definition: target_kind.h:53
Definitions and helper macros for IR/AST nodes.
Definition: target_kind.h:283
std::true_type type
Definition: target_kind.h:278
Definition: target_kind.h:272
std::false_type type
Definition: target_kind.h:273