From ba424c1b7d4fd9d72f24ebe95db3f1f057d29de4 Mon Sep 17 00:00:00 2001 From: Andrew Nesbitt Date: Tue, 12 May 2026 12:27:45 +0100 Subject: [PATCH] safehttp: SSRF-safe transport for client and fetch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an HTTP transport that hardens registry fetches against three related threats: * Server-Side Request Forgery via a registry / CDN that 30x's to an internal address (http://localhost, RFC1918, link-local, ...) so the client probes services it shouldn't reach. * DNS rebinding between resolve and connect — the IP that comes back from DNS is not necessarily what gets connected to if the resolver cache or upstream is poisoned. * Redirect-target exfiltration via non-http(s) schemes (file://, gopher://, data://) returned in a Location header. Three defences in concert: 1. Dial-time IP gate. DNS is resolved once per dial; each resolved address is checked against the block list (loopback, RFC1918, CGNAT 100.64.0.0/10, link-local, multicast, unspecified) before any TCP connect. The connection then dials the resolved IP directly, so a rebind between check and connect cannot escape. 2. Redirect cap at 10 hops, re-validating every target through the same dial gate. 3. Non-http(s) scheme rejected on redirect. New package github.com/git-pkgs/registries/safehttp holds the transport. client.WithSafeHTTP() opts a Client into it. The fetch package's existing dnscache-backed dialer now gates each resolved IP against safehttp.CheckIP before connecting. Loopback opt-out via safehttp.EnableLoopbackForTesting (called from TestMain) keeps existing httptest.Server-backed test suites working; production paths never see the opt-out. --- client/client.go | 17 ++++ fetch/fetcher.go | 19 +++- fetch/main_test.go | 16 +++ safehttp/safehttp.go | 201 ++++++++++++++++++++++++++++++++++++++ safehttp/safehttp_test.go | 163 +++++++++++++++++++++++++++++++ 5 files changed, 413 insertions(+), 3 deletions(-) create mode 100644 fetch/main_test.go create mode 100644 safehttp/safehttp.go create mode 100644 safehttp/safehttp_test.go diff --git a/client/client.go b/client/client.go index cfd8ed5..8e86cf6 100644 --- a/client/client.go +++ b/client/client.go @@ -10,6 +10,8 @@ import ( "net/http" "strconv" "time" + + "github.com/git-pkgs/registries/safehttp" ) const ( @@ -241,3 +243,18 @@ func WithTransport(rt http.RoundTripper) Option { c.HTTPClient.Transport = rt } } + +// WithSafeHTTP wraps the client's underlying *http.Client with the +// safehttp transport: dial-time IP gate (rejects loopback, RFC1918, +// CGNAT, link-local, multicast, unspecified), redirect chain capped +// at 10, non-http(s) schemes rejected on redirect. DNS is resolved at +// dial time and the connection dials the resolved IP directly so a +// rebind between resolve and connect cannot escape the gate. Suitable +// for any code path that fetches from URLs an attacker might control +// (a malicious registry response, a manifest-supplied URL, a redirect +// target). +func WithSafeHTTP() Option { + return func(c *Client) { + c.HTTPClient = safehttp.New(c.HTTPClient, safehttp.Options{}) + } +} diff --git a/fetch/fetcher.go b/fetch/fetcher.go index 4a6fc9e..20e6617 100644 --- a/fetch/fetcher.go +++ b/fetch/fetcher.go @@ -15,6 +15,8 @@ import ( "time" "github.com/rs/dnscache" + + "github.com/git-pkgs/registries/safehttp" ) const ( @@ -143,13 +145,24 @@ func NewFetcher(opts ...Option) *Fetcher { if err != nil { return nil, err } + // Gate every resolved IP against the safehttp block + // list (loopback, RFC1918, CGNAT, link-local, ...) + // before dialing. The dial is to the resolved IP + // directly so a rebind between gate and connect + // cannot escape. var lastErr error for _, ip := range ips { - conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) - if err == nil { + if parsed := net.ParseIP(ip); parsed != nil { + if err := safehttp.CheckIP(parsed, safehttp.Options{}); err != nil { + lastErr = err + continue + } + } + conn, derr := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) + if derr == nil { return conn, nil } - lastErr = err + lastErr = derr } if lastErr == nil { return nil, fmt.Errorf("no IPs resolved for %s", host) diff --git a/fetch/main_test.go b/fetch/main_test.go new file mode 100644 index 0000000..9a34e82 --- /dev/null +++ b/fetch/main_test.go @@ -0,0 +1,16 @@ +package fetch + +import ( + "os" + "testing" + + "github.com/git-pkgs/registries/safehttp" +) + +// TestMain opts the safehttp dial gate off loopback so the package's +// httptest-backed test suite continues to run. The opt-out is binary- +// scoped — production code never sees it. +func TestMain(m *testing.M) { + safehttp.EnableLoopbackForTesting() + os.Exit(m.Run()) +} diff --git a/safehttp/safehttp.go b/safehttp/safehttp.go new file mode 100644 index 0000000..ce02bc2 --- /dev/null +++ b/safehttp/safehttp.go @@ -0,0 +1,201 @@ +// Package safehttp builds an http.Client suitable for fetching from +// untrusted hosts. The transport applies three defences in concert: +// +// 1. Dial-time IP gate. DNS is resolved once per dial; each resolved +// address is checked against the block list (loopback, RFC1918, +// CGNAT 100.64.0.0/10, link-local, multicast, unspecified) before +// any TCP connect. The connection then dials the resolved IP +// directly, so a rebind between check and connect cannot escape +// the gate. +// +// 2. Redirect cap. Go's default ignores chain length and re-trusts +// each hop. CheckRedirect caps at 10 and re-validates every +// redirect target. +// +// 3. Scheme gate on redirect. file://, gopher://, ftp://, data:// are +// rejected — the only reason a registry would 30x to those is to +// exfiltrate something. +// +// Threat model: a compromised registry, CDN, or maliciously-crafted +// URL that returns a 30x to http://localhost or an RFC1918 address +// should not be able to use this transport to probe internal services +// or exfiltrate local files. The gate is on every dial including +// redirect targets. +package safehttp + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "time" +) + +const ( + // MaxRedirects bounds the redirect chain length. + MaxRedirects = 10 + + defaultTimeout = 30 * time.Second + dialTimeout = 30 * time.Second +) + +// Options configures a safehttp client. The zero value gives the +// production-strict gate; tests can opt parts of it off explicitly. +type Options struct { + // AllowLoopback disables the loopback (127.0.0.0/8, ::1) check. + // Test-only; never set in production paths. + AllowLoopback bool + + // AllowPrivate disables the RFC1918 / ULA / CGNAT checks. Test-only. + AllowPrivate bool +} + +// testInsecure flips both AllowLoopback and AllowPrivate on at the +// gate. Set via EnableLoopbackForTesting from a test binary's TestMain +// so existing httptest-based test suites don't have to thread an +// Options flag through every constructor. +var testInsecure bool + +// EnableLoopbackForTesting flips the SSRF dial gate's loopback and +// private-IP checks off for the calling test binary. Use in TestMain; +// never call from production code. +func EnableLoopbackForTesting() { testInsecure = true } + +// New returns an http.Client that applies the SSRF defences described +// in the package doc. base may be nil; if non-nil its Timeout, Jar, +// and other non-Transport fields are preserved. +func New(base *http.Client, opts Options) *http.Client { + c := http.Client{Timeout: defaultTimeout} + if base != nil { + c = *base + } + + transport, _ := http.DefaultTransport.(*http.Transport) + transport = transport.Clone() + if base != nil { + if t, ok := base.Transport.(*http.Transport); ok && t != nil { + transport = t.Clone() + } + } + + underlying := transport.DialContext + if underlying == nil { + d := &net.Dialer{Timeout: dialTimeout} + underlying = d.DialContext + } + + gate := newGate(opts) + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return gate.dial(ctx, network, addr, underlying) + } + c.Transport = transport + + c.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= MaxRedirects { + return fmt.Errorf("safehttp: stopped after %d redirects", MaxRedirects) + } + return validateRedirect(req.URL) + } + return &c +} + +// CheckIP reports whether an IP is acceptable to dial under the +// supplied options. Exported so other transports (e.g. registries/ +// fetch, which manages its own DialContext for DNS caching) can apply +// the same gate without wiring through a full safehttp client. +func CheckIP(ip net.IP, opts Options) error { + return newGate(opts).check(ip) +} + +type ipGate struct { + allowLoopback bool + allowPrivate bool +} + +func newGate(opts Options) *ipGate { + return &ipGate{allowLoopback: opts.AllowLoopback, allowPrivate: opts.AllowPrivate} +} + +func (g *ipGate) dial(ctx context.Context, network, addr string, dial func(context.Context, string, string) (net.Conn, error)) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + if ip := net.ParseIP(host); ip != nil { + if err := g.check(ip); err != nil { + return nil, err + } + return dial(ctx, network, addr) + } + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + var lastErr error + for _, ip := range ips { + if err := g.check(ip.IP); err != nil { + lastErr = err + continue + } + conn, derr := dial(ctx, network, net.JoinHostPort(ip.IP.String(), port)) + if derr == nil { + return conn, nil + } + lastErr = derr + } + if lastErr != nil { + return nil, lastErr + } + return nil, fmt.Errorf("safehttp: no addresses resolved for %s", host) +} + +var cgnat = mustCIDR("100.64.0.0/10") + +func mustCIDR(s string) *net.IPNet { + _, n, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return n +} + +func (g *ipGate) check(ip net.IP) error { + allowLoopback := g.allowLoopback || testInsecure + allowPrivate := g.allowPrivate || testInsecure + + if ip.IsUnspecified() { + return blockedErr(ip, "unspecified") + } + if ip.IsLoopback() && !allowLoopback { + return blockedErr(ip, "loopback") + } + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return blockedErr(ip, "link-local") + } + if ip.IsInterfaceLocalMulticast() || ip.IsMulticast() { + return blockedErr(ip, "multicast") + } + if !allowPrivate { + if ip.IsPrivate() { + return blockedErr(ip, "private") + } + if cgnat.Contains(ip) { + return blockedErr(ip, "CGNAT") + } + } + return nil +} + +func blockedErr(ip net.IP, kind string) error { + return fmt.Errorf("safehttp: refusing to connect to %s (%s)", ip, kind) +} + +func validateRedirect(u *url.URL) error { + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("safehttp: refusing redirect to scheme %q", u.Scheme) + } + return nil +} diff --git a/safehttp/safehttp_test.go b/safehttp/safehttp_test.go new file mode 100644 index 0000000..5d71b37 --- /dev/null +++ b/safehttp/safehttp_test.go @@ -0,0 +1,163 @@ +package safehttp + +import ( + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func TestCheckIP_Blocks(t *testing.T) { + g := newGate(Options{}) + cases := map[string]string{ + "127.0.0.1": "loopback", + "127.255.255.255": "loopback", + "::1": "loopback", + "10.0.0.1": "private", + "172.16.5.4": "private", + "192.168.1.1": "private", + "169.254.169.254": "link-local", + "100.64.0.1": "CGNAT", + "0.0.0.0": "unspecified", + "fc00::1": "private", + "fe80::1": "link-local", + "ff02::1": "link-local", // link-local multicast: link-local check wins + "239.255.255.250": "multicast", // SSDP — multicast but not link-local + } + for in, wantKind := range cases { + ip := net.ParseIP(in) + if ip == nil { + t.Fatalf("bad test IP %q", in) + } + err := g.check(ip) + if err == nil { + t.Errorf("checkIP(%s) = nil; want %s", in, wantKind) + continue + } + if !strings.Contains(err.Error(), wantKind) { + t.Errorf("checkIP(%s) = %v; want kind %q", in, err, wantKind) + } + } +} + +func TestCheckIP_Allows(t *testing.T) { + g := newGate(Options{}) + for _, in := range []string{"8.8.8.8", "1.1.1.1", "2606:4700:4700::1111"} { + ip := net.ParseIP(in) + if err := g.check(ip); err != nil { + t.Errorf("checkIP(%s) = %v; want nil", in, err) + } + } +} + +func TestCheckIP_AllowLoopbackOptOut(t *testing.T) { + g := newGate(Options{AllowLoopback: true}) + if err := g.check(net.ParseIP("127.0.0.1")); err != nil { + t.Errorf("with AllowLoopback, checkIP(127.0.0.1) = %v; want nil", err) + } + if err := g.check(net.ParseIP("10.0.0.1")); err == nil { + t.Error("AllowLoopback must not relax the private-IP check") + } +} + +// CheckIP is the exported entry point; assert it agrees with the +// internal gate. +func TestCheckIP_Exported(t *testing.T) { + if err := CheckIP(net.ParseIP("127.0.0.1"), Options{}); err == nil { + t.Error("CheckIP(127.0.0.1) under default options should fail") + } + if err := CheckIP(net.ParseIP("8.8.8.8"), Options{}); err != nil { + t.Errorf("CheckIP(8.8.8.8) under default options should pass; err=%v", err) + } +} + +func TestClient_LoopbackRefused(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + c := New(nil, Options{}) + resp, err := c.Get(ts.URL) + if err == nil { + _ = resp.Body.Close() + t.Fatalf("expected loopback to be refused, got 200") + } + if !strings.Contains(err.Error(), "loopback") { + t.Errorf("error %v should mention loopback", err) + } +} + +func TestClient_LoopbackAllowed(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + c := New(nil, Options{AllowLoopback: true}) + resp, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("AllowLoopback: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != 200 { + t.Errorf("status %d, want 200", resp.StatusCode) + } +} + +func TestClient_RedirectCap(t *testing.T) { + var hits int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits++ + http.Redirect(w, r, "/", http.StatusFound) + })) + defer ts.Close() + + c := New(nil, Options{AllowLoopback: true}) + resp, err := c.Get(ts.URL) + if err == nil { + _ = resp.Body.Close() + t.Fatalf("expected redirect-cap error") + } + if !strings.Contains(err.Error(), "stopped after") { + t.Errorf("error %v should mention the redirect cap", err) + } + if hits < MaxRedirects { + t.Errorf("expected at least %d hits before bail, got %d", MaxRedirects, hits) + } +} + +func TestClient_BadSchemeRedirect(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", "file:///etc/passwd") + w.WriteHeader(http.StatusFound) + })) + defer ts.Close() + + c := New(nil, Options{AllowLoopback: true}) + resp, err := c.Get(ts.URL) + if err == nil { + _ = resp.Body.Close() + t.Fatalf("expected scheme rejection on redirect") + } + if !strings.Contains(err.Error(), "refusing redirect to scheme") { + t.Errorf("error %v should mention scheme rejection", err) + } +} + +func TestValidateRedirect(t *testing.T) { + for _, scheme := range []string{"file", "gopher", "ftp", "data"} { + u, _ := url.Parse(scheme + "://x/") + if err := validateRedirect(u); err == nil { + t.Errorf("scheme %q should be refused", scheme) + } + } + for _, scheme := range []string{"http", "https"} { + u, _ := url.Parse(scheme + "://example.com/") + if err := validateRedirect(u); err != nil { + t.Errorf("scheme %q should be allowed: %v", scheme, err) + } + } +}