Skip to content

Commit ecd2663

Browse files
authored
search: add streaming client (#484)
1 parent 0c912e4 commit ecd2663

5 files changed

Lines changed: 630 additions & 0 deletions

File tree

internal/streaming/api.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package streaming
2+
3+
// Progress is an aggregate type representing a progress update.
4+
type Progress struct {
5+
// Done is true if this is a final progress event.
6+
Done bool `json:"done"`
7+
8+
// RepositoriesCount is the number of repositories being searched. It is
9+
// non-nil once the set of repositories has been resolved.
10+
RepositoriesCount *int `json:"repositoriesCount,omitempty"`
11+
12+
// MatchCount is number of non-overlapping matches. If skipped is
13+
// non-empty, then this is a lower bound.
14+
MatchCount int `json:"matchCount"`
15+
16+
// DurationMs is the wall clock time in milliseconds for this search.
17+
DurationMs int `json:"durationMs"`
18+
19+
// Skipped is a description of shards or documents that were skipped. This
20+
// has a deterministic ordering. More important reasons will be listed
21+
// first. If a search is repeated, the final skipped list will be the
22+
// same. However, within a search stream when a new skipped reason is
23+
// found, it may appear anywhere in the list.
24+
Skipped []Skipped `json:"skipped"`
25+
}
26+
27+
// Skipped is a description of shards or documents that were skipped.
28+
type Skipped struct {
29+
// Reason is why a document/shard/repository was skipped. We group counts
30+
// by reason. eg ShardTimeout
31+
Reason SkippedReason `json:"reason"`
32+
// Title is a short message. eg "1,200 timed out".
33+
Title string `json:"title"`
34+
// Message is a message to show the user. Usually includes information
35+
// explaining the reason, count as well as a sample of the missing items.
36+
Message string `json:"message"`
37+
Severity SkippedSeverity `json:"severity"`
38+
// Suggested is a query expression to remedy the skip. eg "archived:yes".
39+
Suggested *SkippedSuggested `json:"suggested,omitempty"`
40+
}
41+
42+
// SkippedSuggested is a query to suggest to the user to resolve the reason
// for skipping.
type SkippedSuggested struct {
	// Title labels the suggestion.
	// NOTE(review): presumably shown to the user next to the query; confirm
	// against the server that emits it.
	Title string `json:"title"`

	// QueryExpression is the query expression being suggested.
	QueryExpression string `json:"queryExpression"`
}
48+
49+
// SkippedReason is an enum for Skipped.Reason.
type SkippedReason string

// Every constant below is declared with an explicit SkippedReason type. A
// constant written as `Name = "value"` inside a const block is an untyped
// string constant, which would let any of these be used where an arbitrary
// string is expected and defeat the purpose of the enum type.
const (
	// DocumentMatchLimit is when we found too many matches in a document, so
	// we stopped searching it.
	DocumentMatchLimit SkippedReason = "document-match-limit"
	// ShardMatchLimit is when we found too many matches in a
	// shard/repository, so we stopped searching it.
	ShardMatchLimit SkippedReason = "shard-match-limit"
	// RepositoryLimit is when we did not search a repository because the set
	// of repositories to search was too large.
	RepositoryLimit SkippedReason = "repository-limit"
	// ShardTimeout is when we ran out of time before searching a
	// shard/repository.
	ShardTimeout SkippedReason = "shard-timeout"
	// RepositoryCloning is when we could not search a repository because it
	// is not cloned.
	RepositoryCloning SkippedReason = "repository-cloning"
	// RepositoryMissing is when we could not search a repository because it
	// is not cloned and we failed to find it on the remote code host.
	RepositoryMissing SkippedReason = "repository-missing"
	// ExcludedFork is when we did not search a repository because it is a
	// fork.
	//
	// NOTE(review): the value is "repository-fork" while ExcludedArchive uses
	// an "excluded-" prefix. It is kept as-is because it is matched against
	// what the server sends; confirm before renaming.
	ExcludedFork SkippedReason = "repository-fork"
	// ExcludedArchive is when we did not search a repository because it is
	// archived.
	ExcludedArchive SkippedReason = "excluded-archive"
)
78+
79+
// SkippedSeverity is an enum for Skipped.Severity.
type SkippedSeverity string

// Both constants are declared with an explicit SkippedSeverity type; without
// it a constant given its own `= "value"` would be an untyped string constant.
const (
	// SeverityInfo is the "info" severity.
	SeverityInfo SkippedSeverity = "info"
	// SeverityWarn is the "warn" severity.
	SeverityWarn SkippedSeverity = "warn"
)

internal/streaming/client.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
package streaming
2+
3+
import (
4+
"bufio"
5+
"bytes"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"net/url"
11+
)
12+
13+
// NewRequest returns an http.Request against the streaming API for query.
//
// The request is a GET of baseURL + "/search/stream" with query passed,
// query-escaped, as the "q" parameter, and with the Accept header set to
// "text/event-stream". Trailing slashes on baseURL are ignored so that
// "https://host" and "https://host/" address the same endpoint.
//
// The error, if any, is the one reported by http.NewRequest.
func NewRequest(baseURL string, query string) (*http.Request, error) {
	// Avoid producing "//search/stream" when baseURL ends in a slash.
	for len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
		baseURL = baseURL[:len(baseURL)-1]
	}

	u := baseURL + "/search/stream?q=" + url.QueryEscape(query)
	req, err := http.NewRequest(http.MethodGet, u, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "text/event-stream")
	return req, nil
}
23+
24+
// Decoder decodes streaming events from a Server Sent Event stream. We only
25+
// support streams which are generated by Sourcegraph. IE this is not a fully
26+
// compliant Server Sent Events decoder.
27+
type Decoder struct {
28+
OnProgress func(*Progress)
29+
OnMatches func([]EventMatch)
30+
OnFilters func([]*EventFilter)
31+
OnAlert func(*EventAlert)
32+
OnError func(*EventError)
33+
OnUnknown func(event, data []byte)
34+
}
35+
36+
func (rr Decoder) ReadAll(r io.Reader) error {
37+
const maxPayloadSize = 10 * 1024 * 1024 // 10mb
38+
scanner := bufio.NewScanner(r)
39+
scanner.Buffer(make([]byte, 0, 4096), maxPayloadSize)
40+
// bufio.ScanLines, except we look for two \n\n which separate events.
41+
split := func(data []byte, atEOF bool) (int, []byte, error) {
42+
if atEOF && len(data) == 0 {
43+
return 0, nil, nil
44+
}
45+
if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
46+
return i + 2, data[:i], nil
47+
}
48+
// If we're at EOF, we have a final, non-terminated event. This should
49+
// be empty.
50+
if atEOF {
51+
return len(data), data, nil
52+
}
53+
// Request more data.
54+
return 0, nil, nil
55+
}
56+
scanner.Split(split)
57+
58+
for scanner.Scan() {
59+
// event: $event\n
60+
// data: json($data)\n\n
61+
data := scanner.Bytes()
62+
nl := bytes.Index(data, []byte("\n"))
63+
if nl < 0 {
64+
return fmt.Errorf("malformed event, no newline: %s", data)
65+
}
66+
67+
eventK, event := splitColon(data[:nl])
68+
dataK, data := splitColon(data[nl+1:])
69+
70+
if !bytes.Equal(eventK, []byte("event")) {
71+
return fmt.Errorf("malformed event, expected event: %s", eventK)
72+
}
73+
if !bytes.Equal(dataK, []byte("data")) {
74+
return fmt.Errorf("malformed event %s, expected data: %s", eventK, dataK)
75+
}
76+
77+
if bytes.Equal(event, []byte("progress")) {
78+
if rr.OnProgress == nil {
79+
continue
80+
}
81+
var d Progress
82+
if err := json.Unmarshal(data, &d); err != nil {
83+
return fmt.Errorf("failed to decode progress payload: %w", err)
84+
}
85+
rr.OnProgress(&d)
86+
} else if bytes.Equal(event, []byte("matches")) {
87+
if rr.OnMatches == nil {
88+
continue
89+
}
90+
var d []eventMatchUnmarshaller
91+
if err := json.Unmarshal(data, &d); err != nil {
92+
return fmt.Errorf("failed to decode matches payload: %w", err)
93+
}
94+
m := make([]EventMatch, 0, len(d))
95+
for _, e := range d {
96+
m = append(m, e.EventMatch)
97+
}
98+
rr.OnMatches(m)
99+
} else if bytes.Equal(event, []byte("filters")) {
100+
if rr.OnFilters == nil {
101+
continue
102+
}
103+
var d []*EventFilter
104+
if err := json.Unmarshal(data, &d); err != nil {
105+
return fmt.Errorf("failed to decode filters payload: %w", err)
106+
}
107+
rr.OnFilters(d)
108+
} else if bytes.Equal(event, []byte("alert")) {
109+
if rr.OnAlert == nil {
110+
continue
111+
}
112+
var d EventAlert
113+
if err := json.Unmarshal(data, &d); err != nil {
114+
return fmt.Errorf("failed to decode alert payload: %w", err)
115+
}
116+
rr.OnAlert(&d)
117+
} else if bytes.Equal(event, []byte("error")) {
118+
if rr.OnError == nil {
119+
continue
120+
}
121+
var d EventError
122+
if err := json.Unmarshal(data, &d); err != nil {
123+
return fmt.Errorf("failed to decode error payload: %w", err)
124+
}
125+
rr.OnError(&d)
126+
} else if bytes.Equal(event, []byte("done")) {
127+
// Always the last event
128+
break
129+
} else {
130+
if rr.OnUnknown == nil {
131+
continue
132+
}
133+
rr.OnUnknown(event, data)
134+
}
135+
}
136+
return scanner.Err()
137+
}
138+
139+
// splitColon splits data at its first ':' and returns the parts before and
// after it, each with surrounding whitespace trimmed. If data contains no
// ':', it returns the trimmed input and nil.
func splitColon(data []byte) ([]byte, []byte) {
	i := bytes.IndexByte(data, ':')
	if i < 0 {
		return bytes.TrimSpace(data), nil
	}
	return bytes.TrimSpace(data[:i]), bytes.TrimSpace(data[i+1:])
}
146+
147+
type eventMatchUnmarshaller struct {
148+
EventMatch
149+
}
150+
151+
func (r *eventMatchUnmarshaller) UnmarshalJSON(b []byte) error {
152+
var typeU struct {
153+
Type MatchType `json:"type"`
154+
}
155+
156+
if err := json.Unmarshal(b, &typeU); err != nil {
157+
return err
158+
}
159+
160+
switch typeU.Type {
161+
case FileMatchType:
162+
r.EventMatch = &EventFileMatch{}
163+
case RepoMatchType:
164+
r.EventMatch = &EventRepoMatch{}
165+
case SymbolMatchType:
166+
r.EventMatch = &EventSymbolMatch{}
167+
case CommitMatchType:
168+
r.EventMatch = &EventCommitMatch{}
169+
default:
170+
return fmt.Errorf("unknown MatchType %v", typeU.Type)
171+
}
172+
return json.Unmarshal(b, r.EventMatch)
173+
}

