Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
.idea/
.vscode/
.cache/
.cursor/

# These are common Python virtual enviornment directory names
venv/
Expand Down
116 changes: 20 additions & 96 deletions include/tensorwrapper/buffer/buffer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <tensorwrapper/buffer/buffer_base_common.hpp>
#include <tensorwrapper/detail_/dsl_base.hpp>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/dsl/labeled.hpp>
Expand All @@ -25,14 +26,19 @@ namespace tensorwrapper::buffer {

/** @brief Common base class for all buffer objects.
*
* All classes which wrap existing tensor libraries derive from this class.
* All classes which own their state and wrap existing tensor libraries derive
* from this class.
*/
class BufferBase : public tensorwrapper::detail_::PolymorphicBase<BufferBase>,
class BufferBase : public BufferBaseCommon<BufferBase>,
public tensorwrapper::detail_::PolymorphicBase<BufferBase>,
public tensorwrapper::detail_::DSLBase<BufferBase> {
private:
/// Type of *this
using my_type = BufferBase;

/// Type of the common base class
using common_base = BufferBaseCommon<my_type>;

/// Traits of my_type
using my_traits = types::ClassTraits<my_type>;

Expand All @@ -58,96 +64,9 @@ class BufferBase : public tensorwrapper::detail_::PolymorphicBase<BufferBase>,
using const_buffer_base_pointer =
typename my_traits::const_buffer_base_pointer;

/// Type of the class describing the physical layout of the buffer
using layout_type = layout::Physical;

/// Type of a read-only reference to a layout
using const_layout_reference = const layout_type&;

/// Type of a pointer to the layout
using layout_pointer = std::unique_ptr<layout_type>;

/// Type used to represent the tensor's rank
using rank_type = typename layout_type::size_type;

// -------------------------------------------------------------------------
// -- Accessors
// -------------------------------------------------------------------------

/** @brief Does *this have a layout?
*
* Default constructed or moved from BufferBase objects do not have
* layouts. This method is used to determine if *this has a layout or not.
*
* @return True if *this has a layout and false otherwise.
*
* @throw None No throw guarantee.
*/
bool has_layout() const noexcept { return static_cast<bool>(m_layout_); }

/** @brief Retrieves the layout of *this.
*
* This method can be used to retrieve the layout associated with *this,
* assuming there is one. See has_layout for determining if *this has a
* layout or not.
*
* @return A read-only reference to the layout.
*
* @throw std::runtime_error if *this does not have a layout. Strong throw
* guarantee.
*/
const_layout_reference layout() const {
assert_layout_();
return *m_layout_;
}

rank_type rank() const noexcept {
return has_layout() ? layout().rank() : 0;
}

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------

/** @brief Is *this value equal to @p rhs?
*
* Two BufferBase objects are value equal if the layouts they contain are
* polymorphically value equal or if both BufferBase objects do not contain
* a layout.
*
* @param[in] rhs The object to compare to.
*
* @return True if *this is value equal to @p rhs and false otherwise.
*
* @throw None No throw guarantee.
*/
bool operator==(const BufferBase& rhs) const noexcept {
if(has_layout() != rhs.has_layout()) return false;
if(has_layout() && m_layout_->are_different(*rhs.m_layout_))
return false;
return true;
}

/** @brief Is *this different from @p rhs?
*
* This method defines "different from" as being "not value equal." See
* the description of operator== for the definition of value equal.
*
* @param[in] rhs The object to compare to.
*
* @return False if *this is value equal to @p rhs and true otherwise.
*
* @throw None No throw guarantee.
*/

bool operator!=(const BufferBase& rhs) const noexcept {
return !(*this == rhs);
}

bool approximately_equal(const BufferBase& rhs, double tol) const {
return approximately_equal_(rhs, tol);
}

protected:
// -------------------------------------------------------------------------
// -- Ctors, assignment
Expand Down Expand Up @@ -215,6 +134,14 @@ class BufferBase : public tensorwrapper::detail_::PolymorphicBase<BufferBase>,
return *this;
}

// -------------------------------------------------------------------------
// -- BufferBaseCommon hooks
// -------------------------------------------------------------------------
friend common_base;
bool has_layout_() const noexcept { return static_cast<bool>(m_layout_); }

const_layout_reference layout_() const { return *m_layout_; }

dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;
Expand All @@ -233,19 +160,16 @@ class BufferBase : public tensorwrapper::detail_::PolymorphicBase<BufferBase>,
virtual bool approximately_equal_(const BufferBase& rhs,
double tol) const = 0;

template<typename BufferBaseType>
bool approximately_equal_(const BufferViewBase<BufferBaseType>& rhs,
double tol) const;

private:
template<typename FxnType>
dsl_reference binary_op_common_(FxnType&& fxn, label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs);

/// Throws std::runtime_error when there is no layout
void assert_layout_() const {
if(has_layout()) return;
throw std::runtime_error(
"Buffer has no layout. Was it default initialized?");
}

/// The layout of *this
layout_pointer m_layout_;
};
Expand Down
150 changes: 150 additions & 0 deletions include/tensorwrapper/buffer/buffer_base_common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Copyright 2024 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <stdexcept>
#include <tensorwrapper/buffer/buffer_fwd.hpp>
#include <tensorwrapper/types/buffer_traits.hpp>

