Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 5 additions & 23 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,20 @@ using LinearAlgebra: BlasFloat

include("yacusolver.jl")

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
return CUSOLVER_HouseholderQR(; kwargs...)
end
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
return LQViaTransposedQR(qr_alg)
end
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
return CUSOLVER_QRIteration(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
return CUSOLVER_Simple(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
return CUSOLVER_DivideAndConquer(; kwargs...)
end

# include for block sector support
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_HouseholderQR(; kwargs...)
end
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
return LQViaTransposedQR(qr_alg)
end
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_Jacobi(; kwargs...)
end
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_Simple(; kwargs...)
end
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
return CUSOLVER_DivideAndConquer(; kwargs...)
end

Expand Down
8 changes: 7 additions & 1 deletion src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,18 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
# -------------------
default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...)
default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,)))
function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
return LAPACK_Expert(; kwargs...)
end
function default_eig_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
return DiagonalAlgorithm(; kwargs...)
end
function default_eig_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
return default_eig_algorithm(A)
end
function default_eig_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
return default_eig_algorithm(A)
end

for f in (:eig_full!, :eig_vals!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
Expand Down
8 changes: 7 additions & 1 deletion src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,18 @@ default_eigh_algorithm(A; kwargs...) = default_eigh_algorithm(typeof(A); kwargs.
function default_eigh_algorithm(T::Type; kwargs...)
throw(MethodError(default_eigh_algorithm, (T,)))
end
function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
end
function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
return DiagonalAlgorithm(; kwargs...)
end
function default_eigh_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
return default_eigh_algorithm(A)
end
function default_eigh_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
return default_eigh_algorithm(A)
end

for f in (:eigh_full!, :eigh_vals!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
Expand Down
8 changes: 7 additions & 1 deletion src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,18 @@ end
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix}
return Native_HouseholderLQ(; kwargs...)
end
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
return LAPACK_HouseholderLQ(; kwargs...)
end
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
return DiagonalAlgorithm(; kwargs...)
end
function default_lq_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
return default_lq_algorithm(A)
end
function default_lq_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
return default_lq_algorithm(A)
end

for f in (:lq_full!, :lq_compact!, :lq_null!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
Expand Down
8 changes: 7 additions & 1 deletion src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,18 @@ end
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix}
return Native_HouseholderQR(; kwargs...)
end
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
return LAPACK_HouseholderQR(; kwargs...)
end
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
return DiagonalAlgorithm(; kwargs...)
end
function default_qr_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
return default_qr_algorithm(A)
end
function default_qr_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
return default_qr_algorithm(A)
end

for f in (:qr_full!, :qr_compact!, :qr_null!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
Expand Down
8 changes: 7 additions & 1 deletion src/interface/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,18 @@ default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs...
function default_svd_algorithm(T::Type; kwargs...)
throw(MethodError(default_svd_algorithm, (T,)))
end
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
return LAPACK_DivideAndConquer(; kwargs...)
end
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
return DiagonalAlgorithm(; kwargs...)
end
function default_svd_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
return default_svd_algorithm(A)
end
function default_svd_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
return default_svd_algorithm(A)
end

for f in (:svd_full!, :svd_compact!, :svd_vals!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
Expand Down
7 changes: 5 additions & 2 deletions src/yalapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK,
using LinearAlgebra.BLAS: @blasfunc, libblastrampoline
using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror

# type alias for matrices that are definitely supported by YALAPACK
# type alias for vectors/matrices that are definitely supported by YALAPACK
const BlasVec{T <: BlasFloat} = StridedVector{T}
const BlasMat{T <: BlasFloat} = StridedMatrix{T}
# type alias for matrices that are possibly supported by YALAPACK, after conversion
# type alias for vectors/matrices that are possibly supported by YALAPACK, after conversion
const MaybeBlasVec = Union{BlasVec, AbstractVector{<:Integer}}
const MaybeBlasMat = Union{BlasMat, AbstractMatrix{<:Integer}}
const MaybeBlasVecOrMat = Union{MaybeBlasVec, MaybeBlasMat}

# LU factorisation (currently unused in MatrixAlgebraKit)
# for (getrf, getrs, elty) in (
Expand Down