Skip to content
Open
11 changes: 11 additions & 0 deletions sdks/csharp/src/SpacetimeDBClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,17 @@ public void Disconnect()
{
webSocket.Close();
}
#if UNITY_WEBGL && !UNITY_EDITOR
else if (webSocket.IsConnecting)
{
webSocket.Abort(); // forceful during connecting
}
#else
else if (webSocket.IsConnecting || webSocket.IsNoneState)
{
webSocket.Abort(); // forceful during connecting
}
#endif

_parseCancellationTokenSource.Cancel();
}
Expand Down
85 changes: 81 additions & 4 deletions sdks/csharp/src/WebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public struct ConnectOptions
private readonly ConcurrentQueue<Action> dispatchQueue = new();

protected ClientWebSocket Ws = new();
private CancellationTokenSource? _connectCts;

public WebSocket(ConnectOptions options)
{
Expand All @@ -60,9 +61,13 @@ public WebSocket(ConnectOptions options)
#if UNITY_WEBGL && !UNITY_EDITOR
private bool _isConnected = false;
private bool _isConnecting = false;
private bool _cancelConnectRequested = false;
public bool IsConnected => _isConnected;
public bool IsConnecting => _isConnecting;
#else
public bool IsConnected { get { return Ws != null && Ws.State == WebSocketState.Open; } }
public bool IsConnecting { get { return Ws != null && Ws.State == WebSocketState.Connecting; } }
public bool IsNoneState { get { return Ws != null && Ws.State == WebSocketState.None; } }
#endif

#if UNITY_WEBGL && !UNITY_EDITOR
Expand Down Expand Up @@ -145,8 +150,9 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne
{
#if UNITY_WEBGL && !UNITY_EDITOR
if (_isConnecting || _isConnected) return;

_isConnecting = true;
_cancelConnectRequested = false;
try
{
var uri = $"{host}/v1/database/{nameOrAddress}/subscribe?connection_id={connectionId}&compression={compression}";
Expand All @@ -161,6 +167,11 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne
dispatchQueue.Enqueue(() => OnConnectError?.Invoke(
new Exception("Failed to connect WebSocket")));
}
else if (_cancelConnectRequested)
{
// If cancel was requested before open, proactively close now.
WebSocket_Close(_webglSocketId, (int)WebSocketCloseStatus.NormalClosure, "Canceled during connect.");
}
}
catch (Exception e)
{
Expand All @@ -180,7 +191,7 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne
var url = new Uri(uri);
Ws.Options.AddSubProtocol(_options.Protocol);

var source = new CancellationTokenSource(10000);
_connectCts = new CancellationTokenSource(10000);
if (!string.IsNullOrEmpty(auth))
{
Ws.Options.SetRequestHeader("Authorization", $"Bearer {auth}");
Expand All @@ -192,7 +203,7 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne

try
{
await Ws.ConnectAsync(url, source.Token);
await Ws.ConnectAsync(url, _connectCts.Token);
if (Ws.State == WebSocketState.Open)
{
if (OnConnect != null)
Expand Down Expand Up @@ -364,14 +375,36 @@ await Ws.CloseAsync(WebSocketCloseStatus.MessageTooBig, closeMessage,
#endif
}

/// <summary>
/// Cancel an in-flight ConnectAsync. Safe to call if no connect is pending.
/// </summary>
public void CancelConnect()
{
#if UNITY_WEBGL && !UNITY_EDITOR
// No CTS on WebGL. Mark cancel intent so that when socket id arrives or open fires,
// we immediately close and avoid reporting a connected state.
_cancelConnectRequested = true;
return;
#else
try { _connectCts?.Cancel(); } catch { /* ignore */ }
#endif
}

public Task Close(WebSocketCloseStatus code = WebSocketCloseStatus.NormalClosure)
{
#if UNITY_WEBGL && !UNITY_EDITOR
if (_isConnected && _webglSocketId >= 0)
if (_webglSocketId >= 0)
{
// If connected or connecting with a valid socket id, request a close.
WebSocket_Close(_webglSocketId, (int)code, "Disconnecting normally.");
_cancelConnectRequested = false; // graceful close intent
_isConnected = false;
}
else if (_isConnecting)
{
// We don't yet have a socket id; remember to cancel once it arrives/opens.
_cancelConnectRequested = true;
}
#else
if (Ws?.State == WebSocketState.Open)
{
Expand All @@ -381,6 +414,35 @@ public Task Close(WebSocketCloseStatus code = WebSocketCloseStatus.NormalClosure
return Task.CompletedTask;
}

/// <summary>
/// Forcefully abort the WebSocket connection. This terminates any in-flight connect/receive/send
/// and ensures the server-side socket is torn down promptly. Prefer Close() for graceful shutdowns.
/// </summary>
public void Abort()
{
#if UNITY_WEBGL && !UNITY_EDITOR
if (_webglSocketId >= 0)
{
WebSocket_Close(_webglSocketId, (int)WebSocketCloseStatus.NormalClosure, "Aborting connection.");
_isConnected = false;
}
else if (_isConnecting)
{
// No socket yet; ensure we close immediately once it opens.
_cancelConnectRequested = true;
}
#else
try
{
Ws?.Abort();
}
catch
{
// Intentionally swallow; Abort is best-effort.
}
#endif
}

private Task? senderTask;
private readonly ConcurrentQueue<ClientMessage> messageSendQueue = new();

Expand Down Expand Up @@ -447,11 +509,21 @@ public WebSocketState GetState()
{
return Ws!.State;
}

#if UNITY_WEBGL && !UNITY_EDITOR
public void HandleWebGLOpen(int socketId)
{
if (socketId == _webglSocketId)
{
if (_cancelConnectRequested)
{
// Immediately close instead of reporting connected.
WebSocket_Close(_webglSocketId, (int)WebSocketCloseStatus.NormalClosure, "Canceled during connect.");
_isConnecting = false;
_isConnected = false;
_cancelConnectRequested = false;
return;
}
_isConnected = true;
if (OnConnect != null)
dispatchQueue.Enqueue(() => OnConnect());
Expand All @@ -472,6 +544,9 @@ public void HandleWebGLClose(int socketId, int code, string reason)
if (socketId == _webglSocketId && OnClose != null)
{
_isConnected = false;
_isConnecting = false;
_webglSocketId = -1;
_cancelConnectRequested = false;
var ex = code != (int)WebSocketCloseStatus.NormalClosure ? new Exception($"WebSocket closed with code {code}: {reason}") : null;
dispatchQueue.Enqueue(() => OnClose?.Invoke(ex));
}
Expand All @@ -482,6 +557,8 @@ public void HandleWebGLError(int socketId)
UnityEngine.Debug.Log($"HandleWebGLError: {socketId}");
if (socketId == _webglSocketId && OnConnectError != null)
{
_isConnecting = false;
_webglSocketId = -1;
dispatchQueue.Enqueue(() => OnConnectError(new Exception($"Socket {socketId} error.")));
}
}
Expand Down
Loading