Skip to content

Commit b3edd87

Browse files
authored
Update the StorageConfiguration to invoke singleton client objects, a… (#1386)
* Update the StorageConfiguration to invoke singleton client objects, and add a test * Fix formatting * using guava suppliers * Add aws region * Cleanup and mock test
1 parent e4969e3 commit b3edd87

File tree

2 files changed

+187
-30
lines changed

2 files changed

+187
-30
lines changed

service/common/src/main/java/org/apache/polaris/service/storage/StorageConfiguration.java

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import com.google.auth.oauth2.AccessToken;
2222
import com.google.auth.oauth2.GoogleCredentials;
23+
import com.google.common.base.Suppliers;
2324
import java.io.IOException;
2425
import java.time.Duration;
2526
import java.time.Instant;
@@ -60,38 +61,40 @@ public interface StorageConfiguration {
6061
Optional<Duration> gcpAccessTokenLifespan();
6162

6263
default Supplier<StsClient> stsClientSupplier() {
63-
return () -> {
64-
StsClientBuilder stsClientBuilder = StsClient.builder();
65-
if (awsAccessKey().isPresent() && awsSecretKey().isPresent()) {
66-
LoggerFactory.getLogger(StorageConfiguration.class)
67-
.warn("Using hard-coded AWS credentials - this is not recommended for production");
68-
StaticCredentialsProvider awsCredentialsProvider =
69-
StaticCredentialsProvider.create(
70-
AwsBasicCredentials.create(awsAccessKey().get(), awsSecretKey().get()));
71-
stsClientBuilder.credentialsProvider(awsCredentialsProvider);
72-
}
73-
return stsClientBuilder.build();
74-
};
64+
return Suppliers.memoize(
65+
() -> {
66+
StsClientBuilder stsClientBuilder = StsClient.builder();
67+
if (awsAccessKey().isPresent() && awsSecretKey().isPresent()) {
68+
LoggerFactory.getLogger(StorageConfiguration.class)
69+
.warn("Using hard-coded AWS credentials - this is not recommended for production");
70+
StaticCredentialsProvider awsCredentialsProvider =
71+
StaticCredentialsProvider.create(
72+
AwsBasicCredentials.create(awsAccessKey().get(), awsSecretKey().get()));
73+
stsClientBuilder.credentialsProvider(awsCredentialsProvider);
74+
}
75+
return stsClientBuilder.build();
76+
});
7577
}
7678

7779
default Supplier<GoogleCredentials> gcpCredentialsSupplier() {
78-
return () -> {
79-
if (gcpAccessToken().isEmpty()) {
80-
try {
81-
return GoogleCredentials.getApplicationDefault();
82-
} catch (IOException e) {
83-
throw new RuntimeException("Failed to get GCP credentials", e);
84-
}
85-
} else {
86-
AccessToken accessToken =
87-
new AccessToken(
88-
gcpAccessToken().get(),
89-
new Date(
90-
Instant.now()
91-
.plus(gcpAccessTokenLifespan().orElse(DEFAULT_TOKEN_LIFESPAN))
92-
.toEpochMilli()));
93-
return GoogleCredentials.create(accessToken);
94-
}
95-
};
80+
return Suppliers.memoize(
81+
() -> {
82+
if (gcpAccessToken().isEmpty()) {
83+
try {
84+
return GoogleCredentials.getApplicationDefault();
85+
} catch (IOException e) {
86+
throw new RuntimeException("Failed to get GCP credentials", e);
87+
}
88+
} else {
89+
AccessToken accessToken =
90+
new AccessToken(
91+
gcpAccessToken().get(),
92+
new Date(
93+
Instant.now()
94+
.plus(gcpAccessTokenLifespan().orElse(DEFAULT_TOKEN_LIFESPAN))
95+
.toEpochMilli()));
96+
return GoogleCredentials.create(accessToken);
97+
}
98+
});
9699
}
97100
}
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.polaris.service.storage;
20+
21+
import static org.assertj.core.api.Assertions.assertThat;
22+
import static org.mockito.Mockito.*;
23+
24+
import com.google.auth.oauth2.AccessToken;
25+
import com.google.auth.oauth2.GoogleCredentials;
26+
import java.time.Duration;
27+
import java.time.Instant;
28+
import java.util.Optional;
29+
import java.util.function.Supplier;
30+
import org.junit.jupiter.api.Test;
31+
import org.mockito.ArgumentCaptor;
32+
import org.mockito.MockedStatic;
33+
import org.mockito.Mockito;
34+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
35+
import software.amazon.awssdk.services.sts.StsClient;
36+
import software.amazon.awssdk.services.sts.StsClientBuilder;
37+
38+
public class StorageConfigurationTest {
39+
40+
private static final String TEST_ACCESS_KEY = "test-access-key";
41+
private static final String TEST_GCP_TOKEN = "ya29.test-token";
42+
private static final String TEST_SECRET_KEY = "test-secret-key";
43+
private static final Duration TEST_TOKEN_LIFESPAN = Duration.ofMinutes(20);
44+
45+
private StorageConfiguration configWithAwsCredentialsAndGcpToken() {
46+
return new StorageConfiguration() {
47+
@Override
48+
public Optional<String> awsAccessKey() {
49+
return Optional.of(TEST_ACCESS_KEY);
50+
}
51+
52+
@Override
53+
public Optional<String> awsSecretKey() {
54+
return Optional.of(TEST_SECRET_KEY);
55+
}
56+
57+
@Override
58+
public Optional<String> gcpAccessToken() {
59+
return Optional.of(TEST_GCP_TOKEN);
60+
}
61+
62+
@Override
63+
public Optional<Duration> gcpAccessTokenLifespan() {
64+
return Optional.of(TEST_TOKEN_LIFESPAN);
65+
}
66+
};
67+
}
68+
69+
private StorageConfiguration configWithoutGcpToken() {
70+
return new StorageConfiguration() {
71+
@Override
72+
public Optional<String> awsAccessKey() {
73+
return Optional.empty();
74+
}
75+
76+
@Override
77+
public Optional<String> awsSecretKey() {
78+
return Optional.empty();
79+
}
80+
81+
@Override
82+
public Optional<String> gcpAccessToken() {
83+
return Optional.empty();
84+
}
85+
86+
@Override
87+
public Optional<Duration> gcpAccessTokenLifespan() {
88+
return Optional.empty();
89+
}
90+
};
91+
}
92+
93+
@Test
94+
public void testSingletonStsClientWithStaticCredentials() {
95+
StsClientBuilder mockBuilder = mock(StsClientBuilder.class);
96+
StsClient mockStsClient = mock(StsClient.class);
97+
ArgumentCaptor<StaticCredentialsProvider> credsCaptor =
98+
ArgumentCaptor.forClass(StaticCredentialsProvider.class);
99+
100+
when(mockBuilder.credentialsProvider(credsCaptor.capture())).thenReturn(mockBuilder);
101+
when(mockBuilder.region(any())).thenReturn(mockBuilder);
102+
when(mockBuilder.build()).thenReturn(mockStsClient);
103+
104+
try (MockedStatic<StsClient> staticMock = Mockito.mockStatic(StsClient.class)) {
105+
staticMock.when(StsClient::builder).thenReturn(mockBuilder);
106+
107+
StorageConfiguration config = configWithAwsCredentialsAndGcpToken();
108+
Supplier<StsClient> supplier = config.stsClientSupplier();
109+
StsClient client1 = supplier.get();
110+
StsClient client2 = supplier.get();
111+
112+
assertThat(client1).isSameAs(client2);
113+
assertThat(client1).isNotNull();
114+
115+
StaticCredentialsProvider credentialsProvider = credsCaptor.getValue();
116+
assertThat(credentialsProvider.resolveCredentials().accessKeyId()).isEqualTo(TEST_ACCESS_KEY);
117+
assertThat(credentialsProvider.resolveCredentials().secretAccessKey())
118+
.isEqualTo(TEST_SECRET_KEY);
119+
}
120+
}
121+
122+
@Test
123+
public void testCreateGcpCredentialsFromStaticToken() {
124+
Supplier<GoogleCredentials> supplier =
125+
configWithAwsCredentialsAndGcpToken().gcpCredentialsSupplier();
126+
127+
GoogleCredentials credentials = supplier.get();
128+
assertThat(credentials).isNotNull();
129+
130+
AccessToken accessToken = credentials.getAccessToken();
131+
assertThat(accessToken).isNotNull();
132+
assertThat(accessToken.getTokenValue()).isEqualTo(TEST_GCP_TOKEN);
133+
long expectedExpiry = Instant.now().plus(Duration.ofMinutes(20)).toEpochMilli();
134+
long actualExpiry = accessToken.getExpirationTime().getTime();
135+
assertThat(actualExpiry).isBetween(expectedExpiry - 500, expectedExpiry + 500);
136+
}
137+
138+
@Test
139+
public void testGcpCredentialsFromDefault() {
140+
GoogleCredentials mockDefaultCreds = mock(GoogleCredentials.class);
141+
142+
try (MockedStatic<GoogleCredentials> mockedStatic =
143+
Mockito.mockStatic(GoogleCredentials.class)) {
144+
145+
mockedStatic.when(GoogleCredentials::getApplicationDefault).thenReturn(mockDefaultCreds);
146+
147+
Supplier<GoogleCredentials> supplier = configWithoutGcpToken().gcpCredentialsSupplier();
148+
GoogleCredentials result = supplier.get();
149+
150+
assertThat(result).isSameAs(mockDefaultCreds);
151+
mockedStatic.verify(GoogleCredentials::getApplicationDefault, times(1));
152+
}
153+
}
154+
}

0 commit comments

Comments
 (0)