diff --git a/Project.toml b/Project.toml index 2784ccb..59dea41 100644 --- a/Project.toml +++ b/Project.toml @@ -29,9 +29,11 @@ WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [weakdeps] TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" [extensions] ITensorNetworksNextTensorOperationsExt = "TensorOperations" +ITensorNetworksNextDaggerExt = "Dagger" [compat] AbstractTrees = "0.4.5" @@ -39,15 +41,15 @@ Adapt = "4.3" AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" -DataGraphs = "0.2.7" +DataGraphs = "0.3.1" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" -FunctionImplementations = "0.4" +FunctionImplementations = "0.4.1" Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.14.2" -NamedGraphs = "0.6.9, 0.7, 0.8" +NamedDimsArrays = "0.14.3" +NamedGraphs = "0.10" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" TensorOperations = "5.3.1" diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl new file mode 100644 index 0000000..33375c7 --- /dev/null +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -0,0 +1,63 @@ +module ITensorNetworksNextDaggerExt + +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP +using Dagger +using ITensorNetworksNext.ITensorNetworksNextParallel: + DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel +using Dictionaries: set! + +function ITNNP.DaggerNestedAlgorithm(f, iterable; kwargs...) + return DaggerNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) +end + +function ITNNP.dagger_algorithm(f::Base.Callable, iterable; kwargs...) + return DaggerNestedAlgorithm(f, iterable; kwargs...) 
+end + +function ITNNP.initialize_dagger_state( + problem::AIE.Problem, algorithm::AIE.Algorithm; iterate + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + + remote_results = Dictionary{Int, Dagger.DTask}() + + return ITNNP.DaggerState(; + iterate, + remote_results, + stopping_criterion_state + ) +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::ITNNP.DaggerNestedAlgorithm; + kwargs... + ) + return ITNNP.initialize_dagger_state(problem, algorithm; kwargs...) +end + +function AI.step!( + problem::AIE.Problem, + algorithm::ITNNP.DaggerNestedAlgorithm, + state::ITNNP.DaggerState; + kwargs... + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + + iterate = ITNNP.get_subiterate(subproblem, subalgorithm, state) + + dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm; iterate) + + set!(state.remote_results, state.iteration, dtask) + + return state +end + +include("daggerbeliefpropagation.jl") + +end # ITensorNetworksNextDaggerExt diff --git a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl new file mode 100644 index 0000000..547b6fb --- /dev/null +++ b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl @@ -0,0 +1,194 @@ +import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP +using Dagger +using DataGraphs: DataGraphs, edge_data, get_edge_data, get_vertex_data, is_edge_assigned, + is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph +using Dictionaries: Dictionary, Indices, getindices +using Graphs: AbstractEdge, AbstractGraph, dst, edges, src, vertices +using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagation, BeliefPropagationCache, + BeliefPropagationProblem, BeliefPropagationState, beliefpropagation, + forest_cover_edge_sequence, select_algorithm, subcache +using NamedGraphs.GraphsExtensions: boundary_edges +using 
NamedGraphs.PartitionedGraphs: QuotientVertex, quotientedges, quotientvertices +using NamedGraphs: NamedGraphs + +const DaggerBeliefPropagation = BeliefPropagation{<:ITNNP.DaggerNestedAlgorithm}; + +function ITNNP.DaggerBeliefPropagationCache( + network::AbstractGraph; + workers = nothing, + scopes = nothing + ) + underlying_cache = BeliefPropagationCache(network) + + keys = Indices(quotientvertices(underlying_cache)) + + if isnothing(scopes) + workers = isnothing(workers) ? Dagger.Distributed.workers() : workers + + sorted_workers = Iterators.take(Iterators.cycle(workers), length(keys)) + + scopes = map(Dagger.ProcessScope, collect(sorted_workers)) + else + if length(keys) != length(scopes) + throw( + ArgumentError( + "Number of provided scopes must match the number of vertex partitions of underlying graph" + ) + ) + end + end + + scope_dict = Dictionary(keys, scopes) + + quotient_chunks = map(keys) do quotient_vertex + scope = scope_dict[quotient_vertex] + iterate = subcache(underlying_cache, quotient_vertex) + chunk = Dagger.@mutable scope = scope BeliefPropagationState(; iterate) + return chunk + end + + return ITNNP.DaggerBeliefPropagationCache(underlying_cache, quotient_chunks) +end + +function DataGraphs.underlying_graph(cache::ITNNP.DaggerBeliefPropagationCache) + return underlying_graph(cache.underlying_cache) +end + +function DataGraphs.is_vertex_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) + return is_vertex_assigned(bpc.underlying_cache, vertex) +end +function DataGraphs.is_edge_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, edge) + return is_edge_assigned(bpc.underlying_cache, edge) +end + +function DataGraphs.get_vertex_data(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) + return get_vertex_data(bpc.underlying_cache, vertex) +end +function DataGraphs.get_edge_data( + bpc::ITNNP.DaggerBeliefPropagationCache, + edge::AbstractEdge + ) + return get_edge_data(bpc.underlying_cache, edge) +end + +function 
DataGraphs.set_vertex_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, vertex) + return set_vertex_data!(bpc.underlying_cache, val, vertex) +end +function DataGraphs.set_edge_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, edge) + return set_edge_data!(bpc.underlying_cache, val, edge) +end + +NamedGraphs.to_graph_index(::ITNNP.DaggerBeliefPropagationCache, qv::QuotientVertex) = qv +function DataGraphs.get_index_data( + cache::ITNNP.DaggerBeliefPropagationCache, + qv::QuotientVertex + ) + return cache.quotient_chunks[qv] +end + +function ITensorNetworksNext.beliefpropagation_sweep( + cache::ITNNP.DaggerBeliefPropagationCache; + edges, + kwargs... + ) + return ITNNP.dagger_algorithm(quotientvertices(cache)) do quotient_vertex + substate = fetch(cache[quotient_vertex]) + subcache = substate.iterate + + subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges + incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) + + alg = select_algorithm( + beliefpropagation, + subcache; + # Don't update the incoming messages + edges = setdiff(subcache_edges, incoming_edges), + maxiter = 1, + kwargs... + ) + + return alg + end +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::BeliefPropagation{<:ITNNP.DaggerNestedAlgorithm}; + kwargs... + ) + return ITNNP.initialize_dagger_state(problem, algorithm; kwargs...) 
+end + +function ITNNP.get_subiterate( + ::BeliefPropagationProblem, + ::BeliefPropagation, # Our parallel region runs a small BP + state::ITNNP.DaggerState + ) + cache = state.iterate.iterate + + quotient_vertex = collect(quotientvertices(cache))[state.iteration] + + subiterate = cache[quotient_vertex] + + return subiterate +end + +function AIE.set_substate!( + problem::BeliefPropagationProblem, + algorithm::AIE.NestedAlgorithm, + state::AIE.State, + substate::ITNNP.DaggerState + ) + dst_cache = state.iterate.iterate + + state.iterate.maxdiff = 0.0 + + maxdiff_dtasks = map(substate.remote_results) do remote_result + return Dagger.spawn(dtask -> dtask.iterate.maxdiff, remote_result) + end + + maxdiff = maximum(fetch, maxdiff_dtasks) + + if maxdiff > state.iterate.maxdiff + state.iterate.maxdiff = maxdiff + end + + transfer_dtasks = map(quotientedges(dst_cache)) do quotient_edge + src_subcache = dst_cache[src(quotient_edge)] + dst_subcache = dst_cache[dst(quotient_edge)] + + src_subcache = fetch(src_subcache) + + return Dagger.spawn( + dst_subcache, + fetch(src_subcache), + edges(dst_cache, quotient_edge) + ) do dst, src, edges + src_subcache = src.iterate + dst_subcache = dst.iterate + for edge in edges + dst_subcache[edge] = src_subcache[edge] + end + end + end + + foreach(wait, transfer_dtasks) + + return state +end + +function ITNNP.finalize_state!( + ::BeliefPropagationProblem, + ::BeliefPropagation, + state::ITNNP.DaggerState + ) + dst_cache = state.iterate.iterate + + for quotient_vertex in quotientvertices(dst_cache) + substate = fetch(dst_cache[quotient_vertex]) + subcache = substate.iterate + edge_data(dst_cache) .= edge_data(subcache) + end + + return state +end diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index 5d9561a..bfe1d0c 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ 
b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -149,10 +149,10 @@ end # ============================ NestedAlgorithm ============================================= -abstract type NestedAlgorithm <: Algorithm end +abstract type NestedAlgorithm{Child} <: Algorithm end -function nested_algorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultNestedAlgorithm(f, nalgorithms; kwargs...) +function nested_algorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(f, iterable; kwargs...) end max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) @@ -202,7 +202,7 @@ from a list of stored algorithms. ChildAlgorithm <: Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, - } <: NestedAlgorithm + } <: NestedAlgorithm{ChildAlgorithm} algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index ace4030..dc50876 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -9,4 +9,10 @@ include("contract_network.jl") include("sweeping/utils.jl") include("sweeping/eigenproblem.jl") +include("beliefpropagation/abstractbeliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationproblem.jl") + +include("ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl") + end diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl new file mode 100644 index 0000000..5d56ea4 --- /dev/null +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -0,0 +1,39 @@ +module ITensorNetworksNextParallel + +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE + +abstract type ParallelAlgorithm{Child} <: AIE.NestedAlgorithm{Child} end +const 
IterativeParallelAlgorithm{Child <: ParallelAlgorithm} = AIE.NestedAlgorithm{Child} + +""" + get_subiterate(subproblem::AI.Problem, subalgorithm::AI.Algorithm, state::AI.State) + +For a given `subproblem` and `subalgorithm` of a parent nested algorithm, +derive (from the parent state `state`) the iterate to be used in the associated sub state. +The returned value of this function is then pass to a remote call of `initialize_state`. +""" +get_subiterate(::AI.Problem, ::AI.Algorithm, state::AI.State) = state.iterate + +finalize_state!(::AI.Problem, ::AI.Algorithm, state::AI.State) = state + +function AI.is_finished!( + problem::AI.Problem, + algorithm::IterativeParallelAlgorithm, + state::AI.State + ) + c = algorithm.stopping_criterion + st = state.stopping_criterion_state + + isfinished = AI.is_finished!(problem, algorithm, state, c, st) + + if isfinished + finalize_state!(problem, algorithm, state) + end + + return isfinished +end + +include("dagger.jl") + +end # ITensorNetworksNextParallel diff --git a/src/ITensorNetworksNextParallel/dagger.jl b/src/ITensorNetworksNextParallel/dagger.jl new file mode 100644 index 0000000..4ed21ad --- /dev/null +++ b/src/ITensorNetworksNextParallel/dagger.jl @@ -0,0 +1,47 @@ +import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI +using Dictionaries: Dictionary +using ITensorNetworksNext: AbstractBeliefPropagationCache + +@kwdef mutable struct DaggerState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, DTask, + } <: AIE.State + iterate::Iterate # DaggerBeliefPropagationCache + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState + remote_results::Dictionary{Int, DTask} = Dict{Int, Any}() +end + +function initialize_dagger_state(problem, algorithm; kwargs...) + throw( + ErrorException( + "Package Dagger not loaded. Please install and load the Dagger package." 
+ ) + ) +end + +@kwdef struct DaggerNestedAlgorithm{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: ParallelAlgorithm{ChildAlgorithm} + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end + +function dagger_algorithm(f, iterable; kwargs...) + throw( + ErrorException( + "Package Dagger not loaded. Please install and load the Dagger package." + ) + ) +end + +# ================================== belief propagation ================================== # + +struct DaggerBeliefPropagationCache{ + V, VD, ED, UC <: AbstractBeliefPropagationCache{V, VD, ED}, Chunks, + } <: AbstractBeliefPropagationCache{V, VD, ED} + underlying_cache::UC + quotient_chunks::Chunks +end diff --git a/src/LazyNamedDimsArrays/lazynameddimsarray.jl b/src/LazyNamedDimsArrays/lazynameddimsarray.jl index 4cbd3f9..62774a7 100644 --- a/src/LazyNamedDimsArrays/lazynameddimsarray.jl +++ b/src/LazyNamedDimsArrays/lazynameddimsarray.jl @@ -7,7 +7,7 @@ using WrappedUnions: @wrapped union::Union{A, Mul{LazyNamedDimsArray{T, A}}} end -parenttype(::Type{LazyNamedDimsArray{<:Any, A}}) where {A} = A +parenttype(::Type{LazyNamedDimsArray{T, A}}) where {T, A} = A parenttype(::Type{LazyNamedDimsArray{T}}) where {T} = AbstractNamedDimsArray{T} parenttype(::Type{LazyNamedDimsArray}) = AbstractNamedDimsArray diff --git a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl index 172ec08..44bae0a 100644 --- a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl +++ b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl @@ -5,6 +5,9 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = function symnameddims(symname, dims) return lazy(nameddims(SymbolicArray(symname, denamed.(dims)), name.(dims))) end +function symnameddims(name, ndarray::AbstractNamedDimsArray) + return symnameddims(name, 
Tuple(inds(ndarray))) +end symnameddims(name) = symnameddims(name, ()) using AbstractTrees: AbstractTrees function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index fb661f5..1e25df5 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,27 +1,24 @@ -using Adapt: Adapt, adapt, adapt_structure +using Adapt: Adapt, adapt using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, underlying_graph, - underlying_graph_type, vertex_data +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, set_vertex_data!, + underlying_graph, underlying_graph_type, vertex_data using Dictionaries: Dictionary -using Graphs: Graphs, AbstractEdge, AbstractGraph, Graph, add_edge!, add_vertex!, bfs_tree, - center, dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices -using LinearAlgebra: LinearAlgebra, factorize +using Graphs: Graphs, AbstractEdge, AbstractGraph, add_edge!, add_vertex!, dst, edges, + edgetype, ne, neighbors, nv, rem_edge!, src, vertices +using LinearAlgebra: LinearAlgebra using MacroTools: @capture using NamedDimsArrays: dimnames, inds using NamedGraphs.GraphsExtensions: - directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype, ⊔ -using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree -using SplitApplyCombine: flatten + directed_graph, incident_edges, rem_edges!, similar_graph, vertextype +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end -function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) - rem_edge!(underlying_graph(tn), e) - return tn -end +# Need to be careful about removing edges from tensor networks in case there is a bond +Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -# 
TODO: Define a generic fallback for `AbstractDataGraph`? -DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data") +DataGraphs.edge_data_type(::Type{<:AbstractTensorNetwork}) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -36,10 +33,11 @@ function Graphs.weights(graph::AbstractTensorNetwork) end # Copy -Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") +Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) + Base.keys(tn::AbstractTensorNetwork) = vertices(tn) # TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition, @@ -49,20 +47,7 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# Derived interface, may need to be overloaded -function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) - return underlying_graph_type(data_graph_type(G)) -end - -# AbstractDataGraphs overloads -function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end -function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) 
- return error("Not implemented") -end - -DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") +DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) end @@ -81,40 +66,37 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) return map_vertex_data_preserve_graph(adapt(to), tn) end -function linkinds(tn::AbstractTensorNetwork, edge::Pair) - return linkinds(tn, edgetype(tn)(edge)) -end -function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) - return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) -end -function linkaxes(tn::AbstractTensorNetwork, edge::Pair) +linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge)) +linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) + +function linkaxes(tn::AbstractGraph, edge::Pair) return linkaxes(tn, edgetype(tn)(edge)) end -function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linkaxes(tn::AbstractGraph, edge::AbstractEdge) return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) end -function linknames(tn::AbstractTensorNetwork, edge::Pair) +function linknames(tn::AbstractGraph, edge::Pair) return linknames(tn, edgetype(tn)(edge)) end -function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linknames(tn::AbstractGraph, edge::AbstractEdge) return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) end -function siteinds(tn::AbstractTensorNetwork, v) +function siteinds(tn::AbstractGraph, v) s = inds(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, inds(tn[v′])) end return s end -function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function siteaxes(tn::AbstractGraph, edge::AbstractEdge) s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, axes(tn[v′])) end return s end -function 
sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function sitenames(tn::AbstractGraph, edge::AbstractEdge) s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, dimnames(tn[v′])) @@ -122,8 +104,8 @@ function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) return s end -function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) - vertex_data(tn)[vertex] = value +function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) + set_vertex_data!(tn, value, vertex) return tn end @@ -153,7 +135,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should exist based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork) +function add_missing_edges!(tn::AbstractGraph) foreach(v -> add_missing_edges!(tn, v), vertices(tn)) return tn end @@ -161,7 +143,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should be incident to the vertex `v` # based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork, v) +function add_missing_edges!(tn::AbstractGraph, v) for v′ in vertices(tn) if v ≠ v′ e = v => v′ @@ -175,13 +157,13 @@ end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity. -function fix_edges!(tn::AbstractTensorNetwork) +function fix_edges!(tn::AbstractGraph) foreach(v -> fix_edges!(tn, v), vertices(tn)) return tn end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity at vertex `v`. -function fix_edges!(tn::AbstractTensorNetwork, v) +function fix_edges!(tn::AbstractGraph, v) rem_edges!(tn, incident_edges(tn, v)) add_missing_edges!(tn, v) return tn @@ -215,28 +197,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v) fix_edges!(tn, v) return tn end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. 
function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) graph[vertices(graph)[vertex]] = value return graph end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) - return error("No edge data.") -end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) - return error("No edge data.") -end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented() +Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() # Fix ambiguity error. function Base.setindex!( tn::AbstractTensorNetwork, value, edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger} ) - return error("No edge data.") + return not_implemented() end function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl new file mode 100644 index 0000000..a810c9a --- /dev/null +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -0,0 +1,161 @@ +using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data +using Graphs: AbstractEdge, AbstractGraph +using NamedGraphs.GraphsExtensions: boundary_edges +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent, QuotientVertex + +messages(bp_cache::AbstractGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] + +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] + +deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() +function deletemessage!(bp_cache::AbstractDataGraph, edge) + ms = messages(bp_cache) + delete!(ms, edge) + return bp_cache +end + +function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache)) + for e in edges + 
deletemessage!(bp_cache, e) + end + return bp_cache +end + +setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() +function setmessage!(bp_cache::AbstractDataGraph, edge, message) + setindex!(bp_cache, message, edge) + return bp_cache +end +function setmessage!(bp_cache::QuotientView, edge, message) + setmessages!(parent(bp_cache), QuotientEdge(edge), message) + return bp_cache +end + +function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message) + for e in edges(bp_cache, edge) + setmessage!(parent(bp_cache), e, message[e]) + end + return bp_cache +end +function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) + for e in edges + setmessage!(bpc_dst, e, message(bpc_src, e)) + end + return bpc_dst +end + +factors(bpc::AbstractGraph) = vertex_data(bpc) +factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] +factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) + +factor(bpc::AbstractGraph, vertex) = bpc[vertex] + +setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() +function setfactor!(bpc::AbstractDataGraph, vertex, factor) + fs = factors(bpc) + setindex!(fs, vertex, factor) + return bpc +end + +function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] +end + +function region_scalar(bp_cache::AbstractGraph, vertex) + messages = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + + return (reduce(*, messages) * reduce(*, state))[] +end + +message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) +message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_type(type) + +function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) + return map(v -> region_scalar(bp_cache, v), vertices) +end + +function edge_scalars( + bp_cache::AbstractGraph, + edges = 
edges(undirected_graph(underlying_graph(bp_cache))) + ) + return map(e -> region_scalar(bp_cache, e), edges) +end + +function scalar_factors_quotient(bp_cache::AbstractGraph) + return vertex_scalars(bp_cache), edge_scalars(bp_cache) +end + +function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) + b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in) + b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges + return messages(bp_cache, b_edges) +end + +default_messages(::AbstractGraph) = not_implemented() + +#Adapt interface for changing device +map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es) +function map_messages!(f, bp_cache, es = edges(bp_cache)) + for e in es + setmessage!(bp_cache, e, f(message(bp_cache, e))) + end + return bp_cache +end + +map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs) +function map_factors!(f, bp_cache, vs = vertices(bp_cache)) + for v in vs + setfactor!(bp_cache, v, f(factor(bp_cache, v))) + end + return bp_cache +end + +adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) +adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) + +abstract type AbstractBeliefPropagationCache{V, VD, ED} <: AbstractDataGraph{V, VD, ED} end + +factor_type(bpc::AbstractBeliefPropagationCache) = typeof(bpc) +factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD + +message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) +message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED + +function free_energy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + + if any(t -> real(t) < 0, numerator_terms) + numerator_terms = complex.(numerator_terms) + end + if any(t -> real(t) < 0, denominator_terms) + denominator_terms = 
complex.(denominator_terms) + end + + if any(iszero, denominator_terms) + return -Inf + end + + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) +end +partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) + +function subcache(cache::AbstractBeliefPropagationCache, vertex::QuotientVertex) + return subcache(cache, vertices(cache, vertex)) +end +function subcache(cache::AbstractBeliefPropagationCache, vertices) + subcache = subgraph(cache, vertices) + + for vertex in vertices + for neighbor_vertex in neighbors(cache, vertex) + add_vertex!(subcache, neighbor_vertex) + # Add in necessary messages. + subcache[vertex => neighbor_vertex] = cache[vertex => neighbor_vertex] + subcache[neighbor_vertex => vertex] = cache[neighbor_vertex => vertex] + end + end + + return subcache +end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl new file mode 100644 index 0000000..5d1a31c --- /dev/null +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -0,0 +1,164 @@ +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, + set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, vertex_data_type +using Dictionaries: Dictionary, delete!, getindices, set! +using Graphs: AbstractGraph, connected_components, is_directed, is_tree +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype +using NamedGraphs.GraphsExtensions: + default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph +using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices + +struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: + AbstractBeliefPropagationCache{V, VD, ED} + underlying_graph::G # we only use this for the edges. 
+ factors::Dictionary{V, VD} + messages::Dictionary{E, ED} + function BeliefPropagationCache( + graph::AbstractGraph, + factors::Dictionary, + messages::Dictionary + ) + # Ensure the graph is directed, if not make it directed. + digraph = is_directed(graph) ? graph : directed_graph(graph) + + V = keytype(factors) + VD = eltype(factors) + + E = keytype(messages) + ED = eltype(messages) + + bpc = new{V, VD, ED, E, typeof(digraph)}(digraph, factors, messages) + + for edge in edges(bpc) + get!(() -> default_message(bpc, edge), messages, edge) + end + return bpc + end +end + +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = bpc.underlying_graph + +function DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) + return haskey(bpc.factors, vertex) +end +DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) + +DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = bpc.factors[vertex] +function DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) + return bpc.messages[edge] +end + +function DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) + return set!(bpc.factors, vertex, val) +end +function DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) + return set!(bpc.messages, edge, val) +end + +# These two methods assume `network` behaves llike a tensor network +# (could be e.g. a QuotientView) otherwise how would one know what the factors should be. 
+function BeliefPropagationCache(network::AbstractGraph) + graph = underlying_graph(network) + return BeliefPropagationCache(graph, copy(vertex_data(network))) +end +function BeliefPropagationCache(MT::Type, network::AbstractGraph) + graph = underlying_graph(network) + return BeliefPropagationCache(MT, graph, copy(vertex_data(network))) +end + +function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary) + MT = vertex_data_type(typeof(graph)) + return BeliefPropagationCache(MT, graph, factors) +end +function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Dictionary) + messages = Dictionary{edgetype(graph), MT}() + return BeliefPropagationCache(graph, factors, messages) +end + +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache( + copy(bp_cache.underlying_graph), + copy(bp_cache.factors), + copy(bp_cache.messages) + ) +end + +# TODO: This needs to go in GraphsExtensions +function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) + # All we care about are the edges so the type of the graph doesnt matter + g = NamedGraph(vertices(gi)) + add_edges!(g, edges(gi)) + forests = forest_cover(g) + rv = edgetype(g)[] + for forest in forests + trees = [forest[Vertices(vs)] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) 
+ end
+ end
+ return rv
+end
+
+function induced_subgraph_bpcache(graph, subvertices)
+ underlying_subgraph, vlist =
+ Graphs.induced_subgraph(underlying_graph(graph), subvertices)
+
+ assigned = v -> isassigned(graph, v)
+
+ assigned_subvertices = Iterators.filter(assigned, subvertices)
+ assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph))
+
+ factors = getindices(vertex_data(graph), Indices(assigned_subvertices))
+ messages = getindices(edge_data(graph), Indices(assigned_subedges))
+
+ subgraph = BeliefPropagationCache(underlying_subgraph, factors, messages)
+
+ return subgraph, vlist
+end
+
+function NamedGraphs.induced_subgraph_from_vertices(
+ graph::BeliefPropagationCache,
+ subvertices
+ )
+ return induced_subgraph_bpcache(graph, subvertices)
+end
+
+## PartitionedGraphs
+
+# Take a QuotientView of the underlying graph.
+function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache)
+ graph = underlying_graph(bpc)
+
+ quotient_view = QuotientView(graph)
+
+ factors = map(v -> bpc[QuotientVertex(v)], Indices(vertices(quotient_view)))
+ messages = map(e -> bpc[QuotientEdge(e)], Indices(edges(quotient_view)))
+
+ return BeliefPropagationCache(quotient_view, factors, messages)
+end
+
+function default_message(bpc::BeliefPropagationCache, edge)
+ return default_message(message_type(bpc), bpc[src(edge)], bpc[dst(edge)])
+end
+function default_message(T::Type, src, dst)
+ array = ones(Tuple(inds(src) ∩ inds(dst)))
+ return convert(T, array)
+end
+function default_message(T::Type{<:LazyNamedDimsArray}, src, dst)
+ message = default_message(parenttype(T), src, dst)
+ return convert(T, lazy(message))
+end
+
+NamedGraphs.to_graph_index(::BeliefPropagationCache, vertex::QuotientVertex) = vertex
+# When getting data according to the quotient vertices, take a lazy contraction.
+function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientVertex) + data = collect(map(v -> tn[v], vertices(tn, vertex))) + return mapreduce(lazy, *, data) +end +function DataGraphs.is_graph_index_assigned( + tn::BeliefPropagationCache, + vertex::QuotientVertex + ) + return isassigned(tn, Vertices(vertices(tn, vertex))) +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl new file mode 100644 index 0000000..833775f --- /dev/null +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -0,0 +1,314 @@ +import .AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI +using DataGraphs: edge_data +using Graphs: AbstractEdge, edges, has_edge, vertices +using LinearAlgebra: norm, normalize +using NamedDimsArrays: AbstractNamedDimsArray +using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph +using NamedGraphs.PartitionedGraphs: quotientvertices + +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 = 0.0 +end + +@kwdef mutable struct StopWhenConvergedState <: AI.StoppingCriterionState + delta::Float64 = Inf +end + +function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged) + return StopWhenConvergedState() +end + +function AI.initialize_state!( + ::AIE.Problem, + ::AIE.Algorithm, + ::StopWhenConverged, + st::StopWhenConvergedState + ) + st.delta = Inf + return st +end + +function AI.is_finished!( + problem::AIE.Problem, + algorithm::AIE.Algorithm, + state::AIE.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + + # maxdiff = 0.0 initially, so skip this the first time. 
+ if state.iteration > 0 + st.delta = state.iterate.maxdiff + end + + return AI.is_finished(problem, algorithm, state, c, st) +end + +function AI.is_finished( + ::AIE.Problem, + ::AIE.Algorithm, + ::AIE.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + return st.delta < c.tol +end + +struct BeliefPropagationProblem{Network} <: AIE.Problem + network::Network +end + +@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: + AIE.NonIterativeAlgorithmState + iterate::Iterate + diffs::Diffs = similar(edge_data(iterate), Float64) + maxdiff::Float64 = 0.0 +end + +@kwdef struct BeliefPropagation{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm{ChildAlgorithm} + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end + +function BeliefPropagation(f::Function, niterations::Int; kwargs...) + return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) +end + +abstract type AbstractMessageUpdate <: AIE.NonIterativeAlgorithm end + +struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} <: AbstractMessageUpdate + edge::E + kwargs::Kwargs +end + +function SimpleMessageUpdate( + edge; + normalize = true, + contraction_alg = "exact", + compute_diff = false, + kwargs... + ) + return SimpleMessageUpdate( + edge, + (; normalize, contraction_alg, compute_diff, kwargs...) 
+ ) +end + +function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) + if name in (:edge, :kwargs) + return getfield(alg, name) + else + return getproperty(getfield(alg, :kwargs), name) + end +end + +struct BeliefPropagationSweep{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + } <: AIE.NestedAlgorithm{ChildAlgorithm} + algorithms::Algorithms + stopping_criterion::AI.StopAfterIteration + function BeliefPropagationSweep(; algorithms) + stopping_criterion = AI.StopAfterIteration(length(algorithms)) + return new{eltype(algorithms), typeof(algorithms)}(algorithms, stopping_criterion) + end +end + +function BeliefPropagationSweep(f::Function, edges) + return BeliefPropagationSweep(; algorithms = f.(edges)) +end + +function AI.initialize_state( + ::BeliefPropagationProblem, ::AIE.NonIterativeAlgorithm; iterate, kwargs... + ) + diffs = iterate.diffs + maxdiff = iterate.maxdiff + + return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) +end + +# This gets called at the start of every sweep. 
+function AI.initialize_state!( + ::BeliefPropagationProblem, + ::BeliefPropagationSweep, + iteration_state::AIE.State + ) + iteration_state.iterate.maxdiff = 0.0 + return iteration_state +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + ::BeliefPropagationSweep, + sweep_state::AIE.DefaultState, + noniterative_substate::BeliefPropagationState + ) + sweep_state.iterate = noniterative_substate + + return sweep_state +end + +struct MessageUpdateProblem{Factor, Messages} <: AIE.Problem + factor::Factor + messages::Messages +end + +function AI.solve!( + problem::BeliefPropagationProblem, + algorithm::SimpleMessageUpdate, + state::BeliefPropagationState; + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm) + ) + logger = AI.algorithm_logger() + + cache = state.iterate + edge = algorithm.edge + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) + ) + + new_message = updated_message(algorithm, cache) + + if !isnothing(algorithm.message_diff_function) + diff = algorithm.message_diff_function(new_message, cache[edge]) + + if diff > state.maxdiff + state.maxdiff = diff + end + + state.diffs[edge] = diff + end + + setmessage!(cache, edge, new_message) + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostUpdate) + ) + + return state +end + +default_message_diff_function(m1, m2) = norm(normalize(m1) - normalize(m2)) + +function updated_message(algorithm, cache) + edge = algorithm.edge + + vertex = src(edge) + messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) + + update_problem = MessageUpdateProblem(cache[vertex], messages) + + message_state = AI.solve(update_problem, algorithm; iterate = message(cache, edge)) + + return message_state.iterate +end + +function AI.solve!( + problem::MessageUpdateProblem, + algorithm::SimpleMessageUpdate, + state::AIE.DefaultNonIterativeAlgorithmState; + logging_context_prefix 
= AIE.default_logging_context_prefix(problem, algorithm), + kwargs... + ) + logger = AI.algorithm_logger() + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) + ) + + state.iterate = + contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) + + AI.emit_message( + logger, problem, algorithm, state, Symbol( + logging_context_prefix, + :PreNormalization + ) + ) + + if algorithm.normalize + # TODO: use `sum` not `norm` + message_norm = LinearAlgebra.norm(state.iterate) + if !iszero(message_norm) + state.iterate /= message_norm + end + end + + AI.emit_message( + logger, problem, algorithm, state, + Symbol(logging_context_prefix, :PostNormalization) + ) + + return state +end + +function contract_messages(alg, factor::AbstractArray, messages) + factors = typeof(factor)[factor] + return contract_network(vcat(factors, messages); alg) +end + +function beliefpropagation(network; kwargs...) + return beliefpropagation(BeliefPropagationCache(network), network; kwargs...) +end +function beliefpropagation( + cache::AbstractBeliefPropagationCache, + network = nothing; + kwargs... + ) + problem = BeliefPropagationProblem(network) + + algorithm = select_algorithm(beliefpropagation, cache; kwargs...) + + # The nested algorithms will wrap and manipulate this object. + base_state = BeliefPropagationState(; iterate = cache) + + state = AI.initialize_state(problem, algorithm; iterate = base_state) + + state = AI.solve!(problem, algorithm, state) + + return state.iterate.iterate +end + +function select_algorithm( + ::typeof(beliefpropagation), + cache::AbstractBeliefPropagationCache; + edges = forest_cover_edge_sequence(cache), + maxiter = is_tree(cache) ? 1 : nothing, + tol = -Inf, + message_diff_function = if tol > -Inf + (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) + else + nothing + end, + kwargs... 
+ ) + if isnothing(maxiter) + throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) + end + + stopping_criterion = AI.StopAfterIteration(maxiter) + + if tol > -Inf + stopping_criterion = stopping_criterion | StopWhenConverged(tol) + end + + extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter) + edge_kwargs = rows(extended_kwargs, maxiter) + + return BeliefPropagation(maxiter; stopping_criterion) do repnum + return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...) + end +end + +# A single sweep across the given edges. +function beliefpropagation_sweep(::BeliefPropagationCache; edges, kwargs...) + return BeliefPropagationSweep(edges) do edge + return SimpleMessageUpdate(edge; kwargs...) + end +end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 9c03adf..a371373 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,10 +1,19 @@ +using .LazyNamedDimsArrays: Mul, lazy using Combinatorics: combinations -using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph -using Dictionaries: AbstractDictionary, Indices, dictionary -using Graphs: AbstractSimpleGraph +using DataGraphs.DataGraphsPartitionedGraphsExt +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph, edge_data, get_vertices_data, + vertex_data, vertex_data_type +using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset! +using Graphs: AbstractSimpleGraph, rem_edge!, rem_vertex! 
using NamedDimsArrays: AbstractNamedDimsArray, dimnames -using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype -using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype +using NamedGraphs.GraphsExtensions: + GraphsExtensions, arrange_edge, arranged_edges, vertextype +using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraphs, + QuotientVertex, QuotientVertexVertices, QuotientVertices, departition, + partitioned_vertices, partitionedgraph, quotient_graph, quotient_graph_type, + quotientvertices +using NamedGraphs: + NamedGraphs, NamedEdge, NamedGraph, Vertices, parent_graph_indices, vertextype function _TensorNetwork end @@ -20,12 +29,32 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar end end # This assumes the tensor connectivity matches the graph structure. -function _TensorNetwork(graph::AbstractGraph, tensors) - return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) +function TensorNetwork(graph::AbstractGraph, tensors) + return TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) +end +function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) + tn = _TensorNetwork(graph, tensors) + fix_links!(tn) + return tn +end + +function TensorNetwork{V, VD, UG, Tensors}( + graph::UG + ) where {V, VD, UG <: AbstractGraph{V}, Tensors} + return _TensorNetwork(graph, Tensors()) end -DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) -DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) +# DataGraphs interface + +DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph + +DataGraphs.is_vertex_assigned(tn::TensorNetwork, v) = haskey(tn.tensors, v) +DataGraphs.is_edge_assigned(tn::TensorNetwork, e) = false + +DataGraphs.get_vertex_data(tn::TensorNetwork, v) = tn.tensors[v] + +DataGraphs.set_vertex_data!(tn::TensorNetwork, val, v) = set!(tn.tensors, v, val) + function 
DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -49,13 +78,7 @@ end tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) - tensors = Dictionary(vertices(graph), f.(vertices(graph))) - return TensorNetwork(graph, tensors) -end -function TensorNetwork(graph::AbstractGraph, tensors) - tn = _TensorNetwork(graph, tensors) - fix_links!(tn) - return tn + return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end # Insert trivial links for missing edges, and also check @@ -93,3 +116,95 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) + +function Graphs.connected_components(tn::TensorNetwork) + return Graphs.connected_components(underlying_graph(tn)) +end + +function Graphs.rem_edge!(tn::TensorNetwork, e) + if !has_edge(underlying_graph(tn), e) + return false + end + if !isempty(linkinds(tn, e)) + throw( + ArgumentError( + "cannot remove edge $e due to tensor indices existing on this edge." 
+ )
+ )
+ end
+ rem_edge!(underlying_graph(tn), e)
+ return true
+end
+
+function GraphsExtensions.similar_graph(type::Type{<:TensorNetwork})
+ DT = fieldtype(type, :tensors)
+ empty_dict = DT()
+ return TensorNetwork(similar_graph(underlying_graph_type(type)), empty_dict)
+end
+function GraphsExtensions.similar_graph(tn::TensorNetwork, underlying_graph::AbstractGraph)
+ DT = fieldtype(typeof(tn), :tensors)
+ empty_dict = DT()
+ return _TensorNetwork(underlying_graph, empty_dict)
+end
+
+function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subvertices)
+ return tensornetwork_induced_subgraph(graph, subvertices)
+end
+
+function tensornetwork_induced_subgraph(graph, subvertices)
+ underlying_subgraph, vlist =
+ Graphs.induced_subgraph(underlying_graph(graph), subvertices)
+
+ subgraph = TensorNetwork(underlying_subgraph) do vertex
+ return graph[vertex]
+ end
+
+ return subgraph, vlist
+end
+
+## PartitionedGraphs
+function PartitionedGraphs.quotient_graph(tn::TensorNetwork)
+ ug = quotient_graph(underlying_graph(tn))
+
+ inds = Indices(parent_graph_indices(QuotientVertices(tn)))
+ data = map(v -> tn[QuotientVertex(v)], inds)
+
+ return TensorNetwork(ug, data)
+end
+# TODO: This method should not be required once there is a better
+# DataGraphsPartitionedGraphsExt interface.
+function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork})
+ UG = quotient_graph_type(underlying_graph_type(type))
+ VD = Vector{vertex_data_type(type)}
+ V = vertextype(UG)
+ return TensorNetwork{V, VD, UG, Dictionary{V, VD}}
+end
+
+# Partition the underlying graph of the tensor network; does not affect the data.
+function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) + pg = partitionedgraph(underlying_graph(tn), parts) + return TensorNetwork(pg, copy(vertex_data(tn))) +end + +PartitionedGraphs.departition(tn::TensorNetwork) = tn +function PartitionedGraphs.departition( + tn::TensorNetwork{<:Any, <:Any, <:AbstractPartitionedGraph} + ) + return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) +end + +NamedGraphs.to_graph_index(::TensorNetwork, vertex::QuotientVertex) = vertex +# When getting data according the quotient vertices, take a lazy contraction. +function DataGraphs.get_index_data(tn::TensorNetwork, vertex::QuotientVertex) + data = collect(map(v -> tn[v], vertices(tn, vertex))) + return mapreduce(lazy, *, data) +end +function DataGraphs.is_graph_index_assigned(tn::TensorNetwork, vertex::QuotientVertex) + return isassigned(tn, Vertices(vertices(tn, vertex))) +end + +function PartitionedGraphs.quotientview(tn::TensorNetwork) + qview = QuotientView(underlying_graph(tn)) + tensors = map(qv -> vertex_data(tn)[Indices(qv)], Indices(quotientvertices(tn))) + return TensorNetwork(qview, tensors) +end diff --git a/test/Project.toml b/test/Project.toml index 564db3f..50a58c5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,11 +2,13 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" @@ 
-30,7 +32,7 @@ Graphs = "1.13.1" ITensorBase = "0.5" ITensorNetworksNext = "0.3" NamedDimsArrays = "0.14" -NamedGraphs = "0.6.8, 0.7, 0.8" +NamedGraphs = "0.10" QuadGK = "2.11.2" SafeTestsets = "0.1" Suppressor = "0.2.8" diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl new file mode 100644 index 0000000..d1cca76 --- /dev/null +++ b/test/test_beliefpropagation.jl @@ -0,0 +1,90 @@ +using DiagonalArrays: δ +using Dictionaries: Dictionary, set! +using Graphs: AbstractGraph, dst, edges, src, vertices +using ITensorBase: ITensor, Index, noprime, prime +using ITensorNetworksNext: + ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, partitionfunction +using LinearAlgebra: LinearAlgebra +using NamedDimsArrays: inds, name +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype +using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid +using Test: @test, @testset + +function ising_tensornetwork(g::AbstractGraph, β::Real; h = 0.0) + links = Dictionary( + edges(g), + [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)] + ) + links = merge(links, Dictionary(reverse.(edges(g)), [links[e] for e in edges(g)])) + + # symmetric sqrt of Boltzmann matrix W = exp(β σσ') + sqrt_Ws = Dictionary() + for e in edges(g) + W = [exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h))] + + F = LinearAlgebra.svd(W) + U, S, V = F.U, F.S, F.Vt + @assert U * LinearAlgebra.diagm(S) * V ≈ W + id = [1.0 0.0; 0.0 1.0] + set!(sqrt_Ws, e, id) + set!(sqrt_Ws, reverse(e), U * LinearAlgebra.diagm(S) * V) + end + ts = Dictionary{vertextype(g), ITensor}() + for v in vertices(g) + es = incident_edges(g, v; dir = :in) + #t = ITensor(1.0, physical_inds[v]...) 
* delta([links[e] for e in es]) + t = δ(Float64, Tuple([links[e] for e in es])) + for e in es + t_prime = ITensor(sqrt_Ws[e], (name(links[e]), name(prime(links[e])))) * t + newinds = noprime.(inds(t_prime)) + t = ITensor(parent(t_prime), name.(newinds)) + end + set!(ts, v, t) + end + return TensorNetwork(g, ts) +end +@testset "BeliefPropagation" begin + + #Chain of tensors + dims = (2, 1) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact atol = 1.0e-14 + + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact atol = 1.0e-10 + + #Square lattice Ising model + dims = (3, 3) + g = named_grid(dims) + tn = ising_tensornetwork(g, 0.05, h = 0.5) + bpc = ITensorNetworksNext.BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 50, tol = 1.0e-10) + + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact rtol = 1.0e-4 +end diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index 7dda0c6..b453e76 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -1,3 +1,4 @@ +using BackendSelection: @Algorithm_str, Algorithm using Graphs: edges using 
ITensorBase: Index using ITensorNetworksNext: TensorNetwork, contract_network, linkinds, siteinds @@ -7,6 +8,8 @@ using TensorOperations: TensorOperations using Test: @test, @testset @testset "contract_network" begin + orderalg = alg -> Algorithm"exact"(; order_alg = Algorithm(alg)) + @testset "Contract Vectors of ITensors" begin i, j, k = Index(2), Index(2), Index(5) A = [1.0 1.0; 0.5 1.0][i, j] @@ -14,10 +17,9 @@ using Test: @test, @testset C = [5.0, 1.0][j] D = [-2.0, 3.0, 4.0, 5.0, 1.0][k] - ABCD_1 = contract_network([A, B, C, D]; order_alg = "left_associative") - ABCD_2 = contract_network([A, B, C, D]; order_alg = "eager") - ABCD_3 = contract_network([A, B, C, D]; order_alg = "optimal") - + ABCD_1 = contract_network([A, B, C, D]; alg = orderalg("left_associative")) + ABCD_2 = contract_network([A, B, C, D]; alg = orderalg("eager")) + ABCD_3 = contract_network([A, B, C, D]; alg = orderalg("optimal")) @test ABCD_1 == ABCD_2 == ABCD_3 end @@ -31,9 +33,9 @@ using Test: @test, @testset return randn(Tuple(is)) end - z1 = contract_network(tn; order_alg = "left_associative")[] - z2 = contract_network(tn; order_alg = "eager")[] - z3 = contract_network(tn; order_alg = "optimal")[] + z1 = contract_network(tn; alg = orderalg("left_associative"))[] + z2 = contract_network(tn; alg = orderalg("eager"))[] + z3 = contract_network(tn; alg = orderalg("optimal"))[] @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64) @test abs(z1 - z3) / abs(z1) <= 1.0e3 * eps(Float64)