tvm
op.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef TVM_IR_OP_H_
26 #define TVM_IR_OP_H_
27 
28 #include <dmlc/registry.h>
29 #include <tvm/ir/attrs.h>
30 #include <tvm/ir/expr.h>
31 #include <tvm/ir/type.h>
32 #include <tvm/ir/type_relation.h>
34 #include <tvm/runtime/registry.h>
35 
36 #include <string>
37 #include <utility>
38 #include <vector>
39 
40 namespace tvm {
41 
42 // forward declare name.
43 template <typename>
44 class OpAttrMap;
45 
46 // TODO(tvm-team): migrate low-level intrinsics to use Op
// Node describing a primitive operator (builtin intrinsic); its type key is "Op".
//
// NOTE(review): this text is a rendered source listing, so lines that carried
// only documentation or cross-reference links are missing. Per the member
// index at the end of the page, the declarations of `name` (op.h:61),
// `description` (op.h:68), `arguments` (op.h:70), `attrs_type_key` (op.h:75)
// and the `void VisitAttrs(AttrVisitor* v)` header (op.h:93) belong to this
// class but are not shown below.
58 class OpNode : public RelayExprNode {
59  public:
// The type of the operator.
63  mutable FuncType op_type;
69  /* \brief Information of input arguments to the operator */
// Attribute type index. Varies in each run and is not exposed to the frontend
// (it is not among the fields visited below).
80  uint32_t attrs_type_index{0};
// Number of input arguments to the operator; -1 means it is variable length.
85  int32_t num_inputs = -1;
// Support level of the operator; the lower the value, the higher the priority.
91  int32_t support_level = 10;
92 
// Body of VisitAttrs (its header line is elided in this listing): visits each
// reflected field under the string key given as the first argument.
94  v->Visit("name", &name);
95  v->Visit("op_type", &op_type);
96  v->Visit("description", &description);
97  v->Visit("arguments", &arguments);
98  v->Visit("attrs_type_key", &attrs_type_key);
99  v->Visit("num_inputs", &num_inputs);
100  v->Visit("support_level", &support_level);
101  }
102 
// Structural equality for ops is pointer identity; `equal` is unused.
103  bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
104  // pointer equality is fine as there is only one op with the same name.
105  return this == other;
106  }
107 
// Structural hash for ops is derived from the name only.
108  void SHashReduce(SHashReducer hash_reduce) const {
109  // Name uniquely identifies an Op.
110  hash_reduce(name);
111  }
112 
// Checks whether this op is a "primitive operator". The answer is computed by
// IsPrimitiveOp_() on the first call and cached in is_primitive_
// (-1 = not computed yet, otherwise 0 or 1).
118  bool IsPrimitiveOp() const {
119  if (is_primitive_ != -1) return is_primitive_ != 0;
120  is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
121  return is_primitive_ != 0;
122  }
123 
124  static constexpr const char* _type_key = "Op";
// NOTE(review): line 125 is elided here; the page index lists
// TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode) for this class,
// presumably at that position -- confirm against the real header.
126 
127  private:
// Accessors for the registry index and registry name of this op.
129  uint32_t AttrRegistryIndex() const { return index_; }
131  std::string AttrRegistryName() const { return name; }
132 
133  // friend class
// NOTE(review): line 135 (the declaration this first template header applies
// to) is elided in the listing.
134  template <typename>
136  template <typename, typename>
137  friend class AttrRegistry;
138  friend class OpRegEntry;
139 
140  friend bool IsPrimitiveOp(const RelayExpr&);
141  // Program internal unique index of operator.
142  // Used to help index the program.
143  uint32_t index_{0};
144  // whether this is a primitive op. -1 means unknown.
145  mutable int is_primitive_{-1};
146  // Internal function to compute if it is primitive op
// Returns true only when op_type has exactly one type constraint, that
// constraint is a TypeRelationNode, and every type parameter is the same
// object as the relation argument at the same position. Fails an ICHECK when
// op_type has not been set.
147  bool IsPrimitiveOp_() const {
148  const auto& fn_ty = this->op_type;
149  ICHECK(fn_ty.get() != nullptr) << "op_type of " << this->name << " is not registered";
150  if (fn_ty->type_constraints.size() != 1) return false;
151  const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
152  if (rel == nullptr) return false;
153  // validate if the type parameter matches up
// NOTE(review): rel->args is indexed by the type_params index without an
// explicit size comparison here; confirm rel->args always has at least
// type_params.size() entries or that the container checks bounds itself.
154  for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
155  if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
156  }
157  return true;
158  }
159 };
160 
// Managed reference class to OpNode.
165 class Op : public RelayExpr {
166  public:
// Get additional registered attribute about operators, as a typed OpAttrMap
// looked up by attribute name. Defined later in this header.
174  template <typename ValueType>
175  inline static OpAttrMap<ValueType> GetAttrMap(const String& attr_name);
// Checks if an attr map is present in the registry.
181  TVM_DLL static bool HasAttrMap(const String& attr_name);
// Get an Op for a given operator name. Will raise an error if the op has not
// been registered.
188  TVM_DLL static const Op& Get(const String& op_name);
189 
// NOTE(review): line 190 is elided in this listing; presumably it holds the
// object-reference boilerplate macro for Op -- confirm against the real header.
191 
192  private:
// Returns the untyped attribute-map container for `key`; the GetAttrMap
// definition wraps its result in an OpAttrMap<ValueType>.
198  TVM_DLL static const AttrRegistryMapContainerMap<Op>& GetAttrMapContainer(const String& key);
199 };
200 
// Helper structure to register operators. Every setter returns *this so that
// calls can be chained after TVM_REGISTER_OP / RegisterOrGet.
205 class OpRegEntry {
206  public:
// Returns the Op reference held by this entry.
208  const Op& op() const { return op_; }
// Setter used during registration: set the description of the operator.
215  inline OpRegEntry& describe(const std::string& descr); // NOLINT(*)
// Add argument information (name, type string, description) to the operator.
223  inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
224  const std::string& description);
// Attach the type function corresponding to the return type.
232  inline OpRegEntry& add_type_rel(
233  const std::string& rel_name,
234  runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
235  type_rel_func);
// Set the attrs type key and index to be those of AttrsType.
241  template <typename AttrsType>
242  inline OpRegEntry& set_attrs_type();
// Set the attrs type key from a string key.
248  inline OpRegEntry& set_attrs_type_key(const String& key);
// Set the num_inputs.
254  inline OpRegEntry& set_num_inputs(int32_t n); // NOLINT(*)
// Set the support level of op.
260  inline OpRegEntry& set_support_level(int32_t level); // NOLINT(*)
// Register additional attributes to operator. plevel defaults to 10; the
// definition checks that it is greater than 0.
274  template <typename ValueType>
275  inline OpRegEntry& set_attr(const std::string& attr_name, // NOLINT(*)
276  const ValueType& value, int plevel = 10);
277 
// Resets an attr of the registry.
282  inline void reset_attr(const std::string& attr_name);
283 
284  // set the name of the op to be the same as registry
// Only assigns when the op's name is still empty, so an existing name is kept.
285  inline OpRegEntry& set_name() { // NOLINT(*)
286  if (get()->name.length() == 0) {
287  get()->name = name;
288  }
289  return *this;
290  }
// Register or get a new entry.
296  TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
297 
298  private:
299  template <typename, typename>
300  friend class AttrRegistry;
301  // the name
302  std::string name;
// The operator reference that the setters above modify through get().
304  Op op_;
305  // private constructor
306  TVM_DLL OpRegEntry(uint32_t reg_index);
307  // return internal pointer to op.
308  inline OpNode* get();
309  // update the attribute OpAttrMap
310  TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
311 };
312 
// Map<Op, ValueType> used to store meta-information about Op.
317 template <typename ValueType>
318 class OpAttrMap : public AttrRegistryMap<Op, ValueType> {
319  public:
// Get the value registered for the op behind `expr`, or def_value when there
// is none (see the definition later in this header).
327  inline ValueType get(const RelayExpr& expr, ValueType def_value) const;
328 
// NOTE(review): line 329 is elided in this listing; it presumably declares the
// TParent alias for the base class used below -- confirm against the real header.
// Re-export the base-class overloads that the get() declared above would
// otherwise hide.
330  using TParent::count;
331  using TParent::get;
332  using TParent::operator[];
333 
334  private:
335  friend class Op;
336  // constructor
337  explicit OpAttrMap(const AttrRegistryMapContainerMap<Op>& map) : TParent(map) {}
338 };
339 
340 // Internal macro: declaration prefix of the static OpRegEntry& variable that TVM_REGISTER_OP defines.
341 #define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op
342 
// Registers (or fetches) the operator entry named OpName and binds it to a
// static reference whose name is built from TVM_OP_REGISTER_VAR_DEF and
// __COUNTER__ via TVM_STR_CONCAT. The expansion ends with `.set_name()`, which
// returns the entry, so further setters can be chained after the macro call.
358 #define TVM_REGISTER_OP(OpName) \
359  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
360  ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
361 
362 // implementations
363 
// Definition of Op::GetAttrMap: wraps the container returned by
// Op::GetAttrMapContainer(key) in a typed OpAttrMap<ValueType>.
// NOTE(review): the signature line (op.h:365) is elided in this listing; the
// body refers to the parameter as `key`.
364 template <typename ValueType>
366  return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key));
367 }
368 
369 inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); }
370 
371 inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*)
372  get()->description = descr;
373  return *this;
374 }
375 
376 inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
377  const std::string& description) {
378  auto n = make_object<AttrFieldInfoNode>();
379  n->name = name;
380  n->type_info = type;
381  n->description = description;
382  get()->arguments.push_back(AttrFieldInfo(n));
383  return *this;
384 }
385 
// Definition of OpRegEntry::add_type_rel: attach the type function
// corresponding to the return type.
// NOTE(review): the first line of the signature (op.h:386) is elided in this
// listing, so the text below starts at the parameter list.
//
// Resolves the global function "tvm.relay.type_relation.<rel_name>" (registering
// type_rel_func under that name only if no such function exists yet), builds a
// FuncType with one TypeVar per input plus an "out" TypeVar constrained by a
// single TypeRelation, stores it in the op's op_type, and returns *this.
387  const std::string& rel_name,
388  runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
389  type_rel_func) {
390  auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
391  TypeRelationFn env_type_rel_func;
392 
// Reuse an already-registered global function of this name; in that case
// type_rel_func is not registered.
393  if (runtime::Registry::Get(func_name)) {
394  auto env_func = EnvFunc::Get(func_name);
395  env_type_rel_func = env_func;
396  } else {
397  runtime::Registry::Register(func_name).set_body(type_rel_func.packed());
398  auto env_func = EnvFunc::Get(func_name);
399  env_type_rel_func = env_func;
400  }
401 
402  Array<TypeVar> type_params;
403  Array<Type> arg_types;
404 
405  // Add inputs.
// One type variable "in<i>" per input. num_inputs is read at call time, so it
// must already be set for inputs to be added; with the default of -1 the loop
// body never runs.
406  std::string input_name_prefix = "in";
407  for (int i = 0; i < get()->num_inputs; i++) {
408  auto name = input_name_prefix + std::to_string(i);
409  auto param = TypeVar(name, TypeKind::kType);
410  type_params.push_back(param);
411  arg_types.push_back(param);
412  }
413 
// The relation's argument list is the input types followed by the output type.
414  Array<Type> ty_call_args = arg_types;
415 
416  // Add output type.
417  auto out_param = TypeVar("out", TypeKind::kType);
418  type_params.push_back(out_param);
419  // this will trigger copy on write.
420  ty_call_args.push_back(out_param);
421 
422  // The attributes of primitive op is nullptr
423  //
424  // The attributes of primitive operator can vary at the call site.
425  // The type of sum is also dependent on Attrs being passed.
426  // So putting nullptr in the Attrs means that the operator is polymorphic on Attrs.
427  //
428  // A common example is sum(x, axis), where the choice of axis
429  // can affect the type of the function.
// arg_types.size() is passed as the number of inputs of the relation.
430  TypeConstraint type_rel =
431  TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs());
432 
433  auto func_type = FuncType(arg_types, out_param, type_params, {type_rel});
434 
435  get()->op_type = func_type;
436 
437  return *this;
438 }
439 
440 inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*)
441  get()->num_inputs = n;
442  return *this;
443 }
444 
445 template <typename AttrsType>
446 inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*)
447  get()->attrs_type_key = AttrsType::_type_key;
448  get()->attrs_type_index = AttrsType::RuntimeTypeIndex();
449  return *this;
450 }
451 
// Sets the op's attrs type key from the string `key` and returns this entry
// for chaining.
// NOTE(review): line 454 is elided in this listing; it presumably also updates
// attrs_type_index from the key -- confirm against the real header.
452 inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*)
453  get()->attrs_type_key = key;
455  return *this;
456 }
457 
458 inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*)
459  get()->support_level = n;
460  return *this;
461 }
462 
// Register an additional attribute on the operator: `value` is assigned to
// `rv` and forwarded to UpdateAttr together with attr_name and plevel.
// Fails an ICHECK unless plevel is greater than 0. Returns *this for chaining.
// NOTE(review): line 467, where `rv` is declared, is elided in this listing.
463 template <typename ValueType>
464 inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*)
465  const std::string& attr_name, const ValueType& value, int plevel) {
466  ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
468  rv = value;
469  UpdateAttr(attr_name, rv, plevel);
470  return *this;
471 }
472 
473 // member functions of OpAttrMap
474 
475 template <typename ValueType>
476 inline ValueType OpAttrMap<ValueType>::get(const RelayExpr& expr, ValueType def_value) const {
477  ICHECK(expr.defined());
478  if (const OpNode* op = expr.as<OpNode>()) {
479  return this->map_.get(GetRef<Op>(op), def_value);
480  } else {
481  return def_value;
482  }
483 }
484 
498 inline bool IsPrimitiveOp(const RelayExpr& expr) {
499  const auto* op = expr.as<OpNode>();
500  return op != nullptr && op->IsPrimitiveOp();
501 }
502 
503 } // namespace tvm
504 #endif // TVM_IR_OP_H_
Attribute map used in registry.
AttrFieldInfo.
Definition: attrs.h:128
Generic attribute map.
Definition: attr_registry_map.h:38
Map<Key, ValueType> used to store meta-data.
Definition: attr_registry_map.h:101
ValueType get(const Op &key, ValueType def_value) const
get the corresponding value element at key with default value.
Definition: attr_registry_map.h:126
int count(const Op &key) const
Check if the map has op as key.
Definition: attr_registry_map.h:113
Definition: executor.h:43
Visitor class to get the attributes of an AST/IR node. The content is going to be called for each fie...
Definition: reflection.h:52
Managed reference to BaseAttrsNode.
Definition: attrs.h:190
static EnvFunc Get(const String &name)
Get a global function based on the name.
Managed reference to FuncTypeNode.
Definition: type.h:481
Map<Op,ValueType> used to store meta-information about Op.
Definition: op.h:318
ValueType get(const RelayExpr &expr, ValueType def_value) const
get the corresponding value element at op with default value.
Definition: op.h:476
Primitive Op(builtin intrinsics)
Definition: op.h:58
void SHashReduce(SHashReducer hash_reduce) const
Definition: op.h:108
static constexpr const char * _type_key
Definition: op.h:124
bool IsPrimitiveOp() const
Check whether the current op is a "primitive operator", that is, its arguments are all type variables,...
Definition: op.h:118
String name
name of the operator
Definition: op.h:61
uint32_t attrs_type_index
attribute type index, this field varies in each run and is not exposed to frontend.
Definition: op.h:80
String attrs_type_key
The type key of the attribute field. This can be empty, in which case it defaults to anything.
Definition: op.h:75
int32_t support_level
Support level of the operator; the lower the level, the higher the priority. This is analogous to BLAS ...
Definition: op.h:91
void VisitAttrs(AttrVisitor *v)
Definition: op.h:93
int32_t num_inputs
number of input arguments to the operator, -1 means it is variable length
Definition: op.h:85
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode)
FuncType op_type
the type of the operator
Definition: op.h:63
String description
Detailed description of the operator. This can be used to generate a docstring automatically for the ope...
Definition: op.h:68
Array< AttrFieldInfo > arguments
Definition: op.h:70
bool SEqualReduce(const OpNode *other, SEqualReducer equal) const
Definition: op.h:103
Helper structure to register operators.
Definition: op.h:205
OpRegEntry & set_attrs_type_key(const String &key)
Set the attrs type key and index to be AttrsType.
Definition: op.h:452
OpRegEntry & describe(const std::string &descr)
Setter function used during registration. Sets the description of the operator.
Definition: op.h:371
static OpRegEntry & RegisterOrGet(const String &name)
Register or get a new entry.
OpRegEntry & add_type_rel(const std::string &rel_name, runtime::TypedPackedFunc< bool(const Array< Type > &, int, const Attrs &, const TypeReporter &)> type_rel_func)
Attach the type function corresponding to the return type.
Definition: op.h:386
OpRegEntry & set_name()
Definition: op.h:285
void reset_attr(const std::string &attr_name)
Resets an attr of the registry.
OpRegEntry & add_argument(const std::string &name, const std::string &type, const std::string &description)
Add argument information to the function.
Definition: op.h:376
OpRegEntry & set_attrs_type()
Set the attrs type key and index to be AttrsType.
Definition: op.h:446
OpRegEntry & set_support_level(int32_t level)
Set the support level of op.
Definition: op.h:458
const Op & op() const
Definition: op.h:208
OpRegEntry & set_attr(const std::string &attr_name, const ValueType &value, int plevel=10)
Register additional attributes to operator.
Definition: op.h:464
OpRegEntry & set_num_inputs(int32_t n)
Set the num_inputs.
Definition: op.h:440
Managed reference class to OpNode.
Definition: op.h:165
static OpAttrMap< ValueType > GetAttrMap(const String &attr_name)
Get additional registered attribute about operators. If nothing has been registered,...
Definition: op.h:365
static bool HasAttrMap(const String &attr_name)
Checks if an attr map is present in the registry.
static const Op & Get(const String &op_name)
Get an Op for a given operator name. Will raise an error if the op has not been registered.
Base node of all non-primitive expressions.
Definition: expr.h:362
Managed reference to RelayExprNode.
Definition: expr.h:442
TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode)
A Reducer class to reduce the structural equality result of two objects.
Definition: structural_equal.h:137
A Reducer class to reduce the structural hash value.
Definition: structural_hash.h:121
Managed reference to TypeConstraintNode.
Definition: type.h:423
Container class of TypeReporter.
Definition: type_relation.h:145
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
void push_back(const T &item)
push a new item to the back of the list
Definition: array.h:457
size_t size() const
Definition: array.h:420
bool defined() const
Definition: object.h:552
const ObjectType * as() const
Try to downcast the internal Object to a raw pointer of a corresponding type.
Definition: object.h:910
static uint32_t TypeKey2Index(const std::string &key)
Get the type index of the corresponding key from runtime.
static Registry & Register(const String &name, bool override=false)
Register a function with given name.
static const PackedFunc * Get(const String &name)
Get the global function by name.
Registry & set_body(PackedFunc f)
set the body of the function to be f
Reference to string objects.
Definition: string.h:98
size_t length() const
Return the length of the string.
Definition: string.h:201
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlyi...
Definition: packed_func.h:946
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:63
Helpers for attribute objects.
Base expr nodes in TVM.
IR/AST nodes for the unified type system in TVM.
tvm::TypeVar TypeVar
Definition: type.h:49
tvm::TypeRelationNode TypeRelationNode
Definition: type.h:68
tvm::FuncType FuncType
Definition: type.h:57
tvm::TypeRelation TypeRelation
Definition: type.h:67
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr equal(PrimExpr a, PrimExpr b, Span span=Span())
equal
bool IsPrimitiveOp(const RelayExpr &expr)
Check that an expression is a "primitive operator".
Definition: op.h:498
@ kType
Definition: type.h:202
This file defines the TVM global function registry.
Type relation and function for type inference(checking).