diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt index 9f45cd0..0afa075 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt @@ -173,9 +173,10 @@ public class Protocol( } // Start processing incoming messages + val messageChannel = transport.asMessageChannel() scope.launch(CoroutineName("${Protocol::class.simpleName!!}.read-messages")) { runCatching { - for (message in transport.asMessageChannel()) { + for (message in messageChannel) { handleIncomingMessage(message) } }.checkCancelled().onFailure { diff --git a/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestAgent.kt b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestAgent.kt index f683774..acb4efe 100644 --- a/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestAgent.kt +++ b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestAgent.kt @@ -34,7 +34,6 @@ import kotlinx.coroutines.flow.flow import kotlinx.coroutines.runBlocking import kotlinx.serialization.json.JsonElement import kotlin.time.Duration -import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds class TestAgent(val agent: Agent, val agentSupport: TestAgentSupport, val transport: TestTransport) { @@ -124,10 +123,6 @@ fun withTestAgent( val agentSupport = TestAgentSupport(promptHandler) val agent = Agent(protocol, agentSupport) protocol.start() - - // wait a little after protocol start, if messages get sent right away they can get lost - delay(100.milliseconds) - val testAgent = TestAgent(agent, agentSupport, transport) block(testAgent) testAgent.close() diff --git a/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestTransport.kt b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestTransport.kt index ba025b4..c12593d 100644 --- a/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestTransport.kt +++ b/acp/src/jvmTest/kotlin/com/agentclientprotocol/agent/TestTransport.kt @@ -10,11 +10,8 @@ import com.agentclientprotocol.transport.BaseTransport import com.agentclientprotocol.transport.Transport import kotlinx.atomicfu.atomic import kotlinx.coroutines.async +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.filterNotNull -import kotlinx.coroutines.flow.toList -import kotlinx.coroutines.flow.transformWhile import kotlinx.coroutines.withTimeoutOrNull import kotlinx.serialization.json.JsonElement import kotlin.time.Duration @@ -22,19 +19,20 @@ import kotlin.time.Duration class TestTransport(val timeout: Duration) : BaseTransport() { private val requestId = atomic(0) - private val responseFlow: MutableStateFlow = MutableStateFlow(null) + private val responses = Channel(capacity = Channel.UNLIMITED) override fun start() { _state.value = Transport.State.STARTED } override fun send(message: JsonRpcMessage) { - responseFlow.value = message + responses.trySend(message) } override fun close() { _state.value = Transport.State.CLOSING fireClose() + responses.close() _state.value = Transport.State.CLOSED } @@ -42,23 +40,19 @@ class TestTransport(val timeout: Duration) : BaseTransport() { val reqId = RequestId.create(requestId.incrementAndGet()) val jsonReq = JsonRpcRequest(reqId, methodName, params) - return try { - coroutineScope { - val responses = async { - withTimeoutOrNull(timeout) { - responseFlow.filterNotNull() - .transformWhile { - emit(it) - it !is JsonRpcResponse - } - .toList() + return coroutineScope { + val responses = async { + withTimeoutOrNull(timeout) { + buildList { + do { + val message = responses.receive() + add(message) + } while (message !is JsonRpcResponse) } } - fireMessage(jsonReq) - responses.await() ?: emptyList() } - } finally { - responseFlow.value = null + fireMessage(jsonReq) + responses.await() ?: emptyList() } }