tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_IR_OP_H_
26 #define TVM_IR_OP_H_
27 
28 #include <tvm/ir/attrs.h>
29 #include <tvm/ir/env_func.h>
30 #include <tvm/ir/expr.h>
31 #include <tvm/ir/type.h>
33 #include <tvm/runtime/logging.h>
34 #include <tvm/runtime/registry.h>
35 
36 #include <string>
37 #include <utility>
38 #include <vector>
39 
40 namespace tvm {
41 
// Forward declaration: Op::GetAttrMap returns an OpAttrMap, whose full
// definition appears later in this header.
template <typename ValueType>
class OpAttrMap;
45 
46 // TODO(tvm-team): migrate low-level intrinsics to use Op
58 class OpNode : public RelaxExprNode {
59  public:
63  mutable FuncType op_type;
69  /* \brief Information of input arguments to the operator */
80  uint32_t attrs_type_index{0};
85  int32_t num_inputs = -1;
91  int32_t support_level = 10;
92 
94  v->Visit("name", &name);
95  v->Visit("op_type", &op_type);
96  v->Visit("description", &description);
97  v->Visit("arguments", &arguments);
98  v->Visit("attrs_type_key", &attrs_type_key);
99  v->Visit("num_inputs", &num_inputs);
100  v->Visit("support_level", &support_level);
101  }
102 
103  bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
104  // pointer equality is fine as there is only one op with the same name.
105  return this == other;
106  }
107 
108  void SHashReduce(SHashReducer hash_reduce) const {
109  // Name uniquely identifies an Op.
110  hash_reduce(name);
111  }
112 
113  static constexpr const char* _type_key = "Op";
115 
116  private:
118  uint32_t AttrRegistryIndex() const { return index_; }
120  std::string AttrRegistryName() const { return name; }
121 
122  // friend class
123  template <typename>
125  template <typename, typename>
126  friend class AttrRegistry;
127  friend class OpRegEntry;
128 
129  // Program internal unique index of operator.
130  // Used to help index the program.
131  uint32_t index_{0};
132 };
133 
138 class Op : public RelaxExpr {
139  public:
147  template <typename ValueType>
148  inline static OpAttrMap<ValueType> GetAttrMap(const String& attr_name);
154  TVM_DLL static bool HasAttrMap(const String& attr_name);
161  TVM_DLL static const Op& Get(const String& op_name);
162 
164 
165  private:
171  TVM_DLL static const AttrRegistryMapContainerMap<Op>& GetAttrMapContainer(const String& key);
172 };
173 
178 class OpRegEntry {
179  public:
181  const Op& op() const { return op_; }
188  inline OpRegEntry& describe(const std::string& descr); // NOLINT(*)
196  inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
197  const std::string& description);
203  template <typename AttrsType>
204  inline OpRegEntry& set_attrs_type();
210  inline OpRegEntry& set_attrs_type_key(const String& key);
216  inline OpRegEntry& set_num_inputs(int32_t n); // NOLINT(*)
222  inline OpRegEntry& set_support_level(int32_t level); // NOLINT(*)
236  template <typename ValueType>
237  inline OpRegEntry& set_attr(const std::string& attr_name, // NOLINT(*)
238  const ValueType& value, int plevel = 10);
239 
244  inline void reset_attr(const std::string& attr_name);
245 
246  // set the name of the op to be the same as registry
247  inline OpRegEntry& set_name() { // NOLINT(*)
248  if (get()->name.length() == 0) {
249  get()->name = name;
250  }
251  return *this;
252  }
258  TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
259 
260  private:
261  template <typename, typename>
262  friend class AttrRegistry;
263  // the name
264  std::string name;
266  Op op_;
267  // private constructor
268  TVM_DLL OpRegEntry(uint32_t reg_index);
269  // return internal pointer to op.
270  inline OpNode* get();
271  // update the attribute OpAttrMap
272  TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
273 };
274 
279 template <typename ValueType>
280 class OpAttrMap : public AttrRegistryMap<Op, ValueType> {
281  public:
289  inline ValueType get(const RelaxExpr& expr, ValueType def_value) const;
290 
292  using TParent::count;
293  using TParent::get;
294  using TParent::operator[];
295 
296  private:
297  friend class Op;
298  // constructor
299  explicit OpAttrMap(const AttrRegistryMapContainerMap<Op>& map) : TParent(map) {}
300 };
301 
// Internal helper for TVM_REGISTER_OP: expands to the declaration prefix of a
// file-static OpRegEntry reference named __make_Op. TVM_STR_CONCAT (defined
// elsewhere) is assumed to paste __COUNTER__ onto that name, so each
// registration gets a distinct variable.
#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op

/*!
 * \brief Register a new operator, or get the already-registered one, by name.
 *
 * Expands to a static OpRegEntry& initialized from
 * ::tvm::OpRegEntry::RegisterOrGet(op_name).set_name(), so the setters
 * can be chained:
 *
 * \code
 *   TVM_REGISTER_OP("add")
 *   .describe("add two inputs together")
 *   .set_num_inputs(2);
 * \endcode
 *
 * \param op_name The name of the operator.
 */
#define TVM_REGISTER_OP(op_name)                          \
  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::OpRegEntry::RegisterOrGet(op_name).set_name()
