Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions internal/certificates/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,24 @@ var (
ErrDecodeCertificatePEM = errors.New("failed to decode certificate PEM")
ErrDecodePrivateKeyPEM = errors.New("failed to decode private key PEM")
ErrCertFilesNotFound = errors.New("certificate files not found")
ErrCertNotYetValid = errors.New("certificate is not yet valid: notBefore is in the future")
ErrCertExpired = errors.New("certificate has expired")
)

// validateCertDates checks that the certificate's validity window covers the current time.
func validateCertDates(cert *x509.Certificate) error {
now := time.Now()
if now.Before(cert.NotBefore) {
return fmt.Errorf("%w: notBefore=%s", ErrCertNotYetValid, cert.NotBefore.Format(time.RFC3339))
}

if cert.NotAfter.Before(now) {
return fmt.Errorf("%w: notAfter=%s", ErrCertExpired, cert.NotAfter.Format(time.RFC3339))
}

return nil
}

// ObjectStorager extends security.Storager with object storage capabilities.
type ObjectStorager interface {
security.Storager
Expand Down Expand Up @@ -101,6 +117,10 @@ func ParseCertificateFromPEM(certPEM, keyPEM string) (*x509.Certificate, *rsa.Pr
return nil, nil, fmt.Errorf("failed to parse certificate: %w", err)
}

if err := validateCertDates(cert); err != nil {
return nil, nil, err
}

// Decode private key PEM
keyBlock, _ := pem.Decode([]byte(keyPEM))
if keyBlock == nil {
Expand Down Expand Up @@ -135,6 +155,10 @@ func LoadCertificateFromFile(certPath, keyPath string) (*x509.Certificate, *rsa.
return nil, nil, fmt.Errorf("failed to parse certificate: %w", err)
}

if err := validateCertDates(cert); err != nil {
return nil, nil, err
}

// Read private key file
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
Expand Down Expand Up @@ -169,11 +193,11 @@ func CheckAndLoadOrGenerateRootCertificate(addThumbPrintToName bool, commonName,
if err == nil {
return cert, key, nil
}
// If loading fails, fall through to generation
log.Printf("Warning: Failed to load existing certificates: %v. Generating new ones...", err)
// Files exist but failed validation — do not silently regenerate.
return nil, nil, fmt.Errorf("existing root certificate is invalid: %w", err)
}

// Files don't exist or loading failed, generate new certificates
// Files don't exist generate new certificates.
return GenerateRootCertificate(addThumbPrintToName, commonName, country, organization, strong)
}

Expand Down Expand Up @@ -202,7 +226,12 @@ func LoadOrGenerateRootCertificateWithVault(store security.Storager, addThumbPri
return cert, key, nil
}

// Generate new certificates
// Files exist but failed validation — do not silently regenerate.
if !errors.Is(err, ErrCertFilesNotFound) {
return nil, nil, err
}

// Files not found — generate new certificates.
return generateAndStoreRootCert(store, certName, addThumbPrintToName, commonName, country, organization, strong)
}

Expand Down Expand Up @@ -273,11 +302,11 @@ func CheckAndLoadOrGenerateWebServerCertificate(rootCert CertAndKeyType, addThum
if err == nil {
return cert, key, nil
}
// If loading fails, fall through to generation
log.Printf("Warning: Failed to load existing certificates: %v. Generating new ones...", err)
// Files exist but failed validation — do not silently regenerate.
return nil, nil, fmt.Errorf("existing web server certificate is invalid: %w", err)
}

// Files don't exist or loading failed, generate new certificates
// Files don't exist generate new certificates.
return IssueWebServerCertificate(rootCert, addThumbPrintToName, commonName, country, organization, strong)
}

Expand Down Expand Up @@ -308,7 +337,12 @@ func LoadOrGenerateWebServerCertificateWithVault(store security.Storager, rootCe
return cert, key, nil
}

// Generate new certificates
// Files exist but failed validation — do not silently regenerate.
if !errors.Is(err, ErrCertFilesNotFound) {
return nil, nil, err
}

