diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index fb67149e..7a52028b 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -15,38 +15,20 @@ using LinearAlgebra: BlasFloat include("yacusolver.jl") -function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} +function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} return CUSOLVER_HouseholderQR(; kwargs...) end -function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} +function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} qr_alg = CUSOLVER_HouseholderQR(; kwargs...) return LQViaTransposedQR(qr_alg) end -function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} return CUSOLVER_QRIteration(; kwargs...) end -function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} +function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} return CUSOLVER_Simple(; kwargs...) end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}} - return CUSOLVER_DivideAndConquer(; kwargs...) -end - -# include for block sector support -function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} - return CUSOLVER_HouseholderQR(; kwargs...) -end -function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} - qr_alg = CUSOLVER_HouseholderQR(; kwargs...) - return LQViaTransposedQR(qr_alg) -end -function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} - return CUSOLVER_Jacobi(; kwargs...) -end -function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} - return CUSOLVER_Simple(; kwargs...) -end -function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}} return CUSOLVER_DivideAndConquer(; kwargs...) end diff --git a/src/interface/eig.jl b/src/interface/eig.jl index bb111c01..c315d737 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -161,12 +161,18 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # ------------------- default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...) default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,))) -function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} +function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} return LAPACK_Expert(; kwargs...) end function default_eig_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) end +function default_eig_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} + return default_eig_algorithm(A) +end +function default_eig_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A} + return default_eig_algorithm(A) +end for f in (:eig_full!, :eig_vals!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 42c7f9f3..470395fe 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -167,12 +167,18 @@ default_eigh_algorithm(A; kwargs...) = default_eigh_algorithm(typeof(A); kwargs. function default_eigh_algorithm(T::Type; kwargs...) throw(MethodError(default_eigh_algorithm, (T,))) end -function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} +function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) end function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) end +function default_eigh_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} + return default_eigh_algorithm(A) +end +function default_eigh_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A} + return default_eigh_algorithm(A) +end for f in (:eigh_full!, :eigh_vals!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/src/interface/lq.jl b/src/interface/lq.jl index 8254c826..36327d03 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -75,12 +75,18 @@ end function default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} return Native_HouseholderLQ(; kwargs...) end -function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} +function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} return LAPACK_HouseholderLQ(; kwargs...) end function default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) end +function default_lq_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} + return default_lq_algorithm(A) +end +function default_lq_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A} + return default_lq_algorithm(A) +end for f in (:lq_full!, :lq_compact!, :lq_null!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/src/interface/qr.jl b/src/interface/qr.jl index c881b5a6..08dbcc19 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -75,12 +75,18 @@ end function default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} return Native_HouseholderQR(; kwargs...) end -function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} +function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} return LAPACK_HouseholderQR(; kwargs...) end function default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) end +function default_qr_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} + return default_qr_algorithm(A) +end +function default_qr_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A} + return default_qr_algorithm(A) +end for f in (:qr_full!, :qr_compact!, :qr_null!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 6ca88b0c..24611e65 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -161,12 +161,18 @@ default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs... function default_svd_algorithm(T::Type; kwargs...) throw(MethodError(default_svd_algorithm, (T,))) end -function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} +function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} return LAPACK_DivideAndConquer(; kwargs...) end function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) end +function default_svd_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} + return default_svd_algorithm(A) +end +function default_svd_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A} + return default_svd_algorithm(A) +end for f in (:svd_full!, :svd_compact!, :svd_vals!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/src/yalapack.jl b/src/yalapack.jl index 3d3613cc..fded1ed3 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -15,10 +15,13 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK, using LinearAlgebra.BLAS: @blasfunc, libblastrampoline using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror -# type alias for matrices that are definitely supported by YALAPACK +# type alias for vectors/matrices that are definitely supported by YALAPACK +const BlasVec{T <: BlasFloat} = StridedVector{T} const BlasMat{T <: BlasFloat} = StridedMatrix{T} -# type alias for matrices that are possibly supported by YALAPACK, after conversion +# type alias for vectors/matrices that are possibly supported by YALAPACK, after conversion +const MaybeBlasVec = Union{BlasVec, AbstractVector{<:Integer}} const MaybeBlasMat = Union{BlasMat, AbstractMatrix{<:Integer}} +const MaybeBlasVecOrMat = Union{MaybeBlasVec, MaybeBlasMat} # LU factorisation (currently unused in MatrixAlgebraKit) # for (getrf, getrs, elty) in (