From 16510064e861868f649b6bc8fdc54b8a39890812 Mon Sep 17 00:00:00 2001 From: Leo <39062083+lsirac@users.noreply.github.com> Date: Thu, 5 Sep 2024 11:05:35 -0700 Subject: [PATCH] feat: updates UserAuthorizer to support retrieving token response directly with different client auth types (#1486) * feat: updates UserAuthorizer to support retrieving token response directly * fix: cleanup * fix: review * fix: review suggestions * fix: incorrect import * fix: adds missing check --- .../com/google/auth/oauth2/OAuth2Utils.java | 24 + .../google/auth/oauth2/UserAuthorizer.java | 456 ++++++++++++++---- .../javatests/com/google/auth/TestUtils.java | 6 + .../auth/oauth2/MockTokenServerTransport.java | 408 ++++++++++------ .../google/auth/oauth2/OAuth2UtilsTest.java | 101 ++++ .../auth/oauth2/UserAuthorizerTest.java | 254 +++++++++- 6 files changed, 995 insertions(+), 254 deletions(-) create mode 100644 oauth2_http/javatests/com/google/auth/oauth2/OAuth2UtilsTest.java diff --git a/oauth2_http/java/com/google/auth/oauth2/OAuth2Utils.java b/oauth2_http/java/com/google/auth/oauth2/OAuth2Utils.java index 2fb0bda66..31c422bd4 100644 --- a/oauth2_http/java/com/google/auth/oauth2/OAuth2Utils.java +++ b/oauth2_http/java/com/google/auth/oauth2/OAuth2Utils.java @@ -43,6 +43,8 @@ import com.google.api.client.util.SecurityUtils; import com.google.auth.http.AuthHttpConstants; import com.google.auth.http.HttpTransportFactory; +import com.google.common.base.Strings; +import com.google.common.io.BaseEncoding; import com.google.common.io.ByteStreams; import java.io.ByteArrayInputStream; import java.io.File; @@ -80,6 +82,7 @@ class OAuth2Utils { "https://iamcredentials.%s/v1/projects/-/serviceAccounts/%s:generateIdToken"; static final URI TOKEN_SERVER_URI = URI.create("https://oauth2.googleapis.com/token"); + static final URI TOKEN_REVOKE_URI = URI.create("https://oauth2.googleapis.com/revoke"); static final URI USER_AUTH_URI = URI.create("https://accounts.google.com/o/oauth2/auth"); @@ -261,5 +264,26 @@ static PrivateKey privateKeyFromPkcs8(String privateKeyPkcs8) throws IOException throw new IOException("Unexpected exception reading PKCS#8 data", unexpectedException); } + /** + * Generates a Basic Authentication header string for the provided username and password. + * + *

This method constructs a Basic Authentication string using the provided username and + * password. The credentials are encoded in Base64 format and prefixed with the "Basic " scheme + * identifier. + * + * @param username The username for authentication. + * @param password The password for authentication. + * @return The Basic Authentication header value. + * @throws IllegalArgumentException if either username or password is null or empty. + */ + static String generateBasicAuthHeader(String username, String password) { + if (Strings.isNullOrEmpty(username) || Strings.isNullOrEmpty(password)) { + throw new IllegalArgumentException("Username and password cannot be null or empty."); + } + String credentials = username + ":" + password; + String encodedCredentials = BaseEncoding.base64().encode(credentials.getBytes()); + return "Basic " + encodedCredentials; + } + private OAuth2Utils() {} } diff --git a/oauth2_http/java/com/google/auth/oauth2/UserAuthorizer.java b/oauth2_http/java/com/google/auth/oauth2/UserAuthorizer.java index 19305180e..5a008c705 100644 --- a/oauth2_http/java/com/google/auth/oauth2/UserAuthorizer.java +++ b/oauth2_http/java/com/google/auth/oauth2/UserAuthorizer.java @@ -52,10 +52,22 @@ import java.util.Date; import java.util.List; import java.util.Map; +import javax.annotation.Nullable; /** Handles an interactive 3-Legged-OAuth2 (3LO) user consent authorization. */ public class UserAuthorizer { + /** + * Represents the client authentication types as specified in RFC 7591. + * + *

For more details, see RFC 7591. + */ + public enum ClientAuthenticationType { + CLIENT_SECRET_POST, + CLIENT_SECRET_BASIC, + NONE + } + static final URI DEFAULT_CALLBACK_URI = URI.create("/oauth2callback"); private final String TOKEN_STORE_ERROR = "Error parsing stored token data."; @@ -70,38 +82,27 @@ public class UserAuthorizer { private final URI tokenServerUri; private final URI userAuthUri; private final PKCEProvider pkce; + private final ClientAuthenticationType clientAuthenticationType; - /** - * Constructor with all parameters. - * - * @param clientId Client ID to identify the OAuth2 consent prompt - * @param scopes OAuth2 scopes defining the user consent - * @param tokenStore Implementation of a component for long term storage of tokens - * @param callbackUri URI for implementation of the OAuth2 web callback - * @param transportFactory HTTP transport factory, creates the transport used to get access - * tokens. - * @param tokenServerUri URI of the end point that provides tokens - * @param userAuthUri URI of the Web UI for user consent - * @param pkce PKCE implementation - */ - private UserAuthorizer( - ClientId clientId, - Collection scopes, - TokenStore tokenStore, - URI callbackUri, - HttpTransportFactory transportFactory, - URI tokenServerUri, - URI userAuthUri, - PKCEProvider pkce) { - this.clientId = Preconditions.checkNotNull(clientId); - this.scopes = ImmutableList.copyOf(Preconditions.checkNotNull(scopes)); - this.callbackUri = (callbackUri == null) ? DEFAULT_CALLBACK_URI : callbackUri; + /** Internal constructor. See {@link Builder}. */ + private UserAuthorizer(Builder builder) { + this.clientId = Preconditions.checkNotNull(builder.clientId); + this.scopes = ImmutableList.copyOf(Preconditions.checkNotNull(builder.scopes)); + this.callbackUri = (builder.callbackUri == null) ? DEFAULT_CALLBACK_URI : builder.callbackUri; this.transportFactory = - (transportFactory == null) ? OAuth2Utils.HTTP_TRANSPORT_FACTORY : transportFactory; - this.tokenServerUri = (tokenServerUri == null) ? OAuth2Utils.TOKEN_SERVER_URI : tokenServerUri; - this.userAuthUri = (userAuthUri == null) ? OAuth2Utils.USER_AUTH_URI : userAuthUri; - this.tokenStore = (tokenStore == null) ? new MemoryTokensStorage() : tokenStore; - this.pkce = pkce; + (builder.transportFactory == null) + ? OAuth2Utils.HTTP_TRANSPORT_FACTORY + : builder.transportFactory; + this.tokenServerUri = + (builder.tokenServerUri == null) ? OAuth2Utils.TOKEN_SERVER_URI : builder.tokenServerUri; + this.userAuthUri = + (builder.userAuthUri == null) ? OAuth2Utils.USER_AUTH_URI : builder.userAuthUri; + this.tokenStore = (builder.tokenStore == null) ? new MemoryTokensStorage() : builder.tokenStore; + this.pkce = builder.pkce; + this.clientAuthenticationType = + (builder.clientAuthenticationType == null) + ? ClientAuthenticationType.CLIENT_SECRET_POST + : builder.clientAuthenticationType; } /** @@ -162,7 +163,16 @@ public TokenStore getTokenStore() { } /** - * Return an URL that performs the authorization consent prompt web UI. + * Returns the client authentication type as defined in RFC 7591. + * + * @return The {@link ClientAuthenticationType} + */ + public ClientAuthenticationType getClientAuthenticationType() { + return clientAuthenticationType; + } + + /** + * Return a URL that performs the authorization consent prompt web UI. * * @param userId Application's identifier for the end user. * @param state State that is passed on to the OAuth2 callback URI after the consent. @@ -174,7 +184,7 @@ public URL getAuthorizationUrl(String userId, String state, URI baseUri) { } /** - * Return an URL that performs the authorization consent prompt web UI. + * Return a URL that performs the authorization consent prompt web UI. * * @param userId Application's identifier for the end user. * @param state State that is passed on to the OAuth2 callback URI after the consent. @@ -285,61 +295,35 @@ public UserCredentials getCredentialsFromCode(String code, URI baseUri) throws I */ public UserCredentials getCredentialsFromCode( String code, URI baseUri, Map additionalParameters) throws IOException { - Preconditions.checkNotNull(code); - URI resolvedCallbackUri = getCallbackUri(baseUri); - - GenericData tokenData = new GenericData(); - tokenData.put("code", code); - tokenData.put("client_id", clientId.getClientId()); - tokenData.put("client_secret", clientId.getClientSecret()); - tokenData.put("redirect_uri", resolvedCallbackUri); - tokenData.put("grant_type", "authorization_code"); - - if (additionalParameters != null) { - for (Map.Entry entry : additionalParameters.entrySet()) { - tokenData.put(entry.getKey(), entry.getValue()); - } - } - - if (pkce != null) { - tokenData.put("code_verifier", pkce.getCodeVerifier()); - } - - UrlEncodedContent tokenContent = new UrlEncodedContent(tokenData); - HttpRequestFactory requestFactory = transportFactory.create().createRequestFactory(); - HttpRequest tokenRequest = - requestFactory.buildPostRequest(new GenericUrl(tokenServerUri), tokenContent); - tokenRequest.setParser(new JsonObjectParser(OAuth2Utils.JSON_FACTORY)); - - HttpResponse tokenResponse = tokenRequest.execute(); - - GenericJson parsedTokens = tokenResponse.parseAs(GenericJson.class); - String accessTokenValue = - OAuth2Utils.validateString(parsedTokens, "access_token", FETCH_TOKEN_ERROR); - int expiresInSecs = OAuth2Utils.validateInt32(parsedTokens, "expires_in", FETCH_TOKEN_ERROR); - Date expirationTime = new Date(new Date().getTime() + expiresInSecs * 1000); - String scopes = - OAuth2Utils.validateOptionalString( - parsedTokens, OAuth2Utils.TOKEN_RESPONSE_SCOPE, FETCH_TOKEN_ERROR); - AccessToken accessToken = - AccessToken.newBuilder() - .setExpirationTime(expirationTime) - .setTokenValue(accessTokenValue) - .setScopes(scopes) - .build(); - String refreshToken = - OAuth2Utils.validateOptionalString(parsedTokens, "refresh_token", FETCH_TOKEN_ERROR); - + TokenResponseWithConfig tokenResponseWithConfig = + getCredentialsFromCodeInternal(code, baseUri, additionalParameters); return UserCredentials.newBuilder() - .setClientId(clientId.getClientId()) - .setClientSecret(clientId.getClientSecret()) - .setRefreshToken(refreshToken) - .setAccessToken(accessToken) - .setHttpTransportFactory(transportFactory) - .setTokenServerUri(tokenServerUri) + .setClientId(tokenResponseWithConfig.getClientId()) + .setClientSecret(tokenResponseWithConfig.getClientSecret()) + .setAccessToken(tokenResponseWithConfig.getAccessToken()) + .setRefreshToken(tokenResponseWithConfig.getRefreshToken()) + .setHttpTransportFactory(tokenResponseWithConfig.getHttpTransportFactory()) + .setTokenServerUri(tokenResponseWithConfig.getTokenServerUri()) .build(); } + /** + * Handles OAuth2 authorization code exchange and returns a {@link TokenResponseWithConfig} object + * containing the tokens and configuration details. + * + * @param code The authorization code received from the OAuth2 authorization server. + * @param callbackUri The URI to which the authorization server redirected the user after granting + * authorization. + * @param additionalParameters Additional parameters to include in the token exchange request. + * @return A {@link TokenResponseWithConfig} object containing the access token, refresh token (if + * granted), and configuration details used in the OAuth flow. + * @throws IOException If an error occurs during the token exchange process. + */ + public TokenResponseWithConfig getTokenResponseFromAuthCodeExchange( + String code, URI callbackUri, Map additionalParameters) throws IOException { + return getCredentialsFromCodeInternal(code, callbackUri, additionalParameters); + } + /** * Exchanges an authorization code for tokens and stores them. * @@ -418,7 +402,6 @@ public void storeCredentials(String userId, UserCredentials credentials) throws } AccessToken accessToken = credentials.getAccessToken(); String acessTokenValue = null; - String scopes = null; Date expiresBy = null; List grantedScopes = new ArrayList<>(); @@ -450,6 +433,74 @@ protected void monitorCredentials(String userId, UserCredentials credentials) { credentials.addChangeListener(new UserCredentialsListener(userId)); } + private TokenResponseWithConfig getCredentialsFromCodeInternal( + String code, URI baseUri, Map additionalParameters) throws IOException { + Preconditions.checkNotNull(code); + URI resolvedCallbackUri = getCallbackUri(baseUri); + + GenericData tokenData = new GenericData(); + tokenData.put("code", code); + tokenData.put("client_id", clientId.getClientId()); + tokenData.put("redirect_uri", resolvedCallbackUri); + tokenData.put("grant_type", "authorization_code"); + + if (additionalParameters != null) { + for (Map.Entry entry : additionalParameters.entrySet()) { + tokenData.put(entry.getKey(), entry.getValue()); + } + } + + if (pkce != null) { + tokenData.put("code_verifier", pkce.getCodeVerifier()); + } + + if (this.clientAuthenticationType == ClientAuthenticationType.CLIENT_SECRET_POST) { + tokenData.put("client_secret", clientId.getClientSecret()); + } + + HttpRequestFactory requestFactory = transportFactory.create().createRequestFactory(); + UrlEncodedContent tokenContent = new UrlEncodedContent(tokenData); + HttpRequest tokenRequest = + requestFactory.buildPostRequest(new GenericUrl(tokenServerUri), tokenContent); + tokenRequest.setParser(new JsonObjectParser(OAuth2Utils.JSON_FACTORY)); + + if (this.clientAuthenticationType == ClientAuthenticationType.CLIENT_SECRET_BASIC) { + tokenRequest + .getHeaders() + .setAuthorization( + OAuth2Utils.generateBasicAuthHeader( + clientId.getClientId(), clientId.getClientSecret())); + } + + HttpResponse tokenResponse = tokenRequest.execute(); + + GenericJson parsedTokens = tokenResponse.parseAs(GenericJson.class); + String accessTokenValue = + OAuth2Utils.validateString(parsedTokens, "access_token", FETCH_TOKEN_ERROR); + int expiresInSecs = OAuth2Utils.validateInt32(parsedTokens, "expires_in", FETCH_TOKEN_ERROR); + Date expirationTime = new Date(new Date().getTime() + expiresInSecs * 1000); + String scopes = + OAuth2Utils.validateOptionalString( + parsedTokens, OAuth2Utils.TOKEN_RESPONSE_SCOPE, FETCH_TOKEN_ERROR); + AccessToken accessToken = + AccessToken.newBuilder() + .setExpirationTime(expirationTime) + .setTokenValue(accessTokenValue) + .setScopes(scopes) + .build(); + String refreshToken = + OAuth2Utils.validateOptionalString(parsedTokens, "refresh_token", FETCH_TOKEN_ERROR); + + return TokenResponseWithConfig.newBuilder() + .setClientId(clientId.getClientId()) + .setClientSecret(clientId.getClientSecret()) + .setAccessToken(accessToken) + .setRefreshToken(refreshToken) + .setHttpTransportFactory(transportFactory) + .setTokenServerUri(tokenServerUri) + .build(); + } + /** * Implementation of listener used by monitorCredentials to rewrite the credentials when the * tokens are refreshed. @@ -488,6 +539,7 @@ public static class Builder { private Collection scopes; private HttpTransportFactory transportFactory; private PKCEProvider pkce; + private ClientAuthenticationType clientAuthenticationType; protected Builder() {} @@ -500,50 +552,102 @@ protected Builder(UserAuthorizer authorizer) { this.callbackUri = authorizer.callbackUri; this.userAuthUri = authorizer.userAuthUri; this.pkce = new DefaultPKCEProvider(); + this.clientAuthenticationType = authorizer.clientAuthenticationType; } + /** + * Sets the OAuth 2.0 client ID. + * + * @param clientId the client ID + * @return this {@code Builder} object + */ @CanIgnoreReturnValue public Builder setClientId(ClientId clientId) { this.clientId = clientId; return this; } + /** + * Sets the {@link TokenStore} to use for long term token storage. + * + * @param tokenStore the token store + * @return this {@code Builder} object + */ @CanIgnoreReturnValue public Builder setTokenStore(TokenStore tokenStore) { this.tokenStore = tokenStore; return this; } + /** + * Sets the OAuth 2.0 scopes to request. + * + * @param scopes the scopes to request + * @return this {@code Builder} object + */ @CanIgnoreReturnValue public Builder setScopes(Collection scopes) { this.scopes = scopes; return this; } + /** + * Sets the token exchange endpoint. + * + * @param tokenServerUri the token exchange endpoint to use + * @return this {@code Builder} object + */ @CanIgnoreReturnValue public Builder setTokenServerUri(URI tokenServerUri) { this.tokenServerUri = tokenServerUri; return this; } + /** + * Sets the redirect URI registered with your OAuth provider. This is where the user's browser + * will be redirected after granting or denying authorization. + * + * @param callbackUri the redirect URI + * @return this {@code Builder} object + */ @CanIgnoreReturnValue public Builder setCallbackUri(URI callbackUri) { this.callbackUri = callbackUri; return this; } + /** + * Sets the authorization URI where the user is directed to log in and grant authorization. + * + * @param userAuthUri the authorization URI + * @return this {@code Builder} object + */ @CanIgnoreReturnValue public Builder setUserAuthUri(URI userAuthUri) { this.userAuthUri = userAuthUri; return this; } + /** + * Sets the HTTP transport factory. + * + * @param transportFactory the {@code HttpTransportFactory} to set + * @return this {@code Builder} object + */ @CanIgnoreReturnValue public Builder setHttpTransportFactory(HttpTransportFactory transportFactory) { this.transportFactory = transportFactory; return this; } + /** + * Sets the optional {@link PKCEProvider} to enable Proof Key for Code Exchange to be used. This + * enhances security by using a code challenge and verifier to prevent authorization code + * interception attacks. + * + * @param pkce the {@code PKCEProvider} to set + * @return this {@code Builder} object + */ @CanIgnoreReturnValue public Builder setPKCEProvider(PKCEProvider pkce) { if (pkce != null) { @@ -559,6 +663,20 @@ public Builder setPKCEProvider(PKCEProvider pkce) { return this; } + /** + * Sets the optional {@link ClientAuthenticationType}, one of the client authentication methods + * defined in RFC 7591. This specifies how your application authenticates itself to the + * authorization server. + * + * @param clientAuthentication the {@code ClientAuthenticationType} to set + * @return this {@code Builder} object + */ + @CanIgnoreReturnValue + public Builder setClientAuthenticationType(ClientAuthenticationType clientAuthentication) { + this.clientAuthenticationType = clientAuthentication; + return this; + } + public ClientId getClientId() { return clientId; } @@ -591,16 +709,168 @@ public PKCEProvider getPKCEProvider() { return pkce; } + public ClientAuthenticationType getClientAuthenticationType() { + return clientAuthenticationType; + } + public UserAuthorizer build() { - return new UserAuthorizer( - clientId, - scopes, - tokenStore, - callbackUri, - transportFactory, - tokenServerUri, - userAuthUri, - pkce); + return new UserAuthorizer(this); + } + } + + /** + * Represents the response from an OAuth token exchange, including configuration details used to + * initiate the flow. + * + *

This response can be used to initialize the following credentials types: + * + *

{@code
+   * // UserCredentials when Google is the identity provider:
+   * UserCredentials userCredentials = UserCredentials.newBuilder()
+   *     .setHttpTransportFactory(tokenResponseWithConfig.getHttpTransportFactory())
+   *     .setClientId(tokenResponseWithConfig.getClientId())
+   *     .setClientSecret(tokenResponseWithConfig.getClientSecret())
+   *     .setAccessToken(tokenResponseWithConfig.getAccessToken())
+   *     .setRefreshToken(tokenResponseWithConfig.getRefreshToken())
+   *     .setTokenServerUri(tokenResponseWithConfig.getTokenServerUri())
+   *     .build();
+   *
+   * // ExternalAccountAuthorizedUserCredentials when using Workforce Identity Federation:
+   * ExternalAccountAuthorizedUserCredentials externalAccountAuthorizedUserCredentials =
+   *     ExternalAccountAuthorizedUserCredentials.newBuilder()
+   *         .setHttpTransportFactory(tokenResponseWithConfig.getHttpTransportFactory())
+   *         .setClientId(tokenResponseWithConfig.getClientId())
+   *         .setClientSecret(tokenResponseWithConfig.getClientSecret())
+   *         .setAccessToken(tokenResponseWithConfig.getAccessToken())
+   *         .setRefreshToken(tokenResponseWithConfig.getRefreshToken())
+   *         .setTokenUrl(tokenResponseWithConfig.getTokenServerUri().toURL().toString())
+   *         .build();
+   * }
+ */ + public static class TokenResponseWithConfig { + + private final String clientId; + private final String clientSecret; + private final String refreshToken; + private final AccessToken accessToken; + private URI tokenServerUri; + private final HttpTransportFactory httpTransportFactory; + + private TokenResponseWithConfig(Builder builder) { + this.clientId = builder.clientId; + this.clientSecret = builder.clientSecret; + this.accessToken = builder.accessToken; + this.httpTransportFactory = builder.httpTransportFactory; + this.tokenServerUri = builder.tokenServerUri; + this.refreshToken = builder.refreshToken; + } + + /** + * Returns the OAuth 2.0 client ID used. + * + * @return The client ID. + */ + public String getClientId() { + return clientId; + } + + /** + * Returns the OAuth 2.0 client secret used. + * + * @return The client secret. + */ + public String getClientSecret() { + return clientSecret; + } + + /** + * Returns the access token obtained from the token exchange. + * + * @return The access token. + */ + public AccessToken getAccessToken() { + return accessToken; + } + + /** + * Returns the HTTP transport factory used. + * + * @return The HTTP transport factory. + */ + public HttpTransportFactory getHttpTransportFactory() { + return httpTransportFactory; + } + + /** + * Returns the URI of the token server used. + * + * @return The token server URI. + */ + public URI getTokenServerUri() { + return tokenServerUri; + } + + /** + * Returns the refresh token obtained from the token exchange, if available. + * + * @return The refresh token, or null if not granted. + */ + @Nullable + public String getRefreshToken() { + return refreshToken; + } + + static Builder newBuilder() { + return new Builder(); + } + + static class Builder { + private String clientId; + private String clientSecret; + private String refreshToken; + private AccessToken accessToken; + private URI tokenServerUri; + private HttpTransportFactory httpTransportFactory; + + @CanIgnoreReturnValue + Builder setClientId(String clientId) { + this.clientId = clientId; + return this; + } + + @CanIgnoreReturnValue + Builder setClientSecret(String clientSecret) { + this.clientSecret = clientSecret; + return this; + } + + @CanIgnoreReturnValue + Builder setRefreshToken(String refreshToken) { + this.refreshToken = refreshToken; + return this; + } + + @CanIgnoreReturnValue + Builder setAccessToken(AccessToken accessToken) { + this.accessToken = accessToken; + return this; + } + + @CanIgnoreReturnValue + Builder setHttpTransportFactory(HttpTransportFactory httpTransportFactory) { + this.httpTransportFactory = httpTransportFactory; + return this; + } + + @CanIgnoreReturnValue + Builder setTokenServerUri(URI tokenServerUri) { + this.tokenServerUri = tokenServerUri; + return this; + } + + TokenResponseWithConfig build() { + return new TokenResponseWithConfig(this); + } } } } diff --git a/oauth2_http/javatests/com/google/auth/TestUtils.java b/oauth2_http/javatests/com/google/auth/TestUtils.java index f7c53c9f9..99d601da8 100644 --- a/oauth2_http/javatests/com/google/auth/TestUtils.java +++ b/oauth2_http/javatests/com/google/auth/TestUtils.java @@ -47,6 +47,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.UnsupportedEncodingException; +import java.net.URI; import java.net.URLDecoder; import java.text.SimpleDateFormat; import java.util.Calendar; @@ -59,6 +60,11 @@ /** Utilities for test code under com.google.auth. */ public class TestUtils { + public static final URI WORKFORCE_IDENTITY_FEDERATION_AUTH_URI = + URI.create("https://auth.cloud.google/authorize"); + public static final URI WORKFORCE_IDENTITY_FEDERATION_TOKEN_SERVER_URI = + URI.create("https://sts.googleapis.com/v1/oauthtoken"); + private static final JsonFactory JSON_FACTORY = GsonFactory.getDefaultInstance(); public static void assertContainsBearerToken(Map> metadata, String token) { diff --git a/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java b/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java index 95680c02e..a61c185b5 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java @@ -31,6 +31,8 @@ package com.google.auth.oauth2; +import static com.google.auth.TestUtils.WORKFORCE_IDENTITY_FEDERATION_TOKEN_SERVER_URI; + import com.google.api.client.http.LowLevelHttpRequest; import com.google.api.client.http.LowLevelHttpResponse; import com.google.api.client.json.GenericJson; @@ -42,6 +44,7 @@ import com.google.api.client.testing.http.MockLowLevelHttpRequest; import com.google.api.client.testing.http.MockLowLevelHttpResponse; import com.google.auth.TestUtils; +import com.google.auth.oauth2.UserAuthorizer.ClientAuthenticationType; import com.google.common.util.concurrent.Futures; import java.io.IOException; import java.net.URI; @@ -71,6 +74,9 @@ public class MockTokenServerTransport extends MockHttpTransport { private IOException error; private final Queue> responseSequence = new ArrayDeque<>(); private int expiresInSeconds = 3600; + private MockLowLevelHttpRequest request; + private ClientAuthenticationType clientAuthenticationType; + private PKCEProvider pkceProvider; public MockTokenServerTransport() {} @@ -82,6 +88,14 @@ public void setTokenServerUri(URI tokenServerUri) { this.tokenServerUri = tokenServerUri; } + public void setClientAuthType(ClientAuthenticationType type) { + this.clientAuthenticationType = type; + } + + public void setPkceProvider(PKCEProvider pkceProvider) { + this.pkceProvider = pkceProvider; + } + public void addAuthorizationCode( String code, String refreshToken, @@ -93,7 +107,7 @@ public void addAuthorizationCode( this.grantedScopes.put(refreshToken, grantedScopes); if (additionalParameters != null) { - this.additionalParameters.put(refreshToken, additionalParameters); + this.additionalParameters.put(code, additionalParameters); } } @@ -147,6 +161,10 @@ public void setExpiresInSeconds(int expiresInSeconds) { this.expiresInSeconds = expiresInSeconds; } + public MockLowLevelHttpRequest getRequest() { + return request; + } + @Override public LowLevelHttpRequest buildRequest(String method, String url) throws IOException { buildRequestCount++; @@ -158,183 +176,265 @@ public LowLevelHttpRequest buildRequest(String method, String url) throws IOExce final String query = (questionMarkPos > 0) ? url.substring(questionMarkPos + 1) : ""; if (!responseSequence.isEmpty()) { - return new MockLowLevelHttpRequest(url) { - @Override - public LowLevelHttpResponse execute() throws IOException { - try { - return responseSequence.poll().get(); - } catch (ExecutionException e) { - Throwable cause = e.getCause(); - throw (IOException) cause; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Unexpectedly interrupted"); - } - } - }; + request = + new MockLowLevelHttpRequest(url) { + @Override + public LowLevelHttpResponse execute() throws IOException { + try { + return responseSequence.poll().get(); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + throw (IOException) cause; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Unexpectedly interrupted"); + } + } + }; + return request; } if (urlWithoutQuery.equals(tokenServerUri.toString())) { - return new MockLowLevelHttpRequest(url) { - @Override - public LowLevelHttpResponse execute() throws IOException { - - if (!responseSequence.isEmpty()) { - try { - return responseSequence.poll().get(); - } catch (ExecutionException e) { - Throwable cause = e.getCause(); - throw (IOException) cause; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Unexpectedly interrupted"); - } - } - - String content = this.getContentAsString(); - Map query = TestUtils.parseQuery(content); - String accessToken = null; - String refreshToken = null; - String grantedScopesString = null; - boolean generateAccessToken = true; - - String foundId = query.get("client_id"); - boolean isUserEmailScope = false; - if (foundId != null) { - if (!clients.containsKey(foundId)) { - throw new IOException("Client ID not found."); - } - String foundSecret = query.get("client_secret"); - String expectedSecret = clients.get(foundId); - if (foundSecret == null || !foundSecret.equals(expectedSecret)) { - throw new IOException("Client secret not found."); - } - String grantType = query.get("grant_type"); - if (grantType != null && grantType.equals("authorization_code")) { - String foundCode = query.get("code"); - if (!codes.containsKey(foundCode)) { - throw new IOException("Authorization code not found"); + request = + new MockLowLevelHttpRequest(url) { + @Override + public LowLevelHttpResponse execute() throws IOException { + + if (!responseSequence.isEmpty()) { + try { + return responseSequence.poll().get(); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + throw (IOException) cause; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Unexpectedly interrupted"); + } } - refreshToken = codes.get(foundCode); - } else { - refreshToken = query.get("refresh_token"); - } - if (!refreshTokens.containsKey(refreshToken)) { - throw new IOException("Refresh Token not found."); - } - if (refreshToken.equals(REFRESH_TOKEN_WITH_USER_SCOPE)) { - isUserEmailScope = true; - } - accessToken = refreshTokens.get(refreshToken); - - if (grantedScopes.containsKey(refreshToken)) { - grantedScopesString = grantedScopes.get(refreshToken); - } - if (additionalParameters.containsKey(refreshToken)) { - Map additionalParametersMap = additionalParameters.get(refreshToken); - for (Map.Entry entry : additionalParametersMap.entrySet()) { - String key = entry.getKey(); - String expectedValue = entry.getValue(); - if (!query.containsKey(key)) { - throw new IllegalArgumentException("Missing additional parameter: " + key); + String content = this.getContentAsString(); + Map query = TestUtils.parseQuery(content); + String accessToken = null; + String refreshToken = null; + String grantedScopesString = null; + boolean generateAccessToken = true; + + String foundId = query.get("client_id"); + boolean isUserEmailScope = false; + if (foundId != null) { + if (!clients.containsKey(foundId)) { + throw new IOException("Client ID not found."); + } + String foundSecret = query.get("client_secret"); + String expectedSecret = clients.get(foundId); + if ((foundSecret == null || !foundSecret.equals(expectedSecret))) { + throw new IOException("Client secret not found."); + } + String grantType = query.get("grant_type"); + if (grantType != null && grantType.equals("authorization_code")) { + String foundCode = query.get("code"); + if (!codes.containsKey(foundCode)) { + throw new IOException("Authorization code not found"); + } + refreshToken = codes.get(foundCode); } else { - String actualValue = query.get(key); - if (!expectedValue.equals(actualValue)) { - throw new IllegalArgumentException( - "For additional parameter " - + key - + ", Actual value: " - + actualValue - + ", Expected value: " - + expectedValue); + refreshToken = query.get("refresh_token"); + } + if (!refreshTokens.containsKey(refreshToken)) { + throw new IOException("Refresh Token not found."); + } + if (refreshToken.equals(REFRESH_TOKEN_WITH_USER_SCOPE)) { + isUserEmailScope = true; + } + accessToken = refreshTokens.get(refreshToken); + + if (grantedScopes.containsKey(refreshToken)) { + grantedScopesString = grantedScopes.get(refreshToken); + } + validateAdditionalParameters(query); + } else if (query.containsKey("grant_type")) { + String grantType = query.get("grant_type"); + String assertion = query.get("assertion"); + JsonWebSignature signature = JsonWebSignature.parse(JSON_FACTORY, assertion); + if (OAuth2Utils.GRANT_TYPE_JWT_BEARER.equals(grantType)) { + String foundEmail = signature.getPayload().getIssuer(); + if (!serviceAccounts.containsKey(foundEmail)) {} + accessToken = serviceAccounts.get(foundEmail); + String foundTargetAudience = + (String) signature.getPayload().get("target_audience"); + String foundScopes = (String) signature.getPayload().get("scope"); + if ((foundScopes == null || foundScopes.length() == 0) + && (foundTargetAudience == null || foundTargetAudience.length() == 0)) { + throw new IOException("Either target_audience or scopes must be specified."); + } + + if (foundScopes != null && foundTargetAudience != null) { + throw new IOException( + "Only one of target_audience or scopes must be specified."); + } + if (foundTargetAudience != null) { + generateAccessToken = false; + } + + // For GDCH scenario + } else if (OAuth2Utils.TOKEN_TYPE_TOKEN_EXCHANGE.equals(grantType)) { + String foundServiceIdentityName = signature.getPayload().getIssuer(); + if (!gdchServiceAccounts.containsKey(foundServiceIdentityName)) { + throw new IOException( + "GDCH Service Account Service Identity Name not found as issuer."); + } + accessToken = gdchServiceAccounts.get(foundServiceIdentityName); + String foundApiAudience = (String) signature.getPayload().get("api_audience"); + if ((foundApiAudience == null || foundApiAudience.length() == 0)) { + throw new IOException("Api_audience must be specified."); } + } else { + throw new IOException("Service Account Email not found as issuer."); } + } else { + throw new IOException("Unknown token type."); } + + // Create the JSON response + // https://developers.google.com/identity/protocols/OpenIDConnect#server-flow + GenericJson responseContents = new GenericJson(); + responseContents.setFactory(JSON_FACTORY); + responseContents.put("token_type", "Bearer"); + responseContents.put("expires_in", expiresInSeconds); + if (generateAccessToken) { + responseContents.put("access_token", accessToken); + if (refreshToken != null) { + responseContents.put("refresh_token", refreshToken); + } + if (grantedScopesString != null) { + responseContents.put("scope", grantedScopesString); + } + } + if (isUserEmailScope || !generateAccessToken) { + responseContents.put("id_token", ServiceAccountCredentialsTest.DEFAULT_ID_TOKEN); + } + String refreshText = responseContents.toPrettyString(); + + return new MockLowLevelHttpResponse() + .setContentType(Json.MEDIA_TYPE) + .setContent(refreshText); } + }; + return request; + } else if (urlWithoutQuery.equals(OAuth2Utils.TOKEN_REVOKE_URI.toString())) { + request = + new MockLowLevelHttpRequest(url) { + @Override + public LowLevelHttpResponse execute() throws IOException { + Map parameters = TestUtils.parseQuery(this.getContentAsString()); + String token = parameters.get("token"); + if (token == null) { + throw new IOException("Token to revoke not found."); + } + // Token could be access token or refresh token so remove keys and values + refreshTokens.values().removeAll(Collections.singleton(token)); + refreshTokens.remove(token); + return new MockLowLevelHttpResponse().setContentType(Json.MEDIA_TYPE); + } + }; + return request; + } + if (urlWithoutQuery.equals(WORKFORCE_IDENTITY_FEDERATION_TOKEN_SERVER_URI.toString())) { + request = + new MockLowLevelHttpRequest(url) { + @Override + public LowLevelHttpResponse execute() throws IOException { + + if (!responseSequence.isEmpty()) { + try { + return responseSequence.poll().get(); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + throw (IOException) cause; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Unexpectedly interrupted"); + } + } - } else if (query.containsKey("grant_type")) { - String grantType = query.get("grant_type"); - String assertion = query.get("assertion"); - JsonWebSignature signature = JsonWebSignature.parse(JSON_FACTORY, assertion); - if (OAuth2Utils.GRANT_TYPE_JWT_BEARER.equals(grantType)) { - String foundEmail = signature.getPayload().getIssuer(); - if (!serviceAccounts.containsKey(foundEmail)) {} - accessToken = serviceAccounts.get(foundEmail); - String foundTargetAudience = (String) signature.getPayload().get("target_audience"); - String foundScopes = (String) signature.getPayload().get("scope"); - if ((foundScopes == null || foundScopes.length() == 0) - && (foundTargetAudience == null || foundTargetAudience.length() == 0)) { - throw new IOException("Either target_audience or scopes must be specified."); + String content = this.getContentAsString(); + Map query = TestUtils.parseQuery(content); + + // Validate required fields. + if (!query.containsKey("code") + || !query.containsKey("client_id") + || !query.containsKey("redirect_uri") + || !query.containsKey("grant_type")) { + throw new IOException("Invalid request, missing one or more fields."); } - if (foundScopes != null && foundTargetAudience != null) { - throw new IOException("Only one of target_audience or scopes must be specified."); + String clientId = query.get("client_id"); + if (!clients.containsKey(clientId)) { + throw new IOException("Client ID not registered."); } - if (foundTargetAudience != null) { - generateAccessToken = false; + + if (!clients.containsKey(query.get("client_id"))) { + throw new IOException("Client ID not registered."); } - // For GDCH scenario - } else if (OAuth2Utils.TOKEN_TYPE_TOKEN_EXCHANGE.equals(grantType)) { - String foundServiceIdentityName = signature.getPayload().getIssuer(); - if (!gdchServiceAccounts.containsKey(foundServiceIdentityName)) { - throw new IOException( - "GDCH Service Account Service Identity Name not found as issuer."); + String grantType = query.get("grant_type"); + if (!grantType.equals("authorization_code")) { + throw new IOException("Invalid grant_type. Must be authorization_code."); } - accessToken = gdchServiceAccounts.get(foundServiceIdentityName); - String foundApiAudience = (String) signature.getPayload().get("api_audience"); - if ((foundApiAudience == null || foundApiAudience.length() == 0)) { - throw new IOException("Api_audience must be specified."); + + if (pkceProvider != null && !query.containsKey("code_verifier")) { + throw new IOException("Code verifier must be provided."); } - } else { - throw new IOException("Service Account Email not found as issuer."); - } - } else { - throw new IOException("Unknown token type."); - } - // Create the JSON response - // https://developers.google.com/identity/protocols/OpenIDConnect#server-flow - GenericJson responseContents = new GenericJson(); - responseContents.setFactory(JSON_FACTORY); - responseContents.put("token_type", "Bearer"); - responseContents.put("expires_in", expiresInSeconds); - if (generateAccessToken) { - responseContents.put("access_token", accessToken); - if (refreshToken != null) { + validateAdditionalParameters(query); + + // Generate response. + String refreshToken = codes.get(query.get("code")); + String accessToken = getAccessToken(refreshToken); + GenericJson responseContents = new GenericJson(); + responseContents.setFactory(JSON_FACTORY); + responseContents.put("token_type", "Bearer"); + responseContents.put("expires_in", expiresInSeconds); + responseContents.put("access_token", accessToken); responseContents.put("refresh_token", refreshToken); + + if (query.containsKey("scopes")) { + responseContents.put("scope", query.get("scopes")); + } + + String refreshText = responseContents.toPrettyString(); + + return new MockLowLevelHttpResponse() + .setContentType(Json.MEDIA_TYPE) + .setContent(refreshText); } - if (grantedScopesString != null) { - responseContents.put("scope", grantedScopesString); - } - } - if (isUserEmailScope || !generateAccessToken) { - responseContents.put("id_token", ServiceAccountCredentialsTest.DEFAULT_ID_TOKEN); - } - String refreshText = responseContents.toPrettyString(); + }; + return request; + } + return super.buildRequest(method, url); + } - return new MockLowLevelHttpResponse() - .setContentType(Json.MEDIA_TYPE) - .setContent(refreshText); - } - }; - } else if (urlWithoutQuery.equals(OAuth2Utils.TOKEN_REVOKE_URI.toString())) { - return new MockLowLevelHttpRequest(url) { - @Override - public LowLevelHttpResponse execute() throws IOException { - Map parameters = TestUtils.parseQuery(this.getContentAsString()); - String token = parameters.get("token"); - if (token == null) { - throw new IOException("Token to revoke not found."); + private void validateAdditionalParameters(Map query) { + if (additionalParameters.containsKey(query.get("code"))) { + Map additionalParametersMap = additionalParameters.get(query.get("code")); + for (Map.Entry entry : additionalParametersMap.entrySet()) { + String key = entry.getKey(); + String expectedValue = entry.getValue(); + if (!query.containsKey(key)) { + throw new IllegalArgumentException("Missing additional parameter: " + key); + } else { + String actualValue = query.get(key); + if (!expectedValue.equals(actualValue)) { + throw new IllegalArgumentException( + "For additional parameter " + + key + + ", Actual value: " + + actualValue + + ", Expected value: " + + expectedValue); } - // Token could be access token or refresh token so remove keys and values - refreshTokens.values().removeAll(Collections.singleton(token)); - refreshTokens.remove(token); - return new MockLowLevelHttpResponse().setContentType(Json.MEDIA_TYPE); } - }; + } } - return super.buildRequest(method, url); } } diff --git a/oauth2_http/javatests/com/google/auth/oauth2/OAuth2UtilsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2UtilsTest.java new file mode 100644 index 000000000..20f831917 --- /dev/null +++ b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2UtilsTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2024 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.google.auth.oauth2; + +import static com.google.auth.oauth2.OAuth2Utils.generateBasicAuthHeader; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.Test; + +/** Tests for {@link OAuth2Utils}. */ +public class OAuth2UtilsTest { + + @Test + public void testValidCredentials() { + String username = "testUser"; + String password = "testPassword"; + String expectedHeader = "Basic dGVzdFVzZXI6dGVzdFBhc3N3b3Jk"; + + String actualHeader = generateBasicAuthHeader(username, password); + + assertEquals(expectedHeader, actualHeader); + } + + @Test + public void testEmptyUsername_throws() { + String username = ""; + String password = "testPassword"; + + assertThrows( + IllegalArgumentException.class, + () -> { + generateBasicAuthHeader(username, password); + }); + } + + @Test + public void testEmptyPassword_throws() { + String username = "testUser"; + String password = ""; + + assertThrows( + IllegalArgumentException.class, + () -> { + generateBasicAuthHeader(username, password); + }); + } + + @Test + public void testNullUsername_throws() { + String username = null; + String password = "testPassword"; + + assertThrows( + IllegalArgumentException.class, + () -> { + generateBasicAuthHeader(username, password); + }); + } + + @Test + public void testNullPassword_throws() { + String username = "testUser"; + String password = null; + + assertThrows( + IllegalArgumentException.class, + () -> { + generateBasicAuthHeader(username, password); + }); + } +} diff --git a/oauth2_http/javatests/com/google/auth/oauth2/UserAuthorizerTest.java b/oauth2_http/javatests/com/google/auth/oauth2/UserAuthorizerTest.java index e0a8e2753..d53241fed 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/UserAuthorizerTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/UserAuthorizerTest.java @@ -31,17 +31,26 @@ package com.google.auth.oauth2; +import static com.google.auth.TestUtils.WORKFORCE_IDENTITY_FEDERATION_AUTH_URI; +import static com.google.auth.TestUtils.WORKFORCE_IDENTITY_FEDERATION_TOKEN_SERVER_URI; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import com.google.auth.TestUtils; +import com.google.auth.http.HttpTransportFactory; +import com.google.auth.oauth2.UserAuthorizer.ClientAuthenticationType; +import com.google.auth.oauth2.UserAuthorizer.TokenResponseWithConfig; import java.io.IOException; import java.net.URI; import java.net.URL; import java.util.Arrays; +import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.List; @@ -90,6 +99,9 @@ public void constructorMinimum() { assertSame(store, authorizer.getTokenStore()); assertEquals(DUMMY_SCOPES, authorizer.getScopes()); assertEquals(UserAuthorizer.DEFAULT_CALLBACK_URI, authorizer.getCallbackUri()); + assertEquals( + UserAuthorizer.ClientAuthenticationType.CLIENT_SECRET_POST, + authorizer.getClientAuthenticationType()); } @Test @@ -102,12 +114,38 @@ public void constructorCommon() { .setScopes(DUMMY_SCOPES) .setTokenStore(store) .setCallbackUri(CALLBACK_URI) + .setClientAuthenticationType( + UserAuthorizer.ClientAuthenticationType.CLIENT_SECRET_BASIC) .build(); assertSame(CLIENT_ID, authorizer.getClientId()); assertSame(store, authorizer.getTokenStore()); assertEquals(DUMMY_SCOPES, authorizer.getScopes()); assertEquals(CALLBACK_URI, authorizer.getCallbackUri()); + assertEquals( + UserAuthorizer.ClientAuthenticationType.CLIENT_SECRET_BASIC, + authorizer.getClientAuthenticationType()); + } + + @Test + public void constructorWithClientAuthenticationTypeNone() { + TokenStore store = new MemoryTokensStorage(); + + UserAuthorizer authorizer = + UserAuthorizer.newBuilder() + .setClientId(CLIENT_ID) + .setScopes(DUMMY_SCOPES) + .setTokenStore(store) + .setCallbackUri(CALLBACK_URI) + .setClientAuthenticationType(UserAuthorizer.ClientAuthenticationType.NONE) + .build(); + + assertSame(CLIENT_ID, authorizer.getClientId()); + assertSame(store, authorizer.getTokenStore()); + assertEquals(DUMMY_SCOPES, authorizer.getScopes()); + assertEquals(CALLBACK_URI, authorizer.getCallbackUri()); + assertEquals( + UserAuthorizer.ClientAuthenticationType.NONE, authorizer.getClientAuthenticationType()); } @Test(expected = NullPointerException.class) @@ -229,6 +267,157 @@ public void getCredentials_noCredentials_returnsNull() throws IOException { assertNull(credentials); } + @Test + public void testGetTokenResponseFromAuthCodeExchange_convertsCodeToTokens() throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID_VALUE, CLIENT_SECRET); + transportFactory.transport.addAuthorizationCode( + CODE, + REFRESH_TOKEN, + ACCESS_TOKEN_VALUE, + GRANTED_SCOPES_STRING, + /* additionalParameters= */ null); + + UserAuthorizer authorizer = + UserAuthorizer.newBuilder() + .setClientId(CLIENT_ID) + .setScopes(DUMMY_SCOPES) + .setHttpTransportFactory(transportFactory) + .build(); + + TokenResponseWithConfig response = + authorizer.getTokenResponseFromAuthCodeExchange( + CODE, BASE_URI, /* additionalParameters= */ null); + + assertEquals(REFRESH_TOKEN, response.getRefreshToken()); + assertNotNull(response.getAccessToken()); + assertEquals(ACCESS_TOKEN_VALUE, response.getAccessToken().getTokenValue()); + assertEquals(GRANTED_SCOPES, response.getAccessToken().getScopes()); + } + + @Test + public void testGetTokenResponseFromAuthCodeExchange_workforceIdentityFederationClientAuthBasic() + throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID_VALUE, CLIENT_SECRET); + transportFactory.transport.setClientAuthType(ClientAuthenticationType.CLIENT_SECRET_BASIC); + transportFactory.transport.setPkceProvider(new DefaultPKCEProvider()); + transportFactory.transport.addAuthorizationCode( + CODE, + REFRESH_TOKEN, + ACCESS_TOKEN_VALUE, + GRANTED_SCOPES_STRING, + /* additionalParameters= */ null); + + UserAuthorizer authorizer = + UserAuthorizer.newBuilder() + .setClientId(CLIENT_ID) + .setScopes(Collections.singletonList("https://www.googleapis.com/auth/cloud-platform")) + .setTokenServerUri(WORKFORCE_IDENTITY_FEDERATION_TOKEN_SERVER_URI) + .setUserAuthUri(WORKFORCE_IDENTITY_FEDERATION_AUTH_URI) + .setClientAuthenticationType(ClientAuthenticationType.CLIENT_SECRET_BASIC) + .setPKCEProvider(new DefaultPKCEProvider()) + .setHttpTransportFactory(transportFactory) + .build(); + + TokenResponseWithConfig response = + authorizer.getTokenResponseFromAuthCodeExchange( + CODE, BASE_URI, /* additionalParameters= */ null); + + assertEquals(REFRESH_TOKEN, response.getRefreshToken()); + assertNotNull(response.getAccessToken()); + assertEquals(ACCESS_TOKEN_VALUE, response.getAccessToken().getTokenValue()); + + Map> headers = transportFactory.transport.getRequest().getHeaders(); + List authHeader = headers.get("authorization"); + + assertEquals( + OAuth2Utils.generateBasicAuthHeader(CLIENT_ID_VALUE, CLIENT_SECRET), + authHeader.iterator().next()); + assertEquals(1, authHeader.size()); + } + + @Test + public void testGetTokenResponseFromAuthCodeExchange_workforceIdentityFederationNoClientAuth() + throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID_VALUE, CLIENT_SECRET); + transportFactory.transport.setClientAuthType(ClientAuthenticationType.CLIENT_SECRET_POST); + transportFactory.transport.addAuthorizationCode( + CODE, + REFRESH_TOKEN, + ACCESS_TOKEN_VALUE, + GRANTED_SCOPES_STRING, + /* additionalParameters= */ null); + + UserAuthorizer authorizer = + UserAuthorizer.newBuilder() + .setClientId(CLIENT_ID) + .setScopes(Collections.singletonList("https://www.googleapis.com/auth/cloud-platform")) + .setTokenServerUri(WORKFORCE_IDENTITY_FEDERATION_TOKEN_SERVER_URI) + .setUserAuthUri(WORKFORCE_IDENTITY_FEDERATION_AUTH_URI) + .setClientAuthenticationType(ClientAuthenticationType.NONE) + .setHttpTransportFactory(transportFactory) + .build(); + + TokenResponseWithConfig response = + authorizer.getTokenResponseFromAuthCodeExchange( + CODE, BASE_URI, /* additionalParameters= */ null); + + assertEquals(REFRESH_TOKEN, response.getRefreshToken()); + assertNotNull(response.getAccessToken()); + assertEquals(ACCESS_TOKEN_VALUE, response.getAccessToken().getTokenValue()); + + Map> headers = transportFactory.transport.getRequest().getHeaders(); + assertNull(headers.get("authorization")); + } + + @Test + public void testGetTokenResponseFromAuthCodeExchange_missingAuthCode_throws() { + UserAuthorizer authorizer = + UserAuthorizer.newBuilder().setClientId(CLIENT_ID).setScopes(DUMMY_SCOPES).build(); + + assertThrows( + NullPointerException.class, + () -> { + authorizer.getTokenResponseFromAuthCodeExchange( + /* code= */ null, BASE_URI, /* additionalParameters= */ null); + }); + } + + @Test + public void testGetTokenResponseFromAuthCodeExchange_missingAccessToken_throws() + throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID_VALUE, CLIENT_SECRET); + // Missing access token. + transportFactory.transport.addAuthorizationCode( + CODE, + REFRESH_TOKEN, + /* accessToken= */ null, + GRANTED_SCOPES_STRING, + /* additionalParameters= */ null); + + UserAuthorizer authorizer = + UserAuthorizer.newBuilder() + .setClientId(CLIENT_ID) + .setScopes(DUMMY_SCOPES) + .setHttpTransportFactory(transportFactory) + .build(); + + IOException e = + assertThrows( + IOException.class, + () -> { + authorizer.getTokenResponseFromAuthCodeExchange( + CODE, BASE_URI, /* additionalParameters= */ null); + }); + + assertTrue( + e.getMessage() + .contains("Error reading result of Token API:Expected value access_token not found.")); + } + @Test public void getCredentials_storedCredentials_returnsStored() throws IOException { TokenStore tokenStore = new MemoryTokensStorage(); @@ -381,7 +570,7 @@ public void getCredentials_refreshedToken_different_granted_scopes() throws IOEx } @Test - public void getCredentialsFromCode_conevertsCodeToTokens() throws IOException { + public void getCredentialsFromCode_convertsCodeToTokens() throws IOException { MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); transportFactory.transport.addClient(CLIENT_ID_VALUE, CLIENT_SECRET); transportFactory.transport.addAuthorizationCode( @@ -645,12 +834,63 @@ public String getCodeChallenge() { } }; - UserAuthorizer authorizer = - UserAuthorizer.newBuilder() - .setClientId(CLIENT_ID) - .setScopes(DUMMY_SCOPES) - .setTokenStore(new MemoryTokensStorage()) - .setPKCEProvider(pkce) + UserAuthorizer.newBuilder() + .setClientId(CLIENT_ID) + .setScopes(DUMMY_SCOPES) + .setTokenStore(new MemoryTokensStorage()) + .setPKCEProvider(pkce) + .build(); + } + + @Test + public void testTokenResponseWithConfig() { + String clientId = "testClientId"; + String clientSecret = "testClientSecret"; + String refreshToken = "testRefreshToken"; + AccessToken accessToken = new AccessToken("token", new Date()); + URI tokenServerUri = URI.create("https://example.com/token"); + HttpTransportFactory httpTransportFactory = new MockTokenServerTransportFactory(); + + TokenResponseWithConfig tokenResponse = + TokenResponseWithConfig.newBuilder() + .setClientId(clientId) + .setClientSecret(clientSecret) + .setRefreshToken(refreshToken) + .setAccessToken(accessToken) + .setTokenServerUri(tokenServerUri) + .setHttpTransportFactory(httpTransportFactory) .build(); + + assertEquals(clientId, tokenResponse.getClientId()); + assertEquals(clientSecret, tokenResponse.getClientSecret()); + assertEquals(refreshToken, tokenResponse.getRefreshToken()); + assertEquals(accessToken, tokenResponse.getAccessToken()); + assertEquals(tokenServerUri, tokenResponse.getTokenServerUri()); + assertEquals(httpTransportFactory, tokenResponse.getHttpTransportFactory()); + } + + @Test + public void testTokenResponseWithConfig_noRefreshToken() { + String clientId = "testClientId"; + String clientSecret = "testClientSecret"; + AccessToken accessToken = new AccessToken("token", new Date()); + URI tokenServerUri = URI.create("https://example.com/token"); + HttpTransportFactory httpTransportFactory = new MockTokenServerTransportFactory(); + + TokenResponseWithConfig tokenResponse = + TokenResponseWithConfig.newBuilder() + .setClientId(clientId) + .setClientSecret(clientSecret) + .setAccessToken(accessToken) + .setTokenServerUri(tokenServerUri) + .setHttpTransportFactory(httpTransportFactory) + .build(); + + assertEquals(clientId, tokenResponse.getClientId()); + assertEquals(clientSecret, tokenResponse.getClientSecret()); + assertEquals(accessToken, tokenResponse.getAccessToken()); + assertEquals(tokenServerUri, tokenResponse.getTokenServerUri()); + assertEquals(httpTransportFactory, tokenResponse.getHttpTransportFactory()); + assertNull(tokenResponse.getRefreshToken()); } }