From 7578ed1bb425558f046871d4a74684647633c9b2 Mon Sep 17 00:00:00 2001 From: sjmiller609 <7516283+sjmiller609@users.noreply.github.com> Date: Wed, 25 Mar 2026 20:31:25 +0000 Subject: [PATCH 1/3] Wire UpstreamManager.Subscribe into CDP proxy to handle Chromium restarts When Chromium restarts via supervisorctl, the CDP WebSocket proxy now: 1. Subscribes to upstream URL changes per proxy session, so active connections are proactively closed when the upstream URL changes (forcing clients to reconnect with the fresh URL). 2. If dialing the current upstream URL fails (stale URL from a fast restart cycle), waits up to 5s for a new URL from Subscribe and retries the dial once before giving up. Fixes CUS-109. --- server/lib/devtoolsproxy/proxy.go | 93 +++++++++++++++++++++++++++---- 1 file changed, 81 insertions(+), 12 deletions(-) diff --git a/server/lib/devtoolsproxy/proxy.go b/server/lib/devtoolsproxy/proxy.go index 082dd806..8b5137df 100644 --- a/server/lib/devtoolsproxy/proxy.go +++ b/server/lib/devtoolsproxy/proxy.go @@ -204,12 +204,6 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess http.Error(w, "upstream not ready", http.StatusServiceUnavailable) return } - parsed, err := url.Parse(upstreamCurrent) - if err != nil { - http.Error(w, "invalid upstream", http.StatusInternalServerError) - return - } - upstreamURL := (&url.URL{Scheme: parsed.Scheme, Host: parsed.Host, Path: parsed.Path, RawQuery: parsed.RawQuery}).String() var transform wsproxy.MessageTransform if logCDPMessages { @@ -226,15 +220,90 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess dialOpts := &websocket.DialOptions{ CompressionMode: websocket.CompressionContextTakeover, } - wsproxy.Proxy(w, r, upstreamURL, wsproxy.ProxyOptions{ - AcceptOptions: acceptOpts, - DialOptions: dialOpts, - Logger: logger, - Transform: transform, - }) + + // Subscribe to upstream URL changes so we can tear down stale sessions + // when Chromium restarts and retry if the current URL is already dead. + urlCh, unsub := mgr.Subscribe() + defer unsub() + + upstreamURL := normalizeUpstreamURL(upstreamCurrent) + + // Accept the client WebSocket connection. + clientConn, err := websocket.Accept(w, r, acceptOpts) + if err != nil { + logger.Error("websocket accept failed", slog.String("err", err.Error())) + return + } + clientConn.SetReadLimit(100 * 1024 * 1024) + + // Dial upstream. If the URL is stale (Chromium just restarted), wait + // briefly for a fresh URL from Subscribe and retry once. + upstreamConn, _, err := websocket.Dial(r.Context(), upstreamURL, dialOpts) + if err != nil { + logger.Warn("dial upstream failed, waiting for new URL", + slog.String("err", err.Error()), slog.String("url", upstreamURL)) + select { + case newURL, ok := <-urlCh: + if !ok { + clientConn.Close(websocket.StatusInternalError, "upstream unavailable") + return + } + upstreamURL = normalizeUpstreamURL(newURL) + upstreamConn, _, err = websocket.Dial(r.Context(), upstreamURL, dialOpts) + if err != nil { + logger.Error("dial upstream failed after retry", + slog.String("err", err.Error()), slog.String("url", upstreamURL)) + clientConn.Close(websocket.StatusInternalError, "failed to connect to upstream") + return + } + case <-time.After(5 * time.Second): + logger.Error("timed out waiting for new upstream URL") + clientConn.Close(websocket.StatusInternalError, "upstream unavailable") + return + case <-r.Context().Done(): + clientConn.Close(websocket.StatusGoingAway, "request cancelled") + return + } + } + upstreamConn.SetReadLimit(100 * 1024 * 1024) + + logger.Debug("proxying websocket", slog.String("url", upstreamURL)) + + // Cancel the pump when the upstream URL changes (Chromium restarted), + // forcing the client to reconnect with the new upstream. + pumpCtx, pumpCancel := context.WithCancel(r.Context()) + + go func() { + select { + case <-urlCh: + logger.Info("upstream URL changed, closing stale proxy session") + pumpCancel() + case <-pumpCtx.Done(): + } + }() + + var once sync.Once + cleanup := func() { + once.Do(func() { + pumpCancel() + upstreamConn.Close(websocket.StatusNormalClosure, "") + clientConn.Close(websocket.StatusNormalClosure, "") + }) + } + + wsproxy.Pump(pumpCtx, clientConn, upstreamConn, cleanup, logger, transform) }) } +// normalizeUpstreamURL parses a raw DevTools URL and returns a clean form. +func normalizeUpstreamURL(raw string) string { + parsed, err := url.Parse(raw) + if err != nil { + return raw + } + return (&url.URL{Scheme: parsed.Scheme, Host: parsed.Host, Path: parsed.Path, RawQuery: parsed.RawQuery}).String() +} + // logCDPMessage logs a CDP message with its direction if logging is enabled func logCDPMessage(logger *slog.Logger, direction string, mt websocket.MessageType, msg []byte) { if mt != websocket.MessageText { From 0e91b5c80d99e2be0f349ed3ea51c5017a5cc03e Mon Sep 17 00:00:00 2001 From: Steven Miller Date: Thu, 26 Mar 2026 12:20:19 -0400 Subject: [PATCH 2/3] fix: retry latest devtools URL after stale dial --- server/lib/devtoolsproxy/proxy.go | 89 +++++++++++++++++--------- server/lib/devtoolsproxy/proxy_test.go | 70 ++++++++++++++++++++ 2 files changed, 127 insertions(+), 32 deletions(-) diff --git a/server/lib/devtoolsproxy/proxy.go b/server/lib/devtoolsproxy/proxy.go index 8b5137df..117f237c 100644 --- a/server/lib/devtoolsproxy/proxy.go +++ b/server/lib/devtoolsproxy/proxy.go @@ -194,17 +194,53 @@ func (u *UpstreamManager) runTailOnce(ctx context.Context) { } } +func dialUpstreamWithRetry(ctx context.Context, mgr *UpstreamManager, urlCh <-chan string, initialUpstreamURL string, dialOpts *websocket.DialOptions, logger *slog.Logger) (*websocket.Conn, string, error) { + upstreamURL := normalizeUpstreamURL(initialUpstreamURL) + if upstreamURL == "" { + return nil, "", fmt.Errorf("upstream not ready") + } + + deadline := time.NewTimer(5 * time.Second) + defer deadline.Stop() + + for { + upstreamConn, _, err := websocket.Dial(ctx, upstreamURL, dialOpts) + if err == nil { + return upstreamConn, upstreamURL, nil + } + + logger.Warn("dial upstream failed, checking for newer URL", + slog.String("err", err.Error()), slog.String("url", upstreamURL)) + + latestURL := normalizeUpstreamURL(mgr.Current()) + if latestURL != "" && latestURL != upstreamURL { + upstreamURL = latestURL + continue + } + + select { + case newURL, ok := <-urlCh: + if !ok { + return nil, "", fmt.Errorf("upstream unavailable") + } + newURL = normalizeUpstreamURL(newURL) + if newURL == "" || newURL == upstreamURL { + continue + } + upstreamURL = newURL + case <-deadline.C: + return nil, "", fmt.Errorf("timed out waiting for new upstream URL") + case <-ctx.Done(): + return nil, "", ctx.Err() + } + } +} + // WebSocketProxyHandler returns an http.Handler that upgrades incoming connections and // proxies them to the current upstream websocket URL. It expects only websocket requests. // If logCDPMessages is true, all CDP messages will be logged with their direction. func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMessages bool, ctrl scaletozero.Controller) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upstreamCurrent := mgr.Current() - if upstreamCurrent == "" { - http.Error(w, "upstream not ready", http.StatusServiceUnavailable) - return - } - var transform wsproxy.MessageTransform if logCDPMessages { transform = func(direction string, mt websocket.MessageType, msg []byte) []byte { @@ -226,7 +262,11 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess urlCh, unsub := mgr.Subscribe() defer unsub() - upstreamURL := normalizeUpstreamURL(upstreamCurrent) + upstreamCurrent := mgr.Current() + if upstreamCurrent == "" { + http.Error(w, "upstream not ready", http.StatusServiceUnavailable) + return + } // Accept the client WebSocket connection. clientConn, err := websocket.Accept(w, r, acceptOpts) @@ -236,34 +276,19 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess } clientConn.SetReadLimit(100 * 1024 * 1024) - // Dial upstream. If the URL is stale (Chromium just restarted), wait - // briefly for a fresh URL from Subscribe and retry once. - upstreamConn, _, err := websocket.Dial(r.Context(), upstreamURL, dialOpts) + // Dial upstream. If the URL is stale (Chromium just restarted), first + // re-check the manager's latest URL in case we missed the notification, + // then wait briefly for the next update from Subscribe. + upstreamConn, upstreamURL, err := dialUpstreamWithRetry(r.Context(), mgr, urlCh, upstreamCurrent, dialOpts, logger) if err != nil { - logger.Warn("dial upstream failed, waiting for new URL", - slog.String("err", err.Error()), slog.String("url", upstreamURL)) - select { - case newURL, ok := <-urlCh: - if !ok { - clientConn.Close(websocket.StatusInternalError, "upstream unavailable") - return - } - upstreamURL = normalizeUpstreamURL(newURL) - upstreamConn, _, err = websocket.Dial(r.Context(), upstreamURL, dialOpts) - if err != nil { - logger.Error("dial upstream failed after retry", - slog.String("err", err.Error()), slog.String("url", upstreamURL)) - clientConn.Close(websocket.StatusInternalError, "failed to connect to upstream") - return - } - case <-time.After(5 * time.Second): - logger.Error("timed out waiting for new upstream URL") - clientConn.Close(websocket.StatusInternalError, "upstream unavailable") - return - case <-r.Context().Done(): + switch { + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(r.Context().Err(), context.Canceled), errors.Is(r.Context().Err(), context.DeadlineExceeded): clientConn.Close(websocket.StatusGoingAway, "request cancelled") - return + default: + logger.Error("failed to connect to upstream", slog.String("err", err.Error())) + clientConn.Close(websocket.StatusInternalError, "upstream unavailable") } + return } upstreamConn.SetReadLimit(100 * 1024 * 1024) diff --git a/server/lib/devtoolsproxy/proxy_test.go b/server/lib/devtoolsproxy/proxy_test.go index 3092891f..9c8bffe6 100644 --- a/server/lib/devtoolsproxy/proxy_test.go +++ b/server/lib/devtoolsproxy/proxy_test.go @@ -159,6 +159,76 @@ func TestWebSocketProxyHandler_ProxiesEcho(t *testing.T) { } } +func TestDialUpstreamWithRetry_RechecksCurrentAfterMissedUpdate(t *testing.T) { + // Start a working websocket upstream. + upstreamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) + if err != nil { + t.Fatalf("accept failed: %v", err) + return + } + defer c.Close(websocket.StatusNormalClosure, "") + + mt, msg, err := c.Read(r.Context()) + if err != nil { + return + } + if err := c.Write(r.Context(), mt, msg); err != nil { + t.Fatalf("write failed: %v", err) + } + })) + defer upstreamSrv.Close() + + freshURL, err := url.Parse(upstreamSrv.URL) + if err != nil { + t.Fatalf("parse upstream URL: %v", err) + } + freshURL.Scheme = "ws" + freshURL.Path = "/devtools/browser/fresh" + + stalePort, err := getFreePort() + if err != nil { + t.Fatalf("get stale port: %v", err) + } + staleURL := fmt.Sprintf("ws://127.0.0.1:%d/devtools/browser/stale", stalePort) + + logger := silentLogger() + mgr := NewUpstreamManager("/dev/null", logger) + urlCh, cancel := mgr.Subscribe() + defer cancel() + + // Simulate the race window by advancing Current without broadcasting to the + // subscriber channel. The retry path must re-check Current after the stale + // dial fails instead of waiting forever for a missed notification. + mgr.currentURL.Store(freshURL.String()) + + ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer ctxCancel() + + conn, connectedURL, err := dialUpstreamWithRetry(ctx, mgr, urlCh, staleURL, nil, logger) + if err != nil { + t.Fatalf("dialUpstreamWithRetry failed: %v", err) + } + defer conn.Close(websocket.StatusNormalClosure, "") + + if connectedURL != freshURL.String() { + t.Fatalf("expected to connect to %q, got %q", freshURL.String(), connectedURL) + } + + if err := conn.Write(ctx, websocket.MessageText, []byte("ping")); err != nil { + t.Fatalf("write failed: %v", err) + } + _, msg, err := conn.Read(ctx) + if err != nil { + t.Fatalf("read failed: %v", err) + } + if string(msg) != "ping" { + t.Fatalf("expected echo %q, got %q", "ping", string(msg)) + } +} + func TestUpstreamManagerDetectsChromiumAndRestart(t *testing.T) { browser, err := findBrowserBinary() if err != nil { From e31c727b9f5fcfe9a08f02f8446cb4f1e01a1522 Mon Sep 17 00:00:00 2001 From: Steven Miller Date: Thu, 26 Mar 2026 12:51:56 -0400 Subject: [PATCH 3/3] test: cover CDP reconnect during chromium restart --- server/e2e/e2e_cdp_reconnect_test.go | 592 +++++++++++++++++++++++++++ server/lib/devtoolsproxy/proxy.go | 88 +++- 2 files changed, 673 insertions(+), 7 deletions(-) create mode 100644 server/e2e/e2e_cdp_reconnect_test.go diff --git a/server/e2e/e2e_cdp_reconnect_test.go b/server/e2e/e2e_cdp_reconnect_test.go new file mode 100644 index 00000000..dfa9815f --- /dev/null +++ b/server/e2e/e2e_cdp_reconnect_test.go @@ -0,0 +1,592 @@ +package e2e + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/coder/websocket" + instanceoapi "github.com/onkernel/kernel-images/server/lib/oapi" + "github.com/stretchr/testify/require" +) + +type cdpError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type cdpEnvelope struct { + ID int `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *cdpError `json:"error,omitempty"` + SessionID string `json:"sessionId,omitempty"` +} + +type cdpCall struct { + response chan cdpEnvelope +} + +type cdpEventWaiter struct { + method string + sessionID string + response chan cdpEnvelope +} + +type cdpClient struct { + conn *websocket.Conn + closeMu sync.Mutex + closed bool + closeCh chan struct{} + closeErr error + + mu sync.Mutex + nextID int + pending map[int]cdpCall + waiters []cdpEventWaiter +} + +type cdpExerciseResult struct { + Browser string + Title string + Heading string + Sum int + ReadyState string + ScreenshotBytes int +} + +func newCDPClient(ctx context.Context, wsURL string) (*cdpClient, error) { + conn, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + return nil, err + } + conn.SetReadLimit(100 * 1024 * 1024) + + client := &cdpClient{ + conn: conn, + closeCh: make(chan struct{}), + nextID: 1, + pending: make(map[int]cdpCall), + } + go client.readLoop() + return client, nil +} + +func (c *cdpClient) readLoop() { + for { + _, msg, err := c.conn.Read(context.Background()) + if err != nil { + c.closeWithErr(err) + return + } + + var envelope cdpEnvelope + if err := json.Unmarshal(msg, &envelope); err != nil { + c.closeWithErr(fmt.Errorf("unmarshal CDP message: %w", err)) + return + } + + if envelope.ID != 0 { + c.mu.Lock() + call, ok := c.pending[envelope.ID] + if ok { + delete(c.pending, envelope.ID) + } + c.mu.Unlock() + if ok { + call.response <- envelope + } + continue + } + + c.mu.Lock() + for i, waiter := range c.waiters { + if waiter.method == envelope.Method && waiter.sessionID == envelope.SessionID { + c.waiters = append(c.waiters[:i], c.waiters[i+1:]...) + c.mu.Unlock() + waiter.response <- envelope + goto handled + } + } + c.mu.Unlock() + + handled: + } +} + +func (c *cdpClient) closeWithErr(err error) { + c.closeMu.Lock() + defer c.closeMu.Unlock() + if c.closed { + return + } + c.closed = true + c.closeErr = err + + c.mu.Lock() + for _, call := range c.pending { + close(call.response) + } + c.pending = map[int]cdpCall{} + for _, waiter := range c.waiters { + close(waiter.response) + } + c.waiters = nil + c.mu.Unlock() + + close(c.closeCh) +} + +func (c *cdpClient) Close() { + _ = c.conn.Close(websocket.StatusNormalClosure, "") +} + +func (c *cdpClient) WaitClosed(ctx context.Context) error { + select { + case <-c.closeCh: + return c.closeErr + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *cdpClient) Call(ctx context.Context, method string, params any, sessionID string) (json.RawMessage, error) { + c.closeMu.Lock() + closed := c.closed + c.closeMu.Unlock() + if closed { + return nil, fmt.Errorf("connection closed: %w", c.closeErr) + } + + id := 0 + responseCh := make(chan cdpEnvelope, 1) + + c.mu.Lock() + id = c.nextID + c.nextID++ + c.pending[id] = cdpCall{response: responseCh} + c.mu.Unlock() + + payload := map[string]any{ + "id": id, + "method": method, + "params": params, + } + if sessionID != "" { + payload["sessionId"] = sessionID + } + + writeCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := wsjsonWrite(writeCtx, c.conn, payload); err != nil { + c.mu.Lock() + delete(c.pending, id) + c.mu.Unlock() + return nil, err + } + + select { + case envelope, ok := <-responseCh: + if !ok { + return nil, fmt.Errorf("connection closed while waiting for %s", method) + } + if envelope.Error != nil { + return nil, fmt.Errorf("CDP %s failed: %d %s", method, envelope.Error.Code, envelope.Error.Message) + } + return envelope.Result, nil + case <-ctx.Done(): + c.mu.Lock() + delete(c.pending, id) + c.mu.Unlock() + return nil, ctx.Err() + } +} + +func (c *cdpClient) WaitForEvent(ctx context.Context, method, sessionID string) error { + responseCh := make(chan cdpEnvelope, 1) + + c.mu.Lock() + c.waiters = append(c.waiters, cdpEventWaiter{ + method: method, + sessionID: sessionID, + response: responseCh, + }) + c.mu.Unlock() + + select { + case _, ok := <-responseCh: + if !ok { + return fmt.Errorf("connection closed before event %s", method) + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func wsjsonWrite(ctx context.Context, conn *websocket.Conn, value any) error { + payload, err := json.Marshal(value) + if err != nil { + return err + } + return conn.Write(ctx, websocket.MessageText, payload) +} + +func decodeJSONStringField(raw json.RawMessage, field string) (string, error) { + var values map[string]any + if err := json.Unmarshal(raw, &values); err != nil { + return "", err + } + + value, ok := values[field].(string) + if !ok { + return "", fmt.Errorf("expected field %q in %s", field, string(raw)) + } + return value, nil +} + +func connectAndExerciseCDP(ctx context.Context, wsURL, label string) (*cdpClient, cdpExerciseResult, error) { + client, err := newCDPClient(ctx, wsURL) + if err != nil { + return nil, cdpExerciseResult{}, err + } + + versionRaw, err := client.Call(ctx, "Browser.getVersion", map[string]any{}, "") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + browser, err := decodeJSONStringField(versionRaw, "product") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + targetRaw, err := client.Call(ctx, "Target.createTarget", map[string]any{"url": "about:blank"}, "") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + targetID, err := decodeJSONStringField(targetRaw, "targetId") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + attachRaw, err := client.Call(ctx, "Target.attachToTarget", map[string]any{ + "targetId": targetID, + "flatten": true, + }, "") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + sessionID, err := decodeJSONStringField(attachRaw, "sessionId") + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + if _, err := client.Call(ctx, "Page.enable", map[string]any{}, sessionID); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + if _, err := client.Call(ctx, "Runtime.enable", map[string]any{}, sessionID); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + loadCtx, loadCancel := context.WithTimeout(ctx, 15*time.Second) + defer loadCancel() + loadDone := make(chan error, 1) + go func() { + loadDone <- client.WaitForEvent(loadCtx, "Page.loadEventFired", sessionID) + }() + + html := fmt.Sprintf("%s

%s

", label, label) + if _, err := client.Call(ctx, "Page.navigate", map[string]any{ + "url": "data:text/html," + url.PathEscape(html), + }, sessionID); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + if err := <-loadDone; err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + evalRaw, err := client.Call(ctx, "Runtime.evaluate", map[string]any{ + "expression": `JSON.stringify({title:document.title,heading:document.querySelector("h1").textContent,sum:window.sum,ready:document.readyState})`, + "returnByValue": true, + }, sessionID) + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + var evalEnvelope struct { + Result struct { + Value string `json:"value"` + } `json:"result"` + } + if err := json.Unmarshal(evalRaw, &evalEnvelope); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + var summary struct { + Title string `json:"title"` + Heading string `json:"heading"` + Sum int `json:"sum"` + Ready string `json:"ready"` + } + if err := json.Unmarshal([]byte(evalEnvelope.Result.Value), &summary); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + screenshotRaw, err := client.Call(ctx, "Page.captureScreenshot", map[string]any{"format": "png"}, sessionID) + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + var screenshotEnvelope struct { + Data string `json:"data"` + } + if err := json.Unmarshal(screenshotRaw, &screenshotEnvelope); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + screenshotBytes, err := base64.StdEncoding.DecodeString(screenshotEnvelope.Data) + if err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + if _, err := client.Call(ctx, "Target.closeTarget", map[string]any{"targetId": targetID}, ""); err != nil { + client.Close() + return nil, cdpExerciseResult{}, err + } + + return client, cdpExerciseResult{ + Browser: browser, + Title: summary.Title, + Heading: summary.Heading, + Sum: summary.Sum, + ReadyState: summary.Ready, + ScreenshotBytes: len(screenshotBytes), + }, nil +} + +func restartChromiumViaAPI(ctx context.Context, client *instanceoapi.ClientWithResponses) error { + args := []string{"-c", "/etc/supervisor/supervisord.conf", "restart", "chromium"} + req := instanceoapi.ProcessExecJSONRequestBody{ + Command: "supervisorctl", + Args: &args, + } + + rsp, err := client.ProcessExecWithResponse(ctx, req) + if err != nil { + return err + } + if rsp.JSON200 == nil { + return fmt.Errorf("restart chromium returned status %s", rsp.Status()) + } + return nil +} + +func waitForContainerFile(ctx context.Context, client *instanceoapi.ClientWithResponses, path string, timeout time.Duration) error { + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + args := []string{"-lc", fmt.Sprintf("test -f %q", path)} + req := instanceoapi.ProcessExecJSONRequestBody{ + Command: "sh", + Args: &args, + } + reqCtx, reqCancel := context.WithTimeout(waitCtx, 2*time.Second) + rsp, err := client.ProcessExecWithResponse(reqCtx, req) + reqCancel() + if err == nil && rsp.JSON200 != nil && rsp.JSON200.ExitCode != nil && *rsp.JSON200.ExitCode == 0 { + return nil + } + + select { + case <-waitCtx.Done(): + return waitCtx.Err() + case <-ticker.C: + } + } +} + +func touchContainerFile(ctx context.Context, client *instanceoapi.ClientWithResponses, path string) error { + args := []string{"-lc", fmt.Sprintf("mkdir -p %q && touch %q", filepath.Dir(path), path)} + req := instanceoapi.ProcessExecJSONRequestBody{ + Command: "sh", + Args: &args, + } + reqCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + rsp, err := client.ProcessExecWithResponse(reqCtx, req) + if err != nil { + return err + } + if rsp.JSON200 == nil { + return fmt.Errorf("touch %s returned status %s", path, rsp.Status()) + } + if rsp.JSON200.ExitCode == nil || *rsp.JSON200.ExitCode != 0 { + return fmt.Errorf("touch %s failed with exit code %v", path, rsp.JSON200.ExitCode) + } + return nil +} + +func fetchBrowserWebSocketURL(ctx context.Context, c *TestContainer) (string, error) { + versionURL := fmt.Sprintf("http://127.0.0.1:%d/json/version", c.CDPPort) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL, nil) + if err != nil { + return "", err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("json/version returned %s", resp.Status) + } + + var payload struct { + WebSocketDebuggerURL string `json:"webSocketDebuggerUrl"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return "", err + } + if payload.WebSocketDebuggerURL == "" { + return "", fmt.Errorf("json/version missing webSocketDebuggerUrl") + } + return payload.WebSocketDebuggerURL, nil +} + +func waitForChangedBrowserWebSocketURL(ctx context.Context, c *TestContainer, previous string, timeout time.Duration) (string, error) { + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + + for { + current, err := fetchBrowserWebSocketURL(waitCtx, c) + if err == nil && current != "" && current != previous { + return current, nil + } + + select { + case <-waitCtx.Done(): + if err != nil { + return "", fmt.Errorf("waiting for changed browser websocket url: %w", err) + } + return "", waitCtx.Err() + case <-ticker.C: + } + } +} + +func TestCDPProxyReconnectPendingConnectionDuringRestart(t *testing.T) { + t.Parallel() + + if _, err := exec.LookPath("docker"); err != nil { + t.Skipf("docker not available: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute) + defer cancel() + + const hookPath = "/tmp/devtoolsproxy-race" + c := NewTestContainer(t, headlessImage) + require.NoError(t, c.Start(ctx, ContainerConfig{ + Env: map[string]string{ + "DEVTOOLS_PROXY_TEST_POST_CURRENT_BLOCK_FILE": hookPath, + }, + }), "failed to start container") + defer c.Stop(ctx) + + require.NoError(t, c.WaitReady(ctx), "api not ready") + require.NoError(t, c.WaitDevTools(ctx), "devtools not ready") + initialBrowserWS, err := fetchBrowserWebSocketURL(ctx, c) + require.NoError(t, err, "failed to fetch initial browser websocket URL") + + apiClient, err := c.APIClientNoKeepAlive() + require.NoError(t, err) + + initialConn, initialResult, err := connectAndExerciseCDP(ctx, c.CDPURL(), "before-restart") + require.NoError(t, err, "initial CDP session failed") + defer initialConn.Close() + + require.Equal(t, "before-restart", initialResult.Title) + require.Equal(t, "before-restart", initialResult.Heading) + require.Equal(t, 21, initialResult.Sum) + require.Equal(t, "complete", initialResult.ReadyState) + require.Greater(t, initialResult.ScreenshotBytes, 1000) + + type reconnectResult struct { + client *cdpClient + data cdpExerciseResult + err error + } + reconnectCh := make(chan reconnectResult, 1) + reconnectURL := c.CDPURL() + "?devtoolsProxyTestHook=1" + go func() { + client, data, err := connectAndExerciseCDP(ctx, reconnectURL, "after-restart") + reconnectCh <- reconnectResult{client: client, data: data, err: err} + }() + + require.NoError(t, waitForContainerFile(ctx, apiClient, hookPath+".ready", 15*time.Second), "pending reconnect never reached post-current hook") + + require.NoError(t, restartChromiumViaAPI(ctx, apiClient), "restart chromium failed") + + closeCtx, closeCancel := context.WithTimeout(ctx, 30*time.Second) + defer closeCancel() + closeErr := initialConn.WaitClosed(closeCtx) + require.Error(t, closeErr, "expected initial CDP connection to close on chromium restart") + require.False(t, errors.Is(closeErr, context.DeadlineExceeded), "timed out waiting for initial CDP connection to close") + + _, err = initialConn.Call(ctx, "Browser.getVersion", map[string]any{}, "") + require.Error(t, err, "expected stale CDP connection to reject commands after restart") + + updatedBrowserWS, err := waitForChangedBrowserWebSocketURL(ctx, c, initialBrowserWS, 20*time.Second) + require.NoError(t, err, "proxy never exposed a new browser websocket URL after restart") + require.NotEqual(t, initialBrowserWS, updatedBrowserWS) + require.NoError(t, touchContainerFile(ctx, apiClient, hookPath+".release"), "failed to release pending reconnect") + + select { + case reconnect := <-reconnectCh: + require.NoError(t, reconnect.err, "reconnect CDP session failed") + if reconnect.client != nil { + defer reconnect.client.Close() + } + require.Equal(t, "after-restart", reconnect.data.Title) + require.Equal(t, "after-restart", reconnect.data.Heading) + require.Equal(t, 21, reconnect.data.Sum) + require.Equal(t, "complete", reconnect.data.ReadyState) + require.Greater(t, reconnect.data.ScreenshotBytes, 1000) + case <-time.After(30 * time.Second): + t.Fatal("timed out waiting for reconnect CDP session") + } +} diff --git a/server/lib/devtoolsproxy/proxy.go b/server/lib/devtoolsproxy/proxy.go index 117f237c..9c762b35 100644 --- a/server/lib/devtoolsproxy/proxy.go +++ b/server/lib/devtoolsproxy/proxy.go @@ -8,6 +8,7 @@ import ( "log/slog" "net/http" "net/url" + "os" "os/exec" "regexp" "strconv" @@ -236,6 +237,65 @@ func dialUpstreamWithRetry(ctx context.Context, mgr *UpstreamManager, urlCh <-ch } } +func maybePauseAfterCurrentRead(ctx context.Context, logger *slog.Logger, r *http.Request) { + if r.URL.Query().Get("devtoolsProxyTestHook") != "1" { + return + } + + // Test-only hook used by e2e to widen the window between reading Current + // and dialing/subscribing so reconnect races can be reproduced reliably. + rawDelayMs := os.Getenv("DEVTOOLS_PROXY_TEST_POST_CURRENT_DELAY_MS") + if rawDelayMs != "" { + delayMs, err := strconv.Atoi(rawDelayMs) + if err != nil || delayMs <= 0 { + logger.Warn("ignoring invalid devtools proxy test delay", slog.String("value", rawDelayMs)) + } else { + timer := time.NewTimer(time.Duration(delayMs) * time.Millisecond) + defer timer.Stop() + + select { + case <-timer.C: + case <-ctx.Done(): + return + } + } + } + + blockPath := os.Getenv("DEVTOOLS_PROXY_TEST_POST_CURRENT_BLOCK_FILE") + if blockPath == "" { + return + } + + readyPath := blockPath + ".ready" + releasePath := blockPath + ".release" + if err := os.WriteFile(readyPath, []byte("ready\n"), 0o644); err != nil { + logger.Warn("failed to write devtools proxy test ready marker", + slog.String("path", readyPath), + slog.String("err", err.Error())) + return + } + + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + if _, err := os.Stat(releasePath); err == nil { + return + } else if !os.IsNotExist(err) { + logger.Warn("failed to read devtools proxy test release marker", + slog.String("path", releasePath), + slog.String("err", err.Error())) + return + } + + select { + case <-ticker.C: + case <-ctx.Done(): + return + } + } +} + // WebSocketProxyHandler returns an http.Handler that upgrades incoming connections and // proxies them to the current upstream websocket URL. It expects only websocket requests. // If logCDPMessages is true, all CDP messages will be logged with their direction. @@ -267,6 +327,7 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess http.Error(w, "upstream not ready", http.StatusServiceUnavailable) return } + maybePauseAfterCurrentRead(r.Context(), logger, r) // Accept the client WebSocket connection. clientConn, err := websocket.Accept(w, r, acceptOpts) @@ -298,14 +359,27 @@ func WebSocketProxyHandler(mgr *UpstreamManager, logger *slog.Logger, logCDPMess // forcing the client to reconnect with the new upstream. pumpCtx, pumpCancel := context.WithCancel(r.Context()) - go func() { - select { - case <-urlCh: - logger.Info("upstream URL changed, closing stale proxy session") - pumpCancel() - case <-pumpCtx.Done(): + go func(currentUpstreamURL string) { + for { + select { + case newURL, ok := <-urlCh: + if !ok { + return + } + newURL = normalizeUpstreamURL(newURL) + if newURL == "" || newURL == currentUpstreamURL { + continue + } + logger.Info("upstream URL changed, closing stale proxy session", + slog.String("old_url", currentUpstreamURL), + slog.String("new_url", newURL)) + pumpCancel() + return + case <-pumpCtx.Done(): + return + } } - }() + }(upstreamURL) var once sync.Once cleanup := func() {