Open
67 changes: 57 additions & 10 deletions docs/src/user_interface/decompositions.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ lq_full
lq_compact
```

Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithm:
The following algorithm is available for QR and LQ decompositions:

```@docs; canonical=false
LAPACK_HouseholderQR
LAPACK_HouseholderLQ
Householder
```

## Eigenvalue Decomposition
Expand All @@ -63,9 +62,9 @@ These functions return the diagonal elements of `D` in a vector.
Finally, it is also possible to compute a partial or truncated eigenvalue decomposition, using the [`eig_trunc`](@ref) and [`eigh_trunc`](@ref) functions.
To control the behavior of the truncation, we refer to [Truncations](@ref) for more information.

### Symmetric Eigenvalue Decomposition
### Hermitian or Real Symmetric Eigenvalue Decomposition

For symmetric matrices, we provide the following functions:
For Hermitian matrices, which include real symmetric matrices, we provide the following functions:

```@docs; canonical=false
eigh_full
Expand All @@ -78,7 +77,7 @@ eigh_vals
By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results.
See [Gauge choices](@ref sec_gaugefix) for more details.

Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms:
The following algorithms are available for the Hermitian eigenvalue decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Expand All @@ -100,7 +99,7 @@ eig_vals
By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results.
See [Gauge choices](@ref sec_gaugefix) for more details.

Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms:
The following algorithms are available for the standard eigenvalue decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Expand All @@ -120,7 +119,7 @@ schur_full
schur_vals
```

The LAPACK-based implementation for dense arrays is provided by the following algorithms:
The following algorithms are available for the Schur decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Expand Down Expand Up @@ -153,11 +152,11 @@ svd_trunc
By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results.
See [Gauge choices](@ref sec_gaugefix) for more details.

MatrixAlgebraKit again ships with LAPACK-based implementations for dense arrays:
The following algorithms are available for the singular value decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_SVDAlgorithm
Filter = t -> t isa Type && t <: MatrixAlgebraKit.SVDAlgorithms
```

## Polar Decomposition
Expand Down Expand Up @@ -388,6 +387,54 @@ norm(A * N1') < 1e-14 && norm(A * N2') < 1e-14 &&
true
```

## [Driver Selection](@id sec_driverselection)

!!! note "Expert use case"
    Selecting a specific driver is an advanced feature intended for users who need to target a specific computational backend, such as a GPU.
    For most use cases, the default driver selection is sufficient.

Each algorithm in MatrixAlgebraKit can optionally accept a `driver` keyword argument to explicitly select the computational backend.
By default, the driver is set to `DefaultDriver()`, which automatically selects the most appropriate backend based on the input matrix type.
The available drivers are:

```@docs; canonical=false
MatrixAlgebraKit.DefaultDriver
MatrixAlgebraKit.LAPACK
MatrixAlgebraKit.CUSOLVER
MatrixAlgebraKit.ROCSOLVER
MatrixAlgebraKit.GLA
MatrixAlgebraKit.Native
```

For example, to force LAPACK for a generic matrix type, or to use a GPU backend:

```julia
using MatrixAlgebraKit
using MatrixAlgebraKit: LAPACK, CUSOLVER # driver types are not exported by default

# Default: driver is selected automatically based on the input type
U, S, Vᴴ = svd_compact(A)
U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer())

# Expert: explicitly select LAPACK
U, S, Vᴴ = svd_compact(A; alg = SafeDivideAndConquer(; driver = LAPACK()))

# Expert: use a GPU backend (requires loading the appropriate extension)
U, S, Vᴴ = svd_compact(A; alg = QRIteration(; driver = CUSOLVER()))
```

Similarly, for QR decompositions:

```julia
using MatrixAlgebraKit: LAPACK # driver types are not exported by default

# Default: driver is selected automatically
Q, R = qr_compact(A)
Q, R = qr_compact(A; alg = Householder())

# Expert: explicitly select a driver
Q, R = qr_compact(A; alg = Householder(; driver = LAPACK()))
```
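
How the default driver might be resolved from the input type can be illustrated with a small, self-contained sketch.
It mirrors the tiered idea of dispatching first on the algorithm and array, then on their types, and finally on the array type alone.
All names prefixed with `Sketch`/`sketch_` are hypothetical stand-ins and are not part of MatrixAlgebraKit; the actual `default_driver` methods may differ in detail.

```julia
using LinearAlgebra: BlasFloat

# Hypothetical stand-ins for illustration only (not the MatrixAlgebraKit types)
abstract type SketchDriver end
struct SketchLAPACK <: SketchDriver end
struct SketchNative <: SketchDriver end

abstract type SketchAlgorithm end
struct SketchQRIteration <: SketchAlgorithm end

# Tier 3: decide based on the array type alone
sketch_driver(::Type{<:AbstractArray}) = SketchNative()          # generic fallback
sketch_driver(::Type{<:Array{<:BlasFloat}}) = SketchLAPACK()     # BLAS element types

# Wrapper types defer to their parent array type
sketch_driver(::Type{<:SubArray{T, N, P}}) where {T, N, P} = sketch_driver(P)

# Tier 2: (algorithm type, array type); an algorithm could overload this tier
sketch_driver(::Type{Alg}, ::Type{TA}) where {Alg <: SketchAlgorithm, TA} = sketch_driver(TA)

# Tier 1: instances are converted to their types
sketch_driver(alg::SketchAlgorithm, A::AbstractArray) = sketch_driver(typeof(alg), typeof(A))

alg = SketchQRIteration()
d_blas = sketch_driver(alg, rand(3, 3))                    # Matrix{Float64}
d_big  = sketch_driver(alg, big.(rand(3, 3)))              # Matrix{BigFloat}
d_view = sketch_driver(alg, view(rand(4, 4), 1:2, 1:2))    # SubArray of Matrix{Float64}
```

In this sketch, the `Float64` matrix and its view resolve to the LAPACK-like driver, while the `BigFloat` matrix falls back to the native one.
Keeping the array-only tier separate means a backend extension only needs to add a single method for its array type.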

## [Gauge choices](@id sec_gaugefix)

Both eigenvalue and singular value decompositions have residual gauge degrees of freedom even when the eigenvalues or singular values are unique.
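
As a rough illustration of this freedom, the sketch below multiplies each eigenvector of a small Hermitian matrix by an arbitrary phase and then removes the ambiguity by making the largest-magnitude entry of every column real and positive.
The helper `sketch_fix_phase!` is hypothetical and only approximates the convention described for `gaugefix!`; the library's implementation may differ.

```julia
using LinearAlgebra

# Hypothetical helper: rotate each column so that its largest-magnitude
# entry becomes real and positive.
function sketch_fix_phase!(V::AbstractMatrix)
    for j in axes(V, 2)
        v = view(V, :, j)
        a = v[argmax(abs.(v))]      # entry with the largest magnitude
        iszero(a) && continue       # nothing to fix for a zero column
        v .*= conj(a) / abs(a)      # multiply by a unit-modulus phase
    end
    return V
end

A = [2.0 1.0im; -1.0im 3.0]         # Hermitian, with distinct eigenvalues
D, V = eigen(Hermitian(A))

# Any per-column phase gives an equally valid set of eigenvectors
W = V * Diagonal([cis(0.3), cis(-1.1)])

sketch_fix_phase!(V)
sketch_fix_phase!(W)
same_gauge = V ≈ W                  # both now follow the same convention
```

After the phase fix, `V` and `W` are expected to agree, and both still satisfy `A * V ≈ V * Diagonal(D)`.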
Expand Down
29 changes: 17 additions & 12 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Expand Up @@ -6,33 +6,38 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat

include("yarocsolver.jl")

MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedROCMatrix{<:BlasFloat}} = ROCSOLVER()
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
return ROCSOLVER_QRIteration(; kwargs...)
MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedROCVecOrMat{<:BlasFloat}} = ROCSOLVER()

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
return ROCSOLVER_DivideAndConquer(; kwargs...)
end

for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...)
end

_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) =
YArocSOLVER.gesvd!(A, S, U, Vᴴ)
# not yet supported
# _gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
# YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
m, n = size(A)
m >= n && return YArocSOLVER.gesvd!(A, S, U, Vᴴ)
return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
end

function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
m, n = size(A)
m >= n && return YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
end
_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heevj!(A, Dd, V; kwargs...)
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
Expand Down
41 changes: 26 additions & 15 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Expand Up @@ -6,40 +6,51 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank
using CUDA, CUDA.CUBLAS
using CUDA: i32
using LinearAlgebra
using LinearAlgebra: BlasFloat

include("yacusolver.jl")

MatrixAlgebraKit.default_householder_driver(::Type{A}) where {A <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
return CUSOLVER_QRIteration(; kwargs...)
MatrixAlgebraKit.default_driver(::Type{TA}) where {TA <: StridedCuVecOrMat{<:BlasFloat}} = CUSOLVER()

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return CUSOLVER_Simple(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return CUSOLVER_DivideAndConquer(; kwargs...)
end

for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
end

function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
m, n = size(A)
m >= n && return YACUSOLVER.gesvd!(A, S, U, Vᴴ)
return MatrixAlgebraKit.svd_via_adjoint!(gesvd!, CUSOLVER(), A, S, U, Vᴴ; kwargs...)
end

function gesvdj!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
m, n = size(A)
m >= n && return YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, CUSOLVER(), A, S, U, Vᴴ; kwargs...)
end

gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdp!(A, S, U, Vᴴ; kwargs...)

_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)

_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, D, V)
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) =
YACUSOLVER.gesvd!(A, S, U, Vᴴ)
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.Xgesvdr!(A, S, U, Vᴴ; kwargs...)
_gpu_gesvdj!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)

_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
YACUSOLVER.heevj!(A, Dd, V; kwargs...)
Expand Down
8 changes: 3 additions & 5 deletions ext/MatrixAlgebraKitCUDAExt/yacusolver.jl
Expand Up @@ -98,7 +98,7 @@ for (bname, fname, elty, relty) in
end
end

function Xgesvdp!(
function gesvdp!(
A::StridedCuMatrix{T},
S::StridedCuVector = similar(A, real(T), min(size(A)...)),
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
Expand Down Expand Up @@ -164,9 +164,7 @@ function Xgesvdp!(
)
end
err = h_err_sigma[]
if err > tol
warn("Xgesvdp! did not attained requested tolerance: error = $err > tolerance = $tol")
end
err > tol && @warn "gesvdp! did not attain the requested tolerance: error = $err > tolerance = $tol"

flag = @allowscalar dh.info[1]
CUSOLVER.chklapackerror(BlasInt(flag))
Expand Down Expand Up @@ -269,7 +267,7 @@ for (bname, fname, elty, relty) in
end

# Wrapper for randomized SVD
function Xgesvdr!(
function gesvdr!(
A::StridedCuMatrix{T},
S::StridedCuVector = similar(A, real(T), min(size(A)...)),
U::StridedCuMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
Expand Down
54 changes: 23 additions & 31 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Expand Up @@ -2,45 +2,37 @@ module MatrixAlgebraKitGenericLinearAlgebraExt

using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
using MatrixAlgebraKit: GLA
import MatrixAlgebraKit: gesvd!
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
return GLA_QRIteration()
end

for f! in (:svd_compact!, :svd_full!)
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing)
end
MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
const GlaFloat = Union{BigFloat, Complex{BigFloat}}
const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}}
MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA()

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
F = svd!(A)
U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt

do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ)

return U, S, Vᴴ
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
return QRIteration(; kwargs...)
end

function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
F = svd!(A; full = true)
U, Vᴴ = F.U, F.Vt
S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1))))
diagview(S) .= F.S

do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ)

return U, S, Vᴴ
end

function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration)
return svdvals!(A)
function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...)
m, n = size(A)
if length(U) == 0 && length(Vᴴ) == 0
Sv = svdvals!(A)
copyto!(S, Sv)
else
minmn = min(m, n)
# full SVD if U has m columns or Vᴴ has n rows (beyond the compact min(m,n))
full = (length(U) > 0 && size(U, 2) > minmn) || (length(Vᴴ) > 0 && size(Vᴴ, 1) > minmn)
F = svd!(A; full = full)
length(S) > 0 && copyto!(S, F.S)
length(U) > 0 && copyto!(U, F.U)
length(Vᴴ) > 0 && copyto!(Vᴴ, F.Vt)
end
return S, U, Vᴴ
end

function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
return GLA_QRIteration(; kwargs...)
end

Expand Down
1 change: 1 addition & 0 deletions src/MatrixAlgebraKit.jl
Expand Up @@ -32,6 +32,7 @@ export left_orth, right_orth, left_null, right_null
export left_orth!, right_orth!, left_null!, right_null!

export Householder, Native_HouseholderQR, Native_HouseholderLQ
export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer
Expand Down
17 changes: 17 additions & 0 deletions src/algorithms.jl
Expand Up @@ -212,6 +212,23 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati
"""
struct Native <: Driver end

# In order to avoid ambiguities, this method is implemented in a tiered way
# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A))
# default_driver(Talg, TA) -> default_driver(TA)
# This is to try and minimize ambiguity while allowing overloading at multiple levels
@inline default_driver(alg::AbstractAlgorithm, A) = default_driver(typeof(alg), A isa Type ? A : typeof(A))
@inline default_driver(::Type{Alg}, A) where {Alg <: AbstractAlgorithm} = default_driver(Alg, typeof(A))
@inline default_driver(::Type{Alg}, ::Type{TA}) where {Alg <: AbstractAlgorithm, TA} = default_driver(TA)

# defaults
default_driver(::Type{TA}) where {TA <: AbstractArray} = Native() # default fallback
default_driver(::Type{TA}) where {TA <: YALAPACK.MaybeBlasVecOrMat} = LAPACK()

# wrapper types
@inline default_driver(::Type{Alg}, ::Type{<:SubArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
@inline default_driver(::Type{Alg}, ::Type{<:Base.ReshapedArray{T, N, A}}) where {Alg <: AbstractAlgorithm, T, N, A} = default_driver(Alg, A)
@inline default_driver(::Type{<:SubArray{T, N, A}}) where {T, N, A} = default_driver(A)
@inline default_driver(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = default_driver(A)

# Truncation strategy
# -------------------
Expand Down
2 changes: 2 additions & 0 deletions src/common/defaults.jl
Expand Up @@ -59,3 +59,5 @@ function default_fixgauge(new_value::Bool)
DEFAULT_FIXGAUGE[] = new_value
return previous_value
end

const _fixgauge_docs = "The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the output; see also [`default_fixgauge`](@ref) for a global toggle and [`gaugefix!`](@ref) for implementation details."
3 changes: 3 additions & 0 deletions src/common/gauge.jl
Expand Up @@ -8,6 +8,9 @@ This is achieved by ensuring that the entry with the largest magnitude in `V` or
is real and positive.
""" gaugefix!

# Helper functions
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
_largest(x, y) = abs(x) < abs(y) ? y : x

function gaugefix!(::Union{typeof(eig_full!), typeof(eigh_full!), typeof(gen_eig_full!)}, V::AbstractMatrix)
for j in axes(V, 2)
Expand Down