From 4bb3a8f7dc725497f7898d3f87e247d97213f579 Mon Sep 17 00:00:00 2001 From: "Ryan M. Richard" Date: Tue, 10 Mar 2026 22:37:12 -0500 Subject: [PATCH] ReplicatedView supports slicing --- include/tensorwrapper/buffer/buffer_base.hpp | 2 + .../buffer/buffer_base_common.hpp | 21 ++- include/tensorwrapper/buffer/buffer_fwd.hpp | 13 +- .../tensorwrapper/buffer/buffer_view_base.hpp | 132 ++++++++----- include/tensorwrapper/buffer/contiguous.hpp | 76 ++++---- .../buffer/detail_/buffer_view_base_pimpl.hpp | 48 +++++ include/tensorwrapper/buffer/local.hpp | 21 ++- include/tensorwrapper/buffer/replicated.hpp | 13 +- .../buffer/replicated_common.hpp | 80 ++++++++ .../tensorwrapper/buffer/replicated_view.hpp | 178 ++++++++++++++++++ .../tensorwrapper/layout/layout_common.hpp | 1 + include/tensorwrapper/types/buffer_traits.hpp | 81 ++++++-- .../tensorwrapper/types/contiguous_traits.hpp | 25 +-- .../tensorwrapper/types/preserve_const.hpp | 29 +++ src/tensorwrapper/buffer/contiguous.cpp | 23 +-- .../buffer/detail_/replicated_view_pimpl.hpp | 85 +++++++++ .../buffer/detail_/slice_pimpl.hpp | 132 +++++++++++++ src/tensorwrapper/buffer/replicated_view.cpp | 115 +++++++++++ .../buffer/detail_/slice_pimpl.cpp | 85 +++++++++ .../tensorwrapper/buffer/replicated_view.cpp | 93 +++++++++ 20 files changed, 1107 insertions(+), 146 deletions(-) create mode 100644 include/tensorwrapper/buffer/detail_/buffer_view_base_pimpl.hpp create mode 100644 include/tensorwrapper/buffer/replicated_common.hpp create mode 100644 include/tensorwrapper/buffer/replicated_view.hpp create mode 100644 include/tensorwrapper/types/preserve_const.hpp create mode 100644 src/tensorwrapper/buffer/detail_/replicated_view_pimpl.hpp create mode 100644 src/tensorwrapper/buffer/detail_/slice_pimpl.hpp create mode 100644 src/tensorwrapper/buffer/replicated_view.cpp create mode 100644 tests/cxx/unit_tests/tensorwrapper/buffer/detail_/slice_pimpl.cpp create mode 100644 tests/cxx/unit_tests/tensorwrapper/buffer/replicated_view.cpp diff --git a/include/tensorwrapper/buffer/buffer_base.hpp b/include/tensorwrapper/buffer/buffer_base.hpp index d24201b9..1225cc23 100644 --- a/include/tensorwrapper/buffer/buffer_base.hpp +++ b/include/tensorwrapper/buffer/buffer_base.hpp @@ -142,6 +142,8 @@ class BufferBase : public BufferBaseCommon, const_layout_reference layout_() const { return *m_layout_; } + layout_reference layout_() { return *m_layout_; } + dsl_reference addition_assignment_(label_type this_labels, const_labeled_reference lhs, const_labeled_reference rhs) override; diff --git a/include/tensorwrapper/buffer/buffer_base_common.hpp b/include/tensorwrapper/buffer/buffer_base_common.hpp index 10d253fc..e0bf60ce 100644 --- a/include/tensorwrapper/buffer/buffer_base_common.hpp +++ b/include/tensorwrapper/buffer/buffer_base_common.hpp @@ -35,15 +35,14 @@ class BufferBaseCommon { /// Type of *this using my_type = BufferBaseCommon; - /// Traits for my_type - using traits_type = types::ClassTraits; - /// Traits for Derived - using derived_traits = types::ClassTraits; + using traits_type = types::ClassTraits; public: ///@{ using layout_type = typename traits_type::layout_type; + using layout_reference = typename traits_type::layout_reference; + using layout_pointer = typename traits_type::layout_pointer; using const_layout_reference = typename traits_type::const_layout_reference; using rank_type = typename traits_type::rank_type; ///@} @@ -60,6 +59,18 @@ class BufferBaseCommon { */ bool has_layout() const noexcept { return derived_().has_layout_(); } + /** @brief Retrieves the layout of *this. + * + * @return A reference to the layout. + * + * @throw std::runtime_error if *this does not have a layout. Strong throw + * guarantee. + */ + layout_reference layout() { + assert_layout_(); + return derived_().layout_(); + } + /** @brief Retrieves the layout of *this. * * @return A read-only reference to the layout. @@ -143,7 +154,7 @@ class BufferBaseCommon { /// Access derived for CRTP const Derived& derived_() const noexcept { - return *static_cast(this); + return static_cast(*this); } }; diff --git a/include/tensorwrapper/buffer/buffer_fwd.hpp b/include/tensorwrapper/buffer/buffer_fwd.hpp index 74b63fb4..25129c04 100644 --- a/include/tensorwrapper/buffer/buffer_fwd.hpp +++ b/include/tensorwrapper/buffer/buffer_fwd.hpp @@ -26,10 +26,19 @@ class BufferBase; template class BufferViewBase; -class Contiguous; - class Local; +template +class LocalView; + +template +class ReplicatedCommon; + class Replicated; +template +class ReplicatedView; + +class Contiguous; + } // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/buffer/buffer_view_base.hpp b/include/tensorwrapper/buffer/buffer_view_base.hpp index 21039688..583530d3 100644 --- a/include/tensorwrapper/buffer/buffer_view_base.hpp +++ b/include/tensorwrapper/buffer/buffer_view_base.hpp @@ -15,26 +15,28 @@ */ #pragma once +#include #include #include +#include #include namespace tensorwrapper::buffer { /** @brief View of a BufferBase that aliases existing state instead of owning - * it. + * it. * * BufferViewBase has the same layout/equality API as BufferBase (has_layout(), - * layout(), rank(), operator==, operator!=, approximately_equal) but holds a - * non-owning pointer to a BufferBase and delegates all operations to it. + * layout(), rank(), operator==, operator!=, approximately_equal) but uses a + * PIMPL. The view delegates layout operations to the PIMPL. * - * BufferViewBase is templated on the type of the aliased buffer, which must - * be either BufferBase or const BufferBase. This controls whether the view is - * a mutable or const view of the underlying BufferBase. + * BufferViewBase is templated on the type of the aliased buffer (BufferBase or + * const BufferBase) for API compatibility; construction from a buffer copies + * a non-owning pointer to that buffer's layout into the PIMPL. * - * The aliased buffer must outlive this view. Default-constructed or - * moved-from views have no aliased buffer (has_layout() is false, layout() - * throws). + * The referenced layout (and its owner) must outlive this view. Default- + * constructed or moved-from views have no layout (has_layout() is false, + * layout() throws). * * @tparam BufferBaseType Either BufferBase or const BufferBase. */ @@ -48,43 +50,64 @@ class BufferViewBase : public BufferBaseCommon> { /// Type *this derives from using my_base_type = BufferBaseCommon>; - using typename my_base_type::const_layout_reference; - using aliased_type = BufferBaseType; - using aliased_pointer = aliased_type*; + /// Type of the PIMPL + using pimpl_type = detail_::BufferViewBasePIMPL; + using pimpl_reference = pimpl_type&; + using const_pimpl_reference = const pimpl_type&; public: + using typename my_base_type::const_layout_reference; + using typename my_base_type::layout_pointer; + using typename my_base_type::layout_reference; + using typename my_base_type::layout_type; // ------------------------------------------------------------------------- // -- Ctors and assignment // ------------------------------------------------------------------------- - /** @brief Creates a view that aliases no buffer. + /** @brief Creates a view with no layout. * * @throw None No throw guarantee. */ - BufferViewBase() noexcept : m_aliased_(nullptr) {} + BufferViewBase() noexcept : m_pimpl_(nullptr) {} - /** @brief Creates a view that aliases @p buffer. + /** @brief Creates a view that aliases the layout of @p buffer. * - * @param[in] buffer The buffer to alias. Must outlive *this. + * @param[in] buffer The buffer whose layout to alias. The layout must + * outlive *this. * * @throw None No throw guarantee. */ - explicit BufferViewBase(aliased_type& buffer) noexcept : - m_aliased_(&buffer) {} + explicit BufferViewBase(BufferBaseType& buffer) noexcept : + m_pimpl_(buffer.has_layout() ? + std::make_unique(&buffer.layout()) : + nullptr) {} + + /** Creates a read-only view from a mutable buffer. */ + template + requires(!std::is_const_v && + std::is_const_v) + explicit BufferViewBase(OtherBufferBaseType& other) noexcept : + m_pimpl_(other.has_layout() ? + std::make_unique(&other.layout()) : + nullptr) {} + + explicit BufferViewBase(layout_pointer layout) noexcept : + m_pimpl_(std::make_unique(layout)) {} - /** @brief Creates a view that aliases the same buffer as @p other. + /** @brief Creates a view that aliases the same layout as @p other. * * @param[in] other The view to copy. * * @throw None No throw guarantee. */ - BufferViewBase(const BufferViewBase& other) noexcept = default; + BufferViewBase(const BufferViewBase& other) noexcept : + m_pimpl_(other.m_pimpl_ ? other.m_pimpl_->clone() : nullptr) {} - /** @brief Creates a view by taking the alias from @p other. + /** @brief Creates a view by taking the PIMPL from @p other. * - * After construction *this aliases the buffer @p other did, and @p other - * aliases no buffer. + * After construction *this aliases the layout @p other did, and @p other + * has no layout. * * @param[in,out] other The view to move from. * @@ -92,7 +115,7 @@ class BufferViewBase : public BufferBaseCommon> { */ BufferViewBase(BufferViewBase&& other) noexcept = default; - /** @brief Makes *this alias the same buffer as @p rhs. + /** @brief Makes *this alias the same layout as @p rhs. * * @param[in] rhs The view to copy. * @@ -100,9 +123,14 @@ class BufferViewBase : public BufferBaseCommon> { * * @throw None No throw guarantee. */ - BufferViewBase& operator=(const BufferViewBase& rhs) noexcept = default; + BufferViewBase& operator=(const BufferViewBase& rhs) noexcept { + if(this != &rhs) { + m_pimpl_ = rhs.m_pimpl_ ? rhs.m_pimpl_->clone() : nullptr; + } + return *this; + } - /** @brief Replaces the alias in *this with that of @p rhs. + /** @brief Replaces the PIMPL in *this with that of @p rhs. * * @param[in,out] rhs The view to move from. * @@ -133,41 +161,53 @@ class BufferViewBase : public BufferBaseCommon> { // ------------------------------------------------------------------------- bool has_layout_() const noexcept { - return m_aliased_ != nullptr && m_aliased_->has_layout(); + return m_pimpl_ != nullptr && m_pimpl_->has_layout(); + } + + layout_reference layout_() { return pimpl_().layout(); } + + const_layout_reference layout_() const { return pimpl_().layout(); } + + // Will be polymorphic eventually + template + bool approximately_equal_(const BufferViewBase& rhs, + double) const { + return *this == rhs; + } + + // Will be polymorphic eventually + bool approximately_equal_(const BufferBase& rhs, double) const { + return *this == rhs; } - const_layout_reference layout_() const { - if(m_aliased_ == nullptr) { +private: + void assert_pimpl_() const { + if(!m_pimpl_) { throw std::runtime_error( - "Buffer has no layout. Was it default initialized?"); + "BufferViewBase has no PIMPL. Was it default initialized?"); } - return m_aliased_->layout(); } - - template - bool approximately_equal_(const BufferViewBase& rhs, - double tol) const { - if(m_aliased_ == nullptr) return !rhs.has_layout(); - return m_aliased_->approximately_equal(*rhs.m_aliased_, tol); + pimpl_reference pimpl_() { + assert_pimpl_(); + return *m_pimpl_; } - bool approximately_equal_(const BufferBase& rhs, double tol) const { - if(m_aliased_ == nullptr) return !rhs.has_layout(); - return m_aliased_->approximately_equal(rhs, tol); + const_pimpl_reference pimpl_() const { + assert_pimpl_(); + return *m_pimpl_; } -private: - /// The buffer *this aliases (non-owning) - aliased_pointer m_aliased_; + /// PIMPL holding non-owning pointer to the aliased layout + std::unique_ptr m_pimpl_; }; -// Out-of-line definition so both BufferBase and BufferViewBase are complete +// Out-of-line definition so both BufferBase and BufferViewBase are complete. + template bool BufferBase::approximately_equal_(const BufferViewBase& rhs, double tol) const { if(!rhs.has_layout()) return !has_layout(); - return approximately_equal_( - *static_cast(rhs.m_aliased_), tol); + return !this->layout().are_different(rhs.layout()); } } // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/buffer/contiguous.hpp b/include/tensorwrapper/buffer/contiguous.hpp index 4960460a..c44f7a3d 100644 --- a/include/tensorwrapper/buffer/contiguous.hpp +++ b/include/tensorwrapper/buffer/contiguous.hpp @@ -42,9 +42,9 @@ class Contiguous : public Replicated { public: /// Add types from traits_type to public API ///@{ - using value_type = typename traits_type::value_type; - using reference = typename traits_type::reference; - using const_reference = typename traits_type::const_reference; + using value_type = typename traits_type::element_type; + using reference = typename traits_type::element_reference; + using const_reference = typename traits_type::const_element_reference; using buffer_type = typename traits_type::buffer_type; using buffer_view = typename traits_type::buffer_view; using const_buffer_view = typename traits_type::const_buffer_view; @@ -52,11 +52,9 @@ class Contiguous : public Replicated { using shape_type = typename traits_type::shape_type; using const_shape_view = typename traits_type::const_shape_view; using size_type = typename traits_type::size_type; + using index_vector = typename traits_type::index_vector; ///@} - /// Type of an offset vector - using index_vector = std::vector; - /// Type of the object used to annotate modes using typename my_base_type::label_type; using string_type = std::string; @@ -198,39 +196,6 @@ class Contiguous : public Replicated { */ size_type size() const noexcept; - /** @brief Returns the element with the offsets specified by @p index. - * - * This method will retrieve a const reference to the element at the - * offsets specified by @p index. The length of @p index must be equal - * to the rank of *this and each entry in @p index must be less than the - * extent of the corresponding mode of *this. - * - * This method can only be used to retrieve elements from *this. To modify - * elements use set_elem(). - * - * @param[in] index The offsets into each mode of *this for the desired - * element. - * - * @return A const reference to the element at the specified offsets. - */ - const_reference get_elem(index_vector index) const; - - /** @brief Sets the specified element to @p new_value. - * - * This method will set the element at the offsets specified by @p index. - * The length of @p index must be equal to the rank of *this and each - * entry in @p index must be less than the extent of the corresponding - * mode of *this. - * - * @param[in] index The offsets into each mode of *this for the desired - * element. - * @param[in] new_value The new value for the specified element. - * - * @throw std::out_of_range if any entry in @p index is invalid. Strong - * throw guarantee. - */ - void set_elem(index_vector index, value_type new_value); - /** @brief Returns a view of the data. * */ @@ -296,6 +261,39 @@ class Contiguous : public Replicated { /// Uses Eigen's printing capabilities to add to stream std::ostream& add_to_stream_(std::ostream& os) const override; + /** @brief Returns the element with the offsets specified by @p index. + * + * This method will retrieve a const reference to the element at the + * offsets specified by @p index. The length of @p index must be equal + * to the rank of *this and each entry in @p index must be less than the + * extent of the corresponding mode of *this. + * + * This method can only be used to retrieve elements from *this. To modify + * elements use set_elem(). + * + * @param[in] index The offsets into each mode of *this for the desired + * element. + * + * @return A const reference to the element at the specified offsets. + */ + const_element_reference get_elem_(index_vector index) const override; + + /** @brief Sets the specified element to @p new_value. + * + * This method will set the element at the offsets specified by @p index. + * The length of @p index must be equal to the rank of *this and each + * entry in @p index must be less than the extent of the corresponding + * mode of *this. + * + * @param[in] index The offsets into each mode of *this for the desired + * element. + * @param[in] new_value The new value for the specified element. + * + * @throw std::out_of_range if any entry in @p index is invalid. Strong + * throw guarantee. + */ + void set_elem_(index_vector index, element_type new_value) override; + private: /// Type for storing the hash of *this using hash_type = std::size_t; diff --git a/include/tensorwrapper/buffer/detail_/buffer_view_base_pimpl.hpp b/include/tensorwrapper/buffer/detail_/buffer_view_base_pimpl.hpp new file mode 100644 index 00000000..7d76c531 --- /dev/null +++ b/include/tensorwrapper/buffer/detail_/buffer_view_base_pimpl.hpp @@ -0,0 +1,48 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +namespace tensorwrapper::buffer::detail_ { + +/// PIMPL holding a non-owning pointer to a LayoutBase. +template +class BufferViewBasePIMPL { +private: + /// Type of the buffer base + using buffer_base_type = BufferBaseType; + using traits_type = types::ClassTraits; + +public: + using layout_type = typename traits_type::layout_type; + using layout_pointer = typename traits_type::layout_pointer; + using layout_reference = typename traits_type::layout_reference; + using const_layout_reference = typename traits_type::const_layout_reference; + + explicit BufferViewBasePIMPL(layout_pointer p) noexcept : + m_layout_ptr_(p) {} + + layout_reference layout() { return *m_layout_ptr_; } + const_layout_reference layout() const { return *m_layout_ptr_; } + + bool has_layout() const noexcept { return m_layout_ptr_ != nullptr; } + + auto clone() const { return std::make_unique(*this); } + +private: + layout_pointer m_layout_ptr_; +}; +} // namespace tensorwrapper::buffer::detail_ diff --git a/include/tensorwrapper/buffer/local.hpp b/include/tensorwrapper/buffer/local.hpp index e2cd1c5b..e54f57ef 100644 --- a/include/tensorwrapper/buffer/local.hpp +++ b/include/tensorwrapper/buffer/local.hpp @@ -16,7 +16,8 @@ #pragma once #include - +#include +#include namespace tensorwrapper::buffer { /** @brief Establishes that the state in the buffer is obtainable without @@ -36,4 +37,22 @@ class Local : public BufferBase { using my_base_type::my_base_type; }; +/** @brief A view of a Local buffer. + * + * This class is a view of a Local buffer. It is used to create a view of a + * Local buffer. It is not a strong type and does not impart any additional + * state to the BufferViewBase class. + */ +template +class LocalView + : public BufferViewBase> { +private: + using buffer_base_type = types::preserve_const_t; + using my_base_type = BufferViewBase; + +public: + /// Pull in base's ctors + using my_base_type::my_base_type; +}; + } // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/buffer/replicated.hpp b/include/tensorwrapper/buffer/replicated.hpp index 402b8fb3..3d8479c3 100644 --- a/include/tensorwrapper/buffer/replicated.hpp +++ b/include/tensorwrapper/buffer/replicated.hpp @@ -16,6 +16,7 @@ #pragma once #include +#include namespace tensorwrapper::buffer { @@ -24,14 +25,20 @@ namespace tensorwrapper::buffer { * At the moment this class is a strong type and has no additional state over * its base class. */ -class Replicated : public Local { +class Replicated : public ReplicatedCommon, public Local { private: /// Type *this derives from - using my_base_type = Local; + using my_base_type = ReplicatedCommon; public: // Pull in base's ctors - using my_base_type::my_base_type; + using Local::Local; + +protected: + friend my_base_type; + + virtual const_element_reference get_elem_(index_vector index) const = 0; + virtual void set_elem_(index_vector index, element_type value) = 0; }; } // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/buffer/replicated_common.hpp b/include/tensorwrapper/buffer/replicated_common.hpp new file mode 100644 index 00000000..e9fa21ca --- /dev/null +++ b/include/tensorwrapper/buffer/replicated_common.hpp @@ -0,0 +1,80 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace tensorwrapper::buffer { + +/** @brief Implements and defines common functionality for replicated buffers. + * + * @tparam Derived The class *this is implementing. Expected to be unqualified + * Replicated or ReplicatedView. + * + * To use this class the derived class must define: + * - `const_element_reference get_elem_(index_vector index) const` so that it + * returns the element at the given index. + * - `void set_elem_(index_vector index, element_type value)` so that it sets + * the element at the given index to the given value. + * + * This class is used to implement the common functionality for replicated + * buffers. + */ +template +class ReplicatedCommon { +private: + using my_traits = types::ClassTraits; + +public: + /// Pull in types from traits_type + ///@{ + using element_type = typename my_traits::element_type; + using element_reference = typename my_traits::element_reference; + using const_element_reference = typename my_traits::const_element_reference; + using size_type = typename my_traits::size_type; + using index_vector = typename my_traits::index_vector; + ///@} + + /** @brief Returns the element at the given index. + * + * @param[in] index The index of the element to return. + * + * @return The element at the given index. + */ + const_element_reference get_elem(index_vector index) const { + return derived().get_elem_(index); + } + + /** @brief Sets the element at the given index to the given value. + * + * @param[in] index The index of the element to set. + * @param[in] value The value to set the element to. + */ + void set_elem(index_vector index, element_type value) { + derived().set_elem_(index, value); + } + +private: + /// Access derived for CRTP + Derived& derived() { return static_cast(*this); } + + /// Access derived for CRTP read-only + const Derived& derived() const { + return static_cast(*this); + } +}; + +} // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/buffer/replicated_view.hpp b/include/tensorwrapper/buffer/replicated_view.hpp new file mode 100644 index 00000000..be0879ce --- /dev/null +++ b/include/tensorwrapper/buffer/replicated_view.hpp @@ -0,0 +1,178 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +namespace tensorwrapper::buffer { +namespace detail_ { + +template +class ReplicatedViewPIMPL; + +} // namespace detail_ + +/** @brief A view of a Replicated buffer. + * + * @tparam ReplicatedType The type of the Replicated buffer to view. Expected + * to be unqualified Replicated or const Replicated. + * + * This class is a view of a Replicated buffer. It is used to create a view of + * a Replicated buffer. + */ +template +class ReplicatedView + : public ReplicatedCommon>, + public LocalView> { +private: + using my_type = ReplicatedView; + using common_base_type = ReplicatedCommon; + using local_base_type = + LocalView>; + using my_base_type = LocalView; + +public: + /// Pull in base's types + ///@{ + using typename common_base_type::const_element_reference; + using typename common_base_type::element_reference; + using typename common_base_type::element_type; + using typename common_base_type::index_vector; + using typename common_base_type::size_type; + ///@} + + /// Type of the PIMPL + using pimpl_type = detail_::ReplicatedViewPIMPL; + + /// Type of a pointer to the PIMPL + using pimpl_pointer = std::unique_ptr; + + /** @brief Default constructor. + * + * This constructor will create a view with no layout and no elements. + * + * @throw None No throw guarantee. + */ + ReplicatedView(); + + /** @brief Slice construction. + * + * This ctor will create a view of the @p replicated buffer starting at + * the @p first_elem and ending at the @p last_elem. + * + * @param[in] replicated The replicated buffer to slice. The replicated + * buffer must outlive the view. + * @param[in] first_elem The first element of the slice. + * @param[in] last_elem The last element of the slice. + * + * @throw std::runtime_error if the slice is invalid. Strong throw + * guarantee. + * @throw std::bad_alloc if there is a problem allocating the PIMPL. Strong + * throw guarantee. + */ + ReplicatedView(ReplicatedType& replicated, index_vector first_elem, + index_vector last_elem); + + /** @brief Creates a new view implemented by @p pimpl. + * + * @param[in] pimpl A pointer to the PIMPL to use as the backing store. + * + * @throw None No throw guarantee. + */ + ReplicatedView(pimpl_pointer pimpl); + + /** @brief Creates a new view by copying the state of @p other. + * + * This ctor will create a new view by copying the state of @p other. + * After this operation *this will alias the same object @p other did. + * + * @param[in] other The view to copy. + * + * @throw std::bad_alloc if there is a problem allocating the copy. Strong + * throw guarantee. + */ + ReplicatedView(const ReplicatedView& other); + + /** @brief Overwrites the state of *this with the state of @p rhs. + * + * This operator will overwrite the state of *this by moving the pointers + * in @p rhs. After this operation *this will alias + * the same object @p rhs did. + * + * @param[in] rhs The view to move from. + * + * @throw None No throw guarantee. + */ + ReplicatedView(ReplicatedView&& other) noexcept; + + /** @brief Overwrites the state of *this with the state of @p rhs. + * + * This operator will overwrite the state of *this by copying the pointers + * in @p rhs. This is a shallow copy. After this operation *this will alias + * the same object @p rhs did. It is worth noting the aliased object is + * untouched after this operation. + * + * @param[in] rhs The view to copy. + * + * @return *this after making it alias the state in @p rhs. + * + * @throw std::bad_alloc if there is a problem allocating the copy. Strong + * throw guarantee. + */ + ReplicatedView& operator=(const ReplicatedView& rhs); + + /** @brief Overwrites the state of *this with the state of @p rhs. + * + * This operator will overwrite the state of *this with the state of + * @p rhs. After this operation *this will alias the same object @p rhs + * did. + * + * @param[in] rhs The view to copy. + * + * @return *this after making it alias the state in @p rhs. + * + * @throw None No throw guarantee. + */ + ReplicatedView& operator=(ReplicatedView&& rhs) noexcept; + + /// No-throw dtor. + ~ReplicatedView() noexcept; + +protected: + friend common_base_type; + + /// Implements get_elem for the view. + const_element_reference get_elem_(index_vector index) const; + + /// Implements set_elem for the view. + void set_elem_(index_vector index, element_type value); + +private: + /// Does *this have a PIMPL? + bool has_pimpl_() const noexcept; + + /// Throws if *this does not have a PIMPL. + void assert_pimpl_() const; + + /// The PIMPL holding the data for the view. + pimpl_pointer m_pimpl_; +}; + +extern template class ReplicatedView; +extern template class ReplicatedView; + +} // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/layout/layout_common.hpp b/include/tensorwrapper/layout/layout_common.hpp index 92a1bee8..3b359e78 100644 --- a/include/tensorwrapper/layout/layout_common.hpp +++ b/include/tensorwrapper/layout/layout_common.hpp @@ -16,6 +16,7 @@ #pragma once #include +#include #include namespace tensorwrapper::layout { diff --git a/include/tensorwrapper/types/buffer_traits.hpp b/include/tensorwrapper/types/buffer_traits.hpp index e42a5b7e..a6f34205 100644 --- a/include/tensorwrapper/types/buffer_traits.hpp +++ b/include/tensorwrapper/types/buffer_traits.hpp @@ -19,38 +19,79 @@ #include #include #include +#include +#include +#include +#include namespace tensorwrapper::types { -template -struct ClassTraits> { - /// Type of the class describing the physical layout of the buffer - using layout_type = layout::Physical; - - /// Type of a read-only reference to a layout +struct BufferBaseTraitsCommon : public CommonTypes { + using layout_type = layout::Physical; using const_layout_reference = const layout_type&; + using const_layout_pointer = const layout_type*; - /// Type used to represent the tensor's rank - using rank_type = typename layout_type::size_type; + using buffer_base_type = buffer::BufferBase; + using const_buffer_base_pointer = std::unique_ptr; + using const_buffer_base_reference = const buffer_base_type&; }; template<> -struct ClassTraits - : public ClassTraits> { - /// Type all buffers inherit from - using buffer_base_type = buffer::BufferBase; - - /// Type of a mutable reference to a buffer_base_type object +struct ClassTraits : public BufferBaseTraitsCommon { + using layout_reference = layout_type&; + using layout_pointer = layout_type*; using buffer_base_reference = buffer_base_type&; + using buffer_base_pointer = std::unique_ptr; +}; - /// Type of a read-only reference to a buffer_base_type object - using const_buffer_base_reference = const buffer_base_type&; +template<> +struct ClassTraits : public BufferBaseTraitsCommon { + using layout_reference = const layout_type&; + using layout_pointer = const layout_type*; + using buffer_base_reference = const buffer_base_type&; + using buffer_base_pointer = std::unique_ptr; + using const_buffer_base_pointer = std::unique_ptr; +}; - /// Type of a unique_ptr to a mutable buffer_base_type - using buffer_base_pointer = std::unique_ptr; +template +struct ClassTraits> + : public ClassTraits {}; - /// Type of a unique_ptr to a mutable buffer_base_type - using const_buffer_base_pointer = std::unique_ptr; +template<> +struct ClassTraits : public ClassTraits {}; + +template<> +struct ClassTraits + : public ClassTraits {}; + +template +struct ClassTraits> + : public ClassTraits {}; + +struct ReplicatedTraitsCommon { + using element_type = wtf::fp::Float; + using const_element_reference = wtf::fp::FloatView; + using buffer_type = wtf::buffer::FloatBuffer; + using const_buffer_view = wtf::buffer::BufferView; + using index_vector = std::vector; +}; + +template<> +struct ClassTraits : public ReplicatedTraitsCommon, + public ClassTraits { + using element_reference = wtf::fp::FloatView; + using buffer_view = wtf::buffer::BufferView; }; +template<> +struct ClassTraits + : public ReplicatedTraitsCommon, public ClassTraits { + using element_reference = wtf::fp::FloatView; + using buffer_view = wtf::buffer::BufferView; +}; + +template +struct ClassTraits> + : public ClassTraits {}; + } // namespace tensorwrapper::types diff --git a/include/tensorwrapper/types/contiguous_traits.hpp b/include/tensorwrapper/types/contiguous_traits.hpp index 9f4ce64e..41aaeec5 100644 --- a/include/tensorwrapper/types/contiguous_traits.hpp +++ b/include/tensorwrapper/types/contiguous_traits.hpp @@ -16,37 +16,24 @@ #pragma once #include +#include #include #include -#include namespace tensorwrapper::types { struct ContiguousTraitsCommon { - using value_type = wtf::fp::Float; - using const_reference = wtf::fp::FloatView; - using buffer_type = wtf::buffer::FloatBuffer; - using const_buffer_view = wtf::buffer::BufferView; - using shape_type = shape::Smooth; - using const_shape_view = shape::SmoothView; - using rank_type = typename ClassTraits::rank_type; - using size_type = typename ClassTraits::size_type; + using shape_type = shape::Smooth; + using const_shape_view = shape::SmoothView; }; template<> struct ClassTraits - : public ContiguousTraitsCommon { - using reference = wtf::fp::FloatView; - - using buffer_view = wtf::buffer::BufferView; - using const_buffer_view = wtf::buffer::BufferView; -}; + : public ClassTraits, public ContiguousTraitsCommon {}; template<> struct ClassTraits - : public ContiguousTraitsCommon { - using reference = wtf::fp::FloatView; - using buffer_view = wtf::buffer::BufferView; -}; + : public ClassTraits, + public ContiguousTraitsCommon {}; } // namespace tensorwrapper::types diff --git a/include/tensorwrapper/types/preserve_const.hpp b/include/tensorwrapper/types/preserve_const.hpp new file mode 100644 index 00000000..5e9e7120 --- /dev/null +++ b/include/tensorwrapper/types/preserve_const.hpp @@ -0,0 +1,29 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +namespace tensorwrapper::types { + +/**@ brief Uses the const-ness of @p T to ensure @p U has the same const-ness. + * + * @tparam T The type to preserve the const-ness of. + * @tparam U The type to preserve the const-ness of @p T on. + */ +template +using preserve_const_t = std::conditional_t, const U, U>; + +} // namespace tensorwrapper::types diff --git a/src/tensorwrapper/buffer/contiguous.cpp b/src/tensorwrapper/buffer/contiguous.cpp index 41c1cc4c..d91c6b77 100644 --- a/src/tensorwrapper/buffer/contiguous.cpp +++ b/src/tensorwrapper/buffer/contiguous.cpp @@ -59,17 +59,6 @@ auto Contiguous::shape() const -> const_shape_view { return m_shape_; } auto Contiguous::size() const noexcept -> size_type { return m_buffer_.size(); } -auto Contiguous::get_elem(index_vector index) const -> const_reference { - auto ordinal_index = coordinate_to_ordinal_(index); - return m_buffer_.at(ordinal_index); -} - -void Contiguous::set_elem(index_vector index, value_type new_value) { - auto ordinal_index = coordinate_to_ordinal_(index); - mark_for_rehash_(); - m_buffer_.at(ordinal_index) = new_value; -} - auto Contiguous::get_mutable_data() -> buffer_view { mark_for_rehash_(); return m_buffer_; @@ -261,6 +250,18 @@ std::ostream& Contiguous::add_to_stream_(std::ostream& os) const { return os; } +auto Contiguous::get_elem_(index_vector index) const + -> const_element_reference { + auto ordinal_index = coordinate_to_ordinal_(index); + return m_buffer_.at(ordinal_index); +} + +void Contiguous::set_elem_(index_vector index, element_type new_value) { + auto ordinal_index = coordinate_to_ordinal_(index); + mark_for_rehash_(); + m_buffer_.at(ordinal_index) = new_value; +} + // ----------------------------------------------------------------------------- // -- Private Methods // ----------------------------------------------------------------------------- diff --git a/src/tensorwrapper/buffer/detail_/replicated_view_pimpl.hpp b/src/tensorwrapper/buffer/detail_/replicated_view_pimpl.hpp new file mode 100644 index 00000000..e8785307 --- /dev/null +++ b/src/tensorwrapper/buffer/detail_/replicated_view_pimpl.hpp @@ -0,0 +1,85 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include + +namespace tensorwrapper::buffer::detail_ { + +/** @brief Implements the API for all ReplicatedView PIMPLs. + * + * @tparam ReplicatedType The type *this will be a view of. + */ +template +class ReplicatedViewPIMPL { +private: + using traits_type = types::ClassTraits; + +public: + /// Pull in types from traits_type + ///@{ + using layout_reference = typename traits_type::layout_reference; + using const_layout_reference = typename traits_type::const_layout_reference; + using layout_type = typename traits_type::layout_type; + using size_type = typename traits_type::size_type; + using element_type = typename traits_type::element_type; + using const_element_reference = + typename traits_type::const_element_reference; + using index_vector = typename traits_type::index_vector; + ///@} + + /// Type of a pointer to the PIMPL + using pimpl_pointer = std::unique_ptr; + + /// No-throw dtor. + virtual ~ReplicatedViewPIMPL() noexcept = default; + + pimpl_pointer clone() const { return clone_(); } + + layout_reference layout() { return layout_(); } + + const_layout_reference layout() const { return layout_(); } + + const_element_reference get_elem(const index_vector& slice_index) const { + return get_elem_(slice_index); + } + + template + std::enable_if_t> set_elem( + const index_vector& slice_index, element_type value) { + set_elem_(slice_index, std::move(value)); + } + +protected: + virtual layout_reference layout_() = 0; + + virtual const_layout_reference layout_() const = 0; + + virtual pimpl_pointer clone_() const = 0; + + /// Derived class should implement to be consistent with get_elem + virtual const_element_reference get_elem_( + const index_vector& slice_index) const = 0; + + /// Derived class should implement to be consistent with set_elem + virtual void set_elem_(const index_vector& slice_index, + element_type value) = 0; +}; + +} // namespace tensorwrapper::buffer::detail_ diff --git a/src/tensorwrapper/buffer/detail_/slice_pimpl.hpp b/src/tensorwrapper/buffer/detail_/slice_pimpl.hpp new file mode 100644 index 00000000..9c8b1cfe --- /dev/null +++ b/src/tensorwrapper/buffer/detail_/slice_pimpl.hpp @@ -0,0 +1,132 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "replicated_view_pimpl.hpp" +#include + +namespace tensorwrapper::buffer::detail_ { +/** @brief PIMPL holding a non-owning pointer to a Replicated object and slice + * bounds. + * + * Implements a view of a slice of a Replicated buffer. Slice indices are + * relative to the view; index translation to the underlying Replicated is + * performed in get_elem and set_elem. + */ +template +class SlicePIMPL : public ReplicatedViewPIMPL { +private: + using my_base = ReplicatedViewPIMPL; + +public: + /// Pull in types from base + ///@{ + using typename my_base::const_element_reference; + using typename my_base::const_layout_reference; + using typename my_base::element_type; + using typename my_base::index_vector; + using typename my_base::layout_reference; + using typename my_base::layout_type; + using typename my_base::pimpl_pointer; + using typename my_base::size_type; + ///@} + + /// Pull in base's ctors + using my_base::my_base; + + /** @brief Creates a PIMPL that views a slice of @p replicated_ptr. + * + * @param[in] replicated_ptr Non-owning pointer to the Replicated object + * (or nullptr for an empty view). + * @param[in] first_elem Indices of the first element in the slice + * (inclusive). + * @param[in] last_elem Indices of the first element not in the slice + * (exclusive). + */ + SlicePIMPL(ReplicatedType* replicated_ptr, index_vector first_elem, + index_vector last_elem) : + m_replicated_ptr_(replicated_ptr), + m_first_elem_(std::move(first_elem)), + m_last_elem_(std::move(last_elem)), + m_layout_ptr_(replicated_ptr ? + replicated_ptr->layout() + .slice(m_first_elem_.begin(), m_first_elem_.end(), + m_last_elem_.begin(), m_last_elem_.end()) + .template clone_as() : + nullptr) {} + +protected: + pimpl_pointer clone_() const override { + return std::make_unique(m_replicated_ptr_, m_first_elem_, + m_last_elem_); + } + + layout_reference layout_() override { return *m_layout_ptr_; } + + const_layout_reference layout_() const override { return *m_layout_ptr_; } + + const_element_reference get_elem_( + const index_vector& slice_index) const override { + return replicated().get_elem(translate_index_(slice_index)); + } + + void set_elem_(const index_vector& slice_index, + element_type value) override { + if constexpr(std::is_const_v) { + throw std::runtime_error( + "Cannot set elements of a const ReplicatedViewPIMPL."); + } else { + auto new_index = translate_index_(slice_index); + replicated().set_elem(new_index, std::move(value)); + } + } + +private: + ReplicatedType* m_replicated_ptr_; + index_vector m_first_elem_; + index_vector m_last_elem_; + + std::unique_ptr m_layout_ptr_; + + void assert_replicated_ptr_() const { + if(m_replicated_ptr_ == nullptr) { + throw std::runtime_error( + "SlicePIMPL has no Replicated object. Was it default " + "initialized?"); + } + } + + ReplicatedType& replicated() { + assert_replicated_ptr_(); + return *m_replicated_ptr_; + } + + const ReplicatedType& replicated() const { + assert_replicated_ptr_(); + return *m_replicated_ptr_; + } + + index_vector translate_index_(const index_vector& slice_index) const { + index_vector result; + result.reserve(slice_index.size()); + for(size_type i = 0; i < slice_index.size(); ++i) { + result.push_back(m_first_elem_[i] + slice_index[i]); + } + return result; + } +}; + +} // namespace tensorwrapper::buffer::detail_ diff --git a/src/tensorwrapper/buffer/replicated_view.cpp b/src/tensorwrapper/buffer/replicated_view.cpp new file mode 100644 index 00000000..bffe44b9 --- /dev/null +++ b/src/tensorwrapper/buffer/replicated_view.cpp @@ -0,0 +1,115 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "detail_/slice_pimpl.hpp" +#include +#include + +namespace tensorwrapper::buffer { + +#define TPARAMS template +#define REPLICATED_VIEW ReplicatedView + +// ----------------------------------------------------------------------------- +// -- Ctors, assignment, and dtor +// ----------------------------------------------------------------------------- + +TPARAMS +REPLICATED_VIEW::ReplicatedView() : m_pimpl_(nullptr) {} + +TPARAMS +REPLICATED_VIEW::ReplicatedView(ReplicatedType& replicated, + index_vector first_elem, + index_vector last_elem) : + ReplicatedView(std::make_unique>( + &replicated, first_elem, last_elem)) {} + +TPARAMS +REPLICATED_VIEW::ReplicatedView(pimpl_pointer pimpl) : + local_base_type(&pimpl->layout()), m_pimpl_(std::move(pimpl)) {} + +TPARAMS +REPLICATED_VIEW::ReplicatedView(const ReplicatedView& other) : + m_pimpl_(other.has_pimpl_() ? other.m_pimpl_->clone() : nullptr) {} + +TPARAMS +REPLICATED_VIEW::ReplicatedView(ReplicatedView&& other) noexcept = default; + +TPARAMS +auto REPLICATED_VIEW::operator=(const ReplicatedView& rhs) -> ReplicatedView& { + if(this != &rhs) { + m_pimpl_ = rhs.has_pimpl_() ? rhs.m_pimpl_->clone() : nullptr; + } + return *this; +} + +TPARAMS +auto REPLICATED_VIEW::operator=(ReplicatedView&& rhs) noexcept + -> ReplicatedView& { + if(this != &rhs) { + m_pimpl_ = rhs.has_pimpl_() ? rhs.m_pimpl_->clone() : nullptr; + } + return *this; +} + +TPARAMS +REPLICATED_VIEW::~ReplicatedView() noexcept = default; + +// ----------------------------------------------------------------------------- +// -- Protected methods +// ----------------------------------------------------------------------------- + +TPARAMS +auto REPLICATED_VIEW::get_elem_(index_vector index) const + -> const_element_reference { + assert_pimpl_(); + return m_pimpl_->get_elem(index); +} + +TPARAMS +void REPLICATED_VIEW::set_elem_(index_vector index, element_type value) { + assert_pimpl_(); + if constexpr(std::is_const_v) { + throw std::runtime_error( + "Cannot set element of a const ReplicatedView"); + } else { + m_pimpl_->set_elem(index, value); + } +} + +// ----------------------------------------------------------------------------- +// -- Private methods +// ----------------------------------------------------------------------------- + +TPARAMS +bool REPLICATED_VIEW::has_pimpl_() const noexcept { + return static_cast(m_pimpl_); +} + +TPARAMS +void REPLICATED_VIEW::assert_pimpl_() const { + if(has_pimpl_()) return; + throw std::runtime_error( + "ReplicatedView has no PIMPL. Was it default constructed?"); +} + +#undef REPLICATED_VIEW +#undef TPARAMS + +template class ReplicatedView; +template class ReplicatedView; + +} // namespace tensorwrapper::buffer diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/slice_pimpl.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/slice_pimpl.cpp new file mode 100644 index 00000000..7d175f2d --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/slice_pimpl.cpp @@ -0,0 +1,85 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../../testing/testing.hpp" +#include +#include + +using namespace tensorwrapper; +using namespace buffer; +using namespace buffer::detail_; + +TEST_CASE("SlicePIMPL") { + using slice_type = SlicePIMPL; + using const_slice_type = SlicePIMPL; + auto pvector = testing::eigen_vector(4); + auto& vector = *pvector; + // vector has values [0, 1, 2, 3] at indices 0, 1, 2, 3 + + auto pmatrix = testing::eigen_matrix(2, 2); + auto& matrix = *pmatrix; + // matrix has values: (0,0)=1, (0,1)=2, (1,0)=3, (1,1)=4 + + slice_type vector_slice(&vector, {1}, {3}); + const_slice_type const_vector_slice(&vector, {1}, {3}); + + slice_type matrix_slice(&matrix, {0, 1}, {2, 2}); + const_slice_type const_matrix_slice(&matrix, {0, 1}, {2, 2}); + + SECTION("Constructor") { + // Slice indices 1:3 of vector -> slice index 0 maps to 1, slice index 1 + // maps to 2 + REQUIRE(vector_slice.get_elem({0}) == 1.0); + REQUIRE(vector_slice.get_elem({1}) == 2.0); + + REQUIRE(const_vector_slice.get_elem({0}) == 1.0); + REQUIRE(const_vector_slice.get_elem({1}) == 2.0); + + // Slice first_elem={0,1}, last_elem={2,2} -> 2 rows, 1 column + // slice {0,0} -> underlying {0,1} -> value 2 + // slice {1,0} -> underlying {1,1} -> value 4 + + REQUIRE(matrix_slice.get_elem({0, 0}) == 2.0); + REQUIRE(matrix_slice.get_elem({1, 0}) == 4.0); + + REQUIRE(const_matrix_slice.get_elem({0, 0}) == 2.0); + REQUIRE(const_matrix_slice.get_elem({1, 0}) == 4.0); + } + + SECTION("clone") { + vector_slice.set_elem({0}, 1.0); + auto cloned = vector_slice.clone(); + REQUIRE(cloned->get_elem({0}) == 1.0); + REQUIRE(cloned->get_elem({1}) == 2.0); + } + + SECTION("get_elem") { + REQUIRE(const_vector_slice.get_elem({0}) == 1.0); + REQUIRE(const_vector_slice.get_elem({1}) == 2.0); + + slice_type null_slice(nullptr, {0}, {1}); + REQUIRE_THROWS_AS(null_slice.get_elem({0}), std::runtime_error); + } + + SECTION("set_elem") { + vector_slice.set_elem({0}, 99.0); + REQUIRE(vector.get_elem({1}) == 99.0); + REQUIRE(vector_slice.get_elem({0}) == 99.0); + + slice_type null_slice(nullptr, {0}, {1}); + REQUIRE_THROWS_AS(null_slice.set_elem({0}, 1.0), std::runtime_error); + } +} diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/replicated_view.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/replicated_view.cpp new file mode 100644 index 00000000..f0c3afef --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/replicated_view.cpp @@ -0,0 +1,93 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../testing/testing.hpp" +#include +#include +#include +#include + +using namespace tensorwrapper; +using namespace buffer; + +TEST_CASE("ReplicatedView") { + using view_type = ReplicatedView; + using const_view_type = ReplicatedView; + auto pvector = testing::eigen_vector(4); + auto& vector = *pvector; + // vector has values [0, 1, 2, 3] at indices 0, 1, 2, 3 + + auto pmatrix = testing::eigen_matrix(2, 2); + auto& matrix = *pmatrix; + // matrix has values: (0,0)=1, (0,1)=2, (1,0)=3, (1,1)=4 + + view_type vector_view(vector, {1}, {3}); + const_view_type const_vector_view(vector, {1}, {3}); + + view_type matrix_view(matrix, {0, 1}, {2, 2}); + const_view_type const_matrix_view(matrix, {0, 1}, {2, 2}); + + SECTION("Slice Constructor") { + // Slice indices 1:3 of vector -> slice index 0 maps to 1, slice index 1 + // maps to 2 + REQUIRE(vector_view.get_elem({0}) == 1.0); + REQUIRE(vector_view.get_elem({1}) == 2.0); + + auto corr_layout = vector.layout().slice({1}, {3}); + REQUIRE(vector_view.layout().are_equal(corr_layout)); + + REQUIRE(const_vector_view.get_elem({0}) == 1.0); + REQUIRE(const_vector_view.get_elem({1}) == 2.0); + REQUIRE(const_vector_view.layout().are_equal(corr_layout)); + + // Slice first_elem={0,1}, last_elem={2,2} -> 2 rows, 1 column + // slice {0,0} -> underlying {0,1} -> value 2 + // slice {1,0} -> underlying {1,1} -> value 4 + + REQUIRE(matrix_view.get_elem({0, 0}) == 2.0); + REQUIRE(matrix_view.get_elem({1, 0}) == 4.0); + + auto matrix_slice_layout = matrix.layout().slice({0, 1}, {2, 2}); + REQUIRE(matrix_view.layout().are_equal(matrix_slice_layout)); + REQUIRE(const_matrix_view.get_elem({0, 0}) == 2.0); + REQUIRE(const_matrix_view.get_elem({1, 0}) == 4.0); + REQUIRE(const_matrix_view.layout().are_equal(matrix_slice_layout)); + } + + SECTION("clone") { + vector_view.set_elem({0}, 1.0); + auto cloned = view_type(vector_view); + REQUIRE(cloned.get_elem({0}) == 1.0); + REQUIRE(cloned.get_elem({1}) == 2.0); + } + + SECTION("get_elem") { + REQUIRE(const_vector_view.get_elem({0}) == 1.0); + REQUIRE(const_vector_view.get_elem({1}) == 2.0); + + view_type defaulted; + REQUIRE_THROWS_AS(defaulted.get_elem({0}), std::runtime_error); + } + + SECTION("set_elem") { + vector_view.set_elem({0}, 99.0); + REQUIRE(vector.get_elem({1}) == 99.0); + REQUIRE(vector_view.get_elem({0}) == 99.0); + + view_type defaulted; + REQUIRE_THROWS_AS(defaulted.set_elem({0}, 1.0), std::runtime_error); + } +}