tvm
packed_func.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef TVM_RUNTIME_PACKED_FUNC_H_
25 #define TVM_RUNTIME_PACKED_FUNC_H_
26 
30 #include <tvm/runtime/data_type.h>
31 #include <tvm/runtime/logging.h>
32 #include <tvm/runtime/module.h>
33 #include <tvm/runtime/ndarray.h>
34 #include <tvm/runtime/object.h>
35 
36 #include <functional>
37 #include <limits>
38 #include <memory>
39 #include <string>
40 #include <tuple>
41 #include <type_traits>
42 #include <utility>
43 #include <vector>
44 
45 // Whether to use the TVM runtime in header-only mode.
46 #ifndef TVM_RUNTIME_HEADER_ONLY
47 #define TVM_RUNTIME_HEADER_ONLY 0
48 #endif
49 
50 namespace tvm {
51 namespace runtime {
52 
53 // forward declarations
54 class TVMArgs;
55 class TVMArgValue;
56 class TVMMovableArgValueWithContext_;
57 class TVMRetValue;
58 class TVMArgsSetter;
59 
68 class PackedFunc {
69  public:
88  using FType = std::function<void(TVMArgs args, TVMRetValue* rv)>;
92  PackedFunc(std::nullptr_t null) {} // NOLINT(*)
97  explicit PackedFunc(FType body) : body_(body) {}
112  template <typename... Args>
113  inline TVMRetValue operator()(Args&&... args) const;
119  inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
121  inline FType body() const;
123  bool operator==(std::nullptr_t null) const { return body_ == nullptr; }
125  bool operator!=(std::nullptr_t null) const { return body_ != nullptr; }
126 
127  private:
129  FType body_;
130 };
131 
135 template <typename FType>
137 
170 template <typename R, typename... Args>
171 class TypedPackedFunc<R(Args...)> {
172  public:
174  using TSelf = TypedPackedFunc<R(Args...)>;
178  TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
196  inline TypedPackedFunc(PackedFunc packed); // NOLINT(*)
201  inline TypedPackedFunc(const TVMRetValue& value); // NOLINT(*)
206  inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*)
211  inline TypedPackedFunc(TVMMovableArgValueWithContext_&& value); // NOLINT(*)
228  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
229  FLambda, std::function<R(Args...)>>::value>::type>
230  TypedPackedFunc(const FLambda& typed_lambda, std::string name) { // NOLINT(*)
231  this->AssignTypedLambda(typed_lambda, name);
232  }
251  template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
252  FLambda, std::function<R(Args...)>>::value>::type>
253  TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
254  this->AssignTypedLambda(typed_lambda);
255  }
272  template <typename FLambda, typename = typename std::enable_if<
273  std::is_convertible<FLambda,
274  std::function<R(Args...)>>::value>::type>
275  TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
276  this->AssignTypedLambda(typed_lambda);
277  return *this;
278  }
285  packed_ = packed;
286  return *this;
287  }
293  TVM_ALWAYS_INLINE R operator()(Args... args) const;
298  operator PackedFunc() const { return packed(); }
302  const PackedFunc& packed() const { return packed_; }
304  bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
306  bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; }
307 
308  private:
309  friend class TVMRetValue;
311  PackedFunc packed_;
320  template <typename FLambda>
321  inline void AssignTypedLambda(FLambda flambda, std::string name);
330  template <typename FLambda>
331  inline void AssignTypedLambda(FLambda flambda);
332 };
333 
335 class TVMArgs {
336  public:
337  const TVMValue* values;
338  const int* type_codes;
339  int num_args;
346  TVMArgs(const TVMValue* values, const int* type_codes, int num_args)
347  : values(values), type_codes(type_codes), num_args(num_args) {}
349  inline int size() const;
355  inline TVMArgValue operator[](int i) const;
356 };
357 
/*!
 * \brief Convert a type code to its printable name.
 * \param type_code The type code.
 * \return The name as a C string.
 */
inline const char* ArgTypeCode2Str(int type_code);

// Assert that the type code CODE equals the expected code T.
// On mismatch the failure message names both codes through ArgTypeCode2Str.
#define TVM_CHECK_TYPE_CODE(CODE, T)                                      \
  ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " \
                     << ArgTypeCode2Str(CODE)
