Skip to content

Commit 3529539

Browse files
committed
fix: improve model selection reliability, make gateway model source of truth
1 parent 929a2a4 commit 3529539

18 files changed

Lines changed: 450 additions & 201 deletions

File tree

apps/twig/src/main/services/agent/schemas.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,20 @@ export const startSessionInput = z.object({
4949

5050
export type StartSessionInput = z.infer<typeof startSessionInput>;
5151

52+
export const modelOptionSchema = z.object({
53+
modelId: z.string(),
54+
name: z.string(),
55+
description: z.string().nullish(),
56+
provider: z.string().optional(),
57+
});
58+
59+
export type ModelOption = z.infer<typeof modelOptionSchema>;
60+
5261
export const sessionResponseSchema = z.object({
5362
sessionId: z.string(),
5463
channel: z.string(),
64+
availableModels: z.array(modelOptionSchema).optional(),
65+
currentModelId: z.string().optional(),
5566
});
5667

5768
export type SessionResponse = z.infer<typeof sessionResponseSchema>;
@@ -220,3 +231,10 @@ export const sessionInfoSchema = z.object({
220231
});
221232

222233
export const listSessionsOutput = z.array(sessionInfoSchema);
234+
235+
export const getGatewayModelsInput = z.object({
236+
apiHost: z.string(),
237+
apiKey: z.string(),
238+
});
239+
240+
export const getGatewayModelsOutput = z.array(modelOptionSchema);

apps/twig/src/main/services/agent/service.ts

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ import {
1010
type RequestPermissionRequest,
1111
type RequestPermissionResponse,
1212
} from "@agentclientprotocol/sdk";
13-
import { Agent, getLlmGatewayUrl, type OnLogCallback } from "@posthog/agent";
13+
import { Agent } from "@posthog/agent/agent";
14+
import {
15+
fetchGatewayModels,
16+
formatGatewayModelName,
17+
getProviderName,
18+
} from "@posthog/agent/gateway-models";
19+
import { getLlmGatewayUrl } from "@posthog/agent/posthog-api";
20+
import type { OnLogCallback } from "@posthog/agent/types";
1421
import { app } from "electron";
1522
import { injectable, preDestroy } from "inversify";
1623
import type { AcpMessage } from "../../../shared/types/session-events.js";
@@ -179,6 +186,12 @@ interface ManagedSession {
179186
needsRecreation: boolean;
180187
promptPending: boolean;
181188
pendingContext?: string;
189+
availableModels?: Array<{
190+
modelId: string;
191+
name: string;
192+
description?: string | null;
193+
}>;
194+
currentModelId?: string;
182195
}
183196

184197
function getClaudeCliPath(): string {
@@ -410,25 +423,43 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
410423

411424
const mcpServers = this.buildMcpServers(credentials);
412425

426+
let availableModels:
427+
| Array<{ modelId: string; name: string; description?: string | null }>
428+
| undefined;
429+
let currentModelId: string | undefined;
430+
413431
if (isReconnect) {
414-
await connection.extMethod("_posthog/session/resume", {
415-
sessionId: taskRunId,
416-
cwd: repoPath,
417-
mcpServers,
418-
_meta: {
419-
...(logUrl && {
420-
persistence: { taskId, runId: taskRunId, logUrl },
421-
}),
422-
...(sdkSessionId && { sdkSessionId }),
423-
...(additionalDirectories?.length && {
424-
claudeCode: {
425-
options: { additionalDirectories },
426-
},
427-
}),
432+
const resumeResponse = await connection.extMethod(
433+
"_posthog/session/resume",
434+
{
435+
sessionId: taskRunId,
436+
cwd: repoPath,
437+
mcpServers,
438+
_meta: {
439+
...(logUrl && {
440+
persistence: { taskId, runId: taskRunId, logUrl },
441+
}),
442+
...(sdkSessionId && { sdkSessionId }),
443+
...(additionalDirectories?.length && {
444+
claudeCode: {
445+
options: { additionalDirectories },
446+
},
447+
}),
448+
},
428449
},
429-
});
450+
);
451+
const resumeMeta = resumeResponse?._meta as
452+
| {
453+
models?: {
454+
availableModels?: typeof availableModels;
455+
currentModelId?: string;
456+
};
457+
}
458+
| undefined;
459+
availableModels = resumeMeta?.models?.availableModels;
460+
currentModelId = resumeMeta?.models?.currentModelId;
430461
} else {
431-
await connection.newSession({
462+
const newSessionResponse = await connection.newSession({
432463
cwd: repoPath,
433464
mcpServers,
434465
_meta: {
@@ -442,6 +473,8 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
442473
}),
443474
},
444475
});
476+
availableModels = newSessionResponse.models?.availableModels;
477+
currentModelId = newSessionResponse.models?.currentModelId;
445478
}
446479

447480
const session: ManagedSession = {
@@ -457,6 +490,8 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
457490
config,
458491
needsRecreation: false,
459492
promptPending: false,
493+
availableModels,
494+
currentModelId,
460495
};
461496

462497
this.sessions.set(taskRunId, session);
@@ -1027,7 +1062,12 @@ For git operations while detached:
10271062
}
10281063

10291064
private toSessionResponse(session: ManagedSession): SessionResponse {
1030-
return { sessionId: session.taskRunId, channel: session.channel };
1065+
return {
1066+
sessionId: session.taskRunId,
1067+
channel: session.channel,
1068+
availableModels: session.availableModels,
1069+
currentModelId: session.currentModelId,
1070+
};
10311071
}
10321072

10331073
/**
@@ -1137,4 +1177,30 @@ For git operations while detached:
11371177
log.debug("Error in PR URL detection", { taskRunId, error: err });
11381178
}
11391179
}
1180+
1181+
async getGatewayModels(apiHost: string, _apiKey: string) {
1182+
const gatewayUrl = getLlmGatewayUrl(apiHost);
1183+
const models = await fetchGatewayModels({ gatewayUrl });
1184+
1185+
const MODEL_TIER_ORDER = ["opus", "sonnet", "haiku"];
1186+
1187+
const getModelTier = (modelId: string): number => {
1188+
const lowerId = modelId.toLowerCase();
1189+
for (let i = 0; i < MODEL_TIER_ORDER.length; i++) {
1190+
if (lowerId.includes(MODEL_TIER_ORDER[i])) return i;
1191+
}
1192+
return MODEL_TIER_ORDER.length;
1193+
};
1194+
1195+
const mapped = models.map((model) => ({
1196+
modelId: model.id,
1197+
name: formatGatewayModelName(model),
1198+
description: `Context: ${model.context_window.toLocaleString()} tokens`,
1199+
provider: getProviderName(model.owned_by),
1200+
}));
1201+
1202+
return mapped.sort(
1203+
(a, b) => getModelTier(a.modelId) - getModelTier(b.modelId),
1204+
);
1205+
}
11401206
}

apps/twig/src/main/services/folders/service.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { exec } from "node:child_process";
22
import fs from "node:fs";
33
import path from "node:path";
44
import { promisify } from "node:util";
5-
import { WorktreeManager } from "@posthog/agent";
5+
import { WorktreeManager } from "@posthog/agent/worktree-manager";
66
import { dialog } from "electron";
77
import { injectable } from "inversify";
88
import { generateId } from "../../../shared/utils/id.js";

apps/twig/src/main/services/workspace/service.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import * as fs from "node:fs";
33
import * as fsPromises from "node:fs/promises";
44
import path from "node:path";
55
import { promisify } from "node:util";
6-
import { WorktreeManager } from "@posthog/agent";
6+
import { WorktreeManager } from "@posthog/agent/worktree-manager";
77
import { inject, injectable } from "inversify";
88
import type {
99
TaskFolderAssociation,

apps/twig/src/main/trpc/routers/agent.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import {
66
cancelPermissionInput,
77
cancelPromptInput,
88
cancelSessionInput,
9+
getGatewayModelsInput,
10+
getGatewayModelsOutput,
911
listSessionsInput,
1012
listSessionsOutput,
1113
notifySessionContextInput,
@@ -140,4 +142,11 @@ export const agentRouter = router({
140142
markAllForRecreation: publicProcedure.mutation(() =>
141143
getService().markAllSessionsForRecreation(),
142144
),
145+
146+
getGatewayModels: publicProcedure
147+
.input(getGatewayModelsInput)
148+
.output(getGatewayModelsOutput)
149+
.query(({ input }) =>
150+
getService().getGatewayModels(input.apiHost, input.apiKey),
151+
),
143152
});

apps/twig/src/main/utils/store.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { WorktreeManager } from "@posthog/agent";
1+
import { WorktreeManager } from "@posthog/agent/worktree-manager";
22
import { app } from "electron";
33
import Store from "electron-store";
44
import type {

apps/twig/src/renderer/features/sessions/components/ModelSelector.tsx

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
import { useSettingsStore } from "@features/settings/stores/settingsStore";
21
import { Select, Text } from "@radix-ui/themes";
3-
import {
4-
AVAILABLE_MODELS,
5-
getModelsByProvider,
6-
type ModelProvider,
7-
} from "@shared/types/models";
82
import { Fragment } from "react";
3+
import { useModelsStore } from "../stores/modelsStore";
94
import { useSessionActions, useSessionForTask } from "../stores/sessionStore";
105

116
interface ModelSelectorProps {
@@ -19,31 +14,26 @@ export function ModelSelector({
1914
disabled,
2015
onModelChange,
2116
}: ModelSelectorProps) {
22-
const defaultModel = useSettingsStore((state) => state.defaultModel);
23-
const setDefaultModel = useSettingsStore((state) => state.setDefaultModel);
2417
const { setSessionModel } = useSessionActions();
2518
const session = useSessionForTask(taskId);
2619

27-
// Use session model if available, otherwise fall back to default
28-
const activeModel = session?.model ?? defaultModel;
20+
const groupedModels = useModelsStore((s) => s.groupedModels);
21+
const models = useModelsStore((s) => s.models);
22+
const selectedModel = useModelsStore((s) => s.selectedModel);
23+
const setSelectedModel = useModelsStore((s) => s.setSelectedModel);
24+
25+
const activeModel = session?.model ?? selectedModel;
2926

3027
const handleChange = (value: string) => {
31-
// Always update the default
32-
setDefaultModel(value);
28+
setSelectedModel(value);
3329
onModelChange?.(value);
3430

35-
// If there's an active session, update the model mid-session
3631
if (taskId && session?.status === "connected" && !session.isCloud) {
3732
setSessionModel(taskId, value);
3833
}
3934
};
4035

41-
const modelsByProvider = getModelsByProvider();
42-
const providers = (Object.keys(modelsByProvider) as ModelProvider[]).filter(
43-
(provider) => modelsByProvider[provider].models.length > 0,
44-
);
45-
46-
const currentModel = AVAILABLE_MODELS.find((m) => m.id === activeModel);
36+
const currentModel = models.find((m) => m.modelId === activeModel);
4737
const displayName = currentModel?.name ?? activeModel;
4838

4939
return (
@@ -69,13 +59,13 @@ export function ModelSelector({
6959
</Text>
7060
</Select.Trigger>
7161
<Select.Content position="popper" sideOffset={4}>
72-
{providers.map((provider, index) => (
73-
<Fragment key={provider}>
62+
{groupedModels.map((group, index) => (
63+
<Fragment key={group.provider}>
7464
{index > 0 && <Select.Separator />}
7565
<Select.Group>
76-
<Select.Label>{modelsByProvider[provider].name}</Select.Label>
77-
{modelsByProvider[provider].models.map((model) => (
78-
<Select.Item key={model.id} value={model.id}>
66+
<Select.Label>{group.provider}</Select.Label>
67+
{group.models.map((model) => (
68+
<Select.Item key={model.modelId} value={model.modelId}>
7969
{model.name}
8070
</Select.Item>
8171
))}

0 commit comments

Comments
 (0)