Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions desktop/src/renderer/src/lib/install.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { z } from "zod";

import { generateDashboardToken } from "./security";

export const EXISTING_SECRET_VALUE = "__OCTOPAL_DESKTOP_EXISTING_SECRET__";

export function isExistingSecret(value: string | undefined | null): boolean {
Expand Down Expand Up @@ -150,6 +152,18 @@ function secretString(value: string | undefined): string {
return isExistingSecret(value) ? "" : value || "";
}

function dashboardToken(value: string | undefined, webappEnabled: boolean): string {
if (isExistingSecret(value)) {
return "";
}

const configured = (value || "").trim();
if (configured || !webappEnabled) {
return configured;
}
return generateDashboardToken();
}

function secretNullable(value: string | undefined): string | null {
return isExistingSecret(value) ? null : value || null;
}
Expand Down Expand Up @@ -233,9 +247,9 @@ export function buildOctopalConfig(values: InstallForm) {
workspace_dir: "workspace",
},
gateway: {
host: "0.0.0.0",
host: "127.0.0.1",
port: values.dashboardPort,
dashboard_token: secretString(values.dashboardToken),
dashboard_token: dashboardToken(values.dashboardToken, values.dashboardEnabled),
tailscale_auto_serve: true,
tailscale_ips: "",
webapp_enabled: values.dashboardEnabled,
Expand Down
57 changes: 47 additions & 10 deletions src/octopal/gateway/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@
logger = structlog.get_logger(__name__)


def _is_local_ws_client(client_host: str) -> bool:
return client_host in ("127.0.0.1", "::1", "localhost", "testclient")


def _provided_ws_token(socket: WebSocket) -> str:
auth_header = socket.headers.get("authorization", "").strip()
if auth_header.lower().startswith("bearer "):
return auth_header[7:].strip()
return str(socket.query_params.get("token", "")).strip()


async def _reject_ws(socket: WebSocket, *, host: str, reason: str) -> None:
logger.warning("Rejected WebSocket connection", host=host, reason=reason)
await socket.close(code=status.WS_1008_POLICY_VIOLATION)


async def _ws_send_json(
session: _ActiveWsSession,
payload: dict[str, Any],
Expand Down Expand Up @@ -152,12 +168,19 @@ async def websocket_endpoint(socket: WebSocket) -> None:
if allowed_ips:
logger.info("Automatically discovered Tailscale IPs", ips=allowed_ips)

is_local = client_host in ("127.0.0.1", "::1", "localhost")
is_local = _is_local_ws_client(client_host)
if not is_local and not allowed_ips:
await _reject_ws(socket, host=client_host, reason="no Tailscale allowlist available")
return

if allowed_ips and not is_local and client_host not in allowed_ips:
logger.warning("Rejected WebSocket connection from unauthorized IP", host=client_host)
await socket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await _reject_ws(socket, host=client_host, reason="host not in Tailscale allowlist")
return

expected_token = str(getattr(settings, "dashboard_token", "") or "").strip()
if expected_token and _provided_ws_token(socket) != expected_token:
await _reject_ws(socket, host=client_host, reason="invalid dashboard token")
return

await socket.accept()
logger.info("WebSocket connection established", host=client_host)
Expand Down Expand Up @@ -223,8 +246,14 @@ async def _ws_worker_event(chat_id: int, event: str, payload: dict[str, Any]) ->

# A newer WS client takes over the interactive channel from any older session.
async with app.state.ws_session_lock:
previous_session: _ActiveWsSession | None = getattr(app.state, "active_ws_session", None)
if previous_session and previous_session.connection_id != connection_id and not previous_session.closed.is_set():
previous_session: _ActiveWsSession | None = getattr(
app.state, "active_ws_session", None
)
if (
previous_session
and previous_session.connection_id != connection_id
and not previous_session.closed.is_set()
):
logger.info(
"Taking over active WebSocket session",
host=client_host,
Expand All @@ -241,12 +270,16 @@ async def _ws_worker_event(chat_id: int, event: str, payload: dict[str, Any]) ->
event_name="takeover_warning",
)
except Exception:
logger.debug("Failed to notify previous WebSocket session before takeover", exc_info=True)
logger.debug(
"Failed to notify previous WebSocket session before takeover", exc_info=True
)

try:
await previous_session.socket.close(code=status.WS_1000_NORMAL_CLOSURE)
except Exception:
logger.debug("Failed to close previous WebSocket session during takeover", exc_info=True)
logger.debug(
"Failed to close previous WebSocket session during takeover", exc_info=True
)

try:
await asyncio.wait_for(previous_session.closed.wait(), timeout=2.0)
Expand Down Expand Up @@ -282,7 +315,9 @@ async def _ws_worker_event(chat_id: int, event: str, payload: dict[str, Any]) ->
try:
active_workers = await asyncio.to_thread(octo.store.get_active_workers)
except Exception:
logger.debug("Failed to load active workers snapshot for WebSocket session", exc_info=True)
logger.debug(
"Failed to load active workers snapshot for WebSocket session", exc_info=True
)
active_workers = []
await _ws_send_json(
session,
Expand All @@ -293,7 +328,9 @@ async def _ws_worker_event(chat_id: int, event: str, payload: dict[str, Any]) ->
event_name="workers_snapshot",
)

approvals = WsApprovalManager(send=lambda payload: _ws_send_json(session, payload, event_name="approval_request"))
approvals = WsApprovalManager(
send=lambda payload: _ws_send_json(session, payload, event_name="approval_request")
)
message_lock = asyncio.Lock()
tasks: set[asyncio.Task] = set()

Expand Down
2 changes: 1 addition & 1 deletion src/octopal/infrastructure/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class MemoryConfig(BaseModel):


class GatewayConfig(BaseModel):
host: str = "0.0.0.0"
host: str = "127.0.0.1"
port: int = 8000
tailscale_ips: str = ""
dashboard_token: str = ""
Expand Down
2 changes: 1 addition & 1 deletion src/octopal/infrastructure/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class Settings(BaseSettings):
memory_max_chars: int = Field(2000, alias="OCTOPAL_MEMORY_MAX_CHARS")
memory_owner_id: str = Field("default", alias="OCTOPAL_MEMORY_OWNER_ID")

gateway_host: str = Field("0.0.0.0", alias="OCTOPAL_GATEWAY_HOST")
gateway_host: str = Field("127.0.0.1", alias="OCTOPAL_GATEWAY_HOST")
gateway_port: int = Field(8000, alias="OCTOPAL_GATEWAY_PORT")
tailscale_ips: str = Field("", alias="OCTOPAL_TAILSCALE_IPS")
dashboard_token: str = Field("", alias="OCTOPAL_DASHBOARD_TOKEN")
Expand Down
31 changes: 31 additions & 0 deletions tests/test_gateway_ws_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.websockets import WebSocketDisconnect

from octopal.gateway.ws import (
WsApprovalManager,
_ActiveWsSession,
_build_ws_file_payload,
_handle_message,
_is_local_ws_client,
_resolve_ws_chat_id,
register_ws_routes,
)
Expand All @@ -29,6 +31,35 @@ def test_resolve_ws_chat_id_uses_first_allowed_id_when_valid() -> None:
assert _resolve_ws_chat_id(settings) == 42


def test_websocket_client_host_helper_rejects_lan_addresses() -> None:
assert _is_local_ws_client("127.0.0.1")
assert _is_local_ws_client("::1")
assert _is_local_ws_client("testclient")
assert not _is_local_ws_client("192.168.1.55")


def test_websocket_requires_dashboard_token_when_configured() -> None:
class DummyOcto:
def set_output_channel(self, is_ws: bool, **kwargs) -> bool:
return True

app = FastAPI()
app.state.settings = SimpleNamespace(
tailscale_ips="testclient",
allowed_telegram_chat_ids="",
dashboard_token="secret-token",
)
app.state.octo = DummyOcto()
register_ws_routes(app)

with TestClient(app) as client:
with pytest.raises(WebSocketDisconnect), client.websocket_connect("/ws"):
pass

with client.websocket_connect("/ws?token=secret-token") as ws:
assert ws.receive_json() == {"type": "workers_snapshot", "workers": []}


def test_new_websocket_connection_takes_over_previous_session() -> None:
class DummyOcto:
def __init__(self) -> None:
Expand Down