tvm
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_IR_OP_H_
26 #define TVM_IR_OP_H_
27 
28 #include <dmlc/registry.h>
29 #include <tvm/ir/attrs.h>
30 #include <tvm/ir/expr.h>
31 #include <tvm/ir/type.h>
32 #include <tvm/ir/type_relation.h>
34 #include <tvm/runtime/registry.h>
35 
36 #include <string>
37 #include <utility>
38 #include <vector>
39 
40 namespace tvm {
41 
42 // forward declare name.
43 template <typename>
44 class OpAttrMap;
45 
46 // TODO(tvm-team): migrate low-level intrinsics to use Op
58 class OpNode : public RelayExprNode {
59  public:
63  mutable FuncType op_type;
69  /* \brief Information of input arguments to the operator */
80  uint32_t attrs_type_index{0};
85  int32_t num_inputs = -1;
91  int32_t support_level = 10;
92 
94  v->Visit("name", &name);
95  v->Visit("op_type", &op_type);
96  v->Visit("description", &description);
97  v->Visit("arguments", &arguments);
98  v->Visit("attrs_type_key", &attrs_type_key);
99  v->Visit("num_inputs", &num_inputs);
100  v->Visit("support_level", &support_level);
101  }
102 
103  bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
104  // pointer equality is fine as there is only one op with the same name.
105  return this == other;
106  }
107 
108  void SHashReduce(SHashReducer hash_reduce) const {
109  // Name uniquely identifies an Op.
110  hash_reduce(name);
111  }
112 
118  bool IsPrimitiveOp() const {
119  if (is_primitive_ != -1) return is_primitive_ != 0;
120  is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
121  return is_primitive_ != 0;
122  }
123 
124  static constexpr const char* _type_key = "Op";
126 
127  private:
129  uint32_t AttrRegistryIndex() const { return index_; }
131  std::string AttrRegistryName() const { return name; }
132 
133  // friend class
134  template <typename>
136  template <typename, typename>
137  friend class AttrRegistry;
138  friend class OpRegEntry;
139 
140  friend bool IsPrimitiveOp(const RelayExpr&);
141  // Program internal unique index of operator.
142  // Used to help index the program.
143  uint32_t index_{0};
144  // whether this is a primitive op. -1 means unknown.
145  mutable int is_primitive_{-1};
146  // Internal function to compute if it is primitive op
147  bool IsPrimitiveOp_() const {
148  const auto& fn_ty = this->op_type;
149  ICHECK(fn_ty.get() != nullptr) << "op_type of " << this->name << " is not registered";
150  if (fn_ty->type_constraints.size() != 1) return false;
151  const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
152  if (rel == nullptr) return false;
153  // validate if the type parameter matches up
154  for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
155  if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
156  }
157  return true;
158  }
159 };
160 
/*!
 * \brief Managed reference class to OpNode.
 * \sa OpNode
 */
