Skip to content

Commit

Permalink
Prepare general RPC marshaling for open generic interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
AArnott committed Aug 24, 2020
1 parent ff971ff commit 9adeb2a
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 26 deletions.
70 changes: 58 additions & 12 deletions src/StreamJsonRpc/JsonMessageFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ namespace StreamJsonRpc
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.IO;
using System.IO.Pipelines;
Expand Down Expand Up @@ -64,7 +65,7 @@ public class JsonMessageFormatter : IJsonRpcAsyncMessageTextFormatter, IJsonRpcF
/// </remarks>
private static readonly JsonSerializer DefaultSerializer = JsonSerializer.CreateDefault();

private readonly IReadOnlyDictionary<Type, JsonConverter> implicitlyMarshaledTypes;
private readonly IReadOnlyDictionary<Type, RpcMarshalableImplicitConverter> implicitlyMarshaledTypes;

/// <summary>
/// The reusable <see cref="TextWriter"/> to use with newtonsoft.json's serializer.
Expand Down Expand Up @@ -189,12 +190,12 @@ public JsonMessageFormatter(Encoding encoding)

var camelCaseProxyOptions = new JsonRpcProxyOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase };
var camelCaseTargetOptions = new JsonRpcTargetOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase };
this.implicitlyMarshaledTypes = new Dictionary<Type, JsonConverter>
this.implicitlyMarshaledTypes = new Dictionary<Type, RpcMarshalableImplicitConverter>
{
{ typeof(IDisposable), new RpcMarshalableImplicitConverter<IDisposable>(this, camelCaseProxyOptions, camelCaseTargetOptions) },
{ typeof(IDisposable), new RpcMarshalableImplicitConverter(typeof(IDisposable), this, camelCaseProxyOptions, camelCaseTargetOptions) },
};

foreach (KeyValuePair<Type, JsonConverter> implicitlyMarshaledType in this.implicitlyMarshaledTypes)
foreach (KeyValuePair<Type, RpcMarshalableImplicitConverter> implicitlyMarshaledType in this.implicitlyMarshaledTypes)
{
this.JsonSerializer.Converters.Add(implicitlyMarshaledType.Value);
}
Expand Down Expand Up @@ -613,7 +614,7 @@ private JToken TokenizeUserData(Type? declaredType, object? value)
return JValue.CreateNull();
}

if (declaredType is object && this.implicitlyMarshaledTypes.TryGetValue(declaredType, out JsonConverter converter))
if (declaredType is object && this.TryGetImplicitlyMarshaledJsonConverter(declaredType, out RpcMarshalableImplicitConverter? converter))
{
using var jsonWriter = new JTokenWriter();
converter.WriteJson(jsonWriter, value, this.JsonSerializer);
Expand All @@ -623,6 +624,22 @@ private JToken TokenizeUserData(Type? declaredType, object? value)
return JToken.FromObject(value, this.JsonSerializer);
}

private bool TryGetImplicitlyMarshaledJsonConverter(Type type, [NotNullWhen(true)] out RpcMarshalableImplicitConverter? converter)
{
if (this.implicitlyMarshaledTypes.TryGetValue(type, out converter))
{
return true;
}

if (type.IsConstructedGenericType && this.implicitlyMarshaledTypes.TryGetValue(type.GetGenericTypeDefinition(), out converter))
{
converter = converter.WithClosedType(type);
return true;
}

return false;
}

private JsonRpcRequest ReadRequest(JToken json)
{
Requires.NotNull(json, nameof(json));
Expand Down Expand Up @@ -1155,26 +1172,37 @@ public override void WriteJson(JsonWriter writer, Stream? value, JsonSerializer
}
}

