Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,12 @@
<dependency>
<groupId>redis.clients.authentication</groupId>
<artifactId>redis-authx-core</artifactId>
<version>0.1.1-beta1</version>
<version>0.1.1-beta2</version>
</dependency>
<dependency>
<groupId>redis.clients.authentication</groupId>
<artifactId>redis-authx-entraid</artifactId>
<version>0.1.1-beta1</version>
<version>0.1.1-beta2</version>
<scope>test</scope>
</dependency>
<!-- Start of core dependencies -->
Expand Down
144 changes: 75 additions & 69 deletions src/test/java/io/lettuce/authx/EntraIdIntegrationTests.java
Original file line number Diff line number Diff line change
@@ -1,27 +1,19 @@
package io.lettuce.authx;

import io.lettuce.core.ClientOptions;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.RedisURI;
import io.lettuce.core.SocketOptions;
import io.lettuce.core.TimeoutOptions;
import io.lettuce.core.TransactionResult;
import com.azure.identity.DefaultAzureCredential;
import com.azure.identity.DefaultAzureCredentialBuilder;
import io.lettuce.core.*;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.api.sync.RedisCommands;
import io.lettuce.core.cluster.ClusterClientOptions;
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection;
import io.lettuce.core.support.PubSubTestListener;
import io.lettuce.test.Wait;
import io.lettuce.test.env.Endpoints;
import io.lettuce.test.env.Endpoints.Endpoint;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.*;
import redis.clients.authentication.core.TokenAuthConfig;
import redis.clients.authentication.entraid.AzureTokenAuthConfigBuilder;
import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder;

import java.time.Duration;
Expand All @@ -41,52 +33,49 @@ public class EntraIdIntegrationTests {

private static final EntraIdTestContext testCtx = EntraIdTestContext.DEFAULT;

private static TokenBasedRedisCredentialsProvider credentialsProvider;
private RedisClient client;

private static RedisClient client;
private Endpoint standalone;

private static Endpoint standalone;
private ClientOptions clientOptions;

@BeforeAll
public static void setup() {
private TokenBasedRedisCredentialsProvider credentialsProvider;

@BeforeEach
public void setup() {
standalone = Endpoints.DEFAULT.getEndpoint("standalone-entraid-acl");
if (standalone != null) {
Assumptions.assumeTrue(testCtx.getClientId() != null && testCtx.getClientSecret() != null,
"Skipping EntraID tests. Azure AD credentials not provided!");
// Configure timeout options to assure fast test failover
ClusterClientOptions clientOptions = ClusterClientOptions.builder()
.socketOptions(SocketOptions.builder().connectTimeout(Duration.ofSeconds(1)).build())
.timeoutOptions(TimeoutOptions.enabled(Duration.ofSeconds(1)))
// enable re-authentication
.reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build();

TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder().clientId(testCtx.getClientId())
.secret(testCtx.getClientSecret()).authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes())
.expirationRefreshRatio(0.0000001F).build();

credentialsProvider = TokenBasedRedisCredentialsProvider.create(tokenAuthConfig);

RedisURI uri = RedisURI.create((standalone.getEndpoints().get(0)));
uri.setCredentialsProvider(credentialsProvider);
client = RedisClient.create(uri);
client.setOptions(clientOptions);
assumeTrue(standalone != null, "Skipping EntraID tests. Redis host with enabled EntraId not provided!");
Assumptions.assumeTrue(testCtx.getClientId() != null && testCtx.getClientSecret() != null,
"Skipping EntraID tests. Azure AD credentials not provided!");

}
clientOptions = ClientOptions.builder()
.socketOptions(SocketOptions.builder().connectTimeout(Duration.ofSeconds(1)).build())
.timeoutOptions(TimeoutOptions.enabled(Duration.ofSeconds(1)))
.reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build();

TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder().clientId(testCtx.getClientId())
.secret(testCtx.getClientSecret()).authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes())
.expirationRefreshRatio(0.0000001F).build();

TokenBasedRedisCredentialsProvider credentialsProvider = TokenBasedRedisCredentialsProvider.create(tokenAuthConfig);

client = createClient(credentialsProvider);
}

