From 12aba64d7d55aa8919d10a620741d9bd4ed4e1f0 Mon Sep 17 00:00:00 2001 From: Chris Addams Date: Sun, 29 Mar 2026 22:16:25 +0100 Subject: [PATCH] feat: HTTP mode for MCP server with per-request auth Add HTTP transport mode alongside stdio for hosting the MCP server as a web service. The AI sidebar connects to this server for tool execution. - Per-request credential isolation via AsyncLocal (no shared state) - Authorization header + X-Org-Id + X-Instance-Url required on tool calls - Credential-mutating tools blocked in HTTP mode (login, logout, etc.) - Request body size limited to 1MB - CORS origins configurable via MCP_CORS_ORIGINS env var - Tool results truncated to 500 chars in logs - Generic error messages to clients, details logged server-side - SafeArgs regex whitelist on all command inputs - Token refresh with JWT expiry checking - 40 unit tests covering security, isolation, registry, and blocked tools --- mcp/AnythinkMcp.csproj | 2 +- mcp/McpClientFactory.cs | 40 ++- mcp/McpToolRegistry.cs | 173 +++++++++++ mcp/Program.cs | 200 +++++++++++-- mcp/Tools/CliTool.cs | 464 +++++++++++++++++++++++++++-- tests-mcp/AnythinkMcp.Tests.csproj | 31 ++ tests-mcp/BlockedToolsTests.cs | 46 +++ tests-mcp/GlobalUsings.cs | 1 + tests-mcp/McpClientFactoryTests.cs | 86 ++++++ tests-mcp/McpToolRegistryTests.cs | 56 ++++ tests-mcp/SafeArgsTests.cs | 48 +++ 11 files changed, 1100 insertions(+), 47 deletions(-) create mode 100644 mcp/McpToolRegistry.cs create mode 100644 tests-mcp/AnythinkMcp.Tests.csproj create mode 100644 tests-mcp/BlockedToolsTests.cs create mode 100644 tests-mcp/GlobalUsings.cs create mode 100644 tests-mcp/McpClientFactoryTests.cs create mode 100644 tests-mcp/McpToolRegistryTests.cs create mode 100644 tests-mcp/SafeArgsTests.cs diff --git a/mcp/AnythinkMcp.csproj b/mcp/AnythinkMcp.csproj index e6a93b6..d2d567b 100644 --- a/mcp/AnythinkMcp.csproj +++ b/mcp/AnythinkMcp.csproj @@ -1,4 +1,4 @@ - + Exe diff --git a/mcp/McpClientFactory.cs b/mcp/McpClientFactory.cs index deb8b16..d02c56a 100644 --- a/mcp/McpClientFactory.cs +++ b/mcp/McpClientFactory.cs @@ -4,15 +4,19 @@ namespace AnythinkMcp; /// -/// Resolves a profile name (or the active default) into an authenticated -/// . Uses the same config files and token-refresh -/// logic as the CLI — credentials stored by anythink login work here too. +/// Resolves credentials into an authenticated . +/// +/// In stdio mode: uses CLI config files and saved profiles (same as the CLI). +/// In HTTP mode: uses per-request credentials passed via . /// public class McpClientFactory { private readonly string? _profileName; private readonly HttpMessageHandler? _httpHandler; + // Per-request credentials for HTTP mode — AsyncLocal flows correctly across async/await + private static readonly AsyncLocal<(string OrgId, string BaseUrl, string Token)?> _requestCredentials = new(); + public string? ProfileName => _profileName; public McpClientFactory(string? profileName = null) @@ -20,6 +24,24 @@ public McpClientFactory(string? profileName = null) _profileName = profileName; } + /// + /// Sets per-request credentials for HTTP mode. Must be called before tool execution. + /// Thread-static so concurrent requests don't interfere. + /// + public static void SetRequestCredentials(string orgId, string baseUrl, string token) + { + _requestCredentials.Value = (orgId, baseUrl, token); + } + + /// Clears per-request credentials after the request completes. + public static void ClearRequestCredentials() + { + _requestCredentials.Value = null; + } + + /// Returns true if running in HTTP mode with per-request credentials. + public static bool IsHttpMode => _requestCredentials.Value.HasValue; + /// Test-only constructor — injects a mock HTTP handler for all clients. internal McpClientFactory(string? profileName, HttpMessageHandler httpHandler) { @@ -49,11 +71,19 @@ public BillingClient GetUnauthenticatedBillingClient() } /// - /// Returns an authenticated client for the configured profile. - /// Refreshes expired JWT tokens automatically (same logic as the CLI). + /// Returns an authenticated client. In HTTP mode, uses per-request credentials. + /// In stdio mode, uses CLI config files and refreshes expired tokens. /// public AnythinkClient GetClient() { + // HTTP mode: use per-request credentials (no config files) + if (_requestCredentials.Value.HasValue) + { + var creds = _requestCredentials.Value.Value; + return new AnythinkClient(creds.OrgId, creds.BaseUrl, creds.Token); + } + + // Stdio mode: resolve from CLI config var profile = !string.IsNullOrEmpty(_profileName) ? ConfigService.GetProfile(_profileName) ?? throw new InvalidOperationException( diff --git a/mcp/McpToolRegistry.cs b/mcp/McpToolRegistry.cs new file mode 100644 index 0000000..596e2f5 --- /dev/null +++ b/mcp/McpToolRegistry.cs @@ -0,0 +1,173 @@ +using System.Reflection; +using System.Text.Json; +using AnythinkMcp.Tools; +using ModelContextProtocol.Server; + +namespace AnythinkMcp; + +/// +/// Registry for MCP tools — discovers tools from the assembly and provides +/// execution by name for the HTTP transport. In stdio mode, the MCP SDK +/// handles this automatically; in HTTP mode we need to invoke tools manually. +/// +public static class McpToolRegistry +{ + private static readonly Dictionary Tools = DiscoverTools(); + + /// Returns tool definitions in Claude API tool_use format. + public static List GetToolDefinitions() + { + return Tools.Values.Select(t => new + { + name = t.Name, + description = t.Description, + input_schema = t.InputSchema + }).Cast().ToList(); + } + + /// Executes a tool by name, returning the text result. + public static async Task ExecuteToolAsync(string toolName, JsonElement arguments, + IServiceProvider services) + { + if (!Tools.TryGetValue(toolName, out var tool)) + throw new ArgumentException($"Unknown tool: {toolName}"); + + var factory = services.GetRequiredService(); + + // Create an instance of the tool class (all tool classes take McpClientFactory in constructor) + var instance = Activator.CreateInstance(tool.DeclaringType, factory)!; + + // Build method arguments from the JSON + var methodParams = tool.Method.GetParameters(); + var invokeArgs = new object?[methodParams.Length]; + + for (var i = 0; i < methodParams.Length; i++) + { + var param = methodParams[i]; + if (arguments.TryGetProperty(ToCamelCase(param.Name!), out var value) || + arguments.TryGetProperty(param.Name!, out value)) + { + invokeArgs[i] = ConvertJsonElement(value, param.ParameterType); + } + else if (param.HasDefaultValue) + { + invokeArgs[i] = param.DefaultValue; + } + else + { + invokeArgs[i] = param.ParameterType.IsValueType + ? Activator.CreateInstance(param.ParameterType) + : null; + } + } + + // Invoke and await + var result = tool.Method.Invoke(instance, invokeArgs); + if (result is Task taskString) + return await taskString; + if (result is Task task) + { + await task; + return "OK"; + } + return result?.ToString() ?? ""; + } + + private static Dictionary DiscoverTools() + { + var tools = new Dictionary(); + + // Find all types with [McpServerToolType] attribute + var toolTypes = Assembly.GetExecutingAssembly().GetTypes() + .Where(t => t.GetCustomAttribute() != null); + + foreach (var type in toolTypes) + { + // Find methods with [McpServerTool] attribute + foreach (var method in type.GetMethods(BindingFlags.Public | BindingFlags.Instance)) + { + var toolAttr = method.GetCustomAttribute(); + if (toolAttr == null) continue; + + var descAttr = method.GetCustomAttribute(); + + var name = toolAttr.Name ?? method.Name; + var description = descAttr?.Description ?? ""; + + // Build input schema from method parameters + var properties = new Dictionary(); + var required = new List(); + + foreach (var param in method.GetParameters()) + { + var paramDesc = param.GetCustomAttribute(); + var paramName = ToCamelCase(param.Name!); + + properties[paramName] = new + { + type = GetJsonType(param.ParameterType), + description = paramDesc?.Description ?? param.Name + }; + + if (!param.HasDefaultValue && !IsNullable(param.ParameterType)) + required.Add(paramName); + } + + tools[name] = new ToolInfo + { + Name = name, + Description = description, + DeclaringType = type, + Method = method, + InputSchema = new + { + type = "object", + properties, + required = required.ToArray() + } + }; + } + } + + return tools; + } + + private static string GetJsonType(Type type) + { + type = Nullable.GetUnderlyingType(type) ?? type; + if (type == typeof(string)) return "string"; + if (type == typeof(int) || type == typeof(long) || type == typeof(double) || type == typeof(float)) return "number"; + if (type == typeof(bool)) return "boolean"; + return "string"; + } + + private static bool IsNullable(Type type) => + !type.IsValueType || Nullable.GetUnderlyingType(type) != null; + + private static string ToCamelCase(string name) => + string.IsNullOrEmpty(name) ? name : char.ToLowerInvariant(name[0]) + name[1..]; + + private static object? ConvertJsonElement(JsonElement element, Type targetType) + { + targetType = Nullable.GetUnderlyingType(targetType) ?? targetType; + + if (element.ValueKind == JsonValueKind.Null) return null; + if (targetType == typeof(string)) return element.GetString(); + if (targetType == typeof(int)) return element.GetInt32(); + if (targetType == typeof(long)) return element.GetInt64(); + if (targetType == typeof(bool)) return element.GetBoolean(); + if (targetType == typeof(double)) return element.GetDouble(); + if (targetType == typeof(Guid)) return Guid.Parse(element.GetString()!); + + return element.GetString(); + } +} + +internal class ToolInfo +{ + public string Name { get; init; } = ""; + public string Description { get; init; } = ""; + public Type DeclaringType { get; init; } = null!; + public MethodInfo Method { get; init; } = null!; + public object InputSchema { get; init; } = null!; +} diff --git a/mcp/Program.cs b/mcp/Program.cs index 1759db2..56fb2b1 100644 --- a/mcp/Program.cs +++ b/mcp/Program.cs @@ -1,4 +1,6 @@ +using System.Text.Json; using AnythinkMcp; +using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; @@ -7,47 +9,201 @@ // ── Anythink MCP Server ────────────────────────────────────────────────────── // // A thin MCP wrapper around the Anythink CLI client library. -// Runs as a local stdio process — Claude launches it and talks over stdin/stdout. +// Supports two transport modes: // -// Uses the same profiles, auth, and AnythinkClient as the CLI: -// anythink-mcp → uses the active profile -// anythink-mcp --profile my-project → uses a named profile +// stdio (default): Claude Code launches it and talks over stdin/stdout. +// anythink-mcp +// anythink-mcp --profile my-project // -// Tools are organised by domain (entities, fields, data, workflows, roles, -// secrets) and delegate to AnythinkClient methods — no business logic here. +// http: Runs as an HTTP server for the AI sidebar and multi-tenant services. +// anythink-mcp --http +// anythink-mcp --http --port 5300 +// Requires Authorization, X-Org-Id, and X-Instance-Url headers on every request. +// Set MCP_CORS_ORIGINS env var to configure allowed origins (comma-separated). -var profile = ResolveProfile(args); +var profile = ResolveFlag(args, "--profile", "-p"); +var httpMode = args.Contains("--http"); +var port = int.TryParse(ResolveFlag(args, "--port"), out var p) ? p : 5300; -var builder = new HostBuilder(); +if (httpMode) + await RunHttpServer(profile, port); +else + await RunStdioServer(profile); -builder.ConfigureLogging(logging => logging.ClearProviders()); +// ── Stdio mode (for Claude Code) ───────────────────────────────────────────── -builder.ConfigureServices(services => +static async Task RunStdioServer(string? profile) { - services.AddSingleton(new McpClientFactory(profile)); + var builder = new HostBuilder(); + builder.ConfigureLogging(logging => logging.ClearProviders()); + builder.ConfigureServices(services => + { + services.AddSingleton(new McpClientFactory(profile)); + services + .AddMcpServer(server => + { + server.ServerInfo = new() { Name = "anythink", Version = "1.0.0" }; + }) + .WithStdioServerTransport() + .WithToolsFromAssembly(); + }); + + await builder.Build().RunAsync(); +} + +// ── HTTP mode (for AI sidebar / multi-tenant) ──────────────────────────────── + +static async Task RunHttpServer(string? profile, int port) +{ + var builder = WebApplication.CreateBuilder(); + builder.Logging.AddConsole(); + + // Limit request body size to prevent DoS + builder.WebHost.ConfigureKestrel(opts => opts.Limits.MaxRequestBodySize = 1_048_576); // 1 MB - services + builder.Services.AddSingleton(new McpClientFactory(profile)); + builder.Services .AddMcpServer(server => { - server.ServerInfo = new() - { - Name = "anythink", - Version = "1.0.0" - }; + server.ServerInfo = new() { Name = "anythink", Version = "1.0.0" }; }) - .WithStdioServerTransport() .WithToolsFromAssembly(); -}); -await builder.Build().RunAsync(); + var corsOrigins = Environment.GetEnvironmentVariable("MCP_CORS_ORIGINS")?.Split(',') + ?? ["http://localhost:5200"]; + builder.Services.AddCors(options => + { + options.AddDefaultPolicy(policy => + policy.WithOrigins(corsOrigins).AllowAnyHeader().AllowAnyMethod()); + }); + + var app = builder.Build(); + app.UseCors(); + + var logger = app.Services.GetRequiredService>(); + logger.LogInformation("MCP HTTP server starting on port {Port} with CORS origins: {Origins}", + port, string.Join(", ", corsOrigins)); + + // Health check (unauthenticated — standard for K8s probes) + app.MapGet("/health", () => new { status = "healthy", timestamp = DateTime.UtcNow }); + + // List available tools — no auth required (tool definitions are not sensitive, + // and the sidebar needs to cache them without tenant context) + app.MapGet("/tools", () => + { + var tools = McpToolRegistry.GetToolDefinitions(); + return Results.Json(tools); + }); + + // Execute a tool (requires auth + tenant context) + app.MapPost("/tools/call", async (HttpContext context) => + { + if (!ExtractAuth(context, out var token, out var orgId, out var instanceUrl, out var error)) + return error!; + + // Parse request body + JsonElement body; + try + { + body = await JsonSerializer.DeserializeAsync(context.Request.Body); + } + catch (JsonException) + { + return Results.Json(new { error = new { message = "Invalid JSON" } }, statusCode: 400); + } + + if (!body.TryGetProperty("name", out var nameEl) || nameEl.GetString() is not { } toolName) + return Results.Json(new { error = new { message = "'name' field required" } }, statusCode: 400); + + var arguments = body.TryGetProperty("arguments", out var args) + ? args + : JsonSerializer.Deserialize("{}"); + + // Block config-mutating tools in HTTP mode + if (toolName is "login" or "login_direct" or "signup" or "logout" + or "config_use" or "config_remove" or "config_show" + or "accounts_use") + { + return Results.Json(new + { + error = new { message = $"Tool '{toolName}' is not available in HTTP mode." } + }, statusCode: 403); + } + + // Set per-request credentials and execute + McpClientFactory.SetRequestCredentials(orgId!, instanceUrl!, token!); + try + { + logger.LogInformation("Tool call: {ToolName} args={Args}", toolName, arguments.ToString()); + var result = await McpToolRegistry.ExecuteToolAsync(toolName, arguments, + context.RequestServices); + logger.LogInformation("Tool result: {ToolName} => {Result}", toolName, result.Length > 500 ? result[..500] + "..." : result); + + return Results.Json(new + { + result = new { content = new[] { new { type = "text", text = result } } } + }); + } + catch (Exception ex) + { + logger.LogError(ex, "Tool '{ToolName}' execution failed for org {OrgId}", toolName, orgId); + return Results.Json(new + { + error = new { message = "Tool execution failed. Check server logs for details." } + }, statusCode: 500); + } + finally + { + McpClientFactory.ClearRequestCredentials(); + } + }); + + app.Run($"http://0.0.0.0:{port}"); +} + +// ── Auth extraction helper ─────────────────────────────────────────────────── + +static bool ExtractAuth(HttpContext context, out string? token, out string? orgId, + out string? instanceUrl, out IResult? error) +{ + var authHeader = context.Request.Headers.Authorization.FirstOrDefault(); + orgId = context.Request.Headers["X-Org-Id"].FirstOrDefault(); + instanceUrl = context.Request.Headers["X-Instance-Url"].FirstOrDefault(); + + if (string.IsNullOrEmpty(authHeader)) + { + token = null; + error = Results.Json(new { error = "Authorization header required" }, statusCode: 401); + return false; + } + + // Proper Bearer token extraction + token = authHeader.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase) + ? authHeader[7..] + : authHeader; + + if (string.IsNullOrEmpty(orgId)) + { + error = Results.Json(new { error = "X-Org-Id header required" }, statusCode: 400); + return false; + } + if (string.IsNullOrEmpty(instanceUrl)) + { + error = Results.Json(new { error = "X-Instance-Url header required" }, statusCode: 400); + return false; + } + + error = null; + return true; +} // ── Helpers ────────────────────────────────────────────────────────────────── -static string? ResolveProfile(string[] args) +static string? ResolveFlag(string[] args, string flag, string? shortFlag = null) { for (var i = 0; i < args.Length - 1; i++) { - if (args[i] is "--profile" or "-p") + if (args[i] == flag || (shortFlag != null && args[i] == shortFlag)) return args[i + 1]; } return null; diff --git a/mcp/Tools/CliTool.cs b/mcp/Tools/CliTool.cs index d892a81..bb650f6 100644 --- a/mcp/Tools/CliTool.cs +++ b/mcp/Tools/CliTool.cs @@ -1,18 +1,20 @@ using System.ComponentModel; using System.Diagnostics; +using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.RegularExpressions; +using AnythinkCli.Client; +using AnythinkCli.Models; using ModelContextProtocol.Server; namespace AnythinkMcp.Tools; /// -/// Generic MCP tool that shells out to the Anythink CLI. -/// Covers every command the CLI supports — useful as a catch-all for commands -/// that don't have dedicated MCP tool wrappers (accounts, projects, users, -/// files, pay, oauth, migrate, fetch, etc.). +/// Generic MCP tool that covers every CLI command. /// -/// Requires the anythink CLI to be installed and available on PATH -/// (e.g. via dotnet tool install -g anythink-cli). +/// In stdio mode: shells out to the anythink CLI binary (as before). +/// In HTTP mode: routes commands through AnythinkClient in-process, +/// using per-request credentials from McpClientFactory. /// [McpServerToolType] public class CliTool @@ -20,9 +22,6 @@ public class CliTool private readonly McpClientFactory _factory; public CliTool(McpClientFactory factory) => _factory = factory; - // Allow only safe characters: alphanumeric, hyphens, underscores, dots, colons, - // slashes, spaces, equals, commas, braces, brackets, quotes, and @. - // Rejects shell metacharacters like ; | & $ ` \ ! ~ etc. private static readonly Regex SafeArgs = new( @"^[\w\s\-\./:=,@""'\{\}\[\]]+$", RegexOptions.Compiled); @@ -41,14 +40,420 @@ public async Task RunCli( "Do NOT include 'anythink' itself or '--profile' (profile is injected automatically).")] string command) { - // ── Input validation ──────────────────────────────────────────────────── if (string.IsNullOrWhiteSpace(command)) return "Error: command must not be empty."; if (!SafeArgs.IsMatch(command)) return "Error: command contains disallowed characters."; - // ── Build argument list (no shell involved — args passed directly) ────── + // HTTP mode: execute in-process via AnythinkClient + if (McpClientFactory.IsHttpMode) + return await ExecuteInProcess(command); + + // Stdio mode: shell out to the CLI binary + return await ExecuteViaProcess(command); + } + + // ── HTTP mode: in-process execution ────────────────────────────────── + + private async Task ExecuteInProcess(string command) + { + var args = SplitArgs(command); + if (args.Count == 0) return "Error: empty command."; + + var subcommand = args[0].ToLowerInvariant(); + var subArgs = args.Skip(1).ToList(); + var jsonMode = subArgs.Remove("--json"); + + try + { + var client = _factory.GetClient(); + + return subcommand switch + { + "entities" => await HandleEntities(client, subArgs, jsonMode), + "fields" => await HandleFields(client, subArgs, jsonMode), + "data" => await HandleData(client, subArgs, jsonMode), + "workflows" => await HandleWorkflows(client, subArgs, jsonMode), + "users" => await HandleUsers(client, subArgs, jsonMode), + "roles" => await HandleRoles(client, subArgs, jsonMode), + "secrets" => await HandleSecrets(client, subArgs, jsonMode), + "files" => await HandleFiles(client, subArgs, jsonMode), + "fetch" => await HandleFetch(client, subArgs), + "docs" => await HandleDocs(client, jsonMode), + _ => $"Command '{subcommand}' is not supported in HTTP mode." + }; + } + catch (Exception ex) + { + return $"Error: {ex.Message}"; + } + } + + private static async Task HandleEntities(AnythinkClient client, List args, bool json) + { + if (args.Count == 0 || args[0] == "list") + { + var entities = await client.GetEntitiesAsync(); + if (json) return Serialize(entities); + return FormatTable("Entities", entities.Select(e => new + { + e.Name, e.TableName, Fields = e.Fields?.Count ?? 0, + Public = e.IsPublic ? "yes" : "no", + RLS = e.EnableRls ? "yes" : "no" + })); + } + if (args[0] == "get" && args.Count > 1) + return Serialize(await client.GetEntityAsync(args[1])); + if (args[0] == "create" && args.Count > 1) + return Serialize(await client.CreateEntityAsync(new CreateEntityRequest(args[1]))); + if (args[0] == "delete" && args.Count > 1) + { + await client.DeleteEntityAsync(args[1]); + return $"Entity '{args[1]}' deleted."; + } + return "Usage: entities [list|get|create|delete] "; + } + + private static async Task HandleFields(AnythinkClient client, List args, bool json) + { + if (args.Count < 2) return "Usage: fields [list|add|delete] [field] [--type type]"; + + var action = args[0]; + var entity = args[1]; + + if (action == "list") + { + var fields = await client.GetFieldsAsync(entity); + if (json) return Serialize(fields); + return FormatTable($"Fields on {entity}", fields.Select(f => new + { + f.Name, f.DatabaseType, f.DisplayType, Required = f.IsRequired ? "yes" : "no", + Unique = f.IsUnique ? "yes" : "no" + })); + } + if ((action == "add" || action == "create") && args.Count > 2) + { + var fieldName = args[2]; + var dbType = GetFlag(args, "--type") ?? "varchar"; + var displayType = GetFlag(args, "--display") ?? dbType switch + { + "varchar" => "input", + "text" => "textarea", + "integer" or "bigint" => "input", + "decimal" => "input", + "boolean" => "checkbox", + "timestamp" => "timestamp", + "json" or "jsonb" => "json", + _ => "input" + }; + var required = args.Contains("--required"); + return Serialize(await client.AddFieldAsync(entity, + new CreateFieldRequest(fieldName, dbType, displayType, IsRequired: required))); + } + if (action == "delete" && args.Count > 2) + { + var fields = await client.GetFieldsAsync(entity); + var field = fields.FirstOrDefault(f => + string.Equals(f.Name, args[2], StringComparison.OrdinalIgnoreCase)); + if (field == null) return $"Field '{args[2]}' not found on '{entity}'."; + await client.DeleteFieldAsync(entity, field.Id); + return $"Field '{args[2]}' deleted from '{entity}'."; + } + return "Usage: fields [list|add|delete] [field] [--type type]"; + } + + private static async Task HandleData(AnythinkClient client, List args, bool json) + { + if (args.Count < 2) return "Usage: data [list|get|create|update|delete] [id]"; + + var action = args[0]; + var entity = args[1]; + + if (action == "list") + { + var pageSize = int.TryParse(GetFlag(args, "--limit"), out var l) ? l : 25; + var result = await client.ListItemsAsync(entity, pageSize: pageSize); + return json ? Serialize(result) : Serialize(result.Items); + } + if (action == "get" && args.Count > 2 && int.TryParse(args[2], out var getId)) + return Serialize(await client.GetItemAsync(entity, getId)); + if (action == "create" || action == "insert") + { + var dataJson = GetFlag(args, "--data") ?? GetFlag(args, "-d"); + if (dataJson == null) return "Usage: data create --data '{...}'"; + var obj = JsonNode.Parse(dataJson)?.AsObject() ?? new JsonObject(); + return Serialize(await client.CreateItemAsync(entity, obj)); + } + if (action == "update" && args.Count > 2 && int.TryParse(args[2], out var updateId)) + { + var dataJson = GetFlag(args, "--data") ?? GetFlag(args, "-d"); + if (dataJson == null) return "Usage: data update --data '{...}'"; + var obj = JsonNode.Parse(dataJson)?.AsObject() ?? new JsonObject(); + return Serialize(await client.UpdateItemAsync(entity, updateId, obj)); + } + if (action == "delete" && args.Count > 2 && int.TryParse(args[2], out var deleteId)) + { + await client.DeleteItemAsync(entity, deleteId); + return $"Record {deleteId} deleted from '{entity}'."; + } + return "Usage: data [list|get|create|update|delete] [id]"; + } + + private static async Task ResolveWorkflowId(AnythinkClient client, string idOrName) + { + if (int.TryParse(idOrName, out var id)) return id; + var all = await client.GetWorkflowsAsync(); + return all.FirstOrDefault(w => string.Equals(w.Name, idOrName, StringComparison.OrdinalIgnoreCase))?.Id; + } + + private static async Task HandleWorkflows(AnythinkClient client, List args, bool json) + { + if (args.Count == 0 || args[0] == "list") + { + var workflows = await client.GetWorkflowsAsync(); + if (json) return Serialize(workflows); + return FormatTable("Workflows", workflows.Select(w => new + { + w.Id, w.Name, w.Description, w.Enabled, w.Trigger + })); + } + if (args[0] == "get" && args.Count > 1) + { + var id = await ResolveWorkflowId(client, args[1]); + if (id == null) return $"Workflow '{args[1]}' not found."; + return Serialize(await client.GetWorkflowAsync(id.Value)); + } + if (args[0] == "create" && args.Count > 1) + { + var name = args[1]; + // Normalize trigger casing — API expects PascalCase + var rawTrigger = GetFlag(args, "--trigger") ?? "Manual"; + var trigger = rawTrigger.ToLowerInvariant() switch + { + "manual" => "Manual", + "timed" => "Timed", + "event" => "Event", + "api" => "Api", + _ => rawTrigger + }; + var description = GetFlag(args, "--description"); + var cron = GetFlag(args, "--cron"); + var eventType = GetFlag(args, "--event"); + var eventEntity = GetFlag(args, "--entity") ?? GetFlag(args, "--event-entity"); + var apiRoute = GetFlag(args, "--api-route"); + var enabled = args.Contains("--enabled"); + + object options = trigger switch + { + "Timed" => new { cron_expression = cron ?? "0 9 * * *", event_entity = eventEntity ?? "" }, + "Event" => (object)new EventWorkflowOptions(eventType ?? "EntityCreated", eventEntity ?? ""), + "Api" => new { api_route = apiRoute ?? "", event_entity = eventEntity ?? "" }, + _ => new { event_entity = eventEntity ?? "", manual_entities = eventEntity != null ? new[] { eventEntity } : Array.Empty() } + }; + + var wf = await client.CreateWorkflowAsync(new CreateWorkflowRequest( + name, description, trigger, enabled, options, trigger == "Api" ? apiRoute : null)); + return Serialize(wf); + } + if (args[0] == "update" && args.Count > 1 && (await ResolveWorkflowId(client, args[1])) is { } updateId) + { + var name = GetFlag(args, "--name"); + var description = GetFlag(args, "--description"); + var wf = await client.UpdateWorkflowAsync(updateId, new UpdateWorkflowRequest(name, description)); + return Serialize(wf); + } + if (args[0] == "delete" && args.Count > 1 && (await ResolveWorkflowId(client, args[1])) is { } deleteId) + { + await client.DeleteWorkflowAsync(deleteId); + return $"Workflow {deleteId} deleted."; + } + if (args[0] == "trigger" && args.Count > 1 && (await ResolveWorkflowId(client, args[1])) is { } triggerId) + { + var payloadStr = GetFlag(args, "--payload"); + object? payload = null; + if (payloadStr != null) + payload = new { data = JsonNode.Parse(payloadStr) }; + await client.TriggerWorkflowAsync(triggerId, payload); + return $"Workflow {triggerId} triggered."; + } + if (args[0] == "enable" && args.Count > 1 && (await ResolveWorkflowId(client, args[1])) is { } enableId) + { + await client.EnableWorkflowAsync(enableId); + return $"Workflow {enableId} enabled."; + } + if (args[0] == "disable" && args.Count > 1 && (await ResolveWorkflowId(client, args[1])) is { } disableId) + { + await client.DisableWorkflowAsync(disableId); + return $"Workflow {disableId} disabled."; + } + if (args[0] == "jobs" && args.Count > 1 && (await ResolveWorkflowId(client, args[1])) is { } jobsWfId) + { + var result = await client.GetWorkflowJobsAsync(jobsWfId); + return Serialize(result); + } + if (args[0] == "step-add" && args.Count > 2 && int.TryParse(args[1], out var stepAddWfId)) + { + var key = args[2]; + var stepName = GetFlag(args, "--name") ?? key; + var action = GetFlag(args, "--action") ?? "RunScript"; + var paramsJson = GetFlag(args, "--params"); + var isStart = args.Contains("--start"); + var stepEnabled = !args.Contains("--disabled"); + + JsonElement? parameters = null; + if (paramsJson != null) + parameters = JsonSerializer.Deserialize(paramsJson); + + var step = await client.AddWorkflowStepAsync(stepAddWfId, + new CreateWorkflowStepRequest(key, stepName, action, stepEnabled, isStart, null, parameters)); + return Serialize(step); + } + if (args[0] == "step-get" && args.Count > 2 && int.TryParse(args[1], out var stepGetWfId) && int.TryParse(args[2], out var stepGetId)) + { + var wf = await client.GetWorkflowAsync(stepGetWfId); + var step = wf.Steps?.FirstOrDefault(s => s.Id == stepGetId); + if (step == null) return $"Step {stepGetId} not found in workflow {stepGetWfId}."; + return Serialize(step); + } + if (args[0] == "step-update" && args.Count > 2 && int.TryParse(args[1], out var stepUpdWfId) && int.TryParse(args[2], out var stepUpdId)) + { + // Fetch current step to preserve values + var wf = await client.GetWorkflowAsync(stepUpdWfId); + var step = wf.Steps?.FirstOrDefault(s => s.Id == stepUpdId); + if (step == null) return $"Step {stepUpdId} not found in workflow {stepUpdWfId}."; + + var body = new Dictionary + { + ["name"] = GetFlag(args, "--name") ?? step.Name, + ["action"] = GetFlag(args, "--action") ?? step.Action, + ["enabled"] = args.Contains("--disabled") ? false : step.Enabled, + ["is_start_step"] = args.Contains("--start") || step.IsStartStep, + ["on_success_step_id"] = step.OnSuccessStepId, + ["on_failure_step_id"] = step.OnFailureStepId, + }; + + var paramsJson = GetFlag(args, "--params"); + if (paramsJson != null) + body["parameters"] = JsonSerializer.Deserialize(paramsJson); + else if (step.Parameters.HasValue) + body["parameters"] = step.Parameters.Value; + + var updated = await client.UpdateWorkflowStepFullAsync(stepUpdWfId, stepUpdId, body); + return Serialize(updated); + } + if (args[0] == "step-link" && args.Count > 2 && int.TryParse(args[1], out var linkWfId) && int.TryParse(args[2], out var linkStepId)) + { + var wf = await client.GetWorkflowAsync(linkWfId); + var step = wf.Steps?.FirstOrDefault(s => s.Id == linkStepId); + if (step == null) return $"Step {linkStepId} not found in workflow {linkWfId}."; + + var onSuccess = GetFlag(args, "--on-success"); + var onFailure = GetFlag(args, "--on-failure"); + + var body = new Dictionary + { + ["name"] = step.Name, + ["action"] = step.Action, + ["enabled"] = step.Enabled, + ["is_start_step"] = step.IsStartStep, + ["on_success_step_id"] = onSuccess != null && int.TryParse(onSuccess, out var sId) ? sId : step.OnSuccessStepId, + ["on_failure_step_id"] = onFailure != null && int.TryParse(onFailure, out var fId) ? fId : step.OnFailureStepId, + }; + if (step.Parameters.HasValue) body["parameters"] = step.Parameters.Value; + + var updated = await client.UpdateWorkflowStepFullAsync(linkWfId, linkStepId, body); + return Serialize(updated); + } + return "Usage: workflows [list|get|create|update|delete|trigger|enable|disable|jobs|step-add|step-get|step-update|step-link] [id] [options]"; + } + + private static async Task HandleUsers(AnythinkClient client, List args, bool json) + { + if (args.Count == 0 || args[0] == "list") + return Serialize(await client.GetUsersAsync()); + if (args[0] == "me") + return Serialize(await client.GetMeAsync()); + if (args[0] == "get" && args.Count > 1 && int.TryParse(args[1], out var userId)) + return Serialize(await client.GetUserAsync(userId)); + return "Usage: users [list|me|get] [id]"; + } + + private static async Task HandleRoles(AnythinkClient client, List args, bool json) + { + if (args.Count == 0 || args[0] == "list") + return Serialize(await client.GetRolesAsync()); + if (args[0] == "get" && args.Count > 1 && int.TryParse(args[1], out var roleId)) + return Serialize(await client.GetRoleAsync(roleId)); + return "Usage: roles [list|get] [id]"; + } + + private static async Task HandleSecrets(AnythinkClient client, List args, bool json) + { + if (args.Count == 0 || args[0] == "list") + return Serialize(await client.GetSecretsAsync()); + if (args[0] == "create" && args.Count > 1) + { + var key = args[1]; + var value = GetFlag(args, "--value") ?? ""; + return Serialize(await client.CreateSecretAsync(new CreateSecretRequest(key, value))); + } + if (args[0] == "delete" && args.Count > 1) + { + await client.DeleteSecretAsync(args[1]); + return $"Secret '{args[1]}' deleted."; + } + return "Usage: secrets [list|create|delete] [key] [--value val]"; + } + + private static async Task HandleFiles(AnythinkClient client, List args, bool json) + { + if (args.Count == 0 || args[0] == "list") + return Serialize(await client.GetFilesAsync()); + if (args[0] == "get" && args.Count > 1 && int.TryParse(args[1], out var fileId)) + return Serialize(await client.GetFileAsync(fileId)); + if (args[0] == "delete" && args.Count > 1 && int.TryParse(args[1], out var delId)) + { + await client.DeleteFileAsync(delId); + return $"File {delId} deleted."; + } + return "Usage: files [list|get|delete] [id]"; + } + + private static async Task HandleFetch(AnythinkClient client, List args) + { + if (args.Count == 0) return "Usage: fetch [METHOD] [--data '{...}'] or fetch [--method METHOD] [--data '{...}']"; + + // Support both: `fetch POST /path -d '{}'` and `fetch /path --method POST --data '{}'` + var method = "GET"; + var path = args[0]; + if (args.Count > 1 && args[0] is "GET" or "POST" or "PUT" or "PATCH" or "DELETE") + { + method = args[0]; + path = args[1]; + args = args.Skip(2).ToList(); + } + else + { + args = args.Skip(1).ToList(); + method = GetFlag(args, "--method") ?? "GET"; + } + var body = GetFlag(args, "--data") ?? GetFlag(args, "-d"); + // Ensure path is a full URL — prepend the tenant base URL if it's a relative path + if (path.StartsWith('/')) + path = $"{client.BaseUrl}/org/{client.OrgId}{path}"; + return await client.FetchRawAsync(path, method, body); + } + + private static async Task HandleDocs(AnythinkClient client, bool json) + { + return await client.FetchRawAsync("/docs", "GET"); + } + + // ── Stdio mode: subprocess execution (unchanged) ───────────────────── + + private async Task ExecuteViaProcess(string command) + { var psi = new ProcessStartInfo { FileName = "anythink", @@ -56,11 +461,9 @@ public async Task RunCli( RedirectStandardError = true, UseShellExecute = false, CreateNoWindow = true, - // Suppress Spectre.Console ANSI sequences for cleaner output. Environment = { ["NO_COLOR"] = "1", ["TERM"] = "dumb" } }; - // Inject --profile if the MCP server was started with one. var profile = _factory.ProfileName; if (!string.IsNullOrEmpty(profile)) { @@ -68,7 +471,6 @@ public async Task RunCli( psi.ArgumentList.Add(profile); } - // Split the user command into individual arguments. foreach (var arg in SplitArgs(command)) psi.ArgumentList.Add(arg); @@ -93,11 +495,35 @@ public async Task RunCli( } } - /// - /// Splits a command string into arguments, respecting quoted strings. - /// E.g. data create posts --data '{"title":"Hello"}' → - /// ["data", "create", "posts", "--data", "{\"title\":\"Hello\"}"] - /// + // ── Helpers ────────────────────────────────────────────────────────── + + private static string Serialize(object? obj) => + JsonSerializer.Serialize(obj, new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + WriteIndented = true + }); + + private static string FormatTable(string title, IEnumerable rows) + { + var items = rows.ToList(); + if (items.Count == 0) return $"{title}: (none)"; + return $"{title} ({items.Count}):\n{Serialize(items)}"; + } + + private static string? GetFlag(List args, string flag) + { + var idx = args.IndexOf(flag); + if (idx >= 0 && idx < args.Count - 1) + { + var value = args[idx + 1]; + args.RemoveAt(idx + 1); + args.RemoveAt(idx); + return value; + } + return null; + } + internal static List SplitArgs(string input) { var args = new List(); diff --git a/tests-mcp/AnythinkMcp.Tests.csproj b/tests-mcp/AnythinkMcp.Tests.csproj new file mode 100644 index 0000000..2ee43bc --- /dev/null +++ b/tests-mcp/AnythinkMcp.Tests.csproj @@ -0,0 +1,31 @@ + + + + net8.0 + AnythinkMcp.Tests + enable + enable + false + true + + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + diff --git a/tests-mcp/BlockedToolsTests.cs b/tests-mcp/BlockedToolsTests.cs new file mode 100644 index 0000000..04ee44a --- /dev/null +++ b/tests-mcp/BlockedToolsTests.cs @@ -0,0 +1,46 @@ +using FluentAssertions; + +namespace AnythinkMcp.Tests; + +/// +/// Tests that credential-mutating tools are correctly identified for blocking in HTTP mode. +/// The actual blocking happens in Program.cs, but we verify the blocked list is correct. +/// +public class BlockedToolsTests +{ + // This list must match the blocked tools in mcp/Program.cs + private static readonly HashSet BlockedInHttpMode = new() + { + "login", "login_direct", "signup", "logout", + "config_use", "config_remove", "config_show", + "accounts_use" + }; + + [Theory] + [InlineData("login")] + [InlineData("login_direct")] + [InlineData("signup")] + [InlineData("logout")] + [InlineData("config_use")] + [InlineData("config_remove")] + [InlineData("config_show")] + [InlineData("accounts_use")] + public void CredentialMutatingTools_ShouldBeBlocked(string toolName) + { + BlockedInHttpMode.Should().Contain(toolName, + $"'{toolName}' mutates credentials and must be blocked in HTTP mode"); + } + + [Theory] + [InlineData("cli")] + [InlineData("list_entities")] + [InlineData("get_entity")] + [InlineData("list_data")] + [InlineData("create_data")] + [InlineData("fetch")] + public void DataReadTools_ShouldNotBeBlocked(string toolName) + { + BlockedInHttpMode.Should().NotContain(toolName, + $"'{toolName}' is a data tool and should be allowed in HTTP mode"); + } +} diff --git a/tests-mcp/GlobalUsings.cs b/tests-mcp/GlobalUsings.cs new file mode 100644 index 0000000..c802f44 --- /dev/null +++ b/tests-mcp/GlobalUsings.cs @@ -0,0 +1 @@ +global using Xunit; diff --git a/tests-mcp/McpClientFactoryTests.cs b/tests-mcp/McpClientFactoryTests.cs new file mode 100644 index 0000000..e139060 --- /dev/null +++ b/tests-mcp/McpClientFactoryTests.cs @@ -0,0 +1,86 @@ +using FluentAssertions; + +namespace AnythinkMcp.Tests; + +/// +/// Tests for per-request credential isolation in HTTP mode. +/// +public class McpClientFactoryTests : IDisposable +{ + public McpClientFactoryTests() + { + // Ensure clean state before each test + McpClientFactory.ClearRequestCredentials(); + } + + public void Dispose() + { + McpClientFactory.ClearRequestCredentials(); + } + + [Fact] + public void IsHttpMode_ShouldBeFalse_WhenNoCredentialsSet() + { + McpClientFactory.IsHttpMode.Should().BeFalse(); + } + + [Fact] + public void IsHttpMode_ShouldBeTrue_WhenCredentialsSet() + { + McpClientFactory.SetRequestCredentials("123", "https://api.example.com", "token"); + + McpClientFactory.IsHttpMode.Should().BeTrue(); + } + + [Fact] + public void ClearRequestCredentials_ShouldResetHttpMode() + { + McpClientFactory.SetRequestCredentials("123", "https://api.example.com", "token"); + McpClientFactory.ClearRequestCredentials(); + + McpClientFactory.IsHttpMode.Should().BeFalse(); + } + + [Fact] + public void GetClient_ShouldUsePerRequestCredentials_InHttpMode() + { + McpClientFactory.SetRequestCredentials("42", "https://api.test.com", "my-token"); + var factory = new McpClientFactory(); + + var client = factory.GetClient(); + + client.OrgId.Should().Be("42"); + client.BaseUrl.Should().Be("https://api.test.com"); + } + + [Fact] + public async Task PerRequestCredentials_ShouldBeIsolated_AcrossTasks() + { + // Simulate two concurrent requests with different credentials + var task1OrgId = ""; + var task2OrgId = ""; + + var task1 = Task.Run(() => + { + McpClientFactory.SetRequestCredentials("org-1", "https://api1.com", "token1"); + Thread.Sleep(50); // Give task2 time to set its own credentials + var factory = new McpClientFactory(); + task1OrgId = factory.GetClient().OrgId; + McpClientFactory.ClearRequestCredentials(); + }); + + var task2 = Task.Run(() => + { + Thread.Sleep(10); // Start slightly after task1 + McpClientFactory.SetRequestCredentials("org-2", "https://api2.com", "token2"); + var factory = new McpClientFactory(); + task2OrgId = factory.GetClient().OrgId; + McpClientFactory.ClearRequestCredentials(); + }); + + await Task.WhenAll(task1, task2); + + task1OrgId.Should().Be("org-1", "task1 should see its own credentials"); + task2OrgId.Should().Be("org-2", "task2 should see its own credentials"); + } +} diff --git a/tests-mcp/McpToolRegistryTests.cs b/tests-mcp/McpToolRegistryTests.cs new file mode 100644 index 0000000..9b72e99 --- /dev/null +++ b/tests-mcp/McpToolRegistryTests.cs @@ -0,0 +1,56 @@ +using FluentAssertions; + +namespace AnythinkMcp.Tests; + +/// +/// Tests for tool discovery and definition generation. +/// +public class McpToolRegistryTests +{ + [Fact] + public void GetToolDefinitions_ShouldReturnNonEmptyList() + { + var tools = McpToolRegistry.GetToolDefinitions(); + + tools.Should().NotBeEmpty("MCP server should expose at least one tool"); + } + + [Fact] + public void GetToolDefinitions_ShouldIncludeCliTool() + { + var tools = McpToolRegistry.GetToolDefinitions(); + var json = System.Text.Json.JsonSerializer.Serialize(tools); + + json.Should().Contain("\"cli\"", "the catch-all CLI tool should be registered"); + } + + [Fact] + public void GetToolDefinitions_ShouldIncludeCoreTools() + { + var tools = McpToolRegistry.GetToolDefinitions(); + var json = System.Text.Json.JsonSerializer.Serialize(tools); + + json.Should().Contain("\"cli\"", "catch-all CLI tool should be registered"); + json.Should().Contain("\"projects_list\"", "project management tools should be registered"); + json.Should().Contain("\"login\"", "auth tools should be registered"); + } + + [Fact] + public void GetToolDefinitions_ShouldHaveInputSchemas() + { + var tools = McpToolRegistry.GetToolDefinitions(); + var json = System.Text.Json.JsonSerializer.Serialize(tools); + + json.Should().Contain("input_schema", "every tool should have an input schema"); + json.Should().Contain("\"type\":\"object\"", "schemas should be object type"); + } + + [Fact] + public void GetToolDefinitions_ShouldHaveDescriptions() + { + var tools = McpToolRegistry.GetToolDefinitions(); + var json = System.Text.Json.JsonSerializer.Serialize(tools); + + json.Should().Contain("\"description\"", "every tool should have a description"); + } +} diff --git a/tests-mcp/SafeArgsTests.cs b/tests-mcp/SafeArgsTests.cs new file mode 100644 index 0000000..859501c --- /dev/null +++ b/tests-mcp/SafeArgsTests.cs @@ -0,0 +1,48 @@ +using System.Text.RegularExpressions; +using FluentAssertions; + +namespace AnythinkMcp.Tests; + +/// +/// Tests for the SafeArgs regex that validates CLI command input. +/// This is the primary defence against command injection. +/// +public class SafeArgsTests +{ + // Mirror the regex from CliTool.cs + private static readonly Regex SafeArgs = new( + @"^[\w\s\-\./:=,@""'\{\}\[\]]+$", RegexOptions.Compiled); + + [Theory] + [InlineData("entities list")] + [InlineData("data list blog_posts --json")] + [InlineData("users me")] + [InlineData("fetch /api/v1/health")] + [InlineData("data create posts {\"title\": \"Hello World\"}")] + [InlineData("workflows trigger 1 --data '{\"key\": \"value\"}'")] + [InlineData("migrate --from source --to target --dry-run")] + public void SafeArgs_ShouldAllow_ValidCommands(string command) + { + SafeArgs.IsMatch(command).Should().BeTrue($"'{command}' should be allowed"); + } + + [Theory] + [InlineData("entities list; rm -rf /")] + [InlineData("data list | cat /etc/passwd")] + [InlineData("users me && curl evil.com")] + [InlineData("entities list $(whoami)")] + [InlineData("data list `id`")] + [InlineData("entities list\t&& rm -rf /")] + [InlineData("entities list > /tmp/output")] + [InlineData("entities list < /etc/passwd")] + public void SafeArgs_ShouldReject_InjectionAttempts(string command) + { + SafeArgs.IsMatch(command).Should().BeFalse($"'{command}' should be rejected"); + } + + [Fact] + public void SafeArgs_ShouldReject_EmptyString() + { + SafeArgs.IsMatch("").Should().BeFalse(); + } +}