Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Printf = "1"
Random = "1"
SafeTestsets = "0.1"
ScopedValues = "1.3.0"
Strided = "2"
Strided = "2.3.4"
TensorKitSectors = "0.3.5"
TensorOperations = "5.1"
Test = "1"
Expand Down
2 changes: 1 addition & 1 deletion ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using TensorKit.Factorizations
using TensorKit.Strided
using TensorKit.Factorizations: AbstractAlgorithm
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
import TensorKit: randisometry, rand, randn
import TensorKit: randisometry, rand, randn, _copyto!, _add_general_kernel_nonthreaded!, blocktype

using TensorKit: MatrixAlgebraKit

Expand Down
37 changes: 23 additions & 14 deletions ext/TensorKitCUDAExt/cutensormap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr
return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V)
end

# Block (matrix) type used for the symmetry blocks of a `CuTensorMap` with
# scalar type `T`: a dense CUDA matrix backed by device memory.
TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S} = CuMatrix{T, CUDA.DeviceMemory}

for (fname, felt) in ((:zeros, :zero), (:ones, :one))
@eval begin
function CUDA.$fname(
Expand Down Expand Up @@ -101,18 +105,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)]
end

function Base.convert(
TT::Type{CuTensorMap{T, S, N₁, N₂}},
t::AbstractTensorMap{<:Any, S, N₁, N₂}
) where {T, S, N₁, N₂}
if typeof(t) === TT
return t
else
tnew = TT(undef, space(t))
return copy!(tnew, t)
end
end

function LinearAlgebra.isposdef(t::CuTensorMap)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
Expand All @@ -138,10 +130,9 @@ function Base.promote_rule(
return CuTensorMap{T, S, N₁, N₂}
end

TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
CuArray{T, N, CUDA.default_memory}


# CuTensorMap exponentation:
function TensorKit.exp!(t::CuTensorMap)
domain(t) == codomain(t) ||
Expand All @@ -168,3 +159,21 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
return tf
end
end

# CUDA specialization of the generic (non-threaded) fusion-tree addition kernel.
# Follows the structure of the CPU implementation, but converts each
# subtransformer's first entry to a `CuArray` (via `CUDA.adapt`) before the
# multi-block transform, so that path operates on device-resident data.
# NOTE(review): assumes each element of `transformer.data` is a tuple whose
# first entry is the block/coefficient collection consumed by
# `_add_transform_single!`/`_add_transform_multi!` — confirm against
# TensorKit's `GenericTreeTransformer` definition.
function TensorKit._add_general_kernel_nonthreaded!(
        tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
    )
    # preallocate buffers (reused across all subtransformers below)
    buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer)

    for subtransformer in transformer.data
        # Special case without intermediate buffers whenever there is only a single block
        if length(subtransformer[1]) == 1
            TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
        else
            # Move the subtransformer's coefficient data to the GPU before the
            # buffered multi-block path.
            cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...)
            TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...)
        end
    end
    return nothing
end
8 changes: 5 additions & 3 deletions src/auxiliary/auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ function _interleave(a::NTuple{N}, b::NTuple{N}) where {N}
return (a[1], b[1], _interleave(tail(a), tail(b))...)
end

_copyto!(A, B) = copy!(A, B)

# Low-overhead implementation of `copyto!` for specific case of `stride(B, 1) < stride(B, 2)`
# used in indexmanipulations: avoids the overhead of Strided.jl
function _copyto!(A::StridedView{<:Any, 1}, B::StridedView{<:Any, 2})
length(A) == length(B) || throw(DimensionMismatch())
# for CPU-hosted Arrays # used in indexmanipulations: avoids the overhead of Strided.jl
function _copyto!(A::StridedView{TA, 1, AA}, B::StridedView{TB, 2, BB}) where {TA <: Number, TB <: Number, AA <: Memory{TA}, BB <: Memory{TB}}
length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B))"))

Adata = parent(A)
Astr = stride(A, 1)
Expand Down
13 changes: 8 additions & 5 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ storagetype(t) = storagetype(typeof(t))
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
if T isa Union
# attempt to be slightly more specific by promoting unions
Ma = storagetype(T.a)
Mb = storagetype(T.b)
return promote_storagetype(Ma, Mb)
return promote_storagetype(T.a, T.b)
elseif eltype(T) isa Union
# attempt to be slightly more specific by promoting unions
TU = eltype(T)
return promote_storagetype(TU.a, TU.b)
else
# fallback definition by using scalartype
return similarstoragetype(scalartype(T))
Expand Down Expand Up @@ -103,8 +105,9 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =

# implement on tensors
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
similarstoragetype(storagetype(TT), T)
function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number}
return similarstoragetype(storagetype(TT), T)
end

# implement on arrays
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
Expand Down
9 changes: 6 additions & 3 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,15 @@ end
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
function add_transform!(
tdst::AbstractTensorMap,
tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple,
tsrc::BraidingTensor{T, S},
(p₁, p₂)::Index2Tuple,
fusiontreetransform,
α::Number, β::Number, backend::AbstractBackend...
)
) where {T, S}
tsrc_map = similar(tdst, storagetype(tdst), space(tsrc))
copy!(tsrc_map, tsrc)
return add_transform!(
tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β,
backend...
)
end
Expand Down
8 changes: 4 additions & 4 deletions test/cuda/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ for V in spacelist
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t, first(blocksectors(t)))
@test b1 == b2
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
@test_broken typeof(b1) === TensorKit.blocktype(t)
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
@test typeof(b1) === TensorKit.blocktype(t)
@test typeof(c) === sectortype(t)
end
end
Expand Down Expand Up @@ -162,8 +162,8 @@ for V in spacelist
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t', first(blocksectors(t')))
@test b1 == b2
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
@test_broken typeof(b1) === TensorKit.blocktype(t')
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
@test typeof(b1) === TensorKit.blocktype(t')
@test typeof(c) === sectortype(t)
# linear algebra
@test isa(@constinferred(norm(t)), real(T))
Expand Down
Loading