Program Listing for File shape.h

Program Listing for File shape.h

Return to documentation for file (tvm/ffi/container/shape.h)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef TVM_FFI_CONTAINER_SHAPE_H_
#define TVM_FFI_CONTAINER_SHAPE_H_

#include <tvm/ffi/container/array.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/type_traits.h>

#include <algorithm>
#include <ostream>
#include <utility>
#include <vector>

namespace tvm {
namespace ffi {

/*!
 * \brief Non-owning, read-only view over a contiguous run of int64_t extents.
 *
 * The view records only a pointer and a length; it never copies or frees the
 * elements, so the referenced storage must outlive the view.
 */
class ShapeView {
 public:
  /*! \brief Construct an empty view (null data pointer, zero length). */
  ShapeView() : cell_{nullptr, 0} {}
  ShapeView(const ShapeView& other) = default;
  ShapeView& operator=(const ShapeView& other) = default;
  ShapeView(ShapeView&& other) = default;
  ShapeView& operator=(ShapeView&& other) = default;
  /*!
   * \brief View \p size entries starting at \p data.
   * \param data Pointer to the first entry.
   * \param size Number of entries.
   */
  ShapeView(const int64_t* data, size_t size) : cell_{data, size} {}
  /*!
   * \brief View the elements of an initializer list.
   * \note Only the list's begin pointer is stored. The array behind a temporary
   *       braced list ends its lifetime at the end of the full-expression, so a
   *       view built from one must not be kept past that expression.
   */
  ShapeView(const std::initializer_list<int64_t>& other) : cell_{other.begin(), other.size()} {}

  /*! \brief Pointer to the first entry (null for a default-constructed view). */
  const int64_t* data() const { return cell_.data; }
  /*! \brief Number of entries in the view. */
  size_t size() const { return cell_.size; }

  /*!
   * \brief Multiply all entries together.
   * \return The product of the entries; 1 for an empty view.
   * \note The multiplication is not overflow-checked.
   */
  int64_t Product() const {
    int64_t result = 1;
    for (const int64_t* it = begin(); it != end(); ++it) {
      result *= *it;
    }
    return result;
  }

  /*! \brief Unchecked element access; \p idx must be less than size(). */
  int64_t operator[](size_t idx) const { return *(cell_.data + idx); }

  /*! \brief Iterator to the first entry. */
  const int64_t* begin() const { return data(); }

  /*! \brief Iterator one past the last entry. */
  const int64_t* end() const { return data() + size(); }

  /*! \brief Whether the view has no entries. */
  bool empty() const { return cell_.size == 0; }

  /*! \brief First entry; throws IndexError when the view is empty. */
  int64_t front() const { return at(0); }

  /*!
   * \brief Last entry; throws IndexError when the view is empty.
   * \note For an empty view size() - 1 wraps around, which at() rejects.
   */
  int64_t back() const { return at(size() - 1); }

  /*!
   * \brief Bounds-checked element access.
   * \param idx Position of the entry.
   * \return The entry at \p idx.
   * \throws IndexError if \p idx >= size().
   */
  int64_t at(size_t idx) const {
    const size_t count = size();
    if (idx >= count) {
      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << count;
    }
    return cell_.data[idx];
  }

 private:
  /*! \brief The borrowed (pointer, length) pair. */
  TVMFFIShapeCell cell_;
};

class ShapeObj : public Object, public TVMFFIShapeCell {
 public:
  using index_type = int64_t;

  int64_t Product() const {
    int64_t product = 1;
    for (size_t i = 0; i < this->size; ++i) {
      product *= this->data[i];
    }
    return product;
  }

  static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape;
  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIShape, ShapeObj, Object);
};

namespace details {

/*!
 * \brief ShapeObj variant whose entries are owned by a std::vector.
 *
 * Used by Shape's std::vector constructor so the caller's vector buffer can be
 * adopted instead of copied.
 */
class ShapeObjStdImpl : public ShapeObj {
 public:
  /*!
   * \brief Take ownership of \p other and point the shape cell at its buffer.
   * \param other The shape entries; moved into the object (sink parameter).
   */
  explicit ShapeObjStdImpl(std::vector<int64_t> other) : data_{std::move(other)} {
    // data_ is private and never modified after construction, so this pointer
    // stays valid for the object's lifetime.
    this->data = data_.data();
    this->size = data_.size();
  }

