Program Listing for File structural_visit.h

Program Listing for File structural_visit.h

Return to documentation for file (tvm/ffi/extra/structural_visit.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_FFI_EXTRA_STRUCTURAL_VISIT_H_
#define TVM_FFI_EXTRA_STRUCTURAL_VISIT_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/tuple.h>
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/expected.h>
#include <tvm/ffi/extra/visit_error_context.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/function_details.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/accessor.h>

#include <cstddef>
#include <exception>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

namespace tvm {
namespace ffi {

// Object node backing a VisitInterrupt reference.
// It stores a single Any payload and is registered under the static type
// index TypeIndex::kTVMFFIVisitInterrupt.
class VisitInterruptObj : public Object {
 public:
  // Payload attached to the interrupt.
  Any value;

  // Leaves `value` default-constructed.
  VisitInterruptObj() = default;
  // Moves the given payload into `value`.
  explicit VisitInterruptObj(Any value) : value(std::move(value)) {}

  // Fixed (static) type index of this object type.
  static constexpr const int32_t _type_index = TypeIndex::kTVMFFIVisitInterrupt;
  // The type is declared final.
  static const constexpr bool _type_final = true;
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIVisitInterrupt, VisitInterruptObj,
                                     Object);
};

// Reference type for VisitInterruptObj.
// A result holding this type (type index kTVMFFIVisitInterrupt) is what
// details::StructuralVisitNeedEarlyReturn treats as a request to stop traversal.
class VisitInterrupt : public ObjectRef {
 public:
  // Interrupt whose payload is Any(nullptr); delegates to the payload constructor.
  VisitInterrupt() : VisitInterrupt(Any(nullptr)) {}
  // Allocates a new VisitInterruptObj that takes ownership of `value`.
  explicit VisitInterrupt(Any value)
      : ObjectRef(make_object<VisitInterruptObj>(std::move(value))) {}

  // NOTE(review): the macro name indicates this reference is not nullable — confirm in object.h.
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(VisitInterrupt, ObjectRef, VisitInterruptObj);
};

// Forward declaration; the class is defined later in this header.
class StructuralVisitorObj;

// Signature of an ABI-level visit entry point.
// It receives the visitor and the value to visit and returns a raw TVMFFIAny.
// Callers in this header decode that result with
// details::ExpectedUnsafe::MoveFromTVMFFIAny<Optional<VisitInterrupt>>.
// The function type is noexcept, so failures must be reported through the
// returned value rather than by throwing.
using FStructuralVisit = TVMFFIAny (*)(StructuralVisitorObj* visitor, AnyView value) noexcept;

namespace details {

// Visit reflected structural fields of an object-backed value.
// Forward declaration only: StructuralVisitorObj::DefaultVisitExpected calls this
// before the definition, which appears further down in this header.
TVM_FFI_INLINE static Expected<Optional<VisitInterrupt>> VisitReflectedFieldsExpected(
    StructuralVisitorObj* visitor, const Object* obj) noexcept;

}  // namespace details

// Function-pointer table used by StructuralVisitorObj to dispatch visits.
struct StructuralVisitorVTable {
  // Entry point invoked by StructuralVisitorObj::VisitExpected; null until assigned.
  FStructuralVisit visit = nullptr;
};

// Base object for structural visitors.
//
// Visits are dispatched through an explicit function-pointer table
// (StructuralVisitorVTable) chosen at construction time. The table used by the
// default constructor routes every value to DefaultVisitExpected.
class StructuralVisitorObj : public Object {
 public:
  // Construct a visitor wired to the default dispatch table.
  StructuralVisitorObj() : StructuralVisitorObj(VTable()) {}

  // Visit `value` and unwrap the outcome with Expected::value().
  TVM_FFI_INLINE Optional<VisitInterrupt> Visit(AnyView value) {
    Expected<Optional<VisitInterrupt>> outcome = VisitExpected(value);
    return std::move(outcome).value();
  }

  // Visit `value` through the table entry and decode the raw ABI result.
  TVM_FFI_INLINE Expected<Optional<VisitInterrupt>> VisitExpected(AnyView value) noexcept {
    const FStructuralVisit entry = vtable_->visit;
    return details::ExpectedUnsafe::MoveFromTVMFFIAny<Optional<VisitInterrupt>>(
        entry(this, value));
  }

  // Run the default visiting logic and unwrap the outcome with Expected::value().
  TVM_FFI_INLINE Optional<VisitInterrupt> DefaultVisit(AnyView value) {
    Expected<Optional<VisitInterrupt>> outcome = DefaultVisitExpected(value);
    return std::move(outcome).value();
  }

  // Default visiting logic for `value`.
  //
  // The kStructuralVisit type attribute of the value's type is consulted first:
  //   - opaque pointer: reinterpreted as FStructuralVisit and called directly;
  //   - ffi::Function : called with (this, value);
  //   - None          : no override, continue below;
  //   - anything else : reported as a TypeError.
  // Without an override, values whose type index is below kTVMFFIStaticObjectBegin
  // yield an empty result; all others have their reflected fields visited.
  TVM_FFI_INLINE Expected<Optional<VisitInterrupt>> DefaultVisitExpected(AnyView value) noexcept {
    static reflection::TypeAttrColumn visit_attrs(reflection::type_attr::kStructuralVisit);
    int32_t value_type = value.type_index();
    AnyView override_attr = visit_attrs[value_type];

    switch (override_attr.type_index()) {
      case TypeIndex::kTVMFFIOpaquePtr: {
        // Override registered as an opaque ABI visit function pointer.
        FStructuralVisit custom =
            reinterpret_cast<FStructuralVisit>(override_attr.cast<void*>());
        return details::ExpectedUnsafe::MoveFromTVMFFIAny<Optional<VisitInterrupt>>(
            custom(this, value));
      }
      case TypeIndex::kTVMFFIFunction: {
        // Override registered as an ffi::Function.
        return override_attr.cast<Function>().CallExpected<Optional<VisitInterrupt>>(this, value);
      }
      case TypeIndex::kTVMFFINone: {
        // No override registered for this type.
        break;
      }
      default: {
        return Unexpected(Error("TypeError",
                                std::string(reflection::type_attr::kStructuralVisit) +
                                    " must be an opaque function pointer or ffi.Function",
                                ""));
      }
    }

    const bool is_object = value_type >= TypeIndex::kTVMFFIStaticObjectBegin;
    if (!is_object) {
      return Optional<VisitInterrupt>(std::nullopt);
    }
    return details::VisitReflectedFieldsExpected(this, value.cast<const Object*>());
  }

  // Definition-region kind currently in effect on this visitor.
  TVM_FFI_INLINE TVMFFIDefRegionKind def_region_kind() const { return def_region_mode_; }

  // Invoke `callback` with the definition-region kind temporarily set to `kind`.
  // The previous kind is restored when the call leaves this function, including
  // when `callback` exits by throwing. Returns whatever `callback` returns.
  template <typename Callback>
  TVM_FFI_INLINE auto WithDefRegionKind(TVMFFIDefRegionKind kind, Callback&& callback) {
    class DefRegionGuard {
     public:
      DefRegionGuard(StructuralVisitorObj* owner, TVMFFIDefRegionKind next)
          : owner_(owner), saved_(owner->def_region_mode_) {
        owner_->def_region_mode_ = next;
      }
      DefRegionGuard(const DefRegionGuard&) = delete;
      DefRegionGuard& operator=(const DefRegionGuard&) = delete;
      ~DefRegionGuard() { owner_->def_region_mode_ = saved_; }

     private:
      StructuralVisitorObj* owner_;
      TVMFFIDefRegionKind saved_;
    };
    DefRegionGuard guard(this, kind);
    return std::forward<Callback>(callback)();
  }

  static constexpr const bool _type_mutable = true;
  TVM_FFI_DECLARE_OBJECT_INFO("ffi.StructuralVisitor", StructuralVisitorObj, Object);

 protected:
  // Lets subclasses install their own dispatch table.
  explicit StructuralVisitorObj(const StructuralVisitorVTable* vtable) : vtable_(vtable) {}

  // Dispatch table consulted by VisitExpected.
  const StructuralVisitorVTable* vtable_ = nullptr;

  // Definition-region kind exposed through def_region_kind().
  TVMFFIDefRegionKind def_region_mode_ = kTVMFFIDefRegionKindNone;

 private:
  // Process-wide default table whose entry is DispatchVisit.
  static const StructuralVisitorVTable* VTable() {
    static const StructuralVisitorVTable default_table{&StructuralVisitorObj::DispatchVisit};
    return &default_table;
  }

  // Default table entry: run DefaultVisitExpected and, when it failed on an
  // object-backed value, pass that value to UpdateVisitErrorContext.
  static TVMFFIAny DispatchVisit(StructuralVisitorObj* visitor, AnyView value) noexcept {
    auto outcome = visitor->DefaultVisitExpected(value);
    const bool failed = outcome.type_index() == TypeIndex::kTVMFFIError;
    const bool on_object = value.type_index() >= TypeIndex::kTVMFFIStaticObjectBegin;
    if (TVM_FFI_PREDICT_FALSE(failed && on_object)) {
      Error err = outcome.error();
      details::UpdateVisitErrorContext(err, value.cast<ObjectRef>());
    }
    return details::ExpectedUnsafe::MoveToTVMFFIAny(std::move(outcome));
  }
};

// Reference type for StructuralVisitorObj.
class StructuralVisitor : public ObjectRef {
 public:
  // Allocates a plain StructuralVisitorObj, which uses the default dispatch table.
  StructuralVisitor() : ObjectRef(make_object<StructuralVisitorObj>()) {}
  // Wraps an existing visitor node, e.g. a subclass created with make_object.
  explicit StructuralVisitor(ObjectPtr<StructuralVisitorObj> n) : ObjectRef(std::move(n)) {}

  // NOTE(review): the macro name indicates this reference is not nullable — confirm in object.h.
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StructuralVisitor, ObjectRef, StructuralVisitorObj);
};

namespace details {

// Tell whether `result` must stop the traversal.
// Returns true when the stored type index is either an Error or a
// VisitInterrupt, and false for every other type index.
template <typename T>
TVM_FFI_INLINE bool StructuralVisitNeedEarlyReturn(const Expected<T>& result) noexcept {
  switch (result.type_index()) {
    case TypeIndex::kTVMFFIError:
    case TypeIndex::kTVMFFIVisitInterrupt:
      return true;
    default:
      return false;
  }
}

// Visit every reflected field of `obj` with `visitor`.
//
// Fields flagged kTVMFFIFieldFlagBitMaskSEqHashIgnore are skipped. Each remaining
// field is read through its reflection getter and passed to
// visitor->VisitExpected; fields carrying a def-region flag are visited with the
// matching TVMFFIDefRegionKind in effect. Iteration stops at the first getter
// failure, Error, or VisitInterrupt, and that outcome is returned. If nothing
// stops the iteration, the outcome of the last visited field is returned (an
// empty Optional when no field was visited).
TVM_FFI_INLINE static Expected<Optional<VisitInterrupt>> VisitReflectedFieldsExpected(
    StructuralVisitorObj* visitor, const Object* obj) noexcept {
  const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index());
  Expected<Optional<VisitInterrupt>> outcome = Optional<VisitInterrupt>(std::nullopt);

  // Returns true to stop the field iteration, false to continue.
  auto visit_field = [&](const TVMFFIFieldInfo* field) -> bool {
    const auto flags = field->flags;
    if (flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) {
      return false;
    }

    // Load the field through its getter; a non-zero return code signals failure.
    Any loaded;
    const void* field_addr = reinterpret_cast<const char*>(obj) + field->offset;
    const int getter_rc =
        field->getter(const_cast<void*>(field_addr), reinterpret_cast<TVMFFIAny*>(&loaded));
    if (TVM_FFI_PREDICT_FALSE(getter_rc != 0)) {
      outcome = Unexpected(details::MoveFromSafeCallRaised());
      return true;
    }

    // The non-recursive flag takes precedence when both def-region flags are set.
    const TVMFFIDefRegionKind kind =
        (flags & kTVMFFIFieldFlagBitMaskSEqHashDefNonRecursive)
            ? kTVMFFIDefRegionKindNonRecursive
            : ((flags & kTVMFFIFieldFlagBitMaskSEqHashDefRecursive)
                   ? kTVMFFIDefRegionKindRecursive
                   : kTVMFFIDefRegionKindNone);

    if (kind == kTVMFFIDefRegionKindNone) {
      outcome = visitor->VisitExpected(loaded);
    } else {
      outcome =
          visitor->WithDefRegionKind(kind, [&]() { return visitor->VisitExpected(loaded); });
    }
    return StructuralVisitNeedEarlyReturn(outcome);
  };

  reflection::ForEachFieldInfoWithEarlyStop(type_info, visit_field);
  return outcome;
}

}  // namespace details

