diff --git a/CHANGELOG.md b/CHANGELOG.md index 3206f006e..1b2777b8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The following emojis are used to highlight certain changes: - 🛠 `pinning/pinner`: added `Pinner.Close() error`. Close cancels every in-flight operation's context, including streaming goroutines from `RecursiveKeys`, `DirectKeys`, and `InternalPins`, and waits for them to return. A scalar method that observes the cancellation may return `context.Canceled`; a stream interrupted by Close may surface `ErrClosed` on the channel before it closes. After Close returns, every other method returns the new `ErrClosed` sentinel; streaming methods deliver it as `StreamedPin.Err` on a single entry, then close the channel. Close is idempotent and goroutine-safe. **Action required:** downstream `Pinner` implementations must add `Close`. [#1150](https://github.com/ipfs/boxo/pull/1150) - `pinning/pinner/dspinner`: implements `Close`. Close cancels the contexts of in-flight operations, so snapshot iteration in `RecursiveKeys`/`DirectKeys` and DAG fetches in `Pin` bail out promptly instead of draining to completion. Close returns as soon as those operations honor their ctx. Hosts owning the datastore should call `Close` on the pinner before closing the datastore to avoid the use-after-close panic path in stores such as pebble. [#1150](https://github.com/ipfs/boxo/pull/1150) +- `routing/http/types/iter`: added `Limit`, an iterator that caps another iterator at a fixed number of values. ### Changed - upgrade to `go-libp2p-kad-dht` [v0.39.2](https://github.com/libp2p/go-libp2p-kad-dht/releases/tag/v0.39.2) @@ -40,11 +41,14 @@ The following emojis are used to highlight certain changes: See [ipfs/kubo#11254](https://github.com/ipfs/kubo/pull/11254) for a worked example of the call-site update. [#1128](https://github.com/ipfs/boxo/pull/1128) +- ✨ `routing/http/server`: the Delegated Routing server now calls `DelegatedRouter.FindProviders`/`FindPeers` with a limit of `0` (unbounded) and applies the configured records limit itself, after filtering. This is what lets a filtered request still return a full page of results. The server reads the result iterator lazily and closes it once it has enough records, so delegates should return results lazily and stop work when the iterator is closed. A delegate that used the limit to end its walk early will now end it on `Close` instead. + ### Removed ### Fixed - `files`: now builds under `GOOS=js GOARCH=wasm` and `GOOS=wasip1 GOARCH=wasm`. [#935](https://github.com/ipfs/boxo/pull/935) +- `routing/http/server`: filtered `/routing/v1/providers` and `/routing/v1/peers` requests no longer return fewer records than the configured limit. Before, the limit was applied before `filter-addrs`/`filter-protocols` ran, so records dropped by the filters shrank the response. The limit is now applied after filtering. ### Security diff --git a/routing/http/client/client_test.go b/routing/http/client/client_test.go index 9e5ead506..d6fb6f9ec 100644 --- a/routing/http/client/client_test.go +++ b/routing/http/client/client_test.go @@ -403,11 +403,9 @@ func TestClient_FindProviders(t *testing.T) { cid := makeCID() routerResultIter := iter.FromSlice(c.routerResult) - if c.expStreamingResponse { - router.On("FindProviders", mock.Anything, cid, 0).Return(routerResultIter, c.routerErr) - } else { - router.On("FindProviders", mock.Anything, cid, 20).Return(routerResultIter, c.routerErr) - } + // The server always passes 0 (unbounded) to the delegate; it + // enforces records limits itself, after filtering. + router.On("FindProviders", mock.Anything, cid, 0).Return(routerResultIter, c.routerErr) resultIter, err := client.FindProviders(ctx, cid) c.expErrContains.errContains(t, err) @@ -701,11 +699,9 @@ func TestClient_FindPeers(t *testing.T) { } routerResultIter := iter.FromSlice(c.routerResult) - if c.expStreamingResponse { - router.On("FindPeers", mock.Anything, pid, 0).Return(routerResultIter, c.routerErr) - } else { - router.On("FindPeers", mock.Anything, pid, 20).Return(routerResultIter, c.routerErr) - } + // The server always passes 0 (unbounded) to the delegate; it + // enforces records limits itself, after filtering. + router.On("FindPeers", mock.Anything, pid, 0).Return(routerResultIter, c.routerErr) resultIter, err := client.FindPeers(ctx, pid) c.expErrContains.errContains(t, err) diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 1b7a6d254..fe09e1a12 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -69,6 +69,11 @@ type FindProvidersAsyncResponse struct { type DelegatedRouter interface { // FindProviders searches for peers who are able to provide the given [cid.Cid]. // Limit indicates the maximum amount of results to return; 0 means unbounded. + // + // The HTTP server in this package always calls FindProviders with a + // limit of 0 and caps the response itself, after filtering. It consumes + // the iterator lazily and Closes it once enough records are collected, + // so implementations should return results lazily and stop work on Close. FindProviders(ctx context.Context, cid cid.Cid, limit int) (iter.ResultIter[types.Record], error) // Deprecated: historic API from [IPIP-526], may be removed in a future version. @@ -78,6 +83,9 @@ type DelegatedRouter interface { // FindPeers searches for peers who have the provided [peer.ID]. // Limit indicates the maximum amount of results to return; 0 means unbounded. + // + // As with FindProviders, the HTTP server always calls FindPeers with a + // limit of 0 and caps the response itself, after filtering. FindPeers(ctx context.Context, pid peer.ID, limit int) (iter.ResultIter[*types.PeerRecord], error) // GetIPNS searches for an [ipns.Record] for the given [ipns.Name]. @@ -124,18 +132,20 @@ func WithStreamingResultsDisabled() Option { } } -// WithRecordsLimit sets a limit that will be passed to [ContentRouter.FindProviders] -// and [ContentRouter.FindPeers] for non-streaming requests (application/json). -// Default is [DefaultRecordsLimit]. +// WithRecordsLimit caps the number of records returned for non-streaming +// requests (application/json). The server applies the cap after filtering, +// so filtered-out records do not shrink the response. The delegate +// [ContentRouter] is always called with a limit of 0 (unbounded). +// A limit of 0 disables the cap. Default is [DefaultRecordsLimit]. func WithRecordsLimit(limit int) Option { return func(s *server) { s.recordsLimit = limit } } -// WithStreamingRecordsLimit sets a limit that will be passed to [ContentRouter.FindProviders] -// and [ContentRouter.FindPeers] for streaming requests (application/x-ndjson). -// Default is [DefaultStreamingRecordsLimit]. +// WithStreamingRecordsLimit caps the number of records returned for +// streaming requests (application/x-ndjson). See [WithRecordsLimit] for +// how the cap is applied. Default is [DefaultStreamingRecordsLimit]. func WithStreamingRecordsLimit(limit int) Option { return func(s *server) { s.streamingRecordsLimit = limit @@ -272,7 +282,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { } var ( - handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) + handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[types.Record], recordsLimit int, filterAddrs, filterProtocols []string) recordsLimit int ) @@ -287,7 +297,13 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { ctx, cancel := context.WithTimeout(httpReq.Context(), s.routingTimeout) defer cancel() - provIter, err := s.svc.FindProviders(ctx, cid, recordsLimit) + // Pass 0 (unbounded) to the delegate and enforce recordsLimit here, + // after filtering. Passing recordsLimit would let the delegate stop + // early, before filters run, so records dropped by filters would + // shrink the response below recordsLimit. The delegate returns + // results lazily; the limiting iterator closes it once the cap is + // reached. + provIter, err := s.svc.FindProviders(ctx, cid, 0) if err != nil { if errors.Is(err, routing.ErrNotFound) { // handlerFunc takes care of setting the 404 and necessary headers @@ -298,14 +314,15 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { } } - handlerFunc(w, provIter, filterAddrs, filterProtocols) + handlerFunc(w, provIter, recordsLimit, filterAddrs, filterProtocols) } -func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) { +func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], recordsLimit int, filterAddrs, filterProtocols []string) { defer provIter.Close() filteredIter := filters.ApplyFiltersToIter(provIter, filterAddrs, filterProtocols) - providers, err := iter.ReadAllResults(filteredIter) + var limitedIter iter.ResultIter[types.Record] = iter.Limit(filteredIter, recordsLimit) + providers, err := iter.ReadAllResults(limitedIter) if err != nil { writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return @@ -321,10 +338,10 @@ func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIt }) } -func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) { +func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], recordsLimit int, filterAddrs, filterProtocols []string) { filteredIter := filters.ApplyFiltersToIter(provIter, filterAddrs, filterProtocols) - - writeResultsIterNDJSON(w, filteredIter) + var limitedIter iter.ResultIter[types.Record] = iter.Limit(filteredIter, recordsLimit) + writeResultsIterNDJSON(w, limitedIter) } func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { @@ -346,7 +363,7 @@ func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { } var ( - handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[*types.PeerRecord], filterAddrs, filterProtocols []string) + handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[*types.PeerRecord], recordsLimit int, filterAddrs, filterProtocols []string) recordsLimit int ) @@ -362,7 +379,9 @@ func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), s.routingTimeout) defer cancel() - provIter, err := s.svc.FindPeers(ctx, pid, recordsLimit) + // Pass 0 (unbounded) to the delegate and enforce recordsLimit here, + // after filtering. See findProviders for the rationale. + provIter, err := s.svc.FindPeers(ctx, pid, 0) if err != nil { if errors.Is(err, routing.ErrNotFound) { // handlerFunc takes care of setting the 404 and necessary headers @@ -373,7 +392,7 @@ func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { } } - handlerFunc(w, provIter, filterAddrs, filterProtocols) + handlerFunc(w, provIter, recordsLimit, filterAddrs, filterProtocols) } func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { @@ -438,12 +457,13 @@ func (s *server) provide(w http.ResponseWriter, httpReq *http.Request) { writeJSONResult(w, "Provide", resp) } -func (s *server) findPeersJSON(w http.ResponseWriter, peersIter iter.ResultIter[*types.PeerRecord], filterAddrs, filterProtocols []string) { +func (s *server) findPeersJSON(w http.ResponseWriter, peersIter iter.ResultIter[*types.PeerRecord], recordsLimit int, filterAddrs, filterProtocols []string) { defer peersIter.Close() - peersIter = filters.ApplyFiltersToPeerRecordIter(peersIter, filterAddrs, filterProtocols) + filteredIter := filters.ApplyFiltersToPeerRecordIter(peersIter, filterAddrs, filterProtocols) + var limitedIter iter.ResultIter[*types.PeerRecord] = iter.Limit(filteredIter, recordsLimit) - peers, err := iter.ReadAllResults(peersIter) + peers, err := iter.ReadAllResults(limitedIter) if err != nil { writeErr(w, "FindPeers", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err)) return @@ -459,7 +479,7 @@ func (s *server) findPeersJSON(w http.ResponseWriter, peersIter iter.ResultIter[ }) } -func (s *server) findPeersNDJSON(w http.ResponseWriter, peersIter iter.ResultIter[*types.PeerRecord], filterAddrs, filterProtocols []string) { +func (s *server) findPeersNDJSON(w http.ResponseWriter, peersIter iter.ResultIter[*types.PeerRecord], recordsLimit int, filterAddrs, filterProtocols []string) { // Convert PeerRecord to Record so that we can reuse the filtering logic from findProviders mappedIter := iter.Map(peersIter, func(v iter.Result[*types.PeerRecord]) iter.Result[types.Record] { if v.Err != nil || v.Val == nil { @@ -471,7 +491,8 @@ func (s *server) findPeersNDJSON(w http.ResponseWriter, peersIter iter.ResultIte }) filteredIter := filters.ApplyFiltersToIter(mappedIter, filterAddrs, filterProtocols) - writeResultsIterNDJSON(w, filteredIter) + var limitedIter iter.ResultIter[types.Record] = iter.Limit(filteredIter, recordsLimit) + writeResultsIterNDJSON(w, limitedIter) } func (s *server) GetIPNS(w http.ResponseWriter, r *http.Request) { diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index cc17cb5bf..8d9486459 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -46,7 +46,7 @@ func TestHeaders(t *testing.T) { cb, err := cid.Decode(c) require.NoError(t, err) - router.On("FindProviders", mock.Anything, cb, DefaultRecordsLimit). + router.On("FindProviders", mock.Anything, cb, 0). Return(results, nil) resp, err := http.Get(serverAddr + "/routing/v1/providers/" + c) @@ -147,11 +147,9 @@ func TestProviders(t *testing.T) { server := httptest.NewServer(Handler(router)) t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - limit := DefaultRecordsLimit - if expectedStream { - limit = DefaultStreamingRecordsLimit - } - router.On("FindProviders", mock.Anything, cid, limit).Return(results, nil) + // The server enforces records limits itself, after filtering, and + // always passes 0 (unbounded) to the delegate router. + router.On("FindProviders", mock.Anything, cid, 0).Return(results, nil) urlStr := fmt.Sprintf("%s/routing/v1/providers/%s", serverAddr, cidStr) urlStr = filters.AddFiltersToURL(urlStr, strings.Split(filterProtocols, ","), strings.Split(filterAddrs, ",")) @@ -246,7 +244,7 @@ func TestProviders(t *testing.T) { server := httptest.NewServer(Handler(router)) t.Cleanup(server.Close) serverAddr := "http://" + server.Listener.Addr().String() - router.On("FindProviders", mock.Anything, cid, DefaultRecordsLimit).Return(nil, routing.ErrNotFound) + router.On("FindProviders", mock.Anything, cid, 0).Return(nil, routing.ErrNotFound) req, err := http.NewRequest(http.MethodGet, serverAddr+"/routing/v1/providers/"+cidStr, nil) require.NoError(t, err) @@ -263,6 +261,125 @@ func TestProviders(t *testing.T) { }) } +func TestProvidersRecordsLimit(t *testing.T) { + t.Parallel() + + cidStr := "bafkreifjjcie6lypi6ny7amxnfftagclbuxndqonfipmb64f2km2devei4" + c, err := cid.Decode(cidStr) + require.NoError(t, err) + + quicAddr, err := multiaddr.NewMultiaddr("/ip4/127.0.0.1/udp/4001/quic-v1") + require.NoError(t, err) + tcpAddr, err := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + require.NoError(t, err) + + // makeRecords returns n provider records. The first tcpOnly of them + // carry only a TCP address; the rest carry a QUIC address. With + // filter-addrs=quic-v1, only the QUIC records survive filtering. + makeRecords := func(t *testing.T, n, tcpOnly int) []iter.Result[types.Record] { + recs := make([]iter.Result[types.Record], 0, n) + for i := 0; i < n; i++ { + _, pid := makeEd25519PeerID(t) + addr := quicAddr + if i < tcpOnly { + addr = tcpAddr + } + recs = append(recs, iter.Result[types.Record]{ + Val: &types.PeerRecord{ + Schema: types.SchemaPeer, + ID: &pid, + Addrs: []types.Multiaddr{{Multiaddr: addr}}, + }, + }) + } + return recs + } + + const ( + jsonLimit = 5 + streamLimit = 12 + supplied = 30 // exceeds both caps so the cap is the binding limit + ) + + newServerAddr := func(router *mockContentRouter) string { + server := httptest.NewServer(Handler(router, + WithRecordsLimit(jsonLimit), + WithStreamingRecordsLimit(streamLimit), + )) + t.Cleanup(server.Close) + return "http://" + server.Listener.Addr().String() + } + + t.Run("JSON response is capped at WithRecordsLimit", func(t *testing.T) { + t.Parallel() + router := &mockContentRouter{} + // The delegate must be asked for 0 (unbounded); the server caps. + router.On("FindProviders", mock.Anything, c, 0). + Return(iter.FromSlice(makeRecords(t, supplied, 0)), nil) + addr := newServerAddr(router) + + req, err := http.NewRequest(http.MethodGet, addr+"/routing/v1/providers/"+cidStr, nil) + require.NoError(t, err) + req.Header.Set("Accept", mediaTypeJSON) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, jsonLimit, strings.Count(string(body), `"Schema":"peer"`)) + router.AssertExpectations(t) + }) + + t.Run("NDJSON stream is capped at WithStreamingRecordsLimit", func(t *testing.T) { + t.Parallel() + router := &mockContentRouter{} + router.On("FindProviders", mock.Anything, c, 0). + Return(iter.FromSlice(makeRecords(t, supplied, 0)), nil) + addr := newServerAddr(router) + + req, err := http.NewRequest(http.MethodGet, addr+"/routing/v1/providers/"+cidStr, nil) + require.NoError(t, err) + req.Header.Set("Accept", mediaTypeNDJSON) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, streamLimit, strings.Count(string(body), `"Schema":"peer"`)) + router.AssertExpectations(t) + }) + + t.Run("limit counts records that survive filtering", func(t *testing.T) { + t.Parallel() + router := &mockContentRouter{} + // First 10 records are TCP-only and dropped by filter-addrs; the + // remaining 20 carry QUIC and survive. Capping before the filter + // would yield fewer than jsonLimit, so an exact jsonLimit count + // proves the cap is applied after filtering. + router.On("FindProviders", mock.Anything, c, 0). + Return(iter.FromSlice(makeRecords(t, supplied, 10)), nil) + addr := newServerAddr(router) + + urlStr := filters.AddFiltersToURL(addr+"/routing/v1/providers/"+cidStr, nil, []string{"quic-v1"}) + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + require.NoError(t, err) + req.Header.Set("Accept", mediaTypeJSON) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, jsonLimit, strings.Count(string(body), `"Schema":"peer"`)) + router.AssertExpectations(t) + }) +} + func TestPeers(t *testing.T) { makeRequest := func(t *testing.T, router *mockContentRouter, contentType, arg, filterAddrs, filterProtocols string) *http.Response { server := httptest.NewServer(Handler(router)) @@ -296,7 +413,7 @@ func TestPeers(t *testing.T) { results := iter.FromSlice([]iter.Result[*types.PeerRecord]{}) router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil) resp := makeRequest(t, router, mediaTypeJSON, peer.ToCid(pid).String(), "", "") // Per IPIP-0513: Return 200 for empty results @@ -321,7 +438,7 @@ func TestPeers(t *testing.T) { results := iter.FromSlice([]iter.Result[*types.PeerRecord]{}) router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil) // Simulate request with Accept header that includes wildcard match resp := makeRequest(t, router, "text/html,*/*", peer.ToCid(pid).String(), "", "") @@ -343,7 +460,7 @@ func TestPeers(t *testing.T) { results := iter.FromSlice([]iter.Result[*types.PeerRecord]{}) router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil) // Simulate request without Accept header resp := makeRequest(t, router, "", peer.ToCid(pid).String(), "", "") @@ -364,7 +481,7 @@ func TestPeers(t *testing.T) { _, pid := makeEd25519PeerID(t) router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(nil, routing.ErrNotFound) + router.On("FindPeers", mock.Anything, pid, 0).Return(nil, routing.ErrNotFound) // Simulate request without Accept header resp := makeRequest(t, router, "", peer.ToCid(pid).String(), "", "") @@ -399,7 +516,7 @@ func TestPeers(t *testing.T) { }) router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil) libp2pKeyCID := peer.ToCid(pid).String() resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID, "", "") @@ -451,7 +568,7 @@ func TestPeers(t *testing.T) { }) router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil) libp2pKeyCID := peer.ToCid(pid).String() resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID, "tcp", "") @@ -503,7 +620,7 @@ func TestPeers(t *testing.T) { }) router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(results, nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil) libp2pKeyCID := peer.ToCid(pid).String() resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID, "", "transport-bitswap") @@ -529,7 +646,7 @@ func TestPeers(t *testing.T) { results := iter.FromSlice([]iter.Result[*types.PeerRecord]{}) router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultStreamingRecordsLimit).Return(results, nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil) resp := makeRequest(t, router, mediaTypeNDJSON, peer.ToCid(pid).String(), "", "") // Per IPIP-0513: Return 200 for empty results @@ -567,7 +684,7 @@ func TestPeers(t *testing.T) { }) router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultStreamingRecordsLimit).Return(results, nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil) libp2pKeyCID := peer.ToCid(pid).String() resp := makeRequest(t, router, mediaTypeNDJSON, libp2pKeyCID, "", "") @@ -630,7 +747,7 @@ func TestPeers(t *testing.T) { t.Parallel() router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultStreamingRecordsLimit).Return(iter.FromSlice(results), nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(iter.FromSlice(results), nil) resp := makeRequest(t, router, mediaTypeNDJSON, peerIDStr, "", "") require.Equal(t, http.StatusOK, resp.StatusCode) @@ -650,7 +767,7 @@ func TestPeers(t *testing.T) { t.Parallel() router := &mockContentRouter{} - router.On("FindPeers", mock.Anything, pid, DefaultRecordsLimit).Return(iter.FromSlice(results), nil) + router.On("FindPeers", mock.Anything, pid, 0).Return(iter.FromSlice(results), nil) resp := makeRequest(t, router, mediaTypeJSON, peerIDStr, "", "") require.Equal(t, http.StatusOK, resp.StatusCode) diff --git a/routing/http/types/iter/limit.go b/routing/http/types/iter/limit.go new file mode 100644 index 000000000..84eb0042c --- /dev/null +++ b/routing/http/types/iter/limit.go @@ -0,0 +1,35 @@ +package iter + +// Limit returns an iterator that yields at most limit values from iter. +// A limit of 0 or less means no limit, in which case the returned +// iterator behaves like the one passed in. +func Limit[T any](iter Iter[T], limit int) *LimitIter[T] { + return &LimitIter[T]{iter: iter, limit: limit} +} + +// LimitIter is an [Iter] that stops after a fixed number of values, even +// if the underlying iterator has more. Close cascades to that iterator. +type LimitIter[T any] struct { + iter Iter[T] + limit int + count int +} + +func (l *LimitIter[T]) Next() bool { + if l.limit > 0 && l.count >= l.limit { + return false + } + if !l.iter.Next() { + return false + } + l.count++ + return true +} + +func (l *LimitIter[T]) Val() T { + return l.iter.Val() +} + +func (l *LimitIter[T]) Close() error { + return l.iter.Close() +} diff --git a/routing/http/types/iter/limit_test.go b/routing/http/types/iter/limit_test.go new file mode 100644 index 000000000..be1e4ad3c --- /dev/null +++ b/routing/http/types/iter/limit_test.go @@ -0,0 +1,83 @@ +package iter + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLimit(t *testing.T) { + for _, c := range []struct { + name string + input []int + limit int + expResults []int + }{ + { + name: "caps a longer iterator", + input: []int{1, 2, 3, 4, 5}, + limit: 3, + expResults: []int{1, 2, 3}, + }, + { + name: "limit larger than input yields all", + input: []int{1, 2, 3}, + limit: 10, + expResults: []int{1, 2, 3}, + }, + { + name: "limit equal to input yields all", + input: []int{1, 2, 3}, + limit: 3, + expResults: []int{1, 2, 3}, + }, + { + name: "zero limit means unbounded", + input: []int{1, 2, 3}, + limit: 0, + expResults: []int{1, 2, 3}, + }, + { + name: "negative limit means unbounded", + input: []int{1, 2, 3}, + limit: -1, + expResults: []int{1, 2, 3}, + }, + { + name: "empty input yields nothing", + input: []int{}, + limit: 5, + expResults: nil, + }, + } { + t.Run(c.name, func(t *testing.T) { + it := Limit[int](FromSlice(c.input), c.limit) + var res []int + for it.Next() { + res = append(res, it.Val()) + } + assert.Equal(t, c.expResults, res) + }) + } +} + +// closeTrackingIter records whether Close was called, to verify that +// LimitIter cascades Close to the wrapped iterator. +type closeTrackingIter[T any] struct { + Iter[T] + closed bool +} + +func (c *closeTrackingIter[T]) Close() error { + c.closed = true + return c.Iter.Close() +} + +func TestLimitClosesWrappedIter(t *testing.T) { + inner := &closeTrackingIter[int]{Iter: FromSlice([]int{1, 2, 3})} + it := Limit[int](inner, 2) + + require.NoError(t, it.Close()) + require.True(t, inner.closed, "Close must cascade to the wrapped iterator") +}