From b88d99d2f79d1e5f8b4b9f3d388b2c0b356668c7 Mon Sep 17 00:00:00 2001 From: Alexandre Dutra Date: Thu, 12 Dec 2024 16:01:42 +0100 Subject: [PATCH] Auth Manager API part 3: OAuth2 Manager --- .../apache/iceberg/rest/auth/AuthConfig.java | 47 +- .../iceberg/rest/auth/AuthManagers.java | 22 +- .../iceberg/rest/auth/AuthProperties.java | 3 + .../iceberg/rest/auth/AuthSessionCache.java | 127 +++++ .../iceberg/rest/auth/OAuth2Manager.java | 256 +++++++++ .../apache/iceberg/rest/auth/OAuth2Util.java | 26 +- .../rest/auth/RefreshingAuthManager.java | 88 +++ .../iceberg/rest/auth/TestAuthManagers.java | 39 ++ .../rest/auth/TestAuthSessionCache.java | 91 +++ .../iceberg/rest/auth/TestOAuth2Manager.java | 526 ++++++++++++++++++ 10 files changed, 1214 insertions(+), 11 deletions(-) create mode 100644 core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java create mode 100644 core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java create mode 100644 core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java create mode 100644 core/src/test/java/org/apache/iceberg/rest/auth/TestAuthSessionCache.java create mode 100644 core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Manager.java diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java index 275884e1184a..1709d4a3e514 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthConfig.java @@ -21,15 +21,16 @@ import java.util.Map; import javax.annotation.Nullable; import org.apache.iceberg.rest.ResourcePaths; +import org.apache.iceberg.util.PropertyUtil; import org.immutables.value.Value; /** - * The purpose of this class is to hold configuration options for {@link - * org.apache.iceberg.rest.auth.OAuth2Util.AuthSession}. + * The purpose of this class is to hold OAuth configuration options for {@link + * OAuth2Util.AuthSession}. */ @Value.Style(redactedMask = "****") -@SuppressWarnings("ImmutablesStyle") @Value.Immutable +@SuppressWarnings({"ImmutablesStyle", "SafeLoggingPropagation"}) public interface AuthConfig { @Nullable @Value.Redacted @@ -47,7 +48,7 @@ default String scope() { return OAuth2Properties.CATALOG_SCOPE; } - @Value.Lazy + @Value.Default @Nullable default Long expiresAtMillis() { return OAuth2Util.expiresAtMillis(token()); @@ -69,4 +70,42 @@ default String oauth2ServerUri() { static ImmutableAuthConfig.Builder builder() { return ImmutableAuthConfig.builder(); } + + static AuthConfig fromProperties(Map properties) { + return builder() + .credential(properties.get(OAuth2Properties.CREDENTIAL)) + .token(properties.get(OAuth2Properties.TOKEN)) + .scope(properties.getOrDefault(OAuth2Properties.SCOPE, OAuth2Properties.CATALOG_SCOPE)) + .oauth2ServerUri( + properties.getOrDefault(OAuth2Properties.OAUTH2_SERVER_URI, ResourcePaths.tokens())) + .optionalOAuthParams(OAuth2Util.buildOptionalParam(properties)) + .keepRefreshed( + PropertyUtil.propertyAsBoolean( + properties, + OAuth2Properties.TOKEN_REFRESH_ENABLED, + OAuth2Properties.TOKEN_REFRESH_ENABLED_DEFAULT)) + .expiresAtMillis(expiresAtMillis(properties)) + .build(); + } + + private static Long expiresAtMillis(Map props) { + Long expiresAtMillis = null; + + if (props.containsKey(OAuth2Properties.TOKEN)) { + expiresAtMillis = OAuth2Util.expiresAtMillis(props.get(OAuth2Properties.TOKEN)); + } + + if (expiresAtMillis == null) { + if (props.containsKey(OAuth2Properties.TOKEN_EXPIRES_IN_MS)) { + long millis = + PropertyUtil.propertyAsLong( + props, + OAuth2Properties.TOKEN_EXPIRES_IN_MS, + OAuth2Properties.TOKEN_EXPIRES_IN_MS_DEFAULT); + expiresAtMillis = System.currentTimeMillis() + millis; + } + } + + return expiresAtMillis; + } } diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java index 42c2b1eeba83..46188c1281c5 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthManagers.java @@ -31,8 +31,23 @@ public class AuthManagers { private AuthManagers() {} public static AuthManager loadAuthManager(String name, Map properties) { - String authType = - properties.getOrDefault(AuthProperties.AUTH_TYPE, AuthProperties.AUTH_TYPE_NONE); + String authType = properties.get(AuthProperties.AUTH_TYPE); + if (authType == null) { + boolean hasCredential = properties.containsKey(OAuth2Properties.CREDENTIAL); + boolean hasToken = properties.containsKey(OAuth2Properties.TOKEN); + if (hasCredential || hasToken) { + LOG.warn( + "Inferring {}={} since property {} was provided. " + + "Please explicitly set {} to avoid this warning.", + AuthProperties.AUTH_TYPE, + AuthProperties.AUTH_TYPE_OAUTH2, + hasCredential ? OAuth2Properties.CREDENTIAL : OAuth2Properties.TOKEN, + AuthProperties.AUTH_TYPE); + authType = AuthProperties.AUTH_TYPE_OAUTH2; + } else { + authType = AuthProperties.AUTH_TYPE_NONE; + } + } String impl; switch (authType.toLowerCase(Locale.ROOT)) { @@ -42,6 +57,9 @@ public static AuthManager loadAuthManager(String name, Map prope case AuthProperties.AUTH_TYPE_BASIC: impl = AuthProperties.AUTH_MANAGER_IMPL_BASIC; break; + case AuthProperties.AUTH_TYPE_OAUTH2: + impl = AuthProperties.AUTH_MANAGER_IMPL_OAUTH2; + break; default: impl = authType; } diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java index bf94311d5578..a4ba2db586a7 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthProperties.java @@ -26,11 +26,14 @@ private AuthProperties() {} public static final String AUTH_TYPE_NONE = "none"; public static final String AUTH_TYPE_BASIC = "basic"; + public static final String AUTH_TYPE_OAUTH2 = "oauth2"; public static final String AUTH_MANAGER_IMPL_NONE = "org.apache.iceberg.rest.auth.NoopAuthManager"; public static final String AUTH_MANAGER_IMPL_BASIC = "org.apache.iceberg.rest.auth.BasicAuthManager"; + public static final String AUTH_MANAGER_IMPL_OAUTH2 = + "org.apache.iceberg.rest.auth.OAuth2Manager"; public static final String BASIC_USERNAME = "rest.auth.basic.username"; public static final String BASIC_PASSWORD = "rest.auth.basic.password"; diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java b/core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java new file mode 100644 index 000000000000..c27f29f5a5ab --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/AuthSessionCache.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import java.time.Duration; +import java.util.concurrent.Executor; +import java.util.concurrent.ForkJoinPool; +import java.util.function.Function; +import java.util.function.LongSupplier; +import javax.annotation.Nullable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; + +/** A cache for {@link AuthSession} instances. */ +public class AuthSessionCache implements AutoCloseable { + + private final Duration sessionTimeout; + private final Executor executor; + private final LongSupplier nanoTimeSupplier; + + private volatile Cache sessionCache; + + /** + * Creates a new cache with the given session timeout, and with default executor and nano time + * supplier for eviction tasks. + * + * @param sessionTimeout the session timeout. Sessions will become eligible for eviction after + * this duration of inactivity. + */ + public AuthSessionCache(Duration sessionTimeout) { + this(sessionTimeout, null, null); + } + + /** + * Creates a new cache with the given session timeout, executor, and nano time supplier. This + * method is useful for testing mostly. + * + * @param sessionTimeout the session timeout. Sessions will become eligible for eviction after + * this duration of inactivity. + * @param executor the executor to use for eviction tasks; if null, the cache will use the + * {@linkplain ForkJoinPool#commonPool() common pool}. The executor will not be closed when + * this cache is closed. + * @param nanoTimeSupplier the supplier for nano time; if null, the cache will use {@link + * System#nanoTime()}. + */ + public AuthSessionCache( + Duration sessionTimeout, + @Nullable Executor executor, + @Nullable LongSupplier nanoTimeSupplier) { + this.sessionTimeout = sessionTimeout; + this.executor = executor; + this.nanoTimeSupplier = nanoTimeSupplier; + } + + /** + * Returns a cached session for the given key, loading it with the given loader if it is not + * already cached. + * + * @param key the key to use for the session. + * @param loader the loader to use to load the session if it is not already cached. + * @param the type of the session. + * @return the cached session. + */ + @SuppressWarnings("unchecked") + public T cachedSession(String key, Function loader) { + return (T) sessionCache().get(key, loader); + } + + @Override + public void close() { + Cache cache = sessionCache; + this.sessionCache = null; + if (cache != null) { + cache.invalidateAll(); + cache.cleanUp(); + } + } + + @VisibleForTesting + Cache sessionCache() { + if (sessionCache == null) { + synchronized (this) { + if (sessionCache == null) { + this.sessionCache = newSessionCache(sessionTimeout, executor, nanoTimeSupplier); + } + } + } + return sessionCache; + } + + private static Cache newSessionCache( + Duration sessionTimeout, Executor executor, LongSupplier nanoTimeSupplier) { + Caffeine builder = + Caffeine.newBuilder() + .expireAfterAccess(sessionTimeout) + .removalListener( + (id, auth, cause) -> { + if (auth != null) { + auth.close(); + } + }); + if (executor != null) { + builder.executor(executor); + } + if (nanoTimeSupplier != null) { + builder.ticker(nanoTimeSupplier::getAsLong); + } + return builder.build(); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java new file mode 100644 index 000000000000..271b8612f87d --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Manager.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.catalog.SessionCatalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.RESTUtil; +import org.apache.iceberg.rest.ResourcePaths; +import org.apache.iceberg.rest.responses.OAuthTokenResponse; +import org.apache.iceberg.util.PropertyUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OAuth2Manager extends RefreshingAuthManager { + + private static final Logger LOG = LoggerFactory.getLogger(OAuth2Manager.class); + + private static final List TOKEN_PREFERENCE_ORDER = + ImmutableList.of( + OAuth2Properties.ID_TOKEN_TYPE, + OAuth2Properties.ACCESS_TOKEN_TYPE, + OAuth2Properties.JWT_TOKEN_TYPE, + OAuth2Properties.SAML2_TOKEN_TYPE, + OAuth2Properties.SAML1_TOKEN_TYPE); + + // Auth-related properties that are allowed to be passed to the table session + private static final Set TABLE_SESSION_ALLOW_LIST = + ImmutableSet.builder() + .add(OAuth2Properties.TOKEN) + .addAll(TOKEN_PREFERENCE_ORDER) + .build(); + + private RESTClient client; + private long startTimeMillis; + private OAuthTokenResponse authResponse; + private AuthSessionCache sessionCache; + + public OAuth2Manager(String name) { + super(name + "-token-refresh"); + } + + @Override + public OAuth2Util.AuthSession initSession(RESTClient initClient, Map properties) { + warnIfDeprecatedTokenEndpointUsed(properties); + AuthConfig config = AuthConfig.fromProperties(properties); + Map headers = OAuth2Util.authHeaders(config.token()); + OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(headers, config); + if (config.credential() != null && !config.credential().isEmpty()) { + // Do not enable token refresh here since this is a short-lived session, + // but keep track of the start time, so that token refresh can be + // enabled later on when catalogSession is called. + this.startTimeMillis = System.currentTimeMillis(); + this.authResponse = + OAuth2Util.fetchToken( + initClient, + headers, + config.credential(), + config.scope(), + config.oauth2ServerUri(), + config.optionalOAuthParams()); + return OAuth2Util.AuthSession.fromTokenResponse( + initClient, null, authResponse, startTimeMillis, session); + } else if (config.token() != null) { + return OAuth2Util.AuthSession.fromAccessToken( + initClient, null, config.token(), null, session); + } + return session; + } + + @Override + public OAuth2Util.AuthSession catalogSession( + RESTClient sharedClient, Map properties) { + this.client = sharedClient; + this.sessionCache = newSessionCache(properties); + AuthConfig config = AuthConfig.fromProperties(properties); + Map headers = OAuth2Util.authHeaders(config.token()); + OAuth2Util.AuthSession session = new OAuth2Util.AuthSession(headers, config); + keepRefreshed(config.keepRefreshed()); + // authResponse comes from the init phase: this means we already fetched a token + // so reuse it now and turn token refresh on. + if (authResponse != null) { + return OAuth2Util.AuthSession.fromTokenResponse( + client, refreshExecutor(), authResponse, startTimeMillis, session); + } else if (config.credential() != null && !config.credential().isEmpty()) { + OAuthTokenResponse response = + OAuth2Util.fetchToken( + sharedClient, + headers, + config.credential(), + config.scope(), + config.oauth2ServerUri(), + config.optionalOAuthParams()); + return OAuth2Util.AuthSession.fromTokenResponse( + sharedClient, refreshExecutor(), response, System.currentTimeMillis(), session); + } else if (config.token() != null) { + return OAuth2Util.AuthSession.fromAccessToken( + client, refreshExecutor(), config.token(), config.expiresAtMillis(), session); + } + return session; + } + + @Override + public OAuth2Util.AuthSession contextualSession( + SessionCatalog.SessionContext context, AuthSession parent) { + return maybeCreateChildSession( + context.credentials(), + context.properties(), + ignored -> context.sessionId(), + (OAuth2Util.AuthSession) parent); + } + + @Override + public OAuth2Util.AuthSession tableSession( + TableIdentifier table, Map properties, AuthSession parent) { + return maybeCreateChildSession( + Maps.filterKeys(properties, TABLE_SESSION_ALLOW_LIST::contains), + properties, + properties::get, + (OAuth2Util.AuthSession) parent); + } + + @Override + public void close() { + try { + super.close(); + } finally { + AuthSessionCache cache = sessionCache; + this.sessionCache = null; + if (cache != null) { + cache.close(); + } + } + } + + protected AuthSessionCache newSessionCache(Map properties) { + return new AuthSessionCache(sessionTimeout(properties)); + } + + protected OAuth2Util.AuthSession maybeCreateChildSession( + Map credentials, + Map properties, + Function cacheKeyFunc, + OAuth2Util.AuthSession parent) { + if (credentials != null) { + // use the bearer token without exchanging + if (credentials.containsKey(OAuth2Properties.TOKEN)) { + String token = credentials.get(OAuth2Properties.TOKEN); + return sessionCache.cachedSession( + cacheKeyFunc.apply(OAuth2Properties.TOKEN), + k -> newSessionFromAccessToken(token, properties, parent)); + } + + if (credentials.containsKey(OAuth2Properties.CREDENTIAL)) { + // fetch a token using the client credentials flow + String credential = credentials.get(OAuth2Properties.CREDENTIAL); + return sessionCache.cachedSession( + cacheKeyFunc.apply(OAuth2Properties.CREDENTIAL), + k -> newSessionFromCredential(credential, parent)); + } + + for (String tokenType : TOKEN_PREFERENCE_ORDER) { + if (credentials.containsKey(tokenType)) { + // exchange the token for an access token using the token exchange flow + String token = credentials.get(tokenType); + return sessionCache.cachedSession( + cacheKeyFunc.apply(tokenType), + k -> newSessionFromTokenExchange(token, tokenType, parent)); + } + } + } + + return parent; + } + + protected OAuth2Util.AuthSession newSessionFromAccessToken( + String token, Map properties, OAuth2Util.AuthSession parent) { + Long expiresAtMillis = AuthConfig.fromProperties(properties).expiresAtMillis(); + return OAuth2Util.AuthSession.fromAccessToken( + client, refreshExecutor(), token, expiresAtMillis, parent); + } + + protected OAuth2Util.AuthSession newSessionFromCredential( + String credential, OAuth2Util.AuthSession parent) { + return OAuth2Util.AuthSession.fromCredential(client, refreshExecutor(), credential, parent); + } + + protected OAuth2Util.AuthSession newSessionFromTokenExchange( + String token, String tokenType, OAuth2Util.AuthSession parent) { + return OAuth2Util.AuthSession.fromTokenExchange( + client, refreshExecutor(), token, tokenType, parent); + } + + private static void warnIfDeprecatedTokenEndpointUsed(Map properties) { + if (usesDeprecatedTokenEndpoint(properties)) { + String credential = properties.get(OAuth2Properties.CREDENTIAL); + String initToken = properties.get(OAuth2Properties.TOKEN); + boolean hasCredential = credential != null && !credential.isEmpty(); + boolean hasInitToken = initToken != null; + if (hasInitToken || hasCredential) { + LOG.warn( + "Iceberg REST client is missing the OAuth2 server URI configuration and defaults to {}/{}. " + + "This automatic fallback will be removed in a future Iceberg release." + + "It is recommended to configure the OAuth2 endpoint using the '{}' property to be prepared. " + + "This warning will disappear if the OAuth2 endpoint is explicitly configured. " + + "See https://github.com/apache/iceberg/issues/10537", + RESTUtil.stripTrailingSlash(properties.get(CatalogProperties.URI)), + ResourcePaths.tokens(), + OAuth2Properties.OAUTH2_SERVER_URI); + } + } + } + + private static boolean usesDeprecatedTokenEndpoint(Map properties) { + if (properties.containsKey(OAuth2Properties.OAUTH2_SERVER_URI)) { + String oauth2ServerUri = properties.get(OAuth2Properties.OAUTH2_SERVER_URI); + boolean relativePath = !oauth2ServerUri.startsWith("http"); + boolean sameHost = oauth2ServerUri.startsWith(properties.get(CatalogProperties.URI)); + return relativePath || sameHost; + } + return true; + } + + private static Duration sessionTimeout(Map props) { + return Duration.ofMillis( + PropertyUtil.propertyAsLong( + props, + CatalogProperties.AUTH_SESSION_TIMEOUT_MS, + CatalogProperties.AUTH_SESSION_TIMEOUT_MS_DEFAULT)); + } +} diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java index 1757ae653cc9..2bcf592d2aab 100644 --- a/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java +++ b/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Util.java @@ -43,6 +43,9 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.rest.ErrorHandlers; +import org.apache.iceberg.rest.HTTPHeaders; +import org.apache.iceberg.rest.HTTPRequest; +import org.apache.iceberg.rest.ImmutableHTTPRequest; import org.apache.iceberg.rest.RESTClient; import org.apache.iceberg.rest.RESTUtil; import org.apache.iceberg.rest.ResourcePaths; @@ -451,18 +454,26 @@ static Long expiresAtMillis(String token) { } /** Class to handle authorization headers and token refresh. */ - public static class AuthSession { + public static class AuthSession implements org.apache.iceberg.rest.auth.AuthSession { private static int tokenRefreshNumRetries = 5; private static final long MAX_REFRESH_WINDOW_MILLIS = 300_000; // 5 minutes private static final long MIN_REFRESH_WAIT_MILLIS = 10; private volatile Map headers; private volatile AuthConfig config; - public AuthSession(Map baseHeaders, AuthConfig config) { - this.headers = RESTUtil.merge(baseHeaders, authHeaders(config.token())); + public AuthSession(Map headers, AuthConfig config) { + this.headers = ImmutableMap.copyOf(headers); this.config = config; } + @Override + public HTTPRequest authenticate(HTTPRequest request) { + HTTPHeaders newHeaders = request.headers().putIfAbsent(HTTPHeaders.of(headers())); + return newHeaders.equals(request.headers()) + ? request + : ImmutableHTTPRequest.builder().from(request).headers(newHeaders).build(); + } + public Map headers() { return headers; } @@ -487,6 +498,11 @@ public synchronized void stopRefreshing() { this.config = ImmutableAuthConfig.copyOf(config).withKeepRefreshed(false); } + @Override + public void close() { + stopRefreshing(); + } + public String credential() { return config.credential(); } @@ -647,7 +663,7 @@ public static AuthSession fromAccessToken( AuthSession parent) { AuthSession session = new AuthSession( - parent.headers(), + RESTUtil.merge(parent.headers(), authHeaders(token)), AuthConfig.builder() .from(parent.config()) .token(token) @@ -727,7 +743,7 @@ private static AuthSession fromTokenResponse( } AuthSession session = new AuthSession( - parent.headers(), + RESTUtil.merge(parent.headers(), authHeaders(response.token())), AuthConfig.builder() .from(parent.config()) .token(response.token()) diff --git a/core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java b/core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java new file mode 100644 index 000000000000..2b443e0ea5c1 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/rest/auth/RefreshingAuthManager.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import java.util.List; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.iceberg.util.ThreadPools; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An {@link AuthManager} that provides machinery for refreshing authentication data asynchronously, + * using a background thread pool. + */ +public abstract class RefreshingAuthManager implements AuthManager { + + private static final Logger LOG = LoggerFactory.getLogger(RefreshingAuthManager.class); + + private final String executorNamePrefix; + private boolean keepRefreshed = true; + private volatile ScheduledExecutorService refreshExecutor; + + protected RefreshingAuthManager(String executorNamePrefix) { + this.executorNamePrefix = executorNamePrefix; + } + + public void keepRefreshed(boolean keep) { + this.keepRefreshed = keep; + } + + @Override + public void close() { + ScheduledExecutorService service = refreshExecutor; + this.refreshExecutor = null; + if (service != null) { + List tasks = service.shutdownNow(); + tasks.forEach( + task -> { + if (task instanceof Future) { + ((Future) task).cancel(true); + } + }); + + try { + if (!service.awaitTermination(1, TimeUnit.MINUTES)) { + LOG.warn("Timed out waiting for refresh executor to terminate"); + } + } catch (InterruptedException e) { + LOG.warn("Interrupted while waiting for refresh executor to terminate", e); + Thread.currentThread().interrupt(); + } + } + } + + @Nullable + protected ScheduledExecutorService refreshExecutor() { + if (!keepRefreshed) { + return null; + } + if (refreshExecutor == null) { + synchronized (this) { + if (refreshExecutor == null) { + this.refreshExecutor = ThreadPools.newScheduledPool(executorNamePrefix, 1); + } + } + } + return refreshExecutor; + } +} diff --git a/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java index 21bd8c1b2963..d49f398d7a47 100644 --- a/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java +++ b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthManagers.java @@ -43,6 +43,45 @@ public void after() { System.setErr(standardErr); } + @Test + void oauth2Explicit() { + try (AuthManager manager = + AuthManagers.loadAuthManager( + "test", Map.of(AuthProperties.AUTH_TYPE, AuthProperties.AUTH_TYPE_OAUTH2))) { + assertThat(manager).isInstanceOf(OAuth2Manager.class); + } + assertThat(streamCaptor.toString()) + .contains("Loading AuthManager implementation: org.apache.iceberg.rest.auth.OAuth2Manager"); + } + + @Test + void oauth2InferredFromToken() { + try (AuthManager manager = + AuthManagers.loadAuthManager("test", Map.of(OAuth2Properties.TOKEN, "irrelevant"))) { + assertThat(manager).isInstanceOf(OAuth2Manager.class); + } + assertThat(streamCaptor.toString()) + .contains( + "Inferring rest.auth.type=oauth2 since property token was provided. " + + "Please explicitly set rest.auth.type to avoid this warning."); + assertThat(streamCaptor.toString()) + .contains("Loading AuthManager implementation: org.apache.iceberg.rest.auth.OAuth2Manager"); + } + + @Test + void oauth2InferredFromCredential() { + try (AuthManager manager = + AuthManagers.loadAuthManager("test", Map.of(OAuth2Properties.CREDENTIAL, "irrelevant"))) { + assertThat(manager).isInstanceOf(OAuth2Manager.class); + } + assertThat(streamCaptor.toString()) + .contains( + "Inferring rest.auth.type=oauth2 since property credential was provided. " + + "Please explicitly set rest.auth.type to avoid this warning."); + assertThat(streamCaptor.toString()) + .contains("Loading AuthManager implementation: org.apache.iceberg.rest.auth.OAuth2Manager"); + } + @Test void noop() { try (AuthManager manager = AuthManagers.loadAuthManager("test", Map.of())) { diff --git a/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthSessionCache.java b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthSessionCache.java new file mode 100644 index 000000000000..52b742d2c536 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/rest/auth/TestAuthSessionCache.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; + +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +class TestAuthSessionCache { + + @Test + void cachedHitsAndMisses() { + AuthSessionCache cache = + new AuthSessionCache(Duration.ofHours(1), Runnable::run, System::nanoTime); + AuthSession session1 = Mockito.mock(AuthSession.class); + AuthSession session2 = Mockito.mock(AuthSession.class); + + @SuppressWarnings("unchecked") + Function loader = Mockito.mock(Function.class); + Mockito.when(loader.apply("key1")).thenReturn(session1); + Mockito.when(loader.apply("key2")).thenReturn(session2); + + AuthSession session = cache.cachedSession("key1", loader); + assertThat(session).isNotNull().isSameAs(session1); + + session = cache.cachedSession("key1", loader); + assertThat(session).isNotNull().isSameAs(session1); + + session = cache.cachedSession("key2", loader); + assertThat(session).isNotNull().isSameAs(session2); + + session = cache.cachedSession("key2", loader); + assertThat(session).isNotNull().isSameAs(session2); + + Mockito.verify(loader, times(1)).apply("key1"); + Mockito.verify(loader, times(1)).apply("key2"); + + assertThat(cache.sessionCache().asMap()).hasSize(2); + cache.close(); + assertThat(cache.sessionCache().asMap()).isEmpty(); + + Mockito.verify(session1).close(); + Mockito.verify(session2).close(); + } + + @Test + void cacheEviction() { + AtomicLong ticker = new AtomicLong(0); + AuthSessionCache cache = new AuthSessionCache(Duration.ofHours(1), Runnable::run, ticker::get); + AuthSession session1 = Mockito.mock(AuthSession.class); + + @SuppressWarnings("unchecked") + Function loader = Mockito.mock(Function.class); + Mockito.when(loader.apply("key1")).thenReturn(session1); + + AuthSession session = cache.cachedSession("key1", loader); + assertThat(session).isNotNull().isSameAs(session1); + + Mockito.verify(loader, times(1)).apply("key1"); + Mockito.verify(session1, never()).close(); + + ticker.set(TimeUnit.HOURS.toNanos(1)); + cache.sessionCache().cleanUp(); + Mockito.verify(session1).close(); + + cache.close(); + } +} diff --git a/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Manager.java b/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Manager.java new file mode 100644 index 000000000000..1fba602d4aa3 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/rest/auth/TestOAuth2Manager.java @@ -0,0 +1,526 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.rest.auth; + +import static java.util.Map.entry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.when; + +import java.time.Duration; +import java.util.Map; +import org.apache.iceberg.catalog.SessionCatalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.rest.RESTClient; +import org.apache.iceberg.rest.responses.OAuthTokenResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +class TestOAuth2Manager { + + private RESTClient client; + + @BeforeEach + void before() { + client = Mockito.mock(RESTClient.class); + when(client.postForm(any(), any(), eq(OAuthTokenResponse.class), anyMap(), any())) + .thenReturn( + OAuthTokenResponse.builder() + .withToken("test") + .addScope("scope") + .withIssuedTokenType(OAuth2Properties.ACCESS_TOKEN_TYPE) + .withTokenType("bearer") + .setExpirationInSeconds(3600) + .build()); + } + + @Test + void initSessionNoOAuth2Properties() { + Map properties = Map.of(); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession session = manager.initSession(client, properties)) { + assertThat(session.headers()).isEmpty(); + assertThat(manager) + .extracting("refreshExecutor") + .as("should not create refresh executor for init session") + .isNull(); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void initSessionTokenProvided() { + Map properties = Map.of(OAuth2Properties.TOKEN, "test"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession session = manager.initSession(client, properties)) { + assertThat(session.headers()).containsOnly(entry("Authorization", "Bearer test")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should not create refresh executor for init session") + .isNull(); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void initSessionCredentialsProvided() { + Map properties = Map.of(OAuth2Properties.CREDENTIAL, "client:secret"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession session = manager.initSession(client, properties)) { + assertThat(session.headers()).containsOnly(entry("Authorization", "Bearer test")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should not create refresh executor for init session") + .isNull(); + } + Mockito.verify(client) + .postForm( + any(), + eq( + Map.of( + "grant_type", "client_credentials", + "client_id", "client", + "client_secret", "secret", + "scope", "catalog")), + eq(OAuthTokenResponse.class), + eq(Map.of()), + any()); + } + + @Test + void catalogSessionNoOAuth2Properties() { + Map properties = Map.of(); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties)) { + assertThat(catalogSession.headers()).isEmpty(); + assertThat(manager) + .extracting("refreshExecutor") + .as("should not create refresh executor when no token and no credentials provided") + .isNull(); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void catalogSessionTokenProvided() { + Map properties = Map.of(OAuth2Properties.TOKEN, "test"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties)) { + assertThat(catalogSession.headers()).containsOnly(entry("Authorization", "Bearer test")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should create refresh executor when token provided") + .isNotNull(); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void catalogSessionRefreshDisabled() { + Map properties = + Map.of(OAuth2Properties.TOKEN, "test", OAuth2Properties.TOKEN_REFRESH_ENABLED, "false"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties)) { + assertThat(catalogSession.headers()).containsOnly(entry("Authorization", "Bearer test")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should not create refresh executor when refresh disabled") + .isNull(); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void catalogSessionCredentialsProvidedWithInitSession() { + // Emulates the cases where the credentials are exchanged for a token when initSession is + // called, and the obtained token is used for the catalog session. + Map properties = Map.of(OAuth2Properties.CREDENTIAL, "client:secret"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession ignored = manager.initSession(client, properties); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties)) { + assertThat(catalogSession.headers()).containsOnly(entry("Authorization", "Bearer test")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should create refresh executor when credentials provided") + .isNotNull(); + } + Mockito.verify(client) + .postForm( + any(), + eq( + Map.of( + "grant_type", "client_credentials", + "client_id", "client", + "client_secret", "secret", + "scope", "catalog")), + eq(OAuthTokenResponse.class), + eq(Map.of()), + any()); + } + + @Test + void catalogSessionCredentialsProvidedWithoutInitSession() { + // Emulate the case where initSession is not called before catalogSession, + // so the credentials are exchanged for a token during the catalogSession call. + Map properties = Map.of(OAuth2Properties.CREDENTIAL, "client:secret"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties)) { + assertThat(catalogSession.headers()).containsOnly(entry("Authorization", "Bearer test")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should create refresh executor when credentials provided") + .isNotNull(); + } + Mockito.verify(client) + .postForm( + any(), + eq( + Map.of( + "grant_type", "client_credentials", + "client_id", "client", + "client_secret", "secret", + "scope", "catalog")), + eq(OAuthTokenResponse.class), + eq(Map.of()), + any()); + } + + @Test + void contextualSessionEmptyContext() { + SessionCatalog.SessionContext context = SessionCatalog.SessionContext.createEmpty(); + Map properties = Map.of(); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties); + OAuth2Util.AuthSession contextualSession = + manager.contextualSession(context, catalogSession)) { + assertThat(contextualSession).isSameAs(catalogSession); + assertThat(manager) + .extracting("refreshExecutor") + .as("should not create refresh executor when no context credentials provided") + .isNull(); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should not create session cache for empty context") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).isEmpty()); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void contextualSessionTokenProvided() { + SessionCatalog.SessionContext context = + new SessionCatalog.SessionContext( + "test", "test", Map.of(OAuth2Properties.TOKEN, "context-token"), Map.of()); + Map properties = Map.of(); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties); + OAuth2Util.AuthSession contextualSession = + manager.contextualSession(context, catalogSession)) { + assertThat(contextualSession).isNotSameAs(catalogSession); + assertThat(contextualSession.headers()) + .containsOnly(entry("Authorization", "Bearer context-token")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should create refresh executor when contextual session created") + .isNotNull(); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should create session cache for context with token") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1)); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void contextualSessionCredentialsProvided() { + SessionCatalog.SessionContext context = + new SessionCatalog.SessionContext( + "test", "test", Map.of(OAuth2Properties.CREDENTIAL, "client:secret"), Map.of()); + Map properties = Map.of(); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties); + OAuth2Util.AuthSession contextualSession = + manager.contextualSession(context, catalogSession)) { + assertThat(contextualSession).isNotSameAs(catalogSession); + assertThat(contextualSession.headers()).containsOnly(entry("Authorization", "Bearer test")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should create refresh executor when contextual session created") + .isNotNull(); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should create session cache for context with credentials") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1)); + } + Mockito.verify(client) + .postForm( + any(), + eq( + Map.of( + "grant_type", "client_credentials", + "client_id", "client", + "client_secret", "secret", + "scope", "catalog")), + eq(OAuthTokenResponse.class), + eq(Map.of()), + any()); + } + + @Test + void contextualSessionTokenExchange() { + SessionCatalog.SessionContext context = + new SessionCatalog.SessionContext( + "test", "test", Map.of(OAuth2Properties.ACCESS_TOKEN_TYPE, "context-token"), Map.of()); + Map properties = Map.of(OAuth2Properties.TOKEN, "catalog-token"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties); + OAuth2Util.AuthSession contextualSession = + manager.contextualSession(context, catalogSession)) { + assertThat(contextualSession.headers()).containsOnly(entry("Authorization", "Bearer test")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should create refresh executor when contextual session created") + .isNotNull(); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should create session cache for context with token exchange") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1)); + } + Mockito.verify(client) + .postForm( + any(), + eq( + Map.of( + "grant_type", "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token", "context-token", + "subject_token_type", "urn:ietf:params:oauth:token-type:access_token", + "actor_token", "catalog-token", + "actor_token_type", "urn:ietf:params:oauth:token-type:access_token", + "scope", "catalog")), + eq(OAuthTokenResponse.class), + eq(Map.of("Authorization", "Bearer catalog-token")), + any()); + } + + @Test + void contextualSessionCacheHit() { + SessionCatalog.SessionContext context = + new SessionCatalog.SessionContext( + "test", "test", Map.of(OAuth2Properties.TOKEN, "context-token"), Map.of()); + Map properties = Map.of(); + try (OAuth2Manager manager = Mockito.spy(new OAuth2Manager("test")); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties); + OAuth2Util.AuthSession contextualSession1 = + manager.contextualSession(context, catalogSession); + OAuth2Util.AuthSession contextualSession2 = + manager.contextualSession(context, catalogSession)) { + assertThat(contextualSession1).isNotSameAs(catalogSession); + assertThat(contextualSession2).isNotSameAs(catalogSession); + assertThat(contextualSession1).isSameAs(contextualSession2); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should only create and cache contextual session once") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1)); + Mockito.verify(manager, times(1)) + .newSessionFromAccessToken("context-token", Map.of(), catalogSession); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void tableSessionNoTableCredentials() { + Map properties = Map.of(); + TableIdentifier table = TableIdentifier.of("ns", "tbl"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, properties); + OAuth2Util.AuthSession tableSession = + manager.tableSession(table, properties, catalogSession)) { + assertThat(tableSession).isSameAs(catalogSession); + assertThat(manager) + .extracting("refreshExecutor") + .as("should not create refresh executor when no table credentials provided") + .isNull(); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should not create session cache for empty table credentials") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).isEmpty()); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void tableSessionTokenProvided() { + Map catalogProperties = Map.of(); + Map tableProperties = Map.of(OAuth2Properties.TOKEN, "table-token"); + TableIdentifier table = TableIdentifier.of("ns", "tbl"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, catalogProperties); + OAuth2Util.AuthSession tableSession = + manager.tableSession(table, tableProperties, catalogSession)) { + assertThat(tableSession).isNotSameAs(catalogSession); + assertThat(tableSession.headers()).containsOnly(entry("Authorization", "Bearer table-token")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should create refresh executor when table session created") + .isNotNull(); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should create session cache for table with token") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1)); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void tableSessionTokenExchange() { + Map catalogProperties = Map.of(OAuth2Properties.TOKEN, "catalog-token"); + Map tableProperties = Map.of(OAuth2Properties.ACCESS_TOKEN_TYPE, "table-token"); + TableIdentifier table = TableIdentifier.of("ns", "tbl"); + try (OAuth2Manager manager = new OAuth2Manager("test"); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, catalogProperties); + OAuth2Util.AuthSession tableSession = + manager.tableSession(table, tableProperties, catalogSession)) { + assertThat(tableSession.headers()).containsOnly(entry("Authorization", "Bearer test")); + assertThat(manager) + .extracting("refreshExecutor") + .as("should create refresh executor when table session created") + .isNotNull(); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should create session cache for table with token exchange") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1)); + } + Mockito.verify(client) + .postForm( + any(), + eq( + Map.of( + "grant_type", "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token", "table-token", + "subject_token_type", "urn:ietf:params:oauth:token-type:access_token", + "actor_token", "catalog-token", + "actor_token_type", "urn:ietf:params:oauth:token-type:access_token", + "scope", "catalog")), + eq(OAuthTokenResponse.class), + eq(Map.of("Authorization", "Bearer catalog-token")), + any()); + } + + @Test + void tableSessionCacheHit() { + Map catalogProperties = Map.of(); + Map tableProperties = Map.of(OAuth2Properties.TOKEN, "table-token"); + TableIdentifier table = TableIdentifier.of("ns", "tbl"); + try (OAuth2Manager manager = Mockito.spy(new OAuth2Manager("test")); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, catalogProperties); + OAuth2Util.AuthSession tableSession1 = + manager.tableSession(table, tableProperties, catalogSession); + OAuth2Util.AuthSession tableSession2 = + manager.tableSession(table, tableProperties, catalogSession)) { + assertThat(tableSession1).isNotSameAs(catalogSession); + assertThat(tableSession2).isNotSameAs(catalogSession); + assertThat(tableSession1).isSameAs(tableSession2); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should only create and cache table session once") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).hasSize(1)); + Mockito.verify(manager, times(1)) + .newSessionFromAccessToken("table-token", Map.of("token", "table-token"), catalogSession); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void tableSessionDisallowedTableProperties() { + // Servers should not include sensitive information in table properties; + // if they do, such properties should be ignored. + Map catalogProperties = Map.of(); + Map tableProperties = Map.of(OAuth2Properties.CREDENTIAL, "client:secret"); + TableIdentifier table = TableIdentifier.of("ns", "tbl"); + try (OAuth2Manager manager = Mockito.spy(new OAuth2Manager("test")); + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, catalogProperties); + OAuth2Util.AuthSession tableSession = + manager.tableSession(table, tableProperties, catalogSession)) { + assertThat(tableSession).isSameAs(catalogSession); + assertThat(manager) + .extracting("refreshExecutor") + .as("should not create refresh executor when table credentials were filtered out") + .isNull(); + assertThat(manager) + .extracting("sessionCache") + .asInstanceOf(type(AuthSessionCache.class)) + .as("should not create session cache for ignored table credentials") + .satisfies(cache -> assertThat(cache.sessionCache().asMap()).isEmpty()); + } + Mockito.verifyNoInteractions(client); + } + + @Test + void close() { + Map catalogProperties = Map.of(); + SessionCatalog.SessionContext context = + new SessionCatalog.SessionContext( + "test", "test", Map.of(OAuth2Properties.TOKEN, "context-token"), Map.of()); + Map tableProperties = Map.of(OAuth2Properties.TOKEN, "table-token"); + TableIdentifier table = TableIdentifier.of("ns", "tbl"); + try (OAuth2Manager manager = + new OAuth2Manager("test") { + @Override + protected AuthSessionCache newSessionCache(Map properties) { + return new AuthSessionCache(Duration.ofHours(1), Runnable::run, null); + } + + @Override + protected OAuth2Util.AuthSession newSessionFromAccessToken( + String token, Map properties, OAuth2Util.AuthSession parent) { + return Mockito.spy(super.newSessionFromAccessToken(token, properties, parent)); + } + }; + OAuth2Util.AuthSession catalogSession = manager.catalogSession(client, catalogProperties); + OAuth2Util.AuthSession contextualSession = + manager.contextualSession(context, catalogSession); + OAuth2Util.AuthSession tableSession = + manager.tableSession(table, tableProperties, contextualSession)) { + manager.close(); + assertThat(manager) + .extracting("refreshExecutor") + .as("should close refresh executor") + .isNull(); + assertThat(manager).extracting("sessionCache").as("should close session cache").isNull(); + // all cached sessions should be closed + Mockito.verify(contextualSession).close(); + Mockito.verify(tableSession).close(); + } + Mockito.verifyNoInteractions(client); + } +}