Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prepare streaming in management sdk. #2129

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public BinaryPayloadContentBuilder(IReadOnlyList<IHubProtocol> hubProtocols)
_hubProtocols = hubProtocols;
}

public HttpContent? Build(PayloadMessage? payload)
public HttpContent? Build(HubMessage? payload)
{
return payload == null ? null : (HttpContent)new BinaryPayloadMessageContent(payload, _hubProtocols);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ internal class BinaryPayloadMessageContent : HttpContent
};
private static readonly MediaTypeHeaderValue ContentType = new("application/octet-stream");

private readonly PayloadMessage _payloadMessage;
private readonly HubMessage _payloadMessage;
private readonly IReadOnlyList<IHubProtocol> _hubProtocols;

public BinaryPayloadMessageContent(PayloadMessage payloadMessage, IReadOnlyList<IHubProtocol> hubProtocols)
public BinaryPayloadMessageContent(HubMessage payloadMessage, IReadOnlyList<IHubProtocol> hubProtocols)
{
_payloadMessage = payloadMessage ?? throw new ArgumentNullException(nameof(payloadMessage));
_hubProtocols = hubProtocols ?? throw new ArgumentNullException(nameof(hubProtocols));
Expand All @@ -49,13 +49,12 @@ protected override bool TryComputeLength(out long length)

private void WriteMessageCore(IBufferWriter<byte> bufferWriter)
{
var invocationMessage = new InvocationMessage(_payloadMessage.Target, _payloadMessage.Arguments);
var messagePackWriter = new MessagePackWriter(bufferWriter);
messagePackWriter.WriteMapHeader(_hubProtocols.Count);
foreach (var hubProtocol in _hubProtocols)
{
messagePackWriter.WriteString(ProtocolMap[hubProtocol.Name]);
messagePackWriter.Write(hubProtocol.GetMessageBytes(invocationMessage).Span);
messagePackWriter.Write(hubProtocol.GetMessageBytes(_payloadMessage).Span);
}
messagePackWriter.Flush();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Net.Http;
using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable

namespace Microsoft.Azure.SignalR.Common
{
internal interface IPayloadContentBuilder
{
HttpContent? Build(PayloadMessage? payload);
HttpContent? Build(HubMessage? payload);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Net.Http;
using Azure.Core.Serialization;
using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
namespace Microsoft.Azure.SignalR.Common
Expand All @@ -16,7 +17,7 @@ public JsonPayloadContentBuilder(ObjectSerializer jsonObjectSerializer)
_jsonObjectSerializer = jsonObjectSerializer;
}

public HttpContent? Build(PayloadMessage? payload)
public HttpContent? Build(HubMessage? payload)
{
return payload == null ? null : new JsonPayloadMessageContent(payload, _jsonObjectSerializer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core.Serialization;
using Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.Azure.SignalR
{
Expand All @@ -22,10 +23,10 @@ internal class JsonPayloadMessageContent : HttpContent
// We must skip validation because what we break the writing midway and write JSON in other ways.
SkipValidation = true
};
private readonly PayloadMessage _payloadMessage;
private readonly HubMessage _payloadMessage;
private readonly ObjectSerializer _jsonObjectSerializer;

public JsonPayloadMessageContent(PayloadMessage payloadMessage, ObjectSerializer jsonObjectSerializer)
public JsonPayloadMessageContent(HubMessage payloadMessage, ObjectSerializer jsonObjectSerializer)
{
_payloadMessage = payloadMessage ?? throw new System.ArgumentNullException(nameof(payloadMessage));
_jsonObjectSerializer = jsonObjectSerializer;
Expand All @@ -34,14 +35,21 @@ public JsonPayloadMessageContent(PayloadMessage payloadMessage, ObjectSerializer

protected override async Task SerializeToStreamAsync(Stream stream, TransportContext context)
{
using var jsonWriter = new Utf8JsonWriter(stream, JsonWriterOptions);
jsonWriter.WriteStartObject();
jsonWriter.WriteString(nameof(PayloadMessage.Target), _payloadMessage.Target);
jsonWriter.WritePropertyName(nameof(PayloadMessage.Arguments));
await jsonWriter.FlushAsync();
await _jsonObjectSerializer.SerializeAsync(stream, _payloadMessage.Arguments, typeof(object[]), default);
jsonWriter.WriteEndObject();
await jsonWriter.FlushAsync();
if (_payloadMessage is InvocationMessage invocationMessage)
{
using var jsonWriter = new Utf8JsonWriter(stream, JsonWriterOptions);
jsonWriter.WriteStartObject();
jsonWriter.WriteString(nameof(PayloadMessage.Target), invocationMessage.Target);
jsonWriter.WritePropertyName(nameof(PayloadMessage.Arguments));
await jsonWriter.FlushAsync();
await _jsonObjectSerializer.SerializeAsync(stream, invocationMessage.Arguments, typeof(object[]), default);
jsonWriter.WriteEndObject();
await jsonWriter.FlushAsync();
}
else if (_payloadMessage is StreamItemMessage streamItemMessage)
{
await _jsonObjectSerializer.SerializeAsync(stream, streamItemMessage.Item, streamItemMessage.Item?.GetType() ?? typeof(object), default);
}
}

protected override bool TryComputeLength(out long length)
Expand Down
64 changes: 37 additions & 27 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Threading;
using System.Threading.Tasks;
using Azure.Core.Serialization;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Extensions.Primitives;

Expand Down Expand Up @@ -37,47 +38,54 @@ public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder c
public Task SendAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string? methodName = null,
object[]? args = null,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return handleExpectedResponse == null
? SendAsync(api, httpMethod, methodName, args, handleExpectedResponseAsync: null, cancellationToken)
: SendAsync(api, httpMethod, methodName, args, response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
}
CancellationToken cancellationToken = default) =>
vwxyzh marked this conversation as resolved.
Show resolved Hide resolved
SendAsync(api, httpMethod, (Func<HttpResponseMessage, Task<bool>>?)null, cancellationToken);

public Task SendAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string? methodName = null,
object[]? args = null,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
Func<HttpResponseMessage, bool>? handleExpectedResponse,
CancellationToken cancellationToken = default) =>
SendAsync(api, httpMethod, AsAsync(handleExpectedResponse), cancellationToken);

public Task SendAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, methodName, args, handleExpectedResponseAsync, cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, null, handleExpectedResponseAsync, cancellationToken);
}

public Task SendWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string? methodName = null,
object[]? args = null,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, null, AsAsync(handleExpectedResponse), cancellationToken);
}

public Task SendMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string? methodName = null,
object[]? args = null,
string methodName,
object?[] args,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), AsAsync(handleExpectedResponse), cancellationToken);
}

