tvm
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_IR_OP_H_
26 #define TVM_IR_OP_H_
27 
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
#include <tvm/node/attr_registry_map.h>
#include <tvm/runtime/logging.h>

#include <string>
#include <utility>
#include <vector>
40 
41 namespace tvm {
42 
// Forward declaration: OpAttrMap is named by Op::GetAttrMap (and befriended by
// OpNode) before the class template is defined further down in this header.
template <typename ValueType>
class OpAttrMap;
46 
47 // TODO(tvm-team): migrate low-level intrinsics to use Op
59 class OpNode : public RelaxExprNode {
60  public:
62  String name;
64  mutable FuncType op_type;
69  String description;
70  /* \brief Information of input arguments to the operator */
71  Array<AttrFieldInfo> arguments;
81  uint32_t attrs_type_index{0};
86  int32_t num_inputs = -1;
92  int32_t support_level = 10;
93 
94  static void RegisterReflection() {
95  namespace refl = tvm::ffi::reflection;
96  refl::ObjectDef<OpNode>()
97  .def_ro("name", &OpNode::name)
98  .def_ro("op_type", &OpNode::op_type, refl::AttachFieldFlag::SEqHashIgnore())
99  .def_ro("description", &OpNode::description, refl::AttachFieldFlag::SEqHashIgnore())
100  .def_ro("arguments", &OpNode::arguments, refl::AttachFieldFlag::SEqHashIgnore())
101  .def_ro("attrs_type_key", &OpNode::attrs_type_key, refl::AttachFieldFlag::SEqHashIgnore())
102  .def_ro("num_inputs", &OpNode::num_inputs, refl::AttachFieldFlag::SEqHashIgnore())
103  .def_ro("support_level", &OpNode::support_level, refl::AttachFieldFlag::SEqHashIgnore());
104  }
105 
106  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUniqueInstance;
107  static constexpr const char* _type_key = "ir.Op";
109 
110  private:
112  uint32_t AttrRegistryIndex() const { return index_; }
114  std::string AttrRegistryName() const { return name; }
115 
116  // friend class
117  template <typename>
119  template <typename, typename>
120  friend class AttrRegistry;
121  friend class OpRegEntry;
122 
123  // Program internal unique index of operator.
124  // Used to help index the program.
125  uint32_t index_{0};
126 };
127 
132 class Op : public RelaxExpr {
133  public:
141  template <typename ValueType>
142  inline static OpAttrMap<ValueType> GetAttrMap(const String& attr_name);
148  TVM_DLL static bool HasAttrMap(const String& attr_name);
155  TVM_DLL static const Op& Get(const String& op_name);
156 
158 
159  private:
165  TVM_DLL static const AttrRegistryMapContainerMap<Op>& GetAttrMapContainer(const String& key);
166 };
167 
172 class OpRegEntry {
173  public:
175  const Op& op() const { return op_; }
182  inline OpRegEntry& describe(const std::string& descr); // NOLINT(*)
190  inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
191  const std::string& description);
197  template <typename AttrsType>
198  inline OpRegEntry& set_attrs_type();
204  inline OpRegEntry& set_attrs_type_key(const String& key);
210  inline OpRegEntry& set_num_inputs(int32_t n); // NOLINT(*)
216  inline OpRegEntry& set_support_level(int32_t level); // NOLINT(*)
230  template <typename ValueType>
231  inline OpRegEntry& set_attr(const std::string& attr_name, // NOLINT(*)
232  const ValueType& value, int plevel = 10);
233 
238  inline void reset_attr(const std::string& attr_name);
239 
240  // set the name of the op to be the same as registry
241  inline OpRegEntry& set_name() { // NOLINT(*)
242  if (get()->name.length() == 0) {
243  get()->name = name;
244  }
245  return *this;
246  }
252  TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
253 
254  private:
255  template <typename, typename>
256  friend class AttrRegistry;
257  // the name
258  std::string name;
260  Op op_;
261  // private constructor
262  TVM_DLL OpRegEntry(uint32_t reg_index);
263  // return internal pointer to op.
264  inline OpNode* get();
265  // update the attribute OpAttrMap
266  TVM_DLL void UpdateAttr(const String& key, ffi::Any value, int plevel);
267 };
268 
273 template <typename ValueType>
274 class OpAttrMap : public AttrRegistryMap<Op, ValueType> {
275  public:
283  inline ValueType get(const RelaxExpr& expr, ValueType def_value) const;
284 
286  using TParent::count;
287  using TParent::get;
288  using TParent::operator[];
289 
290  private:
291  friend class Op;
292  // constructor
293  explicit OpAttrMap(const AttrRegistryMapContainerMap<Op>& map) : TParent(map) {}
294 };
295 
// Internal helper for TVM_REGISTER_OP: declares an unused static OpRegEntry
// reference named `__make_Op`. TVM_REGISTER_OP pastes __COUNTER__ onto that
// name so every registration in a translation unit gets a distinct variable.
#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op

/*!
 * \brief Register a new operator, or get a registered operator.
 *
 * \code
 *   TVM_REGISTER_OP("add")
 *   .describe("add two inputs together")
 *   .set_num_inputs(2)
 *   .set_support_level(1);
 * \endcode
 *
 * \param OpName The name of the operator to register.
 */
#define TVM_REGISTER_OP(OpName)                          \
  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
