Program Listing for File access_path.h
Return to documentation for file (tvm/ffi/reflection/access_path.h)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_FFI_REFLECTION_ACCESS_PATH_H_
#define TVM_FFI_REFLECTION_ACCESS_PATH_H_
#include <tvm/ffi/any.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/tuple.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/reflection/registry.h>
#include <utility>
#include <vector>
namespace tvm {
namespace ffi {
namespace reflection {
// Kind of a single step in an access path. The meaning of the step's key
// depends on the kind (see the AccessStep factories).
enum class AccessKind : int32_t {
  // Attribute access; the key is the field name.
  kAttr = 0,
  // Array element access; the key is the integer index.
  kArrayItem = 1,
  // Map element access; the key is the map key.
  kMapItem = 2,
  // The following three are used for error reporting when
  // the supposed access field is not available.
  kAttrMissing = 3,
  kArrayItemMissing = 4,
  kMapItemMissing = 5,
};
class AccessStep;
// Object node describing one step of an access path: the kind of access and
// the key used by that access (field name, array index, or map key).
class AccessStepObj : public Object {
 public:
  // What kind of access this step performs. Given a defined default so a
  // default-constructed node never holds an indeterminate value.
  AccessKind kind = AccessKind::kAttr;
  // Key of the access; its interpretation depends on `kind`.
  Any key;
  // default constructor to enable auto-serialization
  AccessStepObj() = default;
  // `key` is taken by value and moved in, avoiding an extra refcount bump.
  AccessStepObj(AccessKind kind, Any key) : kind(kind), key(std::move(key)) {}
  // Returns true when `other` has the same kind and an equal key (via AnyEqual).
  inline bool StepEqual(const AccessStep& other) const;
  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessStep", AccessStepObj, Object);
};
// Reference to an AccessStepObj, with one named factory per AccessKind.
class AccessStep : public ObjectRef {
 public:
  // Creates a step of the given kind; `key` is moved into the new node.
  AccessStep(AccessKind kind, Any key)
      : ObjectRef(make_object<AccessStepObj>(kind, std::move(key))) {}
  // Attribute access to `field_name`.
  static AccessStep Attr(String field_name) {
    return AccessStep(AccessKind::kAttr, std::move(field_name));
  }
  // Attribute `field_name` that was expected but is not available.
  static AccessStep AttrMissing(String field_name) {
    return AccessStep(AccessKind::kAttrMissing, std::move(field_name));
  }
  // Array element access at `index`.
  static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); }
  // Array index `index` that was expected but is not available.
  static AccessStep ArrayItemMissing(int64_t index) {
    return AccessStep(AccessKind::kArrayItemMissing, index);
  }
  // Map element access with `key`.
  static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, std::move(key)); }
  // Map key that was expected but is not available; the key defaults to nullptr.
  static AccessStep MapItemMissing(Any key = nullptr) {
    return AccessStep(AccessKind::kMapItemMissing, std::move(key));
  }
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessStep, ObjectRef, AccessStepObj);
};
inline bool AccessStepObj::StepEqual(const AccessStep& other) const {
return this->kind == other->kind && AnyEqual()(this->key, other->key);
}
// forward declaration
class AccessPath;
// Object node for one element of an access path. A path is a chain of nodes
// linked from leaf to root: each node records its parent path, the step taken
// from that parent, and its distance from the root.
class AccessPathObj : public Object {
 public:
  // Parent path, or std::nullopt for the root. It holds an AccessPath; it is
  // typed as ObjectRef presumably because AccessPath is still incomplete here.
  Optional<ObjectRef> parent;
  // Step taken from `parent` to reach this node; std::nullopt for the root.
  Optional<AccessStep> step;
  // Number of steps from the root (the root has depth 0). Defaulted so a
  // default-constructed node never holds an indeterminate depth.
  int32_t depth = 0;
  // default constructor to enable auto-serialization
  AccessPathObj() = default;
  // Parameters are taken by value and moved into the members.
  AccessPathObj(Optional<ObjectRef> parent, Optional<AccessStep> step, int32_t depth)
      : parent(std::move(parent)), step(std::move(step)), depth(depth) {}
  // Returns the parent as an AccessPath, or std::nullopt if there is none.
  inline Optional<AccessPath> GetParent() const;
  // Returns a new path with `step` appended; this path is left unchanged.
  inline AccessPath Extend(AccessStep step) const;
  // Convenience wrappers: Extend() with the matching AccessStep factory.
  inline AccessPath Attr(String field_name) const;
  inline AccessPath AttrMissing(String field_name) const;
  inline AccessPath ArrayItem(int64_t index) const;
  inline AccessPath ArrayItemMissing(int64_t index) const;
  inline AccessPath MapItem(Any key) const;
  inline AccessPath MapItemMissing(Any key) const;
  // Returns the steps of this path in root-to-leaf order.
  inline Array<AccessStep> ToSteps() const;
  // Structural equality: equal depth and pairwise-equal steps.
  inline bool PathEqual(const AccessPath& other) const;
  // True when this path equals `other` truncated to this path's depth.
  inline bool IsPrefixOf(const AccessPath& other) const;
  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessPath", AccessPathObj, Object);

