Skip to content

Commit

Permalink
Fix | Handle NRE on Azure federated authentication (#1625)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavoudEshtehari authored Jun 7, 2022
1 parent 9b1996a commit 031afe8
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ public void AssertUnrecoverableStateCountIsCorrect()

internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposable
{
// https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/retry-after#simple-retry-for-errors-with-http-error-codes-500-600
internal const int MsalHttpRetryStatusCode = 429;

// CONNECTION AND STATE VARIABLES
private readonly SqlConnectionPoolGroupProviderInfo _poolGroupProviderInfo; // will only be null when called for ChangePassword, or creating SSE User Instance
private TdsParser _parser;
Expand Down Expand Up @@ -2421,7 +2424,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
// Deal with Msal service exceptions first, retry if 429 received.
catch (MsalServiceException serviceException)
{
if (429 == serviceException.StatusCode)
if (serviceException.StatusCode == MsalHttpRetryStatusCode)
{
RetryConditionHeaderValue retryAfter = serviceException.Headers.RetryAfter;
if (retryAfter.Delta.HasValue)
Expand All @@ -2440,9 +2443,15 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
}
else
{
break;
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MsalServiceException error:> Timeout: {0}", serviceException.ErrorCode);
throw SQL.ActiveDirectoryTokenRetrievingTimeout(Enum.GetName(typeof(SqlAuthenticationMethod), ConnectionOptions.Authentication), serviceException.ErrorCode, serviceException);
}
}
else
{
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MsalServiceException error:> {0}", serviceException.ErrorCode);
throw ADP.CreateSqlException(serviceException, ConnectionOptions, this, username);
}
}
// Deal with normal MsalExceptions.
catch (MsalException msalException)
Expand All @@ -2453,21 +2462,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
{
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MSALException error:> {0}", msalException.ErrorCode);

// Error[0]
SqlErrorCollection sqlErs = new();
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, StringsHelper.GetString(Strings.SQL_MSALFailure, username, ConnectionOptions.Authentication.ToString("G")), ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[1]
string errorMessage1 = StringsHelper.GetString(Strings.SQL_MSALInnerException, msalException.ErrorCode);
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, errorMessage1, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[2]
if (!string.IsNullOrEmpty(msalException.Message))
{
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, msalException.Message, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));
}
SqlException exc = SqlException.CreateException(sqlErs, "", this);
throw exc;
throw ADP.CreateSqlException(msalException, ConnectionOptions, this, username);
}

SqlClientEventSource.Log.TryAdvancedTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken|ADV> {0}, sleeping {1}[Milliseconds]", ObjectID, sleepInterval);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,10 @@ internal static Exception ActiveDirectoryDeviceFlowTimeout()
return ADP.TimeoutException(Strings.SQL_Timeout_Active_Directory_DeviceFlow_Authentication);
}

internal static Exception ActiveDirectoryTokenRetrievingTimeout(string authenticaton, string errorCode, Exception exception)
{
return ADP.TimeoutException(StringsHelper.GetString(Strings.AAD_Token_Retrieving_Timeout, authenticaton, errorCode, exception?.Message), exception);
}

//
// SQL.DataCommand
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -1932,10 +1932,7 @@
<data name="SQL_ParameterDirectionInvalidForOptimizedBinding" xml:space="preserve">
<value>Parameter '{0}' cannot have Direction Output or InputOutput when EnableOptimizedParameterBinding is enabled on the parent command.</value>
</data>
<data name="DataCategory_Update" xml:space="preserve">
<value>Update</value>
<data name="AAD_Token_Retrieving_Timeout" xml:space="preserve">
<value>Connection timed out while retrieving an access token using '{0}' authentication method. Last error: {1}: {2}</value>
</data>
<data name="DataCategory_Xml" xml:space="preserve">
<value>XML</value>
</data>
</root>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ public void AssertUnrecoverableStateCountIsCorrect()

sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposable
{
// https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/wiki/retry-after#simple-retry-for-errors-with-http-error-codes-500-600
internal const int MsalHttpRetryStatusCode = 429;

// Connection re-route limit
internal const int _maxNumberOfRedirectRoute = 10;
Expand Down Expand Up @@ -2870,7 +2872,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
// Deal with Msal service exceptions first, retry if 429 received.
catch (MsalServiceException serviceException)
{
if (429 == serviceException.StatusCode)
if (serviceException.StatusCode == MsalHttpRetryStatusCode)
{
RetryConditionHeaderValue retryAfter = serviceException.Headers.RetryAfter;
if (retryAfter.Delta.HasValue)
Expand All @@ -2889,9 +2891,15 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
}
else
{
break;
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MsalServiceException error:> Timeout: {0}", serviceException.ErrorCode);
throw SQL.ActiveDirectoryTokenRetrievingTimeout(Enum.GetName(typeof(SqlAuthenticationMethod), ConnectionOptions.Authentication), serviceException.ErrorCode, serviceException);
}
}
else
{
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MsalServiceException error:> {0}", serviceException.ErrorCode);
throw ADP.CreateSqlException(serviceException, ConnectionOptions, this, username);
}
}
// Deal with normal MsalExceptions.
catch (MsalException msalException)
Expand All @@ -2902,21 +2910,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
{
SqlClientEventSource.Log.TryTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken.MSALException error:> {0}", msalException.ErrorCode);

// Error[0]
SqlErrorCollection sqlErs = new SqlErrorCollection();
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, StringsHelper.GetString(Strings.SQL_MSALFailure, username, ConnectionOptions.Authentication.ToString("G")), ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[1]
string errorMessage1 = StringsHelper.GetString(Strings.SQL_MSALInnerException, msalException.ErrorCode);
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, errorMessage1, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[2]
if (!string.IsNullOrEmpty(msalException.Message))
{
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS, ConnectionOptions.DataSource, msalException.Message, ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));
}
SqlException exc = SqlException.CreateException(sqlErs, "", this);
throw exc;
throw ADP.CreateSqlException(msalException, ConnectionOptions, this, username);
}

SqlClientEventSource.Log.TryAdvancedTraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken|ADV> {0}, sleeping {1}[Milliseconds]", ObjectID, sleepInterval);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,11 @@ static internal Exception ActiveDirectoryDeviceFlowTimeout()
return ADP.TimeoutException(Strings.SQL_Timeout_Active_Directory_DeviceFlow_Authentication);
}

internal static Exception ActiveDirectoryTokenRetrievingTimeout(string authenticaton, string errorCode, Exception exception)
{
return ADP.TimeoutException(StringsHelper.GetString(Strings.AAD_Token_Retrieving_Timeout, authenticaton, errorCode, exception?.Message), exception);
}

//
// SQL.DataCommand
//
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/Microsoft.Data.SqlClient/netfx/src/Resources/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -4617,4 +4617,7 @@
<data name="SQL_ParameterDirectionInvalidForOptimizedBinding" xml:space="preserve">
<value>Parameter '{0}' cannot have Direction Output or InputOutput when EnableOptimizedParameterBinding is enabled on the parent command.</value>
</data>
<data name="AAD_Token_Retrieving_Timeout" xml:space="preserve">
<value>Connection timed out while retrieving an access token using '{0}' authentication method. Last error: {1}: {2}</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
using Microsoft.Data.SqlClient;
using Microsoft.Win32;
using IsolationLevel = System.Data.IsolationLevel;
using Microsoft.Identity.Client;

#if NETFRAMEWORK
using Microsoft.SqlServer.Server;
Expand Down Expand Up @@ -214,9 +215,9 @@ internal static OverflowException Overflow(string error, Exception inner)
return e;
}

internal static TimeoutException TimeoutException(string error)
internal static TimeoutException TimeoutException(string error, Exception inner = null)
{
TimeoutException e = new(error);
TimeoutException e = new(error, inner);
TraceExceptionAsReturnValue(e);
return e;
}
Expand Down Expand Up @@ -416,6 +417,33 @@ internal static ArgumentException InvalidArgumentLength(string argumentName, int
=> Argument(StringsHelper.GetString(Strings.ADP_InvalidArgumentLength, argumentName, limit));

internal static ArgumentException MustBeReadOnly(string argumentName) => Argument(StringsHelper.GetString(Strings.ADP_MustBeReadOnly, argumentName));

internal static Exception CreateSqlException(MsalException msalException, SqlConnectionString connectionOptions, SqlInternalConnectionTds sender, string username)
{
// Error[0]
SqlErrorCollection sqlErs = new();

sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS,
connectionOptions.DataSource,
StringsHelper.GetString(Strings.SQL_MSALFailure, username, connectionOptions.Authentication.ToString("G")),
ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[1]
string errorMessage1 = StringsHelper.GetString(Strings.SQL_MSALInnerException, msalException.ErrorCode);
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS,
connectionOptions.DataSource, errorMessage1,
ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));

// Error[2]
if (!string.IsNullOrEmpty(msalException.Message))
{
sqlErs.Add(new SqlError(0, (byte)0x00, (byte)TdsEnums.MIN_ERROR_CLASS,
connectionOptions.DataSource, msalException.Message,
ActiveDirectoryAuthentication.MSALGetAccessTokenFunctionName, 0));
}
return SqlException.CreateException(sqlErs, "", sender);
}

#endregion

#region CommandBuilder, Command, BulkCopy
Expand Down

0 comments on commit 031afe8

Please sign in to comment.