Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Okta authentication Deadlock issue #139

Merged
merged 2 commits into from
Jun 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Snowflake.Data.Tests/Mock/MockCloseSessionGone.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ public Task<HttpResponseMessage> GetAsync(IRestRequest request, CancellationToke
return Task.FromResult<HttpResponseMessage>(null);
}

public HttpResponseMessage Get(IRestRequest request)
{
return null;
}

public T Post<T>(IRestRequest postRequest)
{
return Task.Run(async () => await PostAsync<T>(postRequest, CancellationToken.None)).Result;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So in theory this could cause deadlock as well right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Expand Down
7 changes: 6 additions & 1 deletion Snowflake.Data.Tests/Mock/MockOkta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public Task<HttpResponseMessage> GetAsync(IRestRequest request, CancellationToke

public T Post<T>(IRestRequest postRequest)
{
throw new System.NotImplementedException();
return Task.Run(async () => await PostAsync<T>(postRequest, CancellationToken.None)).Result;
}

public Task<T> PostAsync<T>(IRestRequest postRequest, CancellationToken cancellationToken)
Expand Down Expand Up @@ -66,5 +66,10 @@ public Task<T> PostAsync<T>(IRestRequest postRequest, CancellationToken cancella
return Task.FromResult<T>((T)(object)tokenResponse);
}
}

public HttpResponseMessage Get(IRestRequest request)
{
return Task.Run(async () => await GetAsync(request, CancellationToken.None)).Result;
}
}
}
5 changes: 5 additions & 0 deletions Snowflake.Data.Tests/Mock/MockRestSessionExpired.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,10 @@ public Task<HttpResponseMessage> GetAsync(IRestRequest request, CancellationToke
{
return Task.FromResult<HttpResponseMessage>(null);
}

public HttpResponseMessage Get(IRestRequest request)
{
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ public Task<HttpResponseMessage> GetAsync(IRestRequest request, CancellationToke
{
return Task.FromResult<HttpResponseMessage>(null);
}

public HttpResponseMessage Get(IRestRequest request)
{
return null;
}
}
}

5 changes: 5 additions & 0 deletions Snowflake.Data.Tests/Mock/MockServiceName.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,10 @@ public Task<HttpResponseMessage> GetAsync(IRestRequest request, CancellationToke
{
return Task.FromResult<HttpResponseMessage>(null);
}

public HttpResponseMessage Get(IRestRequest request)
{
return null;
}
}
}
12 changes: 9 additions & 3 deletions Snowflake.Data.Tests/SFConnectionIT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ public void TestLoginTimeout()
conn.Open();
Assert.Fail();
}
catch(SnowflakeDbException e)
catch (AggregateException e)
{
Assert.AreEqual(SFError.REQUEST_TIMEOUT.GetAttribute<SFErrorAttr>().errorCode,e.ErrorCode);
Assert.AreEqual(SFError.REQUEST_TIMEOUT.GetAttribute<SFErrorAttr>().errorCode,
((SnowflakeDbException)e.InnerException).ErrorCode);
}
Assert.AreEqual(5, conn.ConnectionTimeout);
}
Expand Down Expand Up @@ -301,9 +302,14 @@ public void TestOktaConnection()
{
using (IDbConnection conn = new SnowflakeDbConnection())
{
conn.ConnectionString = "scheme=http;host=10.211.55.3;port=8080;user=qa@snowflakecomputing.com;password=Test123!;" +
conn.ConnectionString = "scheme=http;host=testaccount.reg.snowflakecomputing.com;port=8082;user=qa@snowflakecomputing.com;password=Test123!;" +
"account=testaccount;role=sysadmin;db=testdb;schema=public;warehouse=regress;authenticator=https://snowflakecomputing.okta.com";
conn.Open();
using (IDbCommand command = conn.CreateCommand())
{
command.CommandText = "SELECT 1";
Assert.AreEqual("1", command.ExecuteScalar().ToString());
}
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken)
session.ProcessLoginResponse(response);
}

void IAuthenticator.Authenticate()
{
var loginRequest = BuildLoginRequest();

var response = session.restRequester.Post<AuthnResponse>(loginRequest);

session.ProcessLoginResponse(response);
}