 private:
  // Walks both chains toward the root in lock-step, comparing steps. Returns
  // early on pointer identity, since a shared node implies identical ancestors.
  static bool PathEqual(const AccessPathObj* lhs, const AccessPathObj* rhs) {
    // fast path for same pointer
    if (lhs == rhs) return true;
    if (lhs->depth != rhs->depth) return false;
    // do deep equality checks; equal depths mean both chains reach the root together
    while (lhs->parent.has_value()) {
      TVM_FFI_ICHECK(rhs->parent.has_value());
      TVM_FFI_ICHECK(lhs->step.has_value());
      TVM_FFI_ICHECK(rhs->step.has_value());
      if (!(*lhs->step)->StepEqual(*(rhs->step))) {
        return false;
      }
      lhs = static_cast<const AccessPathObj*>(lhs->parent.get());
      rhs = static_cast<const AccessPathObj*>(rhs->parent.get());
      // fast path for same pointer
      if (lhs == rhs) return true;
      TVM_FFI_ICHECK(lhs != nullptr);
      TVM_FFI_ICHECK(rhs != nullptr);
    }
    return true;
  }
};
class AccessPath : public ObjectRef {
public:
template <typename Iter>
static AccessPath FromSteps(Iter begin, Iter end) {
AccessPath path = AccessPath::Root();
for (Iter it = begin; it != end; ++it) {
path = path->Extend(*it);
}
return path;
}
static AccessPath FromSteps(Array<AccessStep> steps) {
AccessPath path = AccessPath::Root();
for (AccessStep step : steps) {
path = path->Extend(step);
}
return path;
}
static AccessPath Root() {
return AccessPath(make_object<AccessPathObj>(std::nullopt, std::nullopt, 0));
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessPath, ObjectRef, AccessPathObj);
private:
friend class AccessPathObj;
explicit AccessPath(ObjectPtr<AccessPathObj> ptr) : ObjectRef(ptr) {}
};
// Convenience alias for a tuple of two AccessPath values.
using AccessPathPair = Tuple<AccessPath, AccessPath>;
// Returns the parent as an AccessPath; std::nullopt when the stored parent is
// absent (the root) or does not hold an AccessPath.
inline Optional<AccessPath> AccessPathObj::GetParent() const {
  auto maybe_parent = this->parent.as<AccessPath>();
  if (!maybe_parent) {
    return std::nullopt;
  }
  return maybe_parent;
}
// Returns a new path whose parent is this path and whose last step is `step`.
// The new path's depth is this path's depth + 1; this path is not modified.
inline AccessPath AccessPathObj::Extend(AccessStep step) const {
  // `step` is a by-value sink parameter; move it instead of copying the handle.
  return AccessPath(
      make_object<AccessPathObj>(GetRef<AccessPath>(this), std::move(step), this->depth + 1));
}
// Returns a new path extended with an attribute access to `field_name`.
inline AccessPath AccessPathObj::Attr(String field_name) const {
  return this->Extend(AccessStep::Attr(std::move(field_name)));
}
// Returns a new path extended with a step recording that attribute
// `field_name` was expected but is not available.
inline AccessPath AccessPathObj::AttrMissing(String field_name) const {
  return this->Extend(AccessStep::AttrMissing(std::move(field_name)));
}
// Returns a new path extended with an array element access at `index`.
inline AccessPath AccessPathObj::ArrayItem(int64_t index) const {
  AccessStep item_step = AccessStep::ArrayItem(index);
  return Extend(item_step);
}
// Returns a new path extended with a step recording that array index `index`
// was expected but is not available.
inline AccessPath AccessPathObj::ArrayItemMissing(int64_t index) const {
  AccessStep missing_step = AccessStep::ArrayItemMissing(index);
  return Extend(missing_step);
}
// Returns a new path extended with a map element access using `key`.
inline AccessPath AccessPathObj::MapItem(Any key) const {
  return this->Extend(AccessStep::MapItem(std::move(key)));
}
// Returns a new path extended with a step recording that map key `key` was
// expected but is not available.
inline AccessPath AccessPathObj::MapItemMissing(Any key) const {
  return this->Extend(AccessStep::MapItemMissing(std::move(key)));
}
// Returns the steps of this path in root-to-leaf order. Steps are collected
// while walking from this node up to the root, then emitted reversed.
inline Array<AccessStep> AccessPathObj::ToSteps() const {
  std::vector<AccessStep> leaf_to_root;
  leaf_to_root.reserve(this->depth);
  for (const AccessPathObj* node = this; node->parent.has_value();) {
    // every non-root node must carry the step that led to it
    TVM_FFI_ICHECK(node->step.has_value());
    leaf_to_root.push_back(*(node->step));
    node = static_cast<const AccessPathObj*>(node->parent.get());
    TVM_FFI_ICHECK(node != nullptr);
  }
  return Array<AccessStep>(leaf_to_root.rbegin(), leaf_to_root.rend());
}
inline bool AccessPathObj::PathEqual(const AccessPath& other) const {
return PathEqual(this, other.get());
}
inline bool AccessPathObj::IsPrefixOf(const AccessPath& other) const {
if (this->depth > other->depth) {
return false;
}
const AccessPathObj* rhs_path = other.get();
while (rhs_path->depth > this->depth) {
TVM_FFI_ICHECK(rhs_path->parent.has_value());
rhs_path = static_cast<const AccessPathObj*>(rhs_path->parent.get());
}
return PathEqual(this, rhs_path);
}
} // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_