From 32202ead227f099630ef175ce4d45c523a7e1d0e Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Sun, 5 Apr 2026 01:31:34 -0600 Subject: [PATCH 1/2] chore: Add swift-format file --- .swift-format | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .swift-format diff --git a/.swift-format b/.swift-format new file mode 100644 index 0000000..1ffdbf5 --- /dev/null +++ b/.swift-format @@ -0,0 +1,7 @@ +{ + "version": 1, + "indentation" : { + "spaces" : 4 + }, + "lineBreakBeforeEachArgument": true +} From 8f8c77a52f88d176b9fcf90545d1517d46af3164 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Sun, 5 Apr 2026 01:35:34 -0600 Subject: [PATCH 2/2] refactor: Applies formatter --- Package.swift | 2 +- Sources/GraphQLWS/Client.swift | 38 ++++++++--- Sources/GraphQLWS/Messenger.swift | 2 +- Sources/GraphQLWS/Requests.swift | 48 ++++++++----- Sources/GraphQLWS/Responses.swift | 67 ++++++++++++------- Sources/GraphQLWS/Server.swift | 32 +++++++-- Tests/GraphQLWSTests/GraphQLWSTests.swift | 46 +++++++------ Tests/GraphQLWSTests/Utils/TestAPI.swift | 7 +- .../GraphQLWSTests/Utils/TestMessenger.swift | 2 +- 9 files changed, 159 insertions(+), 85 deletions(-) diff --git a/Package.swift b/Package.swift index c8c29db..a7b8a92 100644 --- a/Package.swift +++ b/Package.swift @@ -9,7 +9,7 @@ let package = Package( .library( name: "GraphQLWS", targets: ["GraphQLWS"] - ), + ) ], dependencies: [ .package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"), diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index a7f993f..82656e4 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -27,9 +27,14 @@ public actor Client { /// - onComplete: The callback run on receipt of a `complete` message public init( messenger: Messenger, - onConnectionError: @escaping (ConnectionErrorResponse, Client) async throws -> Void = { _, _ in }, - onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in }, - onConnectionKeepAlive: @escaping (ConnectionKeepAliveResponse, Client) async throws -> Void = { _, _ in }, + onConnectionError: @escaping (ConnectionErrorResponse, Client) async throws -> Void = { + _, + _ in + }, + onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in + }, + onConnectionKeepAlive: + @escaping (ConnectionKeepAliveResponse, Client) async throws -> Void = { _, _ in }, onData: @escaping (DataResponse, Client) async throws -> Void = { _, _ in }, onError: @escaping (ErrorResponse, Client) async throws -> Void = { _, _ in }, onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in } @@ -45,7 +50,8 @@ public actor Client { /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. /// - Parameter incoming: The server message sequence that the client should react to. - public func listen(to incoming: A) async throws -> Void where A.Element == String { + public func listen(to incoming: A) async throws + where A.Element == String { for try await message in incoming { // Detect and ignore error responses. if message.starts(with: "44") { @@ -68,19 +74,34 @@ public actor Client { switch response.type { case .GQL_CONNECTION_ERROR: - guard let connectionErrorResponse = try? decoder.decode(ConnectionErrorResponse.self, from: json) else { + guard + let connectionErrorResponse = try? decoder.decode( + ConnectionErrorResponse.self, + from: json + ) + else { try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) return } try await onConnectionError(connectionErrorResponse, self) case .GQL_CONNECTION_ACK: - guard let connectionAckResponse = try? decoder.decode(ConnectionAckResponse.self, from: json) else { + guard + let connectionAckResponse = try? decoder.decode( + ConnectionAckResponse.self, + from: json + ) + else { try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) return } try await onConnectionAck(connectionAckResponse, self) case .GQL_CONNECTION_KEEP_ALIVE: - guard let connectionKeepAliveResponse = try? decoder.decode(ConnectionKeepAliveResponse.self, from: json) else { + guard + let connectionKeepAliveResponse = try? decoder.decode( + ConnectionKeepAliveResponse.self, + from: json + ) + else { try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE)) return } @@ -98,7 +119,8 @@ public actor Client { } try await onError(errorResponse, self) case .GQL_COMPLETE: - guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) else { + guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) + else { try await error(.invalidResponseFormat(messageType: .GQL_COMPLETE)) return } diff --git a/Sources/GraphQLWS/Messenger.swift b/Sources/GraphQLWS/Messenger.swift index 86ca9d4..e0ba6d9 100644 --- a/Sources/GraphQLWS/Messenger.swift +++ b/Sources/GraphQLWS/Messenger.swift @@ -4,7 +4,7 @@ import Foundation public protocol Messenger: Sendable { /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) async throws -> Void where S.Element == Character + func send(_ message: S) async throws where S.Element == Character /// Close the messenger func close() async throws diff --git a/Sources/GraphQLWS/Requests.swift b/Sources/GraphQLWS/Requests.swift index f86d69b..251df45 100644 --- a/Sources/GraphQLWS/Requests.swift +++ b/Sources/GraphQLWS/Requests.swift @@ -18,10 +18,13 @@ public struct ConnectionInitRequest: Equatable public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_INIT { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(RequestMessageType.GQL_CONNECTION_INIT.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: + "type must be `\(RequestMessageType.GQL_CONNECTION_INIT.type)`" + ) + ) } payload = try container.decode(InitPayload.self, forKey: .payload) } @@ -41,10 +44,12 @@ public struct StartRequest: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_START { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(RequestMessageType.GQL_START.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(RequestMessageType.GQL_START.type)`" + ) + ) } payload = try container.decode(GraphQLRequest.self, forKey: .payload) id = try container.decode(String.self, forKey: .id) @@ -62,11 +67,14 @@ public struct StopRequest: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) - if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(RequestMessageType.GQL_STOP.type)`" - )) + if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE + { + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(RequestMessageType.GQL_STOP.type)`" + ) + ) } id = try container.decode(String.self, forKey: .id) } @@ -80,11 +88,15 @@ public struct ConnectionTerminateRequest: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) - if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(RequestMessageType.GQL_CONNECTION_TERMINATE.type)`" - )) + if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE + { + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: + "type must be `\(RequestMessageType.GQL_CONNECTION_TERMINATE.type)`" + ) + ) } } } diff --git a/Sources/GraphQLWS/Responses.swift b/Sources/GraphQLWS/Responses.swift index f4e5511..c914eb2 100644 --- a/Sources/GraphQLWS/Responses.swift +++ b/Sources/GraphQLWS/Responses.swift @@ -18,10 +18,13 @@ public struct ConnectionAckResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_CONNECTION_ACK { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.GQL_CONNECTION_ACK.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: + "type must be `\(ResponseMessageType.GQL_CONNECTION_ACK.type)`" + ) + ) } payload = try container.decodeIfPresent([String: Map].self, forKey: .payload) } @@ -39,10 +42,13 @@ public struct ConnectionErrorResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_CONNECTION_ERROR { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.GQL_CONNECTION_ERROR.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: + "type must be `\(ResponseMessageType.GQL_CONNECTION_ERROR.type)`" + ) + ) } payload = try container.decodeIfPresent([String: Map].self, forKey: .payload) } @@ -59,11 +65,16 @@ public struct ConnectionKeepAliveResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) - if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_CONNECTION_KEEP_ALIVE { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.GQL_CONNECTION_KEEP_ALIVE.type)`" - )) + if try container.decode(ResponseMessageType.self, forKey: .type) + != .GQL_CONNECTION_KEEP_ALIVE + { + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: + "type must be `\(ResponseMessageType.GQL_CONNECTION_KEEP_ALIVE.type)`" + ) + ) } payload = try container.decodeIfPresent([String: Map].self, forKey: .payload) } @@ -83,10 +94,12 @@ public struct DataResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_DATA { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.GQL_DATA.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.GQL_DATA.type)`" + ) + ) } payload = try container.decodeIfPresent(GraphQLResult.self, forKey: .payload) id = try container.decode(String.self, forKey: .id) @@ -105,10 +118,12 @@ public struct CompleteResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_COMPLETE { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.GQL_COMPLETE.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.GQL_COMPLETE.type)`" + ) + ) } id = try container.decode(String.self, forKey: .id) } @@ -136,10 +151,12 @@ public struct ErrorResponse: Equatable, JsonEncodable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_ERROR { - throw DecodingError.dataCorrupted(.init( - codingPath: decoder.codingPath, - debugDescription: "type must be `\(ResponseMessageType.GQL_ERROR.type)`" - )) + throw DecodingError.dataCorrupted( + .init( + codingPath: decoder.codingPath, + debugDescription: "type must be `\(ResponseMessageType.GQL_ERROR.type)`" + ) + ) } payload = try container.decode([GraphQLError].self, forKey: .payload) id = try container.decode(String.self, forKey: .id) diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index 4fab56d..a04ef6b 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -8,7 +8,8 @@ public actor Server< InitPayload: Equatable & Codable & Sendable, InitPayloadResult: Sendable, SubscriptionSequenceType: AsyncSequence & Sendable -> where +> +where SubscriptionSequenceType.Element == GraphQLResult { let messenger: Messenger @@ -39,7 +40,8 @@ public actor Server< messenger: Messenger, onInit: @escaping (InitPayload) async throws -> InitPayloadResult, onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult, - onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, + onSubscribe: + @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, onOperationComplete: @escaping (String) async throws -> Void = { _ in }, onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in } ) { @@ -53,7 +55,8 @@ public actor Server< /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. /// - Parameter incoming: The client message sequence that the server should react to. - public func listen(to incoming: A) async throws -> Void where A.Element == String { + public func listen(to incoming: A) async throws + where A.Element == String { for try await message in incoming { // Detect and ignore error responses. if message.starts(with: "44") { @@ -77,7 +80,12 @@ public actor Server< // handle incoming message switch request.type { case .GQL_CONNECTION_INIT: - guard let connectionInitRequest = try? decoder.decode(ConnectionInitRequest.self, from: json) else { + guard + let connectionInitRequest = try? decoder.decode( + ConnectionInitRequest.self, + from: json + ) + else { try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT)) return } @@ -95,7 +103,12 @@ public actor Server< } try await onStop(stopRequest) case .GQL_CONNECTION_TERMINATE: - guard let connectionTerminateRequest = try? decoder.decode(ConnectionTerminateRequest.self, from: json) else { + guard + let connectionTerminateRequest = try? decoder.decode( + ConnectionTerminateRequest.self, + from: json + ) + else { try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE)) return } @@ -110,7 +123,10 @@ public actor Server< subscriptionTasks.values.forEach { $0.cancel() } } - private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _: Messenger) async throws { + private func onConnectionInit( + _ connectionInitRequest: ConnectionInitRequest, + _: Messenger + ) async throws { guard !initialized else { try await error(.tooManyInitializations()) return @@ -189,7 +205,9 @@ public actor Server< try await onOperationComplete(id) } - private func onConnectionTerminate(_: ConnectionTerminateRequest, _ messenger: Messenger) async throws { + private func onConnectionTerminate(_: ConnectionTerminateRequest, _ messenger: Messenger) + async throws + { for (_, subscriptionTask) in subscriptionTasks { subscriptionTask.cancel() } diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index 22e0a32..11ccb58 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -28,7 +28,8 @@ struct GraphqlTransportWSTests { ).get() } ) - let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let (messageStream, messageContinuation) = AsyncThrowingStream + .makeStream() let serverMessageStream = serverMessenger.stream.map { message in messageContinuation.yield(message) // Expect only one message @@ -51,10 +52,10 @@ struct GraphqlTransportWSTests { try await client.sendStart( payload: GraphQLRequest( query: """ - query { - hello - } - """ + query { + hello + } + """ ), id: UUID().uuidString ) @@ -64,8 +65,7 @@ struct GraphqlTransportWSTests { result.append(message) } #expect( - messages == - ["\(ErrorCode.notInitialized): Connection not initialized"] + messages == ["\(ErrorCode.notInitialized): Connection not initialized"] ) } @@ -91,7 +91,8 @@ struct GraphqlTransportWSTests { ).get() } ) - let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let (messageStream, messageContinuation) = AsyncThrowingStream + .makeStream() let serverMessageStream = serverMessenger.stream.map { message in messageContinuation.yield(message) // Expect only one message @@ -122,8 +123,7 @@ struct GraphqlTransportWSTests { result.append(message) } #expect( - messages == - ["\(ErrorCode.unauthorized): Unauthorized"] + messages == ["\(ErrorCode.unauthorized): Unauthorized"] ) } @@ -149,7 +149,8 @@ struct GraphqlTransportWSTests { ).get() } ) - let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let (messageStream, messageContinuation) = AsyncThrowingStream + .makeStream() let serverMessageStream = serverMessenger.stream.map { message in messageContinuation.yield(message) return message @@ -160,10 +161,10 @@ struct GraphqlTransportWSTests { try await client.sendStart( payload: GraphQLRequest( query: """ - query { - hello - } - """ + query { + hello + } + """ ), id: id ) @@ -190,7 +191,7 @@ struct GraphqlTransportWSTests { result.append(message) } #expect( - messages.count == 3, // 1 connection_ack, 1 data, 1 complete + messages.count == 3, // 1 connection_ack, 1 data, 1 complete "Messages: \(messages.description)" ) } @@ -222,7 +223,8 @@ struct GraphqlTransportWSTests { return subscription } ) - let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let (messageStream, messageContinuation) = AsyncThrowingStream + .makeStream() // Used to extract the server messages let serverMessageStream = serverMessenger.stream.map { message in messageContinuation.yield(message) @@ -234,10 +236,10 @@ struct GraphqlTransportWSTests { try await client.sendStart( payload: GraphQLRequest( query: """ - subscription { - hello - } - """ + subscription { + hello + } + """ ), id: id ) @@ -276,7 +278,7 @@ struct GraphqlTransportWSTests { result.append(message) } #expect( - messages.count == 5, // 1 connection_ack, 3 next, 1 complete + messages.count == 5, // 1 connection_ack, 3 next, 1 complete "Messages: \(messages.description)" ) } diff --git a/Tests/GraphQLWSTests/Utils/TestAPI.swift b/Tests/GraphQLWSTests/Utils/TestAPI.swift index 8867da1..b6e8cca 100644 --- a/Tests/GraphQLWSTests/Utils/TestAPI.swift +++ b/Tests/GraphQLWSTests/Utils/TestAPI.swift @@ -1,6 +1,6 @@ import Foundation -import Graphiti import GraphQL +import Graphiti struct TestAPI: API { let resolver = TestResolver() @@ -29,7 +29,10 @@ struct TestResolver { context.hello() } - func subscribeHello(context: TestContext, arguments _: NoArguments) -> AsyncThrowingStream { + func subscribeHello( + context: TestContext, + arguments _: NoArguments + ) -> AsyncThrowingStream { context.publisher.subscribe() } } diff --git a/Tests/GraphQLWSTests/Utils/TestMessenger.swift b/Tests/GraphQLWSTests/Utils/TestMessenger.swift index e6931a9..ac86b02 100644 --- a/Tests/GraphQLWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLWSTests/Utils/TestMessenger.swift @@ -1,5 +1,5 @@ - import Foundation + @testable import GraphQLWS /// Messenger for simple testing that doesn't require starting up a websocket server.