Skip to content

Commit 8077d8d

Browse files
committed
feat: persist preferred project by account key
1 parent 7255c9c commit 8077d8d

9 files changed

Lines changed: 355 additions & 26 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
CREATE TABLE `auth_preferences` (
2+
`account_key` text NOT NULL,
3+
`cloud_region` text NOT NULL,
4+
`last_selected_project_id` integer,
5+
`created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
6+
`updated_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL
7+
);
8+
--> statement-breakpoint
9+
CREATE INDEX `auth_preferences_account_region_idx` ON `auth_preferences` (`account_key`,`cloud_region`);

apps/code/src/main/db/migrations/meta/_journal.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
"when": 1774890000000,
3030
"tag": "0003_fair_whiplash",
3131
"breakpoints": true
32+
},
33+
{
34+
"idx": 4,
35+
"version": "7",
36+
"when": 1774891000000,
37+
"tag": "0004_auth_preferences",
38+
"breakpoints": true
3239
}
3340
]
3441
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import type {
2+
AuthPreference,
3+
IAuthPreferenceRepository,
4+
PersistAuthPreferenceInput,
5+
} from "./auth-preference-repository";
6+
7+
export interface MockAuthPreferenceRepository
8+
extends IAuthPreferenceRepository {
9+
_preferences: AuthPreference[];
10+
}
11+
12+
export function createMockAuthPreferenceRepository(): MockAuthPreferenceRepository {
13+
let preferences: AuthPreference[] = [];
14+
15+
const clone = (value: AuthPreference): AuthPreference => ({ ...value });
16+
17+
return {
18+
get _preferences() {
19+
return preferences.map(clone);
20+
},
21+
set _preferences(value) {
22+
preferences = value.map(clone);
23+
},
24+
get: (accountKey, cloudRegion) => {
25+
const preference = preferences.find(
26+
(entry) =>
27+
entry.accountKey === accountKey && entry.cloudRegion === cloudRegion,
28+
);
29+
return preference ? clone(preference) : null;
30+
},
31+
save: (input: PersistAuthPreferenceInput) => {
32+
const timestamp = new Date().toISOString();
33+
const existingIndex = preferences.findIndex(
34+
(entry) =>
35+
entry.accountKey === input.accountKey &&
36+
entry.cloudRegion === input.cloudRegion,
37+
);
38+
39+
const row: AuthPreference = {
40+
accountKey: input.accountKey,
41+
cloudRegion: input.cloudRegion,
42+
lastSelectedProjectId: input.lastSelectedProjectId,
43+
createdAt:
44+
existingIndex >= 0 ? preferences[existingIndex].createdAt : timestamp,
45+
updatedAt: timestamp,
46+
};
47+
48+
if (existingIndex >= 0) {
49+
preferences[existingIndex] = row;
50+
} else {
51+
preferences.push(row);
52+
}
53+
54+
return clone(row);
55+
},
56+
};
57+
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import { and, eq } from "drizzle-orm";
2+
import { inject, injectable } from "inversify";
3+
import { MAIN_TOKENS } from "../../di/tokens";
4+
import { authPreferences } from "../schema";
5+
import type { DatabaseService } from "../service";
6+
7+
export type AuthPreference = typeof authPreferences.$inferSelect;
8+
export type NewAuthPreference = typeof authPreferences.$inferInsert;
9+
10+
export interface PersistAuthPreferenceInput {
11+
accountKey: string;
12+
cloudRegion: "us" | "eu" | "dev";
13+
lastSelectedProjectId: number | null;
14+
}
15+
16+
export interface IAuthPreferenceRepository {
17+
get(
18+
accountKey: string,
19+
cloudRegion: "us" | "eu" | "dev",
20+
): AuthPreference | null;
21+
save(input: PersistAuthPreferenceInput): AuthPreference;
22+
}
23+
24+
const now = () => new Date().toISOString();
25+
26+
@injectable()
27+
export class AuthPreferenceRepository implements IAuthPreferenceRepository {
28+
constructor(
29+
@inject(MAIN_TOKENS.DatabaseService)
30+
private readonly databaseService: DatabaseService,
31+
) {}
32+
33+
private get db() {
34+
return this.databaseService.db;
35+
}
36+
37+
get(
38+
accountKey: string,
39+
cloudRegion: "us" | "eu" | "dev",
40+
): AuthPreference | null {
41+
return (
42+
this.db
43+
.select()
44+
.from(authPreferences)
45+
.where(
46+
and(
47+
eq(authPreferences.accountKey, accountKey),
48+
eq(authPreferences.cloudRegion, cloudRegion),
49+
),
50+
)
51+
.limit(1)
52+
.get() ?? null
53+
);
54+
}
55+
56+
save(input: PersistAuthPreferenceInput): AuthPreference {
57+
const timestamp = now();
58+
const existing = this.get(input.accountKey, input.cloudRegion);
59+
60+
const row: NewAuthPreference = {
61+
accountKey: input.accountKey,
62+
cloudRegion: input.cloudRegion,
63+
lastSelectedProjectId: input.lastSelectedProjectId,
64+
createdAt: existing?.createdAt ?? timestamp,
65+
updatedAt: timestamp,
66+
};
67+
68+
if (existing) {
69+
this.db
70+
.update(authPreferences)
71+
.set(row)
72+
.where(
73+
and(
74+
eq(authPreferences.accountKey, input.accountKey),
75+
eq(authPreferences.cloudRegion, input.cloudRegion),
76+
),
77+
)
78+
.run();
79+
} else {
80+
this.db.insert(authPreferences).values(row).run();
81+
}
82+
83+
const saved = this.get(input.accountKey, input.cloudRegion);
84+
if (!saved) {
85+
throw new Error("Failed to persist auth preference");
86+
}
87+
return saved;
88+
}
89+
}

apps/code/src/main/db/schema.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,20 @@ export const authSessions = sqliteTable("auth_sessions", {
8686
createdAt: createdAt(),
8787
updatedAt: updatedAt(),
8888
});
89+
90+
export const authPreferences = sqliteTable(
91+
"auth_preferences",
92+
{
93+
accountKey: text().notNull(),
94+
cloudRegion: text({ enum: ["us", "eu", "dev"] }).notNull(),
95+
lastSelectedProjectId: integer(),
96+
createdAt: createdAt(),
97+
updatedAt: updatedAt(),
98+
},
99+
(t) => [
100+
index("auth_preferences_account_region_idx").on(
101+
t.accountKey,
102+
t.cloudRegion,
103+
),
104+
],
105+
);

apps/code/src/main/di/container.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import "reflect-metadata";
22

33
import { Container } from "inversify";
44
import { ArchiveRepository } from "../db/repositories/archive-repository";
5+
import { AuthPreferenceRepository } from "../db/repositories/auth-preference-repository";
56
import { AuthSessionRepository } from "../db/repositories/auth-session-repository";
67
import { RepositoryRepository } from "../db/repositories/repository-repository";
78
import { SuspensionRepositoryImpl } from "../db/repositories/suspension-repository";
@@ -52,6 +53,9 @@ export const container = new Container({
5253
});
5354

5455
container.bind(MAIN_TOKENS.DatabaseService).to(DatabaseService);
56+
container
57+
.bind(MAIN_TOKENS.AuthPreferenceRepository)
58+
.to(AuthPreferenceRepository);
5559
container.bind(MAIN_TOKENS.AuthSessionRepository).to(AuthSessionRepository);
5660
container.bind(MAIN_TOKENS.RepositoryRepository).to(RepositoryRepository);
5761
container.bind(MAIN_TOKENS.WorkspaceRepository).to(WorkspaceRepository);

apps/code/src/main/di/tokens.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ export const MAIN_TOKENS = Object.freeze({
99
SettingsStore: Symbol.for("Main.SettingsStore"),
1010

1111
// Database
12+
AuthPreferenceRepository: Symbol.for("Main.AuthPreferenceRepository"),
1213
DatabaseService: Symbol.for("Main.DatabaseService"),
1314
AuthSessionRepository: Symbol.for("Main.AuthSessionRepository"),
1415
RepositoryRepository: Symbol.for("Main.RepositoryRepository"),

apps/code/src/main/services/auth/service.test.ts

Lines changed: 95 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { OAUTH_SCOPE_VERSION } from "@shared/constants/oauth";
22
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
3+
import { createMockAuthPreferenceRepository } from "../../db/repositories/auth-preference-repository.mock";
34
import { createMockAuthSessionRepository } from "../../db/repositories/auth-session-repository.mock";
45
import { decrypt, encrypt } from "../../utils/encryption";
56
import type { ConnectivityService } from "../connectivity/service";
@@ -18,6 +19,7 @@ vi.mock("../../utils/logger.js", () => ({
1819
}));
1920

2021
describe("AuthService", () => {
22+
const preferenceRepository = createMockAuthPreferenceRepository();
2123
const repository = createMockAuthSessionRepository();
2224
const oauthService = {
2325
refreshToken: vi.fn(),
@@ -31,16 +33,43 @@ describe("AuthService", () => {
3133
let service: AuthService;
3234

3335
beforeEach(() => {
36+
preferenceRepository._preferences = [];
3437
repository.clearCurrent();
3538
vi.clearAllMocks();
36-
service = new AuthService(repository, oauthService, connectivityService);
39+
service = new AuthService(
40+
preferenceRepository,
41+
repository,
42+
oauthService,
43+
connectivityService,
44+
);
3745
});
3846

3947
afterEach(async () => {
4048
vi.unstubAllGlobals();
4149
await service.logout();
4250
});
4351

52+
const stubAuthFetch = (accountKey = "user-1") => {
53+
vi.stubGlobal(
54+
"fetch",
55+
vi.fn(async (input: string | Request) => {
56+
const url = typeof input === "string" ? input : input.url;
57+
58+
if (url.includes("/api/users/@me/")) {
59+
return {
60+
ok: true,
61+
json: vi.fn().mockResolvedValue({ uuid: accountKey }),
62+
} as unknown as Response;
63+
}
64+
65+
return {
66+
ok: true,
67+
json: vi.fn().mockResolvedValue({ has_access: true }),
68+
} as unknown as Response;
69+
}) as typeof fetch,
70+
);
71+
};
72+
4473
it("bootstraps to anonymous when there is no stored session", async () => {
4574
await service.initialize();
4675

@@ -99,12 +128,7 @@ describe("AuthService", () => {
99128
},
100129
});
101130

102-
vi.stubGlobal(
103-
"fetch",
104-
vi.fn().mockResolvedValue({
105-
json: vi.fn().mockResolvedValue({ has_access: true }),
106-
}) as unknown as typeof fetch,
107-
);
131+
stubAuthFetch();
108132

109133
await service.initialize();
110134

@@ -151,12 +175,7 @@ describe("AuthService", () => {
151175
scoped_organizations: ["org-1"],
152176
},
153177
});
154-
vi.stubGlobal(
155-
"fetch",
156-
vi.fn().mockResolvedValue({
157-
json: vi.fn().mockResolvedValue({ has_access: true }),
158-
}) as unknown as typeof fetch,
159-
);
178+
stubAuthFetch();
160179

161180
await service.login("us");
162181

@@ -211,12 +230,7 @@ describe("AuthService", () => {
211230
},
212231
});
213232

214-
vi.stubGlobal(
215-
"fetch",
216-
vi.fn().mockResolvedValue({
217-
json: vi.fn().mockResolvedValue({ has_access: true }),
218-
}) as unknown as typeof fetch,
219-
);
233+
stubAuthFetch();
220234

221235
await service.login("us");
222236
await service.selectProject(84);
@@ -237,4 +251,66 @@ describe("AuthService", () => {
237251
availableProjectIds: [42, 84],
238252
});
239253
});
254+
255+
it("restores the selected project after app restart while logged out", async () => {
256+
vi.mocked(oauthService.startFlow)
257+
.mockResolvedValueOnce({
258+
success: true,
259+
data: {
260+
access_token: "initial-access-token",
261+
refresh_token: "initial-refresh-token",
262+
expires_in: 3600,
263+
token_type: "Bearer",
264+
scope: "",
265+
scoped_teams: [42, 84],
266+
scoped_organizations: ["org-1"],
267+
},
268+
})
269+
.mockResolvedValueOnce({
270+
success: true,
271+
data: {
272+
access_token: "second-access-token",
273+
refresh_token: "second-refresh-token",
274+
expires_in: 3600,
275+
token_type: "Bearer",
276+
scope: "",
277+
scoped_teams: [42, 84],
278+
scoped_organizations: ["org-1"],
279+
},
280+
});
281+
vi.mocked(oauthService.refreshToken).mockResolvedValue({
282+
success: true,
283+
data: {
284+
access_token: "refreshed-access-token",
285+
refresh_token: "refreshed-refresh-token",
286+
expires_in: 3600,
287+
token_type: "Bearer",
288+
scope: "",
289+
scoped_teams: [42, 84],
290+
scoped_organizations: ["org-1"],
291+
},
292+
});
293+
294+
stubAuthFetch();
295+
296+
await service.login("us");
297+
await service.selectProject(84);
298+
await service.logout();
299+
300+
service = new AuthService(
301+
preferenceRepository,
302+
repository,
303+
oauthService,
304+
connectivityService,
305+
);
306+
307+
await service.login("us");
308+
309+
expect(service.getState()).toMatchObject({
310+
status: "authenticated",
311+
cloudRegion: "us",
312+
projectId: 84,
313+
availableProjectIds: [42, 84],
314+
});
315+
});
240316
});

0 commit comments

Comments
 (0)