Skip to content

Commit

Permalink
Make infer security schemes opt-in
Browse files Browse the repository at this point in the history
  • Loading branch information
domaindrivendev committed Jul 18, 2022
1 parent 71ed7d3 commit 1332212
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Microsoft.OpenApi.Models;
using Microsoft.AspNetCore.Mvc.ApiExplorer;
using Swashbuckle.AspNetCore.SwaggerGen;
using Microsoft.AspNetCore.Authentication;

namespace Microsoft.Extensions.DependencyInjection
{
Expand Down Expand Up @@ -308,6 +309,22 @@ public static void SupportNonNullableReferenceTypes(this SwaggerGenOptions swagg
swaggerGenOptions.SchemaGeneratorOptions.SupportNonNullableReferenceTypes = true;
}

/// <summary>
/// Automatically infer security schemes from authentication/authorization state in ASP.NET Core.
/// </summary>
/// <param name="swaggerGenOptions"></param>
/// <param name="securitySchemesSelector">
/// Provide alternative implementation that maps ASP.NET Core Authentication schemes to Open API security schemes
/// </param>
/// <remarks>Currently only supports JWT Bearer authentication</remarks>
public static void InferSecuritySchemes(
this SwaggerGenOptions swaggerGenOptions,
Func<IEnumerable<AuthenticationScheme>, IDictionary<string, OpenApiSecurityScheme>> securitySchemesSelector = null)
{
swaggerGenOptions.SwaggerGeneratorOptions.InferSecuritySchemes = true;
swaggerGenOptions.SwaggerGeneratorOptions.SecuritySchemesSelector = securitySchemesSelector;
}

/// <summary>
/// Extend the Swagger Generator with "filters" that can modify Schemas after they're initially generated
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ public SwaggerGenerator(
SwaggerGeneratorOptions options,
IApiDescriptionGroupCollectionProvider apiDescriptionsProvider,
ISchemaGenerator schemaGenerator,
IAuthenticationSchemeProvider authentiationSchemeProvider) : this(options, apiDescriptionsProvider, schemaGenerator)
IAuthenticationSchemeProvider authenticationSchemeProvider) : this(options, apiDescriptionsProvider, schemaGenerator)
{
_authenticationSchemeProvider = authentiationSchemeProvider;
_authenticationSchemeProvider = authenticationSchemeProvider;
}

public async Task<OpenApiDocument> GetSwaggerAsync(string documentName, string host = null, string basePath = null)
Expand Down Expand Up @@ -88,27 +88,34 @@ public OpenApiDocument GetSwagger(string documentName, string host = null, strin
return (applicableApiDescriptions, swaggerDoc, schemaRepository);
}

private async Task<Dictionary<string, OpenApiSecurityScheme>> GetSecuritySchemes()
private async Task<IDictionary<string, OpenApiSecurityScheme>> GetSecuritySchemes()
{
var securitySchemes = new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes);
var authenticationSchemes = Enumerable.Empty<AuthenticationScheme>();
if (_authenticationSchemeProvider is not null)
if (!_options.InferSecuritySchemes)
{
authenticationSchemes = await _authenticationSchemeProvider.GetAllSchemesAsync();
return new Dictionary<string, OpenApiSecurityScheme>(_options.SecuritySchemes);
}
var securitySchemesFromSelector = _options.SecuritySchemesSelector(authenticationSchemes);
// Favor security schemes set via options over those generated
// from the selector. For the default selector, this effectively
// ends up favoring `Bearer` authentication types explicitly set
// by the user over those derived by the selector.
foreach (var securityScheme in securitySchemesFromSelector)

var authenticationSchemes = (_authenticationSchemeProvider is not null)
? await _authenticationSchemeProvider.GetAllSchemesAsync()
: Enumerable.Empty<AuthenticationScheme>();

if (_options.SecuritySchemesSelector != null)
{
if (!securitySchemes.ContainsKey(securityScheme.Key))
{
securitySchemes.Add(securityScheme.Key, securityScheme.Value);
}
return _options.SecuritySchemesSelector(authenticationSchemes);
}
return securitySchemes;

// Default implementation, currently only supports JWT Bearer scheme
return authenticationSchemes
.Where(authScheme => authScheme.Name == "Bearer")
.ToDictionary(
(authScheme) => authScheme.Name,
(authScheme) => new OpenApiSecurityScheme
{
Type = SecuritySchemeType.Http,
Scheme = "bearer", // "bearer" refers to the header name here
In = ParameterLocation.Header,
BearerFormat = "Json Web Token"
});
}

