diff --git a/server/e2e/e2e_cdp_reconnect_test.go b/server/e2e/e2e_cdp_reconnect_test.go new file mode 100644 index 00000000..dfa9815f --- /dev/null +++ b/server/e2e/e2e_cdp_reconnect_test.go @@ -0,0 +1,592 @@ +package e2e + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/coder/websocket" + instanceoapi "github.com/onkernel/kernel-images/server/lib/oapi" + "github.com/stretchr/testify/require" +) + +type cdpError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type cdpEnvelope struct { + ID int `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *cdpError `json:"error,omitempty"` + SessionID string `json:"sessionId,omitempty"` +} + +type cdpCall struct { + response chan cdpEnvelope +} + +type cdpEventWaiter struct { + method string + sessionID string + response chan cdpEnvelope +} + +type cdpClient struct { + conn *websocket.Conn + closeMu sync.Mutex + closed bool + closeCh chan struct{} + closeErr error + + mu sync.Mutex + nextID int + pending map[int]cdpCall + waiters []cdpEventWaiter +} + +type cdpExerciseResult struct { + Browser string + Title string + Heading string + Sum int + ReadyState string + ScreenshotBytes int +} + +func newCDPClient(ctx context.Context, wsURL string) (*cdpClient, error) { + conn, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + return nil, err + } + conn.SetReadLimit(100 * 1024 * 1024) + + client := &cdpClient{ + conn: conn, + closeCh: make(chan struct{}), + nextID: 1, + pending: make(map[int]cdpCall), + } + go client.readLoop() + return client, nil +} + +func (c *cdpClient) readLoop() { + for { + _, msg, err := c.conn.Read(context.Background()) + if err != nil { + c.closeWithErr(err) + return + } + + var envelope cdpEnvelope + if err := json.Unmarshal(msg, &envelope); err != nil { + c.closeWithErr(fmt.Errorf("unmarshal CDP message: %w", err)) + return + } + + if envelope.ID != 0 { + c.mu.Lock() + call, ok := c.pending[envelope.ID] + if ok { + delete(c.pending, envelope.ID) + } + c.mu.Unlock() + if ok { + call.response <- envelope + } + continue + } + + c.mu.Lock() + for i, waiter := range c.waiters { + if waiter.method == envelope.Method && waiter.sessionID == envelope.SessionID { + c.waiters = append(c.waiters[:i], c.waiters[i+1:]...) + c.mu.Unlock() + waiter.response <- envelope + goto handled + } + } + c.mu.Unlock() + + handled: + } +} + +func (c *cdpClient) closeWithErr(err error) { + c.closeMu.Lock() + defer c.closeMu.Unlock() + if c.closed { + return + } + c.closed = true + c.closeErr = err + + c.mu.Lock() + for _, call := range c.pending { + close(call.response) + } + c.pending = map[int]cdpCall{} + for _, waiter := range c.waiters { + close(waiter.response) + } + c.waiters = nil + c.mu.Unlock() + + close(c.closeCh) +} + +func (c *cdpClient) Close() { + _ = c.conn.Close(websocket.StatusNormalClosure, "") +} + +func (c *cdpClient) WaitClosed(ctx context.Context) error { + select { + case <-c.closeCh: + return c.closeErr + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *cdpClient) Call(ctx context.Context, method string, params any, sessionID string) (json.RawMessage, error) { + c.closeMu.Lock() + closed := c.closed + c.closeMu.Unlock() + if closed { + return nil, fmt.Errorf("connection closed: %w", c.closeErr) + } + + id := 0 + responseCh := make(chan cdpEnvelope, 1) + + c.mu.Lock() + id = c.nextID + c.nextID++ + c.pending[id] = cdpCall{response: responseCh} + c.mu.Unlock() + + payload := map[string]any{ + "id": id, + "method": method, + "params": params, + } + if sessionID != "" { + payload["sessionId"] = sessionID + } + + writeCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := wsjsonWrite(writeCtx, c.conn, payload); err != nil { + c.mu.Lock() + delete(c.pending, id) + c.mu.Unlock() + return nil, err + } + + select { + case envelope, ok := <-responseCh: + if !ok { + return nil, fmt.Errorf("connection closed while waiting for %s", method) + } + if envelope.Error != nil { + return nil, fmt.Errorf("CDP %s failed: %d %s", method, envelope.Error.Code, envelope.Error.Message) + } + return envelope.Result, nil + case <-ctx.Done(): + c.mu.Lock() + delete(c.pending, id) + c.mu.Unlock() + return nil, ctx.Err() + } +} + +func (c *cdpClient) WaitForEvent(ctx context.Context, method, sessionID string) error { + responseCh := make(chan cdpEnvelope, 1) + + c.mu.Lock() + c.waiters = append(c.waiters, cdpEventWaiter{ + method: method, + sessionID: sessionID, + response: responseCh, + }) + c.mu.Unlock() + + select { + case _, ok := <-responseCh: + if !ok { + return fmt.Errorf("connection closed before event %s", method) + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func wsjsonWrite(ctx context.Context, conn *websocket.Conn, value any) error { + payload, err := json.Marshal(value) + if err != nil { + return err + } + return conn.Write(ctx, websocket.MessageText, payload) +} + +func decodeJSONStringField(raw json.RawMessage, field string) (string, error) { + var values map[string]any + if err := json.Unmarshal(raw, &values); err != nil { + return "", err + } + + value, ok := values[field].(string) + if !ok { + return "", fmt.Errorf("expected field %q in %s", field, string(raw)) + } + return value, nil +} + +func connectAndExerciseCDP(ctx context.Context, wsURL, label string) (*cdpClient, cdpExerciseResult, error) { + client, err := newCDPClient(ctx, wsURL) + if err != nil { + return nil, cdpExerciseResult{}, err + } + + versionRaw, err := client.Call(ctx, "Browser.getVersion", map[string]any{}, "") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + browser, err := decodeJSONStringField(versionRaw, "product") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + targetRaw, err := client.Call(ctx, "Target.createTarget", map[string]any{"url": "about:blank"}, "") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + targetID, err := decodeJSONStringField(targetRaw, "targetId") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + attachRaw, err := client.Call(ctx, "Target.attachToTarget", map[string]any{ + "targetId": targetID, + "flatten": true, + }, "") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + sessionID, err := decodeJSONStringField(attachRaw, "sessionId") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + if _, err := client.Call(ctx, "Page.enable", map[string]any{}, sessionID); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + if _, err := client.Call(ctx, "Runtime.enable", map[string]any{}, sessionID); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + loadCtx, loadCancel := context.WithTimeout(ctx, 15*time.Second) + defer loadCancel() + loadDone := make(chan error, 1) + go func() { + loadDone <- client.WaitForEvent(loadCtx, "Page.loadEventFired", sessionID) + }() + + html := fmt.Sprintf("%s

