diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz
index a9a3bf700d0..9b991559df9 100644
Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ
diff --git a/tools/server/server-ws.cpp b/tools/server/server-ws.cpp
index 2d6bdf76452..1353f6e8394 100644
--- a/tools/server/server-ws.cpp
+++ b/tools/server/server-ws.cpp
@@ -1,879 +1 @@
-#include "server-ws.h"
-#include "common.h"
-#include "log.h"
-#include "arg.h"
-#include "base64.hpp"
-#include "sha1/sha1.h"
-
-#include
-#include
-#include
-#include
-#include
-
-// Security limits to prevent DoS attacks
-static constexpr size_t WS_MAX_PAYLOAD_SIZE = 10 * 1024 * 1024; // 10MB per frame
-static constexpr size_t WS_MAX_MESSAGE_SIZE = 100 * 1024 * 1024; // 100MB assembled message
-static constexpr size_t WS_MAX_RECEIVE_BUFFER = 16 * 1024 * 1024; // 16MB receive buffer
-static constexpr size_t WS_MAX_CONNECTIONS_DEFAULT = 10; // Default max concurrent connections
-static constexpr int WS_SOCKET_TIMEOUT_SEC = 300; // 5 minute socket read timeout (MCP tools can be slow)
-
-// Get max connections limit from environment or use default
-static size_t get_ws_max_connections() {
- static size_t cached = 0;
- static bool initialized = false;
- if (!initialized) {
- const char * env = std::getenv("LLAMA_WS_MAX_CONNECTIONS");
- if (env) {
- cached = static_cast(std::atoi(env));
- if (cached == 0) {
- cached = WS_MAX_CONNECTIONS_DEFAULT;
- }
- } else {
- cached = WS_MAX_CONNECTIONS_DEFAULT;
- }
- initialized = true;
- }
- return cached;
-}
-
-#ifdef _WIN32
-#include
-#include
-typedef SOCKET socket_t;
-#define INVALID_SOCKET_VALUE INVALID_SOCKET
-#define SOCKET_ERROR_VALUE SOCKET_ERROR
-#else
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-typedef int socket_t;
-#define INVALID_SOCKET_VALUE -1
-#define SOCKET_ERROR_VALUE -1
-#endif
-
-// WebSocket frame constants
-namespace ws_frame {
- constexpr uint8_t FIN_BIT = 0x80;
-
- enum opcode : uint8_t {
- CONTINUATION = 0x0,
- TEXT = 0x1,
- BINARY = 0x2,
- CLOSE = 0x8,
- PING = 0x9,
- PONG = 0xa
- };
-}
-
-// Simple WebSocket implementation using raw sockets
-class ws_connection_impl : public server_ws_connection {
-public:
- ws_connection_impl(socket_t sock, const std::string & path, const std::string & query)
- : sock_(sock)
- , path_(path)
- , query_(query)
- , closed_(false) {
- parse_query_params();
- }
-
- ~ws_connection_impl() override {
- close(1000, "");
- }
-
- void send(const std::string & message) override {
- if (closed_) {
- SRV_WRN("%s: cannot send, connection closed: %s\n", __func__, get_remote_address().c_str());
- return;
- }
-
- // Create WebSocket text frame
- std::vector frame;
-
- uint8_t first_byte = ws_frame::FIN_BIT | ws_frame::TEXT;
- frame.push_back(first_byte);
-
- size_t len = message.size();
- if (len < 126) {
- frame.push_back(static_cast(len));
- } else if (len < 65536) {
- frame.push_back(126);
- frame.push_back(static_cast((len >> 8) & 0xff));
- frame.push_back(static_cast(len & 0xff));
- } else {
- frame.push_back(127);
- for (int i = 7; i >= 0; i--) {
- frame.push_back(static_cast((len >> (i * 8)) & 0xff));
- }
- }
-
- frame.insert(frame.end(), message.begin(), message.end());
-
- // Log frame header for debugging
- SRV_INF("%s: %s: frame_header=[%02x %02x %02x %02x ...] payload_size=%zd\n",
- __func__, get_remote_address().c_str(),
- frame.size() > 0 ? frame[0] : 0,
- frame.size() > 1 ? frame[1] : 0,
- frame.size() > 2 ? frame[2] : 0,
- frame.size() > 3 ? frame[3] : 0,
- len);
-
- // Send frame
-#ifdef _WIN32
- int sent = ::send(sock_, reinterpret_cast(frame.data()), static_cast(frame.size()), 0);
-#else
- ssize_t sent = ::send(sock_, frame.data(), frame.size(), 0);
-#endif
- SRV_INF("%s: %s: frame_size=%zd sent=%zd\n", __func__, get_remote_address().c_str(), frame.size(), sent);
- if (sent < 0) {
- SRV_ERR("%s: send failed: %s (errno=%d)\n", __func__, get_remote_address().c_str(), errno);
- closed_ = true;
- } else if (static_cast(sent) != frame.size()) {
- SRV_WRN("%s: partial send: %zd/%zd bytes\n", __func__, sent, frame.size());
- }
- }
-
- void close(int code, const std::string & reason) override {
- if (closed_) {
- return;
- }
-
- // Send close frame
- std::vector close_frame;
- close_frame.push_back(ws_frame::FIN_BIT | ws_frame::CLOSE);
- close_frame.push_back(2 + reason.size());
- close_frame.push_back(static_cast((code >> 8) & 0xff));
- close_frame.push_back(static_cast(code & 0xff));
- close_frame.insert(close_frame.end(), reason.begin(), reason.end());
-
-#ifdef _WIN32
- ::send(sock_, reinterpret_cast(close_frame.data()), static_cast(close_frame.size()), 0);
- closesocket(sock_);
-#else
- ::send(sock_, close_frame.data(), close_frame.size(), 0);
- ::close(sock_);
-#endif
-
- closed_ = true;
- }
-
- std::string get_query_param(const std::string & key) const override {
- auto it = query_params_.find(key);
- if (it != query_params_.end()) {
- return it->second;
- }
- return "";
- }
-
- std::string get_remote_address() const override {
- return remote_address_;
- }
-
- socket_t socket() const { return sock_; }
- bool is_closed() const { return closed_; }
-
- // Handle incoming data
- void handle_data(const std::vector & data, std::function on_message) {
- // Check receive buffer limit to prevent DoS
- if (receive_buffer_.size() + data.size() > WS_MAX_RECEIVE_BUFFER) {
- SRV_ERR("%s: receive buffer overflow, closing connection\n", __func__);
- close(1009, "Message too big");
- return;
- }
- receive_buffer_.insert(receive_buffer_.end(), data.begin(), data.end());
-
- while (true) {
- if (receive_buffer_.size() < 2) {
- break; // Need at least 2 bytes for header
- }
-
- uint8_t first_byte = receive_buffer_[0];
- uint8_t second_byte = receive_buffer_[1];
-
- // Validate RSV bits (must be 0 unless extension negotiated)
- uint8_t rsv_bits = (first_byte & 0x70);
- if (rsv_bits != 0) {
- SRV_ERR("%s: protocol error - RSV bits must be 0\n", __func__);
- close(1002, "Protocol error");
- return;
- }
-
- bool fin = (first_byte & 0x80) != 0;
- ws_frame::opcode opcode = static_cast(first_byte & 0x0f);
- bool masked = (second_byte & 0x80) != 0;
- uint64_t payload_len = second_byte & 0x7f;
-
- // RFC 6455: Client frames MUST be masked
- if (!masked) {
- SRV_ERR("%s: protocol error - client frames must be masked\n", __func__);
- close(1002, "Protocol error");
- return;
- }
-
- size_t header_len = 2;
-
- if (payload_len == 126) {
- if (receive_buffer_.size() < 4) break;
- payload_len = (static_cast(receive_buffer_[2]) << 8) |
- static_cast(receive_buffer_[3]);
- header_len = 4;
- } else if (payload_len == 127) {
- if (receive_buffer_.size() < 10) break;
- payload_len = 0;
- for (int i = 0; i < 8; i++) {
- payload_len = (payload_len << 8) | receive_buffer_[2 + i];
- }
- header_len = 10;
- }
-
- // Validate payload size to prevent DoS via huge allocations
- if (payload_len > WS_MAX_PAYLOAD_SIZE) {
- SRV_ERR("%s: payload too large: %zu > %zu\n", __func__,
- static_cast(payload_len), WS_MAX_PAYLOAD_SIZE);
- close(1009, "Message too big");
- return;
- }
-
- // Check for integer overflow in total_len calculation
- if (payload_len > SIZE_MAX - header_len - 4) {
- SRV_ERR("%s: frame size overflow\n", __func__);
- close(1009, "Message too big");
- return;
- }
-
- size_t total_len = header_len + payload_len + (masked ? 4 : 0);
- if (receive_buffer_.size() < total_len) {
- break; // Incomplete frame
- }
-
- // Extract payload (skip the mask if present)
- size_t payload_offset = header_len + (masked ? 4 : 0);
- std::vector payload(receive_buffer_.begin() + payload_offset,
- receive_buffer_.begin() + payload_offset + payload_len);
-
- // Unmask if needed
- if (masked) {
- uint8_t mask[4];
- std::memcpy(mask, &receive_buffer_[header_len], 4);
- SRV_DBG("%s: unmasking payload with mask: [0x%02x, 0x%02x, 0x%02x, 0x%02x]\n",
- __func__, mask[0], mask[1], mask[2], mask[3]);
- for (size_t i = 0; i < payload_len; i++) {
- uint8_t masked = payload[i];
- payload[i] ^= mask[i % 4];
- if (i < 20) {
- SRV_DBG("%s: [%zu] masked=0x%02x unmasked=0x%02x ('%c')\n",
- __func__, i, masked, payload[i],
- isprint(payload[i]) ? payload[i] : '.');
- }
- }
- SRV_DBG("%s: first 20 chars of payload: '%.*s'\n",
- __func__, (int)std::min(size_t(20), payload.size()),
- payload.data());
- }
-
- // Remove processed frame from buffer
- receive_buffer_.erase(receive_buffer_.begin(),
- receive_buffer_.begin() + total_len);
-
- // Handle frame
- if (opcode == ws_frame::TEXT || opcode == ws_frame::CONTINUATION) {
- // Check message buffer limit to prevent DoS via fragmented messages
- if (message_buffer_.size() + payload.size() > WS_MAX_MESSAGE_SIZE) {
- SRV_ERR("%s: message too large, closing connection\n", __func__);
- close(1009, "Message too big");
- message_buffer_.clear();
- return;
- }
- if (fin) {
- // Complete message
- message_buffer_.insert(message_buffer_.end(), payload.begin(), payload.end());
- std::string msg(message_buffer_.begin(), message_buffer_.end());
- on_message(msg);
- message_buffer_.clear();
- } else {
- // Fragmented message
- message_buffer_.insert(message_buffer_.end(), payload.begin(), payload.end());
- }
- } else if (opcode == ws_frame::PING) {
- // Respond with pong
- send_pong(payload);
- } else if (opcode == ws_frame::CLOSE) {
- close(1000, "Normal closure");
- break;
- }
- }
- }
-
- void set_remote_address(const std::string & addr) {
- remote_address_ = addr;
- }
-
-private:
- socket_t sock_;
- std::string path_;
- std::string query_;
- std::string remote_address_;
- std::map query_params_;
- bool closed_;
-
- std::vector receive_buffer_;
- std::vector message_buffer_;
-
- void parse_query_params() {
- std::string remaining = query_;
- while (!remaining.empty()) {
- size_t amp_pos = remaining.find('&');
- std::string pair;
- if (amp_pos == std::string::npos) {
- pair = remaining;
- remaining.clear();
- } else {
- pair = remaining.substr(0, amp_pos);
- remaining = remaining.substr(amp_pos + 1);
- }
-
- size_t eq_pos = pair.find('=');
- if (eq_pos != std::string::npos) {
- std::string key = pair.substr(0, eq_pos);
- std::string value = pair.substr(eq_pos + 1);
- query_params_[key] = value;
- }
- }
- }
-
- void send_pong(const std::vector & payload) {
- // RFC 6455: Control frames must have payload ≤ 125 bytes
- if (payload.size() > 125) {
- SRV_WRN("%s: PING payload too large (%zu > 125), ignoring\n",
- __func__, payload.size());
- return;
- }
-
- std::vector frame;
- frame.push_back(ws_frame::FIN_BIT | ws_frame::PONG);
- frame.push_back(static_cast(payload.size()));
- frame.insert(frame.end(), payload.begin(), payload.end());
-
-#ifdef _WIN32
- ::send(sock_, reinterpret_cast(frame.data()), static_cast(frame.size()), 0);
-#else
- ::send(sock_, frame.data(), frame.size(), 0);
-#endif
- }
-};
-
-struct server_ws_context::Impl {
- socket_t listen_sock = INVALID_SOCKET_VALUE;
- std::atomic running{false};
- std::thread accept_thread;
- std::mutex connections_mutex;
- std::map> connections;
-
- server_ws_context::on_open_t on_open_cb;
- server_ws_context::on_message_t on_message_cb;
- server_ws_context::on_close_t on_close_cb;
-
- int port = 0;
- std::string path_prefix = "/mcp";
- std::vector api_keys;
-
- // Methods
- void accept_loop();
- void handle_connection(socket_t sock, const struct sockaddr_in & addr);
- bool validate_api_key(const std::string & auth_header) const;
-};
-
-server_ws_context::server_ws_context()
- : pimpl(std::make_unique()) {
-}
-
-server_ws_context::~server_ws_context() {
- stop();
-}
-
-bool server_ws_context::init(const common_params & params) {
- // Use port + 1 from the HTTP server to avoid conflicts
- // This provides a predictable port for the frontend
- pimpl->port = params.port + 1;
- pimpl->path_prefix = "/mcp";
- pimpl->api_keys = params.api_keys;
-
- if (pimpl->api_keys.size() == 1) {
- auto key = pimpl->api_keys[0];
- std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
- SRV_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
- } else if (pimpl->api_keys.size() > 1) {
- SRV_INF("%s: api_keys: %zu keys loaded\n", __func__, pimpl->api_keys.size());
- }
-
- SRV_INF("%s: WebSocket context initialized\n", __func__);
- return true;
-}
-
-bool server_ws_context::start() {
- if (pimpl->running) {
- SRV_WRN("%s: WebSocket server already running\n", __func__);
- return true;
- }
-
-#ifdef _WIN32
- // Initialize Winsock
- WSADATA wsa_data;
- if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0) {
- SRV_ERR("%s: WSAStartup failed\n", __func__);
- return false;
- }
-#endif
-
- // Create listening socket
- pimpl->listen_sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
- if (pimpl->listen_sock == INVALID_SOCKET_VALUE) {
- SRV_ERR("%s: socket() failed\n", __func__);
- return false;
- }
-
- // Set SO_REUSEADDR
- int opt = 1;
-#ifdef _WIN32
- setsockopt(pimpl->listen_sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(opt));
-#else
- setsockopt(pimpl->listen_sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
-#endif
-
- // Bind to address
- struct sockaddr_in addr = {};
- addr.sin_family = AF_INET;
- addr.sin_addr.s_addr = INADDR_ANY;
- addr.sin_port = htons(pimpl->port);
-
- if (bind(pimpl->listen_sock, reinterpret_cast(&addr), sizeof(addr)) < 0) {
- SRV_ERR("%s: bind() failed on port %d\n", __func__, pimpl->port);
-#ifdef _WIN32
- closesocket(pimpl->listen_sock);
-#else
- close(pimpl->listen_sock);
-#endif
- return false;
- }
-
- // Listen
- if (listen(pimpl->listen_sock, SOMAXCONN) < 0) {
- SRV_ERR("%s: listen() failed\n", __func__);
-#ifdef _WIN32
- closesocket(pimpl->listen_sock);
-#else
- close(pimpl->listen_sock);
-#endif
- return false;
- }
-
- // Get actual port
- struct sockaddr_in actual_addr;
- socklen_t len = sizeof(actual_addr);
- getsockname(pimpl->listen_sock, reinterpret_cast(&actual_addr), &len);
- pimpl->port = ntohs(actual_addr.sin_port);
-
- std::ostringstream oss;
- oss << "ws://" << actual_addr.sin_addr.s_addr << ":" << pimpl->port;
- listening_address = oss.str();
-
- pimpl->running = true;
-
- // Start accept thread
- pimpl->accept_thread = std::thread([this]() {
- pimpl->accept_loop();
- });
-
- is_ready = true;
-
- SRV_INF("%s: WebSocket server started on port %d\n", __func__, pimpl->port);
- return true;
-}
-
-void server_ws_context::stop() {
- if (!pimpl->running) {
- return;
- }
-
- pimpl->running = false;
-
- // Close all connections
- {
- std::lock_guard lock(pimpl->connections_mutex);
- for (auto & [key, conn] : pimpl->connections) {
- conn->close(1001, "Server shutdown");
- }
- pimpl->connections.clear();
- }
-
- // Close listening socket
- if (pimpl->listen_sock != INVALID_SOCKET_VALUE) {
-#ifdef _WIN32
- closesocket(pimpl->listen_sock);
-#else
- close(pimpl->listen_sock);
-#endif
- pimpl->listen_sock = INVALID_SOCKET_VALUE;
- }
-
- // Wait for accept thread
- if (pimpl->accept_thread.joinable()) {
- pimpl->accept_thread.join();
- }
-
- is_ready.store(false);
-
- SRV_INF("%s: WebSocket server stopped\n", __func__);
-}
-
-int server_ws_context::get_actual_port() const {
- return pimpl->port;
-}
-
-void server_ws_context::on_open(on_open_t handler) {
- pimpl->on_open_cb = std::move(handler);
-}
-
-void server_ws_context::on_message(on_message_t handler) {
- pimpl->on_message_cb = std::move(handler);
-}
-
-void server_ws_context::on_close(on_close_t handler) {
- pimpl->on_close_cb = std::move(handler);
-}
-
-bool server_ws_context::Impl::validate_api_key(const std::string & auth_header) const {
- // Use shared authentication helper from server-common
- return ::validate_auth_header(auth_header, api_keys);
-}
-
-void server_ws_context::Impl::accept_loop() {
- while (running) {
- struct sockaddr_in client_addr = {};
- socklen_t client_len = sizeof(client_addr);
-
- socket_t client_sock = accept(listen_sock,
- reinterpret_cast(&client_addr),
- &client_len);
-
- if (client_sock == INVALID_SOCKET_VALUE) {
- if (running) {
- SRV_ERR("%s: accept() failed\n", __func__);
- }
- continue;
- }
-
- // Check connection limit to prevent resource exhaustion
- {
- std::lock_guard lock(connections_mutex);
- const size_t max_conn = get_ws_max_connections();
- if (connections.size() >= max_conn) {
- SRV_WRN("%s: connection limit reached (%zu), rejecting\n",
- __func__, max_conn);
- const char * response = "HTTP/1.1 503 Service Unavailable\r\n\r\n";
- send(client_sock, response, strlen(response), 0);
-#ifdef _WIN32
- closesocket(client_sock);
-#else
- close(client_sock);
-#endif
- continue;
- }
- }
-
- // Handle connection in a thread
- std::thread([this, client_sock, client_addr]() {
- this->handle_connection(client_sock, client_addr);
- }).detach();
- }
-}
-
-void server_ws_context::Impl::handle_connection(socket_t sock, const struct sockaddr_in & addr) {
- // Set socket options
- int flag = 1;
-#ifdef _WIN32
- setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&flag), sizeof(flag));
- // Set socket timeout to prevent slow-loris attacks
- DWORD timeout_ms = WS_SOCKET_TIMEOUT_SEC * 1000;
- setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout_ms), sizeof(timeout_ms));
-#else
- setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
- // Set socket timeout to prevent slow-loris attacks
- struct timeval tv;
- tv.tv_sec = WS_SOCKET_TIMEOUT_SEC;
- tv.tv_usec = 0;
- setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
-#endif
-
- // Read HTTP request (WebSocket handshake)
- std::vector buffer(4096);
-#ifdef _WIN32
- int recv_len = recv(sock, reinterpret_cast(buffer.data()), buffer.size() - 1, 0);
-#else
- ssize_t recv_len = recv(sock, buffer.data(), buffer.size() - 1, 0);
-#endif
-
- if (recv_len <= 0) {
-#ifdef _WIN32
- closesocket(sock);
-#else
- close(sock);
-#endif
- return;
- }
-
- buffer[recv_len] = '\0';
- std::string request(reinterpret_cast(buffer.data()));
-
- // Parse HTTP request
- std::istringstream iss(request);
- std::string method, path, version;
- iss >> method >> path >> version;
-
- if (method != "GET") {
- // Bad request
- const char * response = "HTTP/1.1 400 Bad Request\r\n\r\n";
- send(sock, response, strlen(response), 0);
-#ifdef _WIN32
- closesocket(sock);
-#else
- close(sock);
-#endif
- return;
- }
-
- // Parse path and query
- std::string query;
- size_t qpos = path.find('?');
- if (qpos != std::string::npos) {
- query = path.substr(qpos + 1);
- path = path.substr(0, qpos);
- }
-
- // Check if path matches our prefix
- if (path != path_prefix) {
- // Not found
- const char * response = "HTTP/1.1 404 Not Found\r\n\r\n";
- send(sock, response, strlen(response), 0);
-#ifdef _WIN32
- closesocket(sock);
-#else
- close(sock);
-#endif
- return;
- }
-
- // After iss >> method >> path >> version, consume the remaining empty line
- // The stream is positioned right after "HTTP/1.1", the next line is just "\r\n"
- std::string line; // Declare line for use below
- std::getline(iss, line);
-
- // Extract headers (case-insensitive matching)
- std::string websocket_key;
- std::string websocket_protocol;
- std::string authorization_header;
- while (std::getline(iss, line)) {
- // Trim trailing \r (common in HTTP headers)
- if (!line.empty() && line.back() == '\r') {
- line.pop_back();
- }
- // Empty line marks end of headers
- if (line.empty()) {
- break;
- }
- // Case-insensitive header matching
- // Convert line to lowercase for comparison, but preserve original for value extraction
- std::string line_lower = line;
- for (char & c : line_lower) {
- if (c >= 'A' && c <= 'Z') {
- c = c + 32; // tolower
- }
- }
-
- // Use string_starts_with for safe prefix matching
- if (string_starts_with(line_lower, "sec-websocket-key:")) {
- const std::string header_name = "sec-websocket-key:";
- if (line.length() > header_name.length()) {
- websocket_key = line.substr(header_name.length());
- // Trim leading spaces
- while (!websocket_key.empty() && websocket_key[0] == ' ') {
- websocket_key.erase(0, 1);
- }
- }
- } else if (string_starts_with(line_lower, "sec-websocket-protocol:")) {
- const std::string header_name = "sec-websocket-protocol:";
- if (line.length() > header_name.length()) {
- websocket_protocol = line.substr(header_name.length());
- // Trim leading spaces
- while (!websocket_protocol.empty() && websocket_protocol[0] == ' ') {
- websocket_protocol.erase(0, 1);
- }
- // Trim trailing spaces
- while (!websocket_protocol.empty() && websocket_protocol.back() == ' ') {
- websocket_protocol.pop_back();
- }
- SRV_INF("%s: parsed websocket_protocol='%s'\n", __func__, websocket_protocol.c_str());
- }
- } else if (string_starts_with(line_lower, "authorization:")) {
- const std::string header_name = "authorization:";
- if (line.length() > header_name.length()) {
- authorization_header = line.substr(header_name.length());
- // Trim leading spaces
- while (!authorization_header.empty() && authorization_header[0] == ' ') {
- authorization_header.erase(0, 1);
- }
- SRV_INF("%s: parsed authorization_header='%s'\n", __func__, authorization_header.c_str());
- }
- }
- }
-
- // Validate API key if configured
- if (!validate_api_key(authorization_header)) {
- const char * response = "HTTP/1.1 401 Unauthorized\r\n"
- "Content-Type: application/json\r\n\r\n"
- "{\"error\":{\"message\":\"Invalid API Key\",\"type\":\"authentication_error\",\"code\":401}}";
- send(sock, response, strlen(response), 0);
- SRV_WRN("%s", "Unauthorized: Invalid API Key\n");
-#ifdef _WIN32
- closesocket(sock);
-#else
- close(sock);
-#endif
- return;
- }
-
- if (websocket_key.empty()) {
- const char * response = "HTTP/1.1 400 Bad Request\r\n\r\n";
- send(sock, response, strlen(response), 0);
-#ifdef _WIN32
- closesocket(sock);
-#else
- close(sock);
-#endif
- return;
- }
-
- // Compute accept key: SHA1(key + magic) -> base64
- std::string magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
- std::string combined = websocket_key + magic;
- char hash[20];
- SHA1(hash, combined.c_str(), static_cast(combined.size()));
- std::string accept_key = base64::encode(hash, sizeof(hash));
-
- std::ostringstream full_response_debug;
- full_response_debug << "HTTP/1.1 101 Switching Protocols\r\n";
- full_response_debug << "Upgrade: websocket\r\n";
- full_response_debug << "Connection: Upgrade\r\n";
- full_response_debug << "Sec-WebSocket-Accept: " << accept_key << "\r\n\r\n";
- std::string full_response_str = full_response_debug.str();
-
- SRV_INF("%s: hash_size=%zu, accept_key_len=%zu, last_4_chars=",
- __func__, sizeof(hash), accept_key.length());
- if (accept_key.length() >= 4) {
- for (size_t i = accept_key.length() - 4; i < accept_key.length(); i++) {
- SRV_INF(" [%zu]='%c' (0x%02x)\n", i, accept_key[i], (unsigned char)accept_key[i]);
- }
- }
-
- // Send handshake response
- std::ostringstream response;
- response << "HTTP/1.1 101 Switching Protocols\r\n";
- response << "Upgrade: websocket\r\n";
- response << "Connection: Upgrade\r\n";
- response << "Sec-WebSocket-Accept: " << accept_key << "\r\n";
- // Echo back the protocol if requested
- if (!websocket_protocol.empty()) {
- response << "Sec-WebSocket-Protocol: " << websocket_protocol << "\r\n";
- }
- response << "\r\n";
-
- std::string response_str = response.str();
- SRV_DBG("%s: Sending 101 response, %zu bytes\n", __func__, response_str.size());
- size_t total_sent = 0;
- size_t to_send = response_str.size();
- while (total_sent < to_send) {
-#ifdef _WIN32
- int sent = ::send(sock, response_str.c_str() + total_sent, static_cast(to_send - total_sent), 0);
-#else
- ssize_t sent = ::send(sock, response_str.c_str() + total_sent, to_send - total_sent, 0);
-#endif
- if (sent <= 0) {
- SRV_ERR("%s: send() failed during handshake\n", __func__);
- break;
- }
- total_sent += sent;
- }
-
- // Create connection object
- auto conn = std::make_shared(sock, path, query);
-
- // Set remote address
- char addr_str[INET_ADDRSTRLEN];
- inet_ntop(AF_INET, &addr.sin_addr, addr_str, sizeof(addr_str));
- conn->set_remote_address(addr_str);
-
- // Store connection
- {
- std::lock_guard lock(connections_mutex);
- connections[conn.get()] = conn;
- }
-
- // Call on_open callback
- if (on_open_cb) {
- SRV_INF("%s: About to call on_open_cb for server: %s\n", __func__, query.c_str());
- on_open_cb(conn);
- SRV_INF("%s: Returned from on_open_cb for server: %s\n", __func__, query.c_str());
- }
-
- // Read loop
- SRV_INF("%s: Entering read loop for server: %s\n", __func__, query.c_str());
- std::vector recv_buf(4096);
- int recv_count = 0;
- while (!conn->is_closed()) {
-#ifdef _WIN32
- int n = recv(sock, reinterpret_cast(recv_buf.data()), recv_buf.size(), 0);
-#else
- ssize_t n = recv(sock, recv_buf.data(), recv_buf.size(), 0);
-#endif
-
- if (n <= 0) {
- SRV_INF("%s: recv returned %zd (count=%d) for server: %s\n", __func__, n, recv_count, query.c_str());
- break;
- }
- SRV_INF("%s: recv returned %zd bytes (count=%d) for server: %s\n", __func__, n, ++recv_count, query.c_str());
-
- recv_buf.resize(n);
- SRV_DBG("%s: received %zd bytes from WebSocket\n", __func__, n);
- // Log first 100 bytes in hex
- for (ssize_t i = 0; i < std::min(ssize_t(100), n); i++) {
- SRV_DBG("%s: [%zd] = 0x%02x (%c)\n", __func__, i, recv_buf[i],
- isprint(recv_buf[i]) ? recv_buf[i] : '.');
- }
-
- conn->handle_data(recv_buf, [this, conn](const std::string & msg) {
- SRV_DBG("%s: handle_data callback: msg length=%zu\n", __func__, msg.length());
- if (on_message_cb) {
- on_message_cb(conn, msg);
- }
- });
- }
-
- // Call on_close callback
- if (on_close_cb) {
- on_close_cb(conn);
- }
-
- // Remove connection
- {
- std::lock_guard lock(connections_mutex);
- connections.erase(conn.get());
- }
-
- // Close socket
- conn->close(1000, "");
-}
+// nothing to see here
\ No newline at end of file
diff --git a/tools/server/server-ws.h b/tools/server/server-ws.h
index 1540cf2d35a..43199ef3d40 100644
--- a/tools/server/server-ws.h
+++ b/tools/server/server-ws.h
@@ -1,6 +1,7 @@
#pragma once
#include "server-common.h"
+#include "server-http.h"
#include
#include
#include
@@ -16,69 +17,187 @@ struct common_params;
struct sockaddr_in;
class ws_connection_impl;
-// WebSocket connection interface
-// Abstracts the underlying WebSocket implementation
-struct server_ws_connection {
- virtual ~server_ws_connection() = default;
-
- // Send a message to the client
- virtual void send(const std::string & message) = 0;
- // Close the connection
- virtual void close(int code = 1000, const std::string & reason = "") = 0;
- // Get query parameter by key
- virtual std::string get_query_param(const std::string & key) const = 0;
- // Get the remote address
- virtual std::string get_remote_address() const = 0;
+// @ngxson: this is a demo for how a bi-directional connection between
+// the server and frontend can be implemented using SSE + HTTP POST
+// I'm reusing the name "WS" here, but this is not a real WebSocket implementation
+// the code is 100% written by human, no AI involved
+// but this is just a demo, do not use it in practice
+
+
+
+struct server_ws_connection;
+
+// hacky: server_ws_connection is a member of this struct because
+// we want to have shared_ptr for other handler functions
+// in practice, we don't really need this
+struct server_ws_sse : server_http_res {
+ std::string id;
+ std::shared_ptr conn;
+ const server_http_req & req;
+
+ std::mutex mutex_send;
+ std::condition_variable cv;
+ struct msg {
+ std::string data;
+ bool is_closed = false;
+ };
+ std::queue queue_send;
+
+ server_ws_sse(const server_http_req & req, const std::string & id) : id(id), req(req) {
+ conn = std::make_shared(*this);
+
+ queue_send.push({
+ "data: {\"llamacpp_id\":\"" + id + "\"}", false
+ });
+
+ next = [this, &req, id](std::string & output) {
+ std::unique_lock lk(mutex_send);
+ constexpr auto poll_interval = std::chrono::milliseconds(500);
+ while (true) {
+ if (!queue_send.empty()) {
+ output.clear();
+ auto & front = queue_send.front();
+ if (front.is_closed) {
+ return false; // closed
+ }
+ SRV_INF("%s: sending SSE message: %s\n", id.c_str(), front.data.c_str());
+ output = "data: " + front.data + "\n\n";
+ queue_send.pop();
+ return true;
+ }
+ if (req.should_stop()) {
+ return false; // connection closed
+ }
+ cv.wait_for(lk, poll_interval);
+ }
+ };
+ }
+
+ std::function on_close;
+ ~server_ws_sse() {
+ close();
+ if (on_close) {
+ on_close();
+ }
+ }
+
+ void send(const std::string & message) {
+ std::lock_guard lk(mutex_send);
+ queue_send.push({message, false});
+ cv.notify_all();
+ }
+
+ void close() {
+ std::lock_guard lk(mutex_send);
+ queue_send.push({"", true});
+ cv.notify_all();
+ }
};
-// Forward declarations
-class ws_connection_impl;
-// WebSocket context - manages the WebSocket server
-// Runs on a separate thread and handles WebSocket connections
-struct server_ws_context {
- struct Impl;
- std::unique_ptr pimpl;
- std::thread thread;
- std::atomic is_ready = false;
+struct server_ws_connection {
+ server_ws_sse & parent;
+ server_ws_connection(server_ws_sse & parent) : parent(parent) {}
- std::string path_prefix; // e.g., "/mcp"
- int port;
+ // Send a message to the client
+ void send(const std::string & message) {
+ parent.send(message);
+ }
- server_ws_context();
- ~server_ws_context();
+ // Close the connection
+ void close(int code = 1000, const std::string & reason = "") {
+ SRV_INF("%s: closing connection: code=%d, reason=%s\n",
+ __func__, code, reason.c_str());
+ parent.close();
+ }
- // Initialize the WebSocket server
- bool init(const common_params & params);
+ // Get query parameter by key
+ std::string get_query_param(const std::string & key) const {
+ return parent.req.get_param(key);
+ }
- // Start the WebSocket server (runs in background thread)
- bool start();
+ // Get the remote address
+ std::string get_remote_address() {
+ return parent.id;
+ }
+};
- // Stop the WebSocket server
- void stop();
- // Get the actual port the WebSocket server is listening on
- int get_actual_port() const;
- // Set the port for the WebSocket server (note: actual port may differ if set to 0)
- void set_port(int port) { this->port = port; }
+// SSE + HTTP POST implementation of server_ws_context
+struct server_ws_context {
+ server_ws_context() = default;
+ ~server_ws_context() = default;
+
+ // map ID to connection
+ std::mutex mutex;
+ std::map res_map;
+
+ // SSE endpoint
+ server_http_context::handler_t get_mcp = [this](const server_http_req & req) {
+ auto id = random_string();
+ auto res = std::make_unique(req, id);
+ {
+ std::unique_lock lock(mutex);
+ res_map[id] = res.get();
+ }
+ SRV_INF("%s: new SSE connection established, ID: %s\n%s", __func__, id.c_str(), req.body.c_str());
+ res->id = id;
+ res->status = 200;
+ res->headers["X-Connection-ID"] = id;
+ res->content_type = "text/event-stream";
+ // res->next is set in server_ws_sse constructor
+ res->on_close = [this, id]() {
+ std::unique_lock lock(mutex);
+ handler_on_close(res_map[id]->conn);
+ res_map.erase(id);
+ };
+ handler_on_open(res->conn);
+ return res;
+ };
+
+ // HTTP POST endpoint
+ server_http_context::handler_t post_mcp = [this](const server_http_req & req) {
+ auto id = req.get_param("llamacpp_id");
+ std::shared_ptr conn;
+ SRV_INF("%s: received POST for connection ID: %s\n%s", __func__, id.c_str(), req.body.c_str());
+ std::unique_lock lock(mutex);
+ {
+ auto it = res_map.find(id);
+ if (it != res_map.end()) {
+ conn = it->second->conn;
+ }
+ }
+ if (!conn) {
+ SRV_ERR("%s: invalid connection ID: %s\n", __func__, id.c_str());
+ auto res = std::make_unique();
+ res->status = 400;
+ res->data = "Invalid connection ID";
+ return res;
+ }
+ handler_on_message(conn, req.body);
+ auto res = std::make_unique();
+ res->status = 200;
+ return res;
+ };
// Called when new connection is established
using on_open_t = std::function)>;
- void on_open(on_open_t handler);
+ void on_open(on_open_t handler) { handler_on_open = handler; }
// Called when message is received from a connection
using on_message_t = std::function, const std::string &)>;
- void on_message(on_message_t handler);
+ void on_message(on_message_t handler) { handler_on_message = handler; }
// Called when connection is closed
using on_close_t = std::function)>;
- void on_close(on_close_t handler);
+ void on_close(on_close_t handler) { handler_on_close = handler; }
- // For debugging
- std::string listening_address;
+ on_open_t handler_on_open;
+ on_message_t handler_on_message;
+ on_close_t handler_on_close;
};
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0dcab5d677f..c6bce3e200b 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -120,20 +120,12 @@ int main(int argc, char ** argv, char ** envp) {
// WebSocket Server (for MCP support) - only if --webui-mcp is enabled
//
- server_ws_context * ctx_ws = nullptr;
- server_mcp_bridge * mcp_bridge = nullptr;
+ std::unique_ptr ctx_ws = nullptr;
+ std::unique_ptr mcp_bridge = nullptr;
if (params.webui_mcp) {
- ctx_ws = new server_ws_context();
- mcp_bridge = new server_mcp_bridge();
-
- // Initialize WebSocket server with params (sets port to HTTP port + 1)
- if (!ctx_ws->init(params)) {
- LOG_ERR("%s: failed to initialize WebSocket server\n", __func__);
- delete ctx_ws;
- delete mcp_bridge;
- return 1;
- }
+ ctx_ws = std::make_unique();
+ mcp_bridge = std::make_unique();
}
// Helper function to get MCP config path
@@ -289,6 +281,9 @@ int main(int argc, char ** argv, char ** envp) {
res->data = response.dump();
return res;
});
+
+ ctx_http.get ("/mcp", ex_wrapper(ctx_ws->get_mcp));
+ ctx_http.post("/mcp", ex_wrapper(ctx_ws->post_mcp));
}
//
@@ -315,15 +310,8 @@ int main(int argc, char ** argv, char ** envp) {
if (is_router_server) {
LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__);
- clean_up = [&models_routes, &ctx_ws, &mcp_bridge]() {
+ clean_up = [&models_routes]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
- if (ctx_ws) {
- ctx_ws->stop();
- delete ctx_ws;
- }
- if (mcp_bridge) {
- delete mcp_bridge;
- }
if (models_routes.has_value()) {
models_routes->models.unload_all();
}
@@ -337,34 +325,14 @@ int main(int argc, char ** argv, char ** envp) {
}
ctx_http.is_ready.store(true);
- // Start WebSocket server (OS will assign an available port) - only if --webui-mcp is enabled
- if (params.webui_mcp && ctx_ws) {
- if (!ctx_ws->start()) {
- clean_up();
- LOG_ERR("%s: exiting due to WebSocket server error\n", __func__);
- return 1;
- }
- LOG_INF("%s: WebSocket server started on port %d\n", __func__, ctx_ws->get_actual_port());
- }
-
shutdown_handler = [&](int) {
- if (ctx_ws) {
- ctx_ws->stop();
- }
ctx_http.stop();
};
} else {
// setup clean up function, to be called before exit
- clean_up = [&ctx_http, &ctx_ws, &ctx_server, &mcp_bridge]() {
+ clean_up = [&ctx_http, &ctx_server]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
- if (ctx_ws) {
- ctx_ws->stop();
- delete ctx_ws;
- }
- if (mcp_bridge) {
- delete mcp_bridge;
- }
ctx_http.stop();
ctx_server.terminate();
llama_backend_free();
@@ -377,16 +345,6 @@ int main(int argc, char ** argv, char ** envp) {
return 1;
}
- // Start WebSocket server (OS will assign an available port) - only if --webui-mcp is enabled
- if (params.webui_mcp && ctx_ws) {
- if (!ctx_ws->start()) {
- clean_up();
- LOG_ERR("%s: exiting due to WebSocket server error\n", __func__);
- return 1;
- }
- LOG_INF("%s: WebSocket server started on port %d\n", __func__, ctx_ws->get_actual_port());
- }
-
// load the model
LOG_INF("%s: loading model\n", __func__);
@@ -405,10 +363,6 @@ int main(int argc, char ** argv, char ** envp) {
LOG_INF("%s: model loaded\n", __func__);
shutdown_handler = [&](int) {
- // this will unblock start_loop()
- if (ctx_ws) {
- ctx_ws->stop();
- }
ctx_server.terminate();
};
}
diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json
index 9706678cde1..fa2792f7dbd 100644
--- a/tools/server/webui/package-lock.json
+++ b/tools/server/webui/package-lock.json
@@ -880,6 +880,7 @@
"integrity": "sha512-oJrXtQiAXLvT9clCf1K4kxp3eKsQhIaZqxEyowkBcsvZDdZkbWrVmnGknxs5flTD0VGsxrxKgBCZty1EzoiMzA==",
"dev": true,
"license": "Apache-2.0",
+ "peer": true,
"dependencies": {
"@swc/helpers": "^0.5.0"
}
@@ -2105,6 +2106,7 @@
"integrity": "sha512-rO+YQhHucy47Vh67z318pALmd6x+K1Kj30Fb4a6oOEw4xn4zCo9KTmkMWs24c4oduEXD/eJu3badlRmsVXzyfA==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"ts-dedent": "^2.0.0",
"type-fest": "~2.19"
@@ -2188,6 +2190,7 @@
"integrity": "sha512-Vp3zX/qlwerQmHMP6x0Ry1oY7eKKRcOWGc2P59srOp4zcqyn+etJyQpELgOi4+ZSUgteX8Y387NuwruLgGXLUQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -2227,6 +2230,7 @@
"integrity": "sha512-YZs/OSKOQAQCnJvM/P+F1URotNnYNeU3P2s4oIpzm1uFaqUEqRxUB0g5ejMjEb5Gjb9/PiBI5Ktrq4rUUF8UVQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^5.0.0",
"debug": "^4.4.1",
@@ -2642,6 +2646,7 @@
"integrity": "sha512-pemlzrSESWbdAloYml3bAJMEfNh1Z7EduzqPKprCH5S341frlpYnUEW0H72dLxa6IsYr+mPno20GiSm+h9dEdQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@babel/code-frame": "^7.10.4",
"@babel/runtime": "^7.12.5",
@@ -2809,6 +2814,7 @@
"integrity": "sha512-bJFoMATwIGaxxx8VJPeM8TonI8t579oRvgAuT8zFugJsJZgzqv0Fu8Mhp68iecjzG7cnN3mO2dJQ5uUM2EFrgQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -2876,6 +2882,7 @@
"integrity": "sha512-kVIaQE9vrN9RLCQMQ3iyRlVJpTiDUY6woHGb30JDkfJErqrQEmtdWH3gV0PBAfGZgQXoqzXOO0T3K6ioApbbAA==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@typescript-eslint/scope-manager": "8.37.0",
"@typescript-eslint/types": "8.37.0",
@@ -3100,6 +3107,7 @@
"integrity": "sha512-tJxiPrWmzH8a+w9nLKlQMzAKX/7VjFs50MWgcAj7p9XQ7AQ9/35fByFYptgPELyLw+0aixTnC4pUWV+APcZ/kw==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@testing-library/dom": "^10.4.0",
"@testing-library/user-event": "^14.6.1",
@@ -3203,6 +3211,7 @@
"integrity": "sha512-oukfKT9Mk41LreEW09vt45f8wx7DordoWUZMYdY/cyAk7w5TWkTRCNZYF7sX7n2wB7jyGAl74OxgwhPgKaqDMQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@vitest/utils": "3.2.4",
"pathe": "^2.0.3",
@@ -3273,6 +3282,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
+ "peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -3954,8 +3964,7 @@
"resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz",
"integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==",
"dev": true,
- "license": "MIT",
- "peer": true
+ "license": "MIT"
},
"node_modules/debug": {
"version": "4.4.3",
@@ -4209,6 +4218,7 @@
"dev": true,
"hasInstallScript": true,
"license": "MIT",
+ "peer": true,
"bin": {
"esbuild": "bin/esbuild"
},
@@ -4269,6 +4279,7 @@
"integrity": "sha512-QldCVh/ztyKJJZLr4jXNUByx3gR+TDYZCRXEktiZoUR3PGy4qCmSbkxcIle8GEwGpb5JBZazlaJ/CxLidXdEbQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@eslint-community/eslint-utils": "^4.2.0",
"@eslint-community/regexpp": "^4.12.1",
@@ -7501,6 +7512,7 @@
}
],
"license": "MIT",
+ "peer": true,
"dependencies": {
"nanoid": "^3.3.11",
"picocolors": "^1.1.1",
@@ -7634,6 +7646,7 @@
"integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -7650,6 +7663,7 @@
"integrity": "sha512-pn1ra/0mPObzqoIQn/vUTR3ZZI6UuZ0sHqMK5x2jMLGrs53h0sXhkVuDcrlssHwIMk7FYrMjHBPoUSyyEEDlBQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"peerDependencies": {
"prettier": "^3.0.0",
"svelte": "^3.2.0 || ^4.0.0-next.0 || ^5.0.0-next.0"
@@ -7926,6 +7940,7 @@
"integrity": "sha512-FS+XFBNvn3GTAWq26joslQgWNoFu08F4kl0J4CgdNKADkdSGXQyTCnKteIAJy96Br6YbpEU1LSzV5dYtjMkMDg==",
"dev": true,
"license": "MIT",
+ "peer": true,
"engines": {
"node": ">=0.10.0"
}
@@ -7936,6 +7951,7 @@
"integrity": "sha512-Xs1hdnE+DyKgeHJeJznQmYMIBG3TKIHJJT95Q58nHLSrElKlGQqDTR2HQ9fx5CN/Gk6Vh/kupBTDLU11/nDk/g==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"scheduler": "^0.26.0"
},
@@ -8221,6 +8237,7 @@
"integrity": "sha512-4iya7Jb76fVpQyLoiVpzUrsjQ12r3dM7fIVz+4NwoYvZOShknRmiv+iu9CClZml5ZLGb0XMcYLutK6w9tgxHDw==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@types/estree": "1.0.8"
},
@@ -8342,6 +8359,7 @@
"integrity": "sha512-elOcIZRTM76dvxNAjqYrucTSI0teAF/L2Lv0s6f6b7FOwcwIuA357bIE871580AjHJuSvLIRUosgV+lIWx6Rgg==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"chokidar": "^4.0.0",
"immutable": "^5.0.2",
@@ -8630,6 +8648,7 @@
"integrity": "sha512-7smAu0o+kdm378Q2uIddk32pn0UdIbrtTVU+rXRVtTVTCrK/P2cCui2y4JH+Bl3NgEq1bbBQpCAF/HKrDjk2Qw==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@storybook/global": "^5.0.0",
"@storybook/icons": "^1.6.0",
@@ -8775,6 +8794,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.36.12.tgz",
"integrity": "sha512-c3mWT+b0yBLl3gPGSHiy4pdSQCsPNTjLC0tVoOhrGJ6PPfCzD/RQpAmAfJtQZ304CAae2ph+L3C4aqds3R3seQ==",
"license": "MIT",
+ "peer": true,
"dependencies": {
"@ampproject/remapping": "^2.3.0",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -9018,6 +9038,7 @@
"integrity": "sha512-gBXpgUm/3rp1lMZZrM/w7D8GKqshif0zAymAhbCyIt8KMe+0v9DQ7cdYLR4FHH/cKpdTXb+A/tKKU3eolfsI+g==",
"dev": true,
"license": "MIT",
+ "peer": true,
"funding": {
"type": "github",
"url": "https://github.com/sponsors/dcastil"
@@ -9048,7 +9069,8 @@
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.11.tgz",
"integrity": "sha512-2E9TBm6MDD/xKYe+dvJZAmg3yxIEDNRc0jwlNyDg/4Fil2QcSLjFKGVff0lAf1jjeaArlG/M75Ey/EYr/OJtBA==",
"dev": true,
- "license": "MIT"
+ "license": "MIT",
+ "peer": true
},
"node_modules/tapable": {
"version": "2.2.2",
@@ -9284,6 +9306,7 @@
"integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==",
"dev": true,
"license": "Apache-2.0",
+ "peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -9667,6 +9690,7 @@
"integrity": "sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.5.0",
@@ -9827,6 +9851,7 @@
"integrity": "sha512-LUCP5ev3GURDysTWiP47wRRUpLKMOfPh+yKTx3kVIEiu5KOMeqzpnYNsKyOoVrULivR8tLcks4+lga33Whn90A==",
"dev": true,
"license": "MIT",
+ "peer": true,
"dependencies": {
"@types/chai": "^5.2.2",
"@vitest/expect": "3.2.4",
@@ -10047,6 +10072,7 @@
"resolved": "https://registry.npmjs.org/zod/-/zod-4.2.1.tgz",
"integrity": "sha512-0wZ1IRqGGhMP76gLqz8EyfBXKk0J2qo2+H3fi4mcUP/KtTocoX08nmIAHl1Z2kJIZbZee8KOpBCSNPRgauucjw==",
"license": "MIT",
+ "peer": true,
"funding": {
"url": "https://github.com/sponsors/colinhacks"
}
diff --git a/tools/server/webui/src/lib/services/mcp-transport-custom.ts b/tools/server/webui/src/lib/services/mcp-transport-custom.ts
new file mode 100644
index 00000000000..3ce1c8174c0
--- /dev/null
+++ b/tools/server/webui/src/lib/services/mcp-transport-custom.ts
@@ -0,0 +1,113 @@
+import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
+import { JSONRPCMessageSchema } from '@modelcontextprotocol/sdk/types.js';
+import { type Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
+
+const SUBPROTOCOL = 'mcp';
+
+// modified from WebSocketClientTransport
+
+export class CustomLlamaCppTransport implements Transport {
+ private _evSource?: EventSource;
+ private _url: URL;
+ private _id: string = 'unknown';
+
+ onclose?: () => void;
+ onerror?: (error: Error) => void;
+ onmessage?: (message: JSONRPCMessage) => void;
+
+ constructor(url: string) {
+ this._url = new URL(url);
+ }
+
+ start(): Promise {
+ if (this._evSource) {
+ throw new Error(
+ 'CustomLlamaCppTransport already started! If using Client class, note that connect() calls start() automatically.'
+ );
+ }
+
+ return new Promise((resolve, reject) => {
+ console.log('Connecting to SSE:', this._url.toString());
+ this._evSource = new EventSource(this._url.toString(), { withCredentials: true });
+
+ this._evSource.onerror = (event) => {
+ if (event.eventPhase == EventSource.CLOSED) {
+ this.onclose?.();
+ console.log('Event Source Closed');
+ }
+
+ const error =
+ 'error' in event
+ ? (event.error as Error)
+ : new Error(`EventSource error: ${JSON.stringify(event)}`);
+ reject(error);
+ this.onerror?.(error);
+ };
+
+ // this._evSource.onopen = () => {
+ // resolve();
+ // };
+
+ // this._evSource.onclose = () => {
+ // this.onclose?.();
+ // };
+
+ this._evSource.onmessage = (event: MessageEvent) => {
+ const raw = event.data.startsWith('data: ') ? event.data.slice(6) : event.data;
+ console.log('SSE Message received:', raw);
+ const data = JSON.parse(raw);
+ if (data.llamacpp_id) {
+ this._id = data.llamacpp_id;
+ console.log('Connected to SSE with id:', this._id);
+ resolve();
+ return;
+ }
+ let message: JSONRPCMessage;
+ try {
+ message = JSONRPCMessageSchema.parse(data);
+ } catch (error) {
+ this.onerror?.(error as Error);
+ return;
+ }
+
+ this.onmessage?.(message);
+ };
+ });
+ }
+
+ async close(): Promise {
+ this._evSource?.close();
+ }
+
+ send(message: JSONRPCMessage): Promise {
+ return new Promise((resolve, reject) => {
+ if (!this._evSource) {
+ reject(new Error('Not connected'));
+ return;
+ }
+
+ if (this._id == 'unknown') {
+ reject(new Error('Connection ID not set yet'));
+ return;
+ }
+
+ const url = new URL(this._url.toString());
+ url.searchParams.append('llamacpp_id', this._id);
+ fetch(url.toString(), {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ 'X-Subprotocol': SUBPROTOCOL // redundant, maybe remove later
+ },
+ body: JSON.stringify(message),
+ credentials: 'include' // redundant too, maybe remove later
+ })
+ .then(() => {
+ resolve();
+ })
+ .catch((error) => {
+ reject(error);
+ });
+ });
+ }
+}
diff --git a/tools/server/webui/src/lib/services/mcp.ts b/tools/server/webui/src/lib/services/mcp.ts
index 881ebbfaea5..96cbb83ab58 100644
--- a/tools/server/webui/src/lib/services/mcp.ts
+++ b/tools/server/webui/src/lib/services/mcp.ts
@@ -8,8 +8,9 @@
*/
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
-import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
+// import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import type { Tool, Notification } from '@modelcontextprotocol/sdk/types.js';
+import { CustomLlamaCppTransport } from './mcp-transport-custom';
// Timeout constants - increased for Docker containers that may take time to start
const REQUEST_TIMEOUT_MS = 120_000; // 2 minutes for slow-starting containers
@@ -19,7 +20,7 @@ const INITIAL_RECONNECT_DELAY_MS = 1000; // Start with 1s delay
export class McpService {
private client: Client | null = null;
- private transport: WebSocketClientTransport | null = null;
+ private transport: CustomLlamaCppTransport | null = null;
private reconnectTimer: ReturnType | null = null;
private reconnectAttempts = 0;
private manualDisconnect = false;
@@ -57,7 +58,7 @@ export class McpService {
async connect(): Promise {
try {
// Create a new transport instance for each connection attempt
- this.transport = new WebSocketClientTransport(new URL(this.wsUrl));
+ this.transport = new CustomLlamaCppTransport(this.wsUrl);
// Set up transport event handlers
this.transport.onclose = () => {
@@ -275,8 +276,8 @@ export const mcpServiceFactory = {
*/
create: (serverName: string): McpService => {
const url = new URL(window.location.href);
- const wsPort = (parseInt(url.port) || 80) + 1;
- const wsUrl = `ws://${url.hostname}:${wsPort}/mcp?server=${encodeURIComponent(serverName)}`;
- return new McpService(serverName, wsUrl);
+ url.pathname = '/mcp';
+ url.searchParams.set('server', serverName);
+ return new McpService(serverName, url.toString());
}
};
diff --git a/tools/server/webui/src/lib/stores/mcp.svelte.ts b/tools/server/webui/src/lib/stores/mcp.svelte.ts
index 018de16d58d..f0febf5df03 100644
--- a/tools/server/webui/src/lib/stores/mcp.svelte.ts
+++ b/tools/server/webui/src/lib/stores/mcp.svelte.ts
@@ -16,7 +16,7 @@
*/
import type { McpConnectionState, McpTool } from '$lib/types/mcp';
-import { McpService } from '$lib/services/mcp';
+import { McpService, mcpServiceFactory } from '$lib/services/mcp';
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
class McpStore {
@@ -138,11 +138,6 @@ class McpStore {
await new Promise((resolve) => setTimeout(resolve, 500));
}
- // WebSocket port is always HTTP port + 1
- const url = new URL(window.location.href);
- const wsPort = parseInt(url.port) + 1;
- const wsUrl = `ws://${window.location.hostname}:${wsPort}/mcp?server=${encodeURIComponent(serverName)}`;
-
// Increment connection generation for this server
const currentGen = (this.connectionGenerations.get(serverName) ?? 0) + 1;
this.connectionGenerations.set(serverName, currentGen);
@@ -159,7 +154,7 @@ class McpStore {
});
try {
- const service = new McpService(serverName, wsUrl);
+ const service = mcpServiceFactory.create(serverName);
this.services.set(serverName, service);
// Set up event handlers with generation checking
diff --git a/tools/server/webui/vite.config.ts b/tools/server/webui/vite.config.ts
index 899d3f63ea3..349738a2674 100644
--- a/tools/server/webui/vite.config.ts
+++ b/tools/server/webui/vite.config.ts
@@ -159,13 +159,7 @@ export default defineConfig({
'/models': 'http://localhost:8080',
// HTTP endpoint for MCP
'/mcp/servers': 'http://localhost:8080',
- // WebSocket endpoint for MCP (on HTTP port + 1)
- '/mcp': {
- target: 'http://localhost:8081',
- ws: true,
- // Don't rewrite the path - keep /mcp as-is
- rewrite: (path) => path
- }
+ '/mcp': 'http://localhost:8080'
},
headers: {
'Cross-Origin-Embedder-Policy': 'require-corp',