From 98aac116c898af62c58bfbf7dc6a8f8a734cfdd1 Mon Sep 17 00:00:00 2001 From: "Ryan M. Richard" Date: Mon, 9 Mar 2026 22:13:32 -0500 Subject: [PATCH] Adds BufferBaseCommon --- .gitignore | 1 + include/tensorwrapper/buffer/buffer_base.hpp | 116 ++---------- .../buffer/buffer_base_common.hpp | 150 +++++++++++++++ include/tensorwrapper/buffer/buffer_fwd.hpp | 6 + .../tensorwrapper/buffer/buffer_view_base.hpp | 173 ++++++++++++++++++ include/tensorwrapper/types/buffer_traits.hpp | 16 +- .../tensorwrapper/buffer/buffer_base.cpp | 83 --------- .../buffer/buffer_base_common.cpp | 114 ++++++++++++ .../tensorwrapper/buffer/buffer_view_base.cpp | 136 ++++++++++++++ 9 files changed, 615 insertions(+), 180 deletions(-) create mode 100644 include/tensorwrapper/buffer/buffer_base_common.hpp create mode 100644 include/tensorwrapper/buffer/buffer_view_base.hpp delete mode 100644 tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp create mode 100644 tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base_common.cpp create mode 100644 tests/cxx/unit_tests/tensorwrapper/buffer/buffer_view_base.cpp diff --git a/.gitignore b/.gitignore index 2df4161c..2715503b 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ .idea/ .vscode/ .cache/ +.cursor/ # These are common Python virtual enviornment directory names venv/ diff --git a/include/tensorwrapper/buffer/buffer_base.hpp b/include/tensorwrapper/buffer/buffer_base.hpp index 9926c1e6..d24201b9 100644 --- a/include/tensorwrapper/buffer/buffer_base.hpp +++ b/include/tensorwrapper/buffer/buffer_base.hpp @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include #include @@ -25,14 +26,19 @@ namespace tensorwrapper::buffer { /** @brief Common base class for all buffer objects. * - * All classes which wrap existing tensor libraries derive from this class. + * All classes which own their state and wrap existing tensor libraries derive + * from this class. */ -class BufferBase : public tensorwrapper::detail_::PolymorphicBase, +class BufferBase : public BufferBaseCommon, + public tensorwrapper::detail_::PolymorphicBase, public tensorwrapper::detail_::DSLBase { private: /// Type of *this using my_type = BufferBase; + /// Type of the common base class + using common_base = BufferBaseCommon; + /// Traits of my_type using my_traits = types::ClassTraits; @@ -58,96 +64,9 @@ class BufferBase : public tensorwrapper::detail_::PolymorphicBase, using const_buffer_base_pointer = typename my_traits::const_buffer_base_pointer; - /// Type of the class describing the physical layout of the buffer - using layout_type = layout::Physical; - - /// Type of a read-only reference to a layout - using const_layout_reference = const layout_type&; - /// Type of a pointer to the layout using layout_pointer = std::unique_ptr; - /// Type used to represent the tensor's rank - using rank_type = typename layout_type::size_type; - - // ------------------------------------------------------------------------- - // -- Accessors - // ------------------------------------------------------------------------- - - /** @brief Does *this have a layout? - * - * Default constructed or moved from BufferBase objects do not have - * layouts. This method is used to determine if *this has a layout or not. - * - * @return True if *this has a layout and false otherwise. - * - * @throw None No throw guarantee. - */ - bool has_layout() const noexcept { return static_cast(m_layout_); } - - /** @brief Retrieves the layout of *this. - * - * This method can be used to retrieve the layout associated with *this, - * assuming there is one. See has_layout for determining if *this has a - * layout or not. - * - * @return A read-only reference to the layout. - * - * @throw std::runtime_error if *this does not have a layout. Strong throw - * guarantee. - */ - const_layout_reference layout() const { - assert_layout_(); - return *m_layout_; - } - - rank_type rank() const noexcept { - return has_layout() ? layout().rank() : 0; - } - - // ------------------------------------------------------------------------- - // -- Utility methods - // ------------------------------------------------------------------------- - - /** @brief Is *this value equal to @p rhs? - * - * Two BufferBase objects are value equal if the layouts they contain are - * polymorphically value equal or if both BufferBase objects do not contain - * a layout. - * - * @param[in] rhs The object to compare to. - * - * @return True if *this is value equal to @p rhs and false otherwise. - * - * @throw None No throw guarantee. - */ - bool operator==(const BufferBase& rhs) const noexcept { - if(has_layout() != rhs.has_layout()) return false; - if(has_layout() && m_layout_->are_different(*rhs.m_layout_)) - return false; - return true; - } - - /** @brief Is *this different from @p rhs? - * - * This method defines "different from" as being "not value equal." See - * the description of operator== for the definition of value equal. - * - * @param[in] rhs The object to compare to. - * - * @return False if *this is value equal to @p rhs and true otherwise. - * - * @throw None No throw guarantee. - */ - - bool operator!=(const BufferBase& rhs) const noexcept { - return !(*this == rhs); - } - - bool approximately_equal(const BufferBase& rhs, double tol) const { - return approximately_equal_(rhs, tol); - } - protected: // ------------------------------------------------------------------------- // -- Ctors, assignment @@ -215,6 +134,14 @@ class BufferBase : public tensorwrapper::detail_::PolymorphicBase, return *this; } + // ------------------------------------------------------------------------- + // -- BufferBaseCommon hooks + // ------------------------------------------------------------------------- + friend common_base; + bool has_layout_() const noexcept { return static_cast(m_layout_); } + + const_layout_reference layout_() const { return *m_layout_; } + dsl_reference addition_assignment_(label_type this_labels, const_labeled_reference lhs, const_labeled_reference rhs) override; @@ -233,19 +160,16 @@ class BufferBase : public tensorwrapper::detail_::PolymorphicBase, virtual bool approximately_equal_(const BufferBase& rhs, double tol) const = 0; + template + bool approximately_equal_(const BufferViewBase& rhs, + double tol) const; + private: template dsl_reference binary_op_common_(FxnType&& fxn, label_type this_labels, const_labeled_reference lhs, const_labeled_reference rhs); - /// Throws std::runtime_error when there is no layout - void assert_layout_() const { - if(has_layout()) return; - throw std::runtime_error( - "Buffer has no layout. Was it default initialized?"); - } - /// The layout of *this layout_pointer m_layout_; }; diff --git a/include/tensorwrapper/buffer/buffer_base_common.hpp b/include/tensorwrapper/buffer/buffer_base_common.hpp new file mode 100644 index 00000000..10d253fc --- /dev/null +++ b/include/tensorwrapper/buffer/buffer_base_common.hpp @@ -0,0 +1,150 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace tensorwrapper::buffer { + +/** @brief CRTP base factoring the common layout/equality API of BufferBase and + * BufferViewBase. + * + * Derived must implement the protected hooks: has_layout_(), layout_(), + * and approximately_equal_(const BufferBase&, double). + * + * @tparam Derived The CRTP derived type (BufferBase or BufferViewBase). + */ +template +class BufferBaseCommon { +private: + /// Type of *this + using my_type = BufferBaseCommon; + + /// Traits for my_type + using traits_type = types::ClassTraits; + + /// Traits for Derived + using derived_traits = types::ClassTraits; + +public: + ///@{ + using layout_type = typename traits_type::layout_type; + using const_layout_reference = typename traits_type::const_layout_reference; + using rank_type = typename traits_type::rank_type; + ///@} + + // ------------------------------------------------------------------------- + // -- Accessors + // ------------------------------------------------------------------------- + + /** @brief Does *this have a layout? + * + * @return True if *this has a layout and false otherwise. + * + * @throw None No throw guarantee. + */ + bool has_layout() const noexcept { return derived_().has_layout_(); } + + /** @brief Retrieves the layout of *this. + * + * @return A read-only reference to the layout. + * + * @throw std::runtime_error if *this does not have a layout. Strong throw + * guarantee. + */ + const_layout_reference layout() const { + assert_layout_(); + return derived_().layout_(); + } + + /** @brief Returns the rank of the layout. + * + * @return The rank, or 0 if *this has no layout. + * + * @throw None No throw guarantee. + */ + rank_type rank() const noexcept { + return has_layout() ? layout().rank() : 0; + } + + // ------------------------------------------------------------------------- + // -- Utility methods + // ------------------------------------------------------------------------- + + /** @brief Is *this value equal to @p rhs? + * + * @param[in] rhs The object to compare to. + * + * @return True if *this is value equal to @p rhs and false otherwise. + * + * @throw None No throw guarantee. + */ + template + bool operator==(const BufferBaseCommon& rhs) const noexcept { + if(has_layout() != rhs.has_layout()) return false; + if(has_layout() && layout().are_different(rhs.layout())) return false; + return true; + } + + /** @brief Is *this different from @p rhs? + * + * @param[in] rhs The object to compare to. + * + * @return False if *this is value equal to @p rhs and true otherwise. + * + * @throw None No throw guarantee. + */ + template + bool operator!=(const BufferBaseCommon& rhs) const noexcept { + return !(*this == rhs); + } + + /** @brief Are *this and @p rhs approximately equal within @p tol? + * + * @param[in] rhs The object to compare to. + * @param[in] tol The tolerance for the comparison. + * + * @return True if approximately equal, false otherwise. + */ + template + bool approximately_equal(const BufferBaseCommon& rhs, + double tol) const { + return derived_().approximately_equal_(rhs.derived_(), tol); + } + +protected: + void assert_layout_() const { + if(!has_layout()) { + throw std::runtime_error( + "Buffer has no layout. Was it default initialized?"); + } + } + +private: + template + friend class BufferBaseCommon; + + Derived& derived_() noexcept { return static_cast(*this); } + + /// Access derived for CRTP + const Derived& derived_() const noexcept { + return *static_cast(this); + } +}; + +} // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/buffer/buffer_fwd.hpp b/include/tensorwrapper/buffer/buffer_fwd.hpp index 98f41eea..74b63fb4 100644 --- a/include/tensorwrapper/buffer/buffer_fwd.hpp +++ b/include/tensorwrapper/buffer/buffer_fwd.hpp @@ -18,8 +18,14 @@ namespace tensorwrapper::buffer { +template +class BufferBaseCommon; + class BufferBase; +template +class BufferViewBase; + class Contiguous; class Local; diff --git a/include/tensorwrapper/buffer/buffer_view_base.hpp b/include/tensorwrapper/buffer/buffer_view_base.hpp new file mode 100644 index 00000000..21039688 --- /dev/null +++ b/include/tensorwrapper/buffer/buffer_view_base.hpp @@ -0,0 +1,173 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace tensorwrapper::buffer { + +/** @brief View of a BufferBase that aliases existing state instead of owning + * it. + * + * BufferViewBase has the same layout/equality API as BufferBase (has_layout(), + * layout(), rank(), operator==, operator!=, approximately_equal) but holds a + * non-owning pointer to a BufferBase and delegates all operations to it. + * + * BufferViewBase is templated on the type of the aliased buffer, which must + * be either BufferBase or const BufferBase. This controls whether the view is + * a mutable or const view of the underlying BufferBase. + * + * The aliased buffer must outlive this view. Default-constructed or + * moved-from views have no aliased buffer (has_layout() is false, layout() + * throws). + * + * @tparam BufferBaseType Either BufferBase or const BufferBase. + */ +template +class BufferViewBase : public BufferBaseCommon> { +private: + static_assert(std::is_same_v || + std::is_same_v, + "BufferViewBase BufferBaseType must be BufferBase or " + "const BufferBase"); + + /// Type *this derives from + using my_base_type = BufferBaseCommon>; + using typename my_base_type::const_layout_reference; + + using aliased_type = BufferBaseType; + using aliased_pointer = aliased_type*; + +public: + // ------------------------------------------------------------------------- + // -- Ctors and assignment + // ------------------------------------------------------------------------- + + /** @brief Creates a view that aliases no buffer. + * + * @throw None No throw guarantee. + */ + BufferViewBase() noexcept : m_aliased_(nullptr) {} + + /** @brief Creates a view that aliases @p buffer. + * + * @param[in] buffer The buffer to alias. Must outlive *this. + * + * @throw None No throw guarantee. + */ + explicit BufferViewBase(aliased_type& buffer) noexcept : + m_aliased_(&buffer) {} + + /** @brief Creates a view that aliases the same buffer as @p other. + * + * @param[in] other The view to copy. + * + * @throw None No throw guarantee. + */ + BufferViewBase(const BufferViewBase& other) noexcept = default; + + /** @brief Creates a view by taking the alias from @p other. + * + * After construction *this aliases the buffer @p other did, and @p other + * aliases no buffer. + * + * @param[in,out] other The view to move from. + * + * @throw None No throw guarantee. + */ + BufferViewBase(BufferViewBase&& other) noexcept = default; + + /** @brief Makes *this alias the same buffer as @p rhs. + * + * @param[in] rhs The view to copy. + * + * @return *this. + * + * @throw None No throw guarantee. + */ + BufferViewBase& operator=(const BufferViewBase& rhs) noexcept = default; + + /** @brief Replaces the alias in *this with that of @p rhs. + * + * @param[in,out] rhs The view to move from. + * + * @return *this. + * + * @throw None No throw guarantee. + */ + BufferViewBase& operator=(BufferViewBase&& rhs) noexcept = default; + + /** @brief Is *this different from @p rhs? + * + * @param[in] rhs The view to compare to. + * + * @return False if *this is value equal to @p rhs and true otherwise. + * + * @throw None No throw guarantee. + */ + bool operator!=(const BufferViewBase& rhs) const noexcept { + return !(*this == rhs); + } + +protected: + friend my_base_type; + friend class BufferBase; + + // ------------------------------------------------------------------------- + // -- BufferBaseCommon hooks + // ------------------------------------------------------------------------- + + bool has_layout_() const noexcept { + return m_aliased_ != nullptr && m_aliased_->has_layout(); + } + + const_layout_reference layout_() const { + if(m_aliased_ == nullptr) { + throw std::runtime_error( + "Buffer has no layout. Was it default initialized?"); + } + return m_aliased_->layout(); + } + + template + bool approximately_equal_(const BufferViewBase& rhs, + double tol) const { + if(m_aliased_ == nullptr) return !rhs.has_layout(); + return m_aliased_->approximately_equal(*rhs.m_aliased_, tol); + } + + bool approximately_equal_(const BufferBase& rhs, double tol) const { + if(m_aliased_ == nullptr) return !rhs.has_layout(); + return m_aliased_->approximately_equal(rhs, tol); + } + +private: + /// The buffer *this aliases (non-owning) + aliased_pointer m_aliased_; +}; + +// Out-of-line definition so both BufferBase and BufferViewBase are complete +template +bool BufferBase::approximately_equal_(const BufferViewBase& rhs, + double tol) const { + if(!rhs.has_layout()) return !has_layout(); + return approximately_equal_( + *static_cast(rhs.m_aliased_), tol); +} + +} // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/types/buffer_traits.hpp b/include/tensorwrapper/types/buffer_traits.hpp index be3e7a4e..e42a5b7e 100644 --- a/include/tensorwrapper/types/buffer_traits.hpp +++ b/include/tensorwrapper/types/buffer_traits.hpp @@ -17,12 +17,26 @@ #pragma once #include #include +#include #include namespace tensorwrapper::types { +template +struct ClassTraits> { + /// Type of the class describing the physical layout of the buffer + using layout_type = layout::Physical; + + /// Type of a read-only reference to a layout + using const_layout_reference = const layout_type&; + + /// Type used to represent the tensor's rank + using rank_type = typename layout_type::size_type; +}; + template<> -struct ClassTraits { +struct ClassTraits + : public ClassTraits> { /// Type all buffers inherit from using buffer_base_type = buffer::BufferBase; diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp deleted file mode 100644 index 6de12348..00000000 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2024 NWChemEx-Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../testing/testing.hpp" -#include -#include -#include - -using namespace tensorwrapper; -using namespace buffer; - -/* Testing strategy: - * - * - BufferBase is an abstract class. To test it we must create an instance of - * a derived class. We then will upcast to BufferBase and perform checks - * through the BufferBase interface. - * - `xxx_assignment` methods are tested in the derived classes; however, the - * corresponding `xxx` method is defined in BufferBase and thus is tested - * here (`xxx` being `addition`, `subtraction`, etc.). - * - */ - -TEST_CASE("BufferBase") { - auto pscalar = testing::eigen_scalar(); - auto& scalar = *pscalar; - scalar.set_elem({}, 1.0); - - auto pvector = testing::eigen_vector(2); - auto& vector = *pvector; - - vector.set_elem({0}, 1.0); - vector.set_elem({1}, 2.0); - - auto scalar_layout = testing::scalar_physical(); - auto vector_layout = testing::vector_physical(2); - - buffer::Contiguous defaulted; - BufferBase& defaulted_base = defaulted; - BufferBase& scalar_base = scalar; - BufferBase& vector_base = vector; - - SECTION("has_layout") { - REQUIRE_FALSE(defaulted_base.has_layout()); - REQUIRE(scalar_base.has_layout()); - REQUIRE(vector_base.has_layout()); - } - - SECTION("layout") { - REQUIRE_THROWS_AS(defaulted_base.layout(), std::runtime_error); - REQUIRE(scalar_base.layout().are_equal(scalar_layout)); - REQUIRE(vector_base.layout().are_equal(vector_layout)); - } - - SECTION("operator==") { - // Defaulted layout == defaulted layout - REQUIRE(defaulted_base == buffer::Contiguous{}); - - // Defaulted layout != non-defaulted layout - REQUIRE_FALSE(defaulted_base == scalar_base); - - // Non-defaulted layout different value - REQUIRE_FALSE(scalar_base == vector_base); - } - - SECTION("operator!=") { - // Just spot check because it negates operator==, which was tested - REQUIRE(defaulted_base != scalar_base); - REQUIRE_FALSE(defaulted_base != buffer::Contiguous()); - } -} diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base_common.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base_common.cpp new file mode 100644 index 00000000..193508ca --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base_common.cpp @@ -0,0 +1,114 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../testing/testing.hpp" +#include +#include +#include +#include +#include + +using namespace tensorwrapper; +using namespace buffer; + +TEST_CASE("BufferBaseCommon") { + using MutableView = BufferViewBase; + using ConstView = BufferViewBase; + + auto pscalar = testing::eigen_scalar(); + auto& scalar = *pscalar; + scalar.set_elem({}, 1.0); + + auto pvector = testing::eigen_vector(2); + auto& vector = *pvector; + vector.set_elem({0}, 1.0); + vector.set_elem({1}, 2.0); + + auto scalar_layout = testing::scalar_physical(); + auto vector_layout = testing::vector_physical(2); + + buffer::Contiguous defaulted; + MutableView defaulted_view(defaulted); + MutableView scalar_view(scalar); + MutableView vector_view(vector); + ConstView defaulted_const_view(defaulted); + ConstView scalar_const_view(scalar); + ConstView vector_const_view(vector); + + SECTION("operator== (BufferBase with BufferBaseView)") { + REQUIRE(defaulted_view == defaulted); + REQUIRE(defaulted_const_view == defaulted); + REQUIRE(defaulted == defaulted_view); + REQUIRE(defaulted == defaulted_const_view); + + REQUIRE_FALSE(defaulted_view == scalar); + REQUIRE_FALSE(defaulted_const_view == scalar); + REQUIRE_FALSE(scalar == defaulted_view); + REQUIRE_FALSE(scalar == defaulted_const_view); + } + + SECTION("operator!= (BufferBasewith BufferBase") { + REQUIRE(scalar_view != vector); + REQUIRE(scalar_const_view != vector); + REQUIRE(vector != scalar_view); + REQUIRE(vector != scalar_const_view); + + REQUIRE(scalar_view != defaulted); + REQUIRE(scalar_const_view != defaulted); + REQUIRE(defaulted != scalar_view); + REQUIRE(defaulted != scalar_const_view); + + REQUIRE_FALSE(scalar_view != scalar); + REQUIRE_FALSE(scalar_const_view != scalar); + REQUIRE_FALSE(scalar != scalar_view); + REQUIRE_FALSE(scalar != scalar_const_view); + } + + SECTION("BufferBase operator== with BufferBase") { + REQUIRE(scalar == scalar); + REQUIRE_FALSE(scalar == vector); + REQUIRE_FALSE(defaulted == scalar); + } + + SECTION("BufferViewBase operator== with BufferViewBase") { + REQUIRE(scalar_view == scalar_view); + REQUIRE(scalar_const_view == scalar_const_view); + REQUIRE(scalar_view == scalar_const_view); + REQUIRE(scalar_const_view == scalar_view); + REQUIRE_FALSE(scalar_view == vector_view); + REQUIRE_FALSE(defaulted_view == scalar_view); + REQUIRE_FALSE(scalar_const_view == vector_view); + REQUIRE_FALSE(vector_view == scalar_const_view); + REQUIRE_FALSE(defaulted_view == vector_view); + REQUIRE_FALSE(vector_view == defaulted_view); + } + + SECTION("approximately_equal") { + REQUIRE(scalar.approximately_equal(scalar, 1e-10)); + REQUIRE(scalar_view.approximately_equal(scalar_view, 1e-10)); + REQUIRE(scalar_view.approximately_equal(scalar, 1e-10)); + REQUIRE(scalar.approximately_equal(scalar_view, 1e-10)); + + REQUIRE_FALSE(scalar_view.approximately_equal(vector, 1e-10)); + REQUIRE_FALSE(vector.approximately_equal(scalar_view, 1e-10)); + } + + SECTION("Null view equals buffer with no layout") { + ConstView null_view; + REQUIRE(null_view == defaulted); + REQUIRE_FALSE(null_view == scalar); + } +} diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_view_base.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_view_base.cpp new file mode 100644 index 00000000..61a9d9cf --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_view_base.cpp @@ -0,0 +1,136 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../testing/testing.hpp" +#include +#include +#include +#include + +using namespace tensorwrapper; +using namespace buffer; + +TEST_CASE("BufferViewBase") { + using MutableView = BufferViewBase; + using ConstView = BufferViewBase; + + auto pscalar = testing::eigen_scalar(); + auto& scalar = *pscalar; + scalar.set_elem({}, 1.0); + + auto pvector = testing::eigen_vector(2); + auto& vector = *pvector; + vector.set_elem({0}, 1.0); + vector.set_elem({1}, 2.0); + + auto scalar_layout = testing::scalar_physical(); + auto vector_layout = testing::vector_physical(2); + + buffer::Contiguous defaulted; + + SECTION("Default construction") { + ConstView defaulted_const_view; + MutableView defaulted_view; + REQUIRE_FALSE(defaulted_const_view.has_layout()); + REQUIRE_FALSE(defaulted_view.has_layout()); + REQUIRE_THROWS_AS(defaulted_const_view.layout(), std::runtime_error); + REQUIRE_THROWS_AS(defaulted_view.layout(), std::runtime_error); + REQUIRE(defaulted_const_view.rank() == 0); + REQUIRE(defaulted_view.rank() == 0); + } + + SECTION("Construct from buffer") { + ConstView scalar_const_view(scalar); + MutableView scalar_view(scalar); + REQUIRE(scalar_const_view.has_layout()); + REQUIRE(scalar_view.has_layout()); + REQUIRE(scalar_const_view.layout().are_equal(scalar_layout)); + REQUIRE(scalar_const_view.rank() == 0); + REQUIRE(scalar_view.layout().are_equal(scalar_layout)); + REQUIRE(scalar_view.rank() == 0); + + ConstView vector_const_view(vector); + MutableView vector_view(vector); + REQUIRE(vector_const_view.has_layout()); + REQUIRE(vector_view.has_layout()); + REQUIRE(vector_const_view.layout().are_equal(vector_layout)); + REQUIRE(vector_const_view.rank() == 1); + REQUIRE(vector_view.layout().are_equal(vector_layout)); + REQUIRE(vector_view.rank() == 1); + } + + SECTION("Copy construction") { + ConstView const_view(scalar); + ConstView copy_const(const_view); + REQUIRE(copy_const.has_layout()); + REQUIRE(copy_const.layout().are_equal(scalar_layout)); + REQUIRE(copy_const.rank() == 0); + + MutableView mutable_view(scalar); + MutableView copy_mutable(mutable_view); + REQUIRE(copy_mutable.has_layout()); + REQUIRE(copy_mutable.layout().are_equal(scalar_layout)); + REQUIRE(copy_mutable.rank() == 0); + } + + SECTION("Move construction") { + ConstView const_view(scalar); + ConstView moved_const(std::move(const_view)); + REQUIRE(moved_const.has_layout()); + REQUIRE(moved_const.layout().are_equal(scalar_layout)); + REQUIRE(moved_const.rank() == 0); + + MutableView mutable_view(scalar); + MutableView moved(std::move(mutable_view)); + REQUIRE(moved.has_layout()); + REQUIRE(moved.layout().are_equal(scalar_layout)); + REQUIRE(moved.rank() == 0); + } + + SECTION("Copy assignment") { + ConstView const_view(scalar); + ConstView other_const; + auto pother_const = &(other_const = const_view); + REQUIRE(pother_const == &other_const); + REQUIRE(other_const.has_layout()); + REQUIRE(other_const.layout().are_equal(scalar_layout)); + REQUIRE(other_const.rank() == 0); + + MutableView mutable_view(scalar); + MutableView other; + other = mutable_view; + REQUIRE(other.has_layout()); + REQUIRE(other.layout().are_equal(scalar_layout)); + REQUIRE(other.rank() == 0); + } + + SECTION("Move assignment") { + ConstView const_view(scalar); + ConstView other_const; + auto pother_const = &(other_const = std::move(const_view)); + REQUIRE(pother_const == &other_const); + REQUIRE(other_const.has_layout()); + REQUIRE(other_const.layout().are_equal(scalar_layout)); + REQUIRE(other_const.rank() == 0); + + MutableView mutable_view(scalar); + MutableView other_mutable; + other_mutable = std::move(mutable_view); + REQUIRE(other_mutable.has_layout()); + REQUIRE(other_mutable.layout().are_equal(scalar_layout)); + REQUIRE(other_mutable.rank() == 0); + } +}