Skip to content

Commit 6774b64

Browse files
committed
feat(cloud-agent): user-authored-prs
chore: cleaner chore: cleaner
1 parent e55079e commit 6774b64

14 files changed

Lines changed: 275 additions & 6 deletions

File tree

apps/code/src/main/services/git/schemas.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,14 @@ export const ghStatusOutput = z.object({
223223

224224
export type GhStatusOutput = z.infer<typeof ghStatusOutput>;
225225

226+
export const ghAuthTokenOutput = z.object({
227+
success: z.boolean(),
228+
token: z.string().nullable(),
229+
error: z.string().nullable(),
230+
});
231+
232+
export type GhAuthTokenOutput = z.infer<typeof ghAuthTokenOutput>;
233+
226234
// Pull request status
227235
export const prStatusInput = directoryPathInput;
228236
export const prStatusOutput = z.object({

apps/code/src/main/services/git/service.test.ts

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,61 @@ describe("GitService.getPrChangedFiles", () => {
127127
).rejects.toThrow("Failed to fetch PR files");
128128
});
129129
});
130+
131+
describe("GitService.getGhAuthToken", () => {
132+
let service: GitService;
133+
134+
beforeEach(() => {
135+
vi.clearAllMocks();
136+
service = new GitService({} as LlmGatewayService);
137+
});
138+
139+
it("returns the authenticated GitHub CLI token", async () => {
140+
mockExecGh.mockResolvedValue({
141+
exitCode: 0,
142+
stdout: "ghu_test_token\n",
143+
stderr: "",
144+
});
145+
146+
const result = await service.getGhAuthToken();
147+
148+
expect(mockExecGh).toHaveBeenCalledWith(["auth", "token"]);
149+
expect(result).toEqual({
150+
success: true,
151+
token: "ghu_test_token",
152+
error: null,
153+
});
154+
});
155+
156+
it("returns the gh error when auth token lookup fails", async () => {
157+
mockExecGh.mockResolvedValue({
158+
exitCode: 1,
159+
stdout: "",
160+
stderr: "authentication required",
161+
});
162+
163+
const result = await service.getGhAuthToken();
164+
165+
expect(result).toEqual({
166+
success: false,
167+
token: null,
168+
error: "authentication required",
169+
});
170+
});
171+
172+
it("returns error when stdout is empty", async () => {
173+
mockExecGh.mockResolvedValue({
174+
exitCode: 0,
175+
stdout: "",
176+
stderr: "",
177+
});
178+
179+
const result = await service.getGhAuthToken();
180+
181+
expect(result).toEqual({
182+
success: false,
183+
token: null,
184+
error: "GitHub auth token is empty",
185+
});
186+
});
187+
});

apps/code/src/main/services/git/service.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import type {
4646
DiscardFileChangesOutput,
4747
GetCommitConventionsOutput,
4848
GetPrTemplateOutput,
49+
GhAuthTokenOutput,
4950
GhStatusOutput,
5051
GitCommitInfo,
5152
GitFileStatus,
@@ -706,6 +707,33 @@ export class GitService extends TypedEventEmitter<GitServiceEvents> {
706707
};
707708
}
708709

710+
public async getGhAuthToken(): Promise<GhAuthTokenOutput> {
711+
const result = await execGh(["auth", "token"]);
712+
if (result.exitCode !== 0) {
713+
return {
714+
success: false,
715+
token: null,
716+
error:
717+
result.stderr || result.error || "Failed to read GitHub auth token",
718+
};
719+
}
720+
721+
const token = result.stdout.trim();
722+
if (!token) {
723+
return {
724+
success: false,
725+
token: null,
726+
error: "GitHub auth token is empty",
727+
};
728+
}
729+
730+
return {
731+
success: true,
732+
token,
733+
error: null,
734+
};
735+
}
736+
709737
public async getPrStatus(directoryPath: string): Promise<PrStatusOutput> {
710738
const base: PrStatusOutput = {
711739
hasRemote: false,

apps/code/src/main/trpc/routers/git.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import {
4444
getPrChangedFilesOutput,
4545
getPrTemplateInput,
4646
getPrTemplateOutput,
47+
ghAuthTokenOutput,
4748
ghStatusOutput,
4849
gitStateSnapshotSchema,
4950
openPrInput,
@@ -264,6 +265,10 @@ export const gitRouter = router({
264265
.output(ghStatusOutput)
265266
.query(() => getService().getGhStatus()),
266267

268+
getGhAuthToken: publicProcedure
269+
.output(ghAuthTokenOutput)
270+
.query(() => getService().getGhAuthToken()),
271+
267272
getPrStatus: publicProcedure
268273
.input(prStatusInput)
269274
.output(prStatusOutput)

apps/code/src/renderer/api/posthogClient.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import type {
1010
Task,
1111
TaskRun,
1212
} from "@shared/types";
13+
import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud";
1314
import type { StoredLogEntry } from "@shared/types/session-events";
1415
import { logger } from "@utils/logger";
1516
import { buildApiFetcher } from "./fetcher";
@@ -559,6 +560,12 @@ export class PostHogAPIClient {
559560
branch?: string | null,
560561
resumeOptions?: { resumeFromRunId: string; pendingUserMessage: string },
561562
sandboxEnvironmentId?: string,
563+
runOptions?: {
564+
prAuthorshipMode?: PrAuthorshipMode;
565+
runSource?: CloudRunSource;
566+
signalReportId?: string;
567+
githubUserToken?: string;
568+
},
562569
): Promise<Task> {
563570
const teamId = await this.getTeamId();
564571
const body: Record<string, unknown> = { mode: "interactive" };
@@ -572,6 +579,18 @@ export class PostHogAPIClient {
572579
if (sandboxEnvironmentId) {
573580
body.sandbox_environment_id = sandboxEnvironmentId;
574581
}
582+
if (runOptions?.prAuthorshipMode) {
583+
body.pr_authorship_mode = runOptions.prAuthorshipMode;
584+
}
585+
if (runOptions?.runSource) {
586+
body.run_source = runOptions.runSource;
587+
}
588+
if (runOptions?.signalReportId) {
589+
body.signal_report_id = runOptions.signalReportId;
590+
}
591+
if (runOptions?.githubUserToken) {
592+
body.github_user_token = runOptions.githubUserToken;
593+
}
575594

576595
const data = await this.api.post(
577596
`/api/projects/{project_id}/tasks/{id}/run/`,

apps/code/src/renderer/features/inbox/stores/inboxCloudTaskStore.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ export const useInboxCloudTaskStore = create<InboxCloudTaskStore>()(
6161
workspaceMode: "cloud",
6262
githubIntegrationId: params.githubIntegrationId,
6363
repository: selectedRepo,
64+
cloudPrAuthorshipMode: "user",
65+
cloudRunSource: "signal_report",
66+
signalReportId: params.reportId,
6467
});
6568

6669
if (result.success) {

apps/code/src/renderer/features/sessions/service/service.ts

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import { taskViewedApi } from "@features/sidebar/hooks/useTaskViewed";
2929
import { DEFAULT_GATEWAY_MODEL } from "@posthog/agent/gateway-models";
3030
import { getIsOnline } from "@renderer/stores/connectivityStore";
3131
import { trpcClient } from "@renderer/trpc/client";
32+
import { getGhUserTokenOrThrow } from "@renderer/utils/github";
3233
import { toast } from "@renderer/utils/toast";
3334
import { getCloudUrlFromRegion } from "@shared/constants/oauth";
3435
import {
@@ -39,6 +40,7 @@ import {
3940
type Task,
4041
} from "@shared/types";
4142
import { ANALYTICS_EVENTS } from "@shared/types/analytics";
43+
import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud";
4244
import type { AcpMessage, StoredLogEntry } from "@shared/types/session-events";
4345
import { isJsonRpcRequest } from "@shared/types/session-events";
4446
import { buildPermissionToolMetadata, track } from "@utils/analytics";
@@ -1269,6 +1271,35 @@ export class SessionService {
12691271
throw new Error("Authentication required for cloud commands");
12701272
}
12711273

1274+
const [previousRun, task] = await Promise.all([
1275+
client.getTaskRun(session.taskId, session.taskRunId),
1276+
client.getTask(session.taskId),
1277+
]);
1278+
const hasGitHubRepo = !!task.repository && !!task.github_integration;
1279+
const previousState = previousRun.state as Record<string, unknown>;
1280+
const previousOutput = (previousRun.output ?? {}) as Record<
1281+
string,
1282+
unknown
1283+
>;
1284+
// Prefer the actual working branch the agent last pushed to (synced by
1285+
// agent-server after each turn), then the run-level branch field, then
1286+
// the original base branch from state. This preserves unmerged work when
1287+
// the snapshot has expired and the sandbox is rebuilt from scratch.
1288+
const previousBaseBranch =
1289+
(typeof previousOutput.head_branch === "string"
1290+
? previousOutput.head_branch
1291+
: null) ??
1292+
previousRun.branch ??
1293+
(typeof previousState.pr_base_branch === "string"
1294+
? previousState.pr_base_branch
1295+
: null) ??
1296+
session.cloudBranch;
1297+
const prAuthorshipMode = this.getCloudPrAuthorshipMode(previousState);
1298+
const githubUserToken =
1299+
prAuthorshipMode === "user" && hasGitHubRepo
1300+
? await getGhUserTokenOrThrow()
1301+
: undefined;
1302+
12721303
log.info("Creating resume run for terminal cloud task", {
12731304
taskId: session.taskId,
12741305
previousRunId: session.taskRunId,
@@ -1280,11 +1311,21 @@ export class SessionService {
12801311
// The agent will load conversation history and restore the sandbox snapshot.
12811312
const updatedTask = await client.runTaskInCloud(
12821313
session.taskId,
1283-
session.cloudBranch,
1314+
previousBaseBranch,
12841315
{
12851316
resumeFromRunId: session.taskRunId,
12861317
pendingUserMessage: promptText,
12871318
},
1319+
undefined,
1320+
{
1321+
prAuthorshipMode,
1322+
runSource: this.getCloudRunSource(previousState),
1323+
signalReportId:
1324+
typeof previousState.signal_report_id === "string"
1325+
? previousState.signal_report_id
1326+
: undefined,
1327+
githubUserToken,
1328+
},
12881329
);
12891330
const newRun = updatedTask.latest_run;
12901331
if (!newRun?.id) {
@@ -2007,6 +2048,20 @@ export class SessionService {
20072048
}
20082049
}
20092050

2051+
private getCloudPrAuthorshipMode(
2052+
state: Record<string, unknown>,
2053+
): PrAuthorshipMode {
2054+
const explicitMode = state.pr_authorship_mode;
2055+
if (explicitMode === "user" || explicitMode === "bot") {
2056+
return explicitMode;
2057+
}
2058+
return state.run_source === "signal_report" ? "bot" : "user";
2059+
}
2060+
2061+
private getCloudRunSource(state: Record<string, unknown>): CloudRunSource {
2062+
return state.run_source === "signal_report" ? "signal_report" : "manual";
2063+
}
2064+
20102065
/**
20112066
* Filter out session/prompt events that should be skipped during resume.
20122067
* When resuming a cloud run, the initial session/prompt from the new run's

apps/code/src/renderer/sagas/task/task-creation.test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,12 @@ describe("TaskCreationSaga", () => {
137137
"release/remembered-branch",
138138
undefined,
139139
undefined,
140+
{
141+
prAuthorshipMode: "bot",
142+
runSource: "manual",
143+
signalReportId: undefined,
144+
githubUserToken: undefined,
145+
},
140146
);
141147
expect(onTaskReady).toHaveBeenCalledTimes(1);
142148
expect(onTaskReady.mock.calls[0][0].task.latest_run?.branch).toBe(

apps/code/src/renderer/sagas/task/task-creation.ts

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import { trpcClient } from "@renderer/trpc";
1717
import { generateTitleAndSummary } from "@renderer/utils/generateTitle";
1818
import { getTaskRepository } from "@renderer/utils/repository";
1919
import type { ExecutionMode, Task } from "@shared/types";
20+
import type { CloudRunSource, PrAuthorshipMode } from "@shared/types/cloud";
21+
import { getGhUserTokenOrThrow } from "@utils/github";
2022
import { logger } from "@utils/logger";
2123
import { queryClient } from "@utils/queryClient";
2224

@@ -73,6 +75,8 @@ export interface TaskCreationInput {
7375
reasoningLevel?: string;
7476
environmentId?: string;
7577
sandboxEnvironmentId?: string;
78+
cloudPrAuthorshipMode?: PrAuthorshipMode;
79+
cloudRunSource?: CloudRunSource;
7680
signalReportId?: string;
7781
}
7882

@@ -259,13 +263,29 @@ export class TaskCreationSaga extends Saga<
259263
if (shouldStartCloudRun) {
260264
task = await this.step({
261265
name: "cloud_run",
262-
execute: () =>
263-
this.deps.posthogClient.runTaskInCloud(
266+
execute: async () => {
267+
const hasGitHubRepo = !!task.repository && !!task.github_integration;
268+
const prAuthorshipMode =
269+
input.cloudPrAuthorshipMode ?? (hasGitHubRepo ? "user" : "bot");
270+
let githubUserToken: string | undefined;
271+
272+
if (prAuthorshipMode === "user" && hasGitHubRepo) {
273+
githubUserToken = await getGhUserTokenOrThrow();
274+
}
275+
276+
return this.deps.posthogClient.runTaskInCloud(
264277
task.id,
265278
branch,
266279
undefined,
267280
input.sandboxEnvironmentId,
268-
),
281+
{
282+
prAuthorshipMode,
283+
runSource: input.cloudRunSource ?? "manual",
284+
signalReportId: input.signalReportId,
285+
githubUserToken,
286+
},
287+
);
288+
},
269289
rollback: async () => {
270290
log.info("Rolling back: cloud run (no-op)", { taskId: task.id });
271291
},
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import { trpcClient } from "@renderer/trpc";
2+
3+
export async function getGhUserTokenOrThrow(): Promise<string> {
4+
const tokenResult = await trpcClient.git.getGhAuthToken.query();
5+
if (!tokenResult.success || !tokenResult.token) {
6+
throw new Error(
7+
tokenResult.error ||
8+
"Authenticate GitHub CLI with `gh auth login` before starting a cloud task.",
9+
);
10+
}
11+
return tokenResult.token;
12+
}

0 commit comments

Comments
 (0)