From 22f7b3db65dbc099b9df3bac23cc116de0af1370 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Tue, 31 Dec 2024 14:43:53 -0800 Subject: [PATCH] Add initial validations generator for minimal APIs --- AspNetCore.sln | 57 ++- eng/Dependencies.props | 1 + eng/Versions.props | 1 + .../SupportFiles/Directory.Build.targets | 2 +- .../src/PublicAPI.Unshipped.txt | 1 + src/DefaultBuilder/src/WebApplication.cs | 23 +- .../src/WebApplicationBuilder.cs | 2 +- .../DiagnosticDescriptors.cs | 0 .../GeneratorSteps.cs | 0 ...tCore.Http.RequestDelegateGenerator.csproj | 0 .../RequestDelegateGenerator.cs | 0 .../RequestDelegateGeneratorSources.cs | 0 .../RequestDelegateGeneratorSuppressor.cs | 0 .../Resources.resx | 0 .../Emitters/DiagnosticEmitter.cs | 0 .../Emitters/EmitterConstants.cs | 0 .../Emitters/EmitterContext.cs | 0 .../Emitters/EmitterExtensions.cs | 0 .../Emitters/EndpointEmitter.cs | 0 .../EndpointJsonPreparationEmitter.cs | 0 .../Emitters/EndpointParameterEmitter.cs | 0 .../StaticRouteHandlerModel/Endpoint.cs | 0 .../EndpointDelegateComparer.cs | 0 .../EndpointHttpMethodComparer.cs | 0 .../EndpointParameter.cs | 0 .../EndpointParameterSource.cs | 0 .../EndpointResponse.cs | 0 .../InvocationOperationExtensions.cs | 0 .../Model/ConstructorParameter.cs | 0 .../Model/EndpointParameterExtensions.cs | 0 .../Model/ParameterLookupKey.cs | 0 .../StaticRouteHandlerModel.Emitter.cs | 0 .../Emitters/ValidationsGenerator.Emitter.cs | 49 ++ ...ValidationsGenerator.EndpointKeyEmitter.cs | 53 +++ ...idationsGenerator.TypeValidationEmitter.cs | 142 ++++++ ...nsGenerator.ValidationAttributeEmitters.cs | 52 ++ ...nerator.ValidationProblemBuilderEmitter.cs | 110 +++++ ...tionsGenerator.ValidationsFilterEmitter.cs | 101 ++++ ...idationsGenerator.WithValidationEmitter.cs | 47 ++ .../IInvocationOperationExtensions.cs | 86 ++++ .../Extensions/ISymbolExtensions.cs | 32 ++ .../Extensions/ITypeSymbolExtensions.cs | 99 ++++ .../IncrementalValuesProviderExtensions.cs | 85 ++++ ...spNetCore.Http.ValidationsGenerator.csproj | 33 ++ .../Models/RequiredSymbols.cs | 14 + .../Models/ValidatableEndpoint.cs | 12 + .../Models/ValidatableMember.cs | 15 + .../Models/ValidatableParameter.cs | 19 + .../Models/ValidatableType.cs | 16 + .../Models/ValidatableTypeComparer.cs | 29 ++ .../Models/ValidationAttribute.cs | 14 + .../ValidationsGenerator.EndpointsParser.cs | 65 +++ ...idationsGenerator.RequiredSymbolsParser.cs | 21 + .../ValidationsGenerator.TypesParser.cs | 165 +++++++ .../ValidationsGenerator.cs | 50 ++ ...ft.AspNetCore.Http.Extensions.Tests.csproj | 8 +- .../RequestDelegateCreationTests.cs | 2 +- .../ValidationsGenerator/ModuleInitializer.cs | 11 + .../ValidationsGeneratorTests.ComplexType.cs | 292 ++++++++++++ ...ationsGeneratorTests.IValidatableObject.cs | 189 ++++++++ .../ValidationsGeneratorTests.Interception.cs | 90 ++++ .../ValidationsGeneratorTests.Parameters.cs | 82 ++++ .../ValidationsGeneratorTests.Polymorphism.cs | 207 ++++++++ .../ValidationsGeneratorTests.Recursion.cs | 121 +++++ .../ValidationsGeneratorTestsBase.cs | 111 +++++ ...ypes#RouteHandlerValidations.g.verified.cs | 380 +++++++++++++++ ...ject#RouteHandlerValidations.g.verified.cs | 332 +++++++++++++ ...oads#RouteHandlerValidations.g.verified.cs | 278 +++++++++++ ...ters#RouteHandlerValidations.g.verified.cs | 240 ++++++++++ ...ypes#RouteHandlerValidations.g.verified.cs | 446 ++++++++++++++++++ ...ypes#RouteHandlerValidations.g.verified.cs | 242 ++++++++++ ...ypes#RouteHandlerValidations.g.verified.cs | 368 +++++++++++++++ ...oft.AspNetCore.Http.Microbenchmarks.csproj | 2 +- src/Http/HttpAbstractions.slnf | 5 +- .../ValidationConventionBuilderExtensions.cs | 21 + src/Http/Routing/src/PublicAPI.Unshipped.txt | 2 + .../MinimalSample/MinimalSample.csproj | 2 +- 77 files changed, 4796 insertions(+), 31 deletions(-) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/DiagnosticDescriptors.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/GeneratorSteps.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/RequestDelegateGenerator.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/RequestDelegateGeneratorSources.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/RequestDelegateGeneratorSuppressor.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/Resources.resx (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Emitters/DiagnosticEmitter.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Emitters/EmitterConstants.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Emitters/EmitterContext.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Emitters/EmitterExtensions.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Emitters/EndpointEmitter.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Emitters/EndpointJsonPreparationEmitter.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Emitters/EndpointParameterEmitter.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Endpoint.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/EndpointDelegateComparer.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/EndpointHttpMethodComparer.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/EndpointParameter.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/EndpointParameterSource.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/EndpointResponse.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/InvocationOperationExtensions.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Model/ConstructorParameter.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Model/EndpointParameterExtensions.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/Model/ParameterLookupKey.cs (100%) rename src/Http/Http.Extensions/gen/{ => RequestDelegateGenerator}/StaticRouteHandlerModel/StaticRouteHandlerModel.Emitter.cs (100%) create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.EndpointKeyEmitter.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.TypeValidationEmitter.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationAttributeEmitters.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationProblemBuilderEmitter.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationsFilterEmitter.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.WithValidationEmitter.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/IInvocationOperationExtensions.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/ISymbolExtensions.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Microsoft.AspNetCore.Http.ValidationsGenerator.csproj create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Models/RequiredSymbols.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableEndpoint.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableMember.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableParameter.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableType.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableTypeComparer.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidationAttribute.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs create mode 100644 src/Http/Http.Extensions/gen/ValidationsGenerator/ValidationsGenerator.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/ModuleInitializer.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.ComplexType.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.IValidatableObject.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Interception.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Parameters.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Polymorphism.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Recursion.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestsBase.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#RouteHandlerValidations.g.verified.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#RouteHandlerValidations.g.verified.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateOnAllMapOverloads#RouteHandlerValidations.g.verified.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#RouteHandlerValidations.g.verified.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#RouteHandlerValidations.g.verified.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#RouteHandlerValidations.g.verified.cs create mode 100644 src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTypeTests.CanValidateComplexTypes#RouteHandlerValidations.g.verified.cs create mode 100644 src/Http/Routing/src/Builder/ValidationConventionBuilderExtensions.cs diff --git a/AspNetCore.sln b/AspNetCore.sln index 1805c5617512..58d4cbec675d 100644 --- a/AspNetCore.sln +++ b/AspNetCore.sln @@ -1732,7 +1732,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Server EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Server.Kestrel.Transport.NamedPipes.Tests", "src\Servers\Kestrel\Transport.NamedPipes\test\Microsoft.AspNetCore.Server.Kestrel.Transport.NamedPipes.Tests.csproj", "{97C7D2A4-87E5-4A4A-A170-D736427D5C21}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Http.RequestDelegateGenerator", "src\Http\Http.Extensions\gen\Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj", "{4730F56D-24EF-4BB2-AA75-862E31205F3A}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Http.RequestDelegateGenerator", "src\Http\Http.Extensions\gen\RequestDelegateGenerator\Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj", "{4730F56D-24EF-4BB2-AA75-862E31205F3A}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "QuickGrid", "QuickGrid", "{C406D9E0-1585-43F9-AA8F-D468AF84A996}" EndProject @@ -1812,6 +1812,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Assets", "Assets", "{2B858B EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.App.Internal.Assets", "src\Assets\Microsoft.AspNetCore.App.Internal.Assets.csproj", "{2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.Http.ValidationsGenerator", "src\Http\Http.Extensions\gen\ValidationsGenerator\Microsoft.AspNetCore.Http.ValidationsGenerator.csproj", "{185D74AB-76CE-44C1-86EB-16E93C84033F}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -10943,22 +10945,38 @@ Global {C3928C15-1836-46DB-A09D-9EFBCCA33E08}.Release|x64.Build.0 = Release|Any CPU {C3928C15-1836-46DB-A09D-9EFBCCA33E08}.Release|x86.ActiveCfg = Release|Any CPU {C3928C15-1836-46DB-A09D-9EFBCCA33E08}.Release|x86.Build.0 = Release|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Debug|Any CPU.Build.0 = Debug|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Debug|arm64.ActiveCfg = Debug|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Debug|arm64.Build.0 = Debug|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Debug|x64.ActiveCfg = Debug|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Debug|x64.Build.0 = Debug|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Debug|x86.ActiveCfg = Debug|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Debug|x86.Build.0 = Debug|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Release|Any CPU.ActiveCfg = Release|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Release|Any CPU.Build.0 = Release|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Release|arm64.ActiveCfg = Release|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Release|arm64.Build.0 = Release|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Release|x64.ActiveCfg = Release|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Release|x64.Build.0 = Release|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Release|x86.ActiveCfg = Release|Any CPU - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62}.Release|x86.Build.0 = Release|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Debug|arm64.ActiveCfg = Debug|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Debug|arm64.Build.0 = Debug|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Debug|x64.ActiveCfg = Debug|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Debug|x64.Build.0 = Debug|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Debug|x86.ActiveCfg = Debug|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Debug|x86.Build.0 = Debug|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Release|Any CPU.ActiveCfg = Release|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Release|Any CPU.Build.0 = Release|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Release|arm64.ActiveCfg = Release|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Release|arm64.Build.0 = Release|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Release|x64.ActiveCfg = Release|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Release|x64.Build.0 = Release|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Release|x86.ActiveCfg = Release|Any CPU + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11}.Release|x86.Build.0 = Release|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Debug|arm64.ActiveCfg = Debug|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Debug|arm64.Build.0 = Debug|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Debug|x64.ActiveCfg = Debug|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Debug|x64.Build.0 = Debug|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Debug|x86.ActiveCfg = Debug|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Debug|x86.Build.0 = Debug|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Release|Any CPU.Build.0 = Release|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Release|arm64.ActiveCfg = Release|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Release|arm64.Build.0 = Release|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Release|x64.ActiveCfg = Release|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Release|x64.Build.0 = Release|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Release|x86.ActiveCfg = Release|Any CPU + {185D74AB-76CE-44C1-86EB-16E93C84033F}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -11853,8 +11871,9 @@ Global {B32FF7A7-9CB3-4DCD-AE97-3B2594DB9DAC} = {2299CCD8-8F9C-4F2B-A633-9BF4DA81022B} {B9BBC1A8-7F58-4F43-94C3-5F3CB125CEF7} = {B32FF7A7-9CB3-4DCD-AE97-3B2594DB9DAC} {C3928C15-1836-46DB-A09D-9EFBCCA33E08} = {B5D98AEB-9409-4280-8225-9C1EC6A791B2} - {2B858B82-5F0B-4A24-B3C0-5E99149F70D6} = {017429CC-C5FB-48B4-9C46-034E29EE2F06} - {2AAE7819-BC3E-48F4-9CFA-5DD4CD5FFD62} = {2B858B82-5F0B-4A24-B3C0-5E99149F70D6} + {18A3AF88-D633-44E6-9407-3B2C708F7F64} = {225AEDCF-7162-4A86-AC74-06B84660B379} + {FB45AD76-D348-4F96-A8EE-71F61DC6DA11} = {18A3AF88-D633-44E6-9407-3B2C708F7F64} + {185D74AB-76CE-44C1-86EB-16E93C84033F} = {18A3AF88-D633-44E6-9407-3B2C708F7F64} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {3E8720B3-DBDD-498C-B383-2CC32A054E8F} diff --git a/eng/Dependencies.props b/eng/Dependencies.props index 26dc86e0e31c..d22eeeb2ec28 100644 --- a/eng/Dependencies.props +++ b/eng/Dependencies.props @@ -224,6 +224,7 @@ and are generated based on the last package release. + diff --git a/eng/Versions.props b/eng/Versions.props index fa900f196f02..b9cc1fc53df4 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -327,6 +327,7 @@ 5.0.0 6.6.2 19.14.0 + 2.2.0 2.0.3 1.15.0 2.9.2 diff --git a/eng/testing/linker/SupportFiles/Directory.Build.targets b/eng/testing/linker/SupportFiles/Directory.Build.targets index eb82d94bec27..f03257488b40 100644 --- a/eng/testing/linker/SupportFiles/Directory.Build.targets +++ b/eng/testing/linker/SupportFiles/Directory.Build.targets @@ -76,7 +76,7 @@ - + diff --git a/src/DefaultBuilder/src/PublicAPI.Unshipped.txt b/src/DefaultBuilder/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..3316f2dbcdd0 100644 --- a/src/DefaultBuilder/src/PublicAPI.Unshipped.txt +++ b/src/DefaultBuilder/src/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ #nullable enable +Microsoft.AspNetCore.Builder.WebApplication.Conventions.get -> Microsoft.AspNetCore.Builder.IEndpointConventionBuilder! diff --git a/src/DefaultBuilder/src/WebApplication.cs b/src/DefaultBuilder/src/WebApplication.cs index ed700260a0e4..b5d10eb637be 100644 --- a/src/DefaultBuilder/src/WebApplication.cs +++ b/src/DefaultBuilder/src/WebApplication.cs @@ -27,15 +27,18 @@ public sealed class WebApplication : IHost, IApplicationBuilder, IEndpointRouteB internal const string GlobalEndpointRouteBuilderKey = "__GlobalEndpointRouteBuilder"; private readonly IHost _host; - private readonly List _dataSources = new(); + private readonly GlobalEndpointRouteBuilder _innerBuilder; + private readonly RouteGroupBuilder _globalRouteGroup; internal WebApplication(IHost host) { _host = host; + _innerBuilder = new(this); + _globalRouteGroup = _innerBuilder.MapGroup(""); ApplicationBuilder = new ApplicationBuilder(host.Services, ServerFeatures); Logger = host.Services.GetRequiredService().CreateLogger(Environment.ApplicationName ?? nameof(WebApplication)); - Properties[GlobalEndpointRouteBuilderKey] = this; + Properties[GlobalEndpointRouteBuilderKey] = _innerBuilder; } /// @@ -80,9 +83,14 @@ IServiceProvider IApplicationBuilder.ApplicationServices internal IDictionary Properties => ApplicationBuilder.Properties; IDictionary IApplicationBuilder.Properties => Properties; - internal ICollection DataSources => _dataSources; + internal ICollection DataSources => ((IEndpointRouteBuilder)_globalRouteGroup).DataSources; ICollection IEndpointRouteBuilder.DataSources => DataSources; + /// + /// Gets the for the application. + /// + public IEndpointConventionBuilder Conventions => _globalRouteGroup; + internal ApplicationBuilder ApplicationBuilder { get; } IServiceProvider IEndpointRouteBuilder.ServiceProvider => Services; @@ -307,4 +315,13 @@ public IList? Middleware } } } + + private class GlobalEndpointRouteBuilder(WebApplication application) : IEndpointRouteBuilder + { + public IServiceProvider ServiceProvider => application.Services; + + public ICollection DataSources { get; } = []; + + public IApplicationBuilder CreateApplicationBuilder() => application; + } } diff --git a/src/DefaultBuilder/src/WebApplicationBuilder.cs b/src/DefaultBuilder/src/WebApplicationBuilder.cs index 2cff4ae4ffb9..e2c79345c258 100644 --- a/src/DefaultBuilder/src/WebApplicationBuilder.cs +++ b/src/DefaultBuilder/src/WebApplicationBuilder.cs @@ -408,7 +408,7 @@ private void ConfigureApplication(WebHostBuilderContext context, IApplicationBui // destination.UseEndpoints() // Set the route builder so that UseRouting will use the WebApplication as the IEndpointRouteBuilder for route matching - app.Properties.Add(WebApplication.GlobalEndpointRouteBuilderKey, _builtApplication); + app.Properties.Add(WebApplication.GlobalEndpointRouteBuilderKey, _builtApplication.Properties[WebApplication.GlobalEndpointRouteBuilderKey]); // Only call UseRouting() if there are endpoints configured and UseRouting() wasn't called on the global route builder already if (_builtApplication.DataSources.Count > 0) diff --git a/src/Http/Http.Extensions/gen/DiagnosticDescriptors.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/DiagnosticDescriptors.cs similarity index 100% rename from src/Http/Http.Extensions/gen/DiagnosticDescriptors.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/DiagnosticDescriptors.cs diff --git a/src/Http/Http.Extensions/gen/GeneratorSteps.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/GeneratorSteps.cs similarity index 100% rename from src/Http/Http.Extensions/gen/GeneratorSteps.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/GeneratorSteps.cs diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj similarity index 100% rename from src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj diff --git a/src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/RequestDelegateGenerator.cs similarity index 100% rename from src/Http/Http.Extensions/gen/RequestDelegateGenerator.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/RequestDelegateGenerator.cs diff --git a/src/Http/Http.Extensions/gen/RequestDelegateGeneratorSources.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/RequestDelegateGeneratorSources.cs similarity index 100% rename from src/Http/Http.Extensions/gen/RequestDelegateGeneratorSources.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/RequestDelegateGeneratorSources.cs diff --git a/src/Http/Http.Extensions/gen/RequestDelegateGeneratorSuppressor.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/RequestDelegateGeneratorSuppressor.cs similarity index 100% rename from src/Http/Http.Extensions/gen/RequestDelegateGeneratorSuppressor.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/RequestDelegateGeneratorSuppressor.cs diff --git a/src/Http/Http.Extensions/gen/Resources.resx b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/Resources.resx similarity index 100% rename from src/Http/Http.Extensions/gen/Resources.resx rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/Resources.resx diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/DiagnosticEmitter.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/DiagnosticEmitter.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/DiagnosticEmitter.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/DiagnosticEmitter.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EmitterConstants.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EmitterConstants.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EmitterConstants.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EmitterConstants.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EmitterContext.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EmitterContext.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EmitterContext.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EmitterContext.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EmitterExtensions.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EmitterExtensions.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EmitterExtensions.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EmitterExtensions.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EndpointEmitter.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EndpointEmitter.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EndpointEmitter.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EndpointEmitter.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EndpointJsonPreparationEmitter.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EndpointJsonPreparationEmitter.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EndpointJsonPreparationEmitter.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EndpointJsonPreparationEmitter.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EndpointParameterEmitter.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EndpointParameterEmitter.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Emitters/EndpointParameterEmitter.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Emitters/EndpointParameterEmitter.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Endpoint.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Endpoint.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Endpoint.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointDelegateComparer.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointDelegateComparer.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointDelegateComparer.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointDelegateComparer.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointHttpMethodComparer.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointHttpMethodComparer.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointHttpMethodComparer.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointHttpMethodComparer.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameter.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameter.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameterSource.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameterSource.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointParameterSource.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameterSource.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointResponse.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/EndpointResponse.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/EndpointResponse.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/InvocationOperationExtensions.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/InvocationOperationExtensions.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/InvocationOperationExtensions.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Model/ConstructorParameter.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Model/ConstructorParameter.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Model/ConstructorParameter.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Model/ConstructorParameter.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Model/EndpointParameterExtensions.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Model/EndpointParameterExtensions.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Model/EndpointParameterExtensions.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Model/EndpointParameterExtensions.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Model/ParameterLookupKey.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Model/ParameterLookupKey.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/Model/ParameterLookupKey.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/Model/ParameterLookupKey.cs diff --git a/src/Http/Http.Extensions/gen/StaticRouteHandlerModel/StaticRouteHandlerModel.Emitter.cs b/src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/StaticRouteHandlerModel.Emitter.cs similarity index 100% rename from src/Http/Http.Extensions/gen/StaticRouteHandlerModel/StaticRouteHandlerModel.Emitter.cs rename to src/Http/Http.Extensions/gen/RequestDelegateGenerator/StaticRouteHandlerModel/StaticRouteHandlerModel.Emitter.cs diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs new file mode 100644 index 000000000000..a24c9fb42aba --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.IO; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator +{ + internal static void EmitValidationsFile(SourceProductionContext context, ((string Left, string Right) Left, ImmutableArray Right) source) + { + var withValidations = source.Left.Left; + var typeValidations = source.Left.Right; + var validationsFilters = source.Right; + var writer = new StringWriter(); + var output = new CodeWriter(writer, baseIndent: 0); + output.WriteLine("// "); + output.WriteLine("#nullable enable"); + output.WriteLine("namespace System.Runtime.CompilerServices"); + output.StartBlock(); + output.WriteLine("[AttributeUsage(System.AttributeTargets.Method, AllowMultiple = true)]"); + output.WriteLine("file sealed class InterceptsLocationAttribute : Attribute"); + output.StartBlock(); + output.WriteLine("public InterceptsLocationAttribute(int version, string data) { }"); + output.EndBlock(); + output.EndBlock(); + output.WriteLine(); + output.WriteLine("namespace Microsoft.AspNetCore.Http.Validations.Generated"); + output.StartBlock(); + output.WriteLine("using System;"); + output.WriteLine("using System.Linq;"); + output.WriteLine("using System.Diagnostics;"); + output.WriteLine("using System.ComponentModel.DataAnnotations;"); + output.WriteLine(); + output.Indent--; + output.WriteLine(EmitEndpointKey()); + output.WriteLine(EmitValidationProblemBuilder()); + output.WriteLine(withValidations); + output.WriteLine(); + output.WriteLine(typeValidations); + output.WriteLine(); + output.Write(EmitEndpointValidationFilters(validationsFilters)); + output.WriteLine("}"); + output.WriteLine("#nullable restore"); + context.AddSource("RouteHandlerValidations.g.cs", writer.ToString()); + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.EndpointKeyEmitter.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.EndpointKeyEmitter.cs new file mode 100644 index 000000000000..b07f4540cf5d --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.EndpointKeyEmitter.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator +{ + internal static string EmitEndpointKey() + { + var writer = new StringWriter(); + var code = new CodeWriter(writer, baseIndent: 1); + code.WriteLine("file class EndpointKey(string route, global::System.Collections.Generic.IEnumerable methods)"); + code.StartBlock(); + code.WriteLine("public string Route { get; } = route;"); + code.WriteLine("public global::System.Collections.Generic.IEnumerable Methods { get; } = methods;"); + code.WriteLine(); + code.WriteLine("public override bool Equals(object? obj)"); + code.StartBlock(); + code.WriteLine("if (obj is EndpointKey other)"); + code.StartBlock(); + code.WriteLine("return string.Equals(Route, other.Route, global::System.StringComparison.OrdinalIgnoreCase) &&"); + code.WriteLine("Methods.SequenceEqual(other.Methods, global::System.StringComparer.OrdinalIgnoreCase);"); + code.EndBlock(); + code.WriteLine("return false;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public override int GetHashCode()"); + code.StartBlock(); + code.WriteLine("int hash = 17;"); + code.WriteLine("hash = hash * 23 + (Route?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0);"); + code.WriteLine("hash = hash * 23 + GetMethodsHashCode(Methods);"); + code.WriteLine("return hash;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("private static int GetMethodsHashCode(global::System.Collections.Generic.IEnumerable methods)"); + code.StartBlock(); + code.WriteLine("if (methods == null)"); + code.StartBlock(); + code.WriteLine("return 0;"); + code.EndBlock(); + code.WriteLine("int hash = 17;"); + code.WriteLine("foreach (var method in methods)"); + code.StartBlock(); + code.WriteLine("hash = hash * 23 + (method?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0);"); + code.EndBlock(); + code.WriteLine("return hash;"); + code.EndBlock(); + code.EndBlock(); + return writer.ToString(); + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.TypeValidationEmitter.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.TypeValidationEmitter.cs new file mode 100644 index 000000000000..23148d3e176c --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.TypeValidationEmitter.cs @@ -0,0 +1,142 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.IO; +using System.Threading; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator +{ + internal static string EmitTypeValidations(ImmutableArray validatableTypes, CancellationToken cancellationToken) + { + var writer = new StringWriter(); + var code = new CodeWriter(writer, baseIndent: 1); + code.WriteLine("file static class ValidationTypes"); + code.StartBlock(); + code.WriteLine("public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(T value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) => null;"); + if (validatableTypes.Length == 0) + { + code.EndBlock(); + return writer.ToString(); + } + code.WriteLine(); + foreach (var type in validatableTypes) + { + var needsValidationResult = true; + foreach (var member in type.Members) + { + foreach (var attribute in member.Attributes) + { + code.WriteLine(EmitValidationAttribute(attribute)); + } + } + code.WriteLine(); + code.WriteLine($"public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate({type.Name}? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false)"); + code.StartBlock(); + code.WriteLine("var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth;"); + code.WriteLine("if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth))"); + code.StartBlock(); + code.WriteLine("return null;"); + code.EndBlock(); + code.WriteLine("ValidationProblemBuilder resultBuilder = new();"); + code.WriteLine($"if (value != null)"); + code.StartBlock(); + foreach (var subTypeName in type.ValidatableSubTypeNames) + { + code.WriteLine($"if (value is {subTypeName} subType{subTypeName})"); + code.StartBlock(); + code.WriteLine($"var subType{subTypeName}ValidationResult = ValidationTypes.Validate(({subTypeName})subType{subTypeName}, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse);"); + code.WriteLine($"if (subType{subTypeName}ValidationResult is not null)"); + code.StartBlock(); + code.WriteLine($"foreach (var error in subType{subTypeName}ValidationResult.Errors)"); + code.StartBlock(); + code.WriteLine("resultBuilder.WithErrors(error.Key, error.Value);"); + code.EndBlock(); + code.EndBlock(); + code.EndBlock(); + } + foreach (var derivedTypeName in type.ValidatableDerivedTypeNames) + { + code.WriteLine($"if (!skipRecurse && value is {derivedTypeName} derivedType{derivedTypeName})"); + code.StartBlock(); + code.WriteLine($"var derivedType{derivedTypeName}ValidationResult = ValidationTypes.Validate(derivedType{derivedTypeName}, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse: true);"); + code.WriteLine($"if (derivedType{derivedTypeName}ValidationResult is not null)"); + code.StartBlock(); + code.WriteLine($"foreach (var error in derivedType{derivedTypeName}ValidationResult.Errors)"); + code.StartBlock(); + code.WriteLine("resultBuilder.WithErrors(error.Key, error.Value);"); + code.EndBlock(); + code.EndBlock(); + code.WriteLine("return resultBuilder.HasValue() ? resultBuilder.Build() : null;"); + code.EndBlock(); + } + code.WriteLine("validationContext ??= new(value, serviceProvider: serviceProvider, items: null);"); + foreach (var member in type.Members) + { + code.WriteLine($@"validationContext.DisplayName = ""{member.DisplayName}"";"); + code.WriteLine($@"validationContext.MemberName = ""{member.Name}"";"); + if (member.IsEnumerable && member.HasValidatableType) + { + code.WriteLine($"var {member.Name}Index = 0;"); + code.WriteLine($"foreach (var item in value?.{member.Name} ?? [])"); + code.StartBlock(); + code.WriteLine($"var itemValidationResult = ValidationTypes.Validate(item, validationContext, serviceProvider, currentDepth, skipRecurse);"); + code.WriteLine($"if (itemValidationResult is not null)"); + code.StartBlock(); + code.WriteLine($"foreach (var error in itemValidationResult.Errors)"); + code.StartBlock(); + code.WriteLine($@"resultBuilder.WithErrors($""{member.Name}[{{ {member.Name}Index }}].{{error.Key}}"", error.Value);"); + code.EndBlock(); + code.WriteLine($"{member.Name}Index++;"); + code.EndBlock(); + code.EndBlock(); + } + else if (member.HasValidatableType) + { + code.WriteLine($"var type{member.Name}ValidationResult = ValidationTypes.Validate(value?.{member.Name}, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse);"); + code.WriteLine($"if (type{member.Name}ValidationResult is not null)"); + code.StartBlock(); + code.WriteLine($"foreach (var error in type{member.Name}ValidationResult.Errors)"); + code.StartBlock(); + code.WriteLine($@"resultBuilder.WithErrors($""{member.Name}.{{error.Key}}"", error.Value);"); + code.EndBlock(); + code.EndBlock(); + } + foreach (var attribute in member.Attributes) + { + if (needsValidationResult) + { + code.WriteLine("global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult;"); + needsValidationResult = false; + } + code.WriteLine($@"validationResult = {attribute.Name}.GetValidationResult(value?.{member.Name}, validationContext);"); + code.WriteLine("if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null })"); + code.StartBlock(); + code.WriteLine($@"resultBuilder.WithError(""{member.Name}"", validationResult.ErrorMessage);"); + code.EndBlock(); + } + } + if (type.IsIValidatableObject) + { + code.WriteLine("if (!resultBuilder.HasValue())"); + code.StartBlock(); + code.WriteLine("validationContext = new(value, serviceProvider: serviceProvider, items: null);"); + code.WriteLine("foreach (var validatableValidationResult in value.Validate(validationContext))"); + code.StartBlock(); + code.WriteLine("if (validatableValidationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null })"); + code.StartBlock(); + code.WriteLine($@"resultBuilder.WithError(validatableValidationResult.MemberNames.FirstOrDefault() ?? ""{type.Name}"", validatableValidationResult.ErrorMessage);"); + code.EndBlock(); + code.EndBlock(); + code.EndBlock(); + } + code.EndBlock(); + code.WriteLine("return resultBuilder.HasValue() ? resultBuilder.Build() : null;"); + code.EndBlock(); + } + code.EndBlock(); + return writer.ToString(); + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationAttributeEmitters.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationAttributeEmitters.cs new file mode 100644 index 000000000000..a10a658ccded --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationAttributeEmitters.cs @@ -0,0 +1,52 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public partial class ValidationsGenerator +{ + internal static string EmitValidationAttribute(ValidationAttribute attribute) + { + var builder = new StringBuilder(); + if (attribute.ForParameter) + { + builder.Append("var "); + } + else + { + builder.Append("private static readonly ValidationAttribute "); + } + builder.Append(attribute.Name); + builder.Append(' '); + builder.Append("= "); + builder.Append("new "); + builder.Append(attribute.ClassName); + builder.Append('('); + for (var i = 0; i < attribute.Arguments.Count; i++) + { + builder.Append(attribute.Arguments[i]); + if (i < attribute.Arguments.Count - 1) + { + builder.Append(", "); + } + } + if (attribute.NamedArguments.Count > 0) + { + builder.Append(") { "); + foreach (var kvp in attribute.NamedArguments) + { + builder.Append(kvp.Key); + builder.Append(" = "); + builder.Append(kvp.Value); + } + builder.Append(" };"); + } + else + { + builder.Append(");"); + } + return builder.ToString(); + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationProblemBuilderEmitter.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationProblemBuilderEmitter.cs new file mode 100644 index 000000000000..9f2053b7fa5b --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationProblemBuilderEmitter.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator +{ + internal static string EmitValidationProblemBuilder() + { + var writer = new StringWriter(); + var code = new CodeWriter(writer, baseIndent: 1); + code.WriteLine("file class ValidationProblemBuilder"); + code.StartBlock(); + code.WriteLine("private readonly global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails _problemDetails;"); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder()"); + code.StartBlock(); + code.WriteLine("_problemDetails = new global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails();"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder WithTitle(string title)"); + code.StartBlock(); + code.WriteLine("_problemDetails.Title = title;"); + code.WriteLine("return this;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder WithStatus(int? status)"); + code.StartBlock(); + code.WriteLine("_problemDetails.Status = status;"); + code.WriteLine("return this;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder WithDetail(string detail)"); + code.StartBlock(); + code.WriteLine("_problemDetails.Detail = detail;"); + code.WriteLine("return this;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder WithInstance(string instance)"); + code.StartBlock(); + code.WriteLine("_problemDetails.Instance = instance;"); + code.WriteLine("return this;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder WithType(string type)"); + code.StartBlock(); + code.WriteLine("_problemDetails.Type = type;"); + code.WriteLine("return this;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder WithExtensions(global::System.Collections.Generic.IDictionary extensions)"); + code.StartBlock(); + code.WriteLine("foreach (var kvp in extensions)"); + code.StartBlock(); + code.WriteLine("_problemDetails.Extensions[kvp.Key] = kvp.Value;"); + code.EndBlock(); + code.WriteLine("return this;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder WithErrors(global::System.Collections.Generic.IDictionary errors)"); + code.StartBlock(); + code.WriteLine("foreach (var kvp in errors)"); + code.StartBlock(); + code.WriteLine("_problemDetails.Errors[kvp.Key] = kvp.Value;"); + code.EndBlock(); + code.WriteLine("return this;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder WithError(string key, string error)"); + code.StartBlock(); + code.WriteLine("if (_problemDetails.Errors.ContainsKey(key))"); + code.StartBlock(); + code.WriteLine("_problemDetails.Errors[key] = _problemDetails.Errors[key].Append(error).ToArray();"); + code.EndBlock(); + code.WriteLine("else"); + code.StartBlock(); + code.WriteLine("_problemDetails.Errors[key] = new string[] { error };"); + code.EndBlock(); + code.WriteLine("return this;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public ValidationProblemBuilder WithErrors(string key, string[] errors)"); + code.StartBlock(); + code.WriteLine("if (_problemDetails.Errors.ContainsKey(key))"); + code.StartBlock(); + code.WriteLine("_problemDetails.Errors[key] = _problemDetails.Errors[key].Concat(errors).ToArray();"); + code.EndBlock(); + code.WriteLine("else"); + code.StartBlock(); + code.WriteLine("_problemDetails.Errors[key] = errors;"); + code.EndBlock(); + code.WriteLine("return this;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails Build()"); + code.StartBlock(); + code.WriteLine("return _problemDetails;"); + code.EndBlock(); + code.WriteLine(); + code.WriteLine("public bool HasValue()"); + code.StartBlock(); + code.WriteLine("return _problemDetails.Errors.Count > 0;"); + code.EndBlock(); + code.EndBlock(); + + return writer.ToString(); + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationsFilterEmitter.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationsFilterEmitter.cs new file mode 100644 index 000000000000..e8794e62e72a --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.ValidationsFilterEmitter.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.IO; +using System.Linq; +using System.Threading; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator +{ + internal static string EmitEndpointValidationFilters(ImmutableArray filterDeclarations) + { + var writer = new StringWriter(); + var code = new CodeWriter(writer, baseIndent: 1); + code.WriteLine("file static class ValidationsFilters"); + code.StartBlock(); + code.WriteLine("public static readonly global::System.Collections.Generic.Dictionary> Filters = new()"); + code.WriteLine("{"); + code.Indent--; + foreach (var filter in filterDeclarations) + { + code.WriteLine(filter); + } + code.Indent++; + code.WriteLine("};"); + code.Indent--; + code.WriteLine("}"); + return writer.ToString(); + } + + internal string EmitEndpointValidationFilter(ValidatableEndpoint endpoint, CancellationToken cancellationToken) + { + var writer = new StringWriter(); + var code = new CodeWriter(writer, baseIndent: 2); + code.WriteLine($@"{{ {endpoint.EndpointKey}, context => "); + code.Indent++; + code.Indent++; + code.StartBlock(); + code.WriteLine("ValidationProblemBuilder resultBuilder = new();"); + var validationResultEmitted = false; + foreach (var parameter in endpoint.Parameters) + { + var parameterType = parameter.OriginalType.ToDisplayString(); + var parameterIndex = parameter.Index; + if (parameter.Attributes.Any() && !validationResultEmitted) + { + code.WriteLine("global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult = null;"); + validationResultEmitted = true; + } + code.WriteLine($"var value{parameterIndex} = context.GetArgument<{parameterType}>({parameterIndex});"); + foreach (var attribute in parameter.Attributes) + { + code.WriteLine(EmitValidationAttribute(attribute)); + } + if (parameter.Attributes.Any()) + { + foreach (var attribute in parameter.Attributes) + { + code.WriteLine($@"validationResult = {attribute.Name}.GetValidationResult(value{parameterIndex}, new global::System.ComponentModel.DataAnnotations.ValidationContext(value{parameterIndex}) {{ DisplayName = ""{parameter.DisplayName}"" }});"); + code.WriteLine("if (validationResult is not null)"); + code.StartBlock(); + code.WriteLine(@$"resultBuilder.WithError(""{parameter.Name}"", validationResult.ErrorMessage);"); + code.EndBlock(); + } + } + if (parameter.IsEnumerable && parameter.HasValidatableType) + { + code.WriteLine($"var value{parameterIndex}Index = 0;"); + code.WriteLine($"foreach (var item in value{parameterIndex} ?? [])"); + code.StartBlock(); + code.WriteLine($"var itemValidationResult = ValidationTypes.Validate(item, validationContext);"); + code.WriteLine($"if (itemValidationResult is not null)"); + code.StartBlock(); + code.WriteLine($"foreach (var error in itemValidationResult.Errors)"); + code.StartBlock(); + code.WriteLine($@"resultBuilder.WithErrors($""value{parameterIndex}[{{ value{parameterIndex}Index }}].{{error.Key}}"", error.Value);"); + code.EndBlock(); + code.WriteLine($"value{parameterIndex}Index++;"); + code.EndBlock(); + code.EndBlock(); + } + else if (parameter.HasValidatableType) + { + code.WriteLine($"var typeValidationResult = ValidationTypes.Validate(value{parameterIndex}, serviceProvider: context.HttpContext.RequestServices);"); + code.WriteLine("if (typeValidationResult is not null)"); + code.StartBlock(); + code.WriteLine("foreach (var error in typeValidationResult.Errors)"); + code.StartBlock(); + code.WriteLine("resultBuilder.WithErrors(error.Key, error.Value);"); + code.EndBlock(); + code.EndBlock(); + } + } + code.WriteLine("return resultBuilder.HasValue() ? resultBuilder.Build() : null;"); + code.EndBlock(); + code.EndBlockWithComma(); + return writer.ToString(); + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.WithValidationEmitter.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.WithValidationEmitter.cs new file mode 100644 index 000000000000..758bbd0040b6 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Emitters/ValidationsGenerator.WithValidationEmitter.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO; +using System.Threading; +using Microsoft.AspNetCore.Analyzers.Infrastructure; +using Microsoft.CodeAnalysis.CSharp; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator +{ +#pragma warning disable RSEXPERIMENTAL002 + internal static string EmitWithValidationInterception(InterceptableLocation? location, CancellationToken cancellationToken) + { + AnalyzerDebug.Assert(location != null, "Interceptable location should not be null."); + var writer = new StringWriter(); + var code = new CodeWriter(writer, baseIndent: 1); + code.WriteLine("file static class WithValidationsInterceptor"); + code.StartBlock(); + code.WriteLine(location.GetInterceptsLocationAttributeSyntax()); + code.WriteLine("public static global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder WithValidation(this global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder builder)"); + code.StartBlock(); + code.WriteLine("System.Diagnostics.Debugger.Break();"); + code.WriteLine("builder.AddEndpointFilter(async (context, next) =>"); + code.StartBlock(); + code.WriteLine("var targetEndpoint = context.HttpContext.Features.Get()?.Endpoint;"); + code.WriteLine("Debug.Assert(targetEndpoint != null);"); + code.WriteLine("var route = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).RoutePattern.RawText;"); + code.WriteLine(@"var methods = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).Metadata.GetMetadata()?.HttpMethods ?? [""GET""];"); + code.WriteLine("Debug.Assert(route != null);"); + code.WriteLine("var validationFilter = ValidationsFilters.Filters[new EndpointKey(route, methods)];"); + code.WriteLine("var validationProblemDetails = validationFilter(context);"); + code.WriteLine("if (validationProblemDetails == null)"); + code.StartBlock(); + code.WriteLine("return await next(context);"); + code.EndBlock(); + code.WriteLine("return global::Microsoft.AspNetCore.Http.TypedResults.ValidationProblem(validationProblemDetails.Errors);"); + code.Indent--; + code.WriteLine("});"); + code.WriteLine("return builder;"); + code.EndBlock(); + code.EndBlock(); + return writer.ToString(); + } +#pragma warning restore RSEXPERIMENTAL002 +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/IInvocationOperationExtensions.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/IInvocationOperationExtensions.cs new file mode 100644 index 000000000000..8a78decff676 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/IInvocationOperationExtensions.cs @@ -0,0 +1,86 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public static class IInvocationOperationExtensions +{ + public static string GetEndpointKey(this IInvocationOperation operation) + { + var routePattern = operation.GetRoutePattern(); + var httpMethods = operation.GetHttpMethods(); + return $@"new EndpointKey(""{routePattern}"", {httpMethods})"; + } + + private static string GetRoutePattern(this IInvocationOperation operation) + { + if (operation.Arguments.Length < 2) + { + throw new InvalidOperationException("The operation does not contain enough arguments to extract the route pattern."); + } + + var routePatternArgument = operation.Arguments[1]; + if (routePatternArgument.Value.Syntax is LiteralExpressionSyntax literalExpression) + { + return literalExpression.Token.ValueText; + } + + throw new InvalidOperationException("The route pattern argument is not a literal expression."); + } + + private static string GetHttpMethods(this IInvocationOperation operation) + { + var syntax = (InvocationExpressionSyntax)operation.Syntax; + var expression = (MemberAccessExpressionSyntax)syntax.Expression; + var name = (IdentifierNameSyntax)expression.Name; + var identifier = name.Identifier; + var builder = new StringBuilder(); + builder.Append('['); + if (identifier.ValueText == "MapMethods") + { + var methods = ExtractMapMethods(operation); + builder.Append(string.Join(", ", methods.Select(method => @$"""{method}"""))); + } + else + { + builder.Append('"'); + builder.Append(identifier.ValueText switch + { + "MapGet" => "GET", + "MapPost" => "POST", + "MapPut" => "PUT", + "MapDelete" => "DELETE", + "MapPatch" => "PATCH", + _ => throw new InvalidOperationException("Unsupported HTTP method."), + }); + builder.Append('"'); + } + builder.Append(']'); + return builder.ToString(); + + static List ExtractMapMethods(IInvocationOperation operation) + { + var arguments = operation.Arguments; + var methods = arguments[2].Value; + if (methods.Syntax is ImplicitArrayCreationExpressionSyntax implicitArrayCreation) + { + var initializer = implicitArrayCreation.Initializer; + if (initializer != null) + { + return [.. initializer.Expressions + .OfType() + .Select(literal => literal.Token.ValueText)]; + } + } + + return []; + } + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/ISymbolExtensions.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/ISymbolExtensions.cs new file mode 100644 index 000000000000..6d0c926424b9 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/ISymbolExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public static class ISymbolExtensions +{ + public static string GetDisplayName(this ISymbol property, INamedTypeSymbol displayAttribute) + { + var displayNameAttribute = property.GetAttributes() + .FirstOrDefault(attribute => + attribute.AttributeClass is { } attributeClass && + SymbolEqualityComparer.Default.Equals(attributeClass, displayAttribute)); + if (displayNameAttribute is not null) + { + if (displayNameAttribute.ConstructorArguments.Length > 0) + { + return displayNameAttribute.ConstructorArguments[0].Value?.ToString() ?? property.Name; + } + else if (displayNameAttribute.NamedArguments.Length > 0) + { + return displayNameAttribute.NamedArguments[0].Value.Value?.ToString() ?? property.Name; + } + return property.Name; + } + + return property.Name; + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs new file mode 100644 index 000000000000..aada93ae4ad3 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public static class ITypeSymbolExtensions +{ + public static bool IsNullable(this ITypeSymbol type) + { + if (type is INamedTypeSymbol namedType) + { + return namedType.IsGenericType && + namedType.ConstructedFrom.SpecialType == SpecialType.System_Nullable_T; + } + + return false; + } + + public static bool IsEnumerable(this ITypeSymbol type, INamedTypeSymbol enumerable) + { + if (type.SpecialType == SpecialType.System_String) + { + return false; + } + + return type.ImplementsInterface(enumerable) || SymbolEqualityComparer.Default.Equals(type, enumerable); + } + + public static bool ImplementsValidationAttribute(this ITypeSymbol typeSymbol, INamedTypeSymbol validationAttributeSymbol) + { + var baseType = typeSymbol.BaseType; + while (baseType != null) + { + if (SymbolEqualityComparer.Default.Equals(baseType, validationAttributeSymbol)) + { + return true; + } + baseType = baseType.BaseType; + } + + return false; + } + + public static ITypeSymbol UnwrapType(this ITypeSymbol type, INamedTypeSymbol enumerable) + { + if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + { + // Extract the T from a Nullable + type = ((INamedTypeSymbol)type).TypeArguments[0]; + } + + if (type.NullableAnnotation == NullableAnnotation.Annotated) + { + // Extract the underlying type from a reference type + type = type.OriginalDefinition; + } + + if (type is INamedTypeSymbol namedType && namedType.IsEnumerable(enumerable)) + { + // Extract the T from an IEnumerable or List + type = namedType.TypeArguments[0]; + } + + return type; + } + + internal static bool ImplementsInterface(this ITypeSymbol type, ITypeSymbol interfaceType) + { + foreach (var iface in type.AllInterfaces) + { + if (SymbolEqualityComparer.Default.Equals(interfaceType, iface)) + { + return true; + } + } + return false; + } + + internal static ImmutableArray? GetJsonDerivedTypes(this ITypeSymbol type, INamedTypeSymbol jsonDerivedTypeAttribute) + { + var derivedTypes = ImmutableArray.CreateBuilder(); + foreach (var attribute in type.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, jsonDerivedTypeAttribute)) + { + var derivedType = (INamedTypeSymbol?)attribute.ConstructorArguments[0].Value; + if (derivedType is not null && !SymbolEqualityComparer.Default.Equals(derivedType, type)) + { + derivedTypes.Add(derivedType); + } + } + } + + return derivedTypes.Count == 0 ? null : derivedTypes.ToImmutable(); + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs new file mode 100644 index 000000000000..4d587a989a9a --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs @@ -0,0 +1,85 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public static class IncrementalValuesProviderExtensions +{ + public static IncrementalValuesProvider Distinct(this IncrementalValuesProvider source, IEqualityComparer comparer) + { + return source + .Collect() + .WithComparer(ImmutableArrayEqualityComparer.Instance) + .SelectMany((values, cancellationToken) => + { + if (values.IsEmpty) + { + return values; + } + + var results = ImmutableArray.CreateBuilder(values.Length); + HashSet set = new(comparer); + + foreach (var value in values) + { + if (set.Add(value)) + { + results.Add(value); + } + } + + return results.DrainToImmutable(); + }); + } + + private sealed class ImmutableArrayEqualityComparer : IEqualityComparer> + { + public static readonly ImmutableArrayEqualityComparer Instance = new(); + + public bool Equals(ImmutableArray x, ImmutableArray y) + { + if (x.IsDefault) + { + return y.IsDefault; + } + else if (y.IsDefault) + { + return false; + } + + if (x.Length != y.Length) + { + return false; + } + + for (var i = 0; i < x.Length; i++) + { + if (!EqualityComparer.Default.Equals(x[i], y[i])) + { + return false; + } + } + + return true; + } + + public int GetHashCode(ImmutableArray obj) + { + if (obj.IsDefault) + { + return 0; + } + var hashCode = -450793227; + foreach (var item in obj) + { + hashCode = (hashCode * -1521134295) + EqualityComparer.Default.GetHashCode(item); + } + + return hashCode; + } + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Microsoft.AspNetCore.Http.ValidationsGenerator.csproj b/src/Http/Http.Extensions/gen/ValidationsGenerator/Microsoft.AspNetCore.Http.ValidationsGenerator.csproj new file mode 100644 index 000000000000..ba7614b0ca7d --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Microsoft.AspNetCore.Http.ValidationsGenerator.csproj @@ -0,0 +1,33 @@ + + + + netstandard2.0 + true + false + true + false + enable + true + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/RequiredSymbols.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/RequiredSymbols.cs new file mode 100644 index 000000000000..29ca4f5de6bf --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/RequiredSymbols.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class RequiredSymbols( + INamedTypeSymbol DisplayAttribute, + INamedTypeSymbol ValidationAttribute, + INamedTypeSymbol IEnumerable, + INamedTypeSymbol IValidatableObject, + INamedTypeSymbol JsonDerivedTypeAttribute +); diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableEndpoint.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableEndpoint.cs new file mode 100644 index 000000000000..24cf2d6c9af2 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableEndpoint.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidatableEndpoint( + string EndpointKey, + ImmutableArray Parameters, + ImmutableArray ValidatableTypes +); diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableMember.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableMember.cs new file mode 100644 index 000000000000..918d4b9c4bb4 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableMember.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidatableMember( + string Name, + string DisplayName, + bool IsEnumerable, + bool IsNullable, + bool HasValidatableType, + ImmutableArray Attributes +); diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableParameter.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableParameter.cs new file mode 100644 index 000000000000..2702a65e10ea --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableParameter.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidatableParameter( + ITypeSymbol Type, + ITypeSymbol OriginalType, + string Name, + string DisplayName, + int Index, + bool IsEnumerable, + bool IsNullable, + bool HasValidatableType, + ImmutableArray Attributes +); diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableType.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableType.cs new file mode 100644 index 000000000000..7f1a8a60e723 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableType.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidatableType( + ITypeSymbol Type, + string Name, + bool IsIValidatableObject, + ImmutableArray Members, + ImmutableArray ValidatableSubTypeNames, + ImmutableArray ValidatableDerivedTypeNames +); diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableTypeComparer.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableTypeComparer.cs new file mode 100644 index 000000000000..de60b3500da4 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidatableTypeComparer.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal class ValidatableTypeComparer : IEqualityComparer +{ + public static ValidatableTypeComparer Instance { get; } = new(); + + public bool Equals(ValidatableType? x, ValidatableType? y) + { + if (x is null && y is null) + { + return true; + } + if (x is null || y is null) + { + return false; + } + return x.Name == y.Name; + } + + public int GetHashCode(ValidatableType? obj) + { + return obj?.Name.GetHashCode() ?? 0; + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidationAttribute.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidationAttribute.cs new file mode 100644 index 000000000000..e881d5f9cd25 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Models/ValidationAttribute.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidationAttribute( + string Name, + string ClassName, + List Arguments, + Dictionary NamedArguments, + bool ForParameter = false +); diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs new file mode 100644 index 000000000000..692a6cd2757c --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Threading; +using Microsoft.AspNetCore.Analyzers.Infrastructure; +using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal bool FindWithValidation(SyntaxNode syntaxNode, CancellationToken cancellationToken) + { + if (syntaxNode is InvocationExpressionSyntax + && syntaxNode.TryGetMapMethodName(out var method) + && method == "WithValidation") + { + return true; + } + return false; + } + +#pragma warning disable RSEXPERIMENTAL002 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + internal InterceptableLocation? TransformWithValidation(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + var node = (InvocationExpressionSyntax)context.Node; + var semanticModel = context.SemanticModel; + return semanticModel.GetInterceptableLocation(node, cancellationToken); + } +#pragma warning restore RSEXPERIMENTAL002 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + + internal bool FindEndpoints(SyntaxNode syntaxNode, CancellationToken cancellationToken) + { + if (syntaxNode is InvocationExpressionSyntax + && syntaxNode.TryGetMapMethodName(out var method)) + { + return method == "MapMethods" || InvocationOperationExtensions.KnownMethods.Contains(method); + } + return false; + } + + internal IInvocationOperation TransformEndpoints(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + var node = (InvocationExpressionSyntax)context.Node; + var operation = context.SemanticModel.GetOperation(node, cancellationToken); + AnalyzerDebug.Assert(operation != null, "Operation should not be null."); + return (IInvocationOperation)operation; + } + + internal ValidatableEndpoint ExtractValidatableEndpoint((IInvocationOperation Operation, RequiredSymbols RequiredSymbols) input, CancellationToken cancellationToken) + { + var endpointKey = input.Operation.GetEndpointKey(); + HashSet validatableTypes = new HashSet(ValidatableTypeComparer.Instance); + var parameters = ExtractParameters(input.Operation, input.RequiredSymbols, ref validatableTypes); + return new ValidatableEndpoint(endpointKey, parameters, [.. validatableTypes]); + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs new file mode 100644 index 000000000000..a24c53aae009 --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal RequiredSymbols ExtractRequireSymbols(Compilation compilation, CancellationToken cancellationToken) + { + return new RequiredSymbols( + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.DisplayAttribute")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.ValidationAttribute")!, + compilation.GetTypeByMetadataName("System.Collections.IEnumerable")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.IValidatableObject")!, + compilation.GetTypeByMetadataName("System.Text.Json.Serialization.JsonDerivedTypeAttribute")! + ); + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs new file mode 100644 index 000000000000..34fdba9fab5a --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs @@ -0,0 +1,165 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Threading; +using Microsoft.AspNetCore.Analyzers.Infrastructure; +using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + private static readonly SymbolDisplayFormat _symbolDisplayFormat = new( + globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Included, + genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters, + typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces); + + internal ImmutableArray ExtractParameters(IInvocationOperation operation, RequiredSymbols requiredSymbols, ref HashSet validatableTypes) + { + AnalyzerDebug.Assert(operation.SemanticModel != null, "SemanticModel should not be null."); + var parameters = operation.TryGetRouteHandlerMethod(operation.SemanticModel, out var method) + ? method.Parameters + : []; + var validatableParameters = ImmutableArray.CreateBuilder(parameters.Length); + List visitedTypes = []; + foreach (var parameter in parameters) + { + var hasValidatableType = TryExtractValidatableType(parameter.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes); + validatableParameters.Add(new ValidatableParameter( + Type: parameter.Type.UnwrapType(requiredSymbols.IEnumerable), + OriginalType: parameter.Type, + Name: parameter.Name, + DisplayName: parameter.GetDisplayName(requiredSymbols.DisplayAttribute), + Index: parameter.Ordinal, + IsNullable: parameter.Type.IsNullable(), + IsEnumerable: parameter.Type.IsEnumerable(requiredSymbols.IEnumerable), + Attributes: ExtractValidationAttributes(parameter, requiredSymbols), + HasValidatableType: hasValidatableType)); + } + return validatableParameters.ToImmutable(); + } + + internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet validatableTypes, ref List visitedTypes) + { + if (typeSymbol.SpecialType != SpecialType.None) + { + return false; + } + + if (visitedTypes.Contains(typeSymbol)) + { + return true; + } + + visitedTypes.Add(typeSymbol); + + // Extract validatable types discovered in base types of this type and add them to the top-level list. + var current = typeSymbol.BaseType; + List? validatableSubTypes = []; + while (current != null && current.SpecialType != SpecialType.System_Object) + { + _ = TryExtractValidatableType(current, requiredSymbols, ref validatableTypes, ref visitedTypes); + validatableSubTypes.Add(current.Name); + current = current.BaseType; + } + + // Extract validatable types discovered in members of this type and add them to the top-level list. + var members = ExtractValidatableMembers(typeSymbol, requiredSymbols, ref validatableTypes, ref visitedTypes); + + // Extract the validatable types discovered in the JsonDerivedTypeAttributes of this type and add them to the top-level list. + var derivedTypes = typeSymbol.GetJsonDerivedTypes(requiredSymbols.JsonDerivedTypeAttribute); + var derivedTypeNames = derivedTypes?.Select(t => t.Name).ToArray() ?? []; + foreach (var derivedType in derivedTypes ?? []) + { + _ = TryExtractValidatableType(derivedType, requiredSymbols, ref validatableTypes, ref visitedTypes); + } + + // Add the type itself as a validatable type itself. + validatableTypes.Add(new ValidatableType( + Type: typeSymbol, + Name: typeSymbol.ToDisplayString(), + Members: members, + IsIValidatableObject: typeSymbol.ImplementsInterface(requiredSymbols.IValidatableObject), + ValidatableSubTypeNames: [.. validatableSubTypes], + ValidatableDerivedTypeNames: [.. derivedTypeNames])); + + return true; + } + + internal ImmutableArray ExtractValidatableMembers(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet validatableTypes, ref List visitedTypes) + { + var members = new List(); + foreach (var member in typeSymbol.GetMembers().OfType()) + { + var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes); + members.Add(new ValidatableMember( + Name: member.Name, + DisplayName: member.GetDisplayName(requiredSymbols.DisplayAttribute), + IsEnumerable: member.Type.IsEnumerable(requiredSymbols.IEnumerable), + IsNullable: member.Type.IsNullable(), + Attributes: ExtractValidationAttributes(member, requiredSymbols), + HasValidatableType: hasValidatableType)); + } + + return [.. members]; + } + + public ImmutableArray ExtractPropertyTypes(ITypeSymbol type, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + var processed = new HashSet(SymbolEqualityComparer.Default); + + void Traverse(ITypeSymbol currentType) + { + if (currentType == null || currentType.SpecialType != SpecialType.None || processed.Contains(currentType)) + { + return; + } + + processed.Add(currentType); + builder.Add(currentType); + + foreach (var member in currentType.GetMembers().OfType()) + { + if (member.Type is ITypeSymbol propertyType) + { + Traverse(propertyType); + } + } + } + + Traverse(type); + return builder.ToImmutable(); + } + + internal static ImmutableArray ExtractValidationAttributes(IPropertySymbol property, RequiredSymbols requiredSymbols) + { + return [.. property.GetAttributes() + .Where(attribute => attribute.AttributeClass != null) + .Where(attribute => attribute.AttributeClass!.ImplementsValidationAttribute(requiredSymbols.ValidationAttribute)) + .Select(attribute => new ValidationAttribute( + Name: property.Type.Name + property.Name + attribute.AttributeClass!.Name, + ClassName: attribute.AttributeClass!.ToDisplayString(_symbolDisplayFormat), + Arguments: [.. attribute.ConstructorArguments.Select(a => a.ToCSharpString())], + NamedArguments: attribute.NamedArguments.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.ToCSharpString())))]; + } + + internal static ImmutableArray ExtractValidationAttributes(IParameterSymbol parameter, RequiredSymbols requiredSymbols) + { + return [.. parameter.GetAttributes() + .Where(attribute => attribute.AttributeClass != null) + .Where(attribute => attribute.AttributeClass!.ImplementsValidationAttribute(requiredSymbols.ValidationAttribute)) + .Select(attribute => new ValidationAttribute( + Name: parameter.Name + attribute.AttributeClass!.Name, + ClassName: attribute.AttributeClass!.ToDisplayString(_symbolDisplayFormat), + Arguments: [.. attribute.ConstructorArguments.Select(a => a.ToCSharpString())], + NamedArguments: attribute.NamedArguments.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.ToCSharpString()), + ForParameter: true))]; + } +} diff --git a/src/Http/Http.Extensions/gen/ValidationsGenerator/ValidationsGenerator.cs b/src/Http/Http.Extensions/gen/ValidationsGenerator/ValidationsGenerator.cs new file mode 100644 index 000000000000..ba1e76a539bc --- /dev/null +++ b/src/Http/Http.Extensions/gen/ValidationsGenerator/ValidationsGenerator.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + public void Initialize(IncrementalGeneratorInitializationContext context) + { + // Find the app.Conventions.WithValidation to call to indicate the + // user has opted in to validation for the minimal APIs. + var withValidation = context.SyntaxProvider.CreateSyntaxProvider( + predicate: FindWithValidation, + transform: TransformWithValidation + ); + // Extract all minimal API endpoints in the application. + var endpoints = context.SyntaxProvider.CreateSyntaxProvider( + predicate: FindEndpoints, + transform: TransformEndpoints); + // Resolve the symbols that will be required when making comparisons + // in future steps. + var requiredSymbols = context.CompilationProvider.Select(ExtractRequireSymbols); + // Extract all validatable endpoints encountered in the type graph. + var validatableEndpoints = endpoints + .Combine(requiredSymbols) + .Select(ExtractValidatableEndpoint); + // Extract all validatable types encountered in the type graph. + var validatableTypes = validatableEndpoints + .SelectMany((endpoint, ct) => endpoint.ValidatableTypes) + .Distinct(ValidatableTypeComparer.Instance) + .Collect(); + + // Generate emitted code for subtypes, interceptions, and filters. + var typeValidations = validatableTypes + .Select(EmitTypeValidations); + var withValidationInterceptions = withValidation + .Where(location => location is not null) + .Select(EmitWithValidationInterception); + var validationsFilters = validatableEndpoints.Select(EmitEndpointValidationFilter).Collect(); + + var validations = withValidationInterceptions + .Combine(typeValidations) + .Combine(validationsFilters); + + context.RegisterSourceOutput(validations, EmitValidationsFile); + } +} diff --git a/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj b/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj index 4a35778afa55..c6e2542a0d4d 100644 --- a/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj +++ b/src/Http/Http.Extensions/test/Microsoft.AspNetCore.Http.Extensions.Tests.csproj @@ -3,7 +3,7 @@ $(DefaultNetCoreTargetFramework) - $(Features.Replace('nullablePublicOnly', '') + $(Features.Replace('nullablePublicOnly', '')) true @@ -18,9 +18,12 @@ + + + @@ -29,7 +32,8 @@ - + + diff --git a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.cs b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.cs index 9e4f803968e8..fce27e952dab 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.cs @@ -605,7 +605,7 @@ public async Task RequestDelegateHandlesStringValuesFromExplicitQueryStringSourc public async Task RequestDelegateGeneratesCompilableCodeForServiceInNamespaceHttp() { var source = """ -app.MapGet("/hello", ([FromServices] ExampleService e) => e.Act("To be or not to be…")); +app.MapGet("/hello", ([FromServices] global::Http.ExampleService e) => e.Act("To be or not to be…")); """; var (results, compilation) = await RunGeneratorAsync(source); diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ModuleInitializer.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ModuleInitializer.cs new file mode 100644 index 000000000000..2e42614606d4 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ModuleInitializer.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; + +public static class ModuleInitializer +{ + [ModuleInitializer] + public static void Init() => + VerifySourceGenerators.Initialize(); +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.ComplexType.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.ComplexType.cs new file mode 100644 index 000000000000..8d85465f69c9 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.ComplexType.cs @@ -0,0 +1,292 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Http.Json; +using Microsoft.AspNetCore.Http; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestsBase +{ + [Fact] + public async Task CanValidateComplexTypes() + { + // Arrange + var source = """ +using System; +using System.ComponentModel.DataAnnotations; +using System.Collections.Generic; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; + +var builder = WebApplication.CreateBuilder(); + +var app = builder.Build(); + +app.Conventions.WithValidation(); + +app.MapPost("/complex-type", (ComplexType complexType) => Results.Ok()); + +app.Run(); + +public class ComplexType +{ + [Range(10, 100)] + public int IntegerWithRange { get; set; } = 10; + + [Range(10, 100), Display(Name = "Valid identifier")] + public int IntegerWithRangeAndDisplayName { get; set; } = 50; + + [Required] + public SubType PropertyWithMemberAttributes { get; set; } = new SubType(); + + public SubType PropertyWithoutMemberAttributes { get; set; } = new SubType(); + + public SubTypeWithInheritance PropertyWithInheritance { get; set; } = new SubTypeWithInheritance(); + + public List ListOfSubTypes { get; set; } = []; + + [CustomValidation(ErrorMessage = "Value must be an even number")] + public int IntegerWithCustomValidationAttribute { get; set; } + + [CustomValidation, Range(10, 100)] + public int PropertyWithMultipleAttributes { get; set; } = 10; +} + +public class CustomValidationAttribute : ValidationAttribute +{ + public override bool IsValid(object? value) => value is int number && number % 2 == 0; +} + +public class SubType +{ + [Required] + public string RequiredProperty { get; set; } = "some-value"; + + [StringLength(10)] + public string? StringWithLength { get; set; } +} + +public class SubTypeWithInheritance : SubType +{ + [EmailAddress] + public string? EmailString { get; set; } +} +"""; + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, async client => + { + await InvalidIntegerWithRangeProducesError(client); + await InvalidIntegerWithRangeAndDisplayNameProducesError(client); + await MissingRequiredSubtypePropertyProducesError(client); + await InvalidRequiredSubtypePropertyProducesError(client); + await InvalidSubTypeWithInheritancePropertyProducesError(client); + await InvalidListOfSubTypesProducesError(client); + await InvalidPropertyWithCustomValidationAttributeProducesError(client); + await InvalidPropertyWithMultipleAttributesProducesError(client); + + static async Task InvalidIntegerWithRangeProducesError(HttpClient client) + { + var payload = """ + { + "IntegerWithRange": 5 + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/complex-type", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithRange", kvp.Key); + Assert.Equal("The field IntegerWithRange must be between 10 and 100.", kvp.Value.Single()); + }); + } + + static async Task InvalidIntegerWithRangeAndDisplayNameProducesError(HttpClient client) + { + var payload = """ + { + "IntegerWithRangeAndDisplayName": 5 + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/complex-type", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithRangeAndDisplayName", kvp.Key); + Assert.Equal("The field Valid identifier must be between 10 and 100.", kvp.Value.Single()); + }); + } + + static async Task MissingRequiredSubtypePropertyProducesError(HttpClient client) + { + var payload = """ + { + "PropertyWithMemberAttributes": null + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/complex-type", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("PropertyWithMemberAttributes", kvp.Key); + Assert.Equal("The PropertyWithMemberAttributes field is required.", kvp.Value.Single()); + }); + } + + static async Task InvalidRequiredSubtypePropertyProducesError(HttpClient client) + { + var payload = """ + { + "PropertyWithMemberAttributes": { + "RequiredProperty": "", + "StringWithLength": "way-too-long" + } + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/complex-type", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("PropertyWithMemberAttributes.RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithMemberAttributes.StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + static async Task InvalidSubTypeWithInheritancePropertyProducesError(HttpClient client) + { + var payload = """ + { + "PropertyWithInheritance": { + "RequiredProperty": "", + "StringWithLength": "way-too-long", + "EmailString": "not-an-email" + } + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/complex-type", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("PropertyWithInheritance.RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithInheritance.StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithInheritance.EmailString", kvp.Key); + Assert.Equal("The EmailString field is not a valid e-mail address.", kvp.Value.Single()); + }); + } + + static async Task InvalidListOfSubTypesProducesError(HttpClient client) + { + var payload = """ + { + "ListOfSubTypes": [ + { + "RequiredProperty": "", + "StringWithLength": "way-too-long" + }, + { + "RequiredProperty": "valid", + "StringWithLength": "way-too-long" + }, + { + "RequiredProperty": "valid", + "StringWithLength": "valid" + } + ] + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/complex-type", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("ListOfSubTypes[0].RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("ListOfSubTypes[0].StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("ListOfSubTypes[1].StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + static async Task InvalidPropertyWithCustomValidationAttributeProducesError(HttpClient client) + { + var payload = """ + { + "IntegerWithCustomValidationAttribute": 5 + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/complex-type", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithCustomValidationAttribute", kvp.Key); + Assert.Equal("Value must be an even number", kvp.Value.Single()); + }); + } + + static async Task InvalidPropertyWithMultipleAttributesProducesError(HttpClient client) + { + var payload = """ + { + "PropertyWithMultipleAttributes": 5 + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/complex-type", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("PropertyWithMultipleAttributes", kvp.Key); + Assert.Collection(kvp.Value, + error => + { + Assert.Equal("The field PropertyWithMultipleAttributes is invalid.", error); + }, + error => + { + Assert.Equal("The field PropertyWithMultipleAttributes must be between 10 and 100.", error); + }); + }); + } + }); + + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.IValidatableObject.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.IValidatableObject.cs new file mode 100644 index 000000000000..5a0f4f1662fc --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.IValidatableObject.cs @@ -0,0 +1,189 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Http.Json; +using Microsoft.AspNetCore.Http; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestsBase +{ + [Fact] + public async Task CanValidateIValidatableObject() + { + var source = """ +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); + +builder.Services.AddSingleton(); + +var app = builder.Build(); + +app.Conventions.WithValidation(); + +app.MapPost("/validatable-object", (ComplexValidatableType model) => Results.Ok()); + +app.Run(); + +public class ComplexValidatableType: IValidatableObject +{ + [Display(Name = "Value 1")] + public int Value1 { get; set; } + + [EmailAddress] + [Required] + public required string Value2 { get; set; } = "test@example.com"; + + public ValidatableSubType SubType { get; set; } = new ValidatableSubType(); + + public IEnumerable Validate(ValidationContext validationContext) + { + var rangeService = (IRangeService?)validationContext.GetService(typeof(IRangeService)); + var minimum = rangeService?.GetMinimum(); + var maximum = rangeService?.GetMaximum(); + if (Value1 < minimum || Value1 > maximum) + { + yield return new ValidationResult($"The field {validationContext.DisplayName} must be between {minimum} and {maximum}.", [nameof(Value1)]); + } + } +} + +public class SubType +{ + [Required] + public string RequiredProperty { get; set; } = "some-value"; + + [StringLength(10)] + public string? StringWithLength { get; set; } +} + +public class ValidatableSubType : SubType, IValidatableObject +{ + public string Value3 { get; set; } = "some-value"; + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Value3 != "some-value") + { + yield return new ValidationResult($"The field {validationContext.DisplayName} must be 'some-value'.", [nameof(Value3)]); + } + } +} + +public interface IRangeService +{ + int GetMinimum(); + int GetMaximum(); +} + +public class RangeService : IRangeService +{ + public int GetMinimum() => 10; + public int GetMaximum() => 100; +} +"""; + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, async client => + { + await ValidateMethodSkippedIfPropertyValidationsFail(client); + await ValidateForSubtypeInvokedFirst(client); + await ValidateForTopLevelInvoked(client); + + static async Task ValidateMethodSkippedIfPropertyValidationsFail(HttpClient client) + { + var payload = """ + { + "Value1": 5, + "Value2": "", + "SubType": { + "Value3": "foo", + "RequiredProperty": "", + "StringWithLength": "" + } + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/validatable-object", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value2", error.Key); + Assert.Collection(error.Value, + error => + { + Assert.Equal("The Value2 field is not a valid e-mail address.", error); + + }, + error => + { + Assert.Equal("The Value2 field is required.", error); + }); + }, + error => + { + Assert.Equal("SubType.RequiredProperty", error.Key); + Assert.Equal("The RequiredProperty field is required.", error.Value.Single()); + }); + } + + static async Task ValidateForSubtypeInvokedFirst(HttpClient client) + { + var payload = """ + { + "Value1": 5, + "Value2": "test@test.com", + "SubType": { + "Value3": "foo", + "RequiredProperty": "some-value-2", + "StringWithLength": "element" + } + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/validatable-object", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("SubType.Value3", error.Key); + Assert.Equal("The field ValidatableSubType must be 'some-value'.", error.Value.Single()); + }); + } + + static async Task ValidateForTopLevelInvoked(HttpClient client) + { + var payload = """ + { + "Value1": 5, + "Value2": "test@test.com", + "SubType": { + "Value3": "some-value", + "RequiredProperty": "some-value-2", + "StringWithLength": "element" + } + } + """; + var content = new StringContent(payload, new MediaTypeHeaderValue("application/json")); + var response = await client.PostAsync("/validatable-object", content); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field ComplexValidatableType must be between 10 and 100.", error.Value.Single()); + }); + } + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Interception.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Interception.cs new file mode 100644 index 000000000000..b061efa601ae --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Interception.cs @@ -0,0 +1,90 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Http.Json; +using Microsoft.AspNetCore.Http; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestsBase +{ + [Fact] + public async Task CanValidateOnAllMapOverloads() + { + var source = """ +using System; +using System.ComponentModel.DataAnnotations; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Routing; + +var builder = WebApplication.CreateBuilder(); + +var app = builder.Build(); + +app.Conventions.WithValidation(); + +app.MapGet("/todos/{id}", ([Range(1, 10)] int id) => id); +app.MapPost("/todos", (Todo todo) => todo.Id); +app.MapPut("/todos", (Todo todo) => todo.Id); +app.MapMethods("/todos/{id}", new [] { "delete", "PaTcH" }, ([Range(1, 10)] int id) => id); + +app.Run(); + +public class Todo +{ + [Range(1, 10)] + public int Id { get; set; } +} +"""; + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, async client => + { + var response = await client.GetAsync("/todos/12"); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + var error = Assert.Single(problemDetails.Errors); + Assert.Equal("id", error.Key); + Assert.Equal("The field id must be between 1 and 10.", error.Value.Single()); + + var invalidPayload = """ + { + "Id": 12 + } + """; + var invalidResponse = await client.PostAsync("/todos", new StringContent(invalidPayload, new MediaTypeHeaderValue("application/json"))); + Assert.Equal(HttpStatusCode.BadRequest, invalidResponse.StatusCode); + problemDetails = await invalidResponse.Content.ReadFromJsonAsync(); + error = Assert.Single(problemDetails.Errors); + Assert.Equal("Id", error.Key); + Assert.Equal("The field Id must be between 1 and 10.", error.Value.Single()); + + invalidResponse = await client.PutAsync("/todos", new StringContent(invalidPayload, new MediaTypeHeaderValue("application/json"))); + Assert.Equal(HttpStatusCode.BadRequest, invalidResponse.StatusCode); + problemDetails = await invalidResponse.Content.ReadFromJsonAsync(); + error = Assert.Single(problemDetails.Errors); + Assert.Equal("Id", error.Key); + Assert.Equal("The field Id must be between 1 and 10.", error.Value.Single()); + + response = await client.DeleteAsync("/todos/12"); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + problemDetails = await response.Content.ReadFromJsonAsync(); + error = Assert.Single(problemDetails.Errors); + Assert.Equal("id", error.Key); + Assert.Equal("The field id must be between 1 and 10.", error.Value.Single()); + + response = await client.SendAsync(new HttpRequestMessage(HttpMethod.Patch, "/todos/12")); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + problemDetails = await response.Content.ReadFromJsonAsync(); + error = Assert.Single(problemDetails.Errors); + Assert.Equal("id", error.Key); + Assert.Equal("The field id must be between 1 and 10.", error.Value.Single()); + + response = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, "/todos/12")); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + problemDetails = await response.Content.ReadFromJsonAsync(); + error = Assert.Single(problemDetails.Errors); + Assert.Equal("The field id must be between 1 and 10.", error.Value.Single()); + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Parameters.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Parameters.cs new file mode 100644 index 000000000000..fe8ec30431cd --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Parameters.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Http.Json; +using Microsoft.AspNetCore.Http; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestsBase +{ + [Fact] + public async Task CanValidateParameters() + { + var source = """ +using System; +using System.ComponentModel.DataAnnotations; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Routing; + +var builder = WebApplication.CreateBuilder(); + +var app = builder.Build(); + +app.Conventions.WithValidation(); + +app.MapGet("/params", ( + [Range(10, 100)] int value1, + [Range(10, 100), Display(Name = "Valid identifier")] int value2, + [Required] string value3 = "some-value", + [CustomValidation(ErrorMessage = "Value must be an even number")] int value4 = 4, + [CustomValidation, Range(10, 100)] int value5 = 10) => "OK"); + +app.Run(); + +public class CustomValidationAttribute : ValidationAttribute +{ + public override bool IsValid(object? value) => value is int number && number % 2 == 0; +} +"""; + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, async client => + { + var response = await client.GetAsync("/params?value1=5&value2=5&value3=&value4=3&value5=5"); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("value1", error.Key); + Assert.Equal("The field value1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("value2", error.Key); + Assert.Equal("The field Valid identifier must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("value3", error.Key); + Assert.Equal("The value3 field is required.", error.Value.Single()); + }, + error => + { + Assert.Equal("value4", error.Key); + Assert.Equal("Value must be an even number", error.Value.Single()); + }, + error => + { + Assert.Equal("value5", error.Key); + Assert.Collection(error.Value, error => + { + Assert.Equal("The field value5 is invalid.", error); + }, + error => + { + Assert.Equal("The field value5 must be between 10 and 100.", error); + }); + }); + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Polymorphism.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Polymorphism.cs new file mode 100644 index 000000000000..43f9e985d6b3 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Polymorphism.cs @@ -0,0 +1,207 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Http.Json; +using Microsoft.AspNetCore.Http; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestsBase +{ + [Fact] + public async Task CanValidatePolymorphicTypes() + { + var source = """ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; + +var builder = WebApplication.CreateBuilder(); + +var app = builder.Build(); + +app.Conventions.WithValidation(); + +app.MapPost("/basic-polymorphism", (BaseType model) => Results.Ok()); +app.MapPost("/validatable-polymorphism", (BaseValidatableType model) => Results.Ok()); +app.MapPost("/polymorphism-container", (ContainerType model) => Results.Ok()); + +app.Run(); + +public class ContainerType +{ + public BaseType BaseType { get; set; } = new BaseType(); + public BaseValidatableType BaseValidatableType { get; set; } = new BaseValidatableType(); +} + +[JsonDerivedType(typeof(BaseType), typeDiscriminator: "base")] +[JsonDerivedType(typeof(DerivedType), typeDiscriminator: "derived")] +public class BaseType +{ + [Display(Name = "Value 1")] + [Range(10, 100)] + public int Value1 { get; set; } + + [EmailAddress] + [Required] + public string Value2 { get; set; } = "test@example.com"; +} + +public class DerivedType : BaseType +{ + [Base64String] + public string? Value3 { get; set; } +} + +[JsonDerivedType(typeof(BaseValidatableType), typeDiscriminator: "base")] +[JsonDerivedType(typeof(DerivedValidatableType), typeDiscriminator: "derived")] +public class BaseValidatableType : IValidatableObject +{ + [Display(Name = "Value 1")] + public int Value1 { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Value1 < 10 || Value1 > 100) + { + yield return new ValidationResult("The field Value 1 must be between 10 and 100.", new[] { nameof(Value1) }); + } + } +} + +public class DerivedValidatableType : BaseValidatableType +{ + [EmailAddress] + public required string Value3 { get; set; } +} +"""; + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, async client => + { + await CallsBaseTypeValidationsOnDerivedType(client); + await CallsBaseTypeValidationOnDerivedTypeWithIValidateObject(client); + await CanValidateContainerTypeWithPolymorphicProperties(client); + + static async Task CallsBaseTypeValidationsOnDerivedType(HttpClient client) + { + var payload = """ + { + "$type": "derived", + "Value1": 5, + "Value2": "invalid-email", + "Value3": "invalid-base64" + } + """; + var response = await client.PostAsync("/basic-polymorphism", new StringContent(payload, new MediaTypeHeaderValue("application/json"))); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value2", error.Key); + Assert.Equal("The Value2 field is not a valid e-mail address.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value3", error.Key); + Assert.Equal("The Value3 field is not a valid Base64 encoding.", error.Value.Single()); + }); + } + + static async Task CallsBaseTypeValidationOnDerivedTypeWithIValidateObject(HttpClient client) + { + var payload = """ + { + "$type": "derived", + "Value1": 5, + "Value3": "invalid-email" + } + """; + var response = await client.PostAsync("/validatable-polymorphism", new StringContent(payload, new MediaTypeHeaderValue("application/json"))); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value3", error.Key); + Assert.Equal("The Value3 field is not a valid e-mail address.", error.Value.Single()); + }); + + payload = """ + { + "$type": "derived", + "Value1": 5, + "Value3": "test@example.com" + } + """; + response = await client.PostAsync("/validatable-polymorphism", new StringContent(payload, new MediaTypeHeaderValue("application/json"))); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }); + } + + static async Task CanValidateContainerTypeWithPolymorphicProperties(HttpClient client) + { + var payload = """ + { + "BaseType": { + "$type": "derived", + "Value1": 5, + "Value2": "invalid-email", + "Value3": "invalid-base64" + }, + "BaseValidatableType": { + "$type": "derived", + "Value1": 5, + "Value3": "test@example.com" + } + } + """; + var response = await client.PostAsync("/polymorphism-container", new StringContent(payload, new MediaTypeHeaderValue("application/json"))); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("BaseType.Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("BaseType.Value2", error.Key); + Assert.Equal("The Value2 field is not a valid e-mail address.", error.Value.Single()); + }, + error => + { + Assert.Equal("BaseType.Value3", error.Key); + Assert.Equal("The Value3 field is not a valid Base64 encoding.", error.Value.Single()); + }, + error => + { + Assert.Equal("BaseValidatableType.Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }); + } + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Recursion.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Recursion.cs new file mode 100644 index 000000000000..99b6b029caa8 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTests.Recursion.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Formats.Asn1; +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Http.Json; +using Microsoft.AspNetCore.Http; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestsBase +{ + [Fact] + public async Task CanValidateRecursiveTypes() + { + var source = """ +using System; +using System.ComponentModel.DataAnnotations; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); + +var app = builder.Build(); + +app.Conventions.WithValidation(); + +app.MapPost("/recursive-type", (RecursiveType model) => Results.Ok()); + +app.Run(); + +public class RecursiveType +{ + [Range(10, 100)] + public int Value { get; set; } + public RecursiveType? Next { get; set; } +} +"""; + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, async client => + { + await RecursiveTypeRespectsMaximumDepth(client); + + static async Task RecursiveTypeRespectsMaximumDepth(HttpClient client) + { + var payload = """ + { + "value": 1, + "next": { + "value": 2, + "next": { + "value": 3, + "next": { + "value": 4, + "next": { + "value": 5, + "next": { + "value": 6, + "next": { + "value": 7, + "next": { + "value": 8 + } + } + } + } + } + } + } + } + """; + var response = await client.PostAsync("/recursive-type", new StringContent(payload, new MediaTypeHeaderValue("application/json"))); + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var problemDetails = await response.Content.ReadFromJsonAsync(); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("Value", kvp.Key); + Assert.Equal("The field Value must be between 10 and 100.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("Next.Value", kvp.Key); + Assert.Equal("The field Value must be between 10 and 100.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("Next.Next.Value", kvp.Key); + Assert.Equal("The field Value must be between 10 and 100.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("Next.Next.Next.Value", kvp.Key); + Assert.Equal("The field Value must be between 10 and 100.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("Next.Next.Next.Next.Value", kvp.Key); + Assert.Equal("The field Value must be between 10 and 100.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("Next.Next.Next.Next.Next.Value", kvp.Key); + Assert.Equal("The field Value must be between 10 and 100.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("Next.Next.Next.Next.Next.Next.Value", kvp.Key); + Assert.Equal("The field Value must be between 10 and 100.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("Next.Next.Next.Next.Next.Next.Next.Value", kvp.Key); + Assert.Equal("The field Value must be between 10 and 100.", kvp.Value.Single()); + }); + } + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestsBase.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestsBase.cs new file mode 100644 index 000000000000..1efcacbf24b6 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestsBase.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Http; +using System.Runtime.Loader; +using System.Text; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Http.ValidationsGenerator; +using Microsoft.AspNetCore.Mvc.Testing; +using Microsoft.AspNetCore.Routing; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Emit; +using Microsoft.CodeAnalysis.Text; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +[UsesVerify] +public class ValidationsGeneratorTestsBase +{ + private static readonly CSharpParseOptions _parseOptions = new CSharpParseOptions(LanguageVersion.Preview) + .WithFeatures([new KeyValuePair("InterceptorsNamespaces", "Microsoft.AspNetCore.Http.Validations.Generated")]); + + private static string CreateSourceText(string source) => $$""" +{{source}} +// Make Program class public for consumption +// in WebApplicationFactory +public partial class Program { } +"""; + + public Task Verify(string source, out Compilation compilation) + { + var references = AppDomain.CurrentDomain.GetAssemblies() + .Where(assembly => !assembly.IsDynamic && !string.IsNullOrWhiteSpace(assembly.Location)) + .Select(assembly => MetadataReference.CreateFromFile(assembly.Location)) + .Concat( + [ + MetadataReference.CreateFromFile(typeof(WebApplicationBuilder).Assembly.Location), + MetadataReference.CreateFromFile(typeof(EndpointRouteBuilderExtensions).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IApplicationBuilder).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Microsoft.AspNetCore.Mvc.ApiExplorer.IApiDescriptionProvider).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Microsoft.AspNetCore.Mvc.ControllerBase).Assembly.Location), + MetadataReference.CreateFromFile(typeof(MvcCoreMvcBuilderExtensions).Assembly.Location), + MetadataReference.CreateFromFile(typeof(TypedResults).Assembly.Location), + MetadataReference.CreateFromFile(typeof(System.Text.Json.Nodes.JsonArray).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Console).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Uri).Assembly.Location), + MetadataReference.CreateFromFile(typeof(System.ComponentModel.DataAnnotations.ValidationAttribute).Assembly.Location), + MetadataReference.CreateFromFile(typeof(RouteData).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IFeatureCollection).Assembly.Location), + MetadataReference.CreateFromFile(typeof(ValidateOptionsResult).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IHttpMethodMetadata).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IResult).Assembly.Location), + MetadataReference.CreateFromFile(typeof(HttpJsonServiceExtensions).Assembly.Location), + ]); + var generator = new ValidationsGenerator(); + var inputCompilation = CSharpCompilation.Create($"ValidationsGeneratorSample-{Guid.NewGuid()}", + [CSharpSyntaxTree.ParseText(CreateSourceText(source), options: _parseOptions, path: "Program.cs")], + references, + new CSharpCompilationOptions(OutputKind.ConsoleApplication, nullableContextOptions: NullableContextOptions.Enable)); + var driver = CSharpGeneratorDriver.Create(generators: [generator.AsSourceGenerator()], parseOptions: _parseOptions); + return Verifier + .Verify(driver.RunGeneratorsAndUpdateCompilation(inputCompilation, out compilation, out _)) + .AutoVerify() + .UseDirectory("snapshots"); + } + + public async Task VerifyEndpoint(Compilation compilation, Func verifyEndpoint) + { + var symbolsName = compilation.AssemblyName; + var output = new MemoryStream(); + var pdb = new MemoryStream(); + + var emitOptions = new EmitOptions(debugInformationFormat: DebugInformationFormat.PortablePdb, pdbFilePath: symbolsName); + + var embeddedTexts = new List(); + + // Make sure we embed the sources in pdb for easy debugging + foreach (var syntaxTree in compilation.SyntaxTrees) + { + var text = syntaxTree.GetText(); + var encoding = text.Encoding ?? Encoding.UTF8; + var buffer = encoding.GetBytes(text.ToString()); + var sourceText = SourceText.From(buffer, buffer.Length, encoding, canBeEmbedded: true); + + var syntaxRootNode = (CSharpSyntaxNode)syntaxTree.GetRoot(); + var newSyntaxTree = CSharpSyntaxTree.Create(syntaxRootNode, options: _parseOptions, encoding: encoding, path: syntaxTree.FilePath); + + compilation = compilation.ReplaceSyntaxTree(syntaxTree, newSyntaxTree); + + embeddedTexts.Add(EmbeddedText.FromSource(syntaxTree.FilePath, sourceText)); + } + + _ = compilation.Emit(output, pdb, options: emitOptions, embeddedTexts: embeddedTexts); + + output.Position = 0; + pdb.Position = 0; + + var assembly = AssemblyLoadContext.Default.LoadFromStream(output, pdb); + + var depsFileName = $"{assembly.GetName().Name}.deps.json"; + var depsFile = new FileInfo(Path.Combine(AppContext.BaseDirectory, depsFileName)); + File.Create(depsFile.FullName).Dispose(); + + var factory = Activator.CreateInstance(typeof(WebApplicationFactory<>).MakeGenericType(assembly?.GetType("Program")!)); + var client = (HttpClient)factory.GetType().GetMethod("CreateClient", Type.EmptyTypes).Invoke(factory, null); + await verifyEndpoint(client); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#RouteHandlerValidations.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#RouteHandlerValidations.g.verified.cs new file mode 100644 index 000000000000..a205cee894c7 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#RouteHandlerValidations.g.verified.cs @@ -0,0 +1,380 @@ +//HintName: RouteHandlerValidations.g.cs +// +#nullable enable +namespace System.Runtime.CompilerServices +{ + [AttributeUsage(System.AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(int version, string data) { } + } +} + +namespace Microsoft.AspNetCore.Http.Validations.Generated +{ + using System; + using System.Linq; + using System.Diagnostics; + using System.ComponentModel.DataAnnotations; + + file class EndpointKey(string route, global::System.Collections.Generic.IEnumerable methods) + { + public string Route { get; } = route; + public global::System.Collections.Generic.IEnumerable Methods { get; } = methods; + + public override bool Equals(object? obj) + { + if (obj is EndpointKey other) + { + return string.Equals(Route, other.Route, global::System.StringComparison.OrdinalIgnoreCase) && + Methods.SequenceEqual(other.Methods, global::System.StringComparer.OrdinalIgnoreCase); + } + return false; + } + + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + (Route?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + hash = hash * 23 + GetMethodsHashCode(Methods); + return hash; + } + + private static int GetMethodsHashCode(global::System.Collections.Generic.IEnumerable methods) + { + if (methods == null) + { + return 0; + } + int hash = 17; + foreach (var method in methods) + { + hash = hash * 23 + (method?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + } + return hash; + } + } + + file class ValidationProblemBuilder + { + private readonly global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails _problemDetails; + + public ValidationProblemBuilder() + { + _problemDetails = new global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails(); + } + + public ValidationProblemBuilder WithTitle(string title) + { + _problemDetails.Title = title; + return this; + } + + public ValidationProblemBuilder WithStatus(int? status) + { + _problemDetails.Status = status; + return this; + } + + public ValidationProblemBuilder WithDetail(string detail) + { + _problemDetails.Detail = detail; + return this; + } + + public ValidationProblemBuilder WithInstance(string instance) + { + _problemDetails.Instance = instance; + return this; + } + + public ValidationProblemBuilder WithType(string type) + { + _problemDetails.Type = type; + return this; + } + + public ValidationProblemBuilder WithExtensions(global::System.Collections.Generic.IDictionary extensions) + { + foreach (var kvp in extensions) + { + _problemDetails.Extensions[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithErrors(global::System.Collections.Generic.IDictionary errors) + { + foreach (var kvp in errors) + { + _problemDetails.Errors[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithError(string key, string error) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Append(error).ToArray(); + } + else + { + _problemDetails.Errors[key] = new string[] { error }; + } + return this; + } + + public ValidationProblemBuilder WithErrors(string key, string[] errors) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Concat(errors).ToArray(); + } + else + { + _problemDetails.Errors[key] = errors; + } + return this; + } + + public global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails Build() + { + return _problemDetails; + } + + public bool HasValue() + { + return _problemDetails.Errors.Count > 0; + } + } + + file static class WithValidationsInterceptor + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "AnYTrS2tGWrwalLyj4MD5SIBAABQcm9ncmFtLmNz")] + public static global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder WithValidation(this global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder builder) + { + System.Diagnostics.Debugger.Break(); + builder.AddEndpointFilter(async (context, next) => + { + var targetEndpoint = context.HttpContext.Features.Get()?.Endpoint; + Debug.Assert(targetEndpoint != null); + var route = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).RoutePattern.RawText; + var methods = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).Metadata.GetMetadata()?.HttpMethods ?? ["GET"]; + Debug.Assert(route != null); + var validationFilter = ValidationsFilters.Filters[new EndpointKey(route, methods)]; + var validationProblemDetails = validationFilter(context); + if (validationProblemDetails == null) + { + return await next(context); + } + return global::Microsoft.AspNetCore.Http.TypedResults.ValidationProblem(validationProblemDetails.Errors); + }); + return builder; + } + } + + + file static class ValidationTypes + { + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(T value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) => null; + + private static readonly ValidationAttribute StringRequiredPropertyRequiredAttribute = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + private static readonly ValidationAttribute StringStringWithLengthStringLengthAttribute = new global::System.ComponentModel.DataAnnotations.StringLengthAttribute(10); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(SubType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "RequiredProperty"; + validationContext.MemberName = "RequiredProperty"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = StringRequiredPropertyRequiredAttribute.GetValidationResult(value?.RequiredProperty, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("RequiredProperty", validationResult.ErrorMessage); + } + validationContext.DisplayName = "StringWithLength"; + validationContext.MemberName = "StringWithLength"; + validationResult = StringStringWithLengthStringLengthAttribute.GetValidationResult(value?.StringWithLength, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("StringWithLength", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + private static readonly ValidationAttribute StringEmailStringEmailAddressAttribute = new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(SubTypeWithInheritance? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + if (value is SubType subTypeSubType) + { + var subTypeSubTypeValidationResult = ValidationTypes.Validate((SubType)subTypeSubType, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (subTypeSubTypeValidationResult is not null) + { + foreach (var error in subTypeSubTypeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + } + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "EmailString"; + validationContext.MemberName = "EmailString"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = StringEmailStringEmailAddressAttribute.GetValidationResult(value?.EmailString, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("EmailString", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + private static readonly ValidationAttribute Int32IntegerWithRangeRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + private static readonly ValidationAttribute Int32IntegerWithRangeAndDisplayNameRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + private static readonly ValidationAttribute SubTypePropertyWithMemberAttributesRequiredAttribute = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + private static readonly ValidationAttribute Int32IntegerWithCustomValidationAttributeCustomValidationAttribute = new global::CustomValidationAttribute() { ErrorMessage = "Value must be an even number" }; + private static readonly ValidationAttribute Int32PropertyWithMultipleAttributesCustomValidationAttribute = new global::CustomValidationAttribute(); + private static readonly ValidationAttribute Int32PropertyWithMultipleAttributesRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(ComplexType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "IntegerWithRange"; + validationContext.MemberName = "IntegerWithRange"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = Int32IntegerWithRangeRangeAttribute.GetValidationResult(value?.IntegerWithRange, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("IntegerWithRange", validationResult.ErrorMessage); + } + validationContext.DisplayName = "Valid identifier"; + validationContext.MemberName = "IntegerWithRangeAndDisplayName"; + validationResult = Int32IntegerWithRangeAndDisplayNameRangeAttribute.GetValidationResult(value?.IntegerWithRangeAndDisplayName, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("IntegerWithRangeAndDisplayName", validationResult.ErrorMessage); + } + validationContext.DisplayName = "PropertyWithMemberAttributes"; + validationContext.MemberName = "PropertyWithMemberAttributes"; + var typePropertyWithMemberAttributesValidationResult = ValidationTypes.Validate(value?.PropertyWithMemberAttributes, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (typePropertyWithMemberAttributesValidationResult is not null) + { + foreach (var error in typePropertyWithMemberAttributesValidationResult.Errors) + { + resultBuilder.WithErrors($"PropertyWithMemberAttributes.{error.Key}", error.Value); + } + } + validationResult = SubTypePropertyWithMemberAttributesRequiredAttribute.GetValidationResult(value?.PropertyWithMemberAttributes, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("PropertyWithMemberAttributes", validationResult.ErrorMessage); + } + validationContext.DisplayName = "PropertyWithoutMemberAttributes"; + validationContext.MemberName = "PropertyWithoutMemberAttributes"; + var typePropertyWithoutMemberAttributesValidationResult = ValidationTypes.Validate(value?.PropertyWithoutMemberAttributes, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (typePropertyWithoutMemberAttributesValidationResult is not null) + { + foreach (var error in typePropertyWithoutMemberAttributesValidationResult.Errors) + { + resultBuilder.WithErrors($"PropertyWithoutMemberAttributes.{error.Key}", error.Value); + } + } + validationContext.DisplayName = "PropertyWithInheritance"; + validationContext.MemberName = "PropertyWithInheritance"; + var typePropertyWithInheritanceValidationResult = ValidationTypes.Validate(value?.PropertyWithInheritance, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (typePropertyWithInheritanceValidationResult is not null) + { + foreach (var error in typePropertyWithInheritanceValidationResult.Errors) + { + resultBuilder.WithErrors($"PropertyWithInheritance.{error.Key}", error.Value); + } + } + validationContext.DisplayName = "ListOfSubTypes"; + validationContext.MemberName = "ListOfSubTypes"; + var ListOfSubTypesIndex = 0; + foreach (var item in value?.ListOfSubTypes ?? []) + { + var itemValidationResult = ValidationTypes.Validate(item, validationContext, serviceProvider, currentDepth, skipRecurse); + if (itemValidationResult is not null) + { + foreach (var error in itemValidationResult.Errors) + { + resultBuilder.WithErrors($"ListOfSubTypes[{ ListOfSubTypesIndex }].{error.Key}", error.Value); + } + ListOfSubTypesIndex++; + } + } + validationContext.DisplayName = "IntegerWithCustomValidationAttribute"; + validationContext.MemberName = "IntegerWithCustomValidationAttribute"; + validationResult = Int32IntegerWithCustomValidationAttributeCustomValidationAttribute.GetValidationResult(value?.IntegerWithCustomValidationAttribute, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("IntegerWithCustomValidationAttribute", validationResult.ErrorMessage); + } + validationContext.DisplayName = "PropertyWithMultipleAttributes"; + validationContext.MemberName = "PropertyWithMultipleAttributes"; + validationResult = Int32PropertyWithMultipleAttributesCustomValidationAttribute.GetValidationResult(value?.PropertyWithMultipleAttributes, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("PropertyWithMultipleAttributes", validationResult.ErrorMessage); + } + validationResult = Int32PropertyWithMultipleAttributesRangeAttribute.GetValidationResult(value?.PropertyWithMultipleAttributes, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("PropertyWithMultipleAttributes", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + } + + + file static class ValidationsFilters + { + public static readonly global::System.Collections.Generic.Dictionary> Filters = new() + { + { new EndpointKey("/complex-type", ["POST"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + var value0 = context.GetArgument(0); + var typeValidationResult = ValidationTypes.Validate(value0, serviceProvider: context.HttpContext.RequestServices); + if (typeValidationResult is not null) + { + foreach (var error in typeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + }; + } +} +#nullable restore diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#RouteHandlerValidations.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#RouteHandlerValidations.g.verified.cs new file mode 100644 index 000000000000..b4ce357628e8 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#RouteHandlerValidations.g.verified.cs @@ -0,0 +1,332 @@ +//HintName: RouteHandlerValidations.g.cs +// +#nullable enable +namespace System.Runtime.CompilerServices +{ + [AttributeUsage(System.AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(int version, string data) { } + } +} + +namespace Microsoft.AspNetCore.Http.Validations.Generated +{ + using System; + using System.Linq; + using System.Diagnostics; + using System.ComponentModel.DataAnnotations; + + file class EndpointKey(string route, global::System.Collections.Generic.IEnumerable methods) + { + public string Route { get; } = route; + public global::System.Collections.Generic.IEnumerable Methods { get; } = methods; + + public override bool Equals(object? obj) + { + if (obj is EndpointKey other) + { + return string.Equals(Route, other.Route, global::System.StringComparison.OrdinalIgnoreCase) && + Methods.SequenceEqual(other.Methods, global::System.StringComparer.OrdinalIgnoreCase); + } + return false; + } + + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + (Route?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + hash = hash * 23 + GetMethodsHashCode(Methods); + return hash; + } + + private static int GetMethodsHashCode(global::System.Collections.Generic.IEnumerable methods) + { + if (methods == null) + { + return 0; + } + int hash = 17; + foreach (var method in methods) + { + hash = hash * 23 + (method?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + } + return hash; + } + } + + file class ValidationProblemBuilder + { + private readonly global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails _problemDetails; + + public ValidationProblemBuilder() + { + _problemDetails = new global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails(); + } + + public ValidationProblemBuilder WithTitle(string title) + { + _problemDetails.Title = title; + return this; + } + + public ValidationProblemBuilder WithStatus(int? status) + { + _problemDetails.Status = status; + return this; + } + + public ValidationProblemBuilder WithDetail(string detail) + { + _problemDetails.Detail = detail; + return this; + } + + public ValidationProblemBuilder WithInstance(string instance) + { + _problemDetails.Instance = instance; + return this; + } + + public ValidationProblemBuilder WithType(string type) + { + _problemDetails.Type = type; + return this; + } + + public ValidationProblemBuilder WithExtensions(global::System.Collections.Generic.IDictionary extensions) + { + foreach (var kvp in extensions) + { + _problemDetails.Extensions[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithErrors(global::System.Collections.Generic.IDictionary errors) + { + foreach (var kvp in errors) + { + _problemDetails.Errors[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithError(string key, string error) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Append(error).ToArray(); + } + else + { + _problemDetails.Errors[key] = new string[] { error }; + } + return this; + } + + public ValidationProblemBuilder WithErrors(string key, string[] errors) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Concat(errors).ToArray(); + } + else + { + _problemDetails.Errors[key] = errors; + } + return this; + } + + public global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails Build() + { + return _problemDetails; + } + + public bool HasValue() + { + return _problemDetails.Errors.Count > 0; + } + } + + file static class WithValidationsInterceptor + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "M+n0go1d3dTXFKc7wYCJ+YMBAABQcm9ncmFtLmNz")] + public static global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder WithValidation(this global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder builder) + { + System.Diagnostics.Debugger.Break(); + builder.AddEndpointFilter(async (context, next) => + { + var targetEndpoint = context.HttpContext.Features.Get()?.Endpoint; + Debug.Assert(targetEndpoint != null); + var route = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).RoutePattern.RawText; + var methods = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).Metadata.GetMetadata()?.HttpMethods ?? ["GET"]; + Debug.Assert(route != null); + var validationFilter = ValidationsFilters.Filters[new EndpointKey(route, methods)]; + var validationProblemDetails = validationFilter(context); + if (validationProblemDetails == null) + { + return await next(context); + } + return global::Microsoft.AspNetCore.Http.TypedResults.ValidationProblem(validationProblemDetails.Errors); + }); + return builder; + } + } + + + file static class ValidationTypes + { + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(T value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) => null; + + private static readonly ValidationAttribute StringRequiredPropertyRequiredAttribute = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + private static readonly ValidationAttribute StringStringWithLengthStringLengthAttribute = new global::System.ComponentModel.DataAnnotations.StringLengthAttribute(10); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(SubType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "RequiredProperty"; + validationContext.MemberName = "RequiredProperty"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = StringRequiredPropertyRequiredAttribute.GetValidationResult(value?.RequiredProperty, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("RequiredProperty", validationResult.ErrorMessage); + } + validationContext.DisplayName = "StringWithLength"; + validationContext.MemberName = "StringWithLength"; + validationResult = StringStringWithLengthStringLengthAttribute.GetValidationResult(value?.StringWithLength, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("StringWithLength", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(ValidatableSubType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + if (value is SubType subTypeSubType) + { + var subTypeSubTypeValidationResult = ValidationTypes.Validate((SubType)subTypeSubType, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (subTypeSubTypeValidationResult is not null) + { + foreach (var error in subTypeSubTypeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + } + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "Value3"; + validationContext.MemberName = "Value3"; + if (!resultBuilder.HasValue()) + { + validationContext = new(value, serviceProvider: serviceProvider, items: null); + foreach (var validatableValidationResult in value.Validate(validationContext)) + { + if (validatableValidationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError(validatableValidationResult.MemberNames.FirstOrDefault() ?? "ValidatableSubType", validatableValidationResult.ErrorMessage); + } + } + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + private static readonly ValidationAttribute StringValue2EmailAddressAttribute = new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(); + private static readonly ValidationAttribute StringValue2RequiredAttribute = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(ComplexValidatableType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "Value 1"; + validationContext.MemberName = "Value1"; + validationContext.DisplayName = "Value2"; + validationContext.MemberName = "Value2"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = StringValue2EmailAddressAttribute.GetValidationResult(value?.Value2, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("Value2", validationResult.ErrorMessage); + } + validationResult = StringValue2RequiredAttribute.GetValidationResult(value?.Value2, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("Value2", validationResult.ErrorMessage); + } + validationContext.DisplayName = "SubType"; + validationContext.MemberName = "SubType"; + var typeSubTypeValidationResult = ValidationTypes.Validate(value?.SubType, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (typeSubTypeValidationResult is not null) + { + foreach (var error in typeSubTypeValidationResult.Errors) + { + resultBuilder.WithErrors($"SubType.{error.Key}", error.Value); + } + } + if (!resultBuilder.HasValue()) + { + validationContext = new(value, serviceProvider: serviceProvider, items: null); + foreach (var validatableValidationResult in value.Validate(validationContext)) + { + if (validatableValidationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError(validatableValidationResult.MemberNames.FirstOrDefault() ?? "ComplexValidatableType", validatableValidationResult.ErrorMessage); + } + } + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + } + + + file static class ValidationsFilters + { + public static readonly global::System.Collections.Generic.Dictionary> Filters = new() + { + { new EndpointKey("/validatable-object", ["POST"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + var value0 = context.GetArgument(0); + var typeValidationResult = ValidationTypes.Validate(value0, serviceProvider: context.HttpContext.RequestServices); + if (typeValidationResult is not null) + { + foreach (var error in typeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + }; + } +} +#nullable restore diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateOnAllMapOverloads#RouteHandlerValidations.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateOnAllMapOverloads#RouteHandlerValidations.g.verified.cs new file mode 100644 index 000000000000..5d6938786fd0 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateOnAllMapOverloads#RouteHandlerValidations.g.verified.cs @@ -0,0 +1,278 @@ +//HintName: RouteHandlerValidations.g.cs +// +#nullable enable +namespace System.Runtime.CompilerServices +{ + [AttributeUsage(System.AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(int version, string data) { } + } +} + +namespace Microsoft.AspNetCore.Http.Validations.Generated +{ + using System; + using System.Linq; + using System.Diagnostics; + using System.ComponentModel.DataAnnotations; + + file class EndpointKey(string route, global::System.Collections.Generic.IEnumerable methods) + { + public string Route { get; } = route; + public global::System.Collections.Generic.IEnumerable Methods { get; } = methods; + + public override bool Equals(object? obj) + { + if (obj is EndpointKey other) + { + return string.Equals(Route, other.Route, global::System.StringComparison.OrdinalIgnoreCase) && + Methods.SequenceEqual(other.Methods, global::System.StringComparer.OrdinalIgnoreCase); + } + return false; + } + + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + (Route?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + hash = hash * 23 + GetMethodsHashCode(Methods); + return hash; + } + + private static int GetMethodsHashCode(global::System.Collections.Generic.IEnumerable methods) + { + if (methods == null) + { + return 0; + } + int hash = 17; + foreach (var method in methods) + { + hash = hash * 23 + (method?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + } + return hash; + } + } + + file class ValidationProblemBuilder + { + private readonly global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails _problemDetails; + + public ValidationProblemBuilder() + { + _problemDetails = new global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails(); + } + + public ValidationProblemBuilder WithTitle(string title) + { + _problemDetails.Title = title; + return this; + } + + public ValidationProblemBuilder WithStatus(int? status) + { + _problemDetails.Status = status; + return this; + } + + public ValidationProblemBuilder WithDetail(string detail) + { + _problemDetails.Detail = detail; + return this; + } + + public ValidationProblemBuilder WithInstance(string instance) + { + _problemDetails.Instance = instance; + return this; + } + + public ValidationProblemBuilder WithType(string type) + { + _problemDetails.Type = type; + return this; + } + + public ValidationProblemBuilder WithExtensions(global::System.Collections.Generic.IDictionary extensions) + { + foreach (var kvp in extensions) + { + _problemDetails.Extensions[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithErrors(global::System.Collections.Generic.IDictionary errors) + { + foreach (var kvp in errors) + { + _problemDetails.Errors[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithError(string key, string error) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Append(error).ToArray(); + } + else + { + _problemDetails.Errors[key] = new string[] { error }; + } + return this; + } + + public ValidationProblemBuilder WithErrors(string key, string[] errors) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Concat(errors).ToArray(); + } + else + { + _problemDetails.Errors[key] = errors; + } + return this; + } + + public global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails Build() + { + return _problemDetails; + } + + public bool HasValue() + { + return _problemDetails.Errors.Count > 0; + } + } + + file static class WithValidationsInterceptor + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "e96px15Ji9NhefBjLD9dcd8AAABQcm9ncmFtLmNz")] + public static global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder WithValidation(this global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder builder) + { + System.Diagnostics.Debugger.Break(); + builder.AddEndpointFilter(async (context, next) => + { + var targetEndpoint = context.HttpContext.Features.Get()?.Endpoint; + Debug.Assert(targetEndpoint != null); + var route = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).RoutePattern.RawText; + var methods = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).Metadata.GetMetadata()?.HttpMethods ?? ["GET"]; + Debug.Assert(route != null); + var validationFilter = ValidationsFilters.Filters[new EndpointKey(route, methods)]; + var validationProblemDetails = validationFilter(context); + if (validationProblemDetails == null) + { + return await next(context); + } + return global::Microsoft.AspNetCore.Http.TypedResults.ValidationProblem(validationProblemDetails.Errors); + }); + return builder; + } + } + + + file static class ValidationTypes + { + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(T value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) => null; + + private static readonly ValidationAttribute Int32IdRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(1, 10); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(Todo? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "Id"; + validationContext.MemberName = "Id"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = Int32IdRangeAttribute.GetValidationResult(value?.Id, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("Id", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + } + + + file static class ValidationsFilters + { + public static readonly global::System.Collections.Generic.Dictionary> Filters = new() + { + { new EndpointKey("/todos/{id}", ["GET"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult = null; + var value0 = context.GetArgument(0); + var idRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(1, 10); + validationResult = idRangeAttribute.GetValidationResult(value0, new global::System.ComponentModel.DataAnnotations.ValidationContext(value0) { DisplayName = "id" }); + if (validationResult is not null) + { + resultBuilder.WithError("id", validationResult.ErrorMessage); + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + { new EndpointKey("/todos", ["POST"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + var value0 = context.GetArgument(0); + var typeValidationResult = ValidationTypes.Validate(value0, serviceProvider: context.HttpContext.RequestServices); + if (typeValidationResult is not null) + { + foreach (var error in typeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + { new EndpointKey("/todos", ["PUT"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + var value0 = context.GetArgument(0); + var typeValidationResult = ValidationTypes.Validate(value0, serviceProvider: context.HttpContext.RequestServices); + if (typeValidationResult is not null) + { + foreach (var error in typeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + { new EndpointKey("/todos/{id}", ["delete", "PaTcH"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult = null; + var value0 = context.GetArgument(0); + var idRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(1, 10); + validationResult = idRangeAttribute.GetValidationResult(value0, new global::System.ComponentModel.DataAnnotations.ValidationContext(value0) { DisplayName = "id" }); + if (validationResult is not null) + { + resultBuilder.WithError("id", validationResult.ErrorMessage); + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + }; + } +} +#nullable restore diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#RouteHandlerValidations.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#RouteHandlerValidations.g.verified.cs new file mode 100644 index 000000000000..58fb940004af --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#RouteHandlerValidations.g.verified.cs @@ -0,0 +1,240 @@ +//HintName: RouteHandlerValidations.g.cs +// +#nullable enable +namespace System.Runtime.CompilerServices +{ + [AttributeUsage(System.AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(int version, string data) { } + } +} + +namespace Microsoft.AspNetCore.Http.Validations.Generated +{ + using System; + using System.Linq; + using System.Diagnostics; + using System.ComponentModel.DataAnnotations; + + file class EndpointKey(string route, global::System.Collections.Generic.IEnumerable methods) + { + public string Route { get; } = route; + public global::System.Collections.Generic.IEnumerable Methods { get; } = methods; + + public override bool Equals(object? obj) + { + if (obj is EndpointKey other) + { + return string.Equals(Route, other.Route, global::System.StringComparison.OrdinalIgnoreCase) && + Methods.SequenceEqual(other.Methods, global::System.StringComparer.OrdinalIgnoreCase); + } + return false; + } + + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + (Route?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + hash = hash * 23 + GetMethodsHashCode(Methods); + return hash; + } + + private static int GetMethodsHashCode(global::System.Collections.Generic.IEnumerable methods) + { + if (methods == null) + { + return 0; + } + int hash = 17; + foreach (var method in methods) + { + hash = hash * 23 + (method?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + } + return hash; + } + } + + file class ValidationProblemBuilder + { + private readonly global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails _problemDetails; + + public ValidationProblemBuilder() + { + _problemDetails = new global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails(); + } + + public ValidationProblemBuilder WithTitle(string title) + { + _problemDetails.Title = title; + return this; + } + + public ValidationProblemBuilder WithStatus(int? status) + { + _problemDetails.Status = status; + return this; + } + + public ValidationProblemBuilder WithDetail(string detail) + { + _problemDetails.Detail = detail; + return this; + } + + public ValidationProblemBuilder WithInstance(string instance) + { + _problemDetails.Instance = instance; + return this; + } + + public ValidationProblemBuilder WithType(string type) + { + _problemDetails.Type = type; + return this; + } + + public ValidationProblemBuilder WithExtensions(global::System.Collections.Generic.IDictionary extensions) + { + foreach (var kvp in extensions) + { + _problemDetails.Extensions[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithErrors(global::System.Collections.Generic.IDictionary errors) + { + foreach (var kvp in errors) + { + _problemDetails.Errors[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithError(string key, string error) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Append(error).ToArray(); + } + else + { + _problemDetails.Errors[key] = new string[] { error }; + } + return this; + } + + public ValidationProblemBuilder WithErrors(string key, string[] errors) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Concat(errors).ToArray(); + } + else + { + _problemDetails.Errors[key] = errors; + } + return this; + } + + public global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails Build() + { + return _problemDetails; + } + + public bool HasValue() + { + return _problemDetails.Errors.Count > 0; + } + } + + file static class WithValidationsInterceptor + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "jZqu1LqTQ9Hw0e9sWMnm198AAABQcm9ncmFtLmNz")] + public static global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder WithValidation(this global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder builder) + { + System.Diagnostics.Debugger.Break(); + builder.AddEndpointFilter(async (context, next) => + { + var targetEndpoint = context.HttpContext.Features.Get()?.Endpoint; + Debug.Assert(targetEndpoint != null); + var route = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).RoutePattern.RawText; + var methods = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).Metadata.GetMetadata()?.HttpMethods ?? ["GET"]; + Debug.Assert(route != null); + var validationFilter = ValidationsFilters.Filters[new EndpointKey(route, methods)]; + var validationProblemDetails = validationFilter(context); + if (validationProblemDetails == null) + { + return await next(context); + } + return global::Microsoft.AspNetCore.Http.TypedResults.ValidationProblem(validationProblemDetails.Errors); + }); + return builder; + } + } + + + file static class ValidationTypes + { + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(T value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) => null; + } + + + file static class ValidationsFilters + { + public static readonly global::System.Collections.Generic.Dictionary> Filters = new() + { + { new EndpointKey("/params", ["GET"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult = null; + var value0 = context.GetArgument(0); + var value1RangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + validationResult = value1RangeAttribute.GetValidationResult(value0, new global::System.ComponentModel.DataAnnotations.ValidationContext(value0) { DisplayName = "value1" }); + if (validationResult is not null) + { + resultBuilder.WithError("value1", validationResult.ErrorMessage); + } + var value1 = context.GetArgument(1); + var value2RangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + validationResult = value2RangeAttribute.GetValidationResult(value1, new global::System.ComponentModel.DataAnnotations.ValidationContext(value1) { DisplayName = "Valid identifier" }); + if (validationResult is not null) + { + resultBuilder.WithError("value2", validationResult.ErrorMessage); + } + var value2 = context.GetArgument(2); + var value3RequiredAttribute = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + validationResult = value3RequiredAttribute.GetValidationResult(value2, new global::System.ComponentModel.DataAnnotations.ValidationContext(value2) { DisplayName = "value3" }); + if (validationResult is not null) + { + resultBuilder.WithError("value3", validationResult.ErrorMessage); + } + var value3 = context.GetArgument(3); + var value4CustomValidationAttribute = new global::CustomValidationAttribute() { ErrorMessage = "Value must be an even number" }; + validationResult = value4CustomValidationAttribute.GetValidationResult(value3, new global::System.ComponentModel.DataAnnotations.ValidationContext(value3) { DisplayName = "value4" }); + if (validationResult is not null) + { + resultBuilder.WithError("value4", validationResult.ErrorMessage); + } + var value4 = context.GetArgument(4); + var value5CustomValidationAttribute = new global::CustomValidationAttribute(); + var value5RangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + validationResult = value5CustomValidationAttribute.GetValidationResult(value4, new global::System.ComponentModel.DataAnnotations.ValidationContext(value4) { DisplayName = "value5" }); + if (validationResult is not null) + { + resultBuilder.WithError("value5", validationResult.ErrorMessage); + } + validationResult = value5RangeAttribute.GetValidationResult(value4, new global::System.ComponentModel.DataAnnotations.ValidationContext(value4) { DisplayName = "value5" }); + if (validationResult is not null) + { + resultBuilder.WithError("value5", validationResult.ErrorMessage); + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + }; + } +} +#nullable restore diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#RouteHandlerValidations.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#RouteHandlerValidations.g.verified.cs new file mode 100644 index 000000000000..123185d045ed --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#RouteHandlerValidations.g.verified.cs @@ -0,0 +1,446 @@ +//HintName: RouteHandlerValidations.g.cs +// +#nullable enable +namespace System.Runtime.CompilerServices +{ + [AttributeUsage(System.AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(int version, string data) { } + } +} + +namespace Microsoft.AspNetCore.Http.Validations.Generated +{ + using System; + using System.Linq; + using System.Diagnostics; + using System.ComponentModel.DataAnnotations; + + file class EndpointKey(string route, global::System.Collections.Generic.IEnumerable methods) + { + public string Route { get; } = route; + public global::System.Collections.Generic.IEnumerable Methods { get; } = methods; + + public override bool Equals(object? obj) + { + if (obj is EndpointKey other) + { + return string.Equals(Route, other.Route, global::System.StringComparison.OrdinalIgnoreCase) && + Methods.SequenceEqual(other.Methods, global::System.StringComparer.OrdinalIgnoreCase); + } + return false; + } + + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + (Route?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + hash = hash * 23 + GetMethodsHashCode(Methods); + return hash; + } + + private static int GetMethodsHashCode(global::System.Collections.Generic.IEnumerable methods) + { + if (methods == null) + { + return 0; + } + int hash = 17; + foreach (var method in methods) + { + hash = hash * 23 + (method?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + } + return hash; + } + } + + file class ValidationProblemBuilder + { + private readonly global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails _problemDetails; + + public ValidationProblemBuilder() + { + _problemDetails = new global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails(); + } + + public ValidationProblemBuilder WithTitle(string title) + { + _problemDetails.Title = title; + return this; + } + + public ValidationProblemBuilder WithStatus(int? status) + { + _problemDetails.Status = status; + return this; + } + + public ValidationProblemBuilder WithDetail(string detail) + { + _problemDetails.Detail = detail; + return this; + } + + public ValidationProblemBuilder WithInstance(string instance) + { + _problemDetails.Instance = instance; + return this; + } + + public ValidationProblemBuilder WithType(string type) + { + _problemDetails.Type = type; + return this; + } + + public ValidationProblemBuilder WithExtensions(global::System.Collections.Generic.IDictionary extensions) + { + foreach (var kvp in extensions) + { + _problemDetails.Extensions[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithErrors(global::System.Collections.Generic.IDictionary errors) + { + foreach (var kvp in errors) + { + _problemDetails.Errors[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithError(string key, string error) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Append(error).ToArray(); + } + else + { + _problemDetails.Errors[key] = new string[] { error }; + } + return this; + } + + public ValidationProblemBuilder WithErrors(string key, string[] errors) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Concat(errors).ToArray(); + } + else + { + _problemDetails.Errors[key] = errors; + } + return this; + } + + public global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails Build() + { + return _problemDetails; + } + + public bool HasValue() + { + return _problemDetails.Errors.Count > 0; + } + } + + file static class WithValidationsInterceptor + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "y289G4fCMnF0CWz59kYTt0gBAABQcm9ncmFtLmNz")] + public static global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder WithValidation(this global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder builder) + { + System.Diagnostics.Debugger.Break(); + builder.AddEndpointFilter(async (context, next) => + { + var targetEndpoint = context.HttpContext.Features.Get()?.Endpoint; + Debug.Assert(targetEndpoint != null); + var route = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).RoutePattern.RawText; + var methods = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).Metadata.GetMetadata()?.HttpMethods ?? ["GET"]; + Debug.Assert(route != null); + var validationFilter = ValidationsFilters.Filters[new EndpointKey(route, methods)]; + var validationProblemDetails = validationFilter(context); + if (validationProblemDetails == null) + { + return await next(context); + } + return global::Microsoft.AspNetCore.Http.TypedResults.ValidationProblem(validationProblemDetails.Errors); + }); + return builder; + } + } + + + file static class ValidationTypes + { + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(T value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) => null; + + private static readonly ValidationAttribute StringValue3Base64StringAttribute = new global::System.ComponentModel.DataAnnotations.Base64StringAttribute(); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(DerivedType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + if (value is BaseType subTypeBaseType) + { + var subTypeBaseTypeValidationResult = ValidationTypes.Validate((BaseType)subTypeBaseType, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (subTypeBaseTypeValidationResult is not null) + { + foreach (var error in subTypeBaseTypeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + } + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "Value3"; + validationContext.MemberName = "Value3"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = StringValue3Base64StringAttribute.GetValidationResult(value?.Value3, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("Value3", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + private static readonly ValidationAttribute Int32Value1RangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + private static readonly ValidationAttribute StringValue2EmailAddressAttribute = new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(); + private static readonly ValidationAttribute StringValue2RequiredAttribute = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(BaseType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + if (!skipRecurse && value is DerivedType derivedTypeDerivedType) + { + var derivedTypeDerivedTypeValidationResult = ValidationTypes.Validate(derivedTypeDerivedType, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse: true); + if (derivedTypeDerivedTypeValidationResult is not null) + { + foreach (var error in derivedTypeDerivedTypeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "Value 1"; + validationContext.MemberName = "Value1"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = Int32Value1RangeAttribute.GetValidationResult(value?.Value1, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("Value1", validationResult.ErrorMessage); + } + validationContext.DisplayName = "Value2"; + validationContext.MemberName = "Value2"; + validationResult = StringValue2EmailAddressAttribute.GetValidationResult(value?.Value2, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("Value2", validationResult.ErrorMessage); + } + validationResult = StringValue2RequiredAttribute.GetValidationResult(value?.Value2, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("Value2", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + private static readonly ValidationAttribute StringValue3EmailAddressAttribute = new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(DerivedValidatableType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + if (value is BaseValidatableType subTypeBaseValidatableType) + { + var subTypeBaseValidatableTypeValidationResult = ValidationTypes.Validate((BaseValidatableType)subTypeBaseValidatableType, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (subTypeBaseValidatableTypeValidationResult is not null) + { + foreach (var error in subTypeBaseValidatableTypeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + } + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "Value3"; + validationContext.MemberName = "Value3"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = StringValue3EmailAddressAttribute.GetValidationResult(value?.Value3, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("Value3", validationResult.ErrorMessage); + } + if (!resultBuilder.HasValue()) + { + validationContext = new(value, serviceProvider: serviceProvider, items: null); + foreach (var validatableValidationResult in value.Validate(validationContext)) + { + if (validatableValidationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError(validatableValidationResult.MemberNames.FirstOrDefault() ?? "DerivedValidatableType", validatableValidationResult.ErrorMessage); + } + } + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(BaseValidatableType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + if (!skipRecurse && value is DerivedValidatableType derivedTypeDerivedValidatableType) + { + var derivedTypeDerivedValidatableTypeValidationResult = ValidationTypes.Validate(derivedTypeDerivedValidatableType, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse: true); + if (derivedTypeDerivedValidatableTypeValidationResult is not null) + { + foreach (var error in derivedTypeDerivedValidatableTypeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "Value 1"; + validationContext.MemberName = "Value1"; + if (!resultBuilder.HasValue()) + { + validationContext = new(value, serviceProvider: serviceProvider, items: null); + foreach (var validatableValidationResult in value.Validate(validationContext)) + { + if (validatableValidationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError(validatableValidationResult.MemberNames.FirstOrDefault() ?? "BaseValidatableType", validatableValidationResult.ErrorMessage); + } + } + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(ContainerType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "BaseType"; + validationContext.MemberName = "BaseType"; + var typeBaseTypeValidationResult = ValidationTypes.Validate(value?.BaseType, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (typeBaseTypeValidationResult is not null) + { + foreach (var error in typeBaseTypeValidationResult.Errors) + { + resultBuilder.WithErrors($"BaseType.{error.Key}", error.Value); + } + } + validationContext.DisplayName = "BaseValidatableType"; + validationContext.MemberName = "BaseValidatableType"; + var typeBaseValidatableTypeValidationResult = ValidationTypes.Validate(value?.BaseValidatableType, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (typeBaseValidatableTypeValidationResult is not null) + { + foreach (var error in typeBaseValidatableTypeValidationResult.Errors) + { + resultBuilder.WithErrors($"BaseValidatableType.{error.Key}", error.Value); + } + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + } + + + file static class ValidationsFilters + { + public static readonly global::System.Collections.Generic.Dictionary> Filters = new() + { + { new EndpointKey("/basic-polymorphism", ["POST"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + var value0 = context.GetArgument(0); + var typeValidationResult = ValidationTypes.Validate(value0, serviceProvider: context.HttpContext.RequestServices); + if (typeValidationResult is not null) + { + foreach (var error in typeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + { new EndpointKey("/validatable-polymorphism", ["POST"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + var value0 = context.GetArgument(0); + var typeValidationResult = ValidationTypes.Validate(value0, serviceProvider: context.HttpContext.RequestServices); + if (typeValidationResult is not null) + { + foreach (var error in typeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + { new EndpointKey("/polymorphism-container", ["POST"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + var value0 = context.GetArgument(0); + var typeValidationResult = ValidationTypes.Validate(value0, serviceProvider: context.HttpContext.RequestServices); + if (typeValidationResult is not null) + { + foreach (var error in typeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + }; + } +} +#nullable restore diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#RouteHandlerValidations.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#RouteHandlerValidations.g.verified.cs new file mode 100644 index 000000000000..b99c7a57aca9 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#RouteHandlerValidations.g.verified.cs @@ -0,0 +1,242 @@ +//HintName: RouteHandlerValidations.g.cs +// +#nullable enable +namespace System.Runtime.CompilerServices +{ + [AttributeUsage(System.AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(int version, string data) { } + } +} + +namespace Microsoft.AspNetCore.Http.Validations.Generated +{ + using System; + using System.Linq; + using System.Diagnostics; + using System.ComponentModel.DataAnnotations; + + file class EndpointKey(string route, global::System.Collections.Generic.IEnumerable methods) + { + public string Route { get; } = route; + public global::System.Collections.Generic.IEnumerable Methods { get; } = methods; + + public override bool Equals(object? obj) + { + if (obj is EndpointKey other) + { + return string.Equals(Route, other.Route, global::System.StringComparison.OrdinalIgnoreCase) && + Methods.SequenceEqual(other.Methods, global::System.StringComparer.OrdinalIgnoreCase); + } + return false; + } + + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + (Route?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + hash = hash * 23 + GetMethodsHashCode(Methods); + return hash; + } + + private static int GetMethodsHashCode(global::System.Collections.Generic.IEnumerable methods) + { + if (methods == null) + { + return 0; + } + int hash = 17; + foreach (var method in methods) + { + hash = hash * 23 + (method?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + } + return hash; + } + } + + file class ValidationProblemBuilder + { + private readonly global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails _problemDetails; + + public ValidationProblemBuilder() + { + _problemDetails = new global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails(); + } + + public ValidationProblemBuilder WithTitle(string title) + { + _problemDetails.Title = title; + return this; + } + + public ValidationProblemBuilder WithStatus(int? status) + { + _problemDetails.Status = status; + return this; + } + + public ValidationProblemBuilder WithDetail(string detail) + { + _problemDetails.Detail = detail; + return this; + } + + public ValidationProblemBuilder WithInstance(string instance) + { + _problemDetails.Instance = instance; + return this; + } + + public ValidationProblemBuilder WithType(string type) + { + _problemDetails.Type = type; + return this; + } + + public ValidationProblemBuilder WithExtensions(global::System.Collections.Generic.IDictionary extensions) + { + foreach (var kvp in extensions) + { + _problemDetails.Extensions[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithErrors(global::System.Collections.Generic.IDictionary errors) + { + foreach (var kvp in errors) + { + _problemDetails.Errors[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithError(string key, string error) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Append(error).ToArray(); + } + else + { + _problemDetails.Errors[key] = new string[] { error }; + } + return this; + } + + public ValidationProblemBuilder WithErrors(string key, string[] errors) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Concat(errors).ToArray(); + } + else + { + _problemDetails.Errors[key] = errors; + } + return this; + } + + public global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails Build() + { + return _problemDetails; + } + + public bool HasValue() + { + return _problemDetails.Errors.Count > 0; + } + } + + file static class WithValidationsInterceptor + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "Ve9fVbKCj9TV9MAUJK8WnzABAABQcm9ncmFtLmNz")] + public static global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder WithValidation(this global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder builder) + { + System.Diagnostics.Debugger.Break(); + builder.AddEndpointFilter(async (context, next) => + { + var targetEndpoint = context.HttpContext.Features.Get()?.Endpoint; + Debug.Assert(targetEndpoint != null); + var route = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).RoutePattern.RawText; + var methods = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).Metadata.GetMetadata()?.HttpMethods ?? ["GET"]; + Debug.Assert(route != null); + var validationFilter = ValidationsFilters.Filters[new EndpointKey(route, methods)]; + var validationProblemDetails = validationFilter(context); + if (validationProblemDetails == null) + { + return await next(context); + } + return global::Microsoft.AspNetCore.Http.TypedResults.ValidationProblem(validationProblemDetails.Errors); + }); + return builder; + } + } + + + file static class ValidationTypes + { + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(T value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) => null; + + private static readonly ValidationAttribute Int32ValueRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(RecursiveType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null, IServiceProvider? serviceProvider = null, int currentDepth = 0, bool skipRecurse = false) + { + var maxDepth = ((global::Microsoft.Extensions.Options.IOptions)(serviceProvider ?? validationContext).GetService(typeof(global::Microsoft.Extensions.Options.IOptions))).Value?.SerializerOptions.MaxDepth; + if (currentDepth > (maxDepth == null || maxDepth == 0 ? 64 : maxDepth)) + { + return null; + } + ValidationProblemBuilder resultBuilder = new(); + if (value != null) + { + validationContext ??= new(value, serviceProvider: serviceProvider, items: null); + validationContext.DisplayName = "Value"; + validationContext.MemberName = "Value"; + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + validationResult = Int32ValueRangeAttribute.GetValidationResult(value?.Value, validationContext); + if (validationResult is global::System.ComponentModel.DataAnnotations.ValidationResult { ErrorMessage: not null }) + { + resultBuilder.WithError("Value", validationResult.ErrorMessage); + } + validationContext.DisplayName = "Next"; + validationContext.MemberName = "Next"; + var typeNextValidationResult = ValidationTypes.Validate(value?.Next, validationContext, serviceProvider, currentDepth: currentDepth + 1, skipRecurse); + if (typeNextValidationResult is not null) + { + foreach (var error in typeNextValidationResult.Errors) + { + resultBuilder.WithErrors($"Next.{error.Key}", error.Value); + } + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + } + + + file static class ValidationsFilters + { + public static readonly global::System.Collections.Generic.Dictionary> Filters = new() + { + { new EndpointKey("/recursive-type", ["POST"]), context => + { + ValidationProblemBuilder resultBuilder = new(); + var value0 = context.GetArgument(0); + var typeValidationResult = ValidationTypes.Validate(value0, serviceProvider: context.HttpContext.RequestServices); + if (typeValidationResult is not null) + { + foreach (var error in typeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + }, + + }; + } +} +#nullable restore diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTypeTests.CanValidateComplexTypes#RouteHandlerValidations.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTypeTests.CanValidateComplexTypes#RouteHandlerValidations.g.verified.cs new file mode 100644 index 000000000000..3bd699a0b5a1 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTypeTests.CanValidateComplexTypes#RouteHandlerValidations.g.verified.cs @@ -0,0 +1,368 @@ +//HintName: RouteHandlerValidations.g.cs +// +#nullable enable +namespace System.Runtime.CompilerServices +{ + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(int version, string data) { } + } +} + +namespace Microsoft.AspNetCore.Http.Validations.Generated +{ + using System; + using System.Linq; + using System.Diagnostics; + using System.ComponentModel.DataAnnotations; + using Microsoft.AspNetCore.Builder; + using System.Collections.Generic; + using Microsoft.AspNetCore.Http.Features; + using Microsoft.AspNetCore.Routing; + using Microsoft.Extensions.Options; + + file class EndpointKey(string route, global::System.Collections.Generic.IEnumerable methods) + { + public string Route { get; } = route; + public global::System.Collections.Generic.IEnumerable Methods { get; } = methods; + + public override bool Equals(object? obj) + { + if (obj is EndpointKey other) + { + return string.Equals(Route, other.Route, global::System.StringComparison.OrdinalIgnoreCase) && + Methods.SequenceEqual(other.Methods, global::System.StringComparer.OrdinalIgnoreCase); + } + return false; + } + + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + (Route?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + hash = hash * 23 + GetMethodsHashCode(Methods); + return hash; + } + + private static int GetMethodsHashCode(global::System.Collections.Generic.IEnumerable methods) + { + if (methods == null) + return 0; + int hash = 17; + foreach (var method in methods) + { + hash = hash * 23 + (method?.GetHashCode(global::System.StringComparison.OrdinalIgnoreCase) ?? 0); + } + return hash; + } + } + + file class ValidationProblemBuilder + { + private readonly global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails _problemDetails; + + public ValidationProblemBuilder() + { + _problemDetails = new global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails(); + } + + public ValidationProblemBuilder WithTitle(string title) + { + _problemDetails.Title = title; + return this; + } + + public ValidationProblemBuilder WithStatus(int? status) + { + _problemDetails.Status = status; + return this; + } + + public ValidationProblemBuilder WithDetail(string detail) + { + _problemDetails.Detail = detail; + return this; + } + + public ValidationProblemBuilder WithInstance(string instance) + { + _problemDetails.Instance = instance; + return this; + } + + public ValidationProblemBuilder WithType(string type) + { + _problemDetails.Type = type; + return this; + } + + public ValidationProblemBuilder WithExtensions(global::System.Collections.Generic.IDictionary extensions) + { + foreach (var kvp in extensions) + { + _problemDetails.Extensions[kvp.Key] = kvp.Value; + } + return this; + } + + public ValidationProblemBuilder WithErrors(global::System.Collections.Generic.IDictionary errors) + { + foreach (var kvp in errors) + { + _problemDetails.Errors[kvp.Key] = kvp.Value; + } + return this; + } + public ValidationProblemBuilder WithError(string key, string error) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Append(error).ToArray(); + } + else + { + _problemDetails.Errors[key] = new string[] { error }; + } + return this; + } + + public ValidationProblemBuilder WithErrors(string key, string[] errors) + { + if (_problemDetails.Errors.ContainsKey(key)) + { + _problemDetails.Errors[key] = _problemDetails.Errors[key].Concat(errors).ToArray(); + } + else + { + _problemDetails.Errors[key] = errors; + } + return this; + } + + public global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails Build() + { + return _problemDetails; + } + + public bool HasValue() + { + return _problemDetails.Errors.Count > 0; + } + } + + file static class WithValidationsInterceptor + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "Rme9eOraTczInPnMR3QdwhwBAABQcm9ncmFtLmNz")] + public static global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder WithValidation(this global::Microsoft.AspNetCore.Builder.IEndpointConventionBuilder builder) + { + builder.AddEndpointFilter(async (context, next) => + { + System.Diagnostics.Debugger.Break(); + var targetEndpoint = context.HttpContext.Features.Get()?.Endpoint; + Debug.Assert(targetEndpoint != null); + var route = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).RoutePattern.RawText; + var method = ((global::Microsoft.AspNetCore.Routing.RouteEndpoint)targetEndpoint).Metadata.GetMetadata().HttpMethods; + Debug.Assert(route != null); + var validationFilter = ValidationsFilters.Filters[new EndpointKey(route, method)]; + var validationProblemDetails = validationFilter(context); + if (validationProblemDetails == null) + { + return await next(context); + } + return global::Microsoft.AspNetCore.Http.TypedResults.ValidationProblem(validationProblemDetails.Errors); + }); + return builder; + } + } + + + file static class ValidationTypes + { + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(T value) => null; + + private static readonly ValidationAttribute Int32IntegerWithRangeRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + private static readonly ValidationAttribute Int32IntegerWithRangeAndDisplayNameRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + private static readonly ValidationAttribute SubTypePropertyWithMemberAttributesRequiredAttribute = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + private static readonly ValidationAttribute Int32IntegerWithCustomValidationAttributeCustomValidationAttribute = new global::CustomValidationAttribute() { ErrorMessage= "Value must be an even number" }; + private static readonly ValidationAttribute Int32PropertyWithMultipleAttributesCustomValidationAttribute = new global::CustomValidationAttribute(); + private static readonly ValidationAttribute Int32PropertyWithMultipleAttributesRangeAttribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(10, 100); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(ComplexType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null) + { + ValidationProblemBuilder resultBuilder = new(); + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + if (value != null) + { + validationContext ??= new(value); + validationContext.DisplayName = "IntegerWithRange"; + validationContext.MemberName = "IntegerWithRange"; + validationResult = Int32IntegerWithRangeRangeAttribute.GetValidationResult(value.IntegerWithRange, validationContext); + if (validationResult is not null && validationResult != global::System.ComponentModel.DataAnnotations.ValidationResult.Success) + { + resultBuilder.WithError("IntegerWithRange", validationResult.ErrorMessage); + } + validationContext.DisplayName = "Valid identifier"; + validationContext.MemberName = "IntegerWithRangeAndDisplayName"; + validationResult = Int32IntegerWithRangeAndDisplayNameRangeAttribute.GetValidationResult(value.IntegerWithRangeAndDisplayName, validationContext); + if (validationResult is not null && validationResult != global::System.ComponentModel.DataAnnotations.ValidationResult.Success) + { + resultBuilder.WithError("IntegerWithRangeAndDisplayName", validationResult.ErrorMessage); + } + validationContext.DisplayName = "PropertyWithMemberAttributes"; + validationContext.MemberName = "PropertyWithMemberAttributes"; + var typePropertyWithMemberAttributesValidationResult = ValidationTypes.Validate(value?.PropertyWithMemberAttributes, validationContext); + if (typePropertyWithMemberAttributesValidationResult is not null) + { + foreach (var error in typePropertyWithMemberAttributesValidationResult.Errors) + { + resultBuilder.WithErrors($"PropertyWithMemberAttributes.{error.Key}", error.Value); + } + } + validationResult = SubTypePropertyWithMemberAttributesRequiredAttribute.GetValidationResult(value.PropertyWithMemberAttributes, validationContext); + if (validationResult is not null && validationResult != global::System.ComponentModel.DataAnnotations.ValidationResult.Success) + { + resultBuilder.WithError("PropertyWithMemberAttributes", validationResult.ErrorMessage); + } + validationContext.DisplayName = "PropertyWithoutMemberAttributes"; + validationContext.MemberName = "PropertyWithoutMemberAttributes"; + var typePropertyWithoutMemberAttributesValidationResult = ValidationTypes.Validate(value?.PropertyWithoutMemberAttributes, validationContext); + if (typePropertyWithoutMemberAttributesValidationResult is not null) + { + foreach (var error in typePropertyWithoutMemberAttributesValidationResult.Errors) + { + resultBuilder.WithErrors($"PropertyWithoutMemberAttributes.{error.Key}", error.Value); + } + } + validationContext.DisplayName = "PropertyWithInheritance"; + validationContext.MemberName = "PropertyWithInheritance"; + var typePropertyWithInheritanceValidationResult = ValidationTypes.Validate(value?.PropertyWithInheritance, validationContext); + if (typePropertyWithInheritanceValidationResult is not null) + { + foreach (var error in typePropertyWithInheritanceValidationResult.Errors) + { + resultBuilder.WithErrors($"PropertyWithInheritance.{error.Key}", error.Value); + } + } + validationContext.DisplayName = "ListOfSubTypes"; + validationContext.MemberName = "ListOfSubTypes"; + var ListOfSubTypesIndex = 0; + foreach (var item in value.ListOfSubTypes ?? []) + { + var itemValidationResult = ValidationTypes.Validate(item, validationContext); + if (itemValidationResult is not null) + { + foreach (var error in itemValidationResult.Errors) + { + resultBuilder.WithErrors($"ListOfSubTypes[{ ListOfSubTypesIndex }].{error.Key}", error.Value); + } + ListOfSubTypesIndex++; + } + } + validationContext.DisplayName = "IntegerWithCustomValidationAttribute"; + validationContext.MemberName = "IntegerWithCustomValidationAttribute"; + validationResult = Int32IntegerWithCustomValidationAttributeCustomValidationAttribute.GetValidationResult(value.IntegerWithCustomValidationAttribute, validationContext); + if (validationResult is not null && validationResult != global::System.ComponentModel.DataAnnotations.ValidationResult.Success) + { + resultBuilder.WithError("IntegerWithCustomValidationAttribute", validationResult.ErrorMessage); + } + validationContext.DisplayName = "PropertyWithMultipleAttributes"; + validationContext.MemberName = "PropertyWithMultipleAttributes"; + validationResult = Int32PropertyWithMultipleAttributesCustomValidationAttribute.GetValidationResult(value.PropertyWithMultipleAttributes, validationContext); + if (validationResult is not null && validationResult != global::System.ComponentModel.DataAnnotations.ValidationResult.Success) + { + resultBuilder.WithError("PropertyWithMultipleAttributes", validationResult.ErrorMessage); + } + validationResult = Int32PropertyWithMultipleAttributesRangeAttribute.GetValidationResult(value.PropertyWithMultipleAttributes, validationContext); + if (validationResult is not null && validationResult != global::System.ComponentModel.DataAnnotations.ValidationResult.Success) + { + resultBuilder.WithError("PropertyWithMultipleAttributes", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + private static readonly ValidationAttribute StringRequiredPropertyRequiredAttribute = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + private static readonly ValidationAttribute StringStringWithLengthStringLengthAttribute = new global::System.ComponentModel.DataAnnotations.StringLengthAttribute(10); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(SubType? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null) + { + ValidationProblemBuilder resultBuilder = new(); + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + if (value != null) + { + validationContext ??= new(value); + validationContext.DisplayName = "RequiredProperty"; + validationContext.MemberName = "RequiredProperty"; + validationResult = StringRequiredPropertyRequiredAttribute.GetValidationResult(value.RequiredProperty, validationContext); + if (validationResult is not null && validationResult != global::System.ComponentModel.DataAnnotations.ValidationResult.Success) + { + resultBuilder.WithError("RequiredProperty", validationResult.ErrorMessage); + } + validationContext.DisplayName = "StringWithLength"; + validationContext.MemberName = "StringWithLength"; + validationResult = StringStringWithLengthStringLengthAttribute.GetValidationResult(value.StringWithLength, validationContext); + if (validationResult is not null && validationResult != global::System.ComponentModel.DataAnnotations.ValidationResult.Success) + { + resultBuilder.WithError("StringWithLength", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + private static readonly ValidationAttribute StringEmailStringEmailAddressAttribute = new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(); + + public static global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? Validate(SubTypeWithInheritance? value, global::System.ComponentModel.DataAnnotations.ValidationContext? validationContext = null) + { + ValidationProblemBuilder resultBuilder = new(); + global::System.ComponentModel.DataAnnotations.ValidationResult? validationResult; + if (value != null) + { + if (value is SubType subTypeSubType) + { + var subTypeSubTypeValidationResult = ValidationTypes.Validate(subTypeSubType, validationContext); + if (subTypeSubTypeValidationResult is not null) + { + foreach (var error in subTypeSubTypeValidationResult.Errors) + { + resultBuilder.WithErrors(error.Key, error.Value); + } + } + } + validationContext ??= new(value); + validationContext.DisplayName = "EmailString"; + validationContext.MemberName = "EmailString"; + validationResult = StringEmailStringEmailAddressAttribute.GetValidationResult(value.EmailString, validationContext); + if (validationResult is not null && validationResult != global::System.ComponentModel.DataAnnotations.ValidationResult.Success) + { + resultBuilder.WithError("EmailString", validationResult.ErrorMessage); + } + } + return resultBuilder.HasValue() ? resultBuilder.Build() : null; + } + } + + + file static class ValidationsFilters + { + public static readonly global::System.Collections.Generic.Dictionary> Filters = new() + { + { new EndpointKey("/complex-type", ["POST"]), context => + { + + var value = context.GetArgument(0); + global::Microsoft.AspNetCore.Http.HttpValidationProblemDetails? result = null; + var typeValidationResult = ValidationTypes.Validate(value, null); + if (typeValidationResult is not null) + { + foreach (var error in typeValidationResult.Errors) + { + (result ??= new()).Errors.Add(error.Key, error.Value); + } + } + return result; + } + }, + + }; + } +} +#nullable restore diff --git a/src/Http/Http/perf/Microbenchmarks/Microsoft.AspNetCore.Http.Microbenchmarks.csproj b/src/Http/Http/perf/Microbenchmarks/Microsoft.AspNetCore.Http.Microbenchmarks.csproj index 216ac33d03c1..30f74289fc47 100644 --- a/src/Http/Http/perf/Microbenchmarks/Microsoft.AspNetCore.Http.Microbenchmarks.csproj +++ b/src/Http/Http/perf/Microbenchmarks/Microsoft.AspNetCore.Http.Microbenchmarks.csproj @@ -33,7 +33,7 @@ - + diff --git a/src/Http/HttpAbstractions.slnf b/src/Http/HttpAbstractions.slnf index dbf53c6c4c39..959b30a02a55 100644 --- a/src/Http/HttpAbstractions.slnf +++ b/src/Http/HttpAbstractions.slnf @@ -20,7 +20,8 @@ "src\\Http\\Http.Abstractions\\perf\\Microbenchmarks\\Microsoft.AspNetCore.Http.Abstractions.Microbenchmarks.csproj", "src\\Http\\Http.Abstractions\\src\\Microsoft.AspNetCore.Http.Abstractions.csproj", "src\\Http\\Http.Abstractions\\test\\Microsoft.AspNetCore.Http.Abstractions.Tests.csproj", - "src\\Http\\Http.Extensions\\gen\\Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj", + "src\\Http\\Http.Extensions\\gen\\RequestDelegateGenerator\\Microsoft.AspNetCore.Http.RequestDelegateGenerator.csproj", + "src\\Http\\Http.Extensions\\gen\\ValidationsGenerator\\Microsoft.AspNetCore.Http.ValidationsGenerator.csproj", "src\\Http\\Http.Extensions\\src\\Microsoft.AspNetCore.Http.Extensions.csproj", "src\\Http\\Http.Extensions\\test\\Microsoft.AspNetCore.Http.Extensions.Tests.csproj", "src\\Http\\Http.Features\\src\\Microsoft.AspNetCore.Http.Features.csproj", @@ -71,4 +72,4 @@ "src\\WebEncoders\\src\\Microsoft.Extensions.WebEncoders.csproj" ] } -} \ No newline at end of file +} diff --git a/src/Http/Routing/src/Builder/ValidationConventionBuilderExtensions.cs b/src/Http/Routing/src/Builder/ValidationConventionBuilderExtensions.cs new file mode 100644 index 000000000000..b88226328e13 --- /dev/null +++ b/src/Http/Routing/src/Builder/ValidationConventionBuilderExtensions.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Builder; + +/// +/// Extension methods for managing model validation on an endpoint. +/// +public static class ValidationConventionBuilderExtensions +{ + /// + /// Adds validation to the specified . + /// + /// The to add validation to. + /// The with validation added. + public static IEndpointConventionBuilder WithValidation(this IEndpointConventionBuilder builder) + { + // Intentionally empty, requires interception by validations generator + return builder; + } +} diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..9034bc33471b 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Builder.ValidationConventionBuilderExtensions +static Microsoft.AspNetCore.Builder.ValidationConventionBuilderExtensions.WithValidation(this Microsoft.AspNetCore.Builder.IEndpointConventionBuilder! builder) -> Microsoft.AspNetCore.Builder.IEndpointConventionBuilder! diff --git a/src/Http/samples/MinimalSample/MinimalSample.csproj b/src/Http/samples/MinimalSample/MinimalSample.csproj index 7783a82fe5d8..1f34e8c43f92 100644 --- a/src/Http/samples/MinimalSample/MinimalSample.csproj +++ b/src/Http/samples/MinimalSample/MinimalSample.csproj @@ -19,7 +19,7 @@ - +