From 96ded8e0db40e1eb6c011462fc973c42caa4cd62 Mon Sep 17 00:00:00 2001 From: Natalie Gaston Date: Mon, 27 Apr 2026 17:01:49 -0700 Subject: [PATCH 1/4] feat: add certificate validation on load and upload --- internal/certificates/generate.go | 45 +++++++++++++++++---- internal/certificates/generate_fuzz_test.go | 6 +-- internal/certificates/generate_test.go | 42 +++++++++++++++++++ internal/usecase/devices/certificates.go | 13 ++++-- 4 files changed, 91 insertions(+), 15 deletions(-) diff --git a/internal/certificates/generate.go b/internal/certificates/generate.go index 8fca175fe..fbda68b37 100644 --- a/internal/certificates/generate.go +++ b/internal/certificates/generate.go @@ -32,8 +32,19 @@ var ( ErrDecodeCertificatePEM = errors.New("failed to decode certificate PEM") ErrDecodePrivateKeyPEM = errors.New("failed to decode private key PEM") ErrCertFilesNotFound = errors.New("certificate files not found") + ErrCertNotYetValid = errors.New("certificate is not yet valid: notBefore is in the future") ) +// validateCertDates checks that the certificate's validity window covers the current time. +func validateCertDates(cert *x509.Certificate) error { + now := time.Now() + if now.Before(cert.NotBefore) { + return fmt.Errorf("%w: notBefore=%s", ErrCertNotYetValid, cert.NotBefore.Format(time.RFC3339)) + } + + return nil +} + // ObjectStorager extends security.Storager with object storage capabilities. type ObjectStorager interface { security.Storager @@ -101,6 +112,10 @@ func ParseCertificateFromPEM(certPEM, keyPEM string) (*x509.Certificate, *rsa.Pr return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) } + if err := validateCertDates(cert); err != nil { + return nil, nil, err + } + // Decode private key PEM keyBlock, _ := pem.Decode([]byte(keyPEM)) if keyBlock == nil { @@ -135,6 +150,10 @@ func LoadCertificateFromFile(certPath, keyPath string) (*x509.Certificate, *rsa. return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) } + if err := validateCertDates(cert); err != nil { + return nil, nil, err + } + // Read private key file keyPEM, err := os.ReadFile(keyPath) if err != nil { @@ -169,11 +188,11 @@ func CheckAndLoadOrGenerateRootCertificate(addThumbPrintToName bool, commonName, if err == nil { return cert, key, nil } - // If loading fails, fall through to generation - log.Printf("Warning: Failed to load existing certificates: %v. Generating new ones...", err) + // Files exist but failed validation — do not silently regenerate. + return nil, nil, fmt.Errorf("existing root certificate is invalid: %w", err) } - // Files don't exist or loading failed, generate new certificates + // Files don't exist — generate new certificates. return GenerateRootCertificate(addThumbPrintToName, commonName, country, organization, strong) } @@ -202,7 +221,12 @@ func LoadOrGenerateRootCertificateWithVault(store security.Storager, addThumbPri return cert, key, nil } - // Generate new certificates + // Files exist but failed validation — do not silently regenerate. + if !errors.Is(err, ErrCertFilesNotFound) { + return nil, nil, err + } + + // Files not found — generate new certificates. return generateAndStoreRootCert(store, certName, addThumbPrintToName, commonName, country, organization, strong) } @@ -273,11 +297,11 @@ func CheckAndLoadOrGenerateWebServerCertificate(rootCert CertAndKeyType, addThum if err == nil { return cert, key, nil } - // If loading fails, fall through to generation - log.Printf("Warning: Failed to load existing certificates: %v. Generating new ones...", err) + // Files exist but failed validation — do not silently regenerate. + return nil, nil, fmt.Errorf("existing web server certificate is invalid: %w", err) } - // Files don't exist or loading failed, generate new certificates + // Files don't exist — generate new certificates. return IssueWebServerCertificate(rootCert, addThumbPrintToName, commonName, country, organization, strong) } @@ -308,7 +332,12 @@ func LoadOrGenerateWebServerCertificateWithVault(store security.Storager, rootCe return cert, key, nil } - // Generate new certificates + // Files exist but failed validation — do not silently regenerate. + if !errors.Is(err, ErrCertFilesNotFound) { + return nil, nil, err + } + + // Files not found — generate new certificates. return generateAndStoreWebServerCert(store, rootCert, certName, addThumbPrintToName, commonName, country, organization, strong) } diff --git a/internal/certificates/generate_fuzz_test.go b/internal/certificates/generate_fuzz_test.go index 04285085e..b6ff98ebe 100644 --- a/internal/certificates/generate_fuzz_test.go +++ b/internal/certificates/generate_fuzz_test.go @@ -126,7 +126,7 @@ func generateFuzzPEMCertificate(tb testing.TB, seed int64) (certPEM, keyPEM stri serialNumber := new(big.Int).SetInt64(rng.Int63()) - fixedTime := time.Date(2099, 1, 1, 0, 0, 0, 0, time.UTC) + now := time.Now() template := x509.Certificate{ SerialNumber: serialNumber, @@ -134,8 +134,8 @@ func generateFuzzPEMCertificate(tb testing.TB, seed int64) (certPEM, keyPEM stri CommonName: "fuzz-cert", Organization: []string{"console"}, }, - NotBefore: fixedTime.Add(-time.Hour), - NotAfter: fixedTime.Add(24 * time.Hour), + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(24 * time.Hour), KeyUsage: x509.KeyUsageDigitalSignature, BasicConstraintsValid: true, } diff --git a/internal/certificates/generate_test.go b/internal/certificates/generate_test.go index 3362e48f1..bcd8474a5 100644 --- a/internal/certificates/generate_test.go +++ b/internal/certificates/generate_test.go @@ -121,6 +121,48 @@ func TestParseCertificateFromPEM(t *testing.T) { assert.Equal(t, cert.Subject.CommonName, parsedCert.Subject.CommonName) } +// generateFutureNotBeforeCertAndKey generates a cert whose notBefore is in the future. +func generateFutureNotBeforeCertAndKey(t *testing.T) (*x509.Certificate, *rsa.PrivateKey) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + assert.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "future-notbefore-cert", + Organization: []string{"test-org"}, + }, + NotBefore: time.Now().Add(24 * time.Hour), + NotAfter: time.Now().Add(48 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + assert.NoError(t, err) + + cert, err := x509.ParseCertificate(certBytes) + assert.NoError(t, err) + + return cert, privateKey +} + +func TestParseCertificateFromPEM_FutureNotBefore(t *testing.T) { + t.Parallel() + + cert, key := generateFutureNotBeforeCertAndKey(t) + certPEM, keyPEM := certAndKeyToPEM(cert, key) + + _, _, err := ParseCertificateFromPEM(certPEM, keyPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not yet valid") +} + func TestParseCertificateFromPEM_InvalidCert(t *testing.T) { t.Parallel() diff --git a/internal/usecase/devices/certificates.go b/internal/usecase/devices/certificates.go index aaa72334a..8fde23f41 100644 --- a/internal/usecase/devices/certificates.go +++ b/internal/usecase/devices/certificates.go @@ -349,7 +349,7 @@ func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.C block, _ := pem.Decode(certData) if block != nil { if block.Type != "CERTIFICATE" { - return "", err + return "", ValidationError{}.Wrap("AddCertificate", "pemType", fmt.Sprintf("invalid PEM block type: expected CERTIFICATE, got %s", block.Type)) } certData = block.Bytes @@ -360,8 +360,13 @@ func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.C return "", err } - if cert.NotAfter.Before(time.Now()) { - return "", err + now := time.Now() + if now.Before(cert.NotBefore) { + return "", ValidationError{}.Wrap("AddCertificate", "notBefore", fmt.Sprintf("certificate is not yet valid: notBefore=%s", cert.NotBefore.Format(time.RFC3339))) + } + + if cert.NotAfter.Before(now) { + return "", ValidationError{}.Wrap("AddCertificate", "notAfter", fmt.Sprintf("certificate has expired: notAfter=%s", cert.NotAfter.Format(time.RFC3339))) } pemCert := pem.EncodeToMemory(&pem.Block{ @@ -371,7 +376,7 @@ func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.C block, _ = pem.Decode(pemCert) if block == nil { - return "", err + return "", ValidationError{}.Wrap("AddCertificate", "pemReEncode", "failed to re-encode certificate as PEM") } cleanedCert := strings.ReplaceAll(base64.StdEncoding.EncodeToString(block.Bytes), "\r\n", "") From 5b5ee84cf8acee1a1252b3e78e0a836d33b5376e Mon Sep 17 00:00:00 2001 From: Natalie Gaston Date: Mon, 27 Apr 2026 17:10:31 -0700 Subject: [PATCH 2/4] refactor: extract parseCertFromCertInfo to reduce cognitive complexity --- internal/usecase/devices/certificates.go | 43 +++++++++++++----------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/internal/usecase/devices/certificates.go b/internal/usecase/devices/certificates.go index 8fde23f41..0b7bef5d6 100644 --- a/internal/usecase/devices/certificates.go +++ b/internal/usecase/devices/certificates.go @@ -327,25 +327,15 @@ func populateCertificateDTO(cert *x509.Certificate) dto.Certificate { } } -func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.CertInfo) (handle string, err error) { - var certData []byte - - item, err := uc.repo.GetByID(c, guid, "") +// parseCertFromCertInfo decodes, validates, and re-encodes a certificate from a CertInfo DTO. +// Returns the cleaned base64 DER string ready to send to AMT. +func parseCertFromCertInfo(certInfo dto.CertInfo) (string, error) { + certData, err := base64.StdEncoding.DecodeString(certInfo.Cert) if err != nil { return "", err } - if item == nil || item.GUID == "" { - return "", ErrNotFound - } - - // Decode base64 certificate - certData, err = base64.StdEncoding.DecodeString(certInfo.Cert) - if err != nil { - return "", err - } - - // Try to decode as PEM + // Try to decode as PEM; if it is PEM, unwrap to raw DER bytes. block, _ := pem.Decode(certData) if block != nil { if block.Type != "CERTIFICATE" { @@ -369,17 +359,30 @@ func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.C return "", ValidationError{}.Wrap("AddCertificate", "notAfter", fmt.Sprintf("certificate has expired: notAfter=%s", cert.NotAfter.Format(time.RFC3339))) } - pemCert := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - }) + pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) block, _ = pem.Decode(pemCert) if block == nil { return "", ValidationError{}.Wrap("AddCertificate", "pemReEncode", "failed to re-encode certificate as PEM") } - cleanedCert := strings.ReplaceAll(base64.StdEncoding.EncodeToString(block.Bytes), "\r\n", "") + return strings.ReplaceAll(base64.StdEncoding.EncodeToString(block.Bytes), "\r\n", ""), nil +} + +func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.CertInfo) (handle string, err error) { + item, err := uc.repo.GetByID(c, guid, "") + if err != nil { + return "", err + } + + if item == nil || item.GUID == "" { + return "", ErrNotFound + } + + cleanedCert, err := parseCertFromCertInfo(certInfo) + if err != nil { + return "", err + } device, err := uc.device.SetupWsmanClient(c, *item, false, true) if err != nil { From d45a3951beb0540c8e4beec4a3fbb7ee8c76bc99 Mon Sep 17 00:00:00 2001 From: Natalie Gaston Date: Mon, 27 Apr 2026 17:25:36 -0700 Subject: [PATCH 3/4] feat: add certificate validation on load and upload --- internal/certificates/generate.go | 5 ++++ internal/certificates/generate_test.go | 34 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/internal/certificates/generate.go b/internal/certificates/generate.go index fbda68b37..37cde2b2c 100644 --- a/internal/certificates/generate.go +++ b/internal/certificates/generate.go @@ -33,6 +33,7 @@ var ( ErrDecodePrivateKeyPEM = errors.New("failed to decode private key PEM") ErrCertFilesNotFound = errors.New("certificate files not found") ErrCertNotYetValid = errors.New("certificate is not yet valid: notBefore is in the future") + ErrCertExpired = errors.New("certificate has expired") ) // validateCertDates checks that the certificate's validity window covers the current time. @@ -42,6 +43,10 @@ func validateCertDates(cert *x509.Certificate) error { return fmt.Errorf("%w: notBefore=%s", ErrCertNotYetValid, cert.NotBefore.Format(time.RFC3339)) } + if cert.NotAfter.Before(now) { + return fmt.Errorf("%w: notAfter=%s", ErrCertExpired, cert.NotAfter.Format(time.RFC3339)) + } + return nil } diff --git a/internal/certificates/generate_test.go b/internal/certificates/generate_test.go index bcd8474a5..c2d219583 100644 --- a/internal/certificates/generate_test.go +++ b/internal/certificates/generate_test.go @@ -163,6 +163,40 @@ func TestParseCertificateFromPEM_FutureNotBefore(t *testing.T) { assert.Contains(t, err.Error(), "not yet valid") } +func TestParseCertificateFromPEM_Expired(t *testing.T) { + t.Parallel() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + assert.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "expired-cert", + Organization: []string{"test-org"}, + }, + NotBefore: time.Now().Add(-48 * time.Hour), + NotAfter: time.Now().Add(-24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + assert.NoError(t, err) + + cert, err := x509.ParseCertificate(certBytes) + assert.NoError(t, err) + + certPEM, keyPEM := certAndKeyToPEM(cert, privateKey) + + _, _, err = ParseCertificateFromPEM(certPEM, keyPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + func TestParseCertificateFromPEM_InvalidCert(t *testing.T) { t.Parallel() From f9cf02d89f26772179438fdf4d244775c51d36c5 Mon Sep 17 00:00:00 2001 From: Natalie Gaston Date: Mon, 27 Apr 2026 17:54:53 -0700 Subject: [PATCH 4/4] feat: reject certificates with RSA keys below 2048 bits on upload --- internal/usecase/devices/certificates.go | 7 +++ internal/usecase/devices/certificates_test.go | 51 +++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/internal/usecase/devices/certificates.go b/internal/usecase/devices/certificates.go index 0b7bef5d6..63a966c73 100644 --- a/internal/usecase/devices/certificates.go +++ b/internal/usecase/devices/certificates.go @@ -359,6 +359,13 @@ func parseCertFromCertInfo(certInfo dto.CertInfo) (string, error) { return "", ValidationError{}.Wrap("AddCertificate", "notAfter", fmt.Sprintf("certificate has expired: notAfter=%s", cert.NotAfter.Format(time.RFC3339))) } + if rsaPub, ok := cert.PublicKey.(*rsa.PublicKey); ok { + const minRSAKeyBits = 2048 + if rsaPub.N.BitLen() < minRSAKeyBits { + return "", ValidationError{}.Wrap("AddCertificate", "keySize", fmt.Sprintf("RSA key size %d is below the minimum required %d bits", rsaPub.N.BitLen(), minRSAKeyBits)) + } + } + pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) block, _ = pem.Decode(pemCert) diff --git a/internal/usecase/devices/certificates_test.go b/internal/usecase/devices/certificates_test.go index a88f0169b..94cb95415 100644 --- a/internal/usecase/devices/certificates_test.go +++ b/internal/usecase/devices/certificates_test.go @@ -2,9 +2,16 @@ package devices_test import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" "encoding/xml" "errors" + "math/big" "testing" + "time" "github.com/stretchr/testify/require" gomock "go.uber.org/mock/gomock" @@ -332,6 +339,50 @@ func TestAddCertificate(t *testing.T) { } } +func TestAddCertificate_WeakKey(t *testing.T) { + t.Parallel() + + device := &entity.Device{ + GUID: "device-guid-123", + TenantID: "tenant-id-456", + } + + //nolint:gosec // intentionally weak key for validation testing + privateKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: "weak-key-test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(48 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + + weakCertBase64 := base64.StdEncoding.EncodeToString(certBytes) + + useCase, wsmanMock, _, repo := initCertificateTest(t) + + wsmanMock.EXPECT().SetupWsmanClient(gomock.Any(), gomock.Any(), false, true).Times(0) + repo.EXPECT(). + GetByID(context.Background(), device.GUID, ""). + Return(device, nil) + + _, err = useCase.AddCertificate(context.Background(), device.GUID, dto.CertInfo{ + Cert: weakCertBase64, + IsTrusted: true, + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "below the minimum required") +} + func TestDeleteCertificate(t *testing.T) { t.Parallel()