// ---------------------------------------------------------------------------
// Structural Walk API.
// ---------------------------------------------------------------------------

// Outcome returned by a structural-walk callback.
//
// Stored as Variant<VisitInterrupt, int32_t>: the integer alternative holds one
// of the tags below, the VisitInterrupt alternative carries an interrupt signal.
// Instances are created only through the named factories.
class WalkResult : public Variant<VisitInterrupt, int32_t> {
 public:
  using Storage = Variant<VisitInterrupt, int32_t>;

  // Integer tag produced by Advance().
  static constexpr int32_t kAdvanceTag = 0;
  // Integer tag produced by Skip().
  static constexpr int32_t kSkipTag = 1;

  // Result holding kAdvanceTag.
  static WalkResult Advance() {
    return WalkResult(kAdvanceTag);
  }

  // Result holding kSkipTag.
  static WalkResult Skip() {
    return WalkResult(kSkipTag);
  }

  // Result holding `signal`; defaults to a default-constructed VisitInterrupt.
  static WalkResult Interrupt(VisitInterrupt signal = VisitInterrupt()) {
    Storage payload(std::move(signal));
    return WalkResult(std::move(payload));
  }

 private:
  // Raw construction is private so callers go through the factories above.
  explicit WalkResult(int32_t tag) : Storage(tag) {}
  explicit WalkResult(Storage storage) : Storage(std::move(storage)) {}