317 
318 // implementations
319 
320 template <typename ValueType>
321 inline OpAttrMap<ValueType> Op::GetAttrMap(const String& key) {
322  return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key));
323 }
324 
325 inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); }
326 
327 inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*)
328  get()->description = descr;
329  return *this;
330 }
331 
332 inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
333  const std::string& description) {
334  auto n = make_object<AttrFieldInfoNode>();
335  n->name = name;
336  n->type_info = type;
337  n->description = description;
338  get()->arguments.push_back(AttrFieldInfo(n));
339  return *this;
340 }
341 
342 inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*)
343  get()->num_inputs = n;
344  return *this;
345 }
346 
347 template <typename AttrsType>
348 inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*)
349  get()->attrs_type_key = AttrsType::_type_key;
350  get()->attrs_type_index = AttrsType::RuntimeTypeIndex();
351  return *this;
352 }
353 
354 inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*)
355  get()->attrs_type_key = key;
356  get()->attrs_type_index = tvm::ffi::TypeKeyToIndex(key.c_str());
357  return *this;
358 }
359 
360 inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*)
361  get()->support_level = n;
362  return *this;
363 }
364 
365 template <typename ValueType>
366 inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*)
367  const std::string& attr_name, const ValueType& value, int plevel) {
368  ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
369  UpdateAttr(attr_name, Any(value), plevel);
370  return *this;
371 }
372 
373 // member functions of OpAttrMap
374 
375 template <typename ValueType>
376 inline ValueType OpAttrMap<ValueType>::get(const RelaxExpr& expr, ValueType def_value) const {
377  ICHECK(expr.defined());
378  if (const OpNode* op = expr.as<OpNode>()) {
379  return this->map_.get(GetRef<Op>(op), def_value);
380  } else {
381  return def_value;
382  }
383 }
384 
385 } // namespace tvm
386 #endif // TVM_IR_OP_H_
Attribute map used in registry.
Helpers for attribute objects.
AttrFieldInfo.
Definition: attrs.h:92
Generic attribute map.
Definition: attr_registry_map.h:38
Map<Key, ValueType> used to store meta-data.
Definition: attr_registry_map.h:105
ValueType get(const Op &key, ValueType def_value) const
get the corresponding value element at key with default value.
Definition: attr_registry_map.h:136
int count(const Op &key) const
Check if the map has op as key.
Definition: attr_registry_map.h:117
Definition: instruction.h:30
Managed reference to FuncTypeNode.
Definition: type.h:283
Map<Op,ValueType> used to store meta-information about Op.
Definition: op.h:274
ValueType get(const RelaxExpr &expr, ValueType def_value) const
get the corresponding value element at op with default value.
Definition: op.h:376
Primitive Op(builtin intrinsics)
Definition: op.h:59
static void RegisterReflection()
Definition: op.h:94
static constexpr const char * _type_key
Definition: op.h:107
String name
name of the operator
Definition: op.h:62
uint32_t attrs_type_index
attribute type index. This field varies between runs and is not exposed to the frontend.
Definition: op.h:81
String attrs_type_key
The type key of the attribute field. This can be empty, in which case it defaults to anything.
Definition: op.h:76
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind
Definition: op.h:106
int32_t support_level
support level of the operator. The lower the level, the higher the priority. This is analogous to BLAS levels.
Definition: op.h:92
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelaxExprNode)
int32_t num_inputs
number of input arguments to the operator, -1 means it is variable length
Definition: op.h:86
FuncType op_type
the type of the operator
Definition: op.h:64
String description
detailed description of the operator. This can be used to generate docstrings automatically for the operator.
Definition: op.h:69
Array< AttrFieldInfo > arguments
Definition: op.h:71
Helper structure to register operators.
Definition: op.h:172
OpRegEntry & set_attrs_type_key(const String &key)
Set the attrs type key, and look up the matching type index from that key.
Definition: op.h:354
OpRegEntry & describe(const std::string &descr)
Setter function used during registration: set the description of the operator.
Definition: op.h:327
static OpRegEntry & RegisterOrGet(const String &name)
Register or get a new entry.
OpRegEntry & set_name()
Definition: op.h:241
void reset_attr(const std::string &attr_name)
Resets an attr of the registry.
OpRegEntry & add_argument(const std::string &name, const std::string &type, const std::string &description)
Add argument information to the function.
Definition: op.h:332
OpRegEntry & set_attrs_type()
Set the attrs type key and index to be AttrsType.
Definition: op.h:348
OpRegEntry & set_support_level(int32_t level)
Set the support level of op.
Definition: op.h:360
const Op & op() const
Definition: op.h:175
OpRegEntry & set_attr(const std::string &attr_name, const ValueType &value, int plevel=10)
Register additional attributes to operator.
Definition: op.h:366
OpRegEntry & set_num_inputs(int32_t n)
Set the num_inputs.
Definition: op.h:342
Managed reference class to OpNode.
Definition: op.h:132
static OpAttrMap< ValueType > GetAttrMap(const String &attr_name)
Get additional registered attribute about operators. If nothing has been registered,...
Definition: op.h:321
TVM_DEFINE_OBJECT_REF_METHODS(Op, RelaxExpr, OpNode)
static bool HasAttrMap(const String &attr_name)
Checks if an attr map is present in the registry.
static const Op & Get(const String &op_name)
Get an Op for a given operator name. Will raise an error if the op has not been registered.
Base node of all non-primitive expressions.
Definition: expr.h:422
Managed reference to RelaxExprNode.
Definition: expr.h:446
Serializable global function used in IR.
Base expr nodes in TVM.
IR/AST nodes for the unified type system in TVM.
Definition: repr_printer.h:91
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:37