Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/implementations/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
Rc .= R
Rᴴinv = ldiv!(UpperTriangular(Rc)', one!(Rᴴinv))
else # m == n
Q = nothing
R = A
Rc = view(W, 1:n, 1:n)
Rc .= R
Expand Down Expand Up @@ -163,6 +164,7 @@ function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); to
copy!(Lc, L)
Lᴴinv = ldiv!(LowerTriangular(Lc)', one!(Lᴴinv))
else # m == n
Q = nothing
L = A
Lc = view(Wᴴ, 1:m, 1:m)
Lc .= L
Expand Down
2 changes: 2 additions & 0 deletions src/implementations/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ function householder_qr!(
(inplaceQ && (computeR || positive || blocksize > 1 || m < n)) &&
throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required, and using the unblocked algorithm (`blocksize = 1`) with `positive = false`"))

jpvt = Vector{Int}(undef, 0)
τ = Vector{eltype(A)}(undef, 0)
Comment on lines +144 to +145
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the least impactful: making these nothing and having a small type union, or the current approach, which I think does require a small allocation?

# Compute QR in packed form
if blocksize > 1
nb = min(minmn, blocksize)
Expand Down
7 changes: 2 additions & 5 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,12 @@ function eig_trunc_pullback!(
(n, n) == size(ΔA) || throw(DimensionMismatch())
G = V' * V

VᴴΔV = !iszerotangent(ΔV) ? V' * ΔV : zero(G)
ΔVperp = ΔV - V * inv(G) * VᴴΔV
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this even work if iszerotangent(ΔV), e.g. ΔV could be nothing on some of the AD engines like Zygote, no?

if !iszerotangent(ΔV)
(n, p) == size(ΔV) || throw(DimensionMismatch())
VᴴΔV = V' * ΔV
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)

ΔVperp = ΔV - V * inv(G) * VᴴΔV
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
else
VᴴΔV = zero(G)
end

if !iszerotangent(ΔDmat)
Expand Down
2 changes: 2 additions & 0 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ function svd_pullback!(
Ur = view(U, :, 1:r)
Vᴴr = view(Vᴴ, 1:r, :)
Sr = view(S, 1:r)
indU = axes(U, 2)
indV = axes(Vᴴ, 1)
Comment on lines +50 to +51
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These can probably immediately get their final value, since ind is always defined, even if we end up not needing them.


# Extract and check the cotangents
ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ
Expand Down
6 changes: 4 additions & 2 deletions src/yalapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2033,14 +2033,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
for i in 1:2 # first call returns lwork as work[1]
#! format: off
if eltype(A) <: Complex
rwork_ = isnothing(rwork) ? Vector{$relty}(undef, 0) : rwork
ccall((@blasfunc($gesvd), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{$relty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{$elty}, Ref{BlasInt}, Ptr{$relty},
Ptr{BlasInt}, Clong, Clong),
jobu, jobvt, m, n, A, lda,
S, U, ldu, Vᴴ, ldv,
work, lwork, rwork,
work, lwork, rwork_,
info, 1, 1)
else
ccall((@blasfunc($gesvd), libblastrampoline), Cvoid,
Expand Down Expand Up @@ -2135,14 +2136,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
for i in 1:2 # first call returns lwork as work[1]
#! format: off
if eltype(A) <: Complex
rwork_ = isnothing(rwork) ? Vector{$relty}(undef, 0) : rwork
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how this one helps? if eltype(A) <: Complex, then also rwork is not nothing. Is it just a matter of trying to show this to the compiler? How about

rwork_::Vector{$relty} = rwork

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if eltype(A) <: Complex, then also rwork is not nothing

JET seems unable to deduce this fact :(

Let's try the suggestion and see how it reacts!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, no dice, because from JET's PoV this could still involve a conversion from Nothing to a Vector{$relty}...

ccall((@blasfunc($gesdd), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{$relty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{$elty}, Ref{BlasInt}, Ptr{$relty}, Ptr{BlasInt},
Ptr{BlasInt}, Clong),
job, m, n, A, lda,
S, U, ldu, Vᴴ, ldv,
work, lwork, rwork, iwork,
work, lwork, rwork_, iwork,
info, 1)
else
ccall((@blasfunc($gesdd), libblastrampoline), Cvoid,
Expand Down
Loading