Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .swift-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"version": 1,
"indentation" : {
"spaces" : 4
},
"lineBreakBeforeEachArgument": true
}
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ let package = Package(
.library(
name: "GraphQLTransportWS",
targets: ["GraphQLTransportWS"]
),
)
],
dependencies: [
.package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"),
Expand Down
16 changes: 12 additions & 4 deletions Sources/GraphQLTransportWS/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ public actor Client<InitPayload: Equatable & Codable> {
/// - onComplete: The callback run on receipt of a `complete` message
public init(
messenger: Messenger,
onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in },
onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in
},
onNext: @escaping (NextResponse, Client) async throws -> Void = { _, _ in },
onError: @escaping (ErrorResponse, Client) async throws -> Void = { _, _ in },
onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in }
Expand All @@ -37,7 +38,8 @@ public actor Client<InitPayload: Equatable & Codable> {

/// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed.
/// - Parameter incoming: The server message sequence that the client should react to.
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws -> Void where A.Element == String {
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws
where A.Element == String {
for try await message in incoming {
// Detect and ignore error responses.
if message.starts(with: "44") {
Expand All @@ -60,7 +62,12 @@ public actor Client<InitPayload: Equatable & Codable> {

switch response.type {
case .connectionAck:
guard let connectionAckResponse = try? decoder.decode(ConnectionAckResponse.self, from: json) else {
guard
let connectionAckResponse = try? decoder.decode(
ConnectionAckResponse.self,
from: json
)
else {
try await error(.invalidResponseFormat(messageType: .connectionAck))
return
}
Expand All @@ -78,7 +85,8 @@ public actor Client<InitPayload: Equatable & Codable> {
}
try await onError(errorResponse, self)
case .complete:
guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) else {
guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json)
else {
try await error(.invalidResponseFormat(messageType: .complete))
return
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/GraphQLTransportWS/Messenger.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import Foundation
public protocol Messenger: Sendable {
/// Send a message through this messenger
/// - Parameter message: The message to send
func send<S: Sendable & Collection>(_ message: S) async throws -> Void where S.Element == Character
func send<S: Sendable & Collection>(_ message: S) async throws where S.Element == Character

/// Close the messenger
func close() async throws
Expand Down
30 changes: 18 additions & 12 deletions Sources/GraphQLTransportWS/Requests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ public struct ConnectionInitRequest<InitPayload: Codable & Equatable>: Equatable
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(RequestMessageType.self, forKey: .type) != .connectionInit {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.connectionInit.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.connectionInit.type)`"
)
)
}
payload = try container.decode(InitPayload.self, forKey: .payload)
}
Expand All @@ -41,10 +43,12 @@ public struct SubscribeRequest: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(RequestMessageType.self, forKey: .type) != .subscribe {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.subscribe.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.subscribe.type)`"
)
)
}
payload = try container.decode(GraphQLRequest.self, forKey: .payload)
id = try container.decode(String.self, forKey: .id)
Expand All @@ -63,10 +67,12 @@ public struct CompleteRequest: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(RequestMessageType.self, forKey: .type) != .complete {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.complete.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.complete.type)`"
)
)
}
id = try container.decode(String.self, forKey: .id)
}
Expand Down
40 changes: 24 additions & 16 deletions Sources/GraphQLTransportWS/Responses.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ public struct ConnectionAckResponse: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .connectionAck {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.connectionAck.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.connectionAck.type)`"
)
)
}
payload = try container.decodeIfPresent([String: Map].self, forKey: .payload)
}
Expand All @@ -41,10 +43,12 @@ public struct NextResponse: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .next {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.next.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.next.type)`"
)
)
}
payload = try container.decodeIfPresent(GraphQLResult.self, forKey: .payload)
id = try container.decode(String.self, forKey: .id)
Expand All @@ -63,10 +67,12 @@ public struct CompleteResponse: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .complete {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.complete.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.complete.type)`"
)
)
}
id = try container.decode(String.self, forKey: .id)
}
Expand Down Expand Up @@ -94,10 +100,12 @@ public struct ErrorResponse: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .error {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.error.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.error.type)`"
)
)
}
payload = try container.decode([GraphQLError].self, forKey: .payload)
id = try container.decode(String.self, forKey: .id)
Expand Down
27 changes: 20 additions & 7 deletions Sources/GraphQLTransportWS/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ public actor Server<
InitPayload: Equatable & Codable & Sendable,
InitPayloadResult: Sendable,
SubscriptionSequenceType: AsyncSequence & Sendable
> where
>
where
SubscriptionSequenceType.Element == GraphQLResult
{
let messenger: Messenger
Expand Down Expand Up @@ -38,7 +39,8 @@ public actor Server<
messenger: Messenger,
onInit: @escaping (InitPayload) async throws -> InitPayloadResult,
onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult,
onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType,
onSubscribe:
@escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType,
onOperationComplete: @escaping (String) async throws -> Void = { _ in },
onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in }
) {
Expand All @@ -52,7 +54,8 @@ public actor Server<

/// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed.
/// - Parameter incoming: The client message sequence that the server should react to.
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws -> Void where A.Element == String {
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws
where A.Element == String {
for try await message in incoming {
// Detect and ignore error responses.
if message.starts(with: "44") {
Expand All @@ -76,19 +79,26 @@ public actor Server<
// handle incoming message
switch request.type {
case .connectionInit:
guard let connectionInitRequest = try? decoder.decode(ConnectionInitRequest<InitPayload>.self, from: json) else {
guard
let connectionInitRequest = try? decoder.decode(
ConnectionInitRequest<InitPayload>.self,
from: json
)
else {
try await error(.invalidRequestFormat(messageType: .connectionInit))
return
}
try await onConnectionInit(connectionInitRequest, messenger)
case .subscribe:
guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: json) else {
guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: json)
else {
try await error(.invalidRequestFormat(messageType: .subscribe))
return
}
try await onSubscribe(subscribeRequest)
case .complete:
guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: json) else {
guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: json)
else {
try await error(.invalidRequestFormat(messageType: .complete))
return
}
Expand All @@ -103,7 +113,10 @@ public actor Server<
subscriptionTasks.values.forEach { $0.cancel() }
}

private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest<InitPayload>, _: Messenger) async throws {
private func onConnectionInit(
_ connectionInitRequest: ConnectionInitRequest<InitPayload>,
_: Messenger
) async throws {
guard !initialized else {
try await error(.tooManyInitializations())
return
Expand Down
47 changes: 24 additions & 23 deletions Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ struct GraphqlTransportWSTests {
).get()
}
)
let (messageStream, messageContinuation) = AsyncThrowingStream<String, any Error>.makeStream()
let (messageStream, messageContinuation) = AsyncThrowingStream<String, any Error>
.makeStream()
let serverMessageStream = serverMessenger.stream.map { message in
messageContinuation.yield(message)
// Expect only one message
Expand All @@ -51,10 +52,10 @@ struct GraphqlTransportWSTests {
try await client.sendStart(
payload: GraphQLRequest(
query: """
query {
hello
}
"""
query {
hello
}
"""
),
id: UUID().uuidString
)
Expand All @@ -64,8 +65,7 @@ struct GraphqlTransportWSTests {
result.append(message)
}
#expect(
messages ==
["\(ErrorCode.notInitialized): Connection not initialized"]
messages == ["\(ErrorCode.notInitialized): Connection not initialized"]
)
}

Expand All @@ -91,7 +91,8 @@ struct GraphqlTransportWSTests {
).get()
}
)
let (messageStream, messageContinuation) = AsyncThrowingStream<String, any Error>.makeStream()
let (messageStream, messageContinuation) = AsyncThrowingStream<String, any Error>
.makeStream()
let serverMessageStream = serverMessenger.stream.map { message in
messageContinuation.yield(message)
// Expect only one message
Expand Down Expand Up @@ -122,8 +123,7 @@ struct GraphqlTransportWSTests {
result.append(message)
}
#expect(
messages ==
["\(ErrorCode.unauthorized): Unauthorized"]
messages == ["\(ErrorCode.unauthorized): Unauthorized"]
)
}

Expand All @@ -149,7 +149,8 @@ struct GraphqlTransportWSTests {
).get()
}
)
let (messageStream, messageContinuation) = AsyncThrowingStream<String, any Error>.makeStream()
let (messageStream, messageContinuation) = AsyncThrowingStream<String, any Error>
.makeStream()
let serverMessageStream = serverMessenger.stream.map { message in
messageContinuation.yield(message)
return message
Expand All @@ -160,10 +161,10 @@ struct GraphqlTransportWSTests {
try await client.sendStart(
payload: GraphQLRequest(
query: """
query {
hello
}
"""
query {
hello
}
"""
),
id: id
)
Expand Down Expand Up @@ -194,8 +195,7 @@ struct GraphqlTransportWSTests {
result.append(message)
}
#expect(
messages.count ==
3 // 1 connection_ack, 1 next, 1 complete
messages.count == 3 // 1 connection_ack, 1 next, 1 complete
)
}

Expand Down Expand Up @@ -226,7 +226,8 @@ struct GraphqlTransportWSTests {
return subscription
}
)
let (messageStream, messageContinuation) = AsyncThrowingStream<String, any Error>.makeStream()
let (messageStream, messageContinuation) = AsyncThrowingStream<String, any Error>
.makeStream()
// Used to extract the server messages
let serverMessageStream = serverMessenger.stream.map { message in
messageContinuation.yield(message)
Expand All @@ -238,10 +239,10 @@ struct GraphqlTransportWSTests {
try await client.sendStart(
payload: GraphQLRequest(
query: """
subscription {
hello
}
"""
subscription {
hello
}
"""
),
id: id
)
Expand Down Expand Up @@ -285,7 +286,7 @@ struct GraphqlTransportWSTests {
result.append(message)
}
#expect(
messages.count == 5 // 1 connection_ack, 3 next, 1 complete
messages.count == 5 // 1 connection_ack, 3 next, 1 complete
)
}

Expand Down
Loading
Loading