diff --git a/src/Orleans.Core/Lifecycle/MigrationContext.cs b/src/Orleans.Core/Lifecycle/MigrationContext.cs index d030133343..8b918b49f3 100644 --- a/src/Orleans.Core/Lifecycle/MigrationContext.cs +++ b/src/Orleans.Core/Lifecycle/MigrationContext.cs @@ -117,10 +117,10 @@ public bool TryGetValue(string key, out T? value) IEnumerator IEnumerable.GetEnumerator() => new Enumerator(this); IEnumerator IEnumerable.GetEnumerator() => new Enumerator(this); - private sealed class Enumerator : IEnumerator, IEnumerator + private sealed class Enumerator(MigrationContext context) : IEnumerator, IEnumerator { - private Dictionary.KeyCollection.Enumerator _value; - public Enumerator(MigrationContext context) => _value = context._indices.Keys.GetEnumerator(); + private Dictionary.KeyCollection.Enumerator _value = context._indices.Keys.GetEnumerator(); + public string Current => _value.Current; object IEnumerator.Current => Current; public void Dispose() => _value.Dispose(); @@ -133,18 +133,8 @@ public void Reset() } } - internal sealed class SerializationHooks + internal sealed class SerializationHooks(SerializerSessionPool serializerSessionPool) { - private readonly SerializerSessionPool _serializerSessionPool; - - public SerializationHooks(SerializerSessionPool serializerSessionPool) - { - _serializerSessionPool = serializerSessionPool; - } - - public void OnDeserializing(MigrationContext context) - { - context._sessionPool = _serializerSessionPool; - } + public void OnDeserializing(MigrationContext context) => context._sessionPool = serializerSessionPool; } } diff --git a/src/Orleans.Runtime/Catalog/ActivationMigrationManager.cs b/src/Orleans.Runtime/Catalog/ActivationMigrationManager.cs index 02b6d457a5..19a491b95c 100644 --- a/src/Orleans.Runtime/Catalog/ActivationMigrationManager.cs +++ b/src/Orleans.Runtime/Catalog/ActivationMigrationManager.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks.Sources; using Microsoft.Extensions.Logging; using Microsoft.Extensions.ObjectPool; +using Orleans.Internal; using Orleans.Runtime.Internal; using Orleans.Runtime.Scheduler; @@ -52,7 +53,7 @@ internal interface IActivationMigrationManager /// /// Migrates grain activations to target hosts and handles migration requests from other hosts. /// -internal class ActivationMigrationManager : SystemTarget, IActivationMigrationManagerSystemTarget, IActivationMigrationManager +internal class ActivationMigrationManager : SystemTarget, IActivationMigrationManagerSystemTarget, IActivationMigrationManager, ILifecycleParticipant { private const int MaxBatchSize = 1_000; private readonly ConcurrentDictionary WorkItemChannel)> _workers = new(); @@ -305,6 +306,28 @@ private void RemoveWorker(SiloAddress targetSilo) } } + private Task StartAsync(CancellationToken cancellationToken) => Task.CompletedTask; + private async Task StopAsync(CancellationToken cancellationToken) + { + var workerTasks = new List(); + foreach (var (_, value) in _workers) + { + value.WorkItemChannel.Writer.TryComplete(); + workerTasks.Add(value.PumpTask); + } + + await Task.WhenAll(workerTasks).WithCancellation(cancellationToken); + } + + void ILifecycleParticipant.Participate(ISiloLifecycle lifecycle) + { + lifecycle.Subscribe( + nameof(ActivationMigrationManager), + ServiceLifecycleStage.RuntimeGrainServices, + ct => this.RunOrQueueTask(() => StartAsync(ct)), + ct => this.RunOrQueueTask(() => StopAsync(ct))); + } + private class MigrationWorkItem : IValueTaskSource { private ManualResetValueTaskSourceCore _core = new() { RunContinuationsAsynchronously = true }; diff --git a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs index 9a6174c0fe..98165b244a 100644 --- a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs +++ b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs @@ -399,6 +399,7 @@ internal static void AddDefaultServices(ISiloBuilder builder) services.AddSingleton(); services.AddSingleton(); services.AddFromExisting(); + services.AddFromExisting, ActivationMigrationManager>(); ApplyConfiguration(builder); } diff --git a/src/Orleans.Serialization/Codecs/ByteArrayCodec.cs b/src/Orleans.Serialization/Codecs/ByteArrayCodec.cs index 08d0ddb979..a8bb281690 100644 --- a/src/Orleans.Serialization/Codecs/ByteArrayCodec.cs +++ b/src/Orleans.Serialization/Codecs/ByteArrayCodec.cs @@ -296,10 +296,11 @@ public static Memory DeepCopy(Memory input, CopyContext copyContext) /// Serializer for instances. /// [RegisterSerializer] - public sealed class PooledBufferCodec : IValueSerializer + public sealed class PooledBufferCodec : IFieldCodec { - public void Serialize(ref Writer writer, scoped ref PooledBuffer value) where TBufferWriter : IBufferWriter + public void WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, PooledBuffer value) where TBufferWriter : IBufferWriter { + writer.WriteFieldHeader(fieldIdDelta, expectedType, typeof(PooledBuffer), WireType.LengthPrefixed); writer.WriteVarUInt32((uint)value.Length); foreach (var segment in value) { @@ -311,11 +312,12 @@ public void Serialize(ref Writer writer, scoped re // Senders must not use the value after sending. // Receivers must dispose of the value after use. value.Reset(); - value = default; } - public void Deserialize(ref Reader reader, scoped ref PooledBuffer value) + public PooledBuffer ReadValue(ref Reader reader, Field field) { + field.EnsureWireType(WireType.LengthPrefixed); + var value = new PooledBuffer(); const int MaxSpanLength = 4096; var length = (int)reader.ReadVarUInt32(); while (length > 0) @@ -328,6 +330,7 @@ public void Deserialize(ref Reader reader, scoped ref PooledBuff } Debug.Assert(length == 0); + return value; } } diff --git a/test/DefaultCluster.Tests/Migration/MigrationTests.cs b/test/DefaultCluster.Tests/Migration/MigrationTests.cs index 9b537dff85..4b4074cc19 100644 --- a/test/DefaultCluster.Tests/Migration/MigrationTests.cs +++ b/test/DefaultCluster.Tests/Migration/MigrationTests.cs @@ -1,3 +1,4 @@ +using System.Diagnostics; using Orleans.Core.Internal; using Orleans.Runtime; using Orleans.Runtime.Placement; @@ -75,6 +76,61 @@ public async Task DirectedGrainMigrationTest() } } + /// + /// Tests that multiple grains can be migrated simultaneously. + /// + [Fact, TestCategory("BVT")] + public async Task MultiGrainDirectedMigrationTest() + { + var baseId = GetRandomGrainId(); + for (var i = 1; i < 100; ++i) + { + var a = GrainFactory.GetGrain(baseId + 2 * i); + var expectedState = Random.Shared.Next(); + await a.SetState(expectedState); + var originalAddressA = await a.GetGrainAddress(); + var originalHostA = originalAddressA.SiloAddress; + + RequestContext.Set(IPlacementDirector.PlacementHintKey, originalHostA); + var b = GrainFactory.GetGrain(baseId + 1 + 2 * i); + await b.SetState(expectedState); + var originalAddressB = await b.GetGrainAddress(); + Assert.Equal(originalHostA, originalAddressB.SiloAddress); + + var targetHost = Fixture.HostedCluster.GetActiveSilos().Select(s => s.SiloAddress).First(address => address != originalHostA); + + // Trigger migration, setting a placement hint to coerce the placement director to use the target silo + RequestContext.Set(IPlacementDirector.PlacementHintKey, targetHost); + var migrateA = a.Cast().MigrateOnIdle(); + var migrateB = b.Cast().MigrateOnIdle(); + await migrateA; + await migrateB; + + while (true) + { + var newAddress = await a.GetGrainAddress(); + if (newAddress.ActivationId != originalAddressA.ActivationId) + { + Assert.Equal(targetHost, newAddress.SiloAddress); + break; + } + } + + while (true) + { + var newAddress = await b.GetGrainAddress(); + if (newAddress.ActivationId != originalAddressB.ActivationId) + { + Assert.Equal(targetHost, newAddress.SiloAddress); + break; + } + } + + Assert.Equal(expectedState, await a.GetState()); + Assert.Equal(expectedState, await b.GetState()); + } + } + /// /// Tests that grain migration works for a simple grain which uses for state. /// The test specifies an alternative location for the grain to migrate to and asserts that it migrates to that location. diff --git a/test/NonSilo.Tests/Serialization/BuiltInSerializerTests.cs b/test/NonSilo.Tests/Serialization/BuiltInSerializerTests.cs index 9f5c602169..27a30a4c7f 100644 --- a/test/NonSilo.Tests/Serialization/BuiltInSerializerTests.cs +++ b/test/NonSilo.Tests/Serialization/BuiltInSerializerTests.cs @@ -63,7 +63,6 @@ public void ValueTupleTypesHasSerializer() /// /// Tests that the default (non-fallback) serializer can handle complex classes. /// - /// [Fact, TestCategory("BVT")] public void Serialize_ComplexAccessibleClass() { diff --git a/test/Orleans.Serialization.UnitTests/PooledBufferTests.cs b/test/Orleans.Serialization.UnitTests/PooledBufferTests.cs index 19d043e583..a472513e78 100644 --- a/test/Orleans.Serialization.UnitTests/PooledBufferTests.cs +++ b/test/Orleans.Serialization.UnitTests/PooledBufferTests.cs @@ -126,6 +126,90 @@ static void SerializeObject(SerializerSessionPool pool, Serializer serializer, L } } + /// + /// Tests that the serializer can correctly serialized . + /// + [Fact] + public void PooledBuffer_SerializerRoundTrip() + { + var serviceProvider = new ServiceCollection() + .AddSerializer() + .BuildServiceProvider(); + var serializer = serviceProvider.GetRequiredService(); + + var random = new Random(); + for (var i = 0; i < 10; i++) + { + const int TargetLength = 8120; + + // NOTE: The serializer is responsible for freeing the buffer provided to it, so we do not free this. + var buffer = new PooledBuffer(); + while (buffer.Length < TargetLength) + { + var span = buffer.GetSpan(TargetLength - buffer.Length); + var writeLen = Math.Min(span.Length, TargetLength - buffer.Length); + random.NextBytes(span[..writeLen]); + buffer.Advance(writeLen); + } + + var bytes = buffer.ToArray(); + Assert.Equal(TargetLength, bytes.Length); + + var result = serializer.Deserialize(serializer.SerializeToArray(buffer)); + Assert.Equal(TargetLength, result.Length); + + var resultBytes = result.ToArray(); + Assert.Equal(bytes, resultBytes); + + // NOTE: we are responsible for disposing a buffer returned from deserialization. + result.Dispose(); + } + } + + /// + /// Tests that the serializer can correctly serialized when it's embedded in another structure. + /// + [Fact] + public void PooledBuffer_SerializerRoundTrip_Embedded() + { + var serviceProvider = new ServiceCollection() + .AddSerializer() + .BuildServiceProvider(); + var serializer = serviceProvider.GetRequiredService(); + + var random = new Random(); + for (var i = 0; i < 10; i++) + { + const int TargetLength = 8120; + + // NOTE: The serializer is responsible for freeing the buffer provided to it, so we do not free this. + var buffer = new PooledBuffer(); + while (buffer.Length < TargetLength) + { + var span = buffer.GetSpan(TargetLength - buffer.Length); + var writeLen = Math.Min(span.Length, TargetLength - buffer.Length); + random.NextBytes(span[..writeLen]); + buffer.Advance(writeLen); + } + + var bytes = buffer.ToArray(); + Assert.Equal(TargetLength, bytes.Length); + + var embed = (Guid: Guid.NewGuid(), Buffer: buffer, Int: 42); + var result = serializer.Deserialize<(Guid Guid, PooledBuffer Buffer, int Int)>(serializer.SerializeToArray(embed)); + Assert.Equal(embed.Guid, result.Guid); + Assert.Equal(embed.Int, result.Int); + var resultBuffer = result.Buffer; + Assert.Equal(TargetLength, resultBuffer.Length); + + var resultBytes = resultBuffer.ToArray(); + Assert.Equal(bytes, resultBytes); + + // NOTE: we are responsible for disposing a buffer returned from deserialization. + resultBuffer.Dispose(); + } + } + [GenerateSerializer] public readonly record struct LargeObject( [property: Id(0)] Guid Id,