Skip to content

Commit

Permalink
Fix | Handle NRE on Azure federated authentication (dotnet#1625)
Browse files Browse the repository at this point in the history
# Conflicts:
#	src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs
#	src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx
#	src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs
  • Loading branch information
DavoudEshtehari committed Aug 15, 2022
1 parent be9731c commit 8852545
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ public void AssertUnrecoverableStateCountIsCorrect()

internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposable
{
// https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/retry-after#simple-retry-for-errors-with-http-error-codes-500-600
internal const int MsalHttpRetryStatusCode = 429;

// CONNECTION AND STATE VARIABLES
private readonly SqlConnectionPoolGroupProviderInfo _poolGroupProviderInfo; // will only be null when called for ChangePassword, or creating SSE User Instance
private TdsParser _parser;
Expand Down Expand Up @@ -2421,7 +2424,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
// Deal with Msal service exceptions first, retry if 429 received.
catch (MsalServiceException serviceException)
{
if (429 == serviceException.StatusCode)
if (serviceException.StatusCode == MsalHttpRetryStatusCode)
{
RetryConditionHeaderValue retryAfter = serviceException.Headers.RetryAfter;
if (retryAfter.Delta.HasValue)
Expand All @@ -2440,9 +2443,15 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
}
else
{
break;
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MsalServiceException error:> Timeout: {0}", serviceException.ErrorCode);
throw SQL.ActiveDirectoryTokenRetrievingTimeout(Enum.GetName(typeof(SqlAuthenticationMethod), ConnectionOptions.Authentication), serviceException.ErrorCode, serviceException);
}
}
else
{
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MsalServiceException error:> {0}", serviceException.ErrorCode);
throw ADP.CreateSqlException(serviceException, ConnectionOptions, this, username);
}
}
// Deal with normal MsalExceptions.
catch (MsalException msalException)
Expand All @@ -2453,21 +2462,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
{
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MSALException error:> {0}", msalException.ErrorCode);

// Error[0]
SqlErrorCollection sqlErs = new();
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, StringsHelper.GetString(Strings.SQL_MSALFailure, username, ConnectionOptions.Authentication.ToString("G")), ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[1]
string errorMessage1 = StringsHelper.GetString(Strings.SQL_MSALInnerException, msalException.ErrorCode);
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, errorMessage1, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[2]
if (!string.IsNullOrEmpty(msalException.Message))
{
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, msalException.Message, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));
}
SqlException exc = SqlException.CreateException(sqlErs, "", this);
throw exc;
throw ADP.CreateSqlException(msalException, ConnectionOptions, this, username);
}

SqlClientEventSource.Log.TryAdvancedTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken|ADV> {0}, sleeping {1}[Milliseconds]", ObjectID, sleepInterval);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,10 @@ internal static Exception ActiveDirectoryDeviceFlowTimeout()
return ADP.TimeoutException(Strings.SQL_Timeout_Active_Directory_DeviceFlow_Authentication);
}

internal static Exception ActiveDirectoryTokenRetrievingTimeout(string authenticaton, string errorCode, Exception exception)
{
return ADP.TimeoutException(StringsHelper.GetString(Strings.AAD_Token_Retrieving_Timeout, authenticaton, errorCode, exception?.Message), exception);
}

//
// SQL.DataCommand
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -1932,4 +1932,7 @@
<data name="SQL_ParameterDirectionInvalidForOptimizedBinding" xml:space="preserve">
<value>Parameter '{0}' cannot have Direction Output or InputOutput when EnableOptimizedParameterBinding is enabled on the parent command.</value>
</data>
</root>
<data name="AAD_Token_Retrieving_Timeout" xml:space="preserve">
<value>Connection timed out while retrieving an access token using '{0}' authentication method. Last error: {1}: {2}</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ public void AssertUnrecoverableStateCountIsCorrect()

sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposable
{
// https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/retry-after#simple-retry-for-errors-with-http-error-codes-500-600
internal const int MsalHttpRetryStatusCode = 429;

// Connection re-route limit
internal const int _maxNumberOfRedirectRoute = 10;
Expand Down Expand Up @@ -2859,7 +2861,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
// Deal with Msal service exceptions first, retry if 429 received.
catch (MsalServiceException serviceException)
{
if (429 == serviceException.StatusCode)
if (serviceException.StatusCode == MsalHttpRetryStatusCode)
{
RetryConditionHeaderValue retryAfter = serviceException.Headers.RetryAfter;
if (retryAfter.Delta.HasValue)
Expand All @@ -2878,9 +2880,15 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
}
else
{
break;
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MsalServiceException error:> Timeout: {0}", serviceException.ErrorCode);
throw SQL.ActiveDirectoryTokenRetrievingTimeout(Enum.GetName(typeof(SqlAuthenticationMethod), ConnectionOptions.Authentication), serviceException.ErrorCode, serviceException);
}
}
else
{
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MsalServiceException error:> {0}", serviceException.ErrorCode);
throw ADP.CreateSqlException(serviceException, ConnectionOptions, this, username);
}
}
// Deal with normal MsalExceptions.
catch (MsalException msalException)
Expand All @@ -2891,21 +2899,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
{
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MSALException error:> {0}", msalException.ErrorCode);

// Error[0]
SqlErrorCollection sqlErs = new SqlErrorCollection();
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, StringsHelper.GetString(Strings.SQL_MSALFailure, username, ConnectionOptions.Authentication.ToString("G")), ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[1]
string errorMessage1 = StringsHelper.GetString(Strings.SQL_MSALInnerException, msalException.ErrorCode);
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, errorMessage1, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[2]
if (!string.IsNullOrEmpty(msalException.Message))
{
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, msalException.Message, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));
}
SqlException exc = SqlException.CreateException(sqlErs, "", this);
throw exc;
throw ADP.CreateSqlException(msalException, ConnectionOptions, this, username);
}

SqlClientEventSource.Log.TryAdvancedTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken|ADV> {0}, sleeping {1}[Milliseconds]", ObjectID, sleepInterval);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,11 @@ static internal Exception ActiveDirectoryDeviceFlowTimeout()
return ADP.TimeoutException(Strings.SQL_Timeout_Active_Directory_DeviceFlow_Authentication);
}

internal static Exception ActiveDirectoryTokenRetrievingTimeout(string authenticaton, string errorCode, Exception exception)
{
return ADP.TimeoutException(StringsHelper.GetString(Strings.AAD_Token_Retrieving_Timeout, authenticaton, errorCode, exception?.Message), exception);
}

//
// SQL.DataCommand
//
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -4617,4 +4617,7 @@
<data name="SQL_ParameterDirectionInvalidForOptimizedBinding" xml:space="preserve">
<value>Parameter '{0}' cannot have Direction Output or InputOutput when EnableOptimizedParameterBinding is enabled on the parent command.</value>
</data>
<data name="AAD_Token_Retrieving_Timeout" xml:space="preserve">
<value>Connection timed out while retrieving an access token using '{0}' authentication method. Last error: {1}: {2}</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
using Microsoft.Data.SqlClient.Server;
using Microsoft.Win32;
using IsolationLevel = System.Data.IsolationLevel;
using Microsoft.Identity.Client;

namespace Microsoft.Data.Common
{
Expand Down Expand Up @@ -188,9 +189,9 @@ internal static OverflowException Overflow(string error, Exception inner)
return e;
}

internal static TimeoutException TimeoutException(string error)
internal static TimeoutException TimeoutException(string error, Exception inner = null)
{
TimeoutException e = new(error);
TimeoutException e = new(error, inner);
TraceExceptionAsReturnValue(e);
return e;
}
Expand Down Expand Up @@ -390,7 +391,34 @@ internal static ArgumentException InvalidArgumentLength(string argumentName, int
=> Argument(StringsHelper.GetString(Strings.ADP_InvalidArgumentLength, argumentName, limit));

internal static ArgumentException MustBeReadOnly(string argumentName) => Argument(StringsHelper.GetString(Strings.ADP_MustBeReadOnly, argumentName));
#endregion

internal static Exception CreateSqlException(MsalException msalException, SqlConnectionString connectionOptions, SqlInternalConnectionTds sender, string username)
{
// Error[0]
SqlErrorCollection sqlErs = new();

sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS,
connectionOptions.DataSource,
StringsHelper.GetString(Strings.SQL_MSALFailure, username, connectionOptions.Authentication.ToString("G")),
ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[1]
string errorMessage1 = StringsHelper.GetString(Strings.SQL_MSALInnerException, msalException.ErrorCode);
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS,
connectionOptions.DataSource, errorMessage1,
ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[2]
if (!string.IsNullOrEmpty(msalException.Message))
{
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS,
connectionOptions.DataSource, msalException.Message,
ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));
}
return SqlException.CreateException(sqlErs, "", sender);
}

#endregion

#region CommandBuilder, Command, BulkCopy
/// <summary>
Expand Down

0 comments on commit 8852545

Please sign in to comment.