diff --git a/src/Grpc.AspNetCore.Server/Grpc.AspNetCore.Server.csproj b/src/Grpc.AspNetCore.Server/Grpc.AspNetCore.Server.csproj index 9dc5448ac..5d0a10901 100644 --- a/src/Grpc.AspNetCore.Server/Grpc.AspNetCore.Server.csproj +++ b/src/Grpc.AspNetCore.Server/Grpc.AspNetCore.Server.csproj @@ -23,6 +23,11 @@ true + + + + + diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ClientStreamingServerCallHandler.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ClientStreamingServerCallHandler.cs index 55324eb61..917d3e2cd 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ClientStreamingServerCallHandler.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ClientStreamingServerCallHandler.cs @@ -87,7 +87,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC service = activator.Create(); response = await _invoker( service, - new HttpContextStreamReader(serverCallContext, Method.RequestMarshaller.Deserializer), + new HttpContextStreamReader(serverCallContext, Method.RequestMarshaller.ContextualDeserializer), serverCallContext); } finally @@ -101,7 +101,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC else { response = await _pipelineInvoker( - new HttpContextStreamReader(serverCallContext, Method.RequestMarshaller.Deserializer), + new HttpContextStreamReader(serverCallContext, Method.RequestMarshaller.ContextualDeserializer), serverCallContext); } @@ -112,7 +112,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC } var responseBodyWriter = httpContext.Response.BodyWriter; - await responseBodyWriter.WriteMessageAsync(response, serverCallContext, Method.ResponseMarshaller.Serializer, canFlush: false); + await responseBodyWriter.WriteMessageAsync(response, serverCallContext, Method.ResponseMarshaller.ContextualSerializer, canFlush: false); } } } diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/DuplexStreamingServerCallHandler.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/DuplexStreamingServerCallHandler.cs index 1536d9a8c..0b6e5635d 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/DuplexStreamingServerCallHandler.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/DuplexStreamingServerCallHandler.cs @@ -86,8 +86,8 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC service = activator.Create(); await _invoker( service, - new HttpContextStreamReader(serverCallContext, Method.RequestMarshaller.Deserializer), - new HttpContextStreamWriter(serverCallContext, Method.ResponseMarshaller.Serializer), + new HttpContextStreamReader(serverCallContext, Method.RequestMarshaller.ContextualDeserializer), + new HttpContextStreamWriter(serverCallContext, Method.ResponseMarshaller.ContextualSerializer), serverCallContext); } finally @@ -101,8 +101,8 @@ await _invoker( else { await _pipelineInvoker( - new HttpContextStreamReader(serverCallContext, Method.RequestMarshaller.Deserializer), - new HttpContextStreamWriter(serverCallContext, Method.ResponseMarshaller.Serializer), + new HttpContextStreamReader(serverCallContext, Method.RequestMarshaller.ContextualDeserializer), + new HttpContextStreamWriter(serverCallContext, Method.ResponseMarshaller.ContextualSerializer), serverCallContext); } } diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerStreamingServerCallHandler.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerStreamingServerCallHandler.cs index c9a5b64a2..c4cf7473a 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerStreamingServerCallHandler.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerStreamingServerCallHandler.cs @@ -72,7 +72,10 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC { // Decode request var requestPayload = await httpContext.Request.BodyReader.ReadSingleMessageAsync(serverCallContext); - var request = Method.RequestMarshaller.Deserializer(requestPayload); + + serverCallContext.DeserializationContext.SetPayload(requestPayload); + var request = Method.RequestMarshaller.ContextualDeserializer(serverCallContext.DeserializationContext); + serverCallContext.DeserializationContext.SetPayload(null); if (_pipelineInvoker == null) { @@ -84,7 +87,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC await _invoker( service, request, - new HttpContextStreamWriter(serverCallContext, Method.ResponseMarshaller.Serializer), + new HttpContextStreamWriter(serverCallContext, Method.ResponseMarshaller.ContextualSerializer), serverCallContext); } finally @@ -99,7 +102,7 @@ await _invoker( { await _pipelineInvoker( request, - new HttpContextStreamWriter(serverCallContext, Method.ResponseMarshaller.Serializer), + new HttpContextStreamWriter(serverCallContext, Method.ResponseMarshaller.ContextualSerializer), serverCallContext); } } diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/UnaryServerCallHandler.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/UnaryServerCallHandler.cs index b0d74d398..adf3e6e4a 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/UnaryServerCallHandler.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/UnaryServerCallHandler.cs @@ -71,7 +71,10 @@ public UnaryServerCallHandler( protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext) { var requestPayload = await httpContext.Request.BodyReader.ReadSingleMessageAsync(serverCallContext); - var request = Method.RequestMarshaller.Deserializer(requestPayload); + + serverCallContext.DeserializationContext.SetPayload(requestPayload); + var request = Method.RequestMarshaller.ContextualDeserializer(serverCallContext.DeserializationContext); + serverCallContext.DeserializationContext.SetPayload(null); TResponse? response = null; @@ -104,7 +107,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC } var responseBodyWriter = httpContext.Response.BodyWriter; - await responseBodyWriter.WriteMessageAsync(response, serverCallContext, Method.ResponseMarshaller.Serializer, canFlush: false); + await responseBodyWriter.WriteMessageAsync(response, serverCallContext, Method.ResponseMarshaller.ContextualSerializer, canFlush: false); } } } diff --git a/src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs b/src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs index 7a844656d..eddd7a10c 100644 --- a/src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs +++ b/src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs @@ -24,6 +24,7 @@ using System.Threading.Tasks; using Grpc.AspNetCore.Server.Features; using Grpc.Core; +using Grpc.Shared; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.Extensions.Logging; @@ -42,6 +43,8 @@ internal sealed partial class HttpContextServerCallContext : ServerCallContext, private AuthContext? _authContext; // Internal for tests internal ServerCallDeadlineManager? DeadlineManager; + private DefaultSerializationContext? _serializationContext; + private DefaultDeserializationContext? _deserializationContext; internal HttpContextServerCallContext(HttpContext httpContext, GrpcServiceOptions serviceOptions, ILogger logger) { @@ -54,6 +57,15 @@ internal HttpContextServerCallContext(HttpContext httpContext, GrpcServiceOption internal GrpcServiceOptions ServiceOptions { get; } internal string? ResponseGrpcEncoding { get; private set; } + internal DefaultSerializationContext SerializationContext + { + get => _serializationContext ??= new DefaultSerializationContext(); + } + internal DefaultDeserializationContext DeserializationContext + { + get => _deserializationContext ??= new DefaultDeserializationContext(); + } + internal bool HasResponseTrailers => _responseTrailers != null; protected override string? MethodCore => HttpContext.Request.Path.Value; diff --git a/src/Grpc.AspNetCore.Server/Internal/HttpContextStreamReader.cs b/src/Grpc.AspNetCore.Server/Internal/HttpContextStreamReader.cs index 12c2a8597..42151b54a 100644 --- a/src/Grpc.AspNetCore.Server/Internal/HttpContextStreamReader.cs +++ b/src/Grpc.AspNetCore.Server/Internal/HttpContextStreamReader.cs @@ -29,9 +29,9 @@ internal class HttpContextStreamReader : IAsyncStreamReader private static readonly Task False = Task.FromResult(false); private readonly HttpContextServerCallContext _serverCallContext; - private readonly Func _deserializer; + private readonly Func _deserializer; - public HttpContextStreamReader(HttpContextServerCallContext serverCallContext, Func deserializer) + public HttpContextStreamReader(HttpContextServerCallContext serverCallContext, Func deserializer) { _serverCallContext = serverCallContext; _deserializer = deserializer; @@ -75,7 +75,11 @@ private bool ProcessPayload(byte[]? requestPayload) return false; } - Current = _deserializer(requestPayload); + var context = _serverCallContext.DeserializationContext; + context.SetPayload(requestPayload); + Current = _deserializer(context); + context.SetPayload(null); + return true; } } diff --git a/src/Grpc.AspNetCore.Server/Internal/HttpContextStreamWriter.cs b/src/Grpc.AspNetCore.Server/Internal/HttpContextStreamWriter.cs index 842d9f1bb..0b5407ddb 100644 --- a/src/Grpc.AspNetCore.Server/Internal/HttpContextStreamWriter.cs +++ b/src/Grpc.AspNetCore.Server/Internal/HttpContextStreamWriter.cs @@ -25,9 +25,9 @@ namespace Grpc.AspNetCore.Server.Internal internal class HttpContextStreamWriter : IServerStreamWriter { private readonly HttpContextServerCallContext _context; - private readonly Func _serializer; + private readonly Action _serializer; - public HttpContextStreamWriter(HttpContextServerCallContext context, Func serializer) + public HttpContextStreamWriter(HttpContextServerCallContext context, Action serializer) { _context = context; _serializer = serializer; diff --git a/src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs b/src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs index 8cc6a7418..028f5d4c0 100644 --- a/src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs +++ b/src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs @@ -48,9 +48,17 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport return new Status(StatusCode.Unimplemented, $"Unsupported grpc-encoding value '{unsupportedEncoding}'. Supported encodings: {string.Join(", ", supportedEncodings)}"); } - public static Task WriteMessageAsync(this PipeWriter pipeWriter, TResponse response, HttpContextServerCallContext serverCallContext, Func serializer, bool canFlush) + public static Task WriteMessageAsync(this PipeWriter pipeWriter, TResponse response, HttpContextServerCallContext serverCallContext, Action serializer, bool canFlush) { - var responsePayload = serializer(response); + var serializationContext = serverCallContext.SerializationContext; + serializer(response, serializationContext); + var responsePayload = serializationContext.Payload; + serializationContext.Payload = null; + + if (responsePayload == null) + { + return Task.FromException(new InvalidOperationException("Serialization did not return a payload.")); + } // Flush messages unless WriteOptions.Flags has BufferHint set var flush = canFlush && ((serverCallContext.WriteOptions?.Flags ?? default) & WriteFlags.BufferHint) != WriteFlags.BufferHint; diff --git a/src/Grpc.Net.Client/Grpc.Net.Client.csproj b/src/Grpc.Net.Client/Grpc.Net.Client.csproj index d8e024875..fc8dbaf90 100644 --- a/src/Grpc.Net.Client/Grpc.Net.Client.csproj +++ b/src/Grpc.Net.Client/Grpc.Net.Client.csproj @@ -34,4 +34,9 @@ + + + + + diff --git a/src/Grpc.Net.Client/Internal/GrpcCall.cs b/src/Grpc.Net.Client/Internal/GrpcCall.cs index 57215fb96..8c5ee9e0a 100644 --- a/src/Grpc.Net.Client/Internal/GrpcCall.cs +++ b/src/Grpc.Net.Client/Internal/GrpcCall.cs @@ -300,7 +300,7 @@ public async Task GetResponseAsync() var responseStream = await HttpResponse.Content.ReadAsStreamAsync().ConfigureAwait(false); var message = await responseStream.ReadSingleMessageAsync( Logger, - Method.ResponseMarshaller.Deserializer, + Method.ResponseMarshaller.ContextualDeserializer, GrpcProtocolHelpers.GetGrpcEncoding(HttpResponse), _callCts.Token).ConfigureAwait(false); FinishResponse(); @@ -382,7 +382,7 @@ private void SetMessageContent(TRequest request, HttpRequestMessage message) return stream.WriteMessage( Logger, request, - Method.RequestMarshaller.Serializer, + Method.RequestMarshaller.ContextualSerializer, grpcEncoding, Options.CancellationToken); }, diff --git a/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs b/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs index 9d7123684..8f28e56b9 100644 --- a/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs +++ b/src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs @@ -122,7 +122,7 @@ private async Task MoveNextCore(CancellationToken cancellationToken) Current = await _responseStream.ReadStreamedMessageAsync( _call.Logger, - _call.Method.ResponseMarshaller.Deserializer, + _call.Method.ResponseMarshaller.ContextualDeserializer, GrpcProtocolHelpers.GetGrpcEncoding(_httpResponse), cancellationToken).ConfigureAwait(false); if (Current == null) diff --git a/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs b/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs index 3be40a47f..228bb9f21 100644 --- a/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs +++ b/src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs @@ -129,7 +129,7 @@ private async Task WriteAsyncCore(TRequest message) await writeStream.WriteMessage( _call.Logger, message, - _call.Method.RequestMarshaller.Serializer, + _call.Method.RequestMarshaller.ContextualSerializer, _grpcEncoding, _call.CancellationToken).ConfigureAwait(false); } diff --git a/src/Grpc.Net.Client/Internal/StreamExtensions.cs b/src/Grpc.Net.Client/Internal/StreamExtensions.cs index e85a93773..95c90603b 100644 --- a/src/Grpc.Net.Client/Internal/StreamExtensions.cs +++ b/src/Grpc.Net.Client/Internal/StreamExtensions.cs @@ -26,6 +26,7 @@ using System.Threading.Tasks; using Grpc.Core; using Grpc.Net.Client.Internal; +using Grpc.Shared; using Microsoft.Extensions.Logging; using CompressionLevel = System.IO.Compression.CompressionLevel; @@ -46,7 +47,7 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport public static Task ReadSingleMessageAsync( this Stream responseStream, ILogger logger, - Func deserializer, + Func deserializer, string grpcEncoding, CancellationToken cancellationToken) where TResponse : class @@ -57,7 +58,7 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport public static Task ReadStreamedMessageAsync( this Stream responseStream, ILogger logger, - Func deserializer, + Func deserializer, string grpcEncoding, CancellationToken cancellationToken) where TResponse : class @@ -68,7 +69,7 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport private static async Task ReadMessageCoreAsync( this Stream responseStream, ILogger logger, - Func deserializer, + Func deserializer, string grpcEncoding, CancellationToken cancellationToken, bool canBeEmpty, @@ -156,12 +157,18 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport throw new RpcException(CreateUnknownMessageEncodingMessageStatus(grpcEncoding, GrpcProtocolConstants.CompressionProviders.Select(c => c.EncodingName))); } +#pragma warning disable CS8600 // Converting null literal or possible null value to non-nullable type. messageData = decompressedMessage; +#pragma warning restore CS8600 // Converting null literal or possible null value to non-nullable type. } +#pragma warning disable CS8602 // Dereference of a possibly null reference. Log.DeserializingMessage(logger, messageData.Length, typeof(TResponse)); +#pragma warning restore CS8602 // Dereference of a possibly null reference. - var message = deserializer(messageData); + var deserializationContext = new DefaultDeserializationContext(); + deserializationContext.SetPayload(messageData); + var message = deserializer(deserializationContext); if (singleMessage) { @@ -203,7 +210,7 @@ public static async Task WriteMessage( this Stream stream, ILogger logger, TMessage message, - Func serializer, + Action serializer, string grpcEncoding, CancellationToken cancellationToken) { @@ -212,7 +219,14 @@ public static async Task WriteMessage( Log.SendingMessage(logger); // Serialize message first. Need to know size to prefix the length in the header - var data = serializer(message); + var serializationContext = new DefaultSerializationContext(); + serializer(message, serializationContext); + var data = serializationContext.Payload; + + if (data == null) + { + throw new InvalidOperationException("Serialization did not return a payload."); + } var isCompressed = !string.Equals(grpcEncoding, GrpcProtocolConstants.IdentityGrpcEncoding, StringComparison.Ordinal); diff --git a/src/Shared/DefaultDeserializationContext.cs b/src/Shared/DefaultDeserializationContext.cs new file mode 100644 index 000000000..2a137e89d --- /dev/null +++ b/src/Shared/DefaultDeserializationContext.cs @@ -0,0 +1,48 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Buffers; +using System.Diagnostics; +using Grpc.Core; + +namespace Grpc.Shared +{ + internal sealed class DefaultDeserializationContext : DeserializationContext + { + private byte[]? _payload; + + public void SetPayload(byte[]? payload) + { + _payload = payload; + } + + public override byte[] PayloadAsNewBuffer() + { + Debug.Assert(_payload != null, "Payload must be set."); + return _payload; + } + + public override ReadOnlySequence PayloadAsReadOnlySequence() + { + Debug.Assert(_payload != null, "Payload must be set."); + return new ReadOnlySequence(_payload); + } + + public override int PayloadLength => _payload?.Length ?? 0; + } +} diff --git a/src/Shared/DefaultSerializationContext.cs b/src/Shared/DefaultSerializationContext.cs new file mode 100644 index 000000000..70a39b796 --- /dev/null +++ b/src/Shared/DefaultSerializationContext.cs @@ -0,0 +1,32 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Grpc.Core; + +namespace Grpc.Shared +{ + internal sealed class DefaultSerializationContext : SerializationContext + { + public byte[]? Payload { get; set; } + + public override void Complete(byte[] payload) + { + Payload = payload; + } + } +} diff --git a/test/Grpc.AspNetCore.Server.Tests/HttpContextStreamReaderTests.cs b/test/Grpc.AspNetCore.Server.Tests/HttpContextStreamReaderTests.cs index fe9806228..eaa75d569 100644 --- a/test/Grpc.AspNetCore.Server.Tests/HttpContextStreamReaderTests.cs +++ b/test/Grpc.AspNetCore.Server.Tests/HttpContextStreamReaderTests.cs @@ -19,7 +19,6 @@ using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; -using Google.Protobuf; using Greet; using Grpc.AspNetCore.Server.Internal; using Grpc.Tests.Shared; @@ -40,12 +39,7 @@ public void MoveNext_AlreadyCancelledToken_CancelReturnImmediately() var httpContext = new DefaultHttpContext(); var serverCallContext = HttpContextServerCallContextHelper.CreateServerCallContext(httpContext); - var reader = new HttpContextStreamReader(serverCallContext, (data) => - { - var message = new HelloReply(); - message.MergeFrom(data); - return message; - }); + var reader = new HttpContextStreamReader(serverCallContext, MessageHelpers.HelloReplyMarshaller.ContextualDeserializer); // Act var nextTask = reader.MoveNext(new CancellationToken(true)); @@ -64,12 +58,7 @@ public async Task MoveNext_TokenCancelledDuringMoveNext_CancelTask() var httpContext = new DefaultHttpContext(); httpContext.Features.Set(new TestRequestBodyPipeFeature(PipeReader.Create(ms))); var serverCallContext = HttpContextServerCallContextHelper.CreateServerCallContext(httpContext); - var reader = new HttpContextStreamReader(serverCallContext, (data) => - { - var message = new HelloReply(); - message.MergeFrom(data); - return message; - }); + var reader = new HttpContextStreamReader(serverCallContext, MessageHelpers.HelloReplyMarshaller.ContextualDeserializer); var cts = new CancellationTokenSource(); diff --git a/test/Grpc.AspNetCore.Server.Tests/HttpContextStreamWriterTests.cs b/test/Grpc.AspNetCore.Server.Tests/HttpContextStreamWriterTests.cs index bdd4ddf00..697f94f98 100644 --- a/test/Grpc.AspNetCore.Server.Tests/HttpContextStreamWriterTests.cs +++ b/test/Grpc.AspNetCore.Server.Tests/HttpContextStreamWriterTests.cs @@ -19,7 +19,6 @@ using System.IO; using System.IO.Pipelines; using System.Threading.Tasks; -using Google.Protobuf; using Greet; using Grpc.AspNetCore.Server.Internal; using Grpc.Core; @@ -42,7 +41,7 @@ public async Task WriteAsync_DefaultWriteOptions_Flushes() var httpContext = new DefaultHttpContext(); httpContext.Features.Set(new TestResponseBodyPipeFeature(PipeWriter.Create(ms))); var serverCallContext = HttpContextServerCallContextHelper.CreateServerCallContext(httpContext); - var writer = new HttpContextStreamWriter(serverCallContext, (message) => message.ToByteArray()); + var writer = new HttpContextStreamWriter(serverCallContext, MessageHelpers.HelloReplyMarshaller.ContextualSerializer); // Act 1 await writer.WriteAsync(new HelloReply @@ -80,7 +79,7 @@ public async Task WriteAsync_BufferHintWriteOptions_DoesNotFlush() var httpContext = new DefaultHttpContext(); httpContext.Features.Set(new TestResponseBodyPipeFeature(PipeWriter.Create(ms))); var serverCallContext = HttpContextServerCallContextHelper.CreateServerCallContext(httpContext); - var writer = new HttpContextStreamWriter(serverCallContext, (message) => message.ToByteArray()); + var writer = new HttpContextStreamWriter(serverCallContext, MessageHelpers.HelloReplyMarshaller.ContextualSerializer); serverCallContext.WriteOptions = new WriteOptions(WriteFlags.BufferHint); // Act 1 diff --git a/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs index fc35d0250..f43aa3e26 100644 --- a/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs @@ -112,9 +112,9 @@ public async Task AsyncClientStreamingCall_Success_RequestContentSent() await call.RequestStream.CompleteAsync().DefaultTimeout(); var requestContent = await streamTask.DefaultTimeout(); - var requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.Deserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); + var requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); Assert.AreEqual("1", requestMessage.Name); - requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.Deserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); + requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); Assert.AreEqual("2", requestMessage.Name); var responseMessage = await responseTask.DefaultTimeout(); diff --git a/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs index 2525b6789..51097abf1 100644 --- a/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs @@ -117,9 +117,9 @@ public async Task AsyncDuplexStreamingCall_MessagesStreamed_MessagesReceived() Assert.IsNotNull(content); var requestContent = await content!.ReadAsStreamAsync().DefaultTimeout(); - var requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.Deserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); + var requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); Assert.AreEqual("1", requestMessage.Name); - requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.Deserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); + requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); Assert.AreEqual("2", requestMessage.Name); Assert.IsNull(responseStream.Current); diff --git a/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs index 9ebbb19f2..1ccbddfbc 100644 --- a/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs @@ -107,7 +107,7 @@ public async Task AsyncUnaryCall_Success_RequestContentSent() Assert.IsNotNull(content); var requestContent = await content!.ReadAsStreamAsync().DefaultTimeout(); - var requestMessage = await requestContent.ReadSingleMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.Deserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); + var requestMessage = await requestContent.ReadSingleMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout(); Assert.AreEqual("World", requestMessage.Name); } diff --git a/test/Grpc.Net.Client.Tests/CompressionTests.cs b/test/Grpc.Net.Client.Tests/CompressionTests.cs index 1eb73bd67..f084b5ce6 100644 --- a/test/Grpc.Net.Client.Tests/CompressionTests.cs +++ b/test/Grpc.Net.Client.Tests/CompressionTests.cs @@ -53,7 +53,7 @@ public void AsyncUnaryCall_UnknownCompressMetadataSentWithRequest_ThrowsError() helloRequest = await StreamExtensions.ReadSingleMessageAsync( requestStream, NullLogger.Instance, - ClientTestHelpers.ServiceMethod.RequestMarshaller.Deserializer, + ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, "gzip", CancellationToken.None); @@ -96,7 +96,7 @@ public async Task AsyncUnaryCall_CompressMetadataSentWithRequest_RequestMessageC helloRequest = await StreamExtensions.ReadSingleMessageAsync( requestStream, NullLogger.Instance, - ClientTestHelpers.ServiceMethod.RequestMarshaller.Deserializer, + ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, "gzip", CancellationToken.None); @@ -147,7 +147,7 @@ public async Task AsyncUnaryCall_CompressedResponse_ResponseMessageDecompressed( helloRequest = await StreamExtensions.ReadSingleMessageAsync( requestStream, NullLogger.Instance, - ClientTestHelpers.ServiceMethod.RequestMarshaller.Deserializer, + ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, "gzip", CancellationToken.None); @@ -191,7 +191,7 @@ public void AsyncUnaryCall_CompressedResponseWithUnknownEncoding_ErrorThrown() helloRequest = await StreamExtensions.ReadSingleMessageAsync( requestStream, NullLogger.Instance, - ClientTestHelpers.ServiceMethod.RequestMarshaller.Deserializer, + ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, "gzip", CancellationToken.None); diff --git a/test/Shared/MessageHelpers.cs b/test/Shared/MessageHelpers.cs index 65f00defd..027f04b82 100644 --- a/test/Shared/MessageHelpers.cs +++ b/test/Shared/MessageHelpers.cs @@ -18,19 +18,26 @@ using System.Collections.Generic; using System.IO; -using System.IO.Compression; using System.IO.Pipelines; using System.Threading.Tasks; using Google.Protobuf; +using Greet; using Grpc.AspNetCore.Server; using Grpc.AspNetCore.Server.Compression; using Grpc.AspNetCore.Server.Internal; +using Grpc.Core; using Microsoft.AspNetCore.Http; +using CompressionLevel = System.IO.Compression.CompressionLevel; namespace Grpc.Tests.Shared { internal static class MessageHelpers { + public static readonly Marshaller HelloRequestMarshaller = Marshallers.Create(r => r.ToByteArray(), data => HelloRequest.Parser.ParseFrom(data)); + public static readonly Marshaller HelloReplyMarshaller = Marshallers.Create(r => r.ToByteArray(), data => HelloReply.Parser.ParseFrom(data)); + + public static readonly Method ServiceMethod = new Method(MethodType.Unary, "ServiceName", "MethodName", HelloRequestMarshaller, HelloReplyMarshaller); + private static readonly HttpContextServerCallContext TestServerCallContext = HttpContextServerCallContextHelper.CreateServerCallContext(); public static T AssertReadMessage(byte[] messageData, string? compressionEncoding = null, List? compressionProviders = null) where T : IMessage, new()