368 
373 template <typename T>
383  using ContainerType = typename T::ContainerType;
384  if (ptr == nullptr) {
385  if (T::_type_is_nullable) {
386  return NullOpt;
387  } else {
388  return String("nullptr");
389  }
390  }
391  if (ptr->IsInstance<ContainerType>()) {
392  return NullOpt;
393  } else {
394  return String(ptr->GetTypeKey());
395  }
396  }
402  static bool Check(const Object* ptr) {
403  using ContainerType = typename T::ContainerType;
404  if (ptr == nullptr) return T::_type_is_nullable;
405  return ptr->IsInstance<ContainerType>();
406  }
407  static std::string TypeName() {
408  using ContainerType = typename T::ContainerType;
409  return ContainerType::_type_key;
410  }
411 };
412 
413 // Additional overloads for PackedFunc checking.
414 template <typename T>
417  if (ptr == nullptr) {
418  return NullOpt;
419  }
420  if (!ptr->IsInstance<ArrayNode>()) {
421  return String(ptr->GetTypeKey());
422  }
423  const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
424  for (size_t i = 0; i < n->size(); i++) {
425  const ObjectRef& p = (*n)[i];
427  if (check_subtype.defined()) {
428  return String("Array[index " + std::to_string(i) + ": " + check_subtype.value() + "]");
429  }
430  }
431  return NullOpt;
432  }
433  static bool Check(const Object* ptr) {
434  if (ptr == nullptr) return true;
435  if (!ptr->IsInstance<ArrayNode>()) return false;
436  const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
437  for (const ObjectRef& p : *n) {
438  if (!ObjectTypeChecker<T>::Check(p.get())) {
439  return false;
440  }
441  }
442  return true;
443  }
444  static std::string TypeName() { return "Array[" + ObjectTypeChecker<T>::TypeName() + "]"; }
445 };
446 template <typename K, typename V>
447 struct ObjectTypeChecker<Map<K, V>> {
449  if (ptr == nullptr) return NullOpt;
450  if (!ptr->IsInstance<MapNode>()) return String(ptr->GetTypeKey());
451  const MapNode* n = static_cast<const MapNode*>(ptr);
452  for (const auto& kv : *n) {
454  Optional<String> value_type = ObjectTypeChecker<K>::CheckAndGetMismatch(kv.first.get());
455  if (key_type.defined() || value_type.defined()) {
456  std::string key_name =
457  key_type.defined() ? std::string(key_type.value()) : ObjectTypeChecker<K>::TypeName();
458  std::string value_name = value_type.defined() ? std::string(value_type.value())
460  return String("Map[" + key_name + ", " + value_name + "]");
461  }
462  }
463  return NullOpt;
464  }
465  static bool Check(const Object* ptr) {
466  if (ptr == nullptr) return true;
467  if (!ptr->IsInstance<MapNode>()) return false;
468  const MapNode* n = static_cast<const MapNode*>(ptr);
469  for (const auto& kv : *n) {
470  if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
471  if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
472  }
473  return true;
474  }
475  static std::string TypeName() {
477  ']';
478  }
479 };
480 
486  public:
487  operator double() const {
488  // Allow automatic conversion from int to float
489  // This avoids errors when user pass in int from
490  // the frontend while the API expects a float.
491  if (type_code_ == kDLInt) {
492  return static_cast<double>(value_.v_int64);
493  }
494  TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
495  return value_.v_float64;
496  }
497  operator int64_t() const {
498  TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
499  return value_.v_int64;
500  }
501  operator uint64_t() const {
502  TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
503  return value_.v_int64;
504  }
505  operator int() const {
506  TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
507  ICHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
508  ICHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
509  return static_cast<int>(value_.v_int64);
510  }
511  operator bool() const {
512  TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
513  return value_.v_int64 != 0;
514  }
515  operator void*() const {
516  if (type_code_ == kTVMNullptr) return nullptr;
517  if (type_code_ == kTVMDLTensorHandle) return value_.v_handle;
519  return value_.v_handle;
520  }
521  operator DLTensor*() const {
522  if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) {
523  return static_cast<DLTensor*>(value_.v_handle);
524  } else {
525  if (type_code_ == kTVMNullptr) return nullptr;
526  LOG(FATAL) << "Expected "
527  << "DLTensor* or NDArray but got " << ArgTypeCode2Str(type_code_);
528  return nullptr;
529  }
530  }
531  operator NDArray() const {
532  if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr<Object>(nullptr));
534  return NDArray(NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle)));
535  }
536  operator Module() const {
537  if (type_code_ == kTVMNullptr) {
538  return Module(ObjectPtr<Object>(nullptr));
539  }
541  return Module(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
542  }
543  operator Device() const {
544  TVM_CHECK_TYPE_CODE(type_code_, kDLDevice);
545  return value_.v_device;
546  }
547  int type_code() const { return type_code_; }
553  template <typename T>
554  T* ptr() const {
555  return static_cast<T*>(value_.v_handle);
556  }
557  // ObjectRef handling
558  template <typename TObjectRef,
559  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
560  inline bool IsObjectRef() const;
561  template <typename TObjectRef>
562  inline TObjectRef AsObjectRef() const;
563 
564  protected:
565  friend class TVMArgsSetter;
566  friend class TVMRetValue;
567  friend class TVMMovableArgValue_;
568  TVMPODValue_() : type_code_(kTVMNullptr) {}
569  TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {}
570 
575 };
576 
583 class TVMArgValue : public TVMPODValue_ {
584  public:
592  TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {}
593  // reuse converter from parent
594  using TVMPODValue_::operator double;
595  using TVMPODValue_::operator int64_t;
596  using TVMPODValue_::operator uint64_t;
597  using TVMPODValue_::operator int;
598  using TVMPODValue_::operator bool;
599  using TVMPODValue_::operator void*;
600  using TVMPODValue_::operator DLTensor*;
601  using TVMPODValue_::operator NDArray;
602  using TVMPODValue_::operator Device;
603  using TVMPODValue_::operator Module;
606 
607  // conversion operator.
608  operator std::string() const {
609  if (type_code_ == kTVMDataType) {
610  return DLDataType2String(operator DLDataType());
611  } else if (type_code_ == kTVMBytes) {
612  TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
613  return std::string(arr->data, arr->size);
614  } else if (type_code_ == kTVMStr) {
615  return std::string(value_.v_str);
616  } else {
617  ICHECK(IsObjectRef<tvm::runtime::String>())
618  << "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_)
619  << " to a string.";
620  return AsObjectRef<tvm::runtime::String>().operator std::string();
621  }
622  }
623  operator PackedFunc() const {
624  if (type_code_ == kTVMNullptr) return PackedFunc();
626  return *ptr<PackedFunc>();
627  }
628  template <typename FType>
629  operator TypedPackedFunc<FType>() const {
630  return TypedPackedFunc<FType>(operator PackedFunc());
631  }
632  const TVMValue& value() const { return value_; }
633 
634  template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
635  inline operator T() const;
636  inline operator DLDataType() const;
637  inline operator DataType() const;
638 };
639 
651  public:
652  TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {}
653  // reuse converter from parent
654  using TVMPODValue_::operator double;
655  using TVMPODValue_::operator int64_t;
656  using TVMPODValue_::operator uint64_t;
657  using TVMPODValue_::operator int;
658  using TVMPODValue_::operator bool;
659  using TVMPODValue_::operator void*;
660  using TVMPODValue_::operator DLTensor*;
661  using TVMPODValue_::operator NDArray;
662  using TVMPODValue_::operator Device;
663  using TVMPODValue_::operator Module;
664  // reuse conversion rule from ArgValue.
665  operator std::string() const { return AsArgValue().operator std::string(); }
666  operator PackedFunc() const { return AsArgValue().operator PackedFunc(); }
667  template <typename FType>
668  operator TypedPackedFunc<FType>() const {
669  return TypedPackedFunc<FType>(operator PackedFunc());
670  }
671  operator DLDataType() const { return AsArgValue().operator DLDataType(); }
672  operator DataType() const { return AsArgValue().operator DataType(); }
673  operator TVMArgValue() const { return AsArgValue(); }
679  template <typename T,
680  typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
681  inline operator T() const;
682 
683  private:
685  TVMArgValue AsArgValue() const { return TVMArgValue(value_, type_code_); }
686 };
687 
696  public:
705  TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index,
706  const std::string* optional_name)
707  : value_(value, type_code), arg_index_(arg_index), optional_name_(optional_name) {}
708 
709  template <typename T>
710  operator T() const {
711  try {
712  return value_; // implicit conversion happens here
713  } catch (dmlc::Error& e) {
714  LOG(FATAL) << "In function " << (optional_name_ == nullptr ? "<anonymous>" : *optional_name_)
715  << ": error while converting argument " << arg_index_ << ": " << e.what();
716  throw; // never reached, LOG(FATAL) throws, but this silences a warning.
717  }
718  }
719 
720  private:
721  TVMMovableArgValue_ value_;
722  int arg_index_;
723  const std::string* optional_name_;
724 };
725 
734 class TVMRetValue : public TVMPODValue_ {
735  public:
742  TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) {
743  other.value_.v_handle = nullptr;
744  other.type_code_ = kTVMNullptr;
745  }
747  ~TVMRetValue() { this->Clear(); }
748  // reuse converter from parent
749  using TVMPODValue_::operator double;
750  using TVMPODValue_::operator int64_t;
751  using TVMPODValue_::operator uint64_t;
752  using TVMPODValue_::operator int;
753  using TVMPODValue_::operator bool;
754  using TVMPODValue_::operator void*;
755  using TVMPODValue_::operator DLTensor*;
756  using TVMPODValue_::operator Device;
757  using TVMPODValue_::operator NDArray;
758  using TVMPODValue_::operator Module;
761 
762  TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); }
763  // conversion operators
764  operator std::string() const {
765  if (type_code_ == kTVMDataType) {
766  return DLDataType2String(operator DLDataType());
767  } else if (type_code_ == kTVMBytes) {
768  return *ptr<std::string>();
769  }
770  TVM_CHECK_TYPE_CODE(type_code_, kTVMStr);
771  return *ptr<std::string>();
772  }
773  operator DLDataType() const {
774  if (type_code_ == kTVMStr) {
775  return String2DLDataType(operator std::string());
776  }
777  TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
778  return value_.v_type;
779  }
780  operator DataType() const { return DataType(operator DLDataType()); }
781  operator PackedFunc() const {
782  if (type_code_ == kTVMNullptr) return PackedFunc();
784  return *ptr<PackedFunc>();
785  }
786  template <typename FType>
787  operator TypedPackedFunc<FType>() const {
788  return TypedPackedFunc<FType>(operator PackedFunc());
789  }
790  // Assign operators
792  this->Clear();
793  value_ = other.value_;
794  type_code_ = other.type_code_;
795  other.type_code_ = kTVMNullptr;
796  return *this;
797  }
798  TVMRetValue& operator=(double value) {
799  this->SwitchToPOD(kDLFloat);
800  value_.v_float64 = value;
801  return *this;
802  }
803  TVMRetValue& operator=(std::nullptr_t value) {
804  this->SwitchToPOD(kTVMNullptr);
805  value_.v_handle = value;
806  return *this;
807  }
808  TVMRetValue& operator=(void* value) {
809  this->SwitchToPOD(kTVMOpaqueHandle);
810  value_.v_handle = value;
811  return *this;
812  }
813  TVMRetValue& operator=(int64_t value) {
814  this->SwitchToPOD(kDLInt);
815  value_.v_int64 = value;
816  return *this;
817  }
818  TVMRetValue& operator=(int value) {
819  this->SwitchToPOD(kDLInt);
820  value_.v_int64 = value;
821  return *this;
822  }
823  TVMRetValue& operator=(DLDevice value) {
824  this->SwitchToPOD(kDLDevice);
825  value_.v_device = value;
826  return *this;
827  }
828  TVMRetValue& operator=(DLDataType t) {
829  this->SwitchToPOD(kTVMDataType);
830  value_.v_type = t;
831  return *this;
832  }
833  TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); }
834  TVMRetValue& operator=(bool value) {
835  this->SwitchToPOD(kDLInt);
836  value_.v_int64 = value;
837  return *this;
838  }
839  TVMRetValue& operator=(std::string value) {
840  this->SwitchToClass(kTVMStr, value);
841  return *this;
842  }
844  this->SwitchToClass(kTVMBytes, std::string(value.data, value.size));
845  return *this;
846  }
847  TVMRetValue& operator=(NDArray other) {
848  if (other.data_ != nullptr) {
849  this->Clear();
850  type_code_ = kTVMNDArrayHandle;
851  value_.v_handle = NDArray::FFIGetHandle(other);
853  } else {
854  SwitchToPOD(kTVMNullptr);
855  }
856  return *this;
857  }
858  TVMRetValue& operator=(Module m) {
859  SwitchToObject(kTVMModuleHandle, std::move(m.data_));
860  return *this;
861  }
863  if (f == nullptr) {
864  this->SwitchToPOD(kTVMNullptr);
865  } else {
866  this->SwitchToClass(kTVMPackedFuncHandle, f);
867  }
868  return *this;
869  }
870  template <typename FType>
872  return operator=(f.packed());
873  }
874  TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0
875  this->Assign(other);
876  return *this;
877  }
879  this->Assign(other);
880  return *this;
881  }
883  this->Assign(other);
884  return *this;
885  }
895  void MoveToCHost(TVMValue* ret_value, int* ret_type_code) {
896  // cannot move str; need specially handle.
897  ICHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes);
898  *ret_value = value_;
899  *ret_type_code = type_code_;
900  type_code_ = kTVMNullptr;
901  }
909  static TVMRetValue MoveFromCHost(TVMValue value, int type_code) {
910  // Can move POD and everything under the object system.
911  ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle);
913  ret.value_ = value;
914  ret.type_code_ = type_code;
915  return ret;
916  }
918  const TVMValue& value() const {
919  ICHECK(type_code_ != kTVMObjectHandle && type_code_ != kTVMPackedFuncHandle &&
920  type_code_ != kTVMModuleHandle && type_code_ != kTVMStr)
921  << "TVMRetValue.value can only be used for POD data";
922  return value_;
923  }
924  // ObjectRef handling
925  template <typename TObjectRef,
926  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
927  inline TVMRetValue& operator=(TObjectRef other);
928  template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
929  inline operator T() const;
930 
931  private:
932  template <typename T>
933  void Assign(const T& other) {
934  switch (other.type_code()) {
935  case kTVMStr: {
936  SwitchToClass<std::string>(kTVMStr, other);
937  break;
938  }
939  case kTVMBytes: {
940  SwitchToClass<std::string>(kTVMBytes, other);
941  break;
942  }
943  case kTVMPackedFuncHandle: {
944  SwitchToClass<PackedFunc>(kTVMPackedFuncHandle, other);
945  break;
946  }
947  case kTVMModuleHandle: {
948  *this = other.operator Module();
949  break;
950  }
951  case kTVMNDArrayHandle: {
952  *this = other.operator NDArray();
953  break;
954  }
955  case kTVMObjectHandle: {
956  // Avoid operator ObjectRef as we already know it is not NDArray/Module
957  SwitchToObject(kTVMObjectHandle,
958  GetObjectPtr<Object>(static_cast<Object*>(other.value_.v_handle)));
959  break;
960  }
961  case kTVMObjectRValueRefArg: {
962  operator=(other.operator ObjectRef());
963  break;
964  }
965  default: {
966  SwitchToPOD(other.type_code());
967  value_ = other.value_;
968  break;
969  }
970  }
971  }
972  // get the internal container.
973  void SwitchToPOD(int type_code) {
974  if (type_code_ != type_code) {
975  this->Clear();
976  type_code_ = type_code;
977  }
978  }
979  template <typename T>
980  void SwitchToClass(int type_code, T v) {
981  if (type_code_ != type_code) {
982  this->Clear();
983  type_code_ = type_code;
984  value_.v_handle = new T(v);
985  } else {
986  *static_cast<T*>(value_.v_handle) = v;
987  }
988  }
989  void SwitchToObject(int type_code, ObjectPtr<Object> other) {
990  if (other.data_ != nullptr) {
991  this->Clear();
992  type_code_ = type_code;
993  // move the handle out
994  value_.v_handle = other.data_;
995  other.data_ = nullptr;
996  } else {
997  SwitchToPOD(kTVMNullptr);
998  }
999  }
1000  void Clear() {
1001  if (type_code_ == kTVMNullptr) return;
1002  switch (type_code_) {
1003  case kTVMStr:
1004  case kTVMBytes:
1005  delete ptr<std::string>();
1006  break;
1007  case kTVMPackedFuncHandle:
1008  delete ptr<PackedFunc>();
1009  break;
1010  case kTVMNDArrayHandle: {
1011  NDArray::FFIDecRef(static_cast<TVMArrayHandle>(value_.v_handle));
1012  break;
1013  }
1014  case kTVMModuleHandle: {
1015  static_cast<Object*>(value_.v_handle)->DecRef();
1016  break;
1017  }
1018  case kTVMObjectHandle: {
1019  static_cast<Object*>(value_.v_handle)->DecRef();
1020  break;
1021  }
1022  }
1023  type_code_ = kTVMNullptr;
1024  }
1025 };
1026 
// NOTE(review): the line that names this template struct appears to be missing from this
// listing; confirm the declaration against the original header before relying on it.
1036 template <typename TObjectRef>
// Convert a packed argument value to TObjectRef by calling val.AsObjectRef<TObjectRef>().
1043  static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef<TObjectRef>(); }
// Convert a packed return value to TObjectRef by calling val.AsObjectRef<TObjectRef>().
1049  static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef<TObjectRef>(); }
1050 };
1051 
// Export Function under the C symbol ExportName.
// The generated function wraps the raw arrays in a TVMArgs, calls
// Function(args, &rv) and moves the result to the out parameters. It returns 0
// on success. If a std::exception escapes, its message is recorded through
// TVMAPISetLastError and -1 is returned.
#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function)                                    \
  extern "C" {                                                                              \
  TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
                         int* out_type_code, void* resource_handle);                        \
  int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value,         \
                 int* out_type_code, void* resource_handle) {                               \
    try {                                                                                   \
      ::tvm::runtime::TVMRetValue rv;                                                       \
      Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv);                    \
      rv.MoveToCHost(out_value, out_type_code);                                             \
      return 0;                                                                             \
    } catch (const ::std::exception& _except_) {                                            \
      TVMAPISetLastError(_except_.what());                                                  \
      return -1;                                                                            \
    }                                                                                       \
  }                                                                                         \
  }
