Skip to content

Commit

Permalink
SNOW-352846 OAuth Authentication: #2 OAuth Support (#537)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-alhuang authored Jul 12, 2023
1 parent 7d01fb7 commit 2999390
Show file tree
Hide file tree
Showing 15 changed files with 835 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ final class JWTManager extends SecurityManager {
private final transient KeyPair keyPair;

// the token itself
protected final AtomicReference<String> token;
private final AtomicReference<String> token;

/**
* Creates a JWTManager entity for a given account, user and KeyPair with a specified time to
Expand Down
14 changes: 14 additions & 0 deletions src/main/java/net/snowflake/ingest/connection/OAuthClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright (c) 2023 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.ingest.connection;

import java.util.concurrent.atomic.AtomicReference;

/** Interface to perform token refresh request from {@link OAuthManager} */
public interface OAuthClient {
AtomicReference<OAuthCredential> getoAuthCredentialRef();

void refreshToken();
}
55 changes: 55 additions & 0 deletions src/main/java/net/snowflake/ingest/connection/OAuthCredential.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2023 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.ingest.connection;

import java.util.Base64;

/** This class hold credentials for OAuth authentication */
public class OAuthCredential {
private static final String BASIC_AUTH_HEADER_PREFIX = "Basic ";
private final String authHeader;
private final String clientId;
private final String clientSecret;
private String accessToken;
private String refreshToken;
private int expiresIn;

public OAuthCredential(String clientId, String clientSecret, String refreshToken) {
this.authHeader =
BASIC_AUTH_HEADER_PREFIX
+ Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes());
this.clientId = clientId;
this.clientSecret = clientSecret;
this.refreshToken = refreshToken;
}

public String getAuthHeader() {
return authHeader;
}

public String getAccessToken() {
return accessToken;
}

public void setAccessToken(String accessToken) {
this.accessToken = accessToken;
}

public String getRefreshToken() {
return refreshToken;
}

public void setRefreshToken(String refreshToken) {
this.refreshToken = refreshToken;
}

public void setExpiresIn(int expiresIn) {
this.expiresIn = expiresIn;
}

public int getExpiresIn() {
return expiresIn;
}
}
175 changes: 175 additions & 0 deletions src/main/java/net/snowflake/ingest/connection/OAuthManager.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.ingest.connection;

import java.util.concurrent.TimeUnit;
import net.snowflake.client.jdbc.internal.apache.http.client.utils.URIBuilder;
import net.snowflake.ingest.utils.Constants;
import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.SFException;

