diff --git a/apps/code/src/main/services/handoff/schemas.ts b/apps/code/src/main/services/handoff/schemas.ts index 97ea85adc..e6d76c26d 100644 --- a/apps/code/src/main/services/handoff/schemas.ts +++ b/apps/code/src/main/services/handoff/schemas.ts @@ -28,6 +28,22 @@ export const handoffPreflightResult = z.object({ reason: z.string().optional(), localTreeDirty: z.boolean(), localGitState: handoffLocalGitStateSchema.optional(), + changedFiles: z + .array( + z.object({ + path: z.string(), + status: z.enum([ + "modified", + "added", + "deleted", + "renamed", + "untracked", + ]), + linesAdded: z.number().optional(), + linesRemoved: z.number().optional(), + }), + ) + .optional(), }); export type HandoffPreflightResult = z.infer; diff --git a/apps/code/src/main/services/handoff/service.ts b/apps/code/src/main/services/handoff/service.ts index 93b96e190..ed15efaac 100644 --- a/apps/code/src/main/services/handoff/service.ts +++ b/apps/code/src/main/services/handoff/service.ts @@ -12,6 +12,8 @@ import { type GitHandoffBranchDivergence, readHandoffLocalGitState, } from "@posthog/git/handoff"; +import { ResetToDefaultBranchSaga } from "@posthog/git/sagas/branch"; +import { StashPushSaga } from "@posthog/git/sagas/stash"; import { app, dialog, net } from "electron"; import { inject, injectable } from "inversify"; import type { IWorkspaceRepository } from "../../db/repositories/workspace-repository"; @@ -63,9 +65,16 @@ export class HandoffService extends TypedEventEmitter { let localTreeDirty = false; let localGitState: AgentTypes.HandoffLocalGitState | undefined; + let changedFileDetails: HandoffPreflightResult["changedFiles"]; try { const changedFiles = await this.gitService.getChangedFilesHead(repoPath); localTreeDirty = changedFiles.length > 0; + changedFileDetails = changedFiles.map((f) => ({ + path: f.path, + status: f.status, + linesAdded: f.linesAdded, + linesRemoved: f.linesRemoved, + })); localGitState = await this.getLocalGitState(repoPath); } catch (err) { log.warn("Failed to check local working tree", { repoPath, err }); @@ -76,7 +85,13 @@ export class HandoffService extends TypedEventEmitter { ? "Local working tree has uncommitted changes. Commit or stash them first." : undefined; - return { canHandoff, reason, localTreeDirty, localGitState }; + return { + canHandoff, + reason, + localTreeDirty, + localGitState, + changedFiles: changedFileDetails, + }; } async execute(input: HandoffExecuteInput): Promise { @@ -368,12 +383,62 @@ export class HandoffService extends TypedEventEmitter { }; } + await this.cleanupLocalAfterCloudHandoff( + repoPath, + input.localGitState?.branch ?? null, + ); + return { success: true, logEntryCount: result.data.flushedLogEntryCount, }; } + private async cleanupLocalAfterCloudHandoff( + repoPath: string, + branchName: string | null, + ): Promise { + try { + const hasChanges = + (await this.gitService.getChangedFilesHead(repoPath)).length > 0; + + if (hasChanges) { + const label = branchName ?? "unknown"; + const stashSaga = new StashPushSaga(); + const stashResult = await stashSaga.run({ + baseDir: repoPath, + message: `posthog-code: handoff backup (${label})`, + }); + if (!stashResult.success) { + log.warn("Failed to stash changes during cloud handoff cleanup", { + error: stashResult.error, + }); + return; + } + } + + const resetSaga = new ResetToDefaultBranchSaga(); + const resetResult = await resetSaga.run({ baseDir: repoPath }); + if (!resetResult.success) { + log.warn( + "Failed to reset to default branch during cloud handoff cleanup", + { + error: resetResult.error, + }, + ); + return; + } + + log.info("Local cleanup after cloud handoff complete", { + repoPath, + switched: resetResult.data.switched, + defaultBranch: resetResult.data.defaultBranch, + }); + } catch (err) { + log.warn("Post-handoff local cleanup failed", { repoPath, err }); + } + } + private createApiClient(apiHost: string, teamId: number): PostHogAPIClient { const config = this.agentAuthAdapter.createPosthogConfig({ apiHost, diff --git a/apps/code/src/renderer/components/HeaderRow.tsx b/apps/code/src/renderer/components/HeaderRow.tsx index 8a27cde0a..a0096e0c0 100644 --- a/apps/code/src/renderer/components/HeaderRow.tsx +++ b/apps/code/src/renderer/components/HeaderRow.tsx @@ -2,8 +2,10 @@ import { useAuthStateValue } from "@features/auth/hooks/authQueries"; import { DiffStatsBadge } from "@features/code-review/components/DiffStatsBadge"; import { CloudGitInteractionHeader } from "@features/git-interaction/components/CloudGitInteractionHeader"; import { GitInteractionHeader } from "@features/git-interaction/components/GitInteractionHeader"; +import { HandoffConfirmDialog } from "@features/sessions/components/HandoffConfirmDialog"; import { useSessionForTask } from "@features/sessions/hooks/useSession"; import { useSessionCallbacks } from "@features/sessions/hooks/useSessionCallbacks"; +import { useHandoffDialogStore } from "@features/sessions/stores/handoffDialogStore"; import { SidebarTrigger } from "@features/sidebar/components/SidebarTrigger"; import { useSidebarStore } from "@features/sidebar/stores/sidebarStore"; import { useWorkspace } from "@features/workspace/hooks/useWorkspace"; @@ -12,38 +14,79 @@ import type { Task } from "@shared/types"; import { useHeaderStore } from "@stores/headerStore"; import { useNavigationStore } from "@stores/navigationStore"; import { isWindows } from "@utils/platform"; +import { useState } from "react"; function LocalHandoffButton({ taskId, task }: { taskId: string; task: Task }) { const session = useSessionForTask(taskId); const workspace = useWorkspace(taskId); const repoPath = workspace?.folderPath ?? null; const authStatus = useAuthStateValue((s) => s.status); - const { handleContinueInCloud } = useSessionCallbacks({ + const { initiateHandoffToCloud } = useSessionCallbacks({ taskId, task, session: session ?? undefined, repoPath, }); + const confirmOpen = useHandoffDialogStore((s) => s.confirmOpen); + const direction = useHandoffDialogStore((s) => s.direction); + const branchName = useHandoffDialogStore((s) => s.branchName); + const openConfirm = useHandoffDialogStore((s) => s.openConfirm); + const closeConfirm = useHandoffDialogStore((s) => s.closeConfirm); + + const [isSubmitting, setIsSubmitting] = useState(false); + const [error, setError] = useState(null); + if (authStatus !== "authenticated") return null; + const handleConfirm = async () => { + setError(null); + setIsSubmitting(true); + try { + await initiateHandoffToCloud(); + } catch (err) { + setError(err instanceof Error ? err.message : "Handoff failed"); + } finally { + setIsSubmitting(false); + } + }; + return ( - + <> + + {confirmOpen && direction === "to-cloud" && ( + { + if (!open) { + closeConfirm(); + setError(null); + } + }} + direction="to-cloud" + branchName={branchName} + onConfirm={handleConfirm} + isSubmitting={isSubmitting} + error={error} + /> + )} + ); } export const HEADER_HEIGHT = 36; const COLLAPSED_WIDTH = 110; -/** Width reserved for Windows title bar buttons (Close/Minimize/Maximize) */ const WINDOWS_TITLEBAR_INSET = 140; export function HeaderRow() { diff --git a/apps/code/src/renderer/features/git-interaction/components/CloudGitInteractionHeader.tsx b/apps/code/src/renderer/features/git-interaction/components/CloudGitInteractionHeader.tsx index 714524d29..d9477f3e7 100644 --- a/apps/code/src/renderer/features/git-interaction/components/CloudGitInteractionHeader.tsx +++ b/apps/code/src/renderer/features/git-interaction/components/CloudGitInteractionHeader.tsx @@ -1,14 +1,25 @@ +import { + GitBranchDialog, + GitCommitDialog, +} from "@features/git-interaction/components/GitInteractionDialogs"; +import { useGitInteraction } from "@features/git-interaction/hooks/useGitInteraction"; import { usePrActions } from "@features/git-interaction/hooks/usePrActions"; import { usePrDetails } from "@features/git-interaction/hooks/usePrDetails"; +import { useGitInteractionStore } from "@features/git-interaction/state/gitInteractionStore"; +import { getSuggestedBranchName } from "@features/git-interaction/utils/getSuggestedBranchName"; import { getPrVisualConfig, parsePrNumber, } from "@features/git-interaction/utils/prStatus"; +import { DirtyTreeDialog } from "@features/sessions/components/DirtyTreeDialog"; +import { HandoffConfirmDialog } from "@features/sessions/components/HandoffConfirmDialog"; import { useSessionForTask } from "@features/sessions/hooks/useSession"; -import { useSessionCallbacks } from "@features/sessions/hooks/useSessionCallbacks"; +import { getLocalHandoffService } from "@features/sessions/service/localHandoffService"; +import { useHandoffDialogStore } from "@features/sessions/stores/handoffDialogStore"; import { ChevronDownIcon } from "@radix-ui/react-icons"; import { Button, DropdownMenu, Flex, Spinner, Text } from "@radix-ui/themes"; import type { Task } from "@shared/types"; +import { useState } from "react"; interface CloudGitInteractionHeaderProps { taskId: string; @@ -25,12 +36,59 @@ export function CloudGitInteractionHeader({ meta: { state, merged, draft }, } = usePrDetails(prUrl); const { execute, isPending } = usePrActions(prUrl); - const { handleContinueLocally } = useSessionCallbacks({ - taskId, - task, - session: session ?? undefined, - repoPath: null, - }); + const localHandoff = getLocalHandoffService(); + + const confirmOpen = useHandoffDialogStore((s) => s.confirmOpen); + const direction = useHandoffDialogStore((s) => s.direction); + const branchName = useHandoffDialogStore((s) => s.branchName); + const dirtyTreeOpen = useHandoffDialogStore((s) => s.dirtyTreeOpen); + const changedFiles = useHandoffDialogStore((s) => s.changedFiles); + const closeConfirm = useHandoffDialogStore((s) => s.closeConfirm); + const pendingAfterCommit = useHandoffDialogStore((s) => s.pendingAfterCommit); + + const commitRepoPath = pendingAfterCommit?.repoPath; + const git = useGitInteraction(taskId, commitRepoPath); + + const [isPreflighting, setIsPreflighting] = useState(false); + const [preflightError, setPreflightError] = useState(null); + + const handleConfirm = async () => { + setPreflightError(null); + setIsPreflighting(true); + try { + await localHandoff.start(taskId, task); + } catch (err) { + setPreflightError( + err instanceof Error ? err.message : "Preflight failed", + ); + } finally { + setIsPreflighting(false); + } + }; + + const handleCommitAndContinue = async () => { + localHandoff.hideDirtyTree(); + if (git.state.isFeatureBranch) { + useGitInteractionStore.getState().actions.openCommit("commit"); + return; + } + + useGitInteractionStore + .getState() + .actions.openBranch(getSuggestedBranchName(taskId, commitRepoPath)); + }; + + const handleBranchConfirm = async () => { + const branchCreated = await git.actions.runBranch(); + if (!branchCreated) return; + useGitInteractionStore.getState().actions.openCommit("commit"); + }; + + const handleCommitConfirm = async () => { + const committed = await git.actions.runCommit(); + if (!committed) return; + await localHandoff.resumePending(); + }; const config = prUrl && state !== null ? getPrVisualConfig(state, merged, draft) : null; @@ -43,7 +101,9 @@ export function CloudGitInteractionHeader({ size="1" variant="soft" disabled={session?.handoffInProgress} - onClick={handleContinueLocally} + onClick={() => + localHandoff.openConfirm(taskId, session?.cloudBranch ?? null) + } > {session?.handoffInProgress ? "Transferring..." : "Continue locally"} @@ -105,6 +165,78 @@ export function CloudGitInteractionHeader({ )} )} + {confirmOpen && direction === "to-local" && ( + { + if (!open) { + closeConfirm(); + setPreflightError(null); + } + }} + direction="to-local" + branchName={branchName} + onConfirm={handleConfirm} + isSubmitting={isPreflighting} + error={preflightError} + /> + )} + {dirtyTreeOpen && ( + { + if (!open) localHandoff.cancelPendingFlow(); + }} + changedFiles={changedFiles} + onCommitAndContinue={handleCommitAndContinue} + /> + )} + {pendingAfterCommit && ( + { + if (!open) { + git.actions.closeCommit(); + localHandoff.cancelPendingFlow(); + } + }} + branchName={git.state.currentBranch ?? pendingAfterCommit.branchName} + diffStats={git.state.diffStats} + commitMessage={git.modals.commitMessage} + onCommitMessageChange={git.actions.setCommitMessage} + nextStep={git.modals.commitNextStep} + onNextStepChange={git.actions.setCommitNextStep} + pushDisabledReason={git.state.pushDisabledReason} + onContinue={handleCommitConfirm} + isSubmitting={git.modals.isSubmitting} + error={git.modals.commitError} + onGenerateMessage={git.actions.generateCommitMessage} + isGeneratingMessage={git.modals.isGeneratingCommitMessage} + showCommitAllToggle={ + git.state.stagedFiles.length > 0 && + git.state.unstagedFiles.length > 0 + } + commitAll={git.modals.commitAll} + onCommitAllChange={git.actions.setCommitAll} + stagedFileCount={git.state.stagedFiles.length} + /> + )} + {pendingAfterCommit && ( + { + if (!open) { + git.actions.closeBranch(); + localHandoff.cancelPendingFlow(); + } + }} + branchName={git.modals.branchName} + onBranchNameChange={git.actions.setBranchName} + onConfirm={handleBranchConfirm} + isSubmitting={git.modals.isSubmitting} + error={git.modals.branchError} + /> + )} ); } diff --git a/apps/code/src/renderer/features/git-interaction/components/GitInteractionDialogs.tsx b/apps/code/src/renderer/features/git-interaction/components/GitInteractionDialogs.tsx index 987d99a64..bd7057eea 100644 --- a/apps/code/src/renderer/features/git-interaction/components/GitInteractionDialogs.tsx +++ b/apps/code/src/renderer/features/git-interaction/components/GitInteractionDialogs.tsx @@ -168,7 +168,7 @@ interface GitDialogProps { hideCancel?: boolean; } -function GitDialog({ +export function GitDialog({ open, onOpenChange, icon, diff --git a/apps/code/src/renderer/features/git-interaction/hooks/useGitInteraction.ts b/apps/code/src/renderer/features/git-interaction/hooks/useGitInteraction.ts index 10b84473c..ac30eb2ef 100644 --- a/apps/code/src/renderer/features/git-interaction/hooks/useGitInteraction.ts +++ b/apps/code/src/renderer/features/git-interaction/hooks/useGitInteraction.ts @@ -48,6 +48,7 @@ interface GitInteractionState { behind: number; currentBranch: string | null; defaultBranch: string | null; + isFeatureBranch: boolean; prBaseBranch: string | null; prHeadBranch: string | null; diffStats: DiffStats; @@ -69,9 +70,9 @@ interface GitInteractionActions { setPrTitle: (value: string) => void; setPrBody: (value: string) => void; setBranchName: (value: string) => void; - runCommit: () => Promise; + runCommit: () => Promise; runPush: () => Promise; - runBranch: () => Promise; + runBranch: () => Promise; runCreatePr: () => Promise; generateCommitMessage: () => Promise; generatePrTitleAndBody: () => Promise; @@ -338,12 +339,12 @@ export function useGitInteraction( } }; - const runCommit = async () => { - if (!repoPath) return; + const runCommit = async (): Promise => { + if (!repoPath) return false; if (store.commitNextStep === "commit-push" && computed.pushDisabledReason) { modal.setCommitError(computed.pushDisabledReason); - return; + return false; } modal.setIsSubmitting(true); @@ -363,7 +364,7 @@ export function useGitInteraction( "No changes detected to generate a commit message.", ); modal.setIsSubmitting(false); - return; + return false; } message = generated.message; @@ -376,7 +377,7 @@ export function useGitInteraction( : "Failed to generate commit message.", ); modal.setIsSubmitting(false); - return; + return false; } } @@ -394,7 +395,7 @@ export function useGitInteraction( if (!result.success) { trackGitAction(taskId, "commit", false, commitStagingContext); modal.setCommitError(result.message || "Commit failed."); - return; + return false; } trackGitAction(taskId, "commit", true, commitStagingContext); @@ -409,6 +410,7 @@ export function useGitInteraction( if (store.commitNextStep === "commit-push") { modal.openPush(git.hasRemote ? "push" : "publish"); } + return true; } finally { modal.setIsSubmitting(false); } @@ -516,8 +518,8 @@ export function useGitInteraction( } }; - const runBranch = async () => { - if (!repoPath) return; + const runBranch = async (): Promise => { + if (!repoPath) return false; modal.setIsSubmitting(true); modal.setBranchError(null); @@ -534,7 +536,7 @@ export function useGitInteraction( } modal.setBranchError(result.error); - return; + return false; } trackGitAction(taskId, "branch-here", true); @@ -547,12 +549,14 @@ export function useGitInteraction( ); modal.closeBranch(); + return true; } catch (error) { log.error("Failed to create branch", error); trackGitAction(taskId, "branch-here", false); modal.setBranchError( error instanceof Error ? error.message : "Failed to create branch.", ); + return false; } finally { modal.setIsSubmitting(false); } @@ -567,6 +571,7 @@ export function useGitInteraction( behind: git.behind, currentBranch: git.currentBranch, defaultBranch: git.defaultBranch, + isFeatureBranch: git.isFeatureBranch, prBaseBranch: computed.prBaseBranch, prHeadBranch: computed.prHeadBranch, diffStats: git.diffStats, diff --git a/apps/code/src/renderer/features/sessions/components/DirtyTreeDialog.tsx b/apps/code/src/renderer/features/sessions/components/DirtyTreeDialog.tsx new file mode 100644 index 000000000..2a2b6824e --- /dev/null +++ b/apps/code/src/renderer/features/sessions/components/DirtyTreeDialog.tsx @@ -0,0 +1,108 @@ +import { FileIcon } from "@components/ui/FileIcon"; +import { GitDialog } from "@features/git-interaction/components/GitInteractionDialogs"; +import { + getStatusIndicator, + type StatusIndicator, +} from "@features/git-interaction/utils/gitStatusUtils"; +import { Warning } from "@phosphor-icons/react"; +import { Badge, Box, Flex, Text } from "@radix-ui/themes"; +import type { HandoffChangedFile } from "../stores/handoffDialogStore"; + +interface DirtyTreeDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + changedFiles: HandoffChangedFile[]; + onCommitAndContinue: () => void; +} + +function FileLineStats({ file }: { file: HandoffChangedFile }) { + const hasStats = + file.linesAdded !== undefined || file.linesRemoved !== undefined; + if (!hasStats) return null; + + return ( + + {(file.linesAdded ?? 0) > 0 && ( + +{file.linesAdded} + )} + {(file.linesRemoved ?? 0) > 0 && ( + -{file.linesRemoved} + )} + + ); +} + +function StatusBadge({ indicator }: { indicator: StatusIndicator }) { + return ( + + {indicator.label} + + ); +} + +export function DirtyTreeDialog({ + open, + onOpenChange, + changedFiles, + onCommitAndContinue, +}: DirtyTreeDialogProps) { + return ( + } + title="Uncommitted changes" + error={null} + buttonLabel="Commit and continue" + isSubmitting={false} + onSubmit={onCommitAndContinue} + > + + + The following local files have uncommitted changes that would be + overwritten by the handoff. Commit them to continue. + + + {changedFiles.map((file) => { + const fileName = file.path.split("/").pop() || file.path; + const indicator = getStatusIndicator(file.status); + return ( + + + + {fileName} + + + + + ); + })} + + + + ); +} diff --git a/apps/code/src/renderer/features/sessions/components/HandoffConfirmDialog.tsx b/apps/code/src/renderer/features/sessions/components/HandoffConfirmDialog.tsx new file mode 100644 index 000000000..35f4f3485 --- /dev/null +++ b/apps/code/src/renderer/features/sessions/components/HandoffConfirmDialog.tsx @@ -0,0 +1,54 @@ +import { GitDialog } from "@features/git-interaction/components/GitInteractionDialogs"; +import { ArrowLineDown, Cloud } from "@phosphor-icons/react"; +import { Code, Text } from "@radix-ui/themes"; + +interface HandoffConfirmDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + direction: "to-local" | "to-cloud"; + branchName: string | null; + onConfirm: () => void; + isSubmitting: boolean; + error: string | null; +} + +export function HandoffConfirmDialog({ + open, + onOpenChange, + direction, + branchName, + onConfirm, + isSubmitting, + error, +}: HandoffConfirmDialogProps) { + const isToLocal = direction === "to-local"; + + return ( + : } + title={isToLocal ? "Continue locally" : "Continue in cloud"} + error={error} + buttonLabel="Continue" + isSubmitting={isSubmitting} + onSubmit={onConfirm} + > + + {isToLocal ? ( + <> + This will bring your changes from the cloud run into your local + environment on branch{" "} + {branchName ?? "unknown"}. + + ) : ( + <> + This will send your changes on branch{" "} + {branchName ?? "unknown"} to the cloud and + continue running there. + + )} + + + ); +} diff --git a/apps/code/src/renderer/features/sessions/hooks/useSessionCallbacks.ts b/apps/code/src/renderer/features/sessions/hooks/useSessionCallbacks.ts index f61ad7286..7c168787f 100644 --- a/apps/code/src/renderer/features/sessions/hooks/useSessionCallbacks.ts +++ b/apps/code/src/renderer/features/sessions/hooks/useSessionCallbacks.ts @@ -13,42 +13,6 @@ import { sessionStoreSetters } from "../stores/sessionStore"; const log = logger.scope("session-callbacks"); -async function resolveRepoPathFromRemote( - remoteUrl: string | undefined | null, -): Promise { - if (!remoteUrl) return null; - const repo = await trpcClient.folders.getRepositoryByRemoteUrl.query({ - remoteUrl, - }); - return repo?.path ?? null; -} - -async function resolveRepoPathFromPicker( - taskId: string, -): Promise { - const selectedPath = await trpcClient.os.selectDirectory.query(); - if (!selectedPath) return null; - - let folder = (await trpcClient.folders.getFolders.query()).find( - (f) => f.path === selectedPath, - ); - if (!folder) { - folder = await trpcClient.folders.addFolder.mutate({ - folderPath: selectedPath, - }); - } - - await trpcClient.workspace.create.mutate({ - taskId, - mainRepoPath: selectedPath, - folderId: folder.id, - folderPath: selectedPath, - mode: "local", - }); - - return selectedPath; -} - interface UseSessionCallbacksOptions { taskId: string; task: Task; @@ -188,23 +152,7 @@ export function useSessionCallbacks({ [taskId, repoPath], ); - const handleContinueLocally = useCallback(async () => { - try { - const targetPath = - (await resolveRepoPathFromRemote(task.repository)) ?? - (await resolveRepoPathFromPicker(taskId)); - - if (!targetPath) return; - - await getSessionService().handoffToLocal(taskId, targetPath); - } catch (error) { - log.error("Failed to hand off to local", error); - const message = error instanceof Error ? error.message : "Unknown error"; - toast.error(`Failed to continue locally: ${message}`); - } - }, [taskId, task.repository]); - - const handleContinueInCloud = useCallback(async () => { + const initiateHandoffToCloud = useCallback(async () => { if (!repoPath) return; try { await getSessionService().handoffToCloud(taskId, repoPath); @@ -221,7 +169,6 @@ export function useSessionCallbacks({ handleRetry, handleNewSession, handleBashCommand, - handleContinueLocally, - handleContinueInCloud, + initiateHandoffToCloud, }; } diff --git a/apps/code/src/renderer/features/sessions/hooks/useSessionViewState.ts b/apps/code/src/renderer/features/sessions/hooks/useSessionViewState.ts index 4b2551599..8b36d6954 100644 --- a/apps/code/src/renderer/features/sessions/hooks/useSessionViewState.ts +++ b/apps/code/src/renderer/features/sessions/hooks/useSessionViewState.ts @@ -17,7 +17,16 @@ export function useSessionViewState(taskId: string, task: Task) { const isCloudRunTerminal = isCloud && !isCloudRunNotTerminal; const hasError = session?.status === "error"; - const isRunning = isCloud ? !hasError : session?.status === "connected"; + const handoffInProgress = session?.handoffInProgress ?? false; + + let isRunning = false; + if (!handoffInProgress) { + if (isCloud) { + isRunning = !hasError; + } else { + isRunning = session?.status === "connected"; + } + } const events = session?.events ?? []; const isPromptPending = session?.isPromptPending ?? false; diff --git a/apps/code/src/renderer/features/sessions/service/localHandoffService.ts b/apps/code/src/renderer/features/sessions/service/localHandoffService.ts new file mode 100644 index 000000000..581c3b32e --- /dev/null +++ b/apps/code/src/renderer/features/sessions/service/localHandoffService.ts @@ -0,0 +1,135 @@ +import { trpcClient } from "@renderer/trpc/client"; +import type { Task } from "@shared/types"; +import { logger } from "@utils/logger"; +import { toast } from "@utils/toast"; +import { useHandoffDialogStore } from "../stores/handoffDialogStore"; +import { getSessionService } from "./service"; + +const log = logger.scope("local-handoff-service"); + +async function resolveRepoPathFromRemote( + remoteUrl: string | undefined | null, +): Promise { + if (!remoteUrl) return null; + const repo = await trpcClient.folders.getRepositoryByRemoteUrl.query({ + remoteUrl, + }); + return repo?.path ?? null; +} + +async function resolveRepoPathFromPicker( + taskId: string, +): Promise { + const selectedPath = await trpcClient.os.selectDirectory.query(); + if (!selectedPath) return null; + + let folder = (await trpcClient.folders.getFolders.query()).find( + (f) => f.path === selectedPath, + ); + if (!folder) { + folder = await trpcClient.folders.addFolder.mutate({ + folderPath: selectedPath, + }); + } + + await trpcClient.workspace.create.mutate({ + taskId, + mainRepoPath: selectedPath, + folderId: folder.id, + folderPath: selectedPath, + mode: "local", + }); + + return selectedPath; +} + +let serviceInstance: LocalHandoffService | null = null; + +export function getLocalHandoffService(): LocalHandoffService { + if (!serviceInstance) { + serviceInstance = new LocalHandoffService(); + } + return serviceInstance; +} + +export class LocalHandoffService { + public openConfirm(taskId: string, branchName: string | null): void { + useHandoffDialogStore + .getState() + .openConfirm(taskId, "to-local", branchName); + } + + public closeConfirm(): void { + useHandoffDialogStore.getState().closeConfirm(); + } + + public cancelPendingFlow(): void { + useHandoffDialogStore.getState().cancelPendingHandoff(); + } + + public hideDirtyTree(): void { + useHandoffDialogStore.getState().hideDirtyTree(); + } + + public getPendingAfterCommit() { + return useHandoffDialogStore.getState().pendingAfterCommit; + } + + public async start(taskId: string, task: Task): Promise { + try { + const targetPath = + (await resolveRepoPathFromRemote(task.repository)) ?? + (await resolveRepoPathFromPicker(taskId)); + + if (!targetPath) return; + + const preflight = await getSessionService().preflightToLocal( + taskId, + targetPath, + ); + + if (preflight.canHandoff) { + this.closeConfirm(); + await getSessionService().handoffToLocal(taskId, targetPath); + return; + } + + if (preflight.localTreeDirty && preflight.changedFiles) { + useHandoffDialogStore + .getState() + .openDirtyTreeForPendingHandoff(preflight.changedFiles, { + taskId, + repoPath: targetPath, + branchName: preflight.localGitState?.branch ?? null, + }); + return; + } + + toast.error(preflight.reason ?? "Cannot continue locally"); + this.closeConfirm(); + } catch (error) { + log.error("Failed to hand off to local", error); + const message = error instanceof Error ? error.message : "Unknown error"; + toast.error(`Failed to continue locally: ${message}`); + this.closeConfirm(); + } + } + + public async resumePending(): Promise { + const pending = this.getPendingAfterCommit(); + if (!pending) return; + + useHandoffDialogStore.getState().clearPendingAfterCommit(); + + try { + await getSessionService().handoffToLocal( + pending.taskId, + pending.repoPath, + ); + } catch (error) { + log.error("Failed to resume handoff to local", error); + const message = error instanceof Error ? error.message : "Unknown error"; + toast.error(`Failed to continue locally: ${message}`); + } + } +} diff --git a/apps/code/src/renderer/features/sessions/service/service.ts b/apps/code/src/renderer/features/sessions/service/service.ts index ec4fe0048..a486f339b 100644 --- a/apps/code/src/renderer/features/sessions/service/service.ts +++ b/apps/code/src/renderer/features/sessions/service/service.ts @@ -2114,6 +2114,40 @@ export class SessionService { ); } + async preflightToLocal(taskId: string, repoPath: string) { + const session = sessionStoreSetters.getSessionByTaskId(taskId); + if (!session) + return { + canHandoff: false as const, + localTreeDirty: false as const, + reason: "No session found", + }; + + const auth = await this.getHandoffAuth(); + if (!auth) + return { + canHandoff: false as const, + localTreeDirty: false as const, + reason: "Authentication required", + }; + + const preflight = await trpcClient.handoff.preflight.query({ + taskId, + runId: session.taskRunId, + repoPath, + apiHost: auth.apiHost, + teamId: auth.projectId, + }); + + return { + canHandoff: preflight.canHandoff, + localTreeDirty: preflight.localTreeDirty, + localGitState: preflight.localGitState, + changedFiles: preflight.changedFiles, + reason: preflight.reason, + }; + } + async handoffToLocal(taskId: string, repoPath: string): Promise { const session = sessionStoreSetters.getSessionByTaskId(taskId); if (!session) { @@ -2145,8 +2179,11 @@ export class SessionService { ); this.transitionToLocalSession(runId); this.subscribeToChannel(runId); - queryClient.invalidateQueries({ queryKey: ["tasks"] }); - queryClient.invalidateQueries(trpc.workspace.getAll.pathFilter()); + await Promise.all([ + queryClient.refetchQueries({ queryKey: ["tasks"] }), + queryClient.refetchQueries(trpc.workspace.getAll.pathFilter()), + ]); + sessionStoreSetters.updateSession(runId, { handoffInProgress: false }); log.info("Cloud-to-local handoff complete", { taskId, runId }); } catch (err) { log.error("Handoff failed", { taskId, err }); @@ -2209,14 +2246,16 @@ export class SessionService { cloudOutput: undefined, cloudErrorMessage: undefined, cloudBranch: undefined, - handoffInProgress: false, status: "disconnected", processedLineCount: result.logEntryCount ?? 0, }); this.watchCloudTask(taskId, runId, auth.apiHost, auth.projectId); - queryClient.invalidateQueries({ queryKey: ["tasks"] }); - queryClient.invalidateQueries(trpc.workspace.getAll.pathFilter()); + await Promise.all([ + queryClient.refetchQueries({ queryKey: ["tasks"] }), + queryClient.refetchQueries(trpc.workspace.getAll.pathFilter()), + ]); + sessionStoreSetters.updateSession(runId, { handoffInProgress: false }); log.info("Local-to-cloud handoff complete", { taskId, runId }); } catch (err) { log.error("Handoff to cloud failed", { taskId, err }); @@ -2305,7 +2344,6 @@ export class SessionService { cloudOutput: undefined, cloudErrorMessage: undefined, cloudBranch: undefined, - handoffInProgress: false, status: "connected", }); } diff --git a/apps/code/src/renderer/features/sessions/stores/handoffDialogStore.ts b/apps/code/src/renderer/features/sessions/stores/handoffDialogStore.ts new file mode 100644 index 000000000..85de78529 --- /dev/null +++ b/apps/code/src/renderer/features/sessions/stores/handoffDialogStore.ts @@ -0,0 +1,85 @@ +import type { GitFileStatus } from "@shared/types"; +import { create } from "zustand"; + +type HandoffDirection = "to-local" | "to-cloud"; + +export interface HandoffChangedFile { + path: string; + status: GitFileStatus; + linesAdded?: number; + linesRemoved?: number; +} + +interface HandoffDialogState { + confirmOpen: boolean; + direction: HandoffDirection | null; + taskId: string | null; + branchName: string | null; + dirtyTreeOpen: boolean; + changedFiles: HandoffChangedFile[]; + pendingAfterCommit: { + taskId: string; + repoPath: string; + branchName: string | null; + } | null; +} + +interface HandoffDialogActions { + openConfirm: ( + taskId: string, + direction: HandoffDirection, + branchName: string | null, + ) => void; + closeConfirm: () => void; + openDirtyTreeForPendingHandoff: ( + changedFiles: HandoffChangedFile[], + pending: { + taskId: string; + repoPath: string; + branchName: string | null; + }, + ) => void; + hideDirtyTree: () => void; + cancelPendingHandoff: () => void; + clearPendingAfterCommit: () => void; + reset: () => void; +} + +type HandoffDialogStore = HandoffDialogState & HandoffDialogActions; + +const initialState: HandoffDialogState = { + confirmOpen: false, + direction: null, + taskId: null, + branchName: null, + dirtyTreeOpen: false, + changedFiles: [], + pendingAfterCommit: null, +}; + +const closedDirtyTreeState = { + dirtyTreeOpen: false, + changedFiles: [], +} satisfies Pick; + +export const useHandoffDialogStore = create((set) => ({ + ...initialState, + openConfirm: (taskId, direction, branchName) => + set({ confirmOpen: true, taskId, direction, branchName }), + closeConfirm: () => set({ confirmOpen: false }), + openDirtyTreeForPendingHandoff: (changedFiles, pending) => + set({ + confirmOpen: false, + dirtyTreeOpen: true, + changedFiles, + pendingAfterCommit: pending, + }), + hideDirtyTree: () => set(closedDirtyTreeState), + cancelPendingHandoff: () => + set({ + ...closedDirtyTreeState, + pendingAfterCommit: null, + }), + clearPendingAfterCommit: () => set({ pendingAfterCommit: null }), + reset: () => set(initialState), +})); diff --git a/packages/agent/src/handoff-checkpoint.test.ts b/packages/agent/src/handoff-checkpoint.test.ts index 78749ff07..3f1587665 100644 --- a/packages/agent/src/handoff-checkpoint.test.ts +++ b/packages/agent/src/handoff-checkpoint.test.ts @@ -179,5 +179,6 @@ describe("HandoffCheckpointTracker", () => { expect(status).toContain("M tracked.txt"); expect(status).toContain(" M unstaged.txt"); expect(status).toContain("?? untracked.txt"); + expect(localRepo.exists(".posthog/tmp")).toBe(false); }); }); diff --git a/packages/agent/src/handoff-checkpoint.ts b/packages/agent/src/handoff-checkpoint.ts index 818535a33..35c9bcdd1 100644 --- a/packages/agent/src/handoff-checkpoint.ts +++ b/packages/agent/src/handoff-checkpoint.ts @@ -1,4 +1,11 @@ -import { mkdir, readFile, rm, writeFile } from "node:fs/promises"; +import { + mkdir, + readdir, + readFile, + rm, + rmdir, + writeFile, +} from "node:fs/promises"; import { join } from "node:path"; import { type GitHandoffBranchDivergence, @@ -161,6 +168,7 @@ export class HandoffCheckpointTracker { } finally { await this.removeIfPresent(packPath); await this.removeIfPresent(indexPath); + await this.removeTmpDirIfEmpty(tmpDir); } } @@ -361,4 +369,12 @@ export class HandoffCheckpointTracker { } await rm(filePath, { force: true }).catch(() => {}); } + + private async removeTmpDirIfEmpty(tmpDir: string): Promise { + const entries = await readdir(tmpDir).catch(() => null); + if (!entries || entries.length > 0) { + return; + } + await rmdir(tmpDir).catch(() => {}); + } } diff --git a/packages/agent/src/sagas/apply-snapshot-saga.test.ts b/packages/agent/src/sagas/apply-snapshot-saga.test.ts index 89f3970dd..b9ef1e50f 100644 --- a/packages/agent/src/sagas/apply-snapshot-saga.test.ts +++ b/packages/agent/src/sagas/apply-snapshot-saga.test.ts @@ -328,6 +328,7 @@ describe("ApplySnapshotSaga", () => { }); expect(repo.exists(".posthog/tmp/test-tree-hash.tar.gz")).toBe(false); + expect(repo.exists(".posthog/tmp")).toBe(false); }); it("cleans up downloaded archive on checkout failure (rollback verification)", async () => { diff --git a/packages/agent/src/sagas/apply-snapshot-saga.ts b/packages/agent/src/sagas/apply-snapshot-saga.ts index 01a942c30..ab0e554de 100644 --- a/packages/agent/src/sagas/apply-snapshot-saga.ts +++ b/packages/agent/src/sagas/apply-snapshot-saga.ts @@ -1,4 +1,4 @@ -import { mkdir, rm, writeFile } from "node:fs/promises"; +import { mkdir, readdir, rm, rmdir, writeFile } from "node:fs/promises"; import { join } from "node:path"; import { ApplyTreeSaga as GitApplyTreeSaga } from "@posthog/git/sagas/tree"; import { Saga } from "@posthog/shared"; @@ -37,64 +37,78 @@ export class ApplySnapshotSaga extends Saga< const archiveUrl = snapshot.archiveUrl; - await this.step({ - name: "create_tmp_dir", - execute: () => mkdir(tmpDir, { recursive: true }), - rollback: async () => {}, - }); + try { + await this.step({ + name: "create_tmp_dir", + execute: () => mkdir(tmpDir, { recursive: true }), + rollback: async () => {}, + }); - const archivePath = join(tmpDir, `${snapshot.treeHash}.tar.gz`); - this.archivePath = archivePath; - await this.step({ - name: "download_archive", - execute: async () => { - const arrayBuffer = await apiClient.downloadArtifact( - taskId, - runId, - archiveUrl, - ); - if (!arrayBuffer) { - throw new Error("Failed to download archive"); - } - const base64Content = Buffer.from(arrayBuffer).toString("utf-8"); - const binaryContent = Buffer.from(base64Content, "base64"); - await writeFile(archivePath, binaryContent); - this.log.info("Tree archive downloaded", { - treeHash: snapshot.treeHash, - snapshotBytes: binaryContent.byteLength, - snapshotWireBytes: arrayBuffer.byteLength, - totalBytes: binaryContent.byteLength, - totalWireBytes: arrayBuffer.byteLength, - }); - }, - rollback: async () => { - if (this.archivePath) { - await rm(this.archivePath, { force: true }).catch(() => {}); - } - }, - }); + const archivePath = join(tmpDir, `${snapshot.treeHash}.tar.gz`); + this.archivePath = archivePath; + await this.step({ + name: "download_archive", + execute: async () => { + const arrayBuffer = await apiClient.downloadArtifact( + taskId, + runId, + archiveUrl, + ); + if (!arrayBuffer) { + throw new Error("Failed to download archive"); + } + const base64Content = Buffer.from(arrayBuffer).toString("utf-8"); + const binaryContent = Buffer.from(base64Content, "base64"); + await writeFile(archivePath, binaryContent); + this.log.info("Tree archive downloaded", { + treeHash: snapshot.treeHash, + snapshotBytes: binaryContent.byteLength, + snapshotWireBytes: arrayBuffer.byteLength, + totalBytes: binaryContent.byteLength, + totalWireBytes: arrayBuffer.byteLength, + }); + }, + rollback: async () => { + if (this.archivePath) { + await rm(this.archivePath, { force: true }).catch(() => {}); + } + }, + }); - const gitApplySaga = new GitApplyTreeSaga(this.log); - const applyResult = await gitApplySaga.run({ - baseDir: repositoryPath, - treeHash: snapshot.treeHash, - baseCommit: snapshot.baseCommit, - changes: snapshot.changes, - archivePath: this.archivePath, - }); + const gitApplySaga = new GitApplyTreeSaga(this.log); + const applyResult = await gitApplySaga.run({ + baseDir: repositoryPath, + treeHash: snapshot.treeHash, + baseCommit: snapshot.baseCommit, + changes: snapshot.changes, + archivePath: this.archivePath, + }); - if (!applyResult.success) { - throw new Error(`Failed to apply tree: ${applyResult.error}`); - } + if (!applyResult.success) { + throw new Error(`Failed to apply tree: ${applyResult.error}`); + } - await rm(this.archivePath, { force: true }).catch(() => {}); + this.log.info("Tree snapshot applied", { + treeHash: snapshot.treeHash, + totalChanges: snapshot.changes.length, + deletedFiles: snapshot.changes.filter((c) => c.status === "D").length, + }); - this.log.info("Tree snapshot applied", { - treeHash: snapshot.treeHash, - totalChanges: snapshot.changes.length, - deletedFiles: snapshot.changes.filter((c) => c.status === "D").length, - }); + return { treeHash: snapshot.treeHash }; + } finally { + if (this.archivePath) { + await rm(this.archivePath, { force: true }).catch(() => {}); + } + await this.removeTmpDirIfEmpty(tmpDir); + this.archivePath = null; + } + } - return { treeHash: snapshot.treeHash }; + private async removeTmpDirIfEmpty(tmpDir: string): Promise { + const entries = await readdir(tmpDir).catch(() => null); + if (!entries || entries.length > 0) { + return; + } + await rmdir(tmpDir).catch(() => {}); } } diff --git a/packages/agent/src/sagas/capture-tree-saga.test.ts b/packages/agent/src/sagas/capture-tree-saga.test.ts index 275d0e7bc..39d60e46e 100644 --- a/packages/agent/src/sagas/capture-tree-saga.test.ts +++ b/packages/agent/src/sagas/capture-tree-saga.test.ts @@ -366,6 +366,24 @@ describe("CaptureTreeSaga", () => { const indexFiles = files.filter((f: string) => f.startsWith("index-")); expect(indexFiles).toHaveLength(0); }); + + it("cleans up uploaded tree archive and tmp dir on success", async () => { + const mockApiClient = createMockApiClient(); + + await repo.writeFile("new.ts", "content"); + + const saga = new CaptureTreeSaga(mockLogger); + const result = await saga.run({ + repositoryPath: repo.path, + taskId: "task-1", + runId: "run-1", + lastTreeHash: null, + apiClient: mockApiClient, + }); + + expect(result.success).toBe(true); + expect(repo.exists(".posthog/tmp")).toBe(false); + }); }); describe("git state isolation", () => { diff --git a/packages/agent/src/sagas/capture-tree-saga.ts b/packages/agent/src/sagas/capture-tree-saga.ts index 851082637..e0b0980a1 100644 --- a/packages/agent/src/sagas/capture-tree-saga.ts +++ b/packages/agent/src/sagas/capture-tree-saga.ts @@ -1,5 +1,5 @@ import { existsSync } from "node:fs"; -import { readFile, rm } from "node:fs/promises"; +import { readdir, readFile, rm, rmdir } from "node:fs/promises"; import { join } from "node:path"; import { CaptureTreeSaga as GitCaptureTreeSaga } from "@posthog/git/sagas/tree"; import { Saga } from "@posthog/shared"; @@ -45,60 +45,67 @@ export class CaptureTreeSaga extends Saga { ? join(tmpDir, `tree-${Date.now()}.tar.gz`) : undefined; - const gitCaptureSaga = new GitCaptureTreeSaga(this.log); - const captureResult = await gitCaptureSaga.run({ - baseDir: repositoryPath, - lastTreeHash, - archivePath, - }); + try { + const gitCaptureSaga = new GitCaptureTreeSaga(this.log); + const captureResult = await gitCaptureSaga.run({ + baseDir: repositoryPath, + lastTreeHash, + archivePath, + }); - if (!captureResult.success) { - throw new Error(`Failed to capture tree: ${captureResult.error}`); - } + if (!captureResult.success) { + throw new Error(`Failed to capture tree: ${captureResult.error}`); + } - const { - snapshot: gitSnapshot, - archivePath: createdArchivePath, - changed, - } = captureResult.data; - - if (!changed || !gitSnapshot) { - this.log.debug("No changes since last capture", { lastTreeHash }); - return { snapshot: null, newTreeHash: lastTreeHash }; - } + const { + snapshot: gitSnapshot, + archivePath: createdArchivePath, + changed, + } = captureResult.data; - let archiveUrl: string | undefined; - if (apiClient && createdArchivePath) { - try { - archiveUrl = await this.uploadArchive( - createdArchivePath, - gitSnapshot.treeHash, - apiClient, - taskId, - runId, - ); - } finally { - await rm(createdArchivePath, { force: true }).catch(() => {}); + if (!changed || !gitSnapshot) { + this.log.debug("No changes since last capture", { lastTreeHash }); + return { snapshot: null, newTreeHash: lastTreeHash }; } - } - const snapshot: TreeSnapshot = { - treeHash: gitSnapshot.treeHash, - baseCommit: gitSnapshot.baseCommit, - changes: gitSnapshot.changes, - timestamp: gitSnapshot.timestamp, - interrupted, - archiveUrl, - }; - - this.log.info("Tree captured", { - treeHash: snapshot.treeHash, - changes: snapshot.changes.length, - interrupted, - archiveUrl, - }); + let archiveUrl: string | undefined; + if (apiClient && createdArchivePath) { + try { + archiveUrl = await this.uploadArchive( + createdArchivePath, + gitSnapshot.treeHash, + apiClient, + taskId, + runId, + ); + } finally { + await rm(createdArchivePath, { force: true }).catch(() => {}); + } + } - return { snapshot, newTreeHash: snapshot.treeHash }; + const snapshot: TreeSnapshot = { + treeHash: gitSnapshot.treeHash, + baseCommit: gitSnapshot.baseCommit, + changes: gitSnapshot.changes, + timestamp: gitSnapshot.timestamp, + interrupted, + archiveUrl, + }; + + this.log.info("Tree captured", { + treeHash: snapshot.treeHash, + changes: snapshot.changes.length, + interrupted, + archiveUrl, + }); + + return { snapshot, newTreeHash: snapshot.treeHash }; + } finally { + if (archivePath) { + await rm(archivePath, { force: true }).catch(() => {}); + } + await this.removeTmpDirIfEmpty(tmpDir); + } } private async uploadArchive( @@ -147,4 +154,12 @@ export class CaptureTreeSaga extends Saga { return archiveUrl; } + + private async removeTmpDirIfEmpty(tmpDir: string): Promise { + const entries = await readdir(tmpDir).catch(() => null); + if (!entries || entries.length > 0) { + return; + } + await rmdir(tmpDir).catch(() => {}); + } }