diff --git a/internal/certificates/generate.go b/internal/certificates/generate.go index 8fca175fe..37cde2b2c 100644 --- a/internal/certificates/generate.go +++ b/internal/certificates/generate.go @@ -32,8 +32,24 @@ var ( ErrDecodeCertificatePEM = errors.New("failed to decode certificate PEM") ErrDecodePrivateKeyPEM = errors.New("failed to decode private key PEM") ErrCertFilesNotFound = errors.New("certificate files not found") + ErrCertNotYetValid = errors.New("certificate is not yet valid: notBefore is in the future") + ErrCertExpired = errors.New("certificate has expired") ) +// validateCertDates checks that the certificate's validity window covers the current time. +func validateCertDates(cert *x509.Certificate) error { + now := time.Now() + if now.Before(cert.NotBefore) { + return fmt.Errorf("%w: notBefore=%s", ErrCertNotYetValid, cert.NotBefore.Format(time.RFC3339)) + } + + if cert.NotAfter.Before(now) { + return fmt.Errorf("%w: notAfter=%s", ErrCertExpired, cert.NotAfter.Format(time.RFC3339)) + } + + return nil +} + // ObjectStorager extends security.Storager with object storage capabilities. type ObjectStorager interface { security.Storager @@ -101,6 +117,10 @@ func ParseCertificateFromPEM(certPEM, keyPEM string) (*x509.Certificate, *rsa.Pr return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) } + if err := validateCertDates(cert); err != nil { + return nil, nil, err + } + // Decode private key PEM keyBlock, _ := pem.Decode([]byte(keyPEM)) if keyBlock == nil { @@ -135,6 +155,10 @@ func LoadCertificateFromFile(certPath, keyPath string) (*x509.Certificate, *rsa. return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) } + if err := validateCertDates(cert); err != nil { + return nil, nil, err + } + // Read private key file keyPEM, err := os.ReadFile(keyPath) if err != nil { @@ -169,11 +193,11 @@ func CheckAndLoadOrGenerateRootCertificate(addThumbPrintToName bool, commonName, if err == nil { return cert, key, nil } - // If loading fails, fall through to generation - log.Printf("Warning: Failed to load existing certificates: %v. Generating new ones...", err) + // Files exist but failed validation — do not silently regenerate. + return nil, nil, fmt.Errorf("existing root certificate is invalid: %w", err) } - // Files don't exist or loading failed, generate new certificates + // Files don't exist — generate new certificates. return GenerateRootCertificate(addThumbPrintToName, commonName, country, organization, strong) } @@ -202,7 +226,12 @@ func LoadOrGenerateRootCertificateWithVault(store security.Storager, addThumbPri return cert, key, nil } - // Generate new certificates + // Files exist but failed validation — do not silently regenerate. + if !errors.Is(err, ErrCertFilesNotFound) { + return nil, nil, err + } + + // Files not found — generate new certificates. return generateAndStoreRootCert(store, certName, addThumbPrintToName, commonName, country, organization, strong) } @@ -273,11 +302,11 @@ func CheckAndLoadOrGenerateWebServerCertificate(rootCert CertAndKeyType, addThum if err == nil { return cert, key, nil } - // If loading fails, fall through to generation - log.Printf("Warning: Failed to load existing certificates: %v. Generating new ones...", err) + // Files exist but failed validation — do not silently regenerate. + return nil, nil, fmt.Errorf("existing web server certificate is invalid: %w", err) } - // Files don't exist or loading failed, generate new certificates + // Files don't exist — generate new certificates. return IssueWebServerCertificate(rootCert, addThumbPrintToName, commonName, country, organization, strong) } @@ -308,7 +337,12 @@ func LoadOrGenerateWebServerCertificateWithVault(store security.Storager, rootCe return cert, key, nil } - // Generate new certificates + // Files exist but failed validation — do not silently regenerate. + if !errors.Is(err, ErrCertFilesNotFound) { + return nil, nil, err + } + + // Files not found — generate new certificates. return generateAndStoreWebServerCert(store, rootCert, certName, addThumbPrintToName, commonName, country, organization, strong) } diff --git a/internal/certificates/generate_fuzz_test.go b/internal/certificates/generate_fuzz_test.go index 04285085e..b6ff98ebe 100644 --- a/internal/certificates/generate_fuzz_test.go +++ b/internal/certificates/generate_fuzz_test.go @@ -126,7 +126,7 @@ func generateFuzzPEMCertificate(tb testing.TB, seed int64) (certPEM, keyPEM stri serialNumber := new(big.Int).SetInt64(rng.Int63()) - fixedTime := time.Date(2099, 1, 1, 0, 0, 0, 0, time.UTC) + now := time.Now() template := x509.Certificate{ SerialNumber: serialNumber, @@ -134,8 +134,8 @@ func generateFuzzPEMCertificate(tb testing.TB, seed int64) (certPEM, keyPEM stri CommonName: "fuzz-cert", Organization: []string{"console"}, }, - NotBefore: fixedTime.Add(-time.Hour), - NotAfter: fixedTime.Add(24 * time.Hour), + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(24 * time.Hour), KeyUsage: x509.KeyUsageDigitalSignature, BasicConstraintsValid: true, } diff --git a/internal/certificates/generate_test.go b/internal/certificates/generate_test.go index 3362e48f1..c2d219583 100644 --- a/internal/certificates/generate_test.go +++ b/internal/certificates/generate_test.go @@ -121,6 +121,82 @@ func TestParseCertificateFromPEM(t *testing.T) { assert.Equal(t, cert.Subject.CommonName, parsedCert.Subject.CommonName) } +// generateFutureNotBeforeCertAndKey generates a cert whose notBefore is in the future. +func generateFutureNotBeforeCertAndKey(t *testing.T) (*x509.Certificate, *rsa.PrivateKey) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + assert.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "future-notbefore-cert", + Organization: []string{"test-org"}, + }, + NotBefore: time.Now().Add(24 * time.Hour), + NotAfter: time.Now().Add(48 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + assert.NoError(t, err) + + cert, err := x509.ParseCertificate(certBytes) + assert.NoError(t, err) + + return cert, privateKey +} + +func TestParseCertificateFromPEM_FutureNotBefore(t *testing.T) { + t.Parallel() + + cert, key := generateFutureNotBeforeCertAndKey(t) + certPEM, keyPEM := certAndKeyToPEM(cert, key) + + _, _, err := ParseCertificateFromPEM(certPEM, keyPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not yet valid") +} + +func TestParseCertificateFromPEM_Expired(t *testing.T) { + t.Parallel() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + assert.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "expired-cert", + Organization: []string{"test-org"}, + }, + NotBefore: time.Now().Add(-48 * time.Hour), + NotAfter: time.Now().Add(-24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + assert.NoError(t, err) + + cert, err := x509.ParseCertificate(certBytes) + assert.NoError(t, err) + + certPEM, keyPEM := certAndKeyToPEM(cert, privateKey) + + _, _, err = ParseCertificateFromPEM(certPEM, keyPEM) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expired") +} + func TestParseCertificateFromPEM_InvalidCert(t *testing.T) { t.Parallel() diff --git a/internal/usecase/devices/certificates.go b/internal/usecase/devices/certificates.go index aaa72334a..63a966c73 100644 --- a/internal/usecase/devices/certificates.go +++ b/internal/usecase/devices/certificates.go @@ -327,29 +327,19 @@ func populateCertificateDTO(cert *x509.Certificate) dto.Certificate { } } -func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.CertInfo) (handle string, err error) { - var certData []byte - - item, err := uc.repo.GetByID(c, guid, "") - if err != nil { - return "", err - } - - if item == nil || item.GUID == "" { - return "", ErrNotFound - } - - // Decode base64 certificate - certData, err = base64.StdEncoding.DecodeString(certInfo.Cert) +// parseCertFromCertInfo decodes, validates, and re-encodes a certificate from a CertInfo DTO. +// Returns the cleaned base64 DER string ready to send to AMT. +func parseCertFromCertInfo(certInfo dto.CertInfo) (string, error) { + certData, err := base64.StdEncoding.DecodeString(certInfo.Cert) if err != nil { return "", err } - // Try to decode as PEM + // Try to decode as PEM; if it is PEM, unwrap to raw DER bytes. block, _ := pem.Decode(certData) if block != nil { if block.Type != "CERTIFICATE" { - return "", err + return "", ValidationError{}.Wrap("AddCertificate", "pemType", fmt.Sprintf("invalid PEM block type: expected CERTIFICATE, got %s", block.Type)) } certData = block.Bytes @@ -360,21 +350,46 @@ func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.C return "", err } - if cert.NotAfter.Before(time.Now()) { - return "", err + now := time.Now() + if now.Before(cert.NotBefore) { + return "", ValidationError{}.Wrap("AddCertificate", "notBefore", fmt.Sprintf("certificate is not yet valid: notBefore=%s", cert.NotBefore.Format(time.RFC3339))) + } + + if cert.NotAfter.Before(now) { + return "", ValidationError{}.Wrap("AddCertificate", "notAfter", fmt.Sprintf("certificate has expired: notAfter=%s", cert.NotAfter.Format(time.RFC3339))) + } + + if rsaPub, ok := cert.PublicKey.(*rsa.PublicKey); ok { + const minRSAKeyBits = 2048 + if rsaPub.N.BitLen() < minRSAKeyBits { + return "", ValidationError{}.Wrap("AddCertificate", "keySize", fmt.Sprintf("RSA key size %d is below the minimum required %d bits", rsaPub.N.BitLen(), minRSAKeyBits)) + } } - pemCert := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - }) + pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) block, _ = pem.Decode(pemCert) if block == nil { + return "", ValidationError{}.Wrap("AddCertificate", "pemReEncode", "failed to re-encode certificate as PEM") + } + + return strings.ReplaceAll(base64.StdEncoding.EncodeToString(block.Bytes), "\r\n", ""), nil +} + +func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.CertInfo) (handle string, err error) { + item, err := uc.repo.GetByID(c, guid, "") + if err != nil { return "", err } - cleanedCert := strings.ReplaceAll(base64.StdEncoding.EncodeToString(block.Bytes), "\r\n", "") + if item == nil || item.GUID == "" { + return "", ErrNotFound + } + + cleanedCert, err := parseCertFromCertInfo(certInfo) + if err != nil { + return "", err + } device, err := uc.device.SetupWsmanClient(c, *item, false, true) if err != nil { diff --git a/internal/usecase/devices/certificates_test.go b/internal/usecase/devices/certificates_test.go index a88f0169b..94cb95415 100644 --- a/internal/usecase/devices/certificates_test.go +++ b/internal/usecase/devices/certificates_test.go @@ -2,9 +2,16 @@ package devices_test import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" "encoding/xml" "errors" + "math/big" "testing" + "time" "github.com/stretchr/testify/require" gomock "go.uber.org/mock/gomock" @@ -332,6 +339,50 @@ func TestAddCertificate(t *testing.T) { } } +func TestAddCertificate_WeakKey(t *testing.T) { + t.Parallel() + + device := &entity.Device{ + GUID: "device-guid-123", + TenantID: "tenant-id-456", + } + + //nolint:gosec // intentionally weak key for validation testing + privateKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: "weak-key-test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(48 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + + weakCertBase64 := base64.StdEncoding.EncodeToString(certBytes) + + useCase, wsmanMock, _, repo := initCertificateTest(t) + + wsmanMock.EXPECT().SetupWsmanClient(gomock.Any(), gomock.Any(), false, true).Times(0) + repo.EXPECT(). + GetByID(context.Background(), device.GUID, ""). + Return(device, nil) + + _, err = useCase.AddCertificate(context.Background(), device.GUID, dto.CertInfo{ + Cert: weakCertBase64, + IsTrusted: true, + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "below the minimum required") +} + func TestDeleteCertificate(t *testing.T) { t.Parallel()