@@ -10,7 +10,14 @@ import {
1010 type RequestPermissionRequest ,
1111 type RequestPermissionResponse ,
1212} from "@agentclientprotocol/sdk" ;
13- import { Agent , getLlmGatewayUrl , type OnLogCallback } from "@posthog/agent" ;
13+ import { Agent } from "@posthog/agent/agent" ;
14+ import {
15+ fetchGatewayModels ,
16+ formatGatewayModelName ,
17+ getProviderName ,
18+ } from "@posthog/agent/gateway-models" ;
19+ import { getLlmGatewayUrl } from "@posthog/agent/posthog-api" ;
20+ import type { OnLogCallback } from "@posthog/agent/types" ;
1421import { app } from "electron" ;
1522import { injectable , preDestroy } from "inversify" ;
1623import type { AcpMessage } from "../../../shared/types/session-events.js" ;
@@ -179,6 +186,12 @@ interface ManagedSession {
179186 needsRecreation : boolean ;
180187 promptPending : boolean ;
181188 pendingContext ?: string ;
189+ availableModels ?: Array < {
190+ modelId : string ;
191+ name : string ;
192+ description ?: string | null ;
193+ } > ;
194+ currentModelId ?: string ;
182195}
183196
184197function getClaudeCliPath ( ) : string {
@@ -410,25 +423,43 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
410423
411424 const mcpServers = this . buildMcpServers ( credentials ) ;
412425
426+ let availableModels :
427+ | Array < { modelId : string ; name : string ; description ?: string | null } >
428+ | undefined ;
429+ let currentModelId : string | undefined ;
430+
413431 if ( isReconnect ) {
414- await connection . extMethod ( "_posthog/session/resume" , {
415- sessionId : taskRunId ,
416- cwd : repoPath ,
417- mcpServers,
418- _meta : {
419- ...( logUrl && {
420- persistence : { taskId, runId : taskRunId , logUrl } ,
421- } ) ,
422- ...( sdkSessionId && { sdkSessionId } ) ,
423- ...( additionalDirectories ?. length && {
424- claudeCode : {
425- options : { additionalDirectories } ,
426- } ,
427- } ) ,
432+ const resumeResponse = await connection . extMethod (
433+ "_posthog/session/resume" ,
434+ {
435+ sessionId : taskRunId ,
436+ cwd : repoPath ,
437+ mcpServers,
438+ _meta : {
439+ ...( logUrl && {
440+ persistence : { taskId, runId : taskRunId , logUrl } ,
441+ } ) ,
442+ ...( sdkSessionId && { sdkSessionId } ) ,
443+ ...( additionalDirectories ?. length && {
444+ claudeCode : {
445+ options : { additionalDirectories } ,
446+ } ,
447+ } ) ,
448+ } ,
428449 } ,
429- } ) ;
450+ ) ;
451+ const resumeMeta = resumeResponse ?. _meta as
452+ | {
453+ models ?: {
454+ availableModels ?: typeof availableModels ;
455+ currentModelId ?: string ;
456+ } ;
457+ }
458+ | undefined ;
459+ availableModels = resumeMeta ?. models ?. availableModels ;
460+ currentModelId = resumeMeta ?. models ?. currentModelId ;
430461 } else {
431- await connection . newSession ( {
462+ const newSessionResponse = await connection . newSession ( {
432463 cwd : repoPath ,
433464 mcpServers,
434465 _meta : {
@@ -442,6 +473,8 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
442473 } ) ,
443474 } ,
444475 } ) ;
476+ availableModels = newSessionResponse . models ?. availableModels ;
477+ currentModelId = newSessionResponse . models ?. currentModelId ;
445478 }
446479
447480 const session : ManagedSession = {
@@ -457,6 +490,8 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
457490 config,
458491 needsRecreation : false ,
459492 promptPending : false ,
493+ availableModels,
494+ currentModelId,
460495 } ;
461496
462497 this . sessions . set ( taskRunId , session ) ;
@@ -1027,7 +1062,12 @@ For git operations while detached:
10271062 }
10281063
10291064 private toSessionResponse ( session : ManagedSession ) : SessionResponse {
1030- return { sessionId : session . taskRunId , channel : session . channel } ;
1065+ return {
1066+ sessionId : session . taskRunId ,
1067+ channel : session . channel ,
1068+ availableModels : session . availableModels ,
1069+ currentModelId : session . currentModelId ,
1070+ } ;
10311071 }
10321072
10331073 /**
@@ -1137,4 +1177,30 @@ For git operations while detached:
11371177 log . debug ( "Error in PR URL detection" , { taskRunId, error : err } ) ;
11381178 }
11391179 }
1180+
1181+ async getGatewayModels ( apiHost : string , _apiKey : string ) {
1182+ const gatewayUrl = getLlmGatewayUrl ( apiHost ) ;
1183+ const models = await fetchGatewayModels ( { gatewayUrl } ) ;
1184+
1185+ const MODEL_TIER_ORDER = [ "opus" , "sonnet" , "haiku" ] ;
1186+
1187+ const getModelTier = ( modelId : string ) : number => {
1188+ const lowerId = modelId . toLowerCase ( ) ;
1189+ for ( let i = 0 ; i < MODEL_TIER_ORDER . length ; i ++ ) {
1190+ if ( lowerId . includes ( MODEL_TIER_ORDER [ i ] ) ) return i ;
1191+ }
1192+ return MODEL_TIER_ORDER . length ;
1193+ } ;
1194+
1195+ const mapped = models . map ( ( model ) => ( {
1196+ modelId : model . id ,
1197+ name : formatGatewayModelName ( model ) ,
1198+ description : `Context: ${ model . context_window . toLocaleString ( ) } tokens` ,
1199+ provider : getProviderName ( model . owned_by ) ,
1200+ } ) ) ;
1201+
1202+ return mapped . sort (
1203+ ( a , b ) => getModelTier ( a . modelId ) - getModelTier ( b . modelId ) ,
1204+ ) ;
1205+ }
11401206}
0 commit comments