  // The traits specialization rebuilds WalkResult from decoded storage.
  friend struct TypeTraits<WalkResult>;
};

// Turn off the default type traits for WalkResult.
// NOTE(review): presumably required so the dedicated TypeTraits<WalkResult>
// specialization in this header is selected instead — confirm against the
// primary template's definition.
template <>
inline constexpr bool use_default_type_traits_v<WalkResult> = false;

// Type traits that let WalkResult round-trip through Any / Expected.
// Everything is delegated to the traits of the underlying Variant storage, with
// one addition: a None value is accepted and decodes as WalkResult::Advance().
template <>
struct TypeTraits<WalkResult> : public TypeTraits<WalkResult::Storage> {
  using Base = TypeTraits<WalkResult::Storage>;

  // Strict check: None, or anything the Variant storage strictly accepts.
  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
    if (src->type_index == TypeIndex::kTVMFFINone) {
      return true;
    }
    return Base::CheckAnyStrict(src);
  }

  // Decode from borrowed Any storage; call only after CheckAnyStrict succeeded.
  TVM_FFI_INLINE static WalkResult CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
    const bool is_none = src->type_index == TypeIndex::kTVMFFINone;
    return is_none ? WalkResult::Advance() : WalkResult(Base::CopyFromAnyViewAfterCheck(src));
  }

  // Decode by moving out of owned Any storage; call only after CheckAnyStrict succeeded.
  TVM_FFI_INLINE static WalkResult MoveFromAnyAfterCheck(TVMFFIAny* src) {
    const bool is_none = src->type_index == TypeIndex::kTVMFFINone;
    return is_none ? WalkResult::Advance() : WalkResult(Base::MoveFromAnyAfterCheck(src));
  }

  // Lenient decode: None maps to Advance; otherwise try the conversions of the
  // Variant storage and yield std::nullopt when none of them applies.
  TVM_FFI_INLINE static std::optional<WalkResult> TryCastFromAnyView(const TVMFFIAny* src) {
    if (src->type_index == TypeIndex::kTVMFFINone) {
      return WalkResult::Advance();
    }
    auto decoded = Base::TryCastFromAnyView(src);
    if (!decoded) {
      return std::nullopt;
    }
    return WalkResult(*std::move(decoded));
  }

  // Name used for this type in type strings.
  TVM_FFI_INLINE static std::string TypeStr() { return "WalkResult"; }
};

