diff --git a/src/contrib/cluster/Akka.Cluster.Sharding.Tests/ShardEntityFailureSpec.cs b/src/contrib/cluster/Akka.Cluster.Sharding.Tests/ShardEntityFailureSpec.cs index 6d64b399922..972a7ab92b4 100644 --- a/src/contrib/cluster/Akka.Cluster.Sharding.Tests/ShardEntityFailureSpec.cs +++ b/src/contrib/cluster/Akka.Cluster.Sharding.Tests/ShardEntityFailureSpec.cs @@ -135,7 +135,8 @@ public async Task Persistent_Shard_must_recover_from_failing_entity(Props entity settings, new TestMessageExtractor(), PoisonPill.Instance, - provider + provider, + null )); Sys.EventStream.Subscribe(TestActor); diff --git a/src/contrib/cluster/Akka.Cluster.Sharding.Tests/ShardingBufferAdapterSpec.cs b/src/contrib/cluster/Akka.Cluster.Sharding.Tests/ShardingBufferAdapterSpec.cs index c4cff5cefc3..dc708b039b2 100644 --- a/src/contrib/cluster/Akka.Cluster.Sharding.Tests/ShardingBufferAdapterSpec.cs +++ b/src/contrib/cluster/Akka.Cluster.Sharding.Tests/ShardingBufferAdapterSpec.cs @@ -68,6 +68,11 @@ public object Apply(object message, IActorContext context) _counter.IncrementAndGet(); return message; } + + public object UnApply(object message, IActorContext context) + { + return message; + } } private const string ShardTypeName = "Caat"; diff --git a/src/contrib/cluster/Akka.Cluster.Sharding.Tests/WrappedShardBufferedMessageSpec.cs b/src/contrib/cluster/Akka.Cluster.Sharding.Tests/WrappedShardBufferedMessageSpec.cs new file mode 100644 index 00000000000..c0c2434b10f --- /dev/null +++ b/src/contrib/cluster/Akka.Cluster.Sharding.Tests/WrappedShardBufferedMessageSpec.cs @@ -0,0 +1,220 @@ +// ----------------------------------------------------------------------- +// +// Copyright (C) 2009-2025 Lightbend Inc. +// Copyright (C) 2013-2025 .NET Foundation +// +// ----------------------------------------------------------------------- + +using System.Collections.Immutable; +using System.Threading.Tasks; +using Akka.Actor; +using Akka.Cluster.Sharding.Internal; +using Akka.Cluster.Tools.Singleton; +using Akka.Configuration; +using Akka.Event; +using Akka.TestKit; +using FluentAssertions; +using Xunit; +using Xunit.Abstractions; + +namespace Akka.Cluster.Sharding.Tests; + +public class WrappedShardBufferedMessageSpec: AkkaSpec +{ + #region Custom Classes + + private sealed class MyEnvelope : IWrappedMessage + { + public MyEnvelope(object message) + { + Message = message; + } + + public object Message { get; } + } + + private sealed class BufferMessageAdapter: IShardingBufferMessageAdapter + { + public object Apply(object message, IActorContext context) + => new MyEnvelope(message); + + public object UnApply(object message, IActorContext context) + { + return message is MyEnvelope envelope ? envelope.Message : message; + } + } + + private class EchoActor: UntypedActor + { + private readonly ILoggingAdapter _log = Context.GetLogger(); + protected override void OnReceive(object message) + { + _log.Info($">>>> OnReceive {message.GetType()}: {message}"); + if(message is string) + Sender.Tell(message); + else + Unhandled(message); + } + } + + private sealed class FakeRememberEntitiesProvider: IRememberEntitiesProvider + { + private readonly IActorRef _probe; + + public FakeRememberEntitiesProvider(IActorRef probe) + { + _probe = probe; + } + + public Props CoordinatorStoreProps() => FakeCoordinatorStoreActor.Props(); + + public Props ShardStoreProps(string shardId) => FakeShardStoreActor.Props(shardId, _probe); + } + + private class ShardStoreCreated + { + public ShardStoreCreated(IActorRef store, string shardId) + { + Store = store; + ShardId = shardId; + } + + public IActorRef Store { get; } + public string ShardId { get; } + } + + private class CoordinatorStoreCreated + { + public CoordinatorStoreCreated(IActorRef store) + { + Store = store; + } + + public IActorRef Store { get; } + } + + private class FakeShardStoreActor : ActorBase + { + public static Props Props(string shardId, IActorRef probe) => Actor.Props.Create(() => new FakeShardStoreActor(shardId, probe)); + + private readonly string _shardId; + private readonly IActorRef _probe; + + public FakeShardStoreActor(string shardId, IActorRef probe) + { + _shardId = shardId; + _probe = probe; + Context.System.EventStream.Publish(new ShardStoreCreated(Self, shardId)); + } + + protected override bool Receive(object message) + { + switch (message) + { + case RememberEntitiesShardStore.GetEntities: + Sender.Tell(new RememberEntitiesShardStore.RememberedEntities(ImmutableHashSet.Empty)); + return true; + case RememberEntitiesShardStore.Update m: + _probe.Tell(new RememberEntitiesShardStore.UpdateDone(m.Started, m.Stopped)); + return true; + } + return false; + } + } + + private class FakeCoordinatorStoreActor : ActorBase + { + public static Props Props() => Actor.Props.Create(() => new FakeCoordinatorStoreActor()); + + public FakeCoordinatorStoreActor() + { + Context.System.EventStream.Publish(new CoordinatorStoreCreated(Context.Self)); + } + + protected override bool Receive(object message) + { + switch (message) + { + case RememberEntitiesCoordinatorStore.GetShards _: + Sender.Tell(new RememberEntitiesCoordinatorStore.RememberedShards(ImmutableHashSet.Empty)); + return true; + case RememberEntitiesCoordinatorStore.AddShard m: + Sender.Tell(new RememberEntitiesCoordinatorStore.UpdateDone(m.ShardId)); + return true; + } + return false; + } + } + + private static Config GetConfig() + { + return ConfigurationFactory.ParseString(@" + akka.loglevel=DEBUG + akka.actor.provider = cluster + akka.remote.dot-netty.tcp.port = 0 + akka.cluster.sharding.state-store-mode = ddata + akka.cluster.sharding.remember-entities = on + + # no leaks between test runs thank you + akka.cluster.sharding.distributed-data.durable.keys = [] + akka.cluster.sharding.verbose-debug-logging = on + akka.cluster.sharding.fail-on-invalid-entity-state-transition = on") + + .WithFallback(Sharding.ClusterSharding.DefaultConfig()) + .WithFallback(DistributedData.DistributedData.DefaultConfig()) + .WithFallback(ClusterSingleton.DefaultConfig()); + } + + #endregion + + private const string Msg = "hit"; + private readonly IActorRef _shard; + private IActorRef _store; + + public WrappedShardBufferedMessageSpec(ITestOutputHelper output) : base(GetConfig(), output) + { + Sys.EventStream.Subscribe(TestActor, typeof(ShardStoreCreated)); + Sys.EventStream.Subscribe(TestActor, typeof(CoordinatorStoreCreated)); + + _shard = ChildActorOf(Shard.Props( + typeName: "test", + shardId: "test", + entityProps: _ => Props.Create(() => new EchoActor()), + settings: ClusterShardingSettings.Create(Sys), + extractor: new ExtractorAdapter(HashCodeMessageExtractor.Create(10, m => m.ToString())), + handOffStopMessage: PoisonPill.Instance, + rememberEntitiesProvider: new FakeRememberEntitiesProvider(TestActor), + bufferMessageAdapter: new BufferMessageAdapter())); + } + + private async Task ExpectShardStartup() + { + var createdEvent = await ExpectMsgAsync(); + createdEvent.ShardId.Should().Be("test"); + + _store = createdEvent.Store; + + await ExpectMsgAsync(); + + _shard.Tell(new ShardRegion.StartEntity(Msg)); + + return await ExpectMsgAsync(); + } + + [Fact(DisplayName = "Message wrapped in ShardingEnvelope, buffered by Shard, transformed by BufferMessageAdapter, must arrive in entity actor")] + public async Task WrappedMessageDelivery() + { + IgnoreMessages(); + + var continueMessage = await ExpectShardStartup(); + + // this message should be buffered + _shard.Tell(new ShardingEnvelope(Msg, Msg)); + await Task.Yield(); + + // Tell shard to continue processing + _shard.Tell(continueMessage); + + await ExpectMsgAsync(Msg); + } +} \ No newline at end of file diff --git a/src/contrib/cluster/Akka.Cluster.Sharding/IShardingBufferMessageAdapter.cs b/src/contrib/cluster/Akka.Cluster.Sharding/IShardingBufferMessageAdapter.cs index dc30db6e20c..89231850f8a 100644 --- a/src/contrib/cluster/Akka.Cluster.Sharding/IShardingBufferMessageAdapter.cs +++ b/src/contrib/cluster/Akka.Cluster.Sharding/IShardingBufferMessageAdapter.cs @@ -14,6 +14,7 @@ namespace Akka.Cluster.Sharding; public interface IShardingBufferMessageAdapter { public object Apply(object message, IActorContext context); + public object UnApply(object message, IActorContext context); } [InternalApi] @@ -24,6 +25,8 @@ internal class EmptyBufferMessageAdapter: IShardingBufferMessageAdapter private EmptyBufferMessageAdapter() { } - + public object Apply(object message, IActorContext context) => message; + + public object UnApply(object message, IActorContext context) => message; } diff --git a/src/contrib/cluster/Akka.Cluster.Sharding/Shard.cs b/src/contrib/cluster/Akka.Cluster.Sharding/Shard.cs index bc8dfd0eb0d..8afedddb406 100644 --- a/src/contrib/cluster/Akka.Cluster.Sharding/Shard.cs +++ b/src/contrib/cluster/Akka.Cluster.Sharding/Shard.cs @@ -344,7 +344,8 @@ public static Props Props( ClusterShardingSettings settings, IMessageExtractor extractor, object handOffStopMessage, - IRememberEntitiesProvider? rememberEntitiesProvider) + IRememberEntitiesProvider? rememberEntitiesProvider, + IShardingBufferMessageAdapter? bufferMessageAdapter) { return Actor.Props.Create(() => new Shard( typeName, @@ -353,7 +354,8 @@ public static Props Props( settings, extractor, handOffStopMessage, - rememberEntitiesProvider)).WithDeploy(Deploy.Local); + rememberEntitiesProvider, + bufferMessageAdapter)).WithDeploy(Deploy.Local); } [Serializable] @@ -976,7 +978,8 @@ public Shard( ClusterShardingSettings settings, IMessageExtractor extractor, object handOffStopMessage, - IRememberEntitiesProvider? rememberEntitiesProvider) + IRememberEntitiesProvider? rememberEntitiesProvider, + IShardingBufferMessageAdapter? bufferMessageAdapter) { _typeName = typeName; _shardId = shardId; @@ -1020,7 +1023,7 @@ public Shard( _leaseRetryInterval = settings.LeaseSettings.LeaseRetryInterval; } - _bufferMessageAdapter = ClusterSharding.Get(Context.System).BufferMessageAdapter; + _bufferMessageAdapter = bufferMessageAdapter ?? EmptyBufferMessageAdapter.Instance; } protected override SupervisorStrategy SupervisorStrategy() @@ -2001,7 +2004,7 @@ private void SendMsgBuffer(EntityId entityId) if (WrappedMessage.Unwrap(message) is ShardRegion.StartEntity se) StartEntity(se.EntityId, @ref); else - DeliverMessage(entityId, message, @ref); + DeliverMessage(entityId, _bufferMessageAdapter.UnApply(message, Context), @ref); } TouchLastMessageTimestamp(entityId); diff --git a/src/contrib/cluster/Akka.Cluster.Sharding/ShardRegion.cs b/src/contrib/cluster/Akka.Cluster.Sharding/ShardRegion.cs index e282dc5ab7f..3df82a2b618 100644 --- a/src/contrib/cluster/Akka.Cluster.Sharding/ShardRegion.cs +++ b/src/contrib/cluster/Akka.Cluster.Sharding/ShardRegion.cs @@ -1310,7 +1310,8 @@ private IActorRef GetShard(ShardId id) _settings, _messageExtractor, _handOffStopMessage, - _rememberEntitiesProvider) + _rememberEntitiesProvider, + _bufferMessageAdapter) .WithDispatcher(Context.Props.Dispatcher), name)); _shardsByRef = _shardsByRef.SetItem(shardRef, id); diff --git a/src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveClusterSharding.DotNet.verified.txt b/src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveClusterSharding.DotNet.verified.txt index 7c8df1bb4ed..49bc3bd4627 100644 --- a/src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveClusterSharding.DotNet.verified.txt +++ b/src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveClusterSharding.DotNet.verified.txt @@ -223,6 +223,7 @@ namespace Akka.Cluster.Sharding public interface IShardingBufferMessageAdapter { object Apply(object message, Akka.Actor.IActorContext context); + object UnApply(object message, Akka.Actor.IActorContext context); } public interface IStartableAllocationStrategy : Akka.Actor.INoSerializationVerificationNeeded, Akka.Cluster.Sharding.IShardAllocationStrategy { diff --git a/src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveClusterSharding.Net.verified.txt b/src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveClusterSharding.Net.verified.txt index 20920646b0d..5570b269431 100644 --- a/src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveClusterSharding.Net.verified.txt +++ b/src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveClusterSharding.Net.verified.txt @@ -223,6 +223,7 @@ namespace Akka.Cluster.Sharding public interface IShardingBufferMessageAdapter { object Apply(object message, Akka.Actor.IActorContext context); + object UnApply(object message, Akka.Actor.IActorContext context); } public interface IStartableAllocationStrategy : Akka.Actor.INoSerializationVerificationNeeded, Akka.Cluster.Sharding.IShardAllocationStrategy {