24 #ifndef TVM_RUNTIME_PACKED_FUNC_H_ 25 #define TVM_RUNTIME_PACKED_FUNC_H_ 31 #include <tvm/runtime/logging.h> 41 #include <type_traits> 46 #ifndef TVM_RUNTIME_HEADER_ONLY 47 #define TVM_RUNTIME_HEADER_ONLY 0 56 class TVMMovableArgValueWithContext_;
59 template <
typename FType>
61 template <
typename TSignature>
78 static constexpr
const char*
_type_key =
"runtime.PackedFunc";
85 template <
class TPackedFuncSubObj>
113 template <
class TCallable>
115 using TStorage =
typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type;
147 template <
typename TCallable,
148 typename = std::enable_if_t<
149 std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
150 !std::is_base_of<TCallable, PackedFunc>::value>>
153 data_ = make_object<ObjType>(std::forward<TCallable>(data));
169 template <
typename... Args>
170 inline TVMRetValue operator()(Args&&... args)
const;
178 bool operator==(std::nullptr_t null)
const {
return data_ ==
nullptr; }
180 bool operator!=(std::nullptr_t null)
const {
return data_ !=
nullptr; }
191 template <
typename FType>
226 template <
typename R,
typename... Args>
284 template <
typename FLambda,
typename =
typename std::enable_if<std::is_convertible<
285 FLambda, std::function<R(Args...)>>::value>::type>
287 this->AssignTypedLambda(typed_lambda, name);
307 template <
typename FLambda,
typename =
typename std::enable_if<std::is_convertible<
308 FLambda, std::function<R(Args...)>>::value>::type>
310 this->AssignTypedLambda(typed_lambda);
328 template <
typename FLambda,
typename =
typename std::enable_if<
329 std::is_convertible<FLambda,
330 std::function<R(Args...)>>::value>::type>
332 this->AssignTypedLambda(typed_lambda);
349 TVM_ALWAYS_INLINE R operator()(Args... args)
const;
360 bool operator==(std::nullptr_t null)
const {
return packed_ ==
nullptr; }
362 bool operator!=(std::nullptr_t null)
const {
return packed_ !=
nullptr; }
376 template <
typename FLambda>
377 inline void AssignTypedLambda(FLambda flambda, std::string name);
386 template <
typename FLambda>
387 inline void AssignTypedLambda(FLambda flambda);
403 : values(values), type_codes(type_codes), num_args(num_args) {}
405 inline int size()
const;
422 #define TVM_CHECK_TYPE_CODE(CODE, T) \ 423 ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) 429 template <
typename T>
439 using ContainerType =
typename T::ContainerType;
440 if (ptr ==
nullptr) {
441 if (T::_type_is_nullable) {
459 using ContainerType =
typename T::ContainerType;
460 if (ptr ==
nullptr)
return T::_type_is_nullable;
464 using ContainerType =
typename T::ContainerType;
465 return ContainerType::_type_key;
470 template <
typename T>
473 if (ptr ==
nullptr) {
480 for (
size_t i = 0; i < n->size(); i++) {
484 return String(
"Array[index " + std::to_string(i) +
": " + check_subtype.
value() +
"]");
490 if (ptr ==
nullptr)
return true;
502 template <
typename K,
typename V>
505 if (ptr ==
nullptr)
return NullOpt;
508 for (
const auto& kv : *n) {
511 if (key_type.
defined() || value_type.defined()) {
512 std::string key_name =
514 std::string value_name = value_type.defined() ? std::string(value_type.value())
516 return String(
"Map[" + key_name +
", " + value_name +
"]");
522 if (ptr ==
nullptr)
return true;
525 for (
const auto& kv : *n) {
543 operator double()
const {
547 if (type_code_ == kDLInt) {
548 return static_cast<double>(value_.v_int64);
551 return value_.v_float64;
553 operator int64_t()
const {
555 return value_.v_int64;
557 operator uint64_t()
const {
559 return value_.v_int64;
561 operator int()
const {
565 return static_cast<int>(value_.v_int64);
567 operator bool()
const {
569 return value_.v_int64 != 0;
571 operator void*()
const {
575 return value_.v_handle;
577 operator DLTensor*()
const {
579 return static_cast<DLTensor*
>(value_.v_handle);
582 LOG(FATAL) <<
"Expected " 608 return value_.v_device;
616 template <
typename T>
618 return static_cast<T*
>(value_.v_handle);
621 template <
typename TObjectRef,
622 typename =
typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
623 inline bool IsObjectRef()
const;
624 template <
typename TObjectRef>
625 inline TObjectRef AsObjectRef()
const;
657 using TVMPODValue_::operator double;
658 using TVMPODValue_::operator int64_t;
659 using TVMPODValue_::operator uint64_t;
660 using TVMPODValue_::operator int;
661 using TVMPODValue_::operator bool;
662 using TVMPODValue_::operator
void*;
663 using TVMPODValue_::operator DLTensor*;
664 using TVMPODValue_::operator
NDArray;
665 using TVMPODValue_::operator
Device;
666 using TVMPODValue_::operator
Module;
672 operator std::string()
const {
677 return std::string(arr->
data, arr->
size);
678 }
else if (type_code_ ==
kTVMStr) {
679 return std::string(value_.v_str);
681 ICHECK(IsObjectRef<tvm::runtime::String>())
684 return AsObjectRef<tvm::runtime::String>().
operator std::string();
687 template <
typename FType>
693 template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
694 inline operator T()
const;
695 inline operator DLDataType()
const;
713 using TVMPODValue_::operator double;
714 using TVMPODValue_::operator int64_t;
715 using TVMPODValue_::operator uint64_t;
716 using TVMPODValue_::operator int;
717 using TVMPODValue_::operator bool;
718 using TVMPODValue_::operator
void*;
719 using TVMPODValue_::operator DLTensor*;
720 using TVMPODValue_::operator
NDArray;
721 using TVMPODValue_::operator
Device;
722 using TVMPODValue_::operator
Module;
725 operator std::string()
const {
return AsArgValue().operator std::string(); }
726 template <
typename FType>
730 operator DLDataType()
const {
return AsArgValue().operator DLDataType(); }
738 template <
typename T,
739 typename =
typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
740 inline operator T()
const;
766 const std::string* optional_name,
FSig* f_sig)
767 : value_(value, type_code),
768 arg_index_(arg_index),
769 optional_name_(optional_name),
772 template <
typename T>
776 }
catch (dmlc::Error& e) {
777 LOG(FATAL) <<
"In function " << (optional_name_ ==
nullptr ?
"<anonymous>" : *optional_name_)
778 << (f_sig_ ==
nullptr ?
"" : (*f_sig_)()) <<
": error while converting argument " 779 << arg_index_ <<
": " << e.what();
787 const std::string* optional_name_;
814 using TVMPODValue_::operator double;
815 using TVMPODValue_::operator int64_t;
816 using TVMPODValue_::operator uint64_t;
817 using TVMPODValue_::operator int;
818 using TVMPODValue_::operator bool;
819 using TVMPODValue_::operator
void*;
820 using TVMPODValue_::operator DLTensor*;
821 using TVMPODValue_::operator
Device;
822 using TVMPODValue_::operator
NDArray;
823 using TVMPODValue_::operator
Module;
830 operator std::string()
const {
834 return *ptr<std::string>();
837 return *ptr<std::string>();
839 operator DLDataType()
const {
844 return value_.v_type;
847 template <
typename FType>
854 value_ = other.value_;
855 type_code_ = other.type_code_;
860 this->SwitchToPOD(kDLFloat);
861 value_.v_float64 = value;
866 value_.v_handle = value;
871 value_.v_handle = value;
875 this->SwitchToPOD(kDLInt);
876 value_.v_int64 = value;
880 this->SwitchToPOD(kDLInt);
881 value_.v_int64 = value;
886 value_.v_device = value;
896 this->SwitchToPOD(kDLInt);
897 value_.v_int64 = value;
901 this->SwitchToClass(
kTVMStr, value);
909 if (other.
data_ !=
nullptr) {
916 value_.v_handle =
nullptr;
928 template <
typename FType>
957 *ret_type_code = type_code_;
979 <<
"TVMRetValue.value can only be used for POD data";
983 template <
typename TObjectRef,
984 typename =
typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
986 template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type>
987 inline operator T()
const;
990 template <
typename T>
991 void Assign(
const T& other) {
992 switch (other.type_code()) {
994 SwitchToClass<std::string>(
kTVMStr, other);
998 SwitchToClass<std::string>(
kTVMBytes, other);
1002 *
this = other.operator PackedFunc();
1006 *
this = other.operator Module();
1010 *
this = other.operator NDArray();
1016 GetObjectPtr<Object>(static_cast<Object*>(other.value_.v_handle)));
1024 SwitchToPOD(other.type_code());
1025 value_ = other.value_;
1031 void SwitchToPOD(
int type_code) {
1032 if (type_code_ != type_code) {
1034 type_code_ = type_code;
1037 template <
typename T>
1038 void SwitchToClass(
int type_code, T v) {
1039 if (type_code_ != type_code) {
1041 type_code_ = type_code;
1042 value_.v_handle =
new T(v);
1044 *
static_cast<T*
>(value_.v_handle) = v;
1048 if (other.data_ !=
nullptr) {
1050 type_code_ = type_code;
1052 value_.v_handle = other.data_;
1053 other.data_ =
nullptr;
1056 value_.v_handle =
nullptr;
1061 switch (type_code_) {
1064 delete ptr<std::string>();
1095 template <
typename TObjectRef>
1130 #define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ 1132 TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ 1133 int* out_type_code, void* resource_handle); \ 1134 int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ 1135 int* out_type_code, void* resource_handle) { \ 1137 ::tvm::runtime::TVMRetValue rv; \ 1138 Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ 1139 rv.MoveToCHost(out_value, out_type_code); \ 1141 } catch (const ::std::exception& _except_) { \ 1142 TVMAPISetLastError(_except_.what()); \ 1183 #define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ 1185 TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ 1186 int* out_type_code, void* resource_handle) { \ 1188 auto f = Function; \ 1189 using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \ 1190 ::tvm::runtime::TVMRetValue rv; \ 1191 ::tvm::runtime::detail::unpack_call_by_signature<FType>::run( \ 1192 f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ 1193 rv.MoveToCHost(out_value, out_type_code); \ 1195 } catch (const ::std::exception& _except_) { \ 1196 TVMAPISetLastError(_except_.what()); \ 1203 ICHECK_LT(i, num_args) <<
"not enough argument passed, " << num_args <<
" passed" 1204 <<
" but request arg[" << i <<
"].";
1210 template <
class TPackedFuncSubObj>
1213 (
static_cast<const TPackedFuncSubObj*
>(obj))->callable_(args, rv);
1217 (*f_call_packed_)(
this, args, rv);
1221 (
static_cast<PackedFuncObj*
>(data_.get()))->CallPacked(args, rv);
1226 switch (type_code) {
1242 return "ArrayHandle";
1244 return "DLDataType";
1248 return "FunctionHandle";
1250 return "ModuleHandle";
1252 return "NDArrayContainer";
1256 return "ObjectRValueRefArg";
1258 LOG(FATAL) <<
"unknown type_code=" <<
static_cast<int>(type_code);
1264 template <
bool stop, std::
size_t I,
typename F>
1265 struct for_each_dispatcher {
1266 template <
typename T,
typename... Args>
1267 static void run(
const F& f, T&& value, Args&&... args) {
1268 f(I, std::forward<T>(value));
1269 for_each_dispatcher<
sizeof...(Args) == 0, (I + 1), F>::run(f, std::forward<Args>(args)...);
1273 template <std::
size_t I,
typename F>
1274 struct for_each_dispatcher<true, I, F> {
1275 static void run(
const F& f) {}
1278 template <
typename F,
typename... Args>
1279 inline void for_each(
const F& f, Args&&... args) {
1280 for_each_dispatcher<
sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...);
1283 namespace parameter_pack {
1285 template <
typename... EnumArgs>
1286 struct EnumeratedParamPack {
1288 template <
template <
size_t i,
typename TArgument>
class Functor,
typename... ExtraParams>
1289 static void F(ExtraParams&&... extra_params) {
1290 using TExpander =
int[];
1293 (Functor<EnumArgs::i, typename EnumArgs::T>::F(extra_params...), 0)...,
1299 template <
typename... Args>
1300 struct EnumerateImpl {
1302 template <
size_t _i,
typename _T>
1304 static const constexpr
size_t i = _i;
1308 template <
typename...>
1311 template <std::size_t...
id>
1312 struct Zipper<std::integer_sequence<std::size_t, id...>> {
1313 using T = EnumeratedParamPack<Item<id, Args>...>;
1317 using T =
typename Zipper<std::index_sequence_for<Args...>>::T;
1320 template <
typename... Args>
1321 using Enumerate =
typename EnumerateImpl<Args...>::T;
1323 template <
typename... Args>
1325 template <
template <
size_t i,
typename TArgument>
class Functor,
typename... ExtraParams>
1326 static void InvokeWithoutArg(ExtraParams&&... extra_params) {
1327 Enumerate<Args...>::Invoke::template F<Functor, ExtraParams...>(
1328 std::forward<ExtraParams>(extra_params)...);
1338 template <
typename T>
1339 struct func_signature_helper {
1343 template <
typename T,
typename R,
typename... Args>
1344 struct func_signature_helper<R (T::*)(Args...)> {
1345 using FType = R(Args...);
1346 using ParamType = parameter_pack::ParamPack<Args...>;
1348 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1351 template <
typename T,
typename R,
typename... Args>
1352 struct func_signature_helper<R (T::*)(Args...) const> {
1353 using FType = R(Args...);
1354 using ParamType = parameter_pack::ParamPack<Args...>;
1356 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1363 template <
typename T>
1364 struct function_signature {
1365 using FType =
typename func_signature_helper<decltype(&T::operator())>::FType;
1366 using ParamType =
typename func_signature_helper<decltype(&T::operator())>::ParamType;
1367 using RetType =
typename func_signature_helper<decltype(&T::operator())>::RetType;
1371 template <
typename R,
typename... Args>
1372 struct function_signature<R(Args...)> {
1373 using FType = R(Args...);
1374 using ParamType = parameter_pack::ParamPack<Args...>;
1376 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1380 template <
typename R,
typename... Args>
1381 struct function_signature<R (*)(Args...)> {
1382 using FType = R(Args...);
1383 using ParamType = detail::parameter_pack::ParamPack<Args...>;
1385 static_assert(!std::is_reference<R>::value,
"TypedPackedFunc return reference");
1388 template <
typename TSignature>
1391 namespace type2str {
1393 template <
typename T>
1394 struct TypeSimplifier;
1396 template <
typename T>
1398 template <typename = std::enable_if_t<std::is_base_of<ObjectRef, T>::value>>
1399 static std::string v() {
1400 return T::ContainerType::_type_key;
1404 struct Type2Str<int> {
1405 static std::string v() {
return "int"; }
1408 struct Type2Str<double> {
1409 static std::string v() {
return "double"; }
1412 struct Type2Str<int64_t> {
1413 static std::string v() {
return "int64_t"; }
1416 struct Type2Str<uint64_t> {
1417 static std::string v() {
return "uint64_t"; }
1420 struct Type2Str<bool> {
1421 static std::string v() {
return "bool"; }
1424 struct Type2Str<void> {
1425 static std::string v() {
return "void"; }
1428 struct Type2Str<std::basic_string<char>> {
1429 static std::string v() {
return "basic_string<char>"; }
1431 template <
typename K,
typename V>
1432 struct Type2Str<Map<K, V>> {
1433 static std::string v() {
1434 return "Map<" + TypeSimplifier<K>::v() +
", " + TypeSimplifier<V>::v() +
">";
1438 struct Type2Str<DLDevice> {
1439 static std::string v() {
return "DLDevice"; }
1442 struct Type2Str<DLTensor> {
1443 static std::string v() {
return "DLTensor"; }
1446 struct Type2Str<DataType> {
1447 static std::string v() {
return "DataType"; }
1450 struct Type2Str<DLDataType> {
1451 static std::string v() {
return "DLDataType"; }
1454 struct Type2Str<TVMRetValue> {
1455 static std::string v() {
return "TVMRetValue"; }
1458 struct Type2Str<TVMArgValue> {
1459 static std::string v() {
return "TVMArgValue"; }
1461 template <
typename FType>
1462 struct Type2Str<TypedPackedFunc<FType>> {
1465 template <
typename T>
1466 struct Type2Str<Array<T>> {
1467 static std::string v() {
return "Array<" + TypeSimplifier<T>::v() +
">"; }
1474 template <
typename T>
1475 struct TypeSimplifier {
1476 static std::string v() {
1477 using U =
typename std::remove_cv<
1478 typename std::remove_reference<typename std::remove_pointer<T>::type>::type>::type;
1479 return (std::is_const<T>::value ?
"const " :
"") + Type2Str<U>::v() +
1480 (std::is_pointer<T>::value ?
"*" :
"") + (std::is_reference<T>::value ?
"&" :
"");
1490 template <
typename TSignature>
1492 using ParamType =
typename TSignature::ParamType;
1493 using RetType =
typename TSignature::RetType;
1495 template <
size_t i,
typename TArgument>
1496 struct PrintParamType {
1497 static void F(std::ostream& os) {
1498 os << (i == 0 ?
"" :
", ") << i <<
": " << type2str::TypeSimplifier<TArgument>::v();
1502 static std::string F() {
1503 std::ostringstream oss;
1505 ParamType::template InvokeWithoutArg<PrintParamType>(oss);
1506 oss <<
") -> " << type2str::TypeSimplifier<RetType>::v();
1517 template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type>
1519 values_[i].v_int64 =
static_cast<int64_t
>(value);
1520 type_codes_[i] = kDLInt;
1522 TVM_ALWAYS_INLINE
void operator()(
size_t i, uint64_t value)
const {
1523 values_[i].v_int64 =
static_cast<int64_t
>(value);
1525 type_codes_[i] = kDLInt;
1527 TVM_ALWAYS_INLINE
void operator()(
size_t i,
double value)
const {
1528 values_[i].v_float64 = value;
1529 type_codes_[i] = kDLFloat;
1531 TVM_ALWAYS_INLINE
void operator()(
size_t i, std::nullptr_t value)
const {
1532 values_[i].v_handle = value;
1536 values_[i] = value.
value_;
1540 values_[i].v_handle = value;
1543 TVM_ALWAYS_INLINE
void operator()(
size_t i, DLTensor* value)
const {
1544 values_[i].v_handle = value;
1548 values_[i].v_device = value;
1551 TVM_ALWAYS_INLINE
void operator()(
size_t i, DLDataType value)
const {
1552 values_[i].v_type = value;
1556 operator()(i, dtype.operator DLDataType());
1558 TVM_ALWAYS_INLINE
void operator()(
size_t i,
const char* value)
const {
1559 values_[i].v_str = value;
1563 TVM_ALWAYS_INLINE
void operator()(
size_t i,
const std::string& value)
const {
1564 values_[i].v_str = value.c_str();
1568 values_[i].v_handle =
const_cast<TVMByteArray*
>(&value);
1571 template <
typename FType>
1573 operator()(i, value.packed());
1577 values_[i].v_str = value.
ptr<std::string>()->c_str();
1581 values_[i] = value.
value_;
1586 template <
typename TObjectRef,
1587 typename =
typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
1588 TVM_ALWAYS_INLINE
void operator()(
size_t i,
const TObjectRef& value)
const {
1589 this->SetObject(i, value);
1592 template <
typename TObjectRef,
1593 typename =
typename std::enable_if<std::is_base_of<
1594 ObjectRef,
typename std::remove_reference<TObjectRef>::type>::value>::type>
1595 TVM_ALWAYS_INLINE
void operator()(
size_t i, TObjectRef&& value)
const {
1596 this->SetObject(i, std::forward<TObjectRef>(value));
1600 template <
typename TObjectRef>
1601 inline void SetObject(
size_t i, TObjectRef&& value)
const;
1608 template <
typename... Args>
1610 const int kNumArgs =
sizeof...(Args);
1611 const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1613 int type_codes[kArraySize];
1614 detail::for_each(
TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
1617 ->CallPacked(
TVMArgs(values, type_codes, kNumArgs), &rv);
1622 template <
typename R,
int nleft,
int index,
typename F>
1623 struct unpack_call_dispatcher {
1624 template <
typename... Args>
1625 TVM_ALWAYS_INLINE
static void run(
const std::string* optional_name,
FSig* f_sig,
const F& f,
1627 Args&&... unpacked_args) {
1630 unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
1631 optional_name, f_sig, f, args_pack, rv, std::forward<Args>(unpacked_args)...,
1633 optional_name, f_sig));
1637 template <
typename R,
int index,
typename F>
1638 struct unpack_call_dispatcher<R, 0, index, F> {
1639 template <
typename... Args>
1640 TVM_ALWAYS_INLINE
static void run(
const std::string* optional_name,
FSig* f_sig,
const F& f,
1642 Args&&... unpacked_args) {
1643 using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
1644 if (std::is_same<RetType, R>::value) {
1645 *rv = f(std::forward<Args>(unpacked_args)...);
1647 *rv = R(f(std::forward<Args>(unpacked_args)...));
1652 template <
int index,
typename F>
1653 struct unpack_call_dispatcher<void, 0, index, F> {
1654 template <
typename... Args>
1655 TVM_ALWAYS_INLINE
static void run(
const std::string* optional_name,
FSig* f_sig,
const F& f,
1657 Args&&... unpacked_args) {
1658 f(std::forward<Args>(unpacked_args)...);
1662 template <
typename R,
int nargs,
typename F>
1663 TVM_ALWAYS_INLINE
void unpack_call(
const std::string* optional_name,
const F& f,
1665 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<F>>::F;
1666 CHECK_EQ(nargs, args.
size()) <<
"Function " 1667 << (optional_name ==
nullptr ?
"<anonymous>" : *optional_name)
1668 << (f_sig ==
nullptr ?
"" : (*f_sig)()) <<
" expects " << nargs
1669 <<
" arguments but " << args.
size() <<
" were provided";
1670 unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, args, rv);
1673 template <
typename FType>
1674 struct unpack_call_by_signature {};
1676 template <
typename R,
typename... Args>
1677 struct unpack_call_by_signature<R(Args...)> {
1678 template <
typename F>
1679 TVM_ALWAYS_INLINE
static void run(
const F& f,
const TVMArgs& args,
TVMRetValue* rv) {
1680 unpack_call<R,
sizeof...(Args)>(
nullptr, f, args, rv);
1684 template <
typename R,
typename... Args>
1685 TVM_ALWAYS_INLINE R call_packed(
const PackedFunc& pf, Args&&... args) {
1686 return R(pf(std::forward<Args>(args)...));
1689 template <
typename R>
1690 struct typed_packed_call_dispatcher {
1691 template <
typename... Args>
1692 TVM_ALWAYS_INLINE
static R run(
const PackedFunc& pf, Args&&... args) {
1693 return pf(std::forward<Args>(args)...);
1698 struct typed_packed_call_dispatcher<void> {
1699 template <
typename... Args>
1700 TVM_ALWAYS_INLINE
static void run(
const PackedFunc& pf, Args&&... args) {
1701 pf(std::forward<Args>(args)...);
1706 template <
typename R,
typename... Args>
1709 template <
typename R,
typename... Args>
1713 template <
typename R,
typename... Args>
1717 template <
typename R,
typename... Args>
1721 template <
typename R,
typename... Args>
1722 template <
typename FType>
1723 inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, std::string name) {
1724 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1726 if (args.
size() !=
sizeof...(Args)) {
1727 LOG(FATAL) <<
"Function " << name << (f_sig ==
nullptr ?
"" : (*f_sig)()) <<
" expects " 1728 <<
sizeof...(Args) <<
" arguments, but " << args.
size() <<
" were provided.";
1730 detail::unpack_call<R,
sizeof...(Args)>(&name, flambda, args, rv);
1734 template <
typename R,
typename... Args>
1735 template <
typename FType>
1737 FSig* f_sig = detail::SignaturePrinter<detail::function_signature<FType>>::F;
1739 if (args.
size() !=
sizeof...(Args)) {
1740 LOG(FATAL) <<
"Function <anonymous> " << (*f_sig)() <<
" expects " <<
sizeof...(Args)
1741 <<
" arguments, but " << args.
size() <<
" were provided.";
1743 detail::unpack_call<R,
sizeof...(Args)>(
nullptr, flambda, args, rv);
1747 template <
typename R,
typename... Args>
1749 return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...);
1757 template <
typename T>
1758 inline void TVMArgsSetter::SetObject(
size_t i, T&& value)
const {
1759 using ContainerType =
typename std::remove_reference<T>::type::ContainerType;
1760 if (value.defined()) {
1761 Object* ptr = value.data_.data_;
1762 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1763 (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1767 }
else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1768 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1770 values_[i].v_handle = ptr;
1772 }
else if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
1773 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1775 values_[i].v_handle = ptr;
1777 }
else if (std::is_rvalue_reference<decltype(value)>::value) {
1778 values_[i].v_handle =
const_cast<Object**
>(&(value.data_.data_));
1781 values_[i].v_handle = value.data_.data_;
1786 values_[i].v_handle =
nullptr;
1790 template <
typename TObjectRef,
typename>
1792 using ContainerType =
typename TObjectRef::ContainerType;
1794 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1799 if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1801 static_cast<Object*
>(value_.v_handle)->IsInstance<ContainerType>();
1803 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
1805 static_cast<Object*
>(value_.v_handle)->IsInstance<ContainerType>();
1811 return (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1813 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1815 (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1821 template <
typename TObjectRef>
1823 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
1824 "Conversion only works for ObjectRef");
1825 using ContainerType =
typename TObjectRef::ContainerType;
1828 CHECK(TObjectRef::_type_is_nullable)
1829 <<
"Expect a not null value of " << ContainerType::_type_key;
1833 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
1838 CHECK(data->IsInstance<ContainerType>())
1839 <<
"Expected " << ContainerType::_type_key <<
" but got " << data->GetTypeKey();
1840 return TObjectRef(data);
1842 if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
1846 CHECK(data->IsInstance<ContainerType>())
1847 <<
"Expected " << ContainerType::_type_key <<
" but got " << data->GetTypeKey();
1848 return TObjectRef(data);
1850 if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
1854 CHECK(data->IsInstance<ContainerType>())
1855 <<
"Expected " << ContainerType::_type_key <<
" but got " << data->GetTypeKey();
1856 return TObjectRef(data);
1863 <<
", but got " << checked_type.
value();
1864 return TObjectRef(GetObjectPtr<Object>(ptr));
1869 <<
", but got " << checked_type.
value();
1870 return TObjectRef(GetObjectPtr<Object>(ptr));
1871 }
else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1876 return TObjectRef(data);
1877 }
else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1880 return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
1881 }
else if (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
1884 return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
1891 template <
typename TObjectRef,
typename>
1893 using ContainerType =
typename TObjectRef::ContainerType;
1894 const Object* ptr = other.get();
1895 if (ptr !=
nullptr) {
1896 if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
1897 (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
1899 return operator=(
NDArray(std::move(other.data_)));
1901 if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
1902 (std::is_base_of<ContainerType, Module::ContainerType>::value &&
1904 return operator=(
Module(std::move(other.data_)));
1909 value_.v_handle =
nullptr;
1914 template <
typename T,
typename>
1915 inline TVMArgValue::operator T()
const {
1919 template <
typename T,
typename>
1920 inline TVMMovableArgValue_::operator T()
const {
1922 auto** ref =
static_cast<Object**
>(value_.v_handle);
1931 template <
typename T,
typename>
1932 inline TVMRetValue::operator T()
const {
1937 return (*this)->GetFunction(name, query_imports);
1960 template <
typename T>
1976 inline TVMArgValue::operator DLDataType()
const {
1989 return value_.v_type;
1996 #endif // TVM_RUNTIME_PACKED_FUNC_H_ static TObjectRef From(const TVMArgValue &val)
Convert a TObjectRef from an argument value.
Definition: packed_func.h:1102
const char * ArgTypeCode2Str(int type_code)
Convert argument type code to string.
Definition: packed_func.h:1225
int num_args
Definition: packed_func.h:395
TVMArgValue()
default constructor
Definition: packed_func.h:649
static bool Check(const Object *ptr)
Definition: packed_func.h:489
TVMRetValue & operator=(std::string value)
Definition: packed_func.h:900
array node content in array
Definition: array.h:40
Return Value container, Unlike TVMArgValue, which only holds reference and do not delete the underlyi...
Definition: packed_func.h:799
PackedFunc GetFunction(const std::string &name, bool query_imports=false)
Get packed function from current module by name.
Definition: packed_func.h:1936
static void FFIDecRef(TVMArrayHandle handle)
DecRef resource managed by an FFI array handle.
Definition: ndarray.h:438
TVMRetValue & operator=(TVMMovableArgValue_ &&other)
Definition: packed_func.h:940
TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef &&value) const
Definition: packed_func.h:1595
void MoveToCHost(TVMValue *ret_value, int *ret_type_code)
Move the value back to front-end via C API. This marks the current container as null. The managed resources are moved to the front-end. The front end should take charge in managing them.
Definition: packed_func.h:953
PrimExpr min(PrimExpr a, PrimExpr b, Span span=Span())
take minimum of two values
A custom smart pointer for Object.
Definition: object.h:358
TVMArgs(const TVMValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:402
A PackedFunc wrapper to provide typed function signature. It is backed by a PackedFunc internally...
Definition: packed_func.h:227
TVM_ALWAYS_INLINE void operator()(size_t i, const char *value) const
Definition: packed_func.h:1558
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:180
TVMRetValue & operator=(bool value)
Definition: packed_func.h:895
int type_code_
the type code
Definition: packed_func.h:637
Internal base class to handle conversion to POD values.
Definition: packed_func.h:541
static constexpr const uint32_t _type_index
Definition: packed_func.h:77
std::string DLDataType2String(DLDataType t)
convert a TVM type to string.
Definition: data_type.h:341
void * v_handle
Definition: c_runtime_api.h:204
const PackedFunc & packed() const
Definition: packed_func.h:358
Definition: c_runtime_api.h:181
runtime implementation for LibTorch/TorchScript.
Definition: analyzer.h:36
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:504
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1220
Definition: c_runtime_api.h:176
TVMMovableArgValue_(TVMValue value, int type_code)
Definition: packed_func.h:711
TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor *value) const
Definition: packed_func.h:1543
~TVMRetValue()
destructor
Definition: packed_func.h:812
Definition: c_runtime_api.h:172
const TVMValue * values
Definition: packed_func.h:393
Definition: c_runtime_api.h:177
TVMArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:1202
size_t size
Definition: c_runtime_api.h:216
static std::string TypeName()
Definition: packed_func.h:463
static bool Check(const Object *ptr)
Check if an object matches the template type.
Definition: packed_func.h:458
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:201
base class of all object containers.
Definition: object.h:167
Definition: c_runtime_api.h:175
const char * data
Definition: c_runtime_api.h:215
Object * TVMArrayHandleToObjectHandle(TVMArrayHandle handle)
Definition: ndarray.h:442
TVMValue value_
The value.
Definition: packed_func.h:635
TVMRetValue()
default constructor
Definition: packed_func.h:802
Managed NDArray. The array is backed by reference counted blocks.
Definition: ndarray.h:59
TVMRetValue & operator=(int value)
Definition: packed_func.h:879
static void FFIClearAfterMove(ObjectRef *ref)
Clear the object ref data field without DecRef after we successfully moved the field.
Definition: object.h:592
TVMRetValue & operator=(TVMRetValue &&other)
Definition: packed_func.h:852
Byte array type used to pass in byte array When kTVMBytes is used as data type.
Definition: c_runtime_api.h:214
TVMRetValue & operator=(TVMByteArray value)
Definition: packed_func.h:904
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:309
void DecRef()
developer function, decrease reference counter.
Definition: object.h:801
bool IsInstance() const
Definition: object.h:829
TVMRetValue & operator=(const TVMArgValue &other)
Definition: packed_func.h:936
TVMRetValue & operator=(double value)
Definition: packed_func.h:859
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray &value) const
Definition: packed_func.h:1567
Runtime Array container types.
Definition: packed_func.h:38
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
Definition: packed_func.h:709
static bool Check(const Object *ptr)
Definition: packed_func.h:521
TVM_ALWAYS_INLINE void operator()(size_t i, T value) const
Definition: packed_func.h:1518
static Optional< T > From(const TVMRetValue &val)
Definition: packed_func.h:1966
A device-independent managed NDArray abstraction.
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: data_type.h:348
static String From(const TVMRetValue &val)
Definition: packed_func.h:1951
static ObjectPtr< Object > FFIDataFromHandle(TVMArrayHandle handle)
Construct NDArray's Data field from array handle in FFI.
Definition: ndarray.h:425
Type traits for runtime type check during FFI conversion.
Definition: packed_func.h:430
bool IsObjectRef() const
Definition: packed_func.h:1791
static Optional< T > From(const TVMArgValue &val)
Definition: packed_func.h:1962
Runtime primitive data type.
Definition: data_type.h:41
bool defined() const
Definition: object.h:544
TVM_ALWAYS_INLINE void operator()(size_t i, Device value) const
Definition: packed_func.h:1547
Arguments into TVM functions.
Definition: packed_func.h:391
TVMRetValue(TVMRetValue &&other)
move constructor from another return value.
Definition: packed_func.h:807
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:331
PackedFunc(TCallable data)
Constructing a packed function from a callable type whose signature is consistent with PackedFunc ...
Definition: packed_func.h:151
T value() const
Definition: optional.h:92
Array, container representing a contiguous sequence of ObjectRefs.
Definition: array.h:289
Definition: c_runtime_api.h:180
TypedPackedFunc(const FLambda &typed_lambda, std::string name)
construct from a lambda function with the same signature.
Definition: packed_func.h:286
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:1522
static std::string TypeName()
Definition: packed_func.h:500
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:340
runtime::PackedFunc.
Definition: object.h:74
TVMRetValue & operator=(Module m)
Definition: packed_func.h:920
Object container class that backs NDArray.
Definition: ndarray.h:294
TVMRetValue & operator=(PackedFunc f)
Definition: packed_func.h:924
PackedFuncObj(FCallPacked *f_call_pack)
Constructing a packed function object from a function pointer.
Definition: packed_func.h:103
TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef &value) const
Definition: packed_func.h:1588
const int * type_codes
Definition: packed_func.h:394
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:574
TVMRetValue & operator=(DLDataType t)
Definition: packed_func.h:889
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:178
TObjectRef AsObjectRef() const
Definition: packed_func.h:1822
PrimExpr max(PrimExpr a, PrimExpr b, Span span=Span())
take maximum of two values
TVM_ALWAYS_INLINE void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1563
Reference to string objects.
Definition: string.h:97
void(const PackedFuncObj *, TVMArgs, TVMRetValue *) FCallPacked
The internal callable function type.
Definition: packed_func.h:97
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:60
Object container class that backs PackedFunc.
Definition: packed_func.h:68
TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const
Definition: packed_func.h:1555
TVMPODValue_(TVMValue value, int type_code)
Definition: packed_func.h:632
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
Definition: object.h:713
TVM_ALWAYS_INLINE void operator()(size_t i, double value) const
Definition: packed_func.h:1527
TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1531
T * ptr() const
return handle as specific pointer type.
Definition: packed_func.h:617
DLDevice Device
Definition: ndarray.h:43
Definition: c_runtime_api.h:170
const Object * get() const
Definition: object.h:546
TVMRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:929
static TVMRetValue MoveFromCHost(TVMValue value, int type_code)
Construct a new TVMRetValue by moving from return value stored via C API.
Definition: packed_func.h:967
Object & operator=(const Object &other)
Definition: object.h:251
TVMRetValue & operator=(DLDevice value)
Definition: packed_func.h:884
TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object)
TStorage callable_
Type-erased filed for storing callable object.
Definition: packed_func.h:127
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Check if an object matches the template type and return the mismatched type if it exists...
Definition: packed_func.h:438
static bool CanConvertFrom(const TVMArgValue &val)
Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String.
Definition: packed_func.h:1972
Base class of all object reference.
Definition: object.h:511
Base container of module.
Definition: module.h:113
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:1216
int size() const
Definition: packed_func.h:1208
std::string GetTypeKey() const
Definition: object.h:180
TVMRetValue & operator=(const DataType &other)
Definition: packed_func.h:894
Shared content of all specializations of hash map.
Definition: map.h:174
TVMRetValue & operator=(int64_t value)
Definition: packed_func.h:874
A managed object in the TVM runtime.
TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue &value) const
Definition: packed_func.h:1535
static TVMArrayHandle FFIGetHandle(const ObjectRef &nd)
Get FFI Array handle from ndarray.
Definition: ndarray.h:430
void operator()(size_t i, const TVMRetValue &value) const
Definition: packed_func.h:1575
Runtime container of the functions generated by TVM, This is used to support dynamically link...
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:360
Module container of TVM.
Definition: module.h:50
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:362
Map container of NodeRef->NodeRef in DSL graph. Map implements copy on write semantics, which means map is mutable but copy will happen when array is referenced in more than two places.
Definition: map.h:1271
static std::string TypeIndex2Key(uint32_t tindex)
Get the type key of the corresponding index from runtime.
PrimExpr ret(PrimExpr value, Span span=Span())
Return the value.
Runtime Map container types.
static String From(const TVMArgValue &val)
Definition: packed_func.h:1943
static std::string TypeName()
Definition: packed_func.h:531
Definition: c_runtime_api.h:171
int type_code() const
Definition: packed_func.h:610
TVMMovableArgValueWithContext_(TVMValue value, int type_code, int arg_index, const std::string *optional_name, FSig *f_sig)
move constructor from another return value.
Definition: packed_func.h:765
A single argument value to PackedFunc. Containing both type_code and TVMValue.
Definition: packed_func.h:646
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:138
Optional container that to represent to a Nullable variant of T.
Definition: optional.h:51
const TVMValue & value() const
Definition: packed_func.h:691
FCallPacked * f_call_packed_
Internal callable function pointer used to call the packed function.
Definition: packed_func.h:109
TVMRetValue operator()(Args &&... args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1609
PackedFuncObj()=delete
Delete the default constructor explicitly.
Definition: packed_func.h:1513
constexpr runtime::NullOptType NullOpt
Definition: optional.h:160
TVMArgsSetter(TVMValue *values, int *type_codes)
Definition: packed_func.h:1515
static constexpr const char * _type_key
Definition: packed_func.h:78
Derived object class for constructing PackedFuncObj.
Definition: packed_func.h:114
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:234
TVMRetValue(const TVMRetValue &other)
Definition: packed_func.h:828
PackedFunc(std::nullptr_t null)
Constructor from null.
Definition: packed_func.h:141
const TVMValue & value() const
Definition: packed_func.h:976
Definition: c_runtime_api.h:173
TVMArgValue(TVMValue value, int type_code)
constructor
Definition: packed_func.h:655
TVMRetValue & operator=(const TVMRetValue &other)
Definition: packed_func.h:932
#define TVM_CHECK_TYPE_CODE(CODE, T)
Definition: packed_func.h:422
TVM_ALWAYS_INLINE void operator()(size_t i, void *value) const
Definition: packed_func.h:1539
std::string() FSig
Using static function to output TypedPackedFunc signature.
Definition: packed_func.h:186
Definition: c_runtime_api.h:179
PackedFuncSubObj(TCallable callable)
Derived object class for constructing PackedFuncObj.
Definition: packed_func.h:124
TVMRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:864
static Optional< String > CheckAndGetMismatch(const Object *ptr)
Definition: packed_func.h:472
runtime::DataType DataType
Definition: data_type.h:398
Definition: c_runtime_api.h:178
TVMRetValue & operator=(NDArray other)
Definition: packed_func.h:908
Type trait to specify special value conversion rules from TVMArgValue and TVMRetValue.
Definition: packed_func.h:1096
static TObjectRef From(const TVMRetValue &val)
Convert a TObjectRef from a return value.
Definition: packed_func.h:1108
Definition: c_runtime_api.h:174
Definition: packed_func.h:62
TypedPackedFunc()
default constructor
Definition: packed_func.h:232
Internal auxiliary struct for TypedPackedFunc to indicate a movable argument with additional context ...
Definition: packed_func.h:754
TVMRetValue & operator=(void *value)
Definition: packed_func.h:869
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1572
TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1551
TVMPODValue_()
Definition: packed_func.h:631