// Selects when the walk callback runs relative to the default traversal of a value
// (see details::StructuralWalkCallbackVisitorObj::VisitImpl).
enum class WalkOrder : int32_t {
  // Callback runs before the value is traversed; a Skip result suppresses the traversal.
  kPreOrder = 0,
  // Callback runs after the value has been traversed.
  kPostOrder = 1,
};

namespace details {

// Return from the current ABI visit function if Result stops traversal.
// Result must evaluate to Expected whose raw storage can be moved to TVMFFIAny.
//
// "Stops traversal" means StructuralVisitNeedEarlyReturn(Result) is true, i.e. the
// Expected holds an Error or a VisitInterrupt. Result is evaluated exactly once
// (bound to a forwarding reference) and is moved from only on the early-return
// path. The enclosing function must return TVMFFIAny.
#define TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN(Result)                                          \
  do {                                                                                      \
    auto&& tvm_ffi_res_ = (Result);                                                         \
    if (TVM_FFI_PREDICT_FALSE(                                                              \
            ::tvm::ffi::details::StructuralVisitNeedEarlyReturn(tvm_ffi_res_))) {           \
      return ::tvm::ffi::details::ExpectedUnsafe::MoveToTVMFFIAny(std::move(tvm_ffi_res_)); \
    }                                                                                       \
  } while (0)

// Return from the current ABI visit function if Result stops traversal.
// If Result is an Error, append Node to the visit error context before returning.
//
// "Stops traversal" means the Expected holds an Error or a VisitInterrupt (see
// StructuralVisitNeedEarlyReturn). The error context is updated only when Node is
// object-backed, i.e. its type index is at least kTVMFFIStaticObjectBegin; Node is
// then cast to ObjectRef. Result is evaluated exactly once; Node may be evaluated
// more than once, so it should be free of side effects. The enclosing function
// must return TVMFFIAny.
#define TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN_WITH_ERROR_CONTEXT(Result, Node)                   \
  do {                                                                                        \
    auto&& tvm_ffi_res_ = (Result);                                                           \
    if (TVM_FFI_PREDICT_FALSE(                                                                \
            ::tvm::ffi::details::StructuralVisitNeedEarlyReturn(tvm_ffi_res_))) {             \
      if (TVM_FFI_PREDICT_FALSE(tvm_ffi_res_.type_index() ==                                  \
                                ::tvm::ffi::TypeIndex::kTVMFFIError)) {                       \
        if ((Node).type_index() >= ::tvm::ffi::TypeIndex::kTVMFFIStaticObjectBegin) {         \
          ::tvm::ffi::Error tvm_ffi_visit_err_ = tvm_ffi_res_.error();                        \
          ::tvm::ffi::details::UpdateVisitErrorContext(tvm_ffi_visit_err_,                    \
                                                       (Node).cast<::tvm::ffi::ObjectRef>()); \
        }                                                                                     \
      }                                                                                       \
      return ::tvm::ffi::details::ExpectedUnsafe::MoveToTVMFFIAny(std::move(tvm_ffi_res_));   \
    }                                                                                         \
  } while (0)

// Visitor that drives a structural walk by calling `Dispatch` on each visited value.
//
// `order` selects whether the dispatch runs before (kPreOrder) or after
// (kPostOrder) the default traversal of the value. `Dispatch` is invoked as
// dispatch(value, def_region_kind) and its result is inspected as an Expected
// holding an Error, a VisitInterrupt, or an integer WalkResult tag.
template <WalkOrder order, typename Dispatch>
class StructuralWalkCallbackVisitorObj : public StructuralVisitorObj {
 public:
  // Install this class's dispatch table and take ownership of `dispatch`.
  explicit StructuralWalkCallbackVisitorObj(Dispatch dispatch)
      : StructuralVisitorObj(VTable()), dispatch_(std::move(dispatch)) {}

