diff --git a/src/Http/Http.Extensions/gen/DiagnosticDescriptors.cs b/src/Http/Http.Extensions/gen/DiagnosticDescriptors.cs index d6a54e65fd00..9e89aef8634d 100644 --- a/src/Http/Http.Extensions/gen/DiagnosticDescriptors.cs +++ b/src/Http/Http.Extensions/gen/DiagnosticDescriptors.cs @@ -24,4 +24,16 @@ internal static class DiagnosticDescriptors DiagnosticSeverity.Warning, isEnabledByDefault: true ); + + // This is temporary. The plan is to be able to resolve all parameters to a known EndpointParameterSource. + public static DiagnosticDescriptor GetUnableToResolveParameterDescriptor(string parameterName) + { + return new( + "RDG073", + new LocalizableResourceString(nameof(Resources.UnableToResolveParameter_Title), Resources.ResourceManager, typeof(Resources)), + new LocalizableResourceString(nameof(Resources.FormatUnableToResolveParameter_Message), Resources.ResourceManager, typeof(Resources), parameterName), + "Usage", + DiagnosticSeverity.Hidden, + isEnabledByDefault: true); + } } diff --git a/src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs index 8fa3a9116366..77fca5f07d52 100644 --- a/src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs +++ b/src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs @@ -53,20 +53,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { context.ReportDiagnostic(Diagnostic.Create(diagnostic, endpoint.Operation.Syntax.GetLocation(), filePath)); } - foreach (var diagnostic in endpoint.Response.Diagnostics) - { - context.ReportDiagnostic(Diagnostic.Create(diagnostic, endpoint.Operation.Syntax.GetLocation(), filePath)); - } - foreach (var diagnostic in endpoint.Route.Diagnostics) - { - context.ReportDiagnostic(Diagnostic.Create(diagnostic, endpoint.Operation.Syntax.GetLocation(), filePath)); - } }); var endpoints = endpointsWithDiagnostics - .Where(endpoint => endpoint.Diagnostics.Count == 0 && - endpoint.Response.Diagnostics.Count == 0 && - endpoint.Route.Diagnostics.Count == 0) + .Where(endpoint => endpoint.Diagnostics.Count == 0) .WithTrackingName(GeneratorSteps.EndpointsWithoutDiagnosicsStep); var thunks = endpoints.Select((endpoint, _) => $$""" @@ -97,7 +87,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) } {{endpoint.EmitRequestHandler()}} -{{StaticRouteHandlerModelEmitter.EmitFilteredRequestHandler()}} +{{endpoint.EmitFilteredRequestHandler()}} RequestDelegate targetDelegate = filteredInvocation is null ? RequestHandler : RequestHandlerFiltered; var metadata = inferredMetadataResult?.EndpointMetadata ?? ReadOnlyCollection.Empty; diff --git a/src/Http/Http.Extensions/gen/Resources.resx b/src/Http/Http.Extensions/gen/Resources.resx index 6880e701b101..c207634b7525 100644 --- a/src/Http/Http.Extensions/gen/Resources.resx +++ b/src/Http/Http.Extensions/gen/Resources.resx @@ -1,17 +1,17 @@ - @@ -129,4 +129,10 @@ Unable to statically resolve endpoint handler method. Only method groups, lambda expressions or readonly fields/variables are allowed. Compile-time endpoint generation will skip this endpoint. - + + Unable to statically resolve parameter '{0}' for endpoint. Compile-time endpoint generation will skip this endpoint. + + + Unable to resolve parameter + + \ No newline at end of file diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs index b5a694770f80..c698758755bc 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs @@ -1,5 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. + using System; using System.Collections.Generic; using System.Linq; @@ -12,61 +13,118 @@ namespace Microsoft.AspNetCore.Http.Generators.StaticRouteHandlerModel; internal class Endpoint { - public string HttpMethod { get; } - public EndpointRoute Route { get; } - public EndpointResponse Response { get; } - public List Diagnostics { get; } = new List(); - public (string, int) Location { get; } - public IInvocationOperation Operation { get; } - - private WellKnownTypes WellKnownTypes { get; } + private string? _argumentListCache; public Endpoint(IInvocationOperation operation, WellKnownTypes wellKnownTypes) { Operation = operation; - WellKnownTypes = wellKnownTypes; - Location = GetLocation(); - HttpMethod = GetHttpMethod(); - Response = new EndpointResponse(Operation, wellKnownTypes); - Route = new EndpointRoute(Operation); - } + Location = GetLocation(operation); + HttpMethod = GetHttpMethod(operation); - private (string, int) GetLocation() - { - var filePath = Operation.Syntax.SyntaxTree.FilePath; - var span = Operation.Syntax.SyntaxTree.GetLineSpan(Operation.Syntax.Span); - var lineNumber = span.EndLinePosition.Line + 1; - return (filePath, lineNumber); - } + if (!operation.TryGetRouteHandlerPattern(out var routeToken)) + { + Diagnostics.Add(DiagnosticDescriptors.UnableToResolveRoutePattern); + return; + } - private string GetHttpMethod() - { - var syntax = (InvocationExpressionSyntax)Operation.Syntax; - var expression = (MemberAccessExpressionSyntax)syntax.Expression; - var name = (IdentifierNameSyntax)expression.Name; - var identifier = name.Identifier; - return identifier.ValueText; + RoutePattern = routeToken.ValueText; + + if (!operation.TryGetRouteHandlerMethod(out var method)) + { + Diagnostics.Add(DiagnosticDescriptors.UnableToResolveMethod); + return; + } + + Response = new EndpointResponse(method, wellKnownTypes); + + if (method.Parameters.Length == 0) + { + return; + } + + var parameters = new EndpointParameter[method.Parameters.Length]; + + for (var i = 0; i < method.Parameters.Length; i++) + { + var parameter = new EndpointParameter(method.Parameters[i], wellKnownTypes); + + if (parameter.Source == EndpointParameterSource.Unknown) + { + Diagnostics.Add(DiagnosticDescriptors.GetUnableToResolveParameterDescriptor(parameter.Name)); + return; + } + + parameters[i] = parameter; + } + + Parameters = parameters; } - public override bool Equals(object o) + public string HttpMethod { get; } + public string? RoutePattern { get; } + public EndpointResponse? Response { get; } + public EndpointParameter[] Parameters { get; } = Array.Empty(); + public string EmitArgumentList() => _argumentListCache ??= string.Join(", ", Parameters.Select(p => p.EmitArgument())); + + public List Diagnostics { get; } = new List(); + + public (string File, int LineNumber) Location { get; } + public IInvocationOperation Operation { get; } + + public override bool Equals(object o) => + o is Endpoint other && Location == other.Location && SignatureEquals(this, other); + + public override int GetHashCode() => + HashCode.Combine(Location, GetSignatureHashCode(this)); + + public static bool SignatureEquals(Endpoint a, Endpoint b) { - if (o is null) + if (!a.Response.WrappedResponseType.Equals(b.Response.WrappedResponseType, StringComparison.Ordinal) || + !a.HttpMethod.Equals(b.HttpMethod, StringComparison.Ordinal) || + a.Parameters.Length != b.Parameters.Length) { return false; } - if (o is Endpoint endpoint) + for (var i = 0; i < a.Parameters.Length; i++) { - return endpoint.HttpMethod.Equals(HttpMethod, StringComparison.OrdinalIgnoreCase) && - endpoint.Location.Item1.Equals(Location.Item1, StringComparison.OrdinalIgnoreCase) && - endpoint.Location.Item2.Equals(Location.Item2) && - endpoint.Response.Equals(Response) && - endpoint.Diagnostics.SequenceEqual(Diagnostics); + if (a.Parameters[i].Equals(b.Parameters[i])) + { + return false; + } } - return false; + return true; } - public override int GetHashCode() => - HashCode.Combine(HttpMethod, Route, Location, Response, Diagnostics); + public static int GetSignatureHashCode(Endpoint endpoint) + { + var hashCode = new HashCode(); + hashCode.Add(endpoint.Response.WrappedResponseType); + hashCode.Add(endpoint.HttpMethod); + + foreach (var parameter in endpoint.Parameters) + { + hashCode.Add(parameter); + } + + return hashCode.ToHashCode(); + } + + private static (string, int) GetLocation(IInvocationOperation operation) + { + var filePath = operation.Syntax.SyntaxTree.FilePath; + var span = operation.Syntax.SyntaxTree.GetLineSpan(operation.Syntax.Span); + var lineNumber = span.StartLinePosition.Line + 1; + return (filePath, lineNumber); + } + + private static string GetHttpMethod(IInvocationOperation operation) + { + var syntax = (InvocationExpressionSyntax)operation.Syntax; + var expression = (MemberAccessExpressionSyntax)syntax.Expression; + var name = (IdentifierNameSyntax)expression.Name; + var identifier = name.Identifier; + return identifier.ValueText; + } } diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointDelegateComparer.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointDelegateComparer.cs index 7d159ebd7176..41491f11ec2b 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointDelegateComparer.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointDelegateComparer.cs @@ -1,31 +1,14 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections; + using System.Collections.Generic; + namespace Microsoft.AspNetCore.Http.Generators.StaticRouteHandlerModel; -internal sealed class EndpointDelegateComparer : IEqualityComparer, IComparer +internal sealed class EndpointDelegateComparer : IEqualityComparer { public static readonly EndpointDelegateComparer Instance = new EndpointDelegateComparer(); - public bool Equals(Endpoint a, Endpoint b) => Compare(a, b) == 0; - - public int GetHashCode(Endpoint endpoint) => HashCode.Combine( - endpoint.Response.WrappedResponseType, - endpoint.Response.IsVoid, - endpoint.Response.IsAwaitable, - endpoint.HttpMethod); - - public int Compare(Endpoint a, Endpoint b) - { - if (a.Response.IsAwaitable == b.Response.IsAwaitable && - a.Response.IsVoid == b.Response.IsVoid && - a.Response.WrappedResponseType.Equals(b.Response.WrappedResponseType, StringComparison.Ordinal) && - a.HttpMethod.Equals(b.HttpMethod, StringComparison.Ordinal)) - { - return 0; - } - return -1; - } + public bool Equals(Endpoint a, Endpoint b) => Endpoint.SignatureEquals(a, b); + public int GetHashCode(Endpoint endpoint) => Endpoint.GetSignatureHashCode(endpoint); } diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameter.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameter.cs new file mode 100644 index 000000000000..72a6f1f0a948 --- /dev/null +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameter.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.AspNetCore.App.Analyzers.Infrastructure; +using Microsoft.CodeAnalysis; +using WellKnownType = Microsoft.AspNetCore.App.Analyzers.Infrastructure.WellKnownTypeData.WellKnownType; + +namespace Microsoft.AspNetCore.Http.Generators.StaticRouteHandlerModel; + +internal class EndpointParameter +{ + public EndpointParameter(IParameterSymbol parameter, WellKnownTypes wellKnownTypes) + { + Type = parameter.Type; + Name = parameter.Name; + Source = EndpointParameterSource.Unknown; + + if (GetSpecialTypeCallingCode(Type, wellKnownTypes) is string callingCode) + { + Source = EndpointParameterSource.SpecialType; + CallingCode = callingCode; + } + } + + public ITypeSymbol Type { get; } + public EndpointParameterSource Source { get; } + + // TODO: If the parameter has [FromRoute("AnotherName")] or similar, prefer that. + public string Name { get; } + public string? CallingCode { get; } + + public string EmitArgument() + { + switch (Source) + { + case EndpointParameterSource.SpecialType: + return CallingCode!; + default: + // Eventually there should be know unknown parameter sources, but in the meantime we don't expect them to get this far. + // The netstandard2.0 target means there is no UnreachableException. + throw new Exception("Unreachable!"); + } + } + + // TODO: Handle special form types like IFormFileCollection that need special body-reading logic. + private static string? GetSpecialTypeCallingCode(ITypeSymbol type, WellKnownTypes wellKnownTypes) + { + if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_HttpContext))) + { + return "httpContext"; + } + if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_HttpRequest))) + { + return "httpContext.Request"; + } + if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_HttpResponse))) + { + return "httpContext.Response"; + } + if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.System_IO_Pipelines_PipeReader))) + { + return "httpContext.Request.BodyReader"; + } + if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.System_IO_Stream))) + { + return "httpContext.Request.Body"; + } + if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.System_Security_Claims_ClaimsPrincipal))) + { + return "httpContext.User"; + } + if (SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownType.System_Threading_CancellationToken))) + { + return "httpContext.RequestAborted"; + } + + return null; + } + + public override bool Equals(object obj) => + obj is EndpointParameter other && + other.Source == Source && + other.Name == Name && + SymbolEqualityComparer.Default.Equals(other.Type, Type); + + public override int GetHashCode() + { + var hashCode = new HashCode(); + hashCode.Add(Name); + hashCode.Add(Type, SymbolEqualityComparer.Default); + return hashCode.ToHashCode(); + } +} diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameterSource.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameterSource.cs new file mode 100644 index 000000000000..9c7d592c0664 --- /dev/null +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameterSource.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.Generators.StaticRouteHandlerModel; + +internal enum EndpointParameterSource +{ + Route, + Query, + // This should only be necessary if the route pattern is not statically analyzable + RouteOrQuery, + Header, + JsonBody, + JsonBodyOrService, + FormBody, + Service, + // SpecialType refers to HttpContext, HttpRequest, CancellationToken, Stream, etc... + // that are specially checked for in RequestDelegateFactory.CreateArgument() + SpecialType, + BindAsync, + // Unknown should be temporary for development. + Unknown, +} diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs index 4f27b0700ce2..611891215c8b 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs @@ -2,19 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; using System.Linq; using Microsoft.AspNetCore.App.Analyzers.Infrastructure; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Operations; namespace Microsoft.AspNetCore.Http.Generators.StaticRouteHandlerModel; using WellKnownType = WellKnownTypeData.WellKnownType; -public class EndpointResponse +internal class EndpointResponse { - public ITypeSymbol ResponseType { get; set; } + public ITypeSymbol? ResponseType { get; set; } public string WrappedResponseType { get; set; } public string ContentType { get; set; } public bool IsAwaitable { get; set; } @@ -23,16 +21,8 @@ public class EndpointResponse private WellKnownTypes WellKnownTypes { get; init; } - public List Diagnostics { get; init; } = new List(); - - internal EndpointResponse(IInvocationOperation operation, WellKnownTypes wellKnownTypes) + internal EndpointResponse(IMethodSymbol method, WellKnownTypes wellKnownTypes) { - if (!operation.TryGetRouteHandlerMethod(out var method)) - { - Diagnostics.Add(DiagnosticDescriptors.UnableToResolveMethod); - return; - } - WellKnownTypes = wellKnownTypes; ResponseType = UnwrapResponseType(method); WrappedResponseType = method.ReturnType.ToString(); diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointRoute.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointRoute.cs deleted file mode 100644 index 66de4f9d88f6..000000000000 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointRoute.cs +++ /dev/null @@ -1,48 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; - -namespace Microsoft.AspNetCore.Http.Generators.StaticRouteHandlerModel; - -public class EndpointRoute -{ - private const int RoutePatternArgumentOrdinal = 1; - - public string RoutePattern { get; init; } - - public List Diagnostics { get; init; } = new List(); - - public EndpointRoute(IInvocationOperation operation) - { - if (!TryGetRouteHandlerPattern(operation, out var routeToken)) - { - Diagnostics.Add(DiagnosticDescriptors.UnableToResolveRoutePattern); - } - - RoutePattern = routeToken.ValueText; - } - - private static bool TryGetRouteHandlerPattern(IInvocationOperation invocation, out SyntaxToken token) - { - IArgumentOperation? argumentOperation = null; - foreach (var argument in invocation.Arguments) - { - if (argument.Parameter?.Ordinal == RoutePatternArgumentOrdinal) - { - argumentOperation = argument; - } - } - if (argumentOperation?.Syntax is not ArgumentSyntax routePatternArgumentSyntax || - routePatternArgumentSyntax.Expression is not LiteralExpressionSyntax routePatternArgumentLiteralSyntax) - { - token = default; - return false; - } - token = routePatternArgumentLiteralSyntax.Token; - return true; - } -} diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs index 6a0ee6783a49..4ce9b7a87904 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs @@ -7,8 +7,9 @@ namespace Microsoft.AspNetCore.Http.Generators.StaticRouteHandlerModel; -public static class InvocationOperationExtensions +internal static class InvocationOperationExtensions { + private const int RoutePatternArgumentOrdinal = 1; private const int RouteHandlerArgumentOrdinal = 2; public static bool TryGetRouteHandlerMethod(this IInvocationOperation invocation, out IMethodSymbol method) @@ -25,6 +26,26 @@ public static bool TryGetRouteHandlerMethod(this IInvocationOperation invocation return false; } + public static bool TryGetRouteHandlerPattern(this IInvocationOperation invocation, out SyntaxToken token) + { + IArgumentOperation? argumentOperation = null; + foreach (var argument in invocation.Arguments) + { + if (argument.Parameter?.Ordinal == RoutePatternArgumentOrdinal) + { + argumentOperation = argument; + } + } + if (argumentOperation?.Syntax is not ArgumentSyntax routePatternArgumentSyntax || + routePatternArgumentSyntax.Expression is not LiteralExpressionSyntax routePatternArgumentLiteralSyntax) + { + token = default; + return false; + } + token = routePatternArgumentLiteralSyntax.Token; + return true; + } + private static IMethodSymbol ResolveMethodFromOperation(IOperation operation) => operation switch { IArgumentOperation argument => ResolveMethodFromOperation(argument.Value), diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/StaticRouteHandlerModel.Emitter.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/StaticRouteHandlerModel.Emitter.cs index a167d600bb62..4ca73fb4c1d4 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/StaticRouteHandlerModel.Emitter.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/StaticRouteHandlerModel.Emitter.cs @@ -2,41 +2,38 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Linq; +using System.Text; using Microsoft.CodeAnalysis; namespace Microsoft.AspNetCore.Http.Generators.StaticRouteHandlerModel; internal static class StaticRouteHandlerModelEmitter { - /* - * TODO: Emit code that represents the signature of the delegate - * represented by the handler. When the handler does not return a value - * but consumes parameters the following will be emitted: - * - * ``` - * System.Action - * ``` - * - * Where `string` and `int` represent parameter types. For handlers - * that do return a value, `System.Func` will - * be emitted to indicate a `string`return type. - */ public static string EmitHandlerDelegateType(this Endpoint endpoint) { - if (endpoint.Response.IsVoid) + if (endpoint.Parameters.Length == 0) { - return $"System.Action"; + return endpoint.Response.IsVoid ? "System.Action" : $"System.Func<{endpoint.Response.WrappedResponseType}>"; } - if (endpoint.Response.IsAwaitable) + else { - return $"System.Func<{endpoint.Response.WrappedResponseType}>"; + var parameterTypeList = string.Join(", ", endpoint.Parameters.Select(p => p.Type)); + + if (endpoint.Response.IsVoid) + { + return $"System.Action<{parameterTypeList}>"; + } + else + { + return $"System.Func<{parameterTypeList}, {endpoint.Response.WrappedResponseType}>"; + } } - return $"System.Func<{endpoint.Response.ResponseType}>"; } public static string EmitSourceKey(this Endpoint endpoint) { - return $@"(@""{endpoint.Location.Item1}"", {endpoint.Location.Item2})"; + return $@"(@""{endpoint.Location.File}"", {endpoint.Location.LineNumber})"; } public static string EmitVerb(this Endpoint endpoint) @@ -68,7 +65,7 @@ public static string EmitRequestHandler(this Endpoint endpoint) {{handlerSignature}} { {{setContentType}} - {{resultAssignment}}{{awaitHandler}}handler(); + {{resultAssignment}}{{awaitHandler}}handler({{endpoint.EmitArgumentList()}}); {{(endpoint.Response.IsVoid ? "return Task.CompletedTask;" : endpoint.EmitResponseWritingCall())}} } """; @@ -111,42 +108,50 @@ private static string EmitResponseWritingCall(this Endpoint endpoint) * can be used to reduce the boxing that happens at runtime when constructing * the context object. */ - public static string EmitFilteredRequestHandler() + public static string EmitFilteredRequestHandler(this Endpoint endpoint) { - return """ + var argumentList = endpoint.Parameters.Length == 0 ? string.Empty : $", {endpoint.EmitArgumentList()}"; + + return $$""" async Task RequestHandlerFiltered(HttpContext httpContext) { - var result = await filteredInvocation(new DefaultEndpointFilterInvocationContext(httpContext)); + var result = await filteredInvocation(new DefaultEndpointFilterInvocationContext(httpContext{{argumentList}})); await GeneratedRouteBuilderExtensionsCore.ExecuteObjectResult(result, httpContext); } """; } - /* - * TODO: Emit code that will call the `handler` with - * the appropriate arguments processed via the parameter binding. - * - * ``` - * return ValueTask.FromResult(handler(name, age)); - * ``` - * - * If the handler returns void, it will be invoked and an `EmptyHttpResult` - * will be returned to the user. - * - * ``` - * handler(name, age); - * return ValueTask.FromResult(Results.Empty); - * ``` - */ public static string EmitFilteredInvocation(this Endpoint endpoint) { // Note: This string does not need indentation since it is // handled when we generate the output string in the `thunks` pipeline. - return endpoint.Response.IsVoid ? """ -handler(); + return endpoint.Response.IsVoid ? $""" +handler({endpoint.EmitFilteredArgumentList()}); return ValueTask.FromResult(Results.Empty); -""" : """ -return ValueTask.FromResult(handler()); +""" : $""" +return ValueTask.FromResult(handler({endpoint.EmitFilteredArgumentList()})); """; } + + public static string EmitFilteredArgumentList(this Endpoint endpoint) + { + if (endpoint.Parameters.Length == 0) + { + return ""; + } + + var sb = new StringBuilder(); + + for (var i = 0; i < endpoint.Parameters.Length; i++) + { + sb.Append($"ic.GetArgument<{endpoint.Parameters[i].Type}>({i})"); + + if (i < endpoint.Parameters.Length - 1) + { + sb.Append(", "); + } + } + + return sb.ToString(); + } } diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/WellKnownTypeData.cs b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/WellKnownTypeData.cs index 72657181113a..9b26b8c76e37 100644 --- a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/WellKnownTypeData.cs +++ b/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/WellKnownTypeData.cs @@ -7,19 +7,39 @@ internal static class WellKnownTypeData { public enum WellKnownType { + Microsoft_AspNetCore_Http_HttpContext, + Microsoft_AspNetCore_Http_HttpRequest, + Microsoft_AspNetCore_Http_HttpResponse, + Microsoft_AspNetCore_Http_IFormCollection, + Microsoft_AspNetCore_Http_IFormFileCollection, + Microsoft_AspNetCore_Http_IFormFile, Microsoft_AspNetCore_Http_IResult, + System_IO_Pipelines_PipeReader, + System_IO_Stream, + System_Security_Claims_ClaimsPrincipal, + System_Threading_CancellationToken, System_Threading_Tasks_Task, System_Threading_Tasks_Task_T, System_Threading_Tasks_ValueTask, - System_Threading_Tasks_ValueTask_T + System_Threading_Tasks_ValueTask_T, } public static readonly string[] WellKnownTypeNames = new[] { + "Microsoft.AspNetCore.Http.HttpContext", + "Microsoft.AspNetCore.Http.HttpRequest", + "Microsoft.AspNetCore.Http.HttpResponse", + "Microsoft.AspNetCore.Http.IFormCollection", + "Microsoft.AspNetCore.Http.IFormFileCollection", + "Microsoft.AspNetCore.Http.IFormFile", "Microsoft.AspNetCore.Http.IResult", + "System.IO.Pipelines.PipeReader", + "System.IO.Stream", + "System.Security.Claims.ClaimsPrincipal", + "System.Threading.CancellationToken", "System.Threading.Tasks.Task", "System.Threading.Tasks.Task`1", "System.Threading.Tasks.ValueTask", - "System.Threading.Tasks.ValueTask`1" + "System.Threading.Tasks.ValueTask`1", }; } diff --git a/src/Http/Http.Extensions/test/RequestDelegateGenerator/Baselines/MapGet_NoParam_StringReturn_WithFilter.generated.txt b/src/Http/Http.Extensions/test/RequestDelegateGenerator/Baselines/MapAction_NoParam_StringReturn_WithFilter.generated.txt similarity index 100% rename from src/Http/Http.Extensions/test/RequestDelegateGenerator/Baselines/MapGet_NoParam_StringReturn_WithFilter.generated.txt rename to src/Http/Http.Extensions/test/RequestDelegateGenerator/Baselines/MapAction_NoParam_StringReturn_WithFilter.generated.txt diff --git a/src/Http/Http.Extensions/test/RequestDelegateGenerator/Baselines/Multiple_MapAction_WithParams_StringReturn.generated.txt b/src/Http/Http.Extensions/test/RequestDelegateGenerator/Baselines/Multiple_MapAction_WithParams_StringReturn.generated.txt new file mode 100644 index 000000000000..1b85f761bb5d --- /dev/null +++ b/src/Http/Http.Extensions/test/RequestDelegateGenerator/Baselines/Multiple_MapAction_WithParams_StringReturn.generated.txt @@ -0,0 +1,291 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace Microsoft.AspNetCore.Builder +{ + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.Generators, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + internal class SourceKey + { + public string Path { get; init; } + public int Line { get; init; } + + public SourceKey(string path, int line) + { + Path = path; + Line = line; + } + } + + // This class needs to be internal so that the compiled application + // has access to the strongly-typed endpoint definitions that are + // generated by the compiler so that they will be favored by + // overload resolution and opt the runtime in to the code generated + // implementation produced here. + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.Generators, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + internal static class GenerateRouteBuilderEndpoints + { + private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get }; + private static readonly string[] PostVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Post }; + private static readonly string[] PutVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Put }; + private static readonly string[] DeleteVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Delete }; + private static readonly string[] PatchVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Patch }; + + internal static global::Microsoft.AspNetCore.Builder.RouteHandlerBuilder MapGet( + this global::Microsoft.AspNetCore.Routing.IEndpointRouteBuilder endpoints, + [global::System.Diagnostics.CodeAnalysis.StringSyntax("Route")] string pattern, + global::System.Func handler, + [global::System.Runtime.CompilerServices.CallerFilePath] string filePath = "", + [global::System.Runtime.CompilerServices.CallerLineNumber]int lineNumber = 0) + { + return global::Microsoft.AspNetCore.Http.Generated.GeneratedRouteBuilderExtensionsCore.MapCore( + endpoints, + pattern, + handler, + GetVerb, + filePath, + lineNumber); + } + internal static global::Microsoft.AspNetCore.Builder.RouteHandlerBuilder MapGet( + this global::Microsoft.AspNetCore.Routing.IEndpointRouteBuilder endpoints, + [global::System.Diagnostics.CodeAnalysis.StringSyntax("Route")] string pattern, + global::System.Func handler, + [global::System.Runtime.CompilerServices.CallerFilePath] string filePath = "", + [global::System.Runtime.CompilerServices.CallerLineNumber]int lineNumber = 0) + { + return global::Microsoft.AspNetCore.Http.Generated.GeneratedRouteBuilderExtensionsCore.MapCore( + endpoints, + pattern, + handler, + GetVerb, + filePath, + lineNumber); + } + internal static global::Microsoft.AspNetCore.Builder.RouteHandlerBuilder MapGet( + this global::Microsoft.AspNetCore.Routing.IEndpointRouteBuilder endpoints, + [global::System.Diagnostics.CodeAnalysis.StringSyntax("Route")] string pattern, + global::System.Func handler, + [global::System.Runtime.CompilerServices.CallerFilePath] string filePath = "", + [global::System.Runtime.CompilerServices.CallerLineNumber]int lineNumber = 0) + { + return global::Microsoft.AspNetCore.Http.Generated.GeneratedRouteBuilderExtensionsCore.MapCore( + endpoints, + pattern, + handler, + GetVerb, + filePath, + lineNumber); + } + + } +} + +namespace Microsoft.AspNetCore.Http.Generated +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Collections.ObjectModel; + using System.Diagnostics; + using System.Linq; + using System.Reflection; + using System.Threading.Tasks; + using System.IO; + using Microsoft.AspNetCore.Routing; + using Microsoft.AspNetCore.Routing.Patterns; + using Microsoft.AspNetCore.Builder; + using Microsoft.AspNetCore.Http; + using Microsoft.AspNetCore.Http.Metadata; + using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.FileProviders; + using Microsoft.Extensions.Primitives; + + using MetadataPopulator = System.Func; + using RequestDelegateFactoryFunc = System.Func; + + file static class GeneratedRouteBuilderExtensionsCore + { + + private static readonly Dictionary<(string, int), (MetadataPopulator, RequestDelegateFactoryFunc)> map = new() + { + [(@"TestMapActions.cs", 15)] = ( + (methodInfo, options) => + { + Debug.Assert(options?.EndpointBuilder != null, "EndpointBuilder not found."); + options.EndpointBuilder.Metadata.Add(new SourceKey(@"TestMapActions.cs", 15)); + return new RequestDelegateMetadataResult { EndpointMetadata = options.EndpointBuilder.Metadata.AsReadOnly() }; + }, + (del, options, inferredMetadataResult) => + { + var handler = (System.Func)del; + EndpointFilterDelegate? filteredInvocation = null; + + if (options?.EndpointBuilder?.FilterFactories.Count > 0) + { + filteredInvocation = GeneratedRouteBuilderExtensionsCore.BuildFilterDelegate(ic => + { + if (ic.HttpContext.Response.StatusCode == 400) + { + return ValueTask.FromResult(Results.Empty); + } + return ValueTask.FromResult(handler(ic.GetArgument(0))); + }, + options.EndpointBuilder, + handler.Method); + } + + Task RequestHandler(HttpContext httpContext) + { + httpContext.Response.ContentType ??= "text/plain"; + var result = handler(httpContext.Request); + return httpContext.Response.WriteAsync(result); + } + async Task RequestHandlerFiltered(HttpContext httpContext) + { + var result = await filteredInvocation(new DefaultEndpointFilterInvocationContext(httpContext, httpContext.Request)); + await GeneratedRouteBuilderExtensionsCore.ExecuteObjectResult(result, httpContext); + } + + RequestDelegate targetDelegate = filteredInvocation is null ? RequestHandler : RequestHandlerFiltered; + var metadata = inferredMetadataResult?.EndpointMetadata ?? ReadOnlyCollection.Empty; + return new RequestDelegateResult(targetDelegate, metadata); + }), + [(@"TestMapActions.cs", 16)] = ( + (methodInfo, options) => + { + Debug.Assert(options?.EndpointBuilder != null, "EndpointBuilder not found."); + options.EndpointBuilder.Metadata.Add(new SourceKey(@"TestMapActions.cs", 16)); + return new RequestDelegateMetadataResult { EndpointMetadata = options.EndpointBuilder.Metadata.AsReadOnly() }; + }, + (del, options, inferredMetadataResult) => + { + var handler = (System.Func)del; + EndpointFilterDelegate? filteredInvocation = null; + + if (options?.EndpointBuilder?.FilterFactories.Count > 0) + { + filteredInvocation = GeneratedRouteBuilderExtensionsCore.BuildFilterDelegate(ic => + { + if (ic.HttpContext.Response.StatusCode == 400) + { + return ValueTask.FromResult(Results.Empty); + } + return ValueTask.FromResult(handler(ic.GetArgument(0))); + }, + options.EndpointBuilder, + handler.Method); + } + + Task RequestHandler(HttpContext httpContext) + { + httpContext.Response.ContentType ??= "text/plain"; + var result = handler(httpContext.Response); + return httpContext.Response.WriteAsync(result); + } + async Task RequestHandlerFiltered(HttpContext httpContext) + { + var result = await filteredInvocation(new DefaultEndpointFilterInvocationContext(httpContext, httpContext.Response)); + await GeneratedRouteBuilderExtensionsCore.ExecuteObjectResult(result, httpContext); + } + + RequestDelegate targetDelegate = filteredInvocation is null ? RequestHandler : RequestHandlerFiltered; + var metadata = inferredMetadataResult?.EndpointMetadata ?? ReadOnlyCollection.Empty; + return new RequestDelegateResult(targetDelegate, metadata); + }), + [(@"TestMapActions.cs", 17)] = ( + (methodInfo, options) => + { + Debug.Assert(options?.EndpointBuilder != null, "EndpointBuilder not found."); + options.EndpointBuilder.Metadata.Add(new SourceKey(@"TestMapActions.cs", 17)); + return new RequestDelegateMetadataResult { EndpointMetadata = options.EndpointBuilder.Metadata.AsReadOnly() }; + }, + (del, options, inferredMetadataResult) => + { + var handler = (System.Func)del; + EndpointFilterDelegate? filteredInvocation = null; + + if (options?.EndpointBuilder?.FilterFactories.Count > 0) + { + filteredInvocation = GeneratedRouteBuilderExtensionsCore.BuildFilterDelegate(ic => + { + if (ic.HttpContext.Response.StatusCode == 400) + { + return ValueTask.FromResult(Results.Empty); + } + return ValueTask.FromResult(handler(ic.GetArgument(0), ic.GetArgument(1))); + }, + options.EndpointBuilder, + handler.Method); + } + + Task RequestHandler(HttpContext httpContext) + { + httpContext.Response.ContentType ??= "text/plain"; + var result = handler(httpContext.Request, httpContext.Response); + return httpContext.Response.WriteAsync(result); + } + async Task RequestHandlerFiltered(HttpContext httpContext) + { + var result = await filteredInvocation(new DefaultEndpointFilterInvocationContext(httpContext, httpContext.Request, httpContext.Response)); + await GeneratedRouteBuilderExtensionsCore.ExecuteObjectResult(result, httpContext); + } + + RequestDelegate targetDelegate = filteredInvocation is null ? RequestHandler : RequestHandlerFiltered; + var metadata = inferredMetadataResult?.EndpointMetadata ?? ReadOnlyCollection.Empty; + return new RequestDelegateResult(targetDelegate, metadata); + }), + + }; + + internal static RouteHandlerBuilder MapCore( + this IEndpointRouteBuilder routes, + string pattern, + Delegate handler, + IEnumerable httpMethods, + string filePath, + int lineNumber) + { + var (populateMetadata, createRequestDelegate) = map[(filePath, lineNumber)]; + return RouteHandlerServices.Map(routes, pattern, handler, httpMethods, populateMetadata, createRequestDelegate); + } + + private static EndpointFilterDelegate BuildFilterDelegate(EndpointFilterDelegate filteredInvocation, EndpointBuilder builder, MethodInfo mi) + { + var routeHandlerFilters = builder.FilterFactories; + var context0 = new EndpointFilterFactoryContext + { + MethodInfo = mi, + ApplicationServices = builder.ApplicationServices, + }; + var initialFilteredInvocation = filteredInvocation; + for (var i = routeHandlerFilters.Count - 1; i >= 0; i--) + { + var filterFactory = routeHandlerFilters[i]; + filteredInvocation = filterFactory(context0, filteredInvocation); + } + return filteredInvocation; + } + + private static Task ExecuteObjectResult(object? obj, HttpContext httpContext) + { + if (obj is IResult r) + { + return r.ExecuteAsync(httpContext); + } + else if (obj is string s) + { + return httpContext.Response.WriteAsync(s); + } + else + { + return httpContext.Response.WriteAsJsonAsync(obj); + } + } + } +} diff --git a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateGeneratorTestBase.cs b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateGeneratorTestBase.cs index 6322006f7355..f3143dd790ef 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateGeneratorTestBase.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateGeneratorTestBase.cs @@ -6,10 +6,10 @@ using System.Runtime.Loader; using System.Text; using Microsoft.AspNetCore.Builder; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Testing; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.Emit; using Microsoft.CodeAnalysis.Text; using Microsoft.Extensions.DependencyInjection; @@ -50,21 +50,27 @@ public class RequestDelegateGeneratorTestBase : LoggedTest return (Assert.Single(runResult.Results), updatedCompilation); } - internal static StaticRouteHandlerModel.Endpoint GetStaticEndpoint(GeneratorRunResult result, string stepName) + internal static StaticRouteHandlerModel.Endpoint GetStaticEndpoint(GeneratorRunResult result, string stepName) => + Assert.Single(GetStaticEndpoints(result, stepName)); + + internal static StaticRouteHandlerModel.Endpoint[] GetStaticEndpoints(GeneratorRunResult result, string stepName) { // We only invoke the generator once in our test scenarios if (result.TrackedSteps.TryGetValue(stepName, out var staticEndpointSteps)) { - var staticEndpointStep = staticEndpointSteps.Single(); - var staticEndpointOutput = staticEndpointStep.Outputs.Single(); - var (staticEndpoint, _) = staticEndpointOutput; - var endpoint = Assert.IsType(staticEndpoint); - return endpoint; + return staticEndpointSteps + .SelectMany(step => step.Outputs) + .Select(output => Assert.IsType(output.Value)) + .ToArray(); } - return null; + + return Array.Empty(); } - internal static Endpoint GetEndpointFromCompilation(Compilation compilation, bool checkSourceKey = true) + internal static Endpoint GetEndpointFromCompilation(Compilation compilation, bool expectSourceKey = true) => + Assert.Single(GetEndpointsFromCompilation(compilation, expectSourceKey)); + + internal static Endpoint[] GetEndpointsFromCompilation(Compilation compilation, bool expectSourceKey = true) { var assemblyName = compilation.AssemblyName!; var symbolsName = Path.ChangeExtension(assemblyName, "pdb"); @@ -106,6 +112,7 @@ internal static Endpoint GetEndpointFromCompilation(Compilation compilation, boo var handler = assembly.GetType("TestMapActions") ?.GetMethod("MapTestEndpoints", BindingFlags.Public | BindingFlags.Static) ?.CreateDelegate>(); + var sourceKeyType = assembly.GetType("Microsoft.AspNetCore.Builder.SourceKey"); Assert.NotNull(handler); @@ -113,17 +120,25 @@ internal static Endpoint GetEndpointFromCompilation(Compilation compilation, boo _ = handler(builder); var dataSource = Assert.Single(builder.DataSources); + // Trigger Endpoint build by calling getter. - var endpoint = Assert.Single(dataSource.Endpoints); + var endpoints = dataSource.Endpoints.ToArray(); - if (checkSourceKey) + foreach (var endpoint in endpoints) { - var sourceKeyType = assembly.GetType("Microsoft.AspNetCore.Builder.SourceKey"); - var sourceKeyMetadata = endpoint.Metadata.Single(metadata => metadata.GetType() == sourceKeyType); - Assert.NotNull(sourceKeyMetadata); + var sourceKeyMetadata = endpoint.Metadata.FirstOrDefault(metadata => metadata.GetType() == sourceKeyType); + + if (expectSourceKey) + { + Assert.NotNull(sourceKeyMetadata); + } + else + { + Assert.Null(sourceKeyMetadata); + } } - return endpoint; + return endpoints; } internal HttpContext CreateHttpContext() @@ -140,6 +155,16 @@ internal HttpContext CreateHttpContext() return httpContext; } + internal static async Task VerifyResponseBodyAsync(HttpContext httpContext, string expectedBody) + { + var httpResponse = httpContext.Response; + httpResponse.Body.Seek(0, SeekOrigin.Begin); + var streamReader = new StreamReader(httpResponse.Body); + var body = await streamReader.ReadToEndAsync(); + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.Equal(expectedBody, body); + } + private static string GetMapActionString(string sources) => $$""" #nullable enable using System; @@ -278,7 +303,7 @@ public bool TryResolveAssemblyPaths(CompilationLibrary library, List ass } } - private class EmptyServiceProvider : IServiceScope, IServiceProvider, IServiceScopeFactory + private class EmptyServiceProvider : IServiceScope, IServiceProvider, IServiceScopeFactory, IServiceProviderIsService { public IServiceProvider ServiceProvider => this; @@ -291,8 +316,18 @@ public void Dispose() { } public object GetService(Type serviceType) { + if (IsService(serviceType)) + { + return this; + } + return null; } + + public bool IsService(Type serviceType) => + serviceType == typeof(IServiceProvider) || + serviceType == typeof(IServiceScopeFactory) || + serviceType == typeof(IServiceProviderIsService); } private class DefaultEndpointRouteBuilder : IEndpointRouteBuilder diff --git a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateGeneratorTests.cs b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateGeneratorTests.cs index bc7d46e42c43..1a38c9c24f1b 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateGeneratorTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateGeneratorTests.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.AspNetCore.Http.Generators.StaticRouteHandlerModel; + namespace Microsoft.AspNetCore.Http.Generators.Tests; public class RequestDelegateGeneratorTests : RequestDelegateGeneratorTestBase @@ -20,28 +22,112 @@ public async Task MapAction_NoParam_StringReturn(string source, string httpMetho var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); var endpoint = GetEndpointFromCompilation(compilation); - var requestDelegate = endpoint.RequestDelegate; - Assert.Equal("/hello", endpointModel.Route.RoutePattern); + Assert.Equal("/hello", endpointModel.RoutePattern); Assert.Equal(httpMethod, endpointModel.HttpMethod); - var httpContext = new DefaultHttpContext(); + var httpContext = CreateHttpContext(); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, expectedBody); + } + + [Theory] + [InlineData("HttpContext")] + [InlineData("HttpRequest")] + [InlineData("HttpResponse")] + [InlineData("System.IO.Pipelines.PipeReader")] + [InlineData("System.IO.Stream")] + [InlineData("System.Security.Claims.ClaimsPrincipal")] + [InlineData("System.Threading.CancellationToken")] + public async Task MapAction_SingleSpecialTypeParam_StringReturn(string parameterType) + { + var (results, compilation) = await RunGeneratorAsync($""" +app.MapGet("/hello", ({parameterType} p) => p == null ? "null!" : "Hello world!"); +"""); + + var endpointModel = GetStaticEndpoint(results, GeneratorSteps.EndpointModelStep); + var endpoint = GetEndpointFromCompilation(compilation); + + Assert.Equal("/hello", endpointModel.RoutePattern); + Assert.Equal("MapGet", endpointModel.HttpMethod); + var p = Assert.Single(endpointModel.Parameters); + Assert.Equal(EndpointParameterSource.SpecialType, p.Source); + Assert.Equal("p", p.Name); + + var httpContext = CreateHttpContext(); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, "Hello world!"); + } - var outStream = new MemoryStream(); - httpContext.Response.Body = outStream; + [Fact] + public async Task MapAction_MultipleSpecialTypeParam_StringReturn() + { + var (results, compilation) = await RunGeneratorAsync(""" +app.MapGet("/hello", (HttpRequest req, HttpResponse res) => req is null || res is null ? "null!" : "Hello world!"); +"""); - await requestDelegate(httpContext); + var endpointModel = GetStaticEndpoint(results, GeneratorSteps.EndpointModelStep); + var endpoint = GetEndpointFromCompilation(compilation); - var httpResponse = httpContext.Response; - httpResponse.Body.Seek(0, SeekOrigin.Begin); - var streamReader = new StreamReader(httpResponse.Body); - var body = await streamReader.ReadToEndAsync(); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.Equal(expectedBody, body); + Assert.Equal("/hello", endpointModel.RoutePattern); + Assert.Equal("MapGet", endpointModel.HttpMethod); + + Assert.Collection(endpointModel.Parameters, + reqParam => + { + Assert.Equal(EndpointParameterSource.SpecialType, reqParam.Source); + Assert.Equal("req", reqParam.Name); + }, + reqParam => + { + Assert.Equal(EndpointParameterSource.SpecialType, reqParam.Source); + Assert.Equal("res", reqParam.Name); + }); + + var httpContext = CreateHttpContext(); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, "Hello world!"); } [Fact] - public async Task MapGet_NoParam_StringReturn_WithFilter() + public async Task MapGet_WithRequestDelegate_DoesNotGenerateSources() + { + var (results, compilation) = await RunGeneratorAsync(""" +app.MapGet("/hello", (HttpContext context) => Task.CompletedTask); +"""); + + Assert.Empty(GetStaticEndpoints(results, GeneratorSteps.EndpointModelStep)); + + var endpoint = GetEndpointFromCompilation(compilation, expectSourceKey: false); + + var httpContext = CreateHttpContext(); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, ""); + } + + [Fact] + public async Task MapAction_MultilineLambda() + { + var source = """ +app.MapGet("/hello", () => +{ + return "Hello world!"; +}); +"""; + var (result, compilation) = await RunGeneratorAsync(source); + + var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); + var endpoint = GetEndpointFromCompilation(compilation); + + Assert.Equal("/hello", endpointModel.RoutePattern); + + var httpContext = CreateHttpContext(); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, "Hello world!"); + } + + [Fact] + public async Task MapAction_NoParam_StringReturn_WithFilter() { var source = """ app.MapGet("/hello", () => "Hello world!") @@ -57,23 +143,12 @@ public async Task MapGet_NoParam_StringReturn_WithFilter() var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); var endpoint = GetEndpointFromCompilation(compilation); - var requestDelegate = endpoint.RequestDelegate; - - Assert.Equal("/hello", endpointModel.Route.RoutePattern); - var httpContext = new DefaultHttpContext(); + Assert.Equal("/hello", endpointModel.RoutePattern); - var outStream = new MemoryStream(); - httpContext.Response.Body = outStream; - - await requestDelegate(httpContext); - - var httpResponse = httpContext.Response; - httpResponse.Body.Seek(0, SeekOrigin.Begin); - var streamReader = new StreamReader(httpResponse.Body); - var body = await streamReader.ReadToEndAsync(); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.Equal(expectedBody, body); + var httpContext = CreateHttpContext(); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, expectedBody); } [Theory] @@ -86,24 +161,13 @@ public async Task MapAction_NoParam_AnyReturn(string source, string expectedBody var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); var endpoint = GetEndpointFromCompilation(compilation); - var requestDelegate = endpoint.RequestDelegate; - Assert.Equal("/", endpointModel.Route.RoutePattern); + Assert.Equal("/", endpointModel.RoutePattern); Assert.Equal("MapGet", endpointModel.HttpMethod); - var httpContext = new DefaultHttpContext(); - - var outStream = new MemoryStream(); - httpContext.Response.Body = outStream; - - await requestDelegate(httpContext); - - var httpResponse = httpContext.Response; - httpResponse.Body.Seek(0, SeekOrigin.Begin); - var streamReader = new StreamReader(httpResponse.Body); - var body = await streamReader.ReadToEndAsync(); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.Equal(expectedBody, body); + var httpContext = CreateHttpContext(); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, expectedBody); } [Theory] @@ -120,21 +184,13 @@ public async Task MapAction_NoParam_ComplexReturn(string source) var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); var endpoint = GetEndpointFromCompilation(compilation); - var requestDelegate = endpoint.RequestDelegate; - Assert.Equal("/", endpointModel.Route.RoutePattern); + Assert.Equal("/", endpointModel.RoutePattern); Assert.Equal("MapGet", endpointModel.HttpMethod); var httpContext = CreateHttpContext(); - - await requestDelegate(httpContext); - - var httpResponse = httpContext.Response; - httpResponse.Body.Seek(0, SeekOrigin.Begin); - var streamReader = new StreamReader(httpResponse.Body); - var body = await streamReader.ReadToEndAsync(); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.Equal(expectedBody, body); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, expectedBody); } [Theory] @@ -149,7 +205,7 @@ public async Task MapAction_ProducesCorrectContentType(string source, string exp var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); - Assert.Equal("/", endpointModel.Route.RoutePattern); + Assert.Equal("/", endpointModel.RoutePattern); Assert.Equal("MapGet", endpointModel.HttpMethod); Assert.Equal(expectedContentType, endpointModel.Response.ContentType); } @@ -164,22 +220,14 @@ public async Task MapAction_NoParam_TaskOfTReturn(string source, string expected var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); var endpoint = GetEndpointFromCompilation(compilation); - var requestDelegate = endpoint.RequestDelegate; - Assert.Equal("/", endpointModel.Route.RoutePattern); + Assert.Equal("/", endpointModel.RoutePattern); Assert.Equal("MapGet", endpointModel.HttpMethod); Assert.True(endpointModel.Response.IsAwaitable); var httpContext = CreateHttpContext(); - - await requestDelegate(httpContext); - - var httpResponse = httpContext.Response; - httpResponse.Body.Seek(0, SeekOrigin.Begin); - var streamReader = new StreamReader(httpResponse.Body); - var body = await streamReader.ReadToEndAsync(); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.Equal(expectedBody, body); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, expectedBody); } [Theory] @@ -192,22 +240,14 @@ public async Task MapAction_NoParam_ValueTaskOfTReturn(string source, string exp var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); var endpoint = GetEndpointFromCompilation(compilation); - var requestDelegate = endpoint.RequestDelegate; - Assert.Equal("/", endpointModel.Route.RoutePattern); + Assert.Equal("/", endpointModel.RoutePattern); Assert.Equal("MapGet", endpointModel.HttpMethod); Assert.True(endpointModel.Response.IsAwaitable); var httpContext = CreateHttpContext(); - - await requestDelegate(httpContext); - - var httpResponse = httpContext.Response; - httpResponse.Body.Seek(0, SeekOrigin.Begin); - var streamReader = new StreamReader(httpResponse.Body); - var body = await streamReader.ReadToEndAsync(); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.Equal(expectedBody, body); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, expectedBody); } [Theory] @@ -223,22 +263,14 @@ public async Task MapAction_NoParam_TaskLikeOfObjectReturn(string source, string var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); var endpoint = GetEndpointFromCompilation(compilation); - var requestDelegate = endpoint.RequestDelegate; - Assert.Equal("/", endpointModel.Route.RoutePattern); + Assert.Equal("/", endpointModel.RoutePattern); Assert.Equal("MapGet", endpointModel.HttpMethod); Assert.True(endpointModel.Response.IsAwaitable); var httpContext = CreateHttpContext(); - - await requestDelegate(httpContext); - - var httpResponse = httpContext.Response; - httpResponse.Body.Seek(0, SeekOrigin.Begin); - var streamReader = new StreamReader(httpResponse.Body); - var body = await streamReader.ReadToEndAsync(); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.Equal(expectedBody, body); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, expectedBody); } [Fact] @@ -255,6 +287,70 @@ public async Task Multiple_MapAction_NoParam_StringReturn() await VerifyAgainstBaselineUsingFile(compilation); } + [Fact] + public async Task Multiple_MapAction_WithParams_StringReturn() + { + var source = """ +app.MapGet("/en", (HttpRequest req) => "Hello world!"); +app.MapGet("/es", (HttpResponse res) => "Hola mundo!"); +app.MapGet("/zh", (HttpRequest req, HttpResponse res) => "你好世界!"); +"""; + var (results, compilation) = await RunGeneratorAsync(source); + + await VerifyAgainstBaselineUsingFile(compilation); + + var endpointModels = GetStaticEndpoints(results, GeneratorSteps.EndpointModelStep); + + Assert.Collection(endpointModels, + endpointModel => + { + Assert.Equal("/en", endpointModel.RoutePattern); + Assert.Equal("MapGet", endpointModel.HttpMethod); + var reqParam = Assert.Single(endpointModel.Parameters); + Assert.Equal(EndpointParameterSource.SpecialType, reqParam.Source); + Assert.Equal("req", reqParam.Name); + }, + endpointModel => + { + Assert.Equal("/es", endpointModel.RoutePattern); + Assert.Equal("MapGet", endpointModel.HttpMethod); + var reqParam = Assert.Single(endpointModel.Parameters); + Assert.Equal(EndpointParameterSource.SpecialType, reqParam.Source); + Assert.Equal("res", reqParam.Name); + }, + endpointModel => + { + Assert.Equal("/zh", endpointModel.RoutePattern); + Assert.Equal("MapGet", endpointModel.HttpMethod); + Assert.Collection(endpointModel.Parameters, + reqParam => + { + Assert.Equal(EndpointParameterSource.SpecialType, reqParam.Source); + Assert.Equal("req", reqParam.Name); + }, + reqParam => + { + Assert.Equal(EndpointParameterSource.SpecialType, reqParam.Source); + Assert.Equal("res", reqParam.Name); + }); + }); + + var endpoints = GetEndpointsFromCompilation(compilation); + + Assert.Equal(3, endpoints.Length); + var httpContext = CreateHttpContext(); + await endpoints[0].RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, "Hello world!"); + + httpContext = CreateHttpContext(); + await endpoints[1].RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, "Hola mundo!"); + + httpContext = CreateHttpContext(); + await endpoints[2].RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, "你好世界!"); + } + [Fact] public async Task MapAction_VariableRoutePattern_EmitsDiagnostic_NoSource() { @@ -271,19 +367,36 @@ public async Task MapAction_VariableRoutePattern_EmitsDiagnostic_NoSource() Assert.Empty(result.GeneratedSources); // Falls back to runtime-generated endpoint - var endpoint = GetEndpointFromCompilation(compilation, checkSourceKey: false); - var requestDelegate = endpoint.RequestDelegate; + var endpoint = GetEndpointFromCompilation(compilation, expectSourceKey: false); var httpContext = CreateHttpContext(); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, expectedBody); + } - await requestDelegate(httpContext); + [Fact] + public async Task MapAction_UnknownParameter_EmitsDiagnostic_NoSource() + { + // This will eventually be handled by the EndpointParameterSource.JsonBodyOrService. + // All parameters should theoretically be handleable with enough "Or"s in the future + // we'll remove this test and diagnostic. + var source = """ +app.MapGet("/", (IServiceProvider provider) => "Hello world!"); +"""; + var expectedBody = "Hello world!"; + var (result, compilation) = await RunGeneratorAsync(source); + + // Emits diagnostic but generates no source + var diagnostic = Assert.Single(result.Diagnostics); + Assert.Equal(DiagnosticDescriptors.GetUnableToResolveParameterDescriptor("provider").Id, diagnostic.Id); + Assert.Empty(result.GeneratedSources); + + // Falls back to runtime-generated endpoint + var endpoint = GetEndpointFromCompilation(compilation, expectSourceKey: false); - var httpResponse = httpContext.Response; - httpResponse.Body.Seek(0, SeekOrigin.Begin); - var streamReader = new StreamReader(httpResponse.Body); - var body = await streamReader.ReadToEndAsync(); - Assert.Equal(200, httpContext.Response.StatusCode); - Assert.Equal(expectedBody, body); + var httpContext = CreateHttpContext(); + await endpoint.RequestDelegate(httpContext); + await VerifyResponseBodyAsync(httpContext, expectedBody); } [Fact] @@ -293,10 +406,9 @@ public async Task MapAction_RequestDelegateHandler_DoesNotEmit() app.MapGet("/", (HttpContext context) => context.Response.WriteAsync("Hello world")); """; var (result, _) = await RunGeneratorAsync(source); - var endpointModel = GetStaticEndpoint(result, GeneratorSteps.EndpointModelStep); + var endpointModels = GetStaticEndpoints(result, GeneratorSteps.EndpointModelStep); - // Endpoint model is null because we don't pass transform - Assert.Null(endpointModel); Assert.Empty(result.GeneratedSources); + Assert.Empty(endpointModels); } } diff --git a/src/Shared/RoslynUtils/WellKnownTypes.cs b/src/Shared/RoslynUtils/WellKnownTypes.cs index 3d8c22d17b10..27e44a583cbc 100644 --- a/src/Shared/RoslynUtils/WellKnownTypes.cs +++ b/src/Shared/RoslynUtils/WellKnownTypes.cs @@ -116,8 +116,13 @@ public bool Implements(ITypeSymbol type, WellKnownTypeData.WellKnownType[] inter return false; } - public static bool Implements(ITypeSymbol type, ITypeSymbol interfaceType) + public static bool Implements(ITypeSymbol? type, ITypeSymbol interfaceType) { + if (type is null) + { + return false; + } + foreach (var t in type.AllInterfaces) { if (SymbolEqualityComparer.Default.Equals(t, interfaceType))