Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ public class Protocol(
}

// Start processing incoming messages
val messageChannel = transport.asMessageChannel()
scope.launch(CoroutineName("${Protocol::class.simpleName!!}.read-messages")) {
runCatching {
for (message in transport.asMessageChannel()) {
for (message in messageChannel) {
handleIncomingMessage(message)
}
}.checkCancelled().onFailure {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.json.JsonElement
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

class TestAgent(val agent: Agent, val agentSupport: TestAgentSupport, val transport: TestTransport) {
Expand Down Expand Up @@ -124,10 +123,6 @@ fun withTestAgent(
val agentSupport = TestAgentSupport(promptHandler)
val agent = Agent(protocol, agentSupport)
protocol.start()

// wait a little after protocol start, if messages get sent right away they can get lost
delay(100.milliseconds)

val testAgent = TestAgent(agent, agentSupport, transport)
block(testAgent)
testAgent.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,55 +10,49 @@ import com.agentclientprotocol.transport.BaseTransport
import com.agentclientprotocol.transport.Transport
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.flow.transformWhile
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.serialization.json.JsonElement
import kotlin.time.Duration

class TestTransport(val timeout: Duration) : BaseTransport() {
private val requestId = atomic(0)

private val responseFlow: MutableStateFlow<JsonRpcMessage?> = MutableStateFlow(null)
private val responses = Channel<JsonRpcMessage>(capacity = Channel.UNLIMITED)

override fun start() {
_state.value = Transport.State.STARTED
}

override fun send(message: JsonRpcMessage) {
responseFlow.value = message
responses.trySend(message)
}

override fun close() {
_state.value = Transport.State.CLOSING
fireClose()
responses.close()
_state.value = Transport.State.CLOSED
}

suspend fun fireTestRequest(methodName: MethodName, params: JsonElement): List<JsonRpcMessage> {
val reqId = RequestId.create(requestId.incrementAndGet())
val jsonReq = JsonRpcRequest(reqId, methodName, params)

return try {
coroutineScope {
val responses = async {
withTimeoutOrNull(timeout) {
responseFlow.filterNotNull()
.transformWhile {
emit(it)
it !is JsonRpcResponse
}
.toList()
return coroutineScope {
val responses = async {
withTimeoutOrNull(timeout) {
buildList {
do {
val message = responses.receive()
add(message)
} while (message !is JsonRpcResponse)
}
}
fireMessage(jsonReq)
responses.await() ?: emptyList()
}
} finally {
responseFlow.value = null
fireMessage(jsonReq)
responses.await() ?: emptyList()
}
}

Expand Down
Loading