Skip to content

Commit 922af0c

Browse files
committed
fix: improve model selection reliability, make gateway model source of truth
1 parent 56fab62 commit 922af0c

18 files changed

Lines changed: 436 additions & 201 deletions

File tree

apps/twig/src/main/services/agent/schemas.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,20 @@ export const startSessionInput = z.object({
4949

5050
export type StartSessionInput = z.infer<typeof startSessionInput>;
5151

52+
export const modelOptionSchema = z.object({
53+
modelId: z.string(),
54+
name: z.string(),
55+
description: z.string().nullish(),
56+
provider: z.string().optional(),
57+
});
58+
59+
export type ModelOption = z.infer<typeof modelOptionSchema>;
60+
5261
export const sessionResponseSchema = z.object({
5362
sessionId: z.string(),
5463
channel: z.string(),
64+
availableModels: z.array(modelOptionSchema).optional(),
65+
currentModelId: z.string().optional(),
5566
});
5667

5768
export type SessionResponse = z.infer<typeof sessionResponseSchema>;
@@ -220,3 +231,10 @@ export const sessionInfoSchema = z.object({
220231
});
221232

222233
export const listSessionsOutput = z.array(sessionInfoSchema);
234+
235+
export const getGatewayModelsInput = z.object({
236+
apiHost: z.string(),
237+
apiKey: z.string(),
238+
});
239+
240+
export const getGatewayModelsOutput = z.array(modelOptionSchema);

apps/twig/src/main/services/agent/service.ts

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ import {
1010
type RequestPermissionRequest,
1111
type RequestPermissionResponse,
1212
} from "@agentclientprotocol/sdk";
13-
import { Agent, getLlmGatewayUrl, type OnLogCallback } from "@posthog/agent";
13+
import { Agent } from "@posthog/agent/agent";
14+
import {
15+
fetchGatewayModels,
16+
formatGatewayModelName,
17+
getProviderName,
18+
} from "@posthog/agent/gateway-models";
19+
import { getLlmGatewayUrl } from "@posthog/agent/posthog-api";
20+
import type { OnLogCallback } from "@posthog/agent/types";
1421
import { app } from "electron";
1522
import { injectable } from "inversify";
1623
import type { AcpMessage } from "../../../shared/types/session-events.js";
@@ -152,6 +159,12 @@ interface ManagedSession {
152159
needsRecreation: boolean;
153160
promptPending: boolean;
154161
pendingContext?: string;
162+
availableModels?: Array<{
163+
modelId: string;
164+
name: string;
165+
description?: string | null;
166+
}>;
167+
currentModelId?: string;
155168
}
156169

157170
function getClaudeCliPath(): string {
@@ -383,25 +396,43 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
383396

384397
const mcpServers = this.buildMcpServers(credentials);
385398

399+
let availableModels:
400+
| Array<{ modelId: string; name: string; description?: string | null }>
401+
| undefined;
402+
let currentModelId: string | undefined;
403+
386404
if (isReconnect) {
387-
await connection.extMethod("_posthog/session/resume", {
388-
sessionId: taskRunId,
389-
cwd: repoPath,
390-
mcpServers,
391-
_meta: {
392-
...(logUrl && {
393-
persistence: { taskId, runId: taskRunId, logUrl },
394-
}),
395-
...(sdkSessionId && { sdkSessionId }),
396-
...(additionalDirectories?.length && {
397-
claudeCode: {
398-
options: { additionalDirectories },
399-
},
400-
}),
405+
const resumeResponse = await connection.extMethod(
406+
"_posthog/session/resume",
407+
{
408+
sessionId: taskRunId,
409+
cwd: repoPath,
410+
mcpServers,
411+
_meta: {
412+
...(logUrl && {
413+
persistence: { taskId, runId: taskRunId, logUrl },
414+
}),
415+
...(sdkSessionId && { sdkSessionId }),
416+
...(additionalDirectories?.length && {
417+
claudeCode: {
418+
options: { additionalDirectories },
419+
},
420+
}),
421+
},
401422
},
402-
});
423+
);
424+
const resumeMeta = resumeResponse?._meta as
425+
| {
426+
models?: {
427+
availableModels?: typeof availableModels;
428+
currentModelId?: string;
429+
};
430+
}
431+
| undefined;
432+
availableModels = resumeMeta?.models?.availableModels;
433+
currentModelId = resumeMeta?.models?.currentModelId;
403434
} else {
404-
await connection.newSession({
435+
const newSessionResponse = await connection.newSession({
405436
cwd: repoPath,
406437
mcpServers,
407438
_meta: {
@@ -415,6 +446,8 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
415446
}),
416447
},
417448
});
449+
availableModels = newSessionResponse.models?.availableModels;
450+
currentModelId = newSessionResponse.models?.currentModelId;
418451
}
419452

420453
const session: ManagedSession = {
@@ -430,6 +463,8 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
430463
config,
431464
needsRecreation: false,
432465
promptPending: false,
466+
availableModels,
467+
currentModelId,
433468
};
434469

435470
this.sessions.set(taskRunId, session);
@@ -999,7 +1034,12 @@ For git operations while detached:
9991034
}
10001035

10011036
private toSessionResponse(session: ManagedSession): SessionResponse {
1002-
return { sessionId: session.taskRunId, channel: session.channel };
1037+
return {
1038+
sessionId: session.taskRunId,
1039+
channel: session.channel,
1040+
availableModels: session.availableModels,
1041+
currentModelId: session.currentModelId,
1042+
};
10031043
}
10041044

10051045
/**
@@ -1109,4 +1149,16 @@ For git operations while detached:
11091149
log.debug("Error in PR URL detection", { taskRunId, error: err });
11101150
}
11111151
}
1152+
1153+
async getGatewayModels(apiHost: string, _apiKey: string) {
1154+
const gatewayUrl = getLlmGatewayUrl(apiHost);
1155+
const models = await fetchGatewayModels({ gatewayUrl });
1156+
1157+
return models.map((model) => ({
1158+
modelId: model.id,
1159+
name: formatGatewayModelName(model),
1160+
description: `Context: ${model.context_window.toLocaleString()} tokens`,
1161+
provider: getProviderName(model.owned_by),
1162+
}));
1163+
}
11121164
}

apps/twig/src/main/services/folders/service.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { exec } from "node:child_process";
22
import fs from "node:fs";
33
import path from "node:path";
44
import { promisify } from "node:util";
5-
import { WorktreeManager } from "@posthog/agent";
5+
import { WorktreeManager } from "@posthog/agent/worktree-manager";
66
import { dialog } from "electron";
77
import { injectable } from "inversify";
88
import { generateId } from "../../../shared/utils/id.js";

apps/twig/src/main/services/workspace/service.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import * as fs from "node:fs";
33
import * as fsPromises from "node:fs/promises";
44
import path from "node:path";
55
import { promisify } from "node:util";
6-
import { WorktreeManager } from "@posthog/agent";
6+
import { WorktreeManager } from "@posthog/agent/worktree-manager";
77
import { inject, injectable } from "inversify";
88
import type {
99
TaskFolderAssociation,

apps/twig/src/main/trpc/routers/agent.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import {
66
cancelPermissionInput,
77
cancelPromptInput,
88
cancelSessionInput,
9+
getGatewayModelsInput,
10+
getGatewayModelsOutput,
911
listSessionsInput,
1012
listSessionsOutput,
1113
notifySessionContextInput,
@@ -140,4 +142,11 @@ export const agentRouter = router({
140142
markAllForRecreation: publicProcedure.mutation(() =>
141143
getService().markAllSessionsForRecreation(),
142144
),
145+
146+
getGatewayModels: publicProcedure
147+
.input(getGatewayModelsInput)
148+
.output(getGatewayModelsOutput)
149+
.query(({ input }) =>
150+
getService().getGatewayModels(input.apiHost, input.apiKey),
151+
),
143152
});

apps/twig/src/main/utils/store.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { WorktreeManager } from "@posthog/agent";
1+
import { WorktreeManager } from "@posthog/agent/worktree-manager";
22
import { app } from "electron";
33
import Store from "electron-store";
44
import type {

apps/twig/src/renderer/features/sessions/components/ModelSelector.tsx

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
import { useSettingsStore } from "@features/settings/stores/settingsStore";
21
import { Select, Text } from "@radix-ui/themes";
3-
import {
4-
AVAILABLE_MODELS,
5-
getModelsByProvider,
6-
type ModelProvider,
7-
} from "@shared/types/models";
82
import { Fragment } from "react";
3+
import { useModelsStore } from "../stores/modelsStore";
94
import { useSessionActions, useSessionForTask } from "../stores/sessionStore";
105

116
interface ModelSelectorProps {
@@ -19,31 +14,26 @@ export function ModelSelector({
1914
disabled,
2015
onModelChange,
2116
}: ModelSelectorProps) {
22-
const defaultModel = useSettingsStore((state) => state.defaultModel);
23-
const setDefaultModel = useSettingsStore((state) => state.setDefaultModel);
2417
const { setSessionModel } = useSessionActions();
2518
const session = useSessionForTask(taskId);
2619

27-
// Use session model if available, otherwise fall back to default
28-
const activeModel = session?.model ?? defaultModel;
20+
const groupedModels = useModelsStore((s) => s.groupedModels);
21+
const models = useModelsStore((s) => s.models);
22+
const selectedModel = useModelsStore((s) => s.selectedModel);
23+
const setSelectedModel = useModelsStore((s) => s.setSelectedModel);
24+
25+
const activeModel = session?.model ?? selectedModel;
2926

3027
const handleChange = (value: string) => {
31-
// Always update the default
32-
setDefaultModel(value);
28+
setSelectedModel(value);
3329
onModelChange?.(value);
3430

35-
// If there's an active session, update the model mid-session
3631
if (taskId && session?.status === "connected" && !session.isCloud) {
3732
setSessionModel(taskId, value);
3833
}
3934
};
4035

41-
const modelsByProvider = getModelsByProvider();
42-
const providers = (Object.keys(modelsByProvider) as ModelProvider[]).filter(
43-
(provider) => modelsByProvider[provider].models.length > 0,
44-
);
45-
46-
const currentModel = AVAILABLE_MODELS.find((m) => m.id === activeModel);
36+
const currentModel = models.find((m) => m.modelId === activeModel);
4737
const displayName = currentModel?.name ?? activeModel;
4838

4939
return (
@@ -69,13 +59,13 @@ export function ModelSelector({
6959
</Text>
7060
</Select.Trigger>
7161
<Select.Content position="popper" sideOffset={4}>
72-
{providers.map((provider, index) => (
73-
<Fragment key={provider}>
62+
{groupedModels.map((group, index) => (
63+
<Fragment key={group.provider}>
7464
{index > 0 && <Select.Separator />}
7565
<Select.Group>
76-
<Select.Label>{modelsByProvider[provider].name}</Select.Label>
77-
{modelsByProvider[provider].models.map((model) => (
78-
<Select.Item key={model.id} value={model.id}>
66+
<Select.Label>{group.provider}</Select.Label>
67+
{group.models.map((model) => (
68+
<Select.Item key={model.modelId} value={model.modelId}>
7969
{model.name}
8070
</Select.Item>
8171
))}

0 commit comments

Comments
 (0)