namespace tensorwrapper::buffer {

/** @brief CRTP base factoring the common layout/equality API of BufferBase and
* BufferViewBase.
*
* Derived must implement the protected hooks: has_layout_(), layout_(),
* and approximately_equal_(const BufferBase&, double).
*
* @tparam Derived The CRTP derived type (BufferBase or BufferViewBase).
*/
template<typename Derived>
class BufferBaseCommon {
private:
/// Type of *this
using my_type = BufferBaseCommon<Derived>;

/// Traits for my_type
using traits_type = types::ClassTraits<my_type>;

/// Traits for Derived
using derived_traits = types::ClassTraits<Derived>;

public:
///@{
using layout_type = typename traits_type::layout_type;
using const_layout_reference = typename traits_type::const_layout_reference;
using rank_type = typename traits_type::rank_type;
///@}

// -------------------------------------------------------------------------
// -- Accessors
// -------------------------------------------------------------------------

/** @brief Does *this have a layout?
*
* @return True if *this has a layout and false otherwise.
*
* @throw None No throw guarantee.
*/
bool has_layout() const noexcept { return derived_().has_layout_(); }

/** @brief Retrieves the layout of *this.
*
* @return A read-only reference to the layout.
*
* @throw std::runtime_error if *this does not have a layout. Strong throw
* guarantee.
*/
const_layout_reference layout() const {
assert_layout_();
return derived_().layout_();
}

/** @brief Returns the rank of the layout.
*
* @return The rank, or 0 if *this has no layout.
*
* @throw None No throw guarantee.
*/
rank_type rank() const noexcept {
return has_layout() ? layout().rank() : 0;
}

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------

/** @brief Is *this value equal to @p rhs?
*
* @param[in] rhs The object to compare to.
*
* @return True if *this is value equal to @p rhs and false otherwise.
*
* @throw None No throw guarantee.
*/
template<typename OtherDerived>
bool operator==(const BufferBaseCommon<OtherDerived>& rhs) const noexcept {
if(has_layout() != rhs.has_layout()) return false;
if(has_layout() && layout().are_different(rhs.layout())) return false;
return true;
}

/** @brief Is *this different from @p rhs?
*
* @param[in] rhs The object to compare to.
*
* @return False if *this is value equal to @p rhs and true otherwise.
*
* @throw None No throw guarantee.
*/
template<typename OtherDerived>
bool operator!=(const BufferBaseCommon<OtherDerived>& rhs) const noexcept {
return !(*this == rhs);
}

/** @brief Are *this and @p rhs approximately equal within @p tol?
*
* @param[in] rhs The object to compare to.
* @param[in] tol The tolerance for the comparison.
*
* @return True if approximately equal, false otherwise.
*/
template<typename OtherDerived>
bool approximately_equal(const BufferBaseCommon<OtherDerived>& rhs,
double tol) const {
return derived_().approximately_equal_(rhs.derived_(), tol);
}

protected:
void assert_layout_() const {
if(!has_layout()) {
throw std::runtime_error(
"Buffer has no layout. Was it default initialized?");
}
}

private:
template<typename OtherDerived>
friend class BufferBaseCommon;

Derived& derived_() noexcept { return static_cast<Derived&>(*this); }

/// Access derived for CRTP
const Derived& derived_() const noexcept {
return *static_cast<const Derived*>(this);
}
};

} // namespace tensorwrapper::buffer
6 changes: 6 additions & 0 deletions include/tensorwrapper/buffer/buffer_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@

namespace tensorwrapper::buffer {

template<typename Derived>
class BufferBaseCommon;

class BufferBase;

template<typename BufferBaseType>
class BufferViewBase;

class Contiguous;

class Local;
Expand Down
Loading