internal/streaming/client_test.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package streaming
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/google/go-cmp/cmp"
9+
)
10+
11+
func TestDecoder(t *testing.T) {
12+
type Event struct {
13+
Name string
14+
Value interface{}
15+
}
16+
17+
want := []Event{{
18+
Name: "progress",
19+
Value: &Progress{
20+
MatchCount: 5,
21+
},
22+
}, {
23+
Name: "progress",
24+
Value: &Progress{
25+
MatchCount: 10,
26+
},
27+
}, {
28+
Name: "matches",
29+
Value: []EventMatch{
30+
&EventFileMatch{
31+
Type: FileMatchType,
32+
Path: "test",
33+
},
34+
&EventRepoMatch{
35+
Type: RepoMatchType,
36+
Repository: "test",
37+
},
38+
&EventSymbolMatch{
39+
Type: SymbolMatchType,
40+
Path: "test",
41+
},
42+
&EventCommitMatch{
43+
Type: CommitMatchType,
44+
Detail: "test",
45+
},
46+
},
47+
}, {
48+
Name: "filters",
49+
Value: []*EventFilter{{
50+
Value: "filter-1",
51+
}, {
52+
Value: "filter-2",
53+
}},
54+
}, {
55+
Name: "alert",
56+
Value: &EventAlert{
57+
Title: "alert",
58+
},
59+
}, {
60+
Name: "error",
61+
Value: &EventError{
62+
Message: "error",
63+
},
64+
}}
65+
66+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
67+
ew, err := NewWriter(w)
68+
if err != nil {
69+
http.Error(w, err.Error(), http.StatusInternalServerError)
70+
return
71+
}
72+
for _, e := range want {
73+
ew.Event(e.Name, e.Value)
74+
}
75+
ew.Event("done", struct{}{})
76+
}))
77+
defer ts.Close()
78+
79+
req, err := NewRequest(ts.URL, "hello world")
80+
if err != nil {
81+
t.Fatal(err)
82+
}
83+
resp, err := http.DefaultClient.Do(req)
84+
if err != nil {
85+
t.Fatal(err)
86+
}
87+
defer resp.Body.Close()
88+
89+
var got []Event
90+
err = Decoder{
91+
OnProgress: func(d *Progress) {
92+
got = append(got, Event{Name: "progress", Value: d})
93+
},
94+
OnMatches: func(d []EventMatch) {
95+
got = append(got, Event{Name: "matches", Value: d})
96+
},
97+
OnFilters: func(d []*EventFilter) {
98+
got = append(got, Event{Name: "filters", Value: d})
99+
},
100+
OnAlert: func(d *EventAlert) {
101+
got = append(got, Event{Name: "alert", Value: d})
102+
},
103+
OnError: func(d *EventError) {
104+
got = append(got, Event{Name: "error", Value: d})
105+
},
106+
OnUnknown: func(event, data []byte) {
107+
t.Fatalf("got unexpected event: %s %s", event, data)
108+
},
109+
}.ReadAll(resp.Body)
110+
if err != nil {
111+
t.Fatal(err)
112+
}
113+
114+
if d := cmp.Diff(want, got); d != "" {
115+
t.Fatalf("mismatch (-want +got):\n%s", d)
116+
}
117+
}

0 commit comments

Comments
 (0)