 private:
  /*! \brief Backing storage referenced by the inherited data/size fields. */
  std::vector<int64_t> data_;
};

/*!
 * \brief Allocate a ShapeObj with room for \p length int64_t entries stored
 *        directly after the object header.
 *
 * The entries are not written here; callers fill them through \p mutable_data.
 * \param length Number of entries to reserve.
 * \param mutable_data If non-null, receives a writable pointer to the entries.
 * \return The new object, with its data/size fields pointing at the trailing storage.
 */
TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeEmptyShape(size_t length, int64_t** mutable_data) {
  // The entries start right at the end of the header, so the header's
  // alignment and size must both be compatible with int64_t alignment.
  static_assert(alignof(ShapeObj) % alignof(int64_t) == 0);
  static_assert(sizeof(ShapeObj) % alignof(int64_t) == 0);
  ObjectPtr<ShapeObj> shape = make_inplace_array_object<ShapeObj, int64_t>(length);
  char* const header_end = reinterpret_cast<char*>(shape.get()) + sizeof(ShapeObj);
  int64_t* const payload = reinterpret_cast<int64_t*>(header_end);
  shape->data = payload;
  shape->size = length;
  if (mutable_data != nullptr) {
    *mutable_data = payload;
  }
  return shape;
}

/*!
 * \brief Build a ShapeObj whose entries are copied from [begin, end) into
 *        storage allocated in place after the object header.
 *
 * The range is traversed twice (once to count, once to copy), so IterType must
 * be a multi-pass (forward or stronger) iterator.
 * \param begin Iterator to the first entry.
 * \param end Iterator one past the last entry.
 * \return The newly allocated shape object.
 */
template <typename IterType>
TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end) {
  int64_t* dest = nullptr;
  const size_t count = static_cast<size_t>(std::distance(begin, end));
  ObjectPtr<ShapeObj> result = MakeEmptyShape(count, &dest);
  std::copy(begin, end, dest);
  return result;
}

/*!
 * \brief Write compact row-major strides for \p shape into \p out_strides.
 *
 * The last axis gets stride 1 and every earlier axis gets the product of all
 * later extents. Nothing is written for a rank-0 shape.
 * \param shape The extents.
 * \param out_strides Destination; must have room for shape.size() entries.
 */
TVM_FFI_INLINE void FillStridesFromShape(ShapeView shape, int64_t* out_strides) {
  int64_t running = 1;
  for (size_t remaining = shape.size(); remaining > 0; --remaining) {
    const size_t axis = remaining - 1;
    out_strides[axis] = running;
    running *= shape[axis];
  }
}

/*!
 * \brief Allocate a new shape object holding the compact row-major strides of \p shape.
 * \param shape The extents to derive strides from.
 * \return A shape object with one stride per axis of \p shape.
 */
TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(ShapeView shape) {
  int64_t* out = nullptr;
  ObjectPtr<ShapeObj> result = MakeEmptyShape(shape.size(), &out);
  FillStridesFromShape(shape, out);
  return result;
}

}  // namespace details

/*!
 * \brief Non-nullable managed reference to a ShapeObj: an ordered sequence of
 *        int64_t extents with read-only, container-like accessors.
 */
class Shape : public ObjectRef {
 public:
  /*! \brief Element type of the shape entries. */
  using index_type = ShapeObj::index_type;

  /*! \brief Construct a rank-0 (empty) shape. */
  Shape() : ObjectRef(details::MakeEmptyShape(0, nullptr)) {}

  /*!
   * \brief Construct by copying the entries in [begin, end) into in-place storage.
   * \param begin Iterator to the first entry.
   * \param end Iterator one past the last entry.
   */
  template <typename IterType>
  Shape(IterType begin, IterType end) : Shape(details::MakeInplaceShape(begin, end)) {}

