tvm
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
packed_func.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RUNTIME_PACKED_FUNC_H_
25 #define TVM_RUNTIME_PACKED_FUNC_H_
26 
31 #include <tvm/runtime/data_type.h>
32 #include <tvm/runtime/logging.h>
33 #include <tvm/runtime/module.h>
34 #include <tvm/runtime/ndarray.h>
35 #include <tvm/runtime/object.h>
36 
37 #include <functional>
38 #include <limits>
39 #include <memory>
40 #include <string>
41 #include <tuple>
42 #include <type_traits>
43 #include <utility>
44 #include <vector>
45 
46 // Whether use TVM runtime in header only mode.
47 #ifndef TVM_RUNTIME_HEADER_ONLY
48 #define TVM_RUNTIME_HEADER_ONLY 0
49 #endif
50 
51 namespace tvm {
52 namespace runtime {
53 
54 // forward declarations
55 class TVMArgs;
56 class TVMArgValue;
57 class TVMMovableArgValueWithContext_;
58 class TVMRetValue;
59 class TVMArgsSetter;
60 template <typename FType>
62 template <typename TSignature>
64 
69 class PackedFuncObj : public Object {
70  public:
76  TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const;
77 
78  static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc;
79  static constexpr const char* _type_key = "runtime.PackedFunc";
81 
82  protected:
86  template <class TPackedFuncSubObj>
87  struct Extractor {
94  static void Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv);
95  };
96 
98  using FCallPacked = void(const PackedFuncObj*, TVMArgs, TVMRetValue*);
99 
104  explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {}
105 
107  PackedFuncObj() = delete;
108 
111 };
112 
114 template <class TCallable>
116  using TStorage = typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type;
117 
118  public:
125  explicit PackedFuncSubObj(TCallable callable)
126  : PackedFuncObj(Extractor<TSelf>::Call), callable_(callable) {}
128  mutable TStorage callable_;
129 };
130 
139 class PackedFunc : public ObjectRef {
140  public:
142  PackedFunc(std::nullptr_t null) : ObjectRef(nullptr) {} // NOLINT(*)
148  template <typename TCallable,
149  typename = std::enable_if_t<
150  std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
151  !std::is_base_of<TCallable, PackedFunc>::value>>
152  explicit PackedFunc(TCallable data) {
153  using ObjType = PackedFuncSubObj<TCallable>;
154  data_ = make_object<ObjType>(std::forward<TCallable>(data));
155  }
170  template <typename... Args>
171  inline TVMRetValue operator()(Args&&... args) const;
177  TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const;
179  bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
181  bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }
182 
184 };
185 
// Signature of a callback that returns a printable description of a function's
// signature. TVMMovableArgValueWithContext_ appends the result of (*f_sig_)() to
// its argument-conversion error messages (a null FSig* contributes nothing).
using FSig = std::string();
188 
192 template <typename FType>
193 class TypedPackedFunc;
194 
227 template <typename R, typename... Args>
228 class TypedPackedFunc<R(Args...)> {
229  public:
231  using TSelf = TypedPackedFunc<R(Args...)>;
235  TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
253  inline TypedPackedFunc(PackedFunc packed); // NOLINT(*)
258  inline TypedPackedFunc(const TVMRetValue& value); // NOLINT(*)
263  inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*)
268  inline TypedPackedFunc(TVMMovableArgValueWithContext_&& value); // NOLINT(*)
285  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
286  FLambda, std::function<R(Args...)>>::value>::type>
287  TypedPackedFunc(const FLambda& typed_lambda, std::string name) { // NOLINT(*)
288  this->AssignTypedLambda(typed_lambda, name);
289  }
308  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
309  FLambda, std::function<R(Args...)>>::value>::type>
310  TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
311  this->AssignTypedLambda(typed_lambda);
312  }
329  template <typename FLambda, typename = typename std::enable_if<
330  std::is_convertible<FLambda,
331  std::function<R(Args...)>>::value>::type>
332  TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
333  this->AssignTypedLambda(typed_lambda);
334  return *this;
335  }
342  packed_ = packed;
343  return *this;
344  }
350  TVM_ALWAYS_INLINE R operator()(Args... args) const;
355  operator PackedFunc() const { return packed(); }
359  const PackedFunc& packed() const { return packed_; }
361  bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
363  bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; }
364 
365  private:
366  friend class TVMRetValue;
368  PackedFunc packed_;
377  template <typename FLambda>
378  inline void AssignTypedLambda(FLambda flambda, std::string name);
387  template <typename FLambda>
388  inline void AssignTypedLambda(FLambda flambda);
389 };
390 
392 class TVMArgs {
393  public:
394  const TVMValue* values;
395  const int* type_codes;
396  int num_args;
403  TVMArgs(const TVMValue* values, const int* type_codes, int num_args)
406  inline int size() const;
412  inline TVMArgValue operator[](int i) const;
413 };
414 
420 inline const char* ArgTypeCode2Str(int type_code);
421 
422 inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*)
423 
// Check that a runtime type code CODE equals the expected type code T.
// On mismatch the ICHECK_EQ failure message names both codes via
// ArgTypeCode2Str, e.g. "expected float but got str". Because the expansion
// ends in a stream expression, callers may append more context with `<<`.
// Note: CODE and T each appear twice in the expansion, so pass expressions
// without side effects.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
  ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE)
427 
432 template <typename T>
442  using ContainerType = typename T::ContainerType;
443  if (ptr == nullptr) {
444  if (T::_type_is_nullable) {
445  return NullOpt;
446  } else {
447  return String("nullptr");
448  }
449  }
450  if (ptr->IsInstance<ContainerType>()) {
451  return NullOpt;
452  } else {
453  return String(ptr->GetTypeKey());
454  }
455  }
461  static bool Check(const Object* ptr) {
462  using ContainerType = typename T::ContainerType;
463  if (ptr == nullptr) return T::_type_is_nullable;
464  return ptr->IsInstance<ContainerType>();
465  }
466  static std::string TypeName() {
467  using ContainerType = typename T::ContainerType;
468  return ContainerType::_type_key;
469  }
470 };
471 
472 // Additional overloads for PackedFunc checking.
473 template <typename T>
476  if (ptr == nullptr) {
477  return NullOpt;
478  }
479  if (!ptr->IsInstance<ArrayNode>()) {
480  return String(ptr->GetTypeKey());
481  }
482  const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
483  for (size_t i = 0; i < n->size(); i++) {
484  const ObjectRef& p = (*n)[i];
486  if (check_subtype.defined()) {
487  return String("Array[index " + std::to_string(i) + ": " + check_subtype.value() + "]");
488  }
489  }
490  return NullOpt;
491  }
492  static bool Check(const Object* ptr) {
493  if (ptr == nullptr) return true;
494  if (!ptr->IsInstance<ArrayNode>()) return false;
495  const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
496  for (const ObjectRef& p : *n) {
497  if (!ObjectTypeChecker<T>::Check(p.get())) {
498  return false;
499  }
500  }
501  return true;
502  }
503  static std::string TypeName() { return "Array[" + ObjectTypeChecker<T>::TypeName() + "]"; }
504 };
505 template <typename K, typename V>
506 struct ObjectTypeChecker<Map<K, V>> {
508  if (ptr == nullptr) return NullOpt;
509  if (!ptr->IsInstance<MapNode>()) return String(ptr->GetTypeKey());
510  const MapNode* n = static_cast<const MapNode*>(ptr);
511  for (const auto& kv : *n) {
513  Optional<String> value_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
514  if (key_type.defined() || value_type.defined()) {
515  std::string key_name =
516  key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker<K>::TypeName();
517  std::string value_name = value_type.defined() ? std::string(value_type.value())
519  return String("Map[" + key_name + ", " + value_name + "]");
520  }
521  }
522  return NullOpt;
523  }
524  static bool Check(const Object* ptr) {
525  if (ptr == nullptr) return true;
526  if (!ptr->IsInstance<MapNode>()) return false;
527  const MapNode* n = static_cast<const MapNode*>(ptr);
528  for (const auto& kv : *n) {
529  if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
530  if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
531  }
532  return true;
533  }
534  static std::string TypeName() {
536  ']';
537  }
538 };
539 
545  public:
546  operator double() const {
547  // Allow automatic conversion from int to float
548  // This avoids errors when user pass in int from
549  // the frontend while the API expects a float.
550  if (type_code_ == kDLInt) {
551  return static_cast<double>(value_.v_int64);
552  }
553  TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
554  return value_.v_float64;
555  }
556  operator int64_t() const {
558  return value_.v_int64;
559  }
560  operator uint64_t() const {
562  return value_.v_int64;
563  }
564  operator int() const {
568  return static_cast<int>(value_.v_int64);
569  }
570  operator bool() const {
572  return value_.v_int64 != 0;
573  }
574  operator void*() const {
575  if (type_code_ == kTVMNullptr) return nullptr;
578  return value_.v_handle;
579  }
580  operator DLTensor*() const {
582  return static_cast<DLTensor*>(value_.v_handle);
583  } else {
584  if (type_code_ == kTVMNullptr) return nullptr;
585  LOG(FATAL) << "Expected "
586  << "DLTensor* or NDArray but got " << ArgTypeCode2Str(type_code_);
587  return nullptr;
588  }
589  }
590  operator NDArray() const {
591  if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr<Object>(nullptr));
594  }
595  operator Module() const {
596  if (type_code_ == kTVMNullptr) {
597  return Module(ObjectPtr<Object>(nullptr));
598  }
600  return Module(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
601  }
602  operator PackedFunc() const {
603  if (type_code_ == kTVMNullptr) {
604  return PackedFunc(ObjectPtr<Object>(nullptr));
605  }
607  return PackedFunc(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
608  }
609  operator Device() const {
611  return value_.v_device;
612  }
613  int type_code() const { return type_code_; }
619  template <typename T>
620  T* ptr() const {
621  return static_cast<T*>(value_.v_handle);
622  }
623  // ObjectRef handling
624  template <typename TObjectRef,
625  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
626  inline bool IsObjectRef() const;
627  template <typename TObjectRef>
628  inline TObjectRef AsObjectRef() const;
629 
630  protected:
631  friend class TVMArgsSetter;
632  friend class TVMRetValue;
633  friend class TVMMovableArgValue_;
636 
641 };
642 
649 class TVMArgValue : public TVMPODValue_ {
650  public:
659  // reuse converter from parent
660  using TVMPODValue_::operator double;
661  using TVMPODValue_::operator int64_t;
662  using TVMPODValue_::operator uint64_t;
663  using TVMPODValue_::operator int;
664  using TVMPODValue_::operator bool;
665  using TVMPODValue_::operator void*;
666  using TVMPODValue_::operator DLTensor*;
667  using TVMPODValue_::operator NDArray;
668  using TVMPODValue_::operator Device;
669  using TVMPODValue_::operator Module;
670  using TVMPODValue_::operator PackedFunc;
673 
674  // conversion operator.
675  operator std::string() const {
676  if (type_code_ == kTVMDataType) {
677  return DLDataType2String(operator DLDataType());
678  } else if (type_code_ == kTVMBytes) {
679  TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
680  return std::string(arr->data, arr->size);
681  } else if (type_code_ == kTVMStr) {
682  return std::string(value_.v_str);
683  } else {
684  return AsObjectRef<tvm::runtime::String>().operator std::string();
685  }
686  }
687  template <typename FType>
688  operator TypedPackedFunc<FType>() const {
689  return TypedPackedFunc<FType>(operator PackedFunc());
690  }
691  const TVMValue& value() const { return value_; }
692 
693  template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
694  inline operator T() const;
695  inline operator DLDataType() const;
696  inline operator DataType() const;
697 };
698 
710  public:
712  // reuse converter from parent
713  using TVMPODValue_::operator double;
714  using TVMPODValue_::operator int64_t;
715  using TVMPODValue_::operator uint64_t;
716  using TVMPODValue_::operator int;
717  using TVMPODValue_::operator bool;
718  using TVMPODValue_::operator void*;
719  using TVMPODValue_::operator DLTensor*;
720  using TVMPODValue_::operator NDArray;
721  using TVMPODValue_::operator Device;
722  using TVMPODValue_::operator Module;
723  using TVMPODValue_::operator PackedFunc;
724  // reuse conversion rule from ArgValue.
725  operator std::string() const { return AsArgValue().operator std::string(); }
726  template <typename FType>
727  operator TypedPackedFunc<FType>() const {
728  return TypedPackedFunc<FType>(operator PackedFunc());
729  }
730  operator DLDataType() const { return AsArgValue().operator DLDataType(); }
731  operator DataType() const { return AsArgValue().operator DataType(); }
732  operator TVMArgValue() const { return AsArgValue(); }
738  template <typename T,
739  typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
740  inline operator T() const;
741 
742  private:
744  TVMArgValue AsArgValue() const { return TVMArgValue(value_, type_code_); }
745 };
746 
755  public:
765  TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index,
766  const std::string* optional_name, FSig* f_sig)
767  : value_(value, type_code),
768  arg_index_(arg_index),
769  optional_name_(optional_name),
770  f_sig_(f_sig) {}
771 
772  template <typename T>
773  operator T() const {
774  try {
775  return value_; // implicit conversion happens here
776  } catch (dmlc::Error& e) {
777  LOG(FATAL) << "In function " << (optional_name_ == nullptr ? "<anonymous>" : *optional_name_)
778  << (f_sig_ == nullptr ? "" : (*f_sig_)()) << ": error while converting argument "
779  << arg_index_ << ": " << e.what();
780  throw; // never reached, LOG(FATAL) throws, but this silences a warning.
781  }
782  }
783 
784  private:
785  TVMMovableArgValue_ value_;
786  int arg_index_;
787  const std::string* optional_name_;
788  FSig* f_sig_;
789 };
790 
799 class TVMRetValue : public TVMPODValue_ {
800  public:
808  other.value_.v_handle = nullptr;
809  other.type_code_ = kTVMNullptr;
810  }
812  ~TVMRetValue() { this->Clear(); }
813  // reuse converter from parent
814  using TVMPODValue_::operator double;
815  using TVMPODValue_::operator int64_t;
816  using TVMPODValue_::operator uint64_t;
817  using TVMPODValue_::operator int;
818  using TVMPODValue_::operator bool;
819  using TVMPODValue_::operator void*;
820  using TVMPODValue_::operator DLTensor*;
821  using TVMPODValue_::operator Device;
822  using TVMPODValue_::operator NDArray;
823  using TVMPODValue_::operator Module;
824  using TVMPODValue_::operator PackedFunc;
827 
828  TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); }
829  // conversion operators
830  operator std::string() const {
831  if (type_code_ == kTVMDataType) {
832  return DLDataType2String(operator DLDataType());
833  } else if (type_code_ == kTVMBytes) {
834  return *ptr<std::string>();
835  }
837  return *ptr<std::string>();
838  }
839  operator DLDataType() const {
840  if (type_code_ == kTVMStr) {
841  return String2DLDataType(operator std::string());
842  }
844  return value_.v_type;
845  }
846  operator DataType() const { return DataType(operator DLDataType()); }
847  template <typename FType>
848  operator TypedPackedFunc<FType>() const {
849  return TypedPackedFunc<FType>(operator PackedFunc());
850  }
851  // Assign operators
853  this->Clear();
854  value_ = other.value_;
855  type_code_ = other.type_code_;
856  other.type_code_ = kTVMNullptr;
857  return *this;
858  }
860  this->SwitchToPOD(kDLFloat);
862  return *this;
863  }
864  TVMRetValue& operator=(std::nullptr_t value) {
865  this->SwitchToPOD(kTVMNullptr);
867  return *this;
868  }
870  this->SwitchToPOD(kTVMOpaqueHandle);
872  return *this;
873  }
875  this->SwitchToPOD(kDLInt);
876  value_.v_int64 = value;
877  return *this;
878  }
880  this->SwitchToPOD(kDLInt);
881  value_.v_int64 = value;
882  return *this;
883  }
885  this->SwitchToPOD(kDLDevice);
887  return *this;
888  }
889  TVMRetValue& operator=(DLDataType t) {
890  this->SwitchToPOD(kTVMDataType);
891  value_.v_type = t;
892  return *this;
893  }
894  TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); }
896  this->SwitchToPOD(kDLInt);
897  value_.v_int64 = value;
898  return *this;
899  }
900  TVMRetValue& operator=(std::string value) {
901  this->SwitchToClass(kTVMStr, value);
902  return *this;
903  }
905  this->SwitchToClass(kTVMBytes, std::string(value.data, value.size));
906  return *this;
907  }
909  if (other.data_ != nullptr) {
910  this->Clear();
914  } else {
915  SwitchToPOD(kTVMNullptr);
916  value_.v_handle = nullptr;
917  }
918  return *this;
919  }
921  SwitchToObject(kTVMModuleHandle, std::move(m.data_));
922  return *this;
923  }
925  this->SwitchToObject(kTVMPackedFuncHandle, std::move(f.data_));
926  return *this;
927  }
928  template <typename FType>
930  return operator=(f.packed());
931  }
932  TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0
933  this->Assign(other);
934  return *this;
935  }
937  this->Assign(other);
938  return *this;
939  }
941  this->Assign(other);
942  return *this;
943  }
953  void MoveToCHost(TVMValue* ret_value, int* ret_type_code) {
954  // cannot move str; need specially handle.
955  ICHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes);
956  *ret_value = value_;
957  *ret_type_code = type_code_;
959  }
968  // Can move POD and everything under the object system.
971  ret.value_ = value;
972  ret.type_code_ = type_code;
973  return ret;
974  }
976  const TVMValue& value() const {
979  << "TVMRetValue.value can only be used for POD data";
980  return value_;
981  }
982  // ObjectRef handling
983  template <typename TObjectRef,
984  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
985  inline TVMRetValue& operator=(TObjectRef other);
986  template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
987  inline operator T() const;
988 
989  private:
990  template <typename T>
991  void Assign(const T& other) {
992  switch (other.type_code()) {
993  case kTVMStr: {
994  SwitchToClass<std::string>(kTVMStr, other);
995  break;
996  }
997  case kTVMBytes: {
998  SwitchToClass<std::string>(kTVMBytes, other);
999  break;
1000  }
1001  case kTVMPackedFuncHandle: {
1002  *this = other.operator PackedFunc();
1003  break;
1004  }
1005  case kTVMModuleHandle: {
1006  *this = other.operator Module();
1007  break;
1008  }
1009  case kTVMNDArrayHandle: {
1010  *this = other.operator NDArray();
1011  break;
1012  }
1013  case kTVMObjectHandle: {
1014  // Avoid operator ObjectRef as we already know it is not NDArray/Module
1015  SwitchToObject(kTVMObjectHandle,
1016  GetObjectPtr<Object>(static_cast<Object*>(other.value_.v_handle)));
1017  break;
1018  }
1019  case kTVMObjectRValueRefArg: {
1020  operator=(other.operator ObjectRef());
1021  break;
1022  }
1023  default: {
1024  SwitchToPOD(other.type_code());
1025  value_ = other.value_;
1026  break;
1027  }
1028  }
1029  }
1030  // get the internal container.
1031  void SwitchToPOD(int type_code) {
1032  if (type_code_ != type_code) {
1033  this->Clear();
1035  }
1036  }
1037  template <typename T>
1038  void SwitchToClass(int type_code, T v) {
1039  if (type_code_ != type_code) {
1040  this->Clear();
1042  value_.v_handle = new T(v);
1043  } else {
1044  *static_cast<T*>(value_.v_handle) = v;
1045  }
1046  }
1047  void SwitchToObject(int type_code, ObjectPtr<Object> other) {
1048  if (other.data_ != nullptr) {
1049  this->Clear();
1051  // move the handle out
1052  value_.v_handle = other.data_;
1053  other.data_ = nullptr;
1054  } else {
1055  SwitchToPOD(kTVMNullptr);
1056  value_.v_handle = nullptr;
1057  }
1058  }
1059  void Clear() {
1060  if (type_code_ == kTVMNullptr) return;
1061  switch (type_code_) {
1062  case kTVMStr:
1063  case kTVMBytes:
1064  delete ptr<std::string>();
1065  break;
1066  case kTVMPackedFuncHandle:
1067  static_cast<Object*>(value_.v_handle)->DecRef();
1068  break;
1069  case kTVMNDArrayHandle: {
1071  break;
1072  }
1073  case kTVMModuleHandle: {
1074  static_cast<Object*>(value_.v_handle)->DecRef();
1075  break;
1076  }
1077  case kTVMObjectHandle: {
1078  static_cast<Object*>(value_.v_handle)->DecRef();
1079  break;
1080  }
1081  }
1083  }
1084 };
1085 
1095 template <typename TObjectRef>
1102  static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef<TObjectRef>(); }
1108  static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef<TObjectRef>(); }
1109 };
1110 
// Export `Function`, a callable invoked as Function(TVMArgs, TVMRetValue*), as the
// C symbol `ExportName` with the signature
//   int ExportName(TVMValue* args, int* type_code, int num_args,
//                  TVMValue* out_value, int* out_type_code, void* resource_handle).
// Returns 0 on success. On failure it returns -1 and records the message with
// TVMAPISetLastError. Failures include a return value that MoveToCHost rejects
// (str/bytes). Exceptions must never propagate through the extern "C" boundary,
// so exceptions not derived from std::exception are caught as well.
#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function)                                    \
  extern "C" {                                                                              \
  TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
                         int* out_type_code, void* resource_handle);                        \
  int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value,         \
                 int* out_type_code, void* resource_handle) {                               \
    try {                                                                                   \
      ::tvm::runtime::TVMRetValue rv;                                                       \
      Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv);                    \
      rv.MoveToCHost(out_value, out_type_code);                                             \
      return 0;                                                                             \
    } catch (const ::std::exception& _except_) {                                            \
      TVMAPISetLastError(_except_.what());                                                  \
      return -1;                                                                            \
    } catch (...) {                                                                         \
      TVMAPISetLastError("unknown exception");                                              \
      return -1;                                                                            \
    }                                                                                       \
  }                                                                                         \
  }
1147 
// Begin a module "vtable" inside a module class body. It defines:
//   * type_key(), returning TypeKey;
//   * the opening of a GetFunction(_name, _self) override.
// Inside GetFunction, SelfPtr is declared as the type of `this`, for use by the
// TVM_MODULE_VTABLE_ENTRY* macros. Must be closed with TVM_MODULE_VTABLE_END().
#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \
  const char* type_key() const final { return TypeKey; } \
  PackedFunc GetFunction(const String& _name, const ObjectPtr<Object>& _self) final { \
    using SelfPtr = std::remove_cv_t<decltype(this)>;
// Close a vtable opened by TVM_MODULE_VTABLE_BEGIN. When no entry matched
// _name, GetFunction returns a null PackedFunc.
#define TVM_MODULE_VTABLE_END() \
  return PackedFunc(nullptr); \
  }
// Vtable entry binding the name Name to member function MemFunc.
// The returned PackedFunc captures _self by value, so the module object stays
// referenced for as long as the PackedFunc is held. At call time it checks that
// the argument count equals the member function's arity, then forwards each
// argument through ModuleVTableEntryHelper::Call.
#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \
  if (_name == Name) { \
    return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { \
      using Helper = ::tvm::runtime::detail::ModuleVTableEntryHelper<decltype(MemFunc)>; \
      SelfPtr self = static_cast<SelfPtr>(_self.get()); \
      CHECK_EQ(args.size(), Helper::LenArgs) \
          << "Function `" << self->type_key() << "::" << Name << "` requires " << Helper::LenArgs \
          << " arguments, but got " << args.size(); \
      Helper::Call(rv, self, MemFunc, args, Helper::IndexSeq{}); \
    }); \
  }
// Vtable entry binding the name Name to the callable expression Func.
// Func's signature is deduced with detail::function_signature, and Func is
// wrapped in a TypedPackedFunc of that signature.
#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, Func) \
  if (_name == Name) { \
    auto f = (Func); \
    using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \
    return TypedPackedFunc<FType>(std::move(f)).packed(); \
  }
1172 
// Export a typed C++ callable `Function` as the C symbol `ExportName`. The
// signature is deduced with detail::function_signature, and the packed arguments
// are unpacked with detail::unpack_call_by_signature. The C signature is
//   int ExportName(TVMValue* args, int* type_code, int num_args,
//                  TVMValue* out_value, int* out_type_code, void* resource_handle).
// Returns 0 on success. On failure it returns -1 and records the message with
// TVMAPISetLastError. Exceptions must never propagate through the extern "C"
// boundary, so exceptions not derived from std::exception are caught as well.
#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function)                                     \
  extern "C" {                                                                              \
  TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
                         int* out_type_code, void* resource_handle) {                       \
    try {                                                                                   \
      auto f = Function;                                                                    \
      using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType;         \
      ::tvm::runtime::TVMRetValue rv;                                                       \
      ::tvm::runtime::detail::unpack_call_by_signature<FType>::run(                         \
          f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv);                      \
      rv.MoveToCHost(out_value, out_type_code);                                             \
      return 0;                                                                             \
    } catch (const ::std::exception& _except_) {                                            \
      TVMAPISetLastError(_except_.what());                                                  \
      return -1;                                                                            \
    } catch (...) {                                                                         \
      TVMAPISetLastError("unknown exception");                                              \
      return -1;                                                                            \
    }                                                                                       \
  }                                                                                         \
  }
1226 
1227 inline TVMArgValue TVMArgs::operator[](int i) const {
1228  ICHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed"
1229  << " but request arg[" << i << "].";
1230  return TVMArgValue(values[i], type_codes[i]);
1231 }
1232 
1233 inline int TVMArgs::size() const { return num_args; }
1234 
1235 template <class TPackedFuncSubObj>
1237  TVMRetValue* rv) {
1238  (static_cast<const TPackedFuncSubObj*>(obj))->callable_(args, rv);
1239 }
1240 
1241 TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const {
1242  (*f_call_packed_)(this, args, rv);
1243 }
1244 
1245 TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
1246  (static_cast<PackedFuncObj*>(data_.get()))->CallPacked(args, rv);
1247 }
1248 
// Helpers mapping runtime type codes and device types to printable names.
1250 inline const char* ArgTypeCode2Str(int type_code) {
1251  switch (type_code) {
1252  case kDLInt:
1253  return "int";
1254  case kDLUInt:
1255  return "uint";
1256  case kDLFloat:
1257  return "float";
1258  case kTVMStr:
1259  return "str";
1260  case kTVMBytes:
1261  return "bytes";
1262  case kTVMOpaqueHandle:
1263  return "handle";
1264  case kTVMNullptr:
1265  return "NULL";
1266  case kTVMDLTensorHandle:
1267  return "ArrayHandle";
1268  case kTVMDataType:
1269  return "DLDataType";
1270  case kDLDevice:
1271  return "DLDevice";
1272  case kTVMPackedFuncHandle:
1273  return "FunctionHandle";
1274  case kTVMModuleHandle:
1275  return "ModuleHandle";
1276  case kTVMNDArrayHandle:
1277  return "NDArrayContainer";
1278  case kTVMObjectHandle:
1279  return "Object";
1281  return "ObjectRValueRefArg";
1282  default:
1283  LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
1284  }
1285  throw;
1286 }
1287 
1293 inline const char* DLDeviceType2Str(int type) {
1294  switch (type) {
1295  case kDLCPU:
1296  return "cpu";
1297  case kDLCUDA:
1298  return "cuda";
1299  case kDLCUDAHost:
1300  return "cuda_host";
1301  case kDLCUDAManaged:
1302  return "cuda_managed";
1303  case kDLOpenCL:
1304  return "opencl";
1305  case kDLSDAccel:
1306  return "sdaccel";
1307  case kDLAOCL:
1308  return "aocl";
1309  case kDLVulkan:
1310  return "vulkan";
1311  case kDLMetal:
1312  return "metal";
1313  case kDLVPI:
1314  return "vpi";
1315  case kDLROCM:
1316  return "rocm";
1317  case kDLROCMHost:
1318  return "rocm_host";
1319  case kDLExtDev:
1320  return "ext_dev";
1321  case kDLOneAPI:
1322  return "oneapi";
1323  case kDLWebGPU:
1324  return "webgpu";
1325  case kDLHexagon:
1326  return "hexagon";
1327  case kOpenGL:
1328  return "opengl";
1329  case kDLMicroDev:
1330  return "microdev";
1331  default:
1332  LOG(FATAL) << "unknown type = " << type;
1333  }
1334  throw;
1335 }
1336 
1337 namespace detail {
1338 
/*!
 * \brief Recursive helper behind for_each.
 *
 * Handles one argument per step: calls fn(I, head) and then recurses on the
 * remaining arguments with index I + 1. `stop` selects the empty terminal case.
 */
template <bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
  template <typename Head, typename... Tail>
  static void run(const F& fn, Head&& head, Tail&&... tail) {  // NOLINT(*)
    fn(I, std::forward<Head>(head));
    using Next = for_each_dispatcher<sizeof...(Tail) == 0, (I + 1), F>;
    Next::run(fn, std::forward<Tail>(tail)...);
  }
};

/*! \brief Terminal case: no arguments left, nothing to do. */
template <std::size_t I, typename F>
struct for_each_dispatcher<true, I, F> {
  static void run(const F& /*fn*/) {}  // NOLINT(*)
};

/*!
 * \brief Call f(index, arg) for each argument, left to right.
 * \param f Callable invoked as f(std::size_t index, arg), with arg perfectly forwarded.
 * \param args The arguments to visit; index is the zero-based position.
 */
template <typename F, typename... Args>
inline void for_each(const F& f, Args&&... args) {  // NOLINT(*)
  using Start = for_each_dispatcher<sizeof...(Args) == 0, 0, F>;
  Start::run(f, std::forward<Args>(args)...);
}
1357 
1358 template <typename T>
1359 struct ModuleVTableEntryHelper {};
1360 
1361 template <typename T, typename R, typename... Args>
1362 struct ModuleVTableEntryHelper<R (T::*)(Args...) const> {
1363  using MemFnType = R (T::*)(Args...) const;
1364  using IndexSeq = std::index_sequence_for<Args...>;
1365  static constexpr const std::size_t LenArgs = sizeof...(Args);
1366 
1367  template <std::size_t... Is>
1368  static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1369  std::index_sequence<Is...>) {
1370  *rv = (self->*f)(args[Is]...);
1371  }
1372 };
1373 
1374 template <typename T, typename R, typename... Args>
1375 struct ModuleVTableEntryHelper<R (T::*)(Args...)> {
1376  using MemFnType = R (T::*)(Args...);
1377  using IndexSeq = std::index_sequence_for<Args...>;
1378  static constexpr const std::size_t LenArgs = sizeof...(Args);
1379 
1380  template <std::size_t... Is>
1381  static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1382  std::index_sequence<Is...>) {
1383  *rv = (self->*f)(args[Is]...);
1384  }
1385 };
1386 
1387 template <typename T, typename... Args>
1388 struct ModuleVTableEntryHelper<void (T::*)(Args...) const> {
1389  using MemFnType = void (T::*)(Args...) const;
1390  using IndexSeq = std::index_sequence_for<Args...>;
1391  static constexpr const std::size_t LenArgs = sizeof...(Args);
1392 
1393  template <std::size_t... Is>
1394  static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1395  std::index_sequence<Is...>) {
1396  (self->*f)(args[Is]...);
1397  }
1398 };
1399 
1400 template <typename T, typename... Args>
1401 struct ModuleVTableEntryHelper<void (T::*)(Args...)> {
1402  using MemFnType = void (T::*)(Args...);
1403  using IndexSeq = std::index_sequence_for<Args...>;
1404  static constexpr const std::size_t LenArgs = sizeof...(Args);
1405 
1406  template <std::size_t... Is>
1407  static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1408  std::index_sequence<Is...>) {
1409  (self->*f)(args[Is]...);
1410  }
1411 };
1412 
namespace parameter_pack {

/*!
 * \brief Invokers over a pack of enumerated items.
 *
 * Each element of EnumArgs must expose a compile-time index `i` and a type alias
 * `T`; EnumerateImpl::Item below provides exactly that.
 */
template <typename... EnumArgs>
struct EnumeratedParamPack {
  /*! \brief Calls Functor<i, T>::F(extra_params) once per item, in item order. */
  struct InvokeWithoutArg {
    template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
    static void F(ExtraParams&& extra_params) {
      // Braced array initialization evaluates elements left to right, so the calls
      // run in index order. The leading 0 keeps the array non-empty for an empty pack.
      // NOTE(review): extra_params is std::forward-ed once per item; if it is an rvalue
      // and a functor moves from it, later functors observe a moved-from value.
      using TExpander = int[];
      (void)TExpander{
          0,
          (Functor<EnumArgs::i, typename EnumArgs::T>::F(std::forward<ExtraParams>(extra_params)),
           0)...,
      };
    }
  };
  /*!
   * \brief Calls Functor<i, T>::F(extra_params, param_i) for each item.
   *
   * Each item is paired positionally with one element of params. The two packs
   * are expanded in lockstep, so they must have the same length.
   */
  struct InvokeWithArg {
    template <template <size_t i, typename TArgument> class Functor, typename ExtraParams,
              typename... Params>
    static void F(ExtraParams&& extra_params, Params&&... params) {
      using TExpander = int[];
      (void)TExpander{
          0,
          (Functor<EnumArgs::i, typename EnumArgs::T>::F(std::forward<ExtraParams>(extra_params),
                                                         std::forward<Params>(params)),
           0)...,
      };
    }
  };
};

/*!
 * \brief Pairs every type in Args with its zero-based position and exposes the
 *        matching EnumeratedParamPack invokers as WithoutArg / WithArg.
 */
template <typename... Args>
struct EnumerateImpl {
 private:
  // Template parameters were previously spelled `_i` / `_T`. Names that start with
  // an underscore followed by an uppercase letter are reserved for the
  // implementation, so plain names are used instead.
  template <size_t Index, typename Type>
  struct Item {
    static const constexpr size_t i = Index;
    using T = Type;
  };

  template <typename...>
  struct Zipper;

  template <std::size_t... id>
  struct Zipper<std::integer_sequence<std::size_t, id...>> {
    using WithoutArg = typename EnumeratedParamPack<Item<id, Args>...>::InvokeWithoutArg;
    using WithArg = typename EnumeratedParamPack<Item<id, Args>...>::InvokeWithArg;
  };

 public:
  using WithoutArg = typename Zipper<std::index_sequence_for<Args...>>::WithoutArg;
  using WithArg = typename Zipper<std::index_sequence_for<Args...>>::WithArg;
};

/*! \brief Invoker calling Functor<i, Args_i>::F(extra) for every position i. */
template <typename... Args>
using EnumerateWithoutArg = typename EnumerateImpl<Args...>::WithoutArg;

/*! \brief Invoker calling Functor<i, Args_i>::F(extra, param_i) for every position i. */
template <typename... Args>
using EnumerateWithArg = typename EnumerateImpl<Args...>::WithArg;

/*!
 * \brief Compile-time list of parameter types.
 *
 * InvokeWithoutArg<Functor>(extra) calls Functor<i, Args_i>::F(extra) for each
 * position i, in order.
 */
template <typename... Args>
struct ParamPack {
  template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
  static void InvokeWithoutArg(ExtraParams&& extra_params) {
    EnumerateWithoutArg<Args...>::template F<Functor, ExtraParams>(
        std::forward<ExtraParams>(extra_params));
  }
};

}  // namespace parameter_pack
1482 
1487 template <typename T>
1488 struct func_signature_helper {
1489  using FType = void;
1490 };
1491 
1492 template <typename T, typename R, typename... Args>
1493 struct func_signature_helper<R (T::*)(Args...)> {
1494  using FType = R(Args...);
1495  using ParamType = parameter_pack::ParamPack<Args...>;
1496  using RetType = R;
1497  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1498 };
1499 
1500 template <typename T, typename R, typename... Args>
1501 struct func_signature_helper<R (T::*)(Args...) const> {
1502  using FType = R(Args...);
1503  using ParamType = parameter_pack::ParamPack<Args...>;
1504  using RetType = R;
1505  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1506 };
1507 
1512 template <typename T>
1513 struct function_signature {
1514  using FType = typename func_signature_helper<decltype(&T::operator())>::FType;
1515  using ParamType = typename func_signature_helper<decltype(&T::operator())>::ParamType;
1516  using RetType = typename func_signature_helper<decltype(&T::operator())>::RetType;
1517 };
1518 
1519 // handle case of function.
1520 template <typename R, typename... Args>
1521 struct function_signature<R(Args...)> {
1522  using FType = R(Args...);
1523  using ParamType = parameter_pack::ParamPack<Args...>;
1524  using RetType = R;
1525  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1526 };
1527 
1528 // handle case of function ptr.
1529 template <typename R, typename... Args>
1530 struct function_signature<R (*)(Args...)> {
1531  using FType = R(Args...);
1532  using ParamType = detail::parameter_pack::ParamPack<Args...>;
1533  using RetType = R;
1534  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1535 };
1536 
// Forward declaration: Type2Str<TypedPackedFunc<FType>> in namespace type2str
// (below) refers to SignaturePrinter before its full definition.
1537 template <typename TSignature>
1538 struct SignaturePrinter;
1539 
1540 namespace type2str {
1541 
1542 template <typename T>
1543 struct TypeSimplifier;
1544 
1545 template <typename T>
1546 struct Type2Str {
1547  template <typename = std::enable_if_t<std::is_base_of<ObjectRef, T>::value>>
1548  static std::string v() {
1549  return T::ContainerType::_type_key;
1550  }
1551 };
1552 template <>
1553 struct Type2Str<int> {
1554  static std::string v() { return "int"; }
1555 };
1556 template <>
1557 struct Type2Str<double> {
1558  static std::string v() { return "double"; }
1559 };
1560 template <>
1561 struct Type2Str<int64_t> {
1562  static std::string v() { return "int64_t"; }
1563 };
1564 template <>
1565 struct Type2Str<uint64_t> {
1566  static std::string v() { return "uint64_t"; }
1567 };
1568 template <>
1569 struct Type2Str<bool> {
1570  static std::string v() { return "bool"; }
1571 };
1572 template <>
1573 struct Type2Str<void> {
1574  static std::string v() { return "void"; }
1575 };
1576 template <>
1577 struct Type2Str<std::basic_string<char>> {
1578  static std::string v() { return "basic_string<char>"; }
1579 };
1580 template <typename K, typename V>
1581 struct Type2Str<Map<K, V>> {
1582  static std::string v() {
1583  return "Map<" + TypeSimplifier<K>::v() + ", " + TypeSimplifier<V>::v() + ">";
1584  }
1585 };
1586 template <>
1587 struct Type2Str<DLDevice> {
1588  static std::string v() { return "DLDevice"; }
1589 };
1590 template <>
1591 struct Type2Str<DLTensor> {
1592  static std::string v() { return "DLTensor"; }
1593 };
1594 template <>
1595 struct Type2Str<DataType> {
1596  static std::string v() { return "DataType"; }
1597 };
1598 template <>
1599 struct Type2Str<DLDataType> {
1600  static std::string v() { return "DLDataType"; }
1601 };
1602 template <>
1603 struct Type2Str<TVMRetValue> {
1604  static std::string v() { return "TVMRetValue"; }
1605 };
1606 template <>
1607 struct Type2Str<TVMArgValue> {
1608  static std::string v() { return "TVMArgValue"; }
1609 };
1610 template <>
1611 struct Type2Str<TVMByteArray> {
1612  static std::string v() { return "TVMByteArray"; }
1613 };
1614 template <typename FType>
1615 struct Type2Str<TypedPackedFunc<FType>> {
1616  static std::string v() { return SignaturePrinter<function_signature<FType>>::F(); }
1617 };
1618 template <typename T>
1619 struct Type2Str<Array<T>> {
1620  static std::string v() { return "Array<" + TypeSimplifier<T>::v() + ">"; }
1621 };
1622 
1627 template <typename T>
1628 struct TypeSimplifier {
1629  static std::string v() {
1630  using U = typename std::remove_cv<
1631  typename std::remove_reference<typename std::remove_pointer<T>::type>::type>::type;
1632  return (std::is_const<T>::value ? "const " : "") + Type2Str<U>::v() +
1633  (std::is_pointer<T>::value ? "*" : "") + (std::is_reference<T>::value ? "&" : "");
1634  }
1635 };
1636 
1637 } // namespace type2str
1638 
1643 template <typename TSignature>
1644 struct SignaturePrinter {
1645  using ParamType = typename TSignature::ParamType;
1646  using RetType = typename TSignature::RetType;
1647 
1648  template <size_t i, typename TArgument>
1649  struct PrintParamType {
1650  static void F(std::ostream& os) {
1651  os << (i == 0 ? "" : ", ") << i << ": " << type2str::TypeSimplifier<TArgument>::v();
1652  }
1653  };
1654 
1655  static std::string F() {
1656  std::ostringstream oss;
1657  oss << "(";
1658  ParamType::template InvokeWithoutArg<PrintParamType>(oss);
1659  oss << ") -> " << type2str::TypeSimplifier<RetType>::v();
1660  return oss.str();
1661  }
1662 };
1663 } // namespace detail
1664 
1665 /* \brief argument setter to PackedFunc */
1667  public:
1668  TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {}
1669  // setters for POD types
1670  template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
1671  TVM_ALWAYS_INLINE void operator()(size_t i, T value) const {
1672  values_[i].v_int64 = static_cast<int64_t>(value);
1673  type_codes_[i] = kDLInt;
1674  }
1675  TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
1676  values_[i].v_int64 = static_cast<int64_t>(value);
1677  ICHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
1678  type_codes_[i] = kDLInt;
1679  }
1680  TVM_ALWAYS_INLINE void operator()(size_t i, double value) const {
1681  values_[i].v_float64 = value;
1682  type_codes_[i] = kDLFloat;
1683  }
1684  TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const {
1685  values_[i].v_handle = value;
1686  type_codes_[i] = kTVMNullptr;
1687  }
1688  TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const {
1689  values_[i] = value.value_;
1690  type_codes_[i] = value.type_code_;
1691  }
1692  TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const {
1693  values_[i].v_handle = value;
1694  type_codes_[i] = kTVMOpaqueHandle;
1695  }
1696  TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const {
1697  values_[i].v_handle = value;
1698  type_codes_[i] = kTVMDLTensorHandle;
1699  }
1700  TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const {
1701  values_[i].v_device = value;
1702  type_codes_[i] = kDLDevice;
1703  }
1704  TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const {
1705  values_[i].v_type = value;
1706  type_codes_[i] = kTVMDataType;
1707  }
1708  TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const {
1709  operator()(i, dtype.operator DLDataType());
1710  }
1711  TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const {
1712  values_[i].v_str = value;
1713  type_codes_[i] = kTVMStr;
1714  }
1715  // setters for container types
1716  TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const {
1717  values_[i].v_str = value.c_str();
1718  type_codes_[i] = kTVMStr;
1719  }
1720  TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const {
1721  values_[i].v_handle = const_cast<TVMByteArray*>(&value);
1722  type_codes_[i] = kTVMBytes;
1723  }
1724  template <typename FType>
1725  TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
1726  operator()(i, value.packed());
1727  }
1728  void operator()(size_t i, const TVMRetValue& value) const {
1729  if (value.type_code() == kTVMStr) {
1730  values_[i].v_str = value.ptr<std::string>()->c_str();
1731  type_codes_[i] = kTVMStr;
1732  } else {
1733  ICHECK_NE(value.type_code(), kTVMBytes) << "not handled.";
1734  values_[i] = value.value_;
1735  type_codes_[i] = value.type_code();
1736  }
1737  }
1738  // ObjectRef handling
1739  template <typename TObjectRef,
1740  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
1741  TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const {
1742  this->SetObject(i, value);
1743  }
1744 
1745  template <typename TObjectRef,
1746  typename = typename std::enable_if<std::is_base_of<
1747  ObjectRef, typename std::remove_reference<TObjectRef>::type>::value>::type>
1748  TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const {
1749  this->SetObject(i, std::forward<TObjectRef>(value));
1750  }
1751 
1752  private:
  // Encodes an ObjectRef (lvalue or rvalue) into slot i; defined out of line
  // as TVMArgsSetter::SetObject.
1753  template <typename TObjectRef>
1754  inline void SetObject(size_t i, TObjectRef&& value) const;
  // Caller-provided array of argument values, indexed by argument position.
  // Not owned: the setter never frees it.
1756  TVMValue* values_;
  // Caller-provided array of type codes, parallel to values_. Not owned.
1758  int* type_codes_;
1759 };
1760 
1761 template <typename... Args>
1762 inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
1763  const int kNumArgs = sizeof...(Args);
1764  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1765  TVMValue values[kArraySize];
1766  int type_codes[kArraySize];
1767  detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
1768  TVMRetValue rv;
1769  (static_cast<PackedFuncObj*>(data_.get()))
1770  ->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
1771  return rv;
1772 }
1773 
// Per-argument functor used by PackArgs (as TVMArgsSetterApply): for argument
// position i, forwards `value` to the setter, which encodes it into slot i.
// NOTE(review): the struct header line is not present in this listing;
// presumably `struct TVMArgsSetterApply {`, which is the name PackArgs uses.
1774 template <size_t i, typename T>
1776  static TVM_ALWAYS_INLINE void F(TVMArgsSetter* setter, T&& value) {
1777  (*setter)(i, std::forward<T>(value));
1778  }
1779 };
1780 
1781 template <typename... Args>
1782 void TVM_ALWAYS_INLINE PackArgs(TVMValue* values, int* type_codes, Args&&... args) {
1783  TVMArgsSetter setter(values, type_codes);
1784  detail::parameter_pack::EnumerateWithArg<Args...>::template F<TVMArgsSetterApply>(
1785  &setter, std::forward<Args>(args)...);
1786 }
1787 
1788 namespace detail {
1789 template <typename R, int nleft, int index, typename F>
1790 struct unpack_call_dispatcher {
1791  template <typename... Args>
1792  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
1793  const TVMArgs& args_pack, TVMRetValue* rv,
1794  Args&&... unpacked_args) {
1795  // construct a movable argument value
1796  // which allows potential move of argument to the input of F.
1797  unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
1798  optional_name, f_sig, f, args_pack, rv, std::forward<Args>(unpacked_args)...,
1799  TVMMovableArgValueWithContext_(args_pack.values[index], args_pack.type_codes[index], index,
1800  optional_name, f_sig));
1801  }
1802 };
1803 
1804 template <typename R, int index, typename F>
1805 struct unpack_call_dispatcher<R, 0, index, F> {
1806  template <typename... Args>
1807  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
1808  const TVMArgs& args_pack, TVMRetValue* rv,
1809  Args&&... unpacked_args) {
1810  using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
1811  if (std::is_same<RetType, R>::value) {
1812  *rv = f(std::forward<Args>(unpacked_args)...);
1813  } else {
1814  *rv = R(f(std::forward<Args>(unpacked_args)...));
1815  }
1816  }
1817 };
1818 
1819 template <int index, typename F>
1820 struct unpack_call_dispatcher<void, 0, index, F> {
1821  template <typename... Args>
1822  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
1823  const TVMArgs& args_pack, TVMRetValue* rv,
1824  Args&&... unpacked_args) {
1825  f(std::forward<Args>(unpacked_args)...);
1826  }
1827 };
1828 
1829 template <typename R, int nargs, typename F>
1830 TVM_ALWAYS_INLINE void unpack_call(const std::string* optional_name, const F& f,
1831  const TVMArgs& args, TVMRetValue* rv) {
1832  FSig* f_sig = detail::SignaturePrinter<detail::function_signature<F>>::F;
1833  CHECK_EQ(nargs, args.size()) << "Function "
1834  << (optional_name == nullptr ? "<anonymous>" : *optional_name)
1835  << (f_sig == nullptr ? "" : (*f_sig)()) << " expects " << nargs
1836  << " arguments but " << args.size() << " were provided";
1837  unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, args, rv);
1838 }
1839 
1840 template <typename FType>
1841 struct unpack_call_by_signature {};
1842 
1843 template <typename R, typename... Args>
1844 struct unpack_call_by_signature<R(Args...)> {
1845  template <typename F>
1846  TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) {
1847  unpack_call<R, sizeof...(Args)>(nullptr, f, args, rv);
1848  }
1849 };
1850 
1851 template <typename R, typename... Args>
1852 TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) {
1853  return R(pf(std::forward<Args>(args)...));
1854 }
1855 
1856 template <typename R>
1857 struct typed_packed_call_dispatcher {
1858  template <typename... Args>
1859  TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) {
1860  return pf(std::forward<Args>(args)...);
1861  }
1862 };
1863 
1864 template <>
1865 struct typed_packed_call_dispatcher<void> {
1866  template <typename... Args>
1867  TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... args) {
1868  pf(std::forward<Args>(args)...);
1869  }
1870 };
1871 } // namespace detail
1872 
1873 template <typename R, typename... Args>
1874 TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) : packed_(packed) {}
1875 
// Converting constructors: each one stores the PackedFunc obtained from
// `value` through its `operator PackedFunc()`.
// NOTE(review): the declarator lines naming the source types are not shown in
// this listing; presumably TVMRetValue, TVMArgValue and
// TVMMovableArgValueWithContext_. Confirm against the header.
1876 template <typename R, typename... Args>
1878  : packed_(value.operator PackedFunc()) {}
1879 
1880 template <typename R, typename... Args>
1882  : packed_(value.operator PackedFunc()) {}
1883 
1884 template <typename R, typename... Args>
1886  : packed_(value.operator PackedFunc()) {}
1887 
1888 template <typename R, typename... Args>
1889 template <typename FType>
1890 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, std::string name) {
1891  FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1892  packed_ = PackedFunc([flambda, name, f_sig](const TVMArgs& args, TVMRetValue* rv) {
1893  if (args.size() != sizeof...(Args)) {
1894  LOG(FATAL) << "Function " << name << (f_sig == nullptr ? "" : (*f_sig)()) << " expects "
1895  << sizeof...(Args) << " arguments, but " << args.size() << " were provided.";
1896  }
1897  detail::unpack_call<R, sizeof...(Args)>(&name, flambda, args, rv);
1898  });
1899 }
1900 
1901 template <typename R, typename... Args>
1902 template <typename FType>
1903 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
1904  FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1905  packed_ = PackedFunc([flambda, f_sig](const TVMArgs& args, TVMRetValue* rv) {
1906  if (args.size() != sizeof...(Args)) {
1907  LOG(FATAL) << "Function <anonymous> " << (*f_sig)() << " expects " << sizeof...(Args)
1908  << " arguments, but " << args.size() << " were provided.";
1909  }
1910  detail::unpack_call<R, sizeof...(Args)>(nullptr, flambda, args, rv);
1911  });
1912 }
1913 
1914 template <typename R, typename... Args>
1915 TVM_ALWAYS_INLINE R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
1916  return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...);
1917 }
1918 
1919 // ObjectRef related conversion handling
1920 // An object argument is tagged with one of these type codes:
1921 // kTVMNDArrayHandle, kTVMModuleHandle, kTVMPackedFuncHandle,
// kTVMObjectRValueRefArg (ObjectRef passed as an rvalue) or kTVMObjectHandle.
// An undefined reference is passed as kTVMNullptr with a null handle.
1922 //
1923 // We use type traits to eliminate unnecessary checks.
1924 template <typename T>
1925 inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
1926  using ContainerType = typename std::remove_reference<T>::type::ContainerType;
1927  if (value.defined()) {
1928  Object* ptr = value.data_.data_;
  // Each branch fires if the static container type already implies the kind,
  // or if it could be that kind and the runtime check (ptr->IsInstance) confirms it.
1929  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1930  (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
  // NDArrays are passed through their FFI handle rather than the Object pointer.
1932  values_[i].v_handle = NDArray::FFIGetHandle(value);
1933  type_codes_[i] = kTVMNDArrayHandle;
1934  } else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1935  (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1936  ptr->IsInstance<Module::ContainerType>())) {
1937  values_[i].v_handle = ptr;
1938  type_codes_[i] = kTVMModuleHandle;
1939  } else if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
1940  (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1942  values_[i].v_handle = ptr;
1943  type_codes_[i] = kTVMPackedFuncHandle;
1944  } else if (std::is_rvalue_reference<decltype(value)>::value) {
  // Rvalue ObjectRef: pass the address of the reference's internal pointer
  // rather than the pointer itself, presumably so the callee can take
  // ownership (see TVMMovableArgValue_::operator T).
1945  values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
1946  type_codes_[i] = kTVMObjectRValueRefArg;
1947  } else {
1948  values_[i].v_handle = value.data_.data_;
1949  type_codes_[i] = kTVMObjectHandle;
1950  }
1951  } else {
1952  type_codes_[i] = kTVMNullptr;
1953  values_[i].v_handle = nullptr;
1954  }
1955 }
1956 
// Returns whether this value holds an object that can be viewed as TObjectRef.
// The decision uses the stored type code and, where needed, a runtime
// IsInstance<ContainerType>() check on the held object.
1957 template <typename TObjectRef, typename>
1958 inline bool TVMPODValue_::IsObjectRef() const {
1959  using ContainerType = typename TObjectRef::ContainerType;
1960  // NOTE: the following code can be optimized by constant folding.
  // Target is NDArray or a subclass: only an NDArray handle can match.
1961  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1962  return type_code_ == kTVMNDArrayHandle &&
1964  ->IsInstance<ContainerType>();
1965  }
  // Target is Module or a subclass: only a module handle can match.
1966  if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1967  return type_code_ == kTVMModuleHandle &&
1968  static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
1969  }
  // Target is PackedFunc or a subclass: only a packed-function handle can match.
1970  if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
1971  return type_code_ == kTVMPackedFuncHandle &&
1972  static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
1973  }
1974  // NOTE: we don't pass NDArray and runtime::Module as RValue ref.
1977  }
  // Remaining case: ContainerType may be a base of NDArray / Module / PackedFunc,
  // so each of those handle kinds may also qualify.
1978  return (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1980  (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1982  (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1986 }
1987 
// Converts this value to TObjectRef.
// A null value is accepted only if TObjectRef is nullable (CHECK otherwise).
// For each handle kind, the held object's dynamic type is checked against the
// requested container, and mismatches fail a CHECK/ICHECK with a message
// naming the expected and actual type.
1988 template <typename TObjectRef>
1989 inline TObjectRef TVMPODValue_::AsObjectRef() const {
1990  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
1991  "Conversion only works for ObjectRef");
1992  using ContainerType = typename TObjectRef::ContainerType;
1993 
1994  if (type_code_ == kTVMNullptr) {
1995  CHECK(TObjectRef::_type_is_nullable)
1996  << "Expect a not null value of " << ContainerType::_type_key;
1997  return TObjectRef(ObjectPtr<Object>(nullptr));
1998  }
1999  // NOTE: the following code can be optimized by constant folding.
2000  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
2001  // Casting to a sub-class of NDArray
2003  ObjectPtr<Object> data =
2005  CHECK(data->IsInstance<ContainerType>())
2006  << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
2007  return TObjectRef(data);
2008  }
2009  if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
2010  // Casting to a sub-class of Module
2012  ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
2013  CHECK(data->IsInstance<ContainerType>())
2014  << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
2015  return TObjectRef(data);
2016  }
2017  if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
2018  // Casting to a sub-class of PackedFunc
2020  ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
2021  CHECK(data->IsInstance<ContainerType>())
2022  << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
2023  return TObjectRef(data);
2024  }
2025  if (type_code_ == kTVMObjectHandle) {
2026  // normal object type check.
2027  Object* ptr = static_cast<Object*>(value_.v_handle);
2029  ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName()
2030  << ", but got " << checked_type.value();
2031  return TObjectRef(GetObjectPtr<Object>(ptr));
2032  } else if (type_code_ == kTVMObjectRValueRefArg) {
  // The handle points at the caller's internal Object* slot (see
  // TVMArgsSetter::SetObject), hence the extra dereference.
2033  Object* ptr = *static_cast<Object**>(value_.v_handle);
2035  ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName()
2036  << ", but got " << checked_type.value();
2037  return TObjectRef(GetObjectPtr<Object>(ptr));
2038  } else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
2040  // Casting to a base class that NDArray can sub-class
2041  ObjectPtr<Object> data =
2043  return TObjectRef(data);
2044  } else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
2046  // Casting to a base class that Module can sub-class
2047  return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
2048  } else if (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
2050  // Casting to a base class that PackedFunc can sub-class
2051  return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
2052  } else {
  // NOTE(review): the line before this return is not shown in this listing;
  // presumably it reports the type-code mismatch.
2054  return TObjectRef(ObjectPtr<Object>(nullptr));
2055  }
2056 }
2057 
2058 template <typename TObjectRef, typename>
2059 inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
2060  using ContainerType = typename TObjectRef::ContainerType;
2061  const Object* ptr = other.get();
2062  if (ptr != nullptr) {
2063  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
2064  (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
2065  ptr->IsInstance<NDArray::ContainerType>())) {
2066  return operator=(NDArray(std::move(other.data_)));
2067  }
2068  if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
2069  (std::is_base_of<ContainerType, Module::ContainerType>::value &&
2070  ptr->IsInstance<Module::ContainerType>())) {
2071  return operator=(Module(std::move(other.data_)));
2072  }
2073  if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
2074  (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
2075  ptr->IsInstance<PackedFunc::ContainerType>())) {
2076  return operator=(PackedFunc(std::move(other.data_)));
2077  }
2078  SwitchToObject(kTVMObjectHandle, std::move(other.data_));
2079  } else {
2080  SwitchToPOD(kTVMNullptr);
2081  value_.v_handle = nullptr;
2082  }
2083  return *this;
2084 }
2085 
// Generic conversion of an argument value to T, delegated to
// PackedFuncValueConverter<T>.
2086 template <typename T, typename>
2087 inline TVMArgValue::operator T() const {
2088  return PackedFuncValueConverter<T>::From(*this);
2089 }
2090 
// Conversion of a movable argument to T. When the argument was passed as an
// rvalue ObjectRef (kTVMObjectRValueRefArg) and the referenced object passes
// ObjectTypeChecker<T>, a fast path applies. Its body line is not shown in this
// listing; presumably it moves the object out. Otherwise the regular
// TVMArgValue conversion is used.
2091 template <typename T, typename>
2092 inline TVMMovableArgValue_::operator T() const {
2093  if (type_code_ == kTVMObjectRValueRefArg) {
2094  auto** ref = static_cast<Object**>(value_.v_handle);
2095  if (ObjectTypeChecker<T>::Check(*ref)) {
2097  }
2098  }
2099  // fallback
2100  return PackedFuncValueConverter<T>::From(AsArgValue());
2101 }
2102 
// Generic conversion of a return value to T, delegated to
// PackedFuncValueConverter<T>.
2103 template <typename T, typename>
2104 inline TVMRetValue::operator T() const {
2105  return PackedFuncValueConverter<T>::From(*this);
2106 }
2107 
2108 inline PackedFunc Module::GetFunction(const String& name, bool query_imports) {
2109  return (*this)->GetFunction(name, query_imports);
2110 }
2111 
2112 // specializations of PackedFuncValueConverter
2113 template <>
2115  static String From(const TVMArgValue& val) {
2116  if (val.IsObjectRef<tvm::runtime::String>()) {
2117  return val.AsObjectRef<tvm::runtime::String>();
2118  } else {
2119  return tvm::runtime::String(val.operator std::string());
2120  }
2121  }
2122 
2123  static String From(const TVMRetValue& val) {
2124  if (val.IsObjectRef<tvm::runtime::String>()) {
2125  return val.AsObjectRef<tvm::runtime::String>();
2126  } else {
2127  return tvm::runtime::String(val.operator std::string());
2128  }
2129  }
2130 };
2131 
// Converter for Optional<T> (the specialization header is not shown in this
// listing). A kTVMNullptr value becomes an empty Optional. The other return
// paths are not shown; presumably they convert the value to T.
2132 template <typename T>
2134  static Optional<T> From(const TVMArgValue& val) {
2135  if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
2137  }
2138  static Optional<T> From(const TVMRetValue& val) {
2139  if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
2141  }
2142 };
2143 
// Converter for Variant<VariantTypes...>.
// Tries, in order: (1) each alternative as an object reference
// (IsObjectRef/AsObjectRef); (2) each alternative through its value converter,
// treating an InternalError as "not this alternative". If both fail, it emits
// LOG(FATAL).
2144 template <typename... VariantTypes>
2145 struct PackedFuncValueConverter<Variant<VariantTypes...>> {
2146  using VType = Variant<VariantTypes...>;
2147 
2148  // Can't just take `const TVMPODValue&` as an argument, because
2149  // `TVMArgValue` and `TVMRetValue` have different implementations
2150  // for `operator std::string()`.
2151  template <typename PODSubclass>
2152  static VType From(const PODSubclass& val) {
2153  if (auto opt = TryAsObjectRef<VariantTypes...>(val)) {
2154  return opt.value();
2155  }
2156 
2157  if (auto opt = TryValueConverter<PODSubclass, VariantTypes...>(val)) {
2158  return opt.value();
2159  }
2160 
  // NOTE(review): the fold streams the type keys back to back with no
  // separator, so consecutive keys run together in this message; consider
  // inserting ", " between them.
2161  LOG(FATAL) << "Expected one of "
2162  << static_cast<const std::stringstream&>(
2163  (std::stringstream() << ... << VariantTypes::ContainerType::_type_key))
2164  .str()
2165  << " but got " << ArgTypeCode2Str(val.type_code());
2166  }
2167 
  // Returns the first alternative (in declaration order) that val holds as an
  // object reference, or NullOpt if none match.
2168  template <typename VarFirst, typename... VarRest>
2170  if (val.IsObjectRef<VarFirst>()) {
2171  return VType(val.AsObjectRef<VarFirst>());
2172  } else if constexpr (sizeof...(VarRest)) {
2173  return TryAsObjectRef<VarRest...>(val);
2174  } else {
2175  return NullOpt;
2176  }
2177  }
2178 
  // Attempts each alternative's value conversion in order. An InternalError
  // thrown by an attempt is swallowed and the next alternative is tried.
  // The attempt line itself is not shown in this listing.
2179  template <typename PODSubclass, typename VarFirst, typename... VarRest>
2180  static Optional<VType> TryValueConverter(const PODSubclass& val) {
2181  try {
2183  } catch (const InternalError&) {
2184  }
2185 
2186  if constexpr (sizeof...(VarRest)) {
2187  return TryValueConverter<PODSubclass, VarRest...>(val);
2188  } else {
2189  return NullOpt;
2190  }
2191  }
2192 };
2193 
2194 inline bool String::CanConvertFrom(const TVMArgValue& val) {
2195  return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
2196 }
2197 
2198 inline TVMArgValue::operator DLDataType() const {
2199  if (String::CanConvertFrom(*this)) {
2200  return String2DLDataType(PackedFuncValueConverter<String>::From(*this).operator std::string());
2201  }
2202  // None type
2203  if (type_code_ == kTVMNullptr) {
2204  DLDataType t;
2205  t.code = kTVMOpaqueHandle;
2206  t.bits = 0;
2207  t.lanes = 0;
2208  return t;
2209  }
2210  TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
2211  return value_.v_type;
2212 }
2213 
2214 inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); }
2215 
2216 } // namespace runtime
2217 } // namespace tvm
2218 #endif // TVM_RUNTIME_PACKED_FUNC_H_
Runtime Array container types.
@ kTVMPackedFuncHandle
Definition: c_runtime_api.h:184
@ kTVMNDArrayHandle
Definition: c_runtime_api.h:187
@ kTVMModuleHandle
Definition: c_runtime_api.h:183
@ kTVMBytes
Definition: c_runtime_api.h:186
@ kTVMDataType
Definition: c_runtime_api.h:179
@ kTVMDLTensorHandle
Definition: c_runtime_api.h:181
@ kDLDevice
Definition: c_runtime_api.h:180
@ kTVMOpaqueHandle
Definition: c_runtime_api.h:177
@ kTVMObjectHandle
Definition: c_runtime_api.h:182
@ kTVMObjectRValueRefArg
Definition: c_runtime_api.h:188
@ kTVMNullptr
Definition: c_runtime_api.h:178
@ kTVMStr
Definition: c_runtime_api.h:185
@ kDLMicroDev
Definition: c_runtime_api.h:124
@ kOpenGL
Definition: c_runtime_api.h:123
@ kDLSDAccel
Definition: c_runtime_api.h:122
@ kDLAOCL
Definition: c_runtime_api.h:121
DLTensor * TVMArrayHandle
the array handle
Definition: c_runtime_api.h:202
array node content in array
Definition: array.h:40
size_t size() const
Definition: array.h:43
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Runtime primitive data type.
Definition: data_type.h:42
Shared content of all specializations of hash map.
Definition: map.h:174
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
Base container of module.
Definition: module.h:142
Module container of TVM.
Definition: module.h:79
PackedFunc GetFunction(const String &name, bool query_imports=false)
Get packed function from current module by name.
Definition: packed_func.h:2108
Object container class that backs NDArray.
Definition: ndarray.h:287
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
static TVMArrayHandle FFIGetHandle(const ObjectRef &nd)
Get FFI Array handle from ndarray.
Definition: ndarray.h:424
static ObjectPtr< Object > FFIDataFromHandle(TVMArrayHandle handle)
Construct NDArray's Data field from array handle in FFI.
Definition: ndarray.h:419
static void FFIDecRef(TVMArrayHandle handle)
DecRef resource managed by an FFI array handle.
Definition: ndarray.h:432
A custom smart pointer for Object.
Definition: object.h:360
Base class of all object reference.
Definition: object.h:517
bool defined() const
Definition: object.h:550
static void FFIClearAfterMove(ObjectRef *ref)
Clear the object ref data field without DecRef after we successfully moved the field.
Definition: object.h:621
friend class TVMArgsSetter
Definition: object.h:635
const Object * get() const
Definition: object.h:552
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:603
base class of all object containers.
Definition: object.h:169
std::string GetTypeKey() const
Definition: object.h:182
bool IsInstance() const
Definition: object.h:858
Optional container that represents a nullable variant of T.
Definition: optional.h:51
T value() const
Definition: optional.h:92
Object container class that backs PackedFunc.
Definition: packed_func.h:69
PackedFuncObj(FCallPacked *f_call_pack)
Constructing a packed function object from a function pointer.
Definition: packed_func.h:104
static constexpr const uint32_t _type_index
Definition: packed_func.h:78
TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object)
FCallPacked * f_call_packed_
Internal callable function pointer used to call the packed function.
Definition: packed_func.h:110
static constexpr const char * _type_key
Definition: packed_func.h:79
PackedFuncObj()=delete
Delete the default constructor explicitly.
void(const PackedFuncObj *, TVMArgs, TVMRetValue *) FCallPacked
The internal callable function type.
Definition: packed_func.h:98
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1241
Derived object class for constructing PackedFuncObj.
Definition: packed_func.h:115
PackedFuncSubObj(TCallable callable)
Derived object class for constructing PackedFuncObj.
Definition: packed_func.h:125
TStorage callable_
Type-erased field for storing the callable object.
Definition: packed_func.h:128
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:139
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:181
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:179
TVMRetValue operator()(Args &&... args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1762
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1245
PackedFunc(TCallable data)
Constructing a packed function from a callable type whose signature is consistent with PackedFunc
Definition: packed_func.h:152
TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj)
PackedFunc(std::nullptr_t null)
Constructor from null.
Definition: packed_func.h:142
Reference to string objects.
Definition: string.h:98
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:2194
A single argument value to PackedFunc, containing both a type_code and a TVMValue.
Definition: packed_func.h:649
TObjectRef AsObjectRef() const
Definition: packed_func.h:1989
const TVMValue & value() const
Definition: packed_func.h:691
TVMArgValue(TVMValue value, int type_code)
constructor
Definition: packed_func.h:658
bool IsObjectRef() const
Definition: packed_func.h:1958
TVMArgValue()
default constructor
Definition: packed_func.h:652
Definition: packed_func.h:1666
TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor *value) const
Definition: packed_func.h:1696
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:1675
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1725
TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef &value) const
Definition: packed_func.h:1741
void operator()(size_t i, const TVMRetValue &value) const
Definition: packed_func.h:1728
TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef &&value) const
Definition: packed_func.h:1748
TVM_ALWAYS_INLINE void operator()(size_t i, double value) const
Definition: packed_func.h:1680
TVM_ALWAYS_INLINE void operator()(size_t i, T value) const
Definition: packed_func.h:1671
TVMArgsSetter(TVMValue *values, int *type_codes)
Definition: packed_func.h:1668
TVM_ALWAYS_INLINE void operator()(size_t i, const char *value) const
Definition: packed_func.h:1711
TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const
Definition: packed_func.h:1700
TVM_ALWAYS_INLINE void operator()(size_t i, void *value) const
Definition: packed_func.h:1692
TVM_ALWAYS_INLINE void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1716
TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1684
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray &value) const
Definition: packed_func.h:1720
TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const
Definition: packed_func.h:1708
TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1704
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue &value) const
Definition: packed_func.h:1688
Arguments into TVM functions.
Definition: packed_func.h:392
const TVMValue * values
Definition: packed_func.h:394
TVMArgs(const TVMValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:403
TVMArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:1227
const int * type_codes
Definition: packed_func.h:395
int size() const
Definition: packed_func.h:1233
int num_args
Definition: packed_func.h:396
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with additional context ...
Definition: packed_func.h:754
TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index, const std::string *optional_name, FSig *f_sig)
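Constructor from a value and its type code, together with the argument index, optional function name, and signature printer used to give context in conversion errors.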
Definition: packed_func.h:765
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
Definition: packed_func.h:709
TVMMovableArgValue_(TVMValue value, int type_code)
Definition: packed_func.h:711
Internal base class to handle conversion to POD values.
Definition: packed_func.h:544
TObjectRef AsObjectRef() const
Definition: packed_func.h:1989
TVMPODValue_()
Definition: packed_func.h:634
bool IsObjectRef() const
Definition: packed_func.h:1958
TVMValue value_
The value.
Definition: packed_func.h:638
T * ptr() const
return handle as specific pointer type.
Definition: packed_func.h:620
int type_code_
the type code
Definition: packed_func.h:640
int type_code() const
Definition: packed_func.h:613
TVMPODValue_(TVMValue value, int type_code)
Definition: packed_func.h:635
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlyi...
Definition: packed_func.h:799
static TVMRetValue MoveFromCHost(TVMValue value, int type_code)
Construct a new TVMRetValue by moving from return value stored via C API.
Definition: packed_func.h:967
const TVMValue & value() const
Definition: packed_func.h:976
~TVMRetValue()
destructor
Definition: packed_func.h:812
TObjectRef AsObjectRef() const
Definition: packed_func.h:1989
TVMRetValue & operator=(DLDataType t)
Definition: packed_func.h:889
TVMRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:929
TVMRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:864
TVMRetValue & operator=(TVMRetValue &&other)
Definition: packed_func.h:852
TVMRetValue & operator=(PackedFunc f)
Definition: packed_func.h:924
TVMRetValue & operator=(bool value)
Definition: packed_func.h:895
TVMRetValue & operator=(int value)
Definition: packed_func.h:879
bool IsObjectRef() const
Definition: packed_func.h:1958
void MoveToCHost(TVMValue *ret_value, int *ret_type_code)
Move the value back to the front end via the C API. This marks the current container as null....
Definition: packed_func.h:953
TVMRetValue()
default constructor
Definition: packed_func.h:802
TVMRetValue & operator=(void *value)
Definition: packed_func.h:869
TVMRetValue & operator=(std::string value)
Definition: packed_func.h:900
TVMRetValue & operator=(const TVMArgValue &other)
Definition: packed_func.h:936
TVMRetValue(const TVMRetValue &other)
Definition: packed_func.h:828
TVMRetValue & operator=(double value)
Definition: packed_func.h:859
TVMRetValue(TVMRetValue &&other)
move constructor from another return value.
Definition: packed_func.h:807
TVMRetValue & operator=(const TVMRetValue &other)
Definition: packed_func.h:932
TVMRetValue & operator=(int64_t value)
Definition: packed_func.h:874
TVMRetValue & operator=(Module m)
Definition: packed_func.h:920
TVMRetValue & operator=(const DataType &other)
Definition: packed_func.h:894
TVMRetValue & operator=(DLDevice value)
Definition: packed_func.h:884
TVMRetValue & operator=(TVMMovableArgValue_ &&other)
Definition: packed_func.h:940
TVMRetValue & operator=(TVMByteArray value)
Definition: packed_func.h:904
TVMRetValue & operator=(NDArray other)
Definition: packed_func.h:908
A PackedFunc wrapper to provide typed function signature. It is backed by a PackedFunc internally.
Definition: packed_func.h:228
const PackedFunc & packed() const
Definition: packed_func.h:359
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:361
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:310
TypedPackedFunc()
default constructor
Definition: packed_func.h:233
TypedPackedFunc(const FLambda &typed_lambda, std::string name)
construct from a lambda function with the same signature.
Definition: packed_func.h:287
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:332
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:341
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:363
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:235
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:61
Definition: variant.h:69
struct TVMArgs TVMArgs
Runtime Map container types.
const char * ArgTypeCode2Str(int type_code)
Convert argument type code to string.
Definition: packed_func.h:1250
const char * DLDeviceType2Str(int type)
The name of DLDeviceType.
Definition: packed_func.h:1293
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: data_type.h:377
std::string() FSig
Using static function to output TypedPackedFunc signature.
Definition: packed_func.h:187
void TVM_ALWAYS_INLINE PackArgs(TVMValue *values, int *type_codes, Args &&... args)
Definition: packed_func.h:1782
std::string DLDataType2String(DLDataType t)
convert a TVM type to string.
Definition: data_type.h:370
std::ostream & operator<<(std::ostream &os, const ObjectRef &n)
Definition: repr_printer.h:97
Object * TVMArrayHandleToObjectHandle(TVMArrayHandle handle)
Definition: ndarray.h:436
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
runtime::DataType DataType
Definition: data_type.h:433
DLDevice Device
Definition: ndarray.h:43
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
A device-independent managed NDArray abstraction.
A managed object in the TVM runtime.
#define TVM_CHECK_TYPE_CODE(CODE, T)
Definition: packed_func.h:425
Runtime container of the functions generated by TVM, This is used to support dynamically link,...
Definition: packed_func.h:38
Byte array type used to pass in a byte array when kTVMBytes is used as the data type.
Definition: c_runtime_api.h:221
size_t size
Definition: c_runtime_api.h:223
const char * data
Definition: c_runtime_api.h:222
static std::string TypeName()
Definition: packed_func.h:503
static bool Check(const Object *ptr)
Definition: packed_func.h:492
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:475
static bool Check(const Object *ptr)
Definition: packed_func.h:524
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:507
static std::string TypeName()
Definition: packed_func.h:534
Type traits for runtime type check during FFI conversion.
Definition: packed_func.h:433
static std::string TypeName()
Definition: packed_func.h:466
static bool Check(const Object *ptr)
Check if an object matches the template type.
Definition: packed_func.h:461
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Check if an object matches the template type and return the mismatched type if it exists.
Definition: packed_func.h:441
Internal struct for extracting the callable method from callable type.
Definition: packed_func.h:87
static void Call(const PackedFuncObj *obj, TVMArgs args, TVMRetValue *rv)
Extracting the callable method from callable type.
Definition: packed_func.h:1236
static Optional< T > From(const TVMRetValue &val)
Definition: packed_func.h:2138
static Optional< T > From(const TVMArgValue &val)
Definition: packed_func.h:2134
static VType From(const PODSubclass &val)
Definition: packed_func.h:2152
static Optional< VType > TryValueConverter(const PODSubclass &val)
Definition: packed_func.h:2180
static Optional< VType > TryAsObjectRef(const TVMPODValue_ &val)
Definition: packed_func.h:2169
static String From(const TVMArgValue &val)
Definition: packed_func.h:2115
static String From(const TVMRetValue &val)
Definition: packed_func.h:2123
Type trait to specify special value conversion rules from TVMArgValue and TVMRetValue.
Definition: packed_func.h:1096
static TObjectRef From(const TVMRetValue &val)
Convert a TObjectRef from a return value.
Definition: packed_func.h:1108
static TObjectRef From(const TVMArgValue &val)
Convert a TObjectRef from an argument value.
Definition: packed_func.h:1102
Definition: packed_func.h:63
Definition: packed_func.h:1775
static TVM_ALWAYS_INLINE void F(TVMArgsSetter *setter, T &&value)
Definition: packed_func.h:1776
@ kRuntimePackedFunc
runtime::PackedFunc.
Definition: object.h:74
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:208
DLDevice v_device
Definition: c_runtime_api.h:214
void * v_handle
Definition: c_runtime_api.h:211
DLDataType v_type
Definition: c_runtime_api.h:213
int64_t v_int64
Definition: c_runtime_api.h:209
const char * v_str
Definition: c_runtime_api.h:212
double v_float64
Definition: c_runtime_api.h:210
Runtime Variant container types.