class Op : public RelayExpr {
 public:
  /*! \brief default constructor */
  Op() {}
  /*! \brief constructor from node pointer */
  explicit Op(ObjectPtr<Object> n) : RelayExpr(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const OpNode* operator->() const;
  /*!
   * \brief Get additional registered attribute about operators.
   *  If nothing has been registered, an empty OpAttrMap will be returned.
   * \param attr_name The name of the attribute.
   * \tparam ValueType The type of the attribute values.
   * \return An OpAttrMap view of the attribute named attr_name.
   */
  template <typename ValueType>
  inline static OpAttrMap<ValueType> GetAttrMap(const String& attr_name);
  /*!
   * \brief Query the attribute registry for attr_name.
   *  NOTE(review): presumably true when some operator has registered attr_name;
   *  confirm against the out-of-line definition.
   * \param attr_name The name of the attribute.
   */
  TVM_DLL static bool HasAttrMap(const String& attr_name);
  /*!
   * \brief Get the Op registered under a given operator name.
   *  NOTE(review): behavior for an unregistered name is defined out of line; confirm.
   * \param op_name Name of the operator.
   * \return Reference to the registered Op.
   */
  TVM_DLL static const Op& Get(const String& op_name);


 private:
  /*!
   * \brief Get the generic attribute container for a key; GetAttrMap wraps it
   *  in a typed OpAttrMap.
   * \param key The attribute key.
   * \return The attribute container.
   */
  TVM_DLL static const AttrRegistryMapContainerMap<Op>& GetAttrMapContainer(const String& key);
};
210 
215 class OpRegEntry {
216  public:
218  const Op& op() const { return op_; }
225  inline OpRegEntry& describe(const std::string& descr); // NOLINT(*)
233  inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
234  const std::string& description);
242  inline OpRegEntry& add_type_rel(
243  const std::string& rel_name,
244  runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
245  type_rel_func);
251  template <typename AttrsType>
252  inline OpRegEntry& set_attrs_type();
258  inline OpRegEntry& set_attrs_type_key(const String& key);
264  inline OpRegEntry& set_num_inputs(int32_t n); // NOLINT(*)
270  inline OpRegEntry& set_support_level(int32_t level); // NOLINT(*)
284  template <typename ValueType>
285  inline OpRegEntry& set_attr(const std::string& attr_name, // NOLINT(*)
286  const ValueType& value, int plevel = 10);
287 
292  inline void reset_attr(const std::string& attr_name);
293 
294  // set the name of the op to be the same as registry
295  inline OpRegEntry& set_name() { // NOLINT(*)
296  if (get()->name.length() == 0) {
297  get()->name = name;
298  }
299  return *this;
300  }
306  TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
307 
308  private:
309  template <typename, typename>
310  friend class AttrRegistry;
311  // the name
312  std::string name;
314  Op op_;
315  // private constructor
316  TVM_DLL OpRegEntry(uint32_t reg_index);
317  // return internal pointer to op.
318  inline OpNode* get();
319  // update the attribute OpAttrMap
320  TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
321 };
322 
327 template <typename ValueType>
328 class OpAttrMap : public AttrRegistryMap<Op, ValueType> {
329  public:
337  inline ValueType get(const RelayExpr& expr, ValueType def_value) const;
338 
340  using TParent::count;
341  using TParent::get;
342  using TParent::operator[];
343 
344  private:
345  friend class Op;
346  // constructor
347  explicit OpAttrMap(const AttrRegistryMapContainerMap<Op>& map) : TParent(map) {}
348 };
349 
// Internal macro: the declaration prefix for the file-static, unused OpRegEntry
// reference that every TVM_REGISTER_OP expansion defines. `##Op` pastes the
// literal token Op, giving the name prefix __make_Op.
// NOTE(review): identifiers beginning with a double underscore are reserved.
#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op

/*!
 * \brief Register a new operator, or get an already registered one.
 *
 * Expands to a file-static OpRegEntry reference initialized with
 * ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name(). __COUNTER__ is combined
 * with the prefix so each expansion gets a distinct variable name (TVM_STR_CONCAT
 * presumably expands and pastes the two tokens). Builder calls can be chained:
 *
 * \code
 *   TVM_REGISTER_OP("add")
 *   .describe("add two inputs together")
 *   .set_num_inputs(2);
 * \endcode
 *
 * \param OpName The name of the operator.
 */
#define TVM_REGISTER_OP(OpName) \
  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
  ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
