diff --git a/.swift-format b/.swift-format new file mode 100644 index 0000000..1ffdbf5 --- /dev/null +++ b/.swift-format @@ -0,0 +1,7 @@ +{ + "version": 1, + "indentation" : { + "spaces" : 4 + }, + "lineBreakBeforeEachArgument": true +} diff --git a/Package.swift b/Package.swift index db11fb4..33bdc14 100644 --- a/Package.swift +++ b/Package.swift @@ -9,7 +9,7 @@ let package = Package( .library( name: "GraphQLTransportWS", targets: ["GraphQLTransportWS"] - ), + ) ], dependencies: [ .package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"), diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index bcac773..27d15c6 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -23,7 +23,8 @@ public actor Client { /// - onComplete: The callback run on receipt of a `complete` message public init( messenger: Messenger, - onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in }, + onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in + }, onNext: @escaping (NextResponse, Client) async throws -> Void = { _, _ in }, onError: @escaping (ErrorResponse, Client) async throws -> Void = { _, _ in }, onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in } @@ -37,7 +38,8 @@ public actor Client { /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. /// - Parameter incoming: The server message sequence that the client should react to. - public func listen(to incoming: A) async throws -> Void where A.Element == String { + public func listen(to incoming: A) async throws + where A.Element == String { for try await message in incoming { // Detect and ignore error responses. if message.starts(with: "44") { @@ -60,7 +62,12 @@ public actor Client { switch response.type { case .connectionAck: - guard let connectionAckResponse = try? decoder.decode(ConnectionAckResponse.self, from: json) else { + guard + let connectionAckResponse = try? decoder.decode( + ConnectionAckResponse.self, + from: json + ) + else { try await error(.invalidResponseFormat(messageType: .connectionAck)) return } @@ -78,7 +85,8 @@ public actor Client { } try await onError(errorResponse, self) case .complete: - guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) else { + guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) + else { try await error(.invalidResponseFormat(messageType: .complete)) return } diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index 86ca9d4..e0ba6d9 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -4,7 +4,7 @@ import Foundation public protocol Messenger: Sendable { /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) async throws -> Void where S.Element == Character + func send(_ message: S) async throws where S.Element == Character /// Close the messenger func close() async throws diff --git a/Sources/GraphQLTransportWS/Requests.swift b/Sources/GraphQLTransportWS/Requests.swift index 09665b6..5807190 100644 --- a/Sources/GraphQLTransportWS/Requests.swift +++ b/Sources/GraphQLTransportWS/Requests.swift @@ -18,10 +18,12 @@ public struct ConnectionInitRequest: Equatable public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(RequestMessageType.self, forKey: .type) != .connectionInit { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(RequestMessageType.connectionInit.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(RequestMessageType.connectionInit.type)`" + ) + ) } payload = try container.decode(InitPayload.self, forKey: .payload) } @@ -41,10 +43,12 @@ public struct SubscribeRequest: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(RequestMessageType.self, forKey: .type) != .subscribe { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(RequestMessageType.subscribe.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(RequestMessageType.subscribe.type)`" + ) + ) } payload = try container.decode(GraphQLRequest.self, forKey: .payload) id = try container.decode(String.self, forKey: .id) @@ -63,10 +67,12 @@ public struct CompleteRequest: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(RequestMessageType.self, forKey: .type) != .complete { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(RequestMessageType.complete.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(RequestMessageType.complete.type)`" + ) + ) } id = try container.decode(String.self, forKey: .id) } diff --git a/Sources/GraphQLTransportWS/Responses.swift b/Sources/GraphQLTransportWS/Responses.swift index 71d7d4f..8a2bf3d 100644 --- a/Sources/GraphQLTransportWS/Responses.swift +++ b/Sources/GraphQLTransportWS/Responses.swift @@ -18,10 +18,12 @@ public struct ConnectionAckResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(ResponseMessageType.self, forKey: .type) != .connectionAck { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.connectionAck.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.connectionAck.type)`" + ) + ) } payload = try container.decodeIfPresent([String: Map].self, forKey: .payload) } @@ -41,10 +43,12 @@ public struct NextResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(ResponseMessageType.self, forKey: .type) != .next { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.next.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.next.type)`" + ) + ) } payload = try container.decodeIfPresent(GraphQLResult.self, forKey: .payload) id = try container.decode(String.self, forKey: .id) @@ -63,10 +67,12 @@ public struct CompleteResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(ResponseMessageType.self, forKey: .type) != .complete { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.complete.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.complete.type)`" + ) + ) } id = try container.decode(String.self, forKey: .id) } @@ -94,10 +100,12 @@ public struct ErrorResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(ResponseMessageType.self, forKey: .type) != .error { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.error.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.error.type)`" + ) + ) } payload = try container.decode([GraphQLError].self, forKey: .payload) id = try container.decode(String.self, forKey: .id) diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 10f0bfe..3e79e03 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -8,7 +8,8 @@ public actor Server< InitPayload: Equatable & Codable & Sendable, InitPayloadResult: Sendable, SubscriptionSequenceType: AsyncSequence & Sendable -> where +> +where SubscriptionSequenceType.Element == GraphQLResult { let messenger: Messenger @@ -38,7 +39,8 @@ public actor Server< messenger: Messenger, onInit: @escaping (InitPayload) async throws -> InitPayloadResult, onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult, - onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, + onSubscribe: + @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, onOperationComplete: @escaping (String) async throws -> Void = { _ in }, onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in } ) { @@ -52,7 +54,8 @@ public actor Server< /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. /// - Parameter incoming: The client message sequence that the server should react to. - public func listen(to incoming: A) async throws -> Void where A.Element == String { + public func listen(to incoming: A) async throws + where A.Element == String { for try await message in incoming { // Detect and ignore error responses. if message.starts(with: "44") { @@ -76,19 +79,26 @@ public actor Server< // handle incoming message switch request.type { case .connectionInit: - guard let connectionInitRequest = try? decoder.decode(ConnectionInitRequest.self, from: json) else { + guard + let connectionInitRequest = try? decoder.decode( + ConnectionInitRequest.self, + from: json + ) + else { try await error(.invalidRequestFormat(messageType: .connectionInit)) return } try await onConnectionInit(connectionInitRequest, messenger) case .subscribe: - guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: json) else { + guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: json) + else { try await error(.invalidRequestFormat(messageType: .subscribe)) return } try await onSubscribe(subscribeRequest) case .complete: - guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: json) else { + guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: json) + else { try await error(.invalidRequestFormat(messageType: .complete)) return } @@ -103,7 +113,10 @@ public actor Server< subscriptionTasks.values.forEach { $0.cancel() } } - private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _: Messenger) async throws { + private func onConnectionInit( + _ connectionInitRequest: ConnectionInitRequest, + _: Messenger + ) async throws { guard !initialized else { try await error(.tooManyInitializations()) return diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index 9597059..759c65b 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -28,7 +28,8 @@ struct GraphqlTransportWSTests { ).get() } ) - let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let (messageStream, messageContinuation) = AsyncThrowingStream + .makeStream() let serverMessageStream = serverMessenger.stream.map { message in messageContinuation.yield(message) // Expect only one message @@ -51,10 +52,10 @@ struct GraphqlTransportWSTests { try await client.sendStart( payload: GraphQLRequest( query: """ - query { - hello - } - """ + query { + hello + } + """ ), id: UUID().uuidString ) @@ -64,8 +65,7 @@ struct GraphqlTransportWSTests { result.append(message) } #expect( - messages == - ["\(ErrorCode.notInitialized): Connection not initialized"] + messages == ["\(ErrorCode.notInitialized): Connection not initialized"] ) } @@ -91,7 +91,8 @@ struct GraphqlTransportWSTests { ).get() } ) - let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let (messageStream, messageContinuation) = AsyncThrowingStream + .makeStream() let serverMessageStream = serverMessenger.stream.map { message in messageContinuation.yield(message) // Expect only one message @@ -122,8 +123,7 @@ struct GraphqlTransportWSTests { result.append(message) } #expect( - messages == - ["\(ErrorCode.unauthorized): Unauthorized"] + messages == ["\(ErrorCode.unauthorized): Unauthorized"] ) } @@ -149,7 +149,8 @@ struct GraphqlTransportWSTests { ).get() } ) - let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let (messageStream, messageContinuation) = AsyncThrowingStream + .makeStream() let serverMessageStream = serverMessenger.stream.map { message in messageContinuation.yield(message) return message @@ -160,10 +161,10 @@ struct GraphqlTransportWSTests { try await client.sendStart( payload: GraphQLRequest( query: """ - query { - hello - } - """ + query { + hello + } + """ ), id: id ) @@ -194,8 +195,7 @@ struct GraphqlTransportWSTests { result.append(message) } #expect( - messages.count == - 3 // 1 connection_ack, 1 next, 1 complete + messages.count == 3 // 1 connection_ack, 1 next, 1 complete ) } @@ -226,7 +226,8 @@ struct GraphqlTransportWSTests { return subscription } ) - let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let (messageStream, messageContinuation) = AsyncThrowingStream + .makeStream() // Used to extract the server messages let serverMessageStream = serverMessenger.stream.map { message in messageContinuation.yield(message) @@ -238,10 +239,10 @@ struct GraphqlTransportWSTests { try await client.sendStart( payload: GraphQLRequest( query: """ - subscription { - hello - } - """ + subscription { + hello + } + """ ), id: id ) @@ -285,7 +286,7 @@ struct GraphqlTransportWSTests { result.append(message) } #expect( - messages.count == 5 // 1 connection_ack, 3 next, 1 complete + messages.count == 5 // 1 connection_ack, 3 next, 1 complete ) } diff --git a/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift b/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift index 8867da1..b6e8cca 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift @@ -1,6 +1,6 @@ import Foundation -import Graphiti import GraphQL +import Graphiti struct TestAPI: API { let resolver = TestResolver() @@ -29,7 +29,10 @@ struct TestResolver { context.hello() } - func subscribeHello(context: TestContext, arguments _: NoArguments) -> AsyncThrowingStream { + func subscribeHello( + context: TestContext, + arguments _: NoArguments + ) -> AsyncThrowingStream { context.publisher.subscribe() } } diff --git a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift index 90449c9..a27c1d5 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift @@ -1,5 +1,5 @@ - import Foundation + @testable import GraphQLTransportWS /// Messenger for simple testing that doesn't require starting up a websocket server.