 private:
  // Dispatch table shared by all instances of this instantiation.
  static const StructuralVisitorVTable* VTable() {
    static const StructuralVisitorVTable walk_table{
        &StructuralWalkCallbackVisitorObj::DispatchVisit};
    return &walk_table;
  }

  // Table entry: recover the concrete visitor type and forward to VisitImpl.
  static TVMFFIAny DispatchVisit(StructuralVisitorObj* self, AnyView value) noexcept {
    auto* walker = static_cast<StructuralWalkCallbackVisitorObj*>(self);
    return walker->VisitImpl(value);
  }

  // Raw ABI encoding of "no interrupt, keep walking".
  static TVMFFIAny ContinueWalk() noexcept {
    return details::ExpectedUnsafe::MoveToTVMFFIAny(
        Expected<Optional<VisitInterrupt>>(std::nullopt));
  }

  // Visit one value: None is ignored; otherwise run the callback and the default
  // traversal in the order selected by `order`, returning early on Error or
  // VisitInterrupt.
  TVMFFIAny VisitImpl(AnyView value) noexcept {
    const bool is_none = value.type_index() == TypeIndex::kTVMFFINone;
    if (TVM_FFI_PREDICT_FALSE(is_none)) {
      return ContinueWalk();
    }

    if constexpr (order == WalkOrder::kPreOrder) {
      auto pre_result = dispatch_(value, this->def_region_kind());
      TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN_WITH_ERROR_CONTEXT(pre_result, value);
      // Past the early return the result is neither Error nor VisitInterrupt,
      // so only the integer tag alternative remains.
      int32_t result_type = pre_result.type_index();
      TVM_FFI_UNSAFE_ASSUME(result_type == TypeIndex::kTVMFFIInt);
      const bool skip_children =
          details::ExpectedUnsafe::ValueAs<int32_t>(pre_result) == WalkResult::kSkipTag;
      if (TVM_FFI_PREDICT_FALSE(skip_children)) {
        return ContinueWalk();
      }
    }

    TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN_WITH_ERROR_CONTEXT(this->DefaultVisitExpected(value),
                                                          value);

    if constexpr (order == WalkOrder::kPostOrder) {
      // Only Error / VisitInterrupt matter here; Advance and Skip tags are ignored.
      TVM_FFI_S_VISIT_MAYBE_EARLY_RETURN_WITH_ERROR_CONTEXT(
          dispatch_(value, this->def_region_kind()), value);
    }

    return ContinueWalk();
  }

