Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import org.apache.kafka.common.utils.Time;
import org.jose4j.jwk.HttpsJwks;
import org.jose4j.jwk.JsonWebKey;
Expand All @@ -44,7 +45,7 @@
* possible to receive a JWT that contains a <code>kid</code> that points to yet-unknown JWK,
* thus requiring a connection to the OAuth/OIDC provider to be made. Hopefully, in practice,
* keys are made available for some amount of time before they're used within JWTs.
*
* <p>
* This instance is created and provided to the
* {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using
* an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then
Expand Down Expand Up @@ -75,7 +76,7 @@ public final class RefreshingHttpsJwks implements Initable, Closeable {
* JWKS. In some cases, the call to {@link HttpsJwks#getJsonWebKeys()} will trigger a call
* to {@link HttpsJwks#refresh()} which will block the current thread in network I/O. We cache
* the JWKS ourselves (see {@link #jsonWebKeys}) to avoid the network I/O.
*
* <p>
* We want to be very careful where we use the {@link HttpsJwks} instance so that we don't
* perform any operation (directly or indirectly) that could cause blocking. This is because
* the JWKS logic is part of the larger authentication logic which operates on Kafka's network
Expand Down Expand Up @@ -121,23 +122,17 @@ public final class RefreshingHttpsJwks implements Initable, Closeable {
private boolean isInitialized;

/**
* Creates a <code>RefreshingHttpsJwks</code> that will be used by the
* {@link RefreshingHttpsJwksVerificationKeyResolver} to resolve new key IDs in JWTs.
*
* @param time {@link Time} instance
* @param httpsJwks {@link HttpsJwks} instance from which to retrieve the JWKS
* based on the OAuth/OIDC standard
* @param refreshMs The number of milliseconds between refresh passes to connect
* to the OAuth/OIDC JWKS endpoint to retrieve the latest set
* @param refreshRetryBackoffMs Time for delay after initial failed attempt to retrieve JWKS
* @param refreshRetryBackoffMaxMs Maximum time to retrieve JWKS
* Creates a <code>RefreshingHttpsJwks</code>. It should only be used for testing to pass in a mock executor
* service. Otherwise the constructor below should be used.
*/

public RefreshingHttpsJwks(Time time,
HttpsJwks httpsJwks,
long refreshMs,
long refreshRetryBackoffMs,
long refreshRetryBackoffMaxMs) {
// VisibleForTesting
RefreshingHttpsJwks(Time time,
HttpsJwks httpsJwks,
long refreshMs,
long refreshRetryBackoffMs,
long refreshRetryBackoffMaxMs,
ScheduledExecutorService executorService) {
if (refreshMs <= 0)
throw new IllegalArgumentException("JWKS validation key refresh configuration value retryWaitMs value must be positive");

Expand All @@ -146,7 +141,7 @@ public RefreshingHttpsJwks(Time time,
this.refreshMs = refreshMs;
this.refreshRetryBackoffMs = refreshRetryBackoffMs;
this.refreshRetryBackoffMaxMs = refreshRetryBackoffMaxMs;
this.executorService = Executors.newSingleThreadScheduledExecutor();
this.executorService = executorService;
this.missingKeyIds = new LinkedHashMap<String, Long>(MISSING_KEY_ID_CACHE_MAX_ENTRIES, .75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, Long> eldest) {
Expand All @@ -155,6 +150,27 @@ protected boolean removeEldestEntry(Map.Entry<String, Long> eldest) {
};
}

/**
* Creates a <code>RefreshingHttpsJwks</code> that will be used by the
* {@link RefreshingHttpsJwksVerificationKeyResolver} to resolve new key IDs in JWTs.
*
* @param time {@link Time} instance
* @param httpsJwks {@link HttpsJwks} instance from which to retrieve the JWKS
* based on the OAuth/OIDC standard
* @param refreshMs The number of milliseconds between refresh passes to connect
* to the OAuth/OIDC JWKS endpoint to retrieve the latest set
* @param refreshRetryBackoffMs Time for delay after initial failed attempt to retrieve JWKS
* @param refreshRetryBackoffMaxMs Maximum time to retrieve JWKS
*/

public RefreshingHttpsJwks(Time time,
HttpsJwks httpsJwks,
long refreshMs,
long refreshRetryBackoffMs,
long refreshRetryBackoffMaxMs) {
this(time, httpsJwks, refreshMs, refreshRetryBackoffMs, refreshRetryBackoffMaxMs, Executors.newSingleThreadScheduledExecutor());
}

@Override
public void init() throws IOException {
try {
Expand All @@ -180,9 +196,9 @@ public void init() throws IOException {
//
// Note: we refer to this as a _scheduled_ refresh.
executorService.scheduleAtFixedRate(this::refresh,
refreshMs,
refreshMs,
TimeUnit.MILLISECONDS);
refreshMs,
refreshMs,
TimeUnit.MILLISECONDS);

log.info("JWKS validation key refresh thread started with a refresh interval of {} ms", refreshMs);
} finally {
Expand All @@ -203,7 +219,7 @@ public void close() {

if (!executorService.awaitTermination(SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT)) {
log.warn("JWKS validation key refresh thread termination did not end after {} {}",
SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT);
SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT);
}
} catch (InterruptedException e) {
log.warn("JWKS validation key refresh thread error during close", e);
Expand All @@ -217,13 +233,12 @@ public void close() {
* Our implementation avoids the blocking call within {@link HttpsJwks#refresh()} that is
* sometimes called internal to {@link HttpsJwks#getJsonWebKeys()}. We want to avoid any
* blocking I/O as this code is running in the authentication path on the Kafka network thread.
*
* <p>
* The list may be stale up to {@link #refreshMs}.
*
* @return {@link List} of {@link JsonWebKey} instances
*
* @throws JoseException Thrown if a problem is encountered parsing the JSON content into JWKs
* @throws IOException Thrown f a problem is encountered making the HTTP request
* @throws IOException Thrown f a problem is encountered making the HTTP request
*/

public List<JsonWebKey> getJsonWebKeys() throws JoseException, IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,21 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Collection;
import java.util.List;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.AbstractMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

import org.apache.kafka.common.KafkaFuture;
import org.apache.kafka.common.internals.KafkaFutureImpl;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.jose4j.http.SimpleResponse;
Expand Down Expand Up @@ -122,14 +132,36 @@ public void testLongKey() throws Exception {
@Test
public void testSecondaryRefreshAfterElapsedDelay() throws Exception {
String keyId = "abc123";
Time time = MockTime.SYSTEM; // Unfortunately, we can't mock time here because the
// scheduled executor doesn't respect it.
MockTime time = new MockTime();
HttpsJwks httpsJwks = spyHttpsJwks();

try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks)) {
MockExecutorService mockExecutorService = new MockExecutorService(time);
ScheduledExecutorService executorService = Mockito.mock(ScheduledExecutorService.class);
Mockito.doAnswer(invocation -> {
Runnable command = invocation.getArgument(0, Runnable.class);
long delay = invocation.getArgument(1, Long.class);
TimeUnit unit = invocation.getArgument(2, TimeUnit.class);
return mockExecutorService.schedule(() -> {
command.run();
return null;
}, unit.toMillis(delay), null);
}).when(executorService).schedule(Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.any(TimeUnit.class));
Mockito.doAnswer(invocation -> {
Runnable command = invocation.getArgument(0, Runnable.class);
long initialDelay = invocation.getArgument(1, Long.class);
long period = invocation.getArgument(2, Long.class);
TimeUnit unit = invocation.getArgument(3, TimeUnit.class);
return mockExecutorService.schedule(() -> {
command.run();
return null;
}, unit.toMillis(initialDelay), period);
}).when(executorService).scheduleAtFixedRate(Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.anyLong(), Mockito.any(TimeUnit.class));

try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks, executorService)) {
refreshingHttpsJwks.init();
// We refresh once at the initialization time from getJsonWebKeys.
verify(httpsJwks, times(1)).refresh();
assertTrue(refreshingHttpsJwks.maybeExpediteRefresh(keyId));
verify(httpsJwks, times(2)).refresh();
time.sleep(REFRESH_MS + 1);
verify(httpsJwks, times(3)).refresh();
assertFalse(refreshingHttpsJwks.maybeExpediteRefresh(keyId));
Expand All @@ -153,6 +185,10 @@ private RefreshingHttpsJwks getRefreshingHttpsJwks(final Time time, final HttpsJ
return new RefreshingHttpsJwks(time, httpsJwks, REFRESH_MS, RETRY_BACKOFF_MS, RETRY_BACKOFF_MAX_MS);
}

private RefreshingHttpsJwks getRefreshingHttpsJwks(final Time time, final HttpsJwks httpsJwks, final ScheduledExecutorService executorService) {
return new RefreshingHttpsJwks(time, httpsJwks, REFRESH_MS, RETRY_BACKOFF_MS, RETRY_BACKOFF_MAX_MS, executorService);
}

/**
* We *spy* (not *mock*) the {@link HttpsJwks} instance because we want to have it
* _partially mocked_ to determine if it's calling its internal refresh method. We want to
Expand Down Expand Up @@ -195,4 +231,82 @@ public String getBody() {
return Mockito.spy(httpsJwks);
}

}
/**
* A mock ScheduledExecutorService just for the test. Note that this is not a generally reusable mock as it does not
* implement some interfaces like scheduleWithFixedDelay, etc. And it does not return ScheduledFuture correctly.
*/
private class MockExecutorService implements MockTime.Listener {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see it is basically identical to the MockScheduler. Could you explain why can't we use MockScheduler directly here?

Copy link
Contributor Author

@olalamichelle olalamichelle Aug 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is very similar but I have to create my own mock for 2 reasons:

  1. MockScheduler.schedule method does not take a period parameter which schedules a periodical task. I can add this function to MockScheduler but then I will also implement the corresponding addWaiter which is pretty much just put all the MockExecutorService code into MockScheduler. I don't think it is a good idea since it will make MockScheduler very cumbersome. Moreover, if I do that its schedule interface cannot take a ExecutorService parameter and it needs to take care of the execution itself inside. See below for more details.
  2. MockScheduler is really just a scheduler that does not take care of the execution. It just submits tasks to the executor service. The flakiness of the test comes from the real clock based executor service where, because of CPU scheduling, cannot be 100% accurate on timings. We need a callback-based one in the test to make sure it is reliable. It is just the logic (since there principles are the same) behind a mock scheduler and a mock executor service is the same that we have to use MockTime and be callback-based so they are very similar.

private final MockTime time;

private final TreeMap<Long, List<AbstractMap.SimpleEntry<Long, KafkaFutureImpl<Long>>>> waiters = new TreeMap<>();

public MockExecutorService(MockTime time) {
this.time = time;
time.addListener(this);
}

/**
* The actual execution and rescheduling logic. Check all internal tasks to see if any one reaches its next
* execution point, call it and optionally reschedule it if it has a specified period.
*/
@Override
public synchronized void onTimeUpdated() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add some java doc on what we're trying to do in this method? Same as below. Thanks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

long timeMs = time.milliseconds();
while (true) {
Map.Entry<Long, List<AbstractMap.SimpleEntry<Long, KafkaFutureImpl<Long>>>> entry = waiters.firstEntry();
if ((entry == null) || (entry.getKey() > timeMs)) {
break;
}
for (AbstractMap.SimpleEntry<Long, KafkaFutureImpl<Long>> pair : entry.getValue()) {
pair.getValue().complete(timeMs);
if (pair.getKey() != null) {
addWaiter(entry.getKey() + pair.getKey(), pair.getKey(), pair.getValue());
}
}
waiters.remove(entry.getKey());
}
}

/**
* Add a task with `delayMs` and optional period to the internal waiter.
* When `delayMs` < 0, we immediately complete the waiter. Otherwise, we add the task metadata to the waiter and
* onTimeUpdated will take care of execute and reschedule it when it reaches its scheduled timestamp.
*
* @param delayMs Delay time in ms.
* @param period Scheduling period, null means no periodic.
* @param waiter A wrapper over a callable function.
*/
private synchronized void addWaiter(long delayMs, Long period, KafkaFutureImpl<Long> waiter) {
long timeMs = time.milliseconds();
if (delayMs <= 0) {
waiter.complete(timeMs);
} else {
long triggerTimeMs = timeMs + delayMs;
List<AbstractMap.SimpleEntry<Long, KafkaFutureImpl<Long>>> futures =
waiters.computeIfAbsent(triggerTimeMs, k -> new ArrayList<>());
futures.add(new AbstractMap.SimpleEntry<>(period, waiter));
}
}

/**
* Internal utility function for periodic or one time refreshes.
*
* @param period null indicates one time refresh, otherwise it is periodic.
*/
public <T> ScheduledFuture<T> schedule(final Callable<T> callable, long delayMs, Long period) {

KafkaFutureImpl<Long> waiter = new KafkaFutureImpl<>();
waiter.thenApply((KafkaFuture.BaseFunction<Long, Void>) now -> {
try {
callable.call();
} catch (Throwable e) {
e.printStackTrace();
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we log anything here for future troubleshooting?

Copy link
Contributor Author

@olalamichelle olalamichelle Aug 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a test-only class so I think it should be fine? We should make sure scheduled thing does not have exceptions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but we sometimes do e.printStackTrace(); for this kind of situation like here:
https://github.com/apache/kafka/blob/trunk/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicAdminTest.java#L272

It will be helpful when this test someday become flaky. Could we add it?

return null;
});
addWaiter(delayMs, period, waiter);
return null;
}
}

}