Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"net/http"
"strconv"
"time"

"github.com/git-pkgs/registries/safehttp"
)

const (
Expand Down Expand Up @@ -241,3 +243,18 @@ func WithTransport(rt http.RoundTripper) Option {
c.HTTPClient.Transport = rt
}
}

// WithSafeHTTP wraps the client's underlying *http.Client with the
// safehttp transport: dial-time IP gate (rejects loopback, RFC1918,
// CGNAT, link-local, multicast, unspecified), redirect chain capped
// at 10, non-http(s) schemes rejected on redirect. DNS is resolved at
// dial time and the connection dials the resolved IP directly so a
// rebind between resolve and connect cannot escape the gate. Suitable
// for any code path that fetches from URLs an attacker might control
// (a malicious registry response, a manifest-supplied URL, a redirect
// target).
func WithSafeHTTP() Option {
return func(c *Client) {
c.HTTPClient = safehttp.New(c.HTTPClient, safehttp.Options{})
}
}
19 changes: 16 additions & 3 deletions fetch/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"time"

"github.com/rs/dnscache"

"github.com/git-pkgs/registries/safehttp"
)

const (
Expand Down Expand Up @@ -143,13 +145,24 @@ func NewFetcher(opts ...Option) *Fetcher {
if err != nil {
return nil, err
}
// Gate every resolved IP against the safehttp block
// list (loopback, RFC1918, CGNAT, link-local, ...)
// before dialing. The dial is to the resolved IP
// directly so a rebind between gate and connect
// cannot escape.
var lastErr error
for _, ip := range ips {
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
if err == nil {
if parsed := net.ParseIP(ip); parsed != nil {
if err := safehttp.CheckIP(parsed, safehttp.Options{}); err != nil {
lastErr = err
continue
}
}
conn, derr := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
if derr == nil {
return conn, nil
}
lastErr = err
lastErr = derr
}
if lastErr == nil {
return nil, fmt.Errorf("no IPs resolved for %s", host)
Expand Down
16 changes: 16 additions & 0 deletions fetch/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package fetch

import (
"os"
"testing"

"github.com/git-pkgs/registries/safehttp"
)

// TestMain opts the safehttp dial gate off loopback so the package's
// httptest-backed test suite continues to run. The opt-out is binary-
// scoped — production code never sees it.
func TestMain(m *testing.M) {
safehttp.EnableLoopbackForTesting()
os.Exit(m.Run())
}
201 changes: 201 additions & 0 deletions safehttp/safehttp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Package safehttp builds an http.Client suitable for fetching from
// untrusted hosts. The transport applies three defences in concert:
//
// 1. Dial-time IP gate. DNS is resolved once per dial; each resolved
// address is checked against the block list (loopback, RFC1918,
// CGNAT 100.64.0.0/10, link-local, multicast, unspecified) before
// any TCP connect. The connection then dials the resolved IP
// directly, so a rebind between check and connect cannot escape
// the gate.
//
// 2. Redirect cap. Go's default ignores chain length and re-trusts
// each hop. CheckRedirect caps at 10 and re-validates every
// redirect target.
//
// 3. Scheme gate on redirect. file://, gopher://, ftp://, data:// are
// rejected — the only reason a registry would 30x to those is to
// exfiltrate something.
//
// Threat model: a compromised registry, CDN, or maliciously-crafted
// URL that returns a 30x to http://localhost or an RFC1918 address
// should not be able to use this transport to probe internal services
// or exfiltrate local files. The gate is on every dial including
// redirect targets.
package safehttp

import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"time"
)

const (
// MaxRedirects bounds the redirect chain length.
MaxRedirects = 10

defaultTimeout = 30 * time.Second
dialTimeout = 30 * time.Second
)

// Options configures a safehttp client. The zero value gives the
// production-strict gate; tests can opt parts of it off explicitly.
type Options struct {
// AllowLoopback disables the loopback (127.0.0.0/8, ::1) check.
// Test-only; never set in production paths.
AllowLoopback bool

// AllowPrivate disables the RFC1918 / ULA / CGNAT checks. Test-only.
AllowPrivate bool
}

// testInsecure flips both AllowLoopback and AllowPrivate on at the
// gate. Set via EnableLoopbackForTesting from a test binary's TestMain
// so existing httptest-based test suites don't have to thread an
// Options flag through every constructor.
var testInsecure bool

// EnableLoopbackForTesting flips the SSRF dial gate's loopback and
// private-IP checks off for the calling test binary. Use in TestMain;
// never call from production code.
func EnableLoopbackForTesting() { testInsecure = true }

// New returns an http.Client that applies the SSRF defences described
// in the package doc. base may be nil; if non-nil its Timeout, Jar,
// and other non-Transport fields are preserved.
func New(base *http.Client, opts Options) *http.Client {
c := http.Client{Timeout: defaultTimeout}
if base != nil {
c = *base
}

transport, _ := http.DefaultTransport.(*http.Transport)
transport = transport.Clone()
if base != nil {
if t, ok := base.Transport.(*http.Transport); ok && t != nil {
transport = t.Clone()
}
}

underlying := transport.DialContext
if underlying == nil {
d := &net.Dialer{Timeout: dialTimeout}
underlying = d.DialContext
}

gate := newGate(opts)
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return gate.dial(ctx, network, addr, underlying)
}
c.Transport = transport

c.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= MaxRedirects {
return fmt.Errorf("safehttp: stopped after %d redirects", MaxRedirects)
}
return validateRedirect(req.URL)
}
return &c
}

// CheckIP reports whether an IP is acceptable to dial under the
// supplied options. Exported so other transports (e.g. registries/
// fetch, which manages its own DialContext for DNS caching) can apply
// the same gate without wiring through a full safehttp client.
func CheckIP(ip net.IP, opts Options) error {
return newGate(opts).check(ip)
}

type ipGate struct {
allowLoopback bool
allowPrivate bool
}

func newGate(opts Options) *ipGate {
return &ipGate{allowLoopback: opts.AllowLoopback, allowPrivate: opts.AllowPrivate}
}

func (g *ipGate) dial(ctx context.Context, network, addr string, dial func(context.Context, string, string) (net.Conn, error)) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}

