Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve gRPC JSON transcoding request binding and support HttpBody #41523

Merged
merged 4 commits into from
May 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Text.Json;
using Google.Api;
using Grpc.Core;
using Grpc.Shared.Server;
using Microsoft.AspNetCore.Http;
Expand Down Expand Up @@ -42,8 +43,15 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, JsonT
throw new RpcException(new Status(StatusCode.Cancelled, "No message returned from method."));
}

serverCallContext.EnsureResponseHeaders();

await JsonRequestHelpers.SendMessage(serverCallContext, SerializerOptions, response, CancellationToken.None);
if (response is HttpBody httpBody)
{
serverCallContext.EnsureResponseHeaders(httpBody.ContentType);
await serverCallContext.HttpContext.Response.Body.WriteAsync(httpBody.Data.Memory);
}
else
{
serverCallContext.EnsureResponseHeaders();
await JsonRequestHelpers.SendMessage(serverCallContext, SerializerOptions, response, CancellationToken.None);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Text.Json;
using Google.Api;
using Grpc.Core;

namespace Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal;
Expand Down Expand Up @@ -85,7 +86,17 @@ private async Task WriteAsyncCore(TResponse message, CancellationToken cancellat

private async Task WriteMessageAndDelimiter(TResponse message, CancellationToken cancellationToken)
{
await JsonRequestHelpers.SendMessage(_context, _serializerOptions, message, cancellationToken);
if (message is HttpBody httpBody)
{
_context.EnsureResponseHeaders(httpBody.ContentType);
await _context.HttpContext.Response.Body.WriteAsync(httpBody.Data.Memory, cancellationToken);
}
else
{
_context.EnsureResponseHeaders();
await JsonRequestHelpers.SendMessage(_context, _serializerOptions, message, cancellationToken);
}

await _context.HttpContext.Response.Body.WriteAsync(GrpcProtocolConstants.StreamingDelimiter, cancellationToken);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public override TMessage Read(

if (reader.TokenType != JsonTokenType.StartObject)
{
throw new InvalidOperationException($"Unexpected JSON token: {reader.TokenType}");
throw new JsonException($"Unexpected JSON token: {reader.TokenType}");
}

while (reader.Read())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json;
using Google.Api;
using Google.Protobuf;
using Google.Protobuf.Reflection;
using Grpc.Core;
Expand All @@ -13,6 +15,7 @@
using Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal.Json;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc.Formatters;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;

Expand Down Expand Up @@ -85,7 +88,7 @@ public static (Stream stream, bool usesTranscodingStream) GetStream(Stream inner
}
}

public static async Task SendErrorResponse(HttpResponse response, Encoding encoding, Status status, JsonSerializerOptions options)
public static async ValueTask SendErrorResponse(HttpResponse response, Encoding encoding, Status status, JsonSerializerOptions options)
{
if (!response.HasStarted)
{
Expand Down Expand Up @@ -147,7 +150,7 @@ public static int MapStatusCodeToHttpStatus(StatusCode statusCode)
return StatusCodes.Status500InternalServerError;
}

public static async Task WriteResponseMessage(HttpResponse response, Encoding encoding, object responseBody, JsonSerializerOptions options, CancellationToken cancellationToken)
public static async ValueTask WriteResponseMessage(HttpResponse response, Encoding encoding, object responseBody, JsonSerializerOptions options, CancellationToken cancellationToken)
{
var (stream, usesTranscodingStream) = GetStream(response.Body, encoding);

Expand All @@ -164,7 +167,7 @@ public static async Task WriteResponseMessage(HttpResponse response, Encoding en
}
}

public static async Task<TRequest> ReadMessage<TRequest>(JsonTranscodingServerCallContext serverCallContext, JsonSerializerOptions serializerOptions) where TRequest : class
public static async ValueTask<TRequest> ReadMessage<TRequest>(JsonTranscodingServerCallContext serverCallContext, JsonSerializerOptions serializerOptions) where TRequest : class
{
try
{
Expand All @@ -173,65 +176,75 @@ public static async Task<TRequest> ReadMessage<TRequest>(JsonTranscodingServerCa
IMessage requestMessage;
if (serverCallContext.DescriptorInfo.BodyDescriptor != null)
{
if (!serverCallContext.IsJsonRequestContent)
Type type;
object bodyContent;

if (serverCallContext.DescriptorInfo.BodyDescriptor.FullName == HttpBody.Descriptor.FullName)
{
GrpcServerLog.UnsupportedRequestContentType(serverCallContext.Logger, serverCallContext.HttpContext.Request.ContentType);
throw new RpcException(new Status(StatusCode.InvalidArgument, "Request content-type of application/json is required."));
type = typeof(HttpBody);

bodyContent = await ReadHttpBodyAsync(serverCallContext);
}
else
{
if (!serverCallContext.IsJsonRequestContent)
{
GrpcServerLog.UnsupportedRequestContentType(serverCallContext.Logger, serverCallContext.HttpContext.Request.ContentType);
throw new InvalidOperationException($"Unable to read the request as JSON because the request content type '{serverCallContext.HttpContext.Request.ContentType}' is not a known JSON content type.");
}

var (stream, usesTranscodingStream) = GetStream(serverCallContext.HttpContext.Request.Body, serverCallContext.RequestEncoding);
var (stream, usesTranscodingStream) = GetStream(serverCallContext.HttpContext.Request.Body, serverCallContext.RequestEncoding);

try
{
if (serverCallContext.DescriptorInfo.BodyDescriptorRepeated)
try
{
requestMessage = (IMessage)Activator.CreateInstance<TRequest>();
if (serverCallContext.DescriptorInfo.BodyDescriptorRepeated)
{
requestMessage = (IMessage)Activator.CreateInstance<TRequest>();

// TODO: JsonSerializer currently doesn't support deserializing values onto an existing object or collection.
// Either update this to use new functionality in JsonSerializer or improve work-around perf.
var type = JsonConverterHelper.GetFieldType(serverCallContext.DescriptorInfo.BodyFieldDescriptors.Last());
var listType = typeof(List<>).MakeGenericType(type);
// TODO: JsonSerializer currently doesn't support deserializing values onto an existing object or collection.
// Either update this to use new functionality in JsonSerializer or improve work-around perf.
type = JsonConverterHelper.GetFieldType(serverCallContext.DescriptorInfo.BodyFieldDescriptors.Last());
type = typeof(List<>).MakeGenericType(type);

GrpcServerLog.DeserializingMessage(serverCallContext.Logger, listType);
var repeatedContent = (IList)(await JsonSerializer.DeserializeAsync(stream, listType, serializerOptions))!;
GrpcServerLog.DeserializingMessage(serverCallContext.Logger, type);

ServiceDescriptorHelpers.RecursiveSetValue(requestMessage, serverCallContext.DescriptorInfo.BodyFieldDescriptors, repeatedContent);
}
else
{
IMessage bodyContent;
bodyContent = (await JsonSerializer.DeserializeAsync(stream, type, serializerOptions))!;

try
{
GrpcServerLog.DeserializingMessage(serverCallContext.Logger, serverCallContext.DescriptorInfo.BodyDescriptor.ClrType);
bodyContent = (IMessage)(await JsonSerializer.DeserializeAsync(stream, serverCallContext.DescriptorInfo.BodyDescriptor.ClrType, serializerOptions))!;
if (bodyContent == null)
{
throw new InvalidOperationException($"Unable to deserialize null to {type.Name}.");
}
}
catch (JsonException)
{
throw new RpcException(new Status(StatusCode.InvalidArgument, "Request JSON payload is not correctly formatted."));
}
catch (Exception exception)
else
{
throw new RpcException(new Status(StatusCode.InvalidArgument, exception.Message));
}
type = serverCallContext.DescriptorInfo.BodyDescriptor.ClrType;

if (serverCallContext.DescriptorInfo.BodyFieldDescriptors != null)
{
requestMessage = (IMessage)Activator.CreateInstance<TRequest>();
ServiceDescriptorHelpers.RecursiveSetValue(requestMessage, serverCallContext.DescriptorInfo.BodyFieldDescriptors, bodyContent!); // TODO - check nullability
GrpcServerLog.DeserializingMessage(serverCallContext.Logger, type);
bodyContent = (IMessage)(await JsonSerializer.DeserializeAsync(stream, serverCallContext.DescriptorInfo.BodyDescriptor.ClrType, serializerOptions))!;
}
else
}
finally
{
if (usesTranscodingStream)
{
requestMessage = bodyContent;
await stream.DisposeAsync();
}
}
}
finally

if (serverCallContext.DescriptorInfo.BodyFieldDescriptors != null)
{
requestMessage = (IMessage)Activator.CreateInstance<TRequest>();
ServiceDescriptorHelpers.RecursiveSetValue(requestMessage, serverCallContext.DescriptorInfo.BodyFieldDescriptors, bodyContent); // TODO - check nullability
}
else
{
if (usesTranscodingStream)
if (bodyContent == null)
{
await stream.DisposeAsync();
throw new InvalidOperationException($"Unable to deserialize null to {type.Name}.");
}

requestMessage = (IMessage)bodyContent;
}
}
else
Expand Down Expand Up @@ -265,11 +278,60 @@ public static async Task<TRequest> ReadMessage<TRequest>(JsonTranscodingServerCa
GrpcServerLog.ReceivedMessage(serverCallContext.Logger);
return (TRequest)requestMessage;
}
catch (JsonException ex)
{
GrpcServerLog.ErrorReadingMessage(serverCallContext.Logger, ex);
throw new RpcException(new Status(StatusCode.InvalidArgument, "Request JSON payload is not correctly formatted.", ex));
}
catch (Exception ex)
{
GrpcServerLog.ErrorReadingMessage(serverCallContext.Logger, ex);
throw;
throw new RpcException(new Status(StatusCode.InvalidArgument, ex.Message, ex));
}
}

private static async ValueTask<IMessage> ReadHttpBodyAsync(JsonTranscodingServerCallContext serverCallContext)
{
var httpBody = (IMessage)Activator.CreateInstance(serverCallContext.DescriptorInfo.BodyDescriptor!.ClrType)!;

var contentType = serverCallContext.HttpContext.Request.ContentType;
if (contentType != null)
{
httpBody.Descriptor.Fields[HttpBody.ContentTypeFieldNumber].Accessor.SetValue(httpBody, contentType);
}

var data = await ReadDataAsync(serverCallContext);
httpBody.Descriptor.Fields[HttpBody.DataFieldNumber].Accessor.SetValue(httpBody, UnsafeByteOperations.UnsafeWrap(data));

return httpBody;
}

private static async ValueTask<byte[]> ReadDataAsync(JsonTranscodingServerCallContext serverCallContext)
{
// Buffer to disk if content is larger than 30Kb.
// Based on value in XmlSerializer and NewtonsoftJson input formatters.
const int DefaultMemoryThreshold = 1024 * 30;

var memoryThreshold = DefaultMemoryThreshold;
var contentLength = serverCallContext.HttpContext.Request.ContentLength.GetValueOrDefault();
if (contentLength > 0 && contentLength < memoryThreshold)
{
// If the Content-Length is known and is smaller than the default buffer size, use it.
memoryThreshold = (int)contentLength;
halter73 marked this conversation as resolved.
Show resolved Hide resolved
}

using var fs = new FileBufferingReadStream(serverCallContext.HttpContext.Request.Body, memoryThreshold);

// Read the request body into buffer.
// No explicit cancellation token. Request body uses underlying request aborted token.
await fs.DrainAsync(CancellationToken.None);
fs.Seek(0, SeekOrigin.Begin);

var data = new byte[fs.Length];
var read = fs.Read(data);
Debug.Assert(read == data.Length);

return data;
}

private static List<FieldDescriptor>? GetPathDescriptors(JsonTranscodingServerCallContext serverCallContext, IMessage requestMessage, string path)
Expand All @@ -281,7 +343,7 @@ public static async Task<TRequest> ReadMessage<TRequest>(JsonTranscodingServerCa
});
}

public static async Task SendMessage<TResponse>(JsonTranscodingServerCallContext serverCallContext, JsonSerializerOptions serializerOptions, TResponse message, CancellationToken cancellationToken) where TResponse : class
public static async ValueTask SendMessage<TResponse>(JsonTranscodingServerCallContext serverCallContext, JsonSerializerOptions serializerOptions, TResponse message, CancellationToken cancellationToken) where TResponse : class
{
var response = serverCallContext.HttpContext.Response;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,12 @@ protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
return HttpContext.Response.BodyWriter.FlushAsync().GetAsTask();
}

internal void EnsureResponseHeaders()
internal void EnsureResponseHeaders(string? contentType = null)
{
if (!HttpContext.Response.HasStarted)
{
HttpContext.Response.StatusCode = StatusCodes.Status200OK;
HttpContext.Response.ContentType = MediaType.ReplaceEncoding("application/json", RequestEncoding);
HttpContext.Response.ContentType = contentType ?? MediaType.ReplaceEncoding("application/json", RequestEncoding);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
using Google.Api;
using Google.Protobuf;
using Google.Protobuf.Reflection;
using Google.Protobuf.WellKnownTypes;
using Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal.Json;
using Microsoft.AspNetCore.Routing.Patterns;
using Microsoft.Extensions.Primitives;
using Type = System.Type;

namespace Grpc.Shared;

Expand Down Expand Up @@ -165,6 +168,11 @@ public static bool TryResolveDescriptors(MessageDescriptor messageDescriptor, st
case FieldType.Message:
if (IsWrapperType(descriptor.MessageType))
{
if (value == null)
{
return null;
}

return ConvertValue(value, descriptor.MessageType.FindFieldByName("value"));
}
break;
Expand Down Expand Up @@ -219,7 +227,17 @@ public static void RecursiveSetValue(IMessage currentValue, List<FieldDescriptor
}
else if (values is IMessage message)
{
field.Accessor.SetValue(currentValue, message);
if (IsWrapperType(message.Descriptor))
{
const int WrapperValueFieldNumber = Int32Value.ValueFieldNumber;

var wrappedValue = message.Descriptor.Fields[WrapperValueFieldNumber].Accessor.GetValue(message);
field.Accessor.SetValue(currentValue, wrappedValue);
}
else
{
field.Accessor.SetValue(currentValue, message);
}
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
</PropertyGroup>

<ItemGroup>
<Protobuf Include="Proto\httpbody.proto" GrpcServices="Both" />
<Protobuf Include="Proto\transcoding.proto" GrpcServices="Both" />

<Compile Include="..\Shared\TestGrpcServiceActivator.cs" Link="Infrastructure\TestGrpcServiceActivator.cs" />

<Reference Include="Microsoft.AspNetCore.Grpc.JsonTranscoding" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

syntax = "proto3";

import "google/api/annotations.proto";
import "google/api/httpbody.proto";

package transcoding;

service HttpBodyService {
rpc HelloWorld(HelloWorldRequest) returns (google.api.HttpBody) {
option (google.api.http) = {
get: "/helloworld"
};
}
}

message HttpBodySubField {
string name = 1;
google.api.HttpBody sub = 2;
}

message NestedHttpBodySubField {
string name = 1;
HttpBodySubField sub = 2;
}

message HelloWorldRequest {
string name = 1;
}
Loading