diff --git a/cmd/cli/cmd/chat.go b/cmd/cli/cmd/chat.go index 4b589a2..62b2be7 100644 --- a/cmd/cli/cmd/chat.go +++ b/cmd/cli/cmd/chat.go @@ -1,12 +1,13 @@ package cmd import ( + "errors" "fmt" "os" "strings" - "github.com/nullify-platform/cli/internal/auth" "github.com/nullify-platform/cli/internal/chat" + "github.com/nullify-platform/cli/internal/lib" "github.com/nullify-platform/logger/pkg/logger" "github.com/spf13/cobra" ) @@ -27,31 +28,21 @@ Examples: ctx := setupLogger(cmd.Context()) defer logger.L(ctx).Sync() - chatHost := resolveHost(ctx) - - token, err := auth.GetValidToken(ctx, chatHost) + authCtx, err := resolveCommandAuth(ctx) if err != nil { - fmt.Fprintf(os.Stderr, "Error: not authenticated. Run 'nullify auth login' first.\n") + if errors.Is(err, lib.ErrNoToken) { + fmt.Fprintf(os.Stderr, "Error: not authenticated. Run 'nullify auth login' first.\n") + } else { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + } os.Exit(ExitAuthError) } - creds, err := auth.LoadCredentials() - if err != nil { - fmt.Fprintf(os.Stderr, "Error: failed to load credentials: %v\n", err) - os.Exit(1) - } - - hostCreds := creds[auth.CredentialKey(chatHost)] - queryParams := hostCreds.QueryParameters - if queryParams == nil { - queryParams = make(map[string]string) - } - // Connect via WebSocket - conn, err := chat.Dial(ctx, chatHost, token) + conn, err := chat.Dial(ctx, authCtx.Host, authCtx.Token) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) + os.Exit(ExitNetworkError) } // Build client options @@ -67,7 +58,7 @@ Examples: opts = append(opts, chat.WithSystemPrompt(systemPrompt)) } - client := chat.NewClient(conn, queryParams, opts...) + client := chat.NewClient(conn, authCtx.QueryParams, opts...) defer client.Close() if len(args) > 0 { diff --git a/cmd/cli/cmd/ci.go b/cmd/cli/cmd/ci.go index 53ad6f7..1bf2a29 100644 --- a/cmd/cli/cmd/ci.go +++ b/cmd/cli/cmd/ci.go @@ -2,6 +2,7 @@ package cmd import ( "encoding/json" + "errors" "fmt" "os" "sync" @@ -47,7 +48,11 @@ Exit codes: ciHost := resolveHost(ctx) token, err := lib.GetNullifyToken(ctx, ciHost, nullifyToken, githubToken) if err != nil { - fmt.Fprintf(os.Stderr, "Error: not authenticated\n") + if errors.Is(err, lib.ErrNoToken) { + fmt.Fprintf(os.Stderr, "Error: not authenticated. Run 'nullify auth login' first.\n") + } else { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + } os.Exit(ExitAuthError) } @@ -159,7 +164,11 @@ var ciReportCmd = &cobra.Command{ ciHost := resolveHost(ctx) token, err := lib.GetNullifyToken(ctx, ciHost, nullifyToken, githubToken) if err != nil { - fmt.Fprintf(os.Stderr, "Error: not authenticated\n") + if errors.Is(err, lib.ErrNoToken) { + fmt.Fprintf(os.Stderr, "Error: not authenticated. Run 'nullify auth login' first.\n") + } else { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + } os.Exit(ExitAuthError) } @@ -189,6 +198,9 @@ var ciReportCmd = &cobra.Command{ rows := make([]reportRow, len(endpoints)*len(severities)) g, gctx := errgroup.WithContext(ctx) + var successCount int64 + var apiErrors int64 + var mu sync.Mutex for i, ep := range endpoints { for j, sev := range severities { @@ -202,9 +214,13 @@ var ciReportCmd = &cobra.Command{ body, err := lib.DoGet(gctx, nullifyClient.HttpClient, nullifyClient.BaseURL, ep.path+qs) if err != nil { + atomic.AddInt64(&apiErrors, 1) + mu.Lock() fmt.Fprintf(os.Stderr, "Warning: failed to query %s (%s): %v\n", ep.name, sev, err) + mu.Unlock() return nil } + atomic.AddInt64(&successCount, 1) rows[i*len(severities)+j] = reportRow{ scanner: ep.name, @@ -218,6 +234,14 @@ var ciReportCmd = &cobra.Command{ _ = g.Wait() + if successCount == 0 { + fmt.Fprintln(os.Stderr, "Error: all API requests failed, cannot generate report") + os.Exit(ExitNetworkError) + } + if apiErrors > 0 { + fmt.Fprintf(os.Stderr, "Warning: %d API requests failed while generating the report\n", apiErrors) + } + fmt.Println("## Nullify Security Report") fmt.Println() fmt.Println("| Scanner | Severity | Count |") diff --git a/cmd/cli/cmd/findings.go b/cmd/cli/cmd/findings.go index c74589e..77e0a63 100644 --- a/cmd/cli/cmd/findings.go +++ b/cmd/cli/cmd/findings.go @@ -110,6 +110,24 @@ Auto-detects the current repository from git if --repo is not specified.`, _ = g.Wait() + successCount := 0 + errorCount := 0 + for _, result := range results { + if result.Error != "" { + errorCount++ + continue + } + successCount++ + } + + if successCount == 0 { + fmt.Fprintln(os.Stderr, "Error: all scanner requests failed") + os.Exit(ExitNetworkError) + } + if errorCount > 0 { + fmt.Fprintf(os.Stderr, "Warning: %d/%d scanner requests failed\n", errorCount, len(results)) + } + out, _ := json.MarshalIndent(results, "", " ") if err := output.Print(cmd, out); err != nil { fmt.Fprintln(os.Stderr, string(out)) diff --git a/cmd/cli/cmd/mcp.go b/cmd/cli/cmd/mcp.go index 0f75ea8..efbf668 100644 --- a/cmd/cli/cmd/mcp.go +++ b/cmd/cli/cmd/mcp.go @@ -1,12 +1,12 @@ package cmd import ( + "errors" "fmt" "os" "strings" - "github.com/nullify-platform/cli/internal/auth" "github.com/nullify-platform/cli/internal/client" "github.com/nullify-platform/cli/internal/lib" "github.com/nullify-platform/cli/internal/mcp" @@ -28,26 +28,17 @@ var mcpServeCmd = &cobra.Command{ ctx := setupLogger(cmd.Context()) defer logger.L(ctx).Sync() - mcpHost := resolveHost(ctx) - - // Validate that we have a working token before starting the server - if _, err := auth.GetValidToken(ctx, mcpHost); err != nil { - fmt.Fprintf(os.Stderr, "Error: not authenticated. Run 'nullify auth login' first.\n") - os.Exit(ExitAuthError) - } - - creds, err := auth.LoadCredentials() + authCtx, err := resolveCommandAuth(ctx) if err != nil { - fmt.Fprintf(os.Stderr, "Error: failed to load credentials: %v\n", err) - os.Exit(1) + if errors.Is(err, lib.ErrNoToken) { + fmt.Fprintf(os.Stderr, "Error: not authenticated. Run 'nullify auth login' first.\n") + } else { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + } + os.Exit(ExitAuthError) } - hostCreds := creds[auth.CredentialKey(mcpHost)] - - queryParams := hostCreds.QueryParameters - if queryParams == nil { - queryParams = make(map[string]string) - } + queryParams := authCtx.QueryParams // Apply --repo flag or auto-detect from git repoFlag, _ := cmd.Flags().GetString("repo") @@ -74,9 +65,9 @@ var mcpServeCmd = &cobra.Command{ // Create a refreshing client for long-running MCP sessions tokenProvider := func() (string, error) { - return auth.GetValidToken(ctx, mcpHost) + return lib.GetNullifyToken(ctx, authCtx.Host, nullifyToken, githubToken) } - nullifyClient, clientErr := client.NewRefreshingNullifyClient(mcpHost, tokenProvider) + nullifyClient, clientErr := client.NewRefreshingNullifyClient(authCtx.Host, tokenProvider) if clientErr != nil { fmt.Fprintf(os.Stderr, "Error: failed to create client: %v\n", clientErr) os.Exit(1) diff --git a/cmd/cli/cmd/pentest.go b/cmd/cli/cmd/pentest.go index 98b85f5..e153359 100644 --- a/cmd/cli/cmd/pentest.go +++ b/cmd/cli/cmd/pentest.go @@ -37,7 +37,7 @@ func init() { pentestCmd.Flags().String("spec-path", "", "The file path to the OpenAPI file (yaml or json)") _ = pentestCmd.MarkFlagRequired("spec-path") pentestCmd.Flags().String("target-host", "", "The base URL of the API to be scanned") - pentestCmd.Flags().StringSlice("header", nil, "Headers for the pentest agent to authenticate with your API") + pentestCmd.Flags().StringSlice("header", nil, "Header for the pentest agent to authenticate with your API. Repeat the flag for multiple headers.") pentestCmd.Flags().String("github-owner", "", "The GitHub username or organisation") pentestCmd.Flags().String("github-repo", "", "The repository name for the Nullify issue dashboard") diff --git a/cmd/cli/cmd/root.go b/cmd/cli/cmd/root.go index 842f46f..6d5e946 100644 --- a/cmd/cli/cmd/root.go +++ b/cmd/cli/cmd/root.go @@ -39,6 +39,10 @@ var rootCmd = &cobra.Command{ if cmd.Name() == "login" || cmd.Name() == "completion" { return } + + if noColor || os.Getenv("NO_COLOR") != "" { + _ = os.Setenv("NO_COLOR", "1") + } }, } diff --git a/cmd/cli/cmd/runtime_auth.go b/cmd/cli/cmd/runtime_auth.go new file mode 100644 index 0000000..54a41ef --- /dev/null +++ b/cmd/cli/cmd/runtime_auth.go @@ -0,0 +1,39 @@ +package cmd + +import ( + "context" + + "github.com/nullify-platform/cli/internal/auth" + "github.com/nullify-platform/cli/internal/lib" +) + +type commandAuthContext struct { + Host string + Token string + QueryParams map[string]string +} + +func resolveCommandAuth(ctx context.Context) (*commandAuthContext, error) { + commandHost := resolveHost(ctx) + + token, err := lib.GetNullifyToken(ctx, commandHost, nullifyToken, githubToken) + if err != nil { + return nil, err + } + + queryParams := map[string]string{} + creds, err := auth.LoadCredentials() + if err == nil { + if hostCreds, ok := creds[auth.CredentialKey(commandHost)]; ok && hostCreds.QueryParameters != nil { + for key, value := range hostCreds.QueryParameters { + queryParams[key] = value + } + } + } + + return &commandAuthContext{ + Host: commandHost, + Token: token, + QueryParams: queryParams, + }, nil +} diff --git a/cmd/cli/cmd/runtime_auth_test.go b/cmd/cli/cmd/runtime_auth_test.go new file mode 100644 index 0000000..42181d9 --- /dev/null +++ b/cmd/cli/cmd/runtime_auth_test.go @@ -0,0 +1,72 @@ +package cmd + +import ( + "context" + "testing" + "time" + + "github.com/nullify-platform/cli/internal/auth" + "github.com/stretchr/testify/require" +) + +func TestResolveCommandAuthUsesEnvTokenWithoutStoredCredentials(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("NULLIFY_HOST", "acme.nullify.ai") + t.Setenv("NULLIFY_TOKEN", "env-token") + + originalHost := host + originalNullifyToken := nullifyToken + originalGithubToken := githubToken + host = "" + nullifyToken = "" + githubToken = "" + t.Cleanup(func() { + host = originalHost + nullifyToken = originalNullifyToken + githubToken = originalGithubToken + }) + + authCtx, err := resolveCommandAuth(setupLogger(context.Background())) + require.NoError(t, err) + require.Equal(t, "acme.nullify.ai", authCtx.Host) + require.Equal(t, "env-token", authCtx.Token) + require.Empty(t, authCtx.QueryParams) +} + +func TestResolveCommandAuthClonesStoredQueryParams(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("NULLIFY_HOST", "acme.nullify.ai") + t.Setenv("NULLIFY_TOKEN", "env-token") + + err := auth.SaveHostCredentials("acme.nullify.ai", auth.HostCredentials{ + AccessToken: "stored-token", + RefreshToken: "refresh-token", + ExpiresAt: time.Now().Add(time.Hour).Unix(), + QueryParameters: map[string]string{ + "githubOwnerId": "123", + }, + }) + require.NoError(t, err) + + originalHost := host + originalNullifyToken := nullifyToken + originalGithubToken := githubToken + host = "" + nullifyToken = "" + githubToken = "" + t.Cleanup(func() { + host = originalHost + nullifyToken = originalNullifyToken + githubToken = originalGithubToken + }) + + authCtx, err := resolveCommandAuth(setupLogger(context.Background())) + require.NoError(t, err) + require.Equal(t, map[string]string{"githubOwnerId": "123"}, authCtx.QueryParams) + + authCtx.QueryParams["githubOwnerId"] = "456" + + creds, err := auth.LoadCredentials() + require.NoError(t, err) + require.Equal(t, "123", creds[auth.CredentialKey("acme.nullify.ai")].QueryParameters["githubOwnerId"]) +} diff --git a/cmd/cli/cmd/status.go b/cmd/cli/cmd/status.go index 55f1277..eebcbe7 100644 --- a/cmd/cli/cmd/status.go +++ b/cmd/cli/cmd/status.go @@ -2,13 +2,17 @@ package cmd import ( "encoding/json" + "errors" "fmt" "os" + "sort" "strings" + "text/tabwriter" "github.com/nullify-platform/cli/internal/auth" "github.com/nullify-platform/cli/internal/client" "github.com/nullify-platform/cli/internal/lib" + "github.com/nullify-platform/cli/internal/output" "github.com/nullify-platform/logger/pkg/logger" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" @@ -27,7 +31,11 @@ var securityStatusCmd = &cobra.Command{ statusHost := resolveHost(ctx) token, err := lib.GetNullifyToken(ctx, statusHost, nullifyToken, githubToken) if err != nil { - fmt.Fprintf(os.Stderr, "Error: not authenticated. Run 'nullify auth login' first.\n") + if errors.Is(err, lib.ErrNoToken) { + fmt.Fprintf(os.Stderr, "Error: not authenticated. Run 'nullify auth login' first.\n") + } else { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + } os.Exit(ExitAuthError) } @@ -50,28 +58,17 @@ var securityStatusCmd = &cobra.Command{ os.Exit(ExitNetworkError) } - // Try to pretty-print var overview any - if err := json.Unmarshal([]byte(overviewBody), &overview); err == nil { - pretty, _ := json.MarshalIndent(overview, "", " ") - fmt.Println("Security Posture Overview") - fmt.Println("========================") - fmt.Println(string(pretty)) - } else { - fmt.Println(overviewBody) + if err := json.Unmarshal([]byte(overviewBody), &overview); err != nil { + overview = map[string]any{"raw": overviewBody} } - // Fetch individual scanner postures scanners := allScannerEndpoints() - fmt.Println("\nFindings by Scanner") - fmt.Println("===================") - fmt.Printf("%-20s %s\n", "Scanner", "Status") - fmt.Printf("%-20s %s\n", "-------", "------") - type scannerResult struct { name string summary string + err string } results := make([]scannerResult, len(scanners)) g, gctx := errgroup.WithContext(ctx) @@ -82,7 +79,7 @@ var securityStatusCmd = &cobra.Command{ scannerQS := lib.BuildQueryString(queryParams, "limit", "1") body, err := lib.DoGet(gctx, nullifyClient.HttpClient, nullifyClient.BaseURL, scanner.path+scannerQS) if err != nil { - results[i] = scannerResult{name: scanner.name, summary: fmt.Sprintf("error: %v", err)} + results[i] = scannerResult{name: scanner.name, err: err.Error()} } else { results[i] = scannerResult{name: scanner.name, summary: summarizeFindingsResponse(body)} } @@ -92,8 +89,36 @@ var securityStatusCmd = &cobra.Command{ _ = g.Wait() - for _, r := range results { - fmt.Printf("%-20s %s\n", r.name, r.summary) + statusOutput := securityStatusOutput{ + Overview: overview, + Scanners: make([]scannerStatusOutput, 0, len(results)), + } + for _, result := range results { + statusOutput.Scanners = append(statusOutput.Scanners, scannerStatusOutput{ + Name: result.name, + Summary: result.summary, + Error: result.err, + }) + } + + format, _ := cmd.Flags().GetString("output") + outputExplicit := cmd.Flags().Lookup("output").Changed + + if format == "table" || !outputExplicit { + if err := printStatusTable(statusOutput); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + return + } + + out, err := json.Marshal(statusOutput) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: failed to encode status output: %v\n", err) + os.Exit(1) + } + if err := output.Print(cmd, out); err != nil { + fmt.Fprintln(os.Stderr, string(out)) } }, } @@ -132,3 +157,73 @@ func summarizeFindingsResponse(body string) string { return "data available" } + +type securityStatusOutput struct { + Overview any `json:"overview"` + Scanners []scannerStatusOutput `json:"scanners"` +} + +type scannerStatusOutput struct { + Name string `json:"name"` + Summary string `json:"summary,omitempty"` + Error string `json:"error,omitempty"` +} + +func printStatusTable(statusOutput securityStatusOutput) error { + fmt.Println("Security Posture Overview") + fmt.Println("========================") + + if overviewMap, ok := statusOutput.Overview.(map[string]any); ok { + overviewWriter := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(overviewWriter, "KEY\tVALUE") + + keys := make([]string, 0, len(overviewMap)) + for key := range overviewMap { + keys = append(keys, key) + } + sort.Strings(keys) + + for _, key := range keys { + fmt.Fprintf(overviewWriter, "%s\t%s\n", key, statusValueString(overviewMap[key])) + } + if err := overviewWriter.Flush(); err != nil { + return err + } + } else { + pretty, err := json.MarshalIndent(statusOutput.Overview, "", " ") + if err != nil { + return err + } + fmt.Println(string(pretty)) + } + + fmt.Println("\nFindings by Scanner") + fmt.Println("===================") + + scannerWriter := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(scannerWriter, "SCANNER\tSTATUS") + for _, scanner := range statusOutput.Scanners { + status := scanner.Summary + if scanner.Error != "" { + status = "error: " + scanner.Error + } + fmt.Fprintf(scannerWriter, "%s\t%s\n", scanner.Name, status) + } + + return scannerWriter.Flush() +} + +func statusValueString(value any) string { + switch v := value.(type) { + case nil: + return "" + case string: + return v + default: + data, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", value) + } + return string(data) + } +} diff --git a/cmd/cli/cmd/status_test.go b/cmd/cli/cmd/status_test.go index 5d12edb..8228d48 100644 --- a/cmd/cli/cmd/status_test.go +++ b/cmd/cli/cmd/status_test.go @@ -3,6 +3,7 @@ package cmd import ( "testing" + "github.com/spf13/cobra" "github.com/stretchr/testify/require" ) @@ -70,3 +71,36 @@ func TestFilterEndpointsByType(t *testing.T) { noMatch := filterEndpointsByType(endpoints, "nonexistent") require.Nil(t, noMatch) } + +func TestStatusDefaultsToTable(t *testing.T) { + // The global --output flag defaults to "json", but status should + // display a table when the user hasn't explicitly passed -o. + root := &cobra.Command{Use: "test"} + root.PersistentFlags().StringP("output", "o", "json", "Output format") + + var sawTable bool + child := &cobra.Command{ + Use: "status", + Run: func(cmd *cobra.Command, args []string) { + format, _ := cmd.Flags().GetString("output") + outputExplicit := cmd.Flags().Lookup("output").Changed + sawTable = format == "table" || !outputExplicit + }, + } + root.AddCommand(child) + + // No flag -> defaults to table + root.SetArgs([]string{"status"}) + require.NoError(t, root.Execute()) + require.True(t, sawTable, "status should default to table when -o is not explicitly set") + + // Explicit -o json -> JSON + root.SetArgs([]string{"status", "-o", "json"}) + require.NoError(t, root.Execute()) + require.False(t, sawTable, "status should respect explicit -o json") + + // Explicit -o table -> table + root.SetArgs([]string{"status", "-o", "table"}) + require.NoError(t, root.Execute()) + require.True(t, sawTable, "status should show table when -o table is explicit") +} diff --git a/internal/chat/dialer.go b/internal/chat/dialer.go index 932dd71..79de8c5 100644 --- a/internal/chat/dialer.go +++ b/internal/chat/dialer.go @@ -2,24 +2,89 @@ package chat import ( "context" + "encoding/json" + "errors" "fmt" + "io" "net/http" + "strings" "github.com/gorilla/websocket" ) // Dial connects to the Nullify chat WebSocket. func Dial(ctx context.Context, host string, token string) (Conn, error) { - url := fmt.Sprintf("wss://%s/chat/websocket", host) + url := buildWebSocketURL(host) header := http.Header{} header.Set("Authorization", "Bearer "+token) dialer := websocket.DefaultDialer - conn, _, err := dialer.DialContext(ctx, url, header) + conn, resp, err := dialer.DialContext(ctx, url, header) if err != nil { - return nil, fmt.Errorf("failed to connect to chat: %w", err) + return nil, formatDialError(url, err, resp) } return conn, nil } + +func buildWebSocketURL(host string) string { + return fmt.Sprintf("wss://%s/chat/websocket", websocketHost(host)) +} + +func websocketHost(host string) string { + if strings.HasPrefix(host, "api.") { + return host + } + return "api." + host +} + +func formatDialError(url string, err error, resp *http.Response) error { + if resp == nil { + return fmt.Errorf("failed to connect to chat at %s: %w", url, err) + } + + defer resp.Body.Close() + + message := summarizeHandshakeBody(resp.Body) + base := fmt.Sprintf("failed to connect to chat at %s: websocket handshake failed with HTTP %d %s", url, resp.StatusCode, http.StatusText(resp.StatusCode)) + if message != "" { + base += ": " + message + } + + switch resp.StatusCode { + case http.StatusUnauthorized: + base += ". Check that your Nullify token is valid for this host." + case http.StatusForbidden: + base += ". The chat websocket is reachable, but this identity is not allowed to use it. Verify chat permissions for this host." + } + + return errors.New(base) +} + +func summarizeHandshakeBody(body io.Reader) string { + if body == nil { + return "" + } + + data, err := io.ReadAll(io.LimitReader(body, 4096)) + if err != nil { + return "" + } + + raw := strings.TrimSpace(string(data)) + if raw == "" { + return "" + } + + var parsed map[string]any + if json.Unmarshal(data, &parsed) == nil { + for _, key := range []string{"message", "error", "detail"} { + if value, ok := parsed[key].(string); ok && strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + } + + return raw +} diff --git a/internal/chat/dialer_test.go b/internal/chat/dialer_test.go new file mode 100644 index 0000000..2feda47 --- /dev/null +++ b/internal/chat/dialer_test.go @@ -0,0 +1,40 @@ +package chat + +import ( + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +func TestBuildWebSocketURL(t *testing.T) { + require.Equal(t, "wss://api.acme.nullify.ai/chat/websocket", buildWebSocketURL("acme.nullify.ai")) + require.Equal(t, "wss://api.acme.nullify.ai/chat/websocket", buildWebSocketURL("api.acme.nullify.ai")) +} + +func TestFormatDialErrorWithForbiddenHandshake(t *testing.T) { + err := formatDialError( + "wss://api.acme.nullify.ai/chat/websocket", + websocket.ErrBadHandshake, + &http.Response{ + StatusCode: http.StatusForbidden, + Body: io.NopCloser(strings.NewReader(`{"message":"User is not authorized"}`)), + }, + ) + + require.EqualError(t, err, "failed to connect to chat at wss://api.acme.nullify.ai/chat/websocket: websocket handshake failed with HTTP 403 Forbidden: User is not authorized. The chat websocket is reachable, but this identity is not allowed to use it. Verify chat permissions for this host.") +} + +func TestFormatDialErrorWithoutResponse(t *testing.T) { + err := formatDialError("wss://api.acme.nullify.ai/chat/websocket", errors.New("tls: internal error"), nil) + require.EqualError(t, err, "failed to connect to chat at wss://api.acme.nullify.ai/chat/websocket: tls: internal error") +} + +func TestSummarizeHandshakeBodyReturnsRawBody(t *testing.T) { + message := summarizeHandshakeBody(strings.NewReader("plain text failure")) + require.Equal(t, "plain text failure", message) +} diff --git a/internal/chat/renderer.go b/internal/chat/renderer.go index eb5961c..ef7ad3b 100644 --- a/internal/chat/renderer.go +++ b/internal/chat/renderer.go @@ -26,12 +26,16 @@ func isTTY() bool { // ansi returns the escape code if stdout is a terminal, empty string otherwise. func ansi(code string) string { - if isTTY() { + if colorsEnabled(isTTY()) { return code } return "" } +func colorsEnabled(stdoutIsTTY bool) bool { + return stdoutIsTTY && os.Getenv("NO_COLOR") == "" +} + // RenderToolCall renders a tool call message (dim text). func RenderToolCall(message string) string { return fmt.Sprintf("%s%s[tool] %s%s", ansi(ansiDim), ansi(ansiCyan), message, ansi(ansiReset)) diff --git a/internal/chat/renderer_test.go b/internal/chat/renderer_test.go index 2b9a8b2..45237fc 100644 --- a/internal/chat/renderer_test.go +++ b/internal/chat/renderer_test.go @@ -6,6 +6,15 @@ import ( "github.com/stretchr/testify/require" ) +func TestColorsEnabledHonorsNoColor(t *testing.T) { + t.Setenv("NO_COLOR", "1") + require.False(t, colorsEnabled(true)) + + t.Setenv("NO_COLOR", "") + require.True(t, colorsEnabled(true)) + require.False(t, colorsEnabled(false)) +} + func TestRenderToolCall(t *testing.T) { result := RenderToolCall("calling search_findings") require.Contains(t, result, "[tool]") diff --git a/internal/lib/auth_headers.go b/internal/lib/auth_headers.go index 68f9d80..65d57ac 100644 --- a/internal/lib/auth_headers.go +++ b/internal/lib/auth_headers.go @@ -3,26 +3,38 @@ package lib import ( "context" "fmt" + "regexp" "strings" "github.com/nullify-platform/logger/pkg/logger" ) +var multiHeaderPattern = regexp.MustCompile(`, [A-Z][a-zA-Z0-9-]+: `) + func ParseAuthHeaders(ctx context.Context, authHeaders []string) (map[string]string, error) { result := map[string]string{} for _, header := range authHeaders { - headers := strings.Split(header, ",") - for _, h := range headers { - headerParts := strings.SplitN(h, ": ", 2) - if len(headerParts) != 2 { - logger.L(ctx).Error("please provide headers in the format of 'key: value'") - return nil, fmt.Errorf("please provide headers in the format of 'key: value'") - } - - headerName := strings.TrimSpace(headerParts[0]) - headerValue := strings.TrimSpace(headerParts[1]) - result[headerName] = headerValue + headerParts := strings.SplitN(header, ":", 2) + if len(headerParts) != 2 { + logger.L(ctx).Error("please provide one header per flag in the format 'key: value'") + return nil, fmt.Errorf("please provide one header per flag in the format 'key: value'") + } + + headerName := strings.TrimSpace(headerParts[0]) + headerValue := strings.TrimSpace(headerParts[1]) + if headerName == "" { + logger.L(ctx).Error("header name cannot be empty") + return nil, fmt.Errorf("header name cannot be empty") + } + + result[headerName] = headerValue + + // Warn if the value looks like it contains another header. + if multiHeaderPattern.MatchString(headerValue) { + logger.L(ctx).Warn("header value looks like it may contain multiple headers; use repeated --header flags instead", + logger.String("header", headerName), + ) } } diff --git a/internal/lib/auth_headers_test.go b/internal/lib/auth_headers_test.go new file mode 100644 index 0000000..3c4aaa9 --- /dev/null +++ b/internal/lib/auth_headers_test.go @@ -0,0 +1,40 @@ +package lib + +import ( + "context" + "testing" + + "github.com/nullify-platform/logger/pkg/logger" + "github.com/stretchr/testify/require" +) + +func TestParseAuthHeaders(t *testing.T) { + ctx, err := logger.ConfigureDevelopmentLogger(context.Background(), "error") + require.NoError(t, err) + + t.Run("accepts headers without a space after the colon", func(t *testing.T) { + headers, err := ParseAuthHeaders(ctx, []string{"Authorization:Bearer token"}) + require.NoError(t, err) + require.Equal(t, map[string]string{"Authorization": "Bearer token"}, headers) + }) + + t.Run("preserves commas in the header value", func(t *testing.T) { + headers, err := ParseAuthHeaders(ctx, []string{"Cookie: a=b, c=d"}) + require.NoError(t, err) + require.Equal(t, map[string]string{"Cookie": "a=b, c=d"}, headers) + }) + + t.Run("rejects malformed headers", func(t *testing.T) { + _, err := ParseAuthHeaders(ctx, []string{"Authorization"}) + require.Error(t, err) + }) + + t.Run("warns on value resembling multiple headers", func(t *testing.T) { + warnCtx, err := logger.ConfigureDevelopmentLogger(context.Background(), "warn") + require.NoError(t, err) + + headers, err := ParseAuthHeaders(warnCtx, []string{"Authorization: Bearer token, X-Custom: value"}) + require.NoError(t, err) + require.Equal(t, map[string]string{"Authorization": "Bearer token, X-Custom: value"}, headers) + }) +} diff --git a/internal/lib/get_token.go b/internal/lib/get_token.go index 0a903f2..d880e0f 100644 --- a/internal/lib/get_token.go +++ b/internal/lib/get_token.go @@ -99,6 +99,9 @@ func GetNullifyToken( logger.L(ctx).Debug("using token from stored credentials") return storedToken, nil } + if err != nil { + return "", fmt.Errorf("stored credentials: %w", err) + } return "", ErrNoToken } diff --git a/internal/lib/get_token_test.go b/internal/lib/get_token_test.go new file mode 100644 index 0000000..3359708 --- /dev/null +++ b/internal/lib/get_token_test.go @@ -0,0 +1,34 @@ +package lib + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/nullify-platform/cli/internal/auth" + "github.com/nullify-platform/logger/pkg/logger" + "github.com/stretchr/testify/require" +) + +func TestGetNullifyTokenPreservesStoredCredentialErrors(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("NULLIFY_TOKEN", "") + t.Setenv("GITHUB_ACTIONS", "") + + ctx, err := logger.ConfigureDevelopmentLogger(context.Background(), "error") + require.NoError(t, err) + + // Save expired credentials without a refresh token + err = auth.SaveHostCredentials("acme.nullify.ai", auth.HostCredentials{ + AccessToken: "expired-token", + ExpiresAt: time.Now().Add(-time.Hour).Unix(), + }) + require.NoError(t, err) + + _, err = GetNullifyToken(ctx, "acme.nullify.ai", "", "") + require.Error(t, err) + require.False(t, errors.Is(err, ErrNoToken), + "should preserve stored credential error, not return generic ErrNoToken") + require.Contains(t, err.Error(), "stored credentials") +} diff --git a/internal/mcp/resources.go b/internal/mcp/resources.go index 91f20d3..54b1db0 100644 --- a/internal/mcp/resources.go +++ b/internal/mcp/resources.go @@ -30,7 +30,7 @@ func registerResources(s *server.MCPServer, c *client.NullifyClient, queryParams }, func(ctx context.Context, request mcplib.ReadResourceRequest) ([]mcplib.ResourceContents, error) { qs := buildQueryString(queryParams) - result, err := doGet(ctx, c, "/admin/metrics/overview"+qs) + result, err := doPost(ctx, c, "/admin/metrics/overview"+qs, metricsOverviewBody()) if err != nil { return nil, err } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index d5fe167..0259f0b 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -3,6 +3,8 @@ package mcp import ( "context" "fmt" + "io" + "os" "github.com/nullify-platform/cli/internal/client" "github.com/nullify-platform/cli/internal/lib" @@ -17,6 +19,10 @@ func Serve(ctx context.Context, host string, token string, queryParams map[strin } func ServeWithClient(ctx context.Context, nullifyClient *client.NullifyClient, queryParams map[string]string, toolSet ToolSet) error { + return serveWithClientIO(ctx, nullifyClient, queryParams, toolSet, os.Stdin, os.Stdout) +} + +func serveWithClientIO(ctx context.Context, nullifyClient *client.NullifyClient, queryParams map[string]string, toolSet ToolSet, stdin io.Reader, stdout io.Writer) error { s := server.NewMCPServer( "Nullify", logger.Version, @@ -32,7 +38,7 @@ func ServeWithClient(ctx context.Context, nullifyClient *client.NullifyClient, q logger.L(ctx).Debug("starting MCP server over stdio", logger.String("toolSet", string(toolSet))) stdioServer := server.NewStdioServer(s) - return stdioServer.Listen(ctx, nil, nil) + return stdioServer.Listen(ctx, stdin, stdout) } func registerTools(s *server.MCPServer, c *client.NullifyClient, queryParams map[string]string, toolSet ToolSet) { diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 302d21e..e90e770 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -1,8 +1,14 @@ package mcp import ( + "context" + "io" "strings" "testing" + + "github.com/nullify-platform/cli/internal/client" + "github.com/nullify-platform/logger/pkg/logger" + "github.com/stretchr/testify/require" ) func TestBuildQueryString(t *testing.T) { @@ -90,3 +96,19 @@ func TestGetIntArg(t *testing.T) { t.Errorf("getIntArg(name) = %d, want 10 (default for wrong type)", got) } } + +func TestServeWithClientIOHandlesEOF(t *testing.T) { + ctx, err := logger.ConfigureDevelopmentLogger(context.Background(), "error") + require.NoError(t, err) + + err = serveWithClientIO( + ctx, + client.NewNullifyClient("acme.nullify.ai", "token"), + map[string]string{}, + ToolSetDefault, + strings.NewReader(""), + io.Discard, + ) + + require.NoError(t, err) +} diff --git a/internal/mcp/tools_admin.go b/internal/mcp/tools_admin.go index b4f145c..9a1f65c 100644 --- a/internal/mcp/tools_admin.go +++ b/internal/mcp/tools_admin.go @@ -1,19 +1,71 @@ package mcp import ( + "context" + "time" + "github.com/nullify-platform/cli/internal/client" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) +func metricsOverviewBody() map[string]any { + return map[string]any{ + "query": map[string]any{ + "sort": []any{ + map[string]any{ + "isFalsePositive": map[string]any{ + "order": "asc", + "missing": 0, + }, + }, + }, + "isArchived": false, + }, + } +} + +func metricsOverTimeBody(period string) map[string]any { + now := time.Now().UTC() + var from time.Time + switch period { + case "7d": + from = now.AddDate(0, 0, -7) + case "90d": + from = now.AddDate(0, 0, -90) + case "365d": + from = now.AddDate(0, 0, -365) + default: // 30d + from = now.AddDate(0, 0, -30) + } + return map[string]any{ + "query": map[string]any{ + "sort": []any{ + map[string]any{ + "isFalsePositive": map[string]any{ + "order": "asc", + "missing": 0, + }, + }, + }, + "isArchived": false, + "fromDate": from.Format(time.RFC3339), + "toDate": now.Format(time.RFC3339), + }, + } +} + func registerAdminTools(s *server.MCPServer, c *client.NullifyClient, queryParams map[string]string) { s.AddTool( mcp.NewTool( "get_metrics_overview", mcp.WithDescription("Get a high-level security posture overview with counts of findings by severity and type. Use this to understand the overall security state."), ), - makeGetHandler(c, "/admin/metrics/overview", queryParams), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + qs := buildQueryString(queryParams) + return doPost(ctx, c, "/admin/metrics/overview"+qs, metricsOverviewBody()) + }, ) s.AddTool( @@ -22,7 +74,15 @@ func registerAdminTools(s *server.MCPServer, c *client.NullifyClient, queryParam mcp.WithDescription("Get security metrics trends over time. Shows how the number of findings has changed, useful for tracking security posture improvements."), mcp.WithString("period", mcp.Description("Time period"), mcp.Enum("7d", "30d", "90d", "365d")), ), - makeGetHandler(c, "/admin/metrics/over-time", queryParams), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := request.GetArguments() + period := getStringArg(args, "period") + if period == "" { + period = "30d" + } + qs := buildQueryString(queryParams) + return doPost(ctx, c, "/admin/metrics/over-time"+qs, metricsOverTimeBody(period)) + }, ) s.AddTool( diff --git a/internal/mcp/tools_composite.go b/internal/mcp/tools_composite.go index 866cd18..bfd30f7 100644 --- a/internal/mcp/tools_composite.go +++ b/internal/mcp/tools_composite.go @@ -237,14 +237,14 @@ func registerCompositeTools(s *server.MCPServer, c *client.NullifyClient, queryP // Get overview overviewQS := buildQueryString(queryParams) - overviewResult, err := doGet(ctx, c, "/admin/metrics/overview"+overviewQS) + overviewResult, err := doPost(ctx, c, "/admin/metrics/overview"+overviewQS, metricsOverviewBody()) if err != nil { return toolError(err), nil } // Get over-time data - timeQS := buildQueryString(queryParams, "period", period) - timeResult, err := doGet(ctx, c, "/admin/metrics/over-time"+timeQS) + timeQS := buildQueryString(queryParams) + timeResult, err := doPost(ctx, c, "/admin/metrics/over-time"+timeQS, metricsOverTimeBody(period)) if err != nil { return toolError(err), nil } diff --git a/internal/mcp/tools_context.go b/internal/mcp/tools_context.go index aadca1f..e4a7bdf 100644 --- a/internal/mcp/tools_context.go +++ b/internal/mcp/tools_context.go @@ -51,10 +51,21 @@ func registerContextTools(s *server.MCPServer, c *client.NullifyClient, queryPar mcp.NewTool( "list_dependencies", mcp.WithDescription("List third-party dependencies across all monitored repositories. Useful for understanding your supply chain."), - mcp.WithString("repository", mcp.Description("Filter by repository name")), - mcp.WithNumber("limit", mcp.Description("Max results (default 20)")), + mcp.WithNumber("pageSize", mcp.Description("Max results per page (default 20)")), + mcp.WithString("cursor", mcp.Description("Pagination cursor from previous response")), ), - makeGetHandler(c, "/context/dependencies", queryParams), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := request.GetArguments() + extra := []string{} + if ps := getIntArg(args, "pageSize", 0); ps > 0 { + extra = append(extra, "pageSize", fmt.Sprintf("%d", ps)) + } + if cur := getStringArg(args, "cursor"); cur != "" { + extra = append(extra, "cursor", cur) + } + qs := buildQueryString(queryParams, extra...) + return doGet(ctx, c, "/context/deps"+qs) + }, ) s.AddTool( @@ -77,6 +88,8 @@ func registerContextTools(s *server.MCPServer, c *client.NullifyClient, queryPar mcp.NewTool( "get_dependency_exposure", mcp.WithDescription("Get dependency exposure analysis showing which dependencies are exposed to the internet or internal networks."), + mcp.WithString("ecosystem", mcp.Required(), mcp.Description("Package ecosystem (e.g. npm, pip, maven, go, nuget)")), + mcp.WithString("name", mcp.Required(), mcp.Description("Dependency name to check exposure for")), ), makeGetHandler(c, "/context/deps/exposure", queryParams), ) diff --git a/internal/wizard/steps.go b/internal/wizard/steps.go index a0da0bd..fef7633 100644 --- a/internal/wizard/steps.go +++ b/internal/wizard/steps.go @@ -151,11 +151,6 @@ func MCPConfigStep() Step { } } -// mcpConfig is the structure for MCP configuration files. -type mcpConfig struct { - MCPServers map[string]mcpServerConfig `json:"mcpServers"` -} - type mcpServerConfig struct { Command string `json:"command"` Args []string `json:"args"` @@ -163,20 +158,26 @@ type mcpServerConfig struct { func writeMCPConfig(path string) error { // Read existing config if present - existing := mcpConfig{MCPServers: make(map[string]mcpServerConfig)} + existing := map[string]any{} if data, err := os.ReadFile(path); err == nil { _ = json.Unmarshal(data, &existing) } + mcpServers, ok := existing["mcpServers"].(map[string]any) + if !ok || mcpServers == nil { + mcpServers = make(map[string]any) + } + // Only add nullify if not already configured - if _, ok := existing.MCPServers["nullify"]; ok { + if _, ok := mcpServers["nullify"]; ok { return nil } - existing.MCPServers["nullify"] = mcpServerConfig{ + mcpServers["nullify"] = mcpServerConfig{ Command: "nullify", Args: []string{"mcp", "serve"}, } + existing["mcpServers"] = mcpServers data, err := json.MarshalIndent(existing, "", " ") if err != nil { diff --git a/internal/wizard/steps_test.go b/internal/wizard/steps_test.go new file mode 100644 index 0000000..da160db --- /dev/null +++ b/internal/wizard/steps_test.go @@ -0,0 +1,48 @@ +package wizard + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWriteMCPConfigPreservesUnknownTopLevelFields(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "mcp.json") + + err := os.WriteFile(configPath, []byte(`{ + "version": 1, + "mcpServers": { + "existing": { + "command": "existing", + "args": ["serve"] + } + }, + "other": { + "enabled": true + } +}`), 0600) + require.NoError(t, err) + + err = writeMCPConfig(configPath) + require.NoError(t, err) + + data, err := os.ReadFile(configPath) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + require.Equal(t, float64(1), parsed["version"]) + + other, ok := parsed["other"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, other["enabled"]) + + mcpServers, ok := parsed["mcpServers"].(map[string]any) + require.True(t, ok) + require.Contains(t, mcpServers, "existing") + require.Contains(t, mcpServers, "nullify") +}