Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion include/tensorwrapper/buffer/contiguous.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <tensorwrapper/buffer/replicated.hpp>
#include <tensorwrapper/concepts/floating_point.hpp>
#include <tensorwrapper/shape/smooth.hpp>
#include <tensorwrapper/types/contiguous_traits.hpp>
#include <tensorwrapper/types/buffer_traits.hpp>
#include <tensorwrapper/types/floating_point.hpp>

namespace tensorwrapper::buffer {
Expand Down Expand Up @@ -294,6 +294,11 @@ class Contiguous : public Replicated {
*/
void set_elem_(index_vector index, element_type new_value) override;

slice_type slice_(index_vector first_elem, index_vector last_elem) override;

const_slice_type slice_(index_vector first_elem,
index_vector last_elem) const override;

private:
/// Type for storing the hash of *this
using hash_type = std::size_t;
Expand Down
6 changes: 6 additions & 0 deletions include/tensorwrapper/buffer/replicated.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once
#include <tensorwrapper/buffer/local.hpp>
#include <tensorwrapper/buffer/replicated_common.hpp>
#include <tensorwrapper/buffer/replicated_view.hpp>

namespace tensorwrapper::buffer {

Expand All @@ -36,9 +37,14 @@ class Replicated : public ReplicatedCommon<Replicated>, public Local {

protected:
friend my_base_type;
friend my_base_type::sliceable_base;

virtual const_element_reference get_elem_(index_vector index) const = 0;
virtual void set_elem_(index_vector index, element_type value) = 0;
virtual slice_type slice_(index_vector first_elem,
index_vector last_elem) = 0;
virtual const_slice_type slice_(index_vector first_elem,
index_vector last_elem) const = 0;
};

} // namespace tensorwrapper::buffer
10 changes: 9 additions & 1 deletion include/tensorwrapper/buffer/replicated_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/

#pragma once
#include <tensorwrapper/concepts/has_begin_end.hpp>
#include <tensorwrapper/interfaces/sliceable.hpp>
#include <tensorwrapper/types/buffer_traits.hpp>

namespace tensorwrapper::buffer {
Expand All @@ -34,10 +36,13 @@ namespace tensorwrapper::buffer {
* buffers.
*/
template<typename Derived>
class ReplicatedCommon {
class ReplicatedCommon : public interfaces::Sliceable<Derived> {
private:
using my_traits = types::ClassTraits<Derived>;

protected:
using sliceable_base = interfaces::Sliceable<Derived>;

public:
/// Pull in types from traits_type
///@{
Expand All @@ -46,6 +51,9 @@ class ReplicatedCommon {
using const_element_reference = typename my_traits::const_element_reference;
using size_type = typename my_traits::size_type;
using index_vector = typename my_traits::index_vector;
using slice_type = typename my_traits::slice_type;
using const_slice_type = typename my_traits::const_slice_type;
using slice_il_type = typename my_traits::slice_il_type;
///@}

/** @brief Returns the element at the given index.
Expand Down
10 changes: 10 additions & 0 deletions include/tensorwrapper/buffer/replicated_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ class ReplicatedView
/// Pull in base's types
///@{
using typename common_base_type::const_element_reference;
using typename common_base_type::const_slice_type;
using typename common_base_type::element_reference;
using typename common_base_type::element_type;
using typename common_base_type::index_vector;
using typename common_base_type::size_type;
using typename common_base_type::slice_type;
///@}

/// Type of the PIMPL
Expand Down Expand Up @@ -154,13 +156,21 @@ class ReplicatedView

protected:
friend common_base_type;
friend typename common_base_type::sliceable_base;

/// Implements get_elem for the view.
const_element_reference get_elem_(index_vector index) const;

/// Implements set_elem for the view.
void set_elem_(index_vector index, element_type value);

/// Implements slice for the view.
slice_type slice_(index_vector first_elem, index_vector last_elem);

/// Implements slice for the view.
const_slice_type slice_(index_vector first_elem,
index_vector last_elem) const;

private:
/// Does *this have a PIMPL?
bool has_pimpl_() const noexcept;
Expand Down
28 changes: 28 additions & 0 deletions include/tensorwrapper/concepts/has_begin_end.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2026 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <concepts>

namespace tensorwrapper::concepts {

template<typename T>
concept HasBeginEnd = requires(T t) {
{ t.begin() } -> std::same_as<typename T::iterator>;
{ t.end() } -> std::same_as<typename T::iterator>;
};

} // namespace tensorwrapper::concepts
166 changes: 166 additions & 0 deletions include/tensorwrapper/interfaces/sliceable.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Copyright 2026 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <tensorwrapper/concepts/has_begin_end.hpp>
#include <tensorwrapper/types/types.hpp>

namespace tensorwrapper::interfaces {

/** @brief Defines the interface for objects that can be sliced.
*
* @tparam Derived The type of the derived class.
*
* Assume we have a rank @f$r@f$ object. A slice of this object is a new object
* that is a rank @f$r@f$ object with a subset relationship with respect to the
* original object. Slices are contiguous in index space (for selecting
* arbitrary elements from the original object see masks), meaning the slice
* can be specified by providing two indices. The first index is the index of
* the first element IN the slice and the second index is the index of the
* first element NOT IN the slice. We term the first index in the slice is
* denoted @f$i_0@f$ and the first index not in the slice is denoted @f$i_N@f$
* where @f$N@f$ is the size of the slice (i.e. the number of elements in the
* slice).
*
* TODO: Example of the range notation.
*
* The Sliceable interface defines a number of overloads for the `slice` method
* these overloads are:
*
* 1. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ element initializer
* lists. The returned slice is mutable.
* 2. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ element initalizer
* lists. The returned slice is read-only.
* 3. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ element containers that
* support `begin` and `end` iterators. The returned slice is mutable.
* 4. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ element containers that
* support `begin` and `end` iterators. The returned slice is read-only.
* 5. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ ranges of iterators. The
* returned slice is mutable.
* 6. @f$i_0@f$ and @f$i_N@f$ are provided as @f$r@f$ ranges of iterators. The
* returned slice is read-only.
*
* To use this interface the @p Derived class must:
*
* 1. Specialize `ClassTraits<Derived>` to provide a member type `slice_type`
* and `const_slice_type`.
* 2. Implement the const and non-const versions of the `slice_` method.
*/
template<typename Derived>
class Sliceable {
private:
using my_traits = types::ClassTraits<Derived>;

public:
using index_vector = typename my_traits::index_vector;
using slice_type = typename my_traits::slice_type;
using const_slice_type = typename my_traits::const_slice_type;
using slice_il_type = typename my_traits::slice_il_type;

/// Overload 1.
slice_type slice(slice_il_type first_elem, slice_il_type last_elem) {
return slice(first_elem.begin(), first_elem.end(), last_elem.begin(),
last_elem.end());
}

/// Overload 2.
const_slice_type slice(slice_il_type first_elem,
slice_il_type last_elem) const {
return slice_impl_(index_vector(first_elem.begin(), first_elem.end()),
index_vector(last_elem.begin(), last_elem.end()));
}

/// Overload 3.
template<concepts::HasBeginEnd ContainerType0,
concepts::HasBeginEnd ContainerType1>
slice_type slice(ContainerType0&& first_elem, ContainerType1&& last_elem);

/// Overload 4.
template<concepts::HasBeginEnd ContainerType0,
concepts::HasBeginEnd ContainerType1>
const_slice_type slice(ContainerType0&& first_elem,
ContainerType1&& last_elem) const;

/// Overload 5.
template<std::forward_iterator BeginItr0, std::forward_iterator EndItr0,
std::forward_iterator BeginItr1, std::forward_iterator EndItr1>
slice_type slice(BeginItr0 first_elem_begin, EndItr0 first_elem_end,
BeginItr1 last_elem_begin, EndItr1 last_elem_end) {
return slice_impl_(index_vector(first_elem_begin, first_elem_end),
index_vector(last_elem_begin, last_elem_end));
}

/// Overload 6.
template<std::forward_iterator BeginItr0, std::forward_iterator EndItr0,
std::forward_iterator BeginItr1, std::forward_iterator EndItr1>
const_slice_type slice(BeginItr0 first_elem_begin, EndItr0 first_elem_end,
BeginItr1 last_elem_begin,
EndItr1 last_elem_end) const {
return slice_impl_(index_vector(first_elem_begin, first_elem_end),
index_vector(last_elem_begin, last_elem_end));
}

private:
slice_type slice_impl_(index_vector first_elem, index_vector last_elem) {
return derived().slice_(first_elem, last_elem);
}

const_slice_type slice_impl_(index_vector first_elem,
index_vector last_elem) const {
return derived().slice_(first_elem, last_elem);
}

Derived& derived() { return static_cast<Derived&>(*this); }

const Derived& derived() const {
return static_cast<const Derived&>(*this);
}
};

// -----------------------------------------------------------------------------
// -- Out of line implementations
// -----------------------------------------------------------------------------

template<typename Derived>
template<concepts::HasBeginEnd ContainerType0,
concepts::HasBeginEnd ContainerType1>
auto Sliceable<Derived>::slice(ContainerType0&& first_elem,
ContainerType1&& last_elem) -> slice_type {
if constexpr(std::is_same_v<std::decay_t<ContainerType0>, index_vector> &&
std::is_same_v<std::decay_t<ContainerType1>, index_vector>) {
return slice_impl_(first_elem, last_elem);
} else {
return slice_impl_(index_vector(first_elem.begin(), first_elem.end()),
index_vector(last_elem.begin(), last_elem.end()));
}
}

template<typename Derived>
template<concepts::HasBeginEnd ContainerType0,
concepts::HasBeginEnd ContainerType1>
auto Sliceable<Derived>::slice(ContainerType0&& first_elem,
ContainerType1&& last_elem) const
-> const_slice_type {
if constexpr(std::is_same_v<std::decay_t<ContainerType0>, index_vector> &&
std::is_same_v<std::decay_t<ContainerType1>, index_vector>) {
return slice_impl_(first_elem, last_elem);
} else {
return slice_impl_(index_vector(first_elem.begin(), first_elem.end()),
index_vector(last_elem.begin(), last_elem.end()));
}
}

} // namespace tensorwrapper::interfaces
18 changes: 18 additions & 0 deletions include/tensorwrapper/types/buffer_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,42 @@ struct ReplicatedTraitsCommon {
using buffer_type = wtf::buffer::FloatBuffer;
using const_buffer_view = wtf::buffer::BufferView<const element_type>;
using index_vector = std::vector<types::CommonTypes::size_type>;
using const_slice_type = buffer::ReplicatedView<const buffer::Replicated>;
using slice_il_type = std::initializer_list<types::CommonTypes::size_type>;
};

template<>
struct ClassTraits<buffer::Replicated> : public ReplicatedTraitsCommon,
public ClassTraits<buffer::Local> {
using element_reference = wtf::fp::FloatView<element_type>;
using buffer_view = wtf::buffer::BufferView<element_type>;
using slice_type = buffer::ReplicatedView<buffer::Replicated>;
};

template<>
struct ClassTraits<const buffer::Replicated>
: public ReplicatedTraitsCommon, public ClassTraits<const buffer::Local> {
using element_reference = wtf::fp::FloatView<const element_type>;
using buffer_view = wtf::buffer::BufferView<const element_type>;
using slice_type = buffer::ReplicatedView<const buffer::Replicated>;
};

template<typename ReplicatedType>
struct ClassTraits<buffer::ReplicatedView<ReplicatedType>>
: public ClassTraits<ReplicatedType> {};

struct ContiguousTraitsCommon {
using shape_type = shape::Smooth;
using const_shape_view = shape::SmoothView<const shape_type>;
};

template<>
struct ClassTraits<tensorwrapper::buffer::Contiguous>
: public ClassTraits<buffer::Replicated>, public ContiguousTraitsCommon {};

template<>
struct ClassTraits<const tensorwrapper::buffer::Contiguous>
: public ClassTraits<const buffer::Replicated>,
public ContiguousTraitsCommon {};

} // namespace tensorwrapper::types
39 changes: 0 additions & 39 deletions include/tensorwrapper/types/contiguous_traits.hpp

This file was deleted.

Loading
Loading