From 6a9ef9fddc7c2a188c611474d06512d26dc0f61d Mon Sep 17 00:00:00 2001 From: Mike Johanson Date: Tue, 5 May 2026 14:33:58 -0700 Subject: [PATCH] fix: ensure postgres works with ciraconfigs --- internal/controller/httpapi/v1/login.go | 1 + .../usecase/ciraconfigs/transform_test.go | 39 ++++++++ internal/usecase/profiles/usecase_test.go | 80 ++++++++++++++++ internal/usecase/sqldb/ciraconfig.go | 17 ++-- internal/usecase/sqldb/ciraconfig_test.go | 96 +++++++++++++++++++ 5 files changed, 227 insertions(+), 6 deletions(-) create mode 100644 internal/usecase/ciraconfigs/transform_test.go diff --git a/internal/controller/httpapi/v1/login.go b/internal/controller/httpapi/v1/login.go index b0fcd3477..679e53cbf 100644 --- a/internal/controller/httpapi/v1/login.go +++ b/internal/controller/httpapi/v1/login.go @@ -71,6 +71,7 @@ func (lr LoginRoute) handleBasicAuth(creds dto.Credentials, c *gin.Context) { expirationTime := time.Now().Add(config.ConsoleConfig.JWTExpiration) claims := jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expirationTime), + Issuer: config.ConsoleConfig.Issuer, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) diff --git a/internal/usecase/ciraconfigs/transform_test.go b/internal/usecase/ciraconfigs/transform_test.go new file mode 100644 index 000000000..e9694d21e --- /dev/null +++ b/internal/usecase/ciraconfigs/transform_test.go @@ -0,0 +1,39 @@ +package ciraconfigs + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + dto "github.com/device-management-toolkit/console/internal/entity/dto/v1" + "github.com/device-management-toolkit/console/internal/mocks" + "github.com/device-management-toolkit/console/pkg/logger" +) + +func TestDTOEntityRoundTrip_GenerateRandomPassword(t *testing.T) { + t.Parallel() + + uc := &UseCase{log: logger.New("error"), safeRequirements: mocks.MockCrypto{}} + + for _, value := range []bool{true, false} { + value := value + + t.Run("", func(t *testing.T) { + t.Parallel() + + input := &dto.CIRAConfig{ + ConfigName: "cfg", + TenantID: "tenant", + GenerateRandomPassword: value, + } + + ent, err := uc.dtoToEntity(input) + require.NoError(t, err) + assert.Equal(t, value, ent.GenerateRandomPassword) + + back := uc.entityToDTO(ent) + assert.Equal(t, value, back.GenerateRandomPassword) + }) + } +} diff --git a/internal/usecase/profiles/usecase_test.go b/internal/usecase/profiles/usecase_test.go index 3bb442b8b..632176ca9 100644 --- a/internal/usecase/profiles/usecase_test.go +++ b/internal/usecase/profiles/usecase_test.go @@ -1016,6 +1016,86 @@ func TestBuildConfigurationObject(t *testing.T) { }, } + ciraTests := []struct { + name string + generateRandomPassword bool + }{ + {"cira config with generate random password true", true}, + {"cira config with generate random password false", false}, + } + + for _, ct := range ciraTests { + ct := ct + tests = append(tests, struct { + name string + profile *entity.Profile + domain *entity.Domain + wifi []config.WirelessProfile + cira *entity.CIRAConfig + expected config.Configuration + }{ + name: ct.name, + profile: &entity.Profile{ + ProfileName: "test-profile-cira", + Tags: "cira", + DHCPEnabled: true, + IPSyncEnabled: true, + Activation: "acmactivate", + AMTPassword: "amtpw", + MEBXPassword: "mebxpw", + TLSMode: 0, + UserConsent: "None", + }, + domain: &entity.Domain{}, + wifi: []config.WirelessProfile{}, + cira: &entity.CIRAConfig{ + Username: "mpsuser", + Password: "mpspw", + MPSAddress: "mps.example.com", + MPSRootCertificate: "mpscert", + GenerateRandomPassword: ct.generateRandomPassword, + }, + expected: config.Configuration{ + Name: "test-profile-cira", + Tags: []string{"cira"}, + Configuration: config.RemoteManagement{ + GeneralSettings: config.GeneralSettings{}, + Network: config.Network{ + Wired: config.Wired{ + DHCPEnabled: true, + IPSyncEnabled: true, + }, + Wireless: config.Wireless{ + Profiles: []config.WirelessProfile{}, + }, + }, + Redirection: config.Redirection{ + UserConsent: "None", + }, + TLS: config.TLS{}, + EnterpriseAssistant: config.EnterpriseAssistant{ + URL: "http://test.com:8080", + Username: "username", + Password: "password", + }, + AMTSpecific: config.AMTSpecific{ + ControlMode: "acmactivate", + AdminPassword: "amtpw", + MEBXPassword: "mebxpw", + CIRA: config.CIRA{ + MPSUsername: "mpsuser", + MPSPassword: "mpspw", + MPSAddress: "mps.example.com", + MPSCert: "mpscert", + EnvironmentDetection: []string{}, + GenerateRandomPassword: ct.generateRandomPassword, + }, + }, + }, + }, + }) + } + for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() diff --git a/internal/usecase/sqldb/ciraconfig.go b/internal/usecase/sqldb/ciraconfig.go index dbf6c6d11..b0acfe9f3 100644 --- a/internal/usecase/sqldb/ciraconfig.go +++ b/internal/usecase/sqldb/ciraconfig.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "strconv" "github.com/device-management-toolkit/console/internal/entity" "github.com/device-management-toolkit/console/internal/repoerrors" @@ -111,14 +112,16 @@ func (r *CIRARepo) Get(_ context.Context, top, skip int, tenantID string) ([]ent for rows.Next() { p := entity.CIRAConfig{} - var generateRandomPassword sql.NullBool + var generateRandomPassword sql.NullString err = rows.Scan(&p.ConfigName, &p.MPSAddress, &p.MPSPort, &p.Username, &p.Password, &p.CommonName, &p.ServerAddressFormat, &p.AuthMethod, &p.MPSRootCertificate, &p.ProxyDetails, &p.TenantID, &generateRandomPassword) if err != nil { return nil, ErrCIRARepoDatabase.Wrap("Get", "rows.Scan", err) } - p.GenerateRandomPassword = generateRandomPassword.Bool + if generateRandomPassword.Valid { + p.GenerateRandomPassword, _ = strconv.ParseBool(generateRandomPassword.String) + } configs = append(configs, p) } @@ -164,14 +167,16 @@ func (r *CIRARepo) GetByName(_ context.Context, configName, tenantID string) (*e for rows.Next() { p := &entity.CIRAConfig{} - var generateRandomPassword sql.NullBool + var generateRandomPassword sql.NullString err = rows.Scan(&p.ConfigName, &p.MPSAddress, &p.MPSPort, &p.Username, &p.Password, &p.CommonName, &p.ServerAddressFormat, &p.AuthMethod, &p.MPSRootCertificate, &p.ProxyDetails, &p.TenantID, &generateRandomPassword) if err != nil { return p, ErrCIRARepoDatabase.Wrap("GetByName", "rows.Scan", err) } - p.GenerateRandomPassword = generateRandomPassword.Bool + if generateRandomPassword.Valid { + p.GenerateRandomPassword, _ = strconv.ParseBool(generateRandomPassword.String) + } configs = append(configs, p) } @@ -219,7 +224,7 @@ func (r *CIRARepo) Update(_ context.Context, p *entity.CIRAConfig) (bool, error) Set("auth_method", p.AuthMethod). Set("mps_root_certificate", p.MPSRootCertificate). Set("proxydetails", p.ProxyDetails). - Set("generate_random_password", p.GenerateRandomPassword). + Set("generate_random_password", strconv.FormatBool(p.GenerateRandomPassword)). Where("cira_config_name = ? AND tenant_id = ?", p.ConfigName, p.TenantID). ToSql() if err != nil { @@ -244,7 +249,7 @@ func (r *CIRARepo) Insert(_ context.Context, p *entity.CIRAConfig) (string, erro insertBuilder := r.Builder. Insert("ciraconfigs"). Columns("cira_config_name", "mps_server_address", "mps_port", "user_name", "password", "common_name", "server_address_format", "auth_method", "mps_root_certificate", "proxydetails", "tenant_id", "generate_random_password"). - Values(p.ConfigName, p.MPSAddress, p.MPSPort, p.Username, p.Password, p.CommonName, p.ServerAddressFormat, p.AuthMethod, p.MPSRootCertificate, p.ProxyDetails, p.TenantID, p.GenerateRandomPassword) + Values(p.ConfigName, p.MPSAddress, p.MPSPort, p.Username, p.Password, p.CommonName, p.ServerAddressFormat, p.AuthMethod, p.MPSRootCertificate, p.ProxyDetails, p.TenantID, strconv.FormatBool(p.GenerateRandomPassword)) if !r.IsEmbedded { insertBuilder = insertBuilder.Suffix("RETURNING xmin::text") diff --git a/internal/usecase/sqldb/ciraconfig_test.go b/internal/usecase/sqldb/ciraconfig_test.go index 78aa966c9..ffd06004f 100644 --- a/internal/usecase/sqldb/ciraconfig_test.go +++ b/internal/usecase/sqldb/ciraconfig_test.go @@ -598,3 +598,99 @@ func TestCIRARepo_Delete(t *testing.T) { }) } } + +func TestCIRARepo_GenerateRandomPasswordRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value bool + }{ + {"true round-trip", true}, + {"false round-trip", false}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + dbConn := setupDatabase(t) + defer dbConn.Close() + + sqlConfig := CreateSQLConfig(dbConn, false) + mockLog := mocks.NewMockLogger(nil) + repo := sqldb.NewCIRARepo(sqlConfig, mockLog) + + input := &entity.CIRAConfig{ + ConfigName: "cfg", + MPSAddress: "mps", + MPSPort: 4433, + Username: "u", + Password: "p", + CommonName: "cn", + MPSRootCertificate: "cert", + ProxyDetails: "proxy", + TenantID: "tenant1", + GenerateRandomPassword: tc.value, + } + + _, err := repo.Insert(context.Background(), input) + require.NoError(t, err) + + got, err := repo.GetByName(context.Background(), "cfg", "tenant1") + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, tc.value, got.GenerateRandomPassword) + + input.GenerateRandomPassword = !tc.value + + updated, err := repo.Update(context.Background(), input) + require.NoError(t, err) + assert.True(t, updated) + + got, err = repo.GetByName(context.Background(), "cfg", "tenant1") + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, !tc.value, got.GenerateRandomPassword) + }) + } +} + +func TestCIRARepo_Get_GenerateRandomPasswordTextValues(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + stored any + expected bool + }{ + {"text true", "true", true}, + {"text false", "false", false}, + {"null", nil, false}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + dbConn := setupDatabase(t) + defer dbConn.Close() + + _, err := dbConn.ExecContext(context.Background(), + `INSERT INTO ciraconfigs (cira_config_name, mps_server_address, mps_port, user_name, password, common_name, server_address_format, auth_method, mps_root_certificate, proxydetails, tenant_id, generate_random_password) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "cfg", "mps", 4433, "u", "p", "cn", 0, 0, "cert", "proxy", "tenant1", tc.stored) + require.NoError(t, err) + + sqlConfig := CreateSQLConfig(dbConn, false) + mockLog := mocks.NewMockLogger(nil) + repo := sqldb.NewCIRARepo(sqlConfig, mockLog) + + got, err := repo.GetByName(context.Background(), "cfg", "tenant1") + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, tc.expected, got.GenerateRandomPassword) + }) + } +}