public Task SendStreamMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string streamId,
object? arg = null,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), AsAsync(handleExpectedResponse), cancellationToken);
}

private static Uri GetUri(string url, IDictionary<string, StringValues>? query)
Expand Down Expand Up @@ -142,13 +150,12 @@ private async Task SendAsyncCore(
string httpClientName,
RestApiEndpoint api,
HttpMethod httpMethod,
string? methodName = null,
object[]? args = null,
HubMessage? body,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
using var httpClient = _httpClientFactory.CreateClient(httpClientName);
using var request = BuildRequest(api, httpMethod, methodName, args);
using var request = BuildRequest(api, httpMethod, body);

try
{
Expand All @@ -171,17 +178,20 @@ private async Task SendAsyncCore(
}
}

private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string? methodName = null, object[]? args = null)
private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, HubMessage? body)
{
var payload = httpMethod == HttpMethod.Post ? new PayloadMessage { Target = methodName, Arguments = args } : null;
return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token);
var payload = httpMethod == HttpMethod.Post ? body : null;
return GenerateHttpRequest(api.Audience, api.Query, httpMethod, body, api.Token);
}

private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, StringValues> query, HttpMethod httpMethod, PayloadMessage? payload, string tokenString)
private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, StringValues> query, HttpMethod httpMethod, HubMessage? body, string tokenString)
{
var request = new HttpRequestMessage(httpMethod, GetUri(url, query));
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", tokenString);
request.Content = _payloadContentBuilder.Build(payload);
request.Content = _payloadContentBuilder.Build(body);
return request;
}

