diff --git a/include/tensorwrapper/buffer/contiguous.hpp b/include/tensorwrapper/buffer/contiguous.hpp index c44f7a3d..f8dcf24f 100644 --- a/include/tensorwrapper/buffer/contiguous.hpp +++ b/include/tensorwrapper/buffer/contiguous.hpp @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include namespace tensorwrapper::buffer { @@ -294,6 +294,11 @@ class Contiguous : public Replicated { */ void set_elem_(index_vector index, element_type new_value) override; + slice_type slice_(index_vector first_elem, index_vector last_elem) override; + + const_slice_type slice_(index_vector first_elem, + index_vector last_elem) const override; + private: /// Type for storing the hash of *this using hash_type = std::size_t; diff --git a/include/tensorwrapper/buffer/replicated.hpp b/include/tensorwrapper/buffer/replicated.hpp index 3d8479c3..a99f772a 100644 --- a/include/tensorwrapper/buffer/replicated.hpp +++ b/include/tensorwrapper/buffer/replicated.hpp @@ -17,6 +17,7 @@ #pragma once #include #include +#include namespace tensorwrapper::buffer { @@ -36,9 +37,14 @@ class Replicated : public ReplicatedCommon, public Local { protected: friend my_base_type; + friend my_base_type::sliceable_base; virtual const_element_reference get_elem_(index_vector index) const = 0; virtual void set_elem_(index_vector index, element_type value) = 0; + virtual slice_type slice_(index_vector first_elem, + index_vector last_elem) = 0; + virtual const_slice_type slice_(index_vector first_elem, + index_vector last_elem) const = 0; }; } // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/buffer/replicated_common.hpp b/include/tensorwrapper/buffer/replicated_common.hpp index e9fa21ca..d1f90b33 100644 --- a/include/tensorwrapper/buffer/replicated_common.hpp +++ b/include/tensorwrapper/buffer/replicated_common.hpp @@ -15,6 +15,8 @@ */ #pragma once +#include +#include #include namespace tensorwrapper::buffer { @@ -34,10 +36,13 @@ namespace tensorwrapper::buffer { * buffers. */ template -class ReplicatedCommon { +class ReplicatedCommon : public interfaces::Sliceable { private: using my_traits = types::ClassTraits; +protected: + using sliceable_base = interfaces::Sliceable; + public: /// Pull in types from traits_type ///@{ @@ -46,6 +51,9 @@ class ReplicatedCommon { using const_element_reference = typename my_traits::const_element_reference; using size_type = typename my_traits::size_type; using index_vector = typename my_traits::index_vector; + using slice_type = typename my_traits::slice_type; + using const_slice_type = typename my_traits::const_slice_type; + using slice_il_type = typename my_traits::slice_il_type; ///@} /** @brief Returns the element at the given index. diff --git a/include/tensorwrapper/buffer/replicated_view.hpp b/include/tensorwrapper/buffer/replicated_view.hpp index be0879ce..7eb932b1 100644 --- a/include/tensorwrapper/buffer/replicated_view.hpp +++ b/include/tensorwrapper/buffer/replicated_view.hpp @@ -49,10 +49,12 @@ class ReplicatedView /// Pull in base's types ///@{ using typename common_base_type::const_element_reference; + using typename common_base_type::const_slice_type; using typename common_base_type::element_reference; using typename common_base_type::element_type; using typename common_base_type::index_vector; using typename common_base_type::size_type; + using typename common_base_type::slice_type; ///@} /// Type of the PIMPL @@ -154,6 +156,7 @@ class ReplicatedView protected: friend common_base_type; + friend typename common_base_type::sliceable_base; /// Implements get_elem for the view. const_element_reference get_elem_(index_vector index) const; @@ -161,6 +164,13 @@ class ReplicatedView /// Implements set_elem for the view. void set_elem_(index_vector index, element_type value); + /// Implements slice for the view. + slice_type slice_(index_vector first_elem, index_vector last_elem); + + /// Implements slice for the view. + const_slice_type slice_(index_vector first_elem, + index_vector last_elem) const; + private: /// Does *this have a PIMPL? bool has_pimpl_() const noexcept; diff --git a/include/tensorwrapper/concepts/has_begin_end.hpp b/include/tensorwrapper/concepts/has_begin_end.hpp new file mode 100644 index 00000000..d0906c9b --- /dev/null +++ b/include/tensorwrapper/concepts/has_begin_end.hpp @@ -0,0 +1,28 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace tensorwrapper::concepts { + +template +concept HasBeginEnd = requires(T t) { + { t.begin() } -> std::same_as; + { t.end() } -> std::same_as; +}; + +} // namespace tensorwrapper::concepts diff --git a/include/tensorwrapper/interfaces/sliceable.hpp b/include/tensorwrapper/interfaces/sliceable.hpp new file mode 100644 index 00000000..12a9299c --- /dev/null +++ b/include/tensorwrapper/interfaces/sliceable.hpp @@ -0,0 +1,166 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +namespace tensorwrapper::interfaces { + +/** @brief Defines the interface for objects that can be sliced. + * + * @tparam Derived The type of the derived class. + * + * Assume we have a rank @f$r@f$ object. A slice of this object is a new object + * that is a rank @f$r@f$ object with a subset relationship with respect to the + * original object. Slices are contiguous in index space (for selecting + * arbitrary elements from the original object see masks), meaning the slice + * can be specified by providing two indices. The first index is the index of + * the first element IN the slice and the second index is the index of the + * first element NOT IN the slice. We term the first index in the slice is + * denoted @f$i_0@f$ and the first index not in the slice is denoted @f$i_N@f$ + * where @f$N@f$ is the size of the slice (i.e. the number of elements in the + * slice). + * + * TODO: Example of the range notation. + * + * The Sliceable interface defines a number of overloads for the `slice` method + * these overloads are: + * + * 1. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ element initializer + * lists. The returned slice is mutable. + * 2. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ element initalizer + * lists. The returned slice is read-only. + * 3. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ element containers that + * support `begin` and `end` iterators. The returned slice is mutable. + * 4. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ element containers that + * support `begin` and `end` iterators. The returned slice is read-only. + * 5. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ ranges of iterators. The + * returned slice is mutable. + * 6. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ ranges of iterators. The + * returned slice is read-only. + * + * To use this interface the @p Derived class must: + * + * 1. Specialize `ClassTraits` to provide a member type `slice_type` + * and `const_slice_type`. + * 2. Implement the const and non-const versions of the `slice_` method. + */ +template +class Sliceable { +private: + using my_traits = types::ClassTraits; + +public: + using index_vector = typename my_traits::index_vector; + using slice_type = typename my_traits::slice_type; + using const_slice_type = typename my_traits::const_slice_type; + using slice_il_type = typename my_traits::slice_il_type; + + /// Overload 1. + slice_type slice(slice_il_type first_elem, slice_il_type last_elem) { + return slice(first_elem.begin(), first_elem.end(), last_elem.begin(), + last_elem.end()); + } + + /// Overload 2. + const_slice_type slice(slice_il_type first_elem, + slice_il_type last_elem) const { + return slice_impl_(index_vector(first_elem.begin(), first_elem.end()), + index_vector(last_elem.begin(), last_elem.end())); + } + + /// Overload 3. + template + slice_type slice(ContainerType0&& first_elem, ContainerType1&& last_elem); + + /// Overload 4. + template + const_slice_type slice(ContainerType0&& first_elem, + ContainerType1&& last_elem) const; + + /// Overload 5. + template + slice_type slice(BeginItr0 first_elem_begin, EndItr0 first_elem_end, + BeginItr1 last_elem_begin, EndItr1 last_elem_end) { + return slice_impl_(index_vector(first_elem_begin, first_elem_end), + index_vector(last_elem_begin, last_elem_end)); + } + + /// Overload 6. + template + const_slice_type slice(BeginItr0 first_elem_begin, EndItr0 first_elem_end, + BeginItr1 last_elem_begin, + EndItr1 last_elem_end) const { + return slice_impl_(index_vector(first_elem_begin, first_elem_end), + index_vector(last_elem_begin, last_elem_end)); + } + +private: + slice_type slice_impl_(index_vector first_elem, index_vector last_elem) { + return derived().slice_(first_elem, last_elem); + } + + const_slice_type slice_impl_(index_vector first_elem, + index_vector last_elem) const { + return derived().slice_(first_elem, last_elem); + } + + Derived& derived() { return static_cast(*this); } + + const Derived& derived() const { + return static_cast(*this); + } +}; + +// ----------------------------------------------------------------------------- +// -- Out of line implementations +// ----------------------------------------------------------------------------- + +template +template +auto Sliceable::slice(ContainerType0&& first_elem, + ContainerType1&& last_elem) -> slice_type { + if constexpr(std::is_same_v, index_vector> && + std::is_same_v, index_vector>) { + return slice_impl_(first_elem, last_elem); + } else { + return slice_impl_(index_vector(first_elem.begin(), first_elem.end()), + index_vector(last_elem.begin(), last_elem.end())); + } +} + +template +template +auto Sliceable::slice(ContainerType0&& first_elem, + ContainerType1&& last_elem) const + -> const_slice_type { + if constexpr(std::is_same_v, index_vector> && + std::is_same_v, index_vector>) { + return slice_impl_(first_elem, last_elem); + } else { + return slice_impl_(index_vector(first_elem.begin(), first_elem.end()), + index_vector(last_elem.begin(), last_elem.end())); + } +} + +} // namespace tensorwrapper::interfaces diff --git a/include/tensorwrapper/types/buffer_traits.hpp b/include/tensorwrapper/types/buffer_traits.hpp index a6f34205..b044b1db 100644 --- a/include/tensorwrapper/types/buffer_traits.hpp +++ b/include/tensorwrapper/types/buffer_traits.hpp @@ -74,6 +74,8 @@ struct ReplicatedTraitsCommon { using buffer_type = wtf::buffer::FloatBuffer; using const_buffer_view = wtf::buffer::BufferView; using index_vector = std::vector; + using const_slice_type = buffer::ReplicatedView; + using slice_il_type = std::initializer_list; }; template<> @@ -81,6 +83,7 @@ struct ClassTraits : public ReplicatedTraitsCommon, public ClassTraits { using element_reference = wtf::fp::FloatView; using buffer_view = wtf::buffer::BufferView; + using slice_type = buffer::ReplicatedView; }; template<> @@ -88,10 +91,25 @@ struct ClassTraits : public ReplicatedTraitsCommon, public ClassTraits { using element_reference = wtf::fp::FloatView; using buffer_view = wtf::buffer::BufferView; + using slice_type = buffer::ReplicatedView; }; template struct ClassTraits> : public ClassTraits {}; +struct ContiguousTraitsCommon { + using shape_type = shape::Smooth; + using const_shape_view = shape::SmoothView; +}; + +template<> +struct ClassTraits + : public ClassTraits, public ContiguousTraitsCommon {}; + +template<> +struct ClassTraits + : public ClassTraits, + public ContiguousTraitsCommon {}; + } // namespace tensorwrapper::types diff --git a/include/tensorwrapper/types/contiguous_traits.hpp b/include/tensorwrapper/types/contiguous_traits.hpp deleted file mode 100644 index 41aaeec5..00000000 --- a/include/tensorwrapper/types/contiguous_traits.hpp +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2025 NWChemEx-Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include - -namespace tensorwrapper::types { - -struct ContiguousTraitsCommon { - using shape_type = shape::Smooth; - using const_shape_view = shape::SmoothView; -}; - -template<> -struct ClassTraits - : public ClassTraits, public ContiguousTraitsCommon {}; - -template<> -struct ClassTraits - : public ClassTraits, - public ContiguousTraitsCommon {}; - -} // namespace tensorwrapper::types diff --git a/src/tensorwrapper/buffer/contiguous.cpp b/src/tensorwrapper/buffer/contiguous.cpp index d91c6b77..329e43d7 100644 --- a/src/tensorwrapper/buffer/contiguous.cpp +++ b/src/tensorwrapper/buffer/contiguous.cpp @@ -262,6 +262,16 @@ void Contiguous::set_elem_(index_vector index, element_type new_value) { m_buffer_.at(ordinal_index) = new_value; } +auto Contiguous::slice_(index_vector first_elem, index_vector last_elem) + -> slice_type { + return slice_type(*this, first_elem, last_elem); +} + +auto Contiguous::slice_(index_vector first_elem, index_vector last_elem) const + -> const_slice_type { + return const_slice_type(*this, first_elem, last_elem); +} + // ----------------------------------------------------------------------------- // -- Private Methods // ----------------------------------------------------------------------------- diff --git a/src/tensorwrapper/buffer/detail_/replicated_view_pimpl.hpp b/src/tensorwrapper/buffer/detail_/replicated_view_pimpl.hpp index e8785307..228b24b4 100644 --- a/src/tensorwrapper/buffer/detail_/replicated_view_pimpl.hpp +++ b/src/tensorwrapper/buffer/detail_/replicated_view_pimpl.hpp @@ -41,7 +41,9 @@ class ReplicatedViewPIMPL { using element_type = typename traits_type::element_type; using const_element_reference = typename traits_type::const_element_reference; - using index_vector = typename traits_type::index_vector; + using index_vector = typename traits_type::index_vector; + using slice_type = typename traits_type::slice_type; + using const_slice_type = typename traits_type::const_slice_type; ///@} /// Type of a pointer to the PIMPL @@ -56,6 +58,16 @@ class ReplicatedViewPIMPL { const_layout_reference layout() const { return layout_(); } + slice_type slice(const index_vector& first_elem, + const index_vector& last_elem) { + return slice_(first_elem, last_elem); + } + + const_slice_type slice(const index_vector& first_elem, + const index_vector& last_elem) const { + return slice_(first_elem, last_elem); + } + const_element_reference get_elem(const index_vector& slice_index) const { return get_elem_(slice_index); } @@ -73,6 +85,12 @@ class ReplicatedViewPIMPL { virtual pimpl_pointer clone_() const = 0; + virtual slice_type slice_(const index_vector& first_elem, + const index_vector& last_elem) = 0; + + virtual const_slice_type slice_(const index_vector& first_elem, + const index_vector& last_elem) const = 0; + /// Derived class should implement to be consistent with get_elem virtual const_element_reference get_elem_( const index_vector& slice_index) const = 0; diff --git a/src/tensorwrapper/buffer/detail_/slice_pimpl.hpp b/src/tensorwrapper/buffer/detail_/slice_pimpl.hpp index 9c8b1cfe..52d265f5 100644 --- a/src/tensorwrapper/buffer/detail_/slice_pimpl.hpp +++ b/src/tensorwrapper/buffer/detail_/slice_pimpl.hpp @@ -36,12 +36,14 @@ class SlicePIMPL : public ReplicatedViewPIMPL { ///@{ using typename my_base::const_element_reference; using typename my_base::const_layout_reference; + using typename my_base::const_slice_type; using typename my_base::element_type; using typename my_base::index_vector; using typename my_base::layout_reference; using typename my_base::layout_type; using typename my_base::pimpl_pointer; using typename my_base::size_type; + using typename my_base::slice_type; ///@} /// Pull in base's ctors @@ -94,6 +96,21 @@ class SlicePIMPL : public ReplicatedViewPIMPL { } } + slice_type slice_(const index_vector& first_elem, + const index_vector& last_elem) override { + auto new_first_elem = translate_index_(first_elem); + auto new_last_elem = translate_index_(last_elem); + return slice_type(*m_replicated_ptr_, new_first_elem, new_last_elem); + } + + const_slice_type slice_(const index_vector& first_elem, + const index_vector& last_elem) const override { + auto new_first_elem = translate_index_(first_elem); + auto new_last_elem = translate_index_(last_elem); + return const_slice_type(*m_replicated_ptr_, new_first_elem, + new_last_elem); + } + private: ReplicatedType* m_replicated_ptr_; index_vector m_first_elem_; diff --git a/src/tensorwrapper/buffer/replicated_view.cpp b/src/tensorwrapper/buffer/replicated_view.cpp index bffe44b9..de9ad448 100644 --- a/src/tensorwrapper/buffer/replicated_view.cpp +++ b/src/tensorwrapper/buffer/replicated_view.cpp @@ -90,6 +90,20 @@ void REPLICATED_VIEW::set_elem_(index_vector index, element_type value) { } } +TPARAMS +auto REPLICATED_VIEW::slice_(index_vector first_elem, index_vector last_elem) + -> slice_type { + if(!has_pimpl_()) return slice_type{}; + return m_pimpl_->slice(first_elem, last_elem); +} + +TPARAMS +auto REPLICATED_VIEW::slice_(index_vector first_elem, + index_vector last_elem) const -> const_slice_type { + if(!has_pimpl_()) return const_slice_type{}; + return std::as_const(*m_pimpl_).slice(first_elem, last_elem); +} + // ----------------------------------------------------------------------------- // -- Private methods // ----------------------------------------------------------------------------- diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp index ea10af7e..f4c954ef 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/contiguous.cpp @@ -220,6 +220,39 @@ TEMPLATE_LIST_TEST_CASE("Contiguous", "", types::floating_point_types) { REQUIRE(matrix.get_elem({1, 0}) == one); } + SECTION("slice()") { + REQUIRE(scalar.slice({}, {}).get_elem({}) == one); + + auto vector_slice = vector.slice({0}, {2}); + REQUIRE(vector_slice.layout().shape().size() == 2); + REQUIRE(vector_slice.get_elem({0}) == one); + REQUIRE(vector_slice.get_elem({1}) == two); + TestType nine_nine(99.0); + vector_slice.set_elem({0}, nine_nine); + REQUIRE(vector.get_elem({0}) == nine_nine); + + auto matrix_slice = matrix.slice({1, 0}, {2, 2}); + REQUIRE(matrix_slice.layout().shape().size() == 2); + REQUIRE(matrix_slice.get_elem({0, 0}) == three); + REQUIRE(matrix_slice.get_elem({0, 1}) == four); + matrix_slice.set_elem({0, 0}, nine_nine); + REQUIRE(matrix.get_elem({1, 0}) == nine_nine); + } + + SECTION("slice() const") { + REQUIRE(std::as_const(scalar).slice({}, {}).get_elem({}) == one); + + auto vector_slice = std::as_const(vector).slice({0}, {2}); + REQUIRE(vector_slice.layout().shape().size() == 2); + REQUIRE(vector_slice.get_elem({0}) == one); + REQUIRE(vector_slice.get_elem({1}) == two); + + auto matrix_slice = std::as_const(matrix).slice({1, 0}, {2, 2}); + REQUIRE(matrix_slice.layout().shape().size() == 2); + REQUIRE(matrix_slice.get_elem({0, 0}) == three); + REQUIRE(matrix_slice.get_elem({0, 1}) == four); + } + SECTION("infinity_norm") { REQUIRE_THROWS_AS(defaulted.infinity_norm(), std::runtime_error); REQUIRE(scalar.infinity_norm() == one); diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/replicated_view.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/replicated_view.cpp index f0c3afef..c2d66233 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/replicated_view.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/replicated_view.cpp @@ -26,19 +26,22 @@ using namespace buffer; TEST_CASE("ReplicatedView") { using view_type = ReplicatedView; using const_view_type = ReplicatedView; - auto pvector = testing::eigen_vector(4); + using TestType = double; + auto pvector = testing::eigen_vector(4); auto& vector = *pvector; // vector has values [0, 1, 2, 3] at indices 0, 1, 2, 3 - auto pmatrix = testing::eigen_matrix(2, 2); + auto pmatrix = testing::eigen_matrix(2, 2); auto& matrix = *pmatrix; // matrix has values: (0,0)=1, (0,1)=2, (1,0)=3, (1,1)=4 view_type vector_view(vector, {1}, {3}); const_view_type const_vector_view(vector, {1}, {3}); + // vector_view has values [1, 2] at indices 0, 1 view_type matrix_view(matrix, {0, 1}, {2, 2}); const_view_type const_matrix_view(matrix, {0, 1}, {2, 2}); + // matrix_view has values [2, 4] at indices (0,0), (1,0) SECTION("Slice Constructor") { // Slice indices 1:3 of vector -> slice index 0 maps to 1, slice index 1 @@ -90,4 +93,50 @@ TEST_CASE("ReplicatedView") { view_type defaulted; REQUIRE_THROWS_AS(defaulted.set_elem({0}, 1.0), std::runtime_error); } + + SECTION("slice()") { + auto vector_slice = vector_view.slice({1}, {2}); + // vector_slice has values [2] at index 0 + REQUIRE(vector_slice.layout().shape().size() == 1); + REQUIRE(vector_slice.get_elem({0}) == 2.0); + + auto const_vector_slice = const_vector_view.slice({1}, {2}); + REQUIRE(const_vector_slice.layout().shape().size() == 1); + REQUIRE(const_vector_slice.get_elem({0}) == 2.0); + + auto matrix_slice = matrix_view.slice({0, 0}, {1, 1}); + // matrix_slice has values [2] at index (0,0) + REQUIRE(matrix_slice.layout().shape().size() == 1); + REQUIRE(matrix_slice.get_elem({0, 0}) == 2.0); + + auto const_matrix_slice = const_matrix_view.slice({0, 0}, {1, 1}); + REQUIRE(const_matrix_slice.layout().shape().size() == 1); + REQUIRE(const_matrix_slice.get_elem({0, 0}) == 2.0); + + TestType nine_nine(99.0); + vector_slice.set_elem({0}, nine_nine); + REQUIRE(vector.get_elem({2}) == nine_nine); + matrix_slice.set_elem({0, 0}, nine_nine); + REQUIRE(matrix.get_elem({0, 1}) == nine_nine); + } + + SECTION("slice() const") { + auto vector_slice = std::as_const(vector_view).slice({1}, {2}); + REQUIRE(vector_slice.layout().shape().size() == 1); + REQUIRE(vector_slice.get_elem({0}) == 2.0); + + auto const_vector_slice = + std::as_const(const_vector_view).slice({1}, {2}); + REQUIRE(const_vector_slice.layout().shape().size() == 1); + REQUIRE(const_vector_slice.get_elem({0}) == 2.0); + + auto matrix_slice = std::as_const(matrix_view).slice({0, 0}, {1, 1}); + REQUIRE(matrix_slice.layout().shape().size() == 1); + REQUIRE(matrix_slice.get_elem({0, 0}) == 2.0); + + auto const_matrix_slice = + std::as_const(const_matrix_view).slice({0, 0}, {1, 1}); + REQUIRE(const_matrix_slice.layout().shape().size() == 1); + REQUIRE(const_matrix_slice.get_elem({0, 0}) == 2.0); + } }