Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ final class GlobalSecurityContextHolderStrategy implements SecurityContextHolder

private static SecurityContext contextHolder;

SecurityContext peek() {
return contextHolder;
}

@Override
public void clearContext() {
contextHolder = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ final class InheritableThreadLocalSecurityContextHolderStrategy implements Secur

private static final ThreadLocal<SecurityContext> contextHolder = new InheritableThreadLocal<>();

SecurityContext peek() {
return contextHolder.get();
}

@Override
public void clearContext() {
contextHolder.remove();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,73 +16,130 @@

package org.springframework.security.core.context;

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.Arrays;
import java.util.Collection;

final class ListeningSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
import org.springframework.util.Assert;

private static final BiConsumer<SecurityContext, SecurityContext> NULL_PUBLISHER = (previous, current) -> {
};
/**
* An API for notifying when the {@link SecurityContext} changes.
*
* Note that this does not notify when the underlying authentication changes. To get
* notified about authentication changes, ensure that you are using {@link #setContext}
* when changing the authentication like so:
*
* <pre>
* SecurityContext context = SecurityContextHolder.createEmptyContext();
* context.setAuthentication(authentication);
* SecurityContextHolder.setContext(context);
* </pre>
*
* To add a listener to the existing {@link SecurityContextHolder}, you can do:
*
* <pre>
* SecurityContextHolderStrategy original = SecurityContextHolder.getContextHolderStrategy();
* SecurityContextChangedListener listener = new YourListener();
* SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(original, listener);
* SecurityContextHolder.setContextHolderStrategy(strategy);
* </pre>
*
* NOTE: Any object that you supply to the {@link SecurityContextHolder} is now part of
* the static context and as such will not get garbage collected. To remove the reference,
* {@link SecurityContextHolder#setContextHolderStrategy reset the strategy} like so:
*
* <pre>
* SecurityContextHolder.setContextHolderStrategy(original);
* </pre>
*
* This will then allow {@code YourListener} and its members to be garbage collected.
*
* @author Josh Cummings
* @since 5.6
*/
public final class ListeningSecurityContextHolderStrategy implements SecurityContextHolderStrategy {

private final Supplier<SecurityContext> peek;
private final Collection<SecurityContextChangedListener> listeners;

private final SecurityContextHolderStrategy delegate;

private final SecurityContextEventPublisher base = new SecurityContextEventPublisher();

private BiConsumer<SecurityContext, SecurityContext> publisher = NULL_PUBLISHER;
/**
* Construct a {@link ListeningSecurityContextHolderStrategy}
* @param listeners the listeners that should be notified when the
* {@link SecurityContext} is {@link #setContext(SecurityContext) set} or
* {@link #clearContext() cleared}
* @param delegate the underlying {@link SecurityContextHolderStrategy}
*/
public ListeningSecurityContextHolderStrategy(SecurityContextHolderStrategy delegate,
Collection<SecurityContextChangedListener> listeners) {
Assert.notNull(delegate, "securityContextHolderStrategy cannot be null");
Assert.notNull(listeners, "securityContextChangedListeners cannot be null");
Assert.notEmpty(listeners, "securityContextChangedListeners cannot be empty");
Assert.noNullElements(listeners, "securityContextChangedListeners cannot contain null elements");
this.delegate = delegate;
this.listeners = listeners;
}

ListeningSecurityContextHolderStrategy(Supplier<SecurityContext> peek, SecurityContextHolderStrategy delegate) {
this.peek = peek;
/**
* Construct a {@link ListeningSecurityContextHolderStrategy}
* @param listeners the listeners that should be notified when the
* {@link SecurityContext} is {@link #setContext(SecurityContext) set} or
* {@link #clearContext() cleared}
* @param delegate the underlying {@link SecurityContextHolderStrategy}
*/
public ListeningSecurityContextHolderStrategy(SecurityContextHolderStrategy delegate,
SecurityContextChangedListener... listeners) {
Assert.notNull(delegate, "securityContextHolderStrategy cannot be null");
Assert.notNull(listeners, "securityContextChangedListeners cannot be null");
Assert.notEmpty(listeners, "securityContextChangedListeners cannot be empty");
Assert.noNullElements(listeners, "securityContextChangedListeners cannot contain null elements");
this.delegate = delegate;
this.listeners = Arrays.asList(listeners);
}

/**
* {@inheritDoc}
*/
@Override
public void clearContext() {
SecurityContext from = this.peek.get();
SecurityContext from = getContext();
this.delegate.clearContext();
this.publisher.accept(from, null);
publish(from, null);
}

/**
* {@inheritDoc}
*/
@Override
public SecurityContext getContext() {
return this.delegate.getContext();
}

/**
* {@inheritDoc}
*/
@Override
public void setContext(SecurityContext context) {
SecurityContext from = this.peek.get();
SecurityContext from = getContext();
this.delegate.setContext(context);
this.publisher.accept(from, context);
publish(from, context);
}

/**
* {@inheritDoc}
*/
@Override
public SecurityContext createEmptyContext() {
return this.delegate.createEmptyContext();
}

void addListener(SecurityContextChangedListener listener) {
this.base.listeners.add(listener);
this.publisher = this.base;
}

private static class SecurityContextEventPublisher implements BiConsumer<SecurityContext, SecurityContext> {

private final List<SecurityContextChangedListener> listeners = new CopyOnWriteArrayList<>();

@Override
public void accept(SecurityContext previous, SecurityContext current) {
if (previous == current) {
return;
}
SecurityContextChangedEvent event = new SecurityContextChangedEvent(previous, current);
for (SecurityContextChangedListener listener : this.listeners) {
listener.securityContextChanged(event);
}
private void publish(SecurityContext previous, SecurityContext current) {
if (previous == current) {
return;
}
SecurityContextChangedEvent event = new SecurityContextChangedEvent(previous, current);
for (SecurityContextChangedListener listener : this.listeners) {
listener.securityContextChanged(event);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public class SecurityContextHolder {

public static final String MODE_GLOBAL = "MODE_GLOBAL";

private static final String MODE_PRE_INITIALIZED = "MODE_PRE_INITIALIZED";

public static final String SYSTEM_PROPERTY = "spring.security.strategy";

private static String strategyName = System.getProperty(SYSTEM_PROPERTY);
Expand All @@ -69,34 +71,41 @@ public class SecurityContextHolder {
}

private static void initialize() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be called during context init so shouldn't be called concurrently, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is called whenever the strategy changes. For example, if setContextHolderStrategy were called concurrently, then initialize() would be called concurrently. This isn't the way the API is intended to be used, though.

initializeStrategy();
initializeCount++;
}

private static void initializeStrategy() {
if (MODE_PRE_INITIALIZED.equals(strategyName)) {
Assert.state(strategy != null, "When using " + MODE_PRE_INITIALIZED
+ ", setContextHolderStrategy must be called with the fully constructed strategy");
return;
}
if (!StringUtils.hasText(strategyName)) {
// Set default
strategyName = MODE_THREADLOCAL;
}
if (strategyName.equals(MODE_THREADLOCAL)) {
ThreadLocalSecurityContextHolderStrategy delegate = new ThreadLocalSecurityContextHolderStrategy();
strategy = new ListeningSecurityContextHolderStrategy(delegate::peek, delegate);
strategy = new ThreadLocalSecurityContextHolderStrategy();
return;
}
else if (strategyName.equals(MODE_INHERITABLETHREADLOCAL)) {
InheritableThreadLocalSecurityContextHolderStrategy delegate = new InheritableThreadLocalSecurityContextHolderStrategy();
strategy = new ListeningSecurityContextHolderStrategy(delegate::peek, delegate);
if (strategyName.equals(MODE_INHERITABLETHREADLOCAL)) {
strategy = new InheritableThreadLocalSecurityContextHolderStrategy();
return;
}
else if (strategyName.equals(MODE_GLOBAL)) {
GlobalSecurityContextHolderStrategy delegate = new GlobalSecurityContextHolderStrategy();
strategy = new ListeningSecurityContextHolderStrategy(delegate::peek, delegate);
if (strategyName.equals(MODE_GLOBAL)) {
strategy = new GlobalSecurityContextHolderStrategy();
return;
}
else {
// Try to load a custom strategy
try {
Class<?> clazz = Class.forName(strategyName);
Constructor<?> customStrategy = clazz.getConstructor();
strategy = (SecurityContextHolderStrategy) customStrategy.newInstance();
}
catch (Exception ex) {
ReflectionUtils.handleReflectionException(ex);
}
// Try to load a custom strategy
try {
Class<?> clazz = Class.forName(strategyName);
Constructor<?> customStrategy = clazz.getConstructor();
strategy = (SecurityContextHolderStrategy) customStrategy.newInstance();
}
catch (Exception ex) {
ReflectionUtils.handleReflectionException(ex);
}
initializeCount++;
}

/**
Expand All @@ -118,7 +127,9 @@ public static SecurityContext getContext() {
* Primarily for troubleshooting purposes, this method shows how many times the class
* has re-initialized its <code>SecurityContextHolderStrategy</code>.
* @return the count (should be one unless you've called
* {@link #setStrategyName(String)} to switch to an alternate strategy.
* {@link #setStrategyName(String)} or
* {@link #setContextHolderStrategy(SecurityContextHolderStrategy)} to switch to an
* alternate strategy).
*/
public static int getInitializeCount() {
return initializeCount;
Expand All @@ -144,6 +155,41 @@ public static void setStrategyName(String strategyName) {
initialize();
}

/**
* Use this {@link SecurityContextHolderStrategy}.
*
* Call either {@link #setStrategyName(String)} or this method, but not both.
*
* This method is not thread safe. Changing the strategy while requests are in-flight
* may cause race conditions.
*
* {@link SecurityContextHolder} maintains a static reference to the provided
* {@link SecurityContextHolderStrategy}. This means that the strategy and its members
* will not be garbage collected until you remove your strategy.
*
* To ensure garbage collection, remember the original strategy like so:
*
* <pre>
* SecurityContextHolderStrategy original = SecurityContextHolder.getContextHolderStrategy();
* SecurityContextHolder.setContextHolderStrategy(myStrategy);
* </pre>
*
* And then when you are ready for {@code myStrategy} to be garbage collected you can
* do:
*
* <pre>
* SecurityContextHolder.setContextHolderStrategy(original);
* </pre>
* @param strategy the {@link SecurityContextHolderStrategy} to use
* @since 5.6
*/
public static void setContextHolderStrategy(SecurityContextHolderStrategy strategy) {
Assert.notNull(strategy, "securityContextHolderStrategy cannot be null");
SecurityContextHolder.strategyName = MODE_PRE_INITIALIZED;
SecurityContextHolder.strategy = strategy;
initialize();
}

/**
* Allows retrieval of the context strategy. See SEC-1188.
* @return the configured strategy for storing the security context.
Expand All @@ -159,38 +205,10 @@ public static SecurityContext createEmptyContext() {
return strategy.createEmptyContext();
}

/**
* Register a listener to be notified when the {@link SecurityContext} changes.
*
* Note that this does not notify when the underlying authentication changes. To get
* notified about authentication changes, ensure that you are using
* {@link #setContext} when changing the authentication like so:
*
* <pre>
* SecurityContext context = SecurityContextHolder.createEmptyContext();
* context.setAuthentication(authentication);
* SecurityContextHolder.setContext(context);
* </pre>
*
* To integrate this with Spring's
* {@link org.springframework.context.ApplicationEvent} support, you can add a
* listener like so:
*
* <pre>
* SecurityContextHolder.addListener(this.applicationContext::publishEvent);
* </pre>
* @param listener a listener to be notified when the {@link SecurityContext} changes
* @since 5.6
*/
public static void addListener(SecurityContextChangedListener listener) {
Assert.isInstanceOf(ListeningSecurityContextHolderStrategy.class, strategy,
"strategy must be of type ListeningSecurityContextHolderStrategy to add listeners");
((ListeningSecurityContextHolderStrategy) strategy).addListener(listener);
}

@Override
public String toString() {
return "SecurityContextHolder[strategy='" + strategyName + "'; initializeCount=" + initializeCount + "]";
return "SecurityContextHolder[strategy='" + strategy.getClass().getSimpleName() + "'; initializeCount="
+ initializeCount + "]";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextH

private static final ThreadLocal<SecurityContext> contextHolder = new ThreadLocal<>();

SecurityContext peek() {
return contextHolder.get();
}

@Override
public void clearContext() {
contextHolder.remove();
Expand Down
Loading