diff --git a/plugin/gthulhu/auth.go b/plugin/gthulhu/auth.go index 90c101d..dd8a2c6 100644 --- a/plugin/gthulhu/auth.go +++ b/plugin/gthulhu/auth.go @@ -2,6 +2,7 @@ package gthulhu import ( "bytes" + "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" @@ -11,6 +12,8 @@ import ( "os" "strings" "time" + + reg "github.com/Gthulhu/plugin/plugin/internal/registry" ) // TokenRequest represents the request structure for JWT token generation @@ -47,20 +50,65 @@ type JWTClient struct { authEnabled bool } -// NewJWTClient creates a new JWT client +// NewJWTClient creates a new JWT client. When mtlsCfg.Enable is true the +// underlying HTTP client is configured with mutual TLS so the plugin +// authenticates itself to the API server and verifies the server certificate +// against the shared CA. func NewJWTClient( publicKeyPath, apiBaseURL string, authEnabled bool, -) *JWTClient { + mtlsCfg reg.MTLSConfig, +) (*JWTClient, error) { + httpClient := &http.Client{ + Timeout: 30 * time.Second, + } + + if mtlsCfg.Enable { + tlsClient, err := buildMTLSClient(mtlsCfg) + if err != nil { + return nil, err + } + httpClient = tlsClient + } + return &JWTClient{ publicKeyPath: publicKeyPath, apiBaseURL: strings.TrimSuffix(apiBaseURL, "/"), - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, - authEnabled: authEnabled, + httpClient: httpClient, + authEnabled: authEnabled, + }, nil +} + +// buildMTLSClient constructs an HTTP client with mutual TLS configured. +func buildMTLSClient(mtlsCfg reg.MTLSConfig) (*http.Client, error) { + cert, err := tls.X509KeyPair([]byte(mtlsCfg.CertPem), []byte(mtlsCfg.KeyPem)) + if err != nil { + return nil, fmt.Errorf("load mTLS client certificate: %w", err) + } + + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM([]byte(mtlsCfg.CAPem)) { + return nil, fmt.Errorf("parse mTLS CA certificate") + } + + tlsCfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caPool, + MinVersion: tls.VersionTLS12, } + + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + if !ok { + return nil, fmt.Errorf("unexpected default transport type %T", http.DefaultTransport) + } + mtlsTransport := defaultTransport.Clone() + mtlsTransport.TLSClientConfig = tlsCfg + + return &http.Client{ + Timeout: 30 * time.Second, + Transport: mtlsTransport, + }, nil } // loadPublicKey loads the RSA public key from PEM file @@ -161,12 +209,18 @@ func (c *JWTClient) GetAuthenticatedClient() (*http.Client, error) { return nil, err } + // Preserve any custom transport (e.g. mTLS) already configured on the client. + transport := c.httpClient.Transport + if transport == nil { + transport = http.DefaultTransport + } + // Create a custom transport that adds the Authorization header client := &http.Client{ Timeout: 30 * time.Second, Transport: &authenticatedTransport{ token: c.token, - transport: http.DefaultTransport, + transport: transport, }, } diff --git a/plugin/gthulhu/auth_test.go b/plugin/gthulhu/auth_test.go new file mode 100644 index 0000000..4b033f2 --- /dev/null +++ b/plugin/gthulhu/auth_test.go @@ -0,0 +1,210 @@ +package gthulhu + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + reg "github.com/Gthulhu/plugin/plugin/internal/registry" +) + +// authTestCerts holds PEM-encoded self-signed CA + leaf cert for unit testing. +type authTestCerts struct { + caPEM string + certPEM string + keyPEM string +} + +// generateAuthTestCerts creates a minimal self-signed CA and a leaf cert/key signed by it. +func generateAuthTestCerts(t *testing.T) authTestCerts { + t.Helper() + + notBefore := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + notAfter := time.Date(2100, 1, 1, 0, 0, 0, 0, time.UTC) + + // Generate CA key + cert + caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate CA key: %v", err) + } + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test-ca"}, + NotBefore: notBefore, + NotAfter: notAfter, + IsCA: true, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + } + caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + if err != nil { + t.Fatalf("create CA cert: %v", err) + } + caCert, err := x509.ParseCertificate(caDER) + if err != nil { + t.Fatalf("parse CA cert: %v", err) + } + caPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caDER})) + + // Generate leaf key + cert signed by CA + leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate leaf key: %v", err) + } + leafTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{CommonName: "test-leaf"}, + NotBefore: notBefore, + NotAfter: notAfter, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + } + leafDER, err := x509.CreateCertificate(rand.Reader, leafTemplate, caCert, &leafKey.PublicKey, caKey) + if err != nil { + t.Fatalf("create leaf cert: %v", err) + } + certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: leafDER})) + + leafKeyDER, err := x509.MarshalECPrivateKey(leafKey) + if err != nil { + t.Fatalf("marshal leaf key: %v", err) + } + keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: leafKeyDER})) + + return authTestCerts{caPEM: caPEM, certPEM: certPEM, keyPEM: keyPEM} +} + +func TestNewJWTClientMTLSDisabled(t *testing.T) { + mtlsCfg := reg.MTLSConfig{Enable: false} + c, err := NewJWTClient("", "http://localhost", false, mtlsCfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if c == nil { + t.Fatal("expected non-nil JWTClient") + } + // Transport should be nil (plain http.Client default) + if c.httpClient.Transport != nil { + t.Errorf("expected nil transport for plain HTTP, got %T", c.httpClient.Transport) + } +} + +func TestNewJWTClientMTLSBadCert(t *testing.T) { + mtlsCfg := reg.MTLSConfig{ + Enable: true, + CertPem: "not-valid-pem", + KeyPem: "not-valid-pem", + CAPem: "not-valid-pem", + } + _, err := NewJWTClient("", "https://localhost", false, mtlsCfg) + if err == nil { + t.Fatal("expected error for invalid cert PEM, got nil") + } + if !strings.Contains(err.Error(), "load mTLS client certificate") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestNewJWTClientMTLSBadCA(t *testing.T) { + certs := generateAuthTestCerts(t) + mtlsCfg := reg.MTLSConfig{ + Enable: true, + CertPem: certs.certPEM, + KeyPem: certs.keyPEM, + CAPem: "not-a-valid-ca-pem", + } + _, err := NewJWTClient("", "https://localhost", false, mtlsCfg) + if err == nil { + t.Fatal("expected error for invalid CA PEM, got nil") + } + if !strings.Contains(err.Error(), "parse mTLS CA certificate") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestNewJWTClientMTLSEnabled(t *testing.T) { + certs := generateAuthTestCerts(t) + mtlsCfg := reg.MTLSConfig{ + Enable: true, + CertPem: certs.certPEM, + KeyPem: certs.keyPEM, + CAPem: certs.caPEM, + } + c, err := NewJWTClient("", "https://localhost", false, mtlsCfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if c == nil { + t.Fatal("expected non-nil JWTClient") + } + // Transport should be an mTLS-capable *http.Transport + if _, ok := c.httpClient.Transport.(*http.Transport); !ok { + t.Errorf("expected *http.Transport, got %T", c.httpClient.Transport) + } +} + +// TestNewJWTClientMTLSEndToEnd verifies that a JWTClient configured with mTLS +// can successfully complete a round-trip against an mTLS-enforcing httptest server. +func TestNewJWTClientMTLSEndToEnd(t *testing.T) { + certs := generateAuthTestCerts(t) + + // Build mTLS test server that requires a client cert signed by our CA. + serverCert, err := tls.X509KeyPair([]byte(certs.certPEM), []byte(certs.keyPEM)) + if err != nil { + t.Fatalf("load server cert: %v", err) + } + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM([]byte(certs.caPEM)) { + t.Fatal("append CA cert") + } + serverTLSCfg := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: caPool, + } + + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + server.TLS = serverTLSCfg + server.StartTLS() + defer server.Close() + + mtlsCfg := reg.MTLSConfig{ + Enable: true, + CertPem: certs.certPEM, + KeyPem: certs.keyPEM, + CAPem: certs.caPEM, + } + // authEnabled=false so no token fetch is attempted; we only verify the TLS handshake. + c, err := NewJWTClient("", server.URL, false, mtlsCfg) + if err != nil { + t.Fatalf("NewJWTClient: %v", err) + } + + resp, err := c.MakeAuthenticatedRequest("GET", server.URL, nil) + if err != nil { + t.Fatalf("MakeAuthenticatedRequest over mTLS: %v", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Logf("resp.Body.Close(): %v", err) + } + }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d; want %d", resp.StatusCode, http.StatusOK) + } +} diff --git a/plugin/gthulhu/gthulhu.go b/plugin/gthulhu/gthulhu.go index 8da1811..0b6c332 100644 --- a/plugin/gthulhu/gthulhu.go +++ b/plugin/gthulhu/gthulhu.go @@ -34,6 +34,7 @@ func init() { config.APIConfig.PublicKeyPath, config.APIConfig.BaseURL, config.APIConfig.AuthEnabled, + config.APIConfig.MTLS, ) if err != nil { return nil, err @@ -312,8 +313,13 @@ func (g *GthulhuPlugin) InitJWTClient( publicKeyPath, apiBaseURL string, authEnabled bool, + mtlsCfg reg.MTLSConfig, ) error { - g.jwtClient = NewJWTClient(publicKeyPath, apiBaseURL, authEnabled) + client, err := NewJWTClient(publicKeyPath, apiBaseURL, authEnabled, mtlsCfg) + if err != nil { + return err + } + g.jwtClient = client return nil } diff --git a/plugin/internal/registry/registry.go b/plugin/internal/registry/registry.go index 49ba1f0..e65d0c4 100644 --- a/plugin/internal/registry/registry.go +++ b/plugin/internal/registry/registry.go @@ -38,12 +38,23 @@ type Scheduler struct { SliceNsMin uint64 `yaml:"slice_ns_min"` } +// MTLSConfig holds the mutual TLS configuration used for plugin → API server communication. +// CertPem and KeyPem are the plugin's own certificate/key pair signed by the private CA. +// CAPem is the private CA certificate used to verify the API server's certificate. +type MTLSConfig struct { + Enable bool `yaml:"enable"` + CertPem string `yaml:"cert_pem"` + KeyPem string `yaml:"key_pem"` + CAPem string `yaml:"ca_pem"` +} + type APIConfig struct { - PublicKeyPath string `yaml:"public_key_path"` - BaseURL string `yaml:"base_url"` - Interval int `yaml:"interval"` - Enabled bool `yaml:"enabled"` - AuthEnabled bool `yaml:"auth_enabled"` + PublicKeyPath string `yaml:"public_key_path"` + BaseURL string `yaml:"base_url"` + Interval int `yaml:"interval"` + Enabled bool `yaml:"enabled"` + AuthEnabled bool `yaml:"auth_enabled"` + MTLS MTLSConfig `yaml:"mtls"` } // SchedConfig holds the configuration parameters for creating a scheduler plugin diff --git a/plugin/plugin.go b/plugin/plugin.go index fbd4f61..6ed63aa 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -11,6 +11,7 @@ type ( Sched = reg.Sched CustomScheduler = reg.CustomScheduler Scheduler = reg.Scheduler + MTLSConfig = reg.MTLSConfig APIConfig = reg.APIConfig SchedConfig = reg.SchedConfig PluginFactory = reg.PluginFactory