From c1aea94dc990b286c305abe324564ad7af9d1549 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 22:37:48 -0700 Subject: [PATCH 1/8] feat!: Passes init payload result This automatically propagates the init payload result from the callback into the `onExecute` and `onSubscribe` closures. Since the init callback is usually used to determine authentication and authorization, this should be usable from our execution and subscription calls, and lifecycle is most easily managed within this package --- Sources/GraphQLWS/Server.swift | 30 ++++---- Tests/GraphQLWSTests/GraphQLWSTests.swift | 92 +++++++++++++++++------ 2 files changed, 84 insertions(+), 38 deletions(-) diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index b6bd98f..e943209 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -6,6 +6,7 @@ import GraphQL /// By default, there are no authorization checks public class Server< InitPayload: Equatable & Codable & Sendable, + InitPayloadResult: Sendable, SubscriptionSequenceType: AsyncSequence & Sendable >: @unchecked Sendable where SubscriptionSequenceType.Element == GraphQLResult @@ -13,9 +14,9 @@ public class Server< // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? - let onExecute: (GraphQLRequest) async throws -> GraphQLResult - let onSubscribe: (GraphQLRequest) async throws -> SubscriptionSequenceType - var auth: (InitPayload) async throws -> Void + let onInit: (InitPayload) async throws -> InitPayloadResult + let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult + let onSubscribe: (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType var onExit: () async throws -> Void = {} var onMessage: (String) async throws -> Void = { _ in } @@ -23,6 +24,7 @@ public class Server< var onOperationError: (String, [Error]) async throws -> Void = { _, _ in } var initialized = false + var initResult: InitPayloadResult? let decoder = JSONDecoder() let encoder = GraphQLJSONEncoder() @@ -37,13 +39,14 @@ public class Server< /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. public init( messenger: Messenger, - onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult, - onSubscribe: @escaping (GraphQLRequest) async throws -> SubscriptionSequenceType + onInit: @escaping (InitPayload) async throws -> InitPayloadResult, + onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult, + onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType ) { self.messenger = messenger + self.onInit = onInit self.onExecute = onExecute self.onSubscribe = onSubscribe - auth = { _ in } messenger.onReceive { message in guard let messenger = self.messenger else { return } @@ -105,13 +108,6 @@ public class Server< subscriptionTasks.values.forEach { $0.cancel() } } - /// Define a custom callback run during `connection_init` resolution that allows authorization using the `payload`. - /// Throw from this closure to indicate that authorization has failed. - /// - Parameter callback: The callback to assign - public func auth(_ callback: @escaping (InitPayload) async throws -> Void) { - auth = callback - } - /// Define the callback run when the communication is shut down, either by the client or server /// - Parameter callback: The callback to assign public func onExit(_ callback: @escaping () -> Void) { @@ -143,7 +139,7 @@ public class Server< } do { - try await auth(connectionInitRequest.payload) + initResult = try await onInit(connectionInitRequest.payload) } catch { try await self.error(.unauthorized()) return @@ -154,7 +150,7 @@ public class Server< } private func onStart(_ startRequest: StartRequest, _ messenger: Messenger) async throws { - guard initialized else { + guard initialized, let initResult else { try await error(.notInitialized()) return } @@ -177,7 +173,7 @@ public class Server< if isStreaming { subscriptionTasks[id] = Task { do { - let stream = try await onSubscribe(graphQLRequest) + let stream = try await onSubscribe(graphQLRequest, initResult) for try await event in stream { try Task.checkCancellation() try await self.sendData(event, id: id) @@ -192,7 +188,7 @@ public class Server< } } else { do { - let result = try await onExecute(graphQLRequest) + let result = try await onExecute(graphQLRequest, initResult) try await sendData(result, id: id) try await sendComplete(id: id) } catch { diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index 37457b7..b2f0b31 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -8,42 +8,39 @@ import GraphQLWS class GraphqlWsTests: XCTestCase { var clientMessenger: TestMessenger! var serverMessenger: TestMessenger! - var server: Server>! - var context: TestContext! var subscribeReady: Bool! = false + let context = TestContext() + let api = TestAPI() + override func setUp() { // Point the client and server at each other clientMessenger = TestMessenger() serverMessenger = TestMessenger() clientMessenger.other = serverMessenger serverMessenger.other = clientMessenger + } - let api = TestAPI() - let context = TestContext() - - server = .init( + /// Tests that trying to run methods before `connection_init` is not allowed + func testInitialize() async throws { + let server = Server>( messenger: serverMessenger, - onExecute: { graphQLRequest in - try await api.execute( + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await self.api.execute( request: graphQLRequest.query, - context: context + context: self.context ) }, - onSubscribe: { graphQLRequest in - let subscription = try await api.subscribe( + onSubscribe: { graphQLRequest, _ in + let subscription = try await self.api.subscribe( request: graphQLRequest.query, - context: context + context: self.context ).get() self.subscribeReady = true return subscription } ) - self.context = context - } - - /// Tests that trying to run methods before `connection_init` is not allowed - func testInitialize() async throws { let client = Client(messenger: clientMessenger) let messageStream = AsyncThrowingStream { continuation in client.onMessage { message, _ in @@ -78,10 +75,26 @@ class GraphqlWsTests: XCTestCase { /// Tests that throwing in the authorization callback forces an unauthorized error func testAuthWithThrow() async throws { - server.auth { _ in - throw TestError.couldBeAnything - } - + let server = Server>( + messenger: serverMessenger, + onInit: { _ in + throw TestError.couldBeAnything + }, + onExecute: { graphQLRequest, _ in + try await self.api.execute( + request: graphQLRequest.query, + context: self.context + ) + }, + onSubscribe: { graphQLRequest, _ in + let subscription = try await self.api.subscribe( + request: graphQLRequest.query, + context: self.context + ).get() + self.subscribeReady = true + return subscription + } + ) let client = Client(messenger: clientMessenger) let messageStream = AsyncThrowingStream { continuation in client.onMessage { message, _ in @@ -111,6 +124,24 @@ class GraphqlWsTests: XCTestCase { /// Test single op message flow works as expected func testSingleOp() async throws { + let server = Server>( + messenger: serverMessenger, + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await self.api.execute( + request: graphQLRequest.query, + context: self.context + ) + }, + onSubscribe: { graphQLRequest, _ in + let subscription = try await self.api.subscribe( + request: graphQLRequest.query, + context: self.context + ).get() + self.subscribeReady = true + return subscription + } + ) let id = UUID().description let client = Client(messenger: clientMessenger) @@ -152,6 +183,25 @@ class GraphqlWsTests: XCTestCase { /// Test streaming message flow works as expected func testStreaming() async throws { + let server = Server>( + messenger: serverMessenger, + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await self.api.execute( + request: graphQLRequest.query, + context: self.context + ) + }, + onSubscribe: { graphQLRequest, _ in + let subscription = try await self.api.subscribe( + request: graphQLRequest.query, + context: self.context + ).get() + self.subscribeReady = true + return subscription + } + ) + let id = UUID().description var dataIndex = 1 From 169e8a752f7f66bfe3438947642671492934fcef Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 22:48:51 -0700 Subject: [PATCH 2/8] feat!: Messenger is send-only To support receiving messages, we added `listen` functions to server and client. This resolves the confusing ownership rules by avoiding `onReceive` callbacks in Messenger. --- README.md | 30 ++++--------- Sources/GraphQLWS/Client.swift | 16 +++---- Sources/GraphQLWS/Messenger.swift | 10 ++--- Sources/GraphQLWS/Server.swift | 20 +++------ Tests/GraphQLWSTests/GraphQLWSTests.swift | 44 ++++++++++++++++--- .../GraphQLWSTests/Utils/TestMessenger.swift | 33 ++++++-------- 6 files changed, 78 insertions(+), 75 deletions(-) diff --git a/README.md b/README.md index b97e2c4..a4dfe11 100644 --- a/README.md +++ b/README.md @@ -26,32 +26,21 @@ import GraphQLWS /// Messenger wrapper for WebSockets class WebSocketMessenger: Messenger { - private weak var websocket: WebSocket? - private var onReceive: (String) -> Void = { _ in } - + let websocket: WebSocket + init(websocket: WebSocket) { self.websocket = websocket - websocket.onText { _, message in - try await self.onReceive(message) - } } func send(_ message: S) async throws where S: Collection, S.Element == Character async throws { - guard let websocket = websocket else { return } try await websocket.send(message) } - func onReceive(callback: @escaping (String) async throws -> Void) { - self.onReceive = callback - } - func error(_ message: String, code: Int) async throws { - guard let websocket = websocket else { return } try await websocket.send("\(code): \(message)") } func close() async throws { - guard let websocket = websocket else { return } try await websocket.close() } } @@ -85,6 +74,12 @@ routes.webSocket( ) } ) + let incoming = AsyncStream { continuation in + websocket.onText { _, message in + continuation.yield(message) + } + } + try await server.listen(to: incoming) } ) ``` @@ -125,12 +120,3 @@ This example would require `connection_init` message from the client to look lik ``` If the `payload` field is not required on your server, you may make Server's generic declaration optional like `Server` - -## Memory Management - -Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket -implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the -objects required for responses. In order to align cleanly and avoid memory cycles, Server and Client are injected strongly into Messenger -callbacks, and only hold weak references to their Messenger. This means that Messenger objects (or their enclosing WebSocket) must -be persisted to have the connected Server or Client objects function. That is, if a Server's Messenger falls out of scope and deinitializes, -the Server will no longer respond to messages. diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index 0dd16ab..6a75966 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -2,9 +2,9 @@ import Foundation import GraphQL /// Client is an open-ended implementation of the client side of the protocol. It parses and adds callbacks for each type of server respose. -public class Client { +public class Client: @unchecked Sendable { // We keep this weak because we strongly inject this object into the messenger callback - weak var messenger: Messenger? + let messenger: Messenger var onConnectionError: (ConnectionErrorResponse, Client) async throws -> Void = { _, _ in } var onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void = { _, _ in } @@ -25,7 +25,12 @@ public class Client { messenger: Messenger ) { self.messenger = messenger - messenger.onReceive { message in + } + + /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. + /// - Parameter incoming: The server message sequence that the client should react to. + public func listen(to incoming: A) async throws -> Void where A.Element == String { + for try await message in incoming { try await self.onMessage(message, self) // Detect and ignore error responses. @@ -134,7 +139,6 @@ public class Client { /// Send a `connection_init` request through the messenger public func sendConnectionInit(payload: InitPayload) async throws { - guard let messenger = messenger else { return } try await messenger.send( ConnectionInitRequest( payload: payload @@ -144,7 +148,6 @@ public class Client { /// Send a `start` request through the messenger public func sendStart(payload: GraphQLRequest, id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( StartRequest( payload: payload, @@ -155,7 +158,6 @@ public class Client { /// Send a `stop` request through the messenger public func sendStop(id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( StopRequest( id: id @@ -165,7 +167,6 @@ public class Client { /// Send a `connection_terminate` request through the messenger public func sendConnectionTerminate() async throws { - guard let messenger = messenger else { return } try await messenger.send( ConnectionTerminateRequest().toJSON(encoder) ) @@ -173,7 +174,6 @@ public class Client { /// Send an error through the messenger and close the connection private func error(_ error: GraphQLWSError) async throws { - guard let messenger = messenger else { return } try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Sources/GraphQLWS/Messenger.swift b/Sources/GraphQLWS/Messenger.swift index 3a9c157..e0bcf37 100644 --- a/Sources/GraphQLWS/Messenger.swift +++ b/Sources/GraphQLWS/Messenger.swift @@ -1,15 +1,11 @@ import Foundation -/// Protocol for an object that can send and recieve messages. This allows mocking in tests -public protocol Messenger: AnyObject { - // AnyObject compliance requires that the implementing object is a class and we can reference it weakly +/// Protocol for an object that can send messages. This allows mocking in tests +public protocol Messenger { /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character - - /// Set the callback that should be run when a message is recieved - func onReceive(callback: @escaping (String) async throws -> Void) + func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character /// Close the messenger func close() async throws diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index e943209..4a419bb 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -11,8 +11,7 @@ public class Server< >: @unchecked Sendable where SubscriptionSequenceType.Element == GraphQLResult { - // We keep this weak because we strongly inject this object into the messenger callback - weak var messenger: Messenger? + let messenger: Messenger let onInit: (InitPayload) async throws -> InitPayloadResult let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult @@ -47,10 +46,12 @@ public class Server< self.onInit = onInit self.onExecute = onExecute self.onSubscribe = onSubscribe - - messenger.onReceive { message in - guard let messenger = self.messenger else { return } - + } + + /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. + /// - Parameter incoming: The client message sequence that the server should react to. + public func listen(to incoming: A) async throws -> Void where A.Element == String { + for try await message in incoming { try await self.onMessage(message) // Detect and ignore error responses. @@ -223,7 +224,6 @@ public class Server< /// Send a `connection_ack` response through the messenger private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws { - guard let messenger = messenger else { return } try await messenger.send( ConnectionAckResponse(payload: payload).toJSON(encoder) ) @@ -231,7 +231,6 @@ public class Server< /// Send a `connection_error` response through the messenger private func sendConnectionError(_ payload: [String: Map]? = nil) async throws { - guard let messenger = messenger else { return } try await messenger.send( ConnectionErrorResponse(payload: payload).toJSON(encoder) ) @@ -239,7 +238,6 @@ public class Server< /// Send a `ka` response through the messenger private func sendConnectionKeepAlive(_ payload: [String: Map]? = nil) async throws { - guard let messenger = messenger else { return } try await messenger.send( ConnectionKeepAliveResponse(payload: payload).toJSON(encoder) ) @@ -247,7 +245,6 @@ public class Server< /// Send a `data` response through the messenger private func sendData(_ payload: GraphQLResult? = nil, id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( DataResponse( payload: payload, @@ -258,7 +255,6 @@ public class Server< /// Send a `complete` response through the messenger private func sendComplete(id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( CompleteResponse( id: id @@ -269,7 +265,6 @@ public class Server< /// Send an `error` response through the messenger private func sendError(_ errors: [Error], id: String) async throws { - guard let messenger = messenger else { return } try await messenger.send( ErrorResponse( errors, @@ -291,7 +286,6 @@ public class Server< /// Send an error through the messenger and close the connection private func error(_ error: GraphQLWSError) async throws { - guard let messenger = messenger else { return } try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index b2f0b31..b510987 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -14,11 +14,8 @@ class GraphqlWsTests: XCTestCase { let api = TestAPI() override func setUp() { - // Point the client and server at each other clientMessenger = TestMessenger() serverMessenger = TestMessenger() - clientMessenger.other = serverMessenger - serverMessenger.other = clientMessenger } /// Tests that trying to run methods before `connection_init` is not allowed @@ -42,6 +39,15 @@ class GraphqlWsTests: XCTestCase { } ) let client = Client(messenger: clientMessenger) + let serverStream = serverMessenger.stream + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) + } + let messageStream = AsyncThrowingStream { continuation in client.onMessage { message, _ in continuation.yield(message) @@ -96,6 +102,15 @@ class GraphqlWsTests: XCTestCase { } ) let client = Client(messenger: clientMessenger) + let serverStream = serverMessenger.stream + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) + } + let messageStream = AsyncThrowingStream { continuation in client.onMessage { message, _ in continuation.yield(message) @@ -142,9 +157,18 @@ class GraphqlWsTests: XCTestCase { return subscription } ) - let id = UUID().description - let client = Client(messenger: clientMessenger) + let serverStream = serverMessenger.stream + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) + } + + let id = UUID().description + let messageStream = AsyncThrowingStream { continuation in client.onConnectionAck { _, client in try await client.sendStart( @@ -201,13 +225,21 @@ class GraphqlWsTests: XCTestCase { return subscription } ) + let client = Client(messenger: clientMessenger) + let serverStream = serverMessenger.stream + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) + } let id = UUID().description var dataIndex = 1 let dataIndexMax = 3 - let client = Client(messenger: clientMessenger) let messageStream = AsyncThrowingStream { continuation in client.onConnectionAck { _, client in try await client.sendStart( diff --git a/Tests/GraphQLWSTests/Utils/TestMessenger.swift b/Tests/GraphQLWSTests/Utils/TestMessenger.swift index 803f080..d775715 100644 --- a/Tests/GraphQLWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLWSTests/Utils/TestMessenger.swift @@ -4,32 +4,27 @@ import Foundation @testable import GraphQLWS /// Messenger for simple testing that doesn't require starting up a websocket server. -/// -/// Note that this only retains a weak reference to 'other', so the client should retain references -/// or risk them being deinitialized early -class TestMessenger: Messenger, @unchecked Sendable { - weak var other: TestMessenger? - var onReceive: (String) async throws -> Void = { _ in } - let queue: DispatchQueue = .init(label: "Test messenger") - - init() {} - - func send(_ message: S) async throws where S: Collection, S.Element == Character { - guard let other = other else { - return - } - try await other.onReceive(String(message)) +actor TestMessenger: Messenger { + /// An async stream of the messages sent through this messenger. + let stream: AsyncStream + private var continuation: AsyncStream.Continuation + + init() { + let (stream, continuation) = AsyncStream.makeStream() + self.stream = stream + self.continuation = continuation } - func onReceive(callback: @escaping (String) async throws -> Void) { - onReceive = callback + func send(_ message: S) async throws where S.Element == Character { + continuation.yield(String(message)) } func error(_ message: String, code: Int) async throws { - try await send("\(code): \(message)") + continuation.yield("\(code): \(message)") + continuation.finish() } func close() { - // This is a testing no-op + continuation.finish() } } From 8b2cd4dbcef170bb3a1c1c46ba358e4f9f4856c9 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 23:09:40 -0700 Subject: [PATCH 3/8] feat!: Actor conversion and Swift Testing Client and Server became actors to ensure sendability, and Messenger was marked sendable --- Sources/GraphQLWS/Client.swift | 83 +++---- Sources/GraphQLWS/Messenger.swift | 2 +- Sources/GraphQLWS/Server.swift | 52 ++--- Tests/GraphQLWSTests/GraphQLWSTests.swift | 255 ++++++++++------------ 4 files changed, 170 insertions(+), 222 deletions(-) diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index 6a75966..975292c 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -2,29 +2,50 @@ import Foundation import GraphQL /// Client is an open-ended implementation of the client side of the protocol. It parses and adds callbacks for each type of server respose. -public class Client: @unchecked Sendable { +public actor Client { // We keep this weak because we strongly inject this object into the messenger callback let messenger: Messenger - var onConnectionError: (ConnectionErrorResponse, Client) async throws -> Void = { _, _ in } - var onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void = { _, _ in } - var onConnectionKeepAlive: (ConnectionKeepAliveResponse, Client) async throws -> Void = { _, _ in } - var onData: (DataResponse, Client) async throws -> Void = { _, _ in } - var onError: (ErrorResponse, Client) async throws -> Void = { _, _ in } - var onComplete: (CompleteResponse, Client) async throws -> Void = { _, _ in } - var onMessage: (String, Client) async throws -> Void = { _, _ in } + let onConnectionError: (ConnectionErrorResponse, Client) async throws -> Void + let onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void + let onConnectionKeepAlive: (ConnectionKeepAliveResponse, Client) async throws -> Void + let onData: (DataResponse, Client) async throws -> Void + let onError: (ErrorResponse, Client) async throws -> Void + let onComplete: (CompleteResponse, Client) async throws -> Void + let onMessage: (String, Client) async throws -> Void let encoder = GraphQLJSONEncoder() let decoder = JSONDecoder() /// Create a new client. - /// + /// /// - Parameters: /// - messenger: The messenger to bind the client to. + /// - onConnectionError: The callback run on receipt of a `connection_error` message + /// - onConnectionAck: The callback run on receipt of a `connection_ack` message + /// - onConnectionKeepAlive: The callback run on receipt of a `connection_ka` message + /// - onData: The callback run on receipt of a `data` message + /// - onError: The callback run on receipt of an `error` message + /// - onComplete: The callback run on receipt of a `complete` message + /// - onMessage: The callback run on receipt of any message public init( - messenger: Messenger + messenger: Messenger, + onConnectionError: @escaping (ConnectionErrorResponse, Client) async throws -> Void = { _, _ in }, + onConnectionAck: @escaping (ConnectionAckResponse, Client) async throws -> Void = { _, _ in }, + onConnectionKeepAlive: @escaping (ConnectionKeepAliveResponse, Client) async throws -> Void = { _, _ in }, + onData: @escaping (DataResponse, Client) async throws -> Void = { _, _ in }, + onError: @escaping (ErrorResponse, Client) async throws -> Void = { _, _ in }, + onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in }, + onMessage: @escaping (String, Client) async throws -> Void = { _, _ in } ) { self.messenger = messenger + self.onConnectionError = onConnectionError + self.onConnectionAck = onConnectionAck + self.onConnectionKeepAlive = onConnectionKeepAlive + self.onData = onData + self.onError = onError + self.onComplete = onComplete + self.onMessage = onMessage } /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. @@ -95,48 +116,6 @@ public class Client: @unchecked Sendable { } } - /// Define the callback run on receipt of a `connection_error` message - /// - Parameter callback: The callback to assign - public func onConnectionError(_ callback: @escaping (ConnectionErrorResponse, Client) async throws -> Void) { - onConnectionError = callback - } - - /// Define the callback run on receipt of a `connection_ack` message - /// - Parameter callback: The callback to assign - public func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) async throws -> Void) { - onConnectionAck = callback - } - - /// Define the callback run on receipt of a `connection_ka` message - /// - Parameter callback: The callback to assign - public func onConnectionKeepAlive(_ callback: @escaping (ConnectionKeepAliveResponse, Client) async throws -> Void) { - onConnectionKeepAlive = callback - } - - /// Define the callback run on receipt of a `data` message - /// - Parameter callback: The callback to assign - public func onData(_ callback: @escaping (DataResponse, Client) async throws -> Void) { - onData = callback - } - - /// Define the callback run on receipt of an `error` message - /// - Parameter callback: The callback to assign - public func onError(_ callback: @escaping (ErrorResponse, Client) async throws -> Void) { - onError = callback - } - - /// Define the callback run on receipt of any message - /// - Parameter callback: The callback to assign - public func onComplete(_ callback: @escaping (CompleteResponse, Client) async throws -> Void) { - onComplete = callback - } - - /// Define the callback run on receipt of a `complete` message - /// - Parameter callback: The callback to assign - public func onMessage(_ callback: @escaping (String, Client) async throws -> Void) { - onMessage = callback - } - /// Send a `connection_init` request through the messenger public func sendConnectionInit(payload: InitPayload) async throws { try await messenger.send( diff --git a/Sources/GraphQLWS/Messenger.swift b/Sources/GraphQLWS/Messenger.swift index e0bcf37..2b2b113 100644 --- a/Sources/GraphQLWS/Messenger.swift +++ b/Sources/GraphQLWS/Messenger.swift @@ -1,7 +1,7 @@ import Foundation /// Protocol for an object that can send messages. This allows mocking in tests -public protocol Messenger { +public protocol Messenger: Sendable { /// Send a message through this messenger /// - Parameter message: The message to send diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index 4a419bb..3eb10f6 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -4,11 +4,11 @@ import GraphQL /// Server implements the server-side portion of the protocol, allowing a few callbacks for customization. /// /// By default, there are no authorization checks -public class Server< +public actor Server< InitPayload: Equatable & Codable & Sendable, InitPayloadResult: Sendable, SubscriptionSequenceType: AsyncSequence & Sendable ->: @unchecked Sendable where +> where SubscriptionSequenceType.Element == GraphQLResult { let messenger: Messenger @@ -17,17 +17,15 @@ public class Server< let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult let onSubscribe: (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType - var onExit: () async throws -> Void = {} - var onMessage: (String) async throws -> Void = { _ in } - var onOperationComplete: (String) async throws -> Void = { _ in } - var onOperationError: (String, [Error]) async throws -> Void = { _, _ in } - - var initialized = false - var initResult: InitPayloadResult? + let onMessage: (String) async throws -> Void + let onOperationComplete: (String) async throws -> Void + let onOperationError: (String, [Error]) async throws -> Void let decoder = JSONDecoder() let encoder = GraphQLJSONEncoder() + private var initialized = false + private var initResult: InitPayloadResult? private var subscriptionTasks = [String: Task]() /// Create a new server @@ -36,16 +34,25 @@ public class Server< /// - messenger: The messenger to bind the server to. /// - onExecute: Callback run during `start` resolution for non-streaming queries. Typically this is `API.execute`. /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. + /// - onMessage: Optional callback run on every message event + /// - onOperationComplete: Optional callback run when an operation completes + /// - onOperationError: Optional callback run when an operation errors public init( messenger: Messenger, onInit: @escaping (InitPayload) async throws -> InitPayloadResult, onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult, - onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType + onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, + onMessage: @escaping (String) async throws -> Void = { _ in }, + onOperationComplete: @escaping (String) async throws -> Void = { _ in }, + onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in }, ) { self.messenger = messenger self.onInit = onInit self.onExecute = onExecute self.onSubscribe = onSubscribe + self.onMessage = onMessage + self.onOperationComplete = onOperationComplete + self.onOperationError = onOperationError } /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. @@ -109,30 +116,6 @@ public class Server< subscriptionTasks.values.forEach { $0.cancel() } } - /// Define the callback run when the communication is shut down, either by the client or server - /// - Parameter callback: The callback to assign - public func onExit(_ callback: @escaping () -> Void) { - onExit = callback - } - - /// Define the callback run on receipt of any message - /// - Parameter callback: The callback to assign - public func onMessage(_ callback: @escaping (String) -> Void) { - onMessage = callback - } - - /// Define the callback run on the completion a full operation (query/mutation, end of subscription) - /// - Parameter callback: The callback to assign - public func onOperationComplete(_ callback: @escaping (String) -> Void) { - onOperationComplete = callback - } - - /// Define the callback to run on error of any full operation (failed query, interrupted subscription) - /// - Parameter callback: The callback to assign - public func onOperationError(_ callback: @escaping (String, [Error]) -> Void) { - onOperationError = callback - } - private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _: Messenger) async throws { guard !initialized else { try await error(.tooManyInitializations()) @@ -218,7 +201,6 @@ public class Server< subscriptionTask.cancel() } subscriptionTasks.removeAll() - try await onExit() try await messenger.close() } diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index b510987..24b12d9 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -1,44 +1,48 @@ import Foundation import GraphQL -import XCTest +import Testing import GraphQLWS -class GraphqlWsTests: XCTestCase { - var clientMessenger: TestMessenger! - var serverMessenger: TestMessenger! - var subscribeReady: Bool! = false - - let context = TestContext() - let api = TestAPI() - - override func setUp() { - clientMessenger = TestMessenger() - serverMessenger = TestMessenger() - } +@Suite +struct GraphqlTransportWSTests { + let clientMessenger = TestMessenger() + let serverMessenger = TestMessenger() /// Tests that trying to run methods before `connection_init` is not allowed - func testInitialize() async throws { + @Test func initialize() async throws { + let api = TestAPI() + let context = TestContext() let server = Server>( messenger: serverMessenger, onInit: { _ in }, onExecute: { graphQLRequest, _ in - try await self.api.execute( + try await api.execute( request: graphQLRequest.query, - context: self.context + context: context ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await self.api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: self.context + context: context ).get() - self.subscribeReady = true return subscription } ) - let client = Client(messenger: clientMessenger) + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let client = Client( + messenger: clientMessenger, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + }, + onMessage: { message, _ in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + } + ) let serverStream = serverMessenger.stream let clientStream = clientMessenger.stream Task { @@ -47,17 +51,6 @@ class GraphqlWsTests: XCTestCase { Task { try await client.listen(to: serverStream) } - - let messageStream = AsyncThrowingStream { continuation in - client.onMessage { message, _ in - continuation.yield(message) - // Expect only one message - continuation.finish() - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - } try await client.sendStart( payload: GraphQLRequest( @@ -73,35 +66,47 @@ class GraphqlWsTests: XCTestCase { let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages, + #expect( + messages == ["\(ErrorCode.notInitialized): Connection not initialized"] ) } /// Tests that throwing in the authorization callback forces an unauthorized error - func testAuthWithThrow() async throws { + @Test func authWithThrow() async throws { + let api = TestAPI() + let context = TestContext() let server = Server>( messenger: serverMessenger, onInit: { _ in throw TestError.couldBeAnything }, onExecute: { graphQLRequest, _ in - try await self.api.execute( + try await api.execute( request: graphQLRequest.query, - context: self.context + context: context ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await self.api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: self.context + context: context ).get() - self.subscribeReady = true return subscription } ) - let client = Client(messenger: clientMessenger) + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let client = Client( + messenger: clientMessenger, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + }, + onMessage: { message, _ in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + } + ) let serverStream = serverMessenger.stream let clientStream = clientMessenger.stream Task { @@ -110,17 +115,6 @@ class GraphqlWsTests: XCTestCase { Task { try await client.listen(to: serverStream) } - - let messageStream = AsyncThrowingStream { continuation in - client.onMessage { message, _ in - continuation.yield(message) - // Expect only one message - continuation.finish() - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - } try await client.sendConnectionInit( payload: TokenInitPayload( @@ -131,46 +125,39 @@ class GraphqlWsTests: XCTestCase { let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages, + #expect( + messages == ["\(ErrorCode.unauthorized): Unauthorized"] ) } /// Test single op message flow works as expected - func testSingleOp() async throws { + @Test func singleOp() async throws { + let api = TestAPI() + let context = TestContext() + let id = UUID().description + let server = Server>( messenger: serverMessenger, onInit: { _ in }, onExecute: { graphQLRequest, _ in - try await self.api.execute( + try await api.execute( request: graphQLRequest.query, - context: self.context + context: context ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await self.api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: self.context + context: context ).get() - self.subscribeReady = true return subscription } ) - let client = Client(messenger: clientMessenger) - let serverStream = serverMessenger.stream - let clientStream = clientMessenger.stream - Task { - try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) - } - - let id = UUID().description - - let messageStream = AsyncThrowingStream { continuation in - client.onConnectionAck { _, client in + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let client = Client( + messenger: clientMessenger, + onConnectionAck: { _, client in try await client.sendStart( payload: GraphQLRequest( query: """ @@ -181,16 +168,24 @@ class GraphqlWsTests: XCTestCase { ), id: id ) + }, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + }, + onComplete: { _, _ in + messageContinuation.finish() + }, + onMessage: { message, _ in + messageContinuation.yield(message) } - client.onMessage { message, _ in - continuation.yield(message) - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - client.onComplete { _, _ in - continuation.finish() - } + ) + let serverStream = serverMessenger.stream + let clientStream = clientMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) } try await client.sendConnectionInit(payload: TokenInitPayload(authToken: "")) @@ -198,50 +193,43 @@ class GraphqlWsTests: XCTestCase { let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages.count, - 3, // 1 connection_ack, 1 data, 1 complete + #expect( + messages.count == 3, // 1 connection_ack, 1 data, 1 complete "Messages: \(messages.description)" ) } /// Test streaming message flow works as expected - func testStreaming() async throws { + @Test func streaming() async throws { + let api = TestAPI() + let context = TestContext() + let id = UUID().description + var dataIndex = 1 + let dataIndexMax = 3 + + let (subscribeReadyStream, subscribeReadyContinuation) = AsyncStream.makeStream() let server = Server>( messenger: serverMessenger, onInit: { _ in }, onExecute: { graphQLRequest, _ in - try await self.api.execute( + try await api.execute( request: graphQLRequest.query, - context: self.context + context: context ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await self.api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: self.context + context: context ).get() - self.subscribeReady = true + subscribeReadyContinuation.finish() return subscription } ) - let client = Client(messenger: clientMessenger) - let serverStream = serverMessenger.stream - let clientStream = clientMessenger.stream - Task { - try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) - } - - let id = UUID().description - - var dataIndex = 1 - let dataIndexMax = 3 - - let messageStream = AsyncThrowingStream { continuation in - client.onConnectionAck { _, client in + let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let client = Client( + messenger: clientMessenger, + onConnectionAck: { _, client in try await client.sendStart( payload: GraphQLRequest( query: """ @@ -254,34 +242,34 @@ class GraphqlWsTests: XCTestCase { ) // Wait until server has registered subscription - var i = 0 - while !self.subscribeReady, i < 50 { - usleep(1000) - i = i + 1 - } - if i == 50 { - XCTFail("Subscription timeout: Took longer than 50ms to set up") - } - - self.context.publisher.emit(event: "hello \(dataIndex)") - } - client.onData { _, _ in + for await _ in subscribeReadyStream {} + context.publisher.emit(event: "hello \(dataIndex)") + }, + onData: { _, _ in dataIndex = dataIndex + 1 if dataIndex <= dataIndexMax { - self.context.publisher.emit(event: "hello \(dataIndex)") + context.publisher.emit(event: "hello \(dataIndex)") } else { - self.context.publisher.cancel() + context.publisher.cancel() } + }, + onError: { message, _ in + messageContinuation.finish(throwing: message.payload[0]) + }, + onComplete: { _, _ in + messageContinuation.finish() + }, + onMessage: { message, _ in + messageContinuation.yield(message) } - client.onMessage { message, _ in - continuation.yield(message) - } - client.onError { message, _ in - continuation.finish(throwing: message.payload[0]) - } - client.onComplete { _, _ in - continuation.finish() - } + ) + let clientStream = clientMessenger.stream + let serverStream = serverMessenger.stream + Task { + try await server.listen(to: clientStream) + } + Task { + try await client.listen(to: serverStream) } try await client.sendConnectionInit(payload: TokenInitPayload(authToken: "")) @@ -289,9 +277,8 @@ class GraphqlWsTests: XCTestCase { let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) } - XCTAssertEqual( - messages.count, - 5, // 1 connection_ack, 3 data, 1 complete + #expect( + messages.count == 5, // 1 connection_ack, 3 next, 1 complete "Messages: \(messages.description)" ) } From aa1520d5a318e81a2375a8d83876f47725d8318f Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 23:10:47 -0700 Subject: [PATCH 4/8] chore: Formatting --- Sources/GraphQLWS/Client.swift | 50 +++++++++---------- Sources/GraphQLWS/GraphQLWSError.swift | 2 +- Sources/GraphQLWS/Messenger.swift | 3 +- Sources/GraphQLWS/Server.swift | 36 ++++++------- Tests/GraphQLWSTests/GraphQLWSTests.swift | 19 +++---- .../GraphQLWSTests/Utils/TestMessenger.swift | 1 - 6 files changed, 52 insertions(+), 59 deletions(-) diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index 975292c..d73e4fd 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -3,7 +3,7 @@ import GraphQL /// Client is an open-ended implementation of the client side of the protocol. It parses and adds callbacks for each type of server respose. public actor Client { - // We keep this weak because we strongly inject this object into the messenger callback + /// We keep this weak because we strongly inject this object into the messenger callback let messenger: Messenger let onConnectionError: (ConnectionErrorResponse, Client) async throws -> Void @@ -18,7 +18,7 @@ public actor Client { let decoder = JSONDecoder() /// Create a new client. - /// + /// /// - Parameters: /// - messenger: The messenger to bind the client to. /// - onConnectionError: The callback run on receipt of a `connection_error` message @@ -47,12 +47,12 @@ public actor Client { self.onComplete = onComplete self.onMessage = onMessage } - + /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. /// - Parameter incoming: The server message sequence that the client should react to. public func listen(to incoming: A) async throws -> Void where A.Element == String { for try await message in incoming { - try await self.onMessage(message, self) + try await onMessage(message, self) // Detect and ignore error responses. if message.starts(with: "44") { @@ -61,13 +61,13 @@ public actor Client { } guard let json = message.data(using: .utf8) else { - try await self.error(.invalidEncoding()) + try await error(.invalidEncoding()) return } let response: Response do { - response = try self.decoder.decode(Response.self, from: json) + response = try decoder.decode(Response.self, from: json) } catch { try await self.error(.noType()) return @@ -75,43 +75,43 @@ public actor Client { switch response.type { case .GQL_CONNECTION_ERROR: - guard let connectionErrorResponse = try? self.decoder.decode(ConnectionErrorResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) + guard let connectionErrorResponse = try? decoder.decode(ConnectionErrorResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) return } - try await self.onConnectionError(connectionErrorResponse, self) + try await onConnectionError(connectionErrorResponse, self) case .GQL_CONNECTION_ACK: - guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) + guard let connectionAckResponse = try? decoder.decode(ConnectionAckResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) return } - try await self.onConnectionAck(connectionAckResponse, self) + try await onConnectionAck(connectionAckResponse, self) case .GQL_CONNECTION_KEEP_ALIVE: - guard let connectionKeepAliveResponse = try? self.decoder.decode(ConnectionKeepAliveResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE)) + guard let connectionKeepAliveResponse = try? decoder.decode(ConnectionKeepAliveResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE)) return } - try await self.onConnectionKeepAlive(connectionKeepAliveResponse, self) + try await onConnectionKeepAlive(connectionKeepAliveResponse, self) case .GQL_DATA: - guard let nextResponse = try? self.decoder.decode(DataResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .GQL_DATA)) + guard let nextResponse = try? decoder.decode(DataResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .GQL_DATA)) return } - try await self.onData(nextResponse, self) + try await onData(nextResponse, self) case .GQL_ERROR: - guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .GQL_ERROR)) + guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .GQL_ERROR)) return } - try await self.onError(errorResponse, self) + try await onError(errorResponse, self) case .GQL_COMPLETE: - guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else { - try await self.error(.invalidResponseFormat(messageType: .GQL_COMPLETE)) + guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) else { + try await error(.invalidResponseFormat(messageType: .GQL_COMPLETE)) return } - try await self.onComplete(completeResponse, self) + try await onComplete(completeResponse, self) default: - try await self.error(.invalidType()) + try await error(.invalidType()) } } } diff --git a/Sources/GraphQLWS/GraphQLWSError.swift b/Sources/GraphQLWS/GraphQLWSError.swift index 8036e91..52dec61 100644 --- a/Sources/GraphQLWS/GraphQLWSError.swift +++ b/Sources/GraphQLWS/GraphQLWSError.swift @@ -89,7 +89,7 @@ struct GraphQLWSError: Error { /// Error codes for miscellaneous issues public enum ErrorCode: Int, CustomStringConvertible, Sendable { - // Miscellaneous + /// Miscellaneous case miscellaneous = 4400 // Internal errors diff --git a/Sources/GraphQLWS/Messenger.swift b/Sources/GraphQLWS/Messenger.swift index 2b2b113..b0eb8cd 100644 --- a/Sources/GraphQLWS/Messenger.swift +++ b/Sources/GraphQLWS/Messenger.swift @@ -2,10 +2,9 @@ import Foundation /// Protocol for an object that can send messages. This allows mocking in tests public protocol Messenger: Sendable { - /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character + func send(_ message: S) async throws -> Void where S.Element == Character /// Close the messenger func close() async throws diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index 3eb10f6..8b3f204 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -44,7 +44,7 @@ public actor Server< onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, onMessage: @escaping (String) async throws -> Void = { _ in }, onOperationComplete: @escaping (String) async throws -> Void = { _ in }, - onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in }, + onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in } ) { self.messenger = messenger self.onInit = onInit @@ -54,12 +54,12 @@ public actor Server< self.onOperationComplete = onOperationComplete self.onOperationError = onOperationError } - + /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. /// - Parameter incoming: The client message sequence that the server should react to. public func listen(to incoming: A) async throws -> Void where A.Element == String { for try await message in incoming { - try await self.onMessage(message) + try await onMessage(message) // Detect and ignore error responses. if message.starts(with: "44") { @@ -68,13 +68,13 @@ public actor Server< } guard let json = message.data(using: .utf8) else { - try await self.error(.invalidEncoding()) + try await error(.invalidEncoding()) return } let request: Request do { - request = try self.decoder.decode(Request.self, from: json) + request = try decoder.decode(Request.self, from: json) } catch { try await self.error(.noType()) return @@ -83,31 +83,31 @@ public actor Server< // handle incoming message switch request.type { case .GQL_CONNECTION_INIT: - guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT)) + guard let connectionInitRequest = try? decoder.decode(ConnectionInitRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT)) return } - try await self.onConnectionInit(connectionInitRequest, messenger) + try await onConnectionInit(connectionInitRequest, messenger) case .GQL_START: - guard let startRequest = try? self.decoder.decode(StartRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .GQL_START)) + guard let startRequest = try? decoder.decode(StartRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .GQL_START)) return } - try await self.onStart(startRequest, messenger) + try await onStart(startRequest, messenger) case .GQL_STOP: - guard let stopRequest = try? self.decoder.decode(StopRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .GQL_STOP)) + guard let stopRequest = try? decoder.decode(StopRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .GQL_STOP)) return } - try await self.onStop(stopRequest) + try await onStop(stopRequest) case .GQL_CONNECTION_TERMINATE: - guard let connectionTerminateRequest = try? self.decoder.decode(ConnectionTerminateRequest.self, from: json) else { - try await self.error(.invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE)) + guard let connectionTerminateRequest = try? decoder.decode(ConnectionTerminateRequest.self, from: json) else { + try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE)) return } - try await self.onConnectionTerminate(connectionTerminateRequest, messenger) + try await onConnectionTerminate(connectionTerminateRequest, messenger) default: - try await self.error(.invalidType()) + try await error(.invalidType()) } } } diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index 24b12d9..526cac3 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -1,9 +1,7 @@ import Foundation - import GraphQL -import Testing - import GraphQLWS +import Testing @Suite struct GraphqlTransportWSTests { @@ -24,11 +22,10 @@ struct GraphqlTransportWSTests { ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await api.subscribe( + try await api.subscribe( request: graphQLRequest.query, context: context ).get() - return subscription } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() @@ -68,7 +65,7 @@ struct GraphqlTransportWSTests { } #expect( messages == - ["\(ErrorCode.notInitialized): Connection not initialized"] + ["\(ErrorCode.notInitialized): Connection not initialized"] ) } @@ -88,11 +85,10 @@ struct GraphqlTransportWSTests { ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await api.subscribe( + try await api.subscribe( request: graphQLRequest.query, context: context ).get() - return subscription } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() @@ -127,7 +123,7 @@ struct GraphqlTransportWSTests { } #expect( messages == - ["\(ErrorCode.unauthorized): Unauthorized"] + ["\(ErrorCode.unauthorized): Unauthorized"] ) } @@ -147,11 +143,10 @@ struct GraphqlTransportWSTests { ) }, onSubscribe: { graphQLRequest, _ in - let subscription = try await api.subscribe( + try await api.subscribe( request: graphQLRequest.query, context: context ).get() - return subscription } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() @@ -206,7 +201,7 @@ struct GraphqlTransportWSTests { let id = UUID().description var dataIndex = 1 let dataIndexMax = 3 - + let (subscribeReadyStream, subscribeReadyContinuation) = AsyncStream.makeStream() let server = Server>( messenger: serverMessenger, diff --git a/Tests/GraphQLWSTests/Utils/TestMessenger.swift b/Tests/GraphQLWSTests/Utils/TestMessenger.swift index d775715..e6931a9 100644 --- a/Tests/GraphQLWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLWSTests/Utils/TestMessenger.swift @@ -1,6 +1,5 @@ import Foundation - @testable import GraphQLWS /// Messenger for simple testing that doesn't require starting up a websocket server. From 141103ad17fdde56a5d417c470735af10874b7d1 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 23:12:18 -0700 Subject: [PATCH 5/8] docs: Readme updates --- README.md | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index a4dfe11..5915c98 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # GraphQLWS +[![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2FGraphQLSwift%2FGraphQLWS%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/GraphQLSwift/GraphQLWS) +[![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2FGraphQLSwift%2FGraphQLWS%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/GraphQLSwift/GraphQLWS) + This implements the [graphql-ws WebSocket subprotocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md). It is mainly intended for server support, but there is a basic client implementation included. @@ -14,7 +17,7 @@ Features: To use this package, include it in your `Package.swift` dependencies: ```swift -.package(url: "git@gitlab.com:PassiveLogic/platform/GraphQLWS.git", from: ""), +.package(url: "https://github.com/GraphQLSwift/GraphQLWS", from: ""), ``` Then create a class to implement the `Messenger` protocol. Here's an example using @@ -25,12 +28,8 @@ import WebSocketKit import GraphQLWS /// Messenger wrapper for WebSockets -class WebSocketMessenger: Messenger { +struct WebSocketMessenger: Messenger { let websocket: WebSocket - - init(websocket: WebSocket) { - self.websocket = websocket - } func send(_ message: S) async throws where S: Collection, S.Element == Character async throws { try await websocket.send(message) From 919da33ead7196a362caebfc17db3aa6f357273c Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 23:20:01 -0700 Subject: [PATCH 6/8] feat!: Remove `onMessage` callback This is not necessary anymore, since you can just map the incoming AsyncStream --- Sources/GraphQLWS/Client.swift | 8 +-- Sources/GraphQLWS/Server.swift | 6 -- Tests/GraphQLWSTests/GraphQLWSTests.swift | 67 ++++++++++++----------- 3 files changed, 36 insertions(+), 45 deletions(-) diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index d73e4fd..fd77343 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -12,7 +12,6 @@ public actor Client { let onData: (DataResponse, Client) async throws -> Void let onError: (ErrorResponse, Client) async throws -> Void let onComplete: (CompleteResponse, Client) async throws -> Void - let onMessage: (String, Client) async throws -> Void let encoder = GraphQLJSONEncoder() let decoder = JSONDecoder() @@ -27,7 +26,6 @@ public actor Client { /// - onData: The callback run on receipt of a `data` message /// - onError: The callback run on receipt of an `error` message /// - onComplete: The callback run on receipt of a `complete` message - /// - onMessage: The callback run on receipt of any message public init( messenger: Messenger, onConnectionError: @escaping (ConnectionErrorResponse, Client) async throws -> Void = { _, _ in }, @@ -35,8 +33,7 @@ public actor Client { onConnectionKeepAlive: @escaping (ConnectionKeepAliveResponse, Client) async throws -> Void = { _, _ in }, onData: @escaping (DataResponse, Client) async throws -> Void = { _, _ in }, onError: @escaping (ErrorResponse, Client) async throws -> Void = { _, _ in }, - onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in }, - onMessage: @escaping (String, Client) async throws -> Void = { _, _ in } + onComplete: @escaping (CompleteResponse, Client) async throws -> Void = { _, _ in } ) { self.messenger = messenger self.onConnectionError = onConnectionError @@ -45,15 +42,12 @@ public actor Client { self.onData = onData self.onError = onError self.onComplete = onComplete - self.onMessage = onMessage } /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. /// - Parameter incoming: The server message sequence that the client should react to. public func listen(to incoming: A) async throws -> Void where A.Element == String { for try await message in incoming { - try await onMessage(message, self) - // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index 8b3f204..e524df9 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -17,7 +17,6 @@ public actor Server< let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult let onSubscribe: (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType - let onMessage: (String) async throws -> Void let onOperationComplete: (String) async throws -> Void let onOperationError: (String, [Error]) async throws -> Void @@ -34,7 +33,6 @@ public actor Server< /// - messenger: The messenger to bind the server to. /// - onExecute: Callback run during `start` resolution for non-streaming queries. Typically this is `API.execute`. /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. - /// - onMessage: Optional callback run on every message event /// - onOperationComplete: Optional callback run when an operation completes /// - onOperationError: Optional callback run when an operation errors public init( @@ -42,7 +40,6 @@ public actor Server< onInit: @escaping (InitPayload) async throws -> InitPayloadResult, onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult, onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType, - onMessage: @escaping (String) async throws -> Void = { _ in }, onOperationComplete: @escaping (String) async throws -> Void = { _ in }, onOperationError: @escaping (String, [Error]) async throws -> Void = { _, _ in } ) { @@ -50,7 +47,6 @@ public actor Server< self.onInit = onInit self.onExecute = onExecute self.onSubscribe = onSubscribe - self.onMessage = onMessage self.onOperationComplete = onOperationComplete self.onOperationError = onOperationError } @@ -59,8 +55,6 @@ public actor Server< /// - Parameter incoming: The client message sequence that the server should react to. public func listen(to incoming: A) async throws -> Void where A.Element == String { for try await message in incoming { - try await onMessage(message) - // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index 526cac3..22e0a32 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -29,24 +29,23 @@ struct GraphqlTransportWSTests { } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + return message + } let client = Client( messenger: clientMessenger, onError: { message, _ in messageContinuation.finish(throwing: message.payload[0]) - }, - onMessage: { message, _ in - messageContinuation.yield(message) - // Expect only one message - messageContinuation.finish() + await clientMessenger.close() } ) - let serverStream = serverMessenger.stream let clientStream = clientMessenger.stream Task { try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) + await serverMessenger.close() } try await client.sendStart( @@ -59,6 +58,7 @@ struct GraphqlTransportWSTests { ), id: UUID().uuidString ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) @@ -92,24 +92,23 @@ struct GraphqlTransportWSTests { } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + // Expect only one message + messageContinuation.finish() + return message + } let client = Client( messenger: clientMessenger, onError: { message, _ in messageContinuation.finish(throwing: message.payload[0]) - }, - onMessage: { message, _ in - messageContinuation.yield(message) - // Expect only one message - messageContinuation.finish() + await clientMessenger.close() } ) - let serverStream = serverMessenger.stream let clientStream = clientMessenger.stream Task { try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) + await serverMessenger.close() } try await client.sendConnectionInit( @@ -117,6 +116,7 @@ struct GraphqlTransportWSTests { authToken: "" ) ) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) @@ -150,6 +150,10 @@ struct GraphqlTransportWSTests { } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + return message + } let client = Client( messenger: clientMessenger, onConnectionAck: { _, client in @@ -166,24 +170,21 @@ struct GraphqlTransportWSTests { }, onError: { message, _ in messageContinuation.finish(throwing: message.payload[0]) + await clientMessenger.close() }, onComplete: { _, _ in messageContinuation.finish() - }, - onMessage: { message, _ in - messageContinuation.yield(message) + await clientMessenger.close() } ) - let serverStream = serverMessenger.stream let clientStream = clientMessenger.stream Task { try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) + await serverMessenger.close() } try await client.sendConnectionInit(payload: TokenInitPayload(authToken: "")) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) @@ -222,6 +223,11 @@ struct GraphqlTransportWSTests { } ) let (messageStream, messageContinuation) = AsyncThrowingStream.makeStream() + // Used to extract the server messages + let serverMessageStream = serverMessenger.stream.map { message in + messageContinuation.yield(message) + return message + } let client = Client( messenger: clientMessenger, onConnectionAck: { _, client in @@ -250,24 +256,21 @@ struct GraphqlTransportWSTests { }, onError: { message, _ in messageContinuation.finish(throwing: message.payload[0]) + await clientMessenger.close() }, onComplete: { _, _ in messageContinuation.finish() - }, - onMessage: { message, _ in - messageContinuation.yield(message) + await clientMessenger.close() } ) let clientStream = clientMessenger.stream - let serverStream = serverMessenger.stream Task { try await server.listen(to: clientStream) - } - Task { - try await client.listen(to: serverStream) + await serverMessenger.close() } try await client.sendConnectionInit(payload: TokenInitPayload(authToken: "")) + try await client.listen(to: serverMessageStream) let messages = try await messageStream.reduce(into: [String]()) { result, message in result.append(message) From 44b45ec08517f519acf0c6a62879c8d00ea9f2b1 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Wed, 11 Feb 2026 23:21:17 -0700 Subject: [PATCH 7/8] fix!: Avoids closing server on single executions --- Sources/GraphQLWS/Server.swift | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index e524df9..4fab56d 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -127,7 +127,7 @@ public actor Server< // TODO: Should we send the `ka` message? } - private func onStart(_ startRequest: StartRequest, _ messenger: Messenger) async throws { + private func onStart(_ startRequest: StartRequest, _: Messenger) async throws { guard initialized, let initResult else { try await error(.notInitialized()) return @@ -172,7 +172,6 @@ public actor Server< } catch { try await sendError(error, id: id) } - try await messenger.close() } } From 0c07dd744d05f4bc71f5bac4c6db33fe230167e4 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Thu, 12 Feb 2026 10:06:47 -0700 Subject: [PATCH 8/8] docs: Minor corrections --- Sources/GraphQLWS/Client.swift | 1 - Sources/GraphQLWS/Messenger.swift | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index fd77343..a7f993f 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -3,7 +3,6 @@ import GraphQL /// Client is an open-ended implementation of the client side of the protocol. It parses and adds callbacks for each type of server respose. public actor Client { - /// We keep this weak because we strongly inject this object into the messenger callback let messenger: Messenger let onConnectionError: (ConnectionErrorResponse, Client) async throws -> Void diff --git a/Sources/GraphQLWS/Messenger.swift b/Sources/GraphQLWS/Messenger.swift index b0eb8cd..86ca9d4 100644 --- a/Sources/GraphQLWS/Messenger.swift +++ b/Sources/GraphQLWS/Messenger.swift @@ -1,6 +1,6 @@ import Foundation -/// Protocol for an object that can send messages. This allows mocking in tests +/// Protocol for an object that can send messages. public protocol Messenger: Sendable { /// Send a message through this messenger /// - Parameter message: The message to send