diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs index aaa55784e..c812b6482 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs @@ -357,16 +357,11 @@ public Task GetAuthenticationResultForAppAsync( if (tokenAcquisitionOptions != null) { - if (tokenAcquisitionOptions.ExtraQueryParameters != null) + var dict = MergeExtraQueryParameters(mergedOptions, tokenAcquisitionOptions); + + if (dict != null) { - if (mergedOptions.ExtraQueryParameters != null) - { - builder.WithExtraQueryParameters((Dictionary)(MergeExtraQueryParameters(mergedOptions, tokenAcquisitionOptions))); - } - else - { - builder.WithExtraQueryParameters((Dictionary)(tokenAcquisitionOptions.ExtraQueryParameters)); - } + builder.WithExtraQueryParameters(dict); } if (tokenAcquisitionOptions.ExtraHeadersParameters != null) { @@ -719,16 +714,11 @@ private IConfidentialClientApplication BuildConfidentialClientApplication(Merged } if (tokenAcquisitionOptions != null) { - if (tokenAcquisitionOptions.ExtraQueryParameters != null) + var dict = MergeExtraQueryParameters(mergedOptions, tokenAcquisitionOptions); + + if (dict != null) { - if (mergedOptions.ExtraQueryParameters != null) - { - builder.WithExtraQueryParameters((Dictionary)(MergeExtraQueryParameters(mergedOptions, tokenAcquisitionOptions))); - } - else - { - builder.WithExtraQueryParameters((Dictionary)(tokenAcquisitionOptions.ExtraQueryParameters)); - } + builder.WithExtraQueryParameters(dict); } if (tokenAcquisitionOptions.ExtraHeadersParameters != null) { @@ -861,20 +851,11 @@ private Task GetAuthenticationResultForWebAppWithAccountFr if (tokenAcquisitionOptions != null) { - if (tokenAcquisitionOptions.ExtraQueryParameters != null) - { - if (mergedOptions.ExtraQueryParameters != null) - { - builder.WithExtraQueryParameters((Dictionary)(MergeExtraQueryParameters(mergedOptions, tokenAcquisitionOptions))); - } - else - { - builder.WithExtraQueryParameters((Dictionary)(tokenAcquisitionOptions.ExtraQueryParameters)); - } - } - else if (mergedOptions.ExtraQueryParameters != null) + var dict = MergeExtraQueryParameters(mergedOptions, tokenAcquisitionOptions); + + if (dict != null) { - builder.WithExtraQueryParameters((Dictionary)mergedOptions.ExtraQueryParameters); + builder.WithExtraQueryParameters(dict); } if (tokenAcquisitionOptions.ExtraHeadersParameters != null) { @@ -914,18 +895,25 @@ private Task GetAuthenticationResultForWebAppWithAccountFr return builder.ExecuteAsync(tokenAcquisitionOptions != null ? tokenAcquisitionOptions.CancellationToken : CancellationToken.None); } - internal static IDictionary MergeExtraQueryParameters( + internal static Dictionary? MergeExtraQueryParameters( MergedOptions mergedOptions, TokenAcquisitionOptions tokenAcquisitionOptions) { - var mergedDict = tokenAcquisitionOptions!.ExtraQueryParameters; - foreach (var pair in mergedOptions!.ExtraQueryParameters!) + if (tokenAcquisitionOptions.ExtraQueryParameters != null) { - if (!mergedDict!.ContainsKey(pair.Key)) - mergedDict.Add(pair.Key, pair.Value); + var mergedDict = new Dictionary(tokenAcquisitionOptions.ExtraQueryParameters); + if (mergedOptions.ExtraQueryParameters != null) + { + foreach (var pair in mergedOptions!.ExtraQueryParameters) + { + if (!mergedDict!.ContainsKey(pair.Key)) + mergedDict.Add(pair.Key, pair.Value); + } + } + return mergedDict; } - - return mergedDict!; + + return (Dictionary?)mergedOptions.ExtraQueryParameters; } protected static bool AcceptedTokenVersionMismatch(MsalUiRequiredException msalServiceException) diff --git a/tests/DevApps/WebAppCallsMicrosoftGraph/WebAppCallsMicrosoftGraph.csproj b/tests/DevApps/WebAppCallsMicrosoftGraph/WebAppCallsMicrosoftGraph.csproj index ec95925d1..6393f06b4 100644 --- a/tests/DevApps/WebAppCallsMicrosoftGraph/WebAppCallsMicrosoftGraph.csproj +++ b/tests/DevApps/WebAppCallsMicrosoftGraph/WebAppCallsMicrosoftGraph.csproj @@ -7,6 +7,5 @@ - diff --git a/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs b/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs index a77d1384f..c41c95a29 100644 --- a/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs +++ b/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs @@ -236,6 +236,7 @@ public void TestParseAuthorityIfNecessary() [Fact] public void MergeExtraQueryParametersTest() { + // Arrange var mergedOptions = new MergedOptions { ExtraQueryParameters = new Dictionary @@ -253,12 +254,60 @@ public void MergeExtraQueryParametersTest() } }; + // Act var mergedDict = TokenAcquisition.MergeExtraQueryParameters(mergedOptions, tokenAcquisitionOptions); - Assert.Equal(3, mergedDict.Count); + + // Assert + Assert.Equal(3, mergedDict!.Count); Assert.Equal("newvalue1", mergedDict["key1"]); Assert.Equal("value2", mergedDict["key2"]); Assert.Equal("value3", mergedDict["key3"]); } + + [Fact] + public void MergeExtraQueryParameters_TokenAcquisitionOptionsNull_Test() + { + // Arrange + var mergedOptions = new MergedOptions + { + ExtraQueryParameters = new Dictionary + { + { "key1", "value1" }, + { "key2", "value2" } + } + }; + var tokenAcquisitionOptions = new TokenAcquisitionOptions + { + ExtraQueryParameters = null, + }; + + // Act + var mergedDict = TokenAcquisition.MergeExtraQueryParameters(mergedOptions, tokenAcquisitionOptions); + + // Assert + Assert.Equal("value1", mergedDict!["key1"]); + Assert.Equal("value2", mergedDict["key2"]); + } + + [Fact] + public void MergeExtraQueryParameters_MergedOptionsNull_Test() + { + // Arrange + var mergedOptions = new MergedOptions + { + ExtraQueryParameters = null, + }; + var tokenAcquisitionOptions = new TokenAcquisitionOptions + { + ExtraQueryParameters = null, + }; + + // Act + var mergedDict = TokenAcquisition.MergeExtraQueryParameters(mergedOptions, tokenAcquisitionOptions); + + // Assert + Assert.Null(mergedDict); + } } }