// Files not found — generate new certificates.
return generateAndStoreWebServerCert(store, rootCert, certName, addThumbPrintToName, commonName, country, organization, strong)
}

Expand Down
6 changes: 3 additions & 3 deletions internal/certificates/generate_fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,16 @@ func generateFuzzPEMCertificate(tb testing.TB, seed int64) (certPEM, keyPEM stri

serialNumber := new(big.Int).SetInt64(rng.Int63())

fixedTime := time.Date(2099, 1, 1, 0, 0, 0, 0, time.UTC)
now := time.Now()

template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "fuzz-cert",
Organization: []string{"console"},
},
NotBefore: fixedTime.Add(-time.Hour),
NotAfter: fixedTime.Add(24 * time.Hour),
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
}
Expand Down
76 changes: 76 additions & 0 deletions internal/certificates/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,82 @@ func TestParseCertificateFromPEM(t *testing.T) {
assert.Equal(t, cert.Subject.CommonName, parsedCert.Subject.CommonName)
}

// generateFutureNotBeforeCertAndKey generates a cert whose notBefore is in the future.
func generateFutureNotBeforeCertAndKey(t *testing.T) (*x509.Certificate, *rsa.PrivateKey) {
t.Helper()

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
assert.NoError(t, err)

template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "future-notbefore-cert",
Organization: []string{"test-org"},
},
NotBefore: time.Now().Add(24 * time.Hour),
NotAfter: time.Now().Add(48 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
}

certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
assert.NoError(t, err)

cert, err := x509.ParseCertificate(certBytes)
assert.NoError(t, err)

return cert, privateKey
}

func TestParseCertificateFromPEM_FutureNotBefore(t *testing.T) {
t.Parallel()

cert, key := generateFutureNotBeforeCertAndKey(t)
certPEM, keyPEM := certAndKeyToPEM(cert, key)

_, _, err := ParseCertificateFromPEM(certPEM, keyPEM)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not yet valid")
}

func TestParseCertificateFromPEM_Expired(t *testing.T) {
t.Parallel()

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
assert.NoError(t, err)

template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "expired-cert",
Organization: []string{"test-org"},
},
NotBefore: time.Now().Add(-48 * time.Hour),
NotAfter: time.Now().Add(-24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
}

certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
assert.NoError(t, err)

cert, err := x509.ParseCertificate(certBytes)
assert.NoError(t, err)

certPEM, keyPEM := certAndKeyToPEM(cert, privateKey)

_, _, err = ParseCertificateFromPEM(certPEM, keyPEM)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
}

func TestParseCertificateFromPEM_InvalidCert(t *testing.T) {
t.Parallel()

Expand Down
61 changes: 38 additions & 23 deletions internal/usecase/devices/certificates.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,29 +327,19 @@ func populateCertificateDTO(cert *x509.Certificate) dto.Certificate {
}
}

