diff --git a/Directory.Packages.props b/Directory.Packages.props index 462cbe578..6a87cb5f6 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -22,11 +22,12 @@ + - + - + diff --git a/nuget.config b/nuget.config index 0a73357e8..8ef7b7880 100644 --- a/nuget.config +++ b/nuget.config @@ -5,10 +5,25 @@ + + + + + + + + + + + + + + + diff --git a/src/StreamJsonRpc/FormatterBase.cs b/src/StreamJsonRpc/FormatterBase.cs index 7dd4c6479..c7bdb6ffa 100644 --- a/src/StreamJsonRpc/FormatterBase.cs +++ b/src/StreamJsonRpc/FormatterBase.cs @@ -6,6 +6,7 @@ using System.IO.Pipelines; using System.Reflection; using System.Runtime.Serialization; +using Nerdbank.MessagePack; using Nerdbank.Streams; using StreamJsonRpc.Protocol; using StreamJsonRpc.Reflection; diff --git a/src/StreamJsonRpc/JsonRpc.cs b/src/StreamJsonRpc/JsonRpc.cs index e8ab2712a..e5f1e596a 100644 --- a/src/StreamJsonRpc/JsonRpc.cs +++ b/src/StreamJsonRpc/JsonRpc.cs @@ -1697,7 +1697,7 @@ protected virtual async ValueTask DispatchRequestAsync(JsonRpcRe } /// - /// Sends the JSON-RPC message to intance to be transmitted. + /// Sends the JSON-RPC message to instance to be transmitted. /// /// The message to send. /// A token to cancel the send request. diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs new file mode 100644 index 000000000..dfc6ef831 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.CommonString.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics; +using NBMP = Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +public partial class NerdbankMessagePackFormatter +{ + [DebuggerDisplay("{" + nameof(Value) + "}")] + private struct CommonString + { + internal CommonString(string value) + { + Requires.Argument(value.Length > 0 && value.Length <= 16, nameof(value), "Length must be >0 and <=16."); + this.Value = value; + ReadOnlyMemory encodedBytes = MessagePack.Internal.CodeGenHelpers.GetEncodedStringBytes(value); + this.EncodedBytes = encodedBytes; + + ReadOnlySpan span = this.EncodedBytes.Span.Slice(1); + this.Key = MessagePack.Internal.AutomataKeyGen.GetKey(ref span); // header is 1 byte because string length <= 16 + this.Key2 = span.Length > 0 ? (ulong?)MessagePack.Internal.AutomataKeyGen.GetKey(ref span) : null; + } + + /// + /// Gets the original string. + /// + internal string Value { get; } + + /// + /// Gets the 64-bit integer that represents the string without decoding it. + /// + private ulong Key { get; } + + /// + /// Gets the next 64-bit integer that represents the string without decoding it. + /// + private ulong? Key2 { get; } + + /// + /// Gets the messagepack header and UTF-8 bytes for this string. + /// + private ReadOnlyMemory EncodedBytes { get; } + + /// + /// Writes out the messagepack binary for this common string, if it matches the given value. + /// + /// The writer to use. + /// The value to be written, if it matches this . + /// if matches this and it was written; otherwise. + internal bool TryWrite(ref NBMP::MessagePackWriter writer, string value) + { + if (value == this.Value) + { + this.Write(ref writer); + return true; + } + + return false; + } + + internal readonly void Write(ref NBMP::MessagePackWriter writer) => writer.WriteRaw(this.EncodedBytes.Span); + + /// + /// Checks whether a span of UTF-8 bytes equal this common string. + /// + /// The UTF-8 string. + /// if the UTF-8 bytes are the encoding of this common string; otherwise. + internal readonly bool TryRead(ReadOnlySpan utf8String) + { + if (utf8String.Length != this.EncodedBytes.Length - 1) + { + return false; + } + + ulong key1 = MessagePack.Internal.AutomataKeyGen.GetKey(ref utf8String); + if (key1 != this.Key) + { + return false; + } + + if (utf8String.Length > 0) + { + if (!this.Key2.HasValue) + { + return false; + } + + ulong key2 = MessagePack.Internal.AutomataKeyGen.GetKey(ref utf8String); + if (key2 != this.Key2.Value) + { + return false; + } + } + else if (this.Key2.HasValue) + { + return false; + } + + return true; + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs new file mode 100644 index 000000000..3d560bbc5 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContext.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; +using PolyType; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +public sealed partial class NerdbankMessagePackFormatter +{ + internal class FormatterContext(MessagePackSerializer serializer, ITypeShapeProvider shapeProvider) + { + public MessagePackSerializer Serializer => serializer; + + public ITypeShapeProvider ShapeProvider => shapeProvider; + + public T? Deserialize(ref MessagePackReader reader, CancellationToken cancellationToken = default) + { + return serializer.Deserialize(ref reader, shapeProvider, cancellationToken); + } + + public T Deserialize(in RawMessagePack pack, CancellationToken cancellationToken = default) + { + // TODO: Improve the exception + return serializer.Deserialize(pack, shapeProvider, cancellationToken) + ?? throw new InvalidOperationException("Deserialization failed."); + } + + public object? DeserializeObject(in RawMessagePack pack, Type objectType, CancellationToken cancellationToken = default) + { + MessagePackReader reader = new(pack); + return serializer.DeserializeObject( + ref reader, + shapeProvider.Resolve(objectType), + cancellationToken); + } + + public void Serialize(ref MessagePackWriter writer, T? value, CancellationToken cancellationToken = default) + { + serializer.Serialize(ref writer, value, shapeProvider, cancellationToken); + } + + internal void SerializeObject(ref MessagePackWriter writer, object? value, Type objectType, CancellationToken cancellationToken = default) + { + serializer.SerializeObject(ref writer, value, shapeProvider.Resolve(objectType), cancellationToken); + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs new file mode 100644 index 000000000..b511cbf22 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.FormatterContextBuilder.cs @@ -0,0 +1,219 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Collections.Immutable; +using System.IO.Pipelines; +using Nerdbank.MessagePack; +using PolyType; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +public sealed partial class NerdbankMessagePackFormatter +{ + /// + /// Provides methods to build a serialization context for the . + /// + public class FormatterContextBuilder + { + private readonly NerdbankMessagePackFormatter formatter; + private readonly FormatterContext baseContext; + + private ImmutableArray.Builder? typeShapeProvidersBuilder = null; + + /// + /// Initializes a new instance of the class. + /// + /// The formatter to use. + /// The base context to build upon. + internal FormatterContextBuilder(NerdbankMessagePackFormatter formatter, FormatterContext baseContext) + { + this.formatter = formatter; + this.baseContext = baseContext; + } + + /// + /// Adds a type shape provider to the context. + /// + /// The type shape provider to add. + public void AddTypeShapeProvider(ITypeShapeProvider provider) + { + this.typeShapeProvidersBuilder ??= ImmutableArray.CreateBuilder(); + this.typeShapeProvidersBuilder.Add(provider); + } + + /// + /// Registers an async enumerable type with the context. + /// + /// The type of the async enumerable. + /// The type of the elements in the async enumerable. + public void RegisterAsyncEnumerableType() + where TEnumerable : IAsyncEnumerable + { + MessagePackConverter converter = this.formatter.asyncEnumerableConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a converter with the context. + /// + /// The type the converter handles. + /// The converter to register. + public void RegisterConverter(MessagePackConverter converter) + { + this.baseContext.Serializer.RegisterConverter(converter); + } + + /// + /// Registers known subtypes for a base type with the context. + /// + /// The base type. + /// The mapping of known subtypes. + public void RegisterKnownSubTypes(KnownSubTypeMapping mapping) + { + this.baseContext.Serializer.RegisterKnownSubTypes(mapping); + } + + /// + /// Registers a progress type with the context. + /// + /// The type of the progress. + /// The type of the report. + public void RegisterProgressType() + where TProgress : IProgress + { + MessagePackConverter converter = this.formatter.progressConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a duplex pipe type with the context. + /// + /// The type of the duplex pipe. + public void RegisterDuplexPipeType() + where TPipe : IDuplexPipe + { + MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a pipe reader type with the context. + /// + /// The type of the pipe reader. + public void RegisterPipeReaderType() + where TReader : PipeReader + { + MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a pipe writer type with the context. + /// + /// The type of the pipe writer. + public void RegisterPipeWriterType() + where TWriter : PipeWriter + { + MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); + } + + /// + /// Registers a stream type with the context. + /// + /// The type of the stream. + public void RegisterStreamType() + where TStream : Stream + { + MessagePackConverter converter = this.formatter.pipeConverterResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); + } + + /// + /// Registers an exception type with the context. + /// + /// The type of the exception. + public void RegisterExceptionType() + where TException : Exception + { + MessagePackConverter converter = this.formatter.exceptionResolver.GetConverter(); + this.baseContext.Serializer.RegisterConverter(converter); + } + + /// + /// Registers an RPC marshalable type with the context. + /// + /// The type to register. + public void RegisterRpcMarshalableType() + where T : class + { + if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( + typeof(T), + out JsonRpcProxyOptions? proxyOptions, + out JsonRpcTargetOptions? targetOptions, + out RpcMarshalableAttribute? attribute)) + { + var converter = (RpcMarshalableConverter)Activator.CreateInstance( + typeof(RpcMarshalableConverter<>).MakeGenericType(typeof(T)), + this.formatter, + proxyOptions, + targetOptions, + attribute)!; + + this.baseContext.Serializer.RegisterConverter(converter); + } + + // TODO: Throw? + } + + /// + /// Builds the formatter context. + /// + /// The built formatter context. + internal FormatterContext Build() + { + if (this.typeShapeProvidersBuilder is null || this.typeShapeProvidersBuilder.Count < 1) + { + return this.baseContext; + } + + ITypeShapeProvider provider = this.typeShapeProvidersBuilder.Count == 1 + ? this.typeShapeProvidersBuilder[0] + : new CompositeTypeShapeProvider(this.typeShapeProvidersBuilder.ToImmutable()); + + return new FormatterContext(this.baseContext.Serializer, provider); + } + } + + private class CompositeTypeShapeProvider : ITypeShapeProvider + { + private readonly ImmutableArray providers; + + internal CompositeTypeShapeProvider(ImmutableArray providers) + { + this.providers = providers; + } + + public ITypeShape? GetShape(Type type) + { + foreach (ITypeShapeProvider provider in this.providers) + { + ITypeShape? shape = provider.GetShape(type); + if (shape is not null) + { + return shape; + } + } + + return null; + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs new file mode 100644 index 000000000..204196138 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -0,0 +1,2242 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO.Pipelines; +using System.Reflection; +using System.Runtime.ExceptionServices; +using System.Runtime.Serialization; +using System.Text; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType; +using PolyType.Abstractions; +using PolyType.ReflectionProvider; +using PolyType.SourceGenerator; +using StreamJsonRpc.Protocol; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +public sealed partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory +{ + /// + /// The constant "jsonrpc", in its various forms. + /// + private static readonly CommonString VersionPropertyName = new(Constants.jsonrpc); + + /// + /// The constant "id", in its various forms. + /// + private static readonly CommonString IdPropertyName = new(Constants.id); + + /// + /// The constant "method", in its various forms. + /// + private static readonly CommonString MethodPropertyName = new(Constants.Request.method); + + /// + /// The constant "result", in its various forms. + /// + private static readonly CommonString ResultPropertyName = new(Constants.Result.result); + + /// + /// The constant "error", in its various forms. + /// + private static readonly CommonString ErrorPropertyName = new(Constants.Error.error); + + /// + /// The constant "params", in its various forms. + /// + private static readonly CommonString ParamsPropertyName = new(Constants.Request.@params); + + /// + /// The constant "traceparent", in its various forms. + /// + private static readonly CommonString TraceParentPropertyName = new(Constants.Request.traceparent); + + /// + /// The constant "tracestate", in its various forms. + /// + private static readonly CommonString TraceStatePropertyName = new(Constants.Request.tracestate); + + /// + /// The constant "2.0", in its various forms. + /// + private static readonly CommonString Version2 = new("2.0"); + + /// + /// A cache of property names to declared property types, indexed by their containing parameter object type. + /// + /// + /// All access to this field should be while holding a lock on this member's value. + /// + private static readonly Dictionary> ParameterObjectPropertyTypes = new Dictionary>(); + + /// + /// The serializer context to use for top-level RPC messages. + /// + private readonly FormatterContext rpcContext; + + private readonly ProgressConverterResolver progressConverterResolver; + + private readonly AsyncEnumerableConverterResolver asyncEnumerableConverterResolver; + + private readonly PipeConverterResolver pipeConverterResolver; + + private readonly MessagePackExceptionConverterResolver exceptionResolver; + + private readonly ToStringHelper serializationToStringHelper = new(); + + private readonly ToStringHelper deserializationToStringHelper = new(); + + /// + /// The serializer to use for user data (e.g. arguments, return values and errors). + /// + private FormatterContext userDataContext; + + /// + /// Initializes a new instance of the class. + /// + public NerdbankMessagePackFormatter() + { + // Set up initial options for our own message types. + MessagePackSerializer serializer = new() + { + InternStrings = true, + SerializeDefaultValues = false, + }; + + serializer.RegisterConverter(new RequestIdConverter()); + serializer.RegisterConverter(new JsonRpcMessageConverter(this)); + serializer.RegisterConverter(new JsonRpcRequestConverter(this)); + serializer.RegisterConverter(new JsonRpcResultConverter(this)); + serializer.RegisterConverter(new JsonRpcErrorConverter(this)); + serializer.RegisterConverter(new JsonRpcErrorDetailConverter(this)); + serializer.RegisterConverter(new TraceParentConverter()); + + this.rpcContext = new FormatterContext(serializer, ShapeProvider_StreamJsonRpc.Default); + + // Create the specialized formatters/resolvers that we will inject into the chain for user data. + this.progressConverterResolver = new ProgressConverterResolver(this); + this.asyncEnumerableConverterResolver = new AsyncEnumerableConverterResolver(this); + this.pipeConverterResolver = new PipeConverterResolver(this); + this.exceptionResolver = new MessagePackExceptionConverterResolver(this); + + FormatterContext userDataContext = new( + new() + { + InternStrings = true, + SerializeDefaultValues = false, + }, + ReflectionTypeShapeProvider.Default); + + this.MassageUserDataContext(userDataContext); + this.userDataContext = userDataContext; + } + + private interface IJsonRpcMessagePackRetention + { + /// + /// Gets the original msgpack sequence that was deserialized into this message. + /// + /// + /// The buffer is only retained for a short time. If it has already been cleared, the result of this property is an empty sequence. + /// + ReadOnlySequence OriginalMessagePack { get; } + } + + /// + /// Configures the serialization context for user data with the specified configuration action. + /// + /// The action to configure the serialization context. + public void SetFormatterContext(Action configure) + { + Requires.NotNull(configure, nameof(configure)); + + var builder = new FormatterContextBuilder(this, this.userDataContext); + configure(builder); + + FormatterContext context = builder.Build(); + this.MassageUserDataContext(context); + + this.userDataContext = context; + } + + /// + public JsonRpcMessage Deserialize(ReadOnlySequence contentBuffer) + { + JsonRpcMessage message = this.rpcContext.Serializer.Deserialize(contentBuffer, ShapeProvider_StreamJsonRpc.Default) + ?? throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + + IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; + this.deserializationToStringHelper.Activate(contentBuffer); + try + { + tracingCallbacks?.OnMessageDeserialized(message, this.deserializationToStringHelper); + } + finally + { + this.deserializationToStringHelper.Deactivate(); + } + + return message; + } + + /// + public void Serialize(IBufferWriter contentBuffer, JsonRpcMessage message) + { + if (message is Protocol.JsonRpcRequest request + && request.Arguments is not null + && request.ArgumentsList is null + && request.Arguments is not IReadOnlyDictionary) + { + // This request contains named arguments, but not using a standard dictionary. Convert it to a dictionary so that + // the parameters can be matched to the method we're invoking. + if (GetParamsObjectDictionary(request.Arguments) is { } namedArgs) + { + request.Arguments = namedArgs.ArgumentValues; + request.NamedArgumentDeclaredTypes = namedArgs.ArgumentTypes; + } + } + + var writer = new MessagePackWriter(contentBuffer); + try + { + this.rpcContext.Serializer.Serialize(ref writer, message, this.rpcContext.ShapeProvider); + writer.Flush(); + } + catch (Exception ex) + { + throw new MessagePackSerializationException(string.Format(CultureInfo.CurrentCulture, Resources.ErrorWritingJsonRpcMessage, ex.GetType().Name, ex.Message), ex); + } + } + + /// + public object GetJsonText(JsonRpcMessage message) => message is IJsonRpcMessagePackRetention retainedMsgPack + ? MessagePackSerializer.ConvertToJson(retainedMsgPack.OriginalMessagePack) + : throw new NotSupportedException(); + + /// + Protocol.JsonRpcRequest IJsonRpcMessageFactory.CreateRequestMessage() => new OutboundJsonRpcRequest(this); + + /// + Protocol.JsonRpcError IJsonRpcMessageFactory.CreateErrorMessage() => new JsonRpcError(this.userDataContext); + + /// + Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this, this.rpcContext); + + void IJsonRpcFormatterTracingCallbacks.OnSerializationComplete(JsonRpcMessage message, ReadOnlySequence encodedMessage) + { + IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; + this.serializationToStringHelper.Activate(encodedMessage); + try + { + tracingCallbacks?.OnMessageSerialized(message, this.serializationToStringHelper); + } + finally + { + this.serializationToStringHelper.Deactivate(); + } + } + + /// + /// Extracts a dictionary of property names and values from the specified params object. + /// + /// The params object. + /// A dictionary of argument values and another of declared argument types, or if is null. + /// + /// This method supports DataContractSerializer-compliant types. This includes C# anonymous types. + /// + [return: NotNullIfNotNull(nameof(paramsObject))] + private static (IReadOnlyDictionary ArgumentValues, IReadOnlyDictionary ArgumentTypes)? GetParamsObjectDictionary(object? paramsObject) + { + if (paramsObject is null) + { + return default; + } + + // Look up the argument types dictionary if we saved it before. + Type paramsObjectType = paramsObject.GetType(); + IReadOnlyDictionary? argumentTypes; + lock (ParameterObjectPropertyTypes) + { + ParameterObjectPropertyTypes.TryGetValue(paramsObjectType, out argumentTypes); + } + + // If we couldn't find a previously created argument types dictionary, create a mutable one that we'll build this time. + Dictionary? mutableArgumentTypes = argumentTypes is null ? new Dictionary() : null; + + var result = new Dictionary(StringComparer.Ordinal); + + TypeInfo paramsTypeInfo = paramsObject.GetType().GetTypeInfo(); + bool isDataContract = paramsTypeInfo.GetCustomAttribute() is not null; + + BindingFlags bindingFlags = BindingFlags.FlattenHierarchy | BindingFlags.Public | BindingFlags.Instance; + if (isDataContract) + { + bindingFlags |= BindingFlags.NonPublic; + } + + bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) + { + key = memberInfo.Name; + if (isDataContract) + { + DataMemberAttribute? dataMemberAttribute = memberInfo.GetCustomAttribute(); + if (dataMemberAttribute is null) + { + return false; + } + + if (!dataMemberAttribute.EmitDefaultValue) + { + throw new NotSupportedException($"(DataMemberAttribute.EmitDefaultValue == false) is not supported but was found on: {memberInfo.DeclaringType!.FullName}.{memberInfo.Name}."); + } + + key = dataMemberAttribute.Name ?? memberInfo.Name; + return true; + } + else + { + return memberInfo.GetCustomAttribute() is null; + } + } + + foreach (PropertyInfo property in paramsTypeInfo.GetProperties(bindingFlags)) + { + if (property.GetMethod is not null) + { + if (TryGetSerializationInfo(property, out string key)) + { + result[key] = property.GetValue(paramsObject); + if (mutableArgumentTypes is not null) + { + mutableArgumentTypes[key] = property.PropertyType; + } + } + } + } + + foreach (FieldInfo field in paramsTypeInfo.GetFields(bindingFlags)) + { + if (TryGetSerializationInfo(field, out string key)) + { + result[key] = field.GetValue(paramsObject); + if (mutableArgumentTypes is not null) + { + mutableArgumentTypes[key] = field.FieldType; + } + } + } + + // If we assembled the argument types dictionary this time, save it for next time. + if (mutableArgumentTypes is not null) + { + lock (ParameterObjectPropertyTypes) + { + if (ParameterObjectPropertyTypes.TryGetValue(paramsObjectType, out IReadOnlyDictionary? lostRace)) + { + // Of the two, pick the winner to use ourselves so we consolidate on one and allow the GC to collect the loser sooner. + argumentTypes = lostRace; + } + else + { + ParameterObjectPropertyTypes.Add(paramsObjectType, argumentTypes = mutableArgumentTypes); + } + } + } + + return (result, argumentTypes!); + } + + private static ReadOnlySequence GetSliceForNextToken(ref MessagePackReader reader, in SerializationContext context) + { + SequencePosition startingPosition = reader.Position; + reader.Skip(context); + SequencePosition endingPosition = reader.Position; + return reader.Sequence.Slice(startingPosition, endingPosition); + } + + /// + /// Reads a string with an optimized path for the value "2.0". + /// + /// The reader to use. + /// The decoded string. + private static unsafe string ReadProtocolVersion(ref MessagePackReader reader) + { + if (!reader.TryReadStringSpan(out ReadOnlySpan valueBytes)) + { + // TODO: More specific exception type + throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + } + + // Recognize "2.0" since we expect it and can avoid decoding and allocating a new string for it. + if (Version2.TryRead(valueBytes)) + { + return Version2.Value; + } + else + { + // It wasn't the expected value, so decode it. + fixed (byte* pValueBytes = valueBytes) + { + return Encoding.UTF8.GetString(pValueBytes, valueBytes.Length); + } + } + } + + /// + /// Writes the JSON-RPC version property name and value in a highly optimized way. + /// + private static void WriteProtocolVersionPropertyAndValue(ref MessagePackWriter writer, string version) + { + VersionPropertyName.Write(ref writer); + if (!Version2.TryWrite(ref writer, version)) + { + writer.Write(version); + } + } + + private static void ReadUnknownProperty(ref MessagePackReader reader, in SerializationContext context, ref Dictionary>? topLevelProperties, ReadOnlySpan stringKey) + { + topLevelProperties ??= new Dictionary>(StringComparer.Ordinal); +#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER + string name = Encoding.UTF8.GetString(stringKey); +#else + string name = Encoding.UTF8.GetString(stringKey.ToArray()); +#endif + topLevelProperties.Add(name, GetSliceForNextToken(ref reader, context)); + } + + /// + /// Takes the user-supplied resolver for their data types and prepares the wrapping options + /// and the dynamic object wrapper for serialization. + /// + /// The options for user data that is supplied by the user (or the default). + private void MassageUserDataContext(FormatterContext userDataContext) + { + // Add our own resolvers to fill in specialized behavior if the user doesn't provide/override it by their own resolver. + userDataContext.Serializer.RegisterConverter(RequestIdConverter.Instance); + userDataContext.Serializer.RegisterConverter(EventArgsConverter.Instance); + } + + private class MessagePackFormatterConverter : IFormatterConverter + { + private readonly FormatterContext context; + + internal MessagePackFormatterConverter(FormatterContext formatterContext) + { + this.context = formatterContext; + } + +#pragma warning disable CS8766 // This method may in fact return null, and no one cares. + public object? Convert(object value, Type type) +#pragma warning restore CS8766 + { + return this.context.DeserializeObject((RawMessagePack)value, type); + } + + public object Convert(object value, TypeCode typeCode) + { + return typeCode switch + { + TypeCode.Object => this.context.Deserialize((RawMessagePack)value), + _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), + }; + } + + public bool ToBoolean(object value) => this.context.Deserialize((RawMessagePack)value); + + public byte ToByte(object value) => this.context.Deserialize((RawMessagePack)value); + + public char ToChar(object value) => this.context.Deserialize((RawMessagePack)value); + + public DateTime ToDateTime(object value) => this.context.Deserialize((RawMessagePack)value); + + public decimal ToDecimal(object value) => this.context.Deserialize((RawMessagePack)value); + + public double ToDouble(object value) => this.context.Deserialize((RawMessagePack)value); + + public short ToInt16(object value) => this.context.Deserialize((RawMessagePack)value); + + public int ToInt32(object value) => this.context.Deserialize((RawMessagePack)value); + + public long ToInt64(object value) => this.context.Deserialize((RawMessagePack)value); + + public sbyte ToSByte(object value) => this.context.Deserialize((RawMessagePack)value); + + public float ToSingle(object value) => this.context.Deserialize((RawMessagePack)value); + + public string? ToString(object value) => value is null ? null : this.context.Deserialize((RawMessagePack)value); + + public ushort ToUInt16(object value) => this.context.Deserialize((RawMessagePack)value); + + public uint ToUInt32(object value) => this.context.Deserialize((RawMessagePack)value); + + public ulong ToUInt64(object value) => this.context.Deserialize((RawMessagePack)value); + } + + /// + /// A recyclable object that can serialize a message to JSON on demand. + /// + /// + /// In perf traces, creation of this object used to show up as one of the most allocated objects. + /// It is used even when tracing isn't active. So we changed its design to be reused, + /// since its lifetime is only required during a synchronous call to a trace API. + /// + private class ToStringHelper + { + private ReadOnlySequence? encodedMessage; + private string? jsonString; + + public override string ToString() + { + Verify.Operation(this.encodedMessage.HasValue, "This object has not been activated. It may have already been recycled."); + + return this.jsonString ??= MessagePackSerializer.ConvertToJson(this.encodedMessage.Value); + } + + /// + /// Initializes this object to represent a message. + /// + internal void Activate(ReadOnlySequence encodedMessage) + { + this.encodedMessage = encodedMessage; + } + + /// + /// Cleans out this object to release memory and ensure throws if someone uses it after deactivation. + /// + internal void Deactivate() + { + this.encodedMessage = null; + this.jsonString = null; + } + } + + private class RequestIdConverter : MessagePackConverter + { + internal static readonly RequestIdConverter Instance = new(); + + public override RequestId Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + if (reader.NextMessagePackType == MessagePackType.Integer) + { + return new RequestId(reader.ReadInt64()); + } + else + { + // Do *not* read as an interned string here because this ID should be unique. + return new RequestId(reader.ReadString()); + } + } + + public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) + { + context.DepthStep(); + + if (value.Number.HasValue) + { + writer.Write(value.Number.Value); + } + else + { + writer.Write(value.String); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" + { + "type": ["string", { "type": "integer", "format": "int64" }] + } + """)?.AsObject(); + } + + private class ProgressConverterResolver + { + private readonly NerdbankMessagePackFormatter mainFormatter; + + internal ProgressConverterResolver(NerdbankMessagePackFormatter formatter) + { + this.mainFormatter = formatter; + } + + public MessagePackConverter GetConverter() + { + MessagePackConverter? converter = default; + + if (MessageFormatterProgressTracker.CanDeserialize(typeof(T))) + { + converter = new PreciseTypeConverter(this.mainFormatter); + } + else if (MessageFormatterProgressTracker.CanSerialize(typeof(T))) + { + converter = new ProgressClientConverter(this.mainFormatter); + } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); + } + + /// + /// Converts an instance of to a progress token. + /// + private class ProgressClientConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal ProgressClientConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override TClass Read(ref MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException("This formatter only serializes IProgress instances."); + } + + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(ProgressClientConverter)); + } + } + + /// + /// Converts a progress token to an or an into a token. + /// + private class PreciseTypeConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal PreciseTypeConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + [return: MaybeNull] + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override TClass? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default!; + } + + Assumes.NotNull(this.formatter.JsonRpc); + RawMessagePack token = reader.ReadRaw(context); + bool clientRequiresNamedArgs = this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; + return (TClass)this.formatter.FormatterProgressTracker.CreateProgress(this.formatter.JsonRpc, token, typeof(TClass), clientRequiresNamedArgs); + } + + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = this.formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); + } + } + } + + private class AsyncEnumerableConverterResolver + { + private readonly NerdbankMessagePackFormatter mainFormatter; + + internal AsyncEnumerableConverterResolver(NerdbankMessagePackFormatter formatter) + { + this.mainFormatter = formatter; + } + + public MessagePackConverter GetConverter() + { + MessagePackConverter? converter = default; + + if (TrackerHelpers>.IsActualInterfaceMatch(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PreciseTypeConverter<>).MakeGenericType(typeof(T).GenericTypeArguments[0]), new object[] { this.mainFormatter }); + } + else if (TrackerHelpers>.FindInterfaceImplementedBy(typeof(T)) is { } iface) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(GeneratorConverter<,>).MakeGenericType(typeof(T), iface.GenericTypeArguments[0]), new object[] { this.mainFormatter }); + } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); + } + + /// + /// Converts an enumeration token to an + /// or an into an enumeration token. + /// +#pragma warning disable CA1812 + private partial class PreciseTypeConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter> +#pragma warning restore CA1812 + { + /// + /// The constant "token", in its various forms. + /// + private static readonly CommonString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); + + /// + /// The constant "values", in its various forms. + /// + private static readonly CommonString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); + + public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default; + } + + context.DepthStep(); + RawMessagePack? token = default; + IReadOnlyList? initialElements = null; + int propertyCount = reader.ReadMapHeader(); + for (int i = 0; i < propertyCount; i++) + { + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new MessagePackSerializationException(Resources.UnexpectedErrorProcessingJsonRpc); + } + + if (TokenPropertyName.TryRead(stringKey)) + { + token = reader.ReadRaw(context); + } + else if (ValuesPropertyName.TryRead(stringKey)) + { + initialElements = context.GetConverter>(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + reader.Skip(context); + } + } + + return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? token : null, initialElements); + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) + { + Serialize_Shared(mainFormatter, ref writer, value, context); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PreciseTypeConverter)); + } + + internal static void Serialize_Shared(NerdbankMessagePackFormatter mainFormatter, ref MessagePackWriter writer, IAsyncEnumerable? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + (IReadOnlyList Elements, bool Finished) prefetched = value.TearOffPrefetchedElements(); + long token = mainFormatter.EnumerableTracker.GetToken(value); + + int propertyCount = 0; + if (prefetched.Elements.Count > 0) + { + propertyCount++; + } + + if (!prefetched.Finished) + { + propertyCount++; + } + + writer.WriteMapHeader(propertyCount); + + if (!prefetched.Finished) + { + writer.Write(MessageFormatterEnumerableTracker.TokenPropertyName); + writer.Write(token); + } + + if (prefetched.Elements.Count > 0) + { + writer.Write(MessageFormatterEnumerableTracker.ValuesPropertyName); + context.GetConverter>(context.TypeShapeProvider).Write(ref writer, prefetched.Elements, context); + } + } + } + } + + /// + /// Converts an instance of to an enumeration token. + /// +#pragma warning disable CA1812 + private class GeneratorConverter(NerdbankMessagePackFormatter mainFormatter) : MessagePackConverter where TClass : IAsyncEnumerable +#pragma warning restore CA1812 + { + public override TClass Read(ref MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException(); + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + PreciseTypeConverter.Serialize_Shared(mainFormatter, ref writer, value, context); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(GeneratorConverter)); + } + } + } + + private class PipeConverterResolver + { + private readonly NerdbankMessagePackFormatter mainFormatter; + + internal PipeConverterResolver(NerdbankMessagePackFormatter formatter) + { + this.mainFormatter = formatter; + } + + public MessagePackConverter GetConverter() + { + MessagePackConverter? converter = default; + + if (typeof(IDuplexPipe).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(DuplexPipeConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(PipeReader).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeReaderConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(PipeWriter).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(PipeWriterConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + else if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + converter = (MessagePackConverter?)Activator.CreateInstance(typeof(StreamConverter<>).MakeGenericType(typeof(T)), this.mainFormatter)!; + } + + // TODO: Improve Exception + return converter ?? throw new NotSupportedException(); + } + +#pragma warning disable CA1812 + private class DuplexPipeConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : class, IDuplexPipe +#pragma warning restore CA1812 + { + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(DuplexPipeConverter)); + } + } + +#pragma warning disable CA1812 + private class PipeReaderConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : PipeReader +#pragma warning restore CA1812 + { + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipeReader(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PipeReaderConverter)); + } + } + +#pragma warning disable CA1812 + private class PipeWriterConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : PipeWriter +#pragma warning restore CA1812 + { + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + return (T)formatter.DuplexPipeTracker.GetPipeWriter(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(PipeWriterConverter)); + } + } + +#pragma warning disable CA1812 + private class StreamConverter : MessagePackConverter + where T : Stream +#pragma warning restore CA1812 + { + private readonly NerdbankMessagePackFormatter formatter; + + public StreamConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + return (T)this.formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + if (this.formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(StreamConverter)); + } + } + } + +#pragma warning disable CA1812 + private class RpcMarshalableConverter( + NerdbankMessagePackFormatter formatter, + JsonRpcProxyOptions proxyOptions, + JsonRpcTargetOptions targetOptions, + RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter + where T : class +#pragma warning restore CA1812 + { + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter.rpcContext.Deserialize(ref reader); + return token.HasValue ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); + formatter.rpcContext.Serialize(ref writer, token); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(RpcMarshalableConverter)); + } + } + + /// + /// Manages serialization of any -derived type that follows standard rules. + /// + /// + /// A serializable class will: + /// 1. Derive from + /// 2. Be attributed with + /// 3. Declare a constructor with a signature of (, ). + /// + private class MessagePackExceptionConverterResolver + { + /// + /// Tracks recursion count while serializing or deserializing an exception. + /// + /// + /// This is placed here (outside the generic class) + /// so that it's one counter shared across all exception types that may be serialized or deserialized. + /// + private static ThreadLocal exceptionRecursionCounter = new(); + + private readonly object[] formatterActivationArgs; + + internal MessagePackExceptionConverterResolver(NerdbankMessagePackFormatter formatter) + { + this.formatterActivationArgs = new object[] { formatter }; + } + + public MessagePackConverter GetConverter() + { + MessagePackConverter? formatter = null; + if (typeof(Exception).IsAssignableFrom(typeof(T)) && typeof(T).GetCustomAttribute() is object) + { + formatter = (MessagePackConverter)Activator.CreateInstance(typeof(ExceptionConverter<>).MakeGenericType(typeof(T)), this.formatterActivationArgs)!; + } + + // TODO: Improve Exception + return formatter ?? throw new NotSupportedException(); + } + +#pragma warning disable CA1812 + private partial class ExceptionConverter(NerdbankMessagePackFormatter formatter) : MessagePackConverter + where T : Exception +#pragma warning restore CA1812 + { + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + Assumes.NotNull(formatter.JsonRpc); + if (reader.TryReadNil()) + { + return null; + } + + // We have to guard our own recursion because the serializer has no visibility into inner exceptions. + // Each exception in the russian doll is a new serialization job from its perspective. + exceptionRecursionCounter.Value++; + try + { + if (exceptionRecursionCounter.Value > formatter.JsonRpc.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception. + // Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded. + reader.Skip(context); + return null; + } + + // TODO: Is this the right context? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcContext)); + int memberCount = reader.ReadMapHeader(); + for (int i = 0; i < memberCount; i++) + { + string? name = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context) + ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + + // SerializationInfo.GetValue(string, typeof(object)) does not call our formatter, + // so the caller will get a boxed RawMessagePack struct in that case. + // Although we can't do much about *that* in general, we can at least ensure that null values + // are represented as null instead of this boxed struct. + var value = reader.TryReadNil() ? null : (object)reader.ReadRaw(context); + + info.AddSafeValue(name, value); + } + + return ExceptionSerializationHelpers.Deserialize(formatter.JsonRpc, info, formatter.JsonRpc.TraceSource); + } + finally + { + exceptionRecursionCounter.Value--; + } + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + exceptionRecursionCounter.Value++; + try + { + if (exceptionRecursionCounter.Value > formatter.JsonRpc?.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception. + writer.WriteNil(); + return; + } + + // TODO: Is this the right context? + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter.rpcContext)); + ExceptionSerializationHelpers.Serialize(value, info); + writer.WriteMapHeader(info.GetSafeMemberCount()); + foreach (SerializationEntry element in info.GetSafeMembers()) + { + writer.Write(element.Name); +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + formatter.rpcContext.SerializeObject( + ref writer, + element.Value, + element.ObjectType); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + } + finally + { + exceptionRecursionCounter.Value--; + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(ExceptionConverter)); + } + } + } + + private class JsonRpcMessageConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcMessageConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override JsonRpcMessage? Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + MessagePackReader readAhead = reader.CreatePeekReader(); + int propertyCount = readAhead.ReadMapHeader(); + for (int i = 0; i < propertyCount; i++) + { + // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. + // MessagePackFormatter: ReadOnlySpan stringKey = MessagePack.Internal.CodeGenHelpers.ReadStringSpan(ref readAhead); + if (!readAhead.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (MethodPropertyName.TryRead(stringKey)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(stringKey)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(stringKey)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + readAhead.Skip(context); + } + } + + throw new UnrecognizedJsonRpcMessageException(); + } + + public override void Write(ref MessagePackWriter writer, in JsonRpcMessage? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + using (this.formatter.TrackSerialization(value)) + { + context.DepthStep(); + + switch (value) + { + case Protocol.JsonRpcRequest request: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, request, context); + break; + case Protocol.JsonRpcResult result: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, result, context); + break; + case Protocol.JsonRpcError error: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, error, context); + break; + default: + throw new NotSupportedException("Unexpected JsonRpcMessage-derived type: " + value.GetType().Name); + } + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return base.GetJsonSchema(context, typeShape); + } + } + + private class JsonRpcRequestConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcRequestConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override Protocol.JsonRpcRequest? Read(ref MessagePackReader reader, SerializationContext context) + { + var result = new JsonRpcRequest(this.formatter) + { + OriginalMessagePack = reader.Sequence, + }; + + context.DepthStep(); + + int propertyCount = reader.ReadMapHeader(); + Dictionary>? topLevelProperties = null; + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) + { + // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (VersionPropertyName.TryRead(stringKey)) + { + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(stringKey)) + { + result.RequestId = context.GetConverter(null).Read(ref reader, context); + } + else if (MethodPropertyName.TryRead(stringKey)) + { + result.Method = context.GetConverter(null).Read(ref reader, context); + } + else if (ParamsPropertyName.TryRead(stringKey)) + { + SequencePosition paramsTokenStartPosition = reader.Position; + + // Parse out the arguments into a dictionary or array, but don't deserialize them because we don't yet know what types to deserialize them to. + switch (reader.NextMessagePackType) + { + case MessagePackType.Array: + var positionalArgs = new ReadOnlySequence[reader.ReadArrayHeader()]; + for (int i = 0; i < positionalArgs.Length; i++) + { + positionalArgs[i] = GetSliceForNextToken(ref reader, context); + } + + result.MsgPackPositionalArguments = positionalArgs; + break; + case MessagePackType.Map: + int namedArgsCount = reader.ReadMapHeader(); + var namedArgs = new Dictionary>(namedArgsCount); + for (int i = 0; i < namedArgsCount; i++) + { + string? propertyName = context.GetConverter(null).Read(ref reader, context); + if (propertyName is null) + { + throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + } + + namedArgs.Add(propertyName, GetSliceForNextToken(ref reader, context)); + } + + result.MsgPackNamedArguments = namedArgs; + break; + case MessagePackType.Nil: + result.MsgPackPositionalArguments = Array.Empty>(); + reader.ReadNil(); + break; + case MessagePackType type: + throw new MessagePackSerializationException("Expected a map or array of arguments but got " + type); + } + + result.MsgPackArguments = reader.Sequence.Slice(paramsTokenStartPosition, reader.Position); + } + else if (TraceParentPropertyName.TryRead(stringKey)) + { + TraceParent traceParent = context.GetConverter(null).Read(ref reader, context); + result.TraceParent = traceParent.ToString(); + } + else if (TraceStatePropertyName.TryRead(stringKey)) + { + result.TraceState = ReadTraceState(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties, stringKey); + } + } + + if (topLevelProperties is not null) + { + result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataContext, topLevelProperties); + } + + this.formatter.TryHandleSpecialIncomingMessage(result); + + return result; + } + + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequest? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + context.DepthStep(); + + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + + int mapElementCount = value.RequestId.IsEmpty ? 3 : 4; + if (value.TraceParent?.Length > 0) + { + mapElementCount++; + if (value.TraceState?.Length > 0) + { + mapElementCount++; + } + } + + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + if (!value.RequestId.IsEmpty) + { + IdPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.RequestId, context); + } + + MethodPropertyName.Write(ref writer); + writer.Write(value.Method); + + ParamsPropertyName.Write(ref writer); + + // TODO: Get from SetOptions + ITypeShapeProvider? userShapeProvider = context.TypeShapeProvider; + + if (value.ArgumentsList is not null) + { + writer.WriteArrayHeader(value.ArgumentsList.Count); + + + for (int i = 0; i < value.ArgumentsList.Count; i++) + { + object? arg = value.ArgumentsList[i]; + ITypeShape? argShape = arg is null + ? null + : value.ArgumentListDeclaredTypes is not null + ? userShapeProvider?.GetShape(value.ArgumentListDeclaredTypes[i]) + : ReflectionTypeShapeProvider.Default.Resolve(arg.GetType()); + + if (argShape is not null) + { +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.Serializer.SerializeObject(ref writer, arg, argShape, context.CancellationToken); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + else + { + // TODO: NOT REALLY SURE ABOUT THIS YET + writer.WriteNil(); + } + } + } + else if (value.NamedArguments is not null) + { + writer.WriteMapHeader(value.NamedArguments.Count); + foreach (KeyValuePair entry in value.NamedArguments) + { + writer.Write(entry.Key); + ITypeShape? argShape = value.NamedArgumentDeclaredTypes?[entry.Key] is Type argType + ? userShapeProvider?.GetShape(argType) + : null; + + if (argShape is not null) + { +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.Serializer.SerializeObject(ref writer, entry.Value, argShape, context.CancellationToken); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + else + { + // TODO: NOT REALLY SURE ABOUT THIS YET + writer.WriteNil(); + } + } + } + else + { + writer.WriteNil(); + } + + if (value.TraceParent?.Length > 0) + { + TraceParentPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, new TraceParent(value.TraceParent), context); + + if (value.TraceState?.Length > 0) + { + TraceStatePropertyName.Write(ref writer); + WriteTraceState(ref writer, value.TraceState); + } + } + + topLevelPropertyBag?.WriteProperties(ref writer); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcRequestConverter)); + } + + private static void WriteTraceState(ref MessagePackWriter writer, string traceState) + { + ReadOnlySpan traceStateChars = traceState.AsSpan(); + + // Count elements first so we can write the header. + int elementCount = 1; + int commaIndex; + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) + { + elementCount++; + traceStateChars = traceStateChars.Slice(commaIndex + 1); + } + + // For every element, we have a key and value to record. + writer.WriteArrayHeader(elementCount * 2); + + traceStateChars = traceState.AsSpan(); + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) + { + ReadOnlySpan element = traceStateChars.Slice(0, commaIndex); + WritePair(ref writer, element); + traceStateChars = traceStateChars.Slice(commaIndex + 1); + } + + // Write out the last one. + WritePair(ref writer, traceStateChars); + + static void WritePair(ref MessagePackWriter writer, ReadOnlySpan pair) + { + int equalsIndex = pair.IndexOf('='); + ReadOnlySpan key = pair.Slice(0, equalsIndex); + ReadOnlySpan value = pair.Slice(equalsIndex + 1); + writer.Write(key); + writer.Write(value); + } + } + + private static unsafe string ReadTraceState(ref MessagePackReader reader, SerializationContext context) + { + int elements = reader.ReadArrayHeader(); + if (elements % 2 != 0) + { + throw new NotSupportedException("Odd number of elements not expected."); + } + + // With care, we could probably assemble this string with just two allocations (the string + a char[]). + var resultBuilder = new StringBuilder(); + for (int i = 0; i < elements; i += 2) + { + if (resultBuilder.Length > 0) + { + resultBuilder.Append(','); + } + + // We assume the key is a frequent string, and the value is unique, + // so we optimize whether to use string interning or not on that basis. + resultBuilder.Append(context.GetConverter(null).Read(ref reader, context)); + resultBuilder.Append('='); + resultBuilder.Append(reader.ReadString()); + } + + return resultBuilder.ToString(); + } + } + + private partial class JsonRpcResultConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcResultConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, SerializationContext context) + { + var result = new JsonRpcResult(this.formatter, this.formatter.userDataContext) + { + OriginalMessagePack = reader.Sequence, + }; + + Dictionary>? topLevelProperties = null; + context.DepthStep(); + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) + { + // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (VersionPropertyName.TryRead(stringKey)) + { + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(stringKey)) + { + result.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(stringKey)) + { + result.MsgPackResult = GetSliceForNextToken(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties, stringKey); + } + } + + if (topLevelProperties is not null) + { + result.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataContext, topLevelProperties); + } + + return result; + } + + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResult? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + var topLevelPropertyBagMessage = value as IMessageWithTopLevelPropertyBag; + + int mapElementCount = 3; + mapElementCount += (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + IdPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); + + ResultPropertyName.Write(ref writer); + + ITypeShape? typeShape = value.ResultDeclaredType is not null && value.ResultDeclaredType != typeof(void) + ? this.formatter.userDataContext.ShapeProvider.Resolve(value.ResultDeclaredType) + : value.Result is null + ? null + : this.formatter.userDataContext.ShapeProvider.Resolve(value.Result.GetType()); + + if (typeShape is not null) + { +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.Serializer.SerializeObject(ref writer, value.Result, typeShape, context.CancellationToken); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + else + { + // TODO: NOT REALLY SURE ABOUT THIS YET + writer.WriteNil(); + } + + (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.WriteProperties(ref writer); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcResultConverter)); + } + } + + private partial class JsonRpcErrorConverter : MessagePackConverter + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcErrorConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override Protocol.JsonRpcError Read(ref MessagePackReader reader, SerializationContext context) + { + var error = new JsonRpcError(this.formatter.rpcContext) + { + OriginalMessagePack = reader.Sequence, + }; + + Dictionary>? topLevelProperties = null; + + context.DepthStep(); + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + { + // We read the property name in this fancy way in order to avoid paying to decode and allocate a string when we already know what we're looking for. + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (VersionPropertyName.TryRead(stringKey)) + { + error.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(stringKey)) + { + error.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(stringKey)) + { + error.Error = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties, stringKey); + } + } + + if (topLevelProperties is not null) + { + error.TopLevelPropertyBag = new TopLevelPropertyBag(this.formatter.userDataContext, topLevelProperties); + } + + return error; + } + + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + + int mapElementCount = 3; + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + IdPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); + + ErrorPropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.Error, context); + + topLevelPropertyBag?.WriteProperties(ref writer); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcErrorConverter)); + } + } + + private partial class JsonRpcErrorDetailConverter : MessagePackConverter + { + private static readonly CommonString CodePropertyName = new("code"); + private static readonly CommonString MessagePropertyName = new("message"); + private static readonly CommonString DataPropertyName = new("data"); + + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcErrorDetailConverter(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader reader, SerializationContext context) + { + var result = new JsonRpcError.ErrorDetail(this.formatter.userDataContext); + context.DepthStep(); + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + { + if (!reader.TryReadStringSpan(out ReadOnlySpan stringKey)) + { + throw new UnrecognizedJsonRpcMessageException(); + } + + if (CodePropertyName.TryRead(stringKey)) + { + result.Code = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (MessagePropertyName.TryRead(stringKey)) + { + result.Message = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (DataPropertyName.TryRead(stringKey)) + { + result.MsgPackData = GetSliceForNextToken(ref reader, context); + } + else + { + reader.Skip(context); + } + } + + return result; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "")] + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError.ErrorDetail? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + writer.WriteMapHeader(3); + + CodePropertyName.Write(ref writer); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.Code, context); + + MessagePropertyName.Write(ref writer); + writer.Write(value.Message); + + DataPropertyName.Write(ref writer); +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + this.formatter.userDataContext.Serializer.SerializeObject( + ref writer, + value.Data, + this.formatter.userDataContext.ShapeProvider.Resolve()); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(JsonRpcErrorDetailConverter)); + } + } + + /// + /// Enables formatting the default/empty class. + /// + private class EventArgsConverter : MessagePackConverter + { + internal static readonly EventArgsConverter Instance = new(); + + private EventArgsConverter() + { + } + + /// + public override void Write(ref MessagePackWriter writer, in EventArgs? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + writer.WriteMapHeader(0); + } + + /// + public override EventArgs Read(ref MessagePackReader reader, SerializationContext context) + { + reader.Skip(context); + return EventArgs.Empty; + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(EventArgsConverter)); + } + } + + private class TraceParentConverter : MessagePackConverter + { + public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.ReadArrayHeader() != 2) + { + throw new NotSupportedException("Unexpected array length."); + } + + var result = default(TraceParent); + result.Version = reader.ReadByte(); + if (result.Version != 0) + { + throw new NotSupportedException("traceparent version " + result.Version + " is not supported."); + } + + if (reader.ReadArrayHeader() != 3) + { + throw new NotSupportedException("Unexpected array length in version-format."); + } + + ReadOnlySequence bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected traceid not found."); + bytes.CopyTo(new Span(result.TraceId, TraceParent.TraceIdByteCount)); + + bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected parentid not found."); + bytes.CopyTo(new Span(result.ParentId, TraceParent.ParentIdByteCount)); + + result.Flags = (TraceParent.TraceFlags)reader.ReadByte(); + + return result; + } + + public unsafe override void Write(ref MessagePackWriter writer, in TraceParent value, SerializationContext context) + { + if (value.Version != 0) + { + throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); + } + + writer.WriteArrayHeader(2); + + writer.Write(value.Version); + + writer.WriteArrayHeader(3); + + fixed (byte* traceId = value.TraceId) + { + writer.Write(new ReadOnlySpan(traceId, TraceParent.TraceIdByteCount)); + } + + fixed (byte* parentId = value.ParentId) + { + writer.Write(new ReadOnlySpan(parentId, TraceParent.ParentIdByteCount)); + } + + writer.Write((byte)value.Flags); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) + { + return CreateUndocumentedSchema(typeof(TraceParentConverter)); + } + } + + private class TopLevelPropertyBag : TopLevelPropertyBagBase + { + private readonly FormatterContext formatterContext; + private readonly IReadOnlyDictionary>? inboundUnknownProperties; + + /// + /// Initializes a new instance of the class + /// for an incoming message. + /// + /// The serializer options to use for this data. + /// The map of unrecognized inbound properties. + internal TopLevelPropertyBag(FormatterContext userDataContext, IReadOnlyDictionary> inboundUnknownProperties) + : base(isOutbound: false) + { + this.formatterContext = userDataContext; + this.inboundUnknownProperties = inboundUnknownProperties; + } + + /// + /// Initializes a new instance of the class + /// for an outbound message. + /// + /// The serializer options to use for this data. + internal TopLevelPropertyBag(FormatterContext formatterContext) + : base(isOutbound: true) + { + this.formatterContext = formatterContext; + } + + internal int PropertyCount => this.inboundUnknownProperties?.Count ?? this.OutboundProperties?.Count ?? 0; + + /// + /// Writes the properties tracked by this collection to a messagepack writer. + /// + /// The writer to use. + internal void WriteProperties(ref MessagePackWriter writer) + { + if (this.inboundUnknownProperties is not null) + { + // We're actually re-transmitting an incoming message (remote target feature). + // We need to copy all the properties that were in the original message. + // Don't implement this without enabling the tests for the scenario found in JsonRpcRemoteTargetMessagePackFormatterTests.cs. + // The tests fail for reasons even without this support, so there's work to do beyond just implementing this. + throw new NotImplementedException(); + + ////foreach (KeyValuePair> entry in this.inboundUnknownProperties) + ////{ + //// writer.Write(entry.Key); + //// writer.Write(entry.Value); + ////} + } + else + { + foreach (KeyValuePair entry in this.OutboundProperties) + { + ITypeShape shape = this.formatterContext.ShapeProvider.Resolve(entry.Value.DeclaredType); + + writer.Write(entry.Key); + this.formatterContext.Serializer.SerializeObject(ref writer, entry.Value.Value, shape); + } + } + } + + protected internal override bool TryGetTopLevelProperty(string name, [MaybeNull] out T value) + { + if (this.inboundUnknownProperties is null) + { + throw new InvalidOperationException(Resources.InboundMessageOnly); + } + + value = default; + + if (this.inboundUnknownProperties.TryGetValue(name, out ReadOnlySequence serializedValue) is true) + { + var reader = new MessagePackReader(serializedValue); + value = this.formatterContext.Serializer.Deserialize(ref reader, this.formatterContext.ShapeProvider); + return true; + } + + return false; + } + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + [DataContract] + private class OutboundJsonRpcRequest : JsonRpcRequestBase + { + private readonly NerdbankMessagePackFormatter formatter; + + internal OutboundJsonRpcRequest(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter ?? throw new ArgumentNullException(nameof(formatter)); + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataContext); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + [DataContract] + private class JsonRpcRequest : JsonRpcRequestBase, IJsonRpcMessagePackRetention + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcRequest(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter ?? throw new ArgumentNullException(nameof(formatter)); + } + + public override int ArgumentCount => this.MsgPackNamedArguments?.Count ?? this.MsgPackPositionalArguments?.Count ?? base.ArgumentCount; + + public override IEnumerable? ArgumentNames => this.MsgPackNamedArguments?.Keys; + + public ReadOnlySequence OriginalMessagePack { get; internal set; } + + internal ReadOnlySequence MsgPackArguments { get; set; } + + internal IReadOnlyDictionary>? MsgPackNamedArguments { get; set; } + + internal IReadOnlyList>? MsgPackPositionalArguments { get; set; } + + public override ArgumentMatchResult TryGetTypedArguments(ReadOnlySpan parameters, Span typedArguments) + { + using (this.formatter.TrackDeserialization(this, parameters)) + { + if (parameters.Length == 1 && this.MsgPackNamedArguments is not null) + { + if (this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.UseSingleObjectParameterDeserialization ?? false) + { + var reader = new MessagePackReader(this.MsgPackArguments); + try + { + typedArguments[0] = this.formatter.userDataContext.Serializer.DeserializeObject( + ref reader, + this.formatter.userDataContext.ShapeProvider.Resolve(parameters[0].ParameterType)); + + return ArgumentMatchResult.Success; + } + catch (MessagePackSerializationException) + { + return ArgumentMatchResult.ParameterArgumentTypeMismatch; + } + } + } + + return base.TryGetTypedArguments(parameters, typedArguments); + } + } + + public override bool TryGetArgumentByNameOrIndex(string? name, int position, Type? typeHint, out object? value) + { + // If anyone asks us for an argument *after* we've been told deserialization is done, there's something very wrong. + Assumes.True(this.MsgPackNamedArguments is not null || this.MsgPackPositionalArguments is not null); + + ReadOnlySequence msgpackArgument = default; + if (position >= 0 && this.MsgPackPositionalArguments?.Count > position) + { + msgpackArgument = this.MsgPackPositionalArguments[position]; + } + else if (name is not null && this.MsgPackNamedArguments is not null) + { + this.MsgPackNamedArguments.TryGetValue(name, out msgpackArgument); + } + + if (msgpackArgument.IsEmpty) + { + value = null; + return false; + } + + var reader = new MessagePackReader(msgpackArgument); + using (this.formatter.TrackDeserialization(this)) + { + try + { + value = this.formatter.userDataContext.Serializer.DeserializeObject( + ref reader, + this.formatter.userDataContext.ShapeProvider.Resolve(typeHint ?? typeof(object))); + + return true; + } + catch (MessagePackSerializationException ex) + { + if (this.formatter.JsonRpc?.TraceSource.Switch.ShouldTrace(TraceEventType.Warning) ?? false) + { + this.formatter.JsonRpc.TraceSource.TraceEvent(TraceEventType.Warning, (int)JsonRpc.TraceEvents.MethodArgumentDeserializationFailure, Resources.FailureDeserializingRpcArgument, name, position, typeHint, ex); + } + + throw new RpcArgumentDeserializationException(name, position, typeHint, ex); + } + } + } + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + this.MsgPackNamedArguments = null; + this.MsgPackPositionalArguments = null; + this.TopLevelPropertyBag = null; + this.MsgPackArguments = default; + this.OriginalMessagePack = default; + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.userDataContext); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + [DataContract] + private class JsonRpcResult : JsonRpcResultBase, IJsonRpcMessagePackRetention + { + private readonly NerdbankMessagePackFormatter formatter; + private readonly FormatterContext serializerOptions; + + private Exception? resultDeserializationException; + + internal JsonRpcResult(NerdbankMessagePackFormatter formatter, FormatterContext serializationOptions) + { + this.formatter = formatter; + this.serializerOptions = serializationOptions; + } + + public ReadOnlySequence OriginalMessagePack { get; internal set; } + + internal ReadOnlySequence MsgPackResult { get; set; } + + public override T GetResult() + { + if (this.resultDeserializationException is not null) + { + ExceptionDispatchInfo.Capture(this.resultDeserializationException).Throw(); + } + + return this.MsgPackResult.IsEmpty + ? (T)this.Result! + : this.serializerOptions.Serializer.Deserialize(this.MsgPackResult, this.serializerOptions.ShapeProvider) + ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + } + + protected internal override void SetExpectedResultType(Type resultType) + { + Verify.Operation(!this.MsgPackResult.IsEmpty, "Result is no longer available or has already been deserialized."); + + var reader = new MessagePackReader(this.MsgPackResult); + try + { + using (this.formatter.TrackDeserialization(this)) + { + this.Result = this.serializerOptions.Serializer.DeserializeObject( + ref reader, + this.serializerOptions.ShapeProvider.Resolve(resultType)); + } + + this.MsgPackResult = default; + } + catch (MessagePackSerializationException ex) + { + // This was a best effort anyway. We'll throw again later at a more convenient time for JsonRpc. + this.resultDeserializationException = ex; + } + } + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + this.MsgPackResult = default; + this.OriginalMessagePack = default; + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.serializerOptions); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + [DataContract] + private class JsonRpcError : JsonRpcErrorBase, IJsonRpcMessagePackRetention + { + private readonly FormatterContext serializerOptions; + + public JsonRpcError(FormatterContext serializerOptions) + { + this.serializerOptions = serializerOptions; + } + + public ReadOnlySequence OriginalMessagePack { get; internal set; } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.serializerOptions); + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + if (this.Error is ErrorDetail privateDetail) + { + privateDetail.MsgPackData = default; + } + + this.OriginalMessagePack = default; + } + + [DataContract] + internal new class ErrorDetail : Protocol.JsonRpcError.ErrorDetail + { + private readonly FormatterContext serializerOptions; + + internal ErrorDetail(FormatterContext serializerOptions) + { + this.serializerOptions = serializerOptions ?? throw new ArgumentNullException(nameof(serializerOptions)); + } + + internal ReadOnlySequence MsgPackData { get; set; } + + public override object? GetData(Type dataType) + { + Requires.NotNull(dataType, nameof(dataType)); + if (this.MsgPackData.IsEmpty) + { + return this.Data; + } + + var reader = new MessagePackReader(this.MsgPackData); + try + { + return this.serializerOptions.Serializer.DeserializeObject( + ref reader, + this.serializerOptions.ShapeProvider.Resolve(dataType)) + ?? throw new MessagePackSerializationException(Resources.FailureDeserializingJsonRpc); + } + catch (MessagePackSerializationException) + { + // Deserialization failed. Try returning array/dictionary based primitive objects. + try + { + // return MessagePackSerializer.Deserialize(this.MsgPackData, this.serializerOptions.WithResolver(PrimitiveObjectResolver.Instance)); + // TODO: Which Shape Provider to use? + return this.serializerOptions.Serializer.Deserialize(this.MsgPackData, this.serializerOptions.ShapeProvider); + } + catch (MessagePackSerializationException) + { + return null; + } + } + } + + protected internal override void SetExpectedDataType(Type dataType) + { + Verify.Operation(!this.MsgPackData.IsEmpty, "Data is no longer available or has already been deserialized."); + + this.Data = this.GetData(dataType); + + // Clear the source now that we've deserialized to prevent GetData from attempting + // deserialization later when the buffer may be recycled on another thread. + this.MsgPackData = default; + } + } + } +} diff --git a/src/StreamJsonRpc/Protocol/JsonRpcError.cs b/src/StreamJsonRpc/Protocol/JsonRpcError.cs index 7905eb811..143cd82cb 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcError.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcError.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Runtime.Serialization; +using PolyType; using StreamJsonRpc.Reflection; using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; @@ -13,14 +14,16 @@ namespace StreamJsonRpc.Protocol; /// Describes the error resulting from a that failed on the server. /// [DataContract] +[GenerateShape] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + "}")] -public class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId { /// /// Gets or sets the detail about the error. /// [DataMember(Name = "error", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("error"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PropertyShape(Name = "error", Order = 2)] public ErrorDetail? Error { get; set; } /// @@ -30,6 +33,7 @@ public class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -41,6 +45,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -77,6 +82,7 @@ public class ErrorDetail /// [DataMember(Name = "code", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("code"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] + [PropertyShape(Name = "code", Order = 0)] public JsonRpcErrorCode Code { get; set; } /// @@ -87,6 +93,7 @@ public class ErrorDetail /// [DataMember(Name = "message", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("message"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PropertyShape(Name = "message", Order = 1)] public string? Message { get; set; } /// @@ -95,6 +102,7 @@ public class ErrorDetail [DataMember(Name = "data", Order = 2, IsRequired = false)] [Newtonsoft.Json.JsonProperty(DefaultValueHandling = Newtonsoft.Json.DefaultValueHandling.Ignore)] [STJ.JsonPropertyName("data"), STJ.JsonPropertyOrder(2)] + [PropertyShape(Name = "data", Order = 2)] public object? Data { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs index 84acc9373..1a9a6edc9 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs @@ -3,6 +3,8 @@ using System.Diagnostics.CodeAnalysis; using System.Runtime.Serialization; +using Nerdbank.MessagePack; +using PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -14,6 +16,11 @@ namespace StreamJsonRpc.Protocol; [KnownType(typeof(JsonRpcRequest))] [KnownType(typeof(JsonRpcResult))] [KnownType(typeof(JsonRpcError))] +#pragma warning disable CS0618 //'KnownSubTypeAttribute.KnownSubTypeAttribute(Type)' is obsolete: 'Use the generic version of this attribute instead.' +[KnownSubType(typeof(JsonRpcRequest))] +[KnownSubType(typeof(JsonRpcResult))] +[KnownSubType(typeof(JsonRpcError))] +#pragma warning restore CS0618 public abstract class JsonRpcMessage { /// @@ -22,6 +29,7 @@ public abstract class JsonRpcMessage /// Defaults to "2.0". [DataMember(Name = "jsonrpc", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("jsonrpc"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] + [PropertyShape(Name = "jsonrpc", Order = 0)] public string Version { get; set; } = "2.0"; /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs index c41239ac6..33f5a96cd 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Reflection; using System.Runtime.Serialization; +using PolyType; using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; @@ -13,8 +14,9 @@ namespace StreamJsonRpc.Protocol; /// Describes a method to be invoked on the server. /// [DataContract] +[GenerateShape] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] -public class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId { /// /// The result of an attempt to match request arguments with a candidate method's parameters. @@ -47,6 +49,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "method", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("method"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PropertyShape(Name = "method", Order = 2)] public string? Method { get; set; } /// @@ -61,6 +64,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "params", Order = 3, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("params"), STJ.JsonPropertyOrder(3), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "params", Order = 3)] public object? Arguments { get; set; } /// @@ -70,6 +74,7 @@ public enum ArgumentMatchResult [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -81,6 +86,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingDefault)] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -88,6 +94,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public bool IsResponseExpected => !this.RequestId.IsEmpty; /// @@ -95,6 +102,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public bool IsNotification => this.RequestId.IsEmpty; /// @@ -102,6 +110,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public virtual int ArgumentCount => this.NamedArguments?.Count ?? this.ArgumentsList?.Count ?? 0; /// @@ -109,6 +118,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArguments { get => this.Arguments as IReadOnlyDictionary; @@ -127,6 +137,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArgumentDeclaredTypes { get; set; } /// @@ -134,6 +145,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] [Obsolete("Use " + nameof(ArgumentsList) + " instead.")] public object?[]? ArgumentsArray { @@ -146,6 +158,7 @@ public object?[]? ArgumentsArray /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentsList { get => this.Arguments as IReadOnlyList; @@ -166,6 +179,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentListDeclaredTypes { get; set; } /// @@ -173,6 +187,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public virtual IEnumerable? ArgumentNames => this.NamedArguments?.Keys; /// @@ -180,6 +195,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "traceparent", EmitDefaultValue = false)] [STJ.JsonPropertyName("traceparent"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "traceparent")] public string? TraceParent { get; set; } /// @@ -187,6 +203,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "tracestate", EmitDefaultValue = false)] [STJ.JsonPropertyName("tracestate"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "tracestate")] public string? TraceState { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs index 6bd3157e6..0019e660e 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Runtime.Serialization; +using PolyType; using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; @@ -13,13 +14,15 @@ namespace StreamJsonRpc.Protocol; /// [DataContract] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] -public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId +[GenerateShape] +public partial class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId { /// /// Gets or sets the value of the result of an invocation, if any. /// [DataMember(Name = "result", Order = 2, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("result"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PropertyShape(Name = "result", Order = 2)] public object? Result { get; set; } /// @@ -30,6 +33,7 @@ public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public Type? ResultDeclaredType { get; set; } /// @@ -39,6 +43,7 @@ public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -50,6 +55,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// diff --git a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs index 07e5f29d6..e2158249e 100644 --- a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs +++ b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs @@ -7,7 +7,7 @@ namespace StreamJsonRpc; /// Designates an interface that is used in an RPC contract to marshal the object so the receiver can invoke remote methods on it instead of serializing the object to send its data to the remote end. /// /// -/// Learn more about marshable interfaces. +/// Learn more about marshalable interfaces. /// [AttributeUsage(AttributeTargets.Interface, AllowMultiple = false, Inherited = false)] public class RpcMarshalableAttribute : Attribute diff --git a/src/StreamJsonRpc/StreamJsonRpc.csproj b/src/StreamJsonRpc/StreamJsonRpc.csproj index 70fa95f12..b51740ebc 100644 --- a/src/StreamJsonRpc/StreamJsonRpc.csproj +++ b/src/StreamJsonRpc/StreamJsonRpc.csproj @@ -1,6 +1,6 @@  - netstandard2.0;netstandard2.1;net6.0;net8.0 + netstandard2.0;netstandard2.1;net8.0 prompt 4 true @@ -17,6 +17,7 @@ + diff --git a/test/Benchmarks/Benchmarks.csproj b/test/Benchmarks/Benchmarks.csproj index b3a91e41d..bbcc37f3f 100644 --- a/test/Benchmarks/Benchmarks.csproj +++ b/test/Benchmarks/Benchmarks.csproj @@ -2,7 +2,7 @@ Exe - net6.0;net472 + net8.0;net472 diff --git a/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs b/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs index ebf80d0d1..441ab0215 100644 --- a/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs +++ b/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs @@ -63,6 +63,27 @@ public void MessagePackDoesNotLoadNewtonsoftJsonUnnecessarily() } } + [Fact] + public void NerdbankMessagePackDoesNotLoadNewtonsoftJsonUnnecessarily() + { + AppDomain testDomain = CreateTestAppDomain(); + try + { + var driver = (AppDomainTestDriver)testDomain.CreateInstanceAndUnwrap(typeof(AppDomainTestDriver).Assembly.FullName, typeof(AppDomainTestDriver).FullName); + + this.PrintLoadedAssemblies(driver); + + driver.CreateNerdbankMessagePackConnection(); + + this.PrintLoadedAssemblies(driver); + driver.ThrowIfAssembliesLoaded("Newtonsoft.Json"); + } + finally + { + AppDomain.Unload(testDomain); + } + } + [Fact] public void MockFormatterDoesNotLoadJsonOrMessagePackUnnecessarily() { @@ -142,6 +163,11 @@ internal void CreateMessagePackConnection() var jsonRpc = new JsonRpc(new LengthHeaderMessageHandler(FullDuplexStream.CreatePipePair().Item1, new MessagePackFormatter())); } + internal void CreateNerdbankMessagePackConnection() + { + var jsonRpc = new JsonRpc(new LengthHeaderMessageHandler(FullDuplexStream.CreatePipePair().Item1, new NerdbankMessagePackFormatter())); + } + #pragma warning restore CA1822 // Mark members as static private class MockFormatter : IJsonRpcMessageFormatter diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs new file mode 100644 index 000000000..d4c5ee933 --- /dev/null +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +public class AsyncEnumerableNerdbankMessagePackTests : AsyncEnumerableTests +{ + public AsyncEnumerableNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + this.serverMessageFormatter = new NerdbankMessagePackFormatter(); + this.clientMessageFormatter = new NerdbankMessagePackFormatter(); + } +} diff --git a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs new file mode 100644 index 000000000..c4a381e65 --- /dev/null +++ b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; + +public class DisposableProxyNerdbankMessagePackTests : DisposableProxyTests +{ + public DisposableProxyNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter(); +} diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs new file mode 100644 index 000000000..924929ae3 --- /dev/null +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +public class DuplexPipeMarshalingNerdbankMessagePackTests : DuplexPipeMarshalingTests +{ + public DuplexPipeMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + this.serverMessageFormatter = new NerdbankMessagePackFormatter { MultiplexingStream = this.serverMx }; + this.clientMessageFormatter = new NerdbankMessagePackFormatter { MultiplexingStream = this.clientMx }; + } +} diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs new file mode 100644 index 000000000..277ee3287 --- /dev/null +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -0,0 +1,543 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.CompilerServices; +using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; +using PolyType; +using PolyType.SourceGenerator; + +public partial class JsonRpcNerdbankMessagePackLengthTests : JsonRpcTests +{ + public JsonRpcNerdbankMessagePackLengthTests(ITestOutputHelper logger) + : base(logger) + { + } + + internal interface IMessagePackServer + { + Task ReturnUnionTypeAsync(CancellationToken cancellationToken); + + Task AcceptUnionTypeAndReturnStringAsync(UnionBaseClass value, CancellationToken cancellationToken); + + Task AcceptUnionTypeAsync(UnionBaseClass value, CancellationToken cancellationToken); + + Task ProgressUnionType(IProgress progress, CancellationToken cancellationToken); + + IAsyncEnumerable GetAsyncEnumerableOfUnionType(CancellationToken cancellationToken); + + Task IsExtensionArgNonNull(CustomExtensionType extensionValue); + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + [Fact] + public override async Task CanPassAndCallPrivateMethodsObjects() + { + var result = await this.clientRpc.InvokeAsync(nameof(Server.MethodThatAcceptsFoo), new Foo { Bar = "bar", Bazz = 1000 }); + Assert.NotNull(result); + Assert.Equal("bar!", result.Bar); + Assert.Equal(1001, result.Bazz); + } + + [Fact] + public async Task ExceptionControllingErrorData() + { + var exception = await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.ThrowLocalRpcException))).WithCancellation(this.TimeoutToken); + + IDictionary? data = (IDictionary?)exception.ErrorData; + Assert.NotNull(data); + object myCustomData = data["myCustomData"]; + string actual = (string)myCustomData; + Assert.Equal("hi", actual); + } + + [Fact] + public override async Task CanPassExceptionFromServer_ErrorData() + { + RemoteInvocationException exception = await Assert.ThrowsAnyAsync(() => this.clientRpc.InvokeAsync(nameof(Server.MethodThatThrowsUnauthorizedAccessException))); + Assert.Equal((int)JsonRpcErrorCode.InvocationError, exception.ErrorCode); + + var errorData = Assert.IsType(exception.ErrorData); + Assert.NotNull(errorData.StackTrace); + Assert.StrictEqual(COR_E_UNAUTHORIZEDACCESS, errorData.HResult); + } + + /// + /// Verifies that return values can support union types by considering the return type as declared in the server method signature. + /// + [Fact] + public async Task UnionType_ReturnValue() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + UnionBaseClass result = await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.ReturnUnionTypeAsync), null, this.TimeoutToken); + Assert.IsType(result); + } + + /// + /// Verifies that return values can support union types by considering the return type as declared in the server method signature. + /// + [Fact] + public async Task UnionType_ReturnValue_NonAsync() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + UnionBaseClass result = await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.ReturnUnionType), null, this.TimeoutToken); + Assert.IsType(result); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Theory] + [CombinatorialData] + public async Task UnionType_PositionalParameter_NoReturnValue(bool notify) + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + UnionBaseClass? receivedValue; + if (notify) + { + await this.clientRpc.NotifyAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), new object?[] { new UnionDerivedClass() }, new[] { typeof(UnionBaseClass) }).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), new object?[] { new UnionDerivedClass() }, new[] { typeof(UnionBaseClass) }, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Fact] + public async Task UnionType_PositionalParameter_AndReturnValue() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + string? result = await this.clientRpc.InvokeWithCancellationAsync(nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), new object?[] { new UnionDerivedClass() }, new[] { typeof(UnionBaseClass) }, this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Theory] + [CombinatorialData] + public async Task UnionType_NamedParameter_NoReturnValue_UntypedDictionary(bool notify) + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var argument = new Dictionary { { "value", new UnionDerivedClass() } }; + var argumentDeclaredTypes = new Dictionary { { "value", typeof(UnionBaseClass) } }; + + UnionBaseClass? receivedValue; + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + + // Exercise the non-init path by repeating + server.ReceivedValueSource = new TaskCompletionSource(); + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), argument, argumentDeclaredTypes, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_AndReturnValue_UntypedDictionary() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + string? result = await this.clientRpc.InvokeWithParameterObjectAsync( + nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), + new Dictionary { { "value", new UnionDerivedClass() } }, + new Dictionary { { "value", typeof(UnionBaseClass) } }, + this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + Assert.IsType(server.ReceivedValue); + + // Exercise the non-init path by repeating + result = await this.clientRpc.InvokeWithParameterObjectAsync( + nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), + new Dictionary { { "value", new UnionDerivedClass() } }, + new Dictionary { { "value", typeof(UnionBaseClass) } }, + this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + Assert.IsType(server.ReceivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Theory] + [CombinatorialData] + public async Task UnionType_NamedParameter_NoReturnValue(bool notify) + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var namedArgs = new { value = (UnionBaseClass)new UnionDerivedClass() }; + + UnionBaseClass? receivedValue; + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + + // Exercise the non-init path by repeating + server.ReceivedValueSource = new TaskCompletionSource(); + if (notify) + { + await this.clientRpc.NotifyWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs).WithCancellation(this.TimeoutToken); + receivedValue = await server.ReceivedValueSource.Task.WithCancellation(this.TimeoutToken); + } + else + { + await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAsync), namedArgs, this.TimeoutToken); + receivedValue = server.ReceivedValue; + } + + Assert.IsType(receivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_AndReturnValue() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + string? result = await this.clientRpc.InvokeWithParameterObjectAsync(nameof(MessagePackServer.AcceptUnionTypeAndReturnStringAsync), new { value = (UnionBaseClass)new UnionDerivedClass() }, this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that return values can support union types by considering the return type as declared in the server method signature. + /// + [Fact] + public async Task UnionType_ReturnValue_Proxy() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + var clientProxy = this.clientRpc.Attach(); + UnionBaseClass result = await clientProxy.ReturnUnionTypeAsync(this.TimeoutToken); + Assert.IsType(result); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Fact] + public async Task UnionType_PositionalParameter_AndReturnValue_Proxy() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + var clientProxy = this.clientRpc.Attach(); + string? result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_AndReturnValue_Proxy() + { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(new MessagePackServer()); + var clientProxy = this.clientRpc.Attach(new JsonRpcProxyOptions { ServerRequiresNamedArguments = true }); + string? result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + result = await clientProxy.AcceptUnionTypeAndReturnStringAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.Equal(typeof(UnionDerivedClass).Name, result); + } + + /// + /// Verifies that positional parameters can support union types by providing extra type information for each argument. + /// + [Fact] + public async Task UnionType_PositionalParameter_NoReturnValue_Proxy() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + server.ReceivedValueSource = new TaskCompletionSource(); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + } + + /// + /// Verifies that the type information associated with named parameters is used for proper serialization of union types. + /// + [Fact] + public async Task UnionType_NamedParameter_NoReturnValue_Proxy() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(new JsonRpcProxyOptions { ServerRequiresNamedArguments = true }); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + + // Repeat the proxy call to exercise the non-init path of the dynamically generated proxy. + server.ReceivedValueSource = new TaskCompletionSource(); + await clientProxy.AcceptUnionTypeAsync(new UnionDerivedClass(), this.TimeoutToken); + Assert.IsType(server.ReceivedValue); + } + + [Fact] + public async Task UnionType_AsIProgressTypeArgument() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(); + + var reportSource = new TaskCompletionSource(); + var progress = new Progress(v => reportSource.SetResult(v)); + await clientProxy.ProgressUnionType(progress, this.TimeoutToken); + Assert.IsType(await reportSource.Task.WithCancellation(this.TimeoutToken)); + } + + [Fact] + public async Task UnionType_AsAsyncEnumerableTypeArgument() + { + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(); + + UnionBaseClass? actualItem = null; + await foreach (UnionBaseClass item in clientProxy.GetAsyncEnumerableOfUnionType(this.TimeoutToken)) + { + actualItem = item; + } + + Assert.IsType(actualItem); + } + + /// + /// Verifies that an argument that cannot be deserialized by the msgpack primitive formatter will not cause a failure. + /// + /// + /// This is a regression test for a bug where + /// verbose ETW tracing would fail to deserialize arguments with the primitive formatter that deserialize just fine for the actual method dispatch. + /// + [SkippableTheory, PairwiseData] + public async Task VerboseLoggingDoesNotFailWhenArgsDoNotDeserializePrimitively(bool namedArguments) + { + Skip.IfNot(SharedUtilities.GetEventSourceTestMode() == SharedUtilities.EventSourceTestMode.EmulateProduction, $"This test specifically verifies behavior when the EventSource should swallow exceptions. Current mode: {SharedUtilities.GetEventSourceTestMode()}."); + var server = new MessagePackServer(); + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLocalRpcTarget(server); + var clientProxy = this.clientRpc.Attach(new JsonRpcProxyOptions { ServerRequiresNamedArguments = namedArguments }); + + Assert.True(await clientProxy.IsExtensionArgNonNull(new CustomExtensionType())); + } + + protected override void InitializeFormattersAndHandlers( + Stream serverStream, + Stream clientStream, + out IJsonRpcMessageFormatter serverMessageFormatter, + out IJsonRpcMessageFormatter clientMessageFormatter, + out IJsonRpcMessageHandler serverMessageHandler, + out IJsonRpcMessageHandler clientMessageHandler, + bool controlledFlushingClient) + { + serverMessageFormatter = new NerdbankMessagePackFormatter(); + clientMessageFormatter = new NerdbankMessagePackFormatter(); + + ((NerdbankMessagePackFormatter)serverMessageFormatter).SetFormatterContext(Configure); + ((NerdbankMessagePackFormatter)clientMessageFormatter).SetFormatterContext(Configure); + + serverMessageHandler = new LengthHeaderMessageHandler(serverStream, serverStream, serverMessageFormatter); + clientMessageHandler = controlledFlushingClient + ? new DelayedFlushingHandler(clientStream, clientMessageFormatter) + : new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter); + + static void Configure(NerdbankMessagePackFormatter.FormatterContextBuilder b) + { + b.RegisterConverter(new UnserializableTypeConverter()); + b.RegisterConverter(new TypeThrowsWhenDeserializedConverter()); + b.RegisterConverter(new CustomExtensionConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + } + } + + protected override object[] CreateFormatterIntrinsicParamsObject(string arg) => []; + + [GenerateShape] +#pragma warning disable CS0618 + [KnownSubType(typeof(UnionDerivedClass))] +#pragma warning restore CS0618 + public abstract partial class UnionBaseClass + { + } + + [GenerateShape] + public partial class UnionDerivedClass : UnionBaseClass + { + } + + [GenerateShape] + internal partial class CustomExtensionType + { + } + + private class CustomExtensionConverter : MessagePackConverter + { + public override CustomExtensionType? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + if (reader.ReadExtensionHeader() is { TypeCode: 1, Length: 0 }) + { + return new(); + } + + throw new Exception("Unexpected extension header."); + } + + public override void Write(ref MessagePackWriter writer, in CustomExtensionType? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + writer.Write(new Extension(1, default(Memory))); + } + } + } + + private class UnserializableTypeConverter : MessagePackConverter + { + public override CustomSerializedType Read(ref MessagePackReader reader, SerializationContext context) + { + return new CustomSerializedType { Value = reader.ReadString() }; + } + + public override void Write(ref MessagePackWriter writer, in CustomSerializedType? value, SerializationContext context) + { + writer.Write(value?.Value); + } + } + + private class TypeThrowsWhenDeserializedConverter : MessagePackConverter + { + public override TypeThrowsWhenDeserialized Read(ref MessagePackReader reader, SerializationContext context) + { + throw CreateExceptionToBeThrownByDeserializer(); + } + + public override void Write(ref MessagePackWriter writer, in TypeThrowsWhenDeserialized? value, SerializationContext context) + { + writer.WriteArrayHeader(0); + } + } + + private class MessagePackServer : IMessagePackServer + { + internal UnionBaseClass? ReceivedValue { get; private set; } + + internal TaskCompletionSource ReceivedValueSource { get; set; } = new TaskCompletionSource(); + + public Task ReturnUnionTypeAsync(CancellationToken cancellationToken) => Task.FromResult(new UnionDerivedClass()); + + public Task AcceptUnionTypeAndReturnStringAsync(UnionBaseClass value, CancellationToken cancellationToken) => Task.FromResult((this.ReceivedValue = value)?.GetType().Name); + + public Task AcceptUnionTypeAsync(UnionBaseClass value, CancellationToken cancellationToken) + { + this.ReceivedValue = value; + this.ReceivedValueSource.SetResult(value); + return Task.CompletedTask; + } + + public UnionBaseClass ReturnUnionType() => new UnionDerivedClass(); + + public Task ProgressUnionType(IProgress progress, CancellationToken cancellationToken) + { + progress.Report(new UnionDerivedClass()); + return Task.CompletedTask; + } + + public async IAsyncEnumerable GetAsyncEnumerableOfUnionType([EnumeratorCancellation] CancellationToken cancellationToken) + { + await Task.Yield(); + yield return new UnionDerivedClass(); + } + + public Task IsExtensionArgNonNull(CustomExtensionType extensionValue) => Task.FromResult(extensionValue is not null); + } + + private class DelayedFlushingHandler : LengthHeaderMessageHandler, IControlledFlushHandler + { + public DelayedFlushingHandler(Stream stream, IJsonRpcMessageFormatter formatter) + : base(stream, stream, formatter) + { + } + + public AsyncAutoResetEvent FlushEntered { get; } = new AsyncAutoResetEvent(); + + public AsyncManualResetEvent AllowFlushAsyncExit { get; } = new AsyncManualResetEvent(); + + protected override async ValueTask FlushAsync(CancellationToken cancellationToken) + { + this.FlushEntered.Set(); + await this.AllowFlushAsyncExit.WaitAsync(CancellationToken.None); + await base.FlushAsync(cancellationToken); + } + } +} diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs new file mode 100644 index 000000000..13d44420a --- /dev/null +++ b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; + +public class MarshalableProxyNerdbankMessagePackTests : MarshalableProxyTests +{ + public MarshalableProxyNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter(); +} diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs new file mode 100644 index 000000000..9a62e6a5b --- /dev/null +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -0,0 +1,453 @@ +using System.Diagnostics; +using System.Runtime.Serialization; +using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType; +using PolyType.ReflectionProvider; +using PolyType.SourceGenerator; + +public partial class NerdbankMessagePackFormatterTests : FormatterTestBase +{ + public NerdbankMessagePackFormatterTests(ITestOutputHelper logger) + : base(logger) + { + } + + [Fact] + public void JsonRpcRequest_PositionalArgs() + { + var original = new JsonRpcRequest + { + RequestId = new RequestId(5), + Method = "test", + ArgumentsList = new object[] { 5, "hi", new CustomType { Age = 8 } }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Method, actual.Method); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 0, typeof(int), out object? actualArg0)); + Assert.Equal(original.ArgumentsList[0], actualArg0); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 1, typeof(string), out object? actualArg1)); + Assert.Equal(original.ArgumentsList[1], actualArg1); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 2, typeof(CustomType), out object? actualArg2)); + Assert.Equal(((CustomType?)original.ArgumentsList[2])!.Age, ((CustomType)actualArg2!).Age); + } + + [Fact] + public void JsonRpcRequest_NamedArgs() + { + var original = new JsonRpcRequest + { + RequestId = new RequestId(5), + Method = "test", + NamedArguments = new Dictionary + { + { "Number", 5 }, + { "Message", "hi" }, + { "Custom", new CustomType { Age = 8 } }, + }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Method, actual.Method); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Number", -1, typeof(int), out object? actualArg0)); + Assert.Equal(original.NamedArguments["Number"], actualArg0); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Message", -1, typeof(string), out object? actualArg1)); + Assert.Equal(original.NamedArguments["Message"], actualArg1); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Custom", -1, typeof(CustomType), out object? actualArg2)); + Assert.Equal(((CustomType?)original.NamedArguments["Custom"])!.Age, ((CustomType)actualArg2!).Age); + } + + [Fact] + public void JsonRpcResult() + { + var original = new JsonRpcResult + { + RequestId = new RequestId(5), + Result = new CustomType { Age = 7 }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(((CustomType?)original.Result)!.Age, actual.GetResult().Age); + } + + [Fact] + public void JsonRpcError() + { + var original = new JsonRpcError + { + RequestId = new RequestId(5), + Error = new JsonRpcError.ErrorDetail + { + Code = JsonRpcErrorCode.InvocationError, + Message = "Oops", + Data = new CustomType { Age = 15 }, + }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Error.Code, actual.Error!.Code); + Assert.Equal(original.Error.Message, actual.Error.Message); + Assert.Equal(((CustomType)original.Error.Data).Age, actual.Error.GetData().Age); + } + + [Fact] + public async Task BasicJsonRpc() + { + var (clientStream, serverStream) = FullDuplexStream.CreatePair(); + var clientFormatter = new NerdbankMessagePackFormatter(); + var serverFormatter = new NerdbankMessagePackFormatter(); + + var clientHandler = new LengthHeaderMessageHandler(clientStream.UsePipe(), clientFormatter); + var serverHandler = new LengthHeaderMessageHandler(serverStream.UsePipe(), serverFormatter); + + var clientRpc = new JsonRpc(clientHandler); + var serverRpc = new JsonRpc(serverHandler, new Server()); + + serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose); + clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose); + + serverRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); + clientRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); + + clientRpc.StartListening(); + serverRpc.StartListening(); + + int result = await clientRpc.InvokeAsync(nameof(Server.Add), 3, 5).WithCancellation(this.TimeoutToken); + Assert.Equal(8, result); + } + + [Fact] + public void Resolver_RequestArgInArray() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + ArgumentsList = new object[] { originalArg }, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(null, 0, typeof(TypeRequiringCustomFormatter), out object? roundtripArgObj)); + var roundtripArg = (TypeRequiringCustomFormatter)roundtripArgObj!; + Assert.Equal(originalArg.Prop1, roundtripArg.Prop1); + Assert.Equal(originalArg.Prop2, roundtripArg.Prop2); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_AnonymousType() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new { Prop1 = 3, Prop2 = 5 }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.Prop1), -1, typeof(int), out object? prop1)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.Prop2), -1, typeof(int), out object? prop2)); + Assert.Equal(originalArg.Prop1, prop1); + Assert.Equal(originalArg.Prop2, prop2); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_DataContractObject() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new DataContractWithSubsetOfMembersIncluded { ExcludedField = "A", ExcludedProperty = "B", IncludedField = "C", IncludedProperty = "D" }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedProperty), -1, typeof(string), out object? _)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.IncludedField), -1, typeof(string), out object? includedField)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.IncludedProperty), -1, typeof(string), out object? includedProperty)); + Assert.Equal(originalArg.IncludedProperty, includedProperty); + Assert.Equal(originalArg.IncludedField, includedField); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_NonDataContractObject() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalArg = new NonDataContractWithExcludedMembers { ExcludedField = "A", ExcludedProperty = "B", InternalField = "C", InternalProperty = "D", PublicField = "E", PublicProperty = "F" }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedProperty), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.InternalField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.InternalProperty), -1, typeof(string), out object? _)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.PublicField), -1, typeof(string), out object? publicField)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.PublicProperty), -1, typeof(string), out object? publicProperty)); + Assert.Equal(originalArg.PublicProperty, publicProperty); + Assert.Equal(originalArg.PublicField, publicField); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_NullObject() + { + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = null, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.Null(roundtripRequest.Arguments); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex("AnythingReally", -1, typeof(string), out object? _)); + } + + [Fact] + public void Resolver_Result() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalResultValue = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalResult = new JsonRpcResult + { + RequestId = new RequestId(1), + Result = originalResultValue, + }; + var roundtripResult = this.Roundtrip(originalResult); + var roundtripResultValue = roundtripResult.GetResult(); + Assert.Equal(originalResultValue.Prop1, roundtripResultValue.Prop1); + Assert.Equal(originalResultValue.Prop2, roundtripResultValue.Prop2); + } + + [Fact] + public void Resolver_ErrorData() + { + this.Formatter.SetFormatterContext(b => + { + b.RegisterConverter(new CustomConverter()); + b.AddTypeShapeProvider(ShapeProvider_StreamJsonRpc_Tests.Default); + }); + + var originalErrorData = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalError = new JsonRpcError + { + RequestId = new RequestId(1), + Error = new JsonRpcError.ErrorDetail + { + Data = originalErrorData, + }, + }; + var roundtripError = this.Roundtrip(originalError); + var roundtripErrorData = roundtripError.Error!.GetData(); + Assert.Equal(originalErrorData.Prop1, roundtripErrorData.Prop1); + Assert.Equal(originalErrorData.Prop2, roundtripErrorData.Prop2); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcRequest() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = new object[] { "hi" }, + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.method, request.Method); + Assert.Equal(dynamic.@params.Length, request.ArgumentCount); + Assert.True(request.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg)); + Assert.Equal(dynamic.@params[0], arg); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcResult() + { + var dynamic = new + { + jsonrpc = "2.0", + id = 2, + extra = (object?)null, + result = "hi", + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.id, request.RequestId.Number); + Assert.Equal(dynamic.result, request.GetResult()); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcError() + { + var dynamic = new + { + jsonrpc = "2.0", + id = 2, + extra = (object?)null, + error = new { extra = 2, code = 5 }, + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.id, request.RequestId.Number); + Assert.Equal(dynamic.error.code, (int?)request.Error?.Code); + } + + [Fact] + public void StringsInUserDataAreInterned() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = new object[] { "hi" }, + }; + var request1 = this.Read(dynamic); + var request2 = this.Read(dynamic); + Assert.True(request1.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg1)); + Assert.True(request2.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg2)); + Assert.Same(arg2, arg1); // reference equality to ensure it was interned. + } + + [Fact] + public void StringValuesOfStandardPropertiesAreInterned() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = Array.Empty(), + }; + var request1 = this.Read(dynamic); + var request2 = this.Read(dynamic); + Assert.Same(request1.Method, request2.Method); // reference equality to ensure it was interned. + } + + protected override NerdbankMessagePackFormatter CreateFormatter() => new(); + + private T Read(object anonymousObject) + where T : JsonRpcMessage + { + var sequence = new Sequence(); + var writer = new MessagePackWriter(sequence); + new MessagePackSerializer().Serialize(ref writer, anonymousObject, ReflectionTypeShapeProvider.Default); + writer.Flush(); + return (T)this.Formatter.Deserialize(sequence); + } + + [DataContract] + [GenerateShape] + public partial class DataContractWithSubsetOfMembersIncluded + { + [PropertyShape(Ignore = true)] + public string? ExcludedField; + + [DataMember] + internal string? IncludedField; + + [PropertyShape(Ignore = true)] + public string? ExcludedProperty { get; set; } + + [DataMember] + internal string? IncludedProperty { get; set; } + } + + [GenerateShape] + public partial class NonDataContractWithExcludedMembers + { + [IgnoreDataMember] + [PropertyShape(Ignore = true)] + public string? ExcludedField; + + public string? PublicField; + + internal string? InternalField; + + [IgnoreDataMember] + [PropertyShape(Ignore = true)] + public string? ExcludedProperty { get; set; } + + public string? PublicProperty { get; set; } + + internal string? InternalProperty { get; set; } + } + + [GenerateShape] + public partial class TypeRequiringCustomFormatter + { + internal int Prop1 { get; set; } + + internal int Prop2 { get; set; } + } + + private class CustomConverter : MessagePackConverter + { + public override TypeRequiringCustomFormatter Read(ref MessagePackReader reader, SerializationContext context) + { + Assert.Equal(2, reader.ReadArrayHeader()); + return new TypeRequiringCustomFormatter + { + Prop1 = reader.ReadInt32(), + Prop2 = reader.ReadInt32(), + }; + } + + public override void Write(ref MessagePackWriter writer, in TypeRequiringCustomFormatter? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + writer.WriteArrayHeader(2); + writer.Write(value.Prop1); + writer.Write(value.Prop2); + } + } + + private class Server + { + public int Add(int a, int b) => a + b; + } +} diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs new file mode 100644 index 000000000..7f46b1da4 --- /dev/null +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +public class ObserverMarshalingNerdbankMessagePackTests : ObserverMarshalingTests +{ + public ObserverMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter(); +} diff --git a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj index 82fb00a07..f7bf2232d 100644 --- a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj +++ b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj @@ -1,7 +1,7 @@  - net6.0;net8.0 + net8.0 $(TargetFrameworks);net472 @@ -11,30 +11,39 @@ + + + + + + + + + diff --git a/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs new file mode 100644 index 000000000..12152f55d --- /dev/null +++ b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs @@ -0,0 +1,16 @@ +public class TargetObjectEventsNerdbankMessagePackTests : TargetObjectEventsTests +{ + public TargetObjectEventsNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + var serverMessageFormatter = new NerdbankMessagePackFormatter(); + var clientMessageFormatter = new NerdbankMessagePackFormatter(); + + this.serverMessageHandler = new LengthHeaderMessageHandler(this.serverStream, this.serverStream, serverMessageFormatter); + this.clientMessageHandler = new LengthHeaderMessageHandler(this.clientStream, this.clientStream, clientMessageFormatter); + } +} diff --git a/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs new file mode 100644 index 000000000..b14c7b25b --- /dev/null +++ b/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs @@ -0,0 +1,7 @@ +public class WebSocketMessageHandlerNerdbankMessagePackTests : WebSocketMessageHandlerTests +{ + public WebSocketMessageHandlerNerdbankMessagePackTests(ITestOutputHelper logger) + : base(new NerdbankMessagePackFormatter(), logger) + { + } +}