1088 
// Export the typed callable Function under the C symbol ExportName.
// The signature of Function is deduced with detail::function_signature, and
// the packed arguments are unpacked to that signature before the call. The
// generated function returns 0 on success. If a std::exception escapes, its
// message is recorded through TVMAPISetLastError and -1 is returned.
#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function)                                     \
  extern "C" {                                                                              \
  TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \
                         int* out_type_code, void* resource_handle) {                       \
    try {                                                                                   \
      auto f = Function;                                                                    \
      using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType;         \
      ::tvm::runtime::TVMRetValue rv;                                                       \
      ::tvm::runtime::detail::unpack_call_by_signature<FType>::run(                         \
          f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv);                      \
      rv.MoveToCHost(out_value, out_type_code);                                             \
      return 0;                                                                             \
    } catch (const ::std::exception& _except_) {                                            \
      TVMAPISetLastError(_except_.what());                                                  \
      return -1;                                                                            \
    }                                                                                       \
  }                                                                                         \
  }
1142 
1143 inline TVMArgValue TVMArgs::operator[](int i) const {
1144  ICHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed"
1145  << " but request arg[" << i << "].";
1146  return TVMArgValue(values[i], type_codes[i]);
1147 }
1148 
1149 inline int TVMArgs::size() const { return num_args; }
1150 
1151 inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); }
1152 
1153 inline PackedFunc::FType PackedFunc::body() const { return body_; }
1154 
1155 // internal namespace
1156 inline const char* ArgTypeCode2Str(int type_code) {
1157  switch (type_code) {
1158  case kDLInt:
1159  return "int";
1160  case kDLUInt:
1161  return "uint";
1162  case kDLFloat:
1163  return "float";
1164  case kTVMStr:
1165  return "str";
1166  case kTVMBytes:
1167  return "bytes";
1168  case kTVMOpaqueHandle:
1169  return "handle";
1170  case kTVMNullptr:
1171  return "NULL";
1172  case kTVMDLTensorHandle:
1173  return "ArrayHandle";
1174  case kTVMDataType:
1175  return "DLDataType";
1176  case kDLDevice:
1177  return "DLDevice";
1178  case kTVMPackedFuncHandle:
1179  return "FunctionHandle";
1180  case kTVMModuleHandle:
1181  return "ModuleHandle";
1182  case kTVMNDArrayHandle:
1183  return "NDArrayContainer";
1184  case kTVMObjectHandle:
1185  return "Object";
1187  return "ObjectRValueRefArg";
1188  default:
1189  LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
1190  return "";
1191  }
1192 }
1193 
1194 namespace detail {
1195 
/*!
 * \brief Recursive step of for_each: applies the functor to the first value
 *  together with its position, then continues with the remaining values.
 * \tparam Done Whether no values are left; selects the terminating specialization.
 * \tparam Index Position of the first value in the original argument list.
 * \tparam Fn Type of the functor.
 */
template <bool Done, std::size_t Index, typename Fn>
struct for_each_dispatcher {
  template <typename Head, typename... Tail>
  static void run(const Fn& fn, Head&& head, Tail&&... tail) {  // NOLINT(*)
    fn(Index, std::forward<Head>(head));
    using Next = for_each_dispatcher<sizeof...(Tail) == 0, (Index + 1), Fn>;
    Next::run(fn, std::forward<Tail>(tail)...);
  }
};

/*! \brief Terminating step of for_each: nothing is left to visit. */
template <std::size_t Index, typename Fn>
struct for_each_dispatcher<true, Index, Fn> {
  static void run(const Fn& fn) {}  // NOLINT(*)
};

/*!
 * \brief Call f(i, arg_i) for every argument, in order, with i counting from 0.
 * \param f The functor; it is invoked through a const reference.
 * \param args The values to visit.
 */
template <typename F, typename... Args>
inline void for_each(const F& f, Args&&... args) {  // NOLINT(*)
  using First = for_each_dispatcher<sizeof...(Args) == 0, 0, F>;
  First::run(f, std::forward<Args>(args)...);
}
1214 
/*!
 * \brief Extract R(Args...) from a pointer to a member call operator.
 *  The primary template yields void when the type is not recognized.
 */
template <typename T>
struct func_signature_helper {
  using FType = void;
};

template <typename T, typename R, typename... Args>
struct func_signature_helper<R (T::*)(Args...)> {
  using FType = R(Args...);
  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};

template <typename T, typename R, typename... Args>
struct func_signature_helper<R (T::*)(Args...) const> {
  using FType = R(Args...);
  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};

// Since C++17 noexcept is part of the function type, so noexcept call
// operators need their own specializations to be recognized.
#if defined(__cpp_noexcept_function_type)
template <typename T, typename R, typename... Args>
struct func_signature_helper<R (T::*)(Args...) noexcept> {
  using FType = R(Args...);
  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};

template <typename T, typename R, typename... Args>
struct func_signature_helper<R (T::*)(Args...) const noexcept> {
  using FType = R(Args...);
  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};
#endif  // defined(__cpp_noexcept_function_type)

/*!
 * \brief Deduce the signature R(Args...) of a callable.
 * \tparam T A class type with a single, non-template operator() (e.g. a lambda),
 *  a function type, or a function pointer type.
 */
template <typename T>
struct function_signature {
  using FType = typename func_signature_helper<decltype(&T::operator())>::FType;
};

// handle case of function.
template <typename R, typename... Args>
struct function_signature<R(Args...)> {
  using FType = R(Args...);
  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};

// handle case of function ptr.
template <typename R, typename... Args>
struct function_signature<R (*)(Args...)> {
  using FType = R(Args...);
  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};

#if defined(__cpp_noexcept_function_type)
// handle case of noexcept function.
template <typename R, typename... Args>
struct function_signature<R(Args...) noexcept> {
  using FType = R(Args...);
  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};

// handle case of noexcept function ptr.
template <typename R, typename... Args>
struct function_signature<R (*)(Args...) noexcept> {
  using FType = R(Args...);
  static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference");
};
#endif  // defined(__cpp_noexcept_function_type)
1254 } // namespace detail
1255 
1256 /* \brief argument settter to PackedFunc */
1258  public:
1259  TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {}
1260  // setters for POD types
1261  template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
1262  TVM_ALWAYS_INLINE void operator()(size_t i, T value) const {
1263  values_[i].v_int64 = static_cast<int64_t>(value);
1264  type_codes_[i] = kDLInt;
1265  }
1266  TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
1267  values_[i].v_int64 = static_cast<int64_t>(value);
1268  ICHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
1269  type_codes_[i] = kDLInt;
1270  }
1271  TVM_ALWAYS_INLINE void operator()(size_t i, double value) const {
1272  values_[i].v_float64 = value;
1273  type_codes_[i] = kDLFloat;
1274  }
1275  TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const {
1276  values_[i].v_handle = value;
1277  type_codes_[i] = kTVMNullptr;
1278  }
1279  TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const {
1280  values_[i] = value.value_;
1281  type_codes_[i] = value.type_code_;
1282  }
1283  TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const {
1284  values_[i].v_handle = value;
1285  type_codes_[i] = kTVMOpaqueHandle;
1286  }
1287  TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const {
1288  values_[i].v_handle = value;
1289  type_codes_[i] = kTVMDLTensorHandle;
1290  }
1291  TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const {
1292  values_[i].v_device = value;
1293  type_codes_[i] = kDLDevice;
1294  }
1295  TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const {
1296  values_[i].v_type = value;
1297  type_codes_[i] = kTVMDataType;
1298  }
1299  TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const {
1300  operator()(i, dtype.operator DLDataType());
1301  }
1302  TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const {
1303  values_[i].v_str = value;
1304  type_codes_[i] = kTVMStr;
1305  }
1306  // setters for container types
1307  TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const {
1308  values_[i].v_str = value.c_str();
1309  type_codes_[i] = kTVMStr;
1310  }
1311  TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const {
1312  values_[i].v_handle = const_cast<TVMByteArray*>(&value);
1313  type_codes_[i] = kTVMBytes;
1314  }
1315  TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const {
1316  if (value != nullptr) {
1317  values_[i].v_handle = const_cast<PackedFunc*>(&value);
1318  type_codes_[i] = kTVMPackedFuncHandle;
1319  } else {
1320  values_[i].v_handle = nullptr;
1321  type_codes_[i] = kTVMNullptr;
1322  }
1323  }
1324  template <typename FType>
1325  TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
1326  operator()(i, value.packed());
1327  }
1328  void operator()(size_t i, const TVMRetValue& value) const {
1329  if (value.type_code() == kTVMStr) {
1330  values_[i].v_str = value.ptr<std::string>()->c_str();
1331  type_codes_[i] = kTVMStr;
1332  } else {
1333  ICHECK_NE(value.type_code(), kTVMBytes) << "not handled.";
1334  values_[i] = value.value_;
1335  type_codes_[i] = value.type_code();
1336  }
1337  }
1338  // ObjectRef handling
1339  template <typename TObjectRef,
1340  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
1341  TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const {
1342  this->SetObject(i, value);
1343  }
1344 
1345  template <typename TObjectRef,
1346  typename = typename std::enable_if<std::is_base_of<
1347  ObjectRef, typename std::remove_reference<TObjectRef>::type>::value>::type>
1348  TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const {
1349  this->SetObject(i, std::forward<TObjectRef>(value));
1350  }
1351 
1352  private:
1353  template <typename TObjectRef>
1354  inline void SetObject(size_t i, TObjectRef&& value) const;
1356  TVMValue* values_;
1358  int* type_codes_;
1359 };
1360 
1361 template <typename... Args>
1362 inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
1363  const int kNumArgs = sizeof...(Args);
1364  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1365  TVMValue values[kArraySize];
1366  int type_codes[kArraySize];
1367  detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
1368  TVMRetValue rv;
1369  body_(TVMArgs(values, type_codes, kNumArgs), &rv);
1370  return rv;
1371 }
1372 
1373 namespace detail {
1374 template <typename R, int nleft, int index, typename F>
1375 struct unpack_call_dispatcher {
1376  template <typename... Args>
1377  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, const F& f,
1378  const TVMArgs& args_pack, TVMRetValue* rv,
1379  Args&&... unpacked_args) {
1380  // construct a movable argument value
1381  // which allows potential move of argument to the input of F.
1382  unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
1383  optional_name, f, args_pack, rv, std::forward<Args>(unpacked_args)...,
1384  TVMMovableArgValueWithContext_(args_pack.values[index], args_pack.type_codes[index], index,
1385  optional_name));
1386  }
1387 };
1388 
1389 template <typename R, int index, typename F>
1390 struct unpack_call_dispatcher<R, 0, index, F> {
1391  template <typename... Args>
1392  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, const F& f,
1393  const TVMArgs& args_pack, TVMRetValue* rv,
1394  Args&&... unpacked_args) {
1395  using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
1396  if (std::is_same<RetType, R>::value) {
1397  *rv = f(std::forward<Args>(unpacked_args)...);
1398  } else {
1399  *rv = R(f(std::forward<Args>(unpacked_args)...));
1400  }
1401  }
1402 };
1403 
1404 template <int index, typename F>
1405 struct unpack_call_dispatcher<void, 0, index, F> {
1406  template <typename... Args>
1407  TVM_ALWAYS_INLINE static void run(const std::string* optional_name, const F& f,
1408  const TVMArgs& args_pack, TVMRetValue* rv,
1409  Args&&... unpacked_args) {
1410  f(std::forward<Args>(unpacked_args)...);
1411  }
1412 };
1413 
1414 template <typename R, int nargs, typename F>
1415 TVM_ALWAYS_INLINE void unpack_call(const std::string* optional_name, const F& f,
1416  const TVMArgs& args, TVMRetValue* rv) {
1417  CHECK_EQ(nargs, args.size()) << "Function "
1418  << (optional_name == nullptr ? "<anonymous>" : *optional_name)
1419  << " expects " << nargs << " arguments but " << args.size()
1420  << " were provided";
1421  unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f, args, rv);
1422 }
1423 
1424 template <typename FType>
1425 struct unpack_call_by_signature {};
1426 
1427 template <typename R, typename... Args>
1428 struct unpack_call_by_signature<R(Args...)> {
1429  template <typename F>
1430  TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) {
1431  unpack_call<R, sizeof...(Args)>(nullptr, f, args, rv);
1432  }
1433 };
1434 
1435 template <typename R, typename... Args>
1436 TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) {
1437  return R(pf(std::forward<Args>(args)...));
1438 }
1439 
1440 template <typename R>
1441 struct typed_packed_call_dispatcher {
1442  template <typename... Args>
1443  TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) {
1444  return pf(std::forward<Args>(args)...);
1445  }
1446 };
1447 
1448 template <>
1449 struct typed_packed_call_dispatcher<void> {
1450  template <typename... Args>
1451  TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... args) {
1452  pf(std::forward<Args>(args)...);
1453  }
1454 };
1455 } // namespace detail
1456 
1457 template <typename R, typename... Args>
1458 TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) : packed_(packed) {}
1459 
// Three out-of-class TypedPackedFunc constructor definitions. Each initializes
// packed_ by explicitly invoking the argument's conversion operator to PackedFunc.
// NOTE(review): the declarator line of each definition is absent from this listing
// (numbering jumps 1460 -> 1462, 1464 -> 1466 and 1468 -> 1470), so the parameter
// type of `value` in each overload cannot be read from here -- confirm against the
// class declaration.
1460 template <typename R, typename... Args>
1462  : packed_(value.operator PackedFunc()) {}
1463 
1464 template <typename R, typename... Args>
1466  : packed_(value.operator PackedFunc()) {}
1467 
1468 template <typename R, typename... Args>
1470  : packed_(value.operator PackedFunc()) {}
1471 
1472 template <typename R, typename... Args>
1473 template <typename FType>
1474 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, std::string name) {
1475  packed_ = PackedFunc([flambda, name](const TVMArgs& args, TVMRetValue* rv) {
1476  if (args.size() != sizeof...(Args)) {
1477  LOG(FATAL) << "Function " << name << " expects " << sizeof...(Args) << " arguments, but "
1478  << args.size() << " were provided.";
1479  }
1480  detail::unpack_call<R, sizeof...(Args)>(&name, flambda, args, rv);
1481  });
1482 }
1483 
1484 template <typename R, typename... Args>
1485 template <typename FType>
1486 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
1487  packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
1488  if (args.size() != sizeof...(Args)) {
1489  LOG(FATAL) << "Function <anonymous> expects " << sizeof...(Args) << " arguments, but "
1490  << args.size() << " were provided.";
1491  }
1492  detail::unpack_call<R, sizeof...(Args)>(nullptr, flambda, args, rv);
1493  });
1494 }
1495 
1496 template <typename R, typename... Args>
1497 TVM_ALWAYS_INLINE R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
1498  return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...);
1499 }
1500 
1501 // ObjectRef related conversion handling
1502 // Object can have three possible type codes:
1503 // kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle
1504 //
1505 // We use type traits to eliminate un-necessary checks.
1506 template <typename T>
1507 inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
1508  using ContainerType = typename std::remove_reference<T>::type::ContainerType;
1509  if (value.defined()) {
1510  Object* ptr = value.data_.data_;
1511  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1512  (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1514  values_[i].v_handle = NDArray::FFIGetHandle(value);
1515  type_codes_[i] = kTVMNDArrayHandle;
1516  } else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1517  (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1518  ptr->IsInstance<Module::ContainerType>())) {
1519  values_[i].v_handle = ptr;
1520  type_codes_[i] = kTVMModuleHandle;
1521  } else if (std::is_rvalue_reference<decltype(value)>::value) {
1522  values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
1523  type_codes_[i] = kTVMObjectRValueRefArg;
1524  } else {
1525  values_[i].v_handle = value.data_.data_;
1526  type_codes_[i] = kTVMObjectHandle;
1527  }
1528  } else {
1529  type_codes_[i] = kTVMNullptr;
1530  }
1531 }
1532 
1533 template <typename TObjectRef, typename>
1534 inline bool TVMPODValue_::IsObjectRef() const {
1535  using ContainerType = typename TObjectRef::ContainerType;
1536  // NOTE: the following code can be optimized by constant folding.
1537  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1538  return type_code_ == kTVMNDArrayHandle &&
1539  TVMArrayHandleToObjectHandle(static_cast<TVMArrayHandle>(value_.v_handle))
1540  ->IsInstance<ContainerType>();
1541  }
1542  if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1543  return type_code_ == kTVMModuleHandle &&
1544  static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
1545  }
1546  // NOTE: we don't pass NDArray and runtime::Module as RValue ref.
1547  if (type_code_ == kTVMObjectRValueRefArg) {
1548  return ObjectTypeChecker<TObjectRef>::Check(*static_cast<Object**>(value_.v_handle));
1549  }
1550  return (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1551  type_code_ == kTVMNDArrayHandle) ||
1552  (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1553  type_code_ == kTVMModuleHandle) ||
1554  (type_code_ == kTVMObjectHandle &&
1555  ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
1556 }
1557 
// Convert the held value into a TObjectRef.
// The branches below are tried in order: null value, targets derived from NDArray,
// targets derived from Module, plain object handles, rvalue-reference object
// arguments, and finally NDArray / Module handles viewed through one of their
// base classes.
// NOTE(review): this listing skips several source lines inside the function
// (numbers 1572, 1581, 1590, 1596 and 1611 are absent); the notes below mark
// where, and what is visibly affected.
1558 template <typename TObjectRef>
1559 inline TObjectRef TVMPODValue_::AsObjectRef() const {
1560  static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
1561  "Conversion only works for ObjectRef");
1562  using ContainerType = typename TObjectRef::ContainerType;
1563 
// A null value is only accepted for reference types that declare themselves nullable.
1564  if (type_code_ == kTVMNullptr) {
1565  CHECK(TObjectRef::_type_is_nullable)
1566  << "Expect a not null value of " << ContainerType::_type_key;
1567  return TObjectRef(ObjectPtr<Object>(nullptr));
1568  }
1569  // NOTE: the following code can be optimized by constant folding.
1570  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1571  // Casting to a sub-class of NDArray
// NOTE(review): listing line 1572 is missing here.
1573  ObjectPtr<Object> data =
1574  NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle));
1575  CHECK(data->IsInstance<ContainerType>())
1576  << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
1577  return TObjectRef(data);
1578  }
1579  if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1580  // Casting to a sub-class of Module
// NOTE(review): listing line 1581 is missing here.
1582  ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
1583  CHECK(data->IsInstance<ContainerType>())
1584  << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
1585  return TObjectRef(data);
1586  }
1587  if (type_code_ == kTVMObjectHandle) {
1588  // normal object type check.
1589  Object* ptr = static_cast<Object*>(value_.v_handle);
// NOTE(review): listing line 1590, which must declare `checked_type`, is missing.
// The check below fails when checked_type is defined and reports its value as the
// type that was actually received.
1591  ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName()
1592  << ", but got " << checked_type.value();
1593  return TObjectRef(GetObjectPtr<Object>(ptr));
1594  } else if (type_code_ == kTVMObjectRValueRefArg) {
// For rvalue-reference arguments the handle stores the address of an Object*.
1595  Object* ptr = *static_cast<Object**>(value_.v_handle);
// NOTE(review): listing line 1596 (declaration of `checked_type`) is missing.
1597  ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName()
1598  << ", but got " << checked_type.value();
1599  return TObjectRef(GetObjectPtr<Object>(ptr));
1600  } else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1601  type_code_ == kTVMNDArrayHandle) {
1602  // Casting to a base class that NDArray can sub-class
1603  ObjectPtr<Object> data =
1604  NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle));
1605  return TObjectRef(data);
1606  } else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1607  type_code_ == kTVMModuleHandle) {
1608  // Casting to a base class that Module can sub-class
1609  return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
1610  } else {
// NOTE(review): listing line 1611 is missing; presumably a type-code check
// precedes this fallback return -- confirm against the real source.
1612  return TObjectRef(ObjectPtr<Object>(nullptr));
1613  }
1614 }
1615 
// Assign an ObjectRef-derived value to this return value.
// A null reference switches the container to kTVMNullptr. NDArray and Module
// objects are re-dispatched to the dedicated NDArray / Module assignment
// overloads; every other object is stored with kTVMObjectHandle. The reference's
// data_ is moved out of `other` in all non-null cases.
1616 template <typename TObjectRef, typename>
1617 inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
1618  using ContainerType = typename TObjectRef::ContainerType;
1619  const Object* ptr = other.get();
1620  if (ptr != nullptr) {
1621  if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1622  (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
// NOTE(review): listing line 1623 is missing, so the condition above is left
// unclosed here; by analogy with the Module branch below it presumably tests
// ptr->IsInstance<NDArray::ContainerType>() -- confirm against the real source.
1624  return operator=(NDArray(std::move(other.data_)));
1625  }
1626  if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1627  (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1628  ptr->IsInstance<Module::ContainerType>())) {
1629  return operator=(Module(std::move(other.data_)));
1630  }
1631  SwitchToObject(kTVMObjectHandle, std::move(other.data_));
1632  } else {
1633  SwitchToPOD(kTVMNullptr);
1634  }
1635  return *this;
1636 }
1637 
1638 template <typename T, typename>
1639 inline TVMArgValue::operator T() const {
1640  return PackedFuncValueConverter<T>::From(*this);
1641 }
1642 
// Convert a movable argument value to T.
// Rvalue-reference object arguments whose pointee passes the ObjectTypeChecker<T>
// check take a dedicated path; everything else falls back to the regular
// PackedFuncValueConverter<T> conversion of the equivalent TVMArgValue.
1643 template <typename T, typename>
1644 inline TVMMovableArgValue_::operator T() const {
1645  if (type_code_ == kTVMObjectRValueRefArg) {
// The handle stores the address of an Object*.
1646  auto** ref = static_cast<Object**>(value_.v_handle);
1647  if (ObjectTypeChecker<T>::Check(*ref)) {
// NOTE(review): listing line 1648 (the body of this branch) is missing;
// presumably it returns a T built from *ref -- confirm against the real source.
1649  }
1650  }
1651  // fallback
1652  return PackedFuncValueConverter<T>::From(AsArgValue());
1653 }
1654 
1655 template <typename T, typename>
1656 inline TVMRetValue::operator T() const {
1657  return PackedFuncValueConverter<T>::From(*this);
1658 }
1659 
1660 inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
1661  return (*this)->GetFunction(name, query_imports);
1662 }
1663 
1664 // specializations of PackedFuncValueConverter
// Converter specialization producing a tvm::runtime::String.
// Both overloads return the held String object when the value is one, and
// otherwise build a String from the value's std::string conversion.
// NOTE(review): listing line 1666 (the struct header of this specialization) is
// missing from this view.
1665 template <>
// Conversion from an argument value.
1667  static String From(const TVMArgValue& val) {
1668  if (val.IsObjectRef<tvm::runtime::String>()) {
1669  return val.AsObjectRef<tvm::runtime::String>();
1670  } else {
1671  return tvm::runtime::String(val.operator std::string());
1672  }
1673  }
1674 
// Conversion from a return value; same logic as the argument overload.
1675  static String From(const TVMRetValue& val) {
1676  if (val.IsObjectRef<tvm::runtime::String>()) {
1677  return val.AsObjectRef<tvm::runtime::String>();
1678  } else {
1679  return tvm::runtime::String(val.operator std::string());
1680  }
1681  }
1682 };
1683 
// Converter specialization producing an Optional<T>.
// A kTVMNullptr value maps to an empty Optional in both overloads.
// NOTE(review): listing lines 1685 (struct header), 1688 and 1692 (the non-null
// return of each overload) are missing from this view, so the non-null behaviour
// cannot be read from here.
1684 template <typename T>
// Conversion from an argument value.
1686  static Optional<T> From(const TVMArgValue& val) {
1687  if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
1689  }
// Conversion from a return value.
1690  static Optional<T> From(const TVMRetValue& val) {
1691  if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr);
1693  }
1694 };
1695 
1696 inline bool String::CanConvertFrom(const TVMArgValue& val) {
1697  return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
1698 }
1699 
1700 inline TVMArgValue::operator DLDataType() const {
1701  if (String::CanConvertFrom(*this)) {
1702  return String2DLDataType(PackedFuncValueConverter<String>::From(*this).operator std::string());
1703  }
1704  // None type
1705  if (type_code_ == kTVMNullptr) {
1706  DLDataType t;
1707  t.code = kTVMOpaqueHandle;
1708  t.bits = 0;
1709  t.lanes = 0;
1710  return t;
1711  }
1712  TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType);
1713  return value_.v_type;
1714 }
1715 
1716 inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); }
1717 
1718 } // namespace runtime
1719 } // namespace tvm
1720 #endif // TVM_RUNTIME_PACKED_FUNC_H_
static TObjectRef From(const TVMArgValue &val)
Convert a TObjectRef from an argument value.
Definition: packed_func.h:1043
const char * ArgTypeCode2Str(int type_code)
Convert argument type code to string.
Definition: packed_func.h:1156
int num_args
Definition: packed_func.h:339
TVMArgValue()
default constructor
Definition: packed_func.h:586
static bool Check(const Object *ptr)
Definition: packed_func.h:433
TVMRetValue & operator=(std::string value)
Definition: packed_func.h:839
array node content in array
Definition: array.h:38
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:734
PackedFunc GetFunction(const std::string &name, bool query_imports=false)
Get packed function from current module by name.
Definition: packed_func.h:1660
static void FFIDecRef(TVMArrayHandle handle)
DecRef resource managed by an FFI array handle.
Definition: ndarray.h:396
TVMRetValue & operator=(TVMMovableArgValue_ &&other)
Definition: packed_func.h:882
TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef &&value) const
Definition: packed_func.h:1348
void MoveToCHost(TVMValue *ret_value, int *ret_type_code)
Move the value back to front-end via C API. This marks the current container as null. The managed resources are moved to the front-end. The front end should take charge in managing them.
Definition: packed_func.h:895
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
A custom smart pointer for Object.
Definition: object.h:356
TVMArgs(const TVMValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:346
A PackedFunc wrapper to provide typed function signature. It is backed by a PackedFunc internally...
Definition: packed_func.h:171
TVM_ALWAYS_INLINE void operator()(size_t i, const char *value) const
Definition: packed_func.h:1302
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:125
TVMRetValue & operator=(bool value)
Definition: packed_func.h:834
int type_code_
the type code
Definition: packed_func.h:574
Internal base class to handle conversion to POD values.
Definition: packed_func.h:485
std::string DLDataType2String(DLDataType t)
convert a TVM type to string.
Definition: data_type.h:332
void * v_handle
Definition: c_runtime_api.h:147
const PackedFunc & packed() const
Definition: packed_func.h:302
Definition: c_runtime_api.h:124
Performance counters for profiling via the PAPI library.
Definition: analyzer.h:36
struct TVMArgs TVMArgs
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:448
FType body() const
Definition: packed_func.h:1153
Definition: c_runtime_api.h:119
TVMMovableArgValue_(TVMValue value, int type_code)
Definition: packed_func.h:652
TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor *value) const
Definition: packed_func.h:1287
~TVMRetValue()
destructor
Definition: packed_func.h:747
Definition: c_runtime_api.h:115
const TVMValue * values
Definition: packed_func.h:337
Definition: c_runtime_api.h:120
TVMArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:1143
size_t size
Definition: c_runtime_api.h:159
TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index, const std::string *optional_name)
move constructor from another return value.
Definition: packed_func.h:705
static std::string TypeName()
Definition: packed_func.h:407
static bool Check(const Object *ptr)
Check if an object matches the template type.
Definition: packed_func.h:402
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:144
PackedFunc(FType body)
constructing a packed function from a std::function.
Definition: packed_func.h:97
base class of all object containers.
Definition: object.h:165
Definition: c_runtime_api.h:118
const char * data
Definition: c_runtime_api.h:158
Object * TVMArrayHandleToObjectHandle(TVMArrayHandle handle)
Definition: ndarray.h:400
TVMValue value_
The value.
Definition: packed_func.h:572
TVMRetValue()
default constructor
Definition: packed_func.h:737
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:59
TVMRetValue & operator=(int value)
Definition: packed_func.h:818
static void FFIClearAfterMove(ObjectRef *ref)
Clear the object ref data field without DecRef after we successfully moved the field.
Definition: object.h:585
TVMRetValue & operator=(TVMRetValue &&other)
Definition: packed_func.h:791
Byte array type used to pass in a byte array when kTVMBytes is used as the data type.
Definition: c_runtime_api.h:157
TVMRetValue & operator=(TVMByteArray value)
Definition: packed_func.h:843
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:253
bool IsInstance() const
Definition: object.h:822
TVMRetValue & operator=(const TVMArgValue &other)
Definition: packed_func.h:878
PackedFunc()
default constructor
Definition: packed_func.h:90
TVMRetValue & operator=(double value)
Definition: packed_func.h:798
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray &value) const
Definition: packed_func.h:1311
Runtime Array container types.
Definition: packed_func.h:38
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
Definition: packed_func.h:650
static bool Check(const Object *ptr)
Definition: packed_func.h:465
TVM_ALWAYS_INLINE void operator()(size_t i, T value) const
Definition: packed_func.h:1262
static Optional< T > From(const TVMRetValue &val)
Definition: packed_func.h:1690
A device-independent managed NDArray abstraction.
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: data_type.h:339
static String From(const TVMRetValue &val)
Definition: packed_func.h:1675
static ObjectPtr< Object > FFIDataFromHandle(TVMArrayHandle handle)
Construct NDArray's Data field from array handle in FFI.
Definition: ndarray.h:383
Type traits for runtime type check during FFI conversion.
Definition: packed_func.h:374
bool IsObjectRef() const
Definition: packed_func.h:1534
static Optional< T > From(const TVMArgValue &val)
Definition: packed_func.h:1686
Runtime primitive data type.
Definition: data_type.h:41
bool defined() const
Definition: object.h:537
TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const
Definition: packed_func.h:1291
Arguments into TVM functions.
Definition: packed_func.h:335
TVMRetValue(TVMRetValue &&other)
move constructor from another return value.
Definition: packed_func.h:742
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:275
T value() const
Definition: optional.h:92
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Definition: c_runtime_api.h:123
TypedPackedFunc(const FLambda &typed_lambda, std::string name)
construct from a lambda function with the same signature.
Definition: packed_func.h:230
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:1266
static std::string TypeName()
Definition: packed_func.h:444
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:284
std::function< void(TVMArgs args, TVMRetValue *rv)> FType
The internal std::function.
Definition: packed_func.h:88
TVMRetValue & operator=(Module m)
Definition: packed_func.h:858
Object container class that backs NDArray.
Definition: ndarray.h:261
TVMRetValue & operator=(PackedFunc f)
Definition: packed_func.h:862
TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef &value) const
Definition: packed_func.h:1341
const int * type_codes
Definition: packed_func.h:338
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:567
TVMRetValue & operator=(DLDataType t)
Definition: packed_func.h:828
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:123
TObjectRef AsObjectRef() const
Definition: packed_func.h:1559
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
TVM_ALWAYS_INLINE void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1307
Reference to string objects.
Definition: string.h:129
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:136
TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const
Definition: packed_func.h:1299
TVMPODValue_(TVMValue value, int type_code)
Definition: packed_func.h:569
TVM_ALWAYS_INLINE void operator()(size_t i, double value) const
Definition: packed_func.h:1271
TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1275
T * ptr() const
return handle as specific pointer type.
Definition: packed_func.h:554
DLDevice Device
Definition: ndarray.h:43
Definition: c_runtime_api.h:113
const Object * get() const
Definition: object.h:539
TVMRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:871
static TVMRetValue MoveFromCHost(TVMValue value, int type_code)
Construct a new TVMRetValue by moving from return value stored via C API.
Definition: packed_func.h:909
TVMRetValue & operator=(DLDevice value)
Definition: packed_func.h:823
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Check if an object matches the template type and return the mismatched type if it exists...
Definition: packed_func.h:382
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:1696
Base class of all object reference.
Definition: object.h:504
Base container of module.
Definition: module.h:111
int size() const
Definition: packed_func.h:1149
std::string GetTypeKey() const
Definition: object.h:178
TVMRetValue & operator=(const DataType &other)
Definition: packed_func.h:833
Shared content of all specializations of hash map.
Definition: map.h:167
TVMRetValue & operator=(int64_t value)
Definition: packed_func.h:813
A managed object in the TVM runtime.
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue &value) const
Definition: packed_func.h:1279
static TVMArrayHandle FFIGetHandle(const ObjectRef &nd)
Get FFI Array handle from ndarray.
Definition: ndarray.h:388
void operator()(size_t i, const TVMRetValue &value) const
Definition: packed_func.h:1328
TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc &value) const
Definition: packed_func.h:1315
Runtime container of the functions generated by TVM, This is used to support dynamically link...
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:304
Module container of TVM.
Definition: module.h:48
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:306
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1235
static std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Runtime Map container types.
static String From(const TVMArgValue &val)
Definition: packed_func.h:1667
static std::string TypeName()
Definition: packed_func.h:475
Definition: c_runtime_api.h:114
int type_code() const
Definition: packed_func.h:547
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:583
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:68
Optional container that represents a nullable variant of T.
Definition: optional.h:51
const TVMValue & value() const
Definition: packed_func.h:632
TVMRetValue operator()(Args &&... args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1362
Definition: packed_func.h:1257
constexpr runtime::NullOptType NullOpt
Definition: optional.h:155
TVMArgsSetter(TVMValue *values, int *type_codes)
Definition: packed_func.h:1259
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:178
TVMRetValue(const TVMRetValue &other)
Definition: packed_func.h:762
void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1151
PackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:92
const TVMValue & value() const
Definition: packed_func.h:918
Definition: c_runtime_api.h:116
TVMArgValue(TVMValue value, int type_code)
constructor
Definition: packed_func.h:592
TVMRetValue & operator=(const TVMRetValue &other)
Definition: packed_func.h:874
#define TVM_CHECK_TYPE_CODE(CODE, T)
Definition: packed_func.h:366
TVM_ALWAYS_INLINE void operator()(size_t i, void *value) const
Definition: packed_func.h:1283
Definition: c_runtime_api.h:122
TVMRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:803
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:416
runtime::DataType DataType
Definition: data_type.h:389
Definition: c_runtime_api.h:121
TVMRetValue & operator=(NDArray other)
Definition: packed_func.h:847
Type trait to specify special value conversion rules from TVMArgValue and TVMRetValue.
Definition: packed_func.h:1037
static TObjectRef From(const TVMRetValue &val)
Convert a TObjectRef from a return value.
Definition: packed_func.h:1049
Definition: c_runtime_api.h:117
TypedPackedFunc()
default constructor
Definition: packed_func.h:176
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with additional context ...
Definition: packed_func.h:695
TVMRetValue & operator=(void *value)
Definition: packed_func.h:808
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1325
TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1295
TVMPODValue_()
Definition: packed_func.h:568