Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion include/tensorwrapper/layout/layout_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <tensorwrapper/shape/shape_base.hpp>
#include <tensorwrapper/sparsity/pattern.hpp>
#include <tensorwrapper/symmetry/group.hpp>
#include <tensorwrapper/types/layout_traits.hpp>

namespace tensorwrapper::layout {

Expand All @@ -29,6 +30,10 @@ namespace tensorwrapper::layout {
*/
class LayoutBase : public tensorwrapper::detail_::PolymorphicBase<LayoutBase>,
public tensorwrapper::detail_::DSLBase<LayoutBase> {
private:
/// Type defining types for *this
using traits_type = types::ClassTraits<LayoutBase>;

public:
/// Type all layouts derive from
using layout_base = LayoutBase;
Expand Down Expand Up @@ -70,7 +75,7 @@ class LayoutBase : public tensorwrapper::detail_::PolymorphicBase<LayoutBase>,
using sparsity_pointer = std::unique_ptr<sparsity_type>;

/// Type used for indexing and offsets
using size_type = std::size_t;
using size_type = typename traits_type::size_type;

// -------------------------------------------------------------------------
// -- Ctors and dtor
Expand Down Expand Up @@ -186,6 +191,9 @@ class LayoutBase : public tensorwrapper::detail_::PolymorphicBase<LayoutBase>,
return *m_sparsity_;
}

/** @brief True if *this is a NULL layout and false otherwise. */
bool is_null() const noexcept { return !static_cast<bool>(m_shape_); }

/** @brief The rank of the tensor this layout describes.
*
* This method is convenience function for calling the rank methods on one
Expand Down Expand Up @@ -214,6 +222,10 @@ class LayoutBase : public tensorwrapper::detail_::PolymorphicBase<LayoutBase>,
* @throw None No throw guarantee.
*/
bool operator==(const layout_base& rhs) const noexcept {
if(is_null() && rhs.is_null())
return true;
else if(is_null() || rhs.is_null())
return false;
if(m_shape_->are_different(*rhs.m_shape_)) return false;
if(m_symmetry_->are_different(*rhs.m_symmetry_)) return false;
if(m_sparsity_->are_different(*rhs.m_sparsity_)) return false;
Expand Down
173 changes: 173 additions & 0 deletions include/tensorwrapper/layout/layout_common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Copyright 2026 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <tensorwrapper/layout/layout_base.hpp>
#include <tensorwrapper/types/layout_traits.hpp>

namespace tensorwrapper::layout {

template<typename Derived>
class LayoutCommon : public LayoutBase {
private:
/// Type of *this
using my_type = LayoutCommon<Derived>;

/// Type defining the types for *this
using traits_type = types::ClassTraits<my_type>;

public:
///@{
using slice_type = typename traits_type::slice_type;
using offset_il_type = typename traits_type::offset_il_type;
///@}

/// Pull in base class's ctors
using LayoutBase::LayoutBase;

/** @brief Slices a layout given two initializer lists.
*
* C++ doesn't allow templates to work with initializer lists, therefore
* we must provide a special overload for when the input containers are
* initializer lists. This method simply dispatches to the range-based
* method by calling begin()/end() on each initializer list. See the
* description of that method for more details.
*
* @param[in] first_elem An initializer list containing the offsets of
* the first element IN the slice such that
* `first_elem[i]` is the offset along mode i.
* @param[in] last_elem An initializer list containing the offsets of
* the first element NOT IN the slice such that
* `last_elem[i]` is the offset along mode i.
*
* @return The requested slice.
*
* @throws ??? If the range-based method throws. Same throw guarantee.
*/
slice_type slice(offset_il_type first_elem,
offset_il_type last_elem) const {
return slice(first_elem.begin(), first_elem.end(), last_elem.begin(),
last_elem.end());
}

/** @brief Slices a layout given two containers.
*
* @tparam ContainerType0 The type of first_elem. Assumed to have
* begin()/end() methods.
* @tparam ContainerType1 The type of last_elem. Assumed to have
* begin()/end() methods.
*
* Element indices are usually stored in containers. This overload is a
* convenience method for calling begin()/end() on the containers before
* dispatching to the range-based overload. See the documentation for the
* range-based overload for more details.
*
* @param[in] first_elem A container containing the offsets of
* the first element IN the slice such that
* `first_elem[i]` is the offset along mode i.
* @param[in] last_elem A container containing the offsets of
* the first element NOT IN the slice such that
* `last_elem[i]` is the offset along mode i.
*
* @return The requested slice.
*
* @throws ??? If the range-based method throws. Same throw guarantee.
*/
template<typename ContainerType0, typename ContainerType1>
slice_type slice(ContainerType0&& first_elem, ContainerType1&& last_elem) {
return slice(first_elem.begin(), first_elem.end(), last_elem.begin(),
last_elem.end());
}

/** @brief Implements slicing given two ranges.
*
* @tparam BeginItr The type of the iterators pointing to offsets in the
* container holding the first element of the slice.
* @tparam EndItr The type of the iterators pointing to the offsets in
* the container holding the first element NOT in the
* slice.
*
* All other slice functions dispatch to this method.
*
* Slices are assumed to be contiguous, meaning we can uniquely specify
* the slice by providing the first element IN the slice and the first
* element NOT IN the slice.
*
* Specifying an element of a rank @f$r@f$ tensor requires providing
* @f$r@f$ offsets (one for each mode). Generally speaking, this requires
* the offsets to be in a container. This method takes iterators to those
* containers such that the @f$r@f$ elements in the range
* [first_elem_begin, first_elem_end) are the offsets of first element IN
* the slice and [last_elem_begin, last_elem_end) are the offsets of the
* first element NOT IN the slice.
*
* @note Both [first_elem_begin, first_elem_end) and
* [last_elem_begin, last_elem_end) being empty is allowed as long
* as *this is null or for a scalar. In these cases you will get back
* the only slice possible, which is the entire shape, i.e. a copy of
* *this.
*
* @param[in] first_elem_begin An iterator to the offset along mode 0 for
* the first element in the slice.
* @param[in] first_elem_end An iterator pointing to just past the offset
* along mode "r-1" (r being the rank of *this) for the first
* element in the slice.
* @param[in] last_elem_begin An iterator to the offset along mode 0 for
* the first element NOT in the slice.
* @param[in] last_elem_end An iterator pointing to just past the offset
* along mode "r-1" (r being the rank of *this) for the first
* element NOT in the slice.
*
* @return The requested slice.
*
* @throw std::runtime_error if the range
* [first_elem_begin, first_elem_end) does not contain the same
* number of elements as [last_elem_begin, last_elem_end).
* Strong throw guarantee.
* @throw std::runtime_error if the offsets in the range
* [first_elem_begin, first_elem_end) do not come before the
* offsets in [last_elem_begin, last_elem_end). Strong throw
* guarantee.
* @throw std::runtime_error if [first_elem_begin, first_elem_end) and
* [last_elem_begin, last_elem_end) contain the
* same number of offsets, but that number is NOT
* equal to the rank of *this. Strong throw
* guarantee.
*
*/
template<typename BeginItr, typename EndItr>
slice_type slice(BeginItr first_elem_begin, BeginItr first_elem_end,
EndItr last_elem_begin, EndItr last_elem_end) const;
};

template<typename Derived>
template<typename BeginItr, typename EndItr>
inline auto LayoutCommon<Derived>::slice(BeginItr first_elem_begin,
BeginItr first_elem_end,
EndItr last_elem_begin,
EndItr last_elem_end) const
-> slice_type {
if(this->is_null()) return Derived{};
auto new_shape = shape().as_smooth().slice(first_elem_begin, first_elem_end,
last_elem_begin, last_elem_end);
auto new_symmetry = symmetry().slice(first_elem_begin, first_elem_end,
last_elem_begin, last_elem_end);
auto new_sparsity = sparsity().slice(first_elem_begin, first_elem_end,
last_elem_begin, last_elem_end);
return slice_type{new_shape, new_symmetry, new_sparsity};
}

} // namespace tensorwrapper::layout
29 changes: 29 additions & 0 deletions include/tensorwrapper/layout/layout_fwd.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2026 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace tensorwrapper::layout {

class LayoutBase;

template<typename Derived>
class LayoutCommon;

class Logical;
class Physical;

} // namespace tensorwrapper::layout
6 changes: 3 additions & 3 deletions include/tensorwrapper/layout/physical.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/

#pragma once
#include <tensorwrapper/layout/layout_base.hpp>

#include <tensorwrapper/layout/layout_common.hpp>
namespace tensorwrapper::layout {

/** @brief Specializes a LayoutBase for a layout describing how a tensor is
Expand All @@ -26,10 +26,10 @@ namespace tensorwrapper::layout {
* to hold details such as row major vs column major that matter for the
* physical layout, but not the logical layout.
*/
class Physical : public LayoutBase {
class Physical : public LayoutCommon<Physical> {
private:
/// Type *this derives from
using my_base_type = LayoutBase;
using my_base_type = LayoutCommon<Physical>;

public:
/// Pull in base class's types
Expand Down
5 changes: 2 additions & 3 deletions include/tensorwrapper/shape/shape_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

#pragma once
#include <cstddef>
#include <memory>
#include <tensorwrapper/detail_/dsl_base.hpp>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/shape/shape_traits.hpp>
#include <tensorwrapper/shape/smooth_view.hpp>
#include <tensorwrapper/types/shape_traits.hpp>

namespace tensorwrapper::shape {

Expand All @@ -42,7 +41,7 @@ class ShapeBase : public tensorwrapper::detail_::PolymorphicBase<ShapeBase>,
public tensorwrapper::detail_::DSLBase<ShapeBase> {
private:
/// Type implementing the traits of this
using traits_type = ShapeTraits<ShapeBase>;
using traits_type = types::ClassTraits<ShapeBase>;

protected:
/// Typedef of the PolymorphicBase class of *this
Expand Down
82 changes: 0 additions & 82 deletions include/tensorwrapper/shape/shape_traits.hpp

This file was deleted.

2 changes: 1 addition & 1 deletion include/tensorwrapper/shape/smooth.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
#pragma once
#include <functional>
#include <numeric>
#include <shape/shape_traits.hpp>
#include <shape/smooth_common.hpp>
#include <tensorwrapper/shape/shape_base.hpp>
#include <tensorwrapper/types/shape_traits.hpp>
#include <vector>

namespace tensorwrapper::shape {
Expand Down
Loading