diff --git a/connection/http.go b/connection/http.go index 1009d7e3..94e9a2b4 100644 --- a/connection/http.go +++ b/connection/http.go @@ -1,6 +1,8 @@ package connection import ( + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" netHTTP "net/http" @@ -38,7 +40,42 @@ type TLSConfig struct { } func (t TLSConfig) IsEmpty() bool { - return t.CA.IsEmpty() || t.Cert.IsEmpty() || t.Key.IsEmpty() + hasClientCert := !t.Cert.IsEmpty() || !t.Key.IsEmpty() + return !t.InsecureSkipVerify && t.HandshakeTimeout == 0 && t.CA.IsEmpty() && !hasClientCert +} + +func (t TLSConfig) clientConfig() (*tls.Config, error) { + if t.IsEmpty() { + return nil, nil + } + + config := &tls.Config{InsecureSkipVerify: t.InsecureSkipVerify} + if t.CA.ValueStatic != "" { + certPool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + if !certPool.AppendCertsFromPEM([]byte(t.CA.ValueStatic)) { + return nil, fmt.Errorf("failed to append ca certificate") + } + config.RootCAs = certPool + } + + hasCert := !t.Cert.IsEmpty() + hasKey := !t.Key.IsEmpty() + if hasCert != hasKey { + return nil, fmt.Errorf("both client certificate and key must be provided") + } + if hasCert { + cert, err := tls.X509KeyPair([]byte(t.Cert.ValueStatic), []byte(t.Key.ValueStatic)) + if err != nil { + return nil, fmt.Errorf("failed to create client certificate: %w", err) + } + config.Certificates = []tls.Certificate{cert} + } + + return config, nil } // +kubebuilder:object:generate=true @@ -348,19 +385,49 @@ func (h *HTTPConnection) Hydrate(ctx ConnectionContext, namespace string) (*HTTP return h, nil } -func (h HTTPConnection) Transport(opts ...types.ClientOption) netHTTP.RoundTripper { +func (h HTTPConnection) Transport(opts ...types.ClientOption) (netHTTP.RoundTripper, error) { return h.TransportWithContext(nil, opts...) } -func (h HTTPConnection) TransportWithContext(ctx any, opts ...types.ClientOption) netHTTP.RoundTripper { +func (h HTTPConnection) TransportWithContext(ctx any, opts ...types.ClientOption) (netHTTP.RoundTripper, error) { o := types.NewClientOptions(opts...) - base := applyHTTPObservability(ctx, "http", &netHTTP.Transport{}, o.HARCollector) + feature := o.Feature + if feature == "" { + feature = "http" + } + + base, err := h.transport() + if err != nil { + return nil, err + } + + base = applyHTTPObservability(ctx, feature, base, o.HARCollector) rt := &httpConnectionRoundTripper{ HTTPConnection: h, Base: base, - TokenTransport: harTokenTransport(ctx, "http", o.HARCollector), + TokenTransport: harTokenTransport(ctx, feature, o.HARCollector), } - return rt + return rt, nil +} + +func (h HTTPConnection) transport() (netHTTP.RoundTripper, error) { + base, ok := netHTTP.DefaultTransport.(*netHTTP.Transport) + if !ok { + base = &netHTTP.Transport{} + } + + transport := base.Clone() + if !h.TLS.IsEmpty() { + tlsConfig, err := h.TLS.clientConfig() + if err != nil { + return nil, err + } + transport.TLSClientConfig = tlsConfig + if h.TLS.HandshakeTimeout != 0 { + transport.TLSHandshakeTimeout = h.TLS.HandshakeTimeout + } + } + return transport, nil } type httpConnectionRoundTripper struct { @@ -394,10 +461,6 @@ func (rt *httpConnectionRoundTripper) RoundTrip(req *netHTTP.Request) (*netHTTP. } } - if !conn.TLS.IsEmpty() { - rt.TLS = conn.TLS - } - if conn.AWSSigV4 != nil && conn.AwsConfig != nil { base = middlewares.NewAWSSigv4Transport(middlewares.AWSSigv4Config{ Region: conn.AwsConfig.Region, diff --git a/connection/http_test.go b/connection/http_test.go index 65674e6d..fa0d1489 100644 --- a/connection/http_test.go +++ b/connection/http_test.go @@ -2,13 +2,181 @@ package connection import ( gocontext "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net/http" + "net/http/httptest" "testing" + "time" "github.com/flanksource/duty/models" "github.com/flanksource/duty/types" "github.com/onsi/gomega" ) +func TestTLSConfigIsEmpty(t *testing.T) { + tests := []struct { + name string + config TLSConfig + expects bool + }{ + {name: "empty", config: TLSConfig{}, expects: true}, + {name: "insecure skip verify", config: TLSConfig{InsecureSkipVerify: true}}, + {name: "handshake timeout", config: TLSConfig{HandshakeTimeout: time.Second}}, + {name: "ca only", config: TLSConfig{CA: types.EnvVar{ValueStatic: "ca"}}}, + {name: "cert only", config: TLSConfig{Cert: types.EnvVar{ValueStatic: "cert"}}}, + {name: "key only", config: TLSConfig{Key: types.EnvVar{ValueStatic: "key"}}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := gomega.NewWithT(t) + g.Expect(tc.config.IsEmpty()).To(gomega.Equal(tc.expects)) + }) + } +} + +func TestHTTPConnectionTransportTLS(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: server.Certificate().Raw}) + + tests := []struct { + name string + config TLSConfig + expectsErr bool + }{ + { + name: "default TLS rejects unknown CA", + expectsErr: true, + }, + { + name: "custom CA", + config: TLSConfig{CA: types.EnvVar{ValueStatic: string(certPEM)}}, + }, + { + name: "insecure skip verify", + config: TLSConfig{InsecureSkipVerify: true}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + g := gomega.NewWithT(t) + rt, err := HTTPConnection{TLS: tc.config}.Transport() + g.Expect(err).ToNot(gomega.HaveOccurred()) + client := &http.Client{Transport: rt} + + resp, err := client.Get(server.URL) + if tc.expectsErr { + g.Expect(err).To(gomega.HaveOccurred()) + return + } + + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer resp.Body.Close() + g.Expect(resp.StatusCode).To(gomega.Equal(http.StatusOK)) + }) + } +} + +func TestHTTPConnectionTransportMTLS(t *testing.T) { + g := gomega.NewWithT(t) + clientCertPEM, clientKeyPEM, clientCAPool := newClientCertificate(t) + + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + http.Error(w, "client certificate required", http.StatusUnauthorized) + return + } + w.WriteHeader(http.StatusOK) + })) + server.TLS = &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: clientCAPool, + } + server.StartTLS() + defer server.Close() + + serverCAPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: server.Certificate().Raw}) + rt, err := HTTPConnection{TLS: TLSConfig{ + CA: types.EnvVar{ValueStatic: string(serverCAPEM)}, + Cert: types.EnvVar{ValueStatic: string(clientCertPEM)}, + Key: types.EnvVar{ValueStatic: string(clientKeyPEM)}, + }}.Transport() + g.Expect(err).ToNot(gomega.HaveOccurred()) + + client := &http.Client{Transport: rt} + resp, err := client.Get(server.URL) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer resp.Body.Close() + g.Expect(resp.StatusCode).To(gomega.Equal(http.StatusOK)) +} + +func newClientCertificate(t *testing.T) (certPEM, keyPEM []byte, caPool *x509.CertPool) { + t.Helper() + + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate ca key: %v", err) + } + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test-ca"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + } + caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + if err != nil { + t.Fatalf("create ca certificate: %v", err) + } + caCert, err := x509.ParseCertificate(caDER) + if err != nil { + t.Fatalf("parse ca certificate: %v", err) + } + + clientKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate client key: %v", err) + } + clientTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{CommonName: "test-client"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + clientDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caCert, &clientKey.PublicKey, caKey) + if err != nil { + t.Fatalf("create client certificate: %v", err) + } + + caPool = x509.NewCertPool() + caPool.AddCert(caCert) + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientDER}) + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)}) + return certPEM, keyPEM, caPool +} + +func TestCreateHTTPClientWithTLS(t *testing.T) { + g := gomega.NewWithT(t) + client, err := CreateHTTPClient(nil, HTTPConnection{TLS: TLSConfig{InsecureSkipVerify: true}}) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(client).ToNot(gomega.BeNil()) +} + func TestHTTPConnectionPretty(t *testing.T) { tests := []struct { name string diff --git a/connection/prometheus.go b/connection/prometheus.go index 4aea54ff..9af99f03 100644 --- a/connection/prometheus.go +++ b/connection/prometheus.go @@ -36,9 +36,14 @@ func (p *PrometheusConnection) Populate(ctx ConnectionContext) error { } func (p *PrometheusConnection) NewClient(ctx context.Context, opts ...types.ClientOption) (v1.API, error) { + rt, err := p.HTTPConnection.TransportWithContext(ctx, opts...) + if err != nil { + return nil, err + } + cfg := api.Config{ Address: p.HTTPConnection.URL, - RoundTripper: p.HTTPConnection.TransportWithContext(ctx, opts...), + RoundTripper: rt, } client, err := api.NewClient(cfg)