323 
324 // implementations
325 
326 template <typename ValueType>
328  return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key));
329 }
330 
331 inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); }
332 
333 inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*)
334  get()->description = descr;
335  return *this;
336 }
337 
338 inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
339  const std::string& description) {
340  auto n = make_object<AttrFieldInfoNode>();
341  n->name = name;
342  n->type_info = type;
343  n->description = description;
344  get()->arguments.push_back(AttrFieldInfo(n));
345  return *this;
346 }
347 
348 inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*)
349  get()->num_inputs = n;
350  return *this;
351 }
352 
353 template <typename AttrsType>
354 inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*)
355  get()->attrs_type_key = AttrsType::_type_key;
356  get()->attrs_type_index = AttrsType::RuntimeTypeIndex();
357  return *this;
358 }
359 
360 inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*)
361  get()->attrs_type_key = key;
363  return *this;
364 }
365 
366 inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*)
367  get()->support_level = n;
368  return *this;
369 }
370 
371 template <typename ValueType>
372 inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*)
373  const std::string& attr_name, const ValueType& value, int plevel) {
374  ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
376  rv = value;
377  UpdateAttr(attr_name, rv, plevel);
378  return *this;
379 }
380 
381 // member functions of OpAttrMap
382 
383 template <typename ValueType>
384 inline ValueType OpAttrMap<ValueType>::get(const RelaxExpr& expr, ValueType def_value) const {
385  ICHECK(expr.defined());
386  if (const OpNode* op = expr.as<OpNode>()) {
387  return this->map_.get(GetRef<Op>(op), def_value);
388  } else {
389  return def_value;
390  }
391 }
392 
393 } // namespace tvm
394 #endif // TVM_IR_OP_H_
Attribute map used in registry.
Helpers for attribute objects.
AttrFieldInfo.
Definition: attrs.h:128
Generic attribute map.
Definition: attr_registry_map.h:38
Map<Key, ValueType> used to store meta-data.
Definition: attr_registry_map.h:101
ValueType get(const Op &key, ValueType def_value) const
get the corresponding value element at key with default value.
Definition: attr_registry_map.h:126
int count(const Op &key) const
Check if the map has op as key.
Definition: attr_registry_map.h:113
Definition: instruction.h:30
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each field.
Definition: reflection.h:52
Managed reference to FuncTypeNode.
Definition: type.h:301
Map<Op,ValueType> used to store meta-information about Op.
Definition: op.h:280
ValueType get(const RelaxExpr &expr, ValueType def_value) const
get the corresponding value element at op with default value.
Definition: op.h:384
Primitive Op(builtin intrinsics)
Definition: op.h:58
void SHashReduce(SHashReducer hash_reduce) const
Definition: op.h:108
static constexpr const char * _type_key
Definition: op.h:113
String name
name of the operator
Definition: op.h:61
uint32_t attrs_type_index
Attribute type index. This field varies in each run and is not exposed to the frontend.
Definition: op.h:80
String attrs_type_key
The type key of the attribute field. This can be empty, in which case it defaults to anything.
Definition: op.h:75
int32_t support_level
Support level of the operator. The lower the level, the higher the priority. This is analogous to BLAS levels.
Definition: op.h:91
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelaxExprNode)
void VisitAttrs(AttrVisitor *v)
Definition: op.h:93
int32_t num_inputs
Number of input arguments to the operator; -1 means it is variable length.
Definition: op.h:85
FuncType op_type
the type of the operator
Definition: op.h:63
String description
Detailed description of the operator. This can be used to generate docstrings automatically for the operator.
Definition: op.h:68
Array< AttrFieldInfo > arguments
Definition: op.h:70
bool SEqualReduce(const OpNode *other, SEqualReducer equal) const
Definition: op.h:103
Helper structure to register operators.
Definition: op.h:178
OpRegEntry & set_attrs_type_key(const String &key)
Set the attrs type key, and the corresponding type index, from the given string key.
Definition: op.h:360
OpRegEntry & describe(const std::string &descr)
Setter function used during registration. Set the description of the operator.
Definition: op.h:333
static OpRegEntry & RegisterOrGet(const String &name)
Register or get a new entry.
OpRegEntry & set_name()
Definition: op.h:247
void reset_attr(const std::string &attr_name)
Resets an attr of the registry.
OpRegEntry & add_argument(const std::string &name, const std::string &type, const std::string &description)
Add argument information to the function.
Definition: op.h:338
OpRegEntry & set_attrs_type()
Set the attrs type key and index to be AttrsType.
Definition: op.h:354
OpRegEntry & set_support_level(int32_t level)
Set the support level of op.
Definition: op.h:366
const Op & op() const
Definition: op.h:181
OpRegEntry & set_attr(const std::string &attr_name, const ValueType &value, int plevel=10)
Register additional attributes to operator.
Definition: op.h:372
OpRegEntry & set_num_inputs(int32_t n)
Set the num_inputs.
Definition: op.h:348
Managed reference class to OpNode.
Definition: op.h:138
static OpAttrMap< ValueType > GetAttrMap(const String &attr_name)
Get additional registered attributes of operators. If nothing has been registered,...
Definition: op.h:327
static bool HasAttrMap(const String &attr_name)
Checks if an attr map is present in the registry.
static const Op & Get(const String &op_name)
Get an Op for a given operator name. Will raise an error if the op has not been registered.
Base node of all non-primitive expressions.
Definition: expr.h:362
Managed reference to RelaxExprNode.
Definition: expr.h:405
TVM_DEFINE_OBJECT_REF_METHODS(RelaxExpr, BaseExpr, RelaxExprNode)
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:135
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
bool defined() const
Definition: object.h:553
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:911
static uint32_t TypeKey2Index(const std::string &key)
Get the type index of the corresponding key from runtime.
Reference to string objects.
Definition: string.h:97
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlying...
Definition: packed_func.h:946
Serializable global function used in IR.
Base expr nodes in TVM.
IR/AST nodes for the unified type system in TVM.
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
This file defines the TVM global function registry.