371 
372 // implementations
373 inline const OpNode* Op::operator->() const { return static_cast<const OpNode*>(get()); }
374 
375 template <typename ValueType>
377  return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key));
378 }
379 
380 inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); }
381 
382 inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*)
383  get()->description = descr;
384  return *this;
385 }
386 
387 inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
388  const std::string& description) {
389  auto n = make_object<AttrFieldInfoNode>();
390  n->name = name;
391  n->type_info = type;
392  n->description = description;
393  get()->arguments.push_back(AttrFieldInfo(n));
394  return *this;
395 }
396 
398  const std::string& rel_name,
399  runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
400  type_rel_func) {
401  auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
402  TypeRelationFn env_type_rel_func;
403 
404  if (runtime::Registry::Get(func_name)) {
405  auto env_func = EnvFunc::Get(func_name);
406  env_type_rel_func = env_func;
407  } else {
408  runtime::Registry::Register(func_name).set_body(type_rel_func.packed());
409  auto env_func = EnvFunc::Get(func_name);
410  env_type_rel_func = env_func;
411  }
412 
413  Array<TypeVar> type_params;
414  Array<Type> arg_types;
415 
416  // Add inputs.
417  std::string input_name_prefix = "in";
418  for (int i = 0; i < get()->num_inputs; i++) {
419  auto name = input_name_prefix + std::to_string(i);
420  auto param = TypeVar(name, TypeKind::kType);
421  type_params.push_back(param);
422  arg_types.push_back(param);
423  }
424 
425  Array<Type> ty_call_args = arg_types;
426 
427  // Add output type.
428  auto out_param = TypeVar("out", TypeKind::kType);
429  type_params.push_back(out_param);
430  // this will trigger copy on write.
431  ty_call_args.push_back(out_param);
432 
433  // The attributes of primitive op is nullptr
434  //
435  // The attributes of primitive operator can vary at the call site.
436  // The type of sum is also dependent on Attrs being passed.
437  // So puting nullptr in the Attrs means that the operator is polymorphic on Attrs.
438  //
439  // A common example is sum(x, axis), where the choice of axis
440  // can affect the type of the function.
441  TypeConstraint type_rel =
442  TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs());
443 
444  auto func_type = FuncType(arg_types, out_param, type_params, {type_rel});
445 
446  get()->op_type = func_type;
447 
448  return *this;
449 }
450 
451 inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*)
452  get()->num_inputs = n;
453  return *this;
454 }
455 
456 template <typename AttrsType>
457 inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*)
458  get()->attrs_type_key = AttrsType::_type_key;
459  get()->attrs_type_index = AttrsType::RuntimeTypeIndex();
460  return *this;
461 }
462 
463 inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*)
464  get()->attrs_type_key = key;
466  return *this;
467 }
468 
469 inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*)
470  get()->support_level = n;
471  return *this;
472 }
473 
474 template <typename ValueType>
475 inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*)
476  const std::string& attr_name, const ValueType& value, int plevel) {
477  ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
479  rv = value;
480  UpdateAttr(attr_name, rv, plevel);
481  return *this;
482 }
483 
484 // member functions of OpAttrMap
485 
486 template <typename ValueType>
487 inline ValueType OpAttrMap<ValueType>::get(const RelayExpr& expr, ValueType def_value) const {
488  ICHECK(expr.defined());
489  if (const OpNode* op = expr.as<OpNode>()) {
490  return this->map_.get(GetRef<Op>(op), def_value);
491  } else {
492  return def_value;
493  }
494 }
495 
509 inline bool IsPrimitiveOp(const RelayExpr& expr) {
510  const auto* op = expr.as<OpNode>();
511  return op != nullptr && op->IsPrimitiveOp();
512 }
513 
514 } // namespace tvm
515 #endif // TVM_IR_OP_H_
static OpAttrMap< ValueType > GetAttrMap(const String &attr_name)
Get additional registered attribute about operators. If nothing has been registered, an empty OpAttrMap will be returned.
Definition: op.h:376
Container class of TypeReporter.
Definition: type_relation.h:145
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:799
String name
name of the operator
Definition: op.h:61
User defined type relation, it is an input-output relation on types.
Definition: type_relation.h:185
A custom smart pointer for Object.
Definition: object.h:358
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode)
const Op & op() const
Definition: op.h:218
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:102
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
Base expr nodes in TVM.
AttrFieldInfo.
Definition: attrs.h:128
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:102
Type relation and function for type inference(checking).
OpRegEntry & add_argument(const std::string &name, const std::string &type, const std::string &description)
Add argument information to the function.
Definition: op.h:387
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
Op()
default constructor
Definition: op.h:168
Managed reference to TypeConstraintNode.
Definition: type.h:403
static EnvFunc Get(const String &name)
Get a global function based on the name.
OpRegEntry & set_support_level(int32_t level)
Set the support level of op.
Definition: op.h:469
Primitive Op(builtin intrinsics)
Definition: op.h:58
Definition: executor.h:43
void SHashReduce(SHashReducer hash_reduce) const
Definition: op.h:108
Helpers for attribute objects.
OpRegEntry & set_num_inputs(int32_t n)
Set the num_inputs.
Definition: op.h:451
FuncType op_type
the type of the operator
Definition: op.h:63
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:436
static const PackedFunc * Get(const std::string &name)
Get the global function by name.
OpRegEntry & set_name()
Definition: op.h:295
Generic attribute map.
Definition: attr_registry_map.h:38
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
bool IsPrimitiveOp() const
Check whether the current op is a "primitive operator", that is, whether its arguments are all type variables...
Definition: op.h:118
tvm::OpNode OpNode
Definition: op.h:35
static constexpr const char * _type_key
Definition: op.h:124
size_t size() const
Definition: array.h:399
IR/AST nodes for the unified type system in TVM.
bool defined() const
Definition: object.h:544
Array< AttrFieldInfo > arguments
Definition: op.h:70
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Array< Type > args
The type arguments to the type function.
Definition: type_relation.h:194
friend class OpRegEntry
Definition: op.h:138
tvm::FuncType FuncType
Definition: type.h:57
ValueType get(const RelayExpr &expr, ValueType def_value) const
get the corresponding value element at op with default value.
Definition: op.h:487
Reference to string objects.
Definition: string.h:124
tvm::TypeVar TypeVar
Definition: type.h:49
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:60
OpRegEntry & set_attrs_type_key(const String &key)
Set the attrs type key, and the matching type index, from the given key.
Definition: op.h:463
Managed reference to RelayExprNode.
Definition: expr.h:217
Managed reference class to OpNode.
Definition: op.h:165
Map<Key, ValueType> used to store meta-data.
Definition: attr_registry_map.h:101
void VisitAttrs(AttrVisitor *v)
Definition: op.h:93
Managed reference to FuncTypeNode.
Definition: type.h:461
Helper structure to register operators.
Definition: op.h:215
Attribute map used in registry.
const OpNode * operator->() const
access the internal node container
Definition: op.h:373
String description
detailed description of the operator. This can be used to generate docstrings automatically for the ope...
Definition: op.h:68
OpRegEntry & describe(const std::string &descr)
Setter function used during registration: set the description of the operator.
Definition: op.h:382
uint32_t attrs_type_index
attribute type index; this field varies between runs and is not exposed to the frontend.
Definition: op.h:80
Registry & set_body(PackedFunc f)
set the body of the function to be f
int32_t support_level
support level of the operator. The lower the level, the higher the priority. This is analogous to BLAS ...
Definition: op.h:91
bool SEqualReduce(const OpNode *other, SEqualReducer equal) const
Definition: op.h:103
Map<Op,ValueType> used to store meta-information about Op.
Definition: op.h:44
OpRegEntry & set_attrs_type()
Set the attrs type key and index to be AttrsType.
Definition: op.h:457
OpRegEntry & set_attr(const std::string &attr_name, const ValueType &value, int plevel=10)
Register additional attributes to operator.
Definition: op.h:475
Definition: type.h:201
static Registry & Register(const std::string &name, bool override=false)
Register a function with given name.
String attrs_type_key
The type key of the attribute field. This can be empty, in which case it defaults to anything...
Definition: op.h:75
int32_t num_inputs
number of input arguments to the operator; -1 means it is variable length
Definition: op.h:85
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:865
Base node of all non-primitive expressions.
Definition: expr.h:145
OpRegEntry & add_type_rel(const std::string &rel_name, runtime::TypedPackedFunc< bool(const Array< Type > &, int, const Attrs &, const TypeReporter &)> type_rel_func)
Attach the type function corresponding to the return type.
Definition: op.h:397
Op(ObjectPtr< Object > n)
constructor from node pointer
Definition: op.h:170
static uint32_t TypeKey2Index(const std::string &key)
Get the type index of the corresponding key from runtime.
tvm::TypeRelation TypeRelation
Definition: type.h:67
This file defines the TVM global function registry.