From fd2b744ecb0c17f91f684a8612fc0ca0a7c12fb6 Mon Sep 17 00:00:00 2001 From: Michal Budzyn Date: Sun, 5 Apr 2026 02:43:39 +0200 Subject: [PATCH] redesign tls client --- README.md | 4 +- tls/client/client.go | 81 +++++++++++++++++++----- tls/client/config/config.go | 25 +++++++- tls/client/config/config_test.go | 41 ++++++++---- tls/client/filesource/filesource_test.go | 76 ++++++++++++++++++++-- tls/client/option.go | 19 ++++++ tls/client/roundtripper.go | 17 ----- tls/server/config/config_test.go | 18 +++++- tls/server/option.go | 14 ++++ 9 files changed, 241 insertions(+), 54 deletions(-) diff --git a/README.md b/README.md index abebef4..4b77c72 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ import ( ) func main() { - tlsClientConfigFunc, err := tlsclientconfig.GetTLSClientConfigFunc(slog.Default(), &tlsconfig.TLSClientConfig{ + tlsClientConfig, err := tlsclientconfig.GetTLSClientConfig(slog.Default(), &tlsconfig.TLSClientConfig{ Enable: true, Refresh: 1 * time.Second, InsecureSkipVerify: false, @@ -92,7 +92,7 @@ func main() { if err != nil { log.Fatalln(err) } - transport := tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsClientConfigFunc())) + transport := tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsClientConfig)) client := &http.Client{Transport: transport} resp, err := client.Get("https://localhost:8443") if err != nil { diff --git a/tls/client/client.go b/tls/client/client.go index 95b890d..f903036 100644 --- a/tls/client/client.go +++ b/tls/client/client.go @@ -2,6 +2,7 @@ package tlsclient import ( "crypto/tls" + "crypto/x509" "errors" "log/slog" "time" @@ -13,36 +14,86 @@ const ( initLoadTimeout = 5 * time.Second ) -type TLSClientConfigFunc func() *tls.Config +type Provider struct { + store *source.ClientCertsStore + opts []TLSClientConfigOption +} -func NewTLSClientConfigFunc(logger *slog.Logger, src source.ClientCertsSource, opts ...TLSClientConfigOption) (TLSClientConfigFunc, error) { +func NewProvider(logger *slog.Logger, src source.ClientCertsSource, opts ...TLSClientConfigOption) (*Provider, error) { store, err := NewTLSClientCertsStore(logger, src) if err != nil { return nil, err } + return &Provider{ + store: store, + opts: opts, + }, nil +} + +func (p *Provider) LoadClientCerts() source.ClientCerts { + return p.store.LoadClientCerts() +} + +func (p *Provider) TLSConfig() *tls.Config { + cs := p.LoadClientCerts() + x := &tls.Config{ + // Root CAs are read dynamically in VerifyConnection to support rotation. + // nolint:gosec + InsecureSkipVerify: true, + } var getClientCertificateFunc func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) - if store.LoadClientCerts().Certificate != nil { + if cs.Certificate != nil { // Set function only when client certificate is available. // TLS 1.3 checks if GetClientCertificate function is nil, if it is not nil, // it assumes client certificate is available which call cause the panic if nil is returned. //nolint:unparam getClientCertificateFunc = func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return store.LoadClientCerts().Certificate, nil + return p.LoadClientCerts().Certificate, nil } } - return func() *tls.Config { - cs := store.LoadClientCerts() - x := &tls.Config{ - RootCAs: cs.RootCAs, - // nolint:gosec - InsecureSkipVerify: cs.InsecureSkipVerify, - GetClientCertificate: getClientCertificateFunc, + x.GetClientCertificate = getClientCertificateFunc + for _, opt := range p.opts { + opt(x) + } + x.VerifyConnection = p.verifyConnection(x.ServerName) + return x +} + +func (p *Provider) verifyConnection(configuredServerName string) func(tls.ConnectionState) error { + return func(cs tls.ConnectionState) error { + clientCerts := p.LoadClientCerts() + if clientCerts.InsecureSkipVerify { + return nil } - for _, opt := range opts { - opt(x) + if len(cs.PeerCertificates) == 0 { + return errors.New("tls: no peer certificates") } - return x - }, nil + + serverName := cs.ServerName + if serverName == "" { + serverName = configuredServerName + } + + opts := x509.VerifyOptions{ + Roots: clientCerts.RootCAs, + DNSName: serverName, + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + for _, cert := range cs.PeerCertificates[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := cs.PeerCertificates[0].Verify(opts) + return err + } +} + +func NewTLSConfig(logger *slog.Logger, src source.ClientCertsSource, opts ...TLSClientConfigOption) (*tls.Config, error) { + provider, err := NewProvider(logger, src, opts...) + if err != nil { + return nil, err + } + return provider.TLSConfig(), nil } func NewTLSClientCertsStore(logger *slog.Logger, src source.ClientCertsSource) (*source.ClientCertsStore, error) { diff --git a/tls/client/config/config.go b/tls/client/config/config.go index d205bb7..816b287 100644 --- a/tls/client/config/config.go +++ b/tls/client/config/config.go @@ -1,15 +1,17 @@ package config import ( + "crypto/tls" "fmt" "log/slog" + "net/http" "github.com/grepplabs/cert-source/config" tlsclient "github.com/grepplabs/cert-source/tls/client" "github.com/grepplabs/cert-source/tls/client/filesource" ) -func GetTLSClientConfigFunc(logger *slog.Logger, conf *config.TLSClientConfig, opts ...tlsclient.TLSClientConfigOption) (tlsclient.TLSClientConfigFunc, error) { +func GetTLSClientConfig(logger *slog.Logger, conf *config.TLSClientConfig, opts ...tlsclient.TLSClientConfigOption) (*tls.Config, error) { if !conf.Enable { return nil, nil } @@ -25,5 +27,24 @@ func GetTLSClientConfigFunc(logger *slog.Logger, conf *config.TLSClientConfig, o if err != nil { return nil, fmt.Errorf("setup client cert file source: %w", err) } - return tlsclient.NewTLSClientConfigFunc(logger, fs, opts...) + return tlsclient.NewTLSConfig(logger, fs, opts...) +} + +func NewRoundTripper(logger *slog.Logger, conf *config.TLSClientConfig, opts ...tlsclient.TLSClientConfigOption) (*tlsclient.RoundTripper, error) { + tlsConfig, err := GetTLSClientConfig(logger, conf, opts...) + if err != nil { + return nil, err + } + if tlsConfig == nil { + return tlsclient.NewDefaultRoundTripper(), nil + } + return tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)), nil +} + +func NewClient(logger *slog.Logger, conf *config.TLSClientConfig, opts ...tlsclient.TLSClientConfigOption) (*http.Client, error) { + transport, err := NewRoundTripper(logger, conf, opts...) + if err != nil { + return nil, err + } + return &http.Client{Transport: transport}, nil } diff --git a/tls/client/config/config_test.go b/tls/client/config/config_test.go index 9adf43f..66bef5a 100644 --- a/tls/client/config/config_test.go +++ b/tls/client/config/config_test.go @@ -13,7 +13,7 @@ import ( func TestGetClientTLSConfig(t *testing.T) { bundle := testutil.NewCertsBundle() defer bundle.Close() - tlsConfigFunc, err := GetTLSClientConfigFunc(slog.Default(), &config.TLSClientConfig{ + tlsConfig, err := GetTLSClientConfig(slog.Default(), &config.TLSClientConfig{ Enable: true, Refresh: 0, File: config.TLSClientFiles{ @@ -21,11 +21,12 @@ func TestGetClientTLSConfig(t *testing.T) { Cert: bundle.ClientCert.Name(), RootCAs: bundle.CACert.Name(), }, - }, tlsclient.WithTLSClientNextProtos([]string{"h2"})) + }, tlsclient.WithTLSClientHTTP2(), tlsclient.WithTLSServerName("localhost")) require.NoError(t, err) - tlsConfig := tlsConfigFunc() - require.NotNil(t, tlsConfig.RootCAs) + require.True(t, tlsConfig.InsecureSkipVerify) + require.NotNil(t, tlsConfig.VerifyConnection) require.Equal(t, []string{"h2"}, tlsConfig.NextProtos) + require.Equal(t, "localhost", tlsConfig.ServerName) clientCert, err := tlsConfig.GetClientCertificate(nil) require.NoError(t, err) @@ -35,33 +36,49 @@ func TestGetClientTLSConfig(t *testing.T) { func TestGetClientTLSConfigNoConfig(t *testing.T) { bundle := testutil.NewCertsBundle() defer bundle.Close() - tlsConfigFunc, err := GetTLSClientConfigFunc(slog.Default(), &config.TLSClientConfig{ + tlsConfig, err := GetTLSClientConfig(slog.Default(), &config.TLSClientConfig{ Enable: true, Refresh: 0, File: config.TLSClientFiles{}, }) require.NoError(t, err) - tlsConfig := tlsConfigFunc() - require.Nil(t, tlsConfig.RootCAs) + require.True(t, tlsConfig.InsecureSkipVerify) require.Nil(t, tlsConfig.GetClientCertificate) } func TestGetClientTLSConfigSkipVerify(t *testing.T) { bundle := testutil.NewCertsBundle() defer bundle.Close() - tlsConfigFunc, err := GetTLSClientConfigFunc(slog.Default(), &config.TLSClientConfig{ - Enable: true, - Refresh: 0, + tlsConfig, err := GetTLSClientConfig(slog.Default(), &config.TLSClientConfig{ + Enable: true, + Refresh: 0, + InsecureSkipVerify: true, File: config.TLSClientFiles{ Key: bundle.ClientKey.Name(), Cert: bundle.ClientCert.Name(), }, }) require.NoError(t, err) - tlsConfig := tlsConfigFunc() - require.Nil(t, tlsConfig.RootCAs) + require.True(t, tlsConfig.InsecureSkipVerify) clientCert, err := tlsConfig.GetClientCertificate(nil) require.NoError(t, err) require.NotNil(t, clientCert) } + +func TestGetClientTLSHTTP2AndHTTP11Config(t *testing.T) { + bundle := testutil.NewCertsBundle() + defer bundle.Close() + + tlsConfig, err := GetTLSClientConfig(slog.Default(), &config.TLSClientConfig{ + Enable: true, + Refresh: 0, + File: config.TLSClientFiles{ + Key: bundle.ClientKey.Name(), + Cert: bundle.ClientCert.Name(), + RootCAs: bundle.CACert.Name(), + }, + }, tlsclient.WithTLSClientHTTP2AndHTTP11()) + require.NoError(t, err) + require.Equal(t, []string{"h2", "http/1.1"}, tlsConfig.NextProtos) +} diff --git a/tls/client/filesource/filesource_test.go b/tls/client/filesource/filesource_test.go index fa6c8a9..bccf7cf 100644 --- a/tls/client/filesource/filesource_test.go +++ b/tls/client/filesource/filesource_test.go @@ -39,7 +39,7 @@ func TestCertRotation(t *testing.T) { WithNotifyFunc(notifyFunc), ).(*fileSource) - clientCertsStore, err := tlsclient.NewTLSClientCertsStore(slog.Default(), clientSource) + tlsConfig, err := tlsclient.NewTLSConfig(slog.Default(), clientSource) require.NoError(t, err) serverSource := serverfilesource.MustNew( @@ -57,7 +57,7 @@ func TestCertRotation(t *testing.T) { // when client := &http.Client{ - Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientCertsStore(clientCertsStore)), + Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)), } resp, err := client.Do(req) require.NoError(t, err) @@ -77,7 +77,7 @@ func TestCertRotation(t *testing.T) { // old client - bad certificate // create new client as connection can be kept alive client = &http.Client{ - Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientCertsStore(clientCertsStore)), + Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)), } // nolint:bodyclose _, err = client.Do(req) @@ -106,7 +106,7 @@ func TestKeyEncryption(t *testing.T) { WithSystemPool(true), ).(*fileSource) - clientCertsStore, err := tlsclient.NewTLSClientCertsStore(slog.Default(), clientSource) + tlsConfig, err := tlsclient.NewTLSConfig(slog.Default(), clientSource) require.NoError(t, err) serverSource := serverfilesource.MustNew( @@ -124,9 +124,75 @@ func TestKeyEncryption(t *testing.T) { // when client := &http.Client{ - Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientCertsStore(clientCertsStore)), + Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)), } resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() } + +func TestTLSConfigRotatesRootCAs(t *testing.T) { + bundle1 := testutil.NewCertsBundle() + defer bundle1.Close() + + bundle2 := testutil.NewCertsBundle() + defer bundle2.Close() + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + rotatedCh := make(chan struct{}, 1) + notifyFunc := func() { + rotatedCh <- struct{}{} + } + clientSource := MustNew( + WithClientRootCAs(bundle1.CACert.Name()), + WithClientCert(bundle1.ClientCert.Name(), bundle1.ClientKey.Name()), + WithRefresh(1*time.Second), + WithNotifyFunc(notifyFunc), + ) + + serverSource := serverfilesource.MustNew( + serverfilesource.WithX509KeyPair(bundle1.ServerCert.Name(), bundle1.ServerKey.Name()), + serverfilesource.WithClientAuthFile(bundle1.CACert.Name()), + serverfilesource.WithClientCRLFile(bundle1.CAEmptyCRL.Name()), + serverfilesource.WithRefresh(1*time.Second), + ) + ts.TLS = servertls.MustNewServerConfig(slog.Default(), serverSource) + ts.StartTLS() + + tlsConfig, err := tlsclient.NewTLSConfig(slog.Default(), clientSource) + require.NoError(t, err) + + client := &http.Client{ + Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)), + } + resp, err := client.Get(ts.URL) + require.NoError(t, err) + resp.Body.Close() + + require.NoError(t, os.Rename(bundle2.CACert.Name(), bundle1.CACert.Name())) + + select { + case <-rotatedCh: + time.Sleep(100 * time.Millisecond) + case <-time.After(3 * time.Second): + t.Fatal("expected certificate change notification") + } + + client = &http.Client{ + Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)), + } + resp, err = client.Get(ts.URL) + if resp != nil { + resp.Body.Close() + } + require.Error(t, err) + + msg := err.Error() + ok := strings.Contains(msg, "certificate signed by unknown authority") || + strings.Contains(msg, "unknown certificate authority") + require.Truef(t, ok, "unexpected error: %q", msg) +} diff --git a/tls/client/option.go b/tls/client/option.go index c8c656f..cb2a389 100644 --- a/tls/client/option.go +++ b/tls/client/option.go @@ -4,8 +4,27 @@ import "crypto/tls" type TLSClientConfigOption func(*tls.Config) +var ( + http2OnlyNextProtos = []string{"h2"} + http2AndHTTP11NextProtos = []string{"h2", "http/1.1"} +) + func WithTLSClientNextProtos(nextProto []string) TLSClientConfigOption { return func(c *tls.Config) { c.NextProtos = nextProto } } + +func WithTLSClientHTTP2() TLSClientConfigOption { + return WithTLSClientNextProtos(http2OnlyNextProtos) +} + +func WithTLSClientHTTP2AndHTTP11() TLSClientConfigOption { + return WithTLSClientNextProtos(http2AndHTTP11NextProtos) +} + +func WithTLSServerName(serverName string) TLSClientConfigOption { + return func(c *tls.Config) { + c.ServerName = serverName + } +} diff --git a/tls/client/roundtripper.go b/tls/client/roundtripper.go index e550fed..659d695 100644 --- a/tls/client/roundtripper.go +++ b/tls/client/roundtripper.go @@ -4,8 +4,6 @@ import ( "crypto/tls" "crypto/x509" "net/http" - - "github.com/grepplabs/cert-source/tls/client/source" ) type RoundTripper struct { @@ -20,21 +18,6 @@ func WithClientTLSConfig(tlsClientConfig *tls.Config) RoundTripperOption { } } -func WithClientCertsStore(source *source.ClientCertsStore) RoundTripperOption { - return func(rt *RoundTripper) { - cs := source.LoadClientCerts() - if rt.transport.TLSClientConfig == nil { - // nolint:gosec - rt.transport.TLSClientConfig = &tls.Config{} - } - rt.transport.TLSClientConfig.RootCAs = cs.RootCAs - rt.transport.TLSClientConfig.InsecureSkipVerify = cs.InsecureSkipVerify - rt.transport.TLSClientConfig.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return source.LoadClientCerts().Certificate, nil - } - } -} - func WithSystemRootCA(cert *x509.Certificate) RoundTripperOption { certPool, err := x509.SystemCertPool() if err != nil { diff --git a/tls/server/config/config_test.go b/tls/server/config/config_test.go index 735a99b..4896e75 100644 --- a/tls/server/config/config_test.go +++ b/tls/server/config/config_test.go @@ -53,7 +53,7 @@ func TestGetServerTLSOptionsConfig(t *testing.T) { Key: bundle.ServerKey.Name(), Cert: bundle.ServerCert.Name(), }, - }, tlsserver.WithTLSServerNextProtos([]string{"h2"}), + }, tlsserver.WithTLSServerHTTP2(), tlsserver.WithTLSServerMinVersion(tls.VersionTLS13), tlsserver.WithTLSServerCipherSuites([]uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}), tlsserver.WithTLSServerCurvePreferences([]tls.CurveID{tls.CurveP256, tls.CurveP384}), @@ -69,6 +69,22 @@ func TestGetServerTLSOptionsConfig(t *testing.T) { require.Equal(t, []tls.CurveID{tls.CurveP256, tls.CurveP384}, tlsConfig.CurvePreferences) } +func TestGetServerTLSHTTP2AndHTTP11OptionsConfig(t *testing.T) { + bundle := testutil.NewCertsBundle() + defer bundle.Close() + + tlsConfig, err := GetServerTLSConfig(slog.Default(), &config.TLSServerConfig{ + Enable: true, + Refresh: 0, + File: config.TLSServerFiles{ + Key: bundle.ServerKey.Name(), + Cert: bundle.ServerCert.Name(), + }, + }, tlsserver.WithTLSServerHTTP2AndHTTP11()) + require.NoError(t, err) + require.Equal(t, []string{"h2", "http/1.1"}, tlsConfig.NextProtos) +} + func TestGetServerTLSVerifyPeerCertificateConfig(t *testing.T) { bundle := testutil.NewCertsBundle() defer bundle.Close() diff --git a/tls/server/option.go b/tls/server/option.go index 5c307e8..07c2267 100644 --- a/tls/server/option.go +++ b/tls/server/option.go @@ -7,11 +7,25 @@ import ( type TLSServerConfigOption func(*tls.Config) +var ( + http2OnlyNextProtos = []string{"h2"} + http2AndHTTP11NextProtos = []string{"h2", "http/1.1"} +) + func WithTLSServerNextProtos(nextProto []string) TLSServerConfigOption { return func(c *tls.Config) { c.NextProtos = nextProto } } + +func WithTLSServerHTTP2() TLSServerConfigOption { + return WithTLSServerNextProtos(http2OnlyNextProtos) +} + +func WithTLSServerHTTP2AndHTTP11() TLSServerConfigOption { + return WithTLSServerNextProtos(http2AndHTTP11NextProtos) +} + func WithTLSServerMinVersion(minVersion uint16) TLSServerConfigOption { return func(c *tls.Config) { c.MinVersion = minVersion