Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/src/user_interface/decompositions.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ The following algorithms are available for the hermitian eigenvalue decompositio

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EighAlgorithm
Filter = t -> t isa Type && t <: MatrixAlgebraKit.EighAlgorithms
```

### Eigenvalue Decomposition
Expand All @@ -103,7 +103,7 @@ The following algorithms are available for the standard eigenvalue decomposition

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm
Filter = t -> t isa Type && t <: MatrixAlgebraKit.EigAlgorithms
```

## Schur Decomposition
Expand All @@ -123,7 +123,7 @@ The following algorithms are available for the Schur decomposition:

```@autodocs; canonical=false
Modules = [MatrixAlgebraKit]
Filter = t -> t isa Type && t <: MatrixAlgebraKit.LAPACK_EigAlgorithm
Filter = t -> t isa Type && t <: MatrixAlgebraKit.SchurAlgorithms
```

## Singular Value Decomposition
Expand Down
14 changes: 7 additions & 7 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
import MatrixAlgebraKit: heevj!, heevd!, heev!, heevx!
import MatrixAlgebraKit: _sylvester, svd_rank
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand All @@ -20,14 +21,13 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCVecOrMat{<:BlasFloat}}
return ROCSOLVER_DivideAndConquer(; kwargs...)
return DivideAndConquer(; kwargs...)
end

for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...)
end

MatrixAlgebraKit.supports_svd(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)
MatrixAlgebraKit.supports_svd_full(::ROCSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi)

function gesvd!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...)
Expand All @@ -42,13 +42,13 @@ function gesvdj!(::ROCSOLVER, A::StridedROCMatrix, S::StridedROCVector, U::Strid
return MatrixAlgebraKit.svd_via_adjoint!(gesvdj!, ROCSOLVER(), A, S, U, Vᴴ; kwargs...)
end

_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
heevj!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heevj!(A, Dd, V; kwargs...)
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
heevd!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heevd!(A, Dd, V; kwargs...)
_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
heev!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heev!(A, Dd, V; kwargs...)
_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
heevx!(::ROCSOLVER, A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
YArocSOLVER.heevx!(A, Dd, V; kwargs...)

function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::TruncationByValue)
Expand Down
18 changes: 9 additions & 9 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_Xgesvdr!, _sylvester, svd_rank
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
import MatrixAlgebraKit: heevj!, heevd!, geev!
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank
using CUDA, CUDA.CUBLAS
using CUDA: i32
using LinearAlgebra
Expand All @@ -21,18 +22,17 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return CUSOLVER_Simple(; kwargs...)
return QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedCuVecOrMat{<:BlasFloat}}
return CUSOLVER_DivideAndConquer(; kwargs...)
return DivideAndConquer(; kwargs...)
end


for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
end

MatrixAlgebraKit.supports_svd(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)
MatrixAlgebraKit.supports_svd_full(::CUSOLVER, f::Symbol) = f in (:qr_iteration, :jacobi, :svd_polar)

function gesvd!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...)
Expand All @@ -53,12 +53,12 @@ gesvdp!(::CUSOLVER, A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix,
_gpu_Xgesvdr!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
YACUSOLVER.gesvdr!(A, S, U, Vᴴ; kwargs...)

_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, D, V)
geev!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, Dd, V)

_gpu_heevj!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
heevj!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
YACUSOLVER.heevj!(A, Dd, V; kwargs...)
_gpu_heevd!(A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
heevd!(::CUSOLVER, A::StridedCuMatrix, Dd::StridedCuVector, V::StridedCuMatrix; kwargs...) =
YACUSOLVER.heevd!(A, Dd, V; kwargs...)

function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::TruncationByValue)
Expand Down
37 changes: 21 additions & 16 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@ module MatrixAlgebraKitGenericLinearAlgebraExt

using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
using MatrixAlgebraKit: GLA
import MatrixAlgebraKit: gesvd!
using MatrixAlgebraKit: GLA, Driver
import MatrixAlgebraKit: gesvd!, heev!
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!

const GlaFloat = Union{BigFloat, Complex{BigFloat}}
const GlaStridedVecOrMatrix{T <: GlaFloat} = Union{StridedVector{T}, StridedMatrix{T}}
MatrixAlgebraKit.default_driver(::Type{<:QRIteration}, ::Type{TA}) where {TA <: GlaStridedVecOrMatrix} = GLA()

function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
return QRIteration(; kwargs...)
MatrixAlgebraKit.supports_svd_full(::GLA, f::Symbol) = f === :qr_iteration

function MatrixAlgebraKit.default_svd_algorithm(
::Type{T};
driver::Driver = GLA(), kwargs...
) where {T <: GlaStridedVecOrMatrix}
return QRIteration(; driver, kwargs...)
end

