diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt index a700673..3df4878 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt @@ -13,13 +13,16 @@ import kotlinx.atomicfu.atomic import kotlinx.atomicfu.update import kotlinx.collections.immutable.persistentMapOf import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.Job +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import kotlinx.serialization.json.JsonElement import kotlin.coroutines.AbstractCoroutineContextElement import kotlin.coroutines.CoroutineContext -import kotlin.coroutines.cancellation.CancellationException import kotlin.math.min import kotlin.uuid.ExperimentalUuidApi @@ -58,47 +61,52 @@ public class Agent( val clientOperations: ClientSessionOperations, protocol: Protocol ) : BaseSessionWrapper(agent, protocol) { - private class PromptSession(val currentRequestId: RequestId) + private class PromptSession(val currentRequestId: RequestId, val promptJob: Job) private val _activePrompt = atomic(null) suspend fun prompt(content: List, _meta: JsonElement? = null): PromptResponse { val currentRpcRequest = currentCoroutineContext().jsonRpcRequest - if (!_activePrompt.compareAndSet(null, PromptSession(currentRpcRequest.id))) error("There is already active prompt execution") - try { - var response: PromptResponse? = null - - agentSession.prompt(content, _meta).collect { event -> - when (event) { - is Event.PromptResponseEvent -> { - if (response != null) { - logger.error { "Received repeated prompt response: ${event.response} (previous: $response). The last is used" } + var response: PromptResponse? = null + return coroutineScope { + try { + val promptJob = launch(start = CoroutineStart.LAZY) { + agentSession.prompt(content, _meta).collect { event -> + when (event) { + is Event.PromptResponseEvent -> { + if (response != null) { + logger.error { "Received repeated prompt response: ${event.response} (previous: $response). The last is used" } + } + response = event.response + } + + is Event.SessionUpdateEvent -> { + clientOperations.notify(event.update, _meta) + } } - response = event.response } + } - is Event.SessionUpdateEvent -> { - clientOperations.notify(event.update, _meta) - } + val promptSession = PromptSession(currentRpcRequest.id, promptJob) + if (!_activePrompt.compareAndSet(null, promptSession)) { + error("There is already active prompt execution") } + promptJob.join() + response ?: PromptResponse( + stopReason = if (promptJob.isCancelled) StopReason.CANCELLED else StopReason.END_TURN + ) + } finally { + _activePrompt.getAndSet(null) } - - return response ?: PromptResponse(StopReason.END_TURN) - } - catch (ce: CancellationException) { - logger.trace(ce) { "Prompt job cancelled" } - return PromptResponse(StopReason.CANCELLED) - } - finally { - _activePrompt.getAndSet(null) } } suspend fun cancel() { - // TODO do we need it while the cancellation can be handled by coroutine mechanism (catching CE inside `prompt`) + // notify AgentSession about upcoming cancellation, this way implementations can gracefully stop ongoing requests agentSession.cancel() val activePrompt = _activePrompt.getAndSet(null) if (activePrompt != null) { + logger.trace { "Cancelling prompt" } // we expect that all nested outgoing jobs will be cancelled automatically due to structured concurrency // -> prompt task // <- [request] read file @@ -106,7 +114,7 @@ public class Agent( // <- [request] permissions // |suspended| // cancelling the whole prompt should cancel all nested outgoing requests. These requests on CE will propagate cancellation to the other side - protocol.cancelPendingIncomingRequest(activePrompt.currentRequestId) + activePrompt.promptJob.cancel() } } } diff --git a/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/AgentTest.kt b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/AgentTest.kt new file mode 100644 index 0000000..b4334a9 --- /dev/null +++ b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/AgentTest.kt @@ -0,0 +1,70 @@ +package com.agentclientprotocol.agent + +import com.agentclientprotocol.annotations.UnstableApi +import com.agentclientprotocol.model.* +import kotlinx.coroutines.async +import kotlinx.coroutines.delay +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +@OptIn(UnstableApi::class) +class AgentTest { + + @Test + fun `initialize agent`() { + withTestAgent { testAgent -> + val (response) = testAgent.testInitialize(InitializeRequest(LATEST_PROTOCOL_VERSION)) + assertNotNull(response) + assertTrue(testAgent.agentSupport.isInitialized) + } + } + + @Test + fun `create new session`() { + withInitializedTestAgent { testAgent -> + val (response) = testAgent.testNewSession(NewSessionRequest(cwd = ".", mcpServers = emptyList())) + assertNotNull(response) + assertTrue(response.sessionId in testAgent.agentSupport.createdSessions) + } + } + + @Test + fun `simple prompt turn`() { + withTestAgentSession(promptHandler = echoPromptHandler) { testAgent, _ -> + testAgent.simplePrompt("hello").let { (response, updates) -> + assertEquals(StopReason.END_TURN, response.stopReason) + assertEquals(1, updates.size) + + val message = updates.filterIsInstance() + .map { (it.content as? ContentBlock.Text)?.text } + .firstOrNull() + assertEquals("hello", message) + } + + testAgent.simplePrompt("world").let { (response, updates) -> + assertEquals(StopReason.END_TURN, response.stopReason) + assertEquals(1, updates.size) + + val message = updates.filterIsInstance() + .map { (it.content as? ContentBlock.Text)?.text } + .firstOrNull() + assertEquals("world", message) + } + } + } + + @Test + fun `prompt cancellation`() { + withTestAgentSession(promptHandler = delayEchoPromptHandler(2.seconds)) { testAgent, session -> + val deferredResponse = async { testAgent.simplePrompt("hello").first } + delay(1.seconds) + testAgent.testCancel(CancelNotification(session.sessionId)) + + val response = deferredResponse.await() + assertEquals(StopReason.CANCELLED, response.stopReason) + } + } +} diff --git a/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestAgent.kt b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestAgent.kt new file mode 100644 index 0000000..f683774 --- /dev/null +++ b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestAgent.kt @@ -0,0 +1,179 @@ +@file:OptIn(UnstableApi::class) + +package com.agentclientprotocol.agent + +import com.agentclientprotocol.annotations.UnstableApi +import com.agentclientprotocol.client.ClientInfo +import com.agentclientprotocol.common.Event +import com.agentclientprotocol.common.SessionCreationParameters +import com.agentclientprotocol.model.AcpMethod +import com.agentclientprotocol.model.AcpNotification +import com.agentclientprotocol.model.AcpRequest +import com.agentclientprotocol.model.AcpResponse +import com.agentclientprotocol.model.CancelNotification +import com.agentclientprotocol.model.ContentBlock +import com.agentclientprotocol.model.InitializeRequest +import com.agentclientprotocol.model.LATEST_PROTOCOL_VERSION +import com.agentclientprotocol.model.McpServer +import com.agentclientprotocol.model.NewSessionRequest +import com.agentclientprotocol.model.PromptRequest +import com.agentclientprotocol.model.PromptResponse +import com.agentclientprotocol.model.SessionId +import com.agentclientprotocol.model.SessionUpdate +import com.agentclientprotocol.model.StopReason +import com.agentclientprotocol.protocol.Protocol +import com.agentclientprotocol.rpc.ACPJson +import com.agentclientprotocol.rpc.JsonRpcNotification +import com.agentclientprotocol.rpc.JsonRpcResponse +import kotlinx.atomicfu.atomic +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonElement +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds + +class TestAgent(val agent: Agent, val agentSupport: TestAgentSupport, val transport: TestTransport) { + suspend fun testRequest( + method: AcpMethod.AcpRequestResponseMethod, + request: TRequest + ): Pair> { + val received = transport.fireTestRequest( + methodName = method.methodName, + params = ACPJson.encodeToJsonElement(method.requestSerializer, request) + ) + val response = (received.lastOrNull() as? JsonRpcResponse)?.result?.let { + ACPJson.decodeFromJsonElement(method.responseSerializer, it) + } + val notifications = received.filterIsInstance() + return response to notifications + } + + fun testNotification( + method: AcpMethod.AcpNotificationMethod, + notification: TNotification + ) { + transport.fireTestNotification(method.methodName, ACPJson.encodeToJsonElement(method.serializer, notification)) + } + + fun close() { + agent.protocol.close() + } + + suspend fun testInitialize(request: InitializeRequest) = testRequest(AcpMethod.AgentMethods.Initialize, request) + suspend fun testNewSession(request: NewSessionRequest) = testRequest(AcpMethod.AgentMethods.SessionNew, request) + suspend fun testPrompt(request: PromptRequest) = testRequest(AcpMethod.AgentMethods.SessionPrompt, request) + + fun testCancel(notification: CancelNotification) = testNotification(AcpMethod.AgentMethods.SessionCancel, notification) +} + +suspend fun TestAgent.simplePrompt(prompt: String): Pair> { + val session = agentSupport.createdSessions.values.single() + val (resp, notifications) = testPrompt(PromptRequest(session.sessionId, listOf(ContentBlock.Text(prompt)))) + checkNotNull(resp) + + return resp to notifications + .filter { it.method == AcpMethod.ClientMethods.SessionUpdate.methodName } + .mapNotNull { it.params } + .map { ACPJson.decodeFromJsonElement(AcpMethod.ClientMethods.SessionUpdate.serializer, it).update } +} + +class TestAgentSupport(val promptHandler: PromptHandler) : AgentSupport { + var isInitialized = false + val createdSessions = mutableMapOf() + + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + isInitialized = true + return AgentInfo() + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + val sessionId = SessionId("test-agent-session-${sessionId.incrementAndGet()}") + val session = TestAgentSession(sessionId, promptHandler) + createdSessions[sessionId] = session + return session + } + + companion object { + private val sessionId = atomic(0) + } +} + +typealias PromptHandler = suspend FlowCollector.(List) -> Unit + +class TestAgentSession( + override val sessionId: SessionId, + val promptHandler: PromptHandler +) : AgentSession { + override suspend fun prompt(content: List, _meta: JsonElement?): Flow = flow { + promptHandler(content) + } +} + +fun withTestAgent( + timeout: Duration = 5.seconds, + promptHandler: PromptHandler = echoPromptHandler, + block: suspend CoroutineScope.(TestAgent) -> Unit +) = runBlocking { + val transport = TestTransport(timeout) + val protocol = Protocol(this, transport) + val agentSupport = TestAgentSupport(promptHandler) + val agent = Agent(protocol, agentSupport) + protocol.start() + + // wait a little after protocol start, if messages get sent right away they can get lost + delay(100.milliseconds) + + val testAgent = TestAgent(agent, agentSupport, transport) + block(testAgent) + testAgent.close() +} + +fun withInitializedTestAgent( + timeout: Duration = 5.seconds, + promptHandler: PromptHandler = echoPromptHandler, + block: suspend CoroutineScope.(TestAgent) -> Unit +) = withTestAgent( + timeout = timeout, + promptHandler = promptHandler, +) { testAgent -> + testAgent.testInitialize(InitializeRequest(LATEST_PROTOCOL_VERSION)) + check(testAgent.agentSupport.isInitialized) + block(testAgent) +} + +fun withTestAgentSession( + timeout: Duration = 5.seconds, + promptHandler: PromptHandler = echoPromptHandler, + cwd: String = ".", + mcpServers: List = emptyList(), + block: suspend CoroutineScope.(TestAgent, TestAgentSession) -> Unit +) = withInitializedTestAgent( + timeout = timeout, + promptHandler = promptHandler, +) { testAgent -> + val (newSessionResponse) = testAgent.testNewSession(NewSessionRequest(cwd, mcpServers)) + checkNotNull(newSessionResponse) + val session = testAgent.agentSupport.createdSessions[newSessionResponse.sessionId] + checkNotNull(session) + block(testAgent, session) +} + +val echoPromptHandler: PromptHandler = { prompt -> + prompt.filterIsInstance().forEach { + emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(it))) + } + emit(Event.PromptResponseEvent(PromptResponse(StopReason.END_TURN))) +} + +fun delayEchoPromptHandler(delay: Duration): PromptHandler = { prompt -> + delay(delay) + prompt.filterIsInstance().forEach { + emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(it))) + } + emit(Event.PromptResponseEvent(PromptResponse(StopReason.END_TURN))) +} diff --git a/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestTransport.kt b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestTransport.kt new file mode 100644 index 0000000..ba025b4 --- /dev/null +++ b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestTransport.kt @@ -0,0 +1,68 @@ +package com.agentclientprotocol.agent + +import com.agentclientprotocol.rpc.JsonRpcMessage +import com.agentclientprotocol.rpc.JsonRpcNotification +import com.agentclientprotocol.rpc.JsonRpcRequest +import com.agentclientprotocol.rpc.JsonRpcResponse +import com.agentclientprotocol.rpc.MethodName +import com.agentclientprotocol.rpc.RequestId +import com.agentclientprotocol.transport.BaseTransport +import com.agentclientprotocol.transport.Transport +import kotlinx.atomicfu.atomic +import kotlinx.coroutines.async +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.filterNotNull +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.flow.transformWhile +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.serialization.json.JsonElement +import kotlin.time.Duration + +class TestTransport(val timeout: Duration) : BaseTransport() { + private val requestId = atomic(0) + + private val responseFlow: MutableStateFlow = MutableStateFlow(null) + + override fun start() { + _state.value = Transport.State.STARTED + } + + override fun send(message: JsonRpcMessage) { + responseFlow.value = message + } + + override fun close() { + _state.value = Transport.State.CLOSING + fireClose() + _state.value = Transport.State.CLOSED + } + + suspend fun fireTestRequest(methodName: MethodName, params: JsonElement): List { + val reqId = RequestId.create(requestId.incrementAndGet()) + val jsonReq = JsonRpcRequest(reqId, methodName, params) + + return try { + coroutineScope { + val responses = async { + withTimeoutOrNull(timeout) { + responseFlow.filterNotNull() + .transformWhile { + emit(it) + it !is JsonRpcResponse + } + .toList() + } + } + fireMessage(jsonReq) + responses.await() ?: emptyList() + } + } finally { + responseFlow.value = null + } + } + + fun fireTestNotification(methodName: MethodName, params: JsonElement) { + fireMessage(JsonRpcNotification(methodName, params)) + } +} \ No newline at end of file