diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index 06e4696df3..3ec56db2ac 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -7,6 +7,7 @@ using System.Linq; using Orleans.CodeGenerator.Diagnostics; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using System.Threading; namespace Orleans.CodeGenerator { @@ -45,7 +46,14 @@ public static (ClassDeclarationSyntax Syntax, GeneratedInvokerDescription Invoke .AddMembers(fields); if (ctor != null) + { classDeclaration = classDeclaration.AddMembers(ctor); + } + + if (method.ResponseTimeoutTicks.HasValue) + { + classDeclaration = classDeclaration.AddMembers(GenerateResponseTimeoutPropertyMembers(libraryTypes, method.ResponseTimeoutTicks.Value)); + } classDeclaration = AddOptionalMembers(classDeclaration, GenerateGetArgumentCount(method), @@ -119,6 +127,29 @@ static Accessibility GetAccessibility(InvokableInterfaceDescription interfaceDes } } + private static MemberDeclarationSyntax[] GenerateResponseTimeoutPropertyMembers(LibraryTypes libraryTypes, long value) + { + var timespanField = FieldDeclaration( + VariableDeclaration( + libraryTypes.TimeSpan.ToTypeSyntax(), + SingletonSeparatedList(VariableDeclarator("_responseTimeoutValue") + .WithInitializer(EqualsValueClause( + InvocationExpression( + IdentifierName("global::System.TimeSpan").Member("FromTicks"), + ArgumentList(SeparatedList(new[] + { + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value))) + })))))))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)); + + var responseTimeoutProperty = MethodDeclaration(NullableType(libraryTypes.TimeSpan.ToTypeSyntax()), "GetDefaultResponseTimeout") + .WithExpressionBody(ArrowExpressionClause(IdentifierName("_responseTimeoutValue"))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)); +; + return new MemberDeclarationSyntax[] { timespanField, responseTimeoutProperty }; + } + private static ClassDeclarationSyntax AddOptionalMembers(ClassDeclarationSyntax decl, params MemberDeclarationSyntax[] items) => decl.WithMembers(decl.Members.AddRange(items.Where(i => i != null))); diff --git a/src/Orleans.CodeGenerator/LibraryTypes.cs b/src/Orleans.CodeGenerator/LibraryTypes.cs index ef78e92549..5c4a0a65c2 100644 --- a/src/Orleans.CodeGenerator/LibraryTypes.cs +++ b/src/Orleans.CodeGenerator/LibraryTypes.cs @@ -1,3 +1,4 @@ +#nullable enable using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -11,85 +12,83 @@ namespace Orleans.CodeGenerator { internal sealed class LibraryTypes { - private LibraryTypes() { } - - public static LibraryTypes FromCompilation(Compilation compilation, CodeGeneratorOptions options) + public static LibraryTypes FromCompilation(Compilation compilation, CodeGeneratorOptions options) => new LibraryTypes(compilation, options); + private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) { - return new LibraryTypes - { - Compilation = compilation, - ApplicationPartAttribute = Type("Orleans.ApplicationPartAttribute"), - Action_2 = Type("System.Action`2"), - TypeManifestProviderBase = Type("Orleans.Serialization.Configuration.TypeManifestProviderBase"), - Field = Type("Orleans.Serialization.WireProtocol.Field"), - FieldCodec_1 = Type("Orleans.Serialization.Codecs.IFieldCodec`1"), - AbstractTypeSerializer = Type("Orleans.Serialization.Serializers.AbstractTypeSerializer`1"), - DeepCopier_1 = Type("Orleans.Serialization.Cloning.IDeepCopier`1"), - ShallowCopier = Type("Orleans.Serialization.Cloning.ShallowCopier`1"), - CompoundTypeAliasAttribute = Type("Orleans.CompoundTypeAliasAttribute"), - CopyContext = Type("Orleans.Serialization.Cloning.CopyContext"), - MethodInfo = Type("System.Reflection.MethodInfo"), - Func_2 = Type("System.Func`2"), - GenerateMethodSerializersAttribute = Type("Orleans.GenerateMethodSerializersAttribute"), - GenerateSerializerAttribute = Type("Orleans.GenerateSerializerAttribute"), - SerializationCallbacksAttribute = Type("Orleans.SerializationCallbacksAttribute"), - IActivator_1 = Type("Orleans.Serialization.Activators.IActivator`1"), - IBufferWriter = Type("System.Buffers.IBufferWriter`1"), - IdAttributeTypes = options.IdAttributes.Select(Type).ToArray(), - ConstructorAttributeTypes = options.ConstructorAttributes.Select(Type).ToArray(), - AliasAttribute = Type("Orleans.AliasAttribute"), - IInvokable = Type("Orleans.Serialization.Invocation.IInvokable"), - InvokeMethodNameAttribute = Type("Orleans.InvokeMethodNameAttribute"), - RuntimeHelpers = Type("System.Runtime.CompilerServices.RuntimeHelpers"), - InvokableCustomInitializerAttribute = Type("Orleans.InvokableCustomInitializerAttribute"), - DefaultInvokableBaseTypeAttribute = Type("Orleans.DefaultInvokableBaseTypeAttribute"), - GenerateCodeForDeclaringAssemblyAttribute = Type("Orleans.GenerateCodeForDeclaringAssemblyAttribute"), - InvokableBaseTypeAttribute = Type("Orleans.InvokableBaseTypeAttribute"), - ReturnValueProxyAttribute = Type("Orleans.Invocation.ReturnValueProxyAttribute"), - RegisterSerializerAttribute = Type("Orleans.RegisterSerializerAttribute"), - GeneratedActivatorConstructorAttribute = Type("Orleans.GeneratedActivatorConstructorAttribute"), - SerializerTransparentAttribute = Type("Orleans.SerializerTransparentAttribute"), - RegisterActivatorAttribute = Type("Orleans.RegisterActivatorAttribute"), - RegisterConverterAttribute = Type("Orleans.RegisterConverterAttribute"), - RegisterCopierAttribute = Type("Orleans.RegisterCopierAttribute"), - UseActivatorAttribute = Type("Orleans.UseActivatorAttribute"), - SuppressReferenceTrackingAttribute = Type("Orleans.SuppressReferenceTrackingAttribute"), - OmitDefaultMemberValuesAttribute = Type("Orleans.OmitDefaultMemberValuesAttribute"), - ITargetHolder = Type("Orleans.Serialization.Invocation.ITargetHolder"), - TypeManifestProviderAttribute = Type("Orleans.Serialization.Configuration.TypeManifestProviderAttribute"), - NonSerializedAttribute = Type("System.NonSerializedAttribute"), - ObsoleteAttribute = Type("System.ObsoleteAttribute"), - BaseCodec_1 = Type("Orleans.Serialization.Serializers.IBaseCodec`1"), - BaseCopier_1 = Type("Orleans.Serialization.Cloning.IBaseCopier`1"), - ArrayCodec = Type("Orleans.Serialization.Codecs.ArrayCodec`1"), - ArrayCopier = Type("Orleans.Serialization.Codecs.ArrayCopier`1"), - Reader = Type("Orleans.Serialization.Buffers.Reader`1"), - TypeManifestOptions = Type("Orleans.Serialization.Configuration.TypeManifestOptions"), - Task = Type("System.Threading.Tasks.Task"), - Task_1 = Type("System.Threading.Tasks.Task`1"), - Type = Type("System.Type"), - Uri = Type("System.Uri"), - Int128 = TypeOrDefault("System.Int128"), - UInt128 = TypeOrDefault("System.UInt128"), - Half = TypeOrDefault("System.Half"), - DateOnly = TypeOrDefault("System.DateOnly"), - DateTimeOffset = Type("System.DateTimeOffset"), - BitVector32 = Type("System.Collections.Specialized.BitVector32"), - Guid = Type("System.Guid"), - CompareInfo = Type("System.Globalization.CompareInfo"), - CultureInfo = Type("System.Globalization.CultureInfo"), - Version = Type("System.Version"), - TimeOnly = TypeOrDefault("System.TimeOnly"), - ICodecProvider = Type("Orleans.Serialization.Serializers.ICodecProvider"), - ValueSerializer = Type("Orleans.Serialization.Serializers.IValueSerializer`1"), - ValueTask = Type("System.Threading.Tasks.ValueTask"), - ValueTask_1 = Type("System.Threading.Tasks.ValueTask`1"), - ValueTypeGetter_2 = Type("Orleans.Serialization.Utilities.ValueTypeGetter`2"), - ValueTypeSetter_2 = Type("Orleans.Serialization.Utilities.ValueTypeSetter`2"), - Writer = Type("Orleans.Serialization.Buffers.Writer`1"), - FSharpSourceConstructFlagsOrDefault = TypeOrDefault("Microsoft.FSharp.Core.SourceConstructFlags"), - FSharpCompilationMappingAttributeOrDefault = TypeOrDefault("Microsoft.FSharp.Core.CompilationMappingAttribute"), - StaticCodecs = new List + Compilation = compilation; + ApplicationPartAttribute = Type("Orleans.ApplicationPartAttribute"); + Action_2 = Type("System.Action`2"); + TypeManifestProviderBase = Type("Orleans.Serialization.Configuration.TypeManifestProviderBase"); + Field = Type("Orleans.Serialization.WireProtocol.Field"); + FieldCodec_1 = Type("Orleans.Serialization.Codecs.IFieldCodec`1"); + AbstractTypeSerializer = Type("Orleans.Serialization.Serializers.AbstractTypeSerializer`1"); + DeepCopier_1 = Type("Orleans.Serialization.Cloning.IDeepCopier`1"); + ShallowCopier = Type("Orleans.Serialization.Cloning.ShallowCopier`1"); + CompoundTypeAliasAttribute = Type("Orleans.CompoundTypeAliasAttribute"); + CopyContext = Type("Orleans.Serialization.Cloning.CopyContext"); + MethodInfo = Type("System.Reflection.MethodInfo"); + Func_2 = Type("System.Func`2"); + GenerateMethodSerializersAttribute = Type("Orleans.GenerateMethodSerializersAttribute"); + GenerateSerializerAttribute = Type("Orleans.GenerateSerializerAttribute"); + SerializationCallbacksAttribute = Type("Orleans.SerializationCallbacksAttribute"); + IActivator_1 = Type("Orleans.Serialization.Activators.IActivator`1"); + IBufferWriter = Type("System.Buffers.IBufferWriter`1"); + IdAttributeTypes = options.IdAttributes.Select(Type).ToArray(); + ConstructorAttributeTypes = options.ConstructorAttributes.Select(Type).ToArray(); + AliasAttribute = Type("Orleans.AliasAttribute"); + IInvokable = Type("Orleans.Serialization.Invocation.IInvokable"); + InvokeMethodNameAttribute = Type("Orleans.InvokeMethodNameAttribute"); + RuntimeHelpers = Type("System.Runtime.CompilerServices.RuntimeHelpers"); + InvokableCustomInitializerAttribute = Type("Orleans.InvokableCustomInitializerAttribute"); + DefaultInvokableBaseTypeAttribute = Type("Orleans.DefaultInvokableBaseTypeAttribute"); + GenerateCodeForDeclaringAssemblyAttribute = Type("Orleans.GenerateCodeForDeclaringAssemblyAttribute"); + InvokableBaseTypeAttribute = Type("Orleans.InvokableBaseTypeAttribute"); + ReturnValueProxyAttribute = Type("Orleans.Invocation.ReturnValueProxyAttribute"); + RegisterSerializerAttribute = Type("Orleans.RegisterSerializerAttribute"); + ResponseTimeoutAttribute = Type("Orleans.ResponseTimeoutAttribute"); + GeneratedActivatorConstructorAttribute = Type("Orleans.GeneratedActivatorConstructorAttribute"); + SerializerTransparentAttribute = Type("Orleans.SerializerTransparentAttribute"); + RegisterActivatorAttribute = Type("Orleans.RegisterActivatorAttribute"); + RegisterConverterAttribute = Type("Orleans.RegisterConverterAttribute"); + RegisterCopierAttribute = Type("Orleans.RegisterCopierAttribute"); + UseActivatorAttribute = Type("Orleans.UseActivatorAttribute"); + SuppressReferenceTrackingAttribute = Type("Orleans.SuppressReferenceTrackingAttribute"); + OmitDefaultMemberValuesAttribute = Type("Orleans.OmitDefaultMemberValuesAttribute"); + ITargetHolder = Type("Orleans.Serialization.Invocation.ITargetHolder"); + TypeManifestProviderAttribute = Type("Orleans.Serialization.Configuration.TypeManifestProviderAttribute"); + NonSerializedAttribute = Type("System.NonSerializedAttribute"); + ObsoleteAttribute = Type("System.ObsoleteAttribute"); + BaseCodec_1 = Type("Orleans.Serialization.Serializers.IBaseCodec`1"); + BaseCopier_1 = Type("Orleans.Serialization.Cloning.IBaseCopier`1"); + ArrayCodec = Type("Orleans.Serialization.Codecs.ArrayCodec`1"); + ArrayCopier = Type("Orleans.Serialization.Codecs.ArrayCopier`1"); + Reader = Type("Orleans.Serialization.Buffers.Reader`1"); + TypeManifestOptions = Type("Orleans.Serialization.Configuration.TypeManifestOptions"); + Task = Type("System.Threading.Tasks.Task"); + Task_1 = Type("System.Threading.Tasks.Task`1"); + this.Type = Type("System.Type"); + Uri = Type("System.Uri"); + Int128 = TypeOrDefault("System.Int128"); + UInt128 = TypeOrDefault("System.UInt128"); + Half = TypeOrDefault("System.Half"); + DateOnly = TypeOrDefault("System.DateOnly"); + DateTimeOffset = Type("System.DateTimeOffset"); + BitVector32 = Type("System.Collections.Specialized.BitVector32"); + Guid = Type("System.Guid"); + CompareInfo = Type("System.Globalization.CompareInfo"); + CultureInfo = Type("System.Globalization.CultureInfo"); + Version = Type("System.Version"); + TimeOnly = TypeOrDefault("System.TimeOnly"); + ICodecProvider = Type("Orleans.Serialization.Serializers.ICodecProvider"); + ValueSerializer = Type("Orleans.Serialization.Serializers.IValueSerializer`1"); + ValueTask = Type("System.Threading.Tasks.ValueTask"); + ValueTask_1 = Type("System.Threading.Tasks.ValueTask`1"); + ValueTypeGetter_2 = Type("Orleans.Serialization.Utilities.ValueTypeGetter`2"); + ValueTypeSetter_2 = Type("Orleans.Serialization.Utilities.ValueTypeSetter`2"); + Writer = Type("Orleans.Serialization.Buffers.Writer`1"); + FSharpSourceConstructFlagsOrDefault = TypeOrDefault("Microsoft.FSharp.Core.SourceConstructFlags"); + FSharpCompilationMappingAttributeOrDefault = TypeOrDefault("Microsoft.FSharp.Core.CompilationMappingAttribute"); + StaticCodecs = new List { new(compilation.GetSpecialType(SpecialType.System_Object), Type("Orleans.Serialization.Codecs.ObjectCodec")), new(compilation.GetSpecialType(SpecialType.System_Boolean), Type("Orleans.Serialization.Codecs.BoolCodec")), @@ -122,38 +121,38 @@ public static LibraryTypes FromCompilation(Compilation compilation, CodeGenerato new(TypeOrDefault("System.Int128"), TypeOrDefault("Orleans.Serialization.Codecs.Int128Codec")), new(TypeOrDefault("System.Half"), TypeOrDefault("Orleans.Serialization.Codecs.HalfCodec")), new(Type("System.Uri"), Type("Orleans.Serialization.Codecs.UriCodec")), - }.Where(desc => desc.UnderlyingType is {} && desc.CodecType is {}).ToArray(), - WellKnownCodecs = new WellKnownCodecDescription[] - { + }.Where(desc => desc.UnderlyingType is { } && desc.CodecType is { }).ToArray(); + WellKnownCodecs = new WellKnownCodecDescription[] + { new(Type("System.Exception"), Type("Orleans.Serialization.ExceptionCodec")), new(Type("System.Collections.Generic.Dictionary`2"), Type("Orleans.Serialization.Codecs.DictionaryCodec`2")), new(Type("System.Collections.Generic.List`1"), Type("Orleans.Serialization.Codecs.ListCodec`1")), new(Type("System.Collections.Generic.HashSet`1"), Type("Orleans.Serialization.Codecs.HashSetCodec`1")), new(compilation.GetSpecialType(SpecialType.System_Nullable_T), Type("Orleans.Serialization.Codecs.NullableCodec`1")), - }, - StaticCopiers = new WellKnownCopierDescription[] - { + }; + StaticCopiers = new WellKnownCopierDescription[] + { new(compilation.GetSpecialType(SpecialType.System_Object), Type("Orleans.Serialization.Codecs.ObjectCopier")), new(compilation.CreateArrayTypeSymbol(compilation.GetSpecialType(SpecialType.System_Byte), 1), Type("Orleans.Serialization.Codecs.ByteArrayCopier")), new(Type("System.ReadOnlyMemory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.ReadOnlyMemoryOfByteCopier")), new(Type("System.Memory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.MemoryOfByteCopier")), - }, - WellKnownCopiers = new WellKnownCopierDescription[] - { + }; + WellKnownCopiers = new WellKnownCopierDescription[] + { new(Type("System.Exception"), Type("Orleans.Serialization.ExceptionCodec")), new(Type("System.Collections.Generic.Dictionary`2"), Type("Orleans.Serialization.Codecs.DictionaryCopier`2")), new(Type("System.Collections.Generic.List`1"), Type("Orleans.Serialization.Codecs.ListCopier`1")), new(Type("System.Collections.Generic.HashSet`1"), Type("Orleans.Serialization.Codecs.HashSetCopier`1")), new(compilation.GetSpecialType(SpecialType.System_Nullable_T), Type("Orleans.Serialization.Codecs.NullableCopier`1")), - }, - Exception = Type("System.Exception"), - ImmutableAttributes = options.ImmutableAttributes.Select(Type).ToArray(), - TimeSpan = Type("System.TimeSpan"), - IPAddress = Type("System.Net.IPAddress"), - IPEndPoint = Type("System.Net.IPEndPoint"), - CancellationToken = Type("System.Threading.CancellationToken"), - ImmutableContainerTypes = new[] - { + }; + Exception = Type("System.Exception"); + ImmutableAttributes = options.ImmutableAttributes.Select(Type).ToArray(); + TimeSpan = Type("System.TimeSpan"); + IPAddress = Type("System.Net.IPAddress"); + IPEndPoint = Type("System.Net.IPEndPoint"); + CancellationToken = Type("System.Threading.CancellationToken"); + ImmutableContainerTypes = new[] + { compilation.GetSpecialType(SpecialType.System_Nullable_T), Type("System.Tuple`1"), Type("System.Tuple`2"), @@ -179,10 +178,9 @@ public static LibraryTypes FromCompilation(Compilation compilation, CodeGenerato Type("System.Collections.Immutable.ImmutableSortedDictionary`2"), Type("System.Collections.Immutable.ImmutableSortedSet`1"), Type("System.Collections.Immutable.ImmutableStack`1"), - }, + }; - LanguageVersion = (compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions)?.LanguageVersion - }; + LanguageVersion = (compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions)?.LanguageVersion; INamedTypeSymbol Type(string metadataName) { @@ -195,7 +193,7 @@ INamedTypeSymbol Type(string metadataName) return result; } - INamedTypeSymbol TypeOrDefault(string metadataName) + INamedTypeSymbol? TypeOrDefault(string metadataName) { var result = compilation.GetTypeByMetadataName(metadataName); return result; @@ -230,9 +228,9 @@ INamedTypeSymbol TypeOrDefault(string metadataName) public INamedTypeSymbol Task_1 { get; private set; } public INamedTypeSymbol Type { get; private set; } private INamedTypeSymbol Uri; - private INamedTypeSymbol DateOnly; + private INamedTypeSymbol? DateOnly; private INamedTypeSymbol DateTimeOffset; - private INamedTypeSymbol TimeOnly; + private INamedTypeSymbol? TimeOnly; public INamedTypeSymbol MethodInfo { get; private set; } public INamedTypeSymbol ICodecProvider { get; private set; } public INamedTypeSymbol ValueSerializer { get; private set; } @@ -250,6 +248,7 @@ INamedTypeSymbol TypeOrDefault(string metadataName) public WellKnownCopierDescription[] WellKnownCopiers { get; private set; } public INamedTypeSymbol RegisterCopierAttribute { get; private set; } public INamedTypeSymbol RegisterSerializerAttribute { get; private set; } + public INamedTypeSymbol ResponseTimeoutAttribute { get; private set; } public INamedTypeSymbol RegisterConverterAttribute { get; private set; } public INamedTypeSymbol RegisterActivatorAttribute { get; private set; } public INamedTypeSymbol UseActivatorAttribute { get; private set; } @@ -257,7 +256,7 @@ INamedTypeSymbol TypeOrDefault(string metadataName) public INamedTypeSymbol OmitDefaultMemberValuesAttribute { get; private set; } public INamedTypeSymbol CopyContext { get; private set; } public Compilation Compilation { get; private set; } - private INamedTypeSymbol TimeSpan; + public INamedTypeSymbol TimeSpan { get; private set; } private INamedTypeSymbol IPAddress; private INamedTypeSymbol IPEndPoint; private INamedTypeSymbol CancellationToken; @@ -267,11 +266,11 @@ INamedTypeSymbol TypeOrDefault(string metadataName) private INamedTypeSymbol CompareInfo; private INamedTypeSymbol CultureInfo; private INamedTypeSymbol Version; - private INamedTypeSymbol Int128; - private INamedTypeSymbol UInt128; - private INamedTypeSymbol Half; - private INamedTypeSymbol[] _regularShallowCopyableTypes; - private INamedTypeSymbol[] RegularShallowCopyableType => _regularShallowCopyableTypes ??= new List + private INamedTypeSymbol? Int128; + private INamedTypeSymbol? UInt128; + private INamedTypeSymbol? Half; + private INamedTypeSymbol[]? _regularShallowCopyableTypes; + private INamedTypeSymbol[] RegularShallowCopyableType => _regularShallowCopyableTypes ??= new List { TimeSpan, DateOnly, @@ -290,7 +289,7 @@ INamedTypeSymbol TypeOrDefault(string metadataName) UInt128, Int128, Half - }.Where(t => t is {}).ToArray(); + }.Where(t => t is {}).ToArray()!; public INamedTypeSymbol[] ImmutableAttributes { get; private set; } public INamedTypeSymbol Exception { get; private set; } @@ -304,8 +303,8 @@ INamedTypeSymbol TypeOrDefault(string metadataName) public INamedTypeSymbol SerializationCallbacksAttribute { get; private set; } public INamedTypeSymbol GeneratedActivatorConstructorAttribute { get; private set; } public INamedTypeSymbol SerializerTransparentAttribute { get; private set; } - public INamedTypeSymbol FSharpCompilationMappingAttributeOrDefault { get; private set; } - public INamedTypeSymbol FSharpSourceConstructFlagsOrDefault { get; private set; } + public INamedTypeSymbol? FSharpCompilationMappingAttributeOrDefault { get; private set; } + public INamedTypeSymbol? FSharpSourceConstructFlagsOrDefault { get; private set; } public INamedTypeSymbol RuntimeHelpers { get; private set; } public LanguageVersion? LanguageVersion { get; private set; } @@ -435,7 +434,7 @@ private bool AreShallowCopyable(ImmutableArray fields) internal static class LibraryExtensions { - public static WellKnownCodecDescription FindByUnderlyingType(this WellKnownCodecDescription[] values, ISymbol type) + public static WellKnownCodecDescription? FindByUnderlyingType(this WellKnownCodecDescription[] values, ISymbol type) { foreach (var c in values) if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, type)) @@ -444,7 +443,7 @@ public static WellKnownCodecDescription FindByUnderlyingType(this WellKnownCodec return null; } - public static WellKnownCopierDescription FindByUnderlyingType(this WellKnownCopierDescription[] values, ISymbol type) + public static WellKnownCopierDescription? FindByUnderlyingType(this WellKnownCopierDescription[] values, ISymbol type) { foreach (var c in values) if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, type)) diff --git a/src/Orleans.CodeGenerator/Model/MethodDescription.cs b/src/Orleans.CodeGenerator/Model/MethodDescription.cs index 438c66b7d2..19f6468d7b 100644 --- a/src/Orleans.CodeGenerator/Model/MethodDescription.cs +++ b/src/Orleans.CodeGenerator/Model/MethodDescription.cs @@ -122,6 +122,11 @@ private void PopulateOverrides(InvokableInterfaceDescription containingType, IMe CustomInitializerMethods.Add((methodName, methodArgument)); } } + + if (SymbolEqualityComparer.Default.Equals(methodAttr.AttributeClass, containingType.CodeGenerator.LibraryTypes.ResponseTimeoutAttribute)) + { + ResponseTimeoutTicks = TimeSpan.Parse((string)methodAttr.ConstructorArguments[0].Value).Ticks; + } } bool TryGetNamedArgument(ImmutableArray> arguments, string name, out TypedConstant value) @@ -161,6 +166,11 @@ bool TryGetNamedArgument(ImmutableArray> arg /// public Dictionary InvokableBaseTypes { get; } + /// + /// Gets the response timeout ticks, if set. + /// + public long? ResponseTimeoutTicks { get; private set; } + public override int GetHashCode() => SymbolEqualityComparer.Default.GetHashCode(Method); } } \ No newline at end of file diff --git a/src/Orleans.Core.Abstractions/Runtime/GrainReference.cs b/src/Orleans.Core.Abstractions/Runtime/GrainReference.cs index a3e47105ea..4c2a19a031 100644 --- a/src/Orleans.Core.Abstractions/Runtime/GrainReference.cs +++ b/src/Orleans.Core.Abstractions/Runtime/GrainReference.cs @@ -575,6 +575,9 @@ public void AddInvokeMethodOptions(InvokeMethodOptions options) /// public override string ToString() => IRequest.ToString(this); + + /// + public virtual TimeSpan? GetDefaultResponseTimeout() => null; } /// diff --git a/src/Orleans.Core/Runtime/CallbackData.cs b/src/Orleans.Core/Runtime/CallbackData.cs index 07e0bbe210..bc57c251cf 100644 --- a/src/Orleans.Core/Runtime/CallbackData.cs +++ b/src/Orleans.Core/Runtime/CallbackData.cs @@ -1,5 +1,5 @@ using System; -using System.Text; +using System.Diagnostics; using System.Threading; using Microsoft.Extensions.Logging; using Orleans.Serialization.Invocation; @@ -37,10 +37,23 @@ public void OnStatusUpdate(StatusResponse status) public bool IsExpired(long currentTimestamp) { var duration = currentTimestamp - this.stopwatch.GetRawTimestamp(); - return duration > shared.ResponseTimeoutStopwatchTicks; + return duration > GetResponseTimeoutStopwatchTicks(); } - public void OnTimeout(TimeSpan timeout) + private long GetResponseTimeoutStopwatchTicks() + { + var defaultResponseTimeout = (Message.BodyObject as IInvokable)?.GetDefaultResponseTimeout(); + if (defaultResponseTimeout.HasValue) + { + return (long)(defaultResponseTimeout.Value.TotalSeconds * Stopwatch.Frequency); + } + + return shared.ResponseTimeoutStopwatchTicks; + } + + private TimeSpan GetResponseTimeout() => (Message.BodyObject as IInvokable)?.GetDefaultResponseTimeout() ?? shared.ResponseTimeout; + + public void OnTimeout() { if (Interlocked.CompareExchange(ref completed, 1, 0) != 0) { @@ -58,6 +71,7 @@ public void OnTimeout(TimeSpan timeout) var msg = this.Message; // Local working copy var statusMessage = lastKnownStatus is StatusResponse status ? $"Last known status is {status}. " : string.Empty; + var timeout = GetResponseTimeout(); this.shared.Logger.LogWarning( (int)ErrorCode.Runtime_Error_100157, "Response did not arrive on time in {Timeout} for message: {Message}. {StatusMessage}. About to break its promise.", diff --git a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs index 9abceac846..9b8d4bd887 100644 --- a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs +++ b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs @@ -233,11 +233,6 @@ public void SendRequest(GrainReference target, IInvokable request, IResponseComp var message = this.messageFactory.CreateMessage(request, options); OrleansOutsideRuntimeClientEvent.Log.SendRequest(message); - SendRequestMessage(target, message, context, options); - } - - private void SendRequestMessage(GrainReference target, Message message, IResponseCompletionSource context, InvokeMethodOptions options) - { message.InterfaceType = target.InterfaceType; message.InterfaceVersion = target.InterfaceVersion; var targetGrainId = target.GrainId; @@ -254,7 +249,8 @@ private void SendRequestMessage(GrainReference target, Message message, IRespons if (message.IsExpirableMessage(this.clientMessagingOptions.DropExpiredMessages)) { // don't set expiration for system target messages. - message.TimeToLive = this.clientMessagingOptions.ResponseTimeout; + var ttl = request.GetDefaultResponseTimeout() ?? this.clientMessagingOptions.ResponseTimeout; + message.TimeToLive = ttl; } if (!oneWay) @@ -452,7 +448,7 @@ private void OnCallbackExpiryTick(object state) { var callback = pair.Value; if (callback.IsCompleted) continue; - if (callback.IsExpired(currentStopwatchTicks)) callback.OnTimeout(this.clientMessagingOptions.ResponseTimeout); + if (callback.IsExpired(currentStopwatchTicks)) callback.OnTimeout(); } } diff --git a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs index 686cfc1717..c618f07dbf 100644 --- a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs +++ b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs @@ -151,7 +151,7 @@ public void SendRequest( if (message.IsExpirableMessage(this.messagingOptions.DropExpiredMessages)) { - message.TimeToLive = sharedData.ResponseTimeout; + message.TimeToLive = request.GetDefaultResponseTimeout() ?? sharedData.ResponseTimeout; } var oneWay = (options & InvokeMethodOptions.OneWay) != 0; @@ -532,12 +532,11 @@ public void Participate(ISiloLifecycle lifecycle) private void OnCallbackExpiryTick(object state) { var currentStopwatchTicks = ValueStopwatch.GetTimestamp(); - var responseTimeout = this.messagingOptions.ResponseTimeout; foreach (var pair in callbacks) { var callback = pair.Value; if (callback.IsCompleted) continue; - if (callback.IsExpired(currentStopwatchTicks)) callback.OnTimeout(responseTimeout); + if (callback.IsExpired(currentStopwatchTicks)) callback.OnTimeout(); } } } diff --git a/src/Orleans.Serialization.Abstractions/Annotations.cs b/src/Orleans.Serialization.Abstractions/Annotations.cs index da68cb5767..e6a970d5c9 100644 --- a/src/Orleans.Serialization.Abstractions/Annotations.cs +++ b/src/Orleans.Serialization.Abstractions/Annotations.cs @@ -523,6 +523,24 @@ public GenerateCodeForDeclaringAssemblyAttribute(Type type) public Type Type { get; } } + /// + /// Specifies the response timeout for the interface method which it is specified on. + /// + [AttributeUsage(AttributeTargets.Method)] + public sealed class ResponseTimeoutAttribute : Attribute + { + /// + /// Specifies the response timeout for the interface method which it is specified on. + /// + /// The response timeout, using syntax. + public ResponseTimeoutAttribute(string timeout) => Timeout = TimeSpan.Parse(timeout); + + /// + /// Gets or sets the response timeout for this method. + /// + public TimeSpan? Timeout { get; init; } + } + /// /// Functionality for converting between two types. /// diff --git a/src/Orleans.Serialization/Invocation/IInvokable.cs b/src/Orleans.Serialization/Invocation/IInvokable.cs index 9262bdbf46..bbe1b7ae5e 100644 --- a/src/Orleans.Serialization/Invocation/IInvokable.cs +++ b/src/Orleans.Serialization/Invocation/IInvokable.cs @@ -70,5 +70,10 @@ public interface IInvokable : IDisposable /// Gets the interface type. /// Type GetInterfaceType(); + + /// + /// Gets the default response timeout. + /// + TimeSpan? GetDefaultResponseTimeout(); } } \ No newline at end of file diff --git a/test/DefaultCluster.Tests/EchoTaskGrainTests.cs b/test/DefaultCluster.Tests/EchoTaskGrainTests.cs index d61774ce92..f6c80c8744 100644 --- a/test/DefaultCluster.Tests/EchoTaskGrainTests.cs +++ b/test/DefaultCluster.Tests/EchoTaskGrainTests.cs @@ -71,27 +71,27 @@ await promise.ContinueWith(t => } [Fact, TestCategory("SlowBVT"), TestCategory("Echo"), TestCategory("Timeout")] - public async Task EchoGrain_Timeout_Wait() + public async Task EchoGrain_Timeout_ContinueWith() { grain = this.GrainFactory.GetGrain(Guid.NewGuid()); - TimeSpan delay30 = TimeSpan.FromSeconds(30); // grain call timeout (set in config) + TimeSpan delay5 = TimeSpan.FromSeconds(30); // grain call timeout (set in config) TimeSpan delay45 = TimeSpan.FromSeconds(45); TimeSpan delay60 = TimeSpan.FromSeconds(60); Stopwatch sw = new Stopwatch(); sw.Start(); - Task promise = grain.BlockingCallTimeoutAsync(delay60); + Task promise = grain.BlockingCallTimeoutNoResponseTimeoutOverrideAsync(delay60); await promise.ContinueWith( t => { - if (!t.IsFaulted) Assert.True(false); // BlockingCallTimeout should not have completed successfully + if (!t.IsFaulted) Assert.Fail("BlockingCallTimeout should not have completed successfully"); Exception exc = t.Exception; while (exc is AggregateException) exc = exc.InnerException; Assert.IsAssignableFrom(exc); }).WithTimeout(delay45); sw.Stop(); - Assert.True(TimeIsLonger(sw.Elapsed, delay30), $"Elapsed time out of range: {sw.Elapsed}"); + Assert.True(TimeIsLonger(sw.Elapsed, delay5), $"Elapsed time out of range: {sw.Elapsed}"); Assert.True(TimeIsShorter(sw.Elapsed, delay60), $"Elapsed time out of range: {sw.Elapsed}"); } @@ -100,14 +100,14 @@ public async Task EchoGrain_Timeout_Await() { grain = this.GrainFactory.GetGrain(Guid.NewGuid()); - TimeSpan delay30 = TimeSpan.FromSeconds(30); - TimeSpan delay60 = TimeSpan.FromSeconds(60); + TimeSpan delay5 = TimeSpan.FromSeconds(5); + TimeSpan delay25 = TimeSpan.FromSeconds(25); Stopwatch sw = new Stopwatch(); sw.Start(); try { - int res = await grain.BlockingCallTimeoutAsync(delay60); - Assert.True(false); // BlockingCallTimeout should not have completed successfully + int res = await grain.BlockingCallTimeoutAsync(delay25); + Assert.Fail($"BlockingCallTimeout should not have completed successfully, but returned {res}"); } catch (Exception exc) { @@ -115,23 +115,24 @@ public async Task EchoGrain_Timeout_Await() Assert.IsAssignableFrom(exc); } sw.Stop(); - Assert.True(TimeIsLonger(sw.Elapsed, delay30), $"Elapsed time out of range: {sw.Elapsed}"); - Assert.True(TimeIsShorter(sw.Elapsed, delay60), $"Elapsed time out of range: {sw.Elapsed}"); + Assert.True(TimeIsLonger(sw.Elapsed, delay5), $"Elapsed time out of range: {sw.Elapsed}"); + Assert.True(TimeIsShorter(sw.Elapsed, delay25), $"Elapsed time out of range: {sw.Elapsed}"); } [Fact, TestCategory("SlowBVT"), TestCategory("Echo"), TestCategory("Timeout")] - public async Task EchoGrain_Timeout_Result() + public void EchoGrain_Timeout_Result() { grain = this.GrainFactory.GetGrain(Guid.NewGuid()); - TimeSpan delay30 = TimeSpan.FromSeconds(30); - TimeSpan delay60 = TimeSpan.FromSeconds(60); + TimeSpan delay5 = TimeSpan.FromSeconds(5); + TimeSpan delay25 = TimeSpan.FromSeconds(25); Stopwatch sw = new Stopwatch(); sw.Start(); try { - int res = await grain.BlockingCallTimeoutAsync(delay60); - Assert.True(false, "BlockingCallTimeout should not have completed successfully, but returned " + res); + // Note that this method purposely uses Task.Result. + int res = grain.BlockingCallTimeoutAsync(delay25).Result; + Assert.Fail($"BlockingCallTimeout should not have completed successfully, but returned {res}"); } catch (Exception exc) { @@ -139,8 +140,8 @@ public async Task EchoGrain_Timeout_Result() Assert.IsAssignableFrom(exc); } sw.Stop(); - Assert.True(TimeIsLonger(sw.Elapsed, delay30), $"Elapsed time out of range: {sw.Elapsed}"); - Assert.True(TimeIsShorter(sw.Elapsed, delay60), $"Elapsed time out of range: {sw.Elapsed}"); + Assert.True(TimeIsLonger(sw.Elapsed, delay5), $"Elapsed time out of range: {sw.Elapsed}"); + Assert.True(TimeIsShorter(sw.Elapsed, delay25), $"Elapsed time out of range: {sw.Elapsed}"); } [Fact, TestCategory("BVT"), TestCategory("Echo")] diff --git a/test/Grains/TestGrainInterfaces/IEchoTaskGrain.cs b/test/Grains/TestGrainInterfaces/IEchoTaskGrain.cs index 9e47a28227..ec086c9b0c 100644 --- a/test/Grains/TestGrainInterfaces/IEchoTaskGrain.cs +++ b/test/Grains/TestGrainInterfaces/IEchoTaskGrain.cs @@ -27,9 +27,13 @@ public interface IEchoTaskGrain : IGrainWithGuidKey Task EchoAsync(string data); Task EchoErrorAsync(string data); + [ResponseTimeout("00:00:05")] Task BlockingCallTimeoutAsync(TimeSpan delay); + Task BlockingCallTimeoutNoResponseTimeoutOverrideAsync(TimeSpan delay); + Task PingAsync(); + Task PingLocalSiloAsync(); Task PingRemoteSiloAsync(SiloAddress siloAddress); Task PingOtherSiloAsync(); diff --git a/test/Grains/TestInternalGrains/EchoTaskGrain.cs b/test/Grains/TestInternalGrains/EchoTaskGrain.cs index 96088bdd10..49cd1c6ceb 100644 --- a/test/Grains/TestInternalGrains/EchoTaskGrain.cs +++ b/test/Grains/TestInternalGrains/EchoTaskGrain.cs @@ -135,6 +135,16 @@ public Task BlockingCallTimeoutAsync(TimeSpan delay) throw new InvalidOperationException("Timeout should have been returned to caller before " + delay); } + public Task BlockingCallTimeoutNoResponseTimeoutOverrideAsync(TimeSpan delay) + { + logger.LogInformation("IEchoGrainAsync.BlockingCallTimeoutNoResponseTimeoutOverrideAsync Delay={Delay}", delay); + Stopwatch sw = new Stopwatch(); + sw.Start(); + Thread.Sleep(delay); + logger.LogInformation("IEchoGrainAsync.BlockingCallTimeoutNoResponseTimeoutOverrideAsync Awoke from sleep after {ElapsedDuration}", sw.Elapsed); + throw new InvalidOperationException("Timeout should have been returned to caller before " + delay); + } + public Task PingAsync() { logger.LogInformation("IEchoGrainAsync.Ping"); diff --git a/test/Orleans.Serialization.UnitTests/Request.cs b/test/Orleans.Serialization.UnitTests/Request.cs index cb77f52570..9623ea003a 100644 --- a/test/Orleans.Serialization.UnitTests/Request.cs +++ b/test/Orleans.Serialization.UnitTests/Request.cs @@ -22,6 +22,7 @@ public abstract class UnitTestRequestBase : IInvokable public abstract Type GetInterfaceType(); public abstract MethodInfo GetMethod(); + public virtual TimeSpan? GetDefaultResponseTimeout() => null; } [GenerateSerializer]