private class RpcMarshalableImplicitConverter<T> : JsonConverter
where T : class
[DebuggerDisplay("{" + nameof(DebuggerDisplay) + "}")]
private class RpcMarshalableImplicitConverter : JsonConverter
{
private readonly Type implicitlyConvertedType;
private readonly JsonMessageFormatter jsonMessageFormatter;
private readonly JsonRpcProxyOptions proxyOptions;
private readonly JsonRpcTargetOptions targetOptions;

public RpcMarshalableImplicitConverter(JsonMessageFormatter jsonMessageFormatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions)
public RpcMarshalableImplicitConverter(Type implicitlyConvertedType, JsonMessageFormatter jsonMessageFormatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions)
{
this.implicitlyConvertedType = implicitlyConvertedType;
this.jsonMessageFormatter = jsonMessageFormatter;
this.proxyOptions = proxyOptions;
this.targetOptions = targetOptions;
}

public override bool CanConvert(Type objectType) => objectType == typeof(T);
private RpcMarshalableImplicitConverter(RpcMarshalableImplicitConverter copyFrom, Type implicitlyConvertedType)
: this(implicitlyConvertedType, copyFrom.jsonMessageFormatter, copyFrom.proxyOptions, copyFrom.targetOptions)
{
}

private string DebuggerDisplay => $"Implicit converter for: {this.implicitlyConvertedType.Name}";

public override bool CanConvert(Type objectType) =>
objectType == this.implicitlyConvertedType ||
(this.implicitlyConvertedType.IsGenericTypeDefinition && objectType.IsGenericType && objectType.GetGenericTypeDefinition() == this.implicitlyConvertedType);

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var token = (MessageFormatterRpcMarshaledContextTracker.MarshalToken?)JToken.Load(reader).ToObject(typeof(MessageFormatterRpcMarshaledContextTracker.MarshalToken), serializer);
return this.jsonMessageFormatter.RpcMarshaledContextTracker.GetObject<T>(token, this.proxyOptions);
return this.jsonMessageFormatter.RpcMarshaledContextTracker.GetObject(objectType, token, this.proxyOptions);
}

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
Expand All @@ -1185,11 +1213,24 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
}
else
{
IRpcMarshaledContext<object> context = JsonRpc.MarshalWithControlledLifetime((T)value, this.targetOptions);
IRpcMarshaledContext<object> context = JsonRpc.MarshalWithControlledLifetime(this.implicitlyConvertedType, value, this.targetOptions);
MessageFormatterRpcMarshaledContextTracker.MarshalToken token = this.jsonMessageFormatter.RpcMarshaledContextTracker.GetToken(context);
serializer.Serialize(writer, token);
}
}

internal RpcMarshalableImplicitConverter WithClosedType(Type implicitlyMarshaledType)
{
if (this.implicitlyConvertedType == implicitlyMarshaledType ||
!this.implicitlyConvertedType.IsGenericType ||
this.implicitlyConvertedType.IsConstructedGenericType)
{
return this;
}

Assumes.True(implicitlyMarshaledType.GetGenericTypeDefinition().Equals(this.implicitlyConvertedType.GetGenericTypeDefinition()));
return new RpcMarshalableImplicitConverter(this, implicitlyMarshaledType);
}
}

