diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index b48bacbe..38e6ceb7 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -40,11 +40,10 @@ lq_full lq_compact ``` -Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithm: +The following algorithm is available for QR and LQ decompositions: ```@docs; canonical=false -LAPACK_HouseholderQR -LAPACK_HouseholderLQ +Householder ``` ## Eigenvalue Decomposition @@ -63,9 +62,9 @@ These functions return the diagonal elements of `D` in a vector. Finally, it is also possible to compute a partial or truncated eigenvalue decomposition, using the [`eig_trunc`](@ref) and [`eigh_trunc`](@ref) functions. To control the behavior of the truncation, we refer to [Truncations](@ref) for more information. -### Symmetric Eigenvalue Decomposition +### Hermitian or Real Symmetric Eigenvalue Decomposition -For symmetric matrices, we provide the following functions: +For Hermitian matrices, which include real symmetric matrices, we provide the following functions: ```@docs; canonical=false eigh_full @@ -78,7 +77,7 @@ eigh_vals By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. See [Gauge choices](@ref sec_gaugefix) for more details. -Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms: +The following algorithms are available for the Hermitian eigenvalue decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] @@ -100,7 +99,7 @@ eig_vals By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. See [Gauge choices](@ref sec_gaugefix) for more details. 
-Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms: +The following algorithms are available for the standard eigenvalue decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] @@ -120,7 +119,7 @@ schur_full schur_vals ``` -The LAPACK-based implementation for dense arrays is provided by the following algorithms: +The following algorithms are available for the Schur decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] @@ -153,11 +152,11 @@ svd_trunc By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. See [Gauge choices](@ref sec_gaugefix) for more details. -MatrixAlgebraKit again ships with LAPACK-based implementations for dense arrays: +The following algorithms are available for the singular value decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] -Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_SVDAlgorithm +Filter = t -> t isa Type && t <: MatrixAlgebraKit.SVDAlgorithms ``` ## Polar Decomposition @@ -388,6 +387,54 @@ norm(A * N1') < 1e-14 && norm(A * N2') < 1e-14 && true ``` +## [Driver Selection](@id sec_driverselection) + +!!! note "Expert use case" + Selecting a specific driver is an advanced feature intended for users who need to target a specific computational backend, such as a GPU. For most use cases, the default driver selection is sufficient. + +Each algorithm in MatrixAlgebraKit can optionally accept a `driver` keyword argument to explicitly select the computational backend. +By default, the driver is set to `DefaultDriver()`, which automatically selects the most appropriate backend based on the input matrix type. 
+The available drivers are: + +```@docs; canonical=false +MatrixAlgebraKit.DefaultDriver +MatrixAlgebraKit.LAPACK +MatrixAlgebraKit.CUSOLVER +MatrixAlgebraKit.ROCSOLVER +MatrixAlgebraKit.GLA +MatrixAlgebraKit.Native +``` + +For example, to force LAPACK for a generic matrix type, or to use a GPU backend: + +```julia +using MatrixAlgebraKit +using MatrixAlgebraKit: LAPACK, CUSOLVER # driver types are not exported by default + +# Default: driver is selected automatically based on the input type +U, S, Vᴴ = svd_compact(A) +U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer()) + +# Expert: explicitly select LAPACK +U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer(; driver = LAPACK())) + +# Expert: use a GPU backend (requires loading the appropriate extension) +U, S, Vᴴ = svd_compact(A; alg = QRIteration(; driver = CUSOLVER())) +``` + +Similarly, for QR decompositions: + +```julia +using MatrixAlgebraKit: LAPACK # driver types are not exported by default + +# Default: driver is selected automatically +Q, R = qr_compact(A) +Q, R = qr_compact(A; alg = Householder()) + +# Expert: explicitly select a driver +Q, R = qr_compact(A; alg = Householder(; driver = LAPACK())) +``` + ## [Gauge choices](@id sec_gaugefix) Both eigenvalue and singular value decompositions have residual gauge degrees of freedom even when the eigenvalues or singular values are unique. diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index ebec311e..24435d71 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -6,7 +6,7 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! 
using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm -import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! +import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj! import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank using AMDGPU using LinearAlgebra @@ -14,11 +14,12 @@ using LinearAlgebra: BlasFloat include("yarocsolver.jl") -MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCMatrix{<:BlasFloat}} = ROCSOLVER() -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix} - return ROCSOLVER_QRIteration(; kwargs...) +MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER() + +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}} + return QRIteration(; kwargs...) end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}} return ROCSOLVER_DivideAndConquer(; kwargs...) end @@ -26,13 +27,21 @@ for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...) end -_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) = - YArocSOLVER.gesvd!(A, S, U, Vᴴ) -# not yet supported -# _gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = -# YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...) -_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) 
= - YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) +MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) +MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) + +function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) + m, n = size(A) + m >= n && return YArocSOLVER.gesvd!(A, S, U, Vᴴ) + return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...) +end + +function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) + m, n = size(A) + m >= n && return YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) + return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...) +end + _gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...) _gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 937027a4..166cd666 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -6,8 +6,8 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm -import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! -import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank +import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev! 
+import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank using CUDA, CUDA.CUBLAS using CUDA: i32 using LinearAlgebra @@ -15,31 +15,46 @@ using LinearAlgebra: BlasFloat include("yacusolver.jl") -MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER() -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} - return CUSOLVER_QRIteration(; kwargs...) +MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER() + +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} + return QRIteration(; kwargs...) end -function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} +function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} return CUSOLVER_Simple(; kwargs...) end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} return CUSOLVER_DivideAndConquer(; kwargs...) end + for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...) end +MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) +MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) + +function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) + m, n = size(A) + m >= n && return YACUSOLVER.gesvd!(A, S, U, Vᴴ) + return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, CUSOLVER(), A, S, U, Vᴴ; kwargs...) 
+end + +function gesvdj!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) + m, n = size(A) + m >= n && return YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) + return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, CUSOLVER(), A, S, U, Vᴴ; kwargs...) +end + +gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = + YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...) + +_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = + YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...) + _gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V) -_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) = - YACUSOLVER.gesvd!(A, S, U, Vᴴ) -_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = - YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...) -_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = - YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...) -_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = - YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) _gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...) 
diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index aba01fe6..be63eb85 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -98,7 +98,7 @@ for (bname, fname, elty, relty) in end end -function Xgesvdp!( +function gesvdp!( A::StridedCuMatrix{T}, S::StridedCuVector = similar(A, real(T), min(size(A)...)), U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)), @@ -164,9 +164,7 @@ function Xgesvdp!( ) end err = h_err_sigma[] - if err > tol - warn("Xgesvdp! did not attained requested tolerance: error = $err > tolerance = $tol") - end + err > tol && @warn "gesvdp! did not attain the requested tolerance: error = $err > tolerance = $tol" flag = @allowscalar dh.info[1] CUSOLVER.chklapackerror(BlasInt(flag)) @@ -269,7 +267,7 @@ for (bname, fname, elty, relty) in end # Wrapper for randomized SVD -function Xgesvdr!( +function gesvdr!( A::StridedCuMatrix{T}, S::StridedCuVector = similar(A, real(T), min(size(A)...)), U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)), diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 629972cf..bb349b7e 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -2,45 +2,37 @@ module MatrixAlgebraKitGenericLinearAlgebraExt using MatrixAlgebraKit using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge +using MatrixAlgebraKit: GLA +import MatrixAlgebraKit: gesvd! using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} - return GLA_QRIteration() -end - -for f! in (:svd_compact!, :svd_full!) 
- @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing) -end -MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing +const GlaFloat = Union{BigFloat, Complex{BigFloat}} +const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}} +MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA() -function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration) - F = svd!(A) - U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) - - return U, S, Vᴴ +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} + return QRIteration(; kwargs...) end -function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration) - F = svd!(A; full = true) - U, Vᴴ = F.U, F.Vt - S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1)))) - diagview(S) .= F.S - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) - - return U, S, Vᴴ -end - -function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration) - return svdvals!(A) +function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) 
+ m, n = size(A) + if length(U) == 0 && length(Vᴴ) == 0 + Sv = svdvals!(A) + copyto!(S, Sv) + else + minmn = min(m, n) + # full SVD if U has m columns or Vᴴ has n rows (beyond the compact min(m,n)) + full = (length(U) > 0 && size(U, 2) > minmn) || (length(Vᴴ) > 0 && size(Vᴴ, 1) > minmn) + F = svd!(A; full = full) + length(S) > 0 && copyto!(S, F.S) + length(U) > 0 && copyto!(U, F.U) + length(Vᴴ) > 0 && copyto!(Vᴴ, F.Vt) + end + return S, U, Vᴴ end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} return GLA_QRIteration(; kwargs...) end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 2d5e251e..9135b6f9 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -32,6 +32,7 @@ export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! 
export Householder, Native_HouseholderQR, Native_HouseholderLQ +export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer diff --git a/src/algorithms.jl b/src/algorithms.jl index 07685f6a..cf1d2ff4 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -212,6 +212,23 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati """ struct Native <: Driver end +# In order to avoid ambiguities, this method is implemented in a tiered way: +# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A)) +# default_driver(Talg, TA) -> default_driver(TA) +# This tries to minimize ambiguity while allowing overloading at multiple levels +@inline default_driver(alg::AbstractAlgorithm, A) = default_driver(typeof(alg), A isa Type ? 
A : typeof(A)) +@inline default_driver(::Type{Alg}, A) where {Alg <: AbstractAlgorithm} = default_driver(Alg, typeof(A)) +@inline default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA} = default_driver(TA) + +# defaults +default_driver(::Type{TA}) where {TA <: AbstractArray} = Native() # default fallback +default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK() + +# wrapper types +@inline default_driver(::Type{Alg}, ::Type{<:SubArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A) +@inline default_driver(::Type{Alg}, ::Type{<:Base.ReshapedArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A) +@inline default_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_driver(A) +@inline default_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_driver(A) # Truncation strategy # ------------------- diff --git a/src/common/defaults.jl b/src/common/defaults.jl index bc4160a1..c64c24ad 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -59,3 +59,5 @@ function default_fixgauge(new_value::Bool) DEFAULT_FIXGAUGE[] = new_value return previous_value end + +const _fixgauge_docs = "The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the output, see also [`default_fixgauge`](@ref) for a global toggle and [`gaugefix!`](@ref) for implementation details." diff --git a/src/common/gauge.jl b/src/common/gauge.jl index e855548f..f045d4e9 100644 --- a/src/common/gauge.jl +++ b/src/common/gauge.jl @@ -8,6 +8,9 @@ This is achieved by ensuring that the entry with the largest magnitude in `V` or is real and positive. """ gaugefix! +# Helper functions +_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x))) +_largest(x, y) = abs(x) < abs(y) ? 
y : x function gaugefix!(::Union{typeof(eig_full!), typeof(eigh_full!), typeof(gen_eig_full!)}, V::AbstractMatrix) for j in axes(V, 2) diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index cf64ccd6..d248ac40 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -115,7 +115,7 @@ end @inline householder_lq!(A, L, Q; driver::Driver = DefaultDriver(), kwargs...) = householder_lq!(driver, A, L, Q; kwargs...) householder_lq!(::DefaultDriver, A, L, Q; kwargs...) = - householder_lq!(default_householder_driver(A), A, L, Q; kwargs...) + householder_lq!(default_driver(Householder, A), A, L, Q; kwargs...) householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) = lq_via_qr!(A, L, Q, Householder(; driver, kwargs...)) function householder_lq!( @@ -221,7 +221,7 @@ end @inline householder_lq_null!(A, Nᴴ; driver::Driver = DefaultDriver(), kwargs...) = householder_lq_null!(driver, A, Nᴴ; kwargs...) householder_lq_null!(::DefaultDriver, A, Nᴴ; kwargs...) = - householder_lq_null!(default_householder_driver(A), A, Nᴴ; kwargs...) + householder_lq_null!(default_driver(Householder, A), A, Nᴴ; kwargs...) householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...) = lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...)) function householder_lq_null!( diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index f78d8c44..3d340a28 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -117,7 +117,7 @@ end @inline householder_qr!(A, Q, R; driver::Driver = DefaultDriver(), kwargs...) = householder_qr!(driver, A, Q, R; kwargs...) householder_qr!(::DefaultDriver, A, Q, R; kwargs...) = - householder_qr!(default_householder_driver(A), A, Q, R; kwargs...) + householder_qr!(default_driver(Householder, A), A, Q, R; kwargs...) 
function householder_qr!( driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive::Bool = true, pivoted::Bool = false, @@ -248,7 +248,7 @@ end @inline householder_qr_null!(A, N; driver::Driver = DefaultDriver(), kwargs...) = householder_qr_null!(driver, A, N; kwargs...) householder_qr_null!(::DefaultDriver, A, N; kwargs...) = - householder_qr_null!(default_householder_driver(A), A, N; kwargs...) + householder_qr_null!(default_driver(Householder, A), A, N; kwargs...) function householder_qr_null!( driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix; positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index dfe1fa16..fffd49c4 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -105,133 +105,121 @@ function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm return eltype(A) <: Real ? 
diagview(A) : similar(A, real(eltype(A)), size(A, 1)) end -# Implementation -# -------------- -function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) - check_input(svd_full!, A, USVᴴ, alg) - U, S, Vᴴ = USVᴴ - fill!(S, zero(eltype(S))) - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - one!(U) - zero!(S) - one!(Vᴴ) - return USVᴴ - end +# ========================== +# IMPLEMENTATIONS +# ========================== - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_QRIteration - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) - YALAPACK.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) - elseif alg isa LAPACK_DivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) - YALAPACK.gesdd!(A, view(S, 1:minmn, 1), U, Vᴴ) - elseif alg isa LAPACK_SafeDivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, view(S, 1:minmn, 1), U, Vᴴ) - elseif alg isa LAPACK_Bisection - throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) - elseif alg isa LAPACK_Jacobi - throw(ArgumentError("LAPACK_Jacobi is not supported for full SVD")) - else - throw(ArgumentError("Unsupported SVD algorithm")) - end +for f! in (:gesdd!, :gesvd!, :gesvdj!, :gesvdp!, :gesvdx!, :gesvdr!, :gesdvd!) + @eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!")) +end - for i in 2:minmn - S[i, i] = S[i, 1] - S[i, 1] = zero(eltype(S)) - end +""" + svd_via_adjoint!(f!, driver, A, S, U, Vᴴ; kwargs...) - do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) +Compute the SVD of `A` (m × n, m < n) by computing the SVD of `adjoint(A)` using +the provided function `f!(driver, A, S, U, Vᴴ; kwargs...)`. 
Use this as a building +block for drivers whose SVD routines require m ≥ n. +""" +function svd_via_adjoint!(f!::F, driver::Driver, A, S, U, Vᴴ; kwargs...) where {F} + Aᴴ = adjoint!(similar(A'), A) + Uᴴ = similar(U') + V = similar(Vᴴ') + f!(driver, Aᴴ, S, V, Uᴴ; kwargs...) + length(U) > 0 && adjoint!(U, Uᴴ) + length(Vᴴ) > 0 && adjoint!(Vᴴ, V) + return S, U, Vᴴ +end - return USVᴴ +# LAPACK +for f! in (:gesdd!, :gesvd!, :gesvdx!, :gesdvd!) + @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) end -function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) - check_input(svd_compact!, A, USVᴴ, alg) - U, S, Vᴴ = USVᴴ +function gesvdj!(::LAPACK, A, S, U, Vᴴ; kwargs...) m, n = size(A) - minmn = min(m, n) - if minmn == 0 - one!(U) - zero!(S) - one!(Vᴴ) - return USVᴴ - end + m >= n && return YALAPACK.gesvdj!(A, S, U, Vᴴ) + return svd_via_adjoint!(gesvdj!, LAPACK(), A, S, U, Vᴴ; kwargs...) +end - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_QRIteration - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) - YALAPACK.gesvd!(A, diagview(S), U, Vᴴ) - elseif alg isa LAPACK_DivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) - YALAPACK.gesdd!(A, diagview(S), U, Vᴴ) - elseif alg isa LAPACK_SafeDivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, diagview(S), U, Vᴴ) - elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...) 
- elseif alg isa LAPACK_Jacobi - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) - YALAPACK.gesvj!(A, diagview(S), U, Vᴴ) - else - throw(ArgumentError("Unsupported SVD algorithm")) +for (f, f_lapack!, Alg) in ( + (:safe_divide_and_conquer, :gesdvd!, :SafeDivideAndConquer), + (:divide_and_conquer, :gesdd!, :DivideAndConquer), + (:qr_iteration, :gesvd!, :QRIteration), + (:bisection, :gesvdx!, :Bisection), + (:jacobi, :gesvdj!, :Jacobi), + (:svd_polar, :gesvdp!, :SVDViaPolar), + ) + f_svd! = Symbol(f, :_svd!) + f_svd_full! = Symbol(f, :_svd_full!) + f_svd_vals! = Symbol(f, :_svd_vals!) + + # MatrixAlgebraKit wrappers + @eval begin + function svd_compact!(A, USVᴴ, alg::$Alg) + check_input(svd_compact!, A, USVᴴ, alg) + return $f_svd!(A, USVᴴ...; alg.kwargs...) + end + function svd_full!(A, USVᴴ, alg::$Alg) + check_input(svd_full!, A, USVᴴ, alg) + return $f_svd_full!(A, USVᴴ...; alg.kwargs...) + end + function svd_vals!(A, S, alg::$Alg) + check_input(svd_vals!, A, S, alg) + return $f_svd_vals!(A, S; alg.kwargs...) + end end - do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) + # driver + @eval begin + @inline $f_svd!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = $f_svd!(driver, A, U, S, Vᴴ; kwargs...) + @inline $f_svd_full!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = $f_svd_full!(driver, A, U, S, Vᴴ; kwargs...) + @inline $f_svd_vals!(A, S; driver::Driver = DefaultDriver(), kwargs...) = $f_svd_vals!(driver, A, S; kwargs...) 
- return USVᴴ -end - -function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) - check_input(svd_vals!, A, S, alg) - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - zero!(S) - return S - end - U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_QRIteration - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) - YALAPACK.gesvd!(A, S, U, Vᴴ) - elseif alg isa LAPACK_DivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) - YALAPACK.gesdd!(A, S, U, Vᴴ) - elseif alg isa LAPACK_SafeDivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, S, U, Vᴴ) - elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...) - elseif alg isa LAPACK_Jacobi - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) - YALAPACK.gesvj!(A, S, U, Vᴴ) - else - throw(ArgumentError("Unsupported SVD algorithm")) + @inline $f_svd!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = $f_svd!(default_driver($Alg, A), A, U, S, Vᴴ; kwargs...) + @inline $f_svd_full!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = $f_svd_full!(default_driver($Alg, A), A, U, S, Vᴴ; kwargs...) + @inline $f_svd_vals!(::DefaultDriver, A, S; kwargs...) = $f_svd_vals!(default_driver($Alg, A), A, S; kwargs...) end - return S + # Implementation + @eval begin + function $f_svd!(driver::Driver, A, U, S, Vᴴ; fixgauge::Bool = true, kwargs...) + supports_svd(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) + isempty(A) && return one!(U), zero!(S), one!(Vᴴ) + $f_lapack!(driver, A, diagview(S), U, Vᴴ; kwargs...) 
+ fixgauge && gaugefix!(svd_compact!, U, Vᴴ) + return U, S, Vᴴ + end + function $f_svd_full!(driver::Driver, A, U, S, Vᴴ; fixgauge::Bool = true, kwargs...) + supports_svd_full(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) + isempty(A) && return one!(U), zero!(S), one!(Vᴴ) + zero!(S) + minmn = min(size(A)...) + $f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...) + diagview(S) .= view(S, 1:minmn, 1) + zero!(view(S, 2:minmn, 1)) + fixgauge && gaugefix!(svd_full!, U, Vᴴ) + return U, S, Vᴴ + end + function $f_svd_vals!(driver::Driver, A, S; fixgauge::Bool = true, kwargs...) + supports_svd(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) + isempty(A) && return zero!(S) + U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) + $f_lapack!(driver, A, S, U, Vᴴ; kwargs...) + return S + end + end end +supports_svd(::Driver, ::Symbol) = false +supports_svd(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) +supports_svd(::GLA, f::Symbol) = f === :qr_iteration +supports_svd_full(::Driver, ::Symbol) = false +supports_svd_full(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration) +supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration + function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) @@ -294,11 +282,8 @@ function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm) return S end -# GPU logic -# --------- -# placed here to avoid code duplication since much of the logic is replicable across -# CUDA and AMDGPU -### +# GPU logic (randomized SVD - CUSOLVER_Randomized has no CPU analog, kept as-is) +# --------------------------------------------------------------------------------- function 
check_input( ::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized @@ -327,79 +312,11 @@ function initialize_output( return (U, S, Vᴴ) end -function _gpu_gesvd!( - A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix - ) - throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ))) -end -function _gpu_Xgesvdp!( - A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs... - ) - throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ))) -end function _gpu_Xgesvdr!( A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs... ) throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ))) end -function _gpu_gesvdj!( - A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs... - ) - throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ))) -end -function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) - m, n = size(A) - m ≥ n && return _gpu_gesvd!(A, S, U, Vᴴ) - # both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration) - # if this condition is not met, do the SVD via adjoint - minmn = min(m, n) - Aᴴ = min(m, n) > 0 ? 
adjoint!(similar(A'), A)::AbstractMatrix : similar(A') - Uᴴ = similar(U') - V = similar(Vᴴ') - if size(U) == (m, m) - _gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ) - else - _gpu_gesvd!(Aᴴ, S, V, Uᴴ) - end - length(U) > 0 && adjoint!(U, Uᴴ) - length(Vᴴ) > 0 && adjoint!(Vᴴ, V) - return U, S, Vᴴ -end - -# GPU SVD implementation -function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) - check_input(svd_full!, A, USVᴴ, alg) - U, S, Vᴴ = USVᴴ - fill!(S, zero(eltype(S))) - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - one!(U) - zero!(S) - one!(Vᴴ) - return USVᴴ - end - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_QRIteration - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" - _gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ) - elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg_kwargs...) - elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg_kwargs...) 
- else - throw(ArgumentError("Unsupported SVD algorithm")) - end - diagview(S) .= view(S, 1:minmn, 1) - view(S, 2:minmn, 1) .= zero(eltype(S)) - - do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) - - return USVᴴ -end function svd_trunc_no_error!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) U, S, Vᴴ = USVᴴ @@ -427,61 +344,59 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran return Utr, Str, Vᴴtr, ϵ end -function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) - check_input(svd_compact!, A, USVᴴ, alg) - U, S, Vᴴ = USVᴴ - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - one!(U) - zero!(S) - one!(Vᴴ) - return USVᴴ - end - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_QRIteration - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" - _gpu_gesvd_maybe_transpose!(A, diagview(S), U, Vᴴ) - elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, diagview(S), U, Vᴴ; alg_kwargs...) - elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, diagview(S), U, Vᴴ; alg_kwargs...) - else - throw(ArgumentError("Unsupported SVD algorithm")) +# Deprecations +# ------------ +for algtype in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Jacobi, :Bisection) + lapack_algtype = Symbol(:LAPACK_, algtype) + @eval begin + Base.@deprecate( + svd_compact!(A, USVᴴ, alg::$lapack_algtype), + svd_compact!(A, USVᴴ, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) + Base.@deprecate( + svd_full!(A, USVᴴ, alg::$lapack_algtype), + svd_full!(A, USVᴴ, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) + Base.@deprecate( + svd_vals!(A, S, alg::$lapack_algtype), + svd_vals!(A, S, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) end - - do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) - - return USVᴴ end -_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x))) -_largest(x, y) = abs(x) < abs(y) ? 
y : x -function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) - check_input(svd_vals!, A, S, alg) - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - zero!(S) - return S - end - U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_QRIteration - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" - _gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ) - elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, S, U, Vᴴ; alg_kwargs...) - elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, S, U, Vᴴ; alg_kwargs...) - else - throw(ArgumentError("Unsupported SVD algorithm")) +for (algtype, newtype, drivertype) in ( + (:CUSOLVER_QRIteration, :QRIteration, :CUSOLVER), + (:CUSOLVER_Jacobi, :Jacobi, :CUSOLVER), + (:CUSOLVER_SVDPolar, :SVDViaPolar, :CUSOLVER), + (:ROCSOLVER_QRIteration, :QRIteration, :ROCSOLVER), + (:ROCSOLVER_Jacobi, :Jacobi, :ROCSOLVER), + ) + @eval begin + Base.@deprecate( + svd_compact!(A, USVᴴ, alg::$algtype), + svd_compact!(A, USVᴴ, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) + Base.@deprecate( + svd_full!(A, USVᴴ, alg::$algtype), + svd_full!(A, USVᴴ, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) + Base.@deprecate( + svd_vals!(A, S, alg::$algtype), + svd_vals!(A, S, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) end - - return S end + +# GLA_QRIteration SVD deprecations (eigh methods remain in the GLA extension) +Base.@deprecate( + svd_compact!(A, USVᴴ, alg::GLA_QRIteration), + svd_compact!(A, USVᴴ, QRIteration(; driver = GLA(), alg.kwargs...)) +) +Base.@deprecate( + svd_full!(A, USVᴴ, alg::GLA_QRIteration), + svd_full!(A, USVᴴ, QRIteration(; driver = GLA(), alg.kwargs...)) +) +Base.@deprecate( + svd_vals!(A, S, alg::GLA_QRIteration), + svd_vals!(A, S, QRIteration(; driver = GLA(), alg.kwargs...)) +) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index b9cf7595..59cbe76f 100644 --- 
a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -60,7 +60,6 @@ of `R` are non-negative. """ @algdef GLA_HouseholderQR -# TODO: @algdef LAPACK_HouseholderQL @algdef LAPACK_HouseholderRQ @@ -86,38 +85,118 @@ function Householder(; return Householder((; blocksize, driver, pivoted, positive)) end -default_householder_driver(A) = default_householder_driver(typeof(A)) -default_householder_driver(::Type) = Native() +""" + DivideAndConquer(; [driver], fixgauge = default_fixgauge()) -default_householder_driver(::Type{A}) where {A <: YALAPACK.MaybeBlasMat} = LAPACK() +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the divide-and-conquer algorithm. -# note: StridedVector fallback is needed for handling reshaped parent types -default_householder_driver(::Type{A}) where {A <: StridedVector{<:BlasFloat}} = LAPACK() -default_householder_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = - default_householder_driver(A) -default_householder_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = - default_householder_driver(A) +The optional `driver` keyword argument can be used to choose between different implementations of this algorithm. +$_fixgauge_docs +""" +@algdef DivideAndConquer + +""" + SafeDivideAndConquer(; [driver], fixgauge = default_fixgauge()) + +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the divide-and-conquer algorithm, +with an additional fallback to the standard QR Iteration algorithm in case the former fails to converge. + +The optional `driver` keyword argument can be used to choose between different implementations of this algorithm. +$_fixgauge_docs + +!!! warning + This approach requires a copy of the input matrix, and is thus the most memory intensive SVD strategy.
+ However, as it combines the speed of the Divide and Conquer algorithm with the robustness of the + QR Iteration algorithm, it is the default SVD strategy for LAPACK-based implementations in MatrixAlgebraKit. + +See also [`DivideAndConquer`](@ref) and [`QRIteration`](@ref). +""" +@algdef SafeDivideAndConquer + +""" + QRIteration(; [driver], fixgauge = default_fixgauge()) + +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix via QR iteration. + +The optional `driver` keyword argument can be used to choose between different implementations of this algorithm. +$_fixgauge_docs +""" +@algdef QRIteration + +""" + Bisection(; [driver], fixgauge = default_fixgauge()) + +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix via the bisection algorithm. + +The optional `driver` keyword argument can be used to choose between different implementations of this algorithm. +$_fixgauge_docs +""" +@algdef Bisection + +""" + Jacobi(; [driver], fixgauge = default_fixgauge()) + +Algorithm type for computing the singular value decomposition of a general matrix using the Jacobi algorithm. + +The optional `driver` keyword argument can be used to choose between different implementations of this algorithm. +$_fixgauge_docs +""" +@algdef Jacobi + +""" + SVDViaPolar(; [driver], fixgauge = default_fixgauge(), [tol]) + +Algorithm type for computing the singular value decomposition of a general +matrix via Halley's iterative algorithm for the polar decomposition followed by the Hermitian +eigenvalue decomposition of the positive definite factor. + +The optional `driver` keyword argument can be used to choose between different implementations of this algorithm. +$_fixgauge_docs +The tolerance `tol` can optionally be supplied to emit a warning if the decomposition fails to converge to within that value.
+""" +@algdef SVDViaPolar + +for f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi, :svd_polar) + default_f_driver = Symbol(:default_, f, :_driver) + @eval begin + $default_f_driver(A) = $default_f_driver(typeof(A)) + $default_f_driver(::Type) = Native() + + $default_f_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = $default_f_driver(A) + $default_f_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = $default_f_driver(A) + end + + if f !== :svd_polar + @eval begin + $default_f_driver(::Type{A}) where {A <: YALAPACK.MaybeBlasMat} = LAPACK() + # note: StridedVector fallback is needed for handling reshaped parent types + $default_f_driver(::Type{A}) where {A <: StridedVector{<:BlasFloat}} = LAPACK() + end + end + +end # General Eigenvalue Decomposition # ------------------------------- """ - LAPACK_Simple(; fixgauge::Bool = true) + LAPACK_Simple(; fixgauge = default_fixgauge()) Algorithm type to denote the simple LAPACK driver for computing the Schur or non-Hermitian eigenvalue decomposition of a matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef LAPACK_Simple """ - LAPACK_Expert(; fixgauge::Bool = true) + LAPACK_Expert(; fixgauge = default_fixgauge()) Algorithm type to denote the expert LAPACK driver for computing the Schur or non-Hermitian eigenvalue decomposition of a matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef LAPACK_Expert @@ -134,45 +213,38 @@ eigenvalue decomposition of a non-Hermitian matrix. 
# Hermitian Eigenvalue Decomposition # ---------------------------------- """ - LAPACK_QRIteration(; fixgauge::Bool = true) + LAPACK_QRIteration(; fixgauge = default_fixgauge()) -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -QR Iteration algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the QR Iteration algorithm. +$_fixgauge_docs """ @algdef LAPACK_QRIteration """ - LAPACK_Bisection(; fixgauge::Bool = true) + LAPACK_Bisection(; fixgauge = default_fixgauge()) -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -Bisection algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the Bisection algorithm. +$_fixgauge_docs """ @algdef LAPACK_Bisection """ - LAPACK_DivideAndConquer(; fixgauge::Bool = true) + LAPACK_DivideAndConquer(; fixgauge = default_fixgauge()) -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -Divide and Conquer algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). 
+Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the Divide and Conquer algorithm. +$_fixgauge_docs """ @algdef LAPACK_DivideAndConquer """ - LAPACK_MultipleRelativelyRobustRepresentations(; fixgauge::Bool = true) + LAPACK_MultipleRelativelyRobustRepresentations(; fixgauge = default_fixgauge()) -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix using the Multiple Relatively Robust Representations algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix +using the Multiple Relatively Robust Representations algorithm. +$_fixgauge_docs """ @algdef LAPACK_MultipleRelativelyRobustRepresentations @@ -184,26 +256,24 @@ const LAPACK_EighAlgorithm = Union{ } """ - GLA_QRIteration(; fixgauge::Bool = true) + GLA_QRIteration(; fixgauge = default_fixgauge()) Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef GLA_QRIteration # Singular Value Decomposition # ---------------------------- """ - LAPACK_SafeDivideAndConquer(; fixgauge::Bool = true) + LAPACK_SafeDivideAndConquer(; fixgauge = default_fixgauge()) Algorithm type to denote the LAPACK driver for computing the singular value decomposition of a general matrix using the Divide and Conquer algorithm, with an additional fallback to the standard QR Iteration algorithm in case the former fails to converge. 
-The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs !!! warning This approach requires a copy of the input matrix, and is thus the most memory intensive SVD strategy. @@ -213,12 +283,11 @@ see also [`gaugefix!`](@ref). @algdef LAPACK_SafeDivideAndConquer """ - LAPACK_Jacobi(; fixgauge::Bool = true) + LAPACK_Jacobi(; fixgauge = default_fixgauge()) Algorithm type to denote the LAPACK driver for computing the singular value decomposition of a general matrix using the Jacobi algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef LAPACK_Jacobi @@ -291,34 +360,31 @@ the diagonal elements of `R` are non-negative. @algdef CUSOLVER_HouseholderQR """ - CUSOLVER_QRIteration(; fixgauge::Bool = true) + CUSOLVER_QRIteration(; fixgauge = default_fixgauge()) Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the QR Iteration algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef CUSOLVER_QRIteration """ - CUSOLVER_SVDPolar(; fixgauge::Bool = true) + CUSOLVER_SVDPolar(; fixgauge = default_fixgauge()) Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of a general matrix by using Halley's iterative algorithm to compute the polar decompositon, followed by the hermitian eigenvalue decomposition of the positive definite factor. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular -vectors, see also [`gaugefix!`](@ref). 
+$_fixgauge_docs """ @algdef CUSOLVER_SVDPolar """ - CUSOLVER_Jacobi(; fixgauge::Bool = true) + CUSOLVER_Jacobi(; fixgauge = default_fixgauge()) Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of a general matrix using the Jacobi algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular -vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef CUSOLVER_Jacobi @@ -340,25 +406,23 @@ for more information. does_truncate(::TruncatedAlgorithm{<:CUSOLVER_Randomized}) = true """ - CUSOLVER_Simple(; fixgauge::Bool = true) + CUSOLVER_Simple(; fixgauge = default_fixgauge()) Algorithm type to denote the simple CUSOLVER driver for computing the non-Hermitian eigenvalue decomposition of a matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef CUSOLVER_Simple const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} """ - CUSOLVER_DivideAndConquer(; fixgauge::Bool = true) + CUSOLVER_DivideAndConquer(; fixgauge = default_fixgauge()) Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Divide and Conquer algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef CUSOLVER_DivideAndConquer @@ -379,45 +443,41 @@ the diagonal elements of `R` are non-negative. @algdef ROCSOLVER_HouseholderQR """ - ROCSOLVER_QRIteration(; fixgauge::Bool = true) + ROCSOLVER_QRIteration(; fixgauge = default_fixgauge()) Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the QR Iteration algorithm. 
-The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef ROCSOLVER_QRIteration """ - ROCSOLVER_Jacobi(; fixgauge::Bool = true) + ROCSOLVER_Jacobi(; fixgauge = default_fixgauge()) Algorithm type to denote the ROCSOLVER driver for computing the singular value decomposition of a general matrix using the Jacobi algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular -vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef ROCSOLVER_Jacobi """ - ROCSOLVER_Bisection(; fixgauge::Bool = true) + ROCSOLVER_Bisection(; fixgauge = default_fixgauge()) Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Bisection algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef ROCSOLVER_Bisection """ - ROCSOLVER_DivideAndConquer(; fixgauge::Bool = true) + ROCSOLVER_DivideAndConquer(; fixgauge = default_fixgauge()) Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Divide and Conquer algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). 
+$_fixgauge_docs """ @algdef ROCSOLVER_DivideAndConquer @@ -442,7 +502,7 @@ const GPU_Randomized = Union{CUSOLVER_Randomized} const QRAlgorithms = Union{Householder, LAPACK_HouseholderQR, Native_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} const LQAlgorithms = Union{Householder, LAPACK_HouseholderLQ, Native_HouseholderLQ, LQViaTransposedQR} -const SVDAlgorithms = Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm} +const SVDAlgorithms = Union{SafeDivideAndConquer, DivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar} const PolarAlgorithms = Union{PolarViaSVD, PolarNewton} # ================================ diff --git a/src/interface/svd.jl b/src/interface/svd.jl index b973f6c4..3c99eea3 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -158,11 +158,9 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an # Algorithm selection # ------------------- default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs...) -function default_svd_algorithm(T::Type; kwargs...) - throw(MethodError(default_svd_algorithm, (T,))) -end +default_svd_algorithm(T::Type; kwargs...) = throw(MethodError(default_svd_algorithm, (T,))) function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} - return LAPACK_SafeDivideAndConquer(; kwargs...) + return SafeDivideAndConquer(; kwargs...) end function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) 
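The `default_svd_algorithm` methods in the hunk above pick an algorithm from the matrix type alone. That pattern (a value-level entry point forwarding to a type-level method, a throwing fallback, and per-type specializations) can be sketched in isolation; the struct names below are hypothetical stand-ins, not MatrixAlgebraKit's actual types:

```julia
# Toy sketch of type-driven default-algorithm selection.
# `SafeDivideAndConquer` and `DiagonalAlgorithm` are stand-in names here,
# not the actual MatrixAlgebraKit definitions.
using LinearAlgebra: Diagonal

struct SafeDivideAndConquer end
struct DiagonalAlgorithm end

# Value-level entry point forwards to a type-level method,
default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs...)
# which throws for unsupported matrix types,
default_svd_algorithm(T::Type; kwargs...) = throw(MethodError(default_svd_algorithm, (T,)))
# and is specialized per supported matrix type:
default_svd_algorithm(::Type{<:Matrix{<:Union{Float32, Float64}}}; kwargs...) =
    SafeDivideAndConquer()
default_svd_algorithm(::Type{<:Diagonal}; kwargs...) = DiagonalAlgorithm()

default_svd_algorithm(randn(3, 3))          # SafeDivideAndConquer()
default_svd_algorithm(Diagonal([1.0, 2.0])) # DiagonalAlgorithm()
```

Dispatching on `typeof(A)` rather than on the value keeps the selection inferable by the compiler, which is what the `@constinferred` checks in `test/algorithms.jl` exercise.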
diff --git a/src/yalapack.jl b/src/yalapack.jl index 576fe3c5..c5131175 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -2303,7 +2303,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in end return (S, U, Vᴴ) end - function gesvj!( + function gesvdj!( A::AbstractMatrix{$elty}, S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)), U::AbstractMatrix{$elty} = similar( diff --git a/test/algorithms.jl b/test/algorithms.jl index 3acd8b6f..83566803 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -7,7 +7,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, @testset "default_algorithm" begin A = randn(3, 3) for f in (svd_compact!, svd_compact, svd_full!, svd_full) - @test @constinferred(default_algorithm(f, A)) === LAPACK_SafeDivideAndConquer() + @test @constinferred(default_algorithm(f, A)) == SafeDivideAndConquer() end for f in (eig_full!, eig_full, eig_vals!, eig_vals) @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() @@ -21,7 +21,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, end for f in (left_polar!, left_polar, right_polar!, right_polar) @test @constinferred(default_algorithm(f, A)) == - PolarViaSVD(LAPACK_SafeDivideAndConquer()) + PolarViaSVD(SafeDivideAndConquer()) end for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null) @test @constinferred(default_algorithm(f, A)) == Householder() @@ -37,8 +37,8 @@ end @testset "select_algorithm" begin A = randn(3, 3) for f in (svd_trunc!, svd_trunc) - @test @constinferred(select_algorithm(f, A)) === - TruncatedAlgorithm(LAPACK_SafeDivideAndConquer(), notrunc()) + @test @constinferred(select_algorithm(f, A)) == + TruncatedAlgorithm(SafeDivideAndConquer(), notrunc()) end for f in (eig_trunc!, eig_trunc) @test @constinferred(select_algorithm(f, A)) === @@ -55,8 +55,8 @@ end @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc = (; maxrank = 2)) end - @test 
@constinferred(select_algorithm(svd_compact!, A)) === LAPACK_SafeDivideAndConquer() - @test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_SafeDivideAndConquer() + @test @constinferred(select_algorithm(svd_compact!, A)) == SafeDivideAndConquer() + @test @constinferred(select_algorithm(svd_compact!, A, nothing)) == SafeDivideAndConquer() for alg in (:LAPACK_QRIteration, LAPACK_QRIteration, LAPACK_QRIteration()) @test @constinferred(select_algorithm(svd_compact!, A, $alg)) === LAPACK_QRIteration() end diff --git a/test/linearmap.jl b/test/linearmap.jl index a7adaae7..07063ac9 100644 --- a/test/linearmap.jl +++ b/test/linearmap.jl @@ -3,7 +3,7 @@ module LinearMaps export LinearMap using MatrixAlgebraKit - using MatrixAlgebraKit: AbstractAlgorithm, DiagonalAlgorithm, GLA_QRIteration + using MatrixAlgebraKit: AbstractAlgorithm using GenericLinearAlgebra import MatrixAlgebraKit as MAK @@ -31,18 +31,17 @@ module LinearMaps MAK.check_input($f!, parent(A), parent.(F), alg) @eval MAK.initialize_output(::typeof($f!), A::LinearMap, alg::AbstractAlgorithm) = LinearMap.(MAK.initialize_output($f!, parent(A), alg)) - @eval MAK.initialize_output(::typeof($f!), A::LinearMap, alg::GLA_QRIteration) = - (nothing, nothing, nothing) - @eval MAK.$f!(A::LinearMap, F, alg::AbstractAlgorithm) = - LinearMap.(MAK.$f!(parent(A), parent.(F), alg)) - @eval MAK.$f!(A::LinearMap, F, alg::GLA_QRIteration) = - LinearMap.(MAK.$f!(parent(A), F, alg)) - @eval MAK.check_input(::typeof($f!), A::LinearMap, F, alg::DiagonalAlgorithm) = - MAK.check_input($f!, parent(A), parent.(F), alg) - @eval MAK.initialize_output(::typeof($f!), A::LinearMap, alg::DiagonalAlgorithm) = - LinearMap.(MAK.initialize_output($f!, parent(A), alg)) - @eval MAK.$f!(A::LinearMap, F, alg::DiagonalAlgorithm) = - LinearMap.(MAK.$f!(parent(A), parent.(F), alg)) + end + + # Define svd_compact! and svd_full! 
for LinearMap with concrete algorithm types to avoid + # ambiguity with methods like `svd_compact!(A, USVᴴ, alg::SafeDivideAndConquer)`. + # Using AbstractAlgorithm here would be ambiguous since neither A-type nor alg-type would + # be strictly more specific. + for f! in (:svd_compact!, :svd_full!) + for Alg in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Bisection, :Jacobi, :SVDViaPolar) + @eval MAK.$f!(A::LinearMap, USVᴴ, alg::MAK.$Alg) = + LinearMap.(MAK.$f!(parent(A), parent.(USVᴴ), alg)) + end end for f in (:qr, :lq, :svd) diff --git a/test/orthnull.jl b/test/orthnull.jl index dec12946..995e2607 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -13,8 +13,8 @@ using .TestSuite is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -m = 54 -for T in (BLASFloats..., GenericFloats...), n in (37, m, 63) +m = 23 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 27) TestSuite.seed_rng!(123) if T ∈ BLASFloats if CUDA.functional() diff --git a/test/svd.jl b/test/svd.jl index 800f191b..239c7544 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -14,55 +14,78 @@ using .TestSuite is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63) - TestSuite.seed_rng!(123) - if T ∈ BLASFloats - if CUDA.functional() - TestSuite.test_svd(CuMatrix{T}, (m, n)) - CUDA_SVD_ALGS = ( - CUSOLVER_QRIteration(), - CUSOLVER_SVDPolar(), - CUSOLVER_Jacobi(), - ) - TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS) - k = 5 - p = min(m, n) - k - 2 - min(m, n) > k + 2 && TestSuite.test_randomized_svd(CuMatrix{T}, (m, n), (MatrixAlgebraKit.TruncatedAlgorithm(CUSOLVER_Randomized(; k, p, niters = 20), truncrank(k)),)) - if n == m - TestSuite.test_svd(Diagonal{T, CuVector{T}}, m) - TestSuite.test_svd_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) - end - end - if AMDGPU.functional() - TestSuite.test_svd(ROCMatrix{T}, (m, n)) - AMD_SVD_ALGS = ( - ROCSOLVER_QRIteration(), - 
ROCSOLVER_Jacobi(), -            ) -            TestSuite.test_svd_algs(ROCMatrix{T}, (m, n), AMD_SVD_ALGS) -            if n == m -                TestSuite.test_svd(Diagonal{T, ROCVector{T}}, m) -                TestSuite.test_svd_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),)) -            end -        end +# CPU tests +# --------- +if !is_buildkite +    # LAPACK algorithms: +    for T in BLASFloats, m in (0, 54), n in (0, 37, m, 63) +        TestSuite.seed_rng!(123) +        LAPACK_SVD_ALGS = (QRIteration(), DivideAndConquer(), SafeDivideAndConquer(; fixgauge = true)) +        TestSuite.test_svd(T, (m, n)) +        TestSuite.test_svd_algs(T, (m, n), LAPACK_SVD_ALGS) end - if !is_buildkite - if T ∈ BLASFloats - LAPACK_SVD_ALGS = ( - LAPACK_QRIteration(), - LAPACK_DivideAndConquer(), - LAPACK_SafeDivideAndConquer(; fixgauge = true), - ) - TestSuite.test_svd(T, (m, n)) - TestSuite.test_svd_algs(T, (m, n), LAPACK_SVD_ALGS) - elseif T ∈ GenericFloats - TestSuite.test_svd(T, (m, n)) - TestSuite.test_svd_algs(T, (m, n), (GLA_QRIteration(),)) - end - if m == n - AT = Diagonal{T, Vector{T}} - TestSuite.test_svd(AT, m) - TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),)) - end + + # Generic floats: + for T in GenericFloats, m in (0, 54), n in (0, 37, m, 63) + TestSuite.seed_rng!(123) + TestSuite.test_svd(T, (m, n)) + TestSuite.test_svd_algs(T, (m, n), (GLA_QRIteration(),)) + end + + # Diagonal: + for T in (BLASFloats..., GenericFloats...), m in (0, 54) + TestSuite.seed_rng!(123) + AT = Diagonal{T, Vector{T}} + TestSuite.test_svd(AT, m) + TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),)) + end +end + +# CUDA tests +# ------------ +if CUDA.functional() +    # CUSOLVER algorithms: +    for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27) +        TestSuite.seed_rng!(123) +        TestSuite.test_svd(CuMatrix{T}, (m, n)) +        CUDA_SVD_ALGS = (QRIteration(), SVDViaPolar(), Jacobi()) +        TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS) +    end + +    # Randomized SVD: +    for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27) +        TestSuite.seed_rng!(123) +        k = 5 +        p = min(m, n) - k - 2 +        p > 0
|| continue +        TestSuite.test_randomized_svd(CuMatrix{T}, (m, n), (MatrixAlgebraKit.TruncatedAlgorithm(CUSOLVER_Randomized(; k, p, niters = 20), truncrank(k)),)) +    end + +    # Diagonal: +    for T in BLASFloats, m in (0, 23) +        TestSuite.seed_rng!(123) +        AT = Diagonal{T, CuVector{T}} +        TestSuite.test_svd(AT, m) +        TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),)) +    end +end + +# AMDGPU tests +# ------------ +if AMDGPU.functional() +    # ROCSOLVER algorithms: +    for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27) +        TestSuite.seed_rng!(123) +        TestSuite.test_svd(ROCMatrix{T}, (m, n)) +        AMD_SVD_ALGS = (QRIteration(), Jacobi()) +        TestSuite.test_svd_algs(ROCMatrix{T}, (m, n), AMD_SVD_ALGS) +    end + +    # Diagonal: +    for T in BLASFloats, m in (0, 23) +        TestSuite.seed_rng!(123) +        AT = Diagonal{T, ROCVector{T}} +        TestSuite.test_svd(AT, m) +        TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),)) end end
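The `supports_svd` trait introduced near the top of this diff guards each generated entry point before it calls into a backend. A minimal self-contained sketch of that pattern, with toy types standing in for the package's actual `Driver` hierarchy:

```julia
# Toy sketch of the supports_svd driver-trait pattern; `LAPACK` and `GLA` here
# are hypothetical stand-ins, not the package's driver types.
abstract type Driver end
struct LAPACK <: Driver end
struct GLA <: Driver end

# Fallback: a driver supports no SVD kernels unless it opts in.
supports_svd(::Driver, ::Symbol) = false
supports_svd(::LAPACK, f::Symbol) = f in (:divide_and_conquer, :qr_iteration, :jacobi)
supports_svd(::GLA, f::Symbol) = f === :qr_iteration

# Entry points consult the trait before dispatching to a backend kernel.
function svd_kernel(driver::Driver, f::Symbol)
    supports_svd(driver, f) ||
        throw(ArgumentError("driver $driver does not provide $f"))
    return f  # a real implementation would call the backend routine here
end

svd_kernel(LAPACK(), :jacobi)  # returns :jacobi
# svd_kernel(GLA(), :jacobi)   # throws ArgumentError
```

Expressing support as a boolean trait per driver keeps the failure path uniform across backends, mirroring the `ArgumentError` raised in the generated `svd_*!` methods above.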