Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ import (
)

func main() {
tlsClientConfigFunc, err := tlsclientconfig.GetTLSClientConfigFunc(slog.Default(), &tlsconfig.TLSClientConfig{
tlsClientConfig, err := tlsclientconfig.GetTLSClientConfig(slog.Default(), &tlsconfig.TLSClientConfig{
Enable: true,
Refresh: 1 * time.Second,
InsecureSkipVerify: false,
Expand All @@ -92,7 +92,7 @@ func main() {
if err != nil {
log.Fatalln(err)
}
transport := tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsClientConfigFunc()))
transport := tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsClientConfig))
client := &http.Client{Transport: transport}
resp, err := client.Get("https://localhost:8443")
if err != nil {
Expand Down
81 changes: 66 additions & 15 deletions tls/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tlsclient

import (
"crypto/tls"
"crypto/x509"
"errors"
"log/slog"
"time"
Expand All @@ -13,36 +14,86 @@ const (
initLoadTimeout = 5 * time.Second
)

type TLSClientConfigFunc func() *tls.Config
type Provider struct {
store *source.ClientCertsStore
opts []TLSClientConfigOption
}

func NewTLSClientConfigFunc(logger *slog.Logger, src source.ClientCertsSource, opts ...TLSClientConfigOption) (TLSClientConfigFunc, error) {
func NewProvider(logger *slog.Logger, src source.ClientCertsSource, opts ...TLSClientConfigOption) (*Provider, error) {
store, err := NewTLSClientCertsStore(logger, src)
if err != nil {
return nil, err
}
return &Provider{
store: store,
opts: opts,
}, nil
}

func (p *Provider) LoadClientCerts() source.ClientCerts {
return p.store.LoadClientCerts()
}

func (p *Provider) TLSConfig() *tls.Config {
cs := p.LoadClientCerts()
x := &tls.Config{
// Root CAs are read dynamically in VerifyConnection to support rotation.
// nolint:gosec
InsecureSkipVerify: true,
}
var getClientCertificateFunc func(info *tls.CertificateRequestInfo) (*tls.Certificate, error)
if store.LoadClientCerts().Certificate != nil {
if cs.Certificate != nil {
// Set function only when client certificate is available.
// TLS 1.3 checks if GetClientCertificate function is nil, if it is not nil,
// it assumes client certificate is available which call cause the panic if nil is returned.
//nolint:unparam
getClientCertificateFunc = func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return store.LoadClientCerts().Certificate, nil
return p.LoadClientCerts().Certificate, nil
}
}
return func() *tls.Config {
cs := store.LoadClientCerts()
x := &tls.Config{
RootCAs: cs.RootCAs,
// nolint:gosec
InsecureSkipVerify: cs.InsecureSkipVerify,
GetClientCertificate: getClientCertificateFunc,
x.GetClientCertificate = getClientCertificateFunc
for _, opt := range p.opts {
opt(x)
}
x.VerifyConnection = p.verifyConnection(x.ServerName)
return x
}

func (p *Provider) verifyConnection(configuredServerName string) func(tls.ConnectionState) error {
return func(cs tls.ConnectionState) error {
clientCerts := p.LoadClientCerts()
if clientCerts.InsecureSkipVerify {
return nil
}
for _, opt := range opts {
opt(x)
if len(cs.PeerCertificates) == 0 {
return errors.New("tls: no peer certificates")
}
return x
}, nil

serverName := cs.ServerName
if serverName == "" {
serverName = configuredServerName
}

opts := x509.VerifyOptions{
Roots: clientCerts.RootCAs,
DNSName: serverName,
Intermediates: x509.NewCertPool(),
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
for _, cert := range cs.PeerCertificates[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := cs.PeerCertificates[0].Verify(opts)
return err
}
}

func NewTLSConfig(logger *slog.Logger, src source.ClientCertsSource, opts ...TLSClientConfigOption) (*tls.Config, error) {
provider, err := NewProvider(logger, src, opts...)
if err != nil {
return nil, err
}
return provider.TLSConfig(), nil
}

func NewTLSClientCertsStore(logger *slog.Logger, src source.ClientCertsSource) (*source.ClientCertsStore, error) {
Expand Down
25 changes: 23 additions & 2 deletions tls/client/config/config.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package config

import (
"crypto/tls"
"fmt"
"log/slog"
"net/http"

"github.com/grepplabs/cert-source/config"
tlsclient "github.com/grepplabs/cert-source/tls/client"
"github.com/grepplabs/cert-source/tls/client/filesource"
)

func GetTLSClientConfigFunc(logger *slog.Logger, conf *config.TLSClientConfig, opts ...tlsclient.TLSClientConfigOption) (tlsclient.TLSClientConfigFunc, error) {
func GetTLSClientConfig(logger *slog.Logger, conf *config.TLSClientConfig, opts ...tlsclient.TLSClientConfigOption) (*tls.Config, error) {
if !conf.Enable {
return nil, nil
}
Expand All @@ -25,5 +27,24 @@ func GetTLSClientConfigFunc(logger *slog.Logger, conf *config.TLSClientConfig, o
if err != nil {
return nil, fmt.Errorf("setup client cert file source: %w", err)
}
return tlsclient.NewTLSClientConfigFunc(logger, fs, opts...)
return tlsclient.NewTLSConfig(logger, fs, opts...)
}

func NewRoundTripper(logger *slog.Logger, conf *config.TLSClientConfig, opts ...tlsclient.TLSClientConfigOption) (*tlsclient.RoundTripper, error) {
tlsConfig, err := GetTLSClientConfig(logger, conf, opts...)
if err != nil {
return nil, err
}
if tlsConfig == nil {
return tlsclient.NewDefaultRoundTripper(), nil
}
return tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)), nil
}

func NewClient(logger *slog.Logger, conf *config.TLSClientConfig, opts ...tlsclient.TLSClientConfigOption) (*http.Client, error) {
transport, err := NewRoundTripper(logger, conf, opts...)
if err != nil {
return nil, err
}
return &http.Client{Transport: transport}, nil
}
41 changes: 29 additions & 12 deletions tls/client/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@ import (
func TestGetClientTLSConfig(t *testing.T) {
bundle := testutil.NewCertsBundle()
defer bundle.Close()
tlsConfigFunc, err := GetTLSClientConfigFunc(slog.Default(), &config.TLSClientConfig{
tlsConfig, err := GetTLSClientConfig(slog.Default(), &config.TLSClientConfig{
Enable: true,
Refresh: 0,
File: config.TLSClientFiles{
Key: bundle.ClientKey.Name(),
Cert: bundle.ClientCert.Name(),
RootCAs: bundle.CACert.Name(),
},
}, tlsclient.WithTLSClientNextProtos([]string{"h2"}))
}, tlsclient.WithTLSClientHTTP2(), tlsclient.WithTLSServerName("localhost"))
require.NoError(t, err)
tlsConfig := tlsConfigFunc()
require.NotNil(t, tlsConfig.RootCAs)
require.True(t, tlsConfig.InsecureSkipVerify)
require.NotNil(t, tlsConfig.VerifyConnection)
require.Equal(t, []string{"h2"}, tlsConfig.NextProtos)
require.Equal(t, "localhost", tlsConfig.ServerName)

clientCert, err := tlsConfig.GetClientCertificate(nil)
require.NoError(t, err)
Expand All @@ -35,33 +36,49 @@ func TestGetClientTLSConfig(t *testing.T) {
func TestGetClientTLSConfigNoConfig(t *testing.T) {
bundle := testutil.NewCertsBundle()
defer bundle.Close()
tlsConfigFunc, err := GetTLSClientConfigFunc(slog.Default(), &config.TLSClientConfig{
tlsConfig, err := GetTLSClientConfig(slog.Default(), &config.TLSClientConfig{
Enable: true,
Refresh: 0,
File: config.TLSClientFiles{},
})
require.NoError(t, err)
tlsConfig := tlsConfigFunc()
require.Nil(t, tlsConfig.RootCAs)
require.True(t, tlsConfig.InsecureSkipVerify)
require.Nil(t, tlsConfig.GetClientCertificate)
}

func TestGetClientTLSConfigSkipVerify(t *testing.T) {
bundle := testutil.NewCertsBundle()
defer bundle.Close()
tlsConfigFunc, err := GetTLSClientConfigFunc(slog.Default(), &config.TLSClientConfig{
Enable: true,
Refresh: 0,
tlsConfig, err := GetTLSClientConfig(slog.Default(), &config.TLSClientConfig{
Enable: true,
Refresh: 0,
InsecureSkipVerify: true,
File: config.TLSClientFiles{
Key: bundle.ClientKey.Name(),
Cert: bundle.ClientCert.Name(),
},
})
require.NoError(t, err)
tlsConfig := tlsConfigFunc()
require.Nil(t, tlsConfig.RootCAs)
require.True(t, tlsConfig.InsecureSkipVerify)

clientCert, err := tlsConfig.GetClientCertificate(nil)
require.NoError(t, err)
require.NotNil(t, clientCert)
}

func TestGetClientTLSHTTP2AndHTTP11Config(t *testing.T) {
bundle := testutil.NewCertsBundle()
defer bundle.Close()

tlsConfig, err := GetTLSClientConfig(slog.Default(), &config.TLSClientConfig{
Enable: true,
Refresh: 0,
File: config.TLSClientFiles{
Key: bundle.ClientKey.Name(),
Cert: bundle.ClientCert.Name(),
RootCAs: bundle.CACert.Name(),
},
}, tlsclient.WithTLSClientHTTP2AndHTTP11())
require.NoError(t, err)
require.Equal(t, []string{"h2", "http/1.1"}, tlsConfig.NextProtos)
}
76 changes: 71 additions & 5 deletions tls/client/filesource/filesource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestCertRotation(t *testing.T) {
WithNotifyFunc(notifyFunc),
).(*fileSource)

clientCertsStore, err := tlsclient.NewTLSClientCertsStore(slog.Default(), clientSource)
tlsConfig, err := tlsclient.NewTLSConfig(slog.Default(), clientSource)
require.NoError(t, err)

serverSource := serverfilesource.MustNew(
Expand All @@ -57,7 +57,7 @@ func TestCertRotation(t *testing.T) {

// when
client := &http.Client{
Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientCertsStore(clientCertsStore)),
Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)),
}
resp, err := client.Do(req)
require.NoError(t, err)
Expand All @@ -77,7 +77,7 @@ func TestCertRotation(t *testing.T) {
// old client - bad certificate
// create new client as connection can be kept alive
client = &http.Client{
Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientCertsStore(clientCertsStore)),
Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)),
}
// nolint:bodyclose
_, err = client.Do(req)
Expand Down Expand Up @@ -106,7 +106,7 @@ func TestKeyEncryption(t *testing.T) {
WithSystemPool(true),
).(*fileSource)

clientCertsStore, err := tlsclient.NewTLSClientCertsStore(slog.Default(), clientSource)
tlsConfig, err := tlsclient.NewTLSConfig(slog.Default(), clientSource)
require.NoError(t, err)

serverSource := serverfilesource.MustNew(
Expand All @@ -124,9 +124,75 @@ func TestKeyEncryption(t *testing.T) {

// when
client := &http.Client{
Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientCertsStore(clientCertsStore)),
Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)),
}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
}

func TestTLSConfigRotatesRootCAs(t *testing.T) {
bundle1 := testutil.NewCertsBundle()
defer bundle1.Close()

bundle2 := testutil.NewCertsBundle()
defer bundle2.Close()

ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()

rotatedCh := make(chan struct{}, 1)
notifyFunc := func() {
rotatedCh <- struct{}{}
}
clientSource := MustNew(
WithClientRootCAs(bundle1.CACert.Name()),
WithClientCert(bundle1.ClientCert.Name(), bundle1.ClientKey.Name()),
WithRefresh(1*time.Second),
WithNotifyFunc(notifyFunc),
)

serverSource := serverfilesource.MustNew(
serverfilesource.WithX509KeyPair(bundle1.ServerCert.Name(), bundle1.ServerKey.Name()),
serverfilesource.WithClientAuthFile(bundle1.CACert.Name()),
serverfilesource.WithClientCRLFile(bundle1.CAEmptyCRL.Name()),
serverfilesource.WithRefresh(1*time.Second),
)
ts.TLS = servertls.MustNewServerConfig(slog.Default(), serverSource)
ts.StartTLS()

tlsConfig, err := tlsclient.NewTLSConfig(slog.Default(), clientSource)
require.NoError(t, err)

client := &http.Client{
Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)),
}
resp, err := client.Get(ts.URL)
require.NoError(t, err)
resp.Body.Close()

require.NoError(t, os.Rename(bundle2.CACert.Name(), bundle1.CACert.Name()))

select {
case <-rotatedCh:
time.Sleep(100 * time.Millisecond)
case <-time.After(3 * time.Second):
t.Fatal("expected certificate change notification")
}

client = &http.Client{
Transport: tlsclient.NewDefaultRoundTripper(tlsclient.WithClientTLSConfig(tlsConfig)),
}
resp, err = client.Get(ts.URL)
if resp != nil {
resp.Body.Close()
}
require.Error(t, err)

msg := err.Error()
ok := strings.Contains(msg, "certificate signed by unknown authority") ||
strings.Contains(msg, "unknown certificate authority")
require.Truef(t, ok, "unexpected error: %q", msg)
}
19 changes: 19 additions & 0 deletions tls/client/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,27 @@ import "crypto/tls"

type TLSClientConfigOption func(*tls.Config)

var (
http2OnlyNextProtos = []string{"h2"}
http2AndHTTP11NextProtos = []string{"h2", "http/1.1"}
)

func WithTLSClientNextProtos(nextProto []string) TLSClientConfigOption {
return func(c *tls.Config) {
c.NextProtos = nextProto
}
}

func WithTLSClientHTTP2() TLSClientConfigOption {
return WithTLSClientNextProtos(http2OnlyNextProtos)
}

func WithTLSClientHTTP2AndHTTP11() TLSClientConfigOption {
return WithTLSClientNextProtos(http2AndHTTP11NextProtos)
}

func WithTLSServerName(serverName string) TLSClientConfigOption {
return func(c *tls.Config) {
c.ServerName = serverName
}
}
Loading