diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index a9a3bf700d0..9b991559df9 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-ws.cpp b/tools/server/server-ws.cpp index 2d6bdf76452..1353f6e8394 100644 --- a/tools/server/server-ws.cpp +++ b/tools/server/server-ws.cpp @@ -1,879 +1 @@ -#include "server-ws.h" -#include "common.h" -#include "log.h" -#include "arg.h" -#include "base64.hpp" -#include "sha1/sha1.h" - -#include -#include -#include -#include -#include - -// Security limits to prevent DoS attacks -static constexpr size_t WS_MAX_PAYLOAD_SIZE = 10 * 1024 * 1024; // 10MB per frame -static constexpr size_t WS_MAX_MESSAGE_SIZE = 100 * 1024 * 1024; // 100MB assembled message -static constexpr size_t WS_MAX_RECEIVE_BUFFER = 16 * 1024 * 1024; // 16MB receive buffer -static constexpr size_t WS_MAX_CONNECTIONS_DEFAULT = 10; // Default max concurrent connections -static constexpr int WS_SOCKET_TIMEOUT_SEC = 300; // 5 minute socket read timeout (MCP tools can be slow) - -// Get max connections limit from environment or use default -static size_t get_ws_max_connections() { - static size_t cached = 0; - static bool initialized = false; - if (!initialized) { - const char * env = std::getenv("LLAMA_WS_MAX_CONNECTIONS"); - if (env) { - cached = static_cast(std::atoi(env)); - if (cached == 0) { - cached = WS_MAX_CONNECTIONS_DEFAULT; - } - } else { - cached = WS_MAX_CONNECTIONS_DEFAULT; - } - initialized = true; - } - return cached; -} - -#ifdef _WIN32 -#include -#include -typedef SOCKET socket_t; -#define INVALID_SOCKET_VALUE INVALID_SOCKET -#define SOCKET_ERROR_VALUE SOCKET_ERROR -#else -#include -#include -#include -#include -#include -#include -#include -#include -typedef int socket_t; -#define INVALID_SOCKET_VALUE -1 -#define SOCKET_ERROR_VALUE -1 -#endif - -// WebSocket frame constants -namespace ws_frame { - constexpr uint8_t FIN_BIT = 0x80; - - enum opcode : uint8_t { - CONTINUATION = 0x0, - TEXT = 0x1, - BINARY = 0x2, - CLOSE = 0x8, - PING = 0x9, - PONG = 0xa - }; -} - -// Simple WebSocket implementation using raw sockets -class ws_connection_impl : public server_ws_connection { -public: - ws_connection_impl(socket_t sock, const std::string & path, const std::string & query) - : sock_(sock) - , path_(path) - , query_(query) - , closed_(false) { - parse_query_params(); - } - - ~ws_connection_impl() override { - close(1000, ""); - } - - void send(const std::string & message) override { - if (closed_) { - SRV_WRN("%s: cannot send, connection closed: %s\n", __func__, get_remote_address().c_str()); - return; - } - - // Create WebSocket text frame - std::vector frame; - - uint8_t first_byte = ws_frame::FIN_BIT | ws_frame::TEXT; - frame.push_back(first_byte); - - size_t len = message.size(); - if (len < 126) { - frame.push_back(static_cast(len)); - } else if (len < 65536) { - frame.push_back(126); - frame.push_back(static_cast((len >> 8) & 0xff)); - frame.push_back(static_cast(len & 0xff)); - } else { - frame.push_back(127); - for (int i = 7; i >= 0; i--) { - frame.push_back(static_cast((len >> (i * 8)) & 0xff)); - } - } - - frame.insert(frame.end(), message.begin(), message.end()); - - // Log frame header for debugging - SRV_INF("%s: %s: frame_header=[%02x %02x %02x %02x ...] payload_size=%zd\n", - __func__, get_remote_address().c_str(), - frame.size() > 0 ? frame[0] : 0, - frame.size() > 1 ? frame[1] : 0, - frame.size() > 2 ? frame[2] : 0, - frame.size() > 3 ? frame[3] : 0, - len); - - // Send frame -#ifdef _WIN32 - int sent = ::send(sock_, reinterpret_cast(frame.data()), static_cast(frame.size()), 0); -#else - ssize_t sent = ::send(sock_, frame.data(), frame.size(), 0); -#endif - SRV_INF("%s: %s: frame_size=%zd sent=%zd\n", __func__, get_remote_address().c_str(), frame.size(), sent); - if (sent < 0) { - SRV_ERR("%s: send failed: %s (errno=%d)\n", __func__, get_remote_address().c_str(), errno); - closed_ = true; - } else if (static_cast(sent) != frame.size()) { - SRV_WRN("%s: partial send: %zd/%zd bytes\n", __func__, sent, frame.size()); - } - } - - void close(int code, const std::string & reason) override { - if (closed_) { - return; - } - - // Send close frame - std::vector close_frame; - close_frame.push_back(ws_frame::FIN_BIT | ws_frame::CLOSE); - close_frame.push_back(2 + reason.size()); - close_frame.push_back(static_cast((code >> 8) & 0xff)); - close_frame.push_back(static_cast(code & 0xff)); - close_frame.insert(close_frame.end(), reason.begin(), reason.end()); - -#ifdef _WIN32 - ::send(sock_, reinterpret_cast(close_frame.data()), static_cast(close_frame.size()), 0); - closesocket(sock_); -#else - ::send(sock_, close_frame.data(), close_frame.size(), 0); - ::close(sock_); -#endif - - closed_ = true; - } - - std::string get_query_param(const std::string & key) const override { - auto it = query_params_.find(key); - if (it != query_params_.end()) { - return it->second; - } - return ""; - } - - std::string get_remote_address() const override { - return remote_address_; - } - - socket_t socket() const { return sock_; } - bool is_closed() const { return closed_; } - - // Handle incoming data - void handle_data(const std::vector & data, std::function on_message) { - // Check receive buffer limit to prevent DoS - if (receive_buffer_.size() + data.size() > WS_MAX_RECEIVE_BUFFER) { - SRV_ERR("%s: receive buffer overflow, closing connection\n", __func__); - close(1009, "Message too big"); - return; - } - receive_buffer_.insert(receive_buffer_.end(), data.begin(), data.end()); - - while (true) { - if (receive_buffer_.size() < 2) { - break; // Need at least 2 bytes for header - } - - uint8_t first_byte = receive_buffer_[0]; - uint8_t second_byte = receive_buffer_[1]; - - // Validate RSV bits (must be 0 unless extension negotiated) - uint8_t rsv_bits = (first_byte & 0x70); - if (rsv_bits != 0) { - SRV_ERR("%s: protocol error - RSV bits must be 0\n", __func__); - close(1002, "Protocol error"); - return; - } - - bool fin = (first_byte & 0x80) != 0; - ws_frame::opcode opcode = static_cast(first_byte & 0x0f); - bool masked = (second_byte & 0x80) != 0; - uint64_t payload_len = second_byte & 0x7f; - - // RFC 6455: Client frames MUST be masked - if (!masked) { - SRV_ERR("%s: protocol error - client frames must be masked\n", __func__); - close(1002, "Protocol error"); - return; - } - - size_t header_len = 2; - - if (payload_len == 126) { - if (receive_buffer_.size() < 4) break; - payload_len = (static_cast(receive_buffer_[2]) << 8) | - static_cast(receive_buffer_[3]); - header_len = 4; - } else if (payload_len == 127) { - if (receive_buffer_.size() < 10) break; - payload_len = 0; - for (int i = 0; i < 8; i++) { - payload_len = (payload_len << 8) | receive_buffer_[2 + i]; - } - header_len = 10; - } - - // Validate payload size to prevent DoS via huge allocations - if (payload_len > WS_MAX_PAYLOAD_SIZE) { - SRV_ERR("%s: payload too large: %zu > %zu\n", __func__, - static_cast(payload_len), WS_MAX_PAYLOAD_SIZE); - close(1009, "Message too big"); - return; - } - - // Check for integer overflow in total_len calculation - if (payload_len > SIZE_MAX - header_len - 4) { - SRV_ERR("%s: frame size overflow\n", __func__); - close(1009, "Message too big"); - return; - } - - size_t total_len = header_len + payload_len + (masked ? 4 : 0); - if (receive_buffer_.size() < total_len) { - break; // Incomplete frame - } - - // Extract payload (skip the mask if present) - size_t payload_offset = header_len + (masked ? 4 : 0); - std::vector payload(receive_buffer_.begin() + payload_offset, - receive_buffer_.begin() + payload_offset + payload_len); - - // Unmask if needed - if (masked) { - uint8_t mask[4]; - std::memcpy(mask, &receive_buffer_[header_len], 4); - SRV_DBG("%s: unmasking payload with mask: [0x%02x, 0x%02x, 0x%02x, 0x%02x]\n", - __func__, mask[0], mask[1], mask[2], mask[3]); - for (size_t i = 0; i < payload_len; i++) { - uint8_t masked = payload[i]; - payload[i] ^= mask[i % 4]; - if (i < 20) { - SRV_DBG("%s: [%zu] masked=0x%02x unmasked=0x%02x ('%c')\n", - __func__, i, masked, payload[i], - isprint(payload[i]) ? payload[i] : '.'); - } - } - SRV_DBG("%s: first 20 chars of payload: '%.*s'\n", - __func__, (int)std::min(size_t(20), payload.size()), - payload.data()); - } - - // Remove processed frame from buffer - receive_buffer_.erase(receive_buffer_.begin(), - receive_buffer_.begin() + total_len); - - // Handle frame - if (opcode == ws_frame::TEXT || opcode == ws_frame::CONTINUATION) { - // Check message buffer limit to prevent DoS via fragmented messages - if (message_buffer_.size() + payload.size() > WS_MAX_MESSAGE_SIZE) { - SRV_ERR("%s: message too large, closing connection\n", __func__); - close(1009, "Message too big"); - message_buffer_.clear(); - return; - } - if (fin) { - // Complete message - message_buffer_.insert(message_buffer_.end(), payload.begin(), payload.end()); - std::string msg(message_buffer_.begin(), message_buffer_.end()); - on_message(msg); - message_buffer_.clear(); - } else { - // Fragmented message - message_buffer_.insert(message_buffer_.end(), payload.begin(), payload.end()); - } - } else if (opcode == ws_frame::PING) { - // Respond with pong - send_pong(payload); - } else if (opcode == ws_frame::CLOSE) { - close(1000, "Normal closure"); - break; - } - } - } - - void set_remote_address(const std::string & addr) { - remote_address_ = addr; - } - -private: - socket_t sock_; - std::string path_; - std::string query_; - std::string remote_address_; - std::map query_params_; - bool closed_; - - std::vector receive_buffer_; - std::vector message_buffer_; - - void parse_query_params() { - std::string remaining = query_; - while (!remaining.empty()) { - size_t amp_pos = remaining.find('&'); - std::string pair; - if (amp_pos == std::string::npos) { - pair = remaining; - remaining.clear(); - } else { - pair = remaining.substr(0, amp_pos); - remaining = remaining.substr(amp_pos + 1); - } - - size_t eq_pos = pair.find('='); - if (eq_pos != std::string::npos) { - std::string key = pair.substr(0, eq_pos); - std::string value = pair.substr(eq_pos + 1); - query_params_[key] = value; - } - } - } - - void send_pong(const std::vector & payload) { - // RFC 6455: Control frames must have payload ≤ 125 bytes - if (payload.size() > 125) { - SRV_WRN("%s: PING payload too large (%zu > 125), ignoring\n", - __func__, payload.size()); - return; - } - - std::vector frame; - frame.push_back(ws_frame::FIN_BIT | ws_frame::PONG); - frame.push_back(static_cast(payload.size())); - frame.insert(frame.end(), payload.begin(), payload.end()); - -#ifdef _WIN32 - ::send(sock_, reinterpret_cast(frame.data()), static_cast(frame.size()), 0); -#else - ::send(sock_, frame.data(), frame.size(), 0); -#endif - } -}; - -struct server_ws_context::Impl { - socket_t listen_sock = INVALID_SOCKET_VALUE; - std::atomic running{false}; - std::thread accept_thread; - std::mutex connections_mutex; - std::map> connections; - - server_ws_context::on_open_t on_open_cb; - server_ws_context::on_message_t on_message_cb; - server_ws_context::on_close_t on_close_cb; - - int port = 0; - std::string path_prefix = "/mcp"; - std::vector api_keys; - - // Methods - void accept_loop(); - void handle_connection(socket_t sock, const struct sockaddr_in & addr); - bool validate_api_key(const std::string & auth_header) const; -}; - -server_ws_context::server_ws_context() - : pimpl(std::make_unique()) { -} - -server_ws_context::~server_ws_context() { - stop(); -} - -bool server_ws_context::init(const common_params & params) { - // Use port + 1 from the HTTP server to avoid conflicts - // This provides a predictable port for the frontend - pimpl->port = params.port + 1; - pimpl->path_prefix = "/mcp"; - pimpl->api_keys = params.api_keys; - - if (pimpl->api_keys.size() == 1) { - auto key = pimpl->api_keys[0]; - std::string substr = key.substr(std::max((int)(key.length() - 4), 0)); - SRV_INF("%s: api_keys: ****%s\n", __func__, substr.c_str()); - } else if (pimpl->api_keys.size() > 1) { - SRV_INF("%s: api_keys: %zu keys loaded\n", __func__, pimpl->api_keys.size()); - } - - SRV_INF("%s: WebSocket context initialized\n", __func__); - return true; -} - -bool server_ws_context::start() { - if (pimpl->running) { - SRV_WRN("%s: WebSocket server already running\n", __func__); - return true; - } - -#ifdef _WIN32 - // Initialize Winsock - WSADATA wsa_data; - if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0) { - SRV_ERR("%s: WSAStartup failed\n", __func__); - return false; - } -#endif - - // Create listening socket - pimpl->listen_sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (pimpl->listen_sock == INVALID_SOCKET_VALUE) { - SRV_ERR("%s: socket() failed\n", __func__); - return false; - } - - // Set SO_REUSEADDR - int opt = 1; -#ifdef _WIN32 - setsockopt(pimpl->listen_sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(opt)); -#else - setsockopt(pimpl->listen_sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); -#endif - - // Bind to address - struct sockaddr_in addr = {}; - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = INADDR_ANY; - addr.sin_port = htons(pimpl->port); - - if (bind(pimpl->listen_sock, reinterpret_cast(&addr), sizeof(addr)) < 0) { - SRV_ERR("%s: bind() failed on port %d\n", __func__, pimpl->port); -#ifdef _WIN32 - closesocket(pimpl->listen_sock); -#else - close(pimpl->listen_sock); -#endif - return false; - } - - // Listen - if (listen(pimpl->listen_sock, SOMAXCONN) < 0) { - SRV_ERR("%s: listen() failed\n", __func__); -#ifdef _WIN32 - closesocket(pimpl->listen_sock); -#else - close(pimpl->listen_sock); -#endif - return false; - } - - // Get actual port - struct sockaddr_in actual_addr; - socklen_t len = sizeof(actual_addr); - getsockname(pimpl->listen_sock, reinterpret_cast(&actual_addr), &len); - pimpl->port = ntohs(actual_addr.sin_port); - - std::ostringstream oss; - oss << "ws://" << actual_addr.sin_addr.s_addr << ":" << pimpl->port; - listening_address = oss.str(); - - pimpl->running = true; - - // Start accept thread - pimpl->accept_thread = std::thread([this]() { - pimpl->accept_loop(); - }); - - is_ready = true; - - SRV_INF("%s: WebSocket server started on port %d\n", __func__, pimpl->port); - return true; -} - -void server_ws_context::stop() { - if (!pimpl->running) { - return; - } - - pimpl->running = false; - - // Close all connections - { - std::lock_guard lock(pimpl->connections_mutex); - for (auto & [key, conn] : pimpl->connections) { - conn->close(1001, "Server shutdown"); - } - pimpl->connections.clear(); - } - - // Close listening socket - if (pimpl->listen_sock != INVALID_SOCKET_VALUE) { -#ifdef _WIN32 - closesocket(pimpl->listen_sock); -#else - close(pimpl->listen_sock); -#endif - pimpl->listen_sock = INVALID_SOCKET_VALUE; - } - - // Wait for accept thread - if (pimpl->accept_thread.joinable()) { - pimpl->accept_thread.join(); - } - - is_ready.store(false); - - SRV_INF("%s: WebSocket server stopped\n", __func__); -} - -int server_ws_context::get_actual_port() const { - return pimpl->port; -} - -void server_ws_context::on_open(on_open_t handler) { - pimpl->on_open_cb = std::move(handler); -} - -void server_ws_context::on_message(on_message_t handler) { - pimpl->on_message_cb = std::move(handler); -} - -void server_ws_context::on_close(on_close_t handler) { - pimpl->on_close_cb = std::move(handler); -} - -bool server_ws_context::Impl::validate_api_key(const std::string & auth_header) const { - // Use shared authentication helper from server-common - return ::validate_auth_header(auth_header, api_keys); -} - -void server_ws_context::Impl::accept_loop() { - while (running) { - struct sockaddr_in client_addr = {}; - socklen_t client_len = sizeof(client_addr); - - socket_t client_sock = accept(listen_sock, - reinterpret_cast(&client_addr), - &client_len); - - if (client_sock == INVALID_SOCKET_VALUE) { - if (running) { - SRV_ERR("%s: accept() failed\n", __func__); - } - continue; - } - - // Check connection limit to prevent resource exhaustion - { - std::lock_guard lock(connections_mutex); - const size_t max_conn = get_ws_max_connections(); - if (connections.size() >= max_conn) { - SRV_WRN("%s: connection limit reached (%zu), rejecting\n", - __func__, max_conn); - const char * response = "HTTP/1.1 503 Service Unavailable\r\n\r\n"; - send(client_sock, response, strlen(response), 0); -#ifdef _WIN32 - closesocket(client_sock); -#else - close(client_sock); -#endif - continue; - } - } - - // Handle connection in a thread - std::thread([this, client_sock, client_addr]() { - this->handle_connection(client_sock, client_addr); - }).detach(); - } -} - -void server_ws_context::Impl::handle_connection(socket_t sock, const struct sockaddr_in & addr) { - // Set socket options - int flag = 1; -#ifdef _WIN32 - setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&flag), sizeof(flag)); - // Set socket timeout to prevent slow-loris attacks - DWORD timeout_ms = WS_SOCKET_TIMEOUT_SEC * 1000; - setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout_ms), sizeof(timeout_ms)); -#else - setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); - // Set socket timeout to prevent slow-loris attacks - struct timeval tv; - tv.tv_sec = WS_SOCKET_TIMEOUT_SEC; - tv.tv_usec = 0; - setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); -#endif - - // Read HTTP request (WebSocket handshake) - std::vector buffer(4096); -#ifdef _WIN32 - int recv_len = recv(sock, reinterpret_cast(buffer.data()), buffer.size() - 1, 0); -#else - ssize_t recv_len = recv(sock, buffer.data(), buffer.size() - 1, 0); -#endif - - if (recv_len <= 0) { -#ifdef _WIN32 - closesocket(sock); -#else - close(sock); -#endif - return; - } - - buffer[recv_len] = '\0'; - std::string request(reinterpret_cast(buffer.data())); - - // Parse HTTP request - std::istringstream iss(request); - std::string method, path, version; - iss >> method >> path >> version; - - if (method != "GET") { - // Bad request - const char * response = "HTTP/1.1 400 Bad Request\r\n\r\n"; - send(sock, response, strlen(response), 0); -#ifdef _WIN32 - closesocket(sock); -#else - close(sock); -#endif - return; - } - - // Parse path and query - std::string query; - size_t qpos = path.find('?'); - if (qpos != std::string::npos) { - query = path.substr(qpos + 1); - path = path.substr(0, qpos); - } - - // Check if path matches our prefix - if (path != path_prefix) { - // Not found - const char * response = "HTTP/1.1 404 Not Found\r\n\r\n"; - send(sock, response, strlen(response), 0); -#ifdef _WIN32 - closesocket(sock); -#else - close(sock); -#endif - return; - } - - // After iss >> method >> path >> version, consume the remaining empty line - // The stream is positioned right after "HTTP/1.1", the next line is just "\r\n" - std::string line; // Declare line for use below - std::getline(iss, line); - - // Extract headers (case-insensitive matching) - std::string websocket_key; - std::string websocket_protocol; - std::string authorization_header; - while (std::getline(iss, line)) { - // Trim trailing \r (common in HTTP headers) - if (!line.empty() && line.back() == '\r') { - line.pop_back(); - } - // Empty line marks end of headers - if (line.empty()) { - break; - } - // Case-insensitive header matching - // Convert line to lowercase for comparison, but preserve original for value extraction - std::string line_lower = line; - for (char & c : line_lower) { - if (c >= 'A' && c <= 'Z') { - c = c + 32; // tolower - } - } - - // Use string_starts_with for safe prefix matching - if (string_starts_with(line_lower, "sec-websocket-key:")) { - const std::string header_name = "sec-websocket-key:"; - if (line.length() > header_name.length()) { - websocket_key = line.substr(header_name.length()); - // Trim leading spaces - while (!websocket_key.empty() && websocket_key[0] == ' ') { - websocket_key.erase(0, 1); - } - } - } else if (string_starts_with(line_lower, "sec-websocket-protocol:")) { - const std::string header_name = "sec-websocket-protocol:"; - if (line.length() > header_name.length()) { - websocket_protocol = line.substr(header_name.length()); - // Trim leading spaces - while (!websocket_protocol.empty() && websocket_protocol[0] == ' ') { - websocket_protocol.erase(0, 1); - } - // Trim trailing spaces - while (!websocket_protocol.empty() && websocket_protocol.back() == ' ') { - websocket_protocol.pop_back(); - } - SRV_INF("%s: parsed websocket_protocol='%s'\n", __func__, websocket_protocol.c_str()); - } - } else if (string_starts_with(line_lower, "authorization:")) { - const std::string header_name = "authorization:"; - if (line.length() > header_name.length()) { - authorization_header = line.substr(header_name.length()); - // Trim leading spaces - while (!authorization_header.empty() && authorization_header[0] == ' ') { - authorization_header.erase(0, 1); - } - SRV_INF("%s: parsed authorization_header='%s'\n", __func__, authorization_header.c_str()); - } - } - } - - // Validate API key if configured - if (!validate_api_key(authorization_header)) { - const char * response = "HTTP/1.1 401 Unauthorized\r\n" - "Content-Type: application/json\r\n\r\n" - "{\"error\":{\"message\":\"Invalid API Key\",\"type\":\"authentication_error\",\"code\":401}}"; - send(sock, response, strlen(response), 0); - SRV_WRN("%s", "Unauthorized: Invalid API Key\n"); -#ifdef _WIN32 - closesocket(sock); -#else - close(sock); -#endif - return; - } - - if (websocket_key.empty()) { - const char * response = "HTTP/1.1 400 Bad Request\r\n\r\n"; - send(sock, response, strlen(response), 0); -#ifdef _WIN32 - closesocket(sock); -#else - close(sock); -#endif - return; - } - - // Compute accept key: SHA1(key + magic) -> base64 - std::string magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - std::string combined = websocket_key + magic; - char hash[20]; - SHA1(hash, combined.c_str(), static_cast(combined.size())); - std::string accept_key = base64::encode(hash, sizeof(hash)); - - std::ostringstream full_response_debug; - full_response_debug << "HTTP/1.1 101 Switching Protocols\r\n"; - full_response_debug << "Upgrade: websocket\r\n"; - full_response_debug << "Connection: Upgrade\r\n"; - full_response_debug << "Sec-WebSocket-Accept: " << accept_key << "\r\n\r\n"; - std::string full_response_str = full_response_debug.str(); - - SRV_INF("%s: hash_size=%zu, accept_key_len=%zu, last_4_chars=", - __func__, sizeof(hash), accept_key.length()); - if (accept_key.length() >= 4) { - for (size_t i = accept_key.length() - 4; i < accept_key.length(); i++) { - SRV_INF(" [%zu]='%c' (0x%02x)\n", i, accept_key[i], (unsigned char)accept_key[i]); - } - } - - // Send handshake response - std::ostringstream response; - response << "HTTP/1.1 101 Switching Protocols\r\n"; - response << "Upgrade: websocket\r\n"; - response << "Connection: Upgrade\r\n"; - response << "Sec-WebSocket-Accept: " << accept_key << "\r\n"; - // Echo back the protocol if requested - if (!websocket_protocol.empty()) { - response << "Sec-WebSocket-Protocol: " << websocket_protocol << "\r\n"; - } - response << "\r\n"; - - std::string response_str = response.str(); - SRV_DBG("%s: Sending 101 response, %zu bytes\n", __func__, response_str.size()); - size_t total_sent = 0; - size_t to_send = response_str.size(); - while (total_sent < to_send) { -#ifdef _WIN32 - int sent = ::send(sock, response_str.c_str() + total_sent, static_cast(to_send - total_sent), 0); -#else - ssize_t sent = ::send(sock, response_str.c_str() + total_sent, to_send - total_sent, 0); -#endif - if (sent <= 0) { - SRV_ERR("%s: send() failed during handshake\n", __func__); - break; - } - total_sent += sent; - } - - // Create connection object - auto conn = std::make_shared(sock, path, query); - - // Set remote address - char addr_str[INET_ADDRSTRLEN]; - inet_ntop(AF_INET, &addr.sin_addr, addr_str, sizeof(addr_str)); - conn->set_remote_address(addr_str); - - // Store connection - { - std::lock_guard lock(connections_mutex); - connections[conn.get()] = conn; - } - - // Call on_open callback - if (on_open_cb) { - SRV_INF("%s: About to call on_open_cb for server: %s\n", __func__, query.c_str()); - on_open_cb(conn); - SRV_INF("%s: Returned from on_open_cb for server: %s\n", __func__, query.c_str()); - } - - // Read loop - SRV_INF("%s: Entering read loop for server: %s\n", __func__, query.c_str()); - std::vector recv_buf(4096); - int recv_count = 0; - while (!conn->is_closed()) { -#ifdef _WIN32 - int n = recv(sock, reinterpret_cast(recv_buf.data()), recv_buf.size(), 0); -#else - ssize_t n = recv(sock, recv_buf.data(), recv_buf.size(), 0); -#endif - - if (n <= 0) { - SRV_INF("%s: recv returned %zd (count=%d) for server: %s\n", __func__, n, recv_count, query.c_str()); - break; - } - SRV_INF("%s: recv returned %zd bytes (count=%d) for server: %s\n", __func__, n, ++recv_count, query.c_str()); - - recv_buf.resize(n); - SRV_DBG("%s: received %zd bytes from WebSocket\n", __func__, n); - // Log first 100 bytes in hex - for (ssize_t i = 0; i < std::min(ssize_t(100), n); i++) { - SRV_DBG("%s: [%zd] = 0x%02x (%c)\n", __func__, i, recv_buf[i], - isprint(recv_buf[i]) ? recv_buf[i] : '.'); - } - - conn->handle_data(recv_buf, [this, conn](const std::string & msg) { - SRV_DBG("%s: handle_data callback: msg length=%zu\n", __func__, msg.length()); - if (on_message_cb) { - on_message_cb(conn, msg); - } - }); - } - - // Call on_close callback - if (on_close_cb) { - on_close_cb(conn); - } - - // Remove connection - { - std::lock_guard lock(connections_mutex); - connections.erase(conn.get()); - } - - // Close socket - conn->close(1000, ""); -} +// nothing to see here \ No newline at end of file diff --git a/tools/server/server-ws.h b/tools/server/server-ws.h index 1540cf2d35a..43199ef3d40 100644 --- a/tools/server/server-ws.h +++ b/tools/server/server-ws.h @@ -1,6 +1,7 @@ #pragma once #include "server-common.h" +#include "server-http.h" #include #include #include @@ -16,69 +17,187 @@ struct common_params; struct sockaddr_in; class ws_connection_impl; -// WebSocket connection interface -// Abstracts the underlying WebSocket implementation -struct server_ws_connection { - virtual ~server_ws_connection() = default; - - // Send a message to the client - virtual void send(const std::string & message) = 0; - // Close the connection - virtual void close(int code = 1000, const std::string & reason = "") = 0; - // Get query parameter by key - virtual std::string get_query_param(const std::string & key) const = 0; - // Get the remote address - virtual std::string get_remote_address() const = 0; +// @ngxson: this is a demo for how a bi-directional connection between +// the server and frontend can be implemented using SSE + HTTP POST +// I'm reusing the name "WS" here, but this is not a real WebSocket implementation +// the code is 100% written by human, no AI involved +// but this is just a demo, do not use it in practice + + + +struct server_ws_connection; + +// hacky: server_ws_connection is a member of this struct because +// we want to have shared_ptr for other handler functions +// in practice, we don't really need this +struct server_ws_sse : server_http_res { + std::string id; + std::shared_ptr conn; + const server_http_req & req; + + std::mutex mutex_send; + std::condition_variable cv; + struct msg { + std::string data; + bool is_closed = false; + }; + std::queue queue_send; + + server_ws_sse(const server_http_req & req, const std::string & id) : id(id), req(req) { + conn = std::make_shared(*this); + + queue_send.push({ + "data: {\"llamacpp_id\":\"" + id + "\"}", false + }); + + next = [this, &req, id](std::string & output) { + std::unique_lock lk(mutex_send); + constexpr auto poll_interval = std::chrono::milliseconds(500); + while (true) { + if (!queue_send.empty()) { + output.clear(); + auto & front = queue_send.front(); + if (front.is_closed) { + return false; // closed + } + SRV_INF("%s: sending SSE message: %s\n", id.c_str(), front.data.c_str()); + output = "data: " + front.data + "\n\n"; + queue_send.pop(); + return true; + } + if (req.should_stop()) { + return false; // connection closed + } + cv.wait_for(lk, poll_interval); + } + }; + } + + std::function on_close; + ~server_ws_sse() { + close(); + if (on_close) { + on_close(); + } + } + + void send(const std::string & message) { + std::lock_guard lk(mutex_send); + queue_send.push({message, false}); + cv.notify_all(); + } + + void close() { + std::lock_guard lk(mutex_send); + queue_send.push({"", true}); + cv.notify_all(); + } }; -// Forward declarations -class ws_connection_impl; -// WebSocket context - manages the WebSocket server -// Runs on a separate thread and handles WebSocket connections -struct server_ws_context { - struct Impl; - std::unique_ptr pimpl; - std::thread thread; - std::atomic is_ready = false; +struct server_ws_connection { + server_ws_sse & parent; + server_ws_connection(server_ws_sse & parent) : parent(parent) {} - std::string path_prefix; // e.g., "/mcp" - int port; + // Send a message to the client + void send(const std::string & message) { + parent.send(message); + } - server_ws_context(); - ~server_ws_context(); + // Close the connection + void close(int code = 1000, const std::string & reason = "") { + SRV_INF("%s: closing connection: code=%d, reason=%s\n", + __func__, code, reason.c_str()); + parent.close(); + } - // Initialize the WebSocket server - bool init(const common_params & params); + // Get query parameter by key + std::string get_query_param(const std::string & key) const { + return parent.req.get_param(key); + } - // Start the WebSocket server (runs in background thread) - bool start(); + // Get the remote address + std::string get_remote_address() { + return parent.id; + } +}; - // Stop the WebSocket server - void stop(); - // Get the actual port the WebSocket server is listening on - int get_actual_port() const; - // Set the port for the WebSocket server (note: actual port may differ if set to 0) - void set_port(int port) { this->port = port; } +// SSE + HTTP POST implementation of server_ws_context +struct server_ws_context { + server_ws_context() = default; + ~server_ws_context() = default; + + // map ID to connection + std::mutex mutex; + std::map res_map; + + // SSE endpoint + server_http_context::handler_t get_mcp = [this](const server_http_req & req) { + auto id = random_string(); + auto res = std::make_unique(req, id); + { + std::unique_lock lock(mutex); + res_map[id] = res.get(); + } + SRV_INF("%s: new SSE connection established, ID: %s\n%s", __func__, id.c_str(), req.body.c_str()); + res->id = id; + res->status = 200; + res->headers["X-Connection-ID"] = id; + res->content_type = "text/event-stream"; + // res->next is set in server_ws_sse constructor + res->on_close = [this, id]() { + std::unique_lock lock(mutex); + handler_on_close(res_map[id]->conn); + res_map.erase(id); + }; + handler_on_open(res->conn); + return res; + }; + + // HTTP POST endpoint + server_http_context::handler_t post_mcp = [this](const server_http_req & req) { + auto id = req.get_param("llamacpp_id"); + std::shared_ptr conn; + SRV_INF("%s: received POST for connection ID: %s\n%s", __func__, id.c_str(), req.body.c_str()); + std::unique_lock lock(mutex); + { + auto it = res_map.find(id); + if (it != res_map.end()) { + conn = it->second->conn; + } + } + if (!conn) { + SRV_ERR("%s: invalid connection ID: %s\n", __func__, id.c_str()); + auto res = std::make_unique(); + res->status = 400; + res->data = "Invalid connection ID"; + return res; + } + handler_on_message(conn, req.body); + auto res = std::make_unique(); + res->status = 200; + return res; + }; // Called when new connection is established using on_open_t = std::function)>; - void on_open(on_open_t handler); + void on_open(on_open_t handler) { handler_on_open = handler; } // Called when message is received from a connection using on_message_t = std::function, const std::string &)>; - void on_message(on_message_t handler); + void on_message(on_message_t handler) { handler_on_message = handler; } // Called when connection is closed using on_close_t = std::function)>; - void on_close(on_close_t handler); + void on_close(on_close_t handler) { handler_on_close = handler; } - // For debugging - std::string listening_address; + on_open_t handler_on_open; + on_message_t handler_on_message; + on_close_t handler_on_close; }; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 0dcab5d677f..c6bce3e200b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -120,20 +120,12 @@ int main(int argc, char ** argv, char ** envp) { // WebSocket Server (for MCP support) - only if --webui-mcp is enabled // - server_ws_context * ctx_ws = nullptr; - server_mcp_bridge * mcp_bridge = nullptr; + std::unique_ptr ctx_ws = nullptr; + std::unique_ptr mcp_bridge = nullptr; if (params.webui_mcp) { - ctx_ws = new server_ws_context(); - mcp_bridge = new server_mcp_bridge(); - - // Initialize WebSocket server with params (sets port to HTTP port + 1) - if (!ctx_ws->init(params)) { - LOG_ERR("%s: failed to initialize WebSocket server\n", __func__); - delete ctx_ws; - delete mcp_bridge; - return 1; - } + ctx_ws = std::make_unique(); + mcp_bridge = std::make_unique(); } // Helper function to get MCP config path @@ -289,6 +281,9 @@ int main(int argc, char ** argv, char ** envp) { res->data = response.dump(); return res; }); + + ctx_http.get ("/mcp", ex_wrapper(ctx_ws->get_mcp)); + ctx_http.post("/mcp", ex_wrapper(ctx_ws->post_mcp)); } // @@ -315,15 +310,8 @@ int main(int argc, char ** argv, char ** envp) { if (is_router_server) { LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__); - clean_up = [&models_routes, &ctx_ws, &mcp_bridge]() { + clean_up = [&models_routes]() { SRV_INF("%s: cleaning up before exit...\n", __func__); - if (ctx_ws) { - ctx_ws->stop(); - delete ctx_ws; - } - if (mcp_bridge) { - delete mcp_bridge; - } if (models_routes.has_value()) { models_routes->models.unload_all(); } @@ -337,34 +325,14 @@ int main(int argc, char ** argv, char ** envp) { } ctx_http.is_ready.store(true); - // Start WebSocket server (OS will assign an available port) - only if --webui-mcp is enabled - if (params.webui_mcp && ctx_ws) { - if (!ctx_ws->start()) { - clean_up(); - LOG_ERR("%s: exiting due to WebSocket server error\n", __func__); - return 1; - } - LOG_INF("%s: WebSocket server started on port %d\n", __func__, ctx_ws->get_actual_port()); - } - shutdown_handler = [&](int) { - if (ctx_ws) { - ctx_ws->stop(); - } ctx_http.stop(); }; } else { // setup clean up function, to be called before exit - clean_up = [&ctx_http, &ctx_ws, &ctx_server, &mcp_bridge]() { + clean_up = [&ctx_http, &ctx_server]() { SRV_INF("%s: cleaning up before exit...\n", __func__); - if (ctx_ws) { - ctx_ws->stop(); - delete ctx_ws; - } - if (mcp_bridge) { - delete mcp_bridge; - } ctx_http.stop(); ctx_server.terminate(); llama_backend_free(); @@ -377,16 +345,6 @@ int main(int argc, char ** argv, char ** envp) { return 1; } - // Start WebSocket server (OS will assign an available port) - only if --webui-mcp is enabled - if (params.webui_mcp && ctx_ws) { - if (!ctx_ws->start()) { - clean_up(); - LOG_ERR("%s: exiting due to WebSocket server error\n", __func__); - return 1; - } - LOG_INF("%s: WebSocket server started on port %d\n", __func__, ctx_ws->get_actual_port()); - } - // load the model LOG_INF("%s: loading model\n", __func__); @@ -405,10 +363,6 @@ int main(int argc, char ** argv, char ** envp) { LOG_INF("%s: model loaded\n", __func__); shutdown_handler = [&](int) { - // this will unblock start_loop() - if (ctx_ws) { - ctx_ws->stop(); - } ctx_server.terminate(); }; } diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index 9706678cde1..fa2792f7dbd 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -880,6 +880,7 @@ "integrity": "sha512-oJrXtQiAXLvT9clCf1K4kxp3eKsQhIaZqxEyowkBcsvZDdZkbWrVmnGknxs5flTD0VGsxrxKgBCZty1EzoiMzA==", "dev": true, "license": "Apache-2.0", + "peer": true, "dependencies": { "@swc/helpers": "^0.5.0" } @@ -2105,6 +2106,7 @@ "integrity": "sha512-rO+YQhHucy47Vh67z318pALmd6x+K1Kj30Fb4a6oOEw4xn4zCo9KTmkMWs24c4oduEXD/eJu3badlRmsVXzyfA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "ts-dedent": "^2.0.0", "type-fest": "~2.19" @@ -2188,6 +2190,7 @@ "integrity": "sha512-Vp3zX/qlwerQmHMP6x0Ry1oY7eKKRcOWGc2P59srOp4zcqyn+etJyQpELgOi4+ZSUgteX8Y387NuwruLgGXLUQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@standard-schema/spec": "^1.0.0", "@sveltejs/acorn-typescript": "^1.0.5", @@ -2227,6 +2230,7 @@ "integrity": "sha512-YZs/OSKOQAQCnJvM/P+F1URotNnYNeU3P2s4oIpzm1uFaqUEqRxUB0g5ejMjEb5Gjb9/PiBI5Ktrq4rUUF8UVQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@sveltejs/vite-plugin-svelte-inspector": "^5.0.0", "debug": "^4.4.1", @@ -2642,6 +2646,7 @@ "integrity": "sha512-pemlzrSESWbdAloYml3bAJMEfNh1Z7EduzqPKprCH5S341frlpYnUEW0H72dLxa6IsYr+mPno20GiSm+h9dEdQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@babel/code-frame": "^7.10.4", "@babel/runtime": "^7.12.5", @@ -2809,6 +2814,7 @@ "integrity": "sha512-bJFoMATwIGaxxx8VJPeM8TonI8t579oRvgAuT8zFugJsJZgzqv0Fu8Mhp68iecjzG7cnN3mO2dJQ5uUM2EFrgQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -2876,6 +2882,7 @@ "integrity": "sha512-kVIaQE9vrN9RLCQMQ3iyRlVJpTiDUY6woHGb30JDkfJErqrQEmtdWH3gV0PBAfGZgQXoqzXOO0T3K6ioApbbAA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.37.0", "@typescript-eslint/types": "8.37.0", @@ -3100,6 +3107,7 @@ "integrity": "sha512-tJxiPrWmzH8a+w9nLKlQMzAKX/7VjFs50MWgcAj7p9XQ7AQ9/35fByFYptgPELyLw+0aixTnC4pUWV+APcZ/kw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@testing-library/dom": "^10.4.0", "@testing-library/user-event": "^14.6.1", @@ -3203,6 +3211,7 @@ "integrity": "sha512-oukfKT9Mk41LreEW09vt45f8wx7DordoWUZMYdY/cyAk7w5TWkTRCNZYF7sX7n2wB7jyGAl74OxgwhPgKaqDMQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@vitest/utils": "3.2.4", "pathe": "^2.0.3", @@ -3273,6 +3282,7 @@ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -3954,8 +3964,7 @@ "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", "integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==", "dev": true, - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/debug": { "version": "4.4.3", @@ -4209,6 +4218,7 @@ "dev": true, "hasInstallScript": true, "license": "MIT", + "peer": true, "bin": { "esbuild": "bin/esbuild" }, @@ -4269,6 +4279,7 @@ "integrity": "sha512-QldCVh/ztyKJJZLr4jXNUByx3gR+TDYZCRXEktiZoUR3PGy4qCmSbkxcIle8GEwGpb5JBZazlaJ/CxLidXdEbQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", "@eslint-community/regexpp": "^4.12.1", @@ -7501,6 +7512,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -7634,6 +7646,7 @@ "integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==", "dev": true, "license": "MIT", + "peer": true, "bin": { "prettier": "bin/prettier.cjs" }, @@ -7650,6 +7663,7 @@ "integrity": "sha512-pn1ra/0mPObzqoIQn/vUTR3ZZI6UuZ0sHqMK5x2jMLGrs53h0sXhkVuDcrlssHwIMk7FYrMjHBPoUSyyEEDlBQ==", "dev": true, "license": "MIT", + "peer": true, "peerDependencies": { "prettier": "^3.0.0", "svelte": "^3.2.0 || ^4.0.0-next.0 || ^5.0.0-next.0" @@ -7926,6 +7940,7 @@ "integrity": "sha512-FS+XFBNvn3GTAWq26joslQgWNoFu08F4kl0J4CgdNKADkdSGXQyTCnKteIAJy96Br6YbpEU1LSzV5dYtjMkMDg==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=0.10.0" } @@ -7936,6 +7951,7 @@ "integrity": "sha512-Xs1hdnE+DyKgeHJeJznQmYMIBG3TKIHJJT95Q58nHLSrElKlGQqDTR2HQ9fx5CN/Gk6Vh/kupBTDLU11/nDk/g==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "scheduler": "^0.26.0" }, @@ -8221,6 +8237,7 @@ "integrity": "sha512-4iya7Jb76fVpQyLoiVpzUrsjQ12r3dM7fIVz+4NwoYvZOShknRmiv+iu9CClZml5ZLGb0XMcYLutK6w9tgxHDw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@types/estree": "1.0.8" }, @@ -8342,6 +8359,7 @@ "integrity": "sha512-elOcIZRTM76dvxNAjqYrucTSI0teAF/L2Lv0s6f6b7FOwcwIuA357bIE871580AjHJuSvLIRUosgV+lIWx6Rgg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "chokidar": "^4.0.0", "immutable": "^5.0.2", @@ -8630,6 +8648,7 @@ "integrity": "sha512-7smAu0o+kdm378Q2uIddk32pn0UdIbrtTVU+rXRVtTVTCrK/P2cCui2y4JH+Bl3NgEq1bbBQpCAF/HKrDjk2Qw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@storybook/global": "^5.0.0", "@storybook/icons": "^1.6.0", @@ -8775,6 +8794,7 @@ "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.36.12.tgz", "integrity": "sha512-c3mWT+b0yBLl3gPGSHiy4pdSQCsPNTjLC0tVoOhrGJ6PPfCzD/RQpAmAfJtQZ304CAae2ph+L3C4aqds3R3seQ==", "license": "MIT", + "peer": true, "dependencies": { "@ampproject/remapping": "^2.3.0", "@jridgewell/sourcemap-codec": "^1.5.0", @@ -9018,6 +9038,7 @@ "integrity": "sha512-gBXpgUm/3rp1lMZZrM/w7D8GKqshif0zAymAhbCyIt8KMe+0v9DQ7cdYLR4FHH/cKpdTXb+A/tKKU3eolfsI+g==", "dev": true, "license": "MIT", + "peer": true, "funding": { "type": "github", "url": "https://github.com/sponsors/dcastil" @@ -9048,7 +9069,8 @@ "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.11.tgz", "integrity": "sha512-2E9TBm6MDD/xKYe+dvJZAmg3yxIEDNRc0jwlNyDg/4Fil2QcSLjFKGVff0lAf1jjeaArlG/M75Ey/EYr/OJtBA==", "dev": true, - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/tapable": { "version": "2.2.2", @@ -9284,6 +9306,7 @@ "integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==", "dev": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -9667,6 +9690,7 @@ "integrity": "sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.5.0", @@ -9827,6 +9851,7 @@ "integrity": "sha512-LUCP5ev3GURDysTWiP47wRRUpLKMOfPh+yKTx3kVIEiu5KOMeqzpnYNsKyOoVrULivR8tLcks4+lga33Whn90A==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@types/chai": "^5.2.2", "@vitest/expect": "3.2.4", @@ -10047,6 +10072,7 @@ "resolved": "https://registry.npmjs.org/zod/-/zod-4.2.1.tgz", "integrity": "sha512-0wZ1IRqGGhMP76gLqz8EyfBXKk0J2qo2+H3fi4mcUP/KtTocoX08nmIAHl1Z2kJIZbZee8KOpBCSNPRgauucjw==", "license": "MIT", + "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } diff --git a/tools/server/webui/src/lib/services/mcp-transport-custom.ts b/tools/server/webui/src/lib/services/mcp-transport-custom.ts new file mode 100644 index 00000000000..3ce1c8174c0 --- /dev/null +++ b/tools/server/webui/src/lib/services/mcp-transport-custom.ts @@ -0,0 +1,113 @@ +import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; +import { JSONRPCMessageSchema } from '@modelcontextprotocol/sdk/types.js'; +import { type Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; + +const SUBPROTOCOL = 'mcp'; + +// modified from WebSocketClientTransport + +export class CustomLlamaCppTransport implements Transport { + private _evSource?: EventSource; + private _url: URL; + private _id: string = 'unknown'; + + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; + + constructor(url: string) { + this._url = new URL(url); + } + + start(): Promise { + if (this._evSource) { + throw new Error( + 'CustomLlamaCppTransport already started! If using Client class, note that connect() calls start() automatically.' + ); + } + + return new Promise((resolve, reject) => { + console.log('Connecting to SSE:', this._url.toString()); + this._evSource = new EventSource(this._url.toString(), { withCredentials: true }); + + this._evSource.onerror = (event) => { + if (event.eventPhase == EventSource.CLOSED) { + this.onclose?.(); + console.log('Event Source Closed'); + } + + const error = + 'error' in event + ? (event.error as Error) + : new Error(`EventSource error: ${JSON.stringify(event)}`); + reject(error); + this.onerror?.(error); + }; + + // this._evSource.onopen = () => { + // resolve(); + // }; + + // this._evSource.onclose = () => { + // this.onclose?.(); + // }; + + this._evSource.onmessage = (event: MessageEvent) => { + const raw = event.data.startsWith('data: ') ? event.data.slice(6) : event.data; + console.log('SSE Message received:', raw); + const data = JSON.parse(raw); + if (data.llamacpp_id) { + this._id = data.llamacpp_id; + console.log('Connected to SSE with id:', this._id); + resolve(); + return; + } + let message: JSONRPCMessage; + try { + message = JSONRPCMessageSchema.parse(data); + } catch (error) { + this.onerror?.(error as Error); + return; + } + + this.onmessage?.(message); + }; + }); + } + + async close(): Promise { + this._evSource?.close(); + } + + send(message: JSONRPCMessage): Promise { + return new Promise((resolve, reject) => { + if (!this._evSource) { + reject(new Error('Not connected')); + return; + } + + if (this._id == 'unknown') { + reject(new Error('Connection ID not set yet')); + return; + } + + const url = new URL(this._url.toString()); + url.searchParams.append('llamacpp_id', this._id); + fetch(url.toString(), { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-Subprotocol': SUBPROTOCOL // redundant, maybe remove later + }, + body: JSON.stringify(message), + credentials: 'include' // redundant too, maybe remove later + }) + .then(() => { + resolve(); + }) + .catch((error) => { + reject(error); + }); + }); + } +} diff --git a/tools/server/webui/src/lib/services/mcp.ts b/tools/server/webui/src/lib/services/mcp.ts index 881ebbfaea5..96cbb83ab58 100644 --- a/tools/server/webui/src/lib/services/mcp.ts +++ b/tools/server/webui/src/lib/services/mcp.ts @@ -8,8 +8,9 @@ */ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js'; +// import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js'; import type { Tool, Notification } from '@modelcontextprotocol/sdk/types.js'; +import { CustomLlamaCppTransport } from './mcp-transport-custom'; // Timeout constants - increased for Docker containers that may take time to start const REQUEST_TIMEOUT_MS = 120_000; // 2 minutes for slow-starting containers @@ -19,7 +20,7 @@ const INITIAL_RECONNECT_DELAY_MS = 1000; // Start with 1s delay export class McpService { private client: Client | null = null; - private transport: WebSocketClientTransport | null = null; + private transport: CustomLlamaCppTransport | null = null; private reconnectTimer: ReturnType | null = null; private reconnectAttempts = 0; private manualDisconnect = false; @@ -57,7 +58,7 @@ export class McpService { async connect(): Promise { try { // Create a new transport instance for each connection attempt - this.transport = new WebSocketClientTransport(new URL(this.wsUrl)); + this.transport = new CustomLlamaCppTransport(this.wsUrl); // Set up transport event handlers this.transport.onclose = () => { @@ -275,8 +276,8 @@ export const mcpServiceFactory = { */ create: (serverName: string): McpService => { const url = new URL(window.location.href); - const wsPort = (parseInt(url.port) || 80) + 1; - const wsUrl = `ws://${url.hostname}:${wsPort}/mcp?server=${encodeURIComponent(serverName)}`; - return new McpService(serverName, wsUrl); + url.pathname = '/mcp'; + url.searchParams.set('server', serverName); + return new McpService(serverName, url.toString()); } }; diff --git a/tools/server/webui/src/lib/stores/mcp.svelte.ts b/tools/server/webui/src/lib/stores/mcp.svelte.ts index 018de16d58d..f0febf5df03 100644 --- a/tools/server/webui/src/lib/stores/mcp.svelte.ts +++ b/tools/server/webui/src/lib/stores/mcp.svelte.ts @@ -16,7 +16,7 @@ */ import type { McpConnectionState, McpTool } from '$lib/types/mcp'; -import { McpService } from '$lib/services/mcp'; +import { McpService, mcpServiceFactory } from '$lib/services/mcp'; import { SvelteMap, SvelteSet } from 'svelte/reactivity'; class McpStore { @@ -138,11 +138,6 @@ class McpStore { await new Promise((resolve) => setTimeout(resolve, 500)); } - // WebSocket port is always HTTP port + 1 - const url = new URL(window.location.href); - const wsPort = parseInt(url.port) + 1; - const wsUrl = `ws://${window.location.hostname}:${wsPort}/mcp?server=${encodeURIComponent(serverName)}`; - // Increment connection generation for this server const currentGen = (this.connectionGenerations.get(serverName) ?? 0) + 1; this.connectionGenerations.set(serverName, currentGen); @@ -159,7 +154,7 @@ class McpStore { }); try { - const service = new McpService(serverName, wsUrl); + const service = mcpServiceFactory.create(serverName); this.services.set(serverName, service); // Set up event handlers with generation checking diff --git a/tools/server/webui/vite.config.ts b/tools/server/webui/vite.config.ts index 899d3f63ea3..349738a2674 100644 --- a/tools/server/webui/vite.config.ts +++ b/tools/server/webui/vite.config.ts @@ -159,13 +159,7 @@ export default defineConfig({ '/models': 'http://localhost:8080', // HTTP endpoint for MCP '/mcp/servers': 'http://localhost:8080', - // WebSocket endpoint for MCP (on HTTP port + 1) - '/mcp': { - target: 'http://localhost:8081', - ws: true, - // Don't rewrite the path - keep /mcp as-is - rewrite: (path) => path - } + '/mcp': 'http://localhost:8080' }, headers: { 'Cross-Origin-Embedder-Policy': 'require-corp',