%s

", label, label) + if _, err := client.Call(ctx, "Page.navigate", map[string]any{ + "url": "data:text/html," + url.PathEscape(html), + }, sessionID); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + if err := <-loadDone; err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + evalRaw, err := client.Call(ctx, "Runtime.evaluate", map[string]any{ + "expression": `JSON.stringify({title:document.title,heading:document.querySelector("h1").textContent,sum:window.sum,ready:document.readyState})`, + "returnByValue": true, + }, sessionID) + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + var evalEnvelope struct { + Result struct { + Value string `json:"value"` + } `json:"result"` + } + if err := json.Unmarshal(evalRaw, &evalEnvelope); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + var summary struct { + Title string `json:"title"` + Heading string `json:"heading"` + Sum int `json:"sum"` + Ready string `json:"ready"` + } + if err := json.Unmarshal([]byte(evalEnvelope.Result.Value), &summary); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + screenshotRaw, err := client.Call(ctx, "Page.captureScreenshot", map[string]any{"format": "png"}, sessionID) + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + var screenshotEnvelope struct { + Data string `json:"data"` + } + if err := json.Unmarshal(screenshotRaw, &screenshotEnvelope); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + screenshotBytes, err := base64.StdEncoding.DecodeString(screenshotEnvelope.Data) + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + if _, err := client.Call(ctx, "Target.closeTarget", map[string]any{"targetId": targetID}, ""); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + return client, cdpExerciseResult{ + Browser: browser, + Title: summary.Title, + Heading: summary.Heading, + Sum: summary.Sum, + ReadyState: summary.Ready, + ScreenshotBytes: len(screenshotBytes), + }, nil +} + +func restartChromiumViaAPI(ctx context.Context, client *instanceoapi.ClientWithResponses) error { + args := []string{"-c", "/etc/supervisor/supervisord.conf", "restart", "chromium"} + req := instanceoapi.ProcessExecJSONRequestBody{ + Command: "supervisorctl", + Args: &args, + } + + rsp, err := client.ProcessExecWithResponse(ctx, req) + if err != nil { + return err + } + if rsp.JSON200 == nil { + return fmt.Errorf("restart chromium returned status %s", rsp.Status()) + } + return nil +} + +func waitForContainerFile(ctx context.Context, client *instanceoapi.ClientWithResponses, path string, timeout time.Duration) error { + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + args := []string{"-lc", fmt.Sprintf("test -f %q", path)} + req := instanceoapi.ProcessExecJSONRequestBody{ + Command: "sh", + Args: &args, + } + reqCtx, reqCancel := context.WithTimeout(waitCtx, 2*time.Second) + rsp, err := client.ProcessExecWithResponse(reqCtx, req) + reqCancel() + if err == nil && rsp.JSON200 != nil && rsp.JSON200.ExitCode != nil && *rsp.JSON200.ExitCode == 0 { + return nil + } + + select { + case <-waitCtx.Done(): + return waitCtx.Err() + case <-ticker.C: + } + } +} + +func touchContainerFile(ctx context.Context, client *instanceoapi.ClientWithResponses, path string) error { + args := []string{"-lc", fmt.Sprintf("mkdir -p %q && touch %q", filepath.Dir(path), path)} + req := instanceoapi.ProcessExecJSONRequestBody{ + Command: "sh", + Args: &args, + } + reqCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + rsp, err := client.ProcessExecWithResponse(reqCtx, req) + if err != nil { + return err + } + if rsp.JSON200 == nil { + return fmt.Errorf("touch %s returned status %s", path, rsp.Status()) + } + if rsp.JSON200.ExitCode == nil || *rsp.JSON200.ExitCode != 0 { + return fmt.Errorf("touch %s failed with exit code %v", path, rsp.JSON200.ExitCode) + } + return nil +} + +func fetchBrowserWebSocketURL(ctx context.Context, c *TestContainer) (string, error) { + versionURL := fmt.Sprintf("http://127.0.0.1:%d/json/version", c.CDPPort) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL, nil) + if err != nil { + return "", err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("json/version returned %s", resp.Status) + } + + var payload struct { + WebSocketDebuggerURL string `json:"webSocketDebuggerUrl"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return "", err + } + if payload.WebSocketDebuggerURL == "" { + return "", fmt.Errorf("json/version missing webSocketDebuggerUrl") + } + return payload.WebSocketDebuggerURL, nil +} + +func waitForChangedBrowserWebSocketURL(ctx context.Context, c *TestContainer, previous string, timeout time.Duration) (string, error) { + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + + for { + current, err := fetchBrowserWebSocketURL(waitCtx, c) + if err == nil && current != "" && current != previous { + return current, nil + } + + select { + case <-waitCtx.Done(): + if err != nil { + return "", fmt.Errorf("waiting for changed browser websocket url: %w", err) + } + return "", waitCtx.Err() + case <-ticker.C: + } + } +} + +func TestCDPProxyReconnectPendingConnectionDuringRestart(t *testing.T) { + t.Parallel() + + if _, err := exec.LookPath("docker"); err != nil { + t.Skipf("docker not available: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute) + defer cancel() + + const hookPath = "/tmp/devtoolsproxy-race" + c := NewTestContainer(t, headlessImage) + require.NoError(t, c.Start(ctx, ContainerConfig{ + Env: map[string]string{ + "DEVTOOLS_PROXY_TEST_POST_CURRENT_BLOCK_FILE": hookPath, + }, + }), "failed to start container") + defer c.Stop(ctx) + + require.NoError(t, c.WaitReady(ctx), "api not ready") + require.NoError(t, c.WaitDevTools(ctx), "devtools not ready") + initialBrowserWS, err := fetchBrowserWebSocketURL(ctx, c) + require.NoError(t, err, "failed to fetch initial browser websocket URL") + + apiClient, err := c.APIClientNoKeepAlive() + require.NoError(t, err) + + initialConn, initialResult, err := connectAndExerciseCDP(ctx, c.CDPURL(), "before-restart") + require.NoError(t, err, "initial CDP session failed") + defer initialConn.Close() + + require.Equal(t, "before-restart", initialResult.Title) + require.Equal(t, "before-restart", initialResult.Heading) + require.Equal(t, 21, initialResult.Sum) + require.Equal(t, "complete", initialResult.ReadyState) + require.Greater(t, initialResult.ScreenshotBytes, 1000) + + type reconnectResult struct { + client *cdpClient + data cdpExerciseResult + err error + } + reconnectCh := make(chan reconnectResult, 1) + reconnectURL := c.CDPURL() + "?devtoolsProxyTestHook=1" + go func() { + client, data, err := connectAndExerciseCDP(ctx, reconnectURL, "after-restart") + reconnectCh <- reconnectResult{client: client, data: data, err: err} + }() + + require.NoError(t, waitForContainerFile(ctx, apiClient, hookPath+".ready", 15*time.Second), "pending reconnect never reached post-current hook") + + require.NoError(t, restartChromiumViaAPI(ctx, apiClient), "restart chromium failed") + + closeCtx, closeCancel := context.WithTimeout(ctx, 30*time.Second) + defer closeCancel() + closeErr := initialConn.WaitClosed(closeCtx) + require.Error(t, closeErr, "expected initial CDP connection to close on chromium restart") + require.False(t, errors.Is(closeErr, context.DeadlineExceeded), "timed out waiting for initial CDP connection to close") + + _, err = initialConn.Call(ctx, "Browser.getVersion", map[string]any{}, "") + require.Error(t, err, "expected stale CDP connection to reject commands after restart") + + updatedBrowserWS, err := waitForChangedBrowserWebSocketURL(ctx, c, initialBrowserWS, 20*time.Second) + require.NoError(t, err, "proxy never exposed a new browser websocket URL after restart") + require.NotEqual(t, initialBrowserWS, updatedBrowserWS) + require.NoError(t, touchContainerFile(ctx, apiClient, hookPath+".release"), "failed to release pending reconnect") + + select { + case reconnect := <-reconnectCh: + require.NoError(t, reconnect.err, "reconnect CDP session failed") + if reconnect.client != nil { + defer reconnect.client.Close() + } + require.Equal(t, "after-restart", reconnect.data.Title) + require.Equal(t, "after-restart", reconnect.data.Heading) + require.Equal(t, 21, reconnect.data.Sum) + require.Equal(t, "complete", reconnect.data.ReadyState) + require.Greater(t, reconnect.data.ScreenshotBytes, 1000) + case <-time.After(30 * time.Second): + t.Fatal("timed out waiting for reconnect CDP session") + } +} diff --git a/server/lib/devtoolsproxy/proxy.go b/server/lib/devtoolsproxy/proxy.go index 082dd806..9c762b35 100644 --- a/server/lib/devtoolsproxy/proxy.go +++ b/server/lib/devtoolsproxy/proxy.go @@ -8,6 +8,7 @@ import ( "log/slog" "net/http" "net/url" + "os" "os/exec" "regexp" "strconv" @@ -194,23 +195,112 @@ func (u *UpstreamManager) runTailOnce(ctx context.Context) { } } -// WebSocketProxyHandler returns an http.Handler that upgrades incoming connections and -// proxies them to the current upstream websocket URL. It expects only websocket requests. -// If logCDPMessages is true, all CDP messages will be logged with their direction. -func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool, ctrl scaletozero.Controller) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upstreamCurrent := mgr.Current() - if upstreamCurrent == "" { - http.Error(w, "upstream not ready", http.StatusServiceUnavailable) +func dialUpstreamWithRetry(ctx context.Context, mgr *UpstreamManager, urlCh <-chan string, initialUpstreamURL string, dialOpts *websocket.DialOptions, logger *slog.Logger) (*websocket.Conn, string, error) { + upstreamURL := normalizeUpstreamURL(initialUpstreamURL) + if upstreamURL == "" { + return nil, "", fmt.Errorf("upstream not ready") + } + + deadline := time.NewTimer(5 * time.Second) + defer deadline.Stop() + + for { + upstreamConn, _, err := websocket.Dial(ctx, upstreamURL, dialOpts) + if err == nil { + return upstreamConn, upstreamURL, nil + } + + logger.Warn("dial upstream failed, checking for newer URL", + slog.String("err", err.Error()), slog.String("url", upstreamURL)) + + latestURL := normalizeUpstreamURL(mgr.Current()) + if latestURL != "" && latestURL != upstreamURL { + upstreamURL = latestURL + continue + } + + select { + case newURL, ok := <-urlCh: + if !ok { + return nil, "", fmt.Errorf("upstream unavailable") + } + newURL = normalizeUpstreamURL(newURL) + if newURL == "" || newURL == upstreamURL { + continue + } + upstreamURL = newURL + case <-deadline.C: + return nil, "", fmt.Errorf("timed out waiting for new upstream URL") + case <-ctx.Done(): + return nil, "", ctx.Err() + } + } +} + +func maybePauseAfterCurrentRead(ctx context.Context, logger *slog.Logger, r *http.Request) { + if r.URL.Query().Get("devtoolsProxyTestHook") != "1" { + return + } + + // Test-only hook used by e2e to widen the window between reading Current + // and dialing/subscribing so reconnect races can be reproduced reliably. + rawDelayMs := os.Getenv("DEVTOOLS_PROXY_TEST_POST_CURRENT_DELAY_MS") + if rawDelayMs != "" { + delayMs, err := strconv.Atoi(rawDelayMs) + if err != nil || delayMs <= 0 { + logger.Warn("ignoring invalid devtools proxy test delay", slog.String("value", rawDelayMs)) + } else { + timer := time.NewTimer(time.Duration(delayMs) * time.Millisecond) + defer timer.Stop() + + select { + case <-timer.C: + case <-ctx.Done(): + return + } + } + } + + blockPath := os.Getenv("DEVTOOLS_PROXY_TEST_POST_CURRENT_BLOCK_FILE") + if blockPath == "" { + return + } + + readyPath := blockPath + ".ready" + releasePath := blockPath + ".release" + if err := os.WriteFile(readyPath, []byte("ready\n"), 0o644); err != nil { + logger.Warn("failed to write devtools proxy test ready marker", + slog.String("path", readyPath), + slog.String("err", err.Error())) + return + } + + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + if _, err := os.Stat(releasePath); err == nil { + return + } else if !os.IsNotExist(err) { + logger.Warn("failed to read devtools proxy test release marker", + slog.String("path", releasePath), + slog.String("err", err.Error())) return } - parsed, err := url.Parse(upstreamCurrent) - if err != nil { - http.Error(w, "invalid upstream", http.StatusInternalServerError) + + select { + case <-ticker.C: + case <-ctx.Done(): return } - upstreamURL := (&url.URL{Scheme: parsed.Scheme, Host: parsed.Host, Path: parsed.Path, RawQuery: parsed.RawQuery}).String() + } +} +// WebSocketProxyHandler returns an http.Handler that upgrades incoming connections and +// proxies them to the current upstream websocket URL. It expects only websocket requests. +// If logCDPMessages is true, all CDP messages will be logged with their direction. +func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool, ctrl scaletozero.Controller) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var transform wsproxy.MessageTransform if logCDPMessages { transform = func(direction string, mt websocket.MessageType, msg []byte) []byte { @@ -226,15 +316,93 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess dialOpts := &websocket.DialOptions{ CompressionMode: websocket.CompressionContextTakeover, } - wsproxy.Proxy(w, r, upstreamURL, wsproxy.ProxyOptions{ - AcceptOptions: acceptOpts, - DialOptions: dialOpts, - Logger: logger, - Transform: transform, - }) + + // Subscribe to upstream URL changes so we can tear down stale sessions + // when Chromium restarts and retry if the current URL is already dead. + urlCh, unsub := mgr.Subscribe() + defer unsub() + + upstreamCurrent := mgr.Current() + if upstreamCurrent == "" { + http.Error(w, "upstream not ready", http.StatusServiceUnavailable) + return + } + maybePauseAfterCurrentRead(r.Context(), logger, r) + + // Accept the client WebSocket connection. + clientConn, err := websocket.Accept(w, r, acceptOpts) + if err != nil { + logger.Error("websocket accept failed", slog.String("err", err.Error())) + return + } + clientConn.SetReadLimit(100 * 1024 * 1024) + + // Dial upstream. If the URL is stale (Chromium just restarted), first + // re-check the manager's latest URL in case we missed the notification, + // then wait briefly for the next update from Subscribe. + upstreamConn, upstreamURL, err := dialUpstreamWithRetry(r.Context(), mgr, urlCh, upstreamCurrent, dialOpts, logger) + if err != nil { + switch { + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(r.Context().Err(), context.Canceled), errors.Is(r.Context().Err(), context.DeadlineExceeded): + clientConn.Close(websocket.StatusGoingAway, "request cancelled") + default: + logger.Error("failed to connect to upstream", slog.String("err", err.Error())) + clientConn.Close(websocket.StatusInternalError, "upstream unavailable") + } + return + } + upstreamConn.SetReadLimit(100 * 1024 * 1024) + + logger.Debug("proxying websocket", slog.String("url", upstreamURL)) + + // Cancel the pump when the upstream URL changes (Chromium restarted), + // forcing the client to reconnect with the new upstream. + pumpCtx, pumpCancel := context.WithCancel(r.Context()) + + go func(currentUpstreamURL string) { + for { + select { + case newURL, ok := <-urlCh: + if !ok { + return + } + newURL = normalizeUpstreamURL(newURL) + if newURL == "" || newURL == currentUpstreamURL { + continue + } + logger.Info("upstream URL changed, closing stale proxy session", + slog.String("old_url", currentUpstreamURL), + slog.String("new_url", newURL)) + pumpCancel() + return + case <-pumpCtx.Done(): + return + } + } + }(upstreamURL) + + var once sync.Once + cleanup := func() { + once.Do(func() { + pumpCancel() + upstreamConn.Close(websocket.StatusNormalClosure, "") + clientConn.Close(websocket.StatusNormalClosure, "") + }) + } + + wsproxy.Pump(pumpCtx, clientConn, upstreamConn, cleanup, logger, transform) }) } +// normalizeUpstreamURL parses a raw DevTools URL and returns a clean form. +func normalizeUpstreamURL(raw string) string { + parsed, err := url.Parse(raw) + if err != nil { + return raw + } + return (&url.URL{Scheme: parsed.Scheme, Host: parsed.Host, Path: parsed.Path, RawQuery: parsed.RawQuery}).String() +} + // logCDPMessage logs a CDP message with its direction if logging is enabled func logCDPMessage(logger *slog.Logger, direction string, mt websocket.MessageType, msg []byte) { if mt != websocket.MessageText { diff --git a/server/lib/devtoolsproxy/proxy_test.go b/server/lib/devtoolsproxy/proxy_test.go index 3092891f..9c8bffe6 100644 --- a/server/lib/devtoolsproxy/proxy_test.go +++ b/server/lib/devtoolsproxy/proxy_test.go @@ -159,6 +159,76 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) { } } +func TestDialUpstreamWithRetry_RechecksCurrentAfterMissedUpdate(t *testing.T) { + // Start a working websocket upstream. + upstreamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) + if err != nil { + t.Fatalf("accept failed: %v", err) + return + } + defer c.Close(websocket.StatusNormalClosure, "") + + mt, msg, err := c.Read(r.Context()) + if err != nil { + return + } + if err := c.Write(r.Context(), mt, msg); err != nil { + t.Fatalf("write failed: %v", err) + } + })) + defer upstreamSrv.Close() + + freshURL, err := url.Parse(upstreamSrv.URL) + if err != nil { + t.Fatalf("parse upstream URL: %v", err) + } + freshURL.Scheme = "ws" + freshURL.Path = "/devtools/browser/fresh" + + stalePort, err := getFreePort() + if err != nil { + t.Fatalf("get stale port: %v", err) + } + staleURL := fmt.Sprintf("ws://127.0.0.1:%d/devtools/browser/stale", stalePort) + + logger := silentLogger() + mgr := NewUpstreamManager("/dev/null", logger) + urlCh, cancel := mgr.Subscribe() + defer cancel() + + // Simulate the race window by advancing Current without broadcasting to the + // subscriber channel. The retry path must re-check Current after the stale + // dial fails instead of waiting forever for a missed notification. + mgr.currentURL.Store(freshURL.String()) + + ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer ctxCancel() + + conn, connectedURL, err := dialUpstreamWithRetry(ctx, mgr, urlCh, staleURL, nil, logger) + if err != nil { + t.Fatalf("dialUpstreamWithRetry failed: %v", err) + } + defer conn.Close(websocket.StatusNormalClosure, "") + + if connectedURL != freshURL.String() { + t.Fatalf("expected to connect to %q, got %q", freshURL.String(), connectedURL) + } + + if err := conn.Write(ctx, websocket.MessageText, []byte("ping")); err != nil { + t.Fatalf("write failed: %v", err) + } + _, msg, err := conn.Read(ctx) + if err != nil { + t.Fatalf("read failed: %v", err) + } + if string(msg) != "ping" { + t.Fatalf("expected echo %q, got %q", "ping", string(msg)) + } +} + func TestUpstreamManagerDetectsChromiumAndRestart(t *testing.T) { browser, err := findBrowserBinary() if err != nil {