  // User dispatch callable invoked for each non-None value.
  Dispatch dispatch_;
};

struct StructuralWalkCallbackChain {
  template <typename... Callbacks>
  static auto FromChain(Callbacks... callbacks) {
    return [=](AnyView x, TVMFFIDefRegionKind kind) mutable -> Expected<WalkResult> {
      try {
        Optional<Expected<WalkResult>> result;
        // Fold expression: each TryCallLink returns empty Optional on no-match
        // (falsy) or a result on match (truthy); || short-circuits on first match.
        (... || (result = TryCallLink(callbacks, x, kind)));
        if (result.has_value()) {
          return std::move(result).value();
        }
        return WalkResult::Advance();
      } catch (const Error& err) {
        return Unexpected(err);
      }
    };
  }

 private:
  template <typename Callback>
  static Optional<Expected<WalkResult>> TryCallLink(Callback& callback, AnyView x,
                                                    TVMFFIDefRegionKind kind) {
    using FuncInfo = FunctionInfo<std::decay_t<Callback>>;
    static_assert(FuncInfo::num_args == 1 || FuncInfo::num_args == 2,
                  "StructuralWalk callbacks must take one argument (value) or two arguments "
                  "(value, def-region kind)");
    using FirstArg = std::tuple_element_t<0, typename FuncInfo::ArgType>;
    using TSub = std::remove_cv_t<std::remove_reference_t<FirstArg>>;
    if constexpr (std::is_same_v<TSub, AnyView>) {
      // callback on AnyView
      return InvokeCallback(callback, x, kind);
    } else if constexpr (std::is_same_v<TSub, Any>) {
      // callback on Any
      return InvokeCallback(callback, Any(x), kind);
    } else {
      if (auto opt = x.template as<TSub>()) {
        return InvokeCallback(callback, *std::move(opt), kind);
      }
    }
    return std::nullopt;
  }

