Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/Grpc.AspNetCore.Server/Grpc.AspNetCore.Server.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>

<ItemGroup>
<Compile Include="..\Shared\DefaultDeserializationContext.cs" Link="Internal\DefaultDeserializationContext.cs" />
<Compile Include="..\Shared\DefaultSerializationContext.cs" Link="Internal\DefaultSerializationContext.cs" />
</ItemGroup>

<ItemGroup>
<FrameworkReference Include="Microsoft.AspNetCore.App" />

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
service = activator.Create();
response = await _invoker(
service,
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.Deserializer),
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.ContextualDeserializer),
serverCallContext);
}
finally
Expand All @@ -101,7 +101,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
else
{
response = await _pipelineInvoker(
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.Deserializer),
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.ContextualDeserializer),
serverCallContext);
}

Expand All @@ -112,7 +112,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
}

var responseBodyWriter = httpContext.Response.BodyWriter;
await responseBodyWriter.WriteMessageAsync(response, serverCallContext, Method.ResponseMarshaller.Serializer, canFlush: false);
await responseBodyWriter.WriteMessageAsync(response, serverCallContext, Method.ResponseMarshaller.ContextualSerializer, canFlush: false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
service = activator.Create();
await _invoker(
service,
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.Deserializer),
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.Serializer),
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.ContextualDeserializer),
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.ContextualSerializer),
serverCallContext);
}
finally
Expand All @@ -101,8 +101,8 @@ await _invoker(
else
{
await _pipelineInvoker(
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.Deserializer),
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.Serializer),
new HttpContextStreamReader<TRequest>(serverCallContext, Method.RequestMarshaller.ContextualDeserializer),
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.ContextualSerializer),
serverCallContext);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
{
// Decode request
var requestPayload = await httpContext.Request.BodyReader.ReadSingleMessageAsync(serverCallContext);
var request = Method.RequestMarshaller.Deserializer(requestPayload);

serverCallContext.DeserializationContext.SetPayload(requestPayload);
var request = Method.RequestMarshaller.ContextualDeserializer(serverCallContext.DeserializationContext);
serverCallContext.DeserializationContext.SetPayload(null);

if (_pipelineInvoker == null)
{
Expand All @@ -84,7 +87,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
await _invoker(
service,
request,
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.Serializer),
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.ContextualSerializer),
serverCallContext);
}
finally
Expand All @@ -99,7 +102,7 @@ await _invoker(
{
await _pipelineInvoker(
request,
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.Serializer),
new HttpContextStreamWriter<TResponse>(serverCallContext, Method.ResponseMarshaller.ContextualSerializer),
serverCallContext);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ public UnaryServerCallHandler(
protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext)
{
var requestPayload = await httpContext.Request.BodyReader.ReadSingleMessageAsync(serverCallContext);
var request = Method.RequestMarshaller.Deserializer(requestPayload);

serverCallContext.DeserializationContext.SetPayload(requestPayload);
var request = Method.RequestMarshaller.ContextualDeserializer(serverCallContext.DeserializationContext);
serverCallContext.DeserializationContext.SetPayload(null);

TResponse? response = null;

Expand Down Expand Up @@ -104,7 +107,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
}

var responseBodyWriter = httpContext.Response.BodyWriter;
await responseBodyWriter.WriteMessageAsync(response, serverCallContext, Method.ResponseMarshaller.Serializer, canFlush: false);
await responseBodyWriter.WriteMessageAsync(response, serverCallContext, Method.ResponseMarshaller.ContextualSerializer, canFlush: false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
using System.Threading.Tasks;
using Grpc.AspNetCore.Server.Features;
using Grpc.Core;
using Grpc.Shared;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.Logging;
Expand All @@ -42,6 +43,8 @@ internal sealed partial class HttpContextServerCallContext : ServerCallContext,
private AuthContext? _authContext;
// Internal for tests
internal ServerCallDeadlineManager? DeadlineManager;
private DefaultSerializationContext? _serializationContext;
private DefaultDeserializationContext? _deserializationContext;

internal HttpContextServerCallContext(HttpContext httpContext, GrpcServiceOptions serviceOptions, ILogger logger)
{
Expand All @@ -54,6 +57,15 @@ internal HttpContextServerCallContext(HttpContext httpContext, GrpcServiceOption
internal GrpcServiceOptions ServiceOptions { get; }
internal string? ResponseGrpcEncoding { get; private set; }

internal DefaultSerializationContext SerializationContext
{
get => _serializationContext ??= new DefaultSerializationContext();
}
internal DefaultDeserializationContext DeserializationContext
{
get => _deserializationContext ??= new DefaultDeserializationContext();
}

internal bool HasResponseTrailers => _responseTrailers != null;

protected override string? MethodCore => HttpContext.Request.Path.Value;
Expand Down
10 changes: 7 additions & 3 deletions src/Grpc.AspNetCore.Server/Internal/HttpContextStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ internal class HttpContextStreamReader<TRequest> : IAsyncStreamReader<TRequest>
private static readonly Task<bool> False = Task.FromResult(false);

private readonly HttpContextServerCallContext _serverCallContext;
private readonly Func<byte[], TRequest> _deserializer;
private readonly Func<DeserializationContext, TRequest> _deserializer;

public HttpContextStreamReader(HttpContextServerCallContext serverCallContext, Func<byte[], TRequest> deserializer)
public HttpContextStreamReader(HttpContextServerCallContext serverCallContext, Func<DeserializationContext, TRequest> deserializer)
{
_serverCallContext = serverCallContext;
_deserializer = deserializer;
Expand Down Expand Up @@ -75,7 +75,11 @@ private bool ProcessPayload(byte[]? requestPayload)
return false;
}

Current = _deserializer(requestPayload);
var context = _serverCallContext.DeserializationContext;
context.SetPayload(requestPayload);
Current = _deserializer(context);
context.SetPayload(null);

return true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ namespace Grpc.AspNetCore.Server.Internal
internal class HttpContextStreamWriter<TResponse> : IServerStreamWriter<TResponse>
{
private readonly HttpContextServerCallContext _context;
private readonly Func<TResponse, byte[]> _serializer;
private readonly Action<TResponse, SerializationContext> _serializer;

public HttpContextStreamWriter(HttpContextServerCallContext context, Func<TResponse, byte[]> serializer)
public HttpContextStreamWriter(HttpContextServerCallContext context, Action<TResponse, SerializationContext> serializer)
{
_context = context;
_serializer = serializer;
Expand Down
12 changes: 10 additions & 2 deletions src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,17 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
return new Status(StatusCode.Unimplemented, $"Unsupported grpc-encoding value '{unsupportedEncoding}'. Supported encodings: {string.Join(", ", supportedEncodings)}");
}

public static Task WriteMessageAsync<TResponse>(this PipeWriter pipeWriter, TResponse response, HttpContextServerCallContext serverCallContext, Func<TResponse, byte[]> serializer, bool canFlush)
public static Task WriteMessageAsync<TResponse>(this PipeWriter pipeWriter, TResponse response, HttpContextServerCallContext serverCallContext, Action<TResponse, SerializationContext> serializer, bool canFlush)
{
var responsePayload = serializer(response);
var serializationContext = serverCallContext.SerializationContext;
serializer(response, serializationContext);
var responsePayload = serializationContext.Payload;
serializationContext.Payload = null;

if (responsePayload == null)
{
return Task.FromException(new InvalidOperationException("Serialization did not return a payload."));
}

// Flush messages unless WriteOptions.Flags has BufferHint set
var flush = canFlush && ((serverCallContext.WriteOptions?.Flags ?? default) & WriteFlags.BufferHint) != WriteFlags.BufferHint;
Expand Down
5 changes: 5 additions & 0 deletions src/Grpc.Net.Client/Grpc.Net.Client.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,9 @@
<Content Include="build\**\*.targets" PackagePath="%(Identity)" />
</ItemGroup>

<ItemGroup>
<Compile Include="..\Shared\DefaultDeserializationContext.cs" Link="Internal\DefaultDeserializationContext.cs" />
<Compile Include="..\Shared\DefaultSerializationContext.cs" Link="Internal\DefaultSerializationContext.cs" />
</ItemGroup>

</Project>
4 changes: 2 additions & 2 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ public async Task<TResponse> GetResponseAsync()
var responseStream = await HttpResponse.Content.ReadAsStreamAsync().ConfigureAwait(false);
var message = await responseStream.ReadSingleMessageAsync(
Logger,
Method.ResponseMarshaller.Deserializer,
Method.ResponseMarshaller.ContextualDeserializer,
GrpcProtocolHelpers.GetGrpcEncoding(HttpResponse),
_callCts.Token).ConfigureAwait(false);
FinishResponse();
Expand Down Expand Up @@ -382,7 +382,7 @@ private void SetMessageContent(TRequest request, HttpRequestMessage message)
return stream.WriteMessage<TRequest>(
Logger,
request,
Method.RequestMarshaller.Serializer,
Method.RequestMarshaller.ContextualSerializer,
grpcEncoding,
Options.CancellationToken);
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)

Current = await _responseStream.ReadStreamedMessageAsync(
_call.Logger,
_call.Method.ResponseMarshaller.Deserializer,
_call.Method.ResponseMarshaller.ContextualDeserializer,
GrpcProtocolHelpers.GetGrpcEncoding(_httpResponse),
cancellationToken).ConfigureAwait(false);
if (Current == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private async Task WriteAsyncCore(TRequest message)
await writeStream.WriteMessage<TRequest>(
_call.Logger,
message,
_call.Method.RequestMarshaller.Serializer,
_call.Method.RequestMarshaller.ContextualSerializer,
_grpcEncoding,
_call.CancellationToken).ConfigureAwait(false);
}
Expand Down
26 changes: 20 additions & 6 deletions src/Grpc.Net.Client/Internal/StreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
using System.Threading.Tasks;
using Grpc.Core;
using Grpc.Net.Client.Internal;
using Grpc.Shared;
using Microsoft.Extensions.Logging;
using CompressionLevel = System.IO.Compression.CompressionLevel;

Expand All @@ -46,7 +47,7 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
public static Task<TResponse?> ReadSingleMessageAsync<TResponse>(
this Stream responseStream,
ILogger logger,
Func<byte[], TResponse> deserializer,
Func<DeserializationContext, TResponse> deserializer,
string grpcEncoding,
CancellationToken cancellationToken)
where TResponse : class
Expand All @@ -57,7 +58,7 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
public static Task<TResponse?> ReadStreamedMessageAsync<TResponse>(
this Stream responseStream,
ILogger logger,
Func<byte[], TResponse> deserializer,
Func<DeserializationContext, TResponse> deserializer,
string grpcEncoding,
CancellationToken cancellationToken)
where TResponse : class
Expand All @@ -68,7 +69,7 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
private static async Task<TResponse?> ReadMessageCoreAsync<TResponse>(
this Stream responseStream,
ILogger logger,
Func<byte[], TResponse> deserializer,
Func<DeserializationContext, TResponse> deserializer,
string grpcEncoding,
CancellationToken cancellationToken,
bool canBeEmpty,
Expand Down Expand Up @@ -156,12 +157,18 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
throw new RpcException(CreateUnknownMessageEncodingMessageStatus(grpcEncoding, GrpcProtocolConstants.CompressionProviders.Select(c => c.EncodingName)));
}

#pragma warning disable CS8600 // Converting null literal or possible null value to non-nullable type.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nullable attribute isn't working here for some reason. Will clean it up when bug in compiler is fixed. Only effects designer.

messageData = decompressedMessage;
#pragma warning restore CS8600 // Converting null literal or possible null value to non-nullable type.
}

#pragma warning disable CS8602 // Dereference of a possibly null reference.
Log.DeserializingMessage(logger, messageData.Length, typeof(TResponse));
#pragma warning restore CS8602 // Dereference of a possibly null reference.

var message = deserializer(messageData);
var deserializationContext = new DefaultDeserializationContext();
deserializationContext.SetPayload(messageData);
var message = deserializer(deserializationContext);

if (singleMessage)
{
Expand Down Expand Up @@ -203,7 +210,7 @@ public static async Task WriteMessage<TMessage>(
this Stream stream,
ILogger logger,
TMessage message,
Func<TMessage, byte[]> serializer,
Action<TMessage, SerializationContext> serializer,
string grpcEncoding,
CancellationToken cancellationToken)
{
Expand All @@ -212,7 +219,14 @@ public static async Task WriteMessage<TMessage>(
Log.SendingMessage(logger);

// Serialize message first. Need to know size to prefix the length in the header
var data = serializer(message);
var serializationContext = new DefaultSerializationContext();
serializer(message, serializationContext);
var data = serializationContext.Payload;

if (data == null)
{
throw new InvalidOperationException("Serialization did not return a payload.");
}

var isCompressed = !string.Equals(grpcEncoding, GrpcProtocolConstants.IdentityGrpcEncoding, StringComparison.Ordinal);

Expand Down
48 changes: 48 additions & 0 deletions src/Shared/DefaultDeserializationContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using System.Buffers;
using System.Diagnostics;
using Grpc.Core;

namespace Grpc.Shared
{
internal sealed class DefaultDeserializationContext : DeserializationContext
{
private byte[]? _payload;

public void SetPayload(byte[]? payload)
{
_payload = payload;
}

public override byte[] PayloadAsNewBuffer()
{
Debug.Assert(_payload != null, "Payload must be set.");
return _payload;
}

public override ReadOnlySequence<byte> PayloadAsReadOnlySequence()
{
Debug.Assert(_payload != null, "Payload must be set.");
return new ReadOnlySequence<byte>(_payload);
}

public override int PayloadLength => _payload?.Length ?? 0;
}
}
Loading