From 76e4a0bc93338caed2f39da2f0e98c60d05b8606 Mon Sep 17 00:00:00 2001 From: MK Date: Thu, 19 Mar 2026 18:52:54 -0700 Subject: [PATCH 1/2] security: implement Phase 1 critical fixes (C-1 through C-7) - C-1/C-2/C-7: Add strict IPv4 parsing to reject octal/hex/packed SSRF bypass vectors, SafeDialer with post-DNS-resolution IP validation, IPv6 transition address detection (NAT64, 6to4, Teredo), and container-aware allowPrivateIPs for K8s inter-service communication - C-3: Strip Authorization/Cookie headers on cross-origin redirects in http_request and webhook_call tools - C-4: Replace wildcard CORS with origin allowlist (localhost defaults), configurable via --cors-origins flag, FORGE_CORS_ORIGINS env, or forge.yaml cors_origins field - C-6: Add X-Content-Type-Options, Referrer-Policy, X-Frame-Options, and Content-Security-Policy headers to all A2A server responses --- forge-cli/cmd/run.go | 12 + forge-cli/cmd/serve.go | 5 + forge-cli/runtime/runner.go | 31 ++- forge-cli/server/a2a_server.go | 98 ++++++++- forge-cli/server/a2a_server_test.go | 181 +++++++++++++++ forge-core/security/domain_matcher.go | 8 +- forge-core/security/domain_matcher_test.go | 6 + forge-core/security/egress_enforcer.go | 33 ++- forge-core/security/egress_enforcer_test.go | 44 +++- .../security/egress_integration_test.go | 6 +- forge-core/security/egress_proxy.go | 44 +++- forge-core/security/egress_proxy_test.go | 18 +- forge-core/security/ip_validator.go | 204 +++++++++++++++++ forge-core/security/ip_validator_test.go | 207 ++++++++++++++++++ forge-core/security/redirect.go | 63 ++++++ forge-core/security/redirect_test.go | 147 +++++++++++++ forge-core/security/safe_dialer.go | 92 ++++++++ forge-core/security/safe_dialer_test.go | 161 ++++++++++++++ forge-core/security/types.go | 11 +- forge-core/tools/adapters/webhook_call.go | 5 +- forge-core/tools/builtins/http_request.go | 5 +- forge-core/types/config.go | 10 +- 22 files changed, 1324 insertions(+), 67 deletions(-) create mode 100644 forge-cli/server/a2a_server_test.go create mode 100644 forge-core/security/ip_validator.go create mode 100644 forge-core/security/ip_validator_test.go create mode 100644 forge-core/security/redirect.go create mode 100644 forge-core/security/redirect_test.go create mode 100644 forge-core/security/safe_dialer.go create mode 100644 forge-core/security/safe_dialer_test.go diff --git a/forge-cli/cmd/run.go b/forge-cli/cmd/run.go index cb49bc4..d415353 100644 --- a/forge-cli/cmd/run.go +++ b/forge-cli/cmd/run.go @@ -30,6 +30,7 @@ var ( runWithChannels string runNoAuth bool runAuthToken string + runCORSOrigins string ) var runCmd = &cobra.Command{ @@ -51,6 +52,7 @@ func init() { runCmd.Flags().StringVar(&runWithChannels, "with", "", "comma-separated channel adapters to start (e.g. slack,telegram)") runCmd.Flags().BoolVar(&runNoAuth, "no-auth", false, "disable bearer token authentication (localhost only)") runCmd.Flags().StringVar(&runAuthToken, "auth-token", "", "explicit bearer token (default: auto-generated)") + runCmd.Flags().StringVar(&runCORSOrigins, "cors-origins", "", "comma-separated CORS allowed origins (default: localhost only, use '*' for wildcard)") } func runRun(cmd *cobra.Command, args []string) error { @@ -66,6 +68,15 @@ func runRun(cmd *cobra.Command, args []string) error { enforceGuardrails = false } + var corsOrigins []string + if runCORSOrigins != "" { + for _, o := range strings.Split(runCORSOrigins, ",") { + if o = strings.TrimSpace(o); o != "" { + corsOrigins = append(corsOrigins, o) + } + } + } + runner, err := runtime.NewRunner(runtime.RunnerConfig{ Config: cfg, WorkDir: workDir, @@ -81,6 +92,7 @@ func runRun(cmd *cobra.Command, args []string) error { Channels: activeChannels, NoAuth: runNoAuth, AuthToken: runAuthToken, + CORSOrigins: corsOrigins, }) if err != nil { return fmt.Errorf("creating runner: %w", err) diff --git a/forge-cli/cmd/serve.go b/forge-cli/cmd/serve.go index 17114f6..1a961b1 100644 --- a/forge-cli/cmd/serve.go +++ b/forge-cli/cmd/serve.go @@ -36,6 +36,7 @@ var ( serveWithChannels string serveNoAuth bool serveAuthToken string + serveCORSOrigins string ) var serveCmd = &cobra.Command{ @@ -96,6 +97,7 @@ func registerServeFlags(cmd *cobra.Command) { cmd.Flags().StringVar(&serveWithChannels, "with", "", "comma-separated channel adapters to start (e.g. slack,telegram)") cmd.Flags().BoolVar(&serveNoAuth, "no-auth", false, "disable bearer token authentication (localhost only)") cmd.Flags().StringVar(&serveAuthToken, "auth-token", "", "explicit bearer token (default: auto-generated)") + cmd.Flags().StringVar(&serveCORSOrigins, "cors-origins", "", "comma-separated CORS allowed origins (default: localhost only, use '*' for wildcard)") } func init() { @@ -191,6 +193,9 @@ func serveStartRun(cmd *cobra.Command, args []string) error { if serveAuthToken != "" { runArgs = append(runArgs, "--auth-token", serveAuthToken) } + if serveCORSOrigins != "" { + runArgs = append(runArgs, "--cors-origins", serveCORSOrigins) + } // Ensure .forge directory exists forgeDir := filepath.Dir(statePath) diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 2a97806..3ab85ae 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -48,6 +48,7 @@ type RunnerConfig struct { Channels []string // active channel adapters from --with flag NoAuth bool // disable bearer token authentication AuthToken string // explicit bearer token (empty = auto-generate) + CORSOrigins []string // CORS allowed origins (from --cors-origins flag) } // ScheduleNotifier is called after a scheduled task completes to deliver the @@ -223,7 +224,15 @@ func (r *Runner) Run(ctx context.Context) error { r.logger.Warn("failed to resolve egress config, using default", map[string]any{"error": egressErr.Error()}) egressClient = http.DefaultClient } else { - enforcer := security.NewEgressEnforcer(nil, egressCfg.Mode, egressCfg.AllDomains) + // Resolve allowPrivateIPs: explicit config > container auto-detect > false + allowPrivateIPs := false + if r.cfg.Config.Egress.AllowPrivateIPs != nil { + allowPrivateIPs = *r.cfg.Config.Egress.AllowPrivateIPs + } else if security.InContainer() { + allowPrivateIPs = true + } + + enforcer := security.NewEgressEnforcer(nil, egressCfg.Mode, egressCfg.AllDomains, allowPrivateIPs) enforcer.OnAttempt = func(ctx context.Context, domain string, allowed bool) { event := coreruntime.AuditEgressAllowed if !allowed { @@ -241,7 +250,7 @@ func (r *Runner) Run(ctx context.Context) error { // Start local proxy for subprocess egress enforcement if !security.InContainer() && egressCfg.Mode != security.ModeDevOpen { matcher := security.NewDomainMatcher(egressCfg.Mode, egressCfg.AllDomains) - egressProxy = security.NewEgressProxy(matcher) + egressProxy = security.NewEgressProxy(matcher, allowPrivateIPs) egressProxy.OnAttempt = func(domain string, allowed bool) { event := coreruntime.AuditEgressAllowed if !allowed { @@ -583,6 +592,23 @@ func (r *Runner) Run(ctx context.Context) error { return fmt.Errorf("resolving auth: %w", err) } + // 6b. Resolve CORS origins: CLI flag > env var > forge.yaml > defaults + corsOrigins := r.cfg.CORSOrigins + if len(corsOrigins) == 0 { + if envCORS := os.Getenv("FORGE_CORS_ORIGINS"); envCORS != "" { + corsOrigins = strings.Split(envCORS, ",") + for i := range corsOrigins { + corsOrigins[i] = strings.TrimSpace(corsOrigins[i]) + } + } + } + if len(corsOrigins) == 0 && len(r.cfg.Config.CORSOrigins) > 0 { + corsOrigins = r.cfg.Config.CORSOrigins + } + if len(corsOrigins) == 0 { + corsOrigins = server.DefaultAllowedOrigins() + } + // 6. Create A2A server r.startTime = time.Now() srv := server.NewServer(server.ServerConfig{ @@ -591,6 +617,7 @@ func (r *Runner) Run(ctx context.Context) error { ShutdownTimeout: r.cfg.ShutdownTimeout, AgentCard: card, AuthMiddleware: auth.Middleware(authCfg), + AllowedOrigins: corsOrigins, }) // 7. Register JSON-RPC handlers diff --git a/forge-cli/server/a2a_server.go b/forge-cli/server/a2a_server.go index dcfde34..03178c6 100644 --- a/forge-cli/server/a2a_server.go +++ b/forge-cli/server/a2a_server.go @@ -8,6 +8,7 @@ import ( "log" "net" "net/http" + "strings" "sync" "syscall" "time" @@ -28,6 +29,7 @@ type ServerConfig struct { ShutdownTimeout time.Duration // graceful shutdown timeout (0 = immediate) AgentCard *a2a.AgentCard AuthMiddleware func(http.Handler) http.Handler // optional auth middleware + AllowedOrigins []string // CORS allowed origins } type httpRoute struct { @@ -47,11 +49,16 @@ type Server struct { sseHandlers map[string]SSEHandler httpHandlers []httpRoute authMiddleware func(http.Handler) http.Handler + allowedOrigins []string srv *http.Server } // NewServer creates a new A2A server. func NewServer(cfg ServerConfig) *Server { + allowedOrigins := cfg.AllowedOrigins + if len(allowedOrigins) == 0 { + allowedOrigins = DefaultAllowedOrigins() + } s := &Server{ port: cfg.Port, host: cfg.Host, @@ -61,6 +68,7 @@ func NewServer(cfg ServerConfig) *Server { handlers: make(map[string]Handler), sseHandlers: make(map[string]SSEHandler), authMiddleware: cfg.AuthMiddleware, + allowedOrigins: allowedOrigins, } return s } @@ -121,13 +129,14 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("POST /", s.handleJSONRPC) mux.HandleFunc("GET /", s.handleAgentCard) - // Build handler chain: CORS → Auth → Mux + // Build handler chain: CORS → Security Headers → Auth → Mux // CORS is outermost so OPTIONS preflight is handled before auth. var handler http.Handler = mux if s.authMiddleware != nil { handler = s.authMiddleware(handler) } - handler = corsMiddleware(handler) + handler = securityHeadersMiddleware(handler) + handler = newCORSMiddleware(s.allowedOrigins)(handler) s.srv = &http.Server{ Handler: handler, @@ -233,15 +242,84 @@ func writeJSON(w http.ResponseWriter, status int, v any) { json.NewEncoder(w).Encode(v) //nolint:errcheck } -func corsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return +// DefaultAllowedOrigins returns the default CORS origins for local development. +func DefaultAllowedOrigins() []string { + return []string{ + "http://localhost", + "https://localhost", + "http://127.0.0.1", + "https://127.0.0.1", + "http://[::1]", + "https://[::1]", + } +} + +// isOriginAllowed checks if the given origin matches the allowlist. +// Supports exact match and prefix+port matching (e.g. "http://localhost" matches +// "http://localhost:3000"). If the allowed list contains "*", all origins pass. +func isOriginAllowed(origin string, allowed []string) bool { + if origin == "" { + return false + } + for _, a := range allowed { + if a == "*" { + return true + } + if strings.EqualFold(origin, a) { + return true + } + // Prefix+colon match for port variants: "http://localhost" matches "http://localhost:3000" + if strings.HasPrefix(strings.ToLower(origin), strings.ToLower(a)+":") { + return true } + } + return false +} + +// newCORSMiddleware returns CORS middleware that restricts origins to the allowlist. +// When the allowlist contains "*", it behaves as a wildcard (Access-Control-Allow-Origin: *). +// Otherwise it echoes the matched origin and adds Vary: Origin. +func newCORSMiddleware(allowed []string) func(http.Handler) http.Handler { + hasWildcard := false + for _, a := range allowed { + if a == "*" { + hasWildcard = true + break + } + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + if hasWildcard { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + } else if isOriginAllowed(origin, allowed) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.Header().Set("Vary", "Origin") + } + // Non-matching origins: no CORS headers added + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// securityHeadersMiddleware adds security headers to every response. +func securityHeadersMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Referrer-Policy", "no-referrer") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Content-Security-Policy", "default-src 'none'") next.ServeHTTP(w, r) }) } diff --git a/forge-cli/server/a2a_server_test.go b/forge-cli/server/a2a_server_test.go new file mode 100644 index 0000000..d9f92bf --- /dev/null +++ b/forge-cli/server/a2a_server_test.go @@ -0,0 +1,181 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestSecurityHeadersPresent(t *testing.T) { + handler := securityHeadersMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + method string + }{ + {"GET"}, + {"POST"}, + {"OPTIONS"}, + } + + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + expected := map[string]string{ + "X-Content-Type-Options": "nosniff", + "Referrer-Policy": "no-referrer", + "X-Frame-Options": "DENY", + "Content-Security-Policy": "default-src 'none'", + } + for header, want := range expected { + got := rec.Header().Get(header) + if got != want { + t.Errorf("%s %s: header %q = %q, want %q", tt.method, "/", header, got, want) + } + } + }) + } +} + +func TestCORSAllowlistMatchedOrigin(t *testing.T) { + allowed := []string{"http://localhost", "https://app.example.com"} + middleware := newCORSMiddleware(allowed)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + origin string + wantCORS bool + wantOriginEcho string + }{ + {"matching origin", "http://localhost", true, "http://localhost"}, + {"matching with port", "http://localhost:3000", true, "http://localhost:3000"}, + {"matching exact", "https://app.example.com", true, "https://app.example.com"}, + {"non-matching origin", "https://evil.com", false, ""}, + {"no origin header", "", false, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } + rec := httptest.NewRecorder() + middleware.ServeHTTP(rec, req) + + gotOrigin := rec.Header().Get("Access-Control-Allow-Origin") + if tt.wantCORS { + if gotOrigin != tt.wantOriginEcho { + t.Errorf("Access-Control-Allow-Origin = %q, want %q", gotOrigin, tt.wantOriginEcho) + } + if rec.Header().Get("Vary") != "Origin" { + t.Error("expected Vary: Origin header") + } + } else { + if gotOrigin != "" { + t.Errorf("expected no CORS headers, got Access-Control-Allow-Origin = %q", gotOrigin) + } + } + }) + } +} + +func TestCORSWildcard(t *testing.T) { + middleware := newCORSMiddleware([]string{"*"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Origin", "https://anything.com") + rec := httptest.NewRecorder() + middleware.ServeHTTP(rec, req) + + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "*" { + t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "*") + } +} + +func TestCORSPreflight(t *testing.T) { + middleware := newCORSMiddleware([]string{"http://localhost"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("next handler should not be called for OPTIONS") + })) + + req := httptest.NewRequest("OPTIONS", "/", nil) + req.Header.Set("Origin", "http://localhost") + rec := httptest.NewRecorder() + middleware.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Errorf("OPTIONS status = %d, want %d", rec.Code, http.StatusNoContent) + } + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost" { + t.Errorf("Access-Control-Allow-Origin = %q, want %q", got, "http://localhost") + } +} + +func TestDefaultAllowedOrigins(t *testing.T) { + origins := DefaultAllowedOrigins() + if len(origins) == 0 { + t.Fatal("DefaultAllowedOrigins should return at least one origin") + } + + expected := map[string]bool{ + "http://localhost": true, + "https://localhost": true, + "http://127.0.0.1": true, + "https://127.0.0.1": true, + "http://[::1]": true, + "https://[::1]": true, + } + for _, o := range origins { + if !expected[o] { + t.Errorf("unexpected origin in defaults: %q", o) + } + } +} + +func TestIsOriginAllowed(t *testing.T) { + allowed := []string{"http://localhost", "https://app.example.com"} + + tests := []struct { + origin string + want bool + }{ + {"http://localhost", true}, + {"http://localhost:3000", true}, + {"https://app.example.com", true}, + {"https://evil.com", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.origin, func(t *testing.T) { + if got := isOriginAllowed(tt.origin, allowed); got != tt.want { + t.Errorf("isOriginAllowed(%q) = %v, want %v", tt.origin, got, tt.want) + } + }) + } +} + +func TestSecurityHeadersOnErrorResponses(t *testing.T) { + handler := securityHeadersMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } + if got := rec.Header().Get("X-Content-Type-Options"); got != "nosniff" { + t.Errorf("X-Content-Type-Options = %q on 401 response", got) + } +} diff --git a/forge-core/security/domain_matcher.go b/forge-core/security/domain_matcher.go index a74e4cc..bb1a373 100644 --- a/forge-core/security/domain_matcher.go +++ b/forge-core/security/domain_matcher.go @@ -70,10 +70,16 @@ func (m *DomainMatcher) Mode() EgressMode { } // IsLocalhost returns true for loopback addresses. +// It uses strict IPv4 parsing to prevent octal/hex bypass (e.g. 0177.0.0.1). func IsLocalhost(host string) bool { if host == "localhost" { return true } + // Strict IPv4 check prevents octal/hex bypass + if ip4 := ParseStrictIPv4(host); ip4 != nil { + return ip4.IsLoopback() + } + // IPv6 only (To4() == nil ensures we don't re-check IPv4) ip := net.ParseIP(host) - return ip != nil && ip.IsLoopback() + return ip != nil && ip.To4() == nil && ip.IsLoopback() } diff --git a/forge-core/security/domain_matcher_test.go b/forge-core/security/domain_matcher_test.go index ff5ac40..5ed3ed6 100644 --- a/forge-core/security/domain_matcher_test.go +++ b/forge-core/security/domain_matcher_test.go @@ -95,6 +95,12 @@ func TestIsLocalhostExported(t *testing.T) { {"example.com", false}, {"192.168.1.1", false}, {"10.0.0.1", false}, + // SSRF bypass vectors — must NOT be recognized as localhost + {"0177.0.0.1", false}, // octal + {"0x7f.0.0.1", false}, // hex + {"0x7f000001", false}, // hex packed + {"2130706433", false}, // decimal packed + {"127.0.0.01", false}, // leading zero } for _, tt := range tests { diff --git a/forge-core/security/egress_enforcer.go b/forge-core/security/egress_enforcer.go index 6c7ea8b..114b1ac 100644 --- a/forge-core/security/egress_enforcer.go +++ b/forge-core/security/egress_enforcer.go @@ -13,22 +13,24 @@ type egressClientKey struct{} // EgressEnforcer is an http.RoundTripper that validates outbound requests // against a domain allowlist before forwarding them to the base transport. type EgressEnforcer struct { - base http.RoundTripper - matcher *DomainMatcher - OnAttempt func(ctx context.Context, domain string, allowed bool) + base http.RoundTripper + matcher *DomainMatcher + AllowPrivateIPs bool + OnAttempt func(ctx context.Context, domain string, allowed bool) } // NewEgressEnforcer creates a new EgressEnforcer wrapping the given base transport. -// If base is nil, http.DefaultTransport is used. Domains may include wildcard -// prefixes (e.g. "*.github.com") which match any subdomain. -func NewEgressEnforcer(base http.RoundTripper, mode EgressMode, domains []string) *EgressEnforcer { +// If base is nil, a SafeTransport is used instead of http.DefaultTransport. +// Domains may include wildcard prefixes (e.g. "*.github.com") which match any subdomain. +func NewEgressEnforcer(base http.RoundTripper, mode EgressMode, domains []string, allowPrivateIPs bool) *EgressEnforcer { if base == nil { - base = http.DefaultTransport + base = NewSafeTransport(nil, allowPrivateIPs) } return &EgressEnforcer{ - base: base, - matcher: NewDomainMatcher(mode, domains), + base: base, + matcher: NewDomainMatcher(mode, domains), + AllowPrivateIPs: allowPrivateIPs, } } @@ -39,12 +41,21 @@ func (e *EgressEnforcer) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() - // Localhost is always allowed. + // Reject non-standard IP formats (octal, hex, packed decimal) early. + if err := ValidateHostIP(host); err != nil { + if e.OnAttempt != nil { + e.OnAttempt(ctx, host, false) + } + return nil, fmt.Errorf("egress blocked: %w", err) + } + + // Localhost is always allowed. Use http.DefaultTransport to bypass the + // safe dialer (which blocks loopback IPs for DNS rebinding protection). if IsLocalhost(host) { if e.OnAttempt != nil { e.OnAttempt(ctx, host, true) } - return e.base.RoundTrip(req) + return http.DefaultTransport.RoundTrip(req) } allowed := e.matcher.IsAllowed(host) diff --git a/forge-core/security/egress_enforcer_test.go b/forge-core/security/egress_enforcer_test.go index 67b1688..bf1c5fe 100644 --- a/forge-core/security/egress_enforcer_test.go +++ b/forge-core/security/egress_enforcer_test.go @@ -92,7 +92,7 @@ func TestEgressEnforcerAllowlist(t *testing.T) { })) defer ts.Close() - enforcer := NewEgressEnforcer(http.DefaultTransport, ModeAllowlist, tt.domains) + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeAllowlist, tt.domains, false) req, err := http.NewRequest("GET", tt.url, nil) if err != nil { @@ -119,7 +119,7 @@ func TestEgressEnforcerAllowlist(t *testing.T) { } func TestEgressEnforcerDenyAll(t *testing.T) { - enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDenyAll, nil) + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDenyAll, nil, false) req, _ := http.NewRequest("GET", "https://api.openai.com/v1/chat", nil) _, err := enforcer.RoundTrip(req) @@ -134,7 +134,7 @@ func TestEgressEnforcerDenyAllAllowsLocalhost(t *testing.T) { })) defer ts.Close() - enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDenyAll, nil) + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDenyAll, nil, false) req, _ := http.NewRequest("GET", ts.URL+"/test", nil) resp, err := enforcer.RoundTrip(req) @@ -150,7 +150,7 @@ func TestEgressEnforcerDevOpen(t *testing.T) { })) defer ts.Close() - enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDevOpen, nil) + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDevOpen, nil, false) req, _ := http.NewRequest("GET", ts.URL+"/test", nil) resp, err := enforcer.RoundTrip(req) @@ -167,7 +167,7 @@ func TestEgressEnforcerOnAttemptCallback(t *testing.T) { allowed bool } - enforcer := NewEgressEnforcer(http.DefaultTransport, ModeAllowlist, []string{"api.openai.com"}) + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeAllowlist, []string{"api.openai.com"}, false) enforcer.OnAttempt = func(_ context.Context, domain string, allowed bool) { mu.Lock() calls = append(calls, struct { @@ -201,7 +201,7 @@ func TestEgressEnforcerOnAttemptCallback(t *testing.T) { func TestEgressEnforcerDevOpenCallback(t *testing.T) { var called bool - enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDevOpen, nil) + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDevOpen, nil, false) enforcer.OnAttempt = func(_ context.Context, domain string, allowed bool) { called = true if !allowed { @@ -256,9 +256,9 @@ func TestEgressContextMissing(t *testing.T) { } func TestEgressEnforcerNilBase(t *testing.T) { - enforcer := NewEgressEnforcer(nil, ModeAllowlist, []string{"example.com"}) + enforcer := NewEgressEnforcer(nil, ModeAllowlist, []string{"example.com"}, false) if enforcer.base == nil { - t.Error("nil base should be replaced with http.DefaultTransport") + t.Error("nil base should be replaced with SafeTransport") } } @@ -282,3 +282,31 @@ func TestIsLocalhost(t *testing.T) { }) } } + +func TestEgressEnforcerSSRFBypass(t *testing.T) { + // Non-standard IP formats should be blocked by ValidateHostIP + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeAllowlist, []string{"example.com"}, false) + + tests := []struct { + name string + url string + }{ + {"octal loopback", "http://0177.0.0.1/secret"}, + {"hex loopback", "http://0x7f000001/secret"}, + {"packed decimal", "http://2130706433/secret"}, + {"leading zero", "http://127.0.0.01/secret"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", tt.url, nil) + _, err := enforcer.RoundTrip(req) + if err == nil { + t.Error("expected SSRF bypass to be blocked") + } + if err != nil && !strings.Contains(err.Error(), "egress blocked") { + t.Errorf("expected 'egress blocked' error, got: %v", err) + } + }) + } +} diff --git a/forge-core/security/egress_integration_test.go b/forge-core/security/egress_integration_test.go index 1551a56..f0d263c 100644 --- a/forge-core/security/egress_integration_test.go +++ b/forge-core/security/egress_integration_test.go @@ -20,7 +20,7 @@ func TestEgressEnforcerIntegration(t *testing.T) { defer ts.Close() // Build an enforcer that allows only localhost (test server is localhost) - enforcer := NewEgressEnforcer(http.DefaultTransport, ModeAllowlist, []string{"api.openai.com"}) + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeAllowlist, []string{"api.openai.com"}, false) var attemptLog []struct { domain string @@ -78,7 +78,7 @@ func TestEgressEnforcerDenyAllIntegration(t *testing.T) { })) defer ts.Close() - enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDenyAll, nil) + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeDenyAll, nil, false) client := &http.Client{Transport: enforcer} // Localhost should still work even in deny-all @@ -90,7 +90,7 @@ func TestEgressEnforcerDenyAllIntegration(t *testing.T) { } func TestEgressEnforcerWildcardIntegration(t *testing.T) { - enforcer := NewEgressEnforcer(http.DefaultTransport, ModeAllowlist, []string{"*.example.com"}) + enforcer := NewEgressEnforcer(http.DefaultTransport, ModeAllowlist, []string{"*.example.com"}, false) var attempts []struct { domain string diff --git a/forge-core/security/egress_proxy.go b/forge-core/security/egress_proxy.go index 6c167fc..7853cc5 100644 --- a/forge-core/security/egress_proxy.go +++ b/forge-core/security/egress_proxy.go @@ -15,18 +15,23 @@ import ( // It is used to enforce egress rules on subprocesses (skill scripts) that // cannot use the Go-level EgressEnforcer RoundTripper. type EgressProxy struct { - matcher *DomainMatcher - listener net.Listener - srv *http.Server - addr string // "127.0.0.1:" - OnAttempt func(domain string, allowed bool) + matcher *DomainMatcher + safeDialer *SafeDialer + safeTransport *http.Transport + listener net.Listener + srv *http.Server + addr string // "127.0.0.1:" + OnAttempt func(domain string, allowed bool) } // NewEgressProxy creates a new EgressProxy that validates domains using the // given DomainMatcher. Call Start to bind and begin serving. -func NewEgressProxy(matcher *DomainMatcher) *EgressProxy { +func NewEgressProxy(matcher *DomainMatcher, allowPrivateIPs bool) *EgressProxy { + sd := NewSafeDialer(nil, allowPrivateIPs) return &EgressProxy{ - matcher: matcher, + matcher: matcher, + safeDialer: sd, + safeTransport: NewSafeTransport(nil, allowPrivateIPs), } } @@ -101,7 +106,12 @@ func (p *EgressProxy) handleHTTP(w http.ResponseWriter, req *http.Request) { outReq.Header.Del("Proxy-Connection") outReq.Header.Del("Proxy-Authorization") - resp, err := http.DefaultTransport.RoundTrip(outReq) + // Use http.DefaultTransport for localhost (safe dialer blocks loopback). + var transport http.RoundTripper = p.safeTransport + if IsLocalhost(host) { + transport = http.DefaultTransport + } + resp, err := transport.RoundTrip(outReq) if err != nil { http.Error(w, "egress proxy: upstream error", http.StatusBadGateway) return @@ -128,8 +138,16 @@ func (p *EgressProxy) handleConnect(w http.ResponseWriter, req *http.Request) { return } - // Dial the upstream - upstream, err := net.DialTimeout("tcp", req.Host, 10*time.Second) + // Dial the upstream. Use safe dialer for non-localhost, standard dial for localhost + // (safe dialer blocks loopback IPs for DNS rebinding protection). + var upstream net.Conn + var err error + if IsLocalhost(host) { + upstream, err = net.DialTimeout("tcp", req.Host, 10*time.Second) + } else { + ctx := req.Context() + upstream, err = p.safeDialer.SafeDialContext(ctx, "tcp", req.Host) + } if err != nil { http.Error(w, "egress proxy: failed to connect upstream", http.StatusBadGateway) return @@ -165,6 +183,12 @@ func (p *EgressProxy) handleConnect(w http.ResponseWriter, req *http.Request) { // checkDomain validates a host against the matcher, allowing localhost always. func (p *EgressProxy) checkDomain(host string) bool { + // Reject non-standard IP formats early + if err := ValidateHostIP(host); err != nil { + p.fireCallback(host, false) + return false + } + // Localhost is always allowed if IsLocalhost(host) { p.fireCallback(host, true) diff --git a/forge-core/security/egress_proxy_test.go b/forge-core/security/egress_proxy_test.go index 2ce5d1b..3f01e61 100644 --- a/forge-core/security/egress_proxy_test.go +++ b/forge-core/security/egress_proxy_test.go @@ -24,7 +24,7 @@ func TestEgressProxyAllowedHTTP(t *testing.T) { upstreamURL, _ := url.Parse(upstream.URL) matcher := NewDomainMatcher(ModeAllowlist, []string{upstreamURL.Hostname()}) - proxy := NewEgressProxy(matcher) + proxy := NewEgressProxy(matcher, false) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -57,7 +57,7 @@ func TestEgressProxyAllowedHTTP(t *testing.T) { func TestEgressProxyBlockedHTTP(t *testing.T) { matcher := NewDomainMatcher(ModeAllowlist, []string{"allowed.com"}) - proxy := NewEgressProxy(matcher) + proxy := NewEgressProxy(matcher, false) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -90,7 +90,7 @@ func TestEgressProxyBlockedHTTP(t *testing.T) { func TestEgressProxyLocalhostAlwaysAllowed(t *testing.T) { // Even with deny-all, localhost should pass matcher := NewDomainMatcher(ModeDenyAll, nil) - proxy := NewEgressProxy(matcher) + proxy := NewEgressProxy(matcher, false) // Start a local test server upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -128,7 +128,7 @@ func TestEgressProxyLocalhostAlwaysAllowed(t *testing.T) { func TestEgressProxyCONNECTBlocked(t *testing.T) { matcher := NewDomainMatcher(ModeAllowlist, []string{"allowed.com"}) - proxy := NewEgressProxy(matcher) + proxy := NewEgressProxy(matcher, false) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -167,7 +167,7 @@ func TestEgressProxyCONNECTAllowed(t *testing.T) { host, port, _ := net.SplitHostPort(upstreamURL.Host) matcher := NewDomainMatcher(ModeAllowlist, []string{host}) - proxy := NewEgressProxy(matcher) + proxy := NewEgressProxy(matcher, false) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -209,7 +209,7 @@ func TestEgressProxyDevOpenPassthrough(t *testing.T) { // dev-open should allow everything matcher := NewDomainMatcher(ModeDevOpen, nil) - proxy := NewEgressProxy(matcher) + proxy := NewEgressProxy(matcher, false) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -248,7 +248,7 @@ func TestEgressProxyOnAttemptCallback(t *testing.T) { upstreamURL, _ := url.Parse(upstream.URL) matcher := NewDomainMatcher(ModeAllowlist, []string{upstreamURL.Hostname()}) - proxy := NewEgressProxy(matcher) + proxy := NewEgressProxy(matcher, false) var mu sync.Mutex var calls []struct { @@ -304,7 +304,7 @@ func TestEgressProxyOnAttemptCallback(t *testing.T) { func TestEgressProxyStop(t *testing.T) { matcher := NewDomainMatcher(ModeDevOpen, nil) - proxy := NewEgressProxy(matcher) + proxy := NewEgressProxy(matcher, false) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -326,7 +326,7 @@ func TestEgressProxyStop(t *testing.T) { } func TestEgressProxyURL(t *testing.T) { - proxy := NewEgressProxy(NewDomainMatcher(ModeDevOpen, nil)) + proxy := NewEgressProxy(NewDomainMatcher(ModeDevOpen, nil), false) if proxy.ProxyURL() != "" { t.Error("ProxyURL should be empty before Start") } diff --git a/forge-core/security/ip_validator.go b/forge-core/security/ip_validator.go new file mode 100644 index 0000000..cdc51ba --- /dev/null +++ b/forge-core/security/ip_validator.go @@ -0,0 +1,204 @@ +package security + +import ( + "fmt" + "net" + "regexp" + "strings" +) + +// strictIPv4Re matches exactly four dotted-decimal octets (no leading zeros, 0-255). +var strictIPv4Re = regexp.MustCompile( + `^(25[0-5]|2[0-4]\d|1\d\d|[1-9]\d|\d)\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]\d|\d)\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]\d|\d)\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]\d|\d)$`, +) + +// alwaysBlockedCIDRs are blocked regardless of allowPrivate setting. +// Cloud metadata and loopback must never be reachable. +var alwaysBlockedCIDRs []*net.IPNet + +// privateBlockedCIDRs are blocked only when allowPrivate is false. +var privateBlockedCIDRs []*net.IPNet + +func init() { + for _, cidr := range []string{ + "169.254.169.254/32", // cloud metadata endpoint + "127.0.0.0/8", // IPv4 loopback + "::1/128", // IPv6 loopback + "0.0.0.0/8", // "this" network + } { + _, n, _ := net.ParseCIDR(cidr) + alwaysBlockedCIDRs = append(alwaysBlockedCIDRs, n) + } + for _, cidr := range []string{ + "10.0.0.0/8", // RFC 1918 + "172.16.0.0/12", // RFC 1918 + "192.168.0.0/16", // RFC 1918 + "169.254.0.0/16", // link-local + "100.64.0.0/10", // CGNAT + "fc00::/7", // IPv6 ULA + "fe80::/10", // IPv6 link-local + } { + _, n, _ := net.ParseCIDR(cidr) + privateBlockedCIDRs = append(privateBlockedCIDRs, n) + } +} + +// ParseStrictIPv4 parses an IPv4 address in strict dotted-decimal notation. +// It rejects octal (0177.0.0.1), hex (0x7f.0.0.1), packed decimal (2130706433), +// and leading-zero forms (127.0.0.01). Returns nil if the input is not a valid +// strict IPv4 address. +func ParseStrictIPv4(s string) net.IP { + if !strictIPv4Re.MatchString(s) { + return nil + } + return net.ParseIP(s).To4() +} + +// IsBlockedIP checks whether an IP is in a blocked CIDR range. +// When allowPrivate is true, RFC 1918 and link-local ranges are permitted +// (for container/K8s environments), but cloud metadata and loopback are +// always blocked. Returns true (blocked) for nil IPs (fail closed). +func IsBlockedIP(ip net.IP, allowPrivate bool) bool { + if ip == nil { + return true // fail closed + } + + // Check IPv6 transition addresses that embed blocked IPv4 + if isBlockedIPv6Transition(ip, allowPrivate) { + return true + } + + for _, n := range alwaysBlockedCIDRs { + if n.Contains(ip) { + return true + } + } + + if !allowPrivate { + for _, n := range privateBlockedCIDRs { + if n.Contains(ip) { + return true + } + } + } + + return false +} + +// isBlockedIPv6Transition detects IPv6 transition addresses (NAT64, 6to4, Teredo) +// that embed blocked IPv4 addresses. +func isBlockedIPv6Transition(ip net.IP, allowPrivate bool) bool { + // Ensure we're working with a 16-byte representation + ip16 := ip.To16() + if ip16 == nil { + return false + } + // Skip if this is actually an IPv4 address (mapped or native) + if ip.To4() != nil { + return false + } + + // NAT64: 64:ff9b::/96 — embedded IPv4 in last 4 bytes + nat64Prefix := []byte{0, 0x64, 0xff, 0x9b, 0, 0, 0, 0, 0, 0, 0, 0} + if bytesEqual(ip16[:12], nat64Prefix) { + embedded := net.IP(ip16[12:16]) + return IsBlockedIP(embedded.To4(), allowPrivate) + } + + // NAT64 extended: 64:ff9b:1::/48 — embedded IPv4 in last 4 bytes + nat64ExtPrefix := []byte{0, 0x64, 0xff, 0x9b, 0, 0x01} + if bytesEqual(ip16[:6], nat64ExtPrefix) { + embedded := net.IP(ip16[12:16]) + return IsBlockedIP(embedded.To4(), allowPrivate) + } + + // 6to4: 2002::/16 — embedded IPv4 in bytes 2-5 + if ip16[0] == 0x20 && ip16[1] == 0x02 { + embedded := net.IPv4(ip16[2], ip16[3], ip16[4], ip16[5]) + return IsBlockedIP(embedded.To4(), allowPrivate) + } + + // Teredo: 2001:0000::/32 — XOR'd IPv4 in last 4 bytes + if ip16[0] == 0x20 && ip16[1] == 0x01 && ip16[2] == 0x00 && ip16[3] == 0x00 { + embedded := net.IPv4(ip16[12]^0xff, ip16[13]^0xff, ip16[14]^0xff, ip16[15]^0xff) + return IsBlockedIP(embedded.To4(), allowPrivate) + } + + return false +} + +func bytesEqual(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// looksLikeIP returns true if the string looks like it might be a non-standard +// IP representation: digit-only strings (packed decimal), 0x prefix (hex), +// or digit+dot strings that fail strict IPv4 parsing. +func looksLikeIP(s string) bool { + if s == "" { + return false + } + // Hex prefix + if strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X") { + return true + } + + hasLetter := false + hasDot := false + allDigitsOrDots := true + for _, c := range s { + if c == '.' { + hasDot = true + continue + } + if c >= '0' && c <= '9' { + continue + } + allDigitsOrDots = false + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '-' { + hasLetter = true + } + } + + // Pure digit string like "2130706433" (packed decimal) + if allDigitsOrDots && !hasDot { + return true + } + + // Digit+dot string that failed strict parse means octal or leading-zero + if allDigitsOrDots && hasDot && !hasLetter { + return ParseStrictIPv4(s) == nil + } + + return false +} + +// ValidateHostIP validates that a hostname is not using a non-standard IP format +// that could bypass security checks. It rejects octal, hex, packed decimal, and +// leading-zero IP representations. +func ValidateHostIP(host string) error { + // If it's a valid strict IPv4, it's fine (checked elsewhere for blocked ranges) + if ParseStrictIPv4(host) != nil { + return nil + } + + // If it's a standard IPv6, it's fine + if net.ParseIP(host) != nil { + return nil + } + + // Check if it looks like a non-standard IP format + if looksLikeIP(host) { + return fmt.Errorf("rejected non-standard IP format: %q", host) + } + + return nil +} diff --git a/forge-core/security/ip_validator_test.go b/forge-core/security/ip_validator_test.go new file mode 100644 index 0000000..ba703ee --- /dev/null +++ b/forge-core/security/ip_validator_test.go @@ -0,0 +1,207 @@ +package security + +import ( + "net" + "testing" +) + +func TestParseStrictIPv4(t *testing.T) { + tests := []struct { + name string + input string + valid bool + }{ + {"standard loopback", "127.0.0.1", true}, + {"standard private", "10.0.0.1", true}, + {"standard public", "8.8.8.8", true}, + {"zero", "0.0.0.0", true}, + {"max", "255.255.255.255", true}, + {"octal loopback", "0177.0.0.1", false}, + {"hex loopback", "0x7f.0.0.1", false}, + {"packed decimal", "2130706433", false}, + {"leading zero", "127.0.0.01", false}, + {"leading zero octet", "010.0.0.1", false}, + {"empty", "", false}, + {"hostname", "example.com", false}, + {"too many octets", "1.2.3.4.5", false}, + {"too few octets", "1.2.3", false}, + {"negative", "-1.0.0.1", false}, + {"overflow octet", "256.0.0.1", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ParseStrictIPv4(tt.input) + if tt.valid && got == nil { + t.Errorf("ParseStrictIPv4(%q) = nil, want valid IP", tt.input) + } + if !tt.valid && got != nil { + t.Errorf("ParseStrictIPv4(%q) = %v, want nil", tt.input, got) + } + }) + } +} + +func TestIsBlockedIP(t *testing.T) { + tests := []struct { + name string + ip string + allowPrivate bool + blocked bool + }{ + // Always blocked + {"nil IP", "", false, true}, + {"loopback", "127.0.0.1", false, true}, + {"loopback allowPrivate", "127.0.0.1", true, true}, + {"metadata", "169.254.169.254", false, true}, + {"metadata allowPrivate", "169.254.169.254", true, true}, + {"ipv6 loopback", "::1", false, true}, + {"ipv6 loopback allowPrivate", "::1", true, true}, + {"this network", "0.0.0.0", false, true}, + + // Private ranges - blocked when allowPrivate=false + {"rfc1918 10.x", "10.0.0.1", false, true}, + {"rfc1918 10.x allowPrivate", "10.0.0.1", true, false}, + {"rfc1918 172.16.x", "172.16.0.1", false, true}, + {"rfc1918 172.16.x allowPrivate", "172.16.0.1", true, false}, + {"rfc1918 192.168.x", "192.168.1.1", false, true}, + {"rfc1918 192.168.x allowPrivate", "192.168.1.1", true, false}, + {"link-local", "169.254.1.1", false, true}, + {"link-local allowPrivate", "169.254.1.1", true, false}, + {"cgnat", "100.64.0.1", false, true}, + {"cgnat allowPrivate", "100.64.0.1", true, false}, + {"ipv6 ula", "fd00::1", false, true}, + {"ipv6 ula allowPrivate", "fd00::1", true, false}, + {"ipv6 link-local", "fe80::1", false, true}, + {"ipv6 link-local allowPrivate", "fe80::1", true, false}, + + // Public IPs - never blocked + {"public ipv4", "8.8.8.8", false, false}, + {"public ipv4 allowPrivate", "8.8.8.8", true, false}, + {"public ipv6", "2607:f8b0:4004:800::200e", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ip net.IP + if tt.ip != "" { + ip = net.ParseIP(tt.ip) + } + got := IsBlockedIP(ip, tt.allowPrivate) + if got != tt.blocked { + t.Errorf("IsBlockedIP(%v, %v) = %v, want %v", ip, tt.allowPrivate, got, tt.blocked) + } + }) + } +} + +func TestIsBlockedIPv6Transition(t *testing.T) { + tests := []struct { + name string + ip string + allowPrivate bool + blocked bool + }{ + // NAT64: 64:ff9b:: — loopback embedded + {"nat64 loopback", "64:ff9b::127.0.0.1", false, true}, + {"nat64 loopback allowPrivate", "64:ff9b::127.0.0.1", true, true}, + // NAT64 with private + {"nat64 private", "64:ff9b::10.0.0.1", false, true}, + {"nat64 private allowPrivate", "64:ff9b::10.0.0.1", true, false}, + // NAT64 with public + {"nat64 public", "64:ff9b::8.8.8.8", false, false}, + // NAT64 with metadata + {"nat64 metadata", "64:ff9b::169.254.169.254", false, true}, + {"nat64 metadata allowPrivate", "64:ff9b::169.254.169.254", true, true}, + + // 6to4: 2002::: + {"6to4 loopback", "2002:7f00:0001::", false, true}, + {"6to4 private", "2002:0a00:0001::", false, true}, + {"6to4 private allowPrivate", "2002:0a00:0001::", true, false}, + {"6to4 public", "2002:0808:0808::", false, false}, + {"6to4 metadata", "2002:a9fe:a9fe::", false, true}, + + // Teredo: 2001:0000:::: + // Teredo XORs client IPv4 with 0xFFFFFFFF + // 127.0.0.1 XOR'd = 0x80ffff fe = (128.255.255.254) + {"teredo loopback", "2001:0000:4136:e378:8000:63bf:80ff:fffe", false, true}, + // 10.0.0.1 XOR'd = 0xf5fffffe = (245.255.255.254) + {"teredo private", "2001:0000:4136:e378:8000:63bf:f5ff:fffe", false, true}, + {"teredo private allowPrivate", "2001:0000:4136:e378:8000:63bf:f5ff:fffe", true, false}, + + // Regular IPv6 — not transition + {"regular ipv6", "2607:f8b0:4004:800::200e", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("failed to parse IP %q", tt.ip) + } + got := IsBlockedIP(ip, tt.allowPrivate) + if got != tt.blocked { + t.Errorf("IsBlockedIP(%v, %v) = %v, want %v", ip, tt.allowPrivate, got, tt.blocked) + } + }) + } +} + +func TestValidateHostIP(t *testing.T) { + tests := []struct { + name string + host string + wantErr bool + }{ + {"standard ipv4", "127.0.0.1", false}, + {"public ipv4", "8.8.8.8", false}, + {"hostname", "example.com", false}, + {"ipv6", "::1", false}, + {"octal", "0177.0.0.1", true}, + {"hex", "0x7f000001", true}, + {"packed decimal", "2130706433", true}, + {"leading zero", "127.0.0.01", true}, + {"leading zero octet", "010.0.0.1", true}, + {"empty", "", false}, + {"subdomain", "api.example.com", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateHostIP(tt.host) + if tt.wantErr && err == nil { + t.Errorf("ValidateHostIP(%q) = nil, want error", tt.host) + } + if !tt.wantErr && err != nil { + t.Errorf("ValidateHostIP(%q) = %v, want nil", tt.host, err) + } + }) + } +} + +func TestLooksLikeIP(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"2130706433", true}, // packed decimal + {"0x7f000001", true}, // hex + {"0X7F000001", true}, // hex uppercase + {"0177.0.0.1", true}, // octal-looking + {"127.0.0.01", true}, // leading zero + {"127.0.0.1", false}, // valid strict IPv4 — not "suspicious" + {"example.com", false}, // hostname + {"", false}, // empty + {"10.0.0.1", false}, // valid strict IPv4 + {"api.example.com", false}, // hostname with dots + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := looksLikeIP(tt.input) + if got != tt.want { + t.Errorf("looksLikeIP(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} diff --git a/forge-core/security/redirect.go b/forge-core/security/redirect.go new file mode 100644 index 0000000..039e4f9 --- /dev/null +++ b/forge-core/security/redirect.go @@ -0,0 +1,63 @@ +package security + +import ( + "fmt" + "net/http" + "strings" +) + +// SafeRedirectPolicy returns a CheckRedirect function that strips sensitive +// credentials (Authorization, Cookie, etc.) when a redirect crosses origin +// boundaries (different scheme, host, or port from the original request). +func SafeRedirectPolicy(maxRedirects int) func(*http.Request, []*http.Request) error { + return func(req *http.Request, via []*http.Request) error { + if len(via) >= maxRedirects { + return fmt.Errorf("stopped after %d redirects", maxRedirects) + } + + // Compare against the original (first) request + original := via[0] + if !isSameOrigin(original, req) { + req.Header.Del("Authorization") + req.Header.Del("Proxy-Authorization") + req.Header.Del("Cookie") + req.Header.Del("Cookie2") + } + + return nil + } +} + +// isSameOrigin returns true if two requests share the same scheme, host, and port. +func isSameOrigin(a, b *http.Request) bool { + aScheme := strings.ToLower(a.URL.Scheme) + bScheme := strings.ToLower(b.URL.Scheme) + if aScheme != bScheme { + return false + } + + aHost := strings.ToLower(a.URL.Hostname()) + bHost := strings.ToLower(b.URL.Hostname()) + if aHost != bHost { + return false + } + + aPort := effectivePort(a.URL.Port(), aScheme) + bPort := effectivePort(b.URL.Port(), bScheme) + return aPort == bPort +} + +// effectivePort returns the port or the default for the scheme. +func effectivePort(port, scheme string) string { + if port != "" { + return port + } + switch scheme { + case "https": + return "443" + case "http": + return "80" + default: + return "" + } +} diff --git a/forge-core/security/redirect_test.go b/forge-core/security/redirect_test.go new file mode 100644 index 0000000..9d0c87d --- /dev/null +++ b/forge-core/security/redirect_test.go @@ -0,0 +1,147 @@ +package security + +import ( + "net/http" + "testing" +) + +func TestSafeRedirectPolicy(t *testing.T) { + policy := SafeRedirectPolicy(10) + + tests := []struct { + name string + originalURL string + redirectURL string + wantAuthStrip bool + wantCookieStrip bool + }{ + { + name: "same origin preserves headers", + originalURL: "https://api.example.com/v1/data", + redirectURL: "https://api.example.com/v1/data2", + wantAuthStrip: false, + wantCookieStrip: false, + }, + { + name: "different host strips headers", + originalURL: "https://api.example.com/v1/data", + redirectURL: "https://evil.com/capture", + wantAuthStrip: true, + wantCookieStrip: true, + }, + { + name: "different scheme strips headers", + originalURL: "https://api.example.com/v1/data", + redirectURL: "http://api.example.com/v1/data", + wantAuthStrip: true, + wantCookieStrip: true, + }, + { + name: "different port strips headers", + originalURL: "https://api.example.com/v1/data", + redirectURL: "https://api.example.com:8443/v1/data", + wantAuthStrip: true, + wantCookieStrip: true, + }, + { + name: "implicit port matches explicit 443", + originalURL: "https://api.example.com/v1/data", + redirectURL: "https://api.example.com:443/v1/data", + wantAuthStrip: false, + wantCookieStrip: false, + }, + { + name: "implicit port matches explicit 80", + originalURL: "http://api.example.com/v1/data", + redirectURL: "http://api.example.com:80/v1/data", + wantAuthStrip: false, + wantCookieStrip: false, + }, + { + name: "case insensitive host match", + originalURL: "https://API.Example.COM/v1/data", + redirectURL: "https://api.example.com/v1/data", + wantAuthStrip: false, + wantCookieStrip: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original, _ := http.NewRequest("GET", tt.originalURL, nil) + redirect, _ := http.NewRequest("GET", tt.redirectURL, nil) + redirect.Header.Set("Authorization", "Bearer secret") + redirect.Header.Set("Proxy-Authorization", "Basic creds") + redirect.Header.Set("Cookie", "session=abc") + redirect.Header.Set("Cookie2", "old=val") + + err := policy(redirect, []*http.Request{original}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + hasAuth := redirect.Header.Get("Authorization") != "" + hasCookie := redirect.Header.Get("Cookie") != "" + + if tt.wantAuthStrip && hasAuth { + t.Error("expected Authorization header to be stripped") + } + if !tt.wantAuthStrip && !hasAuth { + t.Error("expected Authorization header to be preserved") + } + if tt.wantCookieStrip && hasCookie { + t.Error("expected Cookie header to be stripped") + } + if !tt.wantCookieStrip && !hasCookie { + t.Error("expected Cookie header to be preserved") + } + + // Also check Proxy-Authorization and Cookie2 + if tt.wantAuthStrip && redirect.Header.Get("Proxy-Authorization") != "" { + t.Error("expected Proxy-Authorization header to be stripped") + } + if tt.wantCookieStrip && redirect.Header.Get("Cookie2") != "" { + t.Error("expected Cookie2 header to be stripped") + } + }) + } +} + +func TestSafeRedirectPolicyMaxRedirects(t *testing.T) { + policy := SafeRedirectPolicy(2) + + original, _ := http.NewRequest("GET", "https://example.com/1", nil) + second, _ := http.NewRequest("GET", "https://example.com/2", nil) + third, _ := http.NewRequest("GET", "https://example.com/3", nil) + + // 2 via requests means we're at redirect #3, which exceeds max of 2 + err := policy(third, []*http.Request{original, second}) + if err == nil { + t.Error("expected error for exceeding max redirects") + } +} + +func TestIsSameOrigin(t *testing.T) { + tests := []struct { + name string + a, b string + want bool + }{ + {"same", "https://example.com/a", "https://example.com/b", true}, + {"diff host", "https://a.com/x", "https://b.com/x", false}, + {"diff scheme", "https://a.com/x", "http://a.com/x", false}, + {"diff port", "https://a.com/x", "https://a.com:8443/x", false}, + {"implicit 443", "https://a.com/x", "https://a.com:443/x", true}, + {"implicit 80", "http://a.com/x", "http://a.com:80/x", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a, _ := http.NewRequest("GET", tt.a, nil) + b, _ := http.NewRequest("GET", tt.b, nil) + if got := isSameOrigin(a, b); got != tt.want { + t.Errorf("isSameOrigin(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.want) + } + }) + } +} diff --git a/forge-core/security/safe_dialer.go b/forge-core/security/safe_dialer.go new file mode 100644 index 0000000..ae03d7a --- /dev/null +++ b/forge-core/security/safe_dialer.go @@ -0,0 +1,92 @@ +package security + +import ( + "context" + "fmt" + "net" + "net/http" + "time" +) + +// Resolver abstracts DNS resolution for testability. +type Resolver interface { + LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) +} + +// SafeDialer validates resolved IPs before establishing connections, +// preventing DNS rebinding and SSRF via post-resolution checks. +type SafeDialer struct { + resolver Resolver + dialer net.Dialer + allowPrivate bool +} + +// NewSafeDialer creates a SafeDialer. If resolver is nil, net.DefaultResolver +// is used. Set allowPrivateIPs to true in container environments where RFC 1918 +// addresses are used for inter-service communication. +func NewSafeDialer(resolver Resolver, allowPrivateIPs bool) *SafeDialer { + if resolver == nil { + resolver = net.DefaultResolver + } + return &SafeDialer{ + resolver: resolver, + dialer: net.Dialer{Timeout: 10 * time.Second}, + allowPrivate: allowPrivateIPs, + } +} + +// SafeDialContext resolves the address, validates all resulting IPs against +// blocked ranges, then dials the first safe IP directly to avoid TOCTOU +// re-resolution. +func (s *SafeDialer) SafeDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("safe dialer: invalid address %q: %w", addr, err) + } + + // Check for non-standard IP formats first + if err := ValidateHostIP(host); err != nil { + return nil, fmt.Errorf("safe dialer: %w", err) + } + + // If it's an IP literal, validate and dial directly + if ip := net.ParseIP(host); ip != nil { + if IsBlockedIP(ip, s.allowPrivate) { + return nil, fmt.Errorf("safe dialer: blocked IP %s", ip) + } + return s.dialer.DialContext(ctx, network, addr) + } + + // Resolve hostname + addrs, err := s.resolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("safe dialer: DNS lookup failed for %q: %w", host, err) + } + if len(addrs) == 0 { + return nil, fmt.Errorf("safe dialer: no addresses found for %q", host) + } + + // Validate ALL resolved IPs — reject if ANY is blocked + for _, a := range addrs { + if IsBlockedIP(a.IP, s.allowPrivate) { + return nil, fmt.Errorf("safe dialer: domain %q resolved to blocked IP %s", host, a.IP) + } + } + + // Dial the first safe IP directly (avoids TOCTOU re-resolution) + directAddr := net.JoinHostPort(addrs[0].IP.String(), port) + return s.dialer.DialContext(ctx, network, directAddr) +} + +// NewSafeTransport creates an http.Transport that uses SafeDialer for all +// connections. If resolver is nil, net.DefaultResolver is used. +func NewSafeTransport(resolver Resolver, allowPrivateIPs bool) *http.Transport { + sd := NewSafeDialer(resolver, allowPrivateIPs) + return &http.Transport{ + DialContext: sd.SafeDialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} diff --git a/forge-core/security/safe_dialer_test.go b/forge-core/security/safe_dialer_test.go new file mode 100644 index 0000000..b1df1c2 --- /dev/null +++ b/forge-core/security/safe_dialer_test.go @@ -0,0 +1,161 @@ +package security + +import ( + "context" + "net" + "strings" + "testing" +) + +// mockResolver implements Resolver for testing. +type mockResolver struct { + addrs map[string][]net.IPAddr + err error +} + +func (m *mockResolver) LookupIPAddr(_ context.Context, host string) ([]net.IPAddr, error) { + if m.err != nil { + return nil, m.err + } + if addrs, ok := m.addrs[host]; ok { + return addrs, nil + } + return nil, &net.DNSError{Err: "no such host", Name: host, IsNotFound: true} +} + +func TestSafeDialerBlocksPrivateResolution(t *testing.T) { + resolver := &mockResolver{ + addrs: map[string][]net.IPAddr{ + "internal.example.com": {{IP: net.ParseIP("10.0.0.1")}}, + }, + } + + sd := NewSafeDialer(resolver, false) + _, err := sd.SafeDialContext(context.Background(), "tcp", "internal.example.com:80") + if err == nil { + t.Fatal("expected error for domain resolving to private IP") + } + if !strings.Contains(err.Error(), "blocked IP") { + t.Errorf("expected 'blocked IP' error, got: %v", err) + } +} + +func TestSafeDialerAllowsPrivateWhenConfigured(t *testing.T) { + resolver := &mockResolver{ + addrs: map[string][]net.IPAddr{ + "service.cluster.local": {{IP: net.ParseIP("10.96.0.1")}}, + }, + } + + sd := NewSafeDialer(resolver, true) + // This will fail at the dial stage (no actual service), but should + // pass IP validation + _, err := sd.SafeDialContext(context.Background(), "tcp", "service.cluster.local:80") + if err == nil { + return // unexpectedly connected + } + if strings.Contains(err.Error(), "blocked IP") { + t.Errorf("should allow private IPs when configured, got: %v", err) + } +} + +func TestSafeDialerBlocksMetadataAlways(t *testing.T) { + resolver := &mockResolver{ + addrs: map[string][]net.IPAddr{ + "metadata.internal": {{IP: net.ParseIP("169.254.169.254")}}, + }, + } + + // Even with allowPrivate=true, metadata must be blocked + sd := NewSafeDialer(resolver, true) + _, err := sd.SafeDialContext(context.Background(), "tcp", "metadata.internal:80") + if err == nil { + t.Fatal("expected error for metadata IP") + } + if !strings.Contains(err.Error(), "blocked IP") { + t.Errorf("expected 'blocked IP' error, got: %v", err) + } +} + +func TestSafeDialerBlocksLoopbackAlways(t *testing.T) { + resolver := &mockResolver{ + addrs: map[string][]net.IPAddr{ + "loopback.example.com": {{IP: net.ParseIP("127.0.0.1")}}, + }, + } + + sd := NewSafeDialer(resolver, true) + _, err := sd.SafeDialContext(context.Background(), "tcp", "loopback.example.com:80") + if err == nil { + t.Fatal("expected error for loopback IP") + } + if !strings.Contains(err.Error(), "blocked IP") { + t.Errorf("expected 'blocked IP' error, got: %v", err) + } +} + +func TestSafeDialerBlocksMixedIPs(t *testing.T) { + resolver := &mockResolver{ + addrs: map[string][]net.IPAddr{ + "mixed.example.com": { + {IP: net.ParseIP("8.8.8.8")}, + {IP: net.ParseIP("10.0.0.1")}, + }, + }, + } + + sd := NewSafeDialer(resolver, false) + _, err := sd.SafeDialContext(context.Background(), "tcp", "mixed.example.com:80") + if err == nil { + t.Fatal("expected error when any resolved IP is blocked") + } + if !strings.Contains(err.Error(), "blocked IP") { + t.Errorf("expected 'blocked IP' error, got: %v", err) + } +} + +func TestSafeDialerDNSFailure(t *testing.T) { + resolver := &mockResolver{ + addrs: map[string][]net.IPAddr{}, + } + + sd := NewSafeDialer(resolver, false) + _, err := sd.SafeDialContext(context.Background(), "tcp", "nonexistent.example.com:80") + if err == nil { + t.Fatal("expected error for DNS failure") + } +} + +func TestSafeDialerIPLiteral(t *testing.T) { + sd := NewSafeDialer(nil, false) + + // Blocked IP literal + _, err := sd.SafeDialContext(context.Background(), "tcp", "169.254.169.254:80") + if err == nil { + t.Fatal("expected error for blocked IP literal") + } + if !strings.Contains(err.Error(), "blocked IP") { + t.Errorf("expected 'blocked IP' error, got: %v", err) + } +} + +func TestSafeDialerRejectsNonStandardIP(t *testing.T) { + sd := NewSafeDialer(nil, false) + + _, err := sd.SafeDialContext(context.Background(), "tcp", "0x7f000001:80") + if err == nil { + t.Fatal("expected error for hex IP") + } + if !strings.Contains(err.Error(), "non-standard IP") { + t.Errorf("expected 'non-standard IP' error, got: %v", err) + } +} + +func TestSafeDialerInvalidAddress(t *testing.T) { + sd := NewSafeDialer(nil, false) + + _, err := sd.SafeDialContext(context.Background(), "tcp", "no-port") + if err == nil { + t.Fatal("expected error for invalid address") + } +} diff --git a/forge-core/security/types.go b/forge-core/security/types.go index ccfe928..b252613 100644 --- a/forge-core/security/types.go +++ b/forge-core/security/types.go @@ -21,9 +21,10 @@ const ( // EgressConfig holds the resolved egress configuration. type EgressConfig struct { - Profile EgressProfile `json:"profile"` - Mode EgressMode `json:"mode"` - AllowedDomains []string `json:"allowed_domains,omitempty"` // explicit user domains - ToolDomains []string `json:"tool_domains,omitempty"` // inferred from tools - AllDomains []string `json:"all_domains,omitempty"` // deduplicated union + Profile EgressProfile `json:"profile"` + Mode EgressMode `json:"mode"` + AllowedDomains []string `json:"allowed_domains,omitempty"` // explicit user domains + ToolDomains []string `json:"tool_domains,omitempty"` // inferred from tools + AllDomains []string `json:"all_domains,omitempty"` // deduplicated union + AllowPrivateIPs bool `json:"allow_private_ips,omitempty"` } diff --git a/forge-core/tools/adapters/webhook_call.go b/forge-core/tools/adapters/webhook_call.go index ea9a74b..2a65853 100644 --- a/forge-core/tools/adapters/webhook_call.go +++ b/forge-core/tools/adapters/webhook_call.go @@ -55,8 +55,9 @@ func (t *webhookCallTool) Execute(ctx context.Context, args json.RawMessage) (st } client := &http.Client{ - Transport: security.EgressTransportFromContext(ctx), - Timeout: 30 * time.Second, + Transport: security.EgressTransportFromContext(ctx), + Timeout: 30 * time.Second, + CheckRedirect: security.SafeRedirectPolicy(10), } resp, err := client.Do(req) if err != nil { diff --git a/forge-core/tools/builtins/http_request.go b/forge-core/tools/builtins/http_request.go index 183e543..c597b4d 100644 --- a/forge-core/tools/builtins/http_request.go +++ b/forge-core/tools/builtins/http_request.go @@ -68,8 +68,9 @@ func (t *httpRequestTool) Execute(ctx context.Context, args json.RawMessage) (st } client := &http.Client{ - Transport: security.EgressTransportFromContext(ctx), - Timeout: timeout, + Transport: security.EgressTransportFromContext(ctx), + Timeout: timeout, + CheckRedirect: security.SafeRedirectPolicy(10), } resp, err := client.Do(req) if err != nil { diff --git a/forge-core/types/config.go b/forge-core/types/config.go index 3ff5a6e..3593bea 100644 --- a/forge-core/types/config.go +++ b/forge-core/types/config.go @@ -23,6 +23,7 @@ type ForgeConfig struct { Memory MemoryConfig `yaml:"memory,omitempty"` Secrets SecretsConfig `yaml:"secrets,omitempty"` Schedules []ScheduleConfig `yaml:"schedules,omitempty"` + CORSOrigins []string `yaml:"cors_origins,omitempty"` } // ScheduleConfig defines a recurring scheduled task in forge.yaml. @@ -61,10 +62,11 @@ type MemoryConfig struct { // EgressRef configures egress security controls. type EgressRef struct { - Profile string `yaml:"profile,omitempty"` // strict, standard, permissive - Mode string `yaml:"mode,omitempty"` // deny-all, allowlist, dev-open - AllowedDomains []string `yaml:"allowed_domains,omitempty"` - Capabilities []string `yaml:"capabilities,omitempty"` // capability bundles (e.g., "slack", "telegram") + Profile string `yaml:"profile,omitempty"` // strict, standard, permissive + Mode string `yaml:"mode,omitempty"` // deny-all, allowlist, dev-open + AllowedDomains []string `yaml:"allowed_domains,omitempty"` + Capabilities []string `yaml:"capabilities,omitempty"` // capability bundles (e.g., "slack", "telegram") + AllowPrivateIPs *bool `yaml:"allow_private_ips,omitempty"` } // SkillsRef references a skills definition file. From 29544eeef407d50aa71d0b1d51b1016c57961afd Mon Sep 17 00:00:00 2001 From: MK Date: Thu, 19 Mar 2026 19:28:35 -0700 Subject: [PATCH 2/2] docs: update security docs for Phase 1 critical fixes Sync documentation to reflect IP validation, SafeDialer, CORS restriction, security headers, redirect credential stripping, and container-aware allowPrivateIPs across 6 doc files. --- docs/architecture.md | 14 +++++++- docs/commands.md | 5 +++ docs/configuration.md | 5 +++ docs/security/egress.md | 76 +++++++++++++++++++++++++++++++++++++-- docs/security/overview.md | 26 +++++++++----- docs/tools.md | 4 +-- 6 files changed, 117 insertions(+), 13 deletions(-) diff --git a/docs/architecture.md b/docs/architecture.md index 3c8f692..e14341d 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -344,7 +344,19 @@ AgentSpec JSON is validated against `schemas/agentspec.v1.0.schema.json` (JSON S ## Egress Security -Egress controls operate at both build time and runtime. Build-time controls generate allowlist artifacts and Kubernetes NetworkPolicy manifests. Runtime controls include an in-process `EgressEnforcer` (Go `http.RoundTripper`) and a local `EgressProxy` for subprocess HTTP traffic. See [Egress Security](security/egress.md) for details. +Egress controls operate at both build time and runtime. Build-time controls generate allowlist artifacts and Kubernetes NetworkPolicy manifests. Runtime controls include: + +- **IP Validation** — Rejects non-standard IP formats (octal, hex, packed decimal) and IPv6 transition addresses embedding private IPs +- **SafeDialer** — Validates resolved IPs post-DNS against blocked CIDR ranges before connecting (prevents DNS rebinding) +- **EgressEnforcer** — In-process `http.RoundTripper` backed by `SafeTransport` for domain allowlist enforcement +- **EgressProxy** — Local HTTP/HTTPS forward proxy for subprocess traffic, also backed by `SafeDialer` +- **Redirect credential stripping** — `http_request` and `webhook_call` strip `Authorization`/`Cookie` headers on cross-origin redirects + +The A2A server adds: +- **CORS restriction** — Origin allowlist (localhost by default), configurable via flag/env/YAML +- **Security headers** — `X-Content-Type-Options`, `Referrer-Policy`, `X-Frame-Options`, `Content-Security-Policy` + +See [Egress Security](security/egress.md) for details. --- ← [Installation](installation.md) | [Back to README](../README.md) | [Skills](skills.md) → diff --git a/docs/commands.md b/docs/commands.md index ab9d022..13a47e0 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -147,6 +147,7 @@ forge run [flags] | `--provider` | | LLM provider: `openai`, `anthropic`, or `ollama` | | `--env` | `.env` | Path to .env file | | `--with` | | Comma-separated channel adapters (e.g., `slack,telegram`) | +| `--cors-origins` | localhost | Comma-separated CORS allowed origins (e.g., `https://app.example.com,https://admin.example.com`). Use `*` to allow all origins | ### Examples @@ -165,6 +166,9 @@ forge run --host 0.0.0.0 --shutdown-timeout 30s # Run with guardrails enforced forge run --enforce-guardrails --env .env.production + +# Run with custom CORS origins (for K8s ingress) +forge run --cors-origins 'https://app.example.com,https://admin.example.com' ``` --- @@ -193,6 +197,7 @@ forge serve [start|stop|status|logs] [flags] | `--port` | `8080` | HTTP server port | | `--host` | `127.0.0.1` | Bind address (secure default) | | `--with` | | Channel adapters | +| `--cors-origins` | localhost | Comma-separated CORS allowed origins | ### Examples diff --git a/docs/configuration.md b/docs/configuration.md index 8c473d5..3204bcf 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -41,6 +41,10 @@ egress: - "*.github.com" capabilities: # Capability bundles - "slack" + allow_private_ips: false # Allow RFC 1918 IPs (auto: true in containers) + +cors_origins: # CORS allowed origins for A2A server + - "https://app.example.com" # (default: localhost variants) skills: path: "SKILL.md" @@ -91,6 +95,7 @@ schedules: # Recurring scheduled tasks (optional) | `OPENAI_BASE_URL` | Override OpenAI base URL | | `ANTHROPIC_BASE_URL` | Override Anthropic base URL | | `OLLAMA_BASE_URL` | Override Ollama base URL (default: `http://localhost:11434`) | +| `FORGE_CORS_ORIGINS` | Comma-separated CORS allowed origins for A2A server | | `FORGE_PASSPHRASE` | Passphrase for encrypted secrets file | --- diff --git a/docs/security/egress.md b/docs/security/egress.md index b4971ec..12cd742 100644 --- a/docs/security/egress.md +++ b/docs/security/egress.md @@ -45,15 +45,81 @@ Domain matching is handled by `DomainMatcher` (`forge-core/security/domain_match - **Case insensitive**: `API.OpenAI.COM` matches `api.openai.com` - **Localhost bypass**: `127.0.0.1`, `::1`, and `localhost` are always allowed in all modes +## IP Validation + +All egress paths validate hostnames against non-standard IP formats before domain matching. The IP validator (`forge-core/security/ip_validator.go`) rejects SSRF bypass vectors: + +| Blocked Format | Example | Reason | +|---------------|---------|--------| +| Octal | `0177.0.0.1` | Resolves to `127.0.0.1` in some parsers | +| Hexadecimal | `0x7f000001` | Resolves to `127.0.0.1` in some parsers | +| Packed decimal | `2130706433` | Resolves to `127.0.0.1` in some parsers | +| Leading zeros | `127.0.0.01` | Ambiguous parsing across languages | +| IPv6 transition (NAT64) | `64:ff9b::10.0.0.1` | Embeds private IPv4 in IPv6 | +| IPv6 transition (6to4) | `2002:0a00:0001::` | Embeds private IPv4 in IPv6 | +| IPv6 transition (Teredo) | `2001:0000:...` | Embeds XOR'd IPv4 in IPv6 | + +The `ValidateHostIP()` function is called early in both the EgressEnforcer and EgressProxy before any domain matching occurs. + +## Safe Dialer (DNS Rebinding Protection) + +The `SafeDialer` (`forge-core/security/safe_dialer.go`) prevents DNS rebinding and TOCTOU attacks by validating resolved IPs before connecting: + +1. Resolves hostname to IP addresses via DNS +2. Validates **all** resolved IPs against blocked CIDR ranges +3. Dials the first safe IP directly (bypasses re-resolution) + +Blocked IP ranges depend on the `allowPrivateIPs` setting: + +| CIDR | Always Blocked | Blocked when `allowPrivateIPs=false` | +|------|---------------|--------------------------------------| +| `169.254.169.254/32` (cloud metadata) | Yes | Yes | +| `127.0.0.0/8` (loopback) | Yes | Yes | +| `::1/128` (IPv6 loopback) | Yes | Yes | +| `0.0.0.0/8` | Yes | Yes | +| `10.0.0.0/8` (RFC 1918) | — | Yes | +| `172.16.0.0/12` (RFC 1918) | — | Yes | +| `192.168.0.0/16` (RFC 1918) | — | Yes | +| `169.254.0.0/16` (link-local) | — | Yes | +| `100.64.0.0/10` (CGNAT) | — | Yes | +| `fc00::/7` (IPv6 ULA) | — | Yes | +| `fe80::/10` (IPv6 link-local) | — | Yes | + +Both the EgressEnforcer and EgressProxy use `SafeTransport` (an `http.Transport` wired to the SafeDialer) for non-localhost connections. + +## Container-Aware Private IP Handling + +In container and Kubernetes environments, pods communicate via service DNS names that resolve to RFC 1918 addresses (e.g., `10.96.x.x`). Blocking these would break inter-service communication. + +The `allowPrivateIPs` setting is resolved with this precedence: + +1. **Explicit config** — `egress.allow_private_ips` in `forge.yaml` +2. **Auto-detect** — `true` if `InContainer()` detects Docker/Kubernetes +3. **Default** — `false` (block all private IPs) + +| Scenario | `allowPrivateIPs` | RFC 1918 | Cloud Metadata | Loopback | +|----------|-------------------|----------|----------------|----------| +| Local dev | `false` | Blocked | Blocked | Allowed (localhost bypass) | +| Docker Desktop | `true` (auto) | Allowed | **Blocked** | Allowed (localhost bypass) | +| Kubernetes | `true` (auto) | Allowed | **Blocked** | Allowed (localhost bypass) | + +Cloud metadata (`169.254.169.254`) is **always** blocked regardless of the `allowPrivateIPs` setting. + ## Runtime Egress Enforcer -The `EgressEnforcer` (`forge-core/security/egress_enforcer.go`) is an `http.RoundTripper` that wraps the default HTTP transport. Every outbound HTTP request from in-process Go code (builtins like `http_request`, `web_search`, LLM API calls) passes through it. +The `EgressEnforcer` (`forge-core/security/egress_enforcer.go`) is an `http.RoundTripper` that wraps a `SafeTransport`. Every outbound HTTP request from in-process Go code (builtins like `http_request`, `web_search`, LLM API calls) passes through it. ```go -enforcer := security.NewEgressEnforcer(nil, security.ModeAllowlist, allowedDomains) +enforcer := security.NewEgressEnforcer(nil, security.ModeAllowlist, allowedDomains, false) client := &http.Client{Transport: enforcer} ``` +Request validation order: +1. Reject non-standard IP formats (`ValidateHostIP`) +2. Allow localhost (bypass SafeTransport, use `http.DefaultTransport`) +3. Check domain against allowlist (`DomainMatcher.IsAllowed`) +4. Forward via `SafeTransport` (post-DNS IP validation) + Blocked requests return: `egress blocked: domain "X" not in allowlist (mode=allowlist)` The enforcer fires an `OnAttempt` callback for every request, enabling audit logging with domain, mode, and allow/deny decision. @@ -187,8 +253,11 @@ egress: capabilities: - slack - telegram + allow_private_ips: false # default: auto-detect from container env ``` +The `allow_private_ips` field controls whether RFC 1918 addresses are allowed through the SafeDialer. When omitted, it defaults to `true` inside containers (detected via `KUBERNETES_SERVICE_HOST` or `/.dockerenv`) and `false` otherwise. Cloud metadata (`169.254.169.254`) is always blocked. + ## Production vs Development | Setting | Production | Development | @@ -217,9 +286,12 @@ Events without `"source"` come from the in-process enforcer; events with `"sourc | File | Purpose | |------|---------| | `forge-core/security/types.go` | Profile and mode types, `EgressConfig` | +| `forge-core/security/ip_validator.go` | Strict IP parsing, CIDR blocking, IPv6 transition detection | +| `forge-core/security/safe_dialer.go` | Post-DNS-resolution IP validation, `SafeTransport` | | `forge-core/security/domain_matcher.go` | `DomainMatcher` — shared exact/wildcard matching logic | | `forge-core/security/egress_enforcer.go` | `EgressEnforcer` — in-process `http.RoundTripper` | | `forge-core/security/egress_proxy.go` | `EgressProxy` — localhost HTTP/HTTPS forward proxy | +| `forge-core/security/redirect.go` | Cross-origin redirect credential stripping | | `forge-core/security/container.go` | `InContainer()` — Docker/Kubernetes detection | | `forge-core/security/resolver.go` | Allowlist resolution logic | | `forge-core/security/capabilities.go` | Capability bundle definitions | diff --git a/docs/security/overview.md b/docs/security/overview.md index dc1fc44..1edb980 100644 --- a/docs/security/overview.md +++ b/docs/security/overview.md @@ -15,7 +15,7 @@ Forge's security is organized in layers, each addressing a different threat surf │ (content filtering, PII, jailbreak) │ ├──────────────────────────────────────────────────────────────┤ │ Egress Enforcement │ -│ (EgressEnforcer + EgressProxy + NetworkPolicy) │ +│ (EgressEnforcer + EgressProxy + SafeDialer + NetworkPolicy) │ ├──────────────────────────────────────────────────────────────┤ │ Execution Sandboxing │ │ (env isolation, binary allowlists, arg validation, │ @@ -55,6 +55,8 @@ Forge agents are designed to never expose inbound listeners to the public intern - Slack: Socket Mode (outbound WebSocket via `apps.connections.open`) - Telegram: Long-polling via `getUpdates` - **Local-only HTTP server** — The A2A dev server binds to `localhost` by default +- **CORS restriction** — The A2A server restricts `Access-Control-Allow-Origin` to localhost by default; configurable via `--cors-origins` flag, `FORGE_CORS_ORIGINS` env var, or `cors_origins` in `forge.yaml` +- **Security response headers** — All A2A responses include `X-Content-Type-Options: nosniff`, `Referrer-Policy: no-referrer`, `X-Frame-Options: DENY`, and `Content-Security-Policy: default-src 'none'` - **No hidden listeners** — Every network binding is explicit and logged This means a running Forge agent has zero inbound attack surface by default. @@ -63,17 +65,25 @@ This means a running Forge agent has zero inbound attack surface by default. ## Egress Enforcement -Forge restricts outbound network access at three levels: +Forge restricts outbound network access at multiple levels: -### 1. In-Process Enforcer +### 1. IP Validation -The `EgressEnforcer` is a Go `http.RoundTripper` that wraps every outbound HTTP request from in-process tools (`http_request`, `web_search`, LLM API calls). It validates the destination domain against a resolved allowlist before forwarding. +All egress paths reject non-standard IP formats (octal, hex, packed decimal, leading zeros) that could bypass allowlist checks. IPv6 transition addresses (NAT64, 6to4, Teredo) embedding private IPv4 addresses are also blocked. -### 2. Subprocess Proxy +### 2. In-Process Enforcer -Skill scripts and `cli_execute` subprocesses bypass Go-level enforcement. A local `EgressProxy` on `127.0.0.1:` validates domains for subprocess HTTP traffic via `HTTP_PROXY`/`HTTPS_PROXY` env var injection. +The `EgressEnforcer` is a Go `http.RoundTripper` backed by a `SafeTransport` that validates resolved IPs post-DNS. Every outbound HTTP request from in-process tools (`http_request`, `web_search`, LLM API calls) is checked against IP validation, domain allowlist, and post-resolution CIDR blocking. -### 3. Kubernetes NetworkPolicy +### 3. Subprocess Proxy + +Skill scripts and `cli_execute` subprocesses bypass Go-level enforcement. A local `EgressProxy` on `127.0.0.1:` validates domains and resolved IPs for subprocess HTTP traffic via `HTTP_PROXY`/`HTTPS_PROXY` env var injection. + +### 4. Redirect Credential Stripping + +HTTP clients used by `http_request` and `webhook_call` tools strip `Authorization`, `Cookie`, and `Proxy-Authorization` headers when a redirect crosses origin boundaries (different scheme, host, or port). + +### 5. Kubernetes NetworkPolicy In containerized deployments, generated Kubernetes `NetworkPolicy` manifests enforce egress at the pod level, restricting traffic to allowed domains on ports 80/443. @@ -244,7 +254,7 @@ Production builds enforce: | Document | Description | |----------|-------------| -| [Egress Security](egress.md) | Deep dive into egress enforcement: profiles, modes, domain matching, proxy architecture, NetworkPolicy | +| [Egress Security](egress.md) | Deep dive into egress enforcement: IP validation, SafeDialer, profiles, modes, domain matching, proxy architecture, NetworkPolicy | | [Secrets Management](secrets.md) | Encrypted storage, per-agent secrets, passphrase handling | | [Build Signing & Verification](signing.md) | Key management, build signing, runtime verification | | [Content Guardrails](guardrails.md) | PII detection, jailbreak protection, custom rules | diff --git a/docs/tools.md b/docs/tools.md index d35cdbe..a7a5592 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -17,7 +17,7 @@ Tools are capabilities that an LLM agent can invoke during execution. Forge prov | Tool | Description | |------|-------------| -| `http_request` | Make HTTP requests (GET, POST, PUT, DELETE) | +| `http_request` | Make HTTP requests (GET, POST, PUT, DELETE). Strips credentials on cross-origin redirects | | `json_parse` | Parse and query JSON data | | `csv_parse` | Parse CSV data into structured records | | `datetime_now` | Get current date and time | @@ -86,7 +86,7 @@ All file tools use `PathValidator` (from `pathutil.go`): | Adapter | Description | |---------|-------------| | `mcp_call` | Call tools on MCP servers via JSON-RPC | -| `webhook_call` | POST JSON payloads to webhook URLs | +| `webhook_call` | POST JSON payloads to webhook URLs. Strips credentials on cross-origin redirects | | `openapi_call` | Call OpenAPI-described endpoints | Adapter tools bridge external services into the agent's tool set.