diff --git a/src/WebJobs.Script.WebHost/Security/Authentication/Jwt/ScriptJwtBearerExtensions.cs b/src/WebJobs.Script.WebHost/Security/Authentication/Jwt/ScriptJwtBearerExtensions.cs index f8f9d692f0..3b3557bdbc 100644 --- a/src/WebJobs.Script.WebHost/Security/Authentication/Jwt/ScriptJwtBearerExtensions.cs +++ b/src/WebJobs.Script.WebHost/Security/Authentication/Jwt/ScriptJwtBearerExtensions.cs @@ -86,7 +86,7 @@ public static AuthenticationBuilder AddScriptJwtBearer(this AuthenticationBuilde } }); - private static string[] GetValidAudiences() + private static IEnumerable GetValidAudiences() { if (SystemEnvironment.Instance.IsPlaceholderModeEnabled() && SystemEnvironment.Instance.IsLinuxConsumptionOnAtlas()) @@ -97,11 +97,22 @@ private static string[] GetValidAudiences() }; } - return new string[] + string siteName = ScriptSettingsManager.Instance.GetSetting(AzureWebsiteName); + string runtimeSiteName = ScriptSettingsManager.Instance.GetSetting(AzureWebsiteRuntimeSiteName); + var audiences = new List { - string.Format(SiteAzureFunctionsUriFormat, ScriptSettingsManager.Instance.GetSetting(AzureWebsiteName)), - string.Format(SiteUriFormat, ScriptSettingsManager.Instance.GetSetting(AzureWebsiteName)) + string.Format(SiteAzureFunctionsUriFormat, siteName), + string.Format(SiteUriFormat, siteName) }; + + if (!string.IsNullOrEmpty(runtimeSiteName) && !string.Equals(siteName, runtimeSiteName, StringComparison.OrdinalIgnoreCase)) + { + // on a non-production slot, the runtime site name will differ from the site name + // we allow both for audience + audiences.Add(string.Format(SiteUriFormat, runtimeSiteName)); + } + + return audiences; } public static TokenValidationParameters CreateTokenValidationParameters() diff --git a/test/WebJobs.Script.Tests/Extensions/ScriptJwtBearerExtensionsTests.cs b/test/WebJobs.Script.Tests/Extensions/ScriptJwtBearerExtensionsTests.cs index fc8c9fdabb..7bf622e3d1 100644 --- a/test/WebJobs.Script.Tests/Extensions/ScriptJwtBearerExtensionsTests.cs +++ b/test/WebJobs.Script.Tests/Extensions/ScriptJwtBearerExtensionsTests.cs @@ -67,5 +67,43 @@ public void CreateTokenValidationParameters_HasExpectedAudience(bool isPlacehold } } } + + [Theory] + [InlineData("testsite", "testsite")] + [InlineData("testsite", "testsite__5bb5")] + [InlineData("testsite", null)] + [InlineData("testsite", "")] + public void CreateTokenValidationParameters_NonProductionSlot_HasExpectedAudiences(string siteName, string runtimeSiteName) + { + string azFuncAudience = string.Format(ScriptConstants.SiteAzureFunctionsUriFormat, siteName); + string siteAudience = string.Format(ScriptConstants.SiteUriFormat, siteName); + string runtimeSiteAudience = string.Format(ScriptConstants.SiteUriFormat, runtimeSiteName); + + var testEnv = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { EnvironmentSettingNames.AzureWebsiteName, siteName }, + { EnvironmentSettingNames.AzureWebsiteRuntimeSiteName, runtimeSiteName }, + { ContainerEncryptionKey, Convert.ToBase64String(TestHelpers.GenerateKeyBytes()) } + }; + + using (new TestScopedSettings(ScriptSettingsManager.Instance, testEnv)) + { + var tokenValidationParameters = ScriptJwtBearerExtensions.CreateTokenValidationParameters(); + var audiences = tokenValidationParameters.ValidAudiences.ToArray(); + + Assert.Equal(audiences[0], azFuncAudience); + Assert.Equal(audiences[1], siteAudience); + + if (string.Compare(siteName, runtimeSiteName, StringComparison.OrdinalIgnoreCase) == 0) + { + Assert.Equal(2, audiences.Length); + } + else if (!string.IsNullOrEmpty(runtimeSiteName)) + { + Assert.Equal(3, audiences.Length); + Assert.Equal(audiences[2], runtimeSiteAudience); + } + } + } } } \ No newline at end of file