From b4ae4d7a6fc775f2d73066727fd34833abc854d2 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 2 Oct 2025 13:25:12 -0400 Subject: [PATCH 01/45] Working BP Commit --- src/ITensorNetworksNext.jl | 3 +++ src/abstracttensornetwork.jl | 2 +- test/test_beliefpropagation.jl | 25 +++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 test/test_beliefpropagation.jl diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 19c4109..905d783 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -9,4 +9,7 @@ include("abstract_problem.jl") include("iterators.jl") include("adapters.jl") +include("beliefpropagation/abstractbeliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationcache.jl") + end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index e566752..1ecbffa 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -254,4 +254,4 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl new file mode 100644 index 0000000..4b179fb --- /dev/null +++ b/test/test_beliefpropagation.jl @@ -0,0 +1,25 @@ +using Dictionaries: Dictionary +using ITensorBase: Index +using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, + partitionfunction +using Graphs: edges, vertices +using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using Test: @test, @testset + +@testset "BeliefPropagation" begin + dims = (4, 1) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 10) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 +end \ No newline at end of file From d77d0632e6e88a13ab817d9d8a99a90442d37efe Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 23 Oct 2025 18:23:27 -0400 Subject: [PATCH 02/45] BP Code --- .../abstractbeliefpropagationcache.jl | 151 +++++++++++ .../beliefpropagationcache.jl | 237 ++++++++++++++++++ test/test_beliefpropagation.jl | 20 +- 3 files changed, 407 insertions(+), 1 deletion(-) create mode 100644 src/beliefpropagation/abstractbeliefpropagationcache.jl create mode 100644 src/beliefpropagation/beliefpropagationcache.jl diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl new file mode 100644 index 0000000..5eae283 --- /dev/null +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -0,0 +1,151 @@ +abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end + +#Interface +factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented() +setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented() +messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented() +function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message) + return not_implemented() +end +function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +function rescale_messages( + bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs... + ) + return not_implemented() +end +function rescale_vertices( + bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs... + ) + return not_implemented() +end + +function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) + return not_implemented() +end +function edge_scalar( + bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs... + ) + return not_implemented() +end + +#Graph functionality needed +Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function NamedGraphs.GraphsExtensions.boundary_edges( + bp_cache::AbstractBeliefPropagationCache, vertices; kwargs... + ) + return not_implemented() +end + +#Functions derived from the interface +function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages) + for (e, m) in zip(edges) + setmessage!(bp_cache, e, m) + end + return +end + +function deletemessages!( + bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache) + ) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache +end + +function vertex_scalars( + bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs... + ) + return map(v -> region_scalar(bp_cache, v; kwargs...), vertices) +end + +function edge_scalars( + bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs... + ) + return map(e -> region_scalar(bp_cache, e; kwargs...), edges) +end + +function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache) + return vertex_scalars(bp_cache), edge_scalars(bp_cache) +end + +function incoming_messages( + bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = [] + ) + b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in) + b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges + return messages(bp_cache, b_edges) +end + +function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) + return incoming_messages(bp_cache, [vertex]; kwargs...) +end + +#Adapt interface for changing device +function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache)) + bp_cache = copy(bp_cache) + for e in es + setmessage!(bp_cache, e, f(message(bp_cache, e))) + end + return bp_cache +end +function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache)) + bp_cache = copy(bp_cache) + for v in vs + setfactor!(bp_cache, v, f(factor(bp_cache, v))) + end + return bp_cache +end +function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...) + return map_messages(adapt(to), bp_cache, args...) +end +function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...) + return map_factors(adapt(to), bp_cache, args...) +end + +function freenergy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + if any(t -> real(t) < 0, numerator_terms) + numerator_terms = complex.(numerator_terms) + end + if any(t -> real(t) < 0, denominator_terms) + denominator_terms = complex.(denominator_terms) + end + + any(iszero, denominator_terms) && return -Inf + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) +end + +function partitionfunction(bp_cache::AbstractBeliefPropagationCache) + return exp(freenergy(bp_cache)) +end + +function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return rescale_messages(bp_cache, [edge]) +end + +function rescale_messages(bp_cache::AbstractBeliefPropagationCache) + return rescale_messages(bp_cache, edges(bp_cache)) +end + +function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...) + return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...) +end + +function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...) + return rescale_vertices(bpc, [vertex]; kwargs...) +end + +function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...) + bpc = rescale_messages(bpc) + bpc = rescale_partitions(bpc, args...; kwargs...) + return bpc +end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl new file mode 100644 index 0000000..295502a --- /dev/null +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -0,0 +1,237 @@ +using DiagonalArrays: delta +using Dictionaries: Dictionary, set!, delete! +using Graphs: AbstractGraph, is_tree, connected_components +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using ITensorBase: ITensor, dim +using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype + +struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: + AbstractBeliefPropagationCache{V} + network::N + messages::Dictionary +end + +messages(bp_cache::BeliefPropagationCache) = bp_cache.messages +network(bp_cache::BeliefPropagationCache) = bp_cache.network +default_messages() = Dictionary() + +BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) + +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +end + +function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge) + ms = messages(bp_cache) + delete!(ms, e) + return bp_cache +end + +function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) + ms = messages(bp_cache) + set!(ms, e, message) + return bp_cache +end + +function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) + ms = messages(bp_cache) + return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +end + +function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) + return [message(bp_cache, e) for e in edges] +end + +default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing +#Forward onto the network +for f in [ + :(Graphs.vertices), + :(Graphs.edges), + :(Graphs.is_tree), + :(NamedGraphs.GraphsExtensions.boundary_edges), + :(factors), + :(default_bp_maxiter), + :(ITensorNetworksNext.setfactor!), + :(ITensorNetworksNext.linkinds), + :(ITensorNetworksNext.underlying_graph), + ] + @eval begin + function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) + return $f(network(bp_cache), args...; kwargs...) + end + end +end + +#TODO: Get subgraph working on an ITensorNetwork to overload this directly +function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) + return forest_cover_edge_sequence(underlying_graph(bp_cache)) +end + +function factors(tn::AbstractTensorNetwork, vertex) + return [tn[vertex]] +end + +function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] +end + +function region_scalar(bp_cache::BeliefPropagationCache, vertex) + incoming_ms = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + return (reduce(*, incoming_ms) * reduce(*, state))[] +end + +function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return default_message(network(bp_cache), edge::AbstractEdge) +end + +function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) + t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... + return t +end + +#Algorithmic defaults +default_update_alg(bp_cache::BeliefPropagationCache) = "bp" +default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" +default_normalize(::Algorithm"contract") = true +default_sequence_alg(::Algorithm"contract") = "optimal" +function set_default_kwargs(alg::Algorithm"contract") + normalize = get(alg, :normalize, default_normalize(alg)) + sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) + return Algorithm("contract"; normalize, sequence_alg) +end +function set_default_kwargs(alg::Algorithm"adapt_update") + _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) + return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) +end +default_verbose(::Algorithm"bp") = false +default_tol(::Algorithm"bp") = nothing +function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) + verbose = get(alg, :verbose, default_verbose(alg)) + maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) + edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) + tol = get(alg, :tol, default_tol(alg)) + message_update_alg = set_default_kwargs( + get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) + ) + return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) +end + +#TODO: Update message etc should go here... +function updated_message( + alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + vertex = src(edge) + incoming_ms = incoming_messages( + bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] + ) + state = factors(bp_cache, vertex) + #contract_list = ITensor[incoming_ms; state] + #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) + #updated_messages = contract(contract_list; sequence) + updated_message = + !isempty(incoming_ms) ? reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end + end + return updated_message +end + +function updated_message( + bp_cache::BeliefPropagationCache, + edge::AbstractEdge; + alg = default_message_update_alg(bpc), + kwargs..., + ) + return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge) +end + +function update_message!( + message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge)) +end + +""" +Do a sequential update of the message tensors on `edges` +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edges::Vector; + (update_diff!) = nothing, + ) + bpc = copy(bpc) + for e in edges + prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing + update_message!(alg.message_update_alg, bpc, e) + if !isnothing(update_diff!) + update_diff![] += message_diff(message(bpc, e), prev_message) + end + end + return bpc +end + +""" +Do parallel updates between groups of edges of all message tensors +Currently we send the full message tensor data struct to update for each edge_group. But really we only need the +mts relevant to that group. +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edge_groups::Vector{<:Vector{<:AbstractEdge}}; + (update_diff!) = nothing, + ) + new_mts = empty(messages(bpc)) + for edges in edge_groups + bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!)) + for e in edges + set!(new_mts, e, message(bpc_t, e)) + end + end + return set_messages(bpc, new_mts) +end + +""" +More generic interface for update, with default params +""" +function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) + compute_error = !isnothing(alg.tol) + if isnothing(alg.maxiter) + error("You need to specify a number of iterations for BP!") + end + for i in 1:alg.maxiter + diff = compute_error ? Ref(0.0) : nothing + bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff) + if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol + if alg.verbose + println("BP converged to desired precision after $i iterations.") + end + break + end + end + return bpc +end + +function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...) + return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) +end + +#Edge sequence stuff +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + forests = forest_cover(g) + edges = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) + end + end + return edges +end \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 4b179fb..81ee722 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -3,11 +3,13 @@ using ITensorBase: Index using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, partitionfunction using Graphs: edges, vertices -using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using Test: @test, @testset @testset "BeliefPropagation" begin + + #Chain of tensors dims = (4, 1) g = named_grid(dims) l = Dict(e => Index(2) for e in edges(g)) @@ -17,6 +19,22 @@ using Test: @test, @testset return randn(Tuple(is)) end + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 1) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 + + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 10) z_bp = partitionfunction(bpc) From b80e36eaf6aac3a3702bd0403d7858603366b1e7 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Oct 2025 15:18:28 -0400 Subject: [PATCH 03/45] Express BP in terms of `SweepIterator` interface Introduce `BeliefPropagationProblem` wrapper to hold the cache and the error `diff` field. Also simplifies some kwargs wrangling. --- Project.toml | 2 + src/ITensorNetworksNext.jl | 1 + .../beliefpropagationcache.jl | 126 ++---------------- .../beliefpropagationproblem.jl | 85 ++++++++++++ 4 files changed, 101 insertions(+), 113 deletions(-) create mode 100644 src/beliefpropagation/beliefpropagationproblem.jl diff --git a/Project.toml b/Project.toml index 95b8be0..e0aea23 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" @@ -39,6 +40,7 @@ DerivableInterfaces = "0.5.5" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" +ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.8, 0.9" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 905d783..cca4b6d 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -11,5 +11,6 @@ include("adapters.jl") include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationproblem.jl") end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 295502a..cdae651 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,9 +1,7 @@ -using DiagonalArrays: delta using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges using ITensorBase: ITensor, dim -using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: AbstractBeliefPropagationCache{V} @@ -13,9 +11,8 @@ end messages(bp_cache::BeliefPropagationCache) = bp_cache.messages network(bp_cache::BeliefPropagationCache) = bp_cache.network -default_messages() = Dictionary() -BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) +BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) @@ -33,16 +30,15 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) return bp_cache end -function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) +function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) ms = messages(bp_cache) return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) end -function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) +function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) return [message(bp_cache, e) for e in edges] end -default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing #Forward onto the network for f in [ :(Graphs.vertices), @@ -62,11 +58,6 @@ for f in [ end end -#TODO: Get subgraph working on an ITensorNetwork to overload this directly -function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) - return forest_cover_edge_sequence(underlying_graph(bp_cache)) -end - function factors(tn::AbstractTensorNetwork, vertex) return [tn[vertex]] end @@ -91,33 +82,6 @@ function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) return t end -#Algorithmic defaults -default_update_alg(bp_cache::BeliefPropagationCache) = "bp" -default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" -default_normalize(::Algorithm"contract") = true -default_sequence_alg(::Algorithm"contract") = "optimal" -function set_default_kwargs(alg::Algorithm"contract") - normalize = get(alg, :normalize, default_normalize(alg)) - sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) - return Algorithm("contract"; normalize, sequence_alg) -end -function set_default_kwargs(alg::Algorithm"adapt_update") - _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) - return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) -end -default_verbose(::Algorithm"bp") = false -default_tol(::Algorithm"bp") = nothing -function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) - verbose = get(alg, :verbose, default_verbose(alg)) - maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) - edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) - tol = get(alg, :tol, default_tol(alg)) - message_update_alg = set_default_kwargs( - get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) - ) - return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) -end - #TODO: Update message etc should go here... function updated_message( alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge @@ -141,85 +105,21 @@ function updated_message( return updated_message end -function updated_message( - bp_cache::BeliefPropagationCache, - edge::AbstractEdge; - alg = default_message_update_alg(bpc), - kwargs..., +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal" ) - return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge) + return Algorithm("contract"; normalize, sequence_alg) end - -function update_message!( - message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") ) - return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge)) + return Algorithm("adapt_update"; adapt, alg) end -""" -Do a sequential update of the message tensors on `edges` -""" -function update_iteration( - alg::Algorithm"bp", - bpc::AbstractBeliefPropagationCache, - edges::Vector; - (update_diff!) = nothing, - ) - bpc = copy(bpc) - for e in edges - prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing - update_message!(alg.message_update_alg, bpc, e) - if !isnothing(update_diff!) - update_diff![] += message_diff(message(bpc, e), prev_message) - end - end - return bpc -end - -""" -Do parallel updates between groups of edges of all message tensors -Currently we send the full message tensor data struct to update for each edge_group. But really we only need the -mts relevant to that group. -""" -function update_iteration( - alg::Algorithm"bp", - bpc::AbstractBeliefPropagationCache, - edge_groups::Vector{<:Vector{<:AbstractEdge}}; - (update_diff!) = nothing, +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge ) - new_mts = empty(messages(bpc)) - for edges in edge_groups - bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!)) - for e in edges - set!(new_mts, e, message(bpc_t, e)) - end - end - return set_messages(bpc, new_mts) -end - -""" -More generic interface for update, with default params -""" -function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) - compute_error = !isnothing(alg.tol) - if isnothing(alg.maxiter) - error("You need to specify a number of iterations for BP!") - end - for i in 1:alg.maxiter - diff = compute_error ? Ref(0.0) : nothing - bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff) - if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol - if alg.verbose - println("BP converged to desired precision after $i iterations.") - end - break - end - end - return bpc -end - -function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...) - return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) end #Edge sequence stuff @@ -234,4 +134,4 @@ function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root end end return edges -end \ No newline at end of file +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl new file mode 100644 index 0000000..a497363 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -0,0 +1,85 @@ +mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: + AbstractProblem + const cache::Cache + diff::Union{Nothing, Float64} +end + +function default_algorithm( + ::Type{<:Algorithm"bp"}, + bpc::BeliefPropagationCache; + verbose = false, + tol = nothing, + edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + message_update_alg = default_algorithm(Algorithm"contract"), + maxiter = is_tree(bpc) ? 1 : nothing, + ) + return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) +end + +function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) + prob = iter.problem + + edge_group, kwargs = current_region_plan(iter) + + new_message_tensors = map(edge_group) do edge + old_message = message(prob.cache, edge) + + new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) + + if !isnothing(prob.diff) + # TODO: Define `message_diff` + prob.diff += message_diff(new_message, old_message) + end + + return new_message + end + + foreach(edge_group, new_message_tensors) do edge, new_message + setmessage!(prob.cache, edge, new_message) + end + + return iter +end + +function region_plan( + prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... + ) + edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + plan = map(edges) do e + return [e] => (; sweep_kwargs...) + end + + return plan +end + +function update(bpc::AbstractBeliefPropagationCache; kwargs...) + return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) +end +function update(alg::Algorithm"bp", bpc) + compute_error = !isnothing(alg.tol) + + diff = compute_error ? 0.0 : nothing + + prob = BeliefPropagationProblem(bpc, diff) + + iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) + + for _ in iter + if compute_error && prob.diff <= alg.tol + break + end + end + + if alg.verbose && compute_error + if prob.diff <= alg.tol + println("BP converged to desired precision after $(iter.which_sweep) iterations.") + else + println( + "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations", + ) + end + end + + return bpc +end From fe44b804f7461106caa3a8dbc6f0dad38ff67ede Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 31 Oct 2025 12:46:03 -0400 Subject: [PATCH 04/45] Add method for `setmessages!` that allows messages from one cache to be set from another cache --- src/beliefpropagation/beliefpropagationcache.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index cdae651..b3a32b1 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -30,6 +30,14 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) return bp_cache end +function setmessages!(bpc_dst::BeliefPropagationCache, bpc_src::BeliefPropagationCache, edges) + ms_dst = messages(bpc_dst) + for e in edges + set!(ms_dst, e, message(bpc_src, e)) + end + return bpc_dst +end + function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) ms = messages(bp_cache) return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) From 3ce08983b2a9feae9057dc10ca55491bddf08079 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 10 Nov 2025 14:03:59 -0500 Subject: [PATCH 05/45] Network is now passed to `forest_cover_edge_sequence` directly. --- src/beliefpropagation/beliefpropagationproblem.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a497363..967b454 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -9,7 +9,7 @@ function default_algorithm( bpc::BeliefPropagationCache; verbose = false, tol = nothing, - edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + edge_sequence = forest_cover_edge_sequence(network(bpc)), message_update_alg = default_algorithm(Algorithm"contract"), maxiter = is_tree(bpc) ? 1 : nothing, ) @@ -44,7 +44,8 @@ end function region_plan( prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... ) - edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + edges = forest_cover_edge_sequence(network(prob.cache); root_vertex) plan = map(edges) do e return [e] => (; sweep_kwargs...) From f6e4fd0ea748f4a3da272dc1011a855fdaee7a9e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 11:19:31 -0500 Subject: [PATCH 06/45] test file formatting --- test/test_beliefpropagation.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 81ee722..fc657e7 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,7 +1,17 @@ using Dictionaries: Dictionary using ITensorBase: Index -using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, - partitionfunction +using ITensorNetworksNext: + BeliefPropagationCache, + ITensorNetworksNext, + TensorNetwork, + adapt_messages, + default_message, + default_messages, + edge_scalars, + factors, + messages, + partitionfunction, + setmessages! using Graphs: edges, vertices using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges @@ -15,15 +25,15 @@ using Test: @test, @testset l = Dict(e => Index(2) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1e-14 + @test abs(z_bp - z_exact) <= 1.0e-14 #Tree of tensors dims = (4, 3) @@ -31,13 +41,14 @@ using Test: @test, @testset l = Dict(e => Index(3) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 10) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1e-14 -end \ No newline at end of file + @test abs(z_bp - z_exact) <= 1.0e-14 +end + From 63840a90df869893d87c1ce6a6c58e06bb13973c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 11:25:31 -0500 Subject: [PATCH 07/45] Add `DataGraphsPartitionedGraphsExt` glue for `TensorNetwork` type Also includes some fixes to the way `TensorNetwork` types are constructed based on index structure. --- src/tensornetwork.jl | 79 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 582eec6..11c2e88 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,10 +1,21 @@ using Combinatorics: combinations using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: AbstractDictionary, Indices, dictionary -using Graphs: AbstractSimpleGraph +using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, vertextype +using NamedGraphs.PartitionedGraphs: + AbstractPartitionedGraph, + PartitionedGraphs, + departition, + partitioned_vertices, + partitionedgraph, + quotient_graph, + quotient_graph_type +using .LazyNamedDimsArrays: lazy, Mul +using DataGraphs: vertex_data_eltype, vertex_data, edge_data +using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -24,8 +35,14 @@ function _TensorNetwork(graph::AbstractGraph, tensors) return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) end +function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} + return _TensorNetwork(graph, Tensors()) +end + DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) +DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}() +DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors)) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -70,7 +87,10 @@ function fix_links!(tn::AbstractTensorNetwork) for e in setdiff(arranged_edges(graph), tn_edges) insert_trivial_link!(tn, e) end - return tn + for edge in setdiff(arranged_edges(graph), arranged_edges(graph_structure)) + insert_trivial_link!(network, edge) + end + return network end # Determine the graph structure from the tensors. @@ -93,3 +113,56 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) + +Graphs.connected_components(tn::TensorNetwork) = Graphs.connected_components(underlying_graph(tn)) + +function Graphs.rem_edge!(tn::TensorNetwork, e) + if !has_edge(underlying_graph(tn), e) + return false + end + if !isempty(linkinds(tn, e)) + throw(ArgumentError("cannot remove edge $e due to tensor indices existing on this edge.")) + end + rem_edge!(underlying_graph(tn), e) + return true +end + +function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices) + DT = fieldtype(type, :tensors) + empty_dict = DT() + return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict) +end + +## PartitionedGraphs +function PartitionedGraphs.quotient_graph(tn::TensorNetwork) + ug = quotient_graph(underlying_graph(tn)) + return TensorNetwork(ug, vertex_data(QuotientView(tn))) +end +function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) + UG = quotient_graph_type(underlying_graph_type(type)) + VD = Vector{vertex_data_eltype(type)} + V = vertextype(UG) + return TensorNetwork{V, VD, UG, Dictionary{V, VD}} +end + +function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) + pg = partitionedgraph(underlying_graph(tn), parts) + return TensorNetwork(pg, vertex_data(tn)) +end + +PartitionedGraphs.departition(tn::TensorNetwork) = tn +function PartitionedGraphs.departition( + tn::TensorNetwork{<:Any, <:Any, <:AbstractPartitionedGraph} + ) + return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) +end + +function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data) + return mapreduce(lazy, *, collect(last(data))) +end + +function PartitionedGraphs.quotientview(tn::TensorNetwork) + qview = QuotientView(underlying_graph(tn)) + tensors = vertex_data(QuotientView(tn)) + return TensorNetwork(qview, tensors) +end From ba22ab5b107d2b681a5bd1d29395c0f390f23d56 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:27:20 -0500 Subject: [PATCH 08/45] Make abstract tensor network interface more generic. --- src/abstracttensornetwork.jl | 106 ++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 52 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 1ecbffa..b02c789 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -9,19 +9,23 @@ using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture using NamedDimsArrays: dimnames, inds using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree -using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rem_edges!, - rename_vertices, vertextype +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +using NamedGraphs.GraphsExtensions: + ⊔, + directed_graph, + incident_edges, + rem_edges!, + rename_vertices, + vertextype using SplitApplyCombine: flatten +using NamedGraphs.SimilarType: similar_type abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end -function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) - rem_edge!(underlying_graph(tn), e) - return tn -end +# Need to be careful about removing edges from tensor networks in case there is a bond +Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -# TODO: Define a generic fallback for `AbstractDataGraph`? -DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data") +DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -36,7 +40,7 @@ function Graphs.weights(graph::AbstractTensorNetwork) end # Copy -Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") +Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) @@ -49,20 +53,11 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# Derived interface, may need to be overloaded -function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) - return underlying_graph_type(data_graph_type(G)) -end - # AbstractDataGraphs overloads -function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end -function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end +DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented() +DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented() -DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") +DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) end @@ -81,40 +76,37 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) return map_vertex_data_preserve_graph(adapt(to), tn) end -function linkinds(tn::AbstractTensorNetwork, edge::Pair) - return linkinds(tn, edgetype(tn)(edge)) -end -function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) - return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) -end -function linkaxes(tn::AbstractTensorNetwork, edge::Pair) +linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge)) +linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) + +function linkaxes(tn::AbstractGraph, edge::Pair) return linkaxes(tn, edgetype(tn)(edge)) end -function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linkaxes(tn::AbstractGraph, edge::AbstractEdge) return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) end -function linknames(tn::AbstractTensorNetwork, edge::Pair) +function linknames(tn::AbstractGraph, edge::Pair) return linknames(tn, edgetype(tn)(edge)) end -function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linknames(tn::AbstractGraph, edge::AbstractEdge) return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) end -function siteinds(tn::AbstractTensorNetwork, v) +function siteinds(tn::AbstractGraph, v) s = inds(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, inds(tn[v′])) end return s end -function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function siteaxes(tn::AbstractGraph, edge::AbstractEdge) s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, axes(tn[v′])) end return s end -function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function sitenames(tn::AbstractGraph, edge::AbstractEdge) s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, dimnames(tn[v′])) @@ -122,8 +114,8 @@ function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) return s end -function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) - vertex_data(tn)[vertex] = value +function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) + set!(vertex_data(tn), vertex, value) return tn end @@ -153,7 +145,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should exist based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork) +function add_missing_edges!(tn::AbstractGraph) foreach(v -> add_missing_edges!(tn, v), vertices(tn)) return tn end @@ -161,7 +153,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should be incident to the vertex `v` # based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork, v) +function add_missing_edges!(tn::AbstractGraph, v) for v′ in vertices(tn) if v ≠ v′ e = v => v′ @@ -175,13 +167,13 @@ end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity. -function fix_edges!(tn::AbstractTensorNetwork) +function fix_edges!(tn::AbstractGraph) foreach(v -> fix_edges!(tn, v), vertices(tn)) return tn end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity at vertex `v`. -function fix_edges!(tn::AbstractTensorNetwork, v) +function fix_edges!(tn::AbstractGraph, v) rem_edges!(tn, incident_edges(tn, v)) add_missing_edges!(tn, v) return tn @@ -215,28 +207,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v) fix_edges!(tn, v) return tn end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) graph[vertices(graph)[vertex]] = value return graph end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) - return error("No edge data.") -end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) - return error("No edge data.") -end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented() +Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() # Fix ambiguity error. function Base.setindex!( tn::AbstractTensorNetwork, value, edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, ) - return error("No edge data.") + return not_implemented() end function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) @@ -254,4 +238,22 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) + +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int} + return tensornetwork_induced_subgraph(graph, subvertices) +end +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices) + return tensornetwork_induced_subgraph(graph, subvertices) +end + +function tensornetwork_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + subgraph = similar_type(graph)(underlying_subgraph) + for v in vertices(subgraph) + if isassigned(graph, v) + set!(vertex_data(subgraph), v, graph[v]) + end + end + return subgraph, vlist +end From 49b087015955f1865cc7b333e43f35b47e704751 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:27:50 -0500 Subject: [PATCH 09/45] BP Caching overhauls --- .../abstractbeliefpropagationcache.jl | 184 ++++++++---------- .../beliefpropagationcache.jl | 178 ++++++----------- .../beliefpropagationproblem.jl | 109 ++++++++--- 3 files changed, 226 insertions(+), 245 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 5eae283..8c6b3dd 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,117 +1,124 @@ -abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end +using Graphs: AbstractGraph, AbstractEdge +using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype +using NamedGraphs.GraphsExtensions: boundary_edges +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent -#Interface -factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented() -setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented() -messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented() -function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return not_implemented() -end -default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message) - return not_implemented() +messages(::AbstractGraph) = not_implemented() +messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] + +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] + +deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() +function deletemessage!(bp_cache::AbstractDataGraph, edge) + ms = messages(bp_cache) + delete!(ms, edge) + return bp_cache end -function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return not_implemented() + +function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache)) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache end -function rescale_messages( - bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs... - ) - return not_implemented() + +setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() +function setmessage!(bp_cache::AbstractDataGraph, edge, message) + ms = messages(bp_cache) + set!(ms, edge, message) + return bp_cache end -function rescale_vertices( - bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs... - ) - return not_implemented() +function setmessage!(bp_cache::QuotientView, edge, message) + setmessages!(parent(bp_cache), QuotientEdge(edge), message) + return bp_cache end -function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) - return not_implemented() +function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message) + for e in edges(bp_cache, edge) + setmessage!(parent(bp_cache), e, message[e]) + end + return bp_cache end -function edge_scalar( - bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs... - ) - return not_implemented() +function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) + for e in edges + setmessage!(bpc_dst, e, message(bpc_src, e)) + end + return bpc_dst end -#Graph functionality needed -Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -function NamedGraphs.GraphsExtensions.boundary_edges( - bp_cache::AbstractBeliefPropagationCache, vertices; kwargs... - ) - return not_implemented() +factors(bpc::AbstractGraph) = vertex_data(bpc) +factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] +factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) + +factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] + +setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() +function setfactor!(bpc::AbstractDataGraph, vertex, factor) + fs = factors(bpc) + set!(fs, vertex, factor) + return bpc end -#Functions derived from the interface -function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages) - for (e, m) in zip(edges) - setmessage!(bp_cache, e, m) - end - return +function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) + return message(bp_cache, edge) * message(bp_cache, reverse(edge)) end -function deletemessages!( - bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache) - ) - for e in edges - deletemessage!(bp_cache, e) - end - return bp_cache +function region_scalar(bp_cache::AbstractGraph, vertex) + + messages = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + + return reduce(*, messages) * reduce(*, state) end -function vertex_scalars( - bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs... - ) - return map(v -> region_scalar(bp_cache, v; kwargs...), vertices) +message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) +message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type) + +function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) + return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars( - bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs... - ) - return map(e -> region_scalar(bp_cache, e; kwargs...), edges) +function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache)) + return map(e -> region_scalar(bp_cache, e), edges) end -function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache) +function scalar_factors_quotient(bp_cache::AbstractGraph) return vertex_scalars(bp_cache), edge_scalars(bp_cache) end -function incoming_messages( - bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = [] - ) - b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in) +function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) + b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in) b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges return messages(bp_cache, b_edges) end -function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) - return incoming_messages(bp_cache, [vertex]; kwargs...) -end +default_messages(::AbstractGraph) = not_implemented() #Adapt interface for changing device -function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache)) - bp_cache = copy(bp_cache) +map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es) +function map_messages!(f, bp_cache, es = edges(bp_cache)) for e in es setmessage!(bp_cache, e, f(message(bp_cache, e))) end return bp_cache end -function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache)) - bp_cache = copy(bp_cache) + +map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs) +function map_factors!(f, bp_cache, vs = vertices(bp_cache)) for v in vs setfactor!(bp_cache, v, f(factor(bp_cache, v))) end return bp_cache end -function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...) - return map_messages(adapt(to), bp_cache, args...) -end -function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...) - return map_factors(adapt(to), bp_cache, args...) -end -function freenergy(bp_cache::AbstractBeliefPropagationCache) +adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) +adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) + +abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end + +function free_energy(bp_cache::AbstractBeliefPropagationCache) numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) if any(t -> real(t) < 0, numerator_terms) numerator_terms = complex.(numerator_terms) @@ -123,29 +130,4 @@ function freenergy(bp_cache::AbstractBeliefPropagationCache) any(iszero, denominator_terms) && return -Inf return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end - -function partitionfunction(bp_cache::AbstractBeliefPropagationCache) - return exp(freenergy(bp_cache)) -end - -function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return rescale_messages(bp_cache, [edge]) -end - -function rescale_messages(bp_cache::AbstractBeliefPropagationCache) - return rescale_messages(bp_cache, edges(bp_cache)) -end - -function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...) - return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...) -end - -function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...) - return rescale_vertices(bpc, [vertex]; kwargs...) -end - -function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...) - bpc = rescale_messages(bpc) - bpc = rescale_partitions(bpc, args...; kwargs...) - return bpc -end +partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index b3a32b1..4e441fb 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,145 +1,93 @@ +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using NamedGraphs: convert_vertextype +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph using ITensorBase: ITensor, dim +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph -struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: - AbstractBeliefPropagationCache{V} +struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: + AbstractBeliefPropagationCache{V, MT} network::N - messages::Dictionary + messages::Dictionary{ET, MT} end -messages(bp_cache::BeliefPropagationCache) = bp_cache.messages -network(bp_cache::BeliefPropagationCache) = bp_cache.network +network(bp_cache) = underlying_graph(bp_cache) -BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) - -function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :network) +DataGraphs.edge_data(bpc::BeliefPropagationCache) = getfield(bpc, :messages) +DataGraphs.vertex_data(bpc::BeliefPropagationCache) = vertex_data(network(bpc)) +function DataGraphs.underlying_graph_type(type::Type{<:BeliefPropagationCache}) + return fieldtype(type, :network) end -function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge) - ms = messages(bp_cache) - delete!(ms, e) - return bp_cache -end +message_type(::Type{<:BeliefPropagationCache{V, N, ET, MT}}) where {V, N, ET, MT} = MT -function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) - ms = messages(bp_cache) - set!(ms, e, message) - return bp_cache +function BeliefPropagationCache(alg, network::AbstractGraph) + es = collect(edges(network)) + es = vcat(es, reverse.(es)) + messages = map(edge -> default_message(alg, network, edge), es) + return BeliefPropagationCache(network, Dictionary(es, messages)) end -function setmessages!(bpc_dst::BeliefPropagationCache, bpc_src::BeliefPropagationCache, edges) - ms_dst = messages(bpc_dst) - for e in edges - set!(ms_dst, e, message(bpc_src, e)) - end - return bpc_dst +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) end -function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) - ms = messages(bp_cache) - return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +# TODO: This needs to go in DataGraphsGraphsExtensionsExt +# +# This function is problematic when `ng isa TensorNetwork` as it relies on deleting edges +# and taking subgraphs, which is not always well-defined for the `TensorNetwork` type, +# hence we just strip off any `AbstractDataGraph` data to avoid this. +function forest_cover_edge_sequence(g::AbstractDataGraph; kwargs...) + return forest_cover_edge_sequence(underlying_graph(g); kwargs...) end - -function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) - return [message(bp_cache, e) for e in edges] +# TODO: This needs to go in PartitionedGraphsGraphsExtensionsExt +# +# While it is not at all necessary to explictly instantiate the `QuotientView`, it allows the +# data of a data graph to be removed using the above method if `parent_type(g)` is an +# `AbstractDataGraph`. +function forest_cover_edge_sequence(g::QuotientView; kwargs...) + return forest_cover_edge_sequence(quotient_graph(parent(g)); kwargs...) end - -#Forward onto the network -for f in [ - :(Graphs.vertices), - :(Graphs.edges), - :(Graphs.is_tree), - :(NamedGraphs.GraphsExtensions.boundary_edges), - :(factors), - :(default_bp_maxiter), - :(ITensorNetworksNext.setfactor!), - :(ITensorNetworksNext.linkinds), - :(ITensorNetworksNext.underlying_graph), - ] - @eval begin - function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) - return $f(network(bp_cache), args...; kwargs...) +# TODO: This needs to go in GraphsExtensions +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + add_edges!(g, edges(g)) + forests = forest_cover(g) + rv = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) end end + return rv end -function factors(tn::AbstractTensorNetwork, vertex) - return [tn[vertex]] -end - -function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) - return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] -end - -function region_scalar(bp_cache::BeliefPropagationCache, vertex) - incoming_ms = incoming_messages(bp_cache, vertex) - state = factors(bp_cache, vertex) - return (reduce(*, incoming_ms) * reduce(*, state))[] -end - -function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) - return default_message(network(bp_cache), edge::AbstractEdge) -end - -function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) - t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) - #TODO: Get datatype working on tensornetworks so we can support GPU, etc... - return t -end - -#TODO: Update message etc should go here... -function updated_message( - alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge - ) - vertex = src(edge) - incoming_ms = incoming_messages( - bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] - ) - state = factors(bp_cache, vertex) - #contract_list = ITensor[incoming_ms; state] - #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) - #updated_messages = contract(contract_list; sequence) - updated_message = - !isempty(incoming_ms) ? reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) - if alg.normalize - message_norm = LinearAlgebra.norm(updated_message) - if !iszero(message_norm) - updated_message /= message_norm +function bpcache_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) + subgraph = BeliefPropagationCache(underlying_subgraph, typeof(edge_data(graph))()) + for e in edges(subgraph) + if isassigned(graph, e) + set!(edge_data(subgraph), e, graph[e]) end end - return updated_message + return subgraph, vlist end -function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal" - ) - return Algorithm("contract"; normalize, sequence_alg) +function Graphs.induced_subgraph(graph::BeliefPropagationCache, subvertices) + return bpcache_induced_subgraph(graph, subvertices) end -function default_algorithm( - ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") - ) - return Algorithm("adapt_update"; adapt, alg) +# For method ambiguity +function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::AbstractVector{V}) where {V <: Int} + return bpcache_induced_subgraph(graph, subvertices) end -function update_message!( - message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge - ) - return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) -end +## PartitionedGraphs -#Edge sequence stuff -function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) - forests = forest_cover(g) - edges = edgetype(g)[] - for forest in forests - trees = [forest[vs] for vs in connected_components(forest)] - for tree in trees - tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) - push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) - end - end - return edges +function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) + qview = QuotientView(network(bpc)) + messages = edge_data(QuotientView(bpc)) + return BeliefPropagationCache(qview, messages) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 967b454..a05c97a 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,70 +1,121 @@ -mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: - AbstractProblem +using Distributed: WorkerPool, @everywhere, remotecall, myid, waitall, workers +using Graphs: SimpleGraph, vertices, edges, has_edge +using NamedGraphs: AbstractNamedGraph, position_graph +using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices +using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices + +abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end + +mutable struct BeliefPropagationProblem{Alg, Cache} <: AbstractBeliefPropagationProblem{Alg} + const alg::Alg const cache::Cache diff::Union{Nothing, Float64} end +BeliefPropagationProblem(alg, cache) = BeliefPropagationProblem(alg, cache, nothing) + function default_algorithm( ::Type{<:Algorithm"bp"}, - bpc::BeliefPropagationCache; + bpc; verbose = false, tol = nothing, - edge_sequence = forest_cover_edge_sequence(network(bpc)), + edge_sequence = forest_cover_edge_sequence(bpc), message_update_alg = default_algorithm(Algorithm"contract"), maxiter = is_tree(bpc) ? 1 : nothing, ) return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) end -function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) - prob = iter.problem +function region_plan(prob::BeliefPropagationProblem{<:Algorithm"bp"}; sweep_kwargs...) + edges = prob.alg.edge_sequence - edge_group, kwargs = current_region_plan(iter) + plan = map(edges) do e + return e => (; sweep_kwargs...) + end - new_message_tensors = map(edge_group) do edge - old_message = message(prob.cache, edge) + return plan +end - new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) +function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp"}}) + prob = iter.problem - if !isnothing(prob.diff) - # TODO: Define `message_diff` - prob.diff += message_diff(new_message, old_message) - end + edge, _ = current_region_plan(iter) + new_message = updated_message(prob.alg.message_update_alg, prob.cache, edge) + setmessage!(prob.cache, edge, new_message) - return new_message - end + return iter +end - foreach(edge_group, new_message_tensors) do edge, new_message - setmessage!(prob.cache, edge, new_message) - end +default_message(alg, network, edge) = default_message(typeof(alg), network, edge) - return iter +default_message(::Type{<:Algorithm}, network, edge) = not_implemented() +function default_message(::Type{<:Algorithm"bp"}, network, edge) + + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... + links = linkinds(network, edge) + data = ones(dim.(links)...) + + t = ITensor(data, links) + return t end -function region_plan( - prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... +updated_message(alg, bpc, edge) = not_implemented() +function updated_message(alg::Algorithm"contract", bpc, edge) + vertex = src(edge) + + incoming_ms = incoming_messages( + bpc, vertex; ignore_edges = typeof(edge)[reverse(edge)] ) - edges = forest_cover_edge_sequence(network(prob.cache); root_vertex) + updated_message = contract_messages(alg.contraction_alg, factors(bpc, vertex), incoming_ms) - plan = map(edges) do e - return [e] => (; sweep_kwargs...) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end end + return updated_message +end - return plan +contract_messages(alg, factors, messages) = not_implemented() +function contract_messages( + alg, + factors::Vector{<:AbstractArray}, + messages::Vector{<:AbstractArray}, + ) + return contract_network(alg, vcat(factors, messages)) +end + +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = default_algorithm(Algorithm"exact") + ) + return Algorithm("contract"; normalize, contraction_alg) +end +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") + ) + return Algorithm("adapt_update"; adapt, alg) +end + +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) end function update(bpc::AbstractBeliefPropagationCache; kwargs...) return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) end -function update(alg::Algorithm"bp", bpc) + +function update(alg, bpc) compute_error = !isnothing(alg.tol) diff = compute_error ? 0.0 : nothing - prob = BeliefPropagationProblem(bpc, diff) + prob = BeliefPropagationProblem(alg, bpc, diff) - iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) + iter = SweepIterator(prob, alg.maxiter; compute_error) for _ in iter if compute_error && prob.diff <= alg.tol From db46c04214ed93c05a6bbcc7d88b06c2745f9c34 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:47:19 -0500 Subject: [PATCH 10/45] Remove dead deps --- src/beliefpropagation/beliefpropagationproblem.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a05c97a..f487ccc 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,4 +1,3 @@ -using Distributed: WorkerPool, @everywhere, remotecall, myid, waitall, workers using Graphs: SimpleGraph, vertices, edges, has_edge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices From 400e373b9fbb7205359bfe5914ba8d6e0763cd16 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 13:05:45 -0500 Subject: [PATCH 11/45] Fix merge --- src/beliefpropagation/beliefpropagationproblem.jl | 2 +- src/tensornetwork.jl | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index f487ccc..61c97df 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -87,7 +87,7 @@ function contract_messages( end function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = default_algorithm(Algorithm"exact") + ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = Algorithm("exact") ) return Algorithm("contract"; normalize, contraction_alg) end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 11c2e88..44b883a 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -4,7 +4,7 @@ using Dictionaries: AbstractDictionary, Indices, dictionary using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraphs, @@ -87,10 +87,7 @@ function fix_links!(tn::AbstractTensorNetwork) for e in setdiff(arranged_edges(graph), tn_edges) insert_trivial_link!(tn, e) end - for edge in setdiff(arranged_edges(graph), arranged_edges(graph_structure)) - insert_trivial_link!(network, edge) - end - return network + return tn end # Determine the graph structure from the tensors. From b9aafe890f235c0543d7b209a46fbb86ce9f3b70 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 13:12:01 -0500 Subject: [PATCH 12/45] Fix type inference in TensorNetwork construction --- src/tensornetwork.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 44b883a..0681da5 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -66,8 +66,7 @@ end tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) - tensors = Dictionary(vertices(graph), f.(vertices(graph))) - return TensorNetwork(graph, tensors) + return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end function TensorNetwork(graph::AbstractGraph, tensors) tn = _TensorNetwork(graph, tensors) From 4090e61f0069084ffd64ff53f65095ea3d05353c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:16:04 +0000 Subject: [PATCH 13/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/test_beliefpropagation.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index fc657e7..a39e1a6 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -51,4 +51,3 @@ using Test: @test, @testset z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test abs(z_bp - z_exact) <= 1.0e-14 end - From be0750ee8f0ea1323eb94de8c14eec4490ef1995 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 16:45:45 -0500 Subject: [PATCH 14/45] Remove `ITensorBase` dep --- Project.toml | 2 -- src/beliefpropagation/beliefpropagationcache.jl | 1 - src/beliefpropagation/beliefpropagationproblem.jl | 6 ++---- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index e0aea23..95b8be0 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" -ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" @@ -40,7 +39,6 @@ DerivableInterfaces = "0.5.5" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" -ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.8, 0.9" diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 4e441fb..5d8fa35 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -3,7 +3,6 @@ using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components using NamedGraphs: convert_vertextype using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph -using ITensorBase: ITensor, dim using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 61c97df..49d0ef8 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -52,10 +52,8 @@ function default_message(::Type{<:Algorithm"bp"}, network, edge) #TODO: Get datatype working on tensornetworks so we can support GPU, etc... links = linkinds(network, edge) - data = ones(dim.(links)...) - - t = ITensor(data, links) - return t + data = ones(Tuple(links)) + return data end updated_message(alg, bpc, edge) = not_implemented() From b971b89a91954d4175160c9788e2974267dc6fdc Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 1 Dec 2025 17:24:09 -0500 Subject: [PATCH 15/45] `forest_cover_edge_sequence` now constructs a temporary `NamedGraph` instead of trying to operate on existing graphs The reason for this is: - One only cares about the edges of the input graph - A simple graph cannot be used as it "forgets" its edge names resulting in recursion - As shown with `TensorNetwork`, removing edges may not always be defined. --- .../beliefpropagationcache.jl | 22 ++++--------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 5d8fa35..994f480 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -33,25 +33,11 @@ function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) end -# TODO: This needs to go in DataGraphsGraphsExtensionsExt -# -# This function is problematic when `ng isa TensorNetwork` as it relies on deleting edges -# and taking subgraphs, which is not always well-defined for the `TensorNetwork` type, -# hence we just strip off any `AbstractDataGraph` data to avoid this. -function forest_cover_edge_sequence(g::AbstractDataGraph; kwargs...) - return forest_cover_edge_sequence(underlying_graph(g); kwargs...) -end -# TODO: This needs to go in PartitionedGraphsGraphsExtensionsExt -# -# While it is not at all necessary to explictly instantiate the `QuotientView`, it allows the -# data of a data graph to be removed using the above method if `parent_type(g)` is an -# `AbstractDataGraph`. -function forest_cover_edge_sequence(g::QuotientView; kwargs...) - return forest_cover_edge_sequence(quotient_graph(parent(g)); kwargs...) -end # TODO: This needs to go in GraphsExtensions -function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) - add_edges!(g, edges(g)) +function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) + # All we care about are the edges so the type of the graph doesnt matter + g = NamedGraph(vertices(gi)) + add_edges!(g, edges(gi)) forests = forest_cover(g) rv = edgetype(g)[] for forest in forests From 9ebf0310c19fdf661cf6afd39c294710f167918b Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:42:36 -0500 Subject: [PATCH 16/45] [LazyNamedDimsArrays] Fix `parenttype` method --- src/LazyNamedDimsArrays/lazynameddimsarray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/LazyNamedDimsArrays/lazynameddimsarray.jl b/src/LazyNamedDimsArrays/lazynameddimsarray.jl index b0ed86a..c269902 100644 --- a/src/LazyNamedDimsArrays/lazynameddimsarray.jl +++ b/src/LazyNamedDimsArrays/lazynameddimsarray.jl @@ -7,7 +7,7 @@ using WrappedUnions: @wrapped union::Union{A, Mul{LazyNamedDimsArray{T, A}}} end -parenttype(::Type{LazyNamedDimsArray{<:Any, A}}) where {A} = A +parenttype(::Type{LazyNamedDimsArray{T, A}}) where {T, A} = A parenttype(::Type{LazyNamedDimsArray{T}}) where {T} = AbstractNamedDimsArray{T} parenttype(::Type{LazyNamedDimsArray}) = AbstractNamedDimsArray From 16fe303b73ab7f9ab3f5a1c46118319063a7af4a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:46:08 -0500 Subject: [PATCH 17/45] BP Cache now uses new `DataGraphs`interface --- .../abstractbeliefpropagationcache.jl | 13 +-- .../beliefpropagationcache.jl | 101 +++++++++++++----- 2 files changed, 82 insertions(+), 32 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 8c6b3dd..0cae3fa 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -3,11 +3,13 @@ using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent -messages(::AbstractGraph) = not_implemented() -messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] -message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] +function message(bp_cache::AbstractGraph, edge::AbstractEdge) + ms = messages(bp_cache) + return get!(ms, edge, default_message(bp_cache, edge)) +end deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() function deletemessage!(bp_cache::AbstractDataGraph, edge) @@ -25,8 +27,7 @@ end setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() function setmessage!(bp_cache::AbstractDataGraph, edge, message) - ms = messages(bp_cache) - set!(ms, edge, message) + setindex!(bp_cache, message, edge) return bp_cache end function setmessage!(bp_cache::QuotientView, edge, message) @@ -56,7 +57,7 @@ factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() function setfactor!(bpc::AbstractDataGraph, vertex, factor) fs = factors(bpc) - set!(fs, vertex, factor) + setindex!(fs, vertex, factor) return bpc end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 994f480..c9793e6 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,32 +1,85 @@ -using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph +using DataGraphs: + DataGraphs, + AbstractDataGraph, + DataGraph, + has_edge_data, + get_vertex_data, + get_edge_data, + set_vertex_data!, + set_edge_data!, + unset_vertex_data!, + unset_edge_data!, + vertex_data_eltype, + edge_data_eltype, + underlying_graph, + underlying_graph_type using Dictionaries: Dictionary, set!, delete! -using Graphs: AbstractGraph, is_tree, connected_components -using NamedGraphs: convert_vertextype -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph -using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph +using Graphs: AbstractGraph, is_tree, connected_components, is_directed +using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices +using NamedGraphs.GraphsExtensions: default_root_vertex, + forest_cover, + post_order_dfs_edges, + vertextype, + is_path_graph, + undirected_graph +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges -struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: +struct BeliefPropagationCache{V, G <: AbstractGraph{V}, N <: AbstractGraph{V}, ET, MT} <: AbstractBeliefPropagationCache{V, MT} + underlying_graph::G # we only use this for the edges. network::N messages::Dictionary{ET, MT} + function BeliefPropagationCache(network::AbstractGraph, messages::Dictionary) + + V = vertextype(network) + N = typeof(network) + ET = keytype(messages) + MT = eltype(messages) + + # Construct a directed graph version of the underlying graph of the tensor network. + digraph = directed_graph(underlying_graph(network)) + + bpc = new{V, typeof(digraph), N, ET, MT}(digraph, network, messages) + + for edge in edges(bpc) + get!(() -> default_message(bpc, edge), messages, edge) + end + return bpc + end end -network(bp_cache) = underlying_graph(bp_cache) +network(bp_cache) = getfield(bp_cache, :network) + +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :underlying_graph) + +DataGraphs.has_vertex_data(bpc::BeliefPropagationCache, vertex) = has_vertex_data(network(bpc), vertex) +DataGraphs.has_edge_data(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) -DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :network) -DataGraphs.edge_data(bpc::BeliefPropagationCache) = getfield(bpc, :messages) -DataGraphs.vertex_data(bpc::BeliefPropagationCache) = vertex_data(network(bpc)) -function DataGraphs.underlying_graph_type(type::Type{<:BeliefPropagationCache}) - return fieldtype(type, :network) +DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = get_vertex_data(network(bpc), vertex) +DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] + +DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set_vertex_data!(network(bpc), val, vertex) +DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) + +DataGraphs.unset_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = unset_vertex_data!(network(bpc), val, vertex) +DataGraphs.unset_edge_data!(bpc::BeliefPropagationCache, val, edge) = unset!(bpc.messages, edge, val) + +function DataGraphs.vertex_data_eltype(T::Type{<:BeliefPropagationCache}) + return vertex_data_eltype(fieldtype(T, :network)) +end +function DataGraphs.edge_data_eltype(T::Type{<:BeliefPropagationCache}) + return eltype(fieldtype(T, :messages)) end -message_type(::Type{<:BeliefPropagationCache{V, N, ET, MT}}) where {V, N, ET, MT} = MT +message_type(T::Type{<:BeliefPropagationCache}) = edge_data_eltype(T) -function BeliefPropagationCache(alg, network::AbstractGraph) - es = collect(edges(network)) - es = vcat(es, reverse.(es)) - messages = map(edge -> default_message(alg, network, edge), es) - return BeliefPropagationCache(network, Dictionary(es, messages)) +function BeliefPropagationCache(network::AbstractGraph) + MT = vertex_data_eltype(typeof(network)) + return BeliefPropagationCache(MT, network) +end +function BeliefPropagationCache(MT::Type, network::AbstractGraph) + dict = Dictionary{edgetype(network), MT}() + return BeliefPropagationCache(network, dict) end function Base.copy(bp_cache::BeliefPropagationCache) @@ -61,18 +114,14 @@ function bpcache_induced_subgraph(graph, subvertices) return subgraph, vlist end -function Graphs.induced_subgraph(graph::BeliefPropagationCache, subvertices) - return bpcache_induced_subgraph(graph, subvertices) -end -# For method ambiguity -function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::AbstractVector{V}) where {V <: Int} +function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::Vector{V}) where {V} return bpcache_induced_subgraph(graph, subvertices) end ## PartitionedGraphs function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - qview = QuotientView(network(bpc)) - messages = edge_data(QuotientView(bpc)) - return BeliefPropagationCache(qview, messages) + inds = Indices(parent_graph_indices(QuotientEdges(underlying_graph(bpc)))) + data = map(e -> bpc[QuotientEdge(e)], inds) + return BeliefPropagationCache(QuotientView(network(bpc)), data) end From 24a4335f61699a2d818f8b75a8b2867f7a16b3b5 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:46:49 -0500 Subject: [PATCH 18/45] Adjust `default_message` to take a `message` type as its first argument --- .../beliefpropagationproblem.jl | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 49d0ef8..24b024d 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -2,6 +2,9 @@ using Graphs: SimpleGraph, vertices, edges, has_edge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices +using NamedDimsArrays: AbstractNamedDimsArray +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, parenttype, lazy + abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end @@ -45,15 +48,16 @@ function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp return iter end -default_message(alg, network, edge) = default_message(typeof(alg), network, edge) - -default_message(::Type{<:Algorithm}, network, edge) = not_implemented() -function default_message(::Type{<:Algorithm"bp"}, network, edge) - - #TODO: Get datatype working on tensornetworks so we can support GPU, etc... - links = linkinds(network, edge) - data = ones(Tuple(links)) - return data +function default_message(bpc::BeliefPropagationCache, edge) + return default_message(message_type(bpc), network(bpc), edge) +end +function default_message(T::Type, network, edge) + array = ones(Tuple(linkinds(network, edge))) + return convert(T, array) +end +function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) + message = default_message(parenttype(T), network, edge) + return convert(T, lazy(message)) end updated_message(alg, bpc, edge) = not_implemented() From c43884ecb5185386ab5acc6c08f4344c0d566811 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:47:44 -0500 Subject: [PATCH 19/45] Remove unnecessary code and fix ambiguities in `AbstractTensorNetwork` --- src/abstracttensornetwork.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index b02c789..b820867 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -53,10 +53,6 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# AbstractDataGraphs overloads -DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented() -DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented() - DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) @@ -240,10 +236,7 @@ end Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) -function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int} - return tensornetwork_induced_subgraph(graph, subvertices) -end -function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices) +function Graphs.induced_subgraph(graph::AbstractTensorNetwork{V}, subvertices::Vector{V}) where {V} return tensornetwork_induced_subgraph(graph, subvertices) end From dd6f6454f01380e03e609cd60b1d4bfdf5499718 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:48:10 -0500 Subject: [PATCH 20/45] `TensorNetwork` type now uses new DataGraphs interface --- src/tensornetwork.jl | 50 +++++++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 0681da5..16c80e3 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,9 +1,9 @@ using Combinatorics: combinations using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph -using Dictionaries: AbstractDictionary, Indices, dictionary +using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset! using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames -using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype +using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype, Vertices, parent_graph_indices using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, @@ -12,9 +12,13 @@ using NamedGraphs.PartitionedGraphs: partitioned_vertices, partitionedgraph, quotient_graph, - quotient_graph_type + quotient_graph_type, + QuotientVertex, + QuotientVertices, + QuotientVertexVertices, + quotientvertices using .LazyNamedDimsArrays: lazy, Mul -using DataGraphs: vertex_data_eltype, vertex_data, edge_data +using DataGraphs: vertex_data_eltype, vertex_data, edge_data, get_vertices_data using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -31,7 +35,7 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar end end # This assumes the tensor connectivity matches the graph structure. -function _TensorNetwork(graph::AbstractGraph, tensors) +function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) end @@ -39,10 +43,18 @@ function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: Abstra return _TensorNetwork(graph, Tensors()) end -DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) -DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) -DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}() -DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors)) +# DataGraphs interface + +DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph + +DataGraphs.has_vertex_data(tn::TensorNetwork, v) = haskey(tn.tensors, v) +DataGraphs.has_edge_data(tn::TensorNetwork, e) = false + +DataGraphs.get_vertex_data(tn::TensorNetwork, v) = tn.tensors[v] + +DataGraphs.set_vertex_data!(tn::TensorNetwork, val, v) = set!(tn.tensors, v, val) +DataGraphs.unset_vertex_data!(tn::TensorNetwork, val, v) = unset!(tn.tensors, v, val) + function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -123,17 +135,23 @@ function Graphs.rem_edge!(tn::TensorNetwork, e) return true end -function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices) +function GraphsExtensions.similar(type::Type{<:TensorNetwork}) DT = fieldtype(type, :tensors) empty_dict = DT() - return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict) + return TensorNetwork(similar_graph(underlying_graph_type(type)), empty_dict) end ## PartitionedGraphs function PartitionedGraphs.quotient_graph(tn::TensorNetwork) ug = quotient_graph(underlying_graph(tn)) - return TensorNetwork(ug, vertex_data(QuotientView(tn))) + + inds = Indices(parent_graph_indices(QuotientVertices(tn))) + data = map(v -> tn[QuotientVertex(v)], inds) + + return TensorNetwork(ug, data) end +# TODO: This method should not be required with a better interface with a better +# DataGraphsPartitionedGraphsExt interface. function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) UG = quotient_graph_type(underlying_graph_type(type)) VD = Vector{vertex_data_eltype(type)} @@ -141,9 +159,10 @@ function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) return TensorNetwork{V, VD, UG, Dictionary{V, VD}} end +# Partition the underlying graph of the tensor network; does not affect the data. function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) pg = partitionedgraph(underlying_graph(tn), parts) - return TensorNetwork(pg, vertex_data(tn)) + return TensorNetwork(pg, copy(vertex_data(tn))) end PartitionedGraphs.departition(tn::TensorNetwork) = tn @@ -153,8 +172,9 @@ function PartitionedGraphs.departition( return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) end -function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data) - return mapreduce(lazy, *, collect(last(data))) +function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices) + data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex))) + return mapreduce(lazy, *, data) end function PartitionedGraphs.quotientview(tn::TensorNetwork) From 7bb579c7037c93e591a09a0c88e3aa489ef39c5d Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Fri, 19 Dec 2025 16:37:59 -0500 Subject: [PATCH 21/45] Sweeping algorithms based on AlgorithmsInterface.jl (#30) --- Project.toml | 4 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- .../AlgorithmsInterfaceExtensions.jl | 306 ++++++++++++ src/ITensorNetworksNext.jl | 6 +- src/abstract_problem.jl | 1 - src/adapters.jl | 45 -- src/iterators.jl | 170 ------- src/sweeping/eigenproblem.jl | 44 ++ src/sweeping/utils.jl | 12 + test/Project.toml | 3 +- test/test_algorithmsinterfaceextensions.jl | 472 ++++++++++++++++++ test/test_aqua.jl | 2 +- test/test_dmrg.jl | 34 ++ test/test_iterators.jl | 221 -------- test/test_sweeping.jl | 65 +++ 16 files changed, 944 insertions(+), 445 deletions(-) create mode 100644 src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl delete mode 100644 src/abstract_problem.jl delete mode 100644 src/adapters.jl delete mode 100644 src/iterators.jl create mode 100644 src/sweeping/eigenproblem.jl create mode 100644 src/sweeping/utils.jl create mode 100644 test/test_algorithmsinterfaceextensions.jl create mode 100644 test/test_dmrg.jl delete mode 100644 test/test_iterators.jl create mode 100644 test/test_sweeping.jl diff --git a/Project.toml b/Project.toml index 95b8be0..e6919fc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.2.4" +version = "0.3.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" @@ -32,6 +33,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations" [compat] AbstractTrees = "0.4.5" Adapt = "4.3" +AlgorithmsInterface = "0.1.0" BackendSelection = "0.1.6" Combinatorics = "1" DataGraphs = "0.2.7" diff --git a/docs/Project.toml b/docs/Project.toml index 15d156a..9e273b0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,5 +8,5 @@ ITensorNetworksNext = {path = ".."} [compat] Documenter = "1" -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" Literate = "2" diff --git a/examples/Project.toml b/examples/Project.toml index a9cd21b..bd688e9 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" ITensorNetworksNext = {path = ".."} [compat] -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl new file mode 100644 index 0000000..a8c814e --- /dev/null +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -0,0 +1,306 @@ +module AlgorithmsInterfaceExtensions + +import AlgorithmsInterface as AI + +#========================== Patches for AlgorithmsInterface.jl ============================# + +abstract type Problem <: AI.Problem end +abstract type Algorithm <: AI.Algorithm end +abstract type State <: AI.State end + +function AI.initialize_state!( + problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AI.initialize_state( + problem::Problem, algorithm::Algorithm; kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + return DefaultState(; stopping_criterion_state, kwargs...) +end + +#============================ DefaultState ================================================# + +@kwdef mutable struct DefaultState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +#============================ increment! ==================================================# + +# Custom version of `increment!` that also takes the problem and algorithm as arguments. +function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) + return AI.increment!(state) +end + +#============================ solve! ======================================================# + +# Custom version of `solve!` that allows specifying the logger and also overloads +# `increment!` on the problem and algorithm. +function basetypenameof(x) + return Symbol(last(split(String(Symbol(Base.typename(typeof(x)).wrapper)), "."))) +end +default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_) +function default_logging_context_prefix(problem::Problem, algorithm::Algorithm) + return Symbol( + default_logging_context_prefix(problem), + default_logging_context_prefix(algorithm), + ) +end +function AI.solve!( + problem::Problem, algorithm::Algorithm, state::State; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + kwargs..., + ) + logger = AI.algorithm_logger() + + context_suffixes = [:Start, :PreStep, :PostStep, :Stop] + contexts = Dict(context_suffixes .=> Symbol.(logging_context_prefix, context_suffixes)) + + # initialize the state and emit message + AI.initialize_state!(problem, algorithm, state; kwargs...) + AI.emit_message(logger, problem, algorithm, state, contexts[:Start]) + + # main body of the algorithm + while !AI.is_finished!(problem, algorithm, state) + AI.increment!(problem, algorithm, state) + + # logging event between convergence check and algorithm step + AI.emit_message(logger, problem, algorithm, state, contexts[:PreStep]) + + # algorithm step + AI.step!(problem, algorithm, state; logging_context_prefix) + + # logging event between algorithm step and convergence check + AI.emit_message(logger, problem, algorithm, state, contexts[:PostStep]) + end + + # emit message about finished state + AI.emit_message(logger, problem, algorithm, state, contexts[:Stop]) + return state +end + +function AI.solve( + problem::Problem, algorithm::Algorithm; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + kwargs..., + ) + state = AI.initialize_state(problem, algorithm; kwargs...) + return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...) +end + +#============================ AlgorithmIterator ===========================================# + +abstract type AlgorithmIterator end + +function algorithm_iterator( + problem::Problem, algorithm::Algorithm, state::State + ) + return DefaultAlgorithmIterator(problem, algorithm, state) +end + +function AI.is_finished!(iterator::AlgorithmIterator) + return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.is_finished(iterator::AlgorithmIterator) + return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.increment!(iterator::AlgorithmIterator) + return AI.increment!(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.step!(iterator::AlgorithmIterator) + return AI.step!(iterator.problem, iterator.algorithm, iterator.state) +end +function Base.iterate(iterator::AlgorithmIterator, init = nothing) + AI.is_finished!(iterator) && return nothing + AI.increment!(iterator) + AI.step!(iterator) + return iterator.state, nothing +end + +struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator + problem::Problem + algorithm::Algorithm + state::State +end + +#============================ with_algorithmlogger ========================================# + +# Allow passing functions, not just CallbackActions. +@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) + return AI.with_algorithmlogger(f, args...) +end +@inline function with_algorithmlogger(f, args::Pair{Symbol}...) + return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) +end + +#============================ NestedAlgorithm =============================================# + +abstract type NestedAlgorithm <: Algorithm end + +function nested_algorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultNestedAlgorithm(f, nalgorithms; kwargs...) +end + +max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) + +function get_subproblem( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end + +function set_substate!( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State + ) + state.iterate = substate.iterate + return state +end + +function AI.step!( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State; + logging_context_prefix = Symbol() + ) + # Get the subproblem, subalgorithm, and substate. + subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state) + + # Solve the subproblem with the subalgorithm. + logging_context_prefix = Symbol( + logging_context_prefix, default_logging_context_prefix(subalgorithm) + ) + AI.solve!(subproblem, subalgorithm, substate; logging_context_prefix) + + # Update the state with the substate. + set_substate!(problem, algorithm, state, substate) + + return state +end + +#= + DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm}) + +An algorithm that consists of running an algorithm at each iteration +from a list of stored algorithms. +=# +@kwdef struct DefaultNestedAlgorithm{ + ChildAlgorithm <: Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end +function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +end + +#============================ FlattenedAlgorithm ==========================================# + +# Flatten a nested algorithm. +abstract type FlattenedAlgorithm <: Algorithm end +abstract type FlattenedAlgorithmState <: State end + +function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...) +end + +function AI.initialize_state( + problem::Problem, algorithm::FlattenedAlgorithm; kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...) +end +function AI.increment!( + problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState + ) + # Increment the total iteration count. + state.iteration += 1 + # TODO: Use `is_finished!` instead? + if state.child_iteration ≥ max_iterations(algorithm.algorithms[state.parent_iteration]) + # We're on the last iteration of the child algorithm, so move to the next + # child algorithm. + state.parent_iteration += 1 + state.child_iteration = 1 + else + # Iterate the child algorithm. + state.child_iteration += 1 + end + return state +end +function AI.step!( + problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState; + logging_context_prefix = Symbol() + ) + algorithm_sweep = algorithm.algorithms[state.parent_iteration] + state_sweep = AI.initialize_state( + problem, algorithm_sweep; + state.iterate, iteration = state.child_iteration + ) + AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix) + state.iterate = state_sweep.iterate + return state +end + +@kwdef struct DefaultFlattenedAlgorithm{ + ChildAlgorithm <: Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: FlattenedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = + AI.StopAfterIteration(sum(max_iterations, algorithms)) +end +function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +end + +@kwdef mutable struct DefaultFlattenedAlgorithmState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: FlattenedAlgorithmState + iterate::Iterate + iteration::Int = 0 + parent_iteration::Int = 1 + child_iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +#============================ NonIterativeAlgorithm =======================================# + +# Algorithm that only performs a single step. +abstract type NonIterativeAlgorithm <: Algorithm end +abstract type NonIterativeAlgorithmState <: State end + +function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...) + return DefaultNonIterativeAlgorithmState(; kwargs...) +end +function AI.solve!( + problem::Problem, algorithm::NonIterativeAlgorithm, state::State; kwargs... + ) + return throw(MethodError(AI.solve!, (problem, algorithm, state))) +end + +@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: + NonIterativeAlgorithmState + iterate::Iterate +end + +end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index cca4b6d..d3c5c21 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,13 +1,13 @@ module ITensorNetworksNext +include("AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl") include("LazyNamedDimsArrays/LazyNamedDimsArrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") -include("abstract_problem.jl") -include("iterators.jl") -include("adapters.jl") +include("sweeping/utils.jl") +include("sweeping/eigenproblem.jl") include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") diff --git a/src/abstract_problem.jl b/src/abstract_problem.jl deleted file mode 100644 index 5a65e0a..0000000 --- a/src/abstract_problem.jl +++ /dev/null @@ -1 +0,0 @@ -abstract type AbstractProblem end diff --git a/src/adapters.jl b/src/adapters.jl deleted file mode 100644 index 28318fb..0000000 --- a/src/adapters.jl +++ /dev/null @@ -1,45 +0,0 @@ -""" - struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator - -Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the -process. This allows one to manually call a custom `compute!` or insert their own code it in -the loop body in place of `compute!`. -""" -struct IncrementOnly{S <: AbstractNetworkIterator} <: AbstractNetworkIterator - parent::S -end - -islaststep(adapter::IncrementOnly) = islaststep(adapter.parent) -increment!(adapter::IncrementOnly) = increment!(adapter.parent) -compute!(adapter::IncrementOnly) = adapter - -IncrementOnly(adapter::IncrementOnly) = adapter - -""" - struct EachRegion{SweepIterator} <: AbstractNetworkIterator - -Adapter that flattens each region iterator in the parent sweep iterator into a single -iterator. -""" -struct EachRegion{SI <: SweepIterator} <: AbstractNetworkIterator - parent::SI -end - -# In keeping with Julia convention. -eachregion(iter::SweepIterator) = EachRegion(iter) - -# Essential definitions -function islaststep(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - return islaststep(adapter.parent) && islaststep(region_iter) -end -function increment!(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter) - return adapter -end -function compute!(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - compute!(region_iter) - return adapter -end diff --git a/src/iterators.jl b/src/iterators.jl deleted file mode 100644 index 62d5b21..0000000 --- a/src/iterators.jl +++ /dev/null @@ -1,170 +0,0 @@ -""" - abstract type AbstractNetworkIterator - -A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins -with a call to `increment!` before executing `compute!`, however the initial call to -`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that -this call is implict. Termination of the iterator is controlled by the function `done`. -""" -abstract type AbstractNetworkIterator end - -# We use greater than or equals here as we increment the state at the start of the iteration -islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) - -function Base.iterate(iterator::AbstractNetworkIterator, init = true) - # The assumption is that first "increment!" is implicit, therefore we must skip the - # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not - # defined when length < 1, - init || islaststep(iterator) && return nothing - # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* - # define a method for increment! This way we avoid cases where one may wish to nest - # calls to different step! methods accidentaly incrementing multiple times. - init || increment!(iterator) - rv = compute!(iterator) - return rv, false -end - -increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)})) -compute!(iterator::AbstractNetworkIterator) = iterator - -step!(iterator::AbstractNetworkIterator) = step!(identity, iterator) -function step!(f, iterator::AbstractNetworkIterator) - compute!(iterator) - f(iterator) - increment!(iterator) - return iterator -end - -# -# RegionIterator -# -""" - struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator -""" -mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator - problem::Problem - region_plan::RegionPlan - which_region::Int - const which_sweep::Int - function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} - if isempty(region_plan) - throw(ArgumentError("Cannot construct a region iterator with 0 elements.")) - end - return new{P, R}(problem, region_plan, 1, sweep) - end -end - -function RegionIterator(problem; sweep, sweep_kwargs...) - plan = region_plan(problem; sweep_kwargs...) - return RegionIterator(problem, plan, sweep) -end - -state(region_iter::RegionIterator) = region_iter.which_region -Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) - -problem(region_iter::RegionIterator) = region_iter.problem - -function current_region_plan(region_iter::RegionIterator) - return region_iter.region_plan[region_iter.which_region] -end - -function current_region(region_iter::RegionIterator) - region, _ = current_region_plan(region_iter) - return region -end - -function region_kwargs(region_iter::RegionIterator) - _, kwargs = current_region_plan(region_iter) - return kwargs -end -function region_kwargs(f::Function, iter::RegionIterator) - return get(region_kwargs(iter), Symbol(f, :_kwargs), (;)) -end - -function prev_region(region_iter::RegionIterator) - state(region_iter) <= 1 && return nothing - prev, _ = region_iter.region_plan[region_iter.which_region - 1] - return prev -end - -function next_region(region_iter::RegionIterator) - islaststep(region_iter) && return nothing - next, _ = region_iter.region_plan[region_iter.which_region + 1] - return next -end - -# -# Functions associated with RegionIterator -# -function increment!(region_iter::RegionIterator) - region_iter.which_region += 1 - return region_iter -end - -function compute!(iter::RegionIterator) - extract!(iter; region_kwargs(extract!, iter)...) - update!(iter; region_kwargs(update!, iter)...) - insert!(iter; region_kwargs(insert!, iter)...) - - return iter -end - -region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) - -# -# SweepIterator -# - -mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator - region_iter::RegionIterator{Problem} - sweep_kwargs::Iterators.Stateful{Iter} - which_sweep::Int - function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} - stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) - - first_state = Iterators.peel(stateful_sweep_kwargs) - - if isnothing(first_state) - throw(ArgumentError("Cannot construct a sweep iterator with 0 elements.")) - end - - first_kwargs, _ = first_state - region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) - - return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) - end -end - -islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) - -region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter -problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) - -state(sweep_iter::SweepIterator) = sweep_iter.which_sweep -Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) -function increment!(sweep_iter::SweepIterator) - sweep_iter.which_sweep += 1 - sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) - update_region_iterator!(sweep_iter; sweep_kwargs...) - return sweep_iter -end - -function update_region_iterator!(iterator::SweepIterator; kwargs...) - sweep = state(iterator) - iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...) - return iterator -end - -function compute!(sweep_iter::SweepIterator) - for _ in sweep_iter.region_iter - # TODO: Is it sensible to execute the default region callback function? - end - return -end - -# More basic constructor where sweep_kwargs are constant throughout sweeps -function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) - # Initialize this to an empty RegionIterator - sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) - return SweepIterator(problem, sweep_kwargs_iter) -end diff --git a/src/sweeping/eigenproblem.jl b/src/sweeping/eigenproblem.jl new file mode 100644 index 0000000..36978b2 --- /dev/null +++ b/src/sweeping/eigenproblem.jl @@ -0,0 +1,44 @@ +import AlgorithmsInterface as AI +import .AlgorithmsInterfaceExtensions as AIE + +function dmrg(operator, algorithm, state) + problem = EigenProblem(operator) + return AI.solve(problem, algorithm; iterate = state).iterate +end +function dmrg(operator, state; kwargs...) + problem = EigenProblem(operator) + algorithm = select_algorithm(dmrg, operator, state; kwargs...) + return AI.solve(problem, algorithm; iterate = state).iterate +end + +# TODO: Allow specifying the region algorithm type? +function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, kwargs...) + extended_kwargs = extend_columns((; kwargs...), nsweeps) + region_kwargs = rows(extended_kwargs) + return AIE.nested_algorithm(nsweeps) do i + return AIE.nested_algorithm(length(regions)) do j + return EigsolveRegion(regions[j]; region_kwargs[i]...) + end + end +end +#= + EigenProblem(operator) + +Represents the problem we are trying to solve and minimal algorithm-independent +information, so for an eigenproblem it is the operator we want the eigenvector of. +=# +struct EigenProblem{Operator} <: AIE.Problem + operator::Operator +end + +struct EigsolveRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm + region::R + kwargs::Kwargs +end +EigsolveRegion(region; kwargs...) = EigsolveRegion(region, (; kwargs...)) + +function AI.solve!( + problem::EigenProblem, algorithm::EigsolveRegion, state::AIE.State; kwargs... + ) + return error("EigsolveRegion step for EigenProblem not implemented yet.") +end diff --git a/src/sweeping/utils.jl b/src/sweeping/utils.jl new file mode 100644 index 0000000..39e09e4 --- /dev/null +++ b/src/sweeping/utils.jl @@ -0,0 +1,12 @@ +# Utility functions for processing keyword arguments. +function repeat_last(v::AbstractVector, len::Int) + return [v; fill(v[end], max(len - length(v), 0))] +end +repeat_last(v, len::Int) = fill(v, len) +function extend_columns(nt::NamedTuple, len::Int) + return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) +end +rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) +function rows(nt::NamedTuple, len::Int = rowlength(nt)) + return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] +end diff --git a/test/Project.toml b/test/Project.toml index 4b7dc81..e71e7a4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" @@ -26,7 +27,7 @@ DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.3" -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" NamedDimsArrays = "0.8, 0.9" NamedGraphs = "0.6.8, 0.7, 0.8" QuadGK = "2.11.2" diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl new file mode 100644 index 0000000..8e0665c --- /dev/null +++ b/test/test_algorithmsinterfaceextensions.jl @@ -0,0 +1,472 @@ +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +# Define test problems, algorithms, and states for testing +struct TestProblem <: AIE.Problem + data::Vector{Float64} +end + +@kwdef struct TestAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(10) +end + +@kwdef struct TestAlgorithmStep{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(5) +end + +function AI.step!( + problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState; + logging_context_prefix = Symbol() + ) + state.iterate .+= 1 # Simple increment step + return state +end + +function AI.step!( + problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState; + kwargs... + ) + state.iterate .+= 2 # Different increment step + return state +end + +@testset "AlgorithmsInterfaceExtensions" begin + @testset "DefaultState" begin + # Test DefaultState construction + iterate = [1.0, 2.0, 3.0] + stopping_criterion_state = AI.initialize_state( + TestProblem([1.0]), TestAlgorithm(), TestAlgorithm().stopping_criterion + ) + state = AIE.DefaultState(; iterate = copy(iterate), stopping_criterion_state) + @test state.iterate == iterate + @test state.iteration == 0 + @test state.stopping_criterion_state isa AI.StoppingCriterionState + + # Test DefaultState with custom iteration + state.iteration = 5 + @test state.iteration == 5 + end + + @testset "initialize_state!" begin + # Test initialize_state! with iterate kwarg + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + state = AIE.DefaultState(; + iteration = 2, iterate = [0.0, 0.0], stopping_criterion_state + ) + AI.initialize_state!(problem, algorithm, state) + @test state.iterate == [0.0, 0.0] + @test state.iteration == 0 + @test state.stopping_criterion_state == stopping_criterion_state + end + + @testset "initialize_state" begin + # Test initialize_state without exclamation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + + state = AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) + @test state isa AIE.DefaultState + @test state.iteration == 0 + end + + @testset "increment!" begin + # Test increment! with problem and algorithm + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) + + # Increment and verify iteration counter increases + AI.increment!(problem, algorithm, state) + @test state.iteration == 1 + + AI.increment!(problem, algorithm, state) + @test state.iteration == 2 + end + + @testset "solve! and solve" begin + # Test solve! with simple problem + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(3)) + + initial_iterate = [10.0, 20.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + + # Solve with custom initial iterate + initial_iterate = [5.0, 10.0] + final_state = AI.solve!( + problem, algorithm, state; iterate = copy(initial_iterate) + ) + + @test final_state.iteration == 3 + # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] + @test final_state.iterate ≈ [8.0, 13.0] + + # Test solve without exclamation + problem2 = TestProblem([1.0, 2.0]) + algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate2 = [5.0, 10.0] + + final_state2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) + @test final_state2.iteration == 2 + @test final_state2.iterate ≈ [7.0, 12.0] + end + + @testset "DefaultAlgorithmIterator" begin + # Test algorithm iterator creation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + @test iterator isa AIE.DefaultAlgorithmIterator + @test iterator.problem === problem + @test iterator.algorithm === algorithm + @test iterator.state === state + + # Test iteration interface + @test !AI.is_finished!(iterator) + + # Step through iterator + state_out, _ = iterate(iterator) + @test state_out.iteration == 1 + @test state_out.iterate ≈ [1.0, 1.0] # Incremented by step! + + state_out, _ = iterate(iterator) + @test state_out.iteration == 2 + + @test AI.is_finished!(iterator) + end + + @testset "with_algorithmlogger" begin + # Test with_algorithmlogger with functions + results = [] + function callback1(problem, algorithm, state) + push!(results, :callback1) + return nothing + end + function callback2(problem, algorithm, state) + push!(results, :callback2) + return nothing + end + + problem = TestProblem([1.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + + # Test with CallbackAction (wrapped functions) + state = AIE.with_algorithmlogger( + :TestProblem_TestAlgorithm_PreStep => callback1, + :TestProblem_TestAlgorithm_PostStep => callback2, + ) do + return AI.solve(problem, algorithm; iterate = [0.0]) + end + @test results == [:callback1, :callback2] + end + + @testset "DefaultNestedAlgorithm" begin + # Test creating nested algorithm with function + nested_alg = AIE.nested_algorithm(3) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + @test nested_alg isa AIE.DefaultNestedAlgorithm + @test length(nested_alg.algorithms) == 3 + @test AIE.max_iterations(nested_alg) == 3 + + # Test stepping through nested algorithm + problem = TestProblem([1.0, 2.0]) + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) + + initial_iterate = [0.0, 0.0] + AI.solve!( + problem, nested_alg, state; iterate = copy(initial_iterate) + ) + + @test state.iteration == 3 + # Each nested algorithm runs once with 2 steps, incrementing by 2 + # Total: 3 algorithms × 2 iterations × 2 increment = 12 + @test state.iterate ≈ [12.0, 12.0] + end + + @testset "NestedAlgorithm basic tests" begin + # Test basic nested algorithm functionality + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + problem = TestProblem([1.0, 2.0]) + + # Test state initialization + state_nested = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) + + @test state_nested isa AIE.DefaultState + @test state_nested.iteration == 0 + @test AIE.max_iterations(nested_alg) == 2 + end + + @testset "increment! for nested algorithms" begin + # Test increment! logic for nested algorithm state + problem = TestProblem([1.0]) + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; + iterate = [0.0], + stopping_criterion_state = stopping_criterion_state, + ) + + # Test progression through iterations + @test state.iteration == 0 + + AI.increment!(problem, nested_alg, state) + @test state.iteration == 1 + + AI.increment!(problem, nested_alg, state) + @test state.iteration == 2 + end + + @testset "get_subproblem and set_substate!" begin + # Test get_subproblem + problem = TestProblem([1.0, 2.0]) + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(1)) + end + + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; + iterate = [5.0, 10.0], + iteration = 1, + stopping_criterion_state, + ) + + subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) + @test subproblem === problem + @test subalgorithm === nested_alg.algorithms[1] + @test substate.iterate ≈ [5.0, 10.0] + + # Test set_substate! + new_substate = AIE.DefaultState(; + iterate = [100.0, 200.0], + substate.stopping_criterion_state, + ) + AIE.set_substate!(problem, nested_alg, state, new_substate) + @test state.iterate ≈ [100.0, 200.0] + end + + @testset "basetypenameof and default_logging_context_prefix" begin + # Test basetypenameof utility + problem = TestProblem([1.0]) + algorithm = TestAlgorithm() + + prefix_problem = AIE.default_logging_context_prefix(problem) + prefix_algorithm = AIE.default_logging_context_prefix(algorithm) + prefix_combined = AIE.default_logging_context_prefix(problem, algorithm) + + @test prefix_problem isa Symbol + @test prefix_algorithm isa Symbol + @test prefix_combined isa Symbol + @test contains(String(prefix_combined), String(prefix_problem)) + end + + @testset "DefaultFlattenedAlgorithm" begin + # Create nested algorithms that support max_iterations + nested_algs = map(1:3) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(6) # 3 algorithms × 2 iterations each + ) + + @test flattened_alg isa AIE.DefaultFlattenedAlgorithm + @test length(flattened_alg.algorithms) == 3 + + # Test state initialization + problem = TestProblem([1.0, 2.0]) + state_flat = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) + + @test state_flat isa AIE.DefaultFlattenedAlgorithmState + @test state_flat.iteration == 0 + @test state_flat.parent_iteration == 1 + @test state_flat.child_iteration == 0 + end + + @testset "DefaultFlattenedAlgorithmState increment!" begin + # Create nested algorithms for flattened algorithm + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(4), + ) + + problem = TestProblem([1.0]) + stopping_criterion_state = AI.initialize_state( + problem, flattened_alg, flattened_alg.stopping_criterion + ) + state = AIE.DefaultFlattenedAlgorithmState(; + iterate = [0.0], + stopping_criterion_state = stopping_criterion_state, + ) + + # Test initial state + @test state.iteration == 0 + @test state.parent_iteration == 1 + @test state.child_iteration == 0 + + # First increment - should increment child_iteration + AI.increment!(problem, flattened_alg, state) + @test state.iteration == 1 + @test state.parent_iteration == 1 + @test state.child_iteration == 1 + + # Second increment - should increment child_iteration again + AI.increment!(problem, flattened_alg, state) + @test state.iteration == 2 + @test state.parent_iteration == 2 # Should move to next parent + @test state.child_iteration == 1 + end + + @testset "FlattenedAlgorithm step!" begin + # Test individual step! calls for flattened algorithm + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(4) + ) + + problem = TestProblem([1.0, 2.0]) + state = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) + + # Manually step through to test step! functionality + AI.increment!(problem, flattened_alg, state) + @test state.parent_iteration == 1 + @test state.child_iteration == 1 + + AI.step!(problem, flattened_alg, state) + # The nested algorithm runs TestAlgorithmStep with 2 iterations, each incrementing by 2 + @test state.iterate ≈ [4.0, 4.0] + end + + @testset "flattened_algorithm helper" begin + # Test the flattened_algorithm helper function + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + # Using the helper function + flattened_alg = AIE.flattened_algorithm(2) do i + AIE.nested_algorithm(1) do j + TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + @test flattened_alg isa AIE.DefaultFlattenedAlgorithm + @test length(flattened_alg.algorithms) == 2 + end + + @testset "AlgorithmIterator is_finished (without !)" begin + # Test is_finished without mutation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + # Before any iterations + @test !AI.is_finished(iterator) + + # Run the algorithm + AI.solve!(problem, algorithm, state; iterate = copy(initial_iterate)) + + # After completion + @test AI.is_finished(iterator) + end + + @testset "AlgorithmIterator step!" begin + # Test step! method for iterator + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + # Step the iterator + AI.step!(iterator) + @test iterator.state.iterate ≈ [1.0, 1.0] + + AI.step!(iterator) + @test iterator.state.iterate ≈ [2.0, 2.0] + end + + @testset "NestedAlgorithm with different sub-algorithms" begin + # Test nested algorithm with varying sub-algorithms + nested_alg = AIE.DefaultNestedAlgorithm(; + algorithms = [ + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)), + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + ] + ) + + @test AIE.max_iterations(nested_alg) == 3 + @test length(nested_alg.algorithms) == 3 + + problem = TestProblem([1.0, 2.0]) + state = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) + + AI.solve!(problem, nested_alg, state; iterate = [0.0, 0.0]) + + # First algorithm: 1 iteration × 1 increment = 1 + # Second algorithm: 2 iterations × 2 increment = 4 + # Third algorithm: 1 iteration × 1 increment = 1 + # Total: 1 + 4 + 1 = 6 + @test state.iterate ≈ [6.0, 6.0] + @test state.iteration == 3 + end + + @testset "Edge cases" begin + # Test with single nested algorithm + nested_alg = AIE.nested_algorithm(1) do i + return TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + end + + problem = TestProblem([1.0]) + state = AI.initialize_state(problem, nested_alg; iterate = [0.0]) + AI.solve!(problem, nested_alg, state; iterate = [0.0]) + + @test state.iterate ≈ [1.0] + @test state.iteration == 1 + end +end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 0afead5..a38563a 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(ITensorNetworksNext) + Aqua.test_all(ITensorNetworksNext; persistent_tasks = false) end diff --git a/test/test_dmrg.jl b/test/test_dmrg.jl new file mode 100644 index 0000000..01f04ac --- /dev/null +++ b/test/test_dmrg.jl @@ -0,0 +1,34 @@ +import AlgorithmsInterface as AI +using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +@testset "select_algorithm(dmrg, ...)" begin + operator = "operator" + init = "init" + nsweeps = 3 + regions = ["region1", "region2"] + maxdim = [10, 20] + cutoff = 1.0e-7 + algorithm = select_algorithm(dmrg, operator, init; nsweeps, regions, maxdim, cutoff) + @test algorithm isa AIE.NestedAlgorithm + @test length(algorithm.algorithms) == nsweeps + + maxdims = [10, 20, 20] + cutoffs = [1.0e-7, 1.0e-7, 1.0e-7] + algorithm′ = AIE.nested_algorithm(nsweeps) do i + return AIE.nested_algorithm(length(regions)) do j + return EigsolveRegion( + regions[j]; + maxdim = maxdims[i], + cutoff = cutoffs[i], + ) + end + end + for i in 1:nsweeps + for j in 1:length(regions) + @test algorithm.algorithms[i].algorithms[j] == + algorithm′.algorithms[i].algorithms[j] + end + end +end diff --git a/test/test_iterators.jl b/test/test_iterators.jl deleted file mode 100644 index a17c7be..0000000 --- a/test/test_iterators.jl +++ /dev/null @@ -1,221 +0,0 @@ -using Test: @test, @testset, @test_throws -import ITensorNetworksNext as ITensorNetworks -using .ITensorNetworks: RegionIterator, SweepIterator, IncrementOnly, compute!, increment!, islaststep, state, eachregion - -module TestIteratorUtils - - import ITensorNetworksNext as ITensorNetworks - using .ITensorNetworks - - struct TestProblem <: ITensorNetworks.AbstractProblem - data::Vector{Int} - end - ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)] - function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem}) - kwargs = ITensorNetworks.region_kwargs(iter) - push!(ITensorNetworks.problem(iter).data, kwargs.val) - return iter - end - - - mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator - state::Int - max::Int - output::Vector{Int} - end - - ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1 - Base.length(TI::TestIterator) = TI.max - ITensorNetworks.state(TI::TestIterator) = TI.state - function ITensorNetworks.compute!(TI::TestIterator) - push!(TI.output, ITensorNetworks.state(TI)) - return TI - end - - mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator - parent::TestIterator - end - - Base.length(SA::SquareAdapter) = length(SA.parent) - ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent) - ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent) - function ITensorNetworks.compute!(SA::SquareAdapter) - ITensorNetworks.compute!(SA.parent) - return last(SA.parent.output)^2 - end - -end - -@testset "Iterators" begin - - import .TestIteratorUtils - - @testset "`AbstractNetworkIterator` Interface" begin - - @testset "Edge cases" begin - TI = TestIteratorUtils.TestIterator(1, 1, []) - cb = [] - @test islaststep(TI) - for _ in TI - @test islaststep(TI) - push!(cb, state(TI)) - end - @test length(cb) == 1 - @test length(TI.output) == 1 - @test only(cb) == 1 - - prob = TestIteratorUtils.TestProblem([]) - @test_throws ArgumentError SweepIterator(prob, 0) - @test_throws ArgumentError RegionIterator(prob, [], 1) - end - - TI = TestIteratorUtils.TestIterator(1, 4, []) - - @test !islaststep((TI)) - - # First iterator should compute only - rv, st = iterate(TI) - @test !islaststep((TI)) - @test !st - @test rv === TI - @test length(TI.output) == 1 - @test only(TI.output) == 1 - @test state(TI) == 1 - @test !st - - rv, st = iterate(TI, st) - @test !islaststep((TI)) - @test !st - @test length(TI.output) == 2 - @test state(TI) == 2 - @test TI.output == [1, 2] - - increment!(TI) - @test !islaststep((TI)) - @test state(TI) == 3 - @test length(TI.output) == 2 - @test TI.output == [1, 2] - - compute!(TI) - @test !islaststep((TI)) - @test state(TI) == 3 - @test length(TI.output) == 3 - @test TI.output == [1, 2, 3] - - # Final Step - iterate(TI, false) - @test islaststep((TI)) - @test state(TI) == 4 - @test length(TI.output) == 4 - @test TI.output == [1, 2, 3, 4] - - @test iterate(TI, false) === nothing - - TI = TestIteratorUtils.TestIterator(1, 5, []) - - cb = [] - - for _ in TI - @test length(cb) == length(TI.output) - 1 - @test cb == (TI.output)[1:(end - 1)] - push!(cb, state(TI)) - @test cb == TI.output - end - - @test islaststep((TI)) - @test length(TI.output) == 5 - @test length(cb) == 5 - @test cb == TI.output - - - TI = TestIteratorUtils.TestIterator(1, 5, []) - end - - @testset "Adapters" begin - TI = TestIteratorUtils.TestIterator(1, 5, []) - SA = TestIteratorUtils.SquareAdapter(TI) - - @testset "Generic" begin - - i = 0 - for rv in SA - i += 1 - @test rv isa Int - @test rv == i^2 - @test state(SA) == i - end - - @test islaststep((SA)) - - TI = TestIteratorUtils.TestIterator(1, 5, []) - SA = TestIteratorUtils.SquareAdapter(TI) - - SA_c = collect(SA) - - @test SA_c isa Vector - @test length(SA_c) == 5 - @test SA_c == [1, 4, 9, 16, 25] - - end - - @testset "IncrementOnly" begin - TI = TestIteratorUtils.TestIterator(1, 5, []) - NI = IncrementOnly(TI) - - NI_c = [] - - for _ in IncrementOnly(TI) - push!(NI_c, state(TI)) - end - - @test length(NI_c) == 5 - @test isempty(TI.output) - end - - @testset "EachRegion" begin - prob = TestIteratorUtils.TestProblem([]) - prob_region = TestIteratorUtils.TestProblem([]) - - SI = SweepIterator(prob, 5) - SI_region = SweepIterator(prob_region, 5) - - callback = [] - callback_region = [] - - let i = 1 - for _ in SI - push!(callback, i) - i += 1 - end - end - - @test length(callback) == 5 - - let i = 1 - for _ in eachregion(SI_region) - push!(callback_region, i) - i += 1 - end - end - - @test length(callback_region) == 10 - - @test prob.data == prob_region.data - - @test prob.data[1:2:end] == fill(1, 5) - @test prob.data[2:2:end] == fill(2, 5) - - - let i = 1, prob = TestIteratorUtils.TestProblem([]) - SI = SweepIterator(prob, 1) - cb = [] - for _ in eachregion(SI) - push!(cb, i) - i += 1 - end - @test length(cb) == 2 - end - - end - end -end diff --git a/test/test_sweeping.jl b/test/test_sweeping.jl new file mode 100644 index 0000000..215a8b8 --- /dev/null +++ b/test/test_sweeping.jl @@ -0,0 +1,65 @@ +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +struct TestProblem <: AIE.Problem +end + +struct TestRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm + region::R + kwargs::Kwargs +end +TestRegion(region; kwargs...) = TestRegion(region, (; kwargs...)) + +function AI.solve!(problem::TestProblem, algorithm::TestRegion, state::AIE.State; kwargs...) + new_iterate = (; algorithm.region, algorithm.kwargs.foo, algorithm.kwargs.bar) + state.iterate = [state.iterate; [new_iterate]] + return state +end + +@testset "Sweeping" begin + @testset "TestRegion" begin + algorithm = TestRegion("region"; foo = 1, bar = 2) + @test algorithm isa AIE.NonIterativeAlgorithm + @test algorithm isa AIE.Algorithm + @test algorithm isa AI.Algorithm + @test algorithm.region == "region" + @test algorithm.kwargs == (; foo = 1, bar = 2) + + problem = TestProblem() + iterate = [] + state = AI.solve(problem, algorithm; iterate) + @test state.iterate == [(; region = "region", foo = 1, bar = 2)] + end + @testset "Sweep" begin + algorithm = AIE.nested_algorithm(3) do i + return TestRegion("region$i"; foo = i, bar = 2i) + end + problem = TestProblem() + iterate = [] + state = AI.solve(problem, algorithm; iterate) + @test state.iterate == [ + (; region = "region1", foo = 1, bar = 2), + (; region = "region2", foo = 2, bar = 4), + (; region = "region3", foo = 3, bar = 6), + ] + end + @testset "Sweeping" begin + algorithm = AIE.nested_algorithm(2) do i + AIE.nested_algorithm(3) do j + return TestRegion("sweep$i, region$j"; foo = (i, j), bar = (2i, 2j)) + end + end + problem = TestProblem() + iterate = [] + state = AI.solve(problem, algorithm; iterate) + @test state.iterate == [ + (; region = "sweep1, region1", foo = (1, 1), bar = (2, 2)), + (; region = "sweep1, region2", foo = (1, 2), bar = (2, 4)), + (; region = "sweep1, region3", foo = (1, 3), bar = (2, 6)), + (; region = "sweep2, region1", foo = (2, 1), bar = (4, 2)), + (; region = "sweep2, region2", foo = (2, 2), bar = (4, 4)), + (; region = "sweep2, region3", foo = (2, 3), bar = (4, 6)), + ] + end +end From 032447a00de29e7a8fba27f76bb0ae6a8c193e26 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 23 Dec 2025 18:15:22 -0500 Subject: [PATCH 22/45] Upgrade to NamedDimsArrays.jl v0.11 (#38) --- Project.toml | 6 +++--- test/Project.toml | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index e6919fc..7b86558 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -33,7 +33,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations" [compat] AbstractTrees = "0.4.5" Adapt = "4.3" -AlgorithmsInterface = "0.1.0" +AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" DataGraphs = "0.2.7" @@ -43,7 +43,7 @@ Dictionaries = "0.4.5" Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.8, 0.9" +NamedDimsArrays = "0.8, 0.9, 0.10, 0.11" NamedGraphs = "0.6.9, 0.7, 0.8" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" diff --git a/test/Project.toml b/test/Project.toml index e71e7a4..0e74eef 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,13 +22,14 @@ ITensorNetworksNext = {path = ".."} [compat] AbstractTrees = "0.4.5" +AlgorithmsInterface = "0.1" Aqua = "0.8.14" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" -ITensorBase = "0.3" +ITensorBase = "0.3, 0.4" ITensorNetworksNext = "0.3" -NamedDimsArrays = "0.8, 0.9" +NamedDimsArrays = "0.8, 0.9, 0.10, 0.11" NamedGraphs = "0.6.8, 0.7, 0.8" QuadGK = "2.11.2" SafeTestsets = "0.1" From b256d79f250cc5f06b83885381879b8f0fa41f10 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:34:38 -0500 Subject: [PATCH 23/45] [LazyNamedDimsArrays] New `symnameddims` method that pulls out indices from an array. --- src/LazyNamedDimsArrays/symbolicnameddimsarray.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl index a215319..628baf3 100644 --- a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl +++ b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl @@ -5,6 +5,9 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = function symnameddims(name, dims) return lazy(nameddims(SymbolicArray(name, dename.(dims)), dims)) end +function symnameddims(name, ndarray::AbstractNamedDimsArray) + return symnameddims(name, Tuple(inds(ndarray))) +end symnameddims(name) = symnameddims(name, ()) using AbstractTrees: AbstractTrees function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) From b2da9d80a35da7ea5a2b51fb791a1115342cd8ca Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:35:32 -0500 Subject: [PATCH 24/45] The function `region_scalar` should now return a scalar, rather than a order-0 array --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 0cae3fa..3545b53 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -62,7 +62,7 @@ function setfactor!(bpc::AbstractDataGraph, vertex, factor) end function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) - return message(bp_cache, edge) * message(bp_cache, reverse(edge)) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] end function region_scalar(bp_cache::AbstractGraph, vertex) @@ -70,7 +70,7 @@ function region_scalar(bp_cache::AbstractGraph, vertex) messages = incoming_messages(bp_cache, vertex) state = factors(bp_cache, vertex) - return reduce(*, messages) * reduce(*, state) + return (reduce(*, messages) * reduce(*, state))[] end message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) From 8506e26a3d8814e3e51487a48469f27c9cd64a8f Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:37:43 -0500 Subject: [PATCH 25/45] Fix double counting in `edge_scalars` function This was caused by the change to the `cache` being backed by a directed graph. --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 3545b53..8e7185e 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -81,7 +81,7 @@ function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache)) +function edge_scalars(bp_cache::AbstractGraph, edges = edges(undirected_graph(underlying_graph(bp_cache)))) return map(e -> region_scalar(bp_cache, e), edges) end @@ -120,7 +120,9 @@ adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end function free_energy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + if any(t -> real(t) < 0, numerator_terms) numerator_terms = complex.(numerator_terms) end From 938180af0e35b3e091aa39bfa405a0dd5842d523 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:37:59 -0500 Subject: [PATCH 26/45] Minor code formatting --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 8e7185e..0efc95d 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -130,7 +130,10 @@ function free_energy(bp_cache::AbstractBeliefPropagationCache) denominator_terms = complex.(denominator_terms) end - any(iszero, denominator_terms) && return -Inf + if any(iszero, denominator_terms) + return -Inf + end + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) From 44619673fedaf47c59bd2557222086807f12a2ec Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:39:43 -0500 Subject: [PATCH 27/45] Expressed belief propagation in terms of AlgorithmsInterface --- .../beliefpropagationcache.jl | 13 + .../beliefpropagationproblem.jl | 279 +++++++++++++----- src/sweeping/utils.jl | 8 +- test/test_beliefpropagation.jl | 10 +- 4 files changed, 222 insertions(+), 88 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index c9793e6..27a580d 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -23,6 +23,7 @@ using NamedGraphs.GraphsExtensions: default_root_vertex, is_path_graph, undirected_graph using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype struct BeliefPropagationCache{V, G <: AbstractGraph{V}, N <: AbstractGraph{V}, ET, MT} <: AbstractBeliefPropagationCache{V, MT} @@ -125,3 +126,15 @@ function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) data = map(e -> bpc[QuotientEdge(e)], inds) return BeliefPropagationCache(QuotientView(network(bpc)), data) end + +function default_message(bpc::BeliefPropagationCache, edge) + return default_message(message_type(bpc), network(bpc), edge) +end +function default_message(T::Type, network, edge) + array = ones(Tuple(linkinds(network, edge))) + return convert(T, array) +end +function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) + message = default_message(parenttype(T), network, edge) + return convert(T, lazy(message)) +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 24b024d..0d997ee 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,82 +1,200 @@ -using Graphs: SimpleGraph, vertices, edges, has_edge +using Graphs: SimpleGraph, vertices, edges, has_edge, AbstractEdge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices using NamedDimsArrays: AbstractNamedDimsArray -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, parenttype, lazy +using DataGraphs: edge_data +import AlgorithmsInterface as AI +import .AlgorithmsInterfaceExtensions as AIE -abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 = 0.0 +end -mutable struct BeliefPropagationProblem{Alg, Cache} <: AbstractBeliefPropagationProblem{Alg} - const alg::Alg - const cache::Cache - diff::Union{Nothing, Float64} +@kwdef mutable struct StopWhenConvergedState <: AI.StoppingCriterionState + delta::Float64 = Inf end -BeliefPropagationProblem(alg, cache) = BeliefPropagationProblem(alg, cache, nothing) +function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged) + return StopWhenConvergedState() +end -function default_algorithm( - ::Type{<:Algorithm"bp"}, - bpc; - verbose = false, - tol = nothing, - edge_sequence = forest_cover_edge_sequence(bpc), - message_update_alg = default_algorithm(Algorithm"contract"), - maxiter = is_tree(bpc) ? 1 : nothing, +function AI.initialize_state!( + ::AIE.Problem, + ::AIE.Algorithm, + ::StopWhenConverged, + st::StopWhenConvergedState, ) - return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) + st.delta = Inf + return st end -function region_plan(prob::BeliefPropagationProblem{<:Algorithm"bp"}; sweep_kwargs...) - edges = prob.alg.edge_sequence +function AI.is_finished!( + ::AIE.Problem, + ::AIE.Algorithm, + state::AIE.State, + c::StopWhenConverged, + st::StopWhenConvergedState, + ) - plan = map(edges) do e - return e => (; sweep_kwargs...) + # maxdiff = 0.0 initially, so skip this the first time. + if state.iteration > 0 + st.delta = state.iterate.maxdiff end - return plan + return st.delta < c.tol +end + +struct BeliefPropagationProblem{Network} <: AIE.Problem + network::Network +end + +@kwdef mutable struct BeliefPropagationState{ + Iterate <: BeliefPropagationCache, + Diffs, + } <: AIE.NonIterativeAlgorithmState + iterate::Iterate + diffs::Diffs = similar(edge_data(iterate), Float64) + maxdiff::Float64 = 0.0 +end + +function AI.initialize_state( + problem::BeliefPropagationProblem, + algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... + ) + + diffs = iterate.diffs + maxdiff = iterate.maxdiff + + return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) end -function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp"}}) - prob = iter.problem +# This gets called at the start of every sweep. +function AI.initialize_state!( + problem::BeliefPropagationProblem, + algorithm::AIE.NestedAlgorithm, + state::AIE.State, + ) + state.iterate.maxdiff = 0.0 + return state +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + ::AIE.NestedAlgorithm, + state::AIE.State, + substate::BeliefPropagationState + ) + + state.iterate = substate + + return state +end - edge, _ = current_region_plan(iter) - new_message = updated_message(prob.alg.message_update_alg, prob.cache, edge) - setmessage!(prob.cache, edge, new_message) +abstract type AbstractMessageUpdate <: AIE.NonIterativeAlgorithm end - return iter +struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} <: AbstractMessageUpdate + edge::E + kwargs::Kwargs end -function default_message(bpc::BeliefPropagationCache, edge) - return default_message(message_type(bpc), network(bpc), edge) +function SimpleMessageUpdate( + edge; + normalize = false, + contraction_alg = "eager", + compute_diff = false, + kwargs... + ) + return SimpleMessageUpdate(edge, (; normalize, contraction_alg, compute_diff, kwargs...)) end -function default_message(T::Type, network, edge) - array = ones(Tuple(linkinds(network, edge))) - return convert(T, array) + +function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) + if name in (:edge, :kwargs) + return getfield(alg, name) + else + return getproperty(getfield(alg, :kwargs), name) + end end -function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) - message = default_message(parenttype(T), network, edge) - return convert(T, lazy(message)) + +struct MessageUpdateProblem{Messages, Factors} <: AIE.Problem + messages::Messages + factors::Factors end -updated_message(alg, bpc, edge) = not_implemented() -function updated_message(alg::Algorithm"contract", bpc, edge) +function AI.solve!( + problem::BeliefPropagationProblem, + algorithm::AbstractMessageUpdate, + state::BeliefPropagationState; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + ) + + logger = AI.algorithm_logger() + + cache = state.iterate + edge = algorithm.edge + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) + ) + + new_message = updated_message(algorithm, cache) + + if algorithm.compute_diff + diff = message_diff(new_message, cache[edge]) + + if diff > state.maxdiff + state.maxdiff = diff + end + + state.diffs[edge] = diff + end + + setmessage!(cache, edge, new_message) + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostUpdate) + ) + + return state +end + +message_diff(m1, m2) = LinearAlgebra.norm(m1 - m2) + +function updated_message(algorithm, cache) + edge = algorithm.edge + vertex = src(edge) + messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) + + update_problem = MessageUpdateProblem(messages, factors(cache, vertex)) + + message_state = AI.solve(update_problem, algorithm; iterate = message(cache, edge)) - incoming_ms = incoming_messages( - bpc, vertex; ignore_edges = typeof(edge)[reverse(edge)] + return message_state.iterate +end + +function AI.solve!( + problem::MessageUpdateProblem, + algorithm::SimpleMessageUpdate, + state::AIE.NonIterativeAlgorithmState; + logging_context_prefix = AI.default_logging_context_prefix(problem, algorithm), + kwargs... ) - updated_message = contract_messages(alg.contraction_alg, factors(bpc, vertex), incoming_ms) + # TODO: logging... - if alg.normalize - message_norm = LinearAlgebra.norm(updated_message) + state.iterate = contract_messages(algorithm.contraction_alg, problem.factors, problem.messages) + + if algorithm.normalize + # TODO: use `sum` not `norm` + message_norm = LinearAlgebra.norm(state.iterate) if !iszero(message_norm) - updated_message /= message_norm + state.iterate /= message_norm end end - return updated_message + + return state end contract_messages(alg, factors, messages) = not_implemented() @@ -85,54 +203,51 @@ function contract_messages( factors::Vector{<:AbstractArray}, messages::Vector{<:AbstractArray}, ) - return contract_network(alg, vcat(factors, messages)) + return contract_network(vcat(factors, messages); alg) end -function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = Algorithm("exact") - ) - return Algorithm("contract"; normalize, contraction_alg) -end -function default_algorithm( - ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") - ) - return Algorithm("adapt_update"; adapt, alg) -end +beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network); kwargs...) +function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) -function update_message!( - message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge - ) - return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) -end + problem = BeliefPropagationProblem(network(cache)) -function update(bpc::AbstractBeliefPropagationCache; kwargs...) - return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) -end + algorithm = select_algorithm(beliefpropagation, cache; kwargs...) -function update(alg, bpc) - compute_error = !isnothing(alg.tol) + # The nested algorithms will wrap and manipulate this object. + base_state = BeliefPropagationState(; iterate = cache) - diff = compute_error ? 0.0 : nothing + state = AI.solve(problem, algorithm; iterate = base_state) - prob = BeliefPropagationProblem(alg, bpc, diff) + return state.iterate.iterate +end - iter = SweepIterator(prob, alg.maxiter; compute_error) +function select_algorithm( + ::typeof(beliefpropagation), + cache; + edges = forest_cover_edge_sequence(network(cache)), + maxiter = is_tree(network(cache)) ? 1 : nothing, + tol = 0.0, + kwargs... + ) - for _ in iter - if compute_error && prob.diff <= alg.tol - break - end + if isnothing(maxiter) + throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) end - if alg.verbose && compute_error - if prob.diff <= alg.tol - println("BP converged to desired precision after $(iter.which_sweep) iterations.") - else - println( - "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations", - ) - end + stopping_criterion = AI.StopAfterIteration(maxiter) + compute_diff = false + + if tol > 0.0 + stopping_criterion = stopping_criterion | StopWhenConverged(tol) + compute_diff = true end - return bpc + extended_kwargs = extend_columns((; compute_diff, kwargs...), maxiter) + edge_kwargs = rows(extended_kwargs, len = maxiter) + + return AIE.nested_algorithm(maxiter; stopping_criterion) do repnum + return AIE.nested_algorithm(length(edges)) do edgenum + return SimpleMessageUpdate(edges[edgenum]; edge_kwargs[repnum]...) + end + end end diff --git a/src/sweeping/utils.jl b/src/sweeping/utils.jl index 39e09e4..9a39c9d 100644 --- a/src/sweeping/utils.jl +++ b/src/sweeping/utils.jl @@ -7,6 +7,12 @@ function extend_columns(nt::NamedTuple, len::Int) return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) end rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) +function rows(nt::NamedTuple; len = nothing) + if isnothing(len) + if isempty(nt) + throw(ArgumentError("Got empty named tuple; keyword `len` must be specified in this case.")) + end + len = rowlength(nt) + end return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index a39e1a6..8c7829b 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -20,7 +20,7 @@ using Test: @test, @testset @testset "BeliefPropagation" begin #Chain of tensors - dims = (4, 1) + dims = (2, 1) g = named_grid(dims) l = Dict(e => Index(2) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) @@ -30,10 +30,10 @@ using Test: @test, @testset end bpc = BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.update(bpc; maxiter = 1) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1.0e-14 + @test z_bp ≈ z_exact atol = 1.0e-14 #Tree of tensors dims = (4, 3) @@ -46,8 +46,8 @@ using Test: @test, @testset end bpc = BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.update(bpc; maxiter = 10) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1.0e-14 + @test z_bp ≈ z_exact atol = 1.0e-12 end From d68860ae59092f2382fccfee87d03abe9a097b58 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:40:23 -0500 Subject: [PATCH 28/45] Fixes to TensorNetwork construction from tensor list --- src/abstracttensornetwork.jl | 4 ++-- src/tensornetwork.jl | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index b820867..08f86a1 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,7 +1,7 @@ using Adapt: Adapt, adapt, adapt_structure using BackendSelection: @Algorithm_str, Algorithm using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, underlying_graph, - underlying_graph_type, vertex_data + underlying_graph_type, vertex_data, set_vertex_data! using Dictionaries: Dictionary using Graphs: Graphs, AbstractEdge, AbstractGraph, Graph, add_edge!, add_vertex!, bfs_tree, center, dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices @@ -111,7 +111,7 @@ function sitenames(tn::AbstractGraph, edge::AbstractEdge) end function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) - set!(vertex_data(tn), vertex, value) + set_vertex_data!(tn, value, vertex) return tn end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 16c80e3..b811e2b 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -35,8 +35,13 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar end end # This assumes the tensor connectivity matches the graph structure. +function TensorNetwork(graph::AbstractGraph, tensors) + return TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) +end function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) - return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) + tn = _TensorNetwork(graph, tensors) + fix_links!(tn) + return tn end function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} @@ -80,11 +85,6 @@ tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end -function TensorNetwork(graph::AbstractGraph, tensors) - tn = _TensorNetwork(graph, tensors) - fix_links!(tn) - return tn -end # Insert trivial links for missing edges, and also check # the vertices and edges are consistent between the graph and tensors. @@ -172,6 +172,7 @@ function PartitionedGraphs.departition( return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) end +# When getting data according the quotient vertices, take a lazy contraction. function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices) data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex))) return mapreduce(lazy, *, data) From 2f5c783f4760d813777e392321c97028f05b3f99 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:41:18 -0500 Subject: [PATCH 29/45] Minor simplifications to `contract_network` interface. --- src/contract_network.jl | 91 ++++++++++++++++------------------- test/test_contract_network.jl | 12 ++--- 2 files changed, 48 insertions(+), 55 deletions(-) diff --git a/src/contract_network.jl b/src/contract_network.jl index e89fa00..4511595 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,69 +1,62 @@ using BackendSelection: @Algorithm_str, Algorithm using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, +using NamedDimsArrays: inds +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, Mul, lazy, optimize_evaluation_order, substitute, symnameddims -# This is related to `MatrixAlgebraKit.select_algorithm`. -# TODO: Define this in BackendSelection.jl. -backend_value(::Algorithm{alg}) where {alg} = alg -using BackendSelection: parameters -function merge_parameters(alg::Algorithm; kwargs...) - return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...) +function contract_network(tn; alg = default_kwargs(contract_network, tn).alg) + return contract_network(alg, tn) end -to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...) -to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...) -# `contract_network` -function contract_network(alg::Algorithm, tn) - return throw(ArgumentError("`contract_network` algorithm `$(alg)` not implemented.")) -end -function default_kwargs(::typeof(contract_network), tn) - return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"())) -end -function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...) - return contract_network(to_algorithm(alg; kwargs...), tn) +contract_network(alg::String, tn) = contract_network(Algorithm(alg), tn) + +default_kwargs(::typeof(contract_network), tn) = (; alg = "eager") + +function contract_network( + alg, + tensors, + ) + + order = contraction_expression(tensors; order = alg) + symbols_to_tensors = Dict( + symnameddims(i, tensors[i]) => lazy(tensors[i]) for i in keys(tensors) + ) + + return materialize(substitute(order, symbols_to_tensors)) end -# `contract_network(::Algorithm"exact", ...)` -function get_order(alg::Algorithm"exact", tn) - # Allow specifying either `order` or `order_alg`. - order = get(alg, :order, nothing) - order = if !isnothing(order) - order - else - default_order_alg = default_kwargs(contraction_order, tn).alg - order_alg = get(alg, :order_alg, default_order_alg) - # TODO: Capture other keyword arguments and pass them to `contraction_order`. - contraction_order(tn; alg = order_alg) - end +# `contraction_order` +function contraction_order end +default_kwargs(::typeof(contraction_order), tensors) = (; order = "eager") + +function contraction_expression(tensors; order = default_kwargs(contraction_order, tensors).order) + order = contraction_order(order, tensors) + # Contraction order may or may not have indices attached, canonicalize the format # by attaching indices. - subs = Dict(symnameddims(i) => symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)) + subs = Dict(symnameddims(i) => symnameddims(i, tensors[i]) for i in keys(tensors)) + return substitute(order, subs) end -function contract_network(alg::Algorithm"exact", tn) - order = get_order(alg, tn) - syms_to_ts = Dict(symnameddims(i, Tuple(inds(tn[i]))) => lazy(tn[i]) for i in keys(tn)) - tn_expression = substitute(order, syms_to_ts) - return materialize(tn_expression) -end -# `contraction_order` -function contraction_order end -default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"()) -function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...) - return contraction_order(to_algorithm(alg; kwargs...), tn) +contraction_order(order, tensors) = order +function contraction_order(tensors; order = default_kwargs(contraction_order, tensors).order) + return contraction_order(Algorithm(order), tensors) end # Convert the tensor network to a flat symbolic multiplication expression. -function contraction_order(alg::Algorithm"flat", tn) +function contraction_order(::Algorithm"flat", tensors) # Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`. - syms = vec([symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)]) + syms = vec([symnameddims(i, Tuple(inds(tensors[i]))) for i in keys(tensors)]) return lazy(Mul(syms)) end -function contraction_order(alg::Algorithm"left_associative", tn) - return prod(i -> symnameddims(i, Tuple(inds(tn[i]))), keys(tn)) +function contraction_order(::Algorithm"left_associative", tensors) + return prod(i -> symnameddims(i, Tuple(inds(tensors[i]))), keys(tensors)) end -function contraction_order(alg::Algorithm, tn) - s = contraction_order(Algorithm"flat"(), tn) - return optimize_evaluation_order(s; alg) + +function contraction_order( + order_algorithm::Algorithm, + tensors, + ) + order = contraction_order(tensors; order = "flat") + return optimize_evaluation_order(order; alg = order_algorithm) end diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index c9abfdd..b5ff72e 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -14,9 +14,9 @@ using Test: @test, @testset C = ITensor([5.0, 1.0], j) D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k) - ABCD_1 = contract_network([A, B, C, D]; order_alg = "left_associative") - ABCD_2 = contract_network([A, B, C, D]; order_alg = "eager") - ABCD_3 = contract_network([A, B, C, D]; order_alg = "optimal") + ABCD_1 = contract_network([A, B, C, D]; alg = "left_associative") + ABCD_2 = contract_network([A, B, C, D]; alg = "eager") + ABCD_3 = contract_network([A, B, C, D]; alg = "optimal") @test ABCD_1 == ABCD_2 == ABCD_3 end @@ -31,9 +31,9 @@ using Test: @test, @testset return randn(Tuple(is)) end - z1 = contract_network(tn; order_alg = "left_associative")[] - z2 = contract_network(tn; order_alg = "eager")[] - z3 = contract_network(tn; order_alg = "optimal")[] + z1 = contract_network(tn; alg = "left_associative")[] + z2 = contract_network(tn; alg = "eager")[] + z3 = contract_network(tn; alg = "optimal")[] @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64) @test abs(z1 - z3) / abs(z1) <= 1.0e3 * eps(Float64) From 4eec9b65e4917c3feb11926ccf61207773833e2b Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 10 Feb 2026 11:50:00 -0500 Subject: [PATCH 30/45] Upgrade DataGraphs and NamedGraphs dependencies --- src/abstracttensornetwork.jl | 20 +----- .../abstractbeliefpropagationcache.jl | 19 +++--- .../beliefpropagationcache.jl | 63 ++++++++++--------- src/tensornetwork.jl | 40 +++++++++--- test/Project.toml | 4 +- 5 files changed, 79 insertions(+), 67 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 08f86a1..671ba3a 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -16,7 +16,8 @@ using NamedGraphs.GraphsExtensions: incident_edges, rem_edges!, rename_vertices, - vertextype + vertextype, + similar_graph using SplitApplyCombine: flatten using NamedGraphs.SimilarType: similar_type @@ -25,7 +26,7 @@ abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} # Need to be careful about removing edges from tensor networks in case there is a bond Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented() +DataGraphs.edge_data_type(::Type{<:AbstractTensorNetwork}) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -235,18 +236,3 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) end Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) - -function Graphs.induced_subgraph(graph::AbstractTensorNetwork{V}, subvertices::Vector{V}) where {V} - return tensornetwork_induced_subgraph(graph, subvertices) -end - -function tensornetwork_induced_subgraph(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) - subgraph = similar_type(graph)(underlying_subgraph) - for v in vertices(subgraph) - if isassigned(graph, v) - set!(vertex_data(subgraph), v, graph[v]) - end - end - return subgraph, vlist -end diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 0efc95d..b77fb4e 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,15 +1,12 @@ using Graphs: AbstractGraph, AbstractEdge -using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype +using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_type using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] -function message(bp_cache::AbstractGraph, edge::AbstractEdge) - ms = messages(bp_cache) - return get!(ms, edge, default_message(bp_cache, edge)) -end +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() function deletemessage!(bp_cache::AbstractDataGraph, edge) @@ -52,7 +49,7 @@ factors(bpc::AbstractGraph) = vertex_data(bpc) factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) -factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] +factor(bpc::AbstractGraph, vertex) = bpc[vertex] setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() function setfactor!(bpc::AbstractDataGraph, vertex, factor) @@ -75,7 +72,7 @@ end message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) -message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_type(type) function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) return map(v -> region_scalar(bp_cache, v), vertices) @@ -117,7 +114,13 @@ end adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) -abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end +abstract type AbstractBeliefPropagationCache{V, VD, ED} <: AbstractDataGraph{V, VD, ED} end + +factor_type(bpc::AbstractBeliefPropagationCache) = typeof(bpc) +factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD + +message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) +message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED function free_energy(bp_cache::AbstractBeliefPropagationCache) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 27a580d..10ab586 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -2,20 +2,19 @@ using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph, - has_edge_data, get_vertex_data, get_edge_data, set_vertex_data!, set_edge_data!, - unset_vertex_data!, - unset_edge_data!, - vertex_data_eltype, - edge_data_eltype, + vertex_data_type, + edge_data_type, underlying_graph, - underlying_graph_type + underlying_graph_type, + is_vertex_assigned, + is_edge_assigned using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components, is_directed -using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices +using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices, Vertices using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, @@ -25,22 +24,23 @@ using NamedGraphs.GraphsExtensions: default_root_vertex, using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -struct BeliefPropagationCache{V, G <: AbstractGraph{V}, N <: AbstractGraph{V}, ET, MT} <: - AbstractBeliefPropagationCache{V, MT} +struct BeliefPropagationCache{V, VD, ED, G <: AbstractGraph{V}, N <: AbstractGraph{V}, E} <: + AbstractBeliefPropagationCache{V, VD, ED} underlying_graph::G # we only use this for the edges. network::N - messages::Dictionary{ET, MT} + messages::Dictionary{E, ED} function BeliefPropagationCache(network::AbstractGraph, messages::Dictionary) V = vertextype(network) + VD = vertex_data_type(network) N = typeof(network) ET = keytype(messages) - MT = eltype(messages) + ED = eltype(messages) # Construct a directed graph version of the underlying graph of the tensor network. digraph = directed_graph(underlying_graph(network)) - bpc = new{V, typeof(digraph), N, ET, MT}(digraph, network, messages) + bpc = new{V, VD, ED, typeof(digraph), N, ET}(digraph, network, messages) for edge in edges(bpc) get!(() -> default_message(bpc, edge), messages, edge) @@ -53,8 +53,8 @@ network(bp_cache) = getfield(bp_cache, :network) DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :underlying_graph) -DataGraphs.has_vertex_data(bpc::BeliefPropagationCache, vertex) = has_vertex_data(network(bpc), vertex) -DataGraphs.has_edge_data(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) +DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = is_vertex_assigned(network(bpc), vertex) +DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = get_vertex_data(network(bpc), vertex) DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] @@ -62,20 +62,8 @@ DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc. DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set_vertex_data!(network(bpc), val, vertex) DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) -DataGraphs.unset_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = unset_vertex_data!(network(bpc), val, vertex) -DataGraphs.unset_edge_data!(bpc::BeliefPropagationCache, val, edge) = unset!(bpc.messages, edge, val) - -function DataGraphs.vertex_data_eltype(T::Type{<:BeliefPropagationCache}) - return vertex_data_eltype(fieldtype(T, :network)) -end -function DataGraphs.edge_data_eltype(T::Type{<:BeliefPropagationCache}) - return eltype(fieldtype(T, :messages)) -end - -message_type(T::Type{<:BeliefPropagationCache}) = edge_data_eltype(T) - function BeliefPropagationCache(network::AbstractGraph) - MT = vertex_data_eltype(typeof(network)) + MT = vertex_data_type(typeof(network)) return BeliefPropagationCache(MT, network) end function BeliefPropagationCache(MT::Type, network::AbstractGraph) @@ -95,7 +83,7 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo forests = forest_cover(g) rv = edgetype(g)[] for forest in forests - trees = [forest[vs] for vs in connected_components(forest)] + trees = [forest[Vertices(vs)] for vs in connected_components(forest)] for tree in trees tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) @@ -106,16 +94,19 @@ end function bpcache_induced_subgraph(graph, subvertices) underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) - subgraph = BeliefPropagationCache(underlying_subgraph, typeof(edge_data(graph))()) + + edge_data = Dictionary{edgetype(underlying_subgraph), edge_data_type(typeof(graph))}() + + subgraph = BeliefPropagationCache(underlying_subgraph, edge_data) for e in edges(subgraph) if isassigned(graph, e) - set!(edge_data(subgraph), e, graph[e]) + subgraph[e] = graph[e] end end return subgraph, vlist end -function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::Vector{V}) where {V} +function NamedGraphs.induced_subgraph_from_vertices(graph::BeliefPropagationCache, subvertices) return bpcache_induced_subgraph(graph, subvertices) end @@ -138,3 +129,13 @@ function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) message = default_message(parenttype(T), network, edge) return convert(T, lazy(message)) end + +NamedGraphs.to_graph_index(::BeliefPropagationCache, vertex::QuotientVertex) = vertex +# When getting data according the quotient vertices, take a lazy contraction. +function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientVertex) + data = collect(map(v -> tn[v], vertices(tn, vertex))) + return mapreduce(lazy, *, data) +end +function DataGraphs.is_graph_index_assigned(tn::BeliefPropagationCache, vertex::QuotientVertex) + return isassigned(tn, Vertices(vertices(tn, vertex))) +end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index b811e2b..0d30970 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -18,7 +18,7 @@ using NamedGraphs.PartitionedGraphs: QuotientVertexVertices, quotientvertices using .LazyNamedDimsArrays: lazy, Mul -using DataGraphs: vertex_data_eltype, vertex_data, edge_data, get_vertices_data +using DataGraphs: vertex_data_type, vertex_data, edge_data, get_vertices_data using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -52,13 +52,12 @@ end DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph -DataGraphs.has_vertex_data(tn::TensorNetwork, v) = haskey(tn.tensors, v) -DataGraphs.has_edge_data(tn::TensorNetwork, e) = false +DataGraphs.is_vertex_assigned(tn::TensorNetwork, v) = haskey(tn.tensors, v) +DataGraphs.is_edge_assigned(tn::TensorNetwork, e) = false DataGraphs.get_vertex_data(tn::TensorNetwork, v) = tn.tensors[v] DataGraphs.set_vertex_data!(tn::TensorNetwork, val, v) = set!(tn.tensors, v, val) -DataGraphs.unset_vertex_data!(tn::TensorNetwork, val, v) = unset!(tn.tensors, v, val) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) @@ -135,11 +134,30 @@ function Graphs.rem_edge!(tn::TensorNetwork, e) return true end -function GraphsExtensions.similar(type::Type{<:TensorNetwork}) +function GraphsExtensions.similar_graph(type::Type{<:TensorNetwork}) DT = fieldtype(type, :tensors) empty_dict = DT() return TensorNetwork(similar_graph(underlying_graph_type(type)), empty_dict) end +function GraphsExtensions.similar_graph(tn::TensorNetwork, underlying_graph::AbstractGraph) + DT = fieldtype(typeof(tn), :tensors) + empty_dict = DT() + return _TensorNetwork(underlying_graph, empty_dict) +end + +function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subvertices) + return tensornetwork_induced_subgraph(graph, subvertices) +end + +function tensornetwork_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + + subgraph = TensorNetwork(underlying_subgraph) do vertex + return graph[vertex] + end + + return subgraph, vlist +end ## PartitionedGraphs function PartitionedGraphs.quotient_graph(tn::TensorNetwork) @@ -154,7 +172,7 @@ end # DataGraphsPartitionedGraphsExt interface. function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) UG = quotient_graph_type(underlying_graph_type(type)) - VD = Vector{vertex_data_eltype(type)} + VD = Vector{vertex_data_type(type)} V = vertextype(UG) return TensorNetwork{V, VD, UG, Dictionary{V, VD}} end @@ -172,14 +190,18 @@ function PartitionedGraphs.departition( return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) end +NamedGraphs.to_graph_index(::TensorNetwork, vertex::QuotientVertex) = vertex # When getting data according the quotient vertices, take a lazy contraction. -function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices) - data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex))) +function DataGraphs.get_index_data(tn::TensorNetwork, vertex::QuotientVertex) + data = collect(map(v -> tn[v], vertices(tn, vertex))) return mapreduce(lazy, *, data) end +function DataGraphs.is_graph_index_assigned(tn::TensorNetwork, vertex::QuotientVertex) + return isassigned(tn, Vertices(vertices(tn, vertex))) +end function PartitionedGraphs.quotientview(tn::TensorNetwork) qview = QuotientView(underlying_graph(tn)) - tensors = vertex_data(QuotientView(tn)) + tensors = map(qv -> vertex_data(tn)[Indices(qv)], Indices(quotientvertices(tn))) return TensorNetwork(qview, tensors) end diff --git a/test/Project.toml b/test/Project.toml index 564db3f..975c2c1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,8 +29,8 @@ Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.5" ITensorNetworksNext = "0.3" -NamedDimsArrays = "0.14" -NamedGraphs = "0.6.8, 0.7, 0.8" +NamedDimsArrays = "0.13" +NamedGraphs = "0.10" QuadGK = "2.11.2" SafeTestsets = "0.1" Suppressor = "0.2.8" From 202724ca021139bf7fa5d5cd561406dd497cacd4 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 10 Feb 2026 11:57:32 -0500 Subject: [PATCH 31/45] [AlgorithmsInterfaceExtensions] Allowing mapping over a generic iterable when constructing nested algorithms --- .../AlgorithmsInterfaceExtensions.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index a8c814e..3c887b7 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -152,8 +152,8 @@ end abstract type NestedAlgorithm <: Algorithm end -function nested_algorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultNestedAlgorithm(f, nalgorithms; kwargs...) +function nested_algorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(f, iterable; kwargs...) end max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) @@ -211,6 +211,9 @@ function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...) return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) end +function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) +end #============================ FlattenedAlgorithm ==========================================# # Flatten a nested algorithm. From 69542e32ba7d5ad1a4b616a40822dffcd1de4c9c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 11 Feb 2026 11:44:18 -0500 Subject: [PATCH 32/45] Upgrade serial BP to use own `<:Algorithm` structs. --- .../beliefpropagationproblem.jl | 136 +++++++++++------- 1 file changed, 87 insertions(+), 49 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 0d997ee..75023b3 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,8 +1,9 @@ using Graphs: SimpleGraph, vertices, edges, has_edge, AbstractEdge using NamedGraphs: AbstractNamedGraph, position_graph -using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices +using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices, subgraph, boundary_edges using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices using NamedDimsArrays: AbstractNamedDimsArray +using NamedGraphs.PartitionedGraphs: quotientvertices using DataGraphs: edge_data import AlgorithmsInterface as AI @@ -41,55 +42,35 @@ function AI.is_finished!( # maxdiff = 0.0 initially, so skip this the first time. if state.iteration > 0 st.delta = state.iterate.maxdiff + @info "$(state.iteration): $(st.delta)" end return st.delta < c.tol end -struct BeliefPropagationProblem{Network} <: AIE.Problem - network::Network -end +# struct BeliefPropagationProblem{Network} <: AIE.Problem +# network::Network +# end + +struct BeliefPropagationProblem <: AIE.Problem end -@kwdef mutable struct BeliefPropagationState{ - Iterate <: BeliefPropagationCache, - Diffs, - } <: AIE.NonIterativeAlgorithmState +@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: AIE.NonIterativeAlgorithmState iterate::Iterate diffs::Diffs = similar(edge_data(iterate), Float64) maxdiff::Float64 = 0.0 end -function AI.initialize_state( - problem::BeliefPropagationProblem, - algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... - ) - - diffs = iterate.diffs - maxdiff = iterate.maxdiff - - return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) -end - -# This gets called at the start of every sweep. -function AI.initialize_state!( - problem::BeliefPropagationProblem, - algorithm::AIE.NestedAlgorithm, - state::AIE.State, - ) - state.iterate.maxdiff = 0.0 - return state +@kwdef struct BeliefPropagation{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function AIE.set_substate!( - ::BeliefPropagationProblem, - ::AIE.NestedAlgorithm, - state::AIE.State, - substate::BeliefPropagationState - ) - - state.iterate = substate - - return state +function BeliefPropagation(f::Function, niterations::Int; kwargs...) + return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) end abstract type AbstractMessageUpdate <: AIE.NonIterativeAlgorithm end @@ -101,7 +82,7 @@ end function SimpleMessageUpdate( edge; - normalize = false, + normalize = true, contraction_alg = "eager", compute_diff = false, kwargs... @@ -117,6 +98,53 @@ function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) end end +struct BeliefPropagationSweep{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::AI.StopAfterIteration + function BeliefPropagationSweep(; algorithms) + stopping_criterion = AI.StopAfterIteration(length(algorithms)) + return new{eltype(algorithms), typeof(algorithms)}(algorithms, stopping_criterion) + end +end + +BeliefPropagationSweep(f::Function, edges) = BeliefPropagationSweep(; algorithms = f.(edges)) + +function AI.initialize_state( + problem::BeliefPropagationProblem, + update_algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... + ) + + diffs = iterate.diffs + maxdiff = iterate.maxdiff + + return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) +end + +# This gets called at the start of every sweep. +function AI.initialize_state!( + ::BeliefPropagationProblem, + ::BeliefPropagationSweep, + iteration_state::AIE.State, + ) + iteration_state.iterate.maxdiff = 0.0 + return iteration_state +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + sweep_algorithm::BeliefPropagationSweep, + sweep_state::AIE.DefaultState, + noniterative_substate::BeliefPropagationState, + ) + + sweep_state.iterate = noniterative_substate + + return sweep_state +end + struct MessageUpdateProblem{Messages, Factors} <: AIE.Problem messages::Messages factors::Factors @@ -124,7 +152,7 @@ end function AI.solve!( problem::BeliefPropagationProblem, - algorithm::AbstractMessageUpdate, + algorithm::SimpleMessageUpdate, state::BeliefPropagationState; logging_context_prefix = default_logging_context_prefix(problem, algorithm), ) @@ -177,7 +205,7 @@ end function AI.solve!( problem::MessageUpdateProblem, algorithm::SimpleMessageUpdate, - state::AIE.NonIterativeAlgorithmState; + state::AIE.DefaultNonIterativeAlgorithmState; logging_context_prefix = AI.default_logging_context_prefix(problem, algorithm), kwargs... ) @@ -209,24 +237,29 @@ end beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network); kwargs...) function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) - problem = BeliefPropagationProblem(network(cache)) + # problem = BeliefPropagationProblem(network(cache)) + problem = BeliefPropagationProblem() algorithm = select_algorithm(beliefpropagation, cache; kwargs...) # The nested algorithms will wrap and manipulate this object. + base_state = BeliefPropagationState(; iterate = cache) - state = AI.solve(problem, algorithm; iterate = base_state) + state = AI.initialize_state(problem, algorithm; iterate = base_state) + + state = AI.solve!(problem, algorithm, state) return state.iterate.iterate end + function select_algorithm( ::typeof(beliefpropagation), - cache; + cache::AbstractBeliefPropagationCache; edges = forest_cover_edge_sequence(network(cache)), maxiter = is_tree(network(cache)) ? 1 : nothing, - tol = 0.0, + tol = -Inf, kwargs... ) @@ -237,7 +270,7 @@ function select_algorithm( stopping_criterion = AI.StopAfterIteration(maxiter) compute_diff = false - if tol > 0.0 + if tol > -Inf stopping_criterion = stopping_criterion | StopWhenConverged(tol) compute_diff = true end @@ -245,9 +278,14 @@ function select_algorithm( extended_kwargs = extend_columns((; compute_diff, kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, len = maxiter) - return AIE.nested_algorithm(maxiter; stopping_criterion) do repnum - return AIE.nested_algorithm(length(edges)) do edgenum - return SimpleMessageUpdate(edges[edgenum]; edge_kwargs[repnum]...) - end + return BeliefPropagation(maxiter; stopping_criterion) do repnum + return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...) + end +end + +# A single sweep across the given edges. +function beliefpropagation_sweep(cache::BeliefPropagationCache; edges, kwargs...) + return BeliefPropagationSweep(edges) do edge + return SimpleMessageUpdate(edge; kwargs...) end end From 992506900fd225d106a57e03346fd62e6f74bc80 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:19:04 -0500 Subject: [PATCH 33/45] Simplify BP cache to only store factors --- src/abstracttensornetwork.jl | 26 ++-- .../beliefpropagationcache.jl | 131 +++++++++--------- .../beliefpropagationproblem.jl | 81 ++++++----- 3 files changed, 115 insertions(+), 123 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 671ba3a..c4b6fcb 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,25 +1,17 @@ -using Adapt: Adapt, adapt, adapt_structure +using Adapt: Adapt, adapt using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, underlying_graph, - underlying_graph_type, vertex_data, set_vertex_data! +using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, set_vertex_data!, + underlying_graph, underlying_graph_type, vertex_data using Dictionaries: Dictionary -using Graphs: Graphs, AbstractEdge, AbstractGraph, Graph, add_edge!, add_vertex!, - bfs_tree, center, dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices -using LinearAlgebra: LinearAlgebra, factorize +using Graphs: AbstractEdge, AbstractGraph, Graphs, add_edge!, add_vertex!, + dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices +using LinearAlgebra: LinearAlgebra using MacroTools: @capture using NamedDimsArrays: dimnames, inds -using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree +using NamedGraphs: NamedGraph, NamedGraphs, not_implemented using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger -using NamedGraphs.GraphsExtensions: - ⊔, - directed_graph, - incident_edges, - rem_edges!, - rename_vertices, - vertextype, - similar_graph -using SplitApplyCombine: flatten -using NamedGraphs.SimilarType: similar_type +using NamedGraphs.GraphsExtensions: directed_graph, incident_edges, rem_edges!, + similar_graph, vertextype abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 10ab586..2c253e6 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,46 +1,29 @@ -using DataGraphs: - DataGraphs, - AbstractDataGraph, - DataGraph, - get_vertex_data, - get_edge_data, - set_vertex_data!, - set_edge_data!, - vertex_data_type, - edge_data_type, - underlying_graph, - underlying_graph_type, - is_vertex_assigned, - is_edge_assigned -using Dictionaries: Dictionary, set!, delete! -using Graphs: AbstractGraph, is_tree, connected_components, is_directed -using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices, Vertices -using NamedGraphs.GraphsExtensions: default_root_vertex, - forest_cover, - post_order_dfs_edges, - vertextype, - is_path_graph, - undirected_graph -using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges +using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, edge_data_type, + set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, + vertex_data_type +using Dictionaries: Dictionary, delete!, set!, getindices +using Graphs: AbstractGraph, connected_components, is_tree, is_directed using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph -struct BeliefPropagationCache{V, VD, ED, G <: AbstractGraph{V}, N <: AbstractGraph{V}, E} <: - AbstractBeliefPropagationCache{V, VD, ED} +using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices + +struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: AbstractBeliefPropagationCache{V, VD, ED} underlying_graph::G # we only use this for the edges. - network::N + factors::Dictionary{V, VD} messages::Dictionary{E, ED} - function BeliefPropagationCache(network::AbstractGraph, messages::Dictionary) + function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary, messages::Dictionary) + # Ensure the graph is directed, if not make it directed. + digraph = is_directed(graph) ? graph : directed_graph(graph) - V = vertextype(network) - VD = vertex_data_type(network) - N = typeof(network) - ET = keytype(messages) - ED = eltype(messages) + V = keytype(factors) + VD = eltype(factors) - # Construct a directed graph version of the underlying graph of the tensor network. - digraph = directed_graph(underlying_graph(network)) + E = keytype(messages) + ED = eltype(messages) - bpc = new{V, VD, ED, typeof(digraph), N, ET}(digraph, network, messages) + bpc = new{V, VD, ED, E, typeof(digraph)}(digraph, factors, messages) for edge in edges(bpc) get!(() -> default_message(bpc, edge), messages, edge) @@ -49,30 +32,39 @@ struct BeliefPropagationCache{V, VD, ED, G <: AbstractGraph{V}, N <: AbstractGra end end -network(bp_cache) = getfield(bp_cache, :network) - -DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :underlying_graph) +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = bpc.underlying_graph -DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = is_vertex_assigned(network(bpc), vertex) +DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = haskey(bpc.factors, vertex) DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) -DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = get_vertex_data(network(bpc), vertex) +DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = bpc.factors[vertex] DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] -DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set_vertex_data!(network(bpc), val, vertex) +DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set!(bpc.factors, vertex, val) DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) +# These two methods assume `network` behaves llike a tensor network +# (could be e.g. a QuotientView) otherwise how would one know what the factors should be. function BeliefPropagationCache(network::AbstractGraph) - MT = vertex_data_type(typeof(network)) - return BeliefPropagationCache(MT, network) + graph = underlying_graph(network) + return BeliefPropagationCache(graph, copy(vertex_data(network))) end function BeliefPropagationCache(MT::Type, network::AbstractGraph) - dict = Dictionary{edgetype(network), MT}() - return BeliefPropagationCache(network, dict) + graph = underlying_graph(network) + return BeliefPropagationCache(MT, graph, copy(vertex_data(network))) +end + +function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary) + MT = vertex_data_type(typeof(graph)) + return BeliefPropagationCache(MT, graph, factors) +end +function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Dictionary) + messages = Dictionary{edgetype(graph), MT}() + return BeliefPropagationCache(graph, factors, messages) end function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) + return BeliefPropagationCache(copy(bp_cache.underlying_graph), copy(bp_cache.factors), copy(bp_cache.messages)) end # TODO: This needs to go in GraphsExtensions @@ -92,41 +84,50 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo return rv end -function bpcache_induced_subgraph(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) +function induced_subgraph_bpcache(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) - edge_data = Dictionary{edgetype(underlying_subgraph), edge_data_type(typeof(graph))}() + assigned = v -> isassigned(graph, v) + + assigned_subvertices = Iterators.filter(assigned, subvertices) + assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph)) + + factors = getindices(vertex_data(graph), Indices(assigned_subvertices)) + messages = getindices(edge_data(graph), Indices(assigned_subedges)) + + subgraph = BeliefPropagationCache(underlying_subgraph, factors, messages) - subgraph = BeliefPropagationCache(underlying_subgraph, edge_data) - for e in edges(subgraph) - if isassigned(graph, e) - subgraph[e] = graph[e] - end - end return subgraph, vlist end function NamedGraphs.induced_subgraph_from_vertices(graph::BeliefPropagationCache, subvertices) - return bpcache_induced_subgraph(graph, subvertices) + return induced_subgraph_bpcache(graph, subvertices) end ## PartitionedGraphs +# Take a QuotientView of the underlying graph. function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - inds = Indices(parent_graph_indices(QuotientEdges(underlying_graph(bpc)))) - data = map(e -> bpc[QuotientEdge(e)], inds) - return BeliefPropagationCache(QuotientView(network(bpc)), data) + + graph = underlying_graph(bpc) + + quotient_view = QuotientView(graph) + + factors = map(v -> bpc[QuotientVertex(v)], Indices(vertices(quotient_view))) + messages = map(e -> bpc[QuotientEdge(e)], Indices(edges(quotient_view))) + + return BeliefPropagationCache(quotient_view, factors, messages) end function default_message(bpc::BeliefPropagationCache, edge) - return default_message(message_type(bpc), network(bpc), edge) + return default_message(message_type(bpc), bpc[src(edge)], bpc[dst(edge)]) end -function default_message(T::Type, network, edge) - array = ones(Tuple(linkinds(network, edge))) +function default_message(T::Type, src, dst) + array = ones(Tuple(inds(src) ∩ inds(dst))) return convert(T, array) end -function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) - message = default_message(parenttype(T), network, edge) +function default_message(T::Type{<:LazyNamedDimsArray}, src, dst) + message = default_message(parenttype(T), src, dst) return convert(T, lazy(message)) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 75023b3..89c28df 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,10 +1,9 @@ -using Graphs: SimpleGraph, vertices, edges, has_edge, AbstractEdge -using NamedGraphs: AbstractNamedGraph, position_graph -using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices, subgraph, boundary_edges -using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices +using Graphs: AbstractEdge, edges, has_edge, vertices +using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.PartitionedGraphs: quotientvertices using DataGraphs: edge_data +using LinearAlgebra: norm, normalize import AlgorithmsInterface as AI import .AlgorithmsInterfaceExtensions as AIE @@ -42,17 +41,14 @@ function AI.is_finished!( # maxdiff = 0.0 initially, so skip this the first time. if state.iteration > 0 st.delta = state.iterate.maxdiff - @info "$(state.iteration): $(st.delta)" end return st.delta < c.tol end -# struct BeliefPropagationProblem{Network} <: AIE.Problem -# network::Network -# end - -struct BeliefPropagationProblem <: AIE.Problem end +struct BeliefPropagationProblem{Network} <: AIE.Problem + network::Network +end @kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: AIE.NonIterativeAlgorithmState iterate::Iterate @@ -113,8 +109,7 @@ end BeliefPropagationSweep(f::Function, edges) = BeliefPropagationSweep(; algorithms = f.(edges)) function AI.initialize_state( - problem::BeliefPropagationProblem, - update_algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... + ::BeliefPropagationProblem, ::AIE.NonIterativeAlgorithm; iterate, kwargs... ) diffs = iterate.diffs @@ -135,7 +130,7 @@ end function AIE.set_substate!( ::BeliefPropagationProblem, - sweep_algorithm::BeliefPropagationSweep, + ::BeliefPropagationSweep, sweep_state::AIE.DefaultState, noniterative_substate::BeliefPropagationState, ) @@ -145,16 +140,16 @@ function AIE.set_substate!( return sweep_state end -struct MessageUpdateProblem{Messages, Factors} <: AIE.Problem +struct MessageUpdateProblem{Factor, Messages} <: AIE.Problem + factor::Factor messages::Messages - factors::Factors end function AI.solve!( problem::BeliefPropagationProblem, algorithm::SimpleMessageUpdate, state::BeliefPropagationState; - logging_context_prefix = default_logging_context_prefix(problem, algorithm), + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), ) logger = AI.algorithm_logger() @@ -168,8 +163,8 @@ function AI.solve!( new_message = updated_message(algorithm, cache) - if algorithm.compute_diff - diff = message_diff(new_message, cache[edge]) + if !isnothing(algorithm.message_diff_function) + diff = algorithm.message_diff_function(new_message, cache[edge]) if diff > state.maxdiff state.maxdiff = diff @@ -187,7 +182,7 @@ function AI.solve!( return state end -message_diff(m1, m2) = LinearAlgebra.norm(m1 - m2) +default_message_diff_function(m1, m2) = norm(normalize(m1) - normalize(m2)) function updated_message(algorithm, cache) edge = algorithm.edge @@ -195,7 +190,7 @@ function updated_message(algorithm, cache) vertex = src(edge) messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) - update_problem = MessageUpdateProblem(messages, factors(cache, vertex)) + update_problem = MessageUpdateProblem(cache[vertex], messages) message_state = AI.solve(update_problem, algorithm; iterate = message(cache, edge)) @@ -206,13 +201,21 @@ function AI.solve!( problem::MessageUpdateProblem, algorithm::SimpleMessageUpdate, state::AIE.DefaultNonIterativeAlgorithmState; - logging_context_prefix = AI.default_logging_context_prefix(problem, algorithm), + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), kwargs... ) - # TODO: logging... + logger = AI.algorithm_logger() + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) + ) + + state.iterate = contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) - state.iterate = contract_messages(algorithm.contraction_alg, problem.factors, problem.messages) + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreNormalization) + ) if algorithm.normalize # TODO: use `sum` not `norm` @@ -222,28 +225,26 @@ function AI.solve!( end end + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostNormalization) + ) + return state end -contract_messages(alg, factors, messages) = not_implemented() -function contract_messages( - alg, - factors::Vector{<:AbstractArray}, - messages::Vector{<:AbstractArray}, - ) +function contract_messages(alg, factor::AbstractArray, messages::Vector{<:AbstractArray}) + factors = typeof(factor)[factor] return contract_network(vcat(factors, messages); alg) end -beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network); kwargs...) -function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) +beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network), network; kwargs...) +function beliefpropagation(cache::AbstractBeliefPropagationCache, network = nothing; kwargs...) - # problem = BeliefPropagationProblem(network(cache)) - problem = BeliefPropagationProblem() + problem = BeliefPropagationProblem(network) algorithm = select_algorithm(beliefpropagation, cache; kwargs...) # The nested algorithms will wrap and manipulate this object. - base_state = BeliefPropagationState(; iterate = cache) state = AI.initialize_state(problem, algorithm; iterate = base_state) @@ -253,13 +254,13 @@ function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) return state.iterate.iterate end - function select_algorithm( ::typeof(beliefpropagation), cache::AbstractBeliefPropagationCache; - edges = forest_cover_edge_sequence(network(cache)), - maxiter = is_tree(network(cache)) ? 1 : nothing, + edges = forest_cover_edge_sequence(cache), + maxiter = is_tree(cache) ? 1 : nothing, tol = -Inf, + message_diff_function = tol > -Inf ? (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) : nothing, kwargs... ) @@ -268,14 +269,12 @@ function select_algorithm( end stopping_criterion = AI.StopAfterIteration(maxiter) - compute_diff = false if tol > -Inf stopping_criterion = stopping_criterion | StopWhenConverged(tol) - compute_diff = true end - extended_kwargs = extend_columns((; compute_diff, kwargs...), maxiter) + extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, len = maxiter) return BeliefPropagation(maxiter; stopping_criterion) do repnum @@ -284,7 +283,7 @@ function select_algorithm( end # A single sweep across the given edges. -function beliefpropagation_sweep(cache::BeliefPropagationCache; edges, kwargs...) +function beliefpropagation_sweep(::BeliefPropagationCache; edges, kwargs...) return BeliefPropagationSweep(edges) do edge return SimpleMessageUpdate(edge; kwargs...) end From 292f2fa10be8626746f87148c95ea0fb0ba17ae8 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:28:23 -0500 Subject: [PATCH 34/45] Upgrade to DataGraphs v0.3.1 and NamedGraphs v0.10 --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index efd1d3c..c7133ff 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ Adapt = "4.3" AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" -DataGraphs = "0.2.7" +DataGraphs = "0.3.1" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" FunctionImplementations = "0.4" @@ -47,7 +47,7 @@ Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.14.2" -NamedGraphs = "0.6.9, 0.7, 0.8" +NamedGraphs = "0.10" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" TensorOperations = "5.3.1" From 9d937aa366d7afb54ab3e918a7039606de148112 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:38:37 -0500 Subject: [PATCH 35/45] Fix compat --- Project.toml | 4 ++-- test/Project.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index c7133ff..1da8abe 100644 --- a/Project.toml +++ b/Project.toml @@ -42,11 +42,11 @@ Combinatorics = "1" DataGraphs = "0.3.1" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" -FunctionImplementations = "0.4" +FunctionImplementations = "0.4.1" Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.14.2" +NamedDimsArrays = "0.14.3" NamedGraphs = "0.10" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" diff --git a/test/Project.toml b/test/Project.toml index 975c2c1..cf048b7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,7 +29,7 @@ Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.5" ITensorNetworksNext = "0.3" -NamedDimsArrays = "0.13" +NamedDimsArrays = "0.14" NamedGraphs = "0.10" QuadGK = "2.11.2" SafeTestsets = "0.1" From 5432fe28bb172ff61bb8a191b5de4604da06ef53 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 18:08:12 -0500 Subject: [PATCH 36/45] Fix broken merge Fix broken merge --- .../beliefpropagationproblem.jl | 4 +- src/contract_network.jl | 54 +++++++++---------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 89c28df..c127655 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -79,7 +79,7 @@ end function SimpleMessageUpdate( edge; normalize = true, - contraction_alg = "eager", + contraction_alg = "exact", compute_diff = false, kwargs... ) @@ -275,7 +275,7 @@ function select_algorithm( end extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter) - edge_kwargs = rows(extended_kwargs, len = maxiter) + edge_kwargs = rows(extended_kwargs, maxiter) return BeliefPropagation(maxiter; stopping_criterion) do repnum return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...) diff --git a/src/contract_network.jl b/src/contract_network.jl index 4fda3a7..a8c3fc7 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,11 +1,27 @@ using BackendSelection: @Algorithm_str, Algorithm using Base.Broadcast: materialize -using NamedDimsArrays: inds -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, Mul, lazy, optimize_evaluation_order, +using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, substitute, symnameddims -function contract_network(tn; alg = default_kwargs(contract_network, tn).alg) - return contract_network(alg, tn) +# This is related to `MatrixAlgebraKit.select_algorithm`. +# TODO: Define this in BackendSelection.jl. +backend_value(::Algorithm{alg}) where {alg} = alg +using BackendSelection: parameters +function merge_parameters(alg::Algorithm; kwargs...) + return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...) +end +to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...) +to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...) + +# `contract_network` +function contract_network(alg::Algorithm, tn) + return throw(ArgumentError("`contract_network` algorithm `$(alg)` not implemented.")) +end +function default_kwargs(::typeof(contract_network), tn) + return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"())) +end +function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...) + return contract_network(to_algorithm(alg; kwargs...), tn) end # `contract_network(::Algorithm"exact", ...)` @@ -34,24 +50,12 @@ end # `contraction_order` function contraction_order end -default_kwargs(::typeof(contraction_order), tensors) = (; order = "eager") - -function contraction_expression(tensors; order = default_kwargs(contraction_order, tensors).order) - order = contraction_order(order, tensors) - - # Contraction order may or may not have indices attached, canonicalize the format - # by attaching indices. - subs = Dict(symnameddims(i) => symnameddims(i, tensors[i]) for i in keys(tensors)) - - return substitute(order, subs) -end - -contraction_order(order, tensors) = order -function contraction_order(tensors; order = default_kwargs(contraction_order, tensors).order) - return contraction_order(Algorithm(order), tensors) +default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"()) +function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...) + return contraction_order(to_algorithm(alg; kwargs...), tn) end # Convert the tensor network to a flat symbolic multiplication expression. -function contraction_order(::Algorithm"flat", tensors) +function contraction_order(alg::Algorithm"flat", tn) # Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`. syms = vec([symnameddims(i, Tuple(axes(tn[i]))) for i in keys(tn)]) return lazy(Mul(syms)) @@ -59,11 +63,7 @@ end function contraction_order(alg::Algorithm"left_associative", tn) return prod(i -> symnameddims(i, Tuple(axes(tn[i]))), keys(tn)) end - -function contraction_order( - order_algorithm::Algorithm, - tensors, - ) - order = contraction_order(tensors; order = "flat") - return optimize_evaluation_order(order; alg = order_algorithm) +function contraction_order(alg::Algorithm, tn) + s = contraction_order(Algorithm"flat"(), tn) + return optimize_evaluation_order(s; alg) end From c916c84c19502294b77aeca61165b778ddbd66c8 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 19 Feb 2026 17:44:59 -0500 Subject: [PATCH 37/45] Bug fix; upgrade tests --- .../beliefpropagationproblem.jl | 2 +- test/Project.toml | 1 + test/test_contract_network.jl | 16 +++++++++------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index c127655..0312843 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -232,7 +232,7 @@ function AI.solve!( return state end -function contract_messages(alg, factor::AbstractArray, messages::Vector{<:AbstractArray}) +function contract_messages(alg, factor::AbstractArray, messages) factors = typeof(factor)[factor] return contract_network(vcat(factors, messages); alg) end diff --git a/test/Project.toml b/test/Project.toml index cf048b7..8b1072a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index fc863f6..35b2275 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -5,8 +5,11 @@ using ITensorBase: Index using ITensorNetworksNext: TensorNetwork, linkinds, siteinds, contract_network using TensorOperations: TensorOperations using Test: @test, @testset +using BackendSelection: @Algorithm_str, Algorithm @testset "contract_network" begin + orderalg = alg -> Algorithm"exact"(; order_alg = Algorithm(alg)) + @testset "Contract Vectors of ITensors" begin i, j, k = Index(2), Index(2), Index(5) A = [1.0 1.0; 0.5 1.0][i, j] @@ -14,10 +17,9 @@ using Test: @test, @testset C = [5.0, 1.0][j] D = [-2.0, 3.0, 4.0, 5.0, 1.0][k] - ABCD_1 = contract_network([A, B, C, D]; alg = "left_associative") - ABCD_2 = contract_network([A, B, C, D]; alg = "eager") - ABCD_3 = contract_network([A, B, C, D]; alg = "optimal") - + ABCD_1 = contract_network([A, B, C, D]; alg = orderalg("left_associative")) + ABCD_2 = contract_network([A, B, C, D]; alg = orderalg("eager")) + ABCD_3 = contract_network([A, B, C, D]; alg = orderalg("optimal")) @test ABCD_1 == ABCD_2 == ABCD_3 end @@ -31,9 +33,9 @@ using Test: @test, @testset return randn(Tuple(is)) end - z1 = contract_network(tn; alg = "left_associative")[] - z2 = contract_network(tn; alg = "eager")[] - z3 = contract_network(tn; alg = "optimal")[] + z1 = contract_network(tn; alg = orderalg("left_associative"))[] + z2 = contract_network(tn; alg = orderalg("eager"))[] + z3 = contract_network(tn; alg = orderalg("optimal"))[] @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64) @test abs(z1 - z3) / abs(z1) <= 1.0e3 * eps(Float64) From 4a511a159d298ef466108b7af250b754c6d0dc35 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 24 Feb 2026 09:41:03 -0500 Subject: [PATCH 38/45] Add 2D TN test --- test/Project.toml | 1 + test/test_beliefpropagation.jl | 64 +++++++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 8b1072a..50a58c5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 8c7829b..8a817b2 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,22 +1,48 @@ -using Dictionaries: Dictionary -using ITensorBase: Index +using Dictionaries: Dictionary, set! +using ITensorBase: Index, ITensor, prime, noprime using ITensorNetworksNext: BeliefPropagationCache, ITensorNetworksNext, TensorNetwork, - adapt_messages, - default_message, - default_messages, - edge_scalars, - factors, - messages, - partitionfunction, - setmessages! -using Graphs: edges, vertices + partitionfunction +using DiagonalArrays: δ +using Graphs: src, dst, edges, vertices, AbstractGraph using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree -using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype using Test: @test, @testset +using LinearAlgebra: LinearAlgebra +using NamedDimsArrays: name, inds +function ising_tensornetwork(g::AbstractGraph, β::Real; h = 0.0) + links = Dictionary(edges(g), [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)]) + links = merge(links, Dictionary(reverse.(edges(g)), [links[e] for e in edges(g)])) + + # symmetric sqrt of Boltzmann matrix W = exp(β σσ') + sqrt_Ws = Dictionary() + for e in edges(g) + W = [ exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h)) ] + + F = LinearAlgebra.svd(W) + U, S, V = F.U, F.S, F.Vt + @assert U * LinearAlgebra.diagm(S) * V ≈ W + id = [1.0 0.0; 0.0 1.0] + set!(sqrt_Ws, e, id) + set!(sqrt_Ws, reverse(e), U * LinearAlgebra.diagm(S) * V) + end + ts = Dictionary{vertextype(g), ITensor}() + for v in vertices(g) + es = incident_edges(g, v; dir = :in) + #t = ITensor(1.0, physical_inds[v]...) * delta([links[e] for e in es]) + t = δ(Float64, Tuple([links[e] for e in es])) + for e in es + t_prime = ITensor(sqrt_Ws[e], (name(links[e]), name(prime(links[e])))) * t + newinds = noprime.(inds(t_prime)) + t = ITensor(parent(t_prime), name.(newinds)) + end + set!(ts, v, t) + end + return TensorNetwork(g, ts) +end @testset "BeliefPropagation" begin #Chain of tensors @@ -49,5 +75,17 @@ using Test: @test, @testset bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test z_bp ≈ z_exact atol = 1.0e-12 + @test z_bp ≈ z_exact atol = 1.0e-10 + + #Square lattice Ising model + dims = (3, 3) + g = named_grid(dims) + tn = ising_tensornetwork(g, 0.05, h = 0.5) + bpc = ITensorNetworksNext.BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 50, tol = 1.0e-10) + + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact rtol = 1.0e-4 + end From 5b97af3a6b5a219c09b6d7db9e40022ab398bb51 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 24 Feb 2026 09:47:03 -0500 Subject: [PATCH 39/45] Formatting --- docs/make.jl | 9 +-- docs/make_index.jl | 4 +- docs/make_readme.jl | 4 +- .../ITensorNetworksNextTensorOperationsExt.jl | 4 +- .../AlgorithmsInterfaceExtensions.jl | 41 ++++-------- src/LazyNamedDimsArrays/symbolicarray.jl | 8 ++- src/TensorNetworkGenerators/delta_network.jl | 2 +- src/TensorNetworkGenerators/ising_network.jl | 2 +- src/abstracttensornetwork.jl | 16 ++--- .../abstractbeliefpropagationcache.jl | 13 ++-- .../beliefpropagationcache.jl | 58 +++++++++++----- .../beliefpropagationproblem.jl | 66 +++++++++++-------- src/contract_network.jl | 4 +- src/sweeping/eigenproblem.jl | 2 +- src/tensornetwork.jl | 47 ++++++------- test/runtests.jl | 15 +++-- test/test_algorithmsinterfaceextensions.jl | 14 ++-- test/test_aqua.jl | 2 +- test/test_basics.jl | 2 +- test/test_beliefpropagation.jl | 25 ++++--- test/test_contract_network.jl | 6 +- test/test_dmrg.jl | 4 +- test/test_lazynameddimsarrays.jl | 8 +-- test/test_tensornetworkgenerators.jl | 2 +- 24 files changed, 195 insertions(+), 163 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 1b29518..c4f46f3 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,5 @@ -using ITensorNetworksNext: ITensorNetworksNext using Documenter: Documenter, DocMeta, deploydocs, makedocs +using ITensorNetworksNext: ITensorNetworksNext DocMeta.setdocmeta!( ITensorNetworksNext, :DocTestSetup, :(using ITensorNetworksNext); recursive = true @@ -14,11 +14,12 @@ makedocs(; format = Documenter.HTML(; canonical = "https://itensor.github.io/ITensorNetworksNext.jl", edit_link = "main", - assets = ["assets/favicon.ico", "assets/extras.css"], + assets = ["assets/favicon.ico", "assets/extras.css"] ), - pages = ["Home" => "index.md", "Reference" => "reference.md"], + pages = ["Home" => "index.md", "Reference" => "reference.md"] ) deploydocs(; - repo = "github.com/ITensor/ITensorNetworksNext.jl", devbranch = "main", push_preview = true + repo = "github.com/ITensor/ITensorNetworksNext.jl", devbranch = "main", + push_preview = true ) diff --git a/docs/make_index.jl b/docs/make_index.jl index 038bc87..af08861 100644 --- a/docs/make_index.jl +++ b/docs/make_index.jl @@ -1,5 +1,5 @@ -using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext +using Literate: Literate function ccq_logo(content) include_ccq_logo = """ @@ -17,5 +17,5 @@ Literate.markdown( joinpath(pkgdir(ITensorNetworksNext), "docs", "src"); flavor = Literate.DocumenterFlavor(), name = "index", - postprocess = ccq_logo, + postprocess = ccq_logo ) diff --git a/docs/make_readme.jl b/docs/make_readme.jl index 088dc58..52d0dbb 100644 --- a/docs/make_readme.jl +++ b/docs/make_readme.jl @@ -1,5 +1,5 @@ -using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext +using Literate: Literate function ccq_logo(content) include_ccq_logo = """ @@ -17,5 +17,5 @@ Literate.markdown( joinpath(pkgdir(ITensorNetworksNext)); flavor = Literate.CommonMarkFlavor(), name = "README", - postprocess = ccq_logo, + postprocess = ccq_logo ) diff --git a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl index 4766ee6..972b11e 100644 --- a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl +++ b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl @@ -1,9 +1,9 @@ module ITensorNetworksNextTensorOperationsExt using BackendSelection: @Algorithm_str, Algorithm -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, ismul, symnameddims, - substitute using ITensorNetworksNext.LazyNamedDimsArrays.TermInterface: arguments +using ITensorNetworksNext.LazyNamedDimsArrays: + LazyNamedDimsArrays, ismul, substitute, symnameddims using NamedDimsArrays: inds using TensorOperations: TensorOperations, optimaltree diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index 3c887b7..69a4a97 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -1,8 +1,6 @@ module AlgorithmsInterfaceExtensions -import AlgorithmsInterface as AI - -#========================== Patches for AlgorithmsInterface.jl ============================# +import AlgorithmsInterface as AI #========================== Patches for AlgorithmsInterface.jl ============================# abstract type Problem <: AI.Problem end abstract type Algorithm <: AI.Algorithm end @@ -28,9 +26,7 @@ function AI.initialize_state( problem, algorithm, algorithm.stopping_criterion ) return DefaultState(; stopping_criterion_state, kwargs...) -end - -#============================ DefaultState ================================================# +end #============================ DefaultState ================================================# @kwdef mutable struct DefaultState{ Iterate, StoppingCriterionState <: AI.StoppingCriterionState, @@ -38,16 +34,12 @@ end iterate::Iterate iteration::Int = 0 stopping_criterion_state::StoppingCriterionState -end - -#============================ increment! ==================================================# +end #============================ increment! ==================================================# # Custom version of `increment!` that also takes the problem and algorithm as arguments. function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) return AI.increment!(state) -end - -#============================ solve! ======================================================# +end #============================ solve! ======================================================# # Custom version of `solve!` that allows specifying the logger and also overloads # `increment!` on the problem and algorithm. @@ -58,13 +50,13 @@ default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_) function default_logging_context_prefix(problem::Problem, algorithm::Algorithm) return Symbol( default_logging_context_prefix(problem), - default_logging_context_prefix(algorithm), + default_logging_context_prefix(algorithm) ) end function AI.solve!( problem::Problem, algorithm::Algorithm, state::State; logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs..., + kwargs... ) logger = AI.algorithm_logger() @@ -97,13 +89,11 @@ end function AI.solve( problem::Problem, algorithm::Algorithm; logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs..., + kwargs... ) state = AI.initialize_state(problem, algorithm; kwargs...) return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...) -end - -#============================ AlgorithmIterator ===========================================# +end #============================ AlgorithmIterator ===========================================# abstract type AlgorithmIterator end @@ -136,9 +126,7 @@ struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator problem::Problem algorithm::Algorithm state::State -end - -#============================ with_algorithmlogger ========================================# +end #============================ with_algorithmlogger ========================================# # Allow passing functions, not just CallbackActions. @inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) @@ -146,9 +134,7 @@ end end @inline function with_algorithmlogger(f, args::Pair{Symbol}...) return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) -end - -#============================ NestedAlgorithm =============================================# +end #============================ NestedAlgorithm =============================================# abstract type NestedAlgorithm <: Algorithm end @@ -213,8 +199,7 @@ end function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) return DefaultNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) -end -#============================ FlattenedAlgorithm ==========================================# +end #============================ FlattenedAlgorithm ==========================================# # Flatten a nested algorithm. abstract type FlattenedAlgorithm <: Algorithm end @@ -284,9 +269,7 @@ end parent_iteration::Int = 1 child_iteration::Int = 0 stopping_criterion_state::StoppingCriterionState -end - -#============================ NonIterativeAlgorithm =======================================# +end #============================ NonIterativeAlgorithm =======================================# # Algorithm that only performs a single step. abstract type NonIterativeAlgorithm <: Algorithm end diff --git a/src/LazyNamedDimsArrays/symbolicarray.jl b/src/LazyNamedDimsArrays/symbolicarray.jl index a0922fd..e3ff4d4 100644 --- a/src/LazyNamedDimsArrays/symbolicarray.jl +++ b/src/LazyNamedDimsArrays/symbolicarray.jl @@ -1,8 +1,12 @@ # TODO: Allow dynamic/unknown number of dimensions by supporting vector axes. -struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} +struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: + AbstractArray{T, N} name::Name axes::Axes - function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} + function SymbolicArray{T}( + name, + ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} + ) where {T} N = length(ax) return new{T, N, typeof(name), typeof(ax)}(name, ax) end diff --git a/src/TensorNetworkGenerators/delta_network.jl b/src/TensorNetworkGenerators/delta_network.jl index 8b28def..e6a453c 100644 --- a/src/TensorNetworkGenerators/delta_network.jl +++ b/src/TensorNetworkGenerators/delta_network.jl @@ -1,6 +1,6 @@ +using ..ITensorNetworksNext: TensorNetwork using DiagonalArrays: δ using Graphs: AbstractGraph -using ..ITensorNetworksNext: TensorNetwork using NamedGraphs.GraphsExtensions: incident_edges """ diff --git a/src/TensorNetworkGenerators/ising_network.jl b/src/TensorNetworkGenerators/ising_network.jl index 1f2fa31..e37551c 100644 --- a/src/TensorNetworkGenerators/ising_network.jl +++ b/src/TensorNetworkGenerators/ising_network.jl @@ -1,6 +1,6 @@ +using ..ITensorNetworksNext: @preserve_graph using DiagonalArrays: DiagonalArray using Graphs: degree, dst, edges, src -using ..ITensorNetworksNext: @preserve_graph using LinearAlgebra: Diagonal, eigen using NamedDimsArrays: apply, denamed, name, operator, randname using NamedGraphs.GraphsExtensions: vertextype diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index c4b6fcb..7fca799 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,17 +1,17 @@ using Adapt: Adapt, adapt using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, set_vertex_data!, +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data using Dictionaries: Dictionary -using Graphs: AbstractEdge, AbstractGraph, Graphs, add_edge!, add_vertex!, - dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices +using Graphs: Graphs, AbstractEdge, AbstractGraph, add_edge!, add_vertex!, dst, edges, + edgetype, ne, neighbors, nv, rem_edge!, src, vertices using LinearAlgebra: LinearAlgebra using MacroTools: @capture using NamedDimsArrays: dimnames, inds -using NamedGraphs: NamedGraph, NamedGraphs, not_implemented +using NamedGraphs.GraphsExtensions: + directed_graph, incident_edges, rem_edges!, similar_graph, vertextype using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger -using NamedGraphs.GraphsExtensions: directed_graph, incident_edges, rem_edges!, - similar_graph, vertextype +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end @@ -125,7 +125,7 @@ is_assignment_expr(expr) = false macro preserve_graph(expr) if !is_setindex!_expr(expr) error( - "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)", + "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)" ) end @capture(expr, array_[indices__] = value_) @@ -207,7 +207,7 @@ Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() function Base.setindex!( tn::AbstractTensorNetwork, value, - edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, + edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger} ) return not_implemented() end diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index b77fb4e..33f185b 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,7 +1,7 @@ -using Graphs: AbstractGraph, AbstractEdge -using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_type +using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data +using Graphs: AbstractEdge, AbstractGraph using NamedGraphs.GraphsExtensions: boundary_edges -using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] @@ -63,7 +63,6 @@ function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) end function region_scalar(bp_cache::AbstractGraph, vertex) - messages = incoming_messages(bp_cache, vertex) state = factors(bp_cache, vertex) @@ -78,7 +77,10 @@ function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars(bp_cache::AbstractGraph, edges = edges(undirected_graph(underlying_graph(bp_cache)))) +function edge_scalars( + bp_cache::AbstractGraph, + edges = edges(undirected_graph(underlying_graph(bp_cache))) + ) return map(e -> region_scalar(bp_cache, e), edges) end @@ -123,7 +125,6 @@ message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED function free_energy(bp_cache::AbstractBeliefPropagationCache) - numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) if any(t -> real(t) < 0, numerator_terms) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 2c253e6..5d1a31c 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,19 +1,23 @@ -using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, edge_data_type, - set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, - vertex_data_type -using Dictionaries: Dictionary, delete!, set!, getindices -using Graphs: AbstractGraph, connected_components, is_tree, is_directed +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, + set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, vertex_data_type +using Dictionaries: Dictionary, delete!, getindices, set! +using Graphs: AbstractGraph, connected_components, is_directed, is_tree using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.GraphsExtensions: + default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph - using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices -struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: AbstractBeliefPropagationCache{V, VD, ED} +struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: + AbstractBeliefPropagationCache{V, VD, ED} underlying_graph::G # we only use this for the edges. factors::Dictionary{V, VD} messages::Dictionary{E, ED} - function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary, messages::Dictionary) + function BeliefPropagationCache( + graph::AbstractGraph, + factors::Dictionary, + messages::Dictionary + ) # Ensure the graph is directed, if not make it directed. digraph = is_directed(graph) ? graph : directed_graph(graph) @@ -34,14 +38,22 @@ end DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = bpc.underlying_graph -DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = haskey(bpc.factors, vertex) +function DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) + return haskey(bpc.factors, vertex) +end DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = bpc.factors[vertex] -DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] +function DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) + return bpc.messages[edge] +end -DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set!(bpc.factors, vertex, val) -DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) +function DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) + return set!(bpc.factors, vertex, val) +end +function DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) + return set!(bpc.messages, edge, val) +end # These two methods assume `network` behaves llike a tensor network # (could be e.g. a QuotientView) otherwise how would one know what the factors should be. @@ -64,7 +76,11 @@ function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Diction end function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(bp_cache.underlying_graph), copy(bp_cache.factors), copy(bp_cache.messages)) + return BeliefPropagationCache( + copy(bp_cache.underlying_graph), + copy(bp_cache.factors), + copy(bp_cache.messages) + ) end # TODO: This needs to go in GraphsExtensions @@ -85,7 +101,8 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo end function induced_subgraph_bpcache(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + underlying_subgraph, vlist = + Graphs.induced_subgraph(underlying_graph(graph), subvertices) assigned = v -> isassigned(graph, v) @@ -100,7 +117,10 @@ function induced_subgraph_bpcache(graph, subvertices) return subgraph, vlist end -function NamedGraphs.induced_subgraph_from_vertices(graph::BeliefPropagationCache, subvertices) +function NamedGraphs.induced_subgraph_from_vertices( + graph::BeliefPropagationCache, + subvertices + ) return induced_subgraph_bpcache(graph, subvertices) end @@ -108,7 +128,6 @@ end # Take a QuotientView of the underlying graph. function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - graph = underlying_graph(bpc) quotient_view = QuotientView(graph) @@ -137,6 +156,9 @@ function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientV data = collect(map(v -> tn[v], vertices(tn, vertex))) return mapreduce(lazy, *, data) end -function DataGraphs.is_graph_index_assigned(tn::BeliefPropagationCache, vertex::QuotientVertex) +function DataGraphs.is_graph_index_assigned( + tn::BeliefPropagationCache, + vertex::QuotientVertex + ) return isassigned(tn, Vertices(vertices(tn, vertex))) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 0312843..1a62792 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,12 +1,11 @@ +import .AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI +using DataGraphs: edge_data using Graphs: AbstractEdge, edges, has_edge, vertices -using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph +using LinearAlgebra: norm, normalize using NamedDimsArrays: AbstractNamedDimsArray +using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices -using DataGraphs: edge_data -using LinearAlgebra: norm, normalize - -import AlgorithmsInterface as AI -import .AlgorithmsInterfaceExtensions as AIE @kwdef struct StopWhenConverged <: AI.StoppingCriterion tol::Float64 = 0.0 @@ -24,7 +23,7 @@ function AI.initialize_state!( ::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged, - st::StopWhenConvergedState, + st::StopWhenConvergedState ) st.delta = Inf return st @@ -35,7 +34,7 @@ function AI.is_finished!( ::AIE.Algorithm, state::AIE.State, c::StopWhenConverged, - st::StopWhenConvergedState, + st::StopWhenConvergedState ) # maxdiff = 0.0 initially, so skip this the first time. @@ -50,7 +49,8 @@ struct BeliefPropagationProblem{Network} <: AIE.Problem network::Network end -@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: AIE.NonIterativeAlgorithmState +@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: + AIE.NonIterativeAlgorithmState iterate::Iterate diffs::Diffs = similar(edge_data(iterate), Float64) maxdiff::Float64 = 0.0 @@ -83,7 +83,10 @@ function SimpleMessageUpdate( compute_diff = false, kwargs... ) - return SimpleMessageUpdate(edge, (; normalize, contraction_alg, compute_diff, kwargs...)) + return SimpleMessageUpdate( + edge, + (; normalize, contraction_alg, compute_diff, kwargs...) + ) end function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) @@ -106,12 +109,13 @@ struct BeliefPropagationSweep{ end end -BeliefPropagationSweep(f::Function, edges) = BeliefPropagationSweep(; algorithms = f.(edges)) +function BeliefPropagationSweep(f::Function, edges) + return BeliefPropagationSweep(; algorithms = f.(edges)) +end function AI.initialize_state( ::BeliefPropagationProblem, ::AIE.NonIterativeAlgorithm; iterate, kwargs... ) - diffs = iterate.diffs maxdiff = iterate.maxdiff @@ -122,7 +126,7 @@ end function AI.initialize_state!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, - iteration_state::AIE.State, + iteration_state::AIE.State ) iteration_state.iterate.maxdiff = 0.0 return iteration_state @@ -132,9 +136,8 @@ function AIE.set_substate!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, sweep_state::AIE.DefaultState, - noniterative_substate::BeliefPropagationState, + noniterative_substate::BeliefPropagationState ) - sweep_state.iterate = noniterative_substate return sweep_state @@ -149,9 +152,8 @@ function AI.solve!( problem::BeliefPropagationProblem, algorithm::SimpleMessageUpdate, state::BeliefPropagationState; - logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm) ) - logger = AI.algorithm_logger() cache = state.iterate @@ -204,17 +206,20 @@ function AI.solve!( logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), kwargs... ) - logger = AI.algorithm_logger() AI.emit_message( logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) ) - state.iterate = contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) + state.iterate = + contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreNormalization) + logger, problem, algorithm, state, Symbol( + logging_context_prefix, + :PreNormalization + ) ) if algorithm.normalize @@ -226,7 +231,8 @@ function AI.solve!( end AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostNormalization) + logger, problem, algorithm, state, + Symbol(logging_context_prefix, :PostNormalization) ) return state @@ -237,9 +243,14 @@ function contract_messages(alg, factor::AbstractArray, messages) return contract_network(vcat(factors, messages); alg) end -beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network), network; kwargs...) -function beliefpropagation(cache::AbstractBeliefPropagationCache, network = nothing; kwargs...) - +function beliefpropagation(network; kwargs...) + return beliefpropagation(BeliefPropagationCache(network), network; kwargs...) +end +function beliefpropagation( + cache::AbstractBeliefPropagationCache, + network = nothing; + kwargs... + ) problem = BeliefPropagationProblem(network) algorithm = select_algorithm(beliefpropagation, cache; kwargs...) @@ -260,10 +271,13 @@ function select_algorithm( edges = forest_cover_edge_sequence(cache), maxiter = is_tree(cache) ? 1 : nothing, tol = -Inf, - message_diff_function = tol > -Inf ? (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) : nothing, + message_diff_function = if tol > -Inf + (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) + else + nothing + end, kwargs... ) - if isnothing(maxiter) throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) end diff --git a/src/contract_network.jl b/src/contract_network.jl index a8c3fc7..9db4c32 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,7 +1,7 @@ using BackendSelection: @Algorithm_str, Algorithm using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, - substitute, symnameddims +using ITensorNetworksNext.LazyNamedDimsArrays: + Mul, lazy, optimize_evaluation_order, substitute, symnameddims # This is related to `MatrixAlgebraKit.select_algorithm`. # TODO: Define this in BackendSelection.jl. diff --git a/src/sweeping/eigenproblem.jl b/src/sweeping/eigenproblem.jl index 36978b2..8fefbd0 100644 --- a/src/sweeping/eigenproblem.jl +++ b/src/sweeping/eigenproblem.jl @@ -1,5 +1,5 @@ -import AlgorithmsInterface as AI import .AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI function dmrg(operator, algorithm, state) problem = EigenProblem(operator) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 0d30970..a371373 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,25 +1,19 @@ +using .LazyNamedDimsArrays: Mul, lazy using Combinatorics: combinations -using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph +using DataGraphs.DataGraphsPartitionedGraphsExt +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph, edge_data, get_vertices_data, + vertex_data, vertex_data_type using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset! -using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! +using Graphs: AbstractSimpleGraph, rem_edge!, rem_vertex! using NamedDimsArrays: AbstractNamedDimsArray, dimnames -using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype, Vertices, parent_graph_indices -using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype -using NamedGraphs.PartitionedGraphs: - AbstractPartitionedGraph, - PartitionedGraphs, - departition, - partitioned_vertices, - partitionedgraph, - quotient_graph, - quotient_graph_type, - QuotientVertex, - QuotientVertices, - QuotientVertexVertices, +using NamedGraphs.GraphsExtensions: + GraphsExtensions, arrange_edge, arranged_edges, vertextype +using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraphs, + QuotientVertex, QuotientVertexVertices, QuotientVertices, departition, + partitioned_vertices, partitionedgraph, quotient_graph, quotient_graph_type, quotientvertices -using .LazyNamedDimsArrays: lazy, Mul -using DataGraphs: vertex_data_type, vertex_data, edge_data, get_vertices_data -using DataGraphs.DataGraphsPartitionedGraphsExt +using NamedGraphs: + NamedGraphs, NamedEdge, NamedGraph, Vertices, parent_graph_indices, vertextype function _TensorNetwork end @@ -44,7 +38,9 @@ function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) return tn end -function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} +function TensorNetwork{V, VD, UG, Tensors}( + graph::UG + ) where {V, VD, UG <: AbstractGraph{V}, Tensors} return _TensorNetwork(graph, Tensors()) end @@ -121,14 +117,20 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) -Graphs.connected_components(tn::TensorNetwork) = Graphs.connected_components(underlying_graph(tn)) +function Graphs.connected_components(tn::TensorNetwork) + return Graphs.connected_components(underlying_graph(tn)) +end function Graphs.rem_edge!(tn::TensorNetwork, e) if !has_edge(underlying_graph(tn), e) return false end if !isempty(linkinds(tn, e)) - throw(ArgumentError("cannot remove edge $e due to tensor indices existing on this edge.")) + throw( + ArgumentError( + "cannot remove edge $e due to tensor indices existing on this edge." + ) + ) end rem_edge!(underlying_graph(tn), e) return true @@ -150,7 +152,8 @@ function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subver end function tensornetwork_induced_subgraph(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + underlying_subgraph, vlist = + Graphs.induced_subgraph(underlying_graph(graph), subvertices) subgraph = TensorNetwork(underlying_subgraph) do vertex return graph[vertex] diff --git a/test/runtests.jl b/test/runtests.jl index 0008050..16689fa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,14 +10,19 @@ const GROUP = uppercase( get(ENV, "GROUP", "ALL") else only(match(pat, ARGS[arg_id]).captures) - end, + end ) -"match files of the form `test_*.jl`, but exclude `*setup*.jl`" +""" +match files of the form `test_*.jl`, but exclude `*setup*.jl` +""" function istestfile(fn) - return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") + return endswith(fn, ".jl") && startswith(basename(fn), "test_") && + !contains(fn, "setup") end -"match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`" +""" +match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl` +""" function isexamplefile(fn) return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") end @@ -57,7 +62,7 @@ end :macrocall, GlobalRef(Suppressor, Symbol("@suppress")), LineNumberNode(@__LINE__, @__FILE__), - :(include($filename)), + :(include($filename)) ) ) end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 8e0665c..44e6a09 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -164,7 +164,7 @@ end # Test with CallbackAction (wrapped functions) state = AIE.with_algorithmlogger( :TestProblem_TestAlgorithm_PreStep => callback1, - :TestProblem_TestAlgorithm_PostStep => callback2, + :TestProblem_TestAlgorithm_PostStep => callback2 ) do return AI.solve(problem, algorithm; iterate = [0.0]) end @@ -227,7 +227,7 @@ end ) state = AIE.DefaultState(; iterate = [0.0], - stopping_criterion_state = stopping_criterion_state, + stopping_criterion_state = stopping_criterion_state ) # Test progression through iterations @@ -253,7 +253,7 @@ end state = AIE.DefaultState(; iterate = [5.0, 10.0], iteration = 1, - stopping_criterion_state, + stopping_criterion_state ) subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) @@ -264,7 +264,7 @@ end # Test set_substate! new_substate = AIE.DefaultState(; iterate = [100.0, 200.0], - substate.stopping_criterion_state, + substate.stopping_criterion_state ) AIE.set_substate!(problem, nested_alg, state, new_substate) @test state.iterate ≈ [100.0, 200.0] @@ -321,7 +321,7 @@ end flattened_alg = AIE.DefaultFlattenedAlgorithm(; algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(4), + stopping_criterion = AI.StopAfterIteration(4) ) problem = TestProblem([1.0]) @@ -330,7 +330,7 @@ end ) state = AIE.DefaultFlattenedAlgorithmState(; iterate = [0.0], - stopping_criterion_state = stopping_criterion_state, + stopping_criterion_state = stopping_criterion_state ) # Test initial state @@ -388,7 +388,7 @@ end # Using the helper function flattened_alg = AIE.flattened_algorithm(2) do i AIE.nested_algorithm(1) do j - TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index a38563a..8eb4612 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,5 +1,5 @@ -using ITensorNetworksNext: ITensorNetworksNext using Aqua: Aqua +using ITensorNetworksNext: ITensorNetworksNext using Test: @testset @testset "Code quality (Aqua.jl)" begin diff --git a/test/test_basics.jl b/test/test_basics.jl index 0c9d803..9f80b25 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,7 +1,7 @@ using Dictionaries: Indices using Graphs: dst, edges, has_edge, ne, nv, src, vertices -using ITensorNetworksNext: TensorNetwork, linkinds, siteinds using ITensorBase: Index +using ITensorNetworksNext: TensorNetwork, linkinds, siteinds using NamedDimsArrays: dimnames using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 8a817b2..d1cca76 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,26 +1,26 @@ +using DiagonalArrays: δ using Dictionaries: Dictionary, set! -using ITensorBase: Index, ITensor, prime, noprime +using Graphs: AbstractGraph, dst, edges, src, vertices +using ITensorBase: ITensor, Index, noprime, prime using ITensorNetworksNext: - BeliefPropagationCache, - ITensorNetworksNext, - TensorNetwork, - partitionfunction -using DiagonalArrays: δ -using Graphs: src, dst, edges, vertices, AbstractGraph -using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree + ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, partitionfunction +using LinearAlgebra: LinearAlgebra +using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype +using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid using Test: @test, @testset -using LinearAlgebra: LinearAlgebra -using NamedDimsArrays: name, inds function ising_tensornetwork(g::AbstractGraph, β::Real; h = 0.0) - links = Dictionary(edges(g), [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)]) + links = Dictionary( + edges(g), + [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)] + ) links = merge(links, Dictionary(reverse.(edges(g)), [links[e] for e in edges(g)])) # symmetric sqrt of Boltzmann matrix W = exp(β σσ') sqrt_Ws = Dictionary() for e in edges(g) - W = [ exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h)) ] + W = [exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h))] F = LinearAlgebra.svd(W) U, S, V = F.U, F.S, F.Vt @@ -87,5 +87,4 @@ end z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact rtol = 1.0e-4 - end diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index 35b2275..b453e76 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -1,11 +1,11 @@ +using BackendSelection: @Algorithm_str, Algorithm using Graphs: edges +using ITensorBase: Index +using ITensorNetworksNext: TensorNetwork, contract_network, linkinds, siteinds using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid -using ITensorBase: Index -using ITensorNetworksNext: TensorNetwork, linkinds, siteinds, contract_network using TensorOperations: TensorOperations using Test: @test, @testset -using BackendSelection: @Algorithm_str, Algorithm @testset "contract_network" begin orderalg = alg -> Algorithm"exact"(; order_alg = Algorithm(alg)) diff --git a/test/test_dmrg.jl b/test/test_dmrg.jl index 01f04ac..dba2570 100644 --- a/test/test_dmrg.jl +++ b/test/test_dmrg.jl @@ -1,6 +1,6 @@ import AlgorithmsInterface as AI -using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm using Test: @test, @testset @testset "select_algorithm(dmrg, ...)" begin @@ -21,7 +21,7 @@ using Test: @test, @testset return EigsolveRegion( regions[j]; maxdim = maxdims[i], - cutoff = cutoffs[i], + cutoff = cutoffs[i] ) end end diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index d067c24..751b469 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,9 +1,9 @@ using AbstractTrees: AbstractTrees, print_tree, printnode using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, LazyNamedDimsArray, - Mul, SymbolicArray, ismul, lazy, substitute, symnameddims -using NamedDimsArrays: NamedDimsArray, @names, denamed, dimnames, inds, nameddims, - namedoneto +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, LazyNamedDimsArrays, Mul, + SymbolicArray, ismul, lazy, substitute, symnameddims +using NamedDimsArrays: + @names, NamedDimsArray, denamed, dimnames, inds, nameddims, namedoneto using TermInterface: arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments, sorted_children using Test: @test, @test_throws, @testset diff --git a/test/test_tensornetworkgenerators.jl b/test/test_tensornetworkgenerators.jl index 2d092c3..f29a900 100644 --- a/test/test_tensornetworkgenerators.jl +++ b/test/test_tensornetworkgenerators.jl @@ -1,8 +1,8 @@ using DiagonalArrays: δ using Graphs: edges, ne, nv, vertices using ITensorBase: Index -using ITensorNetworksNext: contract_network using ITensorNetworksNext.TensorNetworkGenerators: delta_network, ising_network +using ITensorNetworksNext: contract_network using NamedDimsArrays: inds using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid From 3654af647838b5b6724c9f14542b1e51534cee8c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:20:29 -0500 Subject: [PATCH 40/45] Working Parallel BP --- Project.toml | 4 + .../ITensorNetworksNextDaggerExt.jl | 86 ++++++++++ .../daggerbeliefpropagation.jl | 150 ++++++++++++++++++ .../ITensorNetworksNextDistributedExt.jl | 84 ++++++++++ .../distributedbeliefpropagation.jl | 116 ++++++++++++++ src/ITensorNetworksNext.jl | 2 + .../ITensorNetworksNextParallel.jl | 27 ++++ src/ITensorNetworksNextParallel/dagger.jl | 38 +++++ .../distributed.jl | 38 +++++ 9 files changed, 545 insertions(+) create mode 100644 ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl create mode 100644 ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl create mode 100644 ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl create mode 100644 ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl create mode 100644 src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl create mode 100644 src/ITensorNetworksNextParallel/dagger.jl create mode 100644 src/ITensorNetworksNextParallel/distributed.jl diff --git a/Project.toml b/Project.toml index 7576273..11bcf27 100644 --- a/Project.toml +++ b/Project.toml @@ -29,9 +29,13 @@ WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [weakdeps] TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" [extensions] ITensorNetworksNextTensorOperationsExt = "TensorOperations" +ITensorNetworksNextDistributedExt = "Distributed" +ITensorNetworksNextDaggerExt = "Dagger" [compat] AbstractTrees = "0.4.5" diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl new file mode 100644 index 0000000..b5e2c80 --- /dev/null +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -0,0 +1,86 @@ +module ITensorNetworksNextDaggerExt + +using Dagger +using Dagger.Distributed +using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerNestedAlgorithm, DaggerState, + ITensorNetworksNextParallel + +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE + + +function ITensorNetworksNextParallel.DaggerNestedAlgorithm(f::Function, iterable; workers = workers(), kwargs...) + return DaggerNestedAlgorithm(; algorithms = map(f, iterable), workers, kwargs...) +end + +function initialize_dagger_state( + problem::AIE.Problem, + algorithm::AIE.Algorithm; + iterate, + remote_subiterates = Dict{Int, Dagger.Chunk}(), + ) + + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + + remote_results = Dict{Int, Dagger.DTask}() + + return DaggerState(; iterate, remote_subiterates, stopping_criterion_state, remote_results) +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::DaggerNestedAlgorithm; + kwargs... + ) + return initialize_dagger_state(problem, algorithm; kwargs...) +end + +function AIE.get_subproblem( + problem::AIE.Problem, + algorithm::AIE.NestedAlgorithm, + state::DaggerState + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + + iterate = state.iterate + remote_subiterates = state.remote_subiterates + + substate = AI.initialize_state(subproblem, subalgorithm; iterate, remote_subiterates) + + return subproblem, subalgorithm, substate +end + + +function AI.step!( + problem::AI.Problem, + algorithm::DaggerNestedAlgorithm, + state::DaggerState; + kwargs... + ) + + subproblem, subalgorithm, subiterate_chunk = AIE.get_subproblem(problem, algorithm, state) + + dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm; iterate = subiterate_chunk) + + AIE.set_substate!(problem, algorithm, state, dtask) + + return state +end + +function AIE.set_substate!( + ::AIE.Problem, + ::DaggerNestedAlgorithm, + state::DaggerState, + dtask::Dagger.DTask, + ) + state.remote_results[state.iteration] = dtask + + return state +end + +include("daggerbeliefpropagation.jl") + +end # ITensorNetworksNextDaggerExt diff --git a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl new file mode 100644 index 0000000..a4acdce --- /dev/null +++ b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl @@ -0,0 +1,150 @@ +using Dagger +using Dagger.Distributed + +using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, + is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph +using Dictionaries: Indices +using Graphs: AbstractEdge, AbstractGraph, dst, edges, src, vertices +using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerBeliefPropagationCache, + DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel, dagger_algorithm, + subcache +using ITensorNetworksNext: BeliefPropagation, BeliefPropagationCache, + BeliefPropagationProblem, BeliefPropagationState, ITensorNetworksNext, + beliefpropagation, forest_cover_edge_sequence, select_algorithm +using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientedges, quotientvertices +using NamedGraphs: NamedGraphs +using NamedGraphs.GraphsExtensions: boundary_edges + +function ITensorNetworksNextParallel.subcache(cache::DaggerBeliefPropagationCache, inds) + return subcache(cache.underlying_cache, inds) +end + +function ITensorNetworksNextParallel.DaggerBeliefPropagationCache(network::AbstractGraph) + underlying_cache = BeliefPropagationCache(network) + + keys = Indices(quotientvertices(underlying_cache)) + + workers = Iterators.cycle(Distributed.workers()) + worker_dict = similar(keys, Int) + + for quotient_vertex in keys + worker, workers = Iterators.peel(workers) + worker_dict[quotient_vertex] = worker + end + + quotient_chunks = map(keys) do quotient_vertex + worker = worker_dict[quotient_vertex] + iterate = subcache(underlying_cache, quotient_vertex) + chunk = Dagger.@mutable worker = worker BeliefPropagationState(; iterate) + return chunk + end + + return DaggerBeliefPropagationCache(underlying_cache, quotient_chunks) +end + +DataGraphs.underlying_graph(cache::DaggerBeliefPropagationCache) = underlying_graph(cache.underlying_cache) + +DataGraphs.is_vertex_assigned(bpc::DaggerBeliefPropagationCache, vertex) = is_vertex_assigned(bpc.underlying_cache, vertex) +DataGraphs.is_edge_assigned(bpc::DaggerBeliefPropagationCache, edge) = is_edge_assigned(bpc.undelying_cache, edge) + +DataGraphs.get_vertex_data(bpc::DaggerBeliefPropagationCache, vertex) = get_vertex_data(bpc.underlying_cache, vertex) +DataGraphs.get_edge_data(bpc::DaggerBeliefPropagationCache, edge::AbstractEdge) = get_edge_data(bpc.undelying_caches, edge) + +DataGraphs.set_vertex_data!(bpc::DaggerBeliefPropagationCache, val, vertex) = set_vertex_data!(bpc.underlying_cache, val, vertex) +DataGraphs.set_edge_data!(bpc::DaggerBeliefPropagationCache, val, edge) = set_edge_data!(bpc.underlying_cache, val, edge) + +NamedGraphs.to_graph_index(::DaggerBeliefPropagationCache, qv::QuotientVertex) = qv +function DataGraphs.get_index_data(cache::DaggerBeliefPropagationCache, qv::QuotientVertex) + return cache.quotient_chunks[qv] +end + +function ITensorNetworksNext.beliefpropagation_sweep(cache::DaggerBeliefPropagationCache; edges, workers = workers(), kwargs...) + + keys = collect(quotientvertices(cache)) + + return dagger_algorithm(keys; keys, workers) do quotient_vertex + + subcache = fetch(cache[quotient_vertex]).iterate + + subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges + incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) + + alg = select_algorithm( + beliefpropagation, + subcache; + # Don't update the incoming messages + edges = setdiff(subcache_edges, incoming_edges), + maxiter = 1, + kwargs... + ) + + return alg + end +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::BeliefPropagation{<:DaggerNestedAlgorithm}; + kwargs... + ) + return initialize_dagger_state(problem, algorithm; kwargs...) +end + +function AIE.get_subproblem( + problem::BeliefPropagationProblem, + algorithm::DaggerNestedAlgorithm, + state::DaggerState, + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + + quotient_vertex = algorithm.keys[state.iteration] + + cache = state.iterate.iterate + + subiterate = cache[quotient_vertex] + + return subproblem, subalgorithm, subiterate +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + algorithm::AIE.NestedAlgorithm, + state::AIE.State, + substate::DaggerState, + ) + + dst_cache = state.iterate.iterate + + state.iterate.maxdiff = 0.0 + + current_algorithm = algorithm.algorithms[state.iteration] + + for (i, quotient_vertex) in enumerate(current_algorithm.keys) + get_maxdiff = dtask -> dtask.iterate.maxdiff + src_maxdiff = fetch(Dagger.@spawn get_maxdiff(substate.remote_results[i])) + + if src_maxdiff > state.iterate.maxdiff + state.iterate.maxdiff = src_maxdiff + end + end + + + transfer_edges! = (dst_chunk, src_chunk, edges) -> begin + src_subcache = src_chunk.iterate + dst_subcache = dst_chunk.iterate + for edge in edges + dst_subcache[edge] = src_subcache[edge] + end + end + + transfer_dtasks = map(quotientedges(dst_cache)) do quotient_edge + src_subcache = dst_cache[src(quotient_edge)] + dst_subcache = dst_cache[dst(quotient_edge)] + return Dagger.@spawn transfer_edges!(dst_subcache, fetch(src_subcache), edges(dst_cache, quotient_edge)) + end + + wait.(transfer_dtasks) + + return state +end diff --git a/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl b/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl new file mode 100644 index 0000000..c19db03 --- /dev/null +++ b/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl @@ -0,0 +1,84 @@ +module ITensorNetworksNextDistributedExt + +using Distributed + +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE + +import ITensorNetworksNext.ITensorNetworksNextParallel as Parallel + +function initialize_distributed_state( + problem::AIE.Problem, + algorithm::AIE.Algorithm; + keys, + iterate, + kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + remote_results = Dict{eltype(keys), Distributed.Future}() + + return Parallel.DistributedState(; iterate, stopping_criterion_state, remote_results) +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::Parallel.DistributedNestedAlgorithm; + kwargs... + ) + return initialize_distributed_state(problem, algorithm; keys = algorithm.keys, kwargs...) +end + +function Parallel.DistributedNestedAlgorithm(f::Function, iterable; kwargs...) + return Parallel.DistributedNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) +end + +function AIE.get_subproblem( + problem::AI.Problem, algorithm::Parallel.DistributedNestedAlgorithm, state::Parallel.DistributedState + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + + return subproblem, subalgorithm, state.iterate +end + +function AI.step!( + problem::AI.Problem, + algorithm::Parallel.DistributedNestedAlgorithm, + state::Parallel.DistributedState; + kwargs... + ) + + subproblem, subalgorithm, subiterate = AIE.get_subproblem(problem, algorithm, state) + + # Do whatever should have happened at `step!`, but store the result as a future. + + function solve(subproblem, subalgorithm, iterate) + rv = AI.solve(subproblem, subalgorithm; iterate) + return rv + end + + future = remotecall(solve, algorithm.workers, subproblem, subalgorithm, subiterate) + + AIE.set_substate!(problem, algorithm, state, future) + + return state +end + +function AIE.set_substate!( + ::AIE.Problem, + algorithm::Parallel.DistributedNestedAlgorithm, + state::Parallel.DistributedState, + future::Distributed.Future, + ) + key = algorithm.keys[state.iteration] + + state.remote_results[key] = future + + return state +end + +include("distributedbeliefpropagation.jl") + +end # ITensorNetworksNextDistributedExt diff --git a/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl b/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl new file mode 100644 index 0000000..3848e65 --- /dev/null +++ b/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl @@ -0,0 +1,116 @@ +using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, + is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph +using Graphs: AbstractEdge, AbstractGraph, edges, vertices +using ITensorNetworksNext.ITensorNetworksNextParallel: DistributedBeliefPropagationCache, + DistributedNestedAlgorithm, DistributedState, ITensorNetworksNextParallel, + distributed_algorithm +using ITensorNetworksNext: BeliefPropagation, BeliefPropagationCache, + BeliefPropagationProblem, BeliefPropagationState, ITensorNetworksNext, + beliefpropagation, forest_cover_edge_sequence, select_algorithm, setmessages! +using NamedGraphs.GraphsExtensions: boundary_edges +using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientvertices +using NamedGraphs: NamedGraphs + +function ITensorNetworksNextParallel.DistributedBeliefPropagationCache(network::AbstractGraph) + underlying_cache = BeliefPropagationCache(network) + return DistributedBeliefPropagationCache(underlying_cache) +end + +DataGraphs.underlying_graph(cache::DistributedBeliefPropagationCache) = underlying_graph(cache.underlying_cache) + +DataGraphs.is_vertex_assigned(bpc::DistributedBeliefPropagationCache, vertex) = is_vertex_assigned(bpc.underlying_cache, vertex) +DataGraphs.is_edge_assigned(bpc::DistributedBeliefPropagationCache, edge) = is_edge_assigned(bpc.undelying_cache, edge) + +DataGraphs.get_vertex_data(bpc::DistributedBeliefPropagationCache, vertex) = get_vertex_data(bpc.underlying_cache, vertex) +DataGraphs.get_edge_data(bpc::DistributedBeliefPropagationCache, edge::AbstractEdge) = get_edge_data(bpc.underlying_cache, edge) + +DataGraphs.set_vertex_data!(bpc::DistributedBeliefPropagationCache, val, vertex) = set_vertex_data!(bpc.underlying_cache, val, vertex) +DataGraphs.set_edge_data!(bpc::DistributedBeliefPropagationCache, val, edge) = set_edge_data!(bpc.underlying_cache, val, edge) + +NamedGraphs.to_graph_index(::DistributedBeliefPropagationCache, qv::QuotientVertex) = qv +function DataGraphs.get_index_data(cache::DistributedBeliefPropagationCache, qv::QuotientVertex) + return ITensorNetworksNextParallel.subcache(cache.underlying_cache, qv) +end +function ITensorNetworksNext.beliefpropagation_sweep( + cache::DistributedBeliefPropagationCache; edges, kwargs... + ) + + keys = collect(quotientvertices(cache)) + + return distributed_algorithm(keys; keys, workers = WorkerPool(workers())) do quotient_vertex + + subcache = cache[quotient_vertex] + subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges + incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) + + alg = select_algorithm( + beliefpropagation, + subcache; + edges = setdiff(subcache_edges, incoming_edges), + maxiter = 1, + kwargs... + ) + + return alg + end +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::BeliefPropagation{<:DistributedNestedAlgorithm}; + kwargs... + ) + + keys = first(algorithm.algorithms).keys + + return initialize_distributed_state(problem, algorithm; keys = keys, kwargs...) +end + +function AIE.get_subproblem( + problem::BeliefPropagationProblem, + algorithm::DistributedNestedAlgorithm, + state::DistributedState + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + + cache = state.iterate.iterate + + quotient_vertex = algorithm.keys[state.iteration] + subiterate = BeliefPropagationState(; iterate = cache[quotient_vertex]) + + return subproblem, subalgorithm, subiterate +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + ::AIE.NestedAlgorithm, + state::AIE.State, + substate::DistributedState, + ) + + dst_cache = state.iterate.iterate + + state.iterate.maxdiff = 0.0 + + for quotient_vertex in quotientvertices(dst_cache) + + src_state = fetch(substate.remote_results[quotient_vertex]).iterate + + src_cache = src_state.iterate + src_maxdiff = src_state.maxdiff + + incoming_edges = boundary_edges(dst_cache, vertices(dst_cache, quotient_vertex); dir = :in) + + updated_messages = setdiff(edges(src_cache), incoming_edges) + + setmessages!(dst_cache, src_cache, updated_messages) + + if src_maxdiff > state.iterate.maxdiff + state.iterate.maxdiff = src_maxdiff + end + + end + + return state +end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index d3c5c21..dc50876 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -13,4 +13,6 @@ include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") include("beliefpropagation/beliefpropagationproblem.jl") +include("ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl") + end diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl new file mode 100644 index 0000000..1131ca3 --- /dev/null +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -0,0 +1,27 @@ +module ITensorNetworksNextParallel + +using Graphs: neighbors, add_vertex!, vertices +using NamedGraphs.GraphsExtensions: subgraph +using NamedGraphs.PartitionedGraphs: QuotientVertex +using ..ITensorNetworksNext: BeliefPropagationCache + +subcache(cache::BeliefPropagationCache, vertex::QuotientVertex) = subcache(cache, vertices(cache, vertex)) +function subcache(cache::BeliefPropagationCache, vertices) + subcache = subgraph(cache, vertices) + + for vertex in vertices + for neighbor_vertex in neighbors(cache, vertex) + add_vertex!(subcache, neighbor_vertex) + # Add in necessary messages. + subcache[vertex => neighbor_vertex] = cache[vertex => neighbor_vertex] + subcache[neighbor_vertex => vertex] = cache[neighbor_vertex => vertex] + end + end + + return subcache +end + +include("distributed.jl") +include("dagger.jl") + +end # ITensorNetworksNextParallel diff --git a/src/ITensorNetworksNextParallel/dagger.jl b/src/ITensorNetworksNextParallel/dagger.jl new file mode 100644 index 0000000..eef37bc --- /dev/null +++ b/src/ITensorNetworksNextParallel/dagger.jl @@ -0,0 +1,38 @@ +import AlgorithmsInterface as AI +import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using ITensorNetworksNext: AbstractBeliefPropagationCache + +@kwdef mutable struct DaggerState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Chunk, DTask, + } <: AIE.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState + remote_subiterates::Dict{Int, Chunk} = Dict{Int, Any}() + remote_results::Dict{Int, DTask} = Dict{Int, Any}() +end + +@kwdef struct DaggerNestedAlgorithm{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + KeyType, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) + workers::Vector{Int} + keys::Vector{KeyType} = collect(1:length(algorithms)) +end + +function dagger_algorithm(f::Function, iterable; kwargs...) + return DaggerNestedAlgorithm(f, iterable; kwargs...) +end + +# ================================== belief propagation ================================== # + +struct DaggerBeliefPropagationCache{ + V, VD, ED, UC <: AbstractBeliefPropagationCache{V, VD, ED}, Chunks, + } <: AbstractBeliefPropagationCache{V, VD, ED} + underlying_cache::UC + quotient_chunks::Chunks +end diff --git a/src/ITensorNetworksNextParallel/distributed.jl b/src/ITensorNetworksNextParallel/distributed.jl new file mode 100644 index 0000000..01c1344 --- /dev/null +++ b/src/ITensorNetworksNextParallel/distributed.jl @@ -0,0 +1,38 @@ +import AlgorithmsInterface as AI +import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE + +using ..ITensorNetworksNext: AbstractBeliefPropagationCache + +@kwdef mutable struct DistributedState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Future, KeyType, + } <: AIE.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState + remote_results::Dict{KeyType, Future} = Dict{Int, Any}() +end + +@kwdef struct DistributedNestedAlgorithm{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + WorkerPool, + KeyType, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) + workers::WorkerPool + keys::Vector{KeyType} = collect(1:length(algorithms)) +end + +function distributed_algorithm(f::Function, iterable; kwargs...) + return DistributedNestedAlgorithm(f, iterable; kwargs...) +end + +# ================================== belief propagation ================================== # + +struct DistributedBeliefPropagationCache{ + V, VD, ED, UC <: AbstractBeliefPropagationCache{V, VD, ED}, + } <: AbstractBeliefPropagationCache{V, VD, ED} + underlying_cache::UC +end From 3b868f3b64b5f743caf81a1e3166a0ea3ff53f49 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 26 Feb 2026 15:01:50 -0500 Subject: [PATCH 41/45] Remove basic `Distributed.jl` implementation. --- .../ITensorNetworksNextDaggerExt.jl | 60 ++++----- .../daggerbeliefpropagation.jl | 115 +++++++++-------- .../ITensorNetworksNextDistributedExt.jl | 84 ------------- .../distributedbeliefpropagation.jl | 116 ------------------ .../ITensorNetworksNextParallel.jl | 26 ++-- src/ITensorNetworksNextParallel/dagger.jl | 6 +- .../distributed.jl | 38 ------ src/abstracttensornetwork.jl | 1 + .../abstractbeliefpropagationcache.jl | 20 ++- 9 files changed, 118 insertions(+), 348 deletions(-) delete mode 100644 ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl delete mode 100644 ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl delete mode 100644 src/ITensorNetworksNextParallel/distributed.jl diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl index b5e2c80..95c4105 100644 --- a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -1,37 +1,33 @@ module ITensorNetworksNextDaggerExt -using Dagger -using Dagger.Distributed -using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerNestedAlgorithm, DaggerState, - ITensorNetworksNextParallel - import AlgorithmsInterface as AI import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP +using Dagger +using ITensorNetworksNext.ITensorNetworksNextParallel: + DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel - -function ITensorNetworksNextParallel.DaggerNestedAlgorithm(f::Function, iterable; workers = workers(), kwargs...) +function ITNNP.DaggerNestedAlgorithm(f::Function, iterable; workers = workers(), kwargs...) return DaggerNestedAlgorithm(; algorithms = map(f, iterable), workers, kwargs...) end -function initialize_dagger_state( - problem::AIE.Problem, - algorithm::AIE.Algorithm; - iterate, - remote_subiterates = Dict{Int, Dagger.Chunk}(), - ) - +function initialize_dagger_state(problem::AIE.Problem, algorithm::AIE.Algorithm; iterate) stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion ) remote_results = Dict{Int, Dagger.DTask}() - return DaggerState(; iterate, remote_subiterates, stopping_criterion_state, remote_results) + return ITNNP.DaggerState(; + iterate, + remote_results, + stopping_criterion_state + ) end function AI.initialize_state( problem::AIE.Problem, - algorithm::DaggerNestedAlgorithm; + algorithm::ITNNP.DaggerNestedAlgorithm; kwargs... ) return initialize_dagger_state(problem, algorithm; kwargs...) @@ -39,43 +35,31 @@ end function AIE.get_subproblem( problem::AIE.Problem, - algorithm::AIE.NestedAlgorithm, - state::DaggerState + algorithm::ITNNP.DaggerNestedAlgorithm, + state::ITNNP.DaggerState ) subproblem = problem subalgorithm = algorithm.algorithms[state.iteration] - iterate = state.iterate - remote_subiterates = state.remote_subiterates + # This might be a Dagger.chun object. + iterate = ITNNP.get_subiterate(subproblem, subalgorithm, state) - substate = AI.initialize_state(subproblem, subalgorithm; iterate, remote_subiterates) + substate = Dagger.@spawn AI.initialize_state(subproblem, subalgorithm; iterate) return subproblem, subalgorithm, substate end - function AI.step!( problem::AI.Problem, - algorithm::DaggerNestedAlgorithm, - state::DaggerState; + algorithm::ITNNP.DaggerNestedAlgorithm, + state::ITNNP.DaggerState; kwargs... ) + subproblem, subalgorithm, substate_future = + AIE.get_subproblem(problem, algorithm, state) - subproblem, subalgorithm, subiterate_chunk = AIE.get_subproblem(problem, algorithm, state) - - dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm; iterate = subiterate_chunk) - - AIE.set_substate!(problem, algorithm, state, dtask) + dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm, substate_future) - return state -end - -function AIE.set_substate!( - ::AIE.Problem, - ::DaggerNestedAlgorithm, - state::DaggerState, - dtask::Dagger.DTask, - ) state.remote_results[state.iteration] = dtask return state diff --git a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl index a4acdce..8029fdd 100644 --- a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl +++ b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl @@ -1,30 +1,22 @@ +import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP using Dagger -using Dagger.Distributed - using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph using Dictionaries: Indices using Graphs: AbstractEdge, AbstractGraph, dst, edges, src, vertices -using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerBeliefPropagationCache, - DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel, dagger_algorithm, - subcache -using ITensorNetworksNext: BeliefPropagation, BeliefPropagationCache, - BeliefPropagationProblem, BeliefPropagationState, ITensorNetworksNext, - beliefpropagation, forest_cover_edge_sequence, select_algorithm +using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagation, BeliefPropagationCache, + BeliefPropagationProblem, BeliefPropagationState, beliefpropagation, + forest_cover_edge_sequence, select_algorithm +using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientedges, quotientvertices using NamedGraphs: NamedGraphs -using NamedGraphs.GraphsExtensions: boundary_edges -function ITensorNetworksNextParallel.subcache(cache::DaggerBeliefPropagationCache, inds) - return subcache(cache.underlying_cache, inds) -end - -function ITensorNetworksNextParallel.DaggerBeliefPropagationCache(network::AbstractGraph) +function ITNNP.DaggerBeliefPropagationCache(network::AbstractGraph) underlying_cache = BeliefPropagationCache(network) keys = Indices(quotientvertices(underlying_cache)) - workers = Iterators.cycle(Distributed.workers()) + workers = Iterators.cycle(Dagger.Distributed.workers()) worker_dict = similar(keys, Int) for quotient_vertex in keys @@ -39,31 +31,54 @@ function ITensorNetworksNextParallel.DaggerBeliefPropagationCache(network::Abstr return chunk end - return DaggerBeliefPropagationCache(underlying_cache, quotient_chunks) + return ITNNP.DaggerBeliefPropagationCache(underlying_cache, quotient_chunks) end -DataGraphs.underlying_graph(cache::DaggerBeliefPropagationCache) = underlying_graph(cache.underlying_cache) +function DataGraphs.underlying_graph(cache::ITNNP.DaggerBeliefPropagationCache) + return underlying_graph(cache.underlying_cache) +end -DataGraphs.is_vertex_assigned(bpc::DaggerBeliefPropagationCache, vertex) = is_vertex_assigned(bpc.underlying_cache, vertex) -DataGraphs.is_edge_assigned(bpc::DaggerBeliefPropagationCache, edge) = is_edge_assigned(bpc.undelying_cache, edge) +function DataGraphs.is_vertex_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) + return is_vertex_assigned(bpc.underlying_cache, vertex) +end +function DataGraphs.is_edge_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, edge) + return is_edge_assigned(bpc.undelying_cache, edge) +end -DataGraphs.get_vertex_data(bpc::DaggerBeliefPropagationCache, vertex) = get_vertex_data(bpc.underlying_cache, vertex) -DataGraphs.get_edge_data(bpc::DaggerBeliefPropagationCache, edge::AbstractEdge) = get_edge_data(bpc.undelying_caches, edge) +function DataGraphs.get_vertex_data(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) + return get_vertex_data(bpc.underlying_cache, vertex) +end +function DataGraphs.get_edge_data( + bpc::ITNNP.DaggerBeliefPropagationCache, + edge::AbstractEdge + ) + return get_edge_data(bpc.undelying_caches, edge) +end -DataGraphs.set_vertex_data!(bpc::DaggerBeliefPropagationCache, val, vertex) = set_vertex_data!(bpc.underlying_cache, val, vertex) -DataGraphs.set_edge_data!(bpc::DaggerBeliefPropagationCache, val, edge) = set_edge_data!(bpc.underlying_cache, val, edge) +function DataGraphs.set_vertex_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, vertex) + return set_vertex_data!(bpc.underlying_cache, val, vertex) +end +function DataGraphs.set_edge_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, edge) + return set_edge_data!(bpc.underlying_cache, val, edge) +end -NamedGraphs.to_graph_index(::DaggerBeliefPropagationCache, qv::QuotientVertex) = qv -function DataGraphs.get_index_data(cache::DaggerBeliefPropagationCache, qv::QuotientVertex) +NamedGraphs.to_graph_index(::ITNNP.DaggerBeliefPropagationCache, qv::QuotientVertex) = qv +function DataGraphs.get_index_data( + cache::ITNNP.DaggerBeliefPropagationCache, + qv::QuotientVertex + ) return cache.quotient_chunks[qv] end -function ITensorNetworksNext.beliefpropagation_sweep(cache::DaggerBeliefPropagationCache; edges, workers = workers(), kwargs...) - +function ITensorNetworksNext.beliefpropagation_sweep( + cache::ITNNP.DaggerBeliefPropagationCache; + edges, + workers = workers(), + kwargs... + ) keys = collect(quotientvertices(cache)) - return dagger_algorithm(keys; keys, workers) do quotient_vertex - + return ITNNP.dagger_algorithm(keys; keys, workers) do quotient_vertex subcache = fetch(cache[quotient_vertex]).iterate subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges @@ -84,64 +99,62 @@ end function AI.initialize_state( problem::AIE.Problem, - algorithm::BeliefPropagation{<:DaggerNestedAlgorithm}; + algorithm::BeliefPropagation{<:ITNNP.DaggerNestedAlgorithm}; kwargs... ) - return initialize_dagger_state(problem, algorithm; kwargs...) + return ITNNP.initialize_dagger_state(problem, algorithm; kwargs...) end -function AIE.get_subproblem( - problem::BeliefPropagationProblem, - algorithm::DaggerNestedAlgorithm, - state::DaggerState, +function ITNNP.get_subiterate( + ::BeliefPropagationProblem, + ::BeliefPropagation, # Our parallel region runs a small BP + state::ITNNP.DaggerState ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - - quotient_vertex = algorithm.keys[state.iteration] - cache = state.iterate.iterate + quotient_vertex = collect(quotientvertices(cache))[state.iteration] + subiterate = cache[quotient_vertex] - return subproblem, subalgorithm, subiterate + return subiterate end function AIE.set_substate!( ::BeliefPropagationProblem, - algorithm::AIE.NestedAlgorithm, + ::AIE.NestedAlgorithm, state::AIE.State, - substate::DaggerState, + substate::ITNNP.DaggerState ) - dst_cache = state.iterate.iterate state.iterate.maxdiff = 0.0 - current_algorithm = algorithm.algorithms[state.iteration] - - for (i, quotient_vertex) in enumerate(current_algorithm.keys) + for remote_result in substate.remote_results get_maxdiff = dtask -> dtask.iterate.maxdiff - src_maxdiff = fetch(Dagger.@spawn get_maxdiff(substate.remote_results[i])) + src_maxdiff = fetch(Dagger.@spawn get_maxdiff(remote_result)) if src_maxdiff > state.iterate.maxdiff state.iterate.maxdiff = src_maxdiff end end - - transfer_edges! = (dst_chunk, src_chunk, edges) -> begin + function transfer_edges!(dst_chunk, src_chunk, edges) src_subcache = src_chunk.iterate dst_subcache = dst_chunk.iterate for edge in edges dst_subcache[edge] = src_subcache[edge] end + return end transfer_dtasks = map(quotientedges(dst_cache)) do quotient_edge src_subcache = dst_cache[src(quotient_edge)] dst_subcache = dst_cache[dst(quotient_edge)] - return Dagger.@spawn transfer_edges!(dst_subcache, fetch(src_subcache), edges(dst_cache, quotient_edge)) + return Dagger.@spawn transfer_edges!( + dst_subcache, + fetch(src_subcache), + edges(dst_cache, quotient_edge) + ) end wait.(transfer_dtasks) diff --git a/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl b/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl deleted file mode 100644 index c19db03..0000000 --- a/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl +++ /dev/null @@ -1,84 +0,0 @@ -module ITensorNetworksNextDistributedExt - -using Distributed - -import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE - -import ITensorNetworksNext.ITensorNetworksNextParallel as Parallel - -function initialize_distributed_state( - problem::AIE.Problem, - algorithm::AIE.Algorithm; - keys, - iterate, - kwargs... - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - remote_results = Dict{eltype(keys), Distributed.Future}() - - return Parallel.DistributedState(; iterate, stopping_criterion_state, remote_results) -end - -function AI.initialize_state( - problem::AIE.Problem, - algorithm::Parallel.DistributedNestedAlgorithm; - kwargs... - ) - return initialize_distributed_state(problem, algorithm; keys = algorithm.keys, kwargs...) -end - -function Parallel.DistributedNestedAlgorithm(f::Function, iterable; kwargs...) - return Parallel.DistributedNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) -end - -function AIE.get_subproblem( - problem::AI.Problem, algorithm::Parallel.DistributedNestedAlgorithm, state::Parallel.DistributedState - ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - - return subproblem, subalgorithm, state.iterate -end - -function AI.step!( - problem::AI.Problem, - algorithm::Parallel.DistributedNestedAlgorithm, - state::Parallel.DistributedState; - kwargs... - ) - - subproblem, subalgorithm, subiterate = AIE.get_subproblem(problem, algorithm, state) - - # Do whatever should have happened at `step!`, but store the result as a future. - - function solve(subproblem, subalgorithm, iterate) - rv = AI.solve(subproblem, subalgorithm; iterate) - return rv - end - - future = remotecall(solve, algorithm.workers, subproblem, subalgorithm, subiterate) - - AIE.set_substate!(problem, algorithm, state, future) - - return state -end - -function AIE.set_substate!( - ::AIE.Problem, - algorithm::Parallel.DistributedNestedAlgorithm, - state::Parallel.DistributedState, - future::Distributed.Future, - ) - key = algorithm.keys[state.iteration] - - state.remote_results[key] = future - - return state -end - -include("distributedbeliefpropagation.jl") - -end # ITensorNetworksNextDistributedExt diff --git a/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl b/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl deleted file mode 100644 index 3848e65..0000000 --- a/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl +++ /dev/null @@ -1,116 +0,0 @@ -using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, - is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph -using Graphs: AbstractEdge, AbstractGraph, edges, vertices -using ITensorNetworksNext.ITensorNetworksNextParallel: DistributedBeliefPropagationCache, - DistributedNestedAlgorithm, DistributedState, ITensorNetworksNextParallel, - distributed_algorithm -using ITensorNetworksNext: BeliefPropagation, BeliefPropagationCache, - BeliefPropagationProblem, BeliefPropagationState, ITensorNetworksNext, - beliefpropagation, forest_cover_edge_sequence, select_algorithm, setmessages! -using NamedGraphs.GraphsExtensions: boundary_edges -using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientvertices -using NamedGraphs: NamedGraphs - -function ITensorNetworksNextParallel.DistributedBeliefPropagationCache(network::AbstractGraph) - underlying_cache = BeliefPropagationCache(network) - return DistributedBeliefPropagationCache(underlying_cache) -end - -DataGraphs.underlying_graph(cache::DistributedBeliefPropagationCache) = underlying_graph(cache.underlying_cache) - -DataGraphs.is_vertex_assigned(bpc::DistributedBeliefPropagationCache, vertex) = is_vertex_assigned(bpc.underlying_cache, vertex) -DataGraphs.is_edge_assigned(bpc::DistributedBeliefPropagationCache, edge) = is_edge_assigned(bpc.undelying_cache, edge) - -DataGraphs.get_vertex_data(bpc::DistributedBeliefPropagationCache, vertex) = get_vertex_data(bpc.underlying_cache, vertex) -DataGraphs.get_edge_data(bpc::DistributedBeliefPropagationCache, edge::AbstractEdge) = get_edge_data(bpc.underlying_cache, edge) - -DataGraphs.set_vertex_data!(bpc::DistributedBeliefPropagationCache, val, vertex) = set_vertex_data!(bpc.underlying_cache, val, vertex) -DataGraphs.set_edge_data!(bpc::DistributedBeliefPropagationCache, val, edge) = set_edge_data!(bpc.underlying_cache, val, edge) - -NamedGraphs.to_graph_index(::DistributedBeliefPropagationCache, qv::QuotientVertex) = qv -function DataGraphs.get_index_data(cache::DistributedBeliefPropagationCache, qv::QuotientVertex) - return ITensorNetworksNextParallel.subcache(cache.underlying_cache, qv) -end -function ITensorNetworksNext.beliefpropagation_sweep( - cache::DistributedBeliefPropagationCache; edges, kwargs... - ) - - keys = collect(quotientvertices(cache)) - - return distributed_algorithm(keys; keys, workers = WorkerPool(workers())) do quotient_vertex - - subcache = cache[quotient_vertex] - subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges - incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) - - alg = select_algorithm( - beliefpropagation, - subcache; - edges = setdiff(subcache_edges, incoming_edges), - maxiter = 1, - kwargs... - ) - - return alg - end -end - -function AI.initialize_state( - problem::AIE.Problem, - algorithm::BeliefPropagation{<:DistributedNestedAlgorithm}; - kwargs... - ) - - keys = first(algorithm.algorithms).keys - - return initialize_distributed_state(problem, algorithm; keys = keys, kwargs...) -end - -function AIE.get_subproblem( - problem::BeliefPropagationProblem, - algorithm::DistributedNestedAlgorithm, - state::DistributedState - ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - - cache = state.iterate.iterate - - quotient_vertex = algorithm.keys[state.iteration] - subiterate = BeliefPropagationState(; iterate = cache[quotient_vertex]) - - return subproblem, subalgorithm, subiterate -end - -function AIE.set_substate!( - ::BeliefPropagationProblem, - ::AIE.NestedAlgorithm, - state::AIE.State, - substate::DistributedState, - ) - - dst_cache = state.iterate.iterate - - state.iterate.maxdiff = 0.0 - - for quotient_vertex in quotientvertices(dst_cache) - - src_state = fetch(substate.remote_results[quotient_vertex]).iterate - - src_cache = src_state.iterate - src_maxdiff = src_state.maxdiff - - incoming_edges = boundary_edges(dst_cache, vertices(dst_cache, quotient_vertex); dir = :in) - - updated_messages = setdiff(edges(src_cache), incoming_edges) - - setmessages!(dst_cache, src_cache, updated_messages) - - if src_maxdiff > state.iterate.maxdiff - state.iterate.maxdiff = src_maxdiff - end - - end - - return state -end diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl index 1131ca3..8e90670 100644 --- a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -1,27 +1,19 @@ module ITensorNetworksNextParallel -using Graphs: neighbors, add_vertex!, vertices +using ..ITensorNetworksNext: BeliefPropagationCache +using Graphs: add_vertex!, neighbors, vertices using NamedGraphs.GraphsExtensions: subgraph using NamedGraphs.PartitionedGraphs: QuotientVertex -using ..ITensorNetworksNext: BeliefPropagationCache - -subcache(cache::BeliefPropagationCache, vertex::QuotientVertex) = subcache(cache, vertices(cache, vertex)) -function subcache(cache::BeliefPropagationCache, vertices) - subcache = subgraph(cache, vertices) - for vertex in vertices - for neighbor_vertex in neighbors(cache, vertex) - add_vertex!(subcache, neighbor_vertex) - # Add in necessary messages. - subcache[vertex => neighbor_vertex] = cache[vertex => neighbor_vertex] - subcache[neighbor_vertex => vertex] = cache[neighbor_vertex => vertex] - end - end +""" + get_subiterate(subproblem::AI.Problem, subalgorithm::AI.Algorithm, state::AI.State) - return subcache -end +For a given `subproblem` and `subalgorithm` of a parent nested algorithm, +derive (from the parent state `state`) the iterate to be used in the associated sub state. +The returned value of this function is then pass to a remote call of `initialize_state`. +""" +get_subiterate(::AI.Problem, ::AI.Algorithm, state::AI.State) = state.iterate -include("distributed.jl") include("dagger.jl") end # ITensorNetworksNextParallel diff --git a/src/ITensorNetworksNextParallel/dagger.jl b/src/ITensorNetworksNextParallel/dagger.jl index eef37bc..8f79de5 100644 --- a/src/ITensorNetworksNextParallel/dagger.jl +++ b/src/ITensorNetworksNextParallel/dagger.jl @@ -1,14 +1,14 @@ -import AlgorithmsInterface as AI import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI using ITensorNetworksNext: AbstractBeliefPropagationCache @kwdef mutable struct DaggerState{ Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Chunk, DTask, } <: AIE.State - iterate::Iterate + iterate::Iterate # DaggerBeliefPropagationCache iteration::Int = 0 stopping_criterion_state::StoppingCriterionState - remote_subiterates::Dict{Int, Chunk} = Dict{Int, Any}() + # remote_subiterates::Dict{Int, Chunk} = Dict{Int, Any}() remote_results::Dict{Int, DTask} = Dict{Int, Any}() end diff --git a/src/ITensorNetworksNextParallel/distributed.jl b/src/ITensorNetworksNextParallel/distributed.jl deleted file mode 100644 index 01c1344..0000000 --- a/src/ITensorNetworksNextParallel/distributed.jl +++ /dev/null @@ -1,38 +0,0 @@ -import AlgorithmsInterface as AI -import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE - -using ..ITensorNetworksNext: AbstractBeliefPropagationCache - -@kwdef mutable struct DistributedState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Future, KeyType, - } <: AIE.State - iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState - remote_results::Dict{KeyType, Future} = Dict{Int, Any}() -end - -@kwdef struct DistributedNestedAlgorithm{ - ChildAlgorithm <: AIE.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - WorkerPool, - KeyType, - } <: AIE.NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) - workers::WorkerPool - keys::Vector{KeyType} = collect(1:length(algorithms)) -end - -function distributed_algorithm(f::Function, iterable; kwargs...) - return DistributedNestedAlgorithm(f, iterable; kwargs...) -end - -# ================================== belief propagation ================================== # - -struct DistributedBeliefPropagationCache{ - V, VD, ED, UC <: AbstractBeliefPropagationCache{V, VD, ED}, - } <: AbstractBeliefPropagationCache{V, VD, ED} - underlying_cache::UC -end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 7fca799..1e25df5 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -37,6 +37,7 @@ Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) + Base.keys(tn::AbstractTensorNetwork) = vertices(tn) # TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition, diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 33f185b..a810c9a 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,7 +1,7 @@ using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data using Graphs: AbstractEdge, AbstractGraph using NamedGraphs.GraphsExtensions: boundary_edges -using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent, QuotientVertex messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] @@ -141,3 +141,21 @@ function free_energy(bp_cache::AbstractBeliefPropagationCache) return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) + +function subcache(cache::AbstractBeliefPropagationCache, vertex::QuotientVertex) + return subcache(cache, vertices(cache, vertex)) +end +function subcache(cache::AbstractBeliefPropagationCache, vertices) + subcache = subgraph(cache, vertices) + + for vertex in vertices + for neighbor_vertex in neighbors(cache, vertex) + add_vertex!(subcache, neighbor_vertex) + # Add in necessary messages. + subcache[vertex => neighbor_vertex] = cache[vertex => neighbor_vertex] + subcache[neighbor_vertex => vertex] = cache[neighbor_vertex => vertex] + end + end + + return subcache +end From 37fe65034e698fb7e499df27c9895106a216c60a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 26 Feb 2026 15:06:03 -0500 Subject: [PATCH 42/45] Fix imports. --- .../ITensorNetworksNextParallel.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl index 8e90670..7117641 100644 --- a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -1,9 +1,6 @@ module ITensorNetworksNextParallel -using ..ITensorNetworksNext: BeliefPropagationCache -using Graphs: add_vertex!, neighbors, vertices -using NamedGraphs.GraphsExtensions: subgraph -using NamedGraphs.PartitionedGraphs: QuotientVertex +import AlgorithmsInterface as AI """ get_subiterate(subproblem::AI.Problem, subalgorithm::AI.Algorithm, state::AI.State) From 3f488fa15957b864e70bf4548d31183f18b52332 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 26 Feb 2026 15:06:41 -0500 Subject: [PATCH 43/45] Fix Project.toml --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 11bcf27..59dea41 100644 --- a/Project.toml +++ b/Project.toml @@ -29,12 +29,10 @@ WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [weakdeps] TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" -Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" [extensions] ITensorNetworksNextTensorOperationsExt = "TensorOperations" -ITensorNetworksNextDistributedExt = "Distributed" ITensorNetworksNextDaggerExt = "Dagger" [compat] From dbc60670cb6f4dc27e669ea4a440f44bb3545029 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 2 Mar 2026 14:07:27 -0500 Subject: [PATCH 44/45] The `NestedAlgorithm` abstract type now takes the type of the child algorithm as its only type parameter --- .../AlgorithmsInterfaceExtensions.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index f44cbeb..bfe1d0c 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -149,7 +149,7 @@ end # ============================ NestedAlgorithm ============================================= -abstract type NestedAlgorithm <: Algorithm end +abstract type NestedAlgorithm{Child} <: Algorithm end function nested_algorithm(f::Function, iterable; kwargs...) return DefaultNestedAlgorithm(f, iterable; kwargs...) @@ -202,7 +202,7 @@ from a list of stored algorithms. ChildAlgorithm <: Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, - } <: NestedAlgorithm + } <: NestedAlgorithm{ChildAlgorithm} algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end From fe3ac5c2e7691523ecfbd9be0023e2d3d8ac1b0e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 2 Mar 2026 14:07:59 -0500 Subject: [PATCH 45/45] Simplify parallel code. --- .../ITensorNetworksNextDaggerExt.jl | 41 +++---- .../daggerbeliefpropagation.jl | 105 ++++++++++++------ .../ITensorNetworksNextParallel.jl | 23 ++++ src/ITensorNetworksNextParallel/dagger.jl | 27 +++-- .../beliefpropagationproblem.jl | 18 ++- 5 files changed, 140 insertions(+), 74 deletions(-) diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl index 95c4105..33375c7 100644 --- a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -6,17 +6,24 @@ import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP using Dagger using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel +using Dictionaries: set! -function ITNNP.DaggerNestedAlgorithm(f::Function, iterable; workers = workers(), kwargs...) - return DaggerNestedAlgorithm(; algorithms = map(f, iterable), workers, kwargs...) +function ITNNP.DaggerNestedAlgorithm(f, iterable; kwargs...) + return DaggerNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) end -function initialize_dagger_state(problem::AIE.Problem, algorithm::AIE.Algorithm; iterate) +function ITNNP.dagger_algorithm(f::Base.Callable, iterable; kwargs...) + return DaggerNestedAlgorithm(f, iterable; kwargs...) +end + +function ITNNP.initialize_dagger_state( + problem::AIE.Problem, algorithm::AIE.Algorithm; iterate + ) stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion ) - remote_results = Dict{Int, Dagger.DTask}() + remote_results = Dictionary{Int, Dagger.DTask}() return ITNNP.DaggerState(; iterate, @@ -30,37 +37,23 @@ function AI.initialize_state( algorithm::ITNNP.DaggerNestedAlgorithm; kwargs... ) - return initialize_dagger_state(problem, algorithm; kwargs...) + return ITNNP.initialize_dagger_state(problem, algorithm; kwargs...) end -function AIE.get_subproblem( +function AI.step!( problem::AIE.Problem, algorithm::ITNNP.DaggerNestedAlgorithm, - state::ITNNP.DaggerState + state::ITNNP.DaggerState; + kwargs... ) subproblem = problem subalgorithm = algorithm.algorithms[state.iteration] - # This might be a Dagger.chun object. iterate = ITNNP.get_subiterate(subproblem, subalgorithm, state) - substate = Dagger.@spawn AI.initialize_state(subproblem, subalgorithm; iterate) - - return subproblem, subalgorithm, substate -end - -function AI.step!( - problem::AI.Problem, - algorithm::ITNNP.DaggerNestedAlgorithm, - state::ITNNP.DaggerState; - kwargs... - ) - subproblem, subalgorithm, substate_future = - AIE.get_subproblem(problem, algorithm, state) - - dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm, substate_future) + dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm; iterate) - state.remote_results[state.iteration] = dtask + set!(state.remote_results, state.iteration, dtask) return state end diff --git a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl index 8029fdd..547b6fb 100644 --- a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl +++ b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl @@ -1,33 +1,49 @@ import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP using Dagger -using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, +using DataGraphs: DataGraphs, edge_data, get_edge_data, get_vertex_data, is_edge_assigned, is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph -using Dictionaries: Indices +using Dictionaries: Dictionary, Indices, getindices using Graphs: AbstractEdge, AbstractGraph, dst, edges, src, vertices using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagation, BeliefPropagationCache, BeliefPropagationProblem, BeliefPropagationState, beliefpropagation, - forest_cover_edge_sequence, select_algorithm + forest_cover_edge_sequence, select_algorithm, subcache using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientedges, quotientvertices using NamedGraphs: NamedGraphs -function ITNNP.DaggerBeliefPropagationCache(network::AbstractGraph) +const DaggerBeliefPropagation = BeliefPropagation{<:ITNNP.DaggerNestedAlgorithm}; + +function ITNNP.DaggerBeliefPropagationCache( + network::AbstractGraph; + workers = nothing, + scopes = nothing + ) underlying_cache = BeliefPropagationCache(network) keys = Indices(quotientvertices(underlying_cache)) - workers = Iterators.cycle(Dagger.Distributed.workers()) - worker_dict = similar(keys, Int) + if isnothing(scopes) + workers = isnothing(workers) ? Dagger.Distributed.workers() : workers - for quotient_vertex in keys - worker, workers = Iterators.peel(workers) - worker_dict[quotient_vertex] = worker + sorted_workers = Iterators.take(Iterators.cycle(workers), length(keys)) + + scopes = map(Dagger.ProcessScope, collect(sorted_workers)) + else + if length(keys) != length(scopes) + throw( + ArgumentError( + "Number of provided scopes must match the number of vertex partitions of underlying graph" + ) + ) + end end + scope_dict = Dictionary(keys, scopes) + quotient_chunks = map(keys) do quotient_vertex - worker = worker_dict[quotient_vertex] + scope = scope_dict[quotient_vertex] iterate = subcache(underlying_cache, quotient_vertex) - chunk = Dagger.@mutable worker = worker BeliefPropagationState(; iterate) + chunk = Dagger.@mutable scope = scope BeliefPropagationState(; iterate) return chunk end @@ -42,7 +58,7 @@ function DataGraphs.is_vertex_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, return is_vertex_assigned(bpc.underlying_cache, vertex) end function DataGraphs.is_edge_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, edge) - return is_edge_assigned(bpc.undelying_cache, edge) + return is_edge_assigned(bpc.underlying_cache, edge) end function DataGraphs.get_vertex_data(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) @@ -52,7 +68,7 @@ function DataGraphs.get_edge_data( bpc::ITNNP.DaggerBeliefPropagationCache, edge::AbstractEdge ) - return get_edge_data(bpc.undelying_caches, edge) + return get_edge_data(bpc.underlying_cache, edge) end function DataGraphs.set_vertex_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, vertex) @@ -73,13 +89,11 @@ end function ITensorNetworksNext.beliefpropagation_sweep( cache::ITNNP.DaggerBeliefPropagationCache; edges, - workers = workers(), kwargs... ) - keys = collect(quotientvertices(cache)) - - return ITNNP.dagger_algorithm(keys; keys, workers) do quotient_vertex - subcache = fetch(cache[quotient_vertex]).iterate + return ITNNP.dagger_algorithm(quotientvertices(cache)) do quotient_vertex + substate = fetch(cache[quotient_vertex]) + subcache = substate.iterate subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) @@ -120,8 +134,8 @@ function ITNNP.get_subiterate( end function AIE.set_substate!( - ::BeliefPropagationProblem, - ::AIE.NestedAlgorithm, + problem::BeliefPropagationProblem, + algorithm::AIE.NestedAlgorithm, state::AIE.State, substate::ITNNP.DaggerState ) @@ -129,35 +143,52 @@ function AIE.set_substate!( state.iterate.maxdiff = 0.0 - for remote_result in substate.remote_results - get_maxdiff = dtask -> dtask.iterate.maxdiff - src_maxdiff = fetch(Dagger.@spawn get_maxdiff(remote_result)) - - if src_maxdiff > state.iterate.maxdiff - state.iterate.maxdiff = src_maxdiff - end + maxdiff_dtasks = map(substate.remote_results) do remote_result + return Dagger.spawn(dtask -> dtask.iterate.maxdiff, remote_result) end - function transfer_edges!(dst_chunk, src_chunk, edges) - src_subcache = src_chunk.iterate - dst_subcache = dst_chunk.iterate - for edge in edges - dst_subcache[edge] = src_subcache[edge] - end - return + maxdiff = maximum(fetch, maxdiff_dtasks) + + if maxdiff > state.iterate.maxdiff + state.iterate.maxdiff = maxdiff end transfer_dtasks = map(quotientedges(dst_cache)) do quotient_edge src_subcache = dst_cache[src(quotient_edge)] dst_subcache = dst_cache[dst(quotient_edge)] - return Dagger.@spawn transfer_edges!( + + src_subcache = fetch(src_subcache) + + return Dagger.spawn( dst_subcache, fetch(src_subcache), edges(dst_cache, quotient_edge) - ) + ) do dst, src, edges + src_subcache = src.iterate + dst_subcache = dst.iterate + for edge in edges + dst_subcache[edge] = src_subcache[edge] + end + end end - wait.(transfer_dtasks) + foreach(wait, transfer_dtasks) + + return state +end + +function ITNNP.finalize_state!( + ::BeliefPropagationProblem, + ::BeliefPropagation, + state::ITNNP.DaggerState + ) + dst_cache = state.iterate.iterate + + for quotient_vertex in quotientvertices(dst_cache) + substate = fetch(dst_cache[quotient_vertex]) + subcache = substate.iterate + edge_data(dst_cache) .= edge_data(subcache) + end return state end diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl index 7117641..5d56ea4 100644 --- a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -1,6 +1,10 @@ module ITensorNetworksNextParallel import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE + +abstract type ParallelAlgorithm{Child} <: AIE.NestedAlgorithm{Child} end +const IterativeParallelAlgorithm{Child <: ParallelAlgorithm} = AIE.NestedAlgorithm{Child} """ get_subiterate(subproblem::AI.Problem, subalgorithm::AI.Algorithm, state::AI.State) @@ -11,6 +15,25 @@ The returned value of this function is then pass to a remote call of `initialize """ get_subiterate(::AI.Problem, ::AI.Algorithm, state::AI.State) = state.iterate +finalize_state!(::AI.Problem, ::AI.Algorithm, state::AI.State) = state + +function AI.is_finished!( + problem::AI.Problem, + algorithm::IterativeParallelAlgorithm, + state::AI.State + ) + c = algorithm.stopping_criterion + st = state.stopping_criterion_state + + isfinished = AI.is_finished!(problem, algorithm, state, c, st) + + if isfinished + finalize_state!(problem, algorithm, state) + end + + return isfinished +end + include("dagger.jl") end # ITensorNetworksNextParallel diff --git a/src/ITensorNetworksNextParallel/dagger.jl b/src/ITensorNetworksNextParallel/dagger.jl index 8f79de5..4ed21ad 100644 --- a/src/ITensorNetworksNextParallel/dagger.jl +++ b/src/ITensorNetworksNextParallel/dagger.jl @@ -1,31 +1,40 @@ import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE import AlgorithmsInterface as AI +using Dictionaries: Dictionary using ITensorNetworksNext: AbstractBeliefPropagationCache @kwdef mutable struct DaggerState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Chunk, DTask, + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, DTask, } <: AIE.State iterate::Iterate # DaggerBeliefPropagationCache iteration::Int = 0 stopping_criterion_state::StoppingCriterionState - # remote_subiterates::Dict{Int, Chunk} = Dict{Int, Any}() - remote_results::Dict{Int, DTask} = Dict{Int, Any}() + remote_results::Dictionary{Int, DTask} = Dict{Int, Any}() +end + +function initialize_dagger_state(problem, algorithm; kwargs...) + throw( + ErrorException( + "Package Dagger not loaded. Please install and load the Dagger package." + ) + ) end @kwdef struct DaggerNestedAlgorithm{ ChildAlgorithm <: AIE.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, - KeyType, - } <: AIE.NestedAlgorithm + } <: ParallelAlgorithm{ChildAlgorithm} algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) - workers::Vector{Int} - keys::Vector{KeyType} = collect(1:length(algorithms)) end -function dagger_algorithm(f::Function, iterable; kwargs...) - return DaggerNestedAlgorithm(f, iterable; kwargs...) +function dagger_algorithm(f, iterable; kwargs...) + throw( + ErrorException( + "Package Dagger not loaded. Please install and load the Dagger package." + ) + ) end # ================================== belief propagation ================================== # diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 1a62792..833775f 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -30,8 +30,8 @@ function AI.initialize_state!( end function AI.is_finished!( - ::AIE.Problem, - ::AIE.Algorithm, + problem::AIE.Problem, + algorithm::AIE.Algorithm, state::AIE.State, c::StopWhenConverged, st::StopWhenConvergedState @@ -42,6 +42,16 @@ function AI.is_finished!( st.delta = state.iterate.maxdiff end + return AI.is_finished(problem, algorithm, state, c, st) +end + +function AI.is_finished( + ::AIE.Problem, + ::AIE.Algorithm, + ::AIE.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) return st.delta < c.tol end @@ -60,7 +70,7 @@ end ChildAlgorithm <: AIE.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, - } <: AIE.NestedAlgorithm + } <: AIE.NestedAlgorithm{ChildAlgorithm} algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end @@ -100,7 +110,7 @@ end struct BeliefPropagationSweep{ ChildAlgorithm <: AIE.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, - } <: AIE.NestedAlgorithm + } <: AIE.NestedAlgorithm{ChildAlgorithm} algorithms::Algorithms stopping_criterion::AI.StopAfterIteration function BeliefPropagationSweep(; algorithms)