/** This class manages creating and automatically refresh the OAuth token */
public final class OAuthManager extends SecurityManager {
private static final double DEFAULT_UPDATE_THRESHOLD_RATIO = 0.8;

// the endpoint for token request
private static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request";

// Content type header to specify the encoding
private static final String OAUTH_CONTENT_TYPE_HEADER = "application/x-www-form-urlencoded";

// Properties for token refresh request
private static final String TOKEN_TYPE = "OAUTH";

// Update threshold, a floating-point value representing the ratio between the expiration time of
// an access token and the time needed to update it. It must be a value greater than 0 and less
// than 1. E.g. An access token with expires_in=600 and update_threshold_ratio=0.8 would be
// updated after 600*0.8 = 480.
private final double updateThresholdRatio;

private OAuthClient oAuthClient;

/**
* Creates a OAuthManager entity for a given account, user and OAuthCredential with default time
* to refresh the access token
*
* @param accountName - the snowflake account name of this user
* @param username - the snowflake username of the current user
* @param oAuthCredential - the OAuth credential we're using to connect
* @param baseURIBuilder - the uri builder with common scheme, host and port
* @param telemetryService reference to the telemetry service
*/
OAuthManager(
String accountName,
String username,
OAuthCredential oAuthCredential,
URIBuilder baseURIBuilder,
TelemetryService telemetryService) {
this(
accountName,
username,
oAuthCredential,
baseURIBuilder,
DEFAULT_UPDATE_THRESHOLD_RATIO,
telemetryService);
}

/**
* Creates a OAuthManager entity for a given account, user and OAuthCredential with a specified
* time to refresh the token
*
* @param accountName - the snowflake account name of this user
* @param username - the snowflake username of the current user
* @param oAuthCredential - the OAuth credential we're using to connect
* @param baseURIBuilder - the uri builder with common scheme, host and port
* @param updateThresholdRatio - the ratio between the expiration time of a token and the time
* needed to refresh it.
* @param telemetryService reference to the telemetry service
*/
OAuthManager(
String accountName,
String username,
OAuthCredential oAuthCredential,
URIBuilder baseURIBuilder,
double updateThresholdRatio,
TelemetryService telemetryService) {
// disable telemetry service until jdbc v3.13.34 is released
super(accountName, username, null);

// if any of our arguments are null, throw an exception
if (oAuthCredential == null || baseURIBuilder == null) {
throw new IllegalArgumentException();
}

// check if update threshold is within (0, 1)
if (updateThresholdRatio <= 0 || updateThresholdRatio >= 1) {
throw new IllegalArgumentException("updateThresholdRatio should fall in (0, 1)");
}
this.updateThresholdRatio = updateThresholdRatio;
this.oAuthClient = new SnowflakeOAuthClient(accountName, oAuthCredential, baseURIBuilder);

// generate our first token
refreshToken();
}

/**
* Creates a OAuthManager entity for a given account, user and OAuthClient with a specified time
* to refresh the token. Use for testing only.
*
* @param accountName - the snowflake account name of this user
* @param username - the snowflake username of the current user
* @param oAuthClient - the OAuth client to perform token refresh
* @param updateThresholdRatio - the ratio between the expiration time of a token and the time
* needed to refresh it.
*/
public OAuthManager(
String accountName, String username, OAuthClient oAuthClient, double updateThresholdRatio) {
super(accountName, username, null);

this.updateThresholdRatio = updateThresholdRatio;
this.oAuthClient = oAuthClient;

refreshToken();
}

@Override
String getToken() {
if (refreshFailed.get()) {
throw new SecurityException("getToken request failed due to token refresh failure");
}
return oAuthClient.getoAuthCredentialRef().get().getAccessToken();
}

@Override
String getTokenType() {
return TOKEN_TYPE;
}

/**
* Set refresh token, this method is for refresh token renewal without requiring to restart
* client. This method only works when the authorization type is OAuth.
*
* @param refreshToken the new refresh token
*/
void setRefreshToken(String refreshToken) {
oAuthClient.getoAuthCredentialRef().get().setRefreshToken(refreshToken);
}

/** refreshToken - Get new access token using refresh_token, client_id, client_secret */
private void refreshToken() {
for (int retries = 0; retries < Constants.MAX_OAUTH_REFRESH_TOKEN_RETRY; retries++) {
try {
oAuthClient.refreshToken();

// Schedule next refresh
long nextRefreshDelay =
(long)
(oAuthClient.getoAuthCredentialRef().get().getExpiresIn()
* this.updateThresholdRatio);
tokenRefresher.schedule(this::refreshToken, nextRefreshDelay, TimeUnit.SECONDS);
LOGGER.debug(
"Refresh access token, next refresh is scheduled after {} seconds", nextRefreshDelay);

return;
} catch (SFException e1) {
// Exponential backoff retries
try {
Thread.sleep((1L << retries) * 1000L);
} catch (InterruptedException e2) {
throw new SFException(ErrorCode.OAUTH_REFRESH_TOKEN_ERROR, e2.getMessage());
}
}
}

refreshFailed.set(true);
throw new SecurityException("Fail to refresh access token");
}

/** Currently, it only shuts down the instance of ExecutorService. */
@Override
public void close() {
tokenRefresher.shutdown();
}
}
Loading

0 comments on commit 2999390

Please sign in to comment.