private static Func<HttpResponseMessage, Task<bool>>? AsAsync(Func<HttpResponseMessage, bool>? syncFunc) =>
syncFunc == null ? null : (response => Task.FromResult(syncFunc(response)));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using System.Threading.Channels;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR.Management
{
/// <summary>
/// A streaming manager abstraction for sending stream response.
/// </summary>
internal interface IStreamingManager
{
Task SendStream<TItem>(string connectionId, string streamId, IAsyncEnumerable<TItem> items);
vwxyzh marked this conversation as resolved.
Show resolved Hide resolved
Task SendStream<TItem>(string connectionId, string streamId, ChannelReader<TItem> items);
void CancelStream(string connectionId, string streamId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using System.Threading.Channels;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR.Management.HubContext
{
internal abstract class StreamingManager : IStreamingManager
{
public abstract void CancelStream(string connectionId, string streamId);

public abstract Task SendStream<TItem>(string connectionId, string streamId, IAsyncEnumerable<TItem> items);

public abstract Task SendStream<TItem>(string connectionId, string streamId, ChannelReader<TItem> items);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Management.HubContext;

namespace Microsoft.Azure.SignalR.Management
{
internal class StreamingManagerAdapter : StreamingManager
{
private readonly ConcurrentDictionary<(string connectionId, string streamId), CancellationTokenSource> _cancellations = new();
private readonly IStreamingHubLifetimeManager _lifetimeManager;

public StreamingManagerAdapter(IStreamingHubLifetimeManager lifetimeManager)
{
_lifetimeManager = lifetimeManager;
}

public override void CancelStream(string connectionId, string streamId)
{
if (_cancellations.TryRemove((connectionId, streamId), out var cts))
{
cts.Cancel();
}
}

public override async Task SendStream<TItem>(string connectionId, string streamId, IAsyncEnumerable<TItem> items)
{
var source = new CancellationTokenSource();
if (!_cancellations.TryAdd((connectionId, streamId), source))
{
throw new InvalidOperationException("Cannot send a stream twice.");
}
try
{
await foreach (var item in items.WithCancellation(source.Token))
{
await _lifetimeManager.SendStreamItemAsync(connectionId, streamId, item);
}
await _lifetimeManager.SendStreamCompletionAsync(connectionId, streamId, null);
vwxyzh marked this conversation as resolved.
Show resolved Hide resolved
}
catch (OperationCanceledException) when (source.Token.IsCancellationRequested)
{
// do not send anything if the stream is cancelled.
}
catch (Exception ex)
{
await _lifetimeManager.SendStreamCompletionAsync(connectionId, streamId, ex.Message);
Copy link
Member

@vicancy vicancy Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this ex.message exposed to the end user? Is it expected?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good question.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
}

public override async Task SendStream<TItem>(string connectionId, string streamId, ChannelReader<TItem> channelReader)
{
var source = new CancellationTokenSource();
if (!_cancellations.TryAdd((connectionId, streamId), source))
{
throw new InvalidOperationException("Cannot send a stream twice.");
}
try
{
while (await channelReader.WaitToReadAsync(source.Token))
{
while (channelReader.TryRead(out var item))
{
await _lifetimeManager.SendStreamItemAsync(connectionId, streamId, item);
}
}
await _lifetimeManager.SendStreamCompletionAsync(connectionId, streamId, null);
}
catch (OperationCanceledException) when (source.Token.IsCancellationRequested)
{
// do not send anything if the stream is cancelled.
}
catch (Exception ex)
{
await _lifetimeManager.SendStreamCompletionAsync(connectionId, streamId, ex.Message);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace Microsoft.Azure.SignalR.Management
{
internal interface IServiceHubLifetimeManager : IHubLifetimeManager, IUserGroupHubLifetimeManager
internal interface IServiceHubLifetimeManager : IHubLifetimeManager, IUserGroupHubLifetimeManager, IStreamingHubLifetimeManager
{
Task CloseConnectionAsync(string connectionId, string reason, CancellationToken cancellationToken);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Azure.SignalR.Management
{
internal interface IStreamingHubLifetimeManager
{
Task SendStreamItemAsync<TItem>(string connectionId, string streamId, TItem item, CancellationToken cancellationToken = default);

Task SendStreamCompletionAsync(string connectionId, string streamId, string error, CancellationToken cancellationToken = default);
}
}
Loading
Loading