function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix; kwargs...)
Expand All @@ -32,20 +37,20 @@ function gesvd!(::GLA, A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix,
return S, U, Vᴴ
end

function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: GlaStridedVecOrMatrix}
return GLA_QRIteration(; kwargs...)
end

MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing)
MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing

function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; driver::Driver = GLA(), kwargs...) where {T <: GlaStridedVecOrMatrix}
return QRIteration(; driver, kwargs...)
end

function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
return eigvals!(Hermitian(A); sortby = real)
function heev!(::GLA, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)
if length(V) > 0
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
copyto!(Dd, eigval)
copyto!(V, eigvec)
else
eigval = eigvals!(Hermitian(A); sortby = real)
copyto!(Dd, eigval)
end
return Dd, V
end

function MatrixAlgebraKit.householder_qr!(
Expand Down
59 changes: 34 additions & 25 deletions ext/MatrixAlgebraKitGenericSchurExt.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,50 @@
module MatrixAlgebraKitGenericSchurExt

using MatrixAlgebraKit
using MatrixAlgebraKit: check_input
using MatrixAlgebraKit: check_input, GS, Driver
import MatrixAlgebraKit: geev!, geevx!, gees!, eig_full!, eig_vals!, schur_full!, schur_vals!
using LinearAlgebra: Diagonal, sorteig!
using GenericSchur

function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
return GS_QRIteration(; kwargs...)
end

MatrixAlgebraKit.initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::GS_QRIteration) = (nothing, nothing)
MatrixAlgebraKit.initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::GS_QRIteration) = nothing

function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration)
D, V = GenericSchur.eigen!(A)
return Diagonal(D), V
end
const GSFloat = Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}

function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration)
return GenericSchur.eigvals!(A)
function MatrixAlgebraKit.default_eig_algorithm(
::Type{T}; driver::Driver = GS(), kwargs...
) where {T <: StridedMatrix{<:GSFloat}}
return QRIteration(; driver, kwargs...)
end

function MatrixAlgebraKit.schur_full!(A::AbstractMatrix, TZv, alg::GS_QRIteration)
check_input(schur_full!, A, TZv, alg)
T, Z, vals = TZv
S = GenericSchur.gschur(A)
copyto!(T, S.T)
copyto!(Z, S.Z)
copyto!(vals, S.values)
return T, Z, vals
function geev!(::GS, A::AbstractMatrix, Dd::AbstractVector, V::AbstractMatrix; kwargs...)
D, Vmat = GenericSchur.eigen!(A)
copyto!(Dd, D)
length(V) > 0 && copyto!(V, Vmat)
return Dd, V
end

function MatrixAlgebraKit.schur_vals!(A::AbstractMatrix, vals, alg::GS_QRIteration)
check_input(schur_vals!, A, vals, alg)
function gees!(driver::GS, A::AbstractMatrix, Z::AbstractMatrix, vals::AbstractVector)
S = GenericSchur.gschur(A)
copyto!(A, S.T)
length(Z) > 0 && copyto!(Z, S.Z)
copyto!(vals, sorteig!(S.values))
return vals
return A, Z, vals
end

Base.@deprecate(
eig_full!(A, DV, alg::GS_QRIteration),
eig_full!(A, DV, QRIteration(; driver = GS(), alg.kwargs...))
)
Base.@deprecate(
eig_vals!(A, D, alg::GS_QRIteration),
eig_vals!(A, D, QRIteration(; driver = GS(), alg.kwargs...))
)

Base.@deprecate(
schur_full!(A, TZv, alg::GS_QRIteration),
schur_full!(A, TZv, QRIteration(; driver = GS(), alg.kwargs...))
)
Base.@deprecate(
schur_vals!(A, vals, alg::GS_QRIteration),
schur_vals!(A, vals, QRIteration(; driver = GS(), alg.kwargs...))
)

end
1 change: 1 addition & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null!

export Householder, Native_HouseholderQR, Native_HouseholderLQ
export DivideAndConquer, SafeDivideAndConquer, QRIteration, Bisection, Jacobi, SVDViaPolar
export RobustRepresentations
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer
Expand Down
7 changes: 7 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ Driver to select a native implementation in MatrixAlgebraKit as the implementati
"""
struct Native <: Driver end

"""
GS <: Driver
Driver to select GenericSchur.jl as the implementation strategy.
"""
struct GS <: Driver end

# In order to avoid amibiguities, this method is implemented in a tiered way
# default_driver(alg, A) -> default_driver(typeof(alg), typeof(A))
# default_driver(Talg, TA) -> default_driver(TA)
Expand Down
Loading
Loading