Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tensorwrapper/buffer/buffer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ class BufferBase : public BufferBaseCommon<BufferBase>,

const_layout_reference layout_() const { return *m_layout_; }

layout_reference layout_() { return *m_layout_; }

dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;
Expand Down
21 changes: 16 additions & 5 deletions include/tensorwrapper/buffer/buffer_base_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ class BufferBaseCommon {
/// Type of *this
using my_type = BufferBaseCommon<Derived>;

/// Traits for my_type
using traits_type = types::ClassTraits<my_type>;

/// Traits for Derived
using derived_traits = types::ClassTraits<Derived>;
using traits_type = types::ClassTraits<Derived>;

public:
///@{
using layout_type = typename traits_type::layout_type;
using layout_reference = typename traits_type::layout_reference;
using layout_pointer = typename traits_type::layout_pointer;
using const_layout_reference = typename traits_type::const_layout_reference;
using rank_type = typename traits_type::rank_type;
///@}
Expand All @@ -60,6 +59,18 @@ class BufferBaseCommon {
*/
bool has_layout() const noexcept { return derived_().has_layout_(); }

/** @brief Retrieves the layout of *this.
*
* @return A reference to the layout.
*
* @throw std::runtime_error if *this does not have a layout. Strong throw
* guarantee.
*/
layout_reference layout() {
assert_layout_();
return derived_().layout_();
}

/** @brief Retrieves the layout of *this.
*
* @return A read-only reference to the layout.
Expand Down Expand Up @@ -143,7 +154,7 @@ class BufferBaseCommon {

/// Access derived for CRTP
const Derived& derived_() const noexcept {
return *static_cast<const Derived*>(this);
return static_cast<const Derived&>(*this);
}
};

Expand Down
13 changes: 11 additions & 2 deletions include/tensorwrapper/buffer/buffer_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,19 @@ class BufferBase;
template<typename BufferBaseType>
class BufferViewBase;

class Contiguous;

class Local;

template<typename LocalType>
class LocalView;

template<typename Derived>
class ReplicatedCommon;

class Replicated;

template<typename ReplicatedType>
class ReplicatedView;

class Contiguous;

} // namespace tensorwrapper::buffer
132 changes: 86 additions & 46 deletions include/tensorwrapper/buffer/buffer_view_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,28 @@
*/

#pragma once
#include <memory>
#include <tensorwrapper/buffer/buffer_base.hpp>
#include <tensorwrapper/buffer/buffer_base_common.hpp>
#include <tensorwrapper/buffer/detail_/buffer_view_base_pimpl.hpp>
#include <type_traits>

namespace tensorwrapper::buffer {

/** @brief View of a BufferBase that aliases existing state instead of owning
* it.
* it.
*
* BufferViewBase has the same layout/equality API as BufferBase (has_layout(),
* layout(), rank(), operator==, operator!=, approximately_equal) but holds a
* non-owning pointer to a BufferBase and delegates all operations to it.
* layout(), rank(), operator==, operator!=, approximately_equal) but uses a
* PIMPL. The view delegates layout operations to the PIMPL.
*
* BufferViewBase is templated on the type of the aliased buffer, which must
* be either BufferBase or const BufferBase. This controls whether the view is
* a mutable or const view of the underlying BufferBase.
* BufferViewBase is templated on the type of the aliased buffer (BufferBase or
* const BufferBase) for API compatibility; construction from a buffer copies
* a non-owning pointer to that buffer's layout into the PIMPL.
*
* The aliased buffer must outlive this view. Default-constructed or
* moved-from views have no aliased buffer (has_layout() is false, layout()
* throws).
* The referenced layout (and its owner) must outlive this view. Default-
* constructed or moved-from views have no layout (has_layout() is false,
* layout() throws).
*
* @tparam BufferBaseType Either BufferBase or const BufferBase.
*/
Expand All @@ -48,61 +50,87 @@ class BufferViewBase : public BufferBaseCommon<BufferViewBase<BufferBaseType>> {

/// Type *this derives from
using my_base_type = BufferBaseCommon<BufferViewBase<BufferBaseType>>;
using typename my_base_type::const_layout_reference;

using aliased_type = BufferBaseType;
using aliased_pointer = aliased_type*;
/// Type of the PIMPL
using pimpl_type = detail_::BufferViewBasePIMPL<BufferBaseType>;
using pimpl_reference = pimpl_type&;
using const_pimpl_reference = const pimpl_type&;

public:
using typename my_base_type::const_layout_reference;
using typename my_base_type::layout_pointer;
using typename my_base_type::layout_reference;
using typename my_base_type::layout_type;
// -------------------------------------------------------------------------
// -- Ctors and assignment
// -------------------------------------------------------------------------

/** @brief Creates a view that aliases no buffer.
/** @brief Creates a view with no layout.
*
* @throw None No throw guarantee.
*/
BufferViewBase() noexcept : m_aliased_(nullptr) {}
BufferViewBase() noexcept : m_pimpl_(nullptr) {}

/** @brief Creates a view that aliases @p buffer.
/** @brief Creates a view that aliases the layout of @p buffer.
*
* @param[in] buffer The buffer to alias. Must outlive *this.
* @param[in] buffer The buffer whose layout to alias. The layout must
* outlive *this.
*
* @throw None No throw guarantee.
*/
explicit BufferViewBase(aliased_type& buffer) noexcept :
m_aliased_(&buffer) {}
explicit BufferViewBase(BufferBaseType& buffer) noexcept :
m_pimpl_(buffer.has_layout() ?
std::make_unique<pimpl_type>(&buffer.layout()) :
nullptr) {}

/** Creates a read-only view from a mutable buffer. */
template<typename OtherBufferBaseType>
requires(!std::is_const_v<OtherBufferBaseType> &&
std::is_const_v<BufferBaseType>)
explicit BufferViewBase(OtherBufferBaseType& other) noexcept :
m_pimpl_(other.has_layout() ?
std::make_unique<pimpl_type>(&other.layout()) :
nullptr) {}

explicit BufferViewBase(layout_pointer layout) noexcept :
m_pimpl_(std::make_unique<pimpl_type>(layout)) {}

/** @brief Creates a view that aliases the same buffer as @p other.
/** @brief Creates a view that aliases the same layout as @p other.
*
* @param[in] other The view to copy.
*
* @throw None No throw guarantee.
*/
BufferViewBase(const BufferViewBase& other) noexcept = default;
BufferViewBase(const BufferViewBase& other) noexcept :
m_pimpl_(other.m_pimpl_ ? other.m_pimpl_->clone() : nullptr) {}

/** @brief Creates a view by taking the alias from @p other.
/** @brief Creates a view by taking the PIMPL from @p other.
*
* After construction *this aliases the buffer @p other did, and @p other
* aliases no buffer.
* After construction *this aliases the layout @p other did, and @p other
* has no layout.
*
* @param[in,out] other The view to move from.
*
* @throw None No throw guarantee.
*/
BufferViewBase(BufferViewBase&& other) noexcept = default;

/** @brief Makes *this alias the same buffer as @p rhs.
/** @brief Makes *this alias the same layout as @p rhs.
*
* @param[in] rhs The view to copy.
*
* @return *this.
*
* @throw None No throw guarantee.
*/
BufferViewBase& operator=(const BufferViewBase& rhs) noexcept = default;
BufferViewBase& operator=(const BufferViewBase& rhs) noexcept {
if(this != &rhs) {
m_pimpl_ = rhs.m_pimpl_ ? rhs.m_pimpl_->clone() : nullptr;
}
return *this;
}

/** @brief Replaces the alias in *this with that of @p rhs.
/** @brief Replaces the PIMPL in *this with that of @p rhs.
*
* @param[in,out] rhs The view to move from.
*
Expand Down Expand Up @@ -133,41 +161,53 @@ class BufferViewBase : public BufferBaseCommon<BufferViewBase<BufferBaseType>> {
// -------------------------------------------------------------------------

bool has_layout_() const noexcept {
return m_aliased_ != nullptr && m_aliased_->has_layout();
return m_pimpl_ != nullptr && m_pimpl_->has_layout();
}

layout_reference layout_() { return pimpl_().layout(); }

const_layout_reference layout_() const { return pimpl_().layout(); }

// Will be polymorphic eventually
template<typename OtherBufferBaseType>
bool approximately_equal_(const BufferViewBase<OtherBufferBaseType>& rhs,
double) const {
return *this == rhs;
}

// Will be polymorphic eventually
bool approximately_equal_(const BufferBase& rhs, double) const {
return *this == rhs;
}

const_layout_reference layout_() const {
if(m_aliased_ == nullptr) {
private:
void assert_pimpl_() const {
if(!m_pimpl_) {
throw std::runtime_error(
"Buffer has no layout. Was it default initialized?");
"BufferViewBase has no PIMPL. Was it default initialized?");
}
return m_aliased_->layout();
}

template<typename OtherBufferBase>
bool approximately_equal_(const BufferViewBase<OtherBufferBase>& rhs,
double tol) const {
if(m_aliased_ == nullptr) return !rhs.has_layout();
return m_aliased_->approximately_equal(*rhs.m_aliased_, tol);
pimpl_reference pimpl_() {
assert_pimpl_();
return *m_pimpl_;
}

bool approximately_equal_(const BufferBase& rhs, double tol) const {
if(m_aliased_ == nullptr) return !rhs.has_layout();
return m_aliased_->approximately_equal(rhs, tol);
const_pimpl_reference pimpl_() const {
assert_pimpl_();
return *m_pimpl_;
}

private:
/// The buffer *this aliases (non-owning)
aliased_pointer m_aliased_;
/// PIMPL holding non-owning pointer to the aliased layout
std::unique_ptr<pimpl_type> m_pimpl_;
};

// Out-of-line definition so both BufferBase and BufferViewBase are complete
// Out-of-line definition so both BufferBase and BufferViewBase are complete.

template<typename BufferBaseType>
bool BufferBase::approximately_equal_(const BufferViewBase<BufferBaseType>& rhs,
double tol) const {
if(!rhs.has_layout()) return !has_layout();
return approximately_equal_(
*static_cast<const BufferBaseType*>(rhs.m_aliased_), tol);
return !this->layout().are_different(rhs.layout());
}

} // namespace tensorwrapper::buffer
Loading