private IList<OpenApiServer> GenerateServers(string host, string basePath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public SwaggerGeneratorOptions()
OperationIdSelector = DefaultOperationIdSelector;
TagsSelector = DefaultTagsSelector;
SortKeySelector = DefaultSortKeySelector;
SecuritySchemesSelector = DefaultSecuritySchemeSelector;
SecuritySchemesSelector = null;
SchemaComparer = StringComparer.Ordinal;
Servers = new List<OpenApiServer>();
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>();
Expand All @@ -45,6 +45,10 @@ public SwaggerGeneratorOptions()

public Func<ApiDescription, string> SortKeySelector { get; set; }

public bool InferSecuritySchemes { get; set; }

public Func<IEnumerable<AuthenticationScheme>, IDictionary<string, OpenApiSecurityScheme>> SecuritySchemesSelector { get; set;}

public bool DescribeAllParametersInCamelCase { get; set; }

public List<OpenApiServer> Servers { get; set; }
Expand All @@ -63,8 +67,6 @@ public SwaggerGeneratorOptions()

public IList<IDocumentFilter> DocumentFilters { get; set; }

public Func<IEnumerable<AuthenticationScheme>, Dictionary<string, OpenApiSecurityScheme>> SecuritySchemesSelector { get; set;}

private bool DefaultDocInclusionPredicate(string documentName, ApiDescription apiDescription)
{
return apiDescription.GroupName == null || apiDescription.GroupName == documentName;
Expand Down Expand Up @@ -106,26 +108,5 @@ private string DefaultSortKeySelector(ApiDescription apiDescription)
{
return TagsSelector(apiDescription).First();
}

private Dictionary<string, OpenApiSecurityScheme> DefaultSecuritySchemeSelector(IEnumerable<AuthenticationScheme> schemes)
{
Dictionary<string, OpenApiSecurityScheme> securitySchemes = new();
#if (NET6_0_OR_GREATER)
foreach (var scheme in schemes)
{
if (scheme.Name == "Bearer")
{
securitySchemes[scheme.Name] = new OpenApiSecurityScheme
{
Type = SecuritySchemeType.Http,
Scheme = "bearer", // "bearer" refers to the header name here
In = ParameterLocation.Header,
BearerFormat = "Json Web Token"
};
}
}
#endif
return securitySchemes;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;

namespace Swashbuckle.AspNetCore.SwaggerGen.Test
{
public class FakeAuthenticationSchemeProvider : IAuthenticationSchemeProvider
{
private readonly IEnumerable<AuthenticationScheme> _authenticationSchemes;

public FakeAuthenticationSchemeProvider(IEnumerable<AuthenticationScheme> authenticationSchemes)
{
_authenticationSchemes = authenticationSchemes;
}

public void AddScheme(AuthenticationScheme scheme)
=> throw new NotImplementedException();
public Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
=> Task.FromResult(_authenticationSchemes);

public Task<AuthenticationScheme> GetDefaultAuthenticateSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultChallengeSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultForbidSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignInSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignOutSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
=> throw new NotImplementedException();

public Task<AuthenticationScheme> GetSchemeAsync(string name)
=> Task.FromResult(_authenticationSchemes.First());

public void RemoveScheme(string name)
=> throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
using Swashbuckle.AspNetCore.TestSupport;
using Xunit;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Server.HttpSys;
using Microsoft.AspNetCore.Authentication;

namespace Swashbuckle.AspNetCore.SwaggerGen.Test
{
Expand Down Expand Up @@ -1081,76 +1081,70 @@ public void GetSwagger_SupportsOption_SecuritySchemes()

var document = subject.GetSwagger("v1");

Assert.Equal(new[] { "basic", "Bearer" }, document.Components.SecuritySchemes.Keys);
}

[Fact]
public async Task GetSwagger_SupportsSecuritySchemesSelector()
{
var subject = Subject(
apiDescriptions: new ApiDescription[] { },
options: new SwaggerGeneratorOptions
{
SwaggerDocs = new Dictionary<string, OpenApiInfo>
{
["v1"] = new OpenApiInfo { Version = "V1", Title = "Test API" }
},
SecuritySchemesSelector = (schemes) => new Dictionary<string, OpenApiSecurityScheme>
{
["basic"] = new OpenApiSecurityScheme { Type = SecuritySchemeType.Http, Scheme = "basic" }
}
}
);

var document = await subject.GetSwaggerAsync("v1");

// Overrides the default set of [basic, bearer] with just [basic]
Assert.Equal(new[] { "basic" }, document.Components.SecuritySchemes.Keys);
}

[Fact]
public async Task GetSwagger_DefaultSecuritySchemeSelectorAddsBearerByDefault()
[Theory]
[InlineData(false, new string[] { })]
[InlineData(true, new string[] { "Bearer" })]
public async Task GetSwagger_SupportsOption_InferSecuritySchemes(
bool inferSecuritySchemes,
string[] expectedSecuritySchemeNames)

{
var subject = Subject(
apiDescriptions: new ApiDescription[] { },
authenticationSchemes: new[] {
new AuthenticationScheme("Bearer", null, typeof(IAuthenticationHandler)),
new AuthenticationScheme("Cookies", null, typeof(IAuthenticationHandler))
},
options: new SwaggerGeneratorOptions
{
SwaggerDocs = new Dictionary<string, OpenApiInfo>
{
["v1"] = new OpenApiInfo { Version = "V1", Title = "Test API" }
},
InferSecuritySchemes = inferSecuritySchemes
}
);

var document = await subject.GetSwaggerAsync("v1");

Assert.Equal(new[] { "Bearer" }, document.Components.SecuritySchemes.Keys);
Assert.Equal(expectedSecuritySchemeNames, document.Components.SecuritySchemes.Keys);
}

[Fact]
public async Task GetSwagger_DefaultSecuritySchemesSelectorDoesNotOverrideBearer()
[Theory]
[InlineData(false, new string[] { })]
[InlineData(true, new string[] { "Bearer", "Cookies" })]
public async Task GetSwagger_SupportsOption_SecuritySchemesSelector(
bool inferSecuritySchemes,
string[] expectedSecuritySchemeNames)

{
var subject = Subject(
apiDescriptions: new ApiDescription[] { },
authenticationSchemes: new[] {
new AuthenticationScheme("Bearer", null, typeof(IAuthenticationHandler)),
new AuthenticationScheme("Cookies", null, typeof(IAuthenticationHandler))
},
options: new SwaggerGeneratorOptions
{
SwaggerDocs = new Dictionary<string, OpenApiInfo>
{
["v1"] = new OpenApiInfo { Version = "V1", Title = "Test API" }
},
SecuritySchemes = new Dictionary<string, OpenApiSecurityScheme>
{
["Bearer"] = new OpenApiSecurityScheme { Type = SecuritySchemeType.ApiKey, Scheme = "someSpecialOne" }
}
InferSecuritySchemes = inferSecuritySchemes,
SecuritySchemesSelector = (authenticationSchemes) =>
authenticationSchemes
.ToDictionary(
(authScheme) => authScheme.Name,
(authScheme) => new OpenApiSecurityScheme())
}
);

var document = await subject.GetSwaggerAsync("v1");

var securityScheme = Assert.Single(document.Components.SecuritySchemes);
Assert.Equal("Bearer", securityScheme.Key);
Assert.Equal(SecuritySchemeType.ApiKey, securityScheme.Value.Type);
Assert.Equal("someSpecialOne", securityScheme.Value.Scheme);
Assert.Equal(expectedSecuritySchemeNames, document.Components.SecuritySchemes.Keys);
}

[Fact]
Expand Down Expand Up @@ -1283,13 +1277,16 @@ public void GetSwagger_SupportsOption_DocumentFilters()
Assert.Contains("ComplexType", document.Components.Schemas.Keys);
}

private SwaggerGenerator Subject(IEnumerable<ApiDescription> apiDescriptions, SwaggerGeneratorOptions options = null)
private SwaggerGenerator Subject(
IEnumerable<ApiDescription> apiDescriptions,
SwaggerGeneratorOptions options = null,
IEnumerable<AuthenticationScheme> authenticationSchemes = null)
{
return new SwaggerGenerator(
options ?? DefaultOptions,
new FakeApiDescriptionGroupCollectionProvider(apiDescriptions),
new SchemaGenerator(new SchemaGeneratorOptions(), new JsonSerializerDataContractResolver(new JsonSerializerOptions())),
new TestAuthenticationSchemeProvider()
new FakeAuthenticationSchemeProvider(authenticationSchemes ?? Enumerable.Empty<AuthenticationScheme>())
);
}

Expand All @@ -1301,41 +1298,4 @@ private SwaggerGenerator Subject(IEnumerable<ApiDescription> apiDescriptions, Sw
}
};
}

class TestAuthenticationSchemeProvider : IAuthenticationSchemeProvider
{
private readonly IEnumerable<AuthenticationScheme> _authenticationSchemes = new AuthenticationScheme[]
{
new AuthenticationScheme("Bearer", null, typeof(IAuthenticationHandler))
};

public void AddScheme(AuthenticationScheme scheme)
=> throw new NotImplementedException();
public Task<IEnumerable<AuthenticationScheme>> GetAllSchemesAsync()
=> Task.FromResult(_authenticationSchemes);

public Task<AuthenticationScheme> GetDefaultAuthenticateSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultChallengeSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultForbidSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignInSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<AuthenticationScheme> GetDefaultSignOutSchemeAsync()
=> Task.FromResult(_authenticationSchemes.First());

public Task<IEnumerable<AuthenticationScheme>> GetRequestHandlerSchemesAsync()
=> throw new NotImplementedException();

public Task<AuthenticationScheme> GetSchemeAsync(string name)
=> Task.FromResult(_authenticationSchemes.First());

public void RemoveScheme(string name)
=> throw new NotImplementedException();
}
}

0 comments on commit 1332212

Please sign in to comment.