  template <typename Callback, typename Value>
  static Expected<WalkResult> InvokeCallback(Callback& callback, Value&& value,
                                             TVMFFIDefRegionKind kind) {
    using FuncInfo = FunctionInfo<std::decay_t<Callback>>;
    if constexpr (FuncInfo::num_args == 1) {
      return callback(std::forward<Value>(value));
    } else {
      return callback(std::forward<Value>(value), kind);
    }
  }
};

}  // namespace details

// Walk `root` structurally, calling the first matching callback on each visited value.
//
// `order` selects pre- or post-order invocation. The callbacks are combined with
// details::StructuralWalkCallbackChain::FromChain and driven by a freshly
// allocated details::StructuralWalkCallbackVisitorObj. Returns the Expected
// produced by visiting `root`: an error, an interrupt, or an empty Optional.
template <WalkOrder order, typename... Callbacks>
Expected<Optional<VisitInterrupt>> StructuralWalkExpected(AnyView root,
                                                          Callbacks&&... callbacks) noexcept {
  static_assert(sizeof...(Callbacks) != 0, "StructuralWalk requires at least one callback");
  using Chain = details::StructuralWalkCallbackChain;
  auto dispatch = Chain::FromChain(std::forward<Callbacks>(callbacks)...);

  using WalkVisitor = details::StructuralWalkCallbackVisitorObj<order, decltype(dispatch)>;
  ObjectPtr<WalkVisitor> node = make_object<WalkVisitor>(std::move(dispatch));
  StructuralVisitor visitor(std::move(node));
  return visitor->VisitExpected(root);
}

// Variant of StructuralWalkExpected that unwraps the outcome with Expected::value().
// Returns the interrupt raised during the walk, or an empty Optional when the
// walk ran to completion.
template <WalkOrder order, typename... Callbacks>
Optional<VisitInterrupt> StructuralWalk(AnyView root, Callbacks&&... callbacks) {
  Expected<Optional<VisitInterrupt>> outcome =
      StructuralWalkExpected<order>(root, std::forward<Callbacks>(callbacks)...);
  return std::move(outcome).value();
}

}  // namespace ffi
}  // namespace tvm
#endif  // TVM_FFI_EXTRA_STRUCTURAL_VISIT_H_