if ip := net.ParseIP(host); ip != nil {
if err := g.check(ip); err != nil {
return nil, err
}
return dial(ctx, network, addr)
}

ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
var lastErr error
for _, ip := range ips {
if err := g.check(ip.IP); err != nil {
lastErr = err
continue
}
conn, derr := dial(ctx, network, net.JoinHostPort(ip.IP.String(), port))
if derr == nil {
return conn, nil
}
lastErr = derr
}
if lastErr != nil {
return nil, lastErr
}
return nil, fmt.Errorf("safehttp: no addresses resolved for %s", host)
}

var cgnat = mustCIDR("100.64.0.0/10")

func mustCIDR(s string) *net.IPNet {
_, n, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
return n
}

func (g *ipGate) check(ip net.IP) error {
allowLoopback := g.allowLoopback || testInsecure
allowPrivate := g.allowPrivate || testInsecure

if ip.IsUnspecified() {
return blockedErr(ip, "unspecified")
}
if ip.IsLoopback() && !allowLoopback {
return blockedErr(ip, "loopback")
}
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return blockedErr(ip, "link-local")
}
if ip.IsInterfaceLocalMulticast() || ip.IsMulticast() {
return blockedErr(ip, "multicast")
}
if !allowPrivate {
if ip.IsPrivate() {
return blockedErr(ip, "private")
}
if cgnat.Contains(ip) {
return blockedErr(ip, "CGNAT")
}
}
return nil
}

func blockedErr(ip net.IP, kind string) error {
return fmt.Errorf("safehttp: refusing to connect to %s (%s)", ip, kind)
}

func validateRedirect(u *url.URL) error {
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("safehttp: refusing redirect to scheme %q", u.Scheme)
}
return nil
}
Loading