Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow registering a custom Predicate for determining non-blocking threads #3854

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2016-2023 VMware Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2016-2024 VMware Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,7 @@
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;

import io.micrometer.core.instrument.MeterRegistry;
Expand Down Expand Up @@ -121,6 +122,10 @@ public abstract class Schedulers {
.map(Boolean::parseBoolean)
.orElse(false);

static final Predicate<Thread> DEFAULT_NON_BLOCKING_THREAD_PREDICATE = thread -> false;

static Predicate<Thread> nonBlockingThreadPredicate = DEFAULT_NON_BLOCKING_THREAD_PREDICATE;

/**
* Create a {@link Scheduler} which uses a backing {@link Executor} to schedule
* Runnables for async operators.
Expand Down Expand Up @@ -659,24 +664,50 @@ public static void onHandleError(String key, BiConsumer<Thread, ? super Throwabl

/**
* Check if calling a Reactor blocking API in the current {@link Thread} is forbidden
* or not, by checking if the thread implements {@link NonBlocking} (in which case it is
* forbidden and this method returns {@code true}).
* or not. This method returns {@code true} and will forbid the Reactor blocking API if
* any of the following conditions meet:
* <ul>
* <li>the thread implements {@link NonBlocking}; or</li>
* <li>any of the {@link Predicate}s registered via {@link #registerNonBlockingThreadPredicate(Predicate)}
* returns {@code true}.</li>
* </ul>
*
* @return {@code true} if blocking is forbidden in this thread, {@code false} otherwise
*/
public static boolean isInNonBlockingThread() {
return Thread.currentThread() instanceof NonBlocking;
return isNonBlockingThread(Thread.currentThread());
}

/**
* Check if calling a Reactor blocking API in the given {@link Thread} is forbidden
* or not, by checking if the thread implements {@link NonBlocking} (in which case it is
* forbidden and this method returns {@code true}).
* or not. This method returns {@code true} and will forbid the Reactor blocking API if
* any of the following conditions meet:
* <ul>
* <li>the thread implements {@link NonBlocking}; or</li>
* <li>any of the {@link Predicate}s registered via {@link #registerNonBlockingThreadPredicate(Predicate)}
* returns {@code true}.</li>
* </ul>
*
* @return {@code true} if blocking is forbidden in that thread, {@code false} otherwise
*/
public static boolean isNonBlockingThread(Thread t) {
return t instanceof NonBlocking;
return t instanceof NonBlocking || nonBlockingThreadPredicate.test(t);
}

/**
* Registers the specified {@link Predicate} that determines whether it is forbidden to call
* a Reactor blocking API in a given {@link Thread} or not.
*/
public static void registerNonBlockingThreadPredicate(Predicate<Thread> predicate) {
nonBlockingThreadPredicate = nonBlockingThreadPredicate.or(predicate);
}

/**
* Unregisters all the {@link Predicate}s registered so far via
* {@link #registerNonBlockingThreadPredicate(Predicate)}.
*/
public static void resetNonBlockingThreadPredicate() {
nonBlockingThreadPredicate = DEFAULT_NON_BLOCKING_THREAD_PREDICATE;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2016-2022 VMware Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2016-2024 VMware Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
import java.time.Duration;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
Expand Down Expand Up @@ -359,10 +360,53 @@ public void isNonBlockingThreadInstanceOf() {

@Test
public void isInNonBlockingThreadTrue() {
new ReactorThreadFactory.NonBlockingThread(() -> assertThat(Schedulers.isInNonBlockingThread())
.as("isInNonBlockingThread")
.isFalse(),
"isInNonBlockingThreadTrue");
assertNonBlockingThread(ReactorThreadFactory.NonBlockingThread::new, true);
}

@Test
public void customNonBlockingThreadPredicate() {
assertThat(Schedulers.nonBlockingThreadPredicate)
.as("nonBlockingThreadPredicate")
.isSameAs(Schedulers.DEFAULT_NON_BLOCKING_THREAD_PREDICATE);

// The custom `Predicate` is not registered yet,
// so `CustomNonBlockingThread` will be considered blocking.
assertNonBlockingThread(CustomNonBlockingThread::new, false);

// Now register the `Predicate` and ensure `CustomNonBlockingThread` is non-blocking.
Schedulers.registerNonBlockingThreadPredicate(t -> t instanceof CustomNonBlockingThread);
try {
assertNonBlockingThread(CustomNonBlockingThread::new, true);
} finally {
// Restore the global predicate.
Schedulers.resetNonBlockingThreadPredicate();
}

assertThat(Schedulers.nonBlockingThreadPredicate)
.as("nonBlockingThreadPredicate (after reset)")
.isSameAs(Schedulers.DEFAULT_NON_BLOCKING_THREAD_PREDICATE);
}

private static void assertNonBlockingThread(BiFunction<Runnable, String, Thread> threadFactory,
boolean expectedNonBlocking) {
CompletableFuture<Void> future = new CompletableFuture<>();
Thread thread = threadFactory.apply(() -> {
try {
assertThat(Schedulers.isInNonBlockingThread())
.as("isInNonBlockingThread")
.isEqualTo(expectedNonBlocking);
future.complete(null);
} catch (Throwable cause) {
future.completeExceptionally(cause);
}
}, "assertNonBlockingThread");

assertThat(Schedulers.isNonBlockingThread(thread))
.as("isNonBlockingThread")
.isEqualTo(expectedNonBlocking);

thread.start();
future.join();
}

@Test
Expand Down Expand Up @@ -1457,4 +1501,10 @@ public void dispose() {
}
}
}

final static class CustomNonBlockingThread extends Thread {
CustomNonBlockingThread(Runnable target, String name) {
super(target, name);
}
}
}
Loading