  /*! \brief Construct from an Array of extents (entries are copied). */
  Shape(Array<int64_t> shape)  // NOLINT(*)
      : Shape(shape.begin(), shape.end()) {}

  /*! \brief Construct from a brace-enclosed list of extents (entries are copied). */
  Shape(std::initializer_list<int64_t> shape) : Shape(shape.begin(), shape.end()) {}

  /*! \brief Construct from a std::vector, handing the vector to the backing object. */
  Shape(std::vector<int64_t> other)  // NOLINT(*)
      : ObjectRef(make_object<details::ShapeObjStdImpl>(std::move(other))) {}

  /*! \brief Construct by copying the entries referenced by a ShapeView. */
  Shape(ShapeView other) : Shape(other.begin(), other.end()) {}  // NOLINT(*)

  /*!
   * \brief Compute compact row-major strides for \p shape.
   * \param shape The extents to derive strides from.
   * \return A new Shape holding one stride per axis.
   */
  static Shape StridesFromShape(ShapeView shape) {
    return Shape(details::MakeStridesFromShape(shape));
  }

  /*!
   * \brief Borrow the entries as a ShapeView.
   * \note The view holds a raw pointer and does not keep the storage alive.
   */
  operator ShapeView() const { return ShapeView(data(), size()); }  // NOLINT(*)

  /*! \brief Pointer to the first entry. */
  const int64_t* data() const { return get()->data; }

  /*! \brief Number of entries (the rank). */
  size_t size() const { return get()->size; }

  /*! \brief Unchecked element access; \p idx must be less than size(). */
  int64_t operator[](size_t idx) const { return *(begin() + idx); }

  /*!
   * \brief Bounds-checked element access.
   * \param idx Position of the entry.
   * \return The entry at \p idx.
   * \throws IndexError if \p idx >= size().
   */
  int64_t at(size_t idx) const {
    const size_t count = size();
    if (idx >= count) {
      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << count;
    }
    return data()[idx];
  }

  /*! \brief Whether the shape has no entries. */
  bool empty() const { return this->size() == 0U; }

  /*! \brief First entry; throws IndexError when empty. */
  int64_t front() const { return at(0); }

  /*! \brief Last entry; throws IndexError when empty (size() - 1 wraps and at() rejects it). */
  int64_t back() const { return at(size() - 1); }

  /*! \brief Iterator to the first entry. */
  const int64_t* begin() const { return data(); }

  /*! \brief Iterator one past the last entry. */
  const int64_t* end() const { return begin() + size(); }

  /*! \brief Product of all entries; 1 for a rank-0 shape. */
  int64_t Product() const { return get()->Product(); }

  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Shape, ObjectRef, ShapeObj);

 private:
  /*! \brief Adopt an already-built shape object. */
  explicit Shape(ObjectPtr<ShapeObj> ptr) : ObjectRef(ptr) {}
};

inline std::ostream& operator<<(std::ostream& os, const Shape& shape) {
  os << '[';
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i != 0) {
      os << ", ";
    }
    os << shape[i];
  }
  os << ']';
  return os;
}

// Shape: opt out of the default type traits. Presumably this makes the FFI use
// the hand-written TypeTraits<Shape> specialization that follows — confirm
// against the primary template of use_default_type_traits_v.
template <>
inline constexpr bool use_default_type_traits_v<Shape> = false;

// Allow auto conversion from Array<int64_t> to Shape, but not from Shape to Array<int64_t>
template <>
struct TypeTraits<Shape> : public ObjectRefWithFallbackTraitsBase<Shape, Array<int64_t>> {
  // Static type index associated with Shape.
  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIShape;
  /*!
   * \brief Convert the fallback representation (Array<int64_t>) into a Shape.
   * \param src The extents; taken by value and moved into Shape's by-value
   *            parameter so the Array handle is not copied a second time.
   * \return A Shape holding a copy of the entries of \p src.
   */
  TVM_FFI_INLINE static Shape ConvertFallbackValue(Array<int64_t> src) {
    return Shape(std::move(src));
  }
};

}  // namespace ffi
}  // namespace tvm

#endif  // TVM_FFI_CONTAINER_SHAPE_H_