diff --git a/src/Orleans.Core/Core/ClientBuilderExtensions.cs b/src/Orleans.Core/Core/ClientBuilderExtensions.cs index 5daf2b48c3..3a7b35033d 100644 --- a/src/Orleans.Core/Core/ClientBuilderExtensions.cs +++ b/src/Orleans.Core/Core/ClientBuilderExtensions.cs @@ -127,7 +127,8 @@ public static IClientBuilder AddActivityPropagation(this IClientBuilder builder) builder.Services.TryAddSingleton(DistributedContextPropagator.Current); return builder - .AddOutgoingGrainCallFilter(); + .AddOutgoingGrainCallFilter() + .AddIncomingGrainCallFilter(); } /// diff --git a/src/Orleans.Core/Core/ClientBuilderGrainCallFilterExtensions.cs b/src/Orleans.Core/Core/ClientBuilderGrainCallFilterExtensions.cs index 3fc6ab0929..bbf57532bb 100644 --- a/src/Orleans.Core/Core/ClientBuilderGrainCallFilterExtensions.cs +++ b/src/Orleans.Core/Core/ClientBuilderGrainCallFilterExtensions.cs @@ -1,42 +1,75 @@ -namespace Orleans.Hosting +namespace Orleans.Hosting; + +/// +/// Extensions for configuring grain call filters. +/// +public static class ClientBuilderGrainCallFilterExtensions { /// - /// Extensions for configuring grain call filters. + /// Adds an to the filter pipeline. + /// + /// The builder. + /// The filter. + /// The builder. + public static IClientBuilder AddIncomingGrainCallFilter(this IClientBuilder builder, IIncomingGrainCallFilter filter) + { + return builder.ConfigureServices(services => services.AddIncomingGrainCallFilter(filter)); + } + + /// + /// Adds an to the filter pipeline. + /// + /// The filter implementation type. + /// The builder. + /// The builder. + public static IClientBuilder AddIncomingGrainCallFilter(this IClientBuilder builder) + where TImplementation : class, IIncomingGrainCallFilter + { + return builder.ConfigureServices(services => services.AddIncomingGrainCallFilter()); + } + + /// + /// Adds an to the filter pipeline via a delegate. + /// + /// The builder. + /// The filter. + /// The builder. + public static IClientBuilder AddIncomingGrainCallFilter(this IClientBuilder builder, IncomingGrainCallFilterDelegate filter) + { + return builder.ConfigureServices(services => services.AddIncomingGrainCallFilter(filter)); + } + + /// + /// Adds an to the filter pipeline. /// - public static class ClientBuilderGrainCallFilterExtensions + /// The builder. + /// The filter. + /// The . + public static IClientBuilder AddOutgoingGrainCallFilter(this IClientBuilder builder, IOutgoingGrainCallFilter filter) { - /// - /// Adds an to the filter pipeline. - /// - /// The builder. - /// The filter. - /// The . - public static IClientBuilder AddOutgoingGrainCallFilter(this IClientBuilder builder, IOutgoingGrainCallFilter filter) - { - return builder.ConfigureServices(services => services.AddOutgoingGrainCallFilter(filter)); - } + return builder.ConfigureServices(services => services.AddOutgoingGrainCallFilter(filter)); + } - /// - /// Adds an to the filter pipeline. - /// - /// The filter implementation type. - /// The builder. - /// The . - public static IClientBuilder AddOutgoingGrainCallFilter(this IClientBuilder builder) - where TImplementation : class, IOutgoingGrainCallFilter - { - return builder.ConfigureServices(services => services.AddOutgoingGrainCallFilter()); - } + /// + /// Adds an to the filter pipeline. + /// + /// The filter implementation type. + /// The builder. + /// The . + public static IClientBuilder AddOutgoingGrainCallFilter(this IClientBuilder builder) + where TImplementation : class, IOutgoingGrainCallFilter + { + return builder.ConfigureServices(services => services.AddOutgoingGrainCallFilter()); + } - /// - /// Adds an to the filter pipeline via a delegate. - /// - /// The builder. - /// The filter. - /// The . - public static IClientBuilder AddOutgoingGrainCallFilter(this IClientBuilder builder, OutgoingGrainCallFilterDelegate filter) - { - return builder.ConfigureServices(services => services.AddOutgoingGrainCallFilter(filter)); - } + /// + /// Adds an to the filter pipeline via a delegate. + /// + /// The builder. + /// The filter. + /// The . + public static IClientBuilder AddOutgoingGrainCallFilter(this IClientBuilder builder, OutgoingGrainCallFilterDelegate filter) + { + return builder.ConfigureServices(services => services.AddOutgoingGrainCallFilter(filter)); } } \ No newline at end of file diff --git a/src/Orleans.Core/Core/DefaultClientServices.cs b/src/Orleans.Core/Core/DefaultClientServices.cs index 9f077f9525..51315459af 100644 --- a/src/Orleans.Core/Core/DefaultClientServices.cs +++ b/src/Orleans.Core/Core/DefaultClientServices.cs @@ -73,6 +73,7 @@ public static void AddDefaultServices(IClientBuilder builder) services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(); + services.TryAddSingleton(); services.TryAddSingleton(); services.AddFromExisting(); services.TryAddFromExisting(); diff --git a/src/Orleans.Runtime/Core/GrainMethodInvoker.cs b/src/Orleans.Core/Core/GrainMethodInvoker.cs similarity index 100% rename from src/Orleans.Runtime/Core/GrainMethodInvoker.cs rename to src/Orleans.Core/Core/GrainMethodInvoker.cs diff --git a/src/Orleans.Core/Runtime/InvokableObjectManager.cs b/src/Orleans.Core/Runtime/InvokableObjectManager.cs index a729e3f6ce..48f752750f 100644 --- a/src/Orleans.Core/Runtime/InvokableObjectManager.cs +++ b/src/Orleans.Core/Runtime/InvokableObjectManager.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Orleans.Runtime; using Orleans.Serialization; @@ -14,23 +15,34 @@ internal class InvokableObjectManager : IDisposable { private readonly CancellationTokenSource disposed = new CancellationTokenSource(); private readonly ConcurrentDictionary localObjects = new ConcurrentDictionary(); + + private readonly InterfaceToImplementationMappingCache _interfaceToImplementationMapping; private readonly IGrainContext rootGrainContext; private readonly IRuntimeClient runtimeClient; private readonly ILogger logger; private readonly DeepCopier deepCopier; + private readonly DeepCopier _responseCopier; private readonly MessagingTrace messagingTrace; + private List _grainCallFilters; + + private List GrainCallFilters + => _grainCallFilters ??= new List(runtimeClient.ServiceProvider.GetServices()); public InvokableObjectManager( IGrainContext rootGrainContext, IRuntimeClient runtimeClient, DeepCopier deepCopier, MessagingTrace messagingTrace, + DeepCopier responseCopier, + InterfaceToImplementationMappingCache interfaceToImplementationMapping, ILogger logger) { this.rootGrainContext = rootGrainContext; this.runtimeClient = runtimeClient; this.deepCopier = deepCopier; this.messagingTrace = messagingTrace; + _responseCopier = responseCopier; + _interfaceToImplementationMapping = interfaceToImplementationMapping; this.logger = logger; } @@ -246,7 +258,20 @@ private async Task LocalObjectMessagePumpAsync() try { request.SetTarget(this); - var response = await request.Invoke(); + var filters = _manager.GrainCallFilters; + Response response; + if (filters is { Count: > 0 } || LocalObject is IIncomingGrainCallFilter) + { + var invoker = new GrainMethodInvoker(message, this, request, filters, _manager._interfaceToImplementationMapping, _manager._responseCopier); + await invoker.Invoke(); + response = invoker.Response; + } + else + { + response = await request.Invoke(); + response = _manager._responseCopier.Copy(response); + } + if (message.Direction != Message.Directions.OneWay) { this.SendResponseAsync(message, response); diff --git a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs index 28e2c34ee8..c44ac4316d 100644 --- a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs +++ b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs @@ -13,6 +13,7 @@ using Orleans.Runtime; using Orleans.Serialization; using Orleans.Serialization.Invocation; +using Orleans.Serialization.Serializers; using static Orleans.Internal.StandardExtensions; namespace Orleans @@ -30,6 +31,7 @@ internal class OutsideRuntimeClient : IRuntimeClient, IDisposable, IClusterConne private bool disposed; private readonly MessagingTrace messagingTrace; + private readonly InterfaceToImplementationMappingCache _interfaceToImplementationMapping; public IInternalGrainFactory InternalGrainFactory { get; private set; } @@ -65,9 +67,11 @@ public OutsideRuntimeClient( IOptions clientMessagingOptions, MessagingTrace messagingTrace, IServiceProvider serviceProvider, - TimeProvider timeProvider) + TimeProvider timeProvider, + InterfaceToImplementationMappingCache interfaceToImplementationMapping) { TimeProvider = timeProvider; + _interfaceToImplementationMapping = interfaceToImplementationMapping; this.ServiceProvider = serviceProvider; _localClientDetails = localClientDetails; this.loggerFactory = loggerFactory; @@ -105,14 +109,14 @@ internal void ConsumeServices() this.InternalGrainFactory = this.ServiceProvider.GetRequiredService(); this.messageFactory = this.ServiceProvider.GetService(); - - var copier = this.ServiceProvider.GetRequiredService(); this.localObjects = new InvokableObjectManager( ServiceProvider.GetRequiredService(), this, - copier, - this.messagingTrace, - this.loggerFactory.CreateLogger()); + ServiceProvider.GetRequiredService(), + messagingTrace, + ServiceProvider.GetRequiredService>(), + _interfaceToImplementationMapping, + loggerFactory.CreateLogger()); this.callbackTimerTask = Task.Run(MonitorCallbackExpiry); diff --git a/src/Orleans.Runtime/Core/HostedClient.cs b/src/Orleans.Runtime/Core/HostedClient.cs index 5a2c6f2a1a..04faef384c 100644 --- a/src/Orleans.Runtime/Core/HostedClient.cs +++ b/src/Orleans.Runtime/Core/HostedClient.cs @@ -12,6 +12,7 @@ using Orleans.Internal; using Orleans.Runtime.Messaging; using Orleans.Serialization; +using Orleans.Serialization.Invocation; namespace Orleans.Runtime { @@ -24,7 +25,7 @@ internal sealed class HostedClient : IGrainContext, IGrainExtensionBinder, IDisp private readonly Channel incomingMessages; private readonly IGrainReferenceRuntime grainReferenceRuntime; private readonly InvokableObjectManager invokableObjects; - private readonly IRuntimeClient runtimeClient; + private readonly InsideRuntimeClient runtimeClient; private readonly ILogger logger; private readonly IInternalGrainFactory grainFactory; private readonly MessageCenter siloMessageCenter; @@ -36,7 +37,7 @@ internal sealed class HostedClient : IGrainContext, IGrainExtensionBinder, IDisp private Task? messagePump; public HostedClient( - IRuntimeClient runtimeClient, + InsideRuntimeClient runtimeClient, ILocalSiloDetails siloDetails, ILogger logger, IGrainReferenceRuntime grainReferenceRuntime, @@ -44,7 +45,8 @@ public HostedClient( MessageCenter messageCenter, MessagingTrace messagingTrace, DeepCopier deepCopier, - GrainReferenceActivator referenceActivator) + GrainReferenceActivator referenceActivator, + InterfaceToImplementationMappingCache interfaceToImplementationMappingCache) { this.incomingMessages = Channel.CreateUnbounded(new UnboundedChannelOptions { @@ -61,6 +63,8 @@ public HostedClient( runtimeClient, deepCopier, messagingTrace, + runtimeClient.ServiceProvider.GetRequiredService>(), + interfaceToImplementationMappingCache, logger); this.siloMessageCenter = messageCenter; this.messagingTrace = messagingTrace; diff --git a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs index 6a50eecfd1..922fe223d2 100644 --- a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs +++ b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs @@ -31,6 +31,7 @@ internal sealed class InsideRuntimeClient : IRuntimeClient, ILifecycleParticipan private readonly ILoggerFactory loggerFactory; private readonly SiloMessagingOptions messagingOptions; private readonly ConcurrentDictionary<(GrainId, CorrelationId), CallbackData> callbacks; + private readonly InterfaceToImplementationMappingCache interfaceToImplementationMapping; private readonly SharedCallbackData sharedCallbackData; private readonly SharedCallbackData systemSharedCallbackData; private readonly PeriodicTimer callbackTimer; @@ -39,10 +40,9 @@ internal sealed class InsideRuntimeClient : IRuntimeClient, ILifecycleParticipan private MessageCenter messageCenter; private List grainCallFilters; private readonly DeepCopier _deepCopier; - private readonly InterfaceToImplementationMappingCache interfaceToImplementationMapping; private HostedClient hostedClient; - private HostedClient HostedClient => this.hostedClient ?? (this.hostedClient = this.ServiceProvider.GetRequiredService()); + private HostedClient HostedClient => this.hostedClient ??= this.ServiceProvider.GetRequiredService(); private readonly MessageFactory messageFactory; private IGrainReferenceRuntime grainReferenceRuntime; private Task callbackTimerTask; @@ -60,10 +60,11 @@ public InsideRuntimeClient( GrainInterfaceTypeResolver interfaceIdResolver, GrainInterfaceTypeToGrainTypeResolver interfaceToTypeResolver, DeepCopier deepCopier, - TimeProvider timeProvider) + TimeProvider timeProvider, + InterfaceToImplementationMappingCache interfaceToImplementationMapping) { TimeProvider = timeProvider; - this.interfaceToImplementationMapping = new InterfaceToImplementationMappingCache(); + this.interfaceToImplementationMapping = interfaceToImplementationMapping; this._deepCopier = deepCopier; this.ServiceProvider = serviceProvider; this.MySilo = siloDetails.SiloAddress; @@ -102,7 +103,7 @@ private GrainLocator GrainLocator => this.grainLocator ?? (this.grainLocator = this.ServiceProvider.GetRequiredService()); private List GrainCallFilters - => this.grainCallFilters ?? (this.grainCallFilters = new List(this.ServiceProvider.GetServices())); + => this.grainCallFilters ??= new List(this.ServiceProvider.GetServices()); private MessageCenter MessageCenter => this.messageCenter ?? (this.messageCenter = this.ServiceProvider.GetRequiredService()); diff --git a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs index 2e9d5d3e42..2985a61904 100644 --- a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs +++ b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs @@ -100,6 +100,7 @@ internal static void AddDefaultServices(ISiloBuilder builder) services.AddTransient(); services.AddKeyedTransient(typeof(ICancellationSourcesExtension), (sp, _) => sp.GetRequiredService()); services.TryAddSingleton(sp => sp.GetRequiredService().ConcreteGrainFactory); + services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddFromExisting(); services.TryAddFromExisting(); diff --git a/test/Grains/TestGrainInterfaces/IMethodInterceptionGrain.cs b/test/Grains/TestGrainInterfaces/IMethodInterceptionGrain.cs index 8a6b720b15..36a5aabbc9 100644 --- a/test/Grains/TestGrainInterfaces/IMethodInterceptionGrain.cs +++ b/test/Grains/TestGrainInterfaces/IMethodInterceptionGrain.cs @@ -20,6 +20,22 @@ public interface IMethodInterceptionGrain : IGrainWithIntegerKey, IMethodFromAno Task SystemWideCallFilterMarker(); } + [GrainInterfaceType("obs-method-interception-custom-name")] + public interface IMethodInterceptionGrainObserver : IGrainObserver, IMethodFromAnotherInterface + { + [Id(14142)] + Task One(); + + [Id(4142)] + Task Echo(string someArg); + Task NotIntercepted(); + Task Throw(); + Task IncorrectResultType(); + Task FilterThrows(); + + Task SystemWideCallFilterMarker(); + } + [GrainInterfaceType("custom-outgoing-interception-grain")] public interface IOutgoingMethodInterceptionGrain : IGrainWithIntegerKey { @@ -34,11 +50,17 @@ public interface IGenericMethodInterceptionGrain : IGrainWithIntegerKey, I Task GetInputAsString(T input); } + [Alias("UnitTests.GrainInterfaces.IGenericMethodInterceptionGrainObserver`1")] + public interface IGenericMethodInterceptionGrainObserver : IGrainObserver, IMethodFromAnotherInterface + { + [Alias("GetInputAsString")] + Task GetInputAsString(T input); + } + public interface IMethodFromAnotherInterface { Task SayHello(); } - [Alias("UnitTests.GrainInterfaces.ITrickyMethodInterceptionGrain")] public interface ITrickyMethodInterceptionGrain : IGenericMethodInterceptionGrain, IGenericMethodInterceptionGrain @@ -47,6 +69,13 @@ public interface ITrickyMethodInterceptionGrain : IGenericMethodInterceptionGrai Task GetBestNumber(); } + [Alias("UnitTests.GrainInterfaces.ITrickyMethodInterceptionGrainObserver")] + public interface ITrickyMethodInterceptionGrainObserver : IGenericMethodInterceptionGrainObserver, IGenericMethodInterceptionGrainObserver + { + [Alias("GetBestNumber")] + Task GetBestNumber(); + } + [Alias("UnitTests.GrainInterfaces.ITrickierMethodInterceptionGrain")] public interface ITrickierMethodInterceptionGrain : IGenericMethodInterceptionGrain>, IGenericMethodInterceptionGrain> { @@ -68,6 +97,17 @@ public interface IGrainCallFilterTestGrain : IGrainWithIntegerKey Task GrainSpecificCallFilterMarker(); } + public interface IGrainCallFilterTestGrainObserver : IGrainObserver + { + Task ThrowIfGreaterThanZero(int value); + Task GetRequestContext(); + + Task SumSet(HashSet numbers); + + Task SystemWideCallFilterMarker(); + Task GrainSpecificCallFilterMarker(); + } + public interface IHungryGrain : IGrainWithIntegerKey { [TestMethodTag("hungry-eat")] diff --git a/test/Grains/TestGrains/MethodInterceptionGrain.cs b/test/Grains/TestGrains/MethodInterceptionGrain.cs index c19718bc5e..514a1a8939 100644 --- a/test/Grains/TestGrains/MethodInterceptionGrain.cs +++ b/test/Grains/TestGrains/MethodInterceptionGrain.cs @@ -29,10 +29,7 @@ public Task ThrowIfGreaterThanZero(int value) public class MethodInterceptionGrain : IMethodInterceptionGrain, IIncomingGrainCallFilter { - public Task One() - { - throw new InvalidOperationException("Not allowed to actually invoke this method!"); - } + public Task One() => throw new InvalidOperationException("Not allowed to actually invoke this method!"); [MessWithResult] public Task Echo(string someArg) => Task.FromResult(someArg); @@ -41,10 +38,95 @@ public Task One() public Task SayHello() => Task.FromResult("Hello"); - public Task Throw() + public Task Throw() => throw new MyDomainSpecificException("Oi!"); + + public Task FilterThrows() => Task.CompletedTask; + + public Task SystemWideCallFilterMarker() => Task.CompletedTask; + + public Task IncorrectResultType() => Task.FromResult("hop scotch"); + + async Task IIncomingGrainCallFilter.Invoke(IIncomingGrainCallContext context) + { + var methodInfo = context.ImplementationMethod; + if (methodInfo.Name == nameof(One) && methodInfo.GetParameters().Length == 0) + { + // Short-circuit the request and return to the caller without actually invoking the grain method. + context.Result = "intercepted one with no args"; + return; + } + + if (methodInfo.Name == nameof(IncorrectResultType)) + { + // This method has a string return type, but we are setting the result to a Guid. + // This should result in an invalid cast exception. + context.Result = Guid.NewGuid(); + return; + } + + if (methodInfo.Name == nameof(FilterThrows)) + { + throw new MyDomainSpecificException("Filter THROW!"); + } + + // Invoke the request. + try + { + await context.Invoke(); + } + catch (MyDomainSpecificException e) + { + context.Result = "EXCEPTION! " + e.Message; + return; + } + + // To prove that the MethodInfo is from the implementation and not the interface, + // we check for this attribute which is only present on the implementation. This could be + // done in a simpler fashion, but this demonstrates a potential usage scenario. + var shouldMessWithResult = methodInfo.GetCustomAttribute(); + var resultString = context.Result as string; + if (shouldMessWithResult != null && resultString != null) + { + context.Result = string.Concat(resultString.Reverse()); + } + } + + [Serializable] + [GenerateSerializer] + public class MyDomainSpecificException : Exception + { + public MyDomainSpecificException() + { + } + + public MyDomainSpecificException(string message) : base(message) + { + } + + [Obsolete] + protected MyDomainSpecificException(SerializationInfo info, StreamingContext context) : base(info, context) + { + } + } + + [AttributeUsage(AttributeTargets.Method)] + public class MessWithResultAttribute : Attribute { - throw new MyDomainSpecificException("Oi!"); } + } + + public class MethodInterceptionGrainObserver : IMethodInterceptionGrainObserver, IIncomingGrainCallFilter + { + public Task One() => throw new InvalidOperationException("Not allowed to actually invoke this method!"); + + [MessWithResult] + public Task Echo(string someArg) => Task.FromResult(someArg); + + public Task NotIntercepted() => Task.FromResult("not intercepted"); + + public Task SayHello() => Task.FromResult("Hello"); + + public Task Throw() => throw new MyDomainSpecificException("Oi!"); public Task FilterThrows() => Task.CompletedTask; @@ -159,6 +241,44 @@ public async Task Invoke(IIncomingGrainCallContext context) } } + public class GenericMethodInterceptionGrainObserver : IGenericMethodInterceptionGrainObserver, IIncomingGrainCallFilter + { + public Task SayHello() => Task.FromResult("Hello"); + + public Task GetInputAsString(T input) => Task.FromResult(input.ToString()); + public async Task Invoke(IIncomingGrainCallContext context) + { + if (context.ImplementationMethod.Name == nameof(GetInputAsString)) + { + context.Result = $"Hah! You wanted {context.Request.GetArgument(0)}, but you got me!"; + return; + } + + await context.Invoke(); + } + } + + public class TrickyInterceptionGrainObserver : ITrickyMethodInterceptionGrainObserver, IIncomingGrainCallFilter + { + public Task SayHello() => Task.FromResult("Hello"); + + public Task GetInputAsString(string input) => Task.FromResult(input); + + public Task GetInputAsString(bool input) => Task.FromResult(input.ToString(CultureInfo.InvariantCulture)); + + public Task GetBestNumber() => Task.FromResult(38); + public async Task Invoke(IIncomingGrainCallContext context) + { + if (context.ImplementationMethod.Name == nameof(GetInputAsString)) + { + context.Result = $"Hah! You wanted {context.Request.GetArgument(0)}, but you got me!"; + return; + } + + await context.Invoke(); + } + } + public class GrainCallFilterTestGrain : IGrainCallFilterTestGrain, IIncomingGrainCallFilter { private const string Key = GrainCallFilterTestConstants.Key; @@ -192,7 +312,7 @@ public async Task Invoke(IIncomingGrainCallContext ctx) if (string.Equals(implementationMethod.Name, nameof(GrainSpecificCallFilterMarker))) { - // explicitely do not continue calling Invoke + // explicitly do not continue calling Invoke return; } @@ -212,20 +332,71 @@ public async Task Invoke(IIncomingGrainCallContext ctx) } } - public Task SumSet(HashSet numbers) - { - return Task.FromResult(numbers.Sum()); - } + public Task SumSet(HashSet numbers) => Task.FromResult(numbers.Sum()); + + public Task SystemWideCallFilterMarker() => Task.CompletedTask; + + public Task GrainSpecificCallFilterMarker() => Task.CompletedTask; + } - public Task SystemWideCallFilterMarker() + public class GrainCallFilterTestGrainObserver : IGrainCallFilterTestGrainObserver, IIncomingGrainCallFilter + { + private const string Key = GrainCallFilterTestConstants.Key; + + public Task ThrowIfGreaterThanZero(int value) { - return Task.CompletedTask; + if (value > 0) + { + throw new ArgumentOutOfRangeException($"{value} is greater than zero!"); + } + + return Task.FromResult("Thanks for nothing"); } - public Task GrainSpecificCallFilterMarker() + public Task GetRequestContext() => Task.FromResult((string)RequestContext.Get(Key) + "4"); + + public async Task Invoke(IIncomingGrainCallContext ctx) { - return Task.CompletedTask; + var attemptsRemaining = 2; + + while (attemptsRemaining > 0) + { + try + { + var interfaceMethod = ctx.InterfaceMethod ?? throw new ArgumentException("InterfaceMethod is null!"); + var implementationMethod = ctx.ImplementationMethod ?? throw new ArgumentException("ImplementationMethod is null!"); + if (!string.Equals(implementationMethod.Name, interfaceMethod.Name)) + { + throw new ArgumentException("InterfaceMethod.Name != ImplementationMethod.Name"); + } + + if (string.Equals(implementationMethod.Name, nameof(GrainSpecificCallFilterMarker))) + { + // explicitly do not continue calling Invoke + return; + } + + if (RequestContext.Get(Key) is string value) RequestContext.Set(Key, value + '3'); + await ctx.Invoke(); + return; + } + catch (ArgumentOutOfRangeException) when (attemptsRemaining > 1) + { + if (string.Equals(ctx.ImplementationMethod?.Name, nameof(ThrowIfGreaterThanZero)) && ctx.Request.GetArgument(0) is int value) + { + ctx.Request.SetArgument(0, value - 1); + } + + --attemptsRemaining; + } + } } + + public Task SumSet(HashSet numbers) => Task.FromResult(numbers.Sum()); + + public Task SystemWideCallFilterMarker() => Task.CompletedTask; + + public Task GrainSpecificCallFilterMarker() => Task.CompletedTask; } public class CaterpillarGrain : ICaterpillarGrain, IIncomingGrainCallFilter diff --git a/test/Tester/GrainCallFilterTests.cs b/test/Tester/GrainCallFilterTests.cs index 02ba146374..f15ece713c 100644 --- a/test/Tester/GrainCallFilterTests.cs +++ b/test/Tester/GrainCallFilterTests.cs @@ -8,6 +8,7 @@ using UnitTests.Grains; using Xunit; using Orleans.Providers; +using System.Diagnostics; namespace UnitTests.General { @@ -109,7 +110,7 @@ public void Configure(ISiloBuilder hostBuilder) if (string.Equals(ctx.InterfaceMethod?.Name, nameof(IMethodInterceptionGrain.SystemWideCallFilterMarker))) { - // explicitely do not continue calling Invoke + // explicitly do not continue calling Invoke return; } @@ -126,6 +127,33 @@ private class ClientConfigurator : IClientBuilderConfigurator public void Configure(IConfiguration configuration, IClientBuilder clientBuilder) { clientBuilder + .AddIncomingGrainCallFilter(context => + { + Assert.NotNull(context); + Assert.NotNull(context.InterfaceMethod); + Assert.NotNull(context.Grain); + Assert.NotNull(context.ImplementationMethod); + Assert.NotNull(context.TargetContext); + Assert.NotEmpty(context.InterfaceName); + Assert.NotEmpty(context.MethodName); + Assert.False(context.TargetId.IsDefault); + Assert.False(context.InterfaceType.IsDefault); + + if (string.Equals(context.InterfaceMethod.Name, nameof(IGrainCallFilterTestGrainObserver.GetRequestContext))) + { + if (RequestContext.Get(GrainCallFilterTestConstants.Key) != null) throw new InvalidOperationException(); + RequestContext.Set(GrainCallFilterTestConstants.Key, "1"); + } + + if (string.Equals(context.InterfaceMethod.Name, nameof(IGrainCallFilterTestGrainObserver.SystemWideCallFilterMarker))) + { + // explicitly do not continue calling Invoke + return Task.CompletedTask; + } + + return context.Invoke(); + }) + .AddIncomingGrainCallFilter() .AddOutgoingGrainCallFilter(RetryCertainCalls) .AddOutgoingGrainCallFilter(async context => { @@ -181,19 +209,11 @@ static async Task RetryCertainCalls(IOutgoingGrainCallContext ctx) } [SuppressMessage("ReSharper", "NotAccessedField.Local")] - public class GrainCallFilterWithDependencies : IIncomingGrainCallFilter + public class GrainCallFilterWithDependencies(IGrainFactory grainFactory) : IIncomingGrainCallFilter { - private readonly Silo silo; - private readonly IGrainFactory grainFactory; - - public GrainCallFilterWithDependencies(Silo silo, IGrainFactory grainFactory) - { - this.silo = silo; - this.grainFactory = grainFactory; - } - public Task Invoke(IIncomingGrainCallContext context) { + Assert.NotNull(grainFactory); if (string.Equals(context.ImplementationMethod?.Name, nameof(IGrainCallFilterTestGrain.GetRequestContext))) { if (RequestContext.Get(GrainCallFilterTestConstants.Key) is string value) @@ -499,5 +519,183 @@ public async Task GrainCallFilter_Outgoing_SystemWideDoesNotCallContextInvoke_Te // InvalidOperationException, not an NullReferenceException. await Assert.ThrowsAsync(() => grain.SystemWideCallFilterMarker()); } + + /// + /// Ensures that grain call filters are invoked around method calls in the correct order. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_Order_Test() + { + var observer = new GrainCallFilterTestGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + + // This grain method reads the context and returns it + var context = await grain.GetRequestContext(); + Assert.NotNull(context); + Assert.Equal("1234", context); + } + + /// + /// Tests that an incoming call filter can retry calls to an observer. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_Retry_Test() + { + var observer = new GrainCallFilterTestGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + + var result = await grain.ThrowIfGreaterThanZero(1); + Assert.Equal("Thanks for nothing", result); + + await Assert.ThrowsAsync(() => grain.ThrowIfGreaterThanZero(2)); + } + + /// + /// Tests that an incoming call filter works on an observer with HashSet. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_HashSet_Test() + { + var observer = new GrainCallFilterTestGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + + var result = await grain.SumSet(new HashSet { 1, 2, 3 }); + Assert.Equal(6, result); + } + + /// + /// Tests that if a grain call filter does not call , + /// an exception is thrown on the caller. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_SystemWideDoesNotCallContextInvoke_Test() + { + var observer = new GrainCallFilterTestGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + + // The call filter doesn't continue the Invoke chain, but the error state should be thrown as an + // InvalidOperationException, not an NullReferenceException. + await Assert.ThrowsAsync(() => grain.SystemWideCallFilterMarker()); + } + + /// + /// Tests that if a grain call filter does not call , + /// an exception is thrown on the caller. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_GrainSpecificDoesNotCallContextInvoke_Test() + { + var observer = new GrainCallFilterTestGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + + // The call filter doesn't continue the Invoke chain, but the error state should be thrown as an + // InvalidOperationException, not an NullReferenceException. + await Assert.ThrowsAsync(() => grain.GrainSpecificCallFilterMarker()); + } + + /// + /// Tests filters on just the grain level. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_GrainLevel_Test() + { + var observer = new MethodInterceptionGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + var result = await grain.One(); + Assert.Equal("intercepted one with no args", result); + + result = await grain.Echo("stao erom tae"); + Assert.Equal("eat more oats", result);// Grain interceptors should receive the MethodInfo of the implementation, not the interface. + + result = await grain.NotIntercepted(); + Assert.Equal("not intercepted", result); + + result = await grain.SayHello(); + Assert.Equal("Hello", result); + } + + /// + /// Tests filters on generic grains. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_GenericGrain_Test() + { + var observer = new GenericMethodInterceptionGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference>(observer); + var result = await grain.GetInputAsString(679); + Assert.Contains("Hah!", result); + Assert.Contains("679", result); + + result = await grain.SayHello(); + Assert.Equal("Hello", result); + } + + /// + /// Tests filters on grains which implement multiple of the same generic interface. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_ConstructedGenericInheritance_Test() + { + var observer = new TrickyInterceptionGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + + var result = await grain.GetInputAsString("2014-12-19T14:32:50Z"); + Assert.Contains("Hah!", result); + Assert.Contains("2014-12-19T14:32:50Z", result); + + result = await grain.SayHello(); + Assert.Equal("Hello", result); + + var bestNumber = await grain.GetBestNumber(); + Assert.Equal(38, bestNumber); + + result = await grain.GetInputAsString(true); + Assert.Contains(true.ToString(CultureInfo.InvariantCulture), result); + } + + /// + /// Tests that grain call filters can handle exceptions. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_ExceptionHandling_Test() + { + var observer = new MethodInterceptionGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + + // This grain method throws, but the exception should be handled by one of the filters and converted + // into a specific message. + var result = await grain.Throw(); + Assert.NotNull(result); + Assert.Equal("EXCEPTION! Oi!", result); + } + + /// + /// Tests that grain call filters can throw exceptions. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_FilterThrows_Test() + { + var observer = new MethodInterceptionGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + + var exception = await Assert.ThrowsAsync(() => grain.FilterThrows()); + Assert.NotNull(exception); + Assert.Equal("Filter THROW!", exception.Message); + } + + /// + /// Tests that if a grain call filter sets an incorrect result type for , + /// an exception is thrown on the caller. + /// + [Fact] + public async Task Observer_GrainCallFilter_Incoming_SetIncorrectResultType_Test() + { + var observer = new MethodInterceptionGrainObserver(); + var grain = this.fixture.GrainFactory.CreateObjectReference(observer); + + // This grain method throws, but the exception should be handled by one of the filters and converted + // into a specific message. + await Assert.ThrowsAsync(() => grain.IncorrectResultType()); + } } }