Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.update
import kotlinx.collections.immutable.persistentMapOf
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.JsonElement
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.cancellation.CancellationException
import kotlin.math.min
import kotlin.uuid.ExperimentalUuidApi

Expand Down Expand Up @@ -58,55 +61,60 @@ public class Agent(
val clientOperations: ClientSessionOperations,
protocol: Protocol
) : BaseSessionWrapper(agent, protocol) {
private class PromptSession(val currentRequestId: RequestId)
private class PromptSession(val currentRequestId: RequestId, val promptJob: Job)
private val _activePrompt = atomic<PromptSession?>(null)

suspend fun prompt(content: List<ContentBlock>, _meta: JsonElement? = null): PromptResponse {
val currentRpcRequest = currentCoroutineContext().jsonRpcRequest
if (!_activePrompt.compareAndSet(null, PromptSession(currentRpcRequest.id))) error("There is already active prompt execution")
try {
var response: PromptResponse? = null

agentSession.prompt(content, _meta).collect { event ->
when (event) {
is Event.PromptResponseEvent -> {
if (response != null) {
logger.error { "Received repeated prompt response: ${event.response} (previous: $response). The last is used" }
var response: PromptResponse? = null
return coroutineScope {
try {
val promptJob = launch(start = CoroutineStart.LAZY) {
agentSession.prompt(content, _meta).collect { event ->
when (event) {
is Event.PromptResponseEvent -> {
if (response != null) {
logger.error { "Received repeated prompt response: ${event.response} (previous: $response). The last is used" }
}
response = event.response
}

is Event.SessionUpdateEvent -> {
clientOperations.notify(event.update, _meta)
}
}
response = event.response
}
}

is Event.SessionUpdateEvent -> {
clientOperations.notify(event.update, _meta)
}
val promptSession = PromptSession(currentRpcRequest.id, promptJob)
if (!_activePrompt.compareAndSet(null, promptSession)) {
error("There is already active prompt execution")
}
promptJob.join()
response ?: PromptResponse(
stopReason = if (promptJob.isCancelled) StopReason.CANCELLED else StopReason.END_TURN
)
} finally {
_activePrompt.getAndSet(null)
}

return response ?: PromptResponse(StopReason.END_TURN)
}
catch (ce: CancellationException) {
logger.trace(ce) { "Prompt job cancelled" }
return PromptResponse(StopReason.CANCELLED)
}
finally {
_activePrompt.getAndSet(null)
}
}

suspend fun cancel() {
// TODO do we need it while the cancellation can be handled by coroutine mechanism (catching CE inside `prompt`)
// notify AgentSession about upcoming cancellation, this way implementations can gracefully stop ongoing requests
agentSession.cancel()

val activePrompt = _activePrompt.getAndSet(null)
if (activePrompt != null) {
logger.trace { "Cancelling prompt" }
// we expect that all nested outgoing jobs will be cancelled automatically due to structured concurrency
// -> prompt task
// <- [request] read file
// -> [response] read file
// <- [request] permissions
// |suspended|
// cancelling the whole prompt should cancel all nested outgoing requests. These requests on CE will propagate cancellation to the other side
protocol.cancelPendingIncomingRequest(activePrompt.currentRequestId)
activePrompt.promptJob.cancel()
}
}
}
Expand Down
70 changes: 70 additions & 0 deletions acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/AgentTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package com.agentclientprotocol.agent

import com.agentclientprotocol.annotations.UnstableApi
import com.agentclientprotocol.model.*
import kotlinx.coroutines.async
import kotlinx.coroutines.delay
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import kotlin.time.Duration.Companion.seconds

@OptIn(UnstableApi::class)
class AgentTest {

@Test
fun `initialize agent`() {
withTestAgent { testAgent ->
val (response) = testAgent.testInitialize(InitializeRequest(LATEST_PROTOCOL_VERSION))
assertNotNull(response)
assertTrue(testAgent.agentSupport.isInitialized)
}
}

@Test
fun `create new session`() {
withInitializedTestAgent { testAgent ->
val (response) = testAgent.testNewSession(NewSessionRequest(cwd = ".", mcpServers = emptyList()))
assertNotNull(response)
assertTrue(response.sessionId in testAgent.agentSupport.createdSessions)
}
}

@Test
fun `simple prompt turn`() {
withTestAgentSession(promptHandler = echoPromptHandler) { testAgent, _ ->
testAgent.simplePrompt("hello").let { (response, updates) ->
assertEquals(StopReason.END_TURN, response.stopReason)
assertEquals(1, updates.size)

val message = updates.filterIsInstance<SessionUpdate.AgentMessageChunk>()
.map { (it.content as? ContentBlock.Text)?.text }
.firstOrNull()
assertEquals("hello", message)
}

testAgent.simplePrompt("world").let { (response, updates) ->
assertEquals(StopReason.END_TURN, response.stopReason)
assertEquals(1, updates.size)

val message = updates.filterIsInstance<SessionUpdate.AgentMessageChunk>()
.map { (it.content as? ContentBlock.Text)?.text }
.firstOrNull()
assertEquals("world", message)
}
}
}

@Test
fun `prompt cancellation`() {
withTestAgentSession(promptHandler = delayEchoPromptHandler(2.seconds)) { testAgent, session ->
val deferredResponse = async { testAgent.simplePrompt("hello").first }
delay(1.seconds)
testAgent.testCancel(CancelNotification(session.sessionId))

val response = deferredResponse.await()
assertEquals(StopReason.CANCELLED, response.stopReason)
}
}
}
179 changes: 179 additions & 0 deletions acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestAgent.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
@file:OptIn(UnstableApi::class)

package com.agentclientprotocol.agent

import com.agentclientprotocol.annotations.UnstableApi
import com.agentclientprotocol.client.ClientInfo
import com.agentclientprotocol.common.Event
import com.agentclientprotocol.common.SessionCreationParameters
import com.agentclientprotocol.model.AcpMethod
import com.agentclientprotocol.model.AcpNotification
import com.agentclientprotocol.model.AcpRequest
import com.agentclientprotocol.model.AcpResponse
import com.agentclientprotocol.model.CancelNotification
import com.agentclientprotocol.model.ContentBlock
import com.agentclientprotocol.model.InitializeRequest
import com.agentclientprotocol.model.LATEST_PROTOCOL_VERSION
import com.agentclientprotocol.model.McpServer
import com.agentclientprotocol.model.NewSessionRequest
import com.agentclientprotocol.model.PromptRequest
import com.agentclientprotocol.model.PromptResponse
import com.agentclientprotocol.model.SessionId
import com.agentclientprotocol.model.SessionUpdate
import com.agentclientprotocol.model.StopReason
import com.agentclientprotocol.protocol.Protocol
import com.agentclientprotocol.rpc.ACPJson
import com.agentclientprotocol.rpc.JsonRpcNotification
import com.agentclientprotocol.rpc.JsonRpcResponse
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.json.JsonElement
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

class TestAgent(val agent: Agent, val agentSupport: TestAgentSupport, val transport: TestTransport) {
suspend fun <TRequest : AcpRequest, TResponse : AcpResponse> testRequest(
method: AcpMethod.AcpRequestResponseMethod<TRequest, TResponse>,
request: TRequest
): Pair<TResponse?, List<JsonRpcNotification>> {
val received = transport.fireTestRequest(
methodName = method.methodName,
params = ACPJson.encodeToJsonElement(method.requestSerializer, request)
)
val response = (received.lastOrNull() as? JsonRpcResponse)?.result?.let {
ACPJson.decodeFromJsonElement(method.responseSerializer, it)
}
val notifications = received.filterIsInstance<JsonRpcNotification>()
return response to notifications
}

fun <TNotification : AcpNotification> testNotification(
method: AcpMethod.AcpNotificationMethod<TNotification>,
notification: TNotification
) {
transport.fireTestNotification(method.methodName, ACPJson.encodeToJsonElement(method.serializer, notification))
}

fun close() {
agent.protocol.close()
}

suspend fun testInitialize(request: InitializeRequest) = testRequest(AcpMethod.AgentMethods.Initialize, request)
suspend fun testNewSession(request: NewSessionRequest) = testRequest(AcpMethod.AgentMethods.SessionNew, request)
suspend fun testPrompt(request: PromptRequest) = testRequest(AcpMethod.AgentMethods.SessionPrompt, request)

fun testCancel(notification: CancelNotification) = testNotification(AcpMethod.AgentMethods.SessionCancel, notification)
}

suspend fun TestAgent.simplePrompt(prompt: String): Pair<PromptResponse, List<SessionUpdate>> {
val session = agentSupport.createdSessions.values.single()
val (resp, notifications) = testPrompt(PromptRequest(session.sessionId, listOf(ContentBlock.Text(prompt))))
checkNotNull(resp)

return resp to notifications
.filter { it.method == AcpMethod.ClientMethods.SessionUpdate.methodName }
.mapNotNull { it.params }
.map { ACPJson.decodeFromJsonElement(AcpMethod.ClientMethods.SessionUpdate.serializer, it).update }
}

class TestAgentSupport(val promptHandler: PromptHandler) : AgentSupport {
var isInitialized = false
val createdSessions = mutableMapOf<SessionId, TestAgentSession>()

override suspend fun initialize(clientInfo: ClientInfo): AgentInfo {
isInitialized = true
return AgentInfo()
}

override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession {
val sessionId = SessionId("test-agent-session-${sessionId.incrementAndGet()}")
val session = TestAgentSession(sessionId, promptHandler)
createdSessions[sessionId] = session
return session
}

companion object {
private val sessionId = atomic(0)
}
}

typealias PromptHandler = suspend FlowCollector<Event>.(List<ContentBlock>) -> Unit

class TestAgentSession(
override val sessionId: SessionId,
val promptHandler: PromptHandler
) : AgentSession {
override suspend fun prompt(content: List<ContentBlock>, _meta: JsonElement?): Flow<Event> = flow {
promptHandler(content)
}
}

fun withTestAgent(
timeout: Duration = 5.seconds,
promptHandler: PromptHandler = echoPromptHandler,
block: suspend CoroutineScope.(TestAgent) -> Unit
) = runBlocking {
val transport = TestTransport(timeout)
val protocol = Protocol(this, transport)
val agentSupport = TestAgentSupport(promptHandler)
val agent = Agent(protocol, agentSupport)
protocol.start()

// wait a little after protocol start, if messages get sent right away they can get lost
delay(100.milliseconds)

val testAgent = TestAgent(agent, agentSupport, transport)
block(testAgent)
testAgent.close()
}

fun withInitializedTestAgent(
timeout: Duration = 5.seconds,
promptHandler: PromptHandler = echoPromptHandler,
block: suspend CoroutineScope.(TestAgent) -> Unit
) = withTestAgent(
timeout = timeout,
promptHandler = promptHandler,
) { testAgent ->
testAgent.testInitialize(InitializeRequest(LATEST_PROTOCOL_VERSION))
check(testAgent.agentSupport.isInitialized)
block(testAgent)
}

fun withTestAgentSession(
timeout: Duration = 5.seconds,
promptHandler: PromptHandler = echoPromptHandler,
cwd: String = ".",
mcpServers: List<McpServer> = emptyList(),
block: suspend CoroutineScope.(TestAgent, TestAgentSession) -> Unit
) = withInitializedTestAgent(
timeout = timeout,
promptHandler = promptHandler,
) { testAgent ->
val (newSessionResponse) = testAgent.testNewSession(NewSessionRequest(cwd, mcpServers))
checkNotNull(newSessionResponse)
val session = testAgent.agentSupport.createdSessions[newSessionResponse.sessionId]
checkNotNull(session)
block(testAgent, session)
}

val echoPromptHandler: PromptHandler = { prompt ->
prompt.filterIsInstance<ContentBlock.Text>().forEach {
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(it)))
}
emit(Event.PromptResponseEvent(PromptResponse(StopReason.END_TURN)))
}

fun delayEchoPromptHandler(delay: Duration): PromptHandler = { prompt ->
delay(delay)
prompt.filterIsInstance<ContentBlock.Text>().forEach {
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(it)))
}
emit(Event.PromptResponseEvent(PromptResponse(StopReason.END_TURN)))
}
Loading
Loading