tvm
packed_func.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RUNTIME_PACKED_FUNC_H_
25 #define TVM_RUNTIME_PACKED_FUNC_H_
26 
32 #include <tvm/runtime/data_type.h>
33 #include <tvm/runtime/logging.h>
34 #include <tvm/runtime/module.h>
35 #include <tvm/runtime/ndarray.h>
36 #include <tvm/runtime/object.h>
37 
38 #include <functional>
39 #include <limits>
40 #include <memory>
41 #include <optional>
42 #include <string>
43 #include <tuple>
44 #include <type_traits>
45 #include <utility>
46 #include <vector>
47 
48 // Whether use TVM runtime in header only mode.
49 #ifndef TVM_RUNTIME_HEADER_ONLY
50 #define TVM_RUNTIME_HEADER_ONLY 0
51 #endif
52 
53 namespace tvm {
54 namespace runtime {
55 
56 // forward declarations
57 class TVMArgs;
58 class TVMArgValue;
59 class TVMMovableArgValueWithContext_;
60 class TVMRetValue;
61 class TVMArgsSetter;
62 template <typename FType>
64 template <typename TSignature>
66 
71 class PackedFuncObj : public Object {
72  public:
78  TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const;
79 
80  static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc;
81  static constexpr const char* _type_key = "runtime.PackedFunc";
83 
84  protected:
88  template <class TPackedFuncSubObj>
89  struct Extractor {
96  static void Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv);
97  };
98 
100  using FCallPacked = void(const PackedFuncObj*, TVMArgs, TVMRetValue*);
101 
106  explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {}
107 
109  PackedFuncObj() = delete;
110 
113 };
114 
116 template <class TCallable>
118  using TStorage = typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type;
119 
120  public:
127  explicit PackedFuncSubObj(TCallable callable)
128  : PackedFuncObj(Extractor<TSelf>::Call), callable_(callable) {}
130  mutable TStorage callable_;
131 };
132 
141 class PackedFunc : public ObjectRef {
142  public:
144  PackedFunc(std::nullptr_t null) : ObjectRef(nullptr) {} // NOLINT(*)
150  template <typename TCallable,
151  typename = std::enable_if_t<
152  std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
153  !std::is_base_of<TCallable, PackedFunc>::value>>
154  explicit PackedFunc(TCallable data) {
155  using ObjType = PackedFuncSubObj<TCallable>;
156  data_ = make_object<ObjType>(std::forward<TCallable>(data));
157  }
172  template <typename... Args>
173  inline TVMRetValue operator()(Args&&... args) const;
179  TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const;
181  bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
183  bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }
184 
186 };
187 
189 using FSig = std::string();
190 
194 template <typename FType>
195 class TypedPackedFunc;
196 
229 template <typename R, typename... Args>
230 class TypedPackedFunc<R(Args...)> {
231  public:
233  using TSelf = TypedPackedFunc<R(Args...)>;
237  TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
255  inline TypedPackedFunc(PackedFunc packed); // NOLINT(*)
260  inline TypedPackedFunc(const TVMRetValue& value); // NOLINT(*)
265  inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*)
270  inline TypedPackedFunc(TVMMovableArgValueWithContext_&& value); // NOLINT(*)
287  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
288  FLambda, std::function<R(Args...)>>::value>::type>
289  TypedPackedFunc(const FLambda& typed_lambda, std::string name) { // NOLINT(*)
290  this->AssignTypedLambda(typed_lambda, name);
291  }
310  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
311  FLambda, std::function<R(Args...)>>::value>::type>
312  TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
313  this->AssignTypedLambda(typed_lambda);
314  }
331  template <typename FLambda, typename = typename std::enable_if<
332  std::is_convertible<FLambda,
333  std::function<R(Args...)>>::value>::type>
334  TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
335  this->AssignTypedLambda(typed_lambda);
336  return *this;
337  }
344  packed_ = packed;
345  return *this;
346  }
352  TVM_ALWAYS_INLINE R operator()(Args... args) const;
357  operator PackedFunc() const { return packed(); }
361  const PackedFunc& packed() const { return packed_; }
363  bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
365  bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; }
366 
367  private:
368  friend class TVMRetValue;
370  PackedFunc packed_;
379  template <typename FLambda>
380  inline void AssignTypedLambda(FLambda flambda, std::string name);
389  template <typename FLambda>
390  inline void AssignTypedLambda(FLambda flambda);
391 };
392 
394 class TVMArgs {
395  public:
396  const TVMValue* values;
397  const int* type_codes;
398  int num_args;
405  TVMArgs(const TVMValue* values, const int* type_codes, int num_args)
408  inline int size() const;
414  inline TVMArgValue operator[](int i) const;
421  template <typename T>
422  inline T At(int i) const;
423 };
424 
430 inline const char* ArgTypeCode2Str(int type_code);
431 
432 inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*)
433 
// Stream fragment "expected <T> but got <CODE>", with both type codes rendered through
// ArgTypeCode2Str; intended to be appended with operator<< to a check/log stream.
434 #define TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) \
435  "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE)
436 
437 // macro to check type code: ICHECK that CODE == T, naming both types on failure.
438 #define TVM_CHECK_TYPE_CODE(CODE, T) ICHECK_EQ(CODE, T) << TVM_LOG_INCORRECT_TYPE_CODE(CODE, T)
439 
444 template <typename T>
454  using ContainerType = typename T::ContainerType;
455  if (ptr == nullptr) {
456  if (T::_type_is_nullable) {
457  return NullOpt;
458  } else {
459  return String("nullptr");
460  }
461  }
462  if (ptr->IsInstance<ContainerType>()) {
463  return NullOpt;
464  } else {
465  return String(ptr->GetTypeKey());
466  }
467  }
473  static bool Check(const Object* ptr) {
474  using ContainerType = typename T::ContainerType;
475  if (ptr == nullptr) return T::_type_is_nullable;
476  return ptr->IsInstance<ContainerType>();
477  }
478  static std::string TypeName() {
479  using ContainerType = typename T::ContainerType;
480  return ContainerType::_type_key;
481  }
482 };
483 
484 // Additional overloads for PackedFunc checking.
485 template <typename T>
488  if (ptr == nullptr) {
489  return NullOpt;
490  }
491  if (!ptr->IsInstance<ArrayNode>()) {
492  return String(ptr->GetTypeKey());
493  }
494 
495  if constexpr (std::is_same_v<T, ObjectRef>) {
496  return NullOpt;
497  }
498 
499  const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
500  for (size_t i = 0; i < n->size(); i++) {
501  const ObjectRef& p = (*n)[i];
503  if (check_subtype.defined()) {
504  return String("Array[index " + std::to_string(i) + ": " + check_subtype.value() + "]");
505  }
506  }
507  return NullOpt;
508  }
509  static bool Check(const Object* ptr) {
510  if (ptr == nullptr) return true;
511  if (!ptr->IsInstance<ArrayNode>()) return false;
512  if constexpr (std::is_same_v<T, ObjectRef>) return true;
513 
514  const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
515  for (const ObjectRef& p : *n) {
516  if (!ObjectTypeChecker<T>::Check(p.get())) {
517  return false;
518  }
519  }
520  return true;
521  }
522  static std::string TypeName() { return "Array[" + ObjectTypeChecker<T>::TypeName() + "]"; }
523 };
524 
525 template <typename K, typename V>
526 struct ObjectTypeChecker<Map<K, V>> {
528  if (ptr == nullptr) return NullOpt;
529  if (!ptr->IsInstance<MapNode>()) return String(ptr->GetTypeKey());
530 
531  if constexpr (std::is_same_v<K, ObjectRef> && std::is_same_v<V, ObjectRef>) {
532  return NullOpt;
533  }
534 
535  const MapNode* n = static_cast<const MapNode*>(ptr);
536  for (const auto& kv : *n) {
537  Optional<String> key_type = NullOpt;
538  if constexpr (!std::is_same_v<K, ObjectRef>) {
539  key_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
540  }
541  Optional<String> value_type = NullOpt;
542  if constexpr (!std::is_same_v<V, ObjectRef>) {
543  value_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
544  }
545  if (key_type.defined() || value_type.defined()) {
546  std::string key_name =
547  key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker<K>::TypeName();
548  std::string value_name = value_type.defined() ? std::string(value_type.value())
550  return String("Map[" + key_name + ", " + value_name + "]");
551  }
552  }
553  return NullOpt;
554  }
555  static bool Check(const Object* ptr) {
556  if (ptr == nullptr) return true;
557  if (!ptr->IsInstance<MapNode>()) return false;
558 
559  if constexpr (std::is_same_v<K, ObjectRef> && std::is_same_v<V, ObjectRef>) {
560  return true;
561  }
562 
563  const MapNode* n = static_cast<const MapNode*>(ptr);
564  for (const auto& kv : *n) {
565  if constexpr (!std::is_same_v<K, ObjectRef>) {
566  if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
567  }
568  if constexpr (!std::is_same_v<V, ObjectRef>) {
569  if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
570  }
571  }
572  return true;
573  }
574  static std::string TypeName() {
576  ']';
577  }
578 };
579 
580 template <typename OnlyVariant>
581 struct ObjectTypeChecker<Variant<OnlyVariant>> {
584  }
585  static bool Check(const Object* ptr) { return ObjectTypeChecker<OnlyVariant>::Check(ptr); }
586  static std::string TypeName() { return "Variant[" + VariantNames() + "]"; }
587  static std::string VariantNames() { return ObjectTypeChecker<OnlyVariant>::TypeName(); }
588 };
589 
590 template <typename FirstVariant, typename... RemainingVariants>
591 struct ObjectTypeChecker<Variant<FirstVariant, RemainingVariants...>> {
594  if (!try_first.defined()) {
595  return try_first;
596  }
597 
598  return ObjectTypeChecker<Variant<RemainingVariants...>>::CheckAndGetMismatch(ptr);
599  }
600  static bool Check(const Object* ptr) {
602  ObjectTypeChecker<Variant<RemainingVariants...>>::Check(ptr);
603  }
604  static std::string TypeName() { return "Variant[" + VariantNames() + "]"; }
605  static std::string VariantNames() {
607  ObjectTypeChecker<Variant<RemainingVariants...>>::VariantNames();
608  }
609 };
610 
616  public:
617  operator void*() const {
618  if (type_code_ == kTVMNullptr) return nullptr;
621  return value_.v_handle;
622  }
623  operator DLTensor*() const {
625  return static_cast<DLTensor*>(value_.v_handle);
626  } else {
627  if (type_code_ == kTVMNullptr) return nullptr;
628  LOG(FATAL) << "Expected "
629  << "DLTensor* or NDArray but got " << ArgTypeCode2Str(type_code_);
630  return nullptr;
631  }
632  }
633  operator NDArray() const {
634  if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr<Object>(nullptr));
637  }
638  operator Module() const {
639  if (type_code_ == kTVMNullptr) {
640  return Module(ObjectPtr<Object>(nullptr));
641  }
643  return Module(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
644  }
645  operator PackedFunc() const {
646  if (type_code_ == kTVMNullptr) {
647  return PackedFunc(ObjectPtr<Object>(nullptr));
648  }
650  return PackedFunc(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
651  }
652  operator Device() const {
654  return value_.v_device;
655  }
656  int type_code() const { return type_code_; }
662  template <typename T>
663  T* ptr() const {
664  return static_cast<T*>(value_.v_handle);
665  }
666 
667  std::optional<bool> TryAsBool() const {
668  // Helper function to reduce duplication in the variable integer
669  // conversions. This is publicly exposed, as it can be useful in
670  // specializations of PackedFuncValueConverter.
671  if (type_code_ == kTVMArgBool) {
672  return static_cast<bool>(value_.v_int64);
673  } else {
674  return std::nullopt;
675  }
676  }
677 
678  std::optional<int64_t> TryAsInt() const {
679  // Helper function to reduce duplication in the variable integer
680  // conversions. This is publicly exposed, as it can be useful in
681  // specializations of PackedFuncValueConverter.
682  if (type_code_ == kDLInt) {
683  return value_.v_int64;
684  } else {
685  return std::nullopt;
686  }
687  }
688 
689  std::optional<double> TryAsFloat() const {
690  // Helper function to reduce duplication in the variable integer
691  // conversions. This is publicly exposed, as it can be useful in
692  // specializations of PackedFuncValueConverter.
693  if (type_code_ == kDLFloat) {
694  return value_.v_float64;
695  } else {
696  return std::nullopt;
697  }
698  }
699 
700  protected:
701  friend class TVMArgsSetter;
702  friend class TVMRetValue;
703  friend class TVMMovableArgValue_;
706 
711 };
712 
737 template <typename Derived>
739  public:
741 
742  // ObjectRef handling
743  template <typename TObjectRef,
744  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
745  inline bool IsObjectRef() const;
746  template <typename TObjectRef>
747  inline TObjectRef AsObjectRef() const;
748 
749  operator double() const {
750  // Allow automatic conversion from int to float
751  // This avoids errors when user pass in int from
752  // the frontend while the API expects a float.
753  if (auto opt = TryAsFloat()) {
754  return opt.value();
755  } else if (auto opt = TryAsInt()) {
756  return opt.value();
757  } else if (auto opt = TryAsBool()) {
758  return opt.value();
759  } else {
760  LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat);
761  }
762  }
763  operator int64_t() const {
764  if (auto opt = TryAsInt()) {
765  return opt.value();
766  } else if (auto opt = TryAsBool()) {
767  return opt.value();
768  } else {
769  LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt);
770  }
771  }
772  operator uint64_t() const { return operator int64_t(); }
773  operator int() const {
774  int64_t value = operator int64_t();
775  ICHECK_LE(value, std::numeric_limits<int>::max());
776  ICHECK_GE(value, std::numeric_limits<int>::min());
777  return value;
778  }
779  operator bool() const {
780  if (auto opt = TryAsBool()) {
781  return opt.value();
782  } else if (auto opt = TryAsInt()) {
783  return opt.value();
784  } else {
785  LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt);
786  }
787  }
788 };
789 
796 class TVMArgValue : public TVMPODValue_CRTP_<TVMArgValue> {
797  public:
806  // reuse converter from parent
807  using TVMPODValue_CRTP_::operator double;
808  using TVMPODValue_CRTP_::operator int64_t;
809  using TVMPODValue_CRTP_::operator uint64_t;
810  using TVMPODValue_CRTP_::operator int;
811  using TVMPODValue_CRTP_::operator bool;
812  using TVMPODValue_::operator void*;
813  using TVMPODValue_::operator DLTensor*;
814  using TVMPODValue_::operator NDArray;
815  using TVMPODValue_::operator Device;
816  using TVMPODValue_::operator Module;
817  using TVMPODValue_::operator PackedFunc;
820 
821  // conversion operator.
822  operator std::string() const {
823  if (type_code_ == kTVMDataType) {
824  return DLDataType2String(operator DLDataType());
825  } else if (type_code_ == kTVMBytes) {
826  TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
827  return std::string(arr->data, arr->size);
828  } else if (type_code_ == kTVMStr) {
829  return std::string(value_.v_str);
830  } else {
831  return AsObjectRef<tvm::runtime::String>().operator std::string();
832  }
833  }
834  template <typename FType>
835  operator TypedPackedFunc<FType>() const {
836  return TypedPackedFunc<FType>(operator PackedFunc());
837  }
838  const TVMValue& value() const { return value_; }
839 
840  template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
841  inline operator T() const;
842  inline operator DLDataType() const;
843  inline operator DataType() const;
844 };
845 
856 class TVMMovableArgValue_ : public TVMPODValue_CRTP_<TVMMovableArgValue_> {
857  public:
859  // reuse converter from parent
860  using TVMPODValue_CRTP_::operator double;
861  using TVMPODValue_CRTP_::operator int64_t;
862  using TVMPODValue_CRTP_::operator uint64_t;
863  using TVMPODValue_CRTP_::operator int;
864  using TVMPODValue_CRTP_::operator bool;
865  using TVMPODValue_::operator void*;
866  using TVMPODValue_::operator DLTensor*;
867  using TVMPODValue_::operator NDArray;
868  using TVMPODValue_::operator Device;
869  using TVMPODValue_::operator Module;
870  using TVMPODValue_::operator PackedFunc;
871  // reuse conversion rule from ArgValue.
872  operator std::string() const { return AsArgValue().operator std::string(); }
873  template <typename FType>
874  operator TypedPackedFunc<FType>() const {
875  return TypedPackedFunc<FType>(operator PackedFunc());
876  }
877  operator DLDataType() const { return AsArgValue().operator DLDataType(); }
878  operator DataType() const { return AsArgValue().operator DataType(); }
879  operator TVMArgValue() const { return AsArgValue(); }
885  template <typename T,
886  typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
887  inline operator T() const;
888 
889  private:
891  TVMArgValue AsArgValue() const { return TVMArgValue(value_, type_code_); }
892 };
893 
902  public:
912  TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index,
913  const std::string* optional_name, FSig* f_sig)
914  : value_(value, type_code),
915  arg_index_(arg_index),
916  optional_name_(optional_name),
917  f_sig_(f_sig) {}
918 
919  template <typename T>
920  operator T() const {
921  try {
922  return value_; // implicit conversion happens here
923  } catch (dmlc::Error& e) {
924  LOG(FATAL) << "In function " << (optional_name_ == nullptr ? "<anonymous>" : *optional_name_)
925  << (f_sig_ == nullptr ? "" : (*f_sig_)()) << ": error while converting argument "
926  << arg_index_ << ": " << e.what();
927  throw; // never reached, LOG(FATAL) throws, but this silences a warning.
928  }
929  }
930 
931  private:
932  TVMMovableArgValue_ value_;
933  int arg_index_;
934  const std::string* optional_name_;
935  FSig* f_sig_;
936 };
937 
946 class TVMRetValue : public TVMPODValue_CRTP_<TVMRetValue> {
947  public:
955  other.value_.v_handle = nullptr;
956  other.type_code_ = kTVMNullptr;
957  }
959  ~TVMRetValue() { this->Clear(); }
960  // reuse converter from parent
961  using TVMPODValue_CRTP_::operator double;
962  using TVMPODValue_CRTP_::operator int64_t;
963  using TVMPODValue_CRTP_::operator uint64_t;
964  using TVMPODValue_CRTP_::operator int;
965  using TVMPODValue_CRTP_::operator bool;
966  using TVMPODValue_::operator void*;
967  using TVMPODValue_::operator DLTensor*;
968  using TVMPODValue_::operator Device;
969  using TVMPODValue_::operator NDArray;
970  using TVMPODValue_::operator Module;
971  using TVMPODValue_::operator PackedFunc;
974 
975  TVMRetValue(const TVMRetValue& other) : TVMPODValue_CRTP_() { this->Assign(other); }
976  // conversion operators
977  operator std::string() const {
978  if (type_code_ == kTVMDataType) {
979  return DLDataType2String(operator DLDataType());
980  } else if (type_code_ == kTVMBytes) {
981  return *ptr<std::string>();
982  }
984  return *ptr<std::string>();
985  }
986  operator DLDataType() const {
987  if (type_code_ == kTVMStr) {
988  return String2DLDataType(operator std::string());
989  }
991  return value_.v_type;
992  }
993  operator DataType() const { return DataType(operator DLDataType()); }
994  template <typename FType>
995  operator TypedPackedFunc<FType>() const {
996  return TypedPackedFunc<FType>(operator PackedFunc());
997  }
998  // Assign operators
1000  this->Clear();
1001  value_ = other.value_;
1002  type_code_ = other.type_code_;
1003  other.type_code_ = kTVMNullptr;
1004  return *this;
1005  }
1007  this->SwitchToPOD(kDLFloat);
1009  return *this;
1010  }
1011  TVMRetValue& operator=(std::nullptr_t value) {
1012  this->SwitchToPOD(kTVMNullptr);
1013  value_.v_handle = value;
1014  return *this;
1015  }
1017  this->SwitchToPOD(kTVMOpaqueHandle);
1018  value_.v_handle = value;
1019  return *this;
1020  }
1022  this->SwitchToPOD(kDLInt);
1023  value_.v_int64 = value;
1024  return *this;
1025  }
1027  this->SwitchToPOD(kDLInt);
1028  value_.v_int64 = value;
1029  return *this;
1030  }
1032  this->SwitchToPOD(kDLDevice);
1033  value_.v_device = value;
1034  return *this;
1035  }
1036  TVMRetValue& operator=(DLDataType t) {
1037  this->SwitchToPOD(kTVMDataType);
1038  value_.v_type = t;
1039  return *this;
1040  }
1041  TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); }
1043  this->SwitchToPOD(kTVMArgBool);
1044  value_.v_int64 = value;
1045  return *this;
1046  }
1047  TVMRetValue& operator=(std::string value) {
1048  this->SwitchToClass(kTVMStr, value);
1049  return *this;
1050  }
1052  this->SwitchToClass(kTVMBytes, std::string(value.data, value.size));
1053  return *this;
1054  }
1056  if (other.data_ != nullptr) {
1057  this->Clear();
1061  } else {
1062  SwitchToPOD(kTVMNullptr);
1063  value_.v_handle = nullptr;
1064  }
1065  return *this;
1066  }
1068  SwitchToObject(kTVMModuleHandle, std::move(m.data_));
1069  return *this;
1070  }
1072  this->SwitchToObject(kTVMPackedFuncHandle, std::move(f.data_));
1073  return *this;
1074  }
1075  template <typename FType>
1077  return operator=(f.packed());
1078  }
1079  TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0
1080  this->Assign(other);
1081  return *this;
1082  }
1084  this->Assign(other);
1085  return *this;
1086  }
1088  this->Assign(other);
1089  return *this;
1090  }
1100  void MoveToCHost(TVMValue* ret_value, int* ret_type_code) {
1101  // cannot move str; need specially handle.
1102  ICHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes);
1103  *ret_value = value_;
1104  *ret_type_code = type_code_;
1106  }
1115  // Can move POD and everything under the object system.
1117  type_code == kTVMArgBool);
1118  TVMRetValue ret;
1119  ret.value_ = value;
1120  ret.type_code_ = type_code;
1121  return ret;
1122  }
1124  const TVMValue& value() const {
1127  << "TVMRetValue.value can only be used for POD data";
1128  return value_;
1129  }
1130  // ObjectRef handling
1131  template <typename TObjectRef,
1132  typename = typename std::enable_if_t<std::is_base_of_v<ObjectRef, TObjectRef>>>
1133  inline TVMRetValue& operator=(TObjectRef other);
1134  template <typename T, typename = typename std::enable_if_t<std::is_class_v<T>>>
1135  inline operator T() const;
1136 
1137  private:
1138  template <typename T>
1139  void Assign(const T& other) {
1140  switch (other.type_code()) {
1141  case kTVMStr: {
1142  SwitchToClass<std::string>(kTVMStr, other);
1143  break;
1144  }
1145  case kTVMBytes: {
1146  SwitchToClass<std::string>(kTVMBytes, other);
1147  break;
1148  }
1149  case kTVMPackedFuncHandle: {
1150  *this = other.operator PackedFunc();
1151  break;
1152  }
1153  case kTVMModuleHandle: {
1154  *this = other.operator Module();
1155  break;
1156  }
1157  case kTVMNDArrayHandle: {
1158  *this = other.operator NDArray();
1159  break;
1160  }
1161  case kTVMObjectHandle: {
1162  // We already known it is not NDArray/Module, but
1163  // operator=(ObjectRef) also handles conversions from wrappers
1164  // around primitive types. For NDArray/Module, the duplicate
1165  // checks are removed with if constexpr.
1166  operator=(other.operator ObjectRef());
1167  break;
1168  }
1169  case kTVMObjectRValueRefArg: {
1170  operator=(other.operator ObjectRef());
1171  break;
1172  }
1173  default: {
1174  SwitchToPOD(other.type_code());
1175  value_ = other.value_;
1176  break;
1177  }
1178  }
1179  }
1180  // get the internal container.
1181  void SwitchToPOD(int type_code) {
1182  if (type_code_ != type_code) {
1183  this->Clear();
1185  }
1186  }
1187  template <typename T>
1188  void SwitchToClass(int type_code, T v) {
1189  if (type_code_ != type_code) {
1190  this->Clear();
1192  value_.v_handle = new T(v);
1193  } else {
1194  *static_cast<T*>(value_.v_handle) = v;
1195  }
1196  }
1197  void SwitchToObject(int type_code, ObjectPtr<Object> other) {
1198  if (other.data_ != nullptr) {
1199  this->Clear();
1201  // move the handle out
1202  value_.v_handle = other.data_;
1203  other.data_ = nullptr;
1204  } else {
1205  SwitchToPOD(kTVMNullptr);
1206  value_.v_handle = nullptr;
1207  }
1208  }
1209  void Clear() {
1210  if (type_code_ == kTVMNullptr) return;
1211  switch (type_code_) {
1212  case kTVMStr:
1213  case kTVMBytes:
1214  delete ptr<std::string>();
1215  break;
1216  case kTVMPackedFuncHandle:
1217  static_cast<Object*>(value_.v_handle)->DecRef();
1218  break;
1219  case kTVMNDArrayHandle: {
1221  break;
1222  }
1223  case kTVMModuleHandle: {
1224  static_cast<Object*>(value_.v_handle)->DecRef();
1225  break;
1226  }
1227  case kTVMObjectHandle: {
1228  static_cast<Object*>(value_.v_handle)->DecRef();
1229  break;
1230  }
1231  }
1233  }
1234 };
1235 
1245 template <typename TObjectRef>
1252  static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef<TObjectRef>(); }
1258  static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef<TObjectRef>(); }
1259 };
1260 
// TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function): declare and define an extern "C"
// symbol `ExportName` with the packed C signature
//   int(TVMValue* args, int* type_code, int num_args, TVMValue* out_value,
//       int* out_type_code, void* resource_handle).
// `Function` is invoked as Function(TVMArgs, TVMRetValue*), and its result is handed back
// through TVMRetValue::MoveToCHost (which ICHECKs that the result is not str/bytes).
// Returns 0 on success; any std::exception is recorded with TVMAPISetLastError and -1 is
// returned. `resource_handle` is not used by the generated body.
1280 #define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \
1281  extern "C" { \
1282  TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
1283  int* out_type_code, void* resource_handle); \
1284  int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
1285  int* out_type_code, void* resource_handle) { \
1286  try { \
1287  ::tvm::runtime::TVMRetValue rv; \
1288  Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \
1289  rv.MoveToCHost(out_value, out_type_code); \
1290  return 0; \
1291  } catch (const ::std::exception& _except_) { \
1292  TVMAPISetLastError(_except_.what()); \
1293  return -1; \
1294  } \
1295  } \
1296  }
1297 
// Module "vtable" helper macros, used together inside a module class (presumably a
// ModuleNode subclass, since they override GetFunction and mark type_key() final):
//   TVM_MODULE_VTABLE_BEGIN(TypeKey)    defines type_key() and opens
//                                       GetFunction(_name, _self), declaring SelfPtr;
//   TVM_MODULE_VTABLE_ENTRY(Name, Fn)   if _name == Name, returns a PackedFunc that
//                                       CHECKs args.size() against Fn's arity and calls
//                                       Fn with each TVMArgValue converted to its parameter;
//   TVM_MODULE_VTABLE_ENTRY_PACKED(Name, Fn)  same dispatch, but Fn receives (args, rv);
//   TVM_MODULE_VTABLE_END()             closes GetFunction, returning a null PackedFunc;
//   TVM_MODULE_VTABLE_END_WITH_DEFAULT(Fn)  closes GetFunction, delegating unmatched
//                                       names to member function Fn(_name).
// The generated lambdas capture `_self` by value; presumably this keeps the module alive
// while the returned PackedFunc exists (confirm against ObjectPtr semantics).
1298 #define TVM_MODULE_VTABLE_BEGIN(TypeKey) \
1299  const char* type_key() const final { return TypeKey; } \
1300  PackedFunc GetFunction(const String& _name, const ObjectPtr<Object>& _self) override { \
1301  using SelfPtr = std::remove_cv_t<decltype(this)>;
1302 #define TVM_MODULE_VTABLE_END() \
1303  return PackedFunc(nullptr); \
1304  }
1305 #define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \
1306  { \
1307  auto f = (MemFunc); \
1308  return (this->*f)(_name); \
1309  } \
1310  } // NOLINT(*)
1311 #define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \
1312  if (_name == Name) { \
1313  return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { \
1314  using Helper = ::tvm::runtime::detail::ModuleVTableEntryHelper<decltype(MemFunc)>; \
1315  SelfPtr self = static_cast<SelfPtr>(_self.get()); \
1316  CHECK_EQ(args.size(), Helper::LenArgs) \
1317  << "Function `" << self->type_key() << "::" << Name << "` requires " << Helper::LenArgs \
1318  << " arguments, but got " << args.size(); \
1319  Helper::Call(rv, self, MemFunc, args, Helper::IndexSeq{}); \
1320  }); \
1321  }
1322 #define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, MemFunc) \
1323  if (_name == Name) { \
1324  return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { \
1325  (static_cast<SelfPtr>(_self.get())->*(MemFunc))(args, rv); \
1326  }); \
1327  }
1328 
// TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function): export an extern "C" symbol with the
// packed C signature (args, type_code, num_args, out_value, out_type_code,
// resource_handle) for an ordinary typed callable. The callable's signature is deduced with
// detail::function_signature, and its arguments are unpacked from TVMArgs by
// detail::unpack_call_by_signature. The result is returned through
// TVMRetValue::MoveToCHost. Returns 0 on success, or -1 after recording any std::exception
// with TVMAPISetLastError.
1364 #define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
1365  extern "C" { \
1366  TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
1367  int* out_type_code, void* resource_handle) { \
1368  try { \
1369  auto f = Function; \
1370  using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \
1371  ::tvm::runtime::TVMRetValue rv; \
1372  ::tvm::runtime::detail::unpack_call_by_signature<FType>::run( \
1373  f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \
1374  rv.MoveToCHost(out_value, out_type_code); \
1375  return 0; \
1376  } catch (const ::std::exception& _except_) { \
1377  TVMAPISetLastError(_except_.what()); \
1378  return -1; \
1379  } \
1380  } \
1381  }
1382 
1383 inline TVMArgValue TVMArgs::operator[](int i) const {
1384  ICHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed"
1385  << " but request arg[" << i << "].";
1386  return TVMArgValue(values[i], type_codes[i]);
1387 }
1388 
1389 inline int TVMArgs::size() const { return num_args; }
1390 
1391 template <class TPackedFuncSubObj>
1393  TVMRetValue* rv) {
1394  (static_cast<const TPackedFuncSubObj*>(obj))->callable_(args, rv);
1395 }
1396 
1397 TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const {
1398  (*f_call_packed_)(this, args, rv);
1399 }
1400 
1401 TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
1402  (static_cast<PackedFuncObj*>(data_.get()))->CallPacked(args, rv);
1403 }
1404 
1405 // internal namespace
1406 inline const char* ArgTypeCode2Str(int type_code) {
1407  switch (type_code) {
1408  case kDLInt:
1409  return "int";
1410  case kTVMArgBool:
1411  return "bool";
1412  case kDLUInt:
1413  return "uint";
1414  case kDLFloat:
1415  return "float";
1416  case kTVMStr:
1417  return "str";
1418  case kTVMBytes:
1419  return "bytes";
1420  case kTVMOpaqueHandle:
1421  return "handle";
1422  case kTVMNullptr:
1423  return "NULL";
1424  case kTVMDLTensorHandle:
1425  return "ArrayHandle";
1426  case kTVMDataType:
1427  return "DLDataType";
1428  case kDLDevice:
1429  return "DLDevice";
1430  case kTVMPackedFuncHandle:
1431  return "FunctionHandle";
1432  case kTVMModuleHandle:
1433  return "ModuleHandle";
1434  case kTVMNDArrayHandle:
1435  return "NDArrayContainer";
1436  case kTVMObjectHandle:
1437  return "Object";
1439  return "ObjectRValueRefArg";
1440  default:
1441  LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
1442  }
1443  throw;
1444 }
1445 
1451 inline const char* DLDeviceType2Str(int type) {
1452  switch (type) {
1453  case kDLCPU:
1454  return "cpu";
1455  case kDLCUDA:
1456  return "cuda";
1457  case kDLCUDAHost:
1458  return "cuda_host";
1459  case kDLCUDAManaged:
1460  return "cuda_managed";
1461  case kDLOpenCL:
1462  return "opencl";
1463  case kDLSDAccel:
1464  return "sdaccel";
1465  case kDLAOCL:
1466  return "aocl";
1467  case kDLVulkan:
1468  return "vulkan";
1469  case kDLMetal:
1470  return "metal";
1471  case kDLVPI:
1472  return "vpi";
1473  case kDLROCM:
1474  return "rocm";
1475  case kDLROCMHost:
1476  return "rocm_host";
1477  case kDLExtDev:
1478  return "ext_dev";
1479  case kDLOneAPI:
1480  return "oneapi";
1481  case kDLWebGPU:
1482  return "webgpu";
1483  case kDLHexagon:
1484  return "hexagon";
1485  case kOpenGL:
1486  return "opengl";
1487  case kDLMicroDev:
1488  return "microdev";
1489  default:
1490  LOG(FATAL) << "unknown type = " << type;
1491  }
1492  throw;
1493 }
1494 
1495 namespace detail {
1496 
/*!
 * \brief Compile-time recursion behind for_each: call `f(kIndex, head)`, then continue
 *  with the remaining arguments at index kIndex + 1.
 * \tparam kDone true once every argument has been visited.
 * \tparam kIndex Position of `head` in the original argument list.
 */
template <bool kDone, std::size_t kIndex, typename F>
struct for_each_dispatcher {
  template <typename Head, typename... Tail>
  static void run(const F& f, Head&& head, Tail&&... tail) {  // NOLINT(*)
    f(kIndex, std::forward<Head>(head));
    constexpr bool kLast = sizeof...(Tail) == 0;
    for_each_dispatcher<kLast, kIndex + 1, F>::run(f, std::forward<Tail>(tail)...);
  }
};

/*! \brief Terminal case: no arguments remain, so nothing is called. */
template <std::size_t kIndex, typename F>
struct for_each_dispatcher<true, kIndex, F> {
  static void run(const F& /*f*/) {}  // NOLINT(*)
};

/*!
 * \brief Call `f(i, arg_i)` for each argument in order, with i counting from 0.
 * \param f Callable invoked as f(std::size_t index, Arg&& value).
 * \param args The arguments, perfectly forwarded to `f`.
 */
template <typename F, typename... Args>
inline void for_each(const F& f, Args&&... args) {  // NOLINT(*)
  constexpr bool kEmpty = sizeof...(Args) == 0;
  for_each_dispatcher<kEmpty, 0, F>::run(f, std::forward<Args>(args)...);
}
1515 
1516 template <typename T>
1517 struct ModuleVTableEntryHelper {};
1518 
1519 template <typename T, typename R, typename... Args>
1520 struct ModuleVTableEntryHelper<R (T::*)(Args...) const> {
1521  using MemFnType = R (T::*)(Args...) const;
1522  using IndexSeq = std::index_sequence_for<Args...>;
1523  static constexpr const std::size_t LenArgs = sizeof...(Args);
1524 
1525  template <std::size_t... Is>
1526  static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1527  std::index_sequence<Is...>) {
1528  *rv = (self->*f)(args[Is]...);
1529  }
1530 };
1531 
1532 template <typename T, typename R, typename... Args>
1533 struct ModuleVTableEntryHelper<R (T::*)(Args...)> {
1534  using MemFnType = R (T::*)(Args...);
1535  using IndexSeq = std::index_sequence_for<Args...>;
1536  static constexpr const std::size_t LenArgs = sizeof...(Args);
1537 
1538  template <std::size_t... Is>
1539  static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1540  std::index_sequence<Is...>) {
1541  *rv = (self->*f)(args[Is]...);
1542  }
1543 };
1544 
1545 template <typename T, typename... Args>
1546 struct ModuleVTableEntryHelper<void (T::*)(Args...) const> {
1547  using MemFnType = void (T::*)(Args...) const;
1548  using IndexSeq = std::index_sequence_for<Args...>;
1549  static constexpr const std::size_t LenArgs = sizeof...(Args);
1550 
1551  template <std::size_t... Is>
1552  static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1553  std::index_sequence<Is...>) {
1554  (self->*f)(args[Is]...);
1555  }
1556 };
1557 
1558 template <typename T, typename... Args>
1559 struct ModuleVTableEntryHelper<void (T::*)(Args...)> {
1560  using MemFnType = void (T::*)(Args...);
1561  using IndexSeq = std::index_sequence_for<Args...>;
1562  static constexpr const std::size_t LenArgs = sizeof...(Args);
1563 
1564  template <std::size_t... Is>
1565  static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
1566  std::index_sequence<Is...>) {
1567  (self->*f)(args[Is]...);
1568  }
1569 };
1570 
namespace parameter_pack {

/*!
 * \brief Invokes a functor template once per enumerated item.
 *
 * Each EnumArgs item exposes a compile-time position `i` and a type `T`.
 * Calls are made in pack order, left to right.
 */
template <typename... EnumArgs>
struct EnumeratedParamPack {
  struct InvokeWithoutArg {
    template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
    static void F(ExtraParams&& extra_params) {
      // Comma fold: one call per item, evaluated left to right.
      (static_cast<void>(
           Functor<EnumArgs::i, typename EnumArgs::T>::F(std::forward<ExtraParams>(extra_params))),
       ...);
    }
  };
  struct InvokeWithArg {
    template <template <size_t i, typename TArgument> class Functor, typename ExtraParams,
              typename... Params>
    static void F(ExtraParams&& extra_params, Params&&... params) {
      // EnumArgs and Params expand in lock-step: item k receives params[k].
      (static_cast<void>(Functor<EnumArgs::i, typename EnumArgs::T>::F(
           std::forward<ExtraParams>(extra_params), std::forward<Params>(params))),
       ...);
    }
  };
};

/*!
 * \brief Pairs each type of Args with its position and exposes the resulting
 *  EnumeratedParamPack invokers.
 */
template <typename... Args>
struct EnumerateImpl {
 private:
  template <size_t Index, typename Type>
  struct Item {
    static const constexpr size_t i = Index;
    using T = Type;
  };

  template <typename...>
  struct Zipper;

  template <std::size_t... Is>
  struct Zipper<std::integer_sequence<std::size_t, Is...>> {
    using Pack = EnumeratedParamPack<Item<Is, Args>...>;
    using WithoutArg = typename Pack::InvokeWithoutArg;
    using WithArg = typename Pack::InvokeWithArg;
  };

  using Zipped = Zipper<std::index_sequence_for<Args...>>;

 public:
  using WithoutArg = typename Zipped::WithoutArg;
  using WithArg = typename Zipped::WithArg;
};

template <typename... Args>
using EnumerateWithoutArg = typename EnumerateImpl<Args...>::WithoutArg;

template <typename... Args>
using EnumerateWithArg = typename EnumerateImpl<Args...>::WithArg;

/*! \brief Entry point: run Functor<i, Arg_i>::F(extra_params) for every Arg_i. */
template <typename... Args>
struct ParamPack {
  template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
  static void InvokeWithoutArg(ExtraParams&& extra_params) {
    using Invoker = EnumerateWithoutArg<Args...>;
    Invoker::template F<Functor, ExtraParams>(std::forward<ExtraParams>(extra_params));
  }
};

}  // namespace parameter_pack
1640 
1645 template <typename T>
1646 struct func_signature_helper {
1647  using FType = void;
1648 };
1649 
1650 template <typename T, typename R, typename... Args>
1651 struct func_signature_helper<R (T::*)(Args...)> {
1652  using FType = R(Args...);
1653  using ParamType = parameter_pack::ParamPack<Args...>;
1654  using RetType = R;
1655  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1656 };
1657 
1658 template <typename T, typename R, typename... Args>
1659 struct func_signature_helper<R (T::*)(Args...) const> {
1660  using FType = R(Args...);
1661  using ParamType = parameter_pack::ParamPack<Args...>;
1662  using RetType = R;
1663  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1664 };
1665 
1670 template <typename T>
1671 struct function_signature {
1672  using FType = typename func_signature_helper<decltype(&T::operator())>::FType;
1673  using ParamType = typename func_signature_helper<decltype(&T::operator())>::ParamType;
1674  using RetType = typename func_signature_helper<decltype(&T::operator())>::RetType;
1675 };
1676 
1677 // handle case of function.
1678 template <typename R, typename... Args>
1679 struct function_signature<R(Args...)> {
1680  using FType = R(Args...);
1681  using ParamType = parameter_pack::ParamPack<Args...>;
1682  using RetType = R;
1683  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1684 };
1685 
1686 // handle case of function ptr.
1687 template <typename R, typename... Args>
1688 struct function_signature<R (*)(Args...)> {
1689  using FType = R(Args...);
1690  using ParamType = detail::parameter_pack::ParamPack<Args...>;
1691  using RetType = R;
1692  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
1693 };
1694 
1695 template <typename TSignature>
1696 struct SignaturePrinter;
1697 
1698 namespace type2str {
1699 
1700 template <typename T>
1701 struct TypeSimplifier;
1702 
1703 template <typename T>
1704 struct Type2Str {
1705  template <typename = std::enable_if_t<std::is_base_of<ObjectRef, T>::value>>
1706  static std::string v() {
1707  return T::ContainerType::_type_key;
1708  }
1709 };
1710 template <>
1711 struct Type2Str<int> {
1712  static std::string v() { return "int"; }
1713 };
1714 template <>
1715 struct Type2Str<double> {
1716  static std::string v() { return "double"; }
1717 };
1718 template <>
1719 struct Type2Str<int64_t> {
1720  static std::string v() { return "int64_t"; }
1721 };
1722 template <>
1723 struct Type2Str<uint64_t> {
1724  static std::string v() { return "uint64_t"; }
1725 };
1726 template <>
1727 struct Type2Str<bool> {
1728  static std::string v() { return "bool"; }
1729 };
1730 template <>
1731 struct Type2Str<void> {
1732  static std::string v() { return "void"; }
1733 };
1734 template <>
1735 struct Type2Str<std::basic_string<char>> {
1736  static std::string v() { return "basic_string<char>"; }
1737 };
1738 template <typename K, typename V>
1739 struct Type2Str<Map<K, V>> {
1740  static std::string v() {
1741  return "Map<" + TypeSimplifier<K>::v() + ", " + TypeSimplifier<V>::v() + ">";
1742  }
1743 };
1744 template <>
1745 struct Type2Str<DLDevice> {
1746  static std::string v() { return "DLDevice"; }
1747 };
1748 template <>
1749 struct Type2Str<DLTensor> {
1750  static std::string v() { return "DLTensor"; }
1751 };
1752 template <>
1753 struct Type2Str<DataType> {
1754  static std::string v() { return "DataType"; }
1755 };
1756 template <>
1757 struct Type2Str<DLDataType> {
1758  static std::string v() { return "DLDataType"; }
1759 };
1760 template <>
1761 struct Type2Str<TVMRetValue> {
1762  static std::string v() { return "TVMRetValue"; }
1763 };
1764 template <>
1765 struct Type2Str<TVMArgValue> {
1766  static std::string v() { return "TVMArgValue"; }
1767 };
1768 template <>
1769 struct Type2Str<TVMByteArray> {
1770  static std::string v() { return "TVMByteArray"; }
1771 };
1772 template <typename FType>
1773 struct Type2Str<TypedPackedFunc<FType>> {
1774  static std::string v() { return SignaturePrinter<function_signature<FType>>::F(); }
1775 };
1776 template <typename T>
1777 struct Type2Str<Array<T>> {
1778  static std::string v() { return "Array<" + TypeSimplifier<T>::v() + ">"; }
1779 };
1780 
1785 template <typename T>
1786 struct TypeSimplifier {
1787  static std::string v() {
1788  using U = typename std::remove_cv<
1789  typename std::remove_reference<typename std::remove_pointer<T>::type>::type>::type;
1790  return (std::is_const<T>::value ? "const " : "") + Type2Str<U>::v() +
1791  (std::is_pointer<T>::value ? "*" : "") + (std::is_reference<T>::value ? "&" : "");
1792  }
1793 };
1794 
1795 } // namespace type2str
1796 
1801 template <typename TSignature>
1802 struct SignaturePrinter {
1803  using ParamType = typename TSignature::ParamType;
1804  using RetType = typename TSignature::RetType;
1805 
1806  template <size_t i, typename TArgument>
1807  struct PrintParamType {
1808  static void F(std::ostream& os) {
1809  os << (i == 0 ? "" : ", ") << i << ": " << type2str::TypeSimplifier<TArgument>::v();
1810  }
1811  };
1812 
1813  static std::string F() {
1814  std::ostringstream oss;
1815  oss << "(";
1816  ParamType::template InvokeWithoutArg<PrintParamType>(oss);
1817  oss << ") -> " << type2str::TypeSimplifier<RetType>::v();
1818  return oss.str();
1819  }
1820 };
1821 } // namespace detail
1822 
1823 /* \brief argument settter to PackedFunc */
1825  public:
1826  TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {}
1827  // setters for POD types
1828  template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
1829  TVM_ALWAYS_INLINE void operator()(size_t i, T value) const {
1830  values_[i].v_int64 = static_cast<int64_t>(value);
1831  type_codes_[i] = kDLInt;
1832  }
1833  TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const {
1834  values_[i].v_int64 = value;
1835  type_codes_[i] = kTVMArgBool;
1836  }
1837  TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
1838  values_[i].v_int64 = static_cast<int64_t>(value);
1839  ICHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
1840  type_codes_[i] = kDLInt;
1841  }
1842  TVM_ALWAYS_INLINE void operator()(size_t i, double value) const {
1843  values_[i].v_float64 = value;
1844  type_codes_[i] = kDLFloat;
1845  }
1846  TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const {
1847  values_[i].v_handle = value;
1848  type_codes_[i] = kTVMNullptr;
1849  }
1850  TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const {
1851  values_[i] = value.value_;
1852  type_codes_[i] = value.type_code_;
1853  }
1854  TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const {
1855  values_[i].v_handle = value;
1856  type_codes_[i] = kTVMOpaqueHandle;
1857  }
1858  TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const {
1859  values_[i].v_handle = value;
1860  type_codes_[i] = kTVMDLTensorHandle;
1861  }
1862  TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const {
1863  values_[i].v_device = value;
1864  type_codes_[i] = kDLDevice;
1865  }
1866  TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const {
1867  values_[i].v_type = value;
1868  type_codes_[i] = kTVMDataType;
1869  }
1870  TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const {
1871  operator()(i, dtype.operator DLDataType());
1872  }
1873  TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const {
1874  values_[i].v_str = value;
1875  type_codes_[i] = kTVMStr;
1876  }
1877  // setters for container types
1878  TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const {
1879  values_[i].v_str = value.c_str();
1880  type_codes_[i] = kTVMStr;
1881  }
1882  TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const {
1883  values_[i].v_handle = const_cast<TVMByteArray*>(&value);
1884  type_codes_[i] = kTVMBytes;
1885  }
1886  template <typename FType>
1887  TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
1888  operator()(i, value.packed());
1889  }
1890  void operator()(size_t i, const TVMRetValue& value) const {
1891  if (value.type_code() == kTVMStr) {
1892  values_[i].v_str = value.ptr<std::string>()->c_str();
1893  type_codes_[i] = kTVMStr;
1894  } else {
1895  ICHECK_NE(value.type_code(), kTVMBytes) << "not handled.";
1896  values_[i] = value.value_;
1897  type_codes_[i] = value.type_code();
1898  }
1899  }
1900  // ObjectRef handling
1901  template <typename TObjectRef,
1902  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
1903  TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const {
1904  this->SetObject(i, value);
1905  }
1906 
1907  template <typename TObjectRef,
1908  typename = typename std::enable_if<std::is_base_of<
1909  ObjectRef, typename std::remove_reference<TObjectRef>::type>::value>::type>
1910  TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const {
1911  this->SetObject(i, std::forward<TObjectRef>(value));
1912  }
1913 
1914  private:
1915  template <typename TObjectRef>
1916  inline void SetObject(size_t i, TObjectRef&& value) const;
1918  TVMValue* values_;
1920  int* type_codes_;
1921 };
1922 
1923 template <typename... Args>
1924 inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
1925  const int kNumArgs = sizeof...(Args);
1926  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1927  TVMValue values[kArraySize];
1928  int type_codes[kArraySize];
1929  detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
1930  TVMRetValue rv;
1931  (static_cast<PackedFuncObj*>(data_.get()))
1932  ->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
1933  return rv;
1934 }
1935 
1936 template <size_t i, typename T>
1938  static TVM_ALWAYS_INLINE void F(TVMArgsSetter* setter, T&& value) {
1939  (*setter)(i, std::forward<T>(value));
1940  }
1941 };
1942 
1943 template <typename... Args>
1944 void TVM_ALWAYS_INLINE PackArgs(TVMValue* values, int* type_codes, Args&&... args) {
1945  TVMArgsSetter setter(values, type_codes);
1946  detail::parameter_pack::EnumerateWithArg<Args...>::template F<TVMArgsSetterApply>(
1947  &setter, std::forward<Args>(args)...);
1948 }
1949 
1950 namespace detail {
1951 template <typename R, int nleft, int index, typename F>
1952 struct unpack_call_dispatcher {
1953  template <typename... Args>
1954  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
1955  const TVMArgs& args_pack, TVMRetValue* rv,
1956  Args&&... unpacked_args) {
1957  // construct a movable argument value
1958  // which allows potential move of argument to the input of F.
1959  unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
1960  optional_name, f_sig, f, args_pack, rv, std::forward<Args>(unpacked_args)...,
1961  TVMMovableArgValueWithContext_(args_pack.values[index], args_pack.type_codes[index], index,
1962  optional_name, f_sig));
1963  }
1964 };
1965 
1966 template <typename R, int index, typename F>
1967 struct unpack_call_dispatcher<R, 0, index, F> {
1968  template <typename... Args>
1969  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
1970  const TVMArgs& args_pack, TVMRetValue* rv,
1971  Args&&... unpacked_args) {
1972  using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
1973  if (std::is_same<RetType, R>::value) {
1974  *rv = f(std::forward<Args>(unpacked_args)...);
1975  } else {
1976  *rv = R(f(std::forward<Args>(unpacked_args)...));
1977  }
1978  }
1979 };
1980 
1981 template <int index, typename F>
1982 struct unpack_call_dispatcher<void, 0, index, F> {
1983  template <typename... Args>
1984  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, FSig* f_sig, const F& f,
1985  const TVMArgs& args_pack, TVMRetValue* rv,
1986  Args&&... unpacked_args) {
1987  f(std::forward<Args>(unpacked_args)...);
1988  }
1989 };
1990 
1991 template <typename R, int nargs, typename F>
1992 TVM_ALWAYS_INLINE void unpack_call(const std::string* optional_name, const F& f,
1993  const TVMArgs& args, TVMRetValue* rv) {
1994  FSig* f_sig = detail::SignaturePrinter<detail::function_signature<F>>::F;
1995  CHECK_EQ(nargs, args.size()) << "Function "
1996  << (optional_name == nullptr ? "<anonymous>" : *optional_name)
1997  << (f_sig == nullptr ? "" : (*f_sig)()) << " expects " << nargs
1998  << " arguments but " << args.size() << " were provided";
1999  unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, args, rv);
2000 }
2001 
2002 template <typename FType>
2003 struct unpack_call_by_signature {};
2004 
2005 template <typename R, typename... Args>
2006 struct unpack_call_by_signature<R(Args...)> {
2007  template <typename F>
2008  TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) {
2009  unpack_call<R, sizeof...(Args)>(nullptr, f, args, rv);
2010  }
2011 };
2012 
2013 template <typename R, typename... Args>
2014 TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) {
2015  return R(pf(std::forward<Args>(args)...));
2016 }
2017 
2018 template <typename R>
2019 struct typed_packed_call_dispatcher {
2020  template <typename... Args>
2021  TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) {
2022  return pf(std::forward<Args>(args)...);
2023  }
2024 };
2025 
2026 template <>
2027 struct typed_packed_call_dispatcher<void> {
2028  template <typename... Args>
2029  TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... args) {
2030  pf(std::forward<Args>(args)...);
2031  }
2032 };
2033 } // namespace detail
2034 
2035 template <typename R, typename... Args>
2036 TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) : packed_(packed) {}
2037 
2038 template <typename R, typename... Args>
2040  : packed_(value.operator PackedFunc()) {}
2041 
2042 template <typename R, typename... Args>
2044  : packed_(value.operator PackedFunc()) {}
2045 
2046 template <typename R, typename... Args>
2048  : packed_(value.operator PackedFunc()) {}
2049 
2050 template <typename R, typename... Args>
2051 template <typename FType>
2052 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, std::string name) {
2053  FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
2054  packed_ = PackedFunc([flambda, name, f_sig](const TVMArgs& args, TVMRetValue* rv) {
2055  if (args.size() != sizeof...(Args)) {
2056  LOG(FATAL) << "Function " << name << (f_sig == nullptr ? "" : (*f_sig)()) << " expects "
2057  << sizeof...(Args) << " arguments, but " << args.size() << " were provided.";
2058  }
2059  detail::unpack_call<R, sizeof...(Args)>(&name, flambda, args, rv);
2060  });
2061 }
2062 
2063 template <typename R, typename... Args>
2064 template <typename FType>
2065 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
2066  FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
2067  packed_ = PackedFunc([flambda, f_sig](const TVMArgs& args, TVMRetValue* rv) {
2068  if (args.size() != sizeof...(Args)) {
2069  LOG(FATAL) << "Function <anonymous> " << (*f_sig)() << " expects " << sizeof...(Args)
2070  << " arguments, but " << args.size() << " were provided.";
2071  }
2072  detail::unpack_call<R, sizeof...(Args)>(nullptr, flambda, args, rv);
2073  });
2074 }
2075 
2076 template <typename R, typename... Args>
2077 TVM_ALWAYS_INLINE R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
2078  return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...);
2079 }
2080 
2081 template <typename T>
2082 inline T TVMArgs::At(int i) const {
2083  TVMArgValue arg = operator[](i);
2084  try {
2085  return arg.operator T();
2086  } catch (const dmlc::Error& e) {
2087  LOG(FATAL) << "Argument " << i << " cannot be converted to type \""
2088  << tvm::runtime::detail::type2str::Type2Str<T>::v() << "\". Its type is \""
2089  << tvm::runtime::ArgTypeCode2Str(arg.type_code()) << "\".";
2090  }
2091  throw;
2092 }
2093 
2094 // ObjectRef related conversion handling
2095 // Object can have three possible type codes:
2096 // kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle
2097 //
2098 // We use type traits to eliminate un-necessary checks.
2099 template <typename T>
2100 inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
2101  using ContainerType = typename std::remove_reference<T>::type::ContainerType;
2102  if (!value.defined()) {
2103  type_codes_[i] = kTVMNullptr;
2104  values_[i].v_handle = nullptr;
2105  return;
2106  }
2107 
2108  Object* ptr = value.data_.data_;
2109  if constexpr (std::is_base_of_v<NDArray::ContainerType, ContainerType> ||
2110  std::is_base_of_v<ContainerType, NDArray::ContainerType>) {
2111  if (std::is_base_of_v<NDArray::ContainerType, ContainerType> ||
2113  values_[i].v_handle = NDArray::FFIGetHandle(value);
2114  type_codes_[i] = kTVMNDArrayHandle;
2115  return;
2116  }
2117  }
2118 
2119  if constexpr (std::is_base_of_v<Module::ContainerType, ContainerType> ||
2120  std::is_base_of_v<ContainerType, Module::ContainerType>) {
2121  if (std::is_base_of_v<Module::ContainerType, ContainerType> ||
2123  values_[i].v_handle = ptr;
2124  type_codes_[i] = kTVMModuleHandle;
2125  return;
2126  }
2127  }
2128 
2129  if constexpr (std::is_base_of_v<PackedFunc::ContainerType, ContainerType> ||
2130  std::is_base_of_v<ContainerType, PackedFunc::ContainerType>) {
2131  if (std::is_base_of_v<PackedFunc::ContainerType, ContainerType> ||
2133  values_[i].v_handle = ptr;
2134  type_codes_[i] = kTVMPackedFuncHandle;
2135  return;
2136  }
2137  }
2138 
2139  // Like with BoxInt, unwrap any BoxBool instances. See the BoxInt
2140  // explanation for more detail.
2141  if constexpr (std::is_base_of_v<Bool::ContainerType, ContainerType> ||
2142  std::is_base_of_v<ContainerType, Bool::ContainerType>) {
2143  if (std::is_base_of_v<Bool::ContainerType, ContainerType> ||
2144  ptr->IsInstance<Bool::ContainerType>()) {
2145  values_[i].v_int64 = static_cast<Bool::ContainerType*>(ptr)->value;
2146  type_codes_[i] = kTVMArgBool;
2147  return;
2148  }
2149  }
2150 
2151  // If a boxed integer is being returned, always unbox it to the
2152  // primitive type. This must be checked at the PackedFunc level to
2153  // ensure that a boxed primitive argument is round-tripped correctly
2154  // when the boxing is no longer required.
2155  //
2156  // For example, consider a PackedFunc with signature `ObjectRef
2157  // func(Array<ObjectRef>)`, and returns the first element of that
2158  // array. When passing a Python array `[5, 17.5, "hello"]`, the
2159  // items are converted to `[Box<i64>(5), Box<double>(17.5),
2160  // String("hello")]` in order to provide an `Array<ObjectRef>`.
2161  //
2162  // If we had no additional conversions, the caller would receive the
2163  // return value as a `Box<i64>(5)`, which would be unexpected and
2164  // require additional unwrapping. We could perform this check
2165  // inside the PackedFunc, but that would require a large amount of
2166  // duplicated checked, and would require explicit handling of
2167  // `TVMRetValue`. Instead, this conversion is checked in the FFI
2168  // return value, to ensure that boxing/unboxing is applied
2169  // consistently.
2170  if constexpr (std::is_base_of_v<Int::ContainerType, ContainerType> ||
2171  std::is_base_of_v<ContainerType, Int::ContainerType>) {
2172  if (std::is_base_of_v<Int::ContainerType, ContainerType> ||
2173  ptr->IsInstance<Int::ContainerType>()) {
2174  values_[i].v_int64 = static_cast<Int::ContainerType*>(ptr)->value;
2175  type_codes_[i] = kTVMArgInt;
2176  return;
2177  }
2178  }
2179 
2180  // Like with BoxInt, unwrap any BoxFloat instances. See the BoxInt
2181  // explanation for more detail.
2182  if constexpr (std::is_base_of_v<Float::ContainerType, ContainerType> ||
2183  std::is_base_of_v<ContainerType, Float::ContainerType>) {
2184  if (std::is_base_of_v<Float::ContainerType, ContainerType> ||
2185  ptr->IsInstance<Float::ContainerType>()) {
2186  values_[i].v_float64 = static_cast<Float::ContainerType*>(ptr)->value;
2187  type_codes_[i] = kTVMArgFloat;
2188  return;
2189  }
2190  }
2191 
2192  // Final fallback, if the ObjectRef has no special cases that must
2193  // be expressed within the TVMRetValue.
2194  if constexpr (std::is_rvalue_reference_v<decltype(value)>) {
2195  values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
2196  type_codes_[i] = kTVMObjectRValueRefArg;
2197  } else {
2198  values_[i].v_handle = value.data_.data_;
2199  type_codes_[i] = kTVMObjectHandle;
2200  }
2201 }
2202 
2203 template <typename Derived>
2204 template <typename TObjectRef, typename>
2206  using ContainerType = typename TObjectRef::ContainerType;
2207  // NOTE: the following code can be optimized by constant folding.
2208  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
2209  return type_code_ == kTVMNDArrayHandle &&
2210  TVMArrayHandleToObjectHandle(static_cast<TVMArrayHandle>(value_.v_handle))
2211  ->IsInstance<ContainerType>();
2212  }
2213  if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
2214  return type_code_ == kTVMModuleHandle &&
2215  static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
2216  }
2217  if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
2218  return type_code_ == kTVMPackedFuncHandle &&
2219  static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
2220  }
2221  // NOTE: we don't pass NDArray and runtime::Module as RValue ref.
2222  if (type_code_ == kTVMObjectRValueRefArg) {
2223  return ObjectTypeChecker<TObjectRef>::Check(*static_cast<Object**>(value_.v_handle));
2224  }
2225  return (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
2226  type_code_ == kTVMNDArrayHandle) ||
2227  (std::is_base_of<ContainerType, Module::ContainerType>::value &&
2228  type_code_ == kTVMModuleHandle) ||
2229  (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
2230  type_code_ == kTVMPackedFuncHandle) ||
2231  (type_code_ == kTVMObjectHandle &&
2232  ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
2233 }
2234 
2235 template <typename Derived>
2236 template <typename TObjectRef>
2237 inline TObjectRef TVMPODValue_CRTP_<Derived>::AsObjectRef() const {
2238  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
2239  "Conversion only works for ObjectRef");
2240  using ContainerType = typename TObjectRef::ContainerType;
2241 
2242  if (type_code_ == kTVMNullptr) {
2243  CHECK(TObjectRef::_type_is_nullable)
2244  << "Expect a not null value of " << ContainerType::_type_key;
2245  return TObjectRef(ObjectPtr<Object>(nullptr));
2246  }
2247 
2248  // NOTE: The following code uses "if constexpr" wherever possible to
2249  // minimize the number of runtime checks.
2250  if constexpr (std::is_base_of_v<NDArray::ContainerType, ContainerType>) {
2251  // Casting to a sub-class of NDArray
2253  ObjectPtr<Object> data =
2254  NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle));
2255  CHECK(data->IsInstance<ContainerType>())
2256  << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
2257  return TObjectRef(data);
2258  }
2259 
2260  if constexpr (std::is_base_of_v<Module::ContainerType, ContainerType>) {
2261  // Casting to a sub-class of Module
2263  ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
2264  CHECK(data->IsInstance<ContainerType>())
2265  << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
2266  return TObjectRef(data);
2267  }
2268 
2269  if constexpr (std::is_base_of_v<PackedFunc::ContainerType, ContainerType>) {
2270  // Casting to a sub-class of PackedFunc
2272  ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
2273  CHECK(data->IsInstance<ContainerType>())
2274  << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
2275  return TObjectRef(data);
2276  }
2277 
2278  if (type_code_ == kTVMObjectHandle) {
2279  // normal object type check.
2280  Object* ptr = static_cast<Object*>(value_.v_handle);
2282  ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName()
2283  << ", but got " << checked_type.value();
2284  return TObjectRef(GetObjectPtr<Object>(ptr));
2285  } else if (type_code_ == kTVMObjectRValueRefArg) {
2286  Object* ptr = *static_cast<Object**>(value_.v_handle);
2288  ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName()
2289  << ", but got " << checked_type.value();
2290  return TObjectRef(GetObjectPtr<Object>(ptr));
2291  }
2292 
2293  if constexpr (std::is_base_of_v<ContainerType, NDArray::ContainerType>) {
2294  if (type_code_ == kTVMNDArrayHandle) {
2295  // Casting to a base class that NDArray can sub-class
2296  ObjectPtr<Object> data =
2297  NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle));
2298  return TObjectRef(data);
2299  }
2300  }
2301 
2302  if constexpr (std::is_base_of_v<ContainerType, Module::ContainerType>) {
2303  if (type_code_ == kTVMModuleHandle) {
2304  // Casting to a base class that Module can sub-class
2305  return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
2306  }
2307  }
2308 
2309  if constexpr (std::is_base_of_v<ContainerType, PackedFunc::ContainerType>) {
2310  if (type_code_ == kTVMPackedFuncHandle) {
2311  // Casting to a base class that PackedFunc can sub-class
2312  return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
2313  }
2314  }
2315 
2316  if constexpr (std::is_base_of_v<TObjectRef, Int>) {
2317  if (type_code_ == kTVMArgInt) {
2318  return Int(value_.v_int64);
2319  }
2320  }
2321 
2322  if constexpr (std::is_base_of_v<TObjectRef, Float>) {
2323  if (type_code_ == kTVMArgFloat) {
2324  return Float(value_.v_float64);
2325  }
2326  }
2327 
2328  if constexpr (std::is_base_of_v<TObjectRef, Bool>) {
2329  if (type_code_ == kTVMArgBool) {
2330  return Bool(value_.v_int64);
2331  }
2332  }
2333 
2334  if constexpr (std::is_base_of_v<TObjectRef, String>) {
2335  if (type_code_ == kTVMStr || type_code_ == kTVMBytes) {
2336  // This step is the reason why `AsObjectRef` cannot be provided
2337  // in the base `TVMPODValue_` class. Because `TVMArgValue` and
2338  // `TVMRetValue` have different implementations of `operator
2339  // std::string`, with different interpretations of `kTVMStr` and
2340  // `kTVMBytes`, we must delegate to those implementations.
2341  //
2342  // This could be done with a pure virtual method in
2343  // `TVMPODValue_`, but that would require a vtable lookup during
2344  // FFI conversions, imposing a runtime overhead.
2345  return String(static_cast<const Derived*>(this)->operator std::string());
2346  }
2347  }
2348 
2350  return TObjectRef(ObjectPtr<Object>(nullptr));
2351 }
2352 
2353 template <typename TObjectRef, typename>
2354 inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
2355  using ContainerType = typename TObjectRef::ContainerType;
2356  const Object* ptr = other.get();
2357 
2358  if (ptr) {
2359  // Check for special cases of ObjectRef that have explicit
2360  // representation within the TVMRetValue structure.
2361  // (e.g. Unboxing of `runtime::Int` into a primitive integer
2362  // with type code kTVMArgInt.) The checks below are written to
2363  // handle three distinct cases.
2364  //
2365  // 1. If TObjectRef is a subclass of TSpecialCase, the special
2366  // case applies, and can be handled without a runtime check.
2367  // No runtime checks should be performed.
2368  //
2369  // 2. If TSpecialCase is a subclass of TObjectRef, the special
2370  // case might apply, and requires a runtime check.
2371  //
2372  // 3. If neither TObjectRef nor TSpecialCase is a subclass of
2373  // the other, then the special case does not apply. No
2374  // runtime checks should be performed.
2375  //
2376  // Use of `if constexpr` ensures that the C++ subclass checks
2377  // are applied when compiling TVM, and runtime overhead are only
2378  // present when they may be applicable.
2379 
2380  if constexpr (std::is_base_of_v<ContainerType, NDArray::ContainerType> ||
2381  std::is_base_of_v<NDArray::ContainerType, ContainerType>) {
2382  if (std::is_base_of_v<NDArray::ContainerType, ContainerType> ||
2383  ptr->IsInstance<NDArray::ContainerType>()) {
2384  return operator=(NDArray(std::move(other.data_)));
2385  }
2386  }
2387 
2388  if constexpr (std::is_base_of_v<ContainerType, Module::ContainerType> ||
2389  std::is_base_of_v<Module::ContainerType, ContainerType>) {
2390  if (std::is_base_of_v<Module::ContainerType, ContainerType> ||
2391  ptr->IsInstance<Module::ContainerType>()) {
2392  return operator=(Module(std::move(other.data_)));
2393  }
2394  }
2395 
2396  if constexpr (std::is_base_of_v<ContainerType, PackedFunc::ContainerType> ||
2397  std::is_base_of_v<PackedFunc::ContainerType, ContainerType>) {
2398  if (std::is_base_of_v<PackedFunc::ContainerType, ContainerType> ||
2399  ptr->IsInstance<PackedFunc::ContainerType>()) {
2400  return operator=(PackedFunc(std::move(other.data_)));
2401  }
2402  }
2403 
2404  if constexpr (std::is_base_of_v<Bool, TObjectRef> || std::is_base_of_v<TObjectRef, Bool>) {
2405  if (std::is_base_of_v<Bool, TObjectRef> || ptr->IsInstance<Bool::ContainerType>()) {
2406  bool value = static_cast<const Bool::ContainerType*>(ptr)->value;
2407  return operator=(value);
2408  }
2409  }
2410 
2411  if constexpr (std::is_base_of_v<Int, TObjectRef> || std::is_base_of_v<TObjectRef, Int>) {
2412  if (std::is_base_of_v<Int, TObjectRef> || ptr->IsInstance<Int::ContainerType>()) {
2413  int64_t value = static_cast<const Int::ContainerType*>(ptr)->value;
2414  return operator=(value);
2415  }
2416  }
2417 
2418  if constexpr (std::is_base_of_v<Float, TObjectRef> || std::is_base_of_v<TObjectRef, Float>) {
2419  if (std::is_base_of_v<Float, TObjectRef> || ptr->IsInstance<Float::ContainerType>()) {
2420  double value = static_cast<const Float::ContainerType*>(ptr)->value;
2421  return operator=(value);
2422  }
2423  }
2424 
2425  // If the object being stored is not one of the special cases,
2426  // it is stored as an ObjectRef.
2427  SwitchToObject(kTVMObjectHandle, std::move(other.data_));
2428 
2429  } else {
2430  // No object is present, set to an explicitly null handle. When
2431  // returning to a Python callee, this will be converted to
2432  // `None`.
2433  SwitchToPOD(kTVMNullptr);
2434  value_.v_handle = nullptr;
2435  }
2436 
2437  return *this;
2438 }
2439 
2440 template <typename T, typename>
2441 inline TVMArgValue::operator T() const {
2442  return PackedFuncValueConverter<T>::From(*this);
2443 }
2444 
2445 template <typename T, typename>
2446 inline TVMMovableArgValue_::operator T() const {
2447  if (type_code_ == kTVMObjectRValueRefArg) {
2448  auto** ref = static_cast<Object**>(value_.v_handle);
2449  if (ObjectTypeChecker<T>::Check(*ref)) {
2451  }
2452  }
2453  // fallback
2454  return PackedFuncValueConverter<T>::From(AsArgValue());
2455 }
2456 
2457 template <typename T, typename>
2458 inline TVMRetValue::operator T() const {
2459  return PackedFuncValueConverter<T>::From(*this);
2460 }
2461 
2462 inline PackedFunc Module::GetFunction(const String& name, bool query_imports) {
2463  return (*this)->GetFunction(name, query_imports);
2464 }
2465 
2466 // specializations of PackedFuncValueConverter
2467 template <>
2469  template <typename PODSubclass>
2470  static String From(const PODSubclass& val) {
2471  if (val.template IsObjectRef<tvm::runtime::String>()) {
2472  return val.template AsObjectRef<tvm::runtime::String>();
2473  } else {
2474  return tvm::runtime::String(val.operator std::string());
2475  }
2476  }
2477 };
2478 
2479 template <typename T>
2481  static Array<T> From(const TVMArgValue& val) {
2482  auto untyped_array = val.AsObjectRef<Array<ObjectRef>>();
2483 
2484  if constexpr (std::is_same_v<T, ObjectRef>) {
2485  return untyped_array;
2486  }
2487 
2488  // Attempt to convert each item of the array into the desired
2489  // type. If the items do not require a conversion, no copies are
2490  // made.
2491  return untyped_array.Map([](ObjectRef item) {
2492  // Recursively apply any conversions that have been registered
2493  // with TVM's FFI.
2494  //
2495  // For example, a function that accepts `Array<PrimExpr>` may
2496  // be called from python with argument `[1,2]`. By the time
2497  // `PackedFuncValueConverter::From` is called, the python list
2498  // has been converted to `Array<ObjectRef>`, with contents
2499  // converted into `runtime::Int`. Converting the `ObjectRef`
2500  // to `TVMArgValue` unboxes the `runtime::Int` back into a
2501  // primitive with type code `kTVMArgInt`. This primitive can
2502  // then be converted to a PrimExpr using
2503  // `PackedFuncValueConverter<PrimExpr>::From`.
2504  //
2505  // The use of two conversions, first from python `int` to
2506  // `runtime::Int` and then from `runtime::Int` to `PrimExpr`,
2507  // is a result of the split between `libtvm_runtime.so` and
2508  // `libtvm.so`. The FFI must function correctly in both
2509  // cases, and so conversions applied by default in the Python
2510  // FFI implementation may only produce types that are
2511  // available in both libraries. In the C++ FFI implementation
2512  // (i.e. this file), libtvm.so may apply additional
2513  // conversions that are not present in libtvm_runtime.so.
2514  TVMValue value;
2515  int type_code;
2516  TVMArgsSetter setter(&value, &type_code);
2517  setter(0, item);
2518  TVMArgValue arg(value, type_code);
2520  });
2521  }
2522  static Array<T> From(const TVMRetValue& val) {
2523  auto untyped_array = val.AsObjectRef<Array<ObjectRef>>();
2524 
2525  if constexpr (std::is_same_v<T, ObjectRef>) {
2526  return untyped_array;
2527  }
2528 
2529  return untyped_array.Map([](ObjectRef item) {
2530  TVMRetValue item_val;
2531  item_val = std::move(item);
2532  return PackedFuncValueConverter<T>::From(item_val);
2533  });
2534  }
2535 };
2536 
2537 template <typename T, typename U>
2539  static Map<T, U> From(const TVMArgValue& val) {
2540  auto untyped_map = val.AsObjectRef<Map<ObjectRef, ObjectRef>>();
2541 
2542  if constexpr (std::is_same_v<T, ObjectRef> && std::is_same_v<U, ObjectRef>) {
2543  return Downcast<Map<T, U>>(untyped_map);
2544  }
2545 
2546  if (ObjectTypeChecker<Map<T, U>>::Check(untyped_map.get())) {
2547  // Early bail-out for common case where no type conversions are
2548  // required.
2549  return Downcast<Map<T, U>>(untyped_map);
2550  }
2551 
2552  Map<T, U> output;
2553  for (const auto& kv : untyped_map) {
2554  T new_key = [&]() {
2555  if constexpr (std::is_same_v<T, ObjectRef>) {
2556  return kv.first;
2557  } else {
2558  TVMValue pod_value;
2559  int type_code;
2560  TVMArgsSetter setter(&pod_value, &type_code);
2561  setter(0, kv.first);
2562  TVMArgValue pod_arg(pod_value, type_code);
2563  return PackedFuncValueConverter<T>::From(pod_arg);
2564  }
2565  }();
2566  U new_value = [&]() {
2567  if constexpr (std::is_same_v<U, ObjectRef>) {
2568  return kv.second;
2569  } else {
2570  TVMValue pod_value;
2571  int type_code;
2572  TVMArgsSetter setter(&pod_value, &type_code);
2573  setter(0, kv.second);
2574  TVMArgValue key_arg(pod_value, type_code);
2575  return PackedFuncValueConverter<U>::From(key_arg);
2576  }
2577  }();
2578  output.Set(new_key, new_value);
2579  }
2580  return output;
2581  }
2582  static Map<T, U> From(const TVMRetValue& val) {
2583  auto untyped_map = val.AsObjectRef<Map<ObjectRef, ObjectRef>>();
2584 
2585  if constexpr (std::is_same_v<T, ObjectRef> && std::is_same_v<U, ObjectRef>) {
2586  return Downcast<Map<T, U>>(untyped_map);
2587  }
2588 
2589  if (ObjectTypeChecker<Map<T, U>>::Check(untyped_map.get())) {
2590  // Early bail-out for common case where no type conversions are
2591  // required.
2592  return Downcast<Map<T, U>>(untyped_map);
2593  }
2594 
2595  Map<T, U> output;
2596  for (const auto& kv : untyped_map) {
2597  T new_key = [&]() {
2598  if constexpr (std::is_same_v<T, ObjectRef>) {
2599  return kv.first;
2600  } else {
2601  TVMRetValue pod;
2602  pod = kv.first;
2604  }
2605  }();
2606  U new_value = [&]() {
2607  if constexpr (std::is_same_v<U, ObjectRef>) {
2608  return kv.second;
2609  } else {
2610  TVMRetValue pod;
2611  pod = kv.second;
2613  }
2614  }();
2615  output.Set(new_key, new_value);
2616  }
2617  return output;
2618  }
2619 };
2620 
2621 template <typename T>
2623  static Optional<T> From(const TVMArgValue& val) {
2624  if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
2626  }
2627  static Optional<T> From(const TVMRetValue& val) {
2628  if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
2630  }
2631 };
2632 
2633 template <typename... VariantTypes>
2634 struct PackedFuncValueConverter<Variant<VariantTypes...>> {
2635  using VType = Variant<VariantTypes...>;
2636 
2637  // Can't just take `const TVMPODValue&` as an argument, because
2638  // `TVMArgValue` and `TVMRetValue` have different implementations
2639  // for `operator std::string()`.
2640  template <typename PODSubclass>
2641  static VType From(const PODSubclass& val) {
2642  if (auto opt = TryAsObjectRef<VariantTypes...>(val)) {
2643  return opt.value();
2644  }
2645 
2646  if (auto opt = TryValueConverter<VariantTypes...>(val)) {
2647  return opt.value();
2648  }
2649 
2650  LOG(FATAL) << "Expected one of "
2651  << static_cast<const std::stringstream&>(
2652  (std::stringstream() << ... << VariantTypes::ContainerType::_type_key))
2653  .str()
2654  << " but got " << ArgTypeCode2Str(val.type_code());
2655  }
2656 
2657  template <typename VarFirst, typename... VarRest, typename PODSubclass>
2658  static Optional<VType> TryAsObjectRef(const PODSubclass& val) {
2659  if (val.template IsObjectRef<VarFirst>()) {
2660  return VType(val.template AsObjectRef<VarFirst>());
2661  } else if constexpr (sizeof...(VarRest)) {
2662  return TryAsObjectRef<VarRest...>(val);
2663  } else {
2664  return NullOpt;
2665  }
2666  }
2667 
2668  template <typename VarFirst, typename... VarRest, typename PODSubclass>
2669  static Optional<VType> TryValueConverter(const PODSubclass& val) {
2670  try {
2672  } catch (const Error&) {
2673  }
2674 
2675  if constexpr (sizeof...(VarRest)) {
2676  return TryValueConverter<VarRest...>(val);
2677  } else {
2678  return NullOpt;
2679  }
2680  }
2681 };
2682 
2683 inline bool String::CanConvertFrom(const TVMArgValue& val) {
2684  return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
2685 }
2686 
2687 inline TVMArgValue::operator DLDataType() const {
2688  if (String::CanConvertFrom(*this)) {
2689  return String2DLDataType(PackedFuncValueConverter<String>::From(*this).operator std::string());
2690  }
2691  // None type
2692  if (type_code_ == kTVMNullptr) {
2693  DLDataType t;
2694  t.code = kTVMOpaqueHandle;
2695  t.bits = 0;
2696  t.lanes = 0;
2697  return t;
2698  }
2699  TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
2700  return value_.v_type;
2701 }
2702 
2703 inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); }
2704 
2705 } // namespace runtime // NOLINT(*)
2706 } // namespace tvm // NOLINT(*)
2707 #endif // TVM_RUNTIME_PACKED_FUNC_H_
Runtime Array container types.
Runtime container types for primitives stored as ObjectRef.
@ kTVMPackedFuncHandle
Definition: c_runtime_api.h:185
@ kTVMNDArrayHandle
Definition: c_runtime_api.h:188
@ kTVMModuleHandle
Definition: c_runtime_api.h:184
@ kTVMBytes
Definition: c_runtime_api.h:187
@ kTVMDataType
Definition: c_runtime_api.h:180
@ kTVMArgBool
Definition: c_runtime_api.h:190
@ kTVMArgInt
Definition: c_runtime_api.h:176
@ kTVMDLTensorHandle
Definition: c_runtime_api.h:182
@ kDLDevice
Definition: c_runtime_api.h:181
@ kTVMOpaqueHandle
Definition: c_runtime_api.h:178
@ kTVMObjectHandle
Definition: c_runtime_api.h:183
@ kTVMObjectRValueRefArg
Definition: c_runtime_api.h:189
@ kTVMNullptr
Definition: c_runtime_api.h:179
@ kTVMArgFloat
Definition: c_runtime_api.h:177
@ kTVMStr
Definition: c_runtime_api.h:186
@ kDLMicroDev
Definition: c_runtime_api.h:125
@ kOpenGL
Definition: c_runtime_api.h:124
@ kDLSDAccel
Definition: c_runtime_api.h:123
@ kDLAOCL
Definition: c_runtime_api.h:122
DLTensor * TVMArrayHandle
the array handle
Definition: c_runtime_api.h:204
array node content in array
Definition: array.h:40
size_t size() const
Definition: array.h:43
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Array< U > Map(F fmap) const
Helper function to apply a map function onto the array.
Definition: array.h:651
Runtime primitive data type.
Definition: data_type.h:43
Shared content of all specializations of hash map.
Definition: map.h:174
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics,...
Definition: map.h:1271
void Set(const K &key, const V &value)
set the Map.
Definition: map.h:1374
Base container of module.
Definition: module.h:142
Module container of TVM.
Definition: module.h:79
PackedFunc GetFunction(const String &name, bool query_imports=false)
Get packed function from current module by name.
Definition: packed_func.h:2462
Object container class that backs NDArray.
Definition: ndarray.h:306
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
static TVMArrayHandle FFIGetHandle(const ObjectRef &nd)
Get FFI Array handle from ndarray.
Definition: ndarray.h:434
static ObjectPtr< Object > FFIDataFromHandle(TVMArrayHandle handle)
Construct NDArray's Data field from array handle in FFI.
Definition: ndarray.h:429
static void FFIDecRef(TVMArrayHandle handle)
DecRef resource managed by an FFI array handle.
Definition: ndarray.h:442
A custom smart pointer for Object.
Definition: object.h:362
Base class of all object reference.
Definition: object.h:519
bool defined() const
Definition: object.h:552
static void FFIClearAfterMove(ObjectRef *ref)
Clear the object ref data field without DecRef after we successfully moved the field.
Definition: object.h:623
friend class TVMArgsSetter
Definition: object.h:637
const Object * get() const
Definition: object.h:554
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:605
base class of all object containers.
Definition: object.h:171
std::string GetTypeKey() const
Definition: object.h:184
bool IsInstance() const
Definition: object.h:874
Optional container that represents a nullable variant of T.
Definition: optional.h:51
T value() const
Definition: optional.h:92
Object container class that backs PackedFunc.
Definition: packed_func.h:71
PackedFuncObj(FCallPacked *f_call_pack)
Constructing a packed function object from a function pointer.
Definition: packed_func.h:106
static constexpr const uint32_t _type_index
Definition: packed_func.h:80
TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object)
FCallPacked * f_call_packed_
Internal callable function pointer used to call the packed function.
Definition: packed_func.h:112
static constexpr const char * _type_key
Definition: packed_func.h:81
PackedFuncObj()=delete
Delete the default constructor explicitly.
void(const PackedFuncObj *, TVMArgs, TVMRetValue *) FCallPacked
The internal callable function type.
Definition: packed_func.h:100
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1397
Derived object class for constructing PackedFuncObj.
Definition: packed_func.h:117
PackedFuncSubObj(TCallable callable)
Construct the object from the given callable.
Definition: packed_func.h:127
TStorage callable_
Type-erased field for storing the callable object.
Definition: packed_func.h:130
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:141
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:183
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:181
TVMRetValue operator()(Args &&... args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1924
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1401
PackedFunc(TCallable data)
Constructing a packed function from a callable type whose signature is consistent with PackedFunc
Definition: packed_func.h:154
TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj)
PackedFunc(std::nullptr_t null)
Constructor from null.
Definition: packed_func.h:144
Reference to string objects.
Definition: string.h:98
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:2683
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:796
bool IsObjectRef() const
Definition: packed_func.h:2205
const TVMValue & value() const
Definition: packed_func.h:838
TVMArgValue(TVMValue value, int type_code)
constructor
Definition: packed_func.h:805
TObjectRef AsObjectRef() const
Definition: packed_func.h:2237
TVMArgValue()
default constructor
Definition: packed_func.h:799
Definition: packed_func.h:1824
TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor *value) const
Definition: packed_func.h:1858
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:1837
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1887
TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const
Definition: packed_func.h:1833
TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef &value) const
Definition: packed_func.h:1903
void operator()(size_t i, const TVMRetValue &value) const
Definition: packed_func.h:1890
TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef &&value) const
Definition: packed_func.h:1910
TVM_ALWAYS_INLINE void operator()(size_t i, double value) const
Definition: packed_func.h:1842
TVM_ALWAYS_INLINE void operator()(size_t i, T value) const
Definition: packed_func.h:1829
TVMArgsSetter(TVMValue *values, int *type_codes)
Definition: packed_func.h:1826
TVM_ALWAYS_INLINE void operator()(size_t i, const char *value) const
Definition: packed_func.h:1873
TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const
Definition: packed_func.h:1862
TVM_ALWAYS_INLINE void operator()(size_t i, void *value) const
Definition: packed_func.h:1854
TVM_ALWAYS_INLINE void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1878
TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1846
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray &value) const
Definition: packed_func.h:1882
TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const
Definition: packed_func.h:1870
TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1866
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue &value) const
Definition: packed_func.h:1850
Arguments into TVM functions.
Definition: packed_func.h:394
T At(int i) const
Get the i-th argument and do proper type checking with detailed error messages.
Definition: packed_func.h:2082
const TVMValue * values
Definition: packed_func.h:396
TVMArgs(const TVMValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:405
int size() const
Definition: packed_func.h:1389
TVMArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:1383
const int * type_codes
Definition: packed_func.h:397
int num_args
Definition: packed_func.h:398
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with additional context ...
Definition: packed_func.h:901
TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index, const std::string *optional_name, FSig *f_sig)
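Constructor from a TVMValue, its type code, the argument index, an optional function name, and a function that outputs the TypedPackedFunc signature.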
Definition: packed_func.h:912
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
Definition: packed_func.h:856
TVMMovableArgValue_(TVMValue value, int type_code)
Definition: packed_func.h:858
A utility class that adds methods useful for each POD type.
Definition: packed_func.h:738
bool IsObjectRef() const
Definition: packed_func.h:2205
TObjectRef AsObjectRef() const
Definition: packed_func.h:2237
Internal base class to handle conversion to POD values.
Definition: packed_func.h:615
TVMPODValue_()
Definition: packed_func.h:704
std::optional< bool > TryAsBool() const
Definition: packed_func.h:667
TVMValue value_
The value.
Definition: packed_func.h:708
std::optional< int64_t > TryAsInt() const
Definition: packed_func.h:678
std::optional< double > TryAsFloat() const
Definition: packed_func.h:689
T * ptr() const
Return the handle as a pointer of the specified type.
Definition: packed_func.h:663
int type_code_
the type code
Definition: packed_func.h:710
int type_code() const
Definition: packed_func.h:656
TVMPODValue_(TVMValue value, int type_code)
Definition: packed_func.h:705
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:946
static TVMRetValue MoveFromCHost(TVMValue value, int type_code)
Construct a new TVMRetValue by moving from return value stored via C API.
Definition: packed_func.h:1114
const TVMValue & value() const
Definition: packed_func.h:1124
~TVMRetValue()
destructor
Definition: packed_func.h:959
TVMRetValue & operator=(DLDataType t)
Definition: packed_func.h:1036
TVMRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:1076
TVMRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:1011
TVMRetValue & operator=(TVMRetValue &&other)
Definition: packed_func.h:999
TVMRetValue & operator=(PackedFunc f)
Definition: packed_func.h:1071
TVMRetValue & operator=(bool value)
Definition: packed_func.h:1042
TVMRetValue & operator=(int value)
Definition: packed_func.h:1026
void MoveToCHost(TVMValue *ret_value, int *ret_type_code)
Move the value back to front-end via C API. This marks the current container as null....
Definition: packed_func.h:1100
TVMRetValue()
default constructor
Definition: packed_func.h:949
TVMRetValue & operator=(void *value)
Definition: packed_func.h:1016
TObjectRef AsObjectRef() const
Definition: packed_func.h:2237
TVMRetValue & operator=(std::string value)
Definition: packed_func.h:1047
TVMRetValue & operator=(const TVMArgValue &other)
Definition: packed_func.h:1083
TVMRetValue(const TVMRetValue &other)
Definition: packed_func.h:975
TVMRetValue & operator=(double value)
Definition: packed_func.h:1006
TVMRetValue(TVMRetValue &&other)
move constructor from another return value.
Definition: packed_func.h:954
TVMRetValue & operator=(const TVMRetValue &other)
Definition: packed_func.h:1079
TVMRetValue & operator=(int64_t value)
Definition: packed_func.h:1021
TVMRetValue & operator=(Module m)
Definition: packed_func.h:1067
TVMRetValue & operator=(const DataType &other)
Definition: packed_func.h:1041
TVMRetValue & operator=(DLDevice value)
Definition: packed_func.h:1031
TVMRetValue & operator=(TVMMovableArgValue_ &&other)
Definition: packed_func.h:1087
TVMRetValue & operator=(TVMByteArray value)
Definition: packed_func.h:1051
TVMRetValue & operator=(NDArray other)
Definition: packed_func.h:1055
A PackedFunc wrapper to provide typed function signature. It is backed by a PackedFunc internally.
Definition: packed_func.h:230
const PackedFunc & packed() const
Definition: packed_func.h:361
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:363
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:312
TypedPackedFunc()
default constructor
Definition: packed_func.h:235
TypedPackedFunc(const FLambda &typed_lambda, std::string name)
construct from a lambda function with the same signature.
Definition: packed_func.h:289
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:334
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:343
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:365
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:237
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:63
Definition: variant.h:69
Runtime Map container types.
Box< bool > Bool
Boxed version of C++ bool.
Definition: boxed_primitive.h:121
const char * ArgTypeCode2Str(int type_code)
Convert argument type code to string.
Definition: packed_func.h:1406
const char * DLDeviceType2Str(int type)
The name of DLDeviceType.
Definition: packed_func.h:1451
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: data_type.h:429
std::string() FSig
Function type used to output a TypedPackedFunc signature as a string.
Definition: packed_func.h:189
Box< int64_t > Int
Boxed version of C++ int64_t.
Definition: boxed_primitive.h:99
void TVM_ALWAYS_INLINE PackArgs(TVMValue *values, int *type_codes, Args &&... args)
Definition: packed_func.h:1944
Box< double > Float
Boxed version of C++ double.
Definition: boxed_primitive.h:107
std::string DLDataType2String(DLDataType t)
convert a TVM type to string.
Definition: data_type.h:422
std::ostream & operator<<(std::ostream &os, const ObjectRef &n)
Definition: repr_printer.h:97
Object * TVMArrayHandleToObjectHandle(TVMArrayHandle handle)
Definition: ndarray.h:446
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
runtime::DataType DataType
Definition: data_type.h:493
DLDevice Device
Definition: ndarray.h:43
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
constexpr runtime::NullOptType NullOpt
Definition: optional.h:169
A device-independent managed NDArray abstraction.
A managed object in the TVM runtime.
#define TVM_LOG_INCORRECT_TYPE_CODE(CODE, T)
Definition: packed_func.h:434
#define TVM_CHECK_TYPE_CODE(CODE, T)
Definition: packed_func.h:438
Runtime container of the functions generated by TVM, This is used to support dynamically link,...
Byte array type used to pass in byte array When kTVMBytes is used as data type.
Definition: c_runtime_api.h:223
size_t size
Definition: c_runtime_api.h:225
const char * data
Definition: c_runtime_api.h:224
static std::string TypeName()
Definition: packed_func.h:522
static bool Check(const Object *ptr)
Definition: packed_func.h:509
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:487
static bool Check(const Object *ptr)
Definition: packed_func.h:555
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:527
static std::string TypeName()
Definition: packed_func.h:574
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:592
static bool Check(const Object *ptr)
Definition: packed_func.h:600
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:582
static bool Check(const Object *ptr)
Definition: packed_func.h:585
static std::string TypeName()
Definition: packed_func.h:586
static std::string VariantNames()
Definition: packed_func.h:587
Type traits for runtime type check during FFI conversion.
Definition: packed_func.h:445
static std::string TypeName()
Definition: packed_func.h:478
static bool Check(const Object *ptr)
Check if an object matches the template type.
Definition: packed_func.h:473
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Check if an object matches the template type and return the mismatched type if it exists.
Definition: packed_func.h:453
Internal struct for extracting the callable method from callable type.
Definition: packed_func.h:89
static void Call(const PackedFuncObj *obj, TVMArgs args, TVMRetValue *rv)
Extracting the callable method from callable type.
Definition: packed_func.h:1392
static Array< T > From(const TVMArgValue &val)
Definition: packed_func.h:2481
static Array< T > From(const TVMRetValue &val)
Definition: packed_func.h:2522
static Map< T, U > From(const TVMArgValue &val)
Definition: packed_func.h:2539
static Map< T, U > From(const TVMRetValue &val)
Definition: packed_func.h:2582
static Optional< T > From(const TVMRetValue &val)
Definition: packed_func.h:2627
static Optional< T > From(const TVMArgValue &val)
Definition: packed_func.h:2623
static VType From(const PODSubclass &val)
Definition: packed_func.h:2641
static Optional< VType > TryValueConverter(const PODSubclass &val)
Definition: packed_func.h:2669
static Optional< VType > TryAsObjectRef(const PODSubclass &val)
Definition: packed_func.h:2658
static String From(const PODSubclass &val)
Definition: packed_func.h:2470
Type trait to specify special value conversion rules from TVMArgValue and TVMRetValue.
Definition: packed_func.h:1246
static TObjectRef From(const TVMRetValue &val)
Convert a TObjectRef from a return value.
Definition: packed_func.h:1258
static TObjectRef From(const TVMArgValue &val)
Convert a TObjectRef from an argument value.
Definition: packed_func.h:1252
Definition: packed_func.h:65
Definition: packed_func.h:1937
static TVM_ALWAYS_INLINE void F(TVMArgsSetter *setter, T &&value)
Definition: packed_func.h:1938
@ kRuntimePackedFunc
runtime::PackedFunc.
Definition: object.h:74
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:210
DLDevice v_device
Definition: c_runtime_api.h:216
void * v_handle
Definition: c_runtime_api.h:213
DLDataType v_type
Definition: c_runtime_api.h:215
int64_t v_int64
Definition: c_runtime_api.h:211
const char * v_str
Definition: c_runtime_api.h:214
double v_float64
Definition: c_runtime_api.h:212
Runtime Variant container types.