func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.CertInfo) (handle string, err error) {
var certData []byte

item, err := uc.repo.GetByID(c, guid, "")
if err != nil {
return "", err
}

if item == nil || item.GUID == "" {
return "", ErrNotFound
}

// Decode base64 certificate
certData, err = base64.StdEncoding.DecodeString(certInfo.Cert)
// parseCertFromCertInfo decodes, validates, and re-encodes a certificate from a CertInfo DTO.
// Returns the cleaned base64 DER string ready to send to AMT.
func parseCertFromCertInfo(certInfo dto.CertInfo) (string, error) {
certData, err := base64.StdEncoding.DecodeString(certInfo.Cert)
if err != nil {
return "", err
}

// Try to decode as PEM
// Try to decode as PEM; if it is PEM, unwrap to raw DER bytes.
block, _ := pem.Decode(certData)
if block != nil {
if block.Type != "CERTIFICATE" {
return "", err
return "", ValidationError{}.Wrap("AddCertificate", "pemType", fmt.Sprintf("invalid PEM block type: expected CERTIFICATE, got %s", block.Type))
}

certData = block.Bytes
Expand All @@ -360,21 +350,46 @@ func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.C
return "", err
}

if cert.NotAfter.Before(time.Now()) {
return "", err
now := time.Now()
if now.Before(cert.NotBefore) {
return "", ValidationError{}.Wrap("AddCertificate", "notBefore", fmt.Sprintf("certificate is not yet valid: notBefore=%s", cert.NotBefore.Format(time.RFC3339)))
}

if cert.NotAfter.Before(now) {
return "", ValidationError{}.Wrap("AddCertificate", "notAfter", fmt.Sprintf("certificate has expired: notAfter=%s", cert.NotAfter.Format(time.RFC3339)))
}

if rsaPub, ok := cert.PublicKey.(*rsa.PublicKey); ok {
const minRSAKeyBits = 2048
if rsaPub.N.BitLen() < minRSAKeyBits {
return "", ValidationError{}.Wrap("AddCertificate", "keySize", fmt.Sprintf("RSA key size %d is below the minimum required %d bits", rsaPub.N.BitLen(), minRSAKeyBits))
}
}

pemCert := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
})
pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})

block, _ = pem.Decode(pemCert)
if block == nil {
return "", ValidationError{}.Wrap("AddCertificate", "pemReEncode", "failed to re-encode certificate as PEM")
}

return strings.ReplaceAll(base64.StdEncoding.EncodeToString(block.Bytes), "\r\n", ""), nil
}

func (uc *UseCase) AddCertificate(c context.Context, guid string, certInfo dto.CertInfo) (handle string, err error) {
item, err := uc.repo.GetByID(c, guid, "")
if err != nil {
return "", err
}

cleanedCert := strings.ReplaceAll(base64.StdEncoding.EncodeToString(block.Bytes), "\r\n", "")
if item == nil || item.GUID == "" {
return "", ErrNotFound
}

cleanedCert, err := parseCertFromCertInfo(certInfo)
if err != nil {
return "", err
}

device, err := uc.device.SetupWsmanClient(c, *item, false, true)
if err != nil {
Expand Down
51 changes: 51 additions & 0 deletions internal/usecase/devices/certificates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/xml"
"errors"
"math/big"
"testing"
"time"

"github.com/stretchr/testify/require"
gomock "go.uber.org/mock/gomock"
Expand Down Expand Up @@ -332,6 +339,50 @@
}
}

func TestAddCertificate_WeakKey(t *testing.T) {
t.Parallel()

device := &entity.Device{
GUID: "device-guid-123",
TenantID: "tenant-id-456",
}

//nolint:gosec // intentionally weak key for validation testing

Check failure on line 350 in internal/usecase/devices/certificates_test.go

View workflow job for this annotation

GitHub Actions / runner / golangci-lint

[golangci] reported by reviewdog 🐶 directive `//nolint:gosec // intentionally weak key for validation testing` is unused for linter "gosec" (nolintlint) Raw Output: internal/usecase/devices/certificates_test.go:350:2: directive `//nolint:gosec // intentionally weak key for validation testing` is unused for linter "gosec" (nolintlint) //nolint:gosec // intentionally weak key for validation testing ^ 1 issues: * nolintlint: 1
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)

serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
require.NoError(t, err)

template := x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: "weak-key-test"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(48 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
}

certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)

weakCertBase64 := base64.StdEncoding.EncodeToString(certBytes)

useCase, wsmanMock, _, repo := initCertificateTest(t)

wsmanMock.EXPECT().SetupWsmanClient(gomock.Any(), gomock.Any(), false, true).Times(0)
repo.EXPECT().
GetByID(context.Background(), device.GUID, "").
Return(device, nil)

_, err = useCase.AddCertificate(context.Background(), device.GUID, dto.CertInfo{
Cert: weakCertBase64,
IsTrusted: true,
})

require.Error(t, err)
require.Contains(t, err.Error(), "below the minimum required")
}

func TestDeleteCertificate(t *testing.T) {
t.Parallel()

Expand Down
Loading