diff --git a/CHANGELOG.md b/CHANGELOG.md index 3206f006e..a91414b87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The following emojis are used to highlight certain changes: ### Added +- `retrieval`: added `State.Snapshot`, `State.Apply`, and `State.Notify` so consumers can stream `State` across a process boundary, e.g. to drive a live progress bar in Kubo's `cat`, `get`, or `dag export`. [#1153](https://github.com/ipfs/boxo/pull/1153) - 🛠 `pinning/pinner`: added `Pinner.Close() error`. Close cancels every in-flight operation's context, including streaming goroutines from `RecursiveKeys`, `DirectKeys`, and `InternalPins`, and waits for them to return. A scalar method that observes the cancellation may return `context.Canceled`; a stream interrupted by Close may surface `ErrClosed` on the channel before it closes. After Close returns, every other method returns the new `ErrClosed` sentinel; streaming methods deliver it as `StreamedPin.Err` on a single entry, then close the channel. Close is idempotent and goroutine-safe. **Action required:** downstream `Pinner` implementations must add `Close`. [#1150](https://github.com/ipfs/boxo/pull/1150) - `pinning/pinner/dspinner`: implements `Close`. Close cancels the contexts of in-flight operations, so snapshot iteration in `RecursiveKeys`/`DirectKeys` and DAG fetches in `Pin` bail out promptly instead of draining to completion. Close returns as soon as those operations honor their ctx. Hosts owning the datastore should call `Close` on the pinner before closing the datastore to avoid the use-after-close panic path in stores such as pebble. [#1150](https://github.com/ipfs/boxo/pull/1150) @@ -40,6 +41,8 @@ The following emojis are used to highlight certain changes: See [ipfs/kubo#11254](https://github.com/ipfs/kubo/pull/11254) for a worked example of the call-site update. 
[#1128](https://github.com/ipfs/boxo/pull/1128) +- `path/resolver`: `ResolveToLastNode`, `ResolvePath`, and `ResolvePathComponents` now populate `retrieval.State` on the request context when one is attached. They advance the state to `PhasePathResolution`, record the root CID from the input path, and record the terminal CID once resolution completes. Until now only the gateway backends populated these fields, leaving non-gateway callers (CLIs, custom tools) without phase or CID diagnostics on retrieval errors. The new calls are idempotent with the existing gateway-side ones, so behavior on the gateway path is unchanged. [#1153](https://github.com/ipfs/boxo/pull/1153) + ### Removed ### Fixed diff --git a/path/resolver/resolver.go b/path/resolver/resolver.go index 47ce1c5f4..43ef2db24 100644 --- a/path/resolver/resolver.go +++ b/path/resolver/resolver.go @@ -9,6 +9,7 @@ import ( "github.com/ipfs/boxo/fetcher" fetcherhelpers "github.com/ipfs/boxo/fetcher/helpers" "github.com/ipfs/boxo/path" + "github.com/ipfs/boxo/retrieval" cid "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" "github.com/ipld/go-ipld-prime" @@ -81,8 +82,10 @@ func (r *basicResolver) ResolveToLastNode(ctx context.Context, fpath path.Immuta defer span.End() c, remainder := fpath.RootCid(), fpath.Segments()[2:] + enterPathResolution(ctx, c) if len(remainder) == 0 { + setTerminalCid(ctx, c) return c, nil, nil } @@ -121,6 +124,7 @@ func (r *basicResolver) ResolveToLastNode(ctx context.Context, fpath path.Immuta // if last node is not a link, just return it's cid, add path to remainder and return if nd.Kind() != ipld.Kind_Link { + setTerminalCid(ctx, lastCid) // return the cid and the remainder of the path return lastCid, remainder[len(remainder)-depth-1:], nil } @@ -135,6 +139,7 @@ func (r *basicResolver) ResolveToLastNode(ctx context.Context, fpath path.Immuta return cid.Cid{}, nil, fmt.Errorf("path %v resolves to a link that is not a cid link: %v", fpath, lnk) } + setTerminalCid(ctx, clnk.Cid) return clnk.Cid, []string{}, nil } @@ -147,6
+152,7 @@ func (r *basicResolver) ResolvePath(ctx context.Context, fpath path.ImmutablePat defer span.End() c, remainder := fpath.RootCid(), fpath.Segments()[2:] + enterPathResolution(ctx, c) // create a selector to traverse all path segments but only match the last pathSelector := pathLeafSelector(remainder) @@ -158,6 +164,7 @@ func (r *basicResolver) ResolvePath(ctx context.Context, fpath path.ImmutablePat if len(nodes) < 1 { return nil, nil, fmt.Errorf("path %v did not resolve to a node", fpath) } + setTerminalCid(ctx, c) return nodes[len(nodes)-1], cidlink.Link{Cid: c}, nil } @@ -172,11 +179,15 @@ func (r *basicResolver) ResolvePathComponents(ctx context.Context, fpath path.Im defer log.Debugw("resolvePathComponents", "fpath", fpath, "error", err) c, remainder := fpath.RootCid(), fpath.Segments()[2:] + enterPathResolution(ctx, c) // create a selector to traverse and match all path segments pathSelector := pathAllSelector(remainder) - nodes, _, _, err = r.resolveNodes(ctx, c, pathSelector) + nodes, terminal, _, err := r.resolveNodes(ctx, c, pathSelector) + if err == nil && terminal.Defined() { + setTerminalCid(ctx, terminal) + } return nodes, err } @@ -246,3 +257,27 @@ func pathSelector(path []string, ssb builder.SelectorSpecBuilder, reduce func(st func startSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { return otel.Tracer("boxo/path/resolver").Start(ctx, "Path."+name, opts...) } + +// enterPathResolution advances retrieval state into PhasePathResolution and +// records the root CID, when a [retrieval.State] is attached to ctx. It is a +// no-op otherwise. Calls are idempotent: SetPhase is monotonic, and SetRootCID +// is last-write-wins under a mutex. 
+func enterPathResolution(ctx context.Context, root cid.Cid) { + if rs := retrieval.StateFromContext(ctx); rs != nil { + rs.SetPhase(retrieval.PhasePathResolution) + if root.Defined() { + rs.SetRootCID(root) + } + } +} + +// setTerminalCid records the CID of the terminating DAG entity on the resolved +// path, when a [retrieval.State] is attached to ctx. Otherwise it is a no-op. +func setTerminalCid(ctx context.Context, terminal cid.Cid) { + if !terminal.Defined() { + return + } + if rs := retrieval.StateFromContext(ctx); rs != nil { + rs.SetTerminalCID(terminal) + } +} diff --git a/path/resolver/resolver_test.go b/path/resolver/resolver_test.go index bd2d4e718..6e4cbf1f5 100644 --- a/path/resolver/resolver_test.go +++ b/path/resolver/resolver_test.go @@ -13,6 +13,7 @@ import ( dagmock "github.com/ipfs/boxo/ipld/merkledag/test" "github.com/ipfs/boxo/path" "github.com/ipfs/boxo/path/resolver" + "github.com/ipfs/boxo/retrieval" blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-cid" "github.com/ipfs/go-unixfsnode" @@ -265,3 +266,75 @@ func TestResolveToLastNode_MixedSegmentTypes(t *testing.T) { require.Equal(t, 0, len(remainder)) require.True(t, cid.Equals(a.Cid())) } + +// TestRetrievalStatePropagation verifies that the resolver advances +// retrieval.State into PhasePathResolution and records both the root and +// terminal CIDs when a State is attached to the request context. This is what +// lets non-gateway callers (like kubo's CLI) get phase + CID diagnostics for +// free, without each call site having to hand-set them. 
+func TestRetrievalStatePropagation(t *testing.T) { + bsrv := dagmock.Bserv() + + root := randNode() + mid := randNode() + leaf := randNode() + require.NoError(t, mid.AddNodeLink("grandchild", leaf)) + require.NoError(t, root.AddNodeLink("child", mid)) + for _, n := range []*merkledag.ProtoNode{root, mid, leaf} { + require.NoError(t, bsrv.AddBlock(t.Context(), n)) + } + + fetcherFactory := bsfetcher.NewFetcherConfig(bsrv) + fetcherFactory.NodeReifier = unixfsnode.Reify + fetcherFactory.PrototypeChooser = dagpb.AddSupportToChooser(func(lnk ipld.Link, lnkCtx ipld.LinkContext) (ipld.NodePrototype, error) { + if tlnkNd, ok := lnkCtx.LinkNode.(schema.TypedLinkNode); ok { + return tlnkNd.LinkTargetNodePrototype(), nil + } + return basicnode.Prototype.Any, nil + }) + r := resolver.NewBasicResolver(fetcherFactory) + + p, err := path.Join(path.FromCid(root.Cid()), "child", "grandchild") + require.NoError(t, err) + imPath, err := path.NewImmutablePath(p) + require.NoError(t, err) + + t.Run("ResolveToLastNode populates state", func(t *testing.T) { + ctx, rs := retrieval.ContextWithState(t.Context()) + require.Equal(t, retrieval.PhaseInitializing, rs.GetPhase()) + + _, _, err := r.ResolveToLastNode(ctx, imPath) + require.NoError(t, err) + + require.GreaterOrEqual(t, int(rs.GetPhase()), int(retrieval.PhasePathResolution)) + require.True(t, rs.GetRootCID().Equals(root.Cid()), "root CID should match path root") + require.True(t, rs.GetTerminalCID().Equals(leaf.Cid()), "terminal CID should match resolved leaf") + }) + + t.Run("ResolvePath populates state", func(t *testing.T) { + ctx, rs := retrieval.ContextWithState(t.Context()) + + _, _, err := r.ResolvePath(ctx, imPath) + require.NoError(t, err) + + require.GreaterOrEqual(t, int(rs.GetPhase()), int(retrieval.PhasePathResolution)) + require.True(t, rs.GetRootCID().Equals(root.Cid())) + require.True(t, rs.GetTerminalCID().Equals(leaf.Cid())) + }) + + t.Run("CID-only path sets terminal to root", func(t *testing.T) { + ctx, rs := 
retrieval.ContextWithState(t.Context()) + + _, _, err := r.ResolveToLastNode(ctx, path.FromCid(root.Cid())) + require.NoError(t, err) + + require.True(t, rs.GetRootCID().Equals(root.Cid())) + require.True(t, rs.GetTerminalCID().Equals(root.Cid()), + "for /ipfs/ with no path, root and terminal should match") + }) + + t.Run("no state on context is a no-op", func(t *testing.T) { + _, _, err := r.ResolveToLastNode(t.Context(), imPath) + require.NoError(t, err) + }) +} diff --git a/retrieval/state.go b/retrieval/state.go index d83ded79a..9eb5f76f8 100644 --- a/retrieval/state.go +++ b/retrieval/state.go @@ -1,8 +1,21 @@ -// Package retrieval provides state tracking for IPFS content retrieval operations. -// It enables detailed diagnostics about the retrieval process, including which stage -// failed (path resolution, provider discovery, connection, or block retrieval) and -// statistics about provider interactions. This information is particularly useful -// for debugging timeout errors and understanding retrieval performance. +// Package retrieval tracks the state of an IPFS content retrieval as +// it moves through path resolution, provider discovery, connection, +// and data transfer. State lives on the request context and is +// updated by boxo's path resolver, provider query manager, and +// gateway middleware as the retrieval progresses. +// +// Typical consumers: +// +// - boxo/gateway wraps timeout errors with the State (see +// [WrapWithState]) so 504 responses include which phase was +// active and how many providers were found. +// +// - CLI tools like Kubo can mirror a daemon's State into a local +// one via the [State.Snapshot] / [State.Apply] / [State.Notify] +// pub/sub interface to drive a live progress bar during commands +// like cat, get, or dag export. +// +// Attach with [ContextWithState]; read with [StateFromContext]. 
package retrieval import ( @@ -49,7 +62,9 @@ const ( PhaseDataRetrieval ) -// String returns a human-readable name for the retrieval phase. +// String returns a human-readable name for the retrieval phase, used +// in error messages and log output. JSON encoding of phases (in +// [Snapshot]) uses the underlying int. func (p RetrievalPhase) String() string { switch p { case PhaseInitializing: @@ -96,12 +111,20 @@ type State struct { // For /ipfs/cid/path/to/file, rootCID is 'cid' and terminalCID is the CID of 'file' rootCID cid.Cid // First CID in the path terminalCID cid.Cid // CID of terminating DAG entity on the path + + // notify is a size-1, coalescing channel used to wake subscribers when + // the State changes. Writers do a non-blocking send; if a wake-up is + // already pending the send is dropped. Subscribers read the channel and + // then call Snapshot to read the latest values; intermediate updates + // between sends are coalesced into a single wake-up. Always non-nil + // after [NewState]. + notify chan struct{} } // NewState creates a new State initialized to PhaseInitializing. The returned // state is safe for concurrent use. func NewState() *State { - rs := &State{} + rs := &State{notify: make(chan struct{}, 1)} rs.phase.Store(int32(PhaseInitializing)) return rs } @@ -120,6 +143,7 @@ func (rs *State) SetPhase(phase RetrievalPhase) { } // Try to update atomically if rs.phase.CompareAndSwap(current, newPhase) { + rs.signal() return } // If CAS failed, another goroutine updated it, loop will check again @@ -135,24 +159,24 @@ func (rs *State) GetPhase() RetrievalPhase { // appendProviders is a helper to append providers to a sample list with size limit. // Only the first MaxProvidersSampleSize providers are kept to prevent unbounded memory growth. // Duplicate peer IDs are automatically filtered out to ensure each peer appears only once. -// This follows the idiomatic append pattern but operates on internal state. 
+// Signals subscribers if any peer was actually added. func (rs *State) appendProviders(list *[]peer.ID, peerIDs ...peer.ID) { rs.mu.Lock() - defer rs.mu.Unlock() - if len(*list) >= MaxProvidersSampleSize { - return - } + prev := len(*list) for _, peerID := range peerIDs { - // Skip if we already have this peer ID in the list + if len(*list) >= MaxProvidersSampleSize { + break + } if slices.Contains(*list, peerID) { continue } - // Stop if we've reached the sample size limit - if len(*list) >= MaxProvidersSampleSize { - return - } *list = append(*list, peerID) } + changed := len(*list) != prev + rs.mu.Unlock() + if changed { + rs.signal() + } } // AddFoundProvider records a provider peer ID that was discovered during provider search. @@ -190,16 +214,24 @@ func (rs *State) GetFailedProviders() []peer.ID { // This method is safe for concurrent use. func (rs *State) SetRootCID(c cid.Cid) { rs.mu.Lock() - defer rs.mu.Unlock() + changed := !rs.rootCID.Equals(c) rs.rootCID = c + rs.mu.Unlock() + if changed { + rs.signal() + } } // SetTerminalCID sets the terminal CID (CID of terminating DAG entity). // This method is safe for concurrent use. func (rs *State) SetTerminalCID(c cid.Cid) { rs.mu.Lock() - defer rs.mu.Unlock() + changed := !rs.terminalCID.Equals(c) rs.terminalCID = c + rs.mu.Unlock() + if changed { + rs.signal() + } } // GetRootCID returns the root CID (first CID in the path). @@ -218,6 +250,135 @@ func (rs *State) GetTerminalCID() cid.Cid { return rs.terminalCID } +// Snapshot is an immutable copy of a [State] at a point in time. It is +// safe to share across goroutines and to serialize as JSON. Receivers +// (e.g. CLIs reading from a streaming endpoint) reconstitute a local +// State by calling [State.Apply] with the snapshot. +// +// JSON encoding uses Go's default field naming (PascalCase, matching +// the struct fields verbatim). Phase is encoded as the underlying +// integer of [RetrievalPhase] (type RetrievalPhase int). 
Receivers can +// compare against the [PhaseInitializing] / [PhasePathResolution] / +// etc. constants directly, or call [RetrievalPhase.String] for a +// human-readable form. +type Snapshot struct { + Phase RetrievalPhase + ProvidersFound int32 + ProvidersAttempted int32 + ProvidersConnected int32 + FoundProviders []peer.ID + FailedProviders []peer.ID + RootCID cid.Cid + TerminalCID cid.Cid +} + +// Snapshot returns an independent copy of the current state. Slice fields +// are cloned, so callers may freely retain or modify the result without +// affecting the live State. +// +// Consistency: the read takes the State's lock for slices and CIDs, but +// the phase and counter fields ([State.ProvidersFound] etc.) are atomics that +// writers update without the lock. A concurrent writer that mutates an +// atomic field while Snapshot is running may produce a snapshot whose +// atomics are slightly newer than its slices (or vice versa). For +// observability and progress UI use cases this eventual consistency is +// fine; callers needing a single-instant atomic view across all fields +// would need writers to also take the lock, which would slow them down. +func (rs *State) Snapshot() Snapshot { + rs.mu.RLock() + defer rs.mu.RUnlock() + return Snapshot{ + Phase: RetrievalPhase(rs.phase.Load()), + ProvidersFound: rs.ProvidersFound.Load(), + ProvidersAttempted: rs.ProvidersAttempted.Load(), + ProvidersConnected: rs.ProvidersConnected.Load(), + FoundProviders: slices.Clone(rs.foundProviders), + FailedProviders: slices.Clone(rs.failedProviders), + RootCID: rs.rootCID, + TerminalCID: rs.terminalCID, + } +} + +// Apply mirrors a Snapshot onto this State. It is intended for +// receivers that observe a remote State over a transport (e.g. NDJSON +// over HTTP) and want to reflect the remote values into a local State +// that some local UI is observing. Phase progression remains monotonic: +// a snapshot with an earlier phase will not move the local phase +// backwards.
All writes happen under one critical section, so a concurrent +// [State.Snapshot] sees the snapshot in full or not at all, and Apply emits +// exactly one wake-up on [State.Notify]. Direct atomic reads (e.g. ProvidersFound.Load) may still observe it mid-way. +// +// Apply assumes snapshots arrive in causal order from a single +// producer. Out-of-order delivery (e.g. multiple producers, or a +// transport that reorders) is unsupported: counters and CID/slice +// fields are written unconditionally, so a stale snapshot can regress +// them. The retrieval-state pipeline shipped in kubo (one daemon-side +// State, one CLI-side subscriber, ordered NDJSON) satisfies this +// assumption by construction. +func (rs *State) Apply(s Snapshot) { + target := int32(s.Phase) + + rs.mu.Lock() + // The phase is advanced inside the critical section so a concurrent + // Snapshot never pairs the new phase with stale fields. The CAS loop + // keeps it monotonic against SetPhase, which does not take rs.mu. + for { + cur := rs.phase.Load() + if target <= cur { + break + } + if rs.phase.CompareAndSwap(cur, target) { + break + } + } + rs.ProvidersFound.Store(s.ProvidersFound) + rs.ProvidersAttempted.Store(s.ProvidersAttempted) + rs.ProvidersConnected.Store(s.ProvidersConnected) + rs.foundProviders = slices.Clone(s.FoundProviders) + rs.failedProviders = slices.Clone(s.FailedProviders) + rs.rootCID = s.RootCID + rs.terminalCID = s.TerminalCID + rs.mu.Unlock() + rs.signal() +} + +// Notify returns a size-1 channel that is signalled when the State +// changes. Writes are coalescing: if multiple updates happen between +// successive receives, the receiver wakes once and should call +// [State.Snapshot] to observe the latest values. +// +// Lifecycle: the channel never closes. Subscribers should stop +// receiving by other means, typically a context cancellation in the +// surrounding select: +// +// for { +// select { +// case <-ctx.Done(): +// return +// case <-state.Notify(): +// publish(state.Snapshot()) +// } +// } +// +// Single-subscriber: the channel is shared, not fan-out.
If two +// goroutines receive on it, each wake-up goes to one of them +// non-deterministically and the other misses it. To support multiple +// subscribers, fan out via your own goroutine: a single reader on +// Notify that broadcasts to a slice of per-subscriber channels. +func (rs *State) Notify() <-chan struct{} { + return rs.notify +} + +// signal performs a non-blocking send on the notification channel. If the +// channel is full (a wake-up is already pending) the send is dropped. Used +// internally by every State write that observers might care about. +func (rs *State) signal() { + select { + case rs.notify <- struct{}{}: + default: + } +} + // formatPeerIDs converts a slice of peer IDs to a formatted string with a prefix. // Returns empty string if the slice is empty. func formatPeerIDs(peers []peer.ID, prefix string) string { diff --git a/retrieval/state_test.go b/retrieval/state_test.go index dba7c438d..a624ac262 100644 --- a/retrieval/state_test.go +++ b/retrieval/state_test.go @@ -2,9 +2,11 @@ package retrieval import ( "context" + "encoding/json" "errors" "sync" "testing" + "time" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/test" @@ -592,3 +594,177 @@ func TestErrorWithState(t *testing.T) { assert.Len(t, state.GetFailedProviders(), 1) }) } + +// TestStateNotifyAndSnapshot covers the pub/sub additions used by external +// progress UIs that observe a [State] across a process boundary. +func TestStateNotifyAndSnapshot(t *testing.T) { + t.Run("Snapshot returns an immutable copy", func(t *testing.T) { + rs := NewState() + rs.SetPhase(PhaseProviderDiscovery) + rs.ProvidersFound.Store(3) + p1, _ := test.RandPeerID() + p2, _ := test.RandPeerID() + rs.AddFoundProvider(p1) + rs.AddFoundProvider(p2) + + snap := rs.Snapshot() + require.Equal(t, PhaseProviderDiscovery, snap.Phase) + require.Equal(t, int32(3), snap.ProvidersFound) + require.Len(t, snap.FoundProviders, 2) + + // Mutating the snapshot must not affect the live State. 
+ snap.FoundProviders[0] = peer.ID("tampered") + require.NotEqual(t, "tampered", string(rs.GetFoundProviders()[0])) + }) + + t.Run("Notify wakes on phase advance and coalesces", func(t *testing.T) { + rs := NewState() + ch := rs.Notify() + + // First write fills the buffer (size 1). + rs.SetPhase(PhasePathResolution) + // Second write before drain is coalesced. + rs.SetPhase(PhaseProviderDiscovery) + + select { + case <-ch: + // got the (single) wake-up + default: + t.Fatal("expected a notification after SetPhase") + } + + // Channel must be empty now: the second SetPhase was coalesced. + select { + case <-ch: + t.Fatal("notification channel should have been drained") + default: + } + + // Subsequent writes wake again. + p, _ := test.RandPeerID() + rs.AddFoundProvider(p) + select { + case <-ch: + default: + t.Fatal("expected a notification after AddFoundProvider") + } + }) + + t.Run("monotonic SetPhase to the current phase does not signal", func(t *testing.T) { + rs := NewState() + rs.SetPhase(PhaseProviderDiscovery) + <-rs.Notify() // drain the initial advance + + // SetPhase at or below current must not signal. + rs.SetPhase(PhasePathResolution) + rs.SetPhase(PhaseProviderDiscovery) + select { + case <-rs.Notify(): + t.Fatal("non-advancing SetPhase must not wake subscribers") + default: + } + }) + + t.Run("Apply restores remote snapshot into a local State", func(t *testing.T) { + // Build a "remote" state with some content. + remote := NewState() + remote.SetPhase(PhaseConnecting) + remote.ProvidersFound.Store(4) + remote.ProvidersAttempted.Store(2) + remote.ProvidersConnected.Store(1) + p, _ := test.RandPeerID() + remote.AddFoundProvider(p) + + // Receiver: apply the snapshot to a fresh State. 
+ local := NewState() + ch := local.Notify() + local.Apply(remote.Snapshot()) + + require.Equal(t, PhaseConnecting, local.GetPhase()) + require.Equal(t, int32(4), local.ProvidersFound.Load()) + require.Equal(t, int32(1), local.ProvidersConnected.Load()) + require.Len(t, local.GetFoundProviders(), 1) + + select { + case <-ch: + default: + t.Fatal("Apply should signal subscribers") + } + }) + + t.Run("Apply preserves monotonic phase", func(t *testing.T) { + local := NewState() + local.SetPhase(PhaseConnecting) + + stale := Snapshot{Phase: PhasePathResolution} + local.Apply(stale) + require.Equal(t, PhaseConnecting, local.GetPhase(), + "Apply with an earlier phase must not move local phase backwards") + }) + + t.Run("concurrent writers and one subscriber", func(t *testing.T) { + rs := NewState() + ch := rs.Notify() + + const writers = 8 + const writes = 25 + var wg sync.WaitGroup + wg.Add(writers) + for w := 0; w < writers; w++ { + go func() { + defer wg.Done() + for i := 0; i < writes; i++ { + p, _ := test.RandPeerID() + rs.AddFoundProvider(p) + } + }() + } + + // Drain wake-ups while writers are running. We don't expect to see + // every individual write (coalesced), only that we never deadlock + // and we eventually observe at least one wake-up. + got := 0 + done := make(chan struct{}) + go func() { + for range ch { + got++ + if got == 1 { + close(done) + } + } + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("no wake-up received under concurrent writers") + } + wg.Wait() + }) +} + +// TestSnapshotJSONRoundTrip confirms a Snapshot survives a JSON +// marshal/unmarshal/Apply round-trip. RetrievalPhase is encoded as +// its underlying int (RetrievalPhase is `type RetrievalPhase int`); +// the wire format uses Go's default field naming. 
+func TestSnapshotJSONRoundTrip(t *testing.T) { + rs := NewState() + rs.SetPhase(PhaseConnecting) + rs.ProvidersFound.Store(7) + + data, err := json.Marshal(rs.Snapshot()) + require.NoError(t, err) + require.Contains(t, string(data), `"Phase":3`, + "Phase encodes as the underlying int (PhaseConnecting == 3)") + require.Contains(t, string(data), `"ProvidersFound":7`) + + var snap Snapshot + require.NoError(t, json.Unmarshal(data, &snap)) + require.Equal(t, PhaseConnecting, snap.Phase) + require.Equal(t, int32(7), snap.ProvidersFound) + + other := NewState() + other.Apply(snap) + require.Equal(t, PhaseConnecting, other.GetPhase()) + require.Equal(t, int32(7), other.ProvidersFound.Load()) +}