24 #ifndef TVM_RUNTIME_PACKED_FUNC_H_ 25 #define TVM_RUNTIME_PACKED_FUNC_H_ 31 #include <tvm/runtime/logging.h> 41 #include <type_traits> 46 #ifndef TVM_RUNTIME_HEADER_ONLY 47 #define TVM_RUNTIME_HEADER_ONLY 0 56 class TVMMovableArgValueWithContext_;
59 template <
typename FType>
61 template <
typename TSignature>
78 static constexpr
const char*
_type_key =
"runtime.PackedFunc";
85 template <
class TPackedFuncSubObj>
113 template <
class TCallable>
115 using TStorage =
typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type;
147 template <
typename TCallable,
148 typename = std::enable_if_t<
149 std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
150 !std::is_base_of<TCallable, PackedFunc>::value>>
153 data_ = make_object<ObjType>(std::forward<TCallable>(data));
169 template <
typename... Args>
170 inline TVMRetValue operator()(Args&&... args)
const;
178 bool operator==(std::nullptr_t null)
const {
return data_ ==
nullptr; }
180 bool operator!=(std::nullptr_t null)
const {
return data_ !=
nullptr; }
191 template <
typename FType>
226 template <
typename R,
typename... Args>
284 template <
typename FLambda,
typename =
typename std::enable_if<std::is_convertible<
285 FLambda, std::function<R(Args...)>>::value>::type>
287 this->AssignTypedLambda(typed_lambda, name);
307 template <
typename FLambda,
typename =
typename std::enable_if<std::is_convertible<
308 FLambda, std::function<R(Args...)>>::value>::type>
310 this->AssignTypedLambda(typed_lambda);
328 template <
typename FLambda,
typename =
typename std::enable_if<
329 std::is_convertible<FLambda,
330 std::function<R(Args...)>>::value>::type>
332 this->AssignTypedLambda(typed_lambda);
349 TVM_ALWAYS_INLINE R operator()(Args... args)
const;
360 bool operator==(std::nullptr_t null)
const {
return packed_ ==
nullptr; }
362 bool operator!=(std::nullptr_t null)
const {
return packed_ !=
nullptr; }
376 template <
typename FLambda>
377 inline void AssignTypedLambda(FLambda flambda, std::string name);
386 template <
typename FLambda>
387 inline void AssignTypedLambda(FLambda flambda);
403 : values(values), type_codes(type_codes), num_args(num_args) {}
405 inline int size()
const;
422 #define TVM_CHECK_TYPE_CODE(CODE, T) \ 423 ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) 429 template <
typename T>
439 using ContainerType =
typename T::ContainerType;
440 if (ptr ==
nullptr) {
441 if (T::_type_is_nullable) {
459 using ContainerType =
typename T::ContainerType;
460 if (ptr ==
nullptr)
return T::_type_is_nullable;
464 using ContainerType =
typename T::ContainerType;
465 return ContainerType::_type_key;
470 template <
typename T>
473 if (ptr ==
nullptr) {
480 for (
size_t i = 0; i < n->size(); i++) {
484 return String(
"Array[index " + std::to_string(i) +
": " + check_subtype.
value() +
"]");
490 if (ptr ==
nullptr)
return true;
502 template <
typename K,
typename V>
505 if (ptr ==
nullptr)
return NullOpt;
508 for (
const auto& kv : *n) {
511 if (key_type.
defined() || value_type.defined()) {
512 std::string key_name =
514 std::string value_name = value_type.defined() ? std::string(value_type.value())
516 return String(
"Map[" + key_name +
", " + value_name +
"]");
522 if (ptr ==
nullptr)
return true;
525 for (
const auto& kv : *n) {
543 operator double()
const {
547 if (type_code_ == kDLInt) {
548 return static_cast<double>(value_.v_int64);
551 return value_.v_float64;
553 operator int64_t()
const {
555 return value_.v_int64;
557 operator uint64_t()
const {
559 return value_.v_int64;
561 operator int()
const {
565 return static_cast<int>(value_.v_int64);
567 operator bool()
const {
569 return value_.v_int64 != 0;
571 operator void*()
const {
575 return value_.v_handle;
577 operator DLTensor*()
const {
579 return static_cast<DLTensor*
>(value_.v_handle);
582 LOG(FATAL) <<
"Expected " 608 return value_.v_device;
616 template <
typename T>
618 return static_cast<T*
>(value_.v_handle);
621 template <
typename TObjectRef,
622 typename =
typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
623 inline bool IsObjectRef()
const;
624 template <
typename TObjectRef>
625 inline TObjectRef AsObjectRef()
const;
657 using TVMPODValue_::operator double;
658 using TVMPODValue_::operator int64_t;
659 using TVMPODValue_::operator uint64_t;
660 using TVMPODValue_::operator int;
661 using TVMPODValue_::operator bool;
662 using TVMPODValue_::operator
void*;
663 using TVMPODValue_::operator DLTensor*;
664 using TVMPODValue_::operator
NDArray;
665 using TVMPODValue_::operator
Device;
666 using TVMPODValue_::operator
Module;
672 operator std::string()
const {
677 return std::string(arr->
data, arr->
size);
678 }
else if (type_code_ ==
kTVMStr) {
679 return std::string(value_.v_str);
681 ICHECK(IsObjectRef<tvm::runtime::String>())
684 return AsObjectRef<tvm::runtime::String>().
operator std::string();
687 template <
typename FType>
693 template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
694 inline operator T()
const;
695 inline operator DLDataType()
const;
713 using TVMPODValue_::operator double;
714 using TVMPODValue_::operator int64_t;
715 using TVMPODValue_::operator uint64_t;
716 using TVMPODValue_::operator int;
717 using TVMPODValue_::operator bool;
718 using TVMPODValue_::operator
void*;
719 using TVMPODValue_::operator DLTensor*;
720 using TVMPODValue_::operator
NDArray;
721 using TVMPODValue_::operator
Device;
722 using TVMPODValue_::operator
Module;
725 operator std::string()
const {
return AsArgValue().operator std::string(); }
726 template <
typename FType>
730 operator DLDataType()
const {
return AsArgValue().operator DLDataType(); }
738 template <
typename T,
739 typename =
typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
740 inline operator T()
const;
766 const std::string* optional_name,
FSig* f_sig)
767 : value_(value, type_code),
768 arg_index_(arg_index),
769 optional_name_(optional_name),
772 template <
typename T>
776 }
catch (dmlc::Error& e) {
777 LOG(FATAL) <<
"In function " << (optional_name_ ==
nullptr ?
"<anonymous>" : *optional_name_)
778 << (f_sig_ ==
nullptr ?
"" : (*f_sig_)()) <<
": error while converting argument " 779 << arg_index_ <<
": " << e.what();
787 const std::string* optional_name_;
814 using TVMPODValue_::operator double;
815 using TVMPODValue_::operator int64_t;
816 using TVMPODValue_::operator uint64_t;
817 using TVMPODValue_::operator int;
818 using TVMPODValue_::operator bool;
819 using TVMPODValue_::operator
void*;
820 using TVMPODValue_::operator DLTensor*;
821 using TVMPODValue_::operator
Device;
822 using TVMPODValue_::operator
NDArray;
823 using TVMPODValue_::operator
Module;
830 operator std::string()
const {
834 return *ptr<std::string>();
837 return *ptr<std::string>();
839 operator DLDataType()
const {
844 return value_.v_type;
847 template <
typename FType>
854 value_ = other.value_;
855 type_code_ = other.type_code_;
860 this->SwitchToPOD(kDLFloat);
861 value_.v_float64 = value;
866 value_.v_handle = value;
871 value_.v_handle = value;
875 this->SwitchToPOD(kDLInt);
876 value_.v_int64 = value;
880 this->SwitchToPOD(kDLInt);
881 value_.v_int64 = value;
886 value_.v_device = value;
896 this->SwitchToPOD(kDLInt);
897 value_.v_int64 = value;
901 this->SwitchToClass(
kTVMStr, value);
909 if (other.
data_ !=
nullptr) {
916 value_.v_handle =
nullptr;
928 template <
typename FType>
957 *ret_type_code = type_code_;
979 <<
"TVMRetValue.value can only be used for POD data";
983 template <
typename TObjectRef,
984 typename =
typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
986 template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
987 inline operator T()
const;
990 template <
typename T>
991 void Assign(
const T& other) {
992 switch (other.type_code()) {
994 SwitchToClass<std::string>(
kTVMStr, other);
998 SwitchToClass<std::string>(
kTVMBytes, other);
1002 *
this = other.operator PackedFunc();
1006 *
this = other.operator Module();
1010 *
this = other.operator NDArray();
1016 GetObjectPtr<Object>(static_cast<Object*>(other.value_.v_handle)));
1024 SwitchToPOD(other.type_code());
1025 value_ = other.value_;
1031 void SwitchToPOD(
int type_code) {
1032 if (type_code_ != type_code) {
1034 type_code_ = type_code;
1037 template <
typename T>
1038 void SwitchToClass(
int type_code, T v) {
1039 if (type_code_ != type_code) {
1041 type_code_ = type_code;
1042 value_.v_handle =
new T(v);
1044 *
static_cast<T*
>(value_.v_handle) = v;
1048 if (other.data_ !=
nullptr) {
1050 type_code_ = type_code;
1052 value_.v_handle = other.data_;
1053 other.data_ =
nullptr;
1056 value_.v_handle =
nullptr;
1061 switch (type_code_) {
1064 delete ptr<std::string>();
1095 template <
typename TObjectRef>
1130 #define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ 1132 TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ 1133 int* out_type_code, void* resource_handle); \ 1134 int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ 1135 int* out_type_code, void* resource_handle) { \ 1137 ::tvm::runtime::TVMRetValue rv; \ 1138 Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ 1139 rv.MoveToCHost(out_value, out_type_code); \ 1141 } catch (const ::std::exception& _except_) { \ 1142 TVMAPISetLastError(_except_.what()); \ 1183 #define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ 1185 TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ 1186 int* out_type_code, void* resource_handle) { \ 1188 auto f = Function; \ 1189 using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \ 1190 ::tvm::runtime::TVMRetValue rv; \ 1191 ::tvm::runtime::detail::unpack_call_by_signature<FType>::run( \ 1192 f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ 1193 rv.MoveToCHost(out_value, out_type_code); \ 1195 } catch (const ::std::exception& _except_) { \ 1196 TVMAPISetLastError(_except_.what()); \ 1203 ICHECK_LT(i, num_args) <<
"not enough argument passed, " << num_args <<
" passed" 1204 <<
" but request arg[" << i <<
"].";
1210 template <
class TPackedFuncSubObj>
1213 (
static_cast<const TPackedFuncSubObj*
>(obj))->callable_(args, rv);
1217 (*f_call_packed_)(
this, args, rv);
1221 (
static_cast<PackedFuncObj*
>(data_.get()))->CallPacked(args, rv);
1226 switch (type_code) {
1242 return "ArrayHandle";
1244 return "DLDataType";
1248 return "FunctionHandle";
1250 return "ModuleHandle";
1252 return "NDArrayContainer";
1256 return "ObjectRValueRefArg";
1258 LOG(FATAL) <<
"unknown type_code=" <<
static_cast<int>(type_code);
1264 template <
bool stop, std::
size_t I,
typename F>
1265 struct for_each_dispatcher {
1266 template <
typename T,
typename... Args>
1267 static void run(
const F& f, T&& value, Args&&... args) {
1268 f(I, std::forward<T>(value));
1269 for_each_dispatcher<
sizeof...(Args) == 0, (I + 1), F>::run(f, std::forward<Args>(args)...);
1273 template <std::
size_t I,
typename F>
1274 struct for_each_dispatcher<true, I, F> {
1275 static void run(
const F& f) {}
1278 template <
typename F,
typename... Args>
1279 inline void for_each(
const F& f, Args&&... args) {
1280 for_each_dispatcher<
sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...);
1283 namespace parameter_pack {
1285 template <
typename... EnumArgs>
1286 struct EnumeratedParamPack {
1288 template <
template <
size_t i,
typename TArgument>
class Functor,
typename... ExtraParams>
1289 static void F(ExtraParams&&... extra_params) {
1290 using TExpander =
int[];
1293 (Functor<EnumArgs::i, typename EnumArgs::T>::F(extra_params...), 0)...,
1299 template <
typename... Args>
1300 struct EnumerateImpl {
1302 template <
size_t _i,
typename _T>
1304 static const constexpr
size_t i = _i;
1308 template <
typename...>
1311 template <std::size_t...
id>
1312 struct Zipper<std::integer_sequence<std::size_t, id...>> {
1313 using T = EnumeratedParamPack<Item<id, Args>...>;
1317 using T =
typename Zipper<std::index_sequence_for<Args...>>::T;
1320 template <
typename... Args>
1321 using Enumerate =
typename EnumerateImpl<Args...>::T;
1323 template <
typename... Args>
1325 template <
template <
size_t i,
typename TArgument>
class Functor,
typename... ExtraParams>
1326 static void InvokeWithoutArg(ExtraParams&&... extra_params) {
1327 Enumerate<Args...>::Invoke::template F<Functor, ExtraParams...>(
1328 std::forward<ExtraParams>(extra_params)...);
1338 template <
typename T>
1339 struct func_signature_helper {
1343 template <
typename T,
typename R,
typename... Args>
1344 struct func_signature_helper<R (T::*)(Args...)> {
1345 using FType = R(Args...);
1346 using ParamType = parameter_pack::ParamPack<Args...>;
1348 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1351 template <
typename T,
typename R,
typename... Args>
1352 struct func_signature_helper<R (T::*)(Args...) const> {
1353 using FType = R(Args...);
1354 using ParamType = parameter_pack::ParamPack<Args...>;
1356 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1363 template <
typename T>
1364 struct function_signature {
1365 using FType =
typename func_signature_helper<decltype(&T::operator())>::FType;
1366 using ParamType =
typename func_signature_helper<decltype(&T::operator())>::ParamType;
1367 using RetType =
typename func_signature_helper<decltype(&T::operator())>::RetType;
1371 template <
typename R,
typename... Args>
1372 struct function_signature<R(Args...)> {
1373 using FType = R(Args...);
1374 using ParamType = parameter_pack::ParamPack<Args...>;
1376 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1380 template <
typename R,
typename... Args>
1381 struct function_signature<R (*)(Args...)> {
1382 using FType = R(Args...);
1383 using ParamType = detail::parameter_pack::ParamPack<Args...>;
1385 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1388 template <
typename TSignature>
1391 namespace type2str {
1393 template <
typename T>
1394 struct TypeSimplifier;
1396 template <
typename T>
1398 template <typename = std::enable_if_t<std::is_base_of<ObjectRef, T>::value>>
1399 static std::string v() {
1400 return T::ContainerType::_type_key;
1404 struct Type2Str<int> {
1405 static std::string v() {
return "int"; }
1408 struct Type2Str<double> {
1409 static std::string v() {
return "double"; }
1412 struct Type2Str<int64_t> {
1413 static std::string v() {
return "int64_t"; }
1416 struct Type2Str<uint64_t> {
1417 static std::string v() {
return "uint64_t"; }
1420 struct Type2Str<bool> {
1421 static std::string v() {
return "bool"; }
1424 struct Type2Str<void> {
1425 static std::string v() {
return "void"; }
1428 struct Type2Str<std::basic_string<char>> {
1429 static std::string v() {
return "basic_string<char>"; }
1431 template <
typename K,
typename V>
1432 struct Type2Str<Map<K, V>> {
1433 static std::string v() {
1434 return "Map<" + TypeSimplifier<K>::v() +
", " + TypeSimplifier<V>::v() +
">";
1438 struct Type2Str<DLDevice> {
1439 static std::string v() {
return "DLDevice"; }
1442 struct Type2Str<DLTensor> {
1443 static std::string v() {
return "DLTensor"; }
1446 struct Type2Str<DataType> {
1447 static std::string v() {
return "DataType"; }
1450 struct Type2Str<DLDataType> {
1451 static std::string v() {
return "DLDataType"; }
1454 struct Type2Str<TVMRetValue> {
1455 static std::string v() {
return "TVMRetValue"; }
1458 struct Type2Str<TVMArgValue> {
1459 static std::string v() {
return "TVMArgValue"; }
1462 struct Type2Str<TVMByteArray> {
1463 static std::string v() {
return "TVMByteArray"; }
1465 template <
typename FType>
1466 struct Type2Str<TypedPackedFunc<FType>> {
1469 template <
typename T>
1470 struct Type2Str<Array<T>> {
1471 static std::string v() {
return "Array<" + TypeSimplifier<T>::v() +
">"; }
1478 template <
typename T>
1479 struct TypeSimplifier {
1480 static std::string v() {
1481 using U =
typename std::remove_cv<
1482 typename std::remove_reference<typename std::remove_pointer<T>::type>::type>::type;
1483 return (std::is_const<T>::value ?
"const " :
"") + Type2Str<U>::v() +
1484 (std::is_pointer<T>::value ?
"*" :
"") + (std::is_reference<T>::value ?
"&" :
"");
1494 template <
typename TSignature>
1496 using ParamType =
typename TSignature::ParamType;
1497 using RetType =
typename TSignature::RetType;
1499 template <
size_t i,
typename TArgument>
1500 struct PrintParamType {
1501 static void F(std::ostream& os) {
1502 os << (i == 0 ?
"" :
", ") << i <<
": " << type2str::TypeSimplifier<TArgument>::v();
1506 static std::string F() {
1507 std::ostringstream oss;
1509 ParamType::template InvokeWithoutArg<PrintParamType>(oss);
1510 oss <<
") -> " << type2str::TypeSimplifier<RetType>::v();
1521 template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
1523 values_[i].v_int64 =
static_cast<int64_t
>(value);
1524 type_codes_[i] = kDLInt;
1526 TVM_ALWAYS_INLINE
void operator()(
size_t i, uint64_t value)
const {
1527 values_[i].v_int64 =
static_cast<int64_t
>(value);
1529 type_codes_[i] = kDLInt;
1531 TVM_ALWAYS_INLINE
void operator()(
size_t i,
double value)
const {
1532 values_[i].v_float64 = value;
1533 type_codes_[i] = kDLFloat;
1535 TVM_ALWAYS_INLINE
void operator()(
size_t i, std::nullptr_t value)
const {
1536 values_[i].v_handle = value;
1540 values_[i] = value.
value_;
1544 values_[i].v_handle = value;
1547 TVM_ALWAYS_INLINE
void operator()(
size_t i, DLTensor* value)
const {
1548 values_[i].v_handle = value;
1552 values_[i].v_device = value;
1555 TVM_ALWAYS_INLINE
void operator()(
size_t i, DLDataType value)
const {
1556 values_[i].v_type = value;
1560 operator()(i, dtype.operator DLDataType());
1562 TVM_ALWAYS_INLINE
void operator()(
size_t i,
const char* value)
const {
1563 values_[i].v_str = value;
1567 TVM_ALWAYS_INLINE
void operator()(
size_t i,
const std::string& value)
const {
1568 values_[i].v_str = value.c_str();
1572 values_[i].v_handle =
const_cast<TVMByteArray*
>(&value);
1575 template <
typename FType>
1577 operator()(i, value.packed());
1581 values_[i].v_str = value.
ptr<std::string>()->c_str();
1585 values_[i] = value.
value_;
1590 template <
typename TObjectRef,
1591 typename =
typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
1592 TVM_ALWAYS_INLINE
void operator()(
size_t i,
const TObjectRef& value)
const {
1593 this->SetObject(i, value);
1596 template <
typename TObjectRef,
1597 typename =
typename std::enable_if<std::is_base_of<
1598 ObjectRef,
typename std::remove_reference<TObjectRef>::type>::value>::type>
1599 TVM_ALWAYS_INLINE
void operator()(
size_t i, TObjectRef&& value)
const {
1600 this->SetObject(i, std::forward<TObjectRef>(value));
1604 template <
typename TObjectRef>
1605 inline void SetObject(
size_t i, TObjectRef&& value)
const;
1612 template <
typename... Args>
1614 const int kNumArgs =
sizeof...(Args);
1615 const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1617 int type_codes[kArraySize];
1618 detail::for_each(
TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
1621 ->CallPacked(
TVMArgs(values, type_codes, kNumArgs), &rv);
1626 template <
typename R,
int nleft,
int index,
typename F>
1627 struct unpack_call_dispatcher {
1628 template <
typename... Args>
1629 TVM_ALWAYS_INLINE
static void run(
const std::string* optional_name,
FSig* f_sig,
const F& f,
1631 Args&&... unpacked_args) {
1634 unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
1635 optional_name, f_sig, f, args_pack, rv, std::forward<Args>(unpacked_args)...,
1637 optional_name, f_sig));
1641 template <
typename R,
int index,
typename F>
1642 struct unpack_call_dispatcher<R, 0, index, F> {
1643 template <
typename... Args>
1644 TVM_ALWAYS_INLINE
static void run(
const std::string* optional_name,
FSig* f_sig,
const F& f,
1646 Args&&... unpacked_args) {
1647 using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
1648 if (std::is_same<RetType, R>::value) {
1649 *rv = f(std::forward<Args>(unpacked_args)...);
1651 *rv = R(f(std::forward<Args>(unpacked_args)...));
1656 template <
int index,
typename F>
1657 struct unpack_call_dispatcher<void, 0, index, F> {
1658 template <
typename... Args>
1659 TVM_ALWAYS_INLINE
static void run(
const std::string* optional_name,
FSig* f_sig,
const F& f,
1661 Args&&... unpacked_args) {
1662 f(std::forward<Args>(unpacked_args)...);
1666 template <
typename R,
int nargs,
typename F>
1667 TVM_ALWAYS_INLINE
void unpack_call(
const std::string* optional_name,
const F& f,
1669 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<F>>::F;
1670 CHECK_EQ(nargs, args.
size()) <<
"Function " 1671 << (optional_name ==
nullptr ?
"<anonymous>" : *optional_name)
1672 << (f_sig ==
nullptr ?
"" : (*f_sig)()) <<
" expects " << nargs
1673 <<
" arguments but " << args.
size() <<
" were provided";
1674 unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, args, rv);
1677 template <
typename FType>
1678 struct unpack_call_by_signature {};
1680 template <
typename R,
typename... Args>
1681 struct unpack_call_by_signature<R(Args...)> {
1682 template <
typename F>
1683 TVM_ALWAYS_INLINE
static void run(
const F& f,
const TVMArgs& args,
TVMRetValue* rv) {
1684 unpack_call<R,
sizeof...(Args)>(
nullptr, f, args, rv);
1688 template <
typename R,
typename... Args>
1689 TVM_ALWAYS_INLINE R call_packed(
const PackedFunc& pf, Args&&... args) {
1690 return R(pf(std::forward<Args>(args)...));
1693 template <
typename R>
1694 struct typed_packed_call_dispatcher {
1695 template <
typename... Args>
1696 TVM_ALWAYS_INLINE
static R run(
const PackedFunc& pf, Args&&... args) {
1697 return pf(std::forward<Args>(args)...);
1702 struct typed_packed_call_dispatcher<void> {
1703 template <
typename... Args>
1704 TVM_ALWAYS_INLINE
static void run(
const PackedFunc& pf, Args&&... args) {
1705 pf(std::forward<Args>(args)...);
1710 template <
typename R,
typename... Args>
1713 template <
typename R,
typename... Args>
1717 template <
typename R,
typename... Args>
1721 template <
typename R,
typename... Args>
1725 template <
typename R,
typename... Args>
1726 template <
typename FType>
1727 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, std::string name) {
1728 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1730 if (args.
size() !=
sizeof...(Args)) {
1731 LOG(FATAL) <<
"Function " << name << (f_sig ==
nullptr ?
"" : (*f_sig)()) <<
" expects " 1732 <<
sizeof...(Args) <<
" arguments, but " << args.
size() <<
" were provided.";
1734 detail::unpack_call<R,
sizeof...(Args)>(&name, flambda, args, rv);
1738 template <
typename R,
typename... Args>
1739 template <
typename FType>
1741 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1743 if (args.
size() !=
sizeof...(Args)) {
1744 LOG(FATAL) <<
"Function <anonymous> " << (*f_sig)() <<
" expects " <<
sizeof...(Args)
1745 <<
" arguments, but " << args.
size() <<
" were provided.";
1747 detail::unpack_call<R,
sizeof...(Args)>(
nullptr, flambda, args, rv);
1751 template <
typename R,
typename... Args>
1753 return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...);
1761 template <
typename T>
1762 inline void TVMArgsSetter::SetObject(
size_t i, T&& value)
const {
1763 using ContainerType =
typename std::remove_reference<T>::type::ContainerType;
1764 if (value.defined()) {
1765 Object* ptr = value.data_.data_;
1766 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1767 (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1771 }
else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1772 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1774 values_[i].v_handle = ptr;
1776 }
else if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
1777 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1779 values_[i].v_handle = ptr;
1781 }
else if (std::is_rvalue_reference<decltype(value)>::value) {
1782 values_[i].v_handle =
const_cast<Object**
>(&(value.data_.data_));
1785 values_[i].v_handle = value.data_.data_;
1790 values_[i].v_handle =
nullptr;
1794 template <
typename TObjectRef,
typename>
1796 using ContainerType =
typename TObjectRef::ContainerType;
1798 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1803 if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1805 static_cast<Object*
>(value_.v_handle)->IsInstance<ContainerType>();
1807 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
1809 static_cast<Object*
>(value_.v_handle)->IsInstance<ContainerType>();
1815 return (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1817 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1819 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1825 template <
typename TObjectRef>
1827 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
1828 "Conversion only works for ObjectRef");
1829 using ContainerType =
typename TObjectRef::ContainerType;
1832 CHECK(TObjectRef::_type_is_nullable)
1833 <<
"Expect a not null value of " << ContainerType::_type_key;
1837 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1842 CHECK(data->IsInstance<ContainerType>())
1843 <<
"Expected " << ContainerType::_type_key <<
" but got " << data->GetTypeKey();
1844 return TObjectRef(data);
1846 if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1850 CHECK(data->IsInstance<ContainerType>())
1851 <<
"Expected " << ContainerType::_type_key <<
" but got " << data->GetTypeKey();
1852 return TObjectRef(data);
1854 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
1858 CHECK(data->IsInstance<ContainerType>())
1859 <<
"Expected " << ContainerType::_type_key <<
" but got " << data->GetTypeKey();
1860 return TObjectRef(data);
1867 <<
", but got " << checked_type.
value();
1868 return TObjectRef(GetObjectPtr<Object>(ptr));
1873 <<
", but got " << checked_type.
value();
1874 return TObjectRef(GetObjectPtr<Object>(ptr));
1875 }
else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1880 return TObjectRef(data);
1881 }
else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1884 return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
1885 }
else if (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1888 return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
1895 template <
typename TObjectRef,
typename>
1897 using ContainerType =
typename TObjectRef::ContainerType;
1898 const Object* ptr = other.get();
1899 if (ptr !=
nullptr) {
1900 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1901 (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1903 return operator=(
NDArray(std::move(other.data_)));
1905 if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1906 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1908 return operator=(
Module(std::move(other.data_)));
1910 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
1911 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1913 return operator=(
PackedFunc(std::move(other.data_)));
1918 value_.v_handle =
nullptr;
1923 template <
typename T,
typename>
1924 inline TVMArgValue::operator T()
const {
1928 template <
typename T,
typename>
1929 inline TVMMovableArgValue_::operator T()
const {
1931 auto** ref =
static_cast<Object**
>(value_.v_handle);
1940 template <
typename T,
typename>
1941 inline TVMRetValue::operator T()
const {
1946 return (*this)->GetFunction(name, query_imports);
1969 template <
typename T>
1985 inline TVMArgValue::operator DLDataType()
const {
1998 return value_.v_type;
2005 #endif // TVM_RUNTIME_PACKED_FUNC_H_ static TObjectRef From(const TVMArgValue &val)
Convert a TObjectRef from an argument value.
Definition: packed_func.h:1102
const char * ArgTypeCode2Str(int type_code)
Convert argument type code to string.
Definition: packed_func.h:1225
int num_args
Definition: packed_func.h:395
TVMArgValue()
default constructor
Definition: packed_func.h:649
static bool Check(const Object *ptr)
Definition: packed_func.h:489
TVMRetValue & operator=(std::string value)
Definition: packed_func.h:900
array node content in array
Definition: array.h:40
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:799
PackedFunc GetFunction(const std::string &name, bool query_imports=false)
Get packed function from current module by name.
Definition: packed_func.h:1945
static void FFIDecRef(TVMArrayHandle handle)
DecRef resource managed by an FFI array handle.
Definition: ndarray.h:430
TVMRetValue & operator=(TVMMovableArgValue_ &&other)
Definition: packed_func.h:940
TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef &&value) const
Definition: packed_func.h:1599
void MoveToCHost(TVMValue *ret_value, int *ret_type_code)
Move the value back to front-end via C API. This marks the current container as null. The managed resources are moved to the front-end. The front end should take charge in managing them.
Definition: packed_func.h:953
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
A custom smart pointer for Object.
Definition: object.h:358
TVMArgs(const TVMValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:402
A PackedFunc wrapper to provide typed function signature. It is backed by a PackedFunc internally...
Definition: packed_func.h:227
TVM_ALWAYS_INLINE void operator()(size_t i, const char *value) const
Definition: packed_func.h:1562
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:180
TVMRetValue & operator=(bool value)
Definition: packed_func.h:895
int type_code_
the type code
Definition: packed_func.h:637
Internal base class to handle conversion to POD values.
Definition: packed_func.h:541
static constexpr const uint32_t _type_index
Definition: packed_func.h:77
std::string DLDataType2String(DLDataType t)
convert a TVM type to string.
Definition: data_type.h:341
void * v_handle
Definition: c_runtime_api.h:211
const PackedFunc & packed() const
Definition: packed_func.h:358
Definition: c_runtime_api.h:188
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:504
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1220
Definition: c_runtime_api.h:183
TVMMovableArgValue_(TVMValue value, int type_code)
Definition: packed_func.h:711
TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor *value) const
Definition: packed_func.h:1547
~TVMRetValue()
destructor
Definition: packed_func.h:812
Definition: c_runtime_api.h:179
const TVMValue * values
Definition: packed_func.h:393
Definition: c_runtime_api.h:184
TVMArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:1202
size_t size
Definition: c_runtime_api.h:223
static std::string TypeName()
Definition: packed_func.h:463
static bool Check(const Object *ptr)
Check if an object matches the template type.
Definition: packed_func.h:458
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:208
base class of all object containers.
Definition: object.h:167
Definition: c_runtime_api.h:182
const char * data
Definition: c_runtime_api.h:222
Object * TVMArrayHandleToObjectHandle(TVMArrayHandle handle)
Definition: ndarray.h:434
TVMValue value_
The value.
Definition: packed_func.h:635
TVMRetValue()
default constructor
Definition: packed_func.h:802
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:51
TVMRetValue & operator=(int value)
Definition: packed_func.h:879
static void FFIClearAfterMove(ObjectRef *ref)
Clear the object ref data field without DecRef after we successfully moved the field.
Definition: object.h:592
TVMRetValue & operator=(TVMRetValue &&other)
Definition: packed_func.h:852
Byte array type used to pass in byte array When kTVMBytes is used as data type.
Definition: c_runtime_api.h:221
TVMRetValue & operator=(TVMByteArray value)
Definition: packed_func.h:904
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:309
void DecRef()
developer function, decrease reference counter.
Definition: object.h:801
bool IsInstance() const
Definition: object.h:829
TVMRetValue & operator=(const TVMArgValue &other)
Definition: packed_func.h:936
TVMRetValue & operator=(double value)
Definition: packed_func.h:859
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray &value) const
Definition: packed_func.h:1571
Runtime Array container types.
Definition: packed_func.h:38
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
Definition: packed_func.h:709
static bool Check(const Object *ptr)
Definition: packed_func.h:521
TVM_ALWAYS_INLINE void operator()(size_t i, T value) const
Definition: packed_func.h:1522
static Optional< T > From(const TVMRetValue &val)
Definition: packed_func.h:1975
A device-independent managed NDArray abstraction.
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: data_type.h:348
static String From(const TVMRetValue &val)
Definition: packed_func.h:1960
static ObjectPtr< Object > FFIDataFromHandle(TVMArrayHandle handle)
Construct NDArray's Data field from array handle in FFI.
Definition: ndarray.h:417
Type traits for runtime type check during FFI conversion.
Definition: packed_func.h:430
bool IsObjectRef() const
Definition: packed_func.h:1795
static Optional< T > From(const TVMArgValue &val)
Definition: packed_func.h:1971
Runtime primitive data type.
Definition: data_type.h:41
bool defined() const
Definition: object.h:544
TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const
Definition: packed_func.h:1551
Arguments into TVM functions.
Definition: packed_func.h:391
TVMRetValue(TVMRetValue &&other)
move constructor from another return value.
Definition: packed_func.h:807
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:331
PackedFunc(TCallable data)
Constructing a packed function from a callable type whose signature is consistent with PackedFunc ...
Definition: packed_func.h:151
T value() const
Definition: optional.h:92
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Definition: c_runtime_api.h:187
TypedPackedFunc(const FLambda &typed_lambda, std::string name)
construct from a lambda function with the same signature.
Definition: packed_func.h:286
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:1526
static std::string TypeName()
Definition: packed_func.h:500
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:340
runtime::PackedFunc.
Definition: object.h:74
TVMRetValue & operator=(Module m)
Definition: packed_func.h:920
Object container class that backs NDArray.
Definition: ndarray.h:286
TVMRetValue & operator=(PackedFunc f)
Definition: packed_func.h:924
PackedFuncObj(FCallPacked *f_call_pack)
Constructing a packed function object from a function pointer.
Definition: packed_func.h:103
TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef &value) const
Definition: packed_func.h:1592
const int * type_codes
Definition: packed_func.h:394
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:574
TVMRetValue & operator=(DLDataType t)
Definition: packed_func.h:889
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:178
TObjectRef AsObjectRef() const
Definition: packed_func.h:1826
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
TVM_ALWAYS_INLINE void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1567
Reference to string objects.
Definition: string.h:98
void(const PackedFuncObj *, TVMArgs, TVMRetValue *) FCallPacked
The internal callable function type.
Definition: packed_func.h:97
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:60
Object container class that backs PackedFunc.
Definition: packed_func.h:68
TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const
Definition: packed_func.h:1559
TVMPODValue_(TVMValue value, int type_code)
Definition: packed_func.h:632
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
TVM_ALWAYS_INLINE void operator()(size_t i, double value) const
Definition: packed_func.h:1531
TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1535
T * ptr() const
return handle as specific pointer type.
Definition: packed_func.h:617
DLDevice Device
Definition: ndarray.h:43
Definition: c_runtime_api.h:177
const Object * get() const
Definition: object.h:546
TVMRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:929
static TVMRetValue MoveFromCHost(TVMValue value, int type_code)
Construct a new TVMRetValue by moving from return value stored via C API.
Definition: packed_func.h:967
Object & operator=(const Object &other)
Definition: object.h:251
TVMRetValue & operator=(DLDevice value)
Definition: packed_func.h:884
TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object)
TStorage callable_
Type-erased filed for storing callable object.
Definition: packed_func.h:127
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Check if an object matches the template type and return the mismatched type if it exists...
Definition: packed_func.h:438
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:1981
Base class of all object reference.
Definition: object.h:511
Base container of module.
Definition: module.h:142
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1216
int size() const
Definition: packed_func.h:1208
std::string GetTypeKey() const
Definition: object.h:180
TVMRetValue & operator=(const DataType &other)
Definition: packed_func.h:894
Shared content of all specializations of hash map.
Definition: map.h:174
TVMRetValue & operator=(int64_t value)
Definition: packed_func.h:874
A managed object in the TVM runtime.
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue &value) const
Definition: packed_func.h:1539
static TVMArrayHandle FFIGetHandle(const ObjectRef &nd)
Get FFI Array handle from ndarray.
Definition: ndarray.h:422
void operator()(size_t i, const TVMRetValue &value) const
Definition: packed_func.h:1579
Runtime container of the functions generated by TVM, This is used to support dynamically link...
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:360
Module container of TVM.
Definition: module.h:79
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:362
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
static std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Runtime Map container types.
static String From(const TVMArgValue &val)
Definition: packed_func.h:1952
static std::string TypeName()
Definition: packed_func.h:531
Definition: c_runtime_api.h:178
int type_code() const
Definition: packed_func.h:610
TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index, const std::string *optional_name, FSig *f_sig)
move constructor from another return value.
Definition: packed_func.h:765
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:646
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:138
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
const TVMValue & value() const
Definition: packed_func.h:691
FCallPacked * f_call_packed_
Internal callable function pointer used to call the packed function.
Definition: packed_func.h:109
TVMRetValue operator()(Args &&... args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1613
PackedFuncObj()=delete
Delete the default constructor explicitly.
Definition: packed_func.h:1517
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
TVMArgsSetter(TVMValue *values, int *type_codes)
Definition: packed_func.h:1519
static constexpr const char * _type_key
Definition: packed_func.h:78
Derived object class for constructing PackedFuncObj.
Definition: packed_func.h:114
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:234
TVMRetValue(const TVMRetValue &other)
Definition: packed_func.h:828
PackedFunc(std::nullptr_t null)
Constructor from null.
Definition: packed_func.h:141
const TVMValue & value() const
Definition: packed_func.h:976
Definition: c_runtime_api.h:180
TVMArgValue(TVMValue value, int type_code)
constructor
Definition: packed_func.h:655
TVMRetValue & operator=(const TVMRetValue &other)
Definition: packed_func.h:932
#define TVM_CHECK_TYPE_CODE(CODE, T)
Definition: packed_func.h:422
TVM_ALWAYS_INLINE void operator()(size_t i, void *value) const
Definition: packed_func.h:1543
std::string() FSig
Using static function to output TypedPackedFunc signature.
Definition: packed_func.h:186
Definition: c_runtime_api.h:186
PackedFuncSubObj(TCallable callable)
Derived object class for constructing PackedFuncObj.
Definition: packed_func.h:124
TVMRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:864
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:472
runtime::DataType DataType
Definition: data_type.h:398
Definition: c_runtime_api.h:185
TVMRetValue & operator=(NDArray other)
Definition: packed_func.h:908
Type trait to specify special value conversion rules from TVMArgValue and TVMRetValue.
Definition: packed_func.h:1096
static TObjectRef From(const TVMRetValue &val)
Convert a TObjectRef from a return value.
Definition: packed_func.h:1108
Definition: c_runtime_api.h:181
Definition: packed_func.h:62
TypedPackedFunc()
default constructor
Definition: packed_func.h:232
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with additional context ...
Definition: packed_func.h:754
TVMRetValue & operator=(void *value)
Definition: packed_func.h:869
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1576
TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1555
TVMPODValue_()
Definition: packed_func.h:631