private SFRestRequest BuildLoginRequest()
{
// build uri
Expand Down
6 changes: 6 additions & 0 deletions Snowflake.Data/Core/Authenticator/IAuthenticator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ internal interface IAuthenticator
/// <returns></returns>
/// <exception cref="SnowflakeDbException"></exception>
Task AuthenticateAsync(CancellationToken cancellationToken);

/// <summary>
/// Process the authentication synchronously
/// </summary>
/// <exception cref="SnowflakeDbException"></exception>
void Authenticate();
}

/// <summary>
Expand Down
36 changes: 36 additions & 0 deletions Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,42 @@ async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken)
session.ProcessLoginResponse(authnResponse);
}

void IAuthenticator.Authenticate()
{
logger.Info("Okta Authentication");

logger.Debug("step 1: get sso and token url");
var authenticatorRestRequest = BuildAuthenticatorRestRequest();
var authenticatorResponse = session.restRequester.Post<AuthnResponse>(authenticatorRestRequest);
FilterFailedResponse(authenticatorResponse);
Uri ssoUrl = new Uri(authenticatorResponse.data.ssoUrl);
Uri tokenUrl = new Uri(authenticatorResponse.data.tokenUrl);

logger.Debug("step 2: verify urls fetched from step 1");
logger.Debug("Checking sso url");
VerifyUrls(ssoUrl, oktaUrl);
logger.Debug("Checking token url");
VerifyUrls(tokenUrl, oktaUrl);

logger.Debug("step 3: get idp onetime token");
IdpTokenRestRequest idpTokenRestRequest = BuildIdpTokenRestRequest(tokenUrl);
var idpResponse = session.restRequester.Post<IdpTokenResponse>(idpTokenRestRequest);
string onetimeToken = idpResponse.CookieToken;

logger.Debug("step 4: get SAML reponse from sso");
var samlRestRequest = BuildSAMLRestRequest(ssoUrl, onetimeToken);
var samlRawResponse = session.restRequester.Get(samlRestRequest);
var samlRawHtmlString = Task.Run(async () => await samlRawResponse.Content.ReadAsStringAsync()).Result;

logger.Debug("step 5: verify postback url in SAML reponse");
VerifyPostbackUrl(samlRawHtmlString);

logger.Debug("step 6: send SAML reponse to snowflake to login");
var loginRestRequest = BuildOktaLoginRestRequest(samlRawHtmlString);
var authnResponse = session.restRequester.Post<AuthnResponse>(loginRestRequest);
session.ProcessLoginResponse(authnResponse);
}

private SFRestRequest BuildAuthenticatorRestRequest()
{
var fedUrl = session.BuildUri(RestPath.SF_AUTHENTICATOR_REQUEST_PATH);
Expand Down
15 changes: 13 additions & 2 deletions Snowflake.Data/Core/RestRequester.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ internal interface IRestRequester

T Post<T>(IRestRequest postRequest);

T Get<T>(IRestRequest request);

Task<T> GetAsync<T>(IRestRequest request, CancellationToken cancellationToken);

T Get<T>(IRestRequest request);

Task<HttpResponseMessage> GetAsync(IRestRequest request, CancellationToken cancellationToken);

HttpResponseMessage Get(IRestRequest request);
}

internal class RestRequester : IRestRequester
Expand Down Expand Up @@ -80,6 +82,15 @@ public Task<HttpResponseMessage> GetAsync(IRestRequest request, CancellationToke

return SendAsync(message, request.GetRestTimeout(), cancellationToken);
}

public HttpResponseMessage Get(IRestRequest request)
{
HttpRequestMessage message = request.ToRequestMessage(HttpMethod.Get);
logger.Debug($"Http method: {message.ToString()}, http request message: {message.ToString()}");

//Run synchronous in a new thread-pool task.
return Task.Run(async () => await GetAsync(request, CancellationToken.None)).Result;
}

private async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request,
TimeSpan timeoutPerRestRequest,
Expand Down
2 changes: 1 addition & 1 deletion Snowflake.Data/Core/SFSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ internal void Open()
authenticator = AuthenticatorFactory.GetAuthenticator(this);
}

authenticator.AuthenticateAsync(CancellationToken.None).GetAwaiter().GetResult();
authenticator.Authenticate();
}

internal async Task OpenAsync(CancellationToken cancellationToken)
Expand Down