@AfterAll
public static void cleanup() {
@AfterEach
public void cleanUp() {
if (credentialsProvider != null) {
credentialsProvider.close();
}
if (client != null) {
client.shutdown();
}
}

// T.1.1
// Verify authentication using Azure AD with service principals using Redis Standalone client
@Test
public void standaloneWithSecret_azureServicePrincipalIntegrationTest() throws ExecutionException, InterruptedException {
assumeTrue(standalone != null, "Skipping EntraID tests. Redis host with enabled EntraId not provided!");

try (StatefulRedisConnection<String, String> connection = client.connect()) {
RedisCommands<String, String> sync = connection.sync();
String key = UUID.randomUUID().toString();
Expand All @@ -102,32 +91,23 @@ public void standaloneWithSecret_azureServicePrincipalIntegrationTest() throws E
// Test that the Redis client is not blocked/interrupted during token renewal.
@Test
public void renewalDuringOperationsTest() throws InterruptedException {
assumeTrue(standalone != null, "Skipping EntraID tests. Redis host with enabled EntraId not provided!");

// Counter to track the number of command cycles
AtomicInteger commandCycleCount = new AtomicInteger(0);

// Start a thread to continuously send Redis commands
Thread commandThread = new Thread(() -> {
try (StatefulRedisConnection<String, String> connection = client.connect()) {
RedisAsyncCommands<String, String> async = connection.async();
for (int i = 1; i <= 10; i++) {
// Start a transaction with SET and INCRBY commands
RedisFuture<String> multi = async.multi();
RedisFuture<String> set = async.set("key", "1");
RedisFuture<Long> incrby = async.incrby("key", 1);
async.multi();
async.set("key", "1");
async.incrby("key", 1);
RedisFuture<TransactionResult> exec = async.exec();
TransactionResult results = exec.get(1, TimeUnit.SECONDS);

// Increment the command cycle count after each execution
commandCycleCount.incrementAndGet();

// Verify the results from EXEC
assertThat(results).hasSize(2); // We expect 2 responses: SET and INCRBY

// Check the response from each command in the transaction
assertThat((String) results.get(0)).isEqualTo("OK"); // SET "key" = "1"
assertThat((Long) results.get(1)).isEqualTo(2L); // INCRBY "key" by 1, expected result is 2
assertThat(results).hasSize(2);
assertThat((String) results.get(0)).isEqualTo("OK");
assertThat((Long) results.get(1)).isEqualTo(2L);
}
} catch (Exception e) {
fail("Command execution failed during token refresh", e);
Expand All @@ -136,16 +116,12 @@ public void renewalDuringOperationsTest() throws InterruptedException {

commandThread.start();

CountDownLatch latch = new CountDownLatch(10); // Wait for at least 10 token renewals

credentialsProvider.credentials().subscribe(cred -> {
latch.countDown(); // Signal each renewal as it's received
});
CountDownLatch latch = new CountDownLatch(10); // Wait for at least 10 token renewalss
credentialsProvider.credentials().subscribe(cred -> latch.countDown());

assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); // Wait to reach 10 renewals
commandThread.join(); // Wait for the command thread to finish

// Verify that at least 10 command cycles were executed during the test
assertThat(commandCycleCount.get()).isGreaterThanOrEqualTo(10);
}

Expand All @@ -162,7 +138,6 @@ public void renewalDuringPubSubOperationsTest() throws InterruptedException {
connectionPubSub.addListener(listener);
connectionPubSub.sync().subscribe("channel");

// Start a thread to continuously send Redis commands
Thread pubsubThread = new Thread(() -> {
for (int i = 1; i <= 100; i++) {
connectionPubSub1.sync().publish("channel", "message");
Expand All @@ -172,17 +147,48 @@ public void renewalDuringPubSubOperationsTest() throws InterruptedException {
pubsubThread.start();

CountDownLatch latch = new CountDownLatch(10);
credentialsProvider.credentials().subscribe(cred -> {
latch.countDown();
});
credentialsProvider.credentials().subscribe(cred -> latch.countDown());

assertThat(latch.await(2, TimeUnit.SECONDS)).isTrue(); // Wait for at least 10 token renewals
pubsubThread.join(); // Wait for the pub/sub thread to finish

// Verify that all messages were received
Wait.untilEquals(100, () -> listener.getMessages().size()).waitOrTimeout();
assertThat(listener.getMessages()).allMatch(msg -> msg.equals("message"));
}
}

@Test
public void azureTokenAuthWithDefaultAzureCredentials() throws ExecutionException, InterruptedException {
DefaultAzureCredential credential = new DefaultAzureCredentialBuilder().build();

TokenAuthConfig tokenAuthConfig = AzureTokenAuthConfigBuilder.builder().defaultAzureCredential(credential)
.tokenRequestExecTimeoutInMs(2000).build();

try (RedisClient azureCredClient = createClient(credentialsProvider);
TokenBasedRedisCredentialsProvider credentialsProvider = TokenBasedRedisCredentialsProvider
.create(tokenAuthConfig);) {
RedisCredentials credentials = credentialsProvider.resolveCredentials().block(Duration.ofSeconds(5));
assertThat(credentials).isNotNull();

String key = UUID.randomUUID().toString();
try (StatefulRedisConnection<String, String> connection = azureCredClient.connect()) {
RedisCommands<String, String> sync = connection.sync();
assertThat(sync.aclWhoami()).isEqualTo(credentials.getUsername());
sync.set(key, "value");
assertThat(sync.get(key)).isEqualTo("value");
assertThat(connection.async().get(key).get()).isEqualTo("value");
assertThat(connection.reactive().get(key).block()).isEqualTo("value");
sync.del(key);
}
}
}

private RedisClient createClient(TokenBasedRedisCredentialsProvider credentialsProvider) {
RedisURI uri = RedisURI.create((standalone.getEndpoints().get(0)));
uri.setCredentialsProvider(credentialsProvider);
RedisClient redis = RedisClient.create(uri);
redis.setOptions(clientOptions);
return redis;
}

}
Loading