From 16e21f9cfda67378610c31c708a5212ca500a509 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 6 Mar 2026 14:21:56 -0500 Subject: [PATCH 01/17] add SVD algorithms change default algorithms --- src/implementations/svd.jl | 225 ++++++++++++++++---------------- src/interface/decompositions.jl | 55 ++++++++ src/interface/svd.jl | 6 +- 3 files changed, 166 insertions(+), 120 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index dfe1fa16..3f539c20 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -105,133 +105,106 @@ function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) end -# Implementation -# -------------- -function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) - check_input(svd_full!, A, USVᴴ, alg) - U, S, Vᴴ = USVᴴ - fill!(S, zero(eltype(S))) - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - one!(U) - zero!(S) - one!(Vᴴ) - return USVᴴ - end - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_QRIteration - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) - YALAPACK.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) - elseif alg isa LAPACK_DivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) - YALAPACK.gesdd!(A, view(S, 1:minmn, 1), U, Vᴴ) - elseif alg isa LAPACK_SafeDivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, view(S, 1:minmn, 1), U, Vᴴ) - elseif alg isa LAPACK_Bisection - throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) - elseif alg isa LAPACK_Jacobi - throw(ArgumentError("LAPACK_Jacobi is not supported for full SVD")) - else - throw(ArgumentError("Unsupported SVD algorithm")) - end - - for i in 2:minmn - S[i, i] = S[i, 1] - S[i, 1] = zero(eltype(S)) - end - - do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) +# ========================== +# IMPLEMENTATIONS +# ========================== - return USVᴴ +for f! in (:gesdd!, :gesvd!, :gesvdj!, :gesvdp!, :gesvdx!, :gesvdr!, :gesdvd!) + @eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!")) end -function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) - check_input(svd_compact!, A, USVᴴ, alg) - U, S, Vᴴ = USVᴴ - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - one!(U) - zero!(S) - one!(Vᴴ) - return USVᴴ - end - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa LAPACK_QRIteration - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) - YALAPACK.gesvd!(A, diagview(S), U, Vᴴ) - elseif alg isa LAPACK_DivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) - YALAPACK.gesdd!(A, diagview(S), U, Vᴴ) - elseif alg isa LAPACK_SafeDivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, diagview(S), U, Vᴴ) - elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...) - elseif alg isa LAPACK_Jacobi - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) - YALAPACK.gesvj!(A, diagview(S), U, Vᴴ) - else - throw(ArgumentError("Unsupported SVD algorithm")) - end - - do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) - - return USVᴴ +# LAPACK +for f! in (:gesdd!, :gesvd!, :gesvdj!, :gesvdx!, :gesdvd!) + @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) end -function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) - check_input(svd_vals!, A, S, alg) - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - zero!(S) - return S +for (f, f_lapack!, Alg) in ( + (:safe_divide_and_conquer, :gesdvd!, :SafeDivideAndConquer), + (:divide_and_conquer, :gesdd!, :DivideAndConquer), + (:qr_iteration, :gesvd!, :QRIteration), + (:bisection, :gesvdx!, :Bisection), + (:jacobi, :gesvdj!, :Jacobi), + ) + f_svd! = Symbol(f, :_svd!) + f_svd_full! = Symbol(f, :_svd_full!) + f_svd_vals! = Symbol(f, :_svd_vals!) + + # MatrixAlgebraKit wrappers + @eval begin + function svd_compact!(A, USVᴴ, alg::$Alg) + check_input(svd_compact!, A, USVᴴ, alg) + return $f_svd!(A, USVᴴ...; alg.kwargs...) + end + function svd_full!(A, USVᴴ, alg::$Alg) + check_input(svd_full!, A, USVᴴ, alg) + return $f_svd_full!(A, USVᴴ...; alg.kwargs...) + end + function svd_vals!(A, S, alg::$Alg) + check_input(svd_vals!, A, S, alg) + return $f_svd_vals!(A, S; alg.kwargs...) + end end - U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - if alg isa LAPACK_QRIteration - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) - YALAPACK.gesvd!(A, S, U, Vᴴ) - elseif alg isa LAPACK_DivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) - YALAPACK.gesdd!(A, S, U, Vᴴ) - elseif alg isa LAPACK_SafeDivideAndConquer - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, S, U, Vᴴ) - elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...) - elseif alg isa LAPACK_Jacobi - isempty(alg_kwargs) || - throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) - YALAPACK.gesvj!(A, S, U, Vᴴ) - else - throw(ArgumentError("Unsupported SVD algorithm")) + # driver + @eval begin + @inline $f_svd!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = + $f_svd!(driver, A, U, S, Vᴴ; kwargs...) + @inline $f_svd_full!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = + $f_svd_full!(driver, A, U, S, Vᴴ; kwargs...) + @inline $f_svd_vals!(A, S; driver::Driver = DefaultDriver(), kwargs...) = + $f_svd_vals!(driver, A, S; kwargs...) + @inline $f_svd!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = + $f_svd!($(Symbol(:default_, f, :_driver))(A), A, U, S, Vᴴ; kwargs...) + @inline $f_svd_full!(::DefaultDriver, A, S; kwargs...) = + $f_svd_full!($(Symbol(:default_, f, :_driver)), A, S; kwargs...) + @inline $f_svd_vals!(::DefaultDriver, A, S; kwargs...) = + $f_svd_vals!($(Symbol(:default_, f, :_driver)), A, S; kwargs...) end - return S + # Implementation + @eval begin + function $f_svd!( + driver::Driver, A::AbstractMatrix, U::AbstractMatrix, S::AbstractMatrix, Vᴴ::AbstractMatrix; + fixgauge::Bool = true, kwargs... + ) + supports_svd(driver, $(QuoteNode(f))) || throw(ArgumentError(lazy"$driver does not provide $f")) + isempty(A) && return one!(U), zero!(S), one!(Vᴴ) + $f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...) + fixgauge && gaugefix!(svd_compact!, U, Vᴴ) + return U, S, Vᴴ + end + function $f_svd_full!( + driver::Driver, A::AbstractMatrix, U::AbstractMatrix, S::AbstractMatrix, Vᴴ::AbstractMatrix; + fixgauge::Bool = true, kwargs... + ) + supports_svd_full(driver, $(QuoteNode(f))) || throw(ArgumentError(lazy"$driver does not provide $f")) + isempty(A) && return one!(U), zero!(S), one!(Vᴴ) + zero!(S) + minmn = min(size(A)...) + $f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...) + diagview(S) .= view(S, 1:minmn, 1) + view(S, 2:minmn, 1) .= zero(eltype(S)) + fixgauge && gaugefix!(svd_full!, U, Vᴴ) + return U, S, Vᴴ + end + function $f_svd_vals!( + driver::Driver, A::AbstractMatrix, S::AbstractVector; + fixgauge::Bool = true, kwargs... + ) + supports_svd(driver, $(QuoteNode(f))) || throw(ArgumentError(lazy"$driver does not provide $f")) + isempty(A) && return zero!(S) + U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) + $f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...) + return S + end + end end +supports_svd(::Driver, ::Symbol) = false +supports_svd(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) +supports_svd_full(::Driver, ::Symbol) = false +supports_svd_full(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration) + function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) @@ -485,3 +458,23 @@ function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) return S end + +# Deprecations +# ------------ +for algtype in (:DivideAndConquer, :QRIteration, :Jacobi, :Bisection) + algtype = Symbol(:LAPACK_, algtype) + @eval begin + Base.@deprecate( + svd_compact!(A, USVᴴ, alg::$algtype), + svd_compact!(A, USVᴴ, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) + Base.@deprecate( + svd_full!(A, USVᴴ, alg::$algtype), + svd_full!(A, USVᴴ, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) + Base.@deprecate( + svd_vals!(A, S, alg::$algtype), + svd_vals!(A, S, $algtype(; driver = LAPACK(), alg.kwargs...)) + ) + end +end diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index b9cf7595..2bd5c76f 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -98,6 +98,61 @@ default_householder_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_householder_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_householder_driver(A) +""" + DivideAndConquer(; [driver], kwargs...) + +Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the divide-and-conquer algorithm. + +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef DivideAndConquer + +""" + SafeDivideAndConquer(; [driver], kwargs...) + +Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the divide-and-conquer algorithm, +with an additional fallback to the standard QR iteration algorithm in case the former fails to converge. + +The optional `driver` symbol can be used to choose between different implementations of this algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). + +!!! warning + This approach requires a copy of the input matrix, and is thus the most memory intensive SVD strategy. + However, as it combines the speed of the Divide and Conquer algorithm with the robustness of the + QR Iteration algorithm, it is the default SVD strategy for LAPACK-based implementations in MatrixAlgebraKit. + +See also [`DivideAndConquer`](@ref) and [`QRIteration`](@ref). +""" +@algdef SafeDivideAndConquer + +""" + QRIteration(; [driver], kwargs...) + +Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix via QR iteration. + +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef QRIteration +@algdef Bisection +@algdef Jacobi + +for f in (:divide_and_conquer, :qr_iteration, :bisection, :jacobi) + default_f_driver = Symbol(:default_, f, :_driver) + @eval begin + $default_f_driver(A) = $default_f_driver(typeof(A)) + $default_f_driver(::Type) = Native() + + $default_f_driver(::Type{A}) where {A <: YALAPACK.MaybeBlasMat} = LAPACK() + + # note: StridedVector fallback is needed for handling reshaped parent types + $default_f_driver(::Type{A}) where {A <: StridedVector{<:BlasFloat}} = LAPACK() + $default_f_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = $default_f_driver(A) + $default_f_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = $default_f_driver(A) + end +end # General Eigenvalue Decomposition # ------------------------------- diff --git a/src/interface/svd.jl b/src/interface/svd.jl index b973f6c4..3c99eea3 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -158,11 +158,9 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an # Algorithm selection # ------------------- default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs...) -function default_svd_algorithm(T::Type; kwargs...) - throw(MethodError(default_svd_algorithm, (T,))) -end +default_svd_algorithm(T::Type; kwargs...) = throw(MethodError(default_svd_algorithm, (T,))) function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} - return LAPACK_SafeDivideAndConquer(; kwargs...) + return SafeDivideAndConquer(; kwargs...) end function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) From abdf12ebb44fb4937c147f424a2a3f9029f77efe Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 6 Mar 2026 15:49:02 -0500 Subject: [PATCH 02/17] uniformize names --- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 10 +--------- ext/MatrixAlgebraKitCUDAExt/yacusolver.jl | 6 +++--- src/yalapack.jl | 2 +- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 937027a4..d47fabcd 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -6,7 +6,7 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm -import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! +import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev! import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank using CUDA, CUDA.CUBLAS using CUDA: i32 @@ -32,14 +32,6 @@ end _gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V) -_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) = - YACUSOLVER.gesvd!(A, S, U, Vᴴ) -_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = - YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...) -_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = - YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...) -_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = - YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) _gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) = YACUSOLVER.heevj!(A, Dd, V; kwargs...) diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index aba01fe6..37da64da 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -98,7 +98,7 @@ for (bname, fname, elty, relty) in end end -function Xgesvdp!( +function gesvdp!( A::StridedCuMatrix{T}, S::StridedCuVector = similar(A, real(T), min(size(A)...)), U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)), @@ -165,7 +165,7 @@ function Xgesvdp!( end err = h_err_sigma[] if err > tol - warn("Xgesvdp! did not attained requested tolerance: error = $err > tolerance = $tol") + warn("gesvdp! did not attained requested tolerance: error = $err > tolerance = $tol") end flag = @allowscalar dh.info[1] @@ -269,7 +269,7 @@ for (bname, fname, elty, relty) in end # Wrapper for randomized SVD -function Xgesvdr!( +function gesvdr!( A::StridedCuMatrix{T}, S::StridedCuVector = similar(A, real(T), min(size(A)...)), U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)), diff --git a/src/yalapack.jl b/src/yalapack.jl index 576fe3c5..c5131175 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -2303,7 +2303,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in end return (S, U, Vᴴ) end - function gesvj!( + function gesvdj!( A::AbstractMatrix{$elty}, S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)), U::AbstractMatrix{$elty} = similar( From d48e209547236f70493df4d6b7a93f651ed92065 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Mar 2026 11:19:53 -0400 Subject: [PATCH 03/17] more cleanup and incorporate safe_svd --- src/implementations/svd.jl | 46 ++++++++++++++------------------- src/interface/decompositions.jl | 2 +- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 3f539c20..f197c080 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -155,46 +155,40 @@ for (f, f_lapack!, Alg) in ( $f_svd_vals!(driver, A, S; kwargs...) @inline $f_svd!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = $f_svd!($(Symbol(:default_, f, :_driver))(A), A, U, S, Vᴴ; kwargs...) - @inline $f_svd_full!(::DefaultDriver, A, S; kwargs...) = - $f_svd_full!($(Symbol(:default_, f, :_driver)), A, S; kwargs...) + @inline $f_svd_full!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = + $f_svd_full!($(Symbol(:default_, f, :_driver))(A), A, U, S, Vᴴ; kwargs...) @inline $f_svd_vals!(::DefaultDriver, A, S; kwargs...) = - $f_svd_vals!($(Symbol(:default_, f, :_driver)), A, S; kwargs...) + $f_svd_vals!($(Symbol(:default_, f, :_driver))(A), A, S; kwargs...) end # Implementation @eval begin - function $f_svd!( - driver::Driver, A::AbstractMatrix, U::AbstractMatrix, S::AbstractMatrix, Vᴴ::AbstractMatrix; - fixgauge::Bool = true, kwargs... - ) - supports_svd(driver, $(QuoteNode(f))) || throw(ArgumentError(lazy"$driver does not provide $f")) + function $f_svd!(driver::Driver, A, U, S, Vᴴ; fixgauge::Bool = true, kwargs...) + supports_svd(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) isempty(A) && return one!(U), zero!(S), one!(Vᴴ) - $f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...) + $f_lapack!(driver, A, diagview(S), U, Vᴴ; kwargs...) fixgauge && gaugefix!(svd_compact!, U, Vᴴ) return U, S, Vᴴ end - function $f_svd_full!( - driver::Driver, A::AbstractMatrix, U::AbstractMatrix, S::AbstractMatrix, Vᴴ::AbstractMatrix; - fixgauge::Bool = true, kwargs... - ) - supports_svd_full(driver, $(QuoteNode(f))) || throw(ArgumentError(lazy"$driver does not provide $f")) + function $f_svd_full!(driver::Driver, A, U, S, Vᴴ; fixgauge::Bool = true, kwargs...) + supports_svd_full(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) isempty(A) && return one!(U), zero!(S), one!(Vᴴ) zero!(S) minmn = min(size(A)...) $f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...) diagview(S) .= view(S, 1:minmn, 1) - view(S, 2:minmn, 1) .= zero(eltype(S)) + zero!(view(S, 2:minmn, 1)) fixgauge && gaugefix!(svd_full!, U, Vᴴ) return U, S, Vᴴ end - function $f_svd_vals!( - driver::Driver, A::AbstractMatrix, S::AbstractVector; - fixgauge::Bool = true, kwargs... - ) - supports_svd(driver, $(QuoteNode(f))) || throw(ArgumentError(lazy"$driver does not provide $f")) + function $f_svd_vals!(driver::Driver, A, S; fixgauge::Bool = true, kwargs...) + supports_svd(driver, $(QuoteNode(f))) || + throw(ArgumentError(LazyString("driver ", driver, " does not provide `$($(QuoteNode(f_lapack!)))`"))) isempty(A) && return zero!(S) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) - $f_lapack!(driver, A, view(S, 1:minmn, 1), U, Vᴴ; kwargs...) + $f_lapack!(driver, A, S, U, Vᴴ; kwargs...) return S end end @@ -461,19 +455,19 @@ end # Deprecations # ------------ -for algtype in (:DivideAndConquer, :QRIteration, :Jacobi, :Bisection) - algtype = Symbol(:LAPACK_, algtype) +for algtype in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Jacobi, :Bisection) + lapack_algtype = Symbol(:LAPACK_, algtype) @eval begin Base.@deprecate( - svd_compact!(A, USVᴴ, alg::$algtype), + svd_compact!(A, USVᴴ, alg::$lapack_algtype), svd_compact!(A, USVᴴ, $algtype(; driver = LAPACK(), alg.kwargs...)) ) Base.@deprecate( - svd_full!(A, USVᴴ, alg::$algtype), + svd_full!(A, USVᴴ, alg::$lapack_algtype), svd_full!(A, USVᴴ, $algtype(; driver = LAPACK(), alg.kwargs...)) ) Base.@deprecate( - svd_vals!(A, S, alg::$algtype), + svd_vals!(A, S, alg::$lapack_algtype), svd_vals!(A, S, $algtype(; driver = LAPACK(), alg.kwargs...)) ) end diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 2bd5c76f..ed7d1ace 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -139,7 +139,7 @@ The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of @algdef Bisection @algdef Jacobi -for f in (:divide_and_conquer, :qr_iteration, :bisection, :jacobi) +for f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) default_f_driver = Symbol(:default_, f, :_driver) @eval begin $default_f_driver(A) = $default_f_driver(typeof(A)) From bb5c7ed484a62172dfc5487ca99f649334b4ea36 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Mar 2026 13:13:14 -0400 Subject: [PATCH 04/17] incorporate changes for GPU and GLA --- .../MatrixAlgebraKitAMDGPUExt.jl | 37 ++-- .../MatrixAlgebraKitCUDAExt.jl | 34 +++- ...MatrixAlgebraKitGenericLinearAlgebraExt.jl | 50 ++--- src/MatrixAlgebraKit.jl | 1 + src/implementations/svd.jl | 182 +++++------------- src/interface/decompositions.jl | 17 ++ test/algorithms.jl | 12 +- 7 files changed, 152 insertions(+), 181 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index ebec311e..ed9ff61f 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -6,7 +6,7 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm -import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! +import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj! import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank using AMDGPU using LinearAlgebra @@ -14,11 +14,13 @@ using LinearAlgebra: BlasFloat include("yarocsolver.jl") -MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCMatrix{<:BlasFloat}} = ROCSOLVER() -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix} - return ROCSOLVER_QRIteration(; kwargs...) +MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER() +MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedROCVecOrMat}) = ROCSOLVER() +MatrixAlgebraKit.default_jacobi_driver(::Type{<:StridedROCVecOrMat}) = ROCSOLVER() +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat} + return QRIteration(; kwargs...) end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat} return ROCSOLVER_DivideAndConquer(; kwargs...) end @@ -26,12 +28,25 @@ for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...) end -_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) = - YArocSOLVER.gesvd!(A, S, U, Vᴴ) -# not yet supported -# _gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = -# YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...) -_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = +function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) + m, n = size(A) + m >= n && return YArocSOLVER.gesvd!(A, S, U, Vᴴ) + # ROCSOLVER requires m ≥ n; compute SVD via adjoint when m < n + minmn = min(m, n) + Aᴴ = minmn > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A') + Uᴴ = similar(U') + V = similar(Vᴴ') + if size(U) == (m, m) + YArocSOLVER.gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ) + else + YArocSOLVER.gesvd!(Aᴴ, S, V, Uᴴ) + end + length(U) > 0 && adjoint!(U, Uᴴ) + length(Vᴴ) > 0 && adjoint!(Vᴴ, V) + return S, U, Vᴴ +end + +gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) _gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index d47fabcd..e848560e 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev! -import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank +import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank using CUDA, CUDA.CUBLAS using CUDA: i32 using LinearAlgebra @@ -16,8 +16,11 @@ using LinearAlgebra: BlasFloat include("yacusolver.jl") MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER() +MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER() +MatrixAlgebraKit.default_jacobi_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER() +MatrixAlgebraKit.default_svd_polar_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER() function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} - return CUSOLVER_QRIteration(; kwargs...) + return QRIteration(; kwargs...) end function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} return CUSOLVER_Simple(; kwargs...) @@ -30,6 +33,33 @@ for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...) end +function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) + m, n = size(A) + m >= n && return YACUSOLVER.gesvd!(A, S, U, Vᴴ) + # CUSOLVER requires m ≥ n; compute SVD via adjoint when m < n + minmn = min(m, n) + Aᴴ = minmn > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A') + Uᴴ = similar(U') + V = similar(Vᴴ') + if size(U) == (m, m) + YACUSOLVER.gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ) + else + YACUSOLVER.gesvd!(Aᴴ, S, V, Uᴴ) + end + length(U) > 0 && adjoint!(U, Uᴴ) + length(Vᴴ) > 0 && adjoint!(Vᴴ, V) + return S, U, Vᴴ +end + +gesvdj!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = + YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) + +gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = + YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...) + +_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = + YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...) + _gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 629972cf..62d767a4 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -2,42 +2,32 @@ module MatrixAlgebraKitGenericLinearAlgebraExt using MatrixAlgebraKit using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge +using MatrixAlgebraKit: GLA +import MatrixAlgebraKit: gesvd! using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} - return GLA_QRIteration() -end - -for f! in (:svd_compact!, :svd_full!) - @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing) -end -MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing +MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}) = GLA() -function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration) - F = svd!(A) - U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) - - return U, S, Vᴴ -end - -function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration) - F = svd!(A; full = true) - U, Vᴴ = F.U, F.Vt - S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1)))) - diagview(S) .= F.S - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) - - return U, S, Vᴴ +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} + return QRIteration(; kwargs...) end -function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration) - return svdvals!(A) +function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...) + m, n = size(A) + if length(U) == 0 && length(Vᴴ) == 0 + Sv = svdvals!(A) + copyto!(S, Sv) + else + minmn = min(m, n) + # full SVD if U has m columns or Vᴴ has n rows (beyond the compact min(m,n)) + full = (length(U) > 0 && size(U, 2) > minmn) || (length(Vᴴ) > 0 && size(Vᴴ, 1) > minmn) + F = svd!(A; full = full) + length(S) > 0 && copyto!(S, F.S) + length(U) > 0 && copyto!(U, F.U) + length(Vᴴ) > 0 && copyto!(Vᴴ, F.Vt) + end + return S, U, Vᴴ end function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 2d5e251e..2380395c 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -32,6 +32,7 @@ export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! export Householder, Native_HouseholderQR, Native_HouseholderLQ +export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDPolar export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index f197c080..2a83f596 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -105,6 +105,10 @@ function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) end +# Helper functions used by gauge.jl +_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x))) +_largest(x, y) = abs(x) < abs(y) ? y : x + # ========================== # IMPLEMENTATIONS # ========================== @@ -124,6 +128,7 @@ for (f, f_lapack!, Alg) in ( (:qr_iteration, :gesvd!, :QRIteration), (:bisection, :gesvdx!, :Bisection), (:jacobi, :gesvdj!, :Jacobi), + (:svd_polar, :gesvdp!, :SVDPolar), ) f_svd! = Symbol(f, :_svd!) f_svd_full! = Symbol(f, :_svd_full!) @@ -196,8 +201,14 @@ end supports_svd(::Driver, ::Symbol) = false supports_svd(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) +supports_svd(::GLA, f::Symbol) = f === :qr_iteration +supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) +supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) supports_svd_full(::Driver, ::Symbol) = false supports_svd_full(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration) +supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration +supports_svd_full(::CUSOLVER, f::Symbol) = f === :qr_iteration +supports_svd_full(::ROCSOLVER, f::Symbol) = f === :qr_iteration function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) @@ -261,11 +272,8 @@ function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm) return S end -# GPU logic -# --------- -# placed here to avoid code duplication since much of the logic is replicable across -# CUDA and AMDGPU -### +# GPU logic (randomized SVD - CUSOLVER_Randomized has no CPU analog, kept as-is) +# --------------------------------------------------------------------------------- function check_input( ::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized @@ -294,79 +302,11 @@ function initialize_output( return (U, S, Vᴴ) end -function _gpu_gesvd!( - A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix - ) - throw(MethodError(_gpu_gesvd!, (A, S, U, Vᴴ))) -end -function _gpu_Xgesvdp!( - A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs... - ) - throw(MethodError(_gpu_Xgesvdp!, (A, S, U, Vᴴ))) -end function _gpu_Xgesvdr!( A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs... ) throw(MethodError(_gpu_Xgesvdr!, (A, S, U, Vᴴ))) end -function _gpu_gesvdj!( - A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs... - ) - throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ))) -end -function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix) - m, n = size(A) - m ≥ n && return _gpu_gesvd!(A, S, U, Vᴴ) - # both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration) - # if this condition is not met, do the SVD via adjoint - minmn = min(m, n) - Aᴴ = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A') - Uᴴ = similar(U') - V = similar(Vᴴ') - if size(U) == (m, m) - _gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ) - else - _gpu_gesvd!(Aᴴ, S, V, Uᴴ) - end - length(U) > 0 && adjoint!(U, Uᴴ) - length(Vᴴ) > 0 && adjoint!(Vᴴ, V) - return U, S, Vᴴ -end - -# GPU SVD implementation -function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) - check_input(svd_full!, A, USVᴴ, alg) - U, S, Vᴴ = USVᴴ - fill!(S, zero(eltype(S))) - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - one!(U) - zero!(S) - one!(Vᴴ) - return USVᴴ - end - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_QRIteration - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" - _gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ) - elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg_kwargs...) - elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg_kwargs...) - else - throw(ArgumentError("Unsupported SVD algorithm")) - end - diagview(S) .= view(S, 1:minmn, 1) - view(S, 2:minmn, 1) .= zero(eltype(S)) - - do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) - - return USVᴴ -end function svd_trunc_no_error!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) U, S, Vᴴ = USVᴴ @@ -394,65 +334,6 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran return Utr, Str, Vᴴtr, ϵ end -function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) - check_input(svd_compact!, A, USVᴴ, alg) - U, S, Vᴴ = USVᴴ - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - one!(U) - zero!(S) - one!(Vᴴ) - return USVᴴ - end - - do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_QRIteration - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" - _gpu_gesvd_maybe_transpose!(A, diagview(S), U, Vᴴ) - elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, diagview(S), U, Vᴴ; alg_kwargs...) - elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, diagview(S), U, Vᴴ; alg_kwargs...) - else - throw(ArgumentError("Unsupported SVD algorithm")) - end - - do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) - - return USVᴴ -end -_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x))) -_largest(x, y) = abs(x) < abs(y) ? y : x - -function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) - check_input(svd_vals!, A, S, alg) - m, n = size(A) - minmn = min(m, n) - if minmn == 0 - zero!(S) - return S - end - U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) - - alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) - - if alg isa GPU_QRIteration - isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" - _gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ) - elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, S, U, Vᴴ; alg_kwargs...) - elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, S, U, Vᴴ; alg_kwargs...) - else - throw(ArgumentError("Unsupported SVD algorithm")) - end - - return S -end - # Deprecations # ------------ for algtype in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Jacobi, :Bisection) @@ -472,3 +353,40 @@ for algtype in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Jacobi, ) end end + +for (algtype, newtype, drivertype) in ( + (:CUSOLVER_QRIteration, :QRIteration, :CUSOLVER), + (:CUSOLVER_Jacobi, :Jacobi, :CUSOLVER), + (:CUSOLVER_SVDPolar, :SVDPolar, :CUSOLVER), + (:ROCSOLVER_QRIteration, :QRIteration, :ROCSOLVER), + (:ROCSOLVER_Jacobi, :Jacobi, :ROCSOLVER), + ) + @eval begin + Base.@deprecate( + svd_compact!(A, USVᴴ, alg::$algtype), + svd_compact!(A, USVᴴ, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) + Base.@deprecate( + svd_full!(A, USVᴴ, alg::$algtype), + svd_full!(A, USVᴴ, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) + Base.@deprecate( + svd_vals!(A, S, alg::$algtype), + svd_vals!(A, S, $newtype(; driver = $drivertype(), alg.kwargs...)) + ) + end +end + +# GLA_QRIteration SVD deprecations (eigh methods remain in the GLA extension) +Base.@deprecate( + svd_compact!(A, USVᴴ, alg::GLA_QRIteration), + svd_compact!(A, USVᴴ, QRIteration(; driver = GLA(), alg.kwargs...)) +) +Base.@deprecate( + svd_full!(A, USVᴴ, alg::GLA_QRIteration), + svd_full!(A, USVᴴ, QRIteration(; driver = GLA(), alg.kwargs...)) +) +Base.@deprecate( + svd_vals!(A, S, alg::GLA_QRIteration), + svd_vals!(A, S, QRIteration(; driver = GLA(), alg.kwargs...)) +) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index ed7d1ace..061ab1dd 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -139,6 +139,18 @@ The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of @algdef Bisection @algdef Jacobi +""" + SVDPolar(; [driver], kwargs...) + +Algorithm type to denote the algorithm for computing the singular value decomposition of a general +matrix via Halley's iterative algorithm for the polar decomposition followed by the Hermitian +eigenvalue decomposition of the positive definite factor. + +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, +see also [`gaugefix!`](@ref). +""" +@algdef SVDPolar + for f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) default_f_driver = Symbol(:default_, f, :_driver) @eval begin @@ -154,6 +166,11 @@ for f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisecti end end +default_svd_polar_driver(A) = default_svd_polar_driver(typeof(A)) +default_svd_polar_driver(::Type) = Native() +default_svd_polar_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_svd_polar_driver(A) +default_svd_polar_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_svd_polar_driver(A) + # General Eigenvalue Decomposition # ------------------------------- """ diff --git a/test/algorithms.jl b/test/algorithms.jl index 3acd8b6f..83566803 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -7,7 +7,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, @testset "default_algorithm" begin A = randn(3, 3) for f in (svd_compact!, svd_compact, svd_full!, svd_full) - @test @constinferred(default_algorithm(f, A)) === LAPACK_SafeDivideAndConquer() + @test @constinferred(default_algorithm(f, A)) == SafeDivideAndConquer() end for f in (eig_full!, eig_full, eig_vals!, eig_vals) @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() @@ -21,7 +21,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, end for f in (left_polar!, left_polar, right_polar!, right_polar) @test @constinferred(default_algorithm(f, A)) == - PolarViaSVD(LAPACK_SafeDivideAndConquer()) + PolarViaSVD(SafeDivideAndConquer()) end for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null) @test @constinferred(default_algorithm(f, A)) == Householder() @@ -37,8 +37,8 @@ end @testset "select_algorithm" begin A = randn(3, 3) for f in (svd_trunc!, svd_trunc) - @test @constinferred(select_algorithm(f, A)) === - TruncatedAlgorithm(LAPACK_SafeDivideAndConquer(), notrunc()) + @test @constinferred(select_algorithm(f, A)) == + TruncatedAlgorithm(SafeDivideAndConquer(), notrunc()) end for f in (eig_trunc!, eig_trunc) @test @constinferred(select_algorithm(f, A)) === @@ -55,8 +55,8 @@ end @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc = (; maxrank = 2)) end - @test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_SafeDivideAndConquer() - @test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_SafeDivideAndConquer() + @test @constinferred(select_algorithm(svd_compact!, A)) == SafeDivideAndConquer() + @test @constinferred(select_algorithm(svd_compact!, A, nothing)) == SafeDivideAndConquer() for alg in (:LAPACK_QRIteration, LAPACK_QRIteration, LAPACK_QRIteration()) @test @constinferred(select_algorithm(svd_compact!, A, $alg)) === LAPACK_QRIteration() end From 6282f32059e08f2fb2305d72b72b90fd5b6b162c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Mar 2026 14:15:22 -0400 Subject: [PATCH 05/17] centralize SVD via adjoint implementation --- .../MatrixAlgebraKitAMDGPUExt.jl | 21 +++++----------- .../MatrixAlgebraKitCUDAExt.jl | 21 +++++----------- src/implementations/svd.jl | 25 ++++++++++++++++++- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index ed9ff61f..6fa66049 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -31,23 +31,14 @@ end function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) m, n = size(A) m >= n && return YArocSOLVER.gesvd!(A, S, U, Vᴴ) - # ROCSOLVER requires m ≥ n; compute SVD via adjoint when m < n - minmn = min(m, n) - Aᴴ = minmn > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A') - Uᴴ = similar(U') - V = similar(Vᴴ') - if size(U) == (m, m) - YArocSOLVER.gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ) - else - YArocSOLVER.gesvd!(Aᴴ, S, V, Uᴴ) - end - length(U) > 0 && adjoint!(U, Uᴴ) - length(Vᴴ) > 0 && adjoint!(Vᴴ, V) - return S, U, Vᴴ + return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...) end -gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = - YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) +function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) + m, n = size(A) + m >= n && return YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) + return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...) +end _gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...) _gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index e848560e..b39b6caa 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -36,23 +36,14 @@ end function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) m, n = size(A) m >= n && return YACUSOLVER.gesvd!(A, S, U, Vᴴ) - # CUSOLVER requires m ≥ n; compute SVD via adjoint when m < n - minmn = min(m, n) - Aᴴ = minmn > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A') - Uᴴ = similar(U') - V = similar(Vᴴ') - if size(U) == (m, m) - YACUSOLVER.gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ) - else - YACUSOLVER.gesvd!(Aᴴ, S, V, Uᴴ) - end - length(U) > 0 && adjoint!(U, Uᴴ) - length(Vᴴ) > 0 && adjoint!(Vᴴ, V) - return S, U, Vᴴ + return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, CUSOLVER(), A, S, U, Vᴴ; kwargs...) end -gesvdj!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = - YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) +function gesvdj!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) + m, n = size(A) + m >= n && return YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) + return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, CUSOLVER(), A, S, U, Vᴴ; kwargs...) +end gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 2a83f596..1f91ce44 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -117,11 +117,34 @@ for f! in (:gesdd!, :gesvd!, :gesvdj!, :gesvdp!, :gesvdx!, :gesvdr!, :gesdvd!) @eval $f!(driver::Driver, args...) = throw(ArgumentError("$driver does not provide $f!")) end +""" + svd_via_adjoint!(f!, driver, A, S, U, Vᴴ; kwargs...) + +Compute the SVD of `A` (m × n, m < n) by computing the SVD of `adjoint(A)` using +the provided function `f!(driver, A, S, U, Vᴴ; kwargs...)`. Use this as a building +block for drivers whose SVD routines require m ≥ n. +""" +function svd_via_adjoint!(f!::F, driver::Driver, A, S, U, Vᴴ; kwargs...) where {F} + Aᴴ = adjoint!(similar(A'), A) + Uᴴ = similar(U') + V = similar(Vᴴ') + f!(driver, Aᴴ, S, V, Uᴴ; kwargs...) + length(U) > 0 && adjoint!(U, Uᴴ) + length(Vᴴ) > 0 && adjoint!(Vᴴ, V) + return S, U, Vᴴ +end + # LAPACK -for f! in (:gesdd!, :gesvd!, :gesvdj!, :gesvdx!, :gesdvd!) +for f! in (:gesdd!, :gesvd!, :gesvdx!, :gesdvd!) @eval $f!(::LAPACK, args...; kwargs...) = YALAPACK.$f!(args...; kwargs...) end +function gesvdj!(::LAPACK, A, S, U, Vᴴ; kwargs...) + m, n = size(A) + m >= n && return YALAPACK.gesvdj!(A, S, U, Vᴴ) + return svd_via_adjoint!(gesvdj!, LAPACK(), A, S, U, Vᴴ; kwargs...) +end + for (f, f_lapack!, Alg) in ( (:safe_divide_and_conquer, :gesdvd!, :SafeDivideAndConquer), (:divide_and_conquer, :gesdd!, :DivideAndConquer), From 1722fe45208d2a4ed200290105d28876213697bc Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Mar 2026 14:16:16 -0400 Subject: [PATCH 06/17] move helper functions --- src/common/gauge.jl | 3 +++ src/implementations/svd.jl | 4 ---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/common/gauge.jl b/src/common/gauge.jl index e855548f..f045d4e9 100644 --- a/src/common/gauge.jl +++ b/src/common/gauge.jl @@ -8,6 +8,9 @@ This is achieved by ensuring that the entry with the largest magnitude in `V` or is real and positive. """ gaugefix! +# Helper functions +_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x))) +_largest(x, y) = abs(x) < abs(y) ? y : x function gaugefix!(::Union{typeof(eig_full!), typeof(eigh_full!), typeof(gen_eig_full!)}, V::AbstractMatrix) for j in axes(V, 2) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 1f91ce44..89228a2a 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -105,10 +105,6 @@ function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) end -# Helper functions used by gauge.jl -_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x))) -_largest(x, y) = abs(x) < abs(y) ? y : x - # ========================== # IMPLEMENTATIONS # ========================== From ae6e9177785b91a23b8a88e45b98624e51ded68e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Mar 2026 14:52:35 -0400 Subject: [PATCH 07/17] update docstrings --- src/interface/decompositions.jl | 51 +++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 061ab1dd..1e21fcf0 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -104,6 +104,7 @@ default_householder_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the divide-and-conquer algorithm. +The optional `driver` symbol can be used to choose between different implementations of this algorithm. The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). """ @algdef DivideAndConquer @@ -111,7 +112,7 @@ The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of """ SafeDivideAndConquer(; [driver], kwargs...) -Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, +Algorithm type to for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the divide-and-conquer algorithm, with an additional fallback to the standard QR iteration algorithm in case the former fails to converge. @@ -128,15 +129,35 @@ See also [`DivideAndConquer`](@ref) and [`QRIteration`](@ref). @algdef SafeDivideAndConquer """ - QRIteration(; [driver], kwargs...) + QRIteration(; [driver], fixgauge = true) -Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix via QR iteration. +The optional `driver` symbol can be used to choose between different implementations of this algorithm. The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). """ @algdef QRIteration + +""" + Bisection(; [driver], fixgauge::Bool = true) + +Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix via the bisection algorithm. + +The optional `driver` symbol can be used to choose between different implementations of this algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +""" @algdef Bisection + +""" + Jacobi(; [driver], fixgauge = true) + +Algorithm type for computing the singular value decomposition of a general matrix using the Jacobi algorithm. + +The optional `driver` symbol can be used to choose between different implementations of this algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +""" @algdef Jacobi """ @@ -146,30 +167,30 @@ Algorithm type to denote the algorithm for computing the singular value decompos matrix via Halley's iterative algorithm for the polar decomposition followed by the Hermitian eigenvalue decomposition of the positive definite factor. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, -see also [`gaugefix!`](@ref). +The optional `driver` symbol can be used to choose between different implementations of this algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). """ @algdef SVDPolar -for f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) +for f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi, :svd_polar) default_f_driver = Symbol(:default_, f, :_driver) @eval begin $default_f_driver(A) = $default_f_driver(typeof(A)) $default_f_driver(::Type) = Native() - $default_f_driver(::Type{A}) where {A <: YALAPACK.MaybeBlasMat} = LAPACK() - - # note: StridedVector fallback is needed for handling reshaped parent types - $default_f_driver(::Type{A}) where {A <: StridedVector{<:BlasFloat}} = LAPACK() $default_f_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = $default_f_driver(A) $default_f_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = $default_f_driver(A) end -end -default_svd_polar_driver(A) = default_svd_polar_driver(typeof(A)) -default_svd_polar_driver(::Type) = Native() -default_svd_polar_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_svd_polar_driver(A) -default_svd_polar_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_svd_polar_driver(A) + if f !== :svd_polar + @eval begin + $default_f_driver(::Type{A}) where {A <: YALAPACK.MaybeBlasMat} = LAPACK() + # note: StridedVector fallback is needed for handling reshaped parent types + $default_f_driver(::Type{A}) where {A <: StridedVector{<:BlasFloat}} = LAPACK() + end + end + +end # General Eigenvalue Decomposition # ------------------------------- From 43ea342a21e69ad3f55b6dcba69106309f96a9ca Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Mar 2026 21:11:05 -0400 Subject: [PATCH 08/17] fix ambiguity misery --- test/linearmap.jl | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/test/linearmap.jl b/test/linearmap.jl index a7adaae7..7e76a119 100644 --- a/test/linearmap.jl +++ b/test/linearmap.jl @@ -3,7 +3,7 @@ module LinearMaps export LinearMap using MatrixAlgebraKit - using MatrixAlgebraKit: AbstractAlgorithm, DiagonalAlgorithm, GLA_QRIteration + using MatrixAlgebraKit: AbstractAlgorithm using GenericLinearAlgebra import MatrixAlgebraKit as MAK @@ -31,18 +31,17 @@ module LinearMaps MAK.check_input($f!, parent(A), parent.(F), alg) @eval MAK.initialize_output(::typeof($f!), A::LinearMap, alg::AbstractAlgorithm) = LinearMap.(MAK.initialize_output($f!, parent(A), alg)) - @eval MAK.initialize_output(::typeof($f!), A::LinearMap, alg::GLA_QRIteration) = - (nothing, nothing, nothing) - @eval MAK.$f!(A::LinearMap, F, alg::AbstractAlgorithm) = - LinearMap.(MAK.$f!(parent(A), parent.(F), alg)) - @eval MAK.$f!(A::LinearMap, F, alg::GLA_QRIteration) = - LinearMap.(MAK.$f!(parent(A), F, alg)) - @eval MAK.check_input(::typeof($f!), A::LinearMap, F, alg::DiagonalAlgorithm) = - MAK.check_input($f!, parent(A), parent.(F), alg) - @eval MAK.initialize_output(::typeof($f!), A::LinearMap, alg::DiagonalAlgorithm) = - LinearMap.(MAK.initialize_output($f!, parent(A), alg)) - @eval MAK.$f!(A::LinearMap, F, alg::DiagonalAlgorithm) = - LinearMap.(MAK.$f!(parent(A), parent.(F), alg)) + end + + # Define svd_compact! and svd_full! for LinearMap with concrete algorithm types to avoid + # ambiguity with methods like `svd_compact!(A, USVᴴ, alg::SafeDivideAndConquer)`. + # Using AbstractAlgorithm here would be ambiguous since neither A-type nor alg-type would + # be strictly more specific. + for f! in (:svd_compact!, :svd_full!) + for Alg in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Bisection, :Jacobi, :SVDPolar) + @eval MAK.$f!(A::LinearMap, USVᴴ, alg::MAK.$Alg) = + LinearMap.(MAK.$f!(parent(A), parent.(USVᴴ), alg)) + end end for f in (:qr, :lq, :svd) From 0fc8f182b575840b7cb376c172824559606ae5ac Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Mar 2026 21:20:27 -0400 Subject: [PATCH 09/17] Apply suggestions from code review Co-authored-by: Jutho --- src/interface/decompositions.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 1e21fcf0..53cf90d5 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -99,13 +99,13 @@ default_householder_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A default_householder_driver(A) """ - DivideAndConquer(; [driver], kwargs...) + DivideAndConquer(; [driver], fixgauge=default_fixgauge()) Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the divide-and-conquer algorithm. The optional `driver` symbol can be used to choose between different implementations of this algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`default_fixgauge`](@ref) and [`gaugefix!`](@ref). """ @algdef DivideAndConquer From 2465a42c337314bb42960e855777129dc0947470 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Mar 2026 09:54:12 -0400 Subject: [PATCH 10/17] update docs --- docs/src/user_interface/decompositions.md | 63 ++++++++++++++++++++--- src/interface/decompositions.jl | 2 +- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index b48bacbe..c23061b2 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -40,11 +40,10 @@ lq_full lq_compact ``` -Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithm: +The following algorithm is available for QR and LQ decompositions: ```@docs; canonical=false -LAPACK_HouseholderQR -LAPACK_HouseholderLQ +Householder ``` ## Eigenvalue Decomposition @@ -78,7 +77,7 @@ eigh_vals By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. See [Gauge choices](@ref sec_gaugefix) for more details. -Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms: +The following algorithms are available for the symmetric eigenvalue decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] @@ -100,7 +99,7 @@ eig_vals By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. See [Gauge choices](@ref sec_gaugefix) for more details. -Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms: +The following algorithms are available for the non-Hermitian eigenvalue decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] @@ -120,7 +119,7 @@ schur_full schur_vals ``` -The LAPACK-based implementation for dense arrays is provided by the following algorithms: +The following algorithms are available for the Schur decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] @@ -153,11 +152,11 @@ svd_trunc By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. See [Gauge choices](@ref sec_gaugefix) for more details. -MatrixAlgebraKit again ships with LAPACK-based implementations for dense arrays: +The following algorithms are available for the singular value decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] -Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_SVDAlgorithm +Filter = t -> t isa Type && t <: MatrixAlgebraKit.SVDAlgorithms ``` ## Polar Decomposition @@ -388,6 +387,54 @@ norm(A * N1') < 1e-14 && norm(A * N2') < 1e-14 && true ``` +## [Driver Selection](@id sec_driverselection) + +!!! note "Expert use case" + Selecting a specific driver is an advanced feature intended for users who need to target a specific computational backend, such as a GPU. For most use cases, the default driver selection is sufficient. + +Each algorithm in MatrixAlgebraKit can optionally accept a `driver` keyword argument to explicitly select the computational backend. +By default, the driver is set to `DefaultDriver()`, which automatically selects the most appropriate backend based on the input matrix type. +The available drivers are: + +```@docs; canonical=false +MatrixAlgebraKit.DefaultDriver +MatrixAlgebraKit.LAPACK +MatrixAlgebraKit.CUSOLVER +MatrixAlgebraKit.ROCSOLVER +MatrixAlgebraKit.GLA +MatrixAlgebraKit.Native +``` + +For example, to force LAPACK for a generic matrix type, or to use a GPU backend: + +```julia +using MatrixAlgebraKit +using MatrixAlgebraKit: LAPACK, CUSOLVER # driver types are not exported by default + +# Default: driver is selected automatically based on the input type +U, S, Vᴴ = svd_compact(A) +U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer()) + +# Expert: explicitly select LAPACK +U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer(; driver = LAPACK())) + +# Expert: use a GPU backend (requires loading the appropriate extension) +U, S, Vᴴ = svd_compact(A; alg = QRIteration(; driver = CUSOLVER())) +``` + +Similarly, for QR decompositions: + +```julia +using MatrixAlgebraKit: LAPACK # driver types are not exported by default + +# Default: driver is selected automatically +Q, R = qr_compact(A) +Q, R = qr_compact(A; alg = Householder()) + +# Expert: explicitly select a driver +Q, R = qr_compact(A; alg = Householder(; driver = LAPACK())) +``` + ## [Gauge choices](@id sec_gaugefix) Both eigenvalue and singular value decompositions have residual gauge degrees of freedom even when the eigenvalues or singular values are unique. diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 53cf90d5..c94ad892 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -535,7 +535,7 @@ const GPU_Randomized = Union{CUSOLVER_Randomized} const QRAlgorithms = Union{Householder, LAPACK_HouseholderQR, Native_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} const LQAlgorithms = Union{Householder, LAPACK_HouseholderLQ, Native_HouseholderLQ, LQViaTransposedQR} -const SVDAlgorithms = Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm} +const SVDAlgorithms = Union{SafeDivideAndConquer, DivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar} const PolarAlgorithms = Union{PolarViaSVD, PolarNewton} # ================================ From a0173efd5805114a73000dbdbb829cccdee810f2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Mar 2026 09:54:18 -0400 Subject: [PATCH 11/17] rename SVDViaPolar --- src/MatrixAlgebraKit.jl | 2 +- src/implementations/svd.jl | 4 ++-- src/interface/decompositions.jl | 4 ++-- test/linearmap.jl | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 2380395c..9135b6f9 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -32,7 +32,7 @@ export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! export Householder, Native_HouseholderQR, Native_HouseholderLQ -export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDPolar +export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 89228a2a..24a06a85 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -147,7 +147,7 @@ for (f, f_lapack!, Alg) in ( (:qr_iteration, :gesvd!, :QRIteration), (:bisection, :gesvdx!, :Bisection), (:jacobi, :gesvdj!, :Jacobi), - (:svd_polar, :gesvdp!, :SVDPolar), + (:svd_polar, :gesvdp!, :SVDViaPolar), ) f_svd! = Symbol(f, :_svd!) f_svd_full! = Symbol(f, :_svd_full!) @@ -376,7 +376,7 @@ end for (algtype, newtype, drivertype) in ( (:CUSOLVER_QRIteration, :QRIteration, :CUSOLVER), (:CUSOLVER_Jacobi, :Jacobi, :CUSOLVER), - (:CUSOLVER_SVDPolar, :SVDPolar, :CUSOLVER), + (:CUSOLVER_SVDPolar, :SVDViaPolar, :CUSOLVER), (:ROCSOLVER_QRIteration, :QRIteration, :ROCSOLVER), (:ROCSOLVER_Jacobi, :Jacobi, :ROCSOLVER), ) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index c94ad892..4f5eccf2 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -161,7 +161,7 @@ The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of @algdef Jacobi """ - SVDPolar(; [driver], kwargs...) + SVDViaPolar(; [driver], kwargs...) Algorithm type to denote the algorithm for computing the singular value decomposition of a general matrix via Halley's iterative algorithm for the polar decomposition followed by the Hermitian @@ -170,7 +170,7 @@ eigenvalue decomposition of the positive definite factor. The optional `driver` symbol can be used to choose between different implementations of this algorithm. The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). """ -@algdef SVDPolar +@algdef SVDViaPolar for f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi, :svd_polar) default_f_driver = Symbol(:default_, f, :_driver) diff --git a/test/linearmap.jl b/test/linearmap.jl index 7e76a119..07063ac9 100644 --- a/test/linearmap.jl +++ b/test/linearmap.jl @@ -38,7 +38,7 @@ module LinearMaps # Using AbstractAlgorithm here would be ambiguous since neither A-type nor alg-type would # be strictly more specific. for f! in (:svd_compact!, :svd_full!) - for Alg in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Bisection, :Jacobi, :SVDPolar) + for Alg in (:SafeDivideAndConquer, :DivideAndConquer, :QRIteration, :Bisection, :Jacobi, :SVDViaPolar) @eval MAK.$f!(A::LinearMap, USVᴴ, alg::MAK.$Alg) = LinearMap.(MAK.$f!(parent(A), parent.(USVᴴ), alg)) end From 556b17eb78b95090816838cf40292c615f14a60b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Mar 2026 10:36:03 -0400 Subject: [PATCH 12/17] Apply suggestions from code review Co-authored-by: Jutho --- docs/src/user_interface/decompositions.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index c23061b2..38e6ceb7 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -62,9 +62,9 @@ These functions return the diagonal elements of `D` in a vector. Finally, it is also possible to compute a partial or truncated eigenvalue decomposition, using the [`eig_trunc`](@ref) and [`eigh_trunc`](@ref) functions. To control the behavior of the truncation, we refer to [Truncations](@ref) for more information. -### Symmetric Eigenvalue Decomposition +### Hermitian or Real Symmetric Eigenvalue Decomposition -For symmetric matrices, we provide the following functions: +For hermitian matrices, thus including real symmetric matrices, we provide the following functions: ```@docs; canonical=false eigh_full @@ -77,7 +77,7 @@ eigh_vals By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. See [Gauge choices](@ref sec_gaugefix) for more details. -The following algorithms are available for the symmetric eigenvalue decomposition: +The following algorithms are available for the hermitian eigenvalue decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] @@ -99,7 +99,7 @@ eig_vals By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. See [Gauge choices](@ref sec_gaugefix) for more details. -The following algorithms are available for the non-Hermitian eigenvalue decomposition: +The following algorithms are available for the standard eigenvalue decomposition: ```@autodocs; canonical=false Modules = [MatrixAlgebraKit] From ec924ffcc6b82d15ede49498d71f7fc3abf9f8c0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Mar 2026 10:39:19 -0400 Subject: [PATCH 13/17] more consistent restriction to BlasFloat --- ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl | 4 ++-- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 6fa66049..7ea744b1 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -17,10 +17,10 @@ include("yarocsolver.jl") MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER() MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedROCVecOrMat}) = ROCSOLVER() MatrixAlgebraKit.default_jacobi_driver(::Type{<:StridedROCVecOrMat}) = ROCSOLVER() -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat} +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}} return QRIteration(; kwargs...) end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}} return ROCSOLVER_DivideAndConquer(; kwargs...) end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index b39b6caa..665b2372 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -19,13 +19,13 @@ MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecO MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER() MatrixAlgebraKit.default_jacobi_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER() MatrixAlgebraKit.default_svd_polar_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER() -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} return QRIteration(; kwargs...) end -function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} +function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} return CUSOLVER_Simple(; kwargs...) end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} return CUSOLVER_DivideAndConquer(; kwargs...) end From 089280fa26da828ba0fb593c9c56caf21b7ce5d9 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Mar 2026 11:38:19 -0400 Subject: [PATCH 14/17] default_driver --- .../MatrixAlgebraKitAMDGPUExt.jl | 5 ++--- .../MatrixAlgebraKitCUDAExt.jl | 6 ++---- ...MatrixAlgebraKitGenericLinearAlgebraExt.jl | 8 +++++--- src/algorithms.jl | 17 +++++++++++++++++ src/implementations/lq.jl | 4 ++-- src/implementations/qr.jl | 4 ++-- src/implementations/svd.jl | 19 +++++++------------ 7 files changed, 37 insertions(+), 26 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 7ea744b1..081218ed 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -14,9 +14,8 @@ using LinearAlgebra: BlasFloat include("yarocsolver.jl") -MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER() -MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedROCVecOrMat}) = ROCSOLVER() -MatrixAlgebraKit.default_jacobi_driver(::Type{<:StridedROCVecOrMat}) = ROCSOLVER() +MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER() + function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}} return QRIteration(; kwargs...) end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 665b2372..837a5bd2 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -15,10 +15,8 @@ using LinearAlgebra: BlasFloat include("yacusolver.jl") -MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER() -MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER() -MatrixAlgebraKit.default_jacobi_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER() -MatrixAlgebraKit.default_svd_polar_driver(::Type{<:StridedCuVecOrMat{<:BlasFloat}}) = CUSOLVER() +MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER() + function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}} return QRIteration(; kwargs...) end diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 62d767a4..bb349b7e 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -7,9 +7,11 @@ import MatrixAlgebraKit: gesvd! using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! -MatrixAlgebraKit.default_qr_iteration_driver(::Type{<:StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}) = GLA() +const GlaFloat = Union{BigFloat, Complex{BigFloat}} +const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}} +MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA() -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} return QRIteration(; kwargs...) end @@ -30,7 +32,7 @@ function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, return S, U, Vᴴ end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix} return GLA_QRIteration(; kwargs...) end diff --git a/src/algorithms.jl b/src/algorithms.jl index 07685f6a..cf1d2ff4 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -212,6 +212,23 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati """ struct Native <: Driver end +# In order to avoid amibiguities, this method is implemented in a tiered way +# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A)) +# default_driver(Talg, TA) -> default_driver(TA) +# This is to try and minimize ambiguity while allowing overloading at multiple levels +@inline default_driver(alg::AbstractAlgorithm, A) = default_driver(typeof(alg), A isa Type ? A : typeof(A)) +@inline default_driver(::Type{Alg}, A) where {Alg <: AbstractAlgorithm} = default_driver(Alg, typeof(A)) +@inline default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA} = default_driver(TA) + +# defaults +default_driver(::Type{TA}) where {TA <: AbstractArray} = Native() # default fallback +default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK() + +# wrapper types +@inline default_driver(::Type{Alg}, ::Type{<:SubArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A) +@inline default_driver(::Type{Alg}, ::Type{<:Base.ReshapedArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A) +@inline default_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_driver(A) +@inline default_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_driver(A) # Truncation strategy # ------------------- diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index cf64ccd6..d248ac40 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -115,7 +115,7 @@ end @inline householder_lq!(A, L, Q; driver::Driver = DefaultDriver(), kwargs...) = householder_lq!(driver, A, L, Q; kwargs...) householder_lq!(::DefaultDriver, A, L, Q; kwargs...) = - householder_lq!(default_householder_driver(A), A, L, Q; kwargs...) + householder_lq!(default_driver(Householder, A), A, L, Q; kwargs...) householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) = lq_via_qr!(A, L, Q, Householder(; driver, kwargs...)) function householder_lq!( @@ -221,7 +221,7 @@ end @inline householder_lq_null!(A, Nᴴ; driver::Driver = DefaultDriver(), kwargs...) = householder_lq_null!(driver, A, Nᴴ; kwargs...) householder_lq_null!(::DefaultDriver, A, Nᴴ; kwargs...) = - householder_lq_null!(default_householder_driver(A), A, Nᴴ; kwargs...) + householder_lq_null!(default_driver(Householder, A), A, Nᴴ; kwargs...) householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...) = lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...)) function householder_lq_null!( diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index f78d8c44..3d340a28 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -117,7 +117,7 @@ end @inline householder_qr!(A, Q, R; driver::Driver = DefaultDriver(), kwargs...) = householder_qr!(driver, A, Q, R; kwargs...) householder_qr!(::DefaultDriver, A, Q, R; kwargs...) = - householder_qr!(default_householder_driver(A), A, Q, R; kwargs...) + householder_qr!(default_driver(Householder, A), A, Q, R; kwargs...) function householder_qr!( driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive::Bool = true, pivoted::Bool = false, @@ -248,7 +248,7 @@ end @inline householder_qr_null!(A, N; driver::Driver = DefaultDriver(), kwargs...) = householder_qr_null!(driver, A, N; kwargs...) householder_qr_null!(::DefaultDriver, A, N; kwargs...) = - householder_qr_null!(default_householder_driver(A), A, N; kwargs...) + householder_qr_null!(default_driver(Householder, A), A, N; kwargs...) function householder_qr_null!( driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix; positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 24a06a85..2ad20d82 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -171,18 +171,13 @@ for (f, f_lapack!, Alg) in ( # driver @eval begin - @inline $f_svd!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = - $f_svd!(driver, A, U, S, Vᴴ; kwargs...) - @inline $f_svd_full!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = - $f_svd_full!(driver, A, U, S, Vᴴ; kwargs...) - @inline $f_svd_vals!(A, S; driver::Driver = DefaultDriver(), kwargs...) = - $f_svd_vals!(driver, A, S; kwargs...) - @inline $f_svd!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = - $f_svd!($(Symbol(:default_, f, :_driver))(A), A, U, S, Vᴴ; kwargs...) - @inline $f_svd_full!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = - $f_svd_full!($(Symbol(:default_, f, :_driver))(A), A, U, S, Vᴴ; kwargs...) - @inline $f_svd_vals!(::DefaultDriver, A, S; kwargs...) = - $f_svd_vals!($(Symbol(:default_, f, :_driver))(A), A, S; kwargs...) + @inline $f_svd!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = $f_svd!(driver, A, U, S, Vᴴ; kwargs...) + @inline $f_svd_full!(A, U, S, Vᴴ; driver::Driver = DefaultDriver(), kwargs...) = $f_svd_full!(driver, A, U, S, Vᴴ; kwargs...) + @inline $f_svd_vals!(A, S; driver::Driver = DefaultDriver(), kwargs...) = $f_svd_vals!(driver, A, S; kwargs...) + + @inline $f_svd!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = $f_svd!(default_driver($Alg, A), A, U, S, Vᴴ; kwargs...) + @inline $f_svd_full!(::DefaultDriver, A, U, S, Vᴴ; kwargs...) = $f_svd_full!(default_driver($Alg, A), A, U, S, Vᴴ; kwargs...) + @inline $f_svd_vals!(::DefaultDriver, A, S; kwargs...) = $f_svd_vals!(default_driver($Alg, A), A, S; kwargs...) end # Implementation From c31d563d4ca78cccd238c5157d268d42b293cde9 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Mar 2026 11:38:27 -0400 Subject: [PATCH 15/17] docs improvements --- ext/MatrixAlgebraKitCUDAExt/yacusolver.jl | 4 +- src/common/defaults.jl | 2 + src/interface/decompositions.jl | 147 +++++++++------------- 3 files changed, 60 insertions(+), 93 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl index 37da64da..be63eb85 100644 --- a/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl +++ b/ext/MatrixAlgebraKitCUDAExt/yacusolver.jl @@ -164,9 +164,7 @@ function gesvdp!( ) end err = h_err_sigma[] - if err > tol - warn("gesvdp! did not attained requested tolerance: error = $err > tolerance = $tol") - end + err > tol && @warn "gesvdp! did not attain the requested tolerance: error = $err > tolerance = $tol" flag = @allowscalar dh.info[1] CUSOLVER.chklapackerror(BlasInt(flag)) diff --git a/src/common/defaults.jl b/src/common/defaults.jl index bc4160a1..c64c24ad 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -59,3 +59,5 @@ function default_fixgauge(new_value::Bool) DEFAULT_FIXGAUGE[] = new_value return previous_value end + +const _fixgauge_docs = "The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the output, see also [`default_fixgauge`](@ref) for a global toggle and [`gaugefix!`](@ref) for implementation details." diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 4f5eccf2..59cbe76f 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -60,7 +60,6 @@ of `R` are non-negative. """ @algdef GLA_HouseholderQR -# TODO: @algdef LAPACK_HouseholderQL @algdef LAPACK_HouseholderRQ @@ -86,38 +85,26 @@ function Householder(; return Householder((; blocksize, driver, pivoted, positive)) end -default_householder_driver(A) = default_householder_driver(typeof(A)) -default_householder_driver(::Type) = Native() - -default_householder_driver(::Type{A}) where {A <: YALAPACK.MaybeBlasMat} = LAPACK() - -# note: StridedVector fallback is needed for handling reshaped parent types -default_householder_driver(::Type{A}) where {A <: StridedVector{<:BlasFloat}} = LAPACK() -default_householder_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = - default_householder_driver(A) -default_householder_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = - default_householder_driver(A) - """ - DivideAndConquer(; [driver], fixgauge=default_fixgauge()) + DivideAndConquer(; [driver], fixgauge = default_fixgauge()) Algorithm type to denote the algorithm for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the divide-and-conquer algorithm. The optional `driver` symbol can be used to choose between different implementations of this algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`default_fixgauge`](@ref) and [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef DivideAndConquer """ - SafeDivideAndConquer(; [driver], kwargs...) + SafeDivideAndConquer(; [driver], fixgauge = default_fixgauge()) Algorithm type to for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the divide-and-conquer algorithm, with an additional fallback to the standard QR iteration algorithm in case the former fails to converge. The optional `driver` symbol can be used to choose between different implementations of this algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs !!! warning This approach requires a copy of the input matrix, and is thus the most memory intensive SVD strategy. @@ -129,46 +116,47 @@ See also [`DivideAndConquer`](@ref) and [`QRIteration`](@ref). @algdef SafeDivideAndConquer """ - QRIteration(; [driver], fixgauge = true) + QRIteration(; [driver], fixgauge = default_fixgauge()) Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix via QR iteration. The optional `driver` symbol can be used to choose between different implementations of this algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef QRIteration """ - Bisection(; [driver], fixgauge::Bool = true) + Bisection(; [driver], fixgauge = default_fixgauge()) Algorithm type for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix via the bisection algorithm. The optional `driver` symbol can be used to choose between different implementations of this algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef Bisection """ - Jacobi(; [driver], fixgauge = true) + Jacobi(; [driver], fixgauge = default_fixgauge()) Algorithm type for computing the singular value decomposition of a general matrix using the Jacobi algorithm. The optional `driver` symbol can be used to choose between different implementations of this algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef Jacobi """ - SVDViaPolar(; [driver], kwargs...) + SVDViaPolar(; [driver], fixgauge = default_fixgauge(), [tol]) Algorithm type to denote the algorithm for computing the singular value decomposition of a general matrix via Halley's iterative algorithm for the polar decomposition followed by the Hermitian eigenvalue decomposition of the positive definite factor. The optional `driver` symbol can be used to choose between different implementations of this algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs +The tolerance `tol` can optionally be used to emit a warning if the decomposition failed to converge beyond that given value. """ @algdef SVDViaPolar @@ -195,22 +183,20 @@ end # General Eigenvalue Decomposition # ------------------------------- """ - LAPACK_Simple(; fixgauge::Bool = true) + LAPACK_Simple(; fixgauge = default_fixgauge()) Algorithm type to denote the simple LAPACK driver for computing the Schur or non-Hermitian eigenvalue decomposition of a matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef LAPACK_Simple """ - LAPACK_Expert(; fixgauge::Bool = true) + LAPACK_Expert(; fixgauge = default_fixgauge()) Algorithm type to denote the expert LAPACK driver for computing the Schur or non-Hermitian eigenvalue decomposition of a matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef LAPACK_Expert @@ -227,45 +213,38 @@ eigenvalue decomposition of a non-Hermitian matrix. # Hermitian Eigenvalue Decomposition # ---------------------------------- """ - LAPACK_QRIteration(; fixgauge::Bool = true) + LAPACK_QRIteration(; fixgauge = default_fixgauge()) -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -QR Iteration algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the QR Iteration algorithm. +$_fixgauge_docs """ @algdef LAPACK_QRIteration """ - LAPACK_Bisection(; fixgauge::Bool = true) + LAPACK_Bisection(; fixgauge = default_fixgauge()) -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -Bisection algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the Bisection algorithm. +$_fixgauge_docs """ @algdef LAPACK_Bisection """ - LAPACK_DivideAndConquer(; fixgauge::Bool = true) + LAPACK_DivideAndConquer(; fixgauge = default_fixgauge()) -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -Divide and Conquer algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix, +or the singular value decomposition of a general matrix using the Divide and Conquer algorithm. +$_fixgauge_docs """ @algdef LAPACK_DivideAndConquer """ - LAPACK_MultipleRelativelyRobustRepresentations(; fixgauge::Bool = true) + LAPACK_MultipleRelativelyRobustRepresentations(; fixgauge = default_fixgauge()) -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix using the Multiple Relatively Robust Representations algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix +using the Multiple Relatively Robust Representations algorithm. +$_fixgauge_docs """ @algdef LAPACK_MultipleRelativelyRobustRepresentations @@ -277,26 +256,24 @@ const LAPACK_EighAlgorithm = Union{ } """ - GLA_QRIteration(; fixgauge::Bool = true) + GLA_QRIteration(; fixgauge = default_fixgauge()) Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef GLA_QRIteration # Singular Value Decomposition # ---------------------------- """ - LAPACK_SafeDivideAndConquer(; fixgauge::Bool = true) + LAPACK_SafeDivideAndConquer(; fixgauge = default_fixgauge()) Algorithm type to denote the LAPACK driver for computing the singular value decomposition of a general matrix using the Divide and Conquer algorithm, with an additional fallback to the standard QR Iteration algorithm in case the former fails to converge. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs !!! warning This approach requires a copy of the input matrix, and is thus the most memory intensive SVD strategy. @@ -306,12 +283,11 @@ see also [`gaugefix!`](@ref). @algdef LAPACK_SafeDivideAndConquer """ - LAPACK_Jacobi(; fixgauge::Bool = true) + LAPACK_Jacobi(; fixgauge = default_fixgauge()) Algorithm type to denote the LAPACK driver for computing the singular value decomposition of a general matrix using the Jacobi algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef LAPACK_Jacobi @@ -384,34 +360,31 @@ the diagonal elements of `R` are non-negative. @algdef CUSOLVER_HouseholderQR """ - CUSOLVER_QRIteration(; fixgauge::Bool = true) + CUSOLVER_QRIteration(; fixgauge = default_fixgauge()) Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the QR Iteration algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef CUSOLVER_QRIteration """ - CUSOLVER_SVDPolar(; fixgauge::Bool = true) + CUSOLVER_SVDPolar(; fixgauge = default_fixgauge()) Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of a general matrix by using Halley's iterative algorithm to compute the polar decompositon, followed by the hermitian eigenvalue decomposition of the positive definite factor. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular -vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef CUSOLVER_SVDPolar """ - CUSOLVER_Jacobi(; fixgauge::Bool = true) + CUSOLVER_Jacobi(; fixgauge = default_fixgauge()) Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of a general matrix using the Jacobi algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular -vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef CUSOLVER_Jacobi @@ -433,25 +406,23 @@ for more information. does_truncate(::TruncatedAlgorithm{<:CUSOLVER_Randomized}) = true """ - CUSOLVER_Simple(; fixgauge::Bool = true) + CUSOLVER_Simple(; fixgauge = default_fixgauge()) Algorithm type to denote the simple CUSOLVER driver for computing the non-Hermitian eigenvalue decomposition of a matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef CUSOLVER_Simple const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} """ - CUSOLVER_DivideAndConquer(; fixgauge::Bool = true) + CUSOLVER_DivideAndConquer(; fixgauge = default_fixgauge()) Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Divide and Conquer algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef CUSOLVER_DivideAndConquer @@ -472,45 +443,41 @@ the diagonal elements of `R` are non-negative. @algdef ROCSOLVER_HouseholderQR """ - ROCSOLVER_QRIteration(; fixgauge::Bool = true) + ROCSOLVER_QRIteration(; fixgauge = default_fixgauge()) Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the QR Iteration algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef ROCSOLVER_QRIteration """ - ROCSOLVER_Jacobi(; fixgauge::Bool = true) + ROCSOLVER_Jacobi(; fixgauge = default_fixgauge()) Algorithm type to denote the ROCSOLVER driver for computing the singular value decomposition of a general matrix using the Jacobi algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular -vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef ROCSOLVER_Jacobi """ - ROCSOLVER_Bisection(; fixgauge::Bool = true) + ROCSOLVER_Bisection(; fixgauge = default_fixgauge()) Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Bisection algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef ROCSOLVER_Bisection """ - ROCSOLVER_DivideAndConquer(; fixgauge::Bool = true) + ROCSOLVER_DivideAndConquer(; fixgauge = default_fixgauge()) Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Divide and Conquer algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). +$_fixgauge_docs """ @algdef ROCSOLVER_DivideAndConquer From 578a1b737b292a990bb15b8ebd07b612c9f024c5 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 17 Mar 2026 14:44:50 -0400 Subject: [PATCH 16/17] more carefully designated supported drivers --- ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl | 4 ++++ ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 4 ++++ src/implementations/svd.jl | 4 ---- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 081218ed..24435d71 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -27,6 +27,9 @@ for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...) end +MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) +MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) + function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) m, n = size(A) m >= n && return YArocSOLVER.gesvd!(A, S, U, Vᴴ) @@ -38,6 +41,7 @@ function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::Strid m >= n && return YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...) return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...) end + _gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...) _gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 837a5bd2..166cd666 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -27,10 +27,14 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T return CUSOLVER_DivideAndConquer(; kwargs...) end + for f in (:geqrf!, :ungqr!, :unmqr!) @eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...) end +MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) +MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) + function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) m, n = size(A) m >= n && return YACUSOLVER.gesvd!(A, S, U, Vᴴ) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 2ad20d82..fffd49c4 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -216,13 +216,9 @@ end supports_svd(::Driver, ::Symbol) = false supports_svd(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration, :bisection, :jacobi) supports_svd(::GLA, f::Symbol) = f === :qr_iteration -supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar) -supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi) supports_svd_full(::Driver, ::Symbol) = false supports_svd_full(::LAPACK, f::Symbol) = f in (:safe_divide_and_conquer, :divide_and_conquer, :qr_iteration) supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration -supports_svd_full(::CUSOLVER, f::Symbol) = f === :qr_iteration -supports_svd_full(::ROCSOLVER, f::Symbol) = f === :qr_iteration function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) From 9b21a2cf57bcc6205c364913779cb79520d3df78 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 17 Mar 2026 15:28:17 -0400 Subject: [PATCH 17/17] refactor SVD tests to reduce memory pressure on GPU --- test/svd.jl | 129 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 80 insertions(+), 49 deletions(-) diff --git a/test/svd.jl b/test/svd.jl index 800f191b..2125a7b4 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -14,55 +14,86 @@ using .TestSuite is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63) - TestSuite.seed_rng!(123) - if T ∈ BLASFloats - if CUDA.functional() - TestSuite.test_svd(CuMatrix{T}, (m, n)) - CUDA_SVD_ALGS = ( - CUSOLVER_QRIteration(), - CUSOLVER_SVDPolar(), - CUSOLVER_Jacobi(), - ) - TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS) - k = 5 - p = min(m, n) - k - 2 - min(m, n) > k + 2 && TestSuite.test_randomized_svd(CuMatrix{T}, (m, n), (MatrixAlgebraKit.TruncatedAlgorithm(CUSOLVER_Randomized(; k, p, niters = 20), truncrank(k)),)) - if n == m - TestSuite.test_svd(Diagonal{T, CuVector{T}}, m) - TestSuite.test_svd_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) - end - end - if AMDGPU.functional() - TestSuite.test_svd(ROCMatrix{T}, (m, n)) - AMD_SVD_ALGS = ( - ROCSOLVER_QRIteration(), - ROCSOLVER_Jacobi(), - ) - TestSuite.test_svd_algs(ROCMatrix{T}, (m, n), AMD_SVD_ALGS) - if n == m - TestSuite.test_svd(Diagonal{T, ROCVector{T}}, m) - TestSuite.test_svd_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),)) - end - end +# CPU tests +# --------- +if !is_buildkite + # LAPACK algorithms: + for T in BLASFloats, m in (0, 54), n in (0, 37, m, 63) + TestSuite.seed_rng!(123) + LAPACK_SVD_ALGS = ( + LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), + LAPACK_SafeDivideAndConquer(; fixgauge = true), + ) + TestSuite.test_svd(T, (m, n)) + TestSuite.test_svd_algs(T, (m, n), LAPACK_SVD_ALGS) end - if !is_buildkite - if T ∈ BLASFloats - LAPACK_SVD_ALGS = ( - LAPACK_QRIteration(), - LAPACK_DivideAndConquer(), - LAPACK_SafeDivideAndConquer(; fixgauge = true), - ) - TestSuite.test_svd(T, (m, n)) - TestSuite.test_svd_algs(T, (m, n), LAPACK_SVD_ALGS) - elseif T ∈ GenericFloats - TestSuite.test_svd(T, (m, n)) - TestSuite.test_svd_algs(T, (m, n), (GLA_QRIteration(),)) - end - if m == n - AT = Diagonal{T, Vector{T}} - TestSuite.test_svd(AT, m) - TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),)) - end + + # Generic floats: + for T in GenericFloats, m in (0, 54), n in (0, 37, m, 63) + TestSuite.seed_rng!(123) + TestSuite.test_svd(T, (m, n)) + TestSuite.test_svd_algs(T, (m, n), (GLA_QRIteration(),)) + end + + # Diagonal: + for T in (BLASFloats..., GenericFloats...), m in (0, 54) + TestSuite.seed_rng!(123) + AT = Diagonal{T, Vector{T}} + TestSuite.test_svd(AT, m) + TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),)) + end +end + +# CUDA tests +# ------------ +if CUDA.functional() + # LAPACK algorithms: + for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27) + TestSuite.seed_rng!(123) + TestSuite.test_svd(CuMatrix{T}, (m, n)) + CUDA_SVD_ALGS = ( + CUSOLVER_QRIteration(), + CUSOLVER_SVDPolar(), + CUSOLVER_Jacobi(), + ) + TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS) + end + + # Randomized SVD: + for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27) + TestSuite.seed_rng!(123) + k = 5 + p = min(m, n) - k - 2 + p > 0 || continue + TestSuite.test_randomized_svd(CuMatrix{T}, (m, n), (MatrixAlgebraKit.TruncatedAlgorithm(CUSOLVER_Randomized(; k, p, niters = 20), truncrank(k)),)) + end + + # Diagonal: + for T in BLASFloats, m in (0, 23) + TestSuite.seed_rng!(123) + AT = Diagonal{T, CuVector{T}} + TestSuite.test_svd(AT, m) + TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),)) + end +end + +# AMDGPU tests +# ------------ +if AMDGPU.functional() + # LAPACK algorithms: + for T in BLASFloats, m in (0, 23), n in (0, 17, m, 27) + TestSuite.seed_rng!(123) + TestSuite.test_svd(ROCMatrix{T}, (m, n)) + AMD_SVD_ALGS = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) + TestSuite.test_svd_algs(ROCMatrix{T}, (m, n), AMD_SVD_ALGS) + end + + # Diagonal: + for T in BLASFloats, m in (0, 23) + TestSuite.seed_rng!(123) + AT = Diagonal{T, ROCVector{T}} + TestSuite.test_svd(AT, m) + TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),)) end end