Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .swift-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"version": 1,
"indentation" : {
"spaces" : 4
},
"lineBreakBeforeEachArgument": true
}
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ let package = Package(
.library(
name: "GraphQLWS",
targets: ["GraphQLWS"]
),
)
],
dependencies: [
.package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"),
Expand Down
38 changes: 30 additions & 8 deletions Sources/GraphQLWS/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,14 @@ public actor Client<InitPayload: Equatable & Codable> {
/// - onComplete: The callback run on receipt of a `complete` message
public init(
messenger: Messenger,
onConnectionError: @escaping (ConnectionErrorResponse, Client) async throws -> Void = { _, _ in },
onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in },
onConnectionKeepAlive: @escaping (ConnectionKeepAliveResponse, Client) async throws -> Void = { _, _ in },
onConnectionError: @escaping (ConnectionErrorResponse, Client) async throws -> Void = {
_,
_ in
},
onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in
},
onConnectionKeepAlive:
@escaping (ConnectionKeepAliveResponse, Client) async throws -> Void = { _, _ in },
onData: @escaping (DataResponse, Client) async throws -> Void = { _, _ in },
onError: @escaping (ErrorResponse, Client) async throws -> Void = { _, _ in },
onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in }
Expand All @@ -45,7 +50,8 @@ public actor Client<InitPayload: Equatable & Codable> {

/// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed.
/// - Parameter incoming: The server message sequence that the client should react to.
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws -> Void where A.Element == String {
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws
where A.Element == String {
for try await message in incoming {
// Detect and ignore error responses.
if message.starts(with: "44") {
Expand All @@ -68,19 +74,34 @@ public actor Client<InitPayload: Equatable & Codable> {

switch response.type {
case .GQL_CONNECTION_ERROR:
guard let connectionErrorResponse = try? decoder.decode(ConnectionErrorResponse.self, from: json) else {
guard
let connectionErrorResponse = try? decoder.decode(
ConnectionErrorResponse.self,
from: json
)
else {
try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR))
return
}
try await onConnectionError(connectionErrorResponse, self)
case .GQL_CONNECTION_ACK:
guard let connectionAckResponse = try? decoder.decode(ConnectionAckResponse.self, from: json) else {
guard
let connectionAckResponse = try? decoder.decode(
ConnectionAckResponse.self,
from: json
)
else {
try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR))
return
}
try await onConnectionAck(connectionAckResponse, self)
case .GQL_CONNECTION_KEEP_ALIVE:
guard let connectionKeepAliveResponse = try? decoder.decode(ConnectionKeepAliveResponse.self, from: json) else {
guard
let connectionKeepAliveResponse = try? decoder.decode(
ConnectionKeepAliveResponse.self,
from: json
)
else {
try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE))
return
}
Expand All @@ -98,7 +119,8 @@ public actor Client<InitPayload: Equatable & Codable> {
}
try await onError(errorResponse, self)
case .GQL_COMPLETE:
guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) else {
guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json)
else {
try await error(.invalidResponseFormat(messageType: .GQL_COMPLETE))
return
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/GraphQLWS/Messenger.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import Foundation
public protocol Messenger: Sendable {
/// Send a message through this messenger
/// - Parameter message: The message to send
func send<S: Sendable & Collection>(_ message: S) async throws -> Void where S.Element == Character
func send<S: Sendable & Collection>(_ message: S) async throws where S.Element == Character

/// Close the messenger
func close() async throws
Expand Down
48 changes: 30 additions & 18 deletions Sources/GraphQLWS/Requests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ public struct ConnectionInitRequest<InitPayload: Codable & Equatable>: Equatable
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_INIT {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.GQL_CONNECTION_INIT.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription:
"type must be `\(RequestMessageType.GQL_CONNECTION_INIT.type)`"
)
)
}
payload = try container.decode(InitPayload.self, forKey: .payload)
}
Expand All @@ -41,10 +44,12 @@ public struct StartRequest: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_START {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.GQL_START.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.GQL_START.type)`"
)
)
}
payload = try container.decode(GraphQLRequest.self, forKey: .payload)
id = try container.decode(String.self, forKey: .id)
Expand All @@ -62,11 +67,14 @@ public struct StopRequest: Equatable, JsonEncodable {

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.GQL_STOP.type)`"
))
if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE
{
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.GQL_STOP.type)`"
)
)
}
id = try container.decode(String.self, forKey: .id)
}
Expand All @@ -80,11 +88,15 @@ public struct ConnectionTerminateRequest: Equatable, JsonEncodable {

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(RequestMessageType.GQL_CONNECTION_TERMINATE.type)`"
))
if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE
{
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription:
"type must be `\(RequestMessageType.GQL_CONNECTION_TERMINATE.type)`"
)
)
}
}
}
Expand Down
67 changes: 42 additions & 25 deletions Sources/GraphQLWS/Responses.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ public struct ConnectionAckResponse: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_CONNECTION_ACK {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.GQL_CONNECTION_ACK.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription:
"type must be `\(ResponseMessageType.GQL_CONNECTION_ACK.type)`"
)
)
}
payload = try container.decodeIfPresent([String: Map].self, forKey: .payload)
}
Expand All @@ -39,10 +42,13 @@ public struct ConnectionErrorResponse: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_CONNECTION_ERROR {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.GQL_CONNECTION_ERROR.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription:
"type must be `\(ResponseMessageType.GQL_CONNECTION_ERROR.type)`"
)
)
}
payload = try container.decodeIfPresent([String: Map].self, forKey: .payload)
}
Expand All @@ -59,11 +65,16 @@ public struct ConnectionKeepAliveResponse: Equatable, JsonEncodable {

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_CONNECTION_KEEP_ALIVE {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.GQL_CONNECTION_KEEP_ALIVE.type)`"
))
if try container.decode(ResponseMessageType.self, forKey: .type)
!= .GQL_CONNECTION_KEEP_ALIVE
{
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription:
"type must be `\(ResponseMessageType.GQL_CONNECTION_KEEP_ALIVE.type)`"
)
)
}
payload = try container.decodeIfPresent([String: Map].self, forKey: .payload)
}
Expand All @@ -83,10 +94,12 @@ public struct DataResponse: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_DATA {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.GQL_DATA.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.GQL_DATA.type)`"
)
)
}
payload = try container.decodeIfPresent(GraphQLResult.self, forKey: .payload)
id = try container.decode(String.self, forKey: .id)
Expand All @@ -105,10 +118,12 @@ public struct CompleteResponse: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_COMPLETE {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.GQL_COMPLETE.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.GQL_COMPLETE.type)`"
)
)
}
id = try container.decode(String.self, forKey: .id)
}
Expand Down Expand Up @@ -136,10 +151,12 @@ public struct ErrorResponse: Equatable, JsonEncodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: Self.CodingKeys.self)
if try container.decode(ResponseMessageType.self, forKey: .type) != .GQL_ERROR {
throw DecodingError.dataCorrupted(.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.GQL_ERROR.type)`"
))
throw DecodingError.dataCorrupted(
.init(
codingPath: decoder.codingPath,
debugDescription: "type must be `\(ResponseMessageType.GQL_ERROR.type)`"
)
)
}
payload = try container.decode([GraphQLError].self, forKey: .payload)
id = try container.decode(String.self, forKey: .id)
Expand Down
32 changes: 25 additions & 7 deletions Sources/GraphQLWS/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ public actor Server<
InitPayload: Equatable & Codable & Sendable,
InitPayloadResult: Sendable,
SubscriptionSequenceType: AsyncSequence & Sendable
> where
>
where
SubscriptionSequenceType.Element == GraphQLResult
{
let messenger: Messenger
Expand Down Expand Up @@ -39,7 +40,8 @@ public actor Server<
messenger: Messenger,
onInit: @escaping (InitPayload) async throws -> InitPayloadResult,
onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult,
onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType,
onSubscribe:
@escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType,
onOperationComplete: @escaping (String) async throws -> Void = { _ in },
onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in }
) {
Expand All @@ -53,7 +55,8 @@ public actor Server<

/// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed.
/// - Parameter incoming: The client message sequence that the server should react to.
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws -> Void where A.Element == String {
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws
where A.Element == String {
for try await message in incoming {
// Detect and ignore error responses.
if message.starts(with: "44") {
Expand All @@ -77,7 +80,12 @@ public actor Server<
// handle incoming message
switch request.type {
case .GQL_CONNECTION_INIT:
guard let connectionInitRequest = try? decoder.decode(ConnectionInitRequest<InitPayload>.self, from: json) else {
guard
let connectionInitRequest = try? decoder.decode(
ConnectionInitRequest<InitPayload>.self,
from: json
)
else {
try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT))
return
}
Expand All @@ -95,7 +103,12 @@ public actor Server<
}
try await onStop(stopRequest)
case .GQL_CONNECTION_TERMINATE:
guard let connectionTerminateRequest = try? decoder.decode(ConnectionTerminateRequest.self, from: json) else {
guard
let connectionTerminateRequest = try? decoder.decode(
ConnectionTerminateRequest.self,
from: json
)
else {
try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE))
return
}
Expand All @@ -110,7 +123,10 @@ public actor Server<
subscriptionTasks.values.forEach { $0.cancel() }
}

private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest<InitPayload>, _: Messenger) async throws {
private func onConnectionInit(
_ connectionInitRequest: ConnectionInitRequest<InitPayload>,
_: Messenger
) async throws {
guard !initialized else {
try await error(.tooManyInitializations())
return
Expand Down Expand Up @@ -189,7 +205,9 @@ public actor Server<
try await onOperationComplete(id)
}

private func onConnectionTerminate(_: ConnectionTerminateRequest, _ messenger: Messenger) async throws {
private func onConnectionTerminate(_: ConnectionTerminateRequest, _ messenger: Messenger)
async throws
{
for (_, subscriptionTask) in subscriptionTasks {
subscriptionTask.cancel()
}
Expand Down
Loading
Loading