24 #ifndef TVM_RUNTIME_PACKED_FUNC_H_ 25 #define TVM_RUNTIME_PACKED_FUNC_H_ 31 #include <tvm/runtime/logging.h> 41 #include <type_traits> 46 #ifndef TVM_RUNTIME_HEADER_ONLY 47 #define TVM_RUNTIME_HEADER_ONLY 0 56 class TVMMovableArgValueWithContext_;
59 template <
typename FType>
61 template <
typename TSignature>
78 static constexpr
const char*
_type_key =
"runtime.PackedFunc";
85 template <
class TPackedFuncSubObj>
113 template <
class TCallable>
115 using TStorage =
typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type;
147 template <
typename TCallable,
148 typename = std::enable_if_t<
149 std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
150 !std::is_base_of<TCallable, PackedFunc>::value>>
153 data_ = make_object<ObjType>(std::forward<TCallable>(data));
169 template <
typename... Args>
170 inline TVMRetValue operator()(Args&&... args)
const;
// Null comparison: yields true when the held `data_` pointer equals nullptr.
// The `null` argument only selects this overload; its value is never read.
178 bool operator==(std::nullptr_t null)
const {
return data_ ==
nullptr; }
// Non-null comparison: yields true when the held `data_` pointer differs from nullptr.
// The `null` argument only selects this overload; its value is never read.
180 bool operator!=(std::nullptr_t null)
const {
return data_ !=
nullptr; }
191 template <
typename FType>
226 template <
typename R,
typename... Args>
284 template <
typename FLambda,
typename =
typename std::enable_if<std::is_convertible<
285 FLambda, std::function<R(Args...)>>::value>::type>
287 this->AssignTypedLambda(typed_lambda, name);
307 template <
typename FLambda,
typename =
typename std::enable_if<std::is_convertible<
308 FLambda, std::function<R(Args...)>>::value>::type>
310 this->AssignTypedLambda(typed_lambda);
328 template <
typename FLambda,
typename =
typename std::enable_if<
329 std::is_convertible<FLambda,
330 std::function<R(Args...)>>::value>::type>
332 this->AssignTypedLambda(typed_lambda);
349 TVM_ALWAYS_INLINE R operator()(Args... args)
const;
// Null comparison: yields true when the `packed_` member compares equal to nullptr.
// The `null` argument only selects this overload; its value is never read.
360 bool operator==(std::nullptr_t null)
const {
return packed_ ==
nullptr; }
// Non-null comparison: yields true when the `packed_` member compares unequal to nullptr.
// The `null` argument only selects this overload; its value is never read.
362 bool operator!=(std::nullptr_t null)
const {
return packed_ !=
nullptr; }
376 template <
typename FLambda>
377 inline void AssignTypedLambda(FLambda flambda, std::string name);
386 template <
typename FLambda>
387 inline void AssignTypedLambda(FLambda flambda);
403 : values(values), type_codes(type_codes), num_args(num_args) {}
405 inline int size()
const;
422 #define TVM_CHECK_TYPE_CODE(CODE, T) \ 423 ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) 429 template <
typename T>
439 using ContainerType =
typename T::ContainerType;
440 if (ptr ==
nullptr) {
441 if (T::_type_is_nullable) {
459 using ContainerType =
typename T::ContainerType;
460 if (ptr ==
nullptr)
return T::_type_is_nullable;
464 using ContainerType =
typename T::ContainerType;
465 return ContainerType::_type_key;
470 template <
typename T>
473 if (ptr ==
nullptr) {
480 for (
size_t i = 0; i < n->size(); i++) {
484 return String(
"Array[index " + std::to_string(i) +
": " + check_subtype.
value() +
"]");
490 if (ptr ==
nullptr)
return true;
502 template <
typename K,
typename V>
505 if (ptr ==
nullptr)
return NullOpt;
508 for (
const auto& kv : *n) {
511 if (key_type.
defined() || value_type.defined()) {
512 std::string key_name =
514 std::string value_name = value_type.defined() ? std::string(value_type.value())
516 return String(
"Map[" + key_name +
", " + value_name +
"]");
522 if (ptr ==
nullptr)
return true;
525 for (
const auto& kv : *n) {
543 operator double()
const {
547 if (type_code_ == kDLInt) {
548 return static_cast<double>(value_.v_int64);
551 return value_.v_float64;
553 operator int64_t()
const {
555 return value_.v_int64;
557 operator uint64_t()
const {
559 return value_.v_int64;
561 operator int()
const {
565 return static_cast<int>(value_.v_int64);
567 operator bool()
const {
569 return value_.v_int64 != 0;
571 operator void*()
const {
575 return value_.v_handle;
577 operator DLTensor*()
const {
579 return static_cast<DLTensor*
>(value_.v_handle);
582 LOG(FATAL) <<
"Expected " 608 return value_.v_device;
616 template <
typename T>
618 return static_cast<T*
>(value_.v_handle);
621 template <
typename TObjectRef,
622 typename =
typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
623 inline bool IsObjectRef()
const;
624 template <
typename TObjectRef>
625 inline TObjectRef AsObjectRef()
const;
657 using TVMPODValue_::operator double;
658 using TVMPODValue_::operator int64_t;
659 using TVMPODValue_::operator uint64_t;
660 using TVMPODValue_::operator int;
661 using TVMPODValue_::operator bool;
662 using TVMPODValue_::operator
void*;
663 using TVMPODValue_::operator DLTensor*;
664 using TVMPODValue_::operator
NDArray;
665 using TVMPODValue_::operator
Device;
666 using TVMPODValue_::operator
Module;
672 operator std::string()
const {
677 return std::string(arr->
data, arr->
size);
678 }
else if (type_code_ ==
kTVMStr) {
679 return std::string(value_.v_str);
681 ICHECK(IsObjectRef<tvm::runtime::String>())
684 return AsObjectRef<tvm::runtime::String>().
operator std::string();
687 template <
typename FType>
693 template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
694 inline operator T()
const;
695 inline operator DLDataType()
const;
713 using TVMPODValue_::operator double;
714 using TVMPODValue_::operator int64_t;
715 using TVMPODValue_::operator uint64_t;
716 using TVMPODValue_::operator int;
717 using TVMPODValue_::operator bool;
718 using TVMPODValue_::operator
void*;
719 using TVMPODValue_::operator DLTensor*;
720 using TVMPODValue_::operator
NDArray;
721 using TVMPODValue_::operator
Device;
722 using TVMPODValue_::operator
Module;
// Conversion to std::string: forwards to the std::string conversion operator of
// the value returned by AsArgValue(), so any checks live in that operator.
725 operator std::string()
const {
return AsArgValue().operator std::string(); }
726 template <
typename FType>
// Conversion to DLDataType: forwards to the DLDataType conversion operator of
// the value returned by AsArgValue(), so any checks live in that operator.
730 operator DLDataType()
const {
return AsArgValue().operator DLDataType(); }
738 template <
typename T,
739 typename =
typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
740 inline operator T()
const;
766 const std::string* optional_name,
FSig* f_sig)
767 : value_(value, type_code),
768 arg_index_(arg_index),
769 optional_name_(optional_name),
772 template <
typename T>
776 }
catch (dmlc::Error& e) {
777 LOG(FATAL) <<
"In function " << (optional_name_ ==
nullptr ?
"<anonymous>" : *optional_name_)
778 << (f_sig_ ==
nullptr ?
"" : (*f_sig_)()) <<
": error while converting argument " 779 << arg_index_ <<
": " << e.what();
787 const std::string* optional_name_;
814 using TVMPODValue_::operator double;
815 using TVMPODValue_::operator int64_t;
816 using TVMPODValue_::operator uint64_t;
817 using TVMPODValue_::operator int;
818 using TVMPODValue_::operator bool;
819 using TVMPODValue_::operator
void*;
820 using TVMPODValue_::operator DLTensor*;
821 using TVMPODValue_::operator
Device;
822 using TVMPODValue_::operator
NDArray;
823 using TVMPODValue_::operator
Module;
830 operator std::string()
const {
834 return *ptr<std::string>();
837 return *ptr<std::string>();
839 operator DLDataType()
const {
844 return value_.v_type;
847 template <
typename FType>
854 value_ = other.value_;
855 type_code_ = other.type_code_;
860 this->SwitchToPOD(kDLFloat);
861 value_.v_float64 = value;
866 value_.v_handle = value;
871 value_.v_handle = value;
875 this->SwitchToPOD(kDLInt);
876 value_.v_int64 = value;
880 this->SwitchToPOD(kDLInt);
881 value_.v_int64 = value;
886 value_.v_device = value;
896 this->SwitchToPOD(kDLInt);
897 value_.v_int64 = value;
901 this->SwitchToClass(
kTVMStr, value);
909 if (other.
data_ !=
nullptr) {
916 value_.v_handle =
nullptr;
928 template <
typename FType>
957 *ret_type_code = type_code_;
979 <<
"TVMRetValue.value can only be used for POD data";
983 template <
typename TObjectRef,
984 typename =
typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
986 template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
987 inline operator T()
const;
990 template <
typename T>
991 void Assign(
const T& other) {
992 switch (other.type_code()) {
994 SwitchToClass<std::string>(
kTVMStr, other);
998 SwitchToClass<std::string>(
kTVMBytes, other);
1002 *
this = other.operator PackedFunc();
1006 *
this = other.operator Module();
1010 *
this = other.operator NDArray();
1016 GetObjectPtr<Object>(static_cast<Object*>(other.value_.v_handle)));
1024 SwitchToPOD(other.type_code());
1025 value_ = other.value_;
1031 void SwitchToPOD(
int type_code) {
1032 if (type_code_ != type_code) {
1034 type_code_ = type_code;
1037 template <
typename T>
1038 void SwitchToClass(
int type_code, T v) {
1039 if (type_code_ != type_code) {
1041 type_code_ = type_code;
1042 value_.v_handle =
new T(v);
1044 *
static_cast<T*
>(value_.v_handle) = v;
1048 if (other.data_ !=
nullptr) {
1050 type_code_ = type_code;
1052 value_.v_handle = other.data_;
1053 other.data_ =
nullptr;
1056 value_.v_handle =
nullptr;
1061 switch (type_code_) {
1064 delete ptr<std::string>();
1095 template <
typename TObjectRef>
1130 #define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ 1132 TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ 1133 int* out_type_code, void* resource_handle); \ 1134 int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ 1135 int* out_type_code, void* resource_handle) { \ 1137 ::tvm::runtime::TVMRetValue rv; \ 1138 Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ 1139 rv.MoveToCHost(out_value, out_type_code); \ 1141 } catch (const ::std::exception& _except_) { \ 1142 TVMAPISetLastError(_except_.what()); \ 1183 #define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ 1185 TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ 1186 int* out_type_code, void* resource_handle) { \ 1188 auto f = Function; \ 1189 using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \ 1190 ::tvm::runtime::TVMRetValue rv; \ 1191 ::tvm::runtime::detail::unpack_call_by_signature<FType>::run( \ 1192 f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ 1193 rv.MoveToCHost(out_value, out_type_code); \ 1195 } catch (const ::std::exception& _except_) { \ 1196 TVMAPISetLastError(_except_.what()); \ 1203 ICHECK_LT(i, num_args) <<
"not enough argument passed, " << num_args <<
" passed" 1204 <<
" but request arg[" << i <<
"].";
1210 template <
class TPackedFuncSubObj>
1213 (
static_cast<const TPackedFuncSubObj*
>(obj))->callable_(args, rv);
1217 (*f_call_packed_)(
this, args, rv);
1221 (
static_cast<PackedFuncObj*
>(data_.get()))->CallPacked(args, rv);
1226 switch (type_code) {
1242 return "ArrayHandle";
1244 return "DLDataType";
1248 return "FunctionHandle";
1250 return "ModuleHandle";
1252 return "NDArrayContainer";
1256 return "ObjectRValueRefArg";
1258 LOG(FATAL) <<
"unknown type_code=" <<
static_cast<int>(type_code);
1265 template <
bool stop, std::
size_t I,
typename F>
1266 struct for_each_dispatcher {
1267 template <
typename T,
typename... Args>
1268 static void run(
const F& f, T&& value, Args&&... args) {
1269 f(I, std::forward<T>(value));
1270 for_each_dispatcher<
sizeof...(Args) == 0, (I + 1), F>::run(f, std::forward<Args>(args)...);
1274 template <std::
size_t I,
typename F>
1275 struct for_each_dispatcher<true, I, F> {
1276 static void run(
const F& f) {}
1279 template <
typename F,
typename... Args>
1280 inline void for_each(
const F& f, Args&&... args) {
1281 for_each_dispatcher<
sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...);
1284 namespace parameter_pack {
1286 template <
typename... EnumArgs>
1287 struct EnumeratedParamPack {
1289 template <
template <
size_t i,
typename TArgument>
class Functor,
typename... ExtraParams>
1290 static void F(ExtraParams&&... extra_params) {
1291 using TExpander =
int[];
1294 (Functor<EnumArgs::i, typename EnumArgs::T>::F(extra_params...), 0)...,
1300 template <
typename... Args>
1301 struct EnumerateImpl {
1303 template <
size_t _i,
typename _T>
1305 static const constexpr
size_t i = _i;
1309 template <
typename...>
1312 template <std::size_t...
id>
1313 struct Zipper<std::integer_sequence<std::size_t, id...>> {
1314 using T = EnumeratedParamPack<Item<id, Args>...>;
1318 using T =
typename Zipper<std::index_sequence_for<Args...>>::T;
1321 template <
typename... Args>
1322 using Enumerate =
typename EnumerateImpl<Args...>::T;
1324 template <
typename... Args>
1326 template <
template <
size_t i,
typename TArgument>
class Functor,
typename... ExtraParams>
1327 static void InvokeWithoutArg(ExtraParams&&... extra_params) {
1328 Enumerate<Args...>::Invoke::template F<Functor, ExtraParams...>(
1329 std::forward<ExtraParams>(extra_params)...);
1339 template <
typename T>
1340 struct func_signature_helper {
1344 template <
typename T,
typename R,
typename... Args>
1345 struct func_signature_helper<R (T::*)(Args...)> {
1346 using FType = R(Args...);
1347 using ParamType = parameter_pack::ParamPack<Args...>;
1349 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1352 template <
typename T,
typename R,
typename... Args>
1353 struct func_signature_helper<R (T::*)(Args...) const> {
1354 using FType = R(Args...);
1355 using ParamType = parameter_pack::ParamPack<Args...>;
1357 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1364 template <
typename T>
1365 struct function_signature {
1366 using FType =
typename func_signature_helper<decltype(&T::operator())>::FType;
1367 using ParamType =
typename func_signature_helper<decltype(&T::operator())>::ParamType;
1368 using RetType =
typename func_signature_helper<decltype(&T::operator())>::RetType;
1372 template <
typename R,
typename... Args>
1373 struct function_signature<R(Args...)> {
1374 using FType = R(Args...);
1375 using ParamType = parameter_pack::ParamPack<Args...>;
1377 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1381 template <
typename R,
typename... Args>
1382 struct function_signature<R (*)(Args...)> {
1383 using FType = R(Args...);
1384 using ParamType = detail::parameter_pack::ParamPack<Args...>;
1386 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1389 template <
typename TSignature>
1392 namespace type2str {
1394 template <
typename T>
1395 struct TypeSimplifier;
1397 template <
typename T>
1399 template <typename = std::enable_if_t<std::is_base_of<ObjectRef, T>::value>>
1400 static std::string v() {
1401 return T::ContainerType::_type_key;
1405 struct Type2Str<int> {
1406 static std::string v() {
return "int"; }
1409 struct Type2Str<double> {
1410 static std::string v() {
return "double"; }
1413 struct Type2Str<int64_t> {
1414 static std::string v() {
return "int64_t"; }
1417 struct Type2Str<uint64_t> {
1418 static std::string v() {
return "uint64_t"; }
1421 struct Type2Str<bool> {
1422 static std::string v() {
return "bool"; }
1425 struct Type2Str<void> {
1426 static std::string v() {
return "void"; }
1429 struct Type2Str<std::basic_string<char>> {
1430 static std::string v() {
return "basic_string<char>"; }
1432 template <
typename K,
typename V>
1433 struct Type2Str<Map<K, V>> {
1434 static std::string v() {
1435 return "Map<" + TypeSimplifier<K>::v() +
", " + TypeSimplifier<V>::v() +
">";
1439 struct Type2Str<DLDevice> {
1440 static std::string v() {
return "DLDevice"; }
1443 struct Type2Str<DLTensor> {
1444 static std::string v() {
return "DLTensor"; }
1447 struct Type2Str<DataType> {
1448 static std::string v() {
return "DataType"; }
1451 struct Type2Str<DLDataType> {
1452 static std::string v() {
return "DLDataType"; }
1455 struct Type2Str<TVMRetValue> {
1456 static std::string v() {
return "TVMRetValue"; }
1459 struct Type2Str<TVMArgValue> {
1460 static std::string v() {
return "TVMArgValue"; }
1462 template <
typename FType>
1463 struct Type2Str<TypedPackedFunc<FType>> {
1466 template <
typename T>
1467 struct Type2Str<Array<T>> {
1468 static std::string v() {
return "Array<" + TypeSimplifier<T>::v() +
">"; }
1475 template <
typename T>
1476 struct TypeSimplifier {
1477 static std::string v() {
1478 using U =
typename std::remove_cv<
1479 typename std::remove_reference<typename std::remove_pointer<T>::type>::type>::type;
1480 return (std::is_const<T>::value ?
"const " :
"") + Type2Str<U>::v() +
1481 (std::is_pointer<T>::value ?
"*" :
"") + (std::is_reference<T>::value ?
"&" :
"");
1491 template <
typename TSignature>
1493 using ParamType =
typename TSignature::ParamType;
1494 using RetType =
typename TSignature::RetType;
1496 template <
size_t i,
typename TArgument>
1497 struct PrintParamType {
1498 static void F(std::ostream& os) {
1499 os << (i == 0 ?
"" :
", ") << i <<
": " << type2str::TypeSimplifier<TArgument>::v();
1503 static std::string F() {
1504 std::ostringstream oss;
1506 ParamType::template InvokeWithoutArg<PrintParamType>(oss);
1507 oss <<
") -> " << type2str::TypeSimplifier<RetType>::v();
1518 template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
1520 values_[i].v_int64 =
static_cast<int64_t
>(value);
1521 type_codes_[i] = kDLInt;
1523 TVM_ALWAYS_INLINE
void operator()(
size_t i, uint64_t value)
const {
1524 values_[i].v_int64 =
static_cast<int64_t
>(value);
1526 type_codes_[i] = kDLInt;
1528 TVM_ALWAYS_INLINE
void operator()(
size_t i,
double value)
const {
1529 values_[i].v_float64 = value;
1530 type_codes_[i] = kDLFloat;
1532 TVM_ALWAYS_INLINE
void operator()(
size_t i, std::nullptr_t value)
const {
1533 values_[i].v_handle = value;
1537 values_[i] = value.
value_;
1541 values_[i].v_handle = value;
1544 TVM_ALWAYS_INLINE
void operator()(
size_t i, DLTensor* value)
const {
1545 values_[i].v_handle = value;
1549 values_[i].v_device = value;
1552 TVM_ALWAYS_INLINE
void operator()(
size_t i, DLDataType value)
const {
1553 values_[i].v_type = value;
1557 operator()(i, dtype.operator DLDataType());
1559 TVM_ALWAYS_INLINE
void operator()(
size_t i,
const char* value)
const {
1560 values_[i].v_str = value;
1564 TVM_ALWAYS_INLINE
void operator()(
size_t i,
const std::string& value)
const {
1565 values_[i].v_str = value.c_str();
1569 values_[i].v_handle =
const_cast<TVMByteArray*
>(&value);
1572 template <
typename FType>
1574 operator()(i, value.packed());
1578 values_[i].v_str = value.
ptr<std::string>()->c_str();
1582 values_[i] = value.
value_;
1587 template <
typename TObjectRef,
1588 typename =
typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
1589 TVM_ALWAYS_INLINE
void operator()(
size_t i,
const TObjectRef& value)
const {
1590 this->SetObject(i, value);
1593 template <
typename TObjectRef,
1594 typename =
typename std::enable_if<std::is_base_of<
1595 ObjectRef,
typename std::remove_reference<TObjectRef>::type>::value>::type>
1596 TVM_ALWAYS_INLINE
void operator()(
size_t i, TObjectRef&& value)
const {
1597 this->SetObject(i, std::forward<TObjectRef>(value));
1601 template <
typename TObjectRef>
1602 inline void SetObject(
size_t i, TObjectRef&& value)
const;
1609 template <
typename... Args>
1611 const int kNumArgs =
sizeof...(Args);
1612 const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1614 int type_codes[kArraySize];
1615 detail::for_each(
TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
1618 ->CallPacked(
TVMArgs(values, type_codes, kNumArgs), &rv);
1623 template <
typename R,
int nleft,
int index,
typename F>
1624 struct unpack_call_dispatcher {
1625 template <
typename... Args>
1626 TVM_ALWAYS_INLINE
static void run(
const std::string* optional_name,
FSig* f_sig,
const F& f,
1628 Args&&... unpacked_args) {
1631 unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
1632 optional_name, f_sig, f, args_pack, rv, std::forward<Args>(unpacked_args)...,
1634 optional_name, f_sig));
1638 template <
typename R,
int index,
typename F>
1639 struct unpack_call_dispatcher<R, 0, index, F> {
1640 template <
typename... Args>
1641 TVM_ALWAYS_INLINE
static void run(
const std::string* optional_name,
FSig* f_sig,
const F& f,
1643 Args&&... unpacked_args) {
1644 using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
1645 if (std::is_same<RetType, R>::value) {
1646 *rv = f(std::forward<Args>(unpacked_args)...);
1648 *rv = R(f(std::forward<Args>(unpacked_args)...));
1653 template <
int index,
typename F>
1654 struct unpack_call_dispatcher<void, 0, index, F> {
1655 template <
typename... Args>
1656 TVM_ALWAYS_INLINE
static void run(
const std::string* optional_name,
FSig* f_sig,
const F& f,
1658 Args&&... unpacked_args) {
1659 f(std::forward<Args>(unpacked_args)...);
1663 template <
typename R,
int nargs,
typename F>
1664 TVM_ALWAYS_INLINE
void unpack_call(
const std::string* optional_name,
const F& f,
1666 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<F>>::F;
1667 CHECK_EQ(nargs, args.
size()) <<
"Function " 1668 << (optional_name ==
nullptr ?
"<anonymous>" : *optional_name)
1669 << (f_sig ==
nullptr ?
"" : (*f_sig)()) <<
" expects " << nargs
1670 <<
" arguments but " << args.
size() <<
" were provided";
1671 unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, args, rv);
1674 template <
typename FType>
1675 struct unpack_call_by_signature {};
1677 template <
typename R,
typename... Args>
1678 struct unpack_call_by_signature<R(Args...)> {
1679 template <
typename F>
1680 TVM_ALWAYS_INLINE
static void run(
const F& f,
const TVMArgs& args,
TVMRetValue* rv) {
1681 unpack_call<R,
sizeof...(Args)>(
nullptr, f, args, rv);
1685 template <
typename R,
typename... Args>
1686 TVM_ALWAYS_INLINE R call_packed(
const PackedFunc& pf, Args&&... args) {
1687 return R(pf(std::forward<Args>(args)...));
1690 template <
typename R>
1691 struct typed_packed_call_dispatcher {
1692 template <
typename... Args>
1693 TVM_ALWAYS_INLINE
static R run(
const PackedFunc& pf, Args&&... args) {
1694 return pf(std::forward<Args>(args)...);
1699 struct typed_packed_call_dispatcher<void> {
1700 template <
typename... Args>
1701 TVM_ALWAYS_INLINE
static void run(
const PackedFunc& pf, Args&&... args) {
1702 pf(std::forward<Args>(args)...);
1707 template <
typename R,
typename... Args>
1710 template <
typename R,
typename... Args>
1714 template <
typename R,
typename... Args>
1718 template <
typename R,
typename... Args>
1722 template <
typename R,
typename... Args>
1723 template <
typename FType>
1724 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, std::string name) {
1725 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1727 if (args.
size() !=
sizeof...(Args)) {
1728 LOG(FATAL) <<
"Function " << name << (f_sig ==
nullptr ?
"" : (*f_sig)()) <<
" expects " 1729 <<
sizeof...(Args) <<
" arguments, but " << args.
size() <<
" were provided.";
1731 detail::unpack_call<R,
sizeof...(Args)>(&name, flambda, args, rv);
1735 template <
typename R,
typename... Args>
1736 template <
typename FType>
1738 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1740 if (args.
size() !=
sizeof...(Args)) {
1741 LOG(FATAL) <<
"Function <anonymous> " << (*f_sig)() <<
" expects " <<
sizeof...(Args)
1742 <<
" arguments, but " << args.
size() <<
" were provided.";
1744 detail::unpack_call<R,
sizeof...(Args)>(
nullptr, flambda, args, rv);
1748 template <
typename R,
typename... Args>
1750 return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...);
1758 template <
typename T>
1759 inline void TVMArgsSetter::SetObject(
size_t i, T&& value)
const {
1760 using ContainerType =
typename std::remove_reference<T>::type::ContainerType;
1761 if (value.defined()) {
1762 Object* ptr = value.data_.data_;
1763 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1764 (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1768 }
else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1769 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1771 values_[i].v_handle = ptr;
1773 }
else if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
1774 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1776 values_[i].v_handle = ptr;
1778 }
else if (std::is_rvalue_reference<decltype(value)>::value) {
1779 values_[i].v_handle =
const_cast<Object**
>(&(value.data_.data_));
1782 values_[i].v_handle = value.data_.data_;
1787 values_[i].v_handle =
nullptr;
1791 template <
typename TObjectRef,
typename>
1793 using ContainerType =
typename TObjectRef::ContainerType;
1795 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1800 if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1802 static_cast<Object*
>(value_.v_handle)->IsInstance<ContainerType>();
1804 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
1806 static_cast<Object*
>(value_.v_handle)->IsInstance<ContainerType>();
1812 return (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1814 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1816 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1822 template <
typename TObjectRef>
1824 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
1825 "Conversion only works for ObjectRef");
1826 using ContainerType =
typename TObjectRef::ContainerType;
1829 CHECK(TObjectRef::_type_is_nullable)
1830 <<
"Expect a not null value of " << ContainerType::_type_key;
1834 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1839 CHECK(data->IsInstance<ContainerType>())
1840 <<
"Expected " << ContainerType::_type_key <<
" but got " << data->GetTypeKey();
1841 return TObjectRef(data);
1843 if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1847 CHECK(data->IsInstance<ContainerType>())
1848 <<
"Expected " << ContainerType::_type_key <<
" but got " << data->GetTypeKey();
1849 return TObjectRef(data);
1851 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
1855 CHECK(data->IsInstance<ContainerType>())
1856 <<
"Expected " << ContainerType::_type_key <<
" but got " << data->GetTypeKey();
1857 return TObjectRef(data);
1864 <<
", but got " << checked_type.
value();
1865 return TObjectRef(GetObjectPtr<Object>(ptr));
1870 <<
", but got " << checked_type.
value();
1871 return TObjectRef(GetObjectPtr<Object>(ptr));
1872 }
else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1877 return TObjectRef(data);
1878 }
else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1881 return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
1882 }
else if (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1885 return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
1892 template <
typename TObjectRef,
typename>
1894 using ContainerType =
typename TObjectRef::ContainerType;
1895 const Object* ptr = other.get();
1896 if (ptr !=
nullptr) {
1897 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1898 (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1900 return operator=(
NDArray(std::move(other.data_)));
1902 if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1903 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1905 return operator=(
Module(std::move(other.data_)));
1910 value_.v_handle =
nullptr;
1915 template <
typename T,
typename>
1916 inline TVMArgValue::operator T()
const {
1920 template <
typename T,
typename>
1921 inline TVMMovableArgValue_::operator T()
const {
1923 auto** ref =
static_cast<Object**
>(value_.v_handle);
1932 template <
typename T,
typename>
1933 inline TVMRetValue::operator T()
const {
1938 return (*this)->GetFunction(name, query_imports);
1961 template <
typename T>
1977 inline TVMArgValue::operator DLDataType()
const {
1990 return value_.v_type;
1997 #endif // TVM_RUNTIME_PACKED_FUNC_H_ static TObjectRef From(const TVMArgValue &val)
Convert a TObjectRef from an argument value.
Definition: packed_func.h:1102
const char * ArgTypeCode2Str(int type_code)
Convert argument type code to string.
Definition: packed_func.h:1225
int num_args
Definition: packed_func.h:395
TVMArgValue()
default constructor
Definition: packed_func.h:649
static bool Check(const Object *ptr)
Definition: packed_func.h:489
TVMRetValue & operator=(std::string value)
Definition: packed_func.h:900
array node content in array
Definition: array.h:38
Return value container. Unlike TVMArgValue, which only holds a reference and does not delete the underlyi...
Definition: packed_func.h:799
PackedFunc GetFunction(const std::string &name, bool query_imports=false)
Get packed function from current module by name.
Definition: packed_func.h:1937
static void FFIDecRef(TVMArrayHandle handle)
DecRef resource managed by an FFI array handle.
Definition: ndarray.h:429
TVMRetValue & operator=(TVMMovableArgValue_ &&other)
Definition: packed_func.h:940
TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef &&value) const
Definition: packed_func.h:1596
void MoveToCHost(TVMValue *ret_value, int *ret_type_code)
Move the value back to the front-end via the C API. This marks the current container as null. The managed resources are moved to the front-end. The front-end should take charge of managing them.
Definition: packed_func.h:953
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
Take the minimum of two values.
A custom smart pointer for Object.
Definition: object.h:358
TVMArgs(const TVMValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:402
A PackedFunc wrapper that provides a typed function signature. It is backed by a PackedFunc internally...
Definition: packed_func.h:227
TVM_ALWAYS_INLINE void operator()(size_t i, const char *value) const
Definition: packed_func.h:1559
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:180
TVMRetValue & operator=(bool value)
Definition: packed_func.h:895
int type_code_
the type code
Definition: packed_func.h:637
Internal base class to handle conversion to POD values.
Definition: packed_func.h:541
static constexpr const uint32_t _type_index
Definition: packed_func.h:77
std::string DLDataType2String(DLDataType t)
Convert a TVM type to a string.
Definition: data_type.h:332
void * v_handle
Definition: c_runtime_api.h:147
const PackedFunc & packed() const
Definition: packed_func.h:358
Definition: c_runtime_api.h:124
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:504
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1220
Definition: c_runtime_api.h:119
TVMMovableArgValue_(TVMValue value, int type_code)
Definition: packed_func.h:711
TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor *value) const
Definition: packed_func.h:1544
~TVMRetValue()
destructor
Definition: packed_func.h:812
Definition: c_runtime_api.h:115
const TVMValue * values
Definition: packed_func.h:393
Definition: c_runtime_api.h:120
TVMArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:1202
size_t size
Definition: c_runtime_api.h:159
static std::string TypeName()
Definition: packed_func.h:463
static bool Check(const Object *ptr)
Check if an object matches the template type.
Definition: packed_func.h:458
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:144
base class of all object containers.
Definition: object.h:167
Definition: c_runtime_api.h:118
const char * data
Definition: c_runtime_api.h:158
Object * TVMArrayHandleToObjectHandle(TVMArrayHandle handle)
Definition: ndarray.h:433
TVMValue value_
The value.
Definition: packed_func.h:635
TVMRetValue()
default constructor
Definition: packed_func.h:802
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:59
TVMRetValue & operator=(int value)
Definition: packed_func.h:879
static void FFIClearAfterMove(ObjectRef *ref)
Clear the object ref data field without DecRef after we successfully moved the field.
Definition: object.h:592
TVMRetValue & operator=(TVMRetValue &&other)
Definition: packed_func.h:852
Byte array type used to pass in a byte array when kTVMBytes is used as the data type.
Definition: c_runtime_api.h:157
TVMRetValue & operator=(TVMByteArray value)
Definition: packed_func.h:904
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:309
void DecRef()
developer function, decrease reference counter.
Definition: object.h:801
bool IsInstance() const
Definition: object.h:829
TVMRetValue & operator=(const TVMArgValue &other)
Definition: packed_func.h:936
TVMRetValue & operator=(double value)
Definition: packed_func.h:859
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray &value) const
Definition: packed_func.h:1568
Runtime Array container types.
Definition: packed_func.h:38
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
Definition: packed_func.h:709
static bool Check(const Object *ptr)
Definition: packed_func.h:521
TVM_ALWAYS_INLINE void operator()(size_t i, T value) const
Definition: packed_func.h:1519
static Optional< T > From(const TVMRetValue &val)
Definition: packed_func.h:1967
A device-independent managed NDArray abstraction.
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: data_type.h:339
static String From(const TVMRetValue &val)
Definition: packed_func.h:1952
static ObjectPtr< Object > FFIDataFromHandle(TVMArrayHandle handle)
Construct NDArray's Data field from array handle in FFI.
Definition: ndarray.h:416
Type traits for runtime type check during FFI conversion.
Definition: packed_func.h:430
bool IsObjectRef() const
Definition: packed_func.h:1792
static Optional< T > From(const TVMArgValue &val)
Definition: packed_func.h:1963
Runtime primitive data type.
Definition: data_type.h:41
bool defined() const
Definition: object.h:544
TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const
Definition: packed_func.h:1548
Arguments into TVM functions.
Definition: packed_func.h:391
TVMRetValue(TVMRetValue &&other)
move constructor from another return value.
Definition: packed_func.h:807
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:331
PackedFunc(TCallable data)
Constructing a packed function from a callable type whose signature is consistent with PackedFunc ...
Definition: packed_func.h:151
T value() const
Definition: optional.h:92
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:270
Definition: c_runtime_api.h:123
TypedPackedFunc(const FLambda &typed_lambda, std::string name)
construct from a lambda function with the same signature.
Definition: packed_func.h:286
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:1523
static std::string TypeName()
Definition: packed_func.h:500
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:340
runtime::PackedFunc.
Definition: object.h:74
TVMRetValue & operator=(Module m)
Definition: packed_func.h:920
Object container class that backs NDArray.
Definition: ndarray.h:294
TVMRetValue & operator=(PackedFunc f)
Definition: packed_func.h:924
PackedFuncObj(FCallPacked *f_call_pack)
Constructing a packed function object from a function pointer.
Definition: packed_func.h:103
TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef &value) const
Definition: packed_func.h:1589
const int * type_codes
Definition: packed_func.h:394
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:574
TVMRetValue & operator=(DLDataType t)
Definition: packed_func.h:889
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:178
TObjectRef AsObjectRef() const
Definition: packed_func.h:1823
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
TVM_ALWAYS_INLINE void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1564
Reference to string objects.
Definition: string.h:124
void(const PackedFuncObj *, TVMArgs, TVMRetValue *) FCallPacked
The internal callable function type.
Definition: packed_func.h:97
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:60
Object container class that backs PackedFunc.
Definition: packed_func.h:68
TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const
Definition: packed_func.h:1556
TVMPODValue_(TVMValue value, int type_code)
Definition: packed_func.h:632
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
TVM_ALWAYS_INLINE void operator()(size_t i, double value) const
Definition: packed_func.h:1528
TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1532
T * ptr() const
return handle as specific pointer type.
Definition: packed_func.h:617
DLDevice Device
Definition: ndarray.h:43
Definition: c_runtime_api.h:113
const Object * get() const
Definition: object.h:546
TVMRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:929
static TVMRetValue MoveFromCHost(TVMValue value, int type_code)
Construct a new TVMRetValue by moving from return value stored via C API.
Definition: packed_func.h:967
Object & operator=(const Object &other)
Definition: object.h:251
TVMRetValue & operator=(DLDevice value)
Definition: packed_func.h:884
TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object)
TStorage callable_
Type-erased field for storing the callable object.
Definition: packed_func.h:127
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Check if an object matches the template type and return the mismatched type if it exists...
Definition: packed_func.h:438
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:1973
Base class of all object reference.
Definition: object.h:511
Base container of module.
Definition: module.h:113
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1216
int size() const
Definition: packed_func.h:1208
std::string GetTypeKey() const
Definition: object.h:180
TVMRetValue & operator=(const DataType &other)
Definition: packed_func.h:894
Shared content of all specializations of hash map.
Definition: map.h:174
TVMRetValue & operator=(int64_t value)
Definition: packed_func.h:874
A managed object in the TVM runtime.
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue &value) const
Definition: packed_func.h:1536
static TVMArrayHandle FFIGetHandle(const ObjectRef &nd)
Get FFI Array handle from ndarray.
Definition: ndarray.h:421
void operator()(size_t i, const TVMRetValue &value) const
Definition: packed_func.h:1576
Runtime container of the functions generated by TVM. This is used to support dynamically link...
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:360
Module container of TVM.
Definition: module.h:50
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:362
Map container of NodeRef->NodeRef in DSL graph. Map implements copy-on-write semantics, which means the map is mutable, but a copy will happen when the map is referenced in more than two places.
Definition: map.h:1268
static std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Runtime Map container types.
static String From(const TVMArgValue &val)
Definition: packed_func.h:1944
static std::string TypeName()
Definition: packed_func.h:531
Definition: c_runtime_api.h:114
int type_code() const
Definition: packed_func.h:610
TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index, const std::string *optional_name, FSig *f_sig)
move constructor from another return value.
Definition: packed_func.h:765
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:646
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:138
Optional container that represents a nullable variant of T.
Definition: optional.h:51
const TVMValue & value() const
Definition: packed_func.h:691
FCallPacked * f_call_packed_
Internal callable function pointer used to call the packed function.
Definition: packed_func.h:109
TVMRetValue operator()(Args &&... args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1610
PackedFuncObj()=delete
Delete the default constructor explicitly.
Definition: packed_func.h:1514
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
TVMArgsSetter(TVMValue *values, int *type_codes)
Definition: packed_func.h:1516
static constexpr const char * _type_key
Definition: packed_func.h:78
Derived object class for constructing PackedFuncObj.
Definition: packed_func.h:114
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:234
TVMRetValue(const TVMRetValue &other)
Definition: packed_func.h:828
PackedFunc(std::nullptr_t null)
Constructor from null.
Definition: packed_func.h:141
const TVMValue & value() const
Definition: packed_func.h:976
Definition: c_runtime_api.h:116
TVMArgValue(TVMValue value, int type_code)
constructor
Definition: packed_func.h:655
TVMRetValue & operator=(const TVMRetValue &other)
Definition: packed_func.h:932
#define TVM_CHECK_TYPE_CODE(CODE, T)
Definition: packed_func.h:422
TVM_ALWAYS_INLINE void operator()(size_t i, void *value) const
Definition: packed_func.h:1540
std::string() FSig
Using static function to output TypedPackedFunc signature.
Definition: packed_func.h:186
Definition: c_runtime_api.h:122
PackedFuncSubObj(TCallable callable)
Derived object class for constructing PackedFuncObj.
Definition: packed_func.h:124
TVMRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:864
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:472
runtime::DataType DataType
Definition: data_type.h:389
Definition: c_runtime_api.h:121
TVMRetValue & operator=(NDArray other)
Definition: packed_func.h:908
Type trait to specify special value conversion rules from TVMArgValue and TVMRetValue.
Definition: packed_func.h:1096
static TObjectRef From(const TVMRetValue &val)
Convert a TObjectRef from a return value.
Definition: packed_func.h:1108
Definition: c_runtime_api.h:117
Definition: packed_func.h:62
TypedPackedFunc()
default constructor
Definition: packed_func.h:232
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with additional context ...
Definition: packed_func.h:754
TVMRetValue & operator=(void *value)
Definition: packed_func.h:869
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1573
TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1552
TVMPODValue_()
Definition: packed_func.h:631