diff --git a/src/Proto.Remote/Serialization/ForcedSerializationSenderMiddleware.cs b/src/Proto.Remote/Serialization/ForcedSerializationSenderMiddleware.cs new file mode 100644 index 0000000000..db7e1dbbfa --- /dev/null +++ b/src/Proto.Remote/Serialization/ForcedSerializationSenderMiddleware.cs @@ -0,0 +1,92 @@ +// ----------------------------------------------------------------------- +// +// Copyright (C) 2015-2022 Asynkron AB All rights reserved +// +// ----------------------------------------------------------------------- +using System; +using Google.Protobuf; +using Microsoft.Extensions.Logging; + +namespace Proto.Remote; + +public static class ForcedSerializationSenderMiddleware +{ + private static readonly ILogger Logger = Log.CreateLogger(nameof(ForcedSerializationSenderMiddleware)); + + /// + /// Returns sender middleware that forces serialization of the message. This middleware serializes and then deserializes the message before + /// sending it further down the pipeline. It simulates the serialization process in . + /// Useful for testing if serialization is working correctly and the messages are immutable. + /// + /// + /// A predicate that can prevent serialization by returning false. + /// If null, it defaults to + /// + /// + /// Middleware configuration function, to be used with WithSenderMiddleware on + /// or on configuration + /// + public static Func Create(Func? shouldSerialize = null) + { + shouldSerialize ??= SkipInternalProtoMessages; + + return next => + (context, target, envelope) => { + object? message = null; + PID? sender; + Proto.MessageHeader headers; + + try + { + if (shouldSerialize?.Invoke(envelope) == false) + return next(context, target, envelope); + + var serialization = context.System.Serialization(); + + // serialize + (message, sender, headers) = Proto.MessageEnvelope.Unwrap(envelope); + + if (message is IRootSerializable rootSerializable) + message = rootSerializable.Serialize(context.System); + + if (message is null) + throw new Exception("Null message passed to the forced serialization middleware"); + + var (bytes, typeName, serializerId) = serialization.Serialize(message); + + // deserialize + var deserializedMessage = serialization.Deserialize(typeName, bytes, serializerId); + + if (message is IRootSerialized rootDeserialized) + deserializedMessage = rootDeserialized.Deserialize(context.System); + + // forward + var newEnvelope = new Proto.MessageEnvelope(deserializedMessage, sender, headers); + + return next(context, target, newEnvelope); + } + catch (CodedOutputStream.OutOfSpaceException oom) + { + Logger.LogError(oom, "Message is too large for serialization {Message}", message?.GetType().Name); + throw; + } + catch (Exception ex) + { + ex.CheckFailFast(); + Logger.LogError(ex, "Forced serialization -> deserialization failed for message {Message}", message?.GetType().Name); + throw; + } + }; + } + + /// + /// Predicate to skip serialization of internal Proto messages + /// + /// + /// + public static bool SkipInternalProtoMessages(Proto.MessageEnvelope envelope) + { + var (message, _, _) = Proto.MessageEnvelope.Unwrap(envelope); + return message.GetType().FullName?.StartsWith("Proto.") == false; + } +} \ No newline at end of file diff --git a/tests/Proto.Cluster.Tests/ForcedSerializationTests.cs b/tests/Proto.Cluster.Tests/ForcedSerializationTests.cs new file mode 100644 index 0000000000..046e5e049f --- /dev/null +++ b/tests/Proto.Cluster.Tests/ForcedSerializationTests.cs @@ -0,0 +1,59 @@ +// ----------------------------------------------------------------------- +// +// Copyright (C) 2015-2022 Asynkron AB All rights reserved +// +// ----------------------------------------------------------------------- +using System.Linq; +using System.Threading.Tasks; +using ClusterTest.Messages; +using FluentAssertions; +using Proto.Cluster.Gossip; +using Proto.Remote; +using Xunit; + +namespace Proto.Cluster.Tests; + +public class ForcedSerializationTests +{ + [Fact] + public async Task Forced_serialization_works_correctly_in_a_cluster() + { + await using var fixture = new ForcedSerializationClusterFixture(); + await fixture.InitializeAsync(); + var entryMember = fixture.Members.First(); + + var testData = Enumerable.Range(1, 100).Select(i => i.ToString()).ToList(); + + var tasks = testData.Select(id => entryMember.Ping(id, id, CancellationTokens.FromSeconds(10))).ToList(); + await Task.WhenAll(tasks); + + var results = tasks.Select(t => t.Result.Message).ToList(); + + results.Should().BeEquivalentTo(testData); + } + + [Fact] + public void The_test_messages_are_allowed_by_the_default_predicate() + { + var predicate = ForcedSerializationSenderMiddleware.SkipInternalProtoMessages; + + predicate(MessageEnvelope.Wrap(new Ping())).Should().BeTrue(); + } + + [Fact] + public void Sample_internal_proto_messages_are_not_allowed_by_the_default_predicate() + { + var predicate = ForcedSerializationSenderMiddleware.SkipInternalProtoMessages; + + predicate(MessageEnvelope.Wrap(new GetGossipStateRequest("test"))).Should().BeFalse(); + predicate(MessageEnvelope.Wrap(new GossipState())).Should().BeFalse(); + } + + private class ForcedSerializationClusterFixture : InMemoryClusterFixture + { + protected override ActorSystemConfig GetActorSystemConfig() => + base.GetActorSystemConfig().WithConfigureRootContext( + conf => conf.WithSenderMiddleware(ForcedSerializationSenderMiddleware.Create()) + ); + } +} \ No newline at end of file diff --git a/tests/Proto.Remote.Tests/ForcedSerializationTests.cs b/tests/Proto.Remote.Tests/ForcedSerializationTests.cs new file mode 100644 index 0000000000..384ed9fe68 --- /dev/null +++ b/tests/Proto.Remote.Tests/ForcedSerializationTests.cs @@ -0,0 +1,194 @@ +// ----------------------------------------------------------------------- +// +// Copyright (C) 2015-2022 Asynkron AB All rights reserved +// +// ----------------------------------------------------------------------- +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using FluentAssertions; +using ForcedSerialization.TestMessages; +using Proto; +using Proto.Remote; +using Xunit; + +namespace Proto.Remote.Tests +{ + public class ForcedSerializationTests + { + private object _receivedMessage; + private PID _sender; + private Proto.MessageHeader _header; + private readonly ManualResetEvent _wait = new(false); + private readonly Props _receivingActorProps; + private readonly Props _sendingActorProps; + + public ForcedSerializationTests() + { + _receivingActorProps = Props.FromFunc(ctx => { + if (ctx.Message is TestMessage or TestRootSerializableMessage) + { + _receivedMessage = ctx.Message; + _sender = ctx.Sender; + _header = ctx.Headers; + ctx.Respond(new TestResponse()); + _wait.Set(); + } + + return Task.CompletedTask; + } + ); + + _sendingActorProps = Props.FromFunc(ctx => { + switch (ctx.Message) + { + case RunRequestAsync msg: + _ = ctx.RequestWithHeadersAsync(msg.Target, new TestMessage("From another actor"), msg.Headers); + break; + case RunRequest msg: + ctx.Request(msg.Target, new TestMessage("From another actor")); + break; + } + + return Task.CompletedTask; + } + ).WithSenderMiddleware(ForcedSerializationSenderMiddleware.Create()); + } + + [Fact] + public void The_test_messages_are_allowed_by_the_default_predicate() + { + var predicate = ForcedSerializationSenderMiddleware.SkipInternalProtoMessages; + + predicate(Proto.MessageEnvelope.Wrap(new TestMessage("test"))).Should().BeTrue(); + predicate(Proto.MessageEnvelope.Wrap(new TestRootSerializableMessage("test"))).Should().BeTrue(); + } + + [Fact] + public void Sample_internal_proto_messages_are_not_allowed_by_the_default_predicate() + { + var predicate = ForcedSerializationSenderMiddleware.SkipInternalProtoMessages; + + predicate(Proto.MessageEnvelope.Wrap(Started.Instance)).Should().BeFalse(); + predicate(Proto.MessageEnvelope.Wrap(new RemoteDeliver(null!, null!, null!, null))).Should().BeFalse(); + } + + [Fact] + public void It_serializes_and_deserializes() + { + var system = new ActorSystem(ActorSystemConfig.Setup() + .WithConfigureRootContext(ctx => ctx.WithSenderMiddleware( + ForcedSerializationSenderMiddleware.Create() + ) + ) + ); + system.Extensions.Register(new Serialization()); + + var pid = system.Root.Spawn(_receivingActorProps); + var sentMessage = new TestMessage("Serialized"); + system.Root.Send(pid, sentMessage); + + _wait.WaitOne(TimeSpan.FromSeconds(2)); + + _receivedMessage.Should().BeEquivalentTo(sentMessage, "the received message should be the same as the sent message"); + _receivedMessage.Should().NotBeSameAs(sentMessage, "the message should have been serialized"); + } + + [Fact] + public void It_should_not_serialize_if_predicate_prevents_it() + { + var system = new ActorSystem(ActorSystemConfig.Setup() + .WithConfigureRootContext(ctx => ctx.WithSenderMiddleware( + ForcedSerializationSenderMiddleware.Create(_ => false) + ) + ) + ); + system.Extensions.Register(new Serialization()); + + var pid = system.Root.Spawn(_receivingActorProps); + var sentMessage = new TestMessage("Not serialized"); + system.Root.Send(pid, sentMessage); + + _wait.WaitOne(TimeSpan.FromSeconds(2)); + + _receivedMessage.Should().BeEquivalentTo(sentMessage, "the received message should be the same as the sent message"); + _receivedMessage.Should().BeSameAs(sentMessage, "the message should not have been serialized"); + } + + [Fact] + public async Task It_preserves_headers() + { + await using var system = new ActorSystem(ActorSystemConfig.Setup()); + system.Extensions.Register(new Serialization()); + + var pid = system.Root.Spawn(_receivingActorProps); + var sender = system.Root.Spawn(_sendingActorProps); + + var headers = new Proto.MessageHeader(new Dictionary {{"key", "value"}}); + system.Root.Send(sender, new RunRequestAsync(pid, headers)); + + _wait.WaitOne(TimeSpan.FromSeconds(2)); + + _header.Should().BeEquivalentTo(headers); + } + + [Fact] + public async Task It_preserves_sender() + { + await using var system = new ActorSystem(ActorSystemConfig.Setup()); + system.Extensions.Register(new Serialization()); + + var pid = system.Root.Spawn(_receivingActorProps); + var sender = system.Root.Spawn(_sendingActorProps); + + system.Root.Send(sender, new RunRequest(pid, null)); + + _wait.WaitOne(TimeSpan.FromSeconds(2)); + + _sender.Should().BeEquivalentTo(sender); + } + + [Fact] + public async Task It_can_handle_root_serializable() + { + await using var system = new ActorSystem(ActorSystemConfig.Setup() + .WithConfigureRootContext(ctx => ctx.WithSenderMiddleware( + ForcedSerializationSenderMiddleware.Create() + ) + ) + ); + system.Extensions.Register(new Serialization()); + + var pid = system.Root.Spawn(_receivingActorProps); + var sentMessage = new TestRootSerializableMessage("Serialized"); + system.Root.Send(pid, sentMessage); + + _wait.WaitOne(TimeSpan.FromSeconds(2)); + + _receivedMessage.Should().BeEquivalentTo(sentMessage, "the received message should be the same as the sent message"); + _receivedMessage.Should().NotBeSameAs(sentMessage, "the message should have been serialized"); + } + } +} + +namespace ForcedSerialization.TestMessages +{ + record TestMessage(string Value); + + record TestRootSerializableMessage(string Value) : IRootSerializable + { + public IRootSerialized Serialize(ActorSystem system) => new TestRootSerializedMessage(Value); + } + + record TestRootSerializedMessage(string Value) : IRootSerialized + { + public IRootSerializable Deserialize(ActorSystem system) => new TestRootSerializableMessage(Value); + } + + record TestResponse(); + + record RunRequest(PID Target, Proto.MessageHeader Headers); + + record RunRequestAsync(PID Target, Proto.MessageHeader Headers); +} \ No newline at end of file