private class JsonConverterFormatter : IFormatterConverter
Expand Down Expand Up @@ -1329,7 +1370,12 @@ public override JsonContract ResolveContract(Type type)
{
foreach (JsonProperty property in objectContract.Properties)
{
if (this.formatter.implicitlyMarshaledTypes.TryGetValue(property.PropertyType, out JsonConverter converter))
if (property.Ignored)
{
continue;
}

if (this.formatter.TryGetImplicitlyMarshaledJsonConverter(property.PropertyType, out RpcMarshalableImplicitConverter? converter))
{
property.Converter = converter;
}
Expand Down
10 changes: 10 additions & 0 deletions src/StreamJsonRpc/JsonRpc.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
private static readonly ReadOnlyDictionary<string, string> EmptyDictionary = new ReadOnlyDictionary<string, string>(new Dictionary<string, string>(StringComparer.Ordinal));
private static readonly object[] EmptyObjectArray = new object[0];
private static readonly JsonSerializer DefaultJsonSerializer = JsonSerializer.CreateDefault();
private static readonly MethodInfo MarshalWithControlledLifetimeOpenGenericMethodInfo = typeof(JsonRpc).GetMethods(BindingFlags.Static | BindingFlags.NonPublic).Single(m => m.Name == nameof(MarshalWithControlledLifetime) && m.IsGenericMethod);

/// <summary>
/// The <see cref="System.Threading.SynchronizationContext"/> to use to schedule work on the threadpool.
Expand Down Expand Up @@ -1288,6 +1289,15 @@ internal static IRpcMarshaledContext<T> MarshalWithControlledLifetime<T>(T marsh
return new RpcMarshaledContext<T>(marshaledObject, options);
}

/// <inheritdoc cref="MarshalWithControlledLifetime{T}(T, JsonRpcTargetOptions)"/>
/// <param name="interfaceType"><inheritdoc cref="MarshalWithControlledLifetime{T}(T, JsonRpcTargetOptions)" path="/typeparam"/></param>
/// <param name="marshaledObject"><inheritdoc cref="MarshalWithControlledLifetime{T}(T, JsonRpcTargetOptions)" path="/param[@name='marshaledObject']"/></param>
/// <param name="options"><inheritdoc cref="MarshalWithControlledLifetime{T}(T, JsonRpcTargetOptions)" path="/param[@name='options']"/></param>
internal static IRpcMarshaledContext<object> MarshalWithControlledLifetime(Type interfaceType, object marshaledObject, JsonRpcTargetOptions options)
{
return (IRpcMarshaledContext<object>)MarshalWithControlledLifetimeOpenGenericMethodInfo.MakeGenericMethod(interfaceType).Invoke(null, new object?[] { marshaledObject, options });
}

/// <inheritdoc cref="MarshalWithControlledLifetime{T}(T, JsonRpcTargetOptions)"/>
/// <returns>A proxy value that may be used within an RPC argument so the RPC server may call back into the <paramref name="marshaledObject"/> object on the RPC client.</returns>
/// <remarks>
Expand Down
70 changes: 67 additions & 3 deletions src/StreamJsonRpc/MessagePackFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,10 @@ private MessagePackSerializerOptions MassageUserDataOptions(MessagePackSerialize
{
var camelCaseProxyOptions = new JsonRpcProxyOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase };
var camelCaseTargetOptions = new JsonRpcTargetOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase };
var implicitlyMarshaledTypes = new (Type, JsonRpcProxyOptions, JsonRpcTargetOptions)[]
{
(typeof(IDisposable), camelCaseProxyOptions, camelCaseTargetOptions),
};

var formatters = new IMessagePackFormatter[]
{
Expand All @@ -449,8 +453,7 @@ private MessagePackSerializerOptions MassageUserDataOptions(MessagePackSerialize
this.pipeFormatterResolver,

// Support for marshalled objects.
CompositeResolver.Create(
new RpcMarshalableImplicitFormatter<IDisposable>(this, camelCaseProxyOptions, camelCaseTargetOptions)),
new RpcMarshalableImplicitResolver(this, implicitlyMarshaledTypes),

// Add resolvers to make types serializable that we expect to be serializable.
MessagePackExceptionResolver.Instance,
Expand Down Expand Up @@ -1103,6 +1106,67 @@ public void Serialize(ref MessagePackWriter writer, T? value, MessagePackSeriali
}
}

private class RpcMarshalableImplicitResolver : IFormatterResolver
{
private readonly MessagePackFormatter formatter;
private readonly ReadOnlyMemory<(Type Type, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions)> implicitlyMarshaledTypes;
private readonly Dictionary<Type, object> formatters = new Dictionary<Type, object>();

internal RpcMarshalableImplicitResolver(MessagePackFormatter formatter, ReadOnlyMemory<(Type Type, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions)> implicitlyMarshaledTypes)
{
this.formatter = formatter;
this.implicitlyMarshaledTypes = implicitlyMarshaledTypes;
}

public IMessagePackFormatter<T>? GetFormatter<T>()
{
if (typeof(T).IsValueType)
{
return null;
}

lock (this.formatters)
{
if (this.formatters.TryGetValue(typeof(T), out object? cachedFormatter))
{
return (IMessagePackFormatter<T>)cachedFormatter;
}
}

(Type Type, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions)? matchingCandidate = null;
foreach ((Type Type, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions) candidate in this.implicitlyMarshaledTypes.Span)
{
if (candidate.Type == typeof(T) ||
(candidate.Type.IsGenericTypeDefinition && typeof(T).IsConstructedGenericType && candidate.Type == typeof(T).GetGenericTypeDefinition()))
{
matchingCandidate = candidate;
break;
}
}

if (!matchingCandidate.HasValue)
{
return null;
}

object formatter = Activator.CreateInstance(
typeof(RpcMarshalableImplicitFormatter<>).MakeGenericType(typeof(T)),
this.formatter,
matchingCandidate.Value.ProxyOptions,
matchingCandidate.Value.TargetOptions);

lock (this.formatters)
{
if (!this.formatters.TryGetValue(typeof(T), out object? cachedFormatter))
{
this.formatters.Add(typeof(T), cachedFormatter = formatter);
}

return (IMessagePackFormatter<T>)cachedFormatter;
}
}
}

private class RpcMarshalableImplicitFormatter<T> : IMessagePackFormatter<T?>
where T : class
{
Expand All @@ -1120,7 +1184,7 @@ public RpcMarshalableImplicitFormatter(MessagePackFormatter messagePackFormatter
public T? Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options)
{
MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = MessagePackSerializer.Deserialize<MessageFormatterRpcMarshaledContextTracker.MarshalToken?>(ref reader, options);
return token.HasValue ? this.messagePackFormatter.RpcMarshaledContextTracker.GetObject<T>(token, this.proxyOptions) : null;
return token.HasValue ? (T?)this.messagePackFormatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, this.proxyOptions) : null;
}

public void Serialize(ref MessagePackWriter writer, T? value, MessagePackSerializerOptions options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,12 @@ internal MarshalToken GetToken(IRpcMarshaledContext<object> marshaledContext)
/// <summary>
/// Creates a proxy for a remote object.
/// </summary>
/// <typeparam name="T">The interface the proxy must implement.</typeparam>
/// <param name="interfaceType">The interface the proxy must implement.</param>
/// <param name="token">The token received from the remote party that includes the handle to the remote object.</param>
/// <param name="options">The options to feed into proxy generation.</param>
/// <returns>The generated proxy, or <c>null</c> if <paramref name="token"/> is null.</returns>
[return: NotNullIfNotNull("token")]
internal T? GetObject<T>(MarshalToken? token, JsonRpcProxyOptions options)
where T : class
internal object? GetObject(Type interfaceType, MarshalToken? token, JsonRpcProxyOptions options)
{
if (token is null)
{
Expand All @@ -123,15 +122,22 @@ internal MarshalToken GetToken(IRpcMarshaledContext<object> marshaledContext)
}

// CONSIDER: If we ever support arbitrary RPC interfaces, we'd need to consider how events on those interfaces would work.
return this.jsonRpc.Attach<T>(new JsonRpcProxyOptions(options)
{
MethodNameTransform = mn => Invariant($"$/invokeProxy/{token.Value.Handle}/{options.MethodNameTransform(mn)}"),
OnDispose = delegate
return this.jsonRpc.Attach(
interfaceType,
new JsonRpcProxyOptions(options)
{
this.jsonRpc.NotifyAsync(Invariant($"$/invokeProxy/{token.Value.Handle}/{options.MethodNameTransform(nameof(IDisposable.Dispose))}")).Forget();
this.jsonRpc.NotifyWithParameterObjectAsync("$/releaseMarshaledObject", new { handle = token.Value.Handle, ownedBySender = false }).Forget();
},
});
MethodNameTransform = mn => Invariant($"$/invokeProxy/{token.Value.Handle}/{options.MethodNameTransform(mn)}"),
OnDispose = delegate
{
// Only forward the Dispose call if the marshaled interface derives from IDisposable.
if (typeof(IDisposable).IsAssignableFrom(interfaceType))
{
this.jsonRpc.NotifyAsync(Invariant($"$/invokeProxy/{token.Value.Handle}/{options.MethodNameTransform(nameof(IDisposable.Dispose))}")).Forget();
}

this.jsonRpc.NotifyWithParameterObjectAsync("$/releaseMarshaledObject", new { handle = token.Value.Handle, ownedBySender = false }).Forget();
},
});
}

/// <summary>
Expand Down

0 comments on commit 9adeb2a

Please sign in to comment.