Skip to content

Commit

Permalink
Merge pull request #677 from Cysharp/feature/RawBytesResponse
Browse files Browse the repository at this point in the history
Introduce ServiceContext.SetRawBytesResponse
  • Loading branch information
mayuki authored Sep 19, 2023
2 parents a2d0f80 + e60ecf1 commit edfe2f0
Show file tree
Hide file tree
Showing 12 changed files with 401 additions and 162 deletions.
17 changes: 17 additions & 0 deletions src/MagicOnion.Abstractions/Internal/RawBytesBox.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System;
using System.ComponentModel;

namespace MagicOnion.Internal
{
// Pubternal API
[EditorBrowsable(EditorBrowsableState.Never)]
public sealed class RawBytesBox
{
public ReadOnlyMemory<byte> Bytes { get; }

public RawBytesBox(ReadOnlyMemory<byte> bytes)
{
Bytes = bytes;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System;
using System.ComponentModel;

namespace MagicOnion.Internal
{
// Pubternal API
[EditorBrowsable(EditorBrowsableState.Never)]
public sealed class RawBytesBox
{
public ReadOnlyMemory<byte> Bytes { get; }

public RawBytesBox(ReadOnlyMemory<byte> bytes)
{
Bytes = bytes;
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Text;
using System.Runtime.CompilerServices;
using Grpc.Core;
using MagicOnion.Internal;
using MagicOnion.Serialization;
Expand All @@ -14,18 +13,39 @@ public static class GrpcMethodHelper
public sealed class MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse>
{
public Method<TRawRequest, TRawResponse> Method { get; }
public Func<TRequest, TRawRequest> ToRawRequest { get; }
public Func<TResponse, TRawResponse> ToRawResponse { get; }
public Func<TRawRequest, TRequest> FromRawRequest { get; }
public Func<TRawResponse, TResponse> FromRawResponse { get; }

public MagicOnionMethod(Method<TRawRequest, TRawResponse> method)
{
Method = method;
ToRawRequest = ((typeof(TRawRequest) == typeof(Box<TRequest>)) ? (Func<TRequest, TRawRequest>)(x => (TRawRequest)(object)Box.Create(x)) : x => DangerousDummyNull.GetObjectOrDummyNull((TRawRequest)(object)x));
ToRawResponse = ((typeof(TRawResponse) == typeof(Box<TResponse>)) ? (Func<TResponse, TRawResponse>)(x => (TRawResponse)(object)Box.Create(x)) : x => DangerousDummyNull.GetObjectOrDummyNull((TRawResponse)(object)x));
FromRawRequest = ((typeof(TRawRequest) == typeof(Box<TRequest>)) ? (Func<TRawRequest, TRequest>)(x => ((Box<TRequest>)(object)x).Value) : x => DangerousDummyNull.GetObjectOrDefault<TRequest>(x));
FromRawResponse = ((typeof(TRawResponse) == typeof(Box<TResponse>)) ? (Func<TRawResponse, TResponse>)(x => ((Box<TResponse>)(object)x).Value) : x => DangerousDummyNull.GetObjectOrDefault<TResponse>(x));
}

public TRawRequest ToRawRequest(TRequest obj) => ToRaw<TRequest, TRawRequest>(obj);
public TRawResponse ToRawResponse(TResponse obj) => ToRaw<TResponse, TRawResponse>(obj);
public TRequest FromRawRequest(TRawRequest obj) => FromRaw<TRawRequest, TRequest>(obj);
public TResponse FromRawResponse(TRawResponse obj) => FromRaw<TRawResponse, TResponse>(obj);

static TRaw ToRaw<T, TRaw>(T obj)
{
if (typeof(TRaw) == typeof(Box<T>))
{
return (TRaw)(object)Box.Create(obj);
}
else
{
return DangerousDummyNull.GetObjectOrDummyNull(Unsafe.As<T, TRaw>(ref obj));
}
}

static T FromRaw<TRaw, T>(TRaw obj)
{
if (typeof(TRaw) == typeof(Box<T>))
{
return ((Box<T>)(object)obj).Value;
}
else
{
return DangerousDummyNull.GetObjectOrDefault<T>(obj);
}
}
}

Expand All @@ -39,27 +59,17 @@ public static MagicOnionMethod<Nil, TResponse, Box<Nil>, TRawResponse> CreateMet
// DynamicClient sends byte[], but GeneratedClient sends Nil, which is incompatible,
// so as a special case we do not serialize/deserialize and always convert to a fixed values.
var isMethodResponseTypeBoxed = typeof(TResponse).IsValueType;
var responseMarshaller = isMethodResponseTypeBoxed
? (object)CreateBoxedMarshaller<TResponse>(messageSerializer)
: (object)CreateMarshaller<TResponse>(messageSerializer);

if (isMethodResponseTypeBoxed)
{
return new MagicOnionMethod<Nil, TResponse, Box<Nil>, TRawResponse>(new Method<Box<Nil>, TRawResponse>(
methodType,
serviceName,
name,
IgnoreNilMarshaller,
(Marshaller<TRawResponse>)(object)CreateBoxedMarshaller<TResponse>(messageSerializer, methodType, methodInfo)
));
}
else
{
return new MagicOnionMethod<Nil, TResponse, Box<Nil>, TRawResponse>(new Method<Box<Nil>, TRawResponse>(
methodType,
serviceName,
name,
IgnoreNilMarshaller,
(Marshaller<TRawResponse>)(object)CreateMarshaller<TResponse>(messageSerializer, methodType, methodInfo)
));
}
return new MagicOnionMethod<Nil, TResponse, Box<Nil>, TRawResponse>(new Method<Box<Nil>, TRawResponse>(
methodType,
serviceName,
name,
IgnoreNilMarshaller,
(Marshaller<TRawResponse>)responseMarshaller
));
}

public static MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse> CreateMethod<TRequest, TResponse, TRawRequest, TRawResponse>(MethodType methodType, string serviceName, string name, IMagicOnionSerializer messageSerializer)
Expand All @@ -71,46 +81,20 @@ public static MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse> C
var isMethodRequestTypeBoxed = typeof(TRequest).IsValueType;
var isMethodResponseTypeBoxed = typeof(TResponse).IsValueType;

if (isMethodRequestTypeBoxed && isMethodResponseTypeBoxed)
{
return new MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse>(new Method<TRawRequest, TRawResponse>(
methodType,
serviceName,
name,
(Marshaller<TRawRequest>)(object)CreateBoxedMarshaller<TRequest>(messageSerializer, methodType, methodInfo),
(Marshaller<TRawResponse>)(object)CreateBoxedMarshaller<TResponse>(messageSerializer, methodType, methodInfo)
));
}
else if (isMethodRequestTypeBoxed)
{
return new MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse>(new Method<TRawRequest, TRawResponse>(
methodType,
serviceName,
name,
(Marshaller<TRawRequest>)(object)CreateBoxedMarshaller<TRequest>(messageSerializer, methodType, methodInfo),
(Marshaller<TRawResponse>)(object)CreateMarshaller<TResponse>(messageSerializer, methodType, methodInfo)
));
}
else if (isMethodResponseTypeBoxed)
{
return new MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse>(new Method<TRawRequest, TRawResponse>(
methodType,
serviceName,
name,
(Marshaller<TRawRequest>)(object)CreateMarshaller<TRequest>(messageSerializer, methodType, methodInfo),
(Marshaller<TRawResponse>)(object)CreateBoxedMarshaller<TResponse>(messageSerializer, methodType, methodInfo)
));
}
else
{
return new MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse>(new Method<TRawRequest, TRawResponse>(
methodType,
serviceName,
name,
(Marshaller<TRawRequest>)(object)CreateMarshaller<TRequest>(messageSerializer, methodType, methodInfo),
(Marshaller<TRawResponse>)(object)CreateMarshaller<TResponse>(messageSerializer, methodType, methodInfo)
));
}
var requestMarshaller = isMethodRequestTypeBoxed
? (object)CreateBoxedMarshaller<TRequest>(messageSerializer)
: (object)CreateMarshaller<TRequest>(messageSerializer);
var responseMarshaller = isMethodResponseTypeBoxed
? (object)CreateBoxedMarshaller<TResponse>(messageSerializer)
: (object)CreateMarshaller<TResponse>(messageSerializer);

return new MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse>(new Method<TRawRequest, TRawResponse>(
methodType,
serviceName,
name,
(Marshaller<TRawRequest>)requestMarshaller,
(Marshaller<TRawResponse>)responseMarshaller
));
}

// WORKAROUND: Prior to MagicOnion 5.0, the request type for the parameter-less method was byte[].
Expand All @@ -131,23 +115,45 @@ public static MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse> C
deserializer: (ctx) => Box.Create(Nil.Default) /* Box.Create always returns cached Box<Nil> */
);

static Marshaller<T> CreateMarshaller<T>(IMagicOnionSerializer messageSerializer, MethodType methodType, MethodInfo methodInfo)
static Marshaller<T> CreateMarshaller<T>(IMagicOnionSerializer messageSerializer)
{
return new Marshaller<T>(
serializer: (obj, ctx) =>
{
messageSerializer.Serialize(ctx.GetBufferWriter(), DangerousDummyNull.GetObjectOrDefault<T>(obj));
if (obj.GetType() == typeof(RawBytesBox))
{
var rawBytesBox = (RawBytesBox)(object)obj;
var writer = ctx.GetBufferWriter();
var buffer = writer.GetSpan(rawBytesBox.Bytes.Length);
rawBytesBox.Bytes.Span.CopyTo(buffer);
writer.Advance(rawBytesBox.Bytes.Length);
}
else
{
messageSerializer.Serialize(ctx.GetBufferWriter(), DangerousDummyNull.GetObjectOrDefault<T>(obj));
}
ctx.Complete();
},
deserializer: (ctx) => DangerousDummyNull.GetObjectOrDummyNull(messageSerializer.Deserialize<T>(ctx.PayloadAsReadOnlySequence())));
}

static Marshaller<Box<T>> CreateBoxedMarshaller<T>(IMagicOnionSerializer messageSerializer, MethodType methodType, MethodInfo methodInfo)
static Marshaller<Box<T>> CreateBoxedMarshaller<T>(IMagicOnionSerializer messageSerializer)
{
return new Marshaller<Box<T>>(
serializer: (obj, ctx) =>
{
messageSerializer.Serialize(ctx.GetBufferWriter(), obj.Value);
if (obj.GetType() == typeof(RawBytesBox))
{
var rawBytesBox = (RawBytesBox)(object)obj;
var writer = ctx.GetBufferWriter();
var buffer = writer.GetSpan(rawBytesBox.Bytes.Length);
rawBytesBox.Bytes.Span.CopyTo(buffer);
writer.Advance(rawBytesBox.Bytes.Length);
}
else
{
messageSerializer.Serialize(ctx.GetBufferWriter(), obj.Value);
}
ctx.Complete();
},
deserializer: (ctx) => Box.Create(messageSerializer.Deserialize<T>(ctx.PayloadAsReadOnlySequence()))
Expand Down
26 changes: 22 additions & 4 deletions src/MagicOnion.Server/Internal/MagicOnionMethodHandlerBinder.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using System.Diagnostics;
using Grpc.Core;
using MagicOnion.Internal;
using MagicOnion.Serialization;
using MessagePack;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace MagicOnion.Server.Internal;

Expand All @@ -25,21 +27,37 @@ internal class MagicOnionMethodHandlerBinder<TRequest, TResponse, TRawRequest, T
{
public static MagicOnionMethodHandlerBinder<TRequest, TResponse, TRawRequest, TRawResponse> Instance { get; } = new MagicOnionMethodHandlerBinder<TRequest, TResponse, TRawRequest, TRawResponse>();

public void BindUnary(ServiceBinderBase binder, Func<TRequest, ServerCallContext, Task<TResponse?>> serverMethod, MethodHandler methodHandler, IMagicOnionSerializer messageSerializer)
public void BindUnary(ServiceBinderBase binder, Func<TRequest, ServerCallContext, Task<object?>> serverMethod, MethodHandler methodHandler, IMagicOnionSerializer messageSerializer)
{
var method = GrpcMethodHelper.CreateMethod<TRequest, TResponse, TRawRequest, TRawResponse>(MethodType.Unary, methodHandler.ServiceName, methodHandler.MethodName, methodHandler.MethodInfo, messageSerializer);
binder.AddMethod(new MagicOnionServerMethod<TRawRequest, TRawResponse>(method.Method, methodHandler),
async (request, context) => method.ToRawResponse(await serverMethod(method.FromRawRequest(request), context)));
async (request, context) =>
{
var response = await serverMethod(method.FromRawRequest(request), context);
if (response is RawBytesBox rawBytesResponse)
{
return Unsafe.As<RawBytesBox, TRawResponse>(ref rawBytesResponse); // NOTE: To disguise an object as a `TRawResponse`, `TRawResponse` must be `class`.
}
return method.ToRawResponse((TResponse?)response);
});
}

public void BindUnaryParameterless(ServiceBinderBase binder, Func<Nil, ServerCallContext, Task<TResponse?>> serverMethod, MethodHandler methodHandler, IMagicOnionSerializer messageSerializer)
public void BindUnaryParameterless(ServiceBinderBase binder, Func<Nil, ServerCallContext, Task<object?>> serverMethod, MethodHandler methodHandler, IMagicOnionSerializer messageSerializer)
{
// WORKAROUND: Prior to MagicOnion 5.0, the request type for the parameter-less method was byte[].
// DynamicClient sends byte[], but GeneratedClient sends Nil, which is incompatible,
// so as a special case we do not serialize/deserialize and always convert to a fixed values.
var method = GrpcMethodHelper.CreateMethod<TResponse, TRawResponse>(MethodType.Unary, methodHandler.ServiceName, methodHandler.MethodName, methodHandler.MethodInfo, messageSerializer);
binder.AddMethod(new MagicOnionServerMethod<Box<Nil>, TRawResponse>(method.Method, methodHandler),
async (request, context) => method.ToRawResponse(await serverMethod(method.FromRawRequest(request), context)));
async (request, context) =>
{
var response = await serverMethod(method.FromRawRequest(request), context);
if (response is RawBytesBox rawBytesResponse)
{
return Unsafe.As<RawBytesBox, TRawResponse>(ref rawBytesResponse); // NOTE: To disguise an object as a `TRawResponse`, `TRawResponse` must be `class`.
}
return method.ToRawResponse((TResponse?)response);
});
}

public void BindStreamingHub(ServiceBinderBase binder, Func<IAsyncStreamReader<TRequest>, IServerStreamWriter<TResponse>, ServerCallContext, Task> serverMethod, MethodHandler methodHandler, IMagicOnionSerializer messageSerializer)
Expand Down
9 changes: 6 additions & 3 deletions src/MagicOnion.Server/MagicOnionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ public static MagicOnionServiceDefinition BuildServerServiceDefinition(IServiceP
var handlers = new HashSet<MethodHandler>();
var streamingHubHandlers = new List<StreamingHubHandler>();

var methodHandlerOptions = new MethodHandlerOptions(options);
var streamingHubHandlerOptions = new StreamingHubHandlerOptions(options);

logger.BeginBuildServiceDefinition();
var sw = Stopwatch.StartNew();

Expand Down Expand Up @@ -179,7 +182,7 @@ public static MagicOnionServiceDefinition BuildServerServiceDefinition(IServiceP
// register for StreamingHub
if (isStreamingHub && methodName != "Connect")
{
var streamingHandler = new StreamingHubHandler(classType, methodInfo, new StreamingHubHandlerOptions(options), serviceProvider);
var streamingHandler = new StreamingHubHandler(classType, methodInfo, streamingHubHandlerOptions, serviceProvider);
if (!tempStreamingHubHandlers!.Add(streamingHandler))
{
throw new InvalidOperationException($"Method does not allow overload, {className}.{methodName}");
Expand All @@ -189,7 +192,7 @@ public static MagicOnionServiceDefinition BuildServerServiceDefinition(IServiceP
else
{
// create handler
var handler = new MethodHandler(classType, methodInfo, methodName, new MethodHandlerOptions(options), serviceProvider, logger, isStreamingHub: false);
var handler = new MethodHandler(classType, methodInfo, methodName, methodHandlerOptions, serviceProvider, logger, isStreamingHub: false);
if (!handlers.Add(handler))
{
throw new InvalidOperationException($"Method does not allow overload, {className}.{methodName}");
Expand All @@ -199,7 +202,7 @@ public static MagicOnionServiceDefinition BuildServerServiceDefinition(IServiceP

if (isStreamingHub)
{
var connectHandler = new MethodHandler(classType, classType.GetMethod("Connect")!, "Connect", new MethodHandlerOptions(options), serviceProvider, logger, isStreamingHub: true);
var connectHandler = new MethodHandler(classType, classType.GetMethod("Connect")!, "Connect", methodHandlerOptions, serviceProvider, logger, isStreamingHub: true);
if (!handlers.Add(connectHandler))
{
throw new InvalidOperationException($"Method does not allow overload, {className}.Connect");
Expand Down
6 changes: 3 additions & 3 deletions src/MagicOnion.Server/MethodHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ void BindHandlerTyped<TRequest, TResponse, TRawRequest, TRawResponse>(ServiceBin
}
}

async Task<TResponse?> UnaryServerMethod<TRequest, TResponse>(TRequest request, ServerCallContext context)
async Task<object?> UnaryServerMethod<TRequest, TResponse>(TRequest request, ServerCallContext context)
{
var isErrorOrInterrupted = false;
var serviceContext = new ServiceContext(ServiceType, MethodInfo, AttributeLookup, this.MethodType, context, messageSerializer, Logger, this, context.GetHttpContext().RequestServices);
serviceContext.SetRawRequest(request);

TResponse? response = default;
object? response = default(TResponse?);
try
{
Logger.BeginInvokeMethod(serviceContext, typeof(TRequest));
Expand All @@ -241,7 +241,7 @@ void BindHandlerTyped<TRequest, TResponse, TRawRequest, TRawResponse>(ServiceBin
await this.methodBody(serviceContext).ConfigureAwait(false);
if (serviceContext.Result is not null)
{
response = (TResponse?)serviceContext.Result;
response = serviceContext.Result;
}
}
catch (ReturnStatusException ex)
Expand Down
Loading

0 comments on commit edfe2f0

Please sign in to comment.