Skip to content

Commit

Permalink
Hardens PropertiesUtil against recursive property sources (#3263)
Browse files Browse the repository at this point in the history
As showed in #3252, Spring's `JndiPropertySource` not only can throw exceptions, but can also perform logging calls.
Such a call causes a recursive call to `PropertiesUtil.getProperty("log4j2.flowMessageFactory"`) and a `StackOverflowException` in the best scenario. The worst scenario includes a deadlock.

This PR:

- Moves the creation of the default `MessageFactory` and `FlowMessageFactory` to the static initializer of `LoggerContext`. This should be close enough to the pre-2.23.0 location in `AbstractLogger`. The `LoggerContext` class is usually initialized, before Spring Boot adds its property sources to `PropertiesUtil`.
- Adds a check to `PropertiesUtil` to ignore recursive calls.

Closes #3252.
  • Loading branch information
ppkarwasz authored Dec 9, 2024
1 parent bad8b56 commit 18a1deb
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
import java.util.Map;
import java.util.Properties;
import java.util.stream.Stream;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.test.ListStatusListener;
import org.apache.logging.log4j.test.junit.UsingStatusListener;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.parallel.ResourceAccessMode;
Expand Down Expand Up @@ -193,16 +196,56 @@ void testPublish() {
@Test
@ResourceLock(value = Resources.SYSTEM_PROPERTIES, mode = ResourceAccessMode.READ)
@Issue("https://github.com/spring-projects/spring-boot/issues/33450")
void testBadPropertySource() {
@UsingStatusListener
void testErrorPropertySource(ListStatusListener statusListener) {
final String key = "testKey";
final Properties props = new Properties();
props.put(key, "test");
final PropertiesUtil util = new PropertiesUtil(props);
final ErrorPropertySource source = new ErrorPropertySource();
util.addPropertySource(source);
try {
statusListener.clear();
assertEquals("test", util.getStringProperty(key));
assertTrue(source.exceptionThrown);
assertThat(statusListener.findStatusData(Level.WARN))
.anySatisfy(data ->
assertThat(data.getMessage().getFormattedMessage()).contains("Failed"));
} finally {
util.removePropertySource(source);
}
}

@Test
@ResourceLock(value = Resources.SYSTEM_PROPERTIES, mode = ResourceAccessMode.READ)
@Issue("https://github.com/apache/logging-log4j2/issues/3252")
@UsingStatusListener
void testRecursivePropertySource(ListStatusListener statusListener) {
final String key = "testKey";
final Properties props = new Properties();
props.put(key, "test");
final PropertiesUtil util = new PropertiesUtil(props);
final PropertySource source = new RecursivePropertySource(util);
util.addPropertySource(source);
try {
// We ignore the recursive source
statusListener.clear();
assertThat(util.getStringProperty(key)).isEqualTo("test");
assertThat(statusListener.findStatusData(Level.WARN))
.anySatisfy(data -> assertThat(data.getMessage().getFormattedMessage())
.contains("Recursive call", "getProperty"));

statusListener.clear();
// To check for existence, the sources are looked up in a random order.
assertThat(util.hasProperty(key)).isTrue();
// To find a missing key, all the sources must be used.
assertThat(util.hasProperty("noSuchKey")).isFalse();
assertThat(statusListener.findStatusData(Level.WARN))
.anySatisfy(data -> assertThat(data.getMessage().getFormattedMessage())
.contains("Recursive call", "containsProperty"));
// We check that the source is recursive
assertThat(source.getProperty(key)).isEqualTo("test");
assertThat(source.containsProperty(key)).isTrue();
} finally {
util.removePropertySource(source);
}
Expand Down Expand Up @@ -289,4 +332,28 @@ public boolean containsProperty(final String key) {
throw new IllegalStateException("Test");
}
}

private static class RecursivePropertySource implements PropertySource {

private final PropertiesUtil propertiesUtil;

private RecursivePropertySource(PropertiesUtil propertiesUtil) {
this.propertiesUtil = propertiesUtil;
}

@Override
public int getPriority() {
return Integer.MIN_VALUE;
}

@Override
public String getProperty(String key) {
return propertiesUtil.getStringProperty(key);
}

@Override
public boolean containsProperty(String key) {
return propertiesUtil.hasProperty(key);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ public void reload() {}
private static final class Environment {

private final Set<PropertySource> sources = ConcurrentHashMap.newKeySet();
private final ThreadLocal<PropertySource> CURRENT_PROPERTY_SOURCE = new ThreadLocal<>();

private Environment(final PropertySource propertySource) {
final PropertySource sysProps = new PropertyFilePropertySource(LOG4J_SYSTEM_PROPERTIES_FILE_NAME, false);
Expand Down Expand Up @@ -547,21 +548,35 @@ private String get(final String key) {
}

private boolean sourceContainsProperty(final PropertySource source, final String key) {
try {
return source.containsProperty(key);
} catch (final Exception e) {
LOGGER.warn("Failed to retrieve Log4j property {} from property source {}.", key, source, e);
return false;
PropertySource recursiveSource = CURRENT_PROPERTY_SOURCE.get();
if (recursiveSource == null) {
CURRENT_PROPERTY_SOURCE.set(source);
try {
return source.containsProperty(key);
} catch (final Exception e) {
LOGGER.warn("Failed to retrieve Log4j property {} from property source {}.", key, source, e);
} finally {
CURRENT_PROPERTY_SOURCE.remove();
}
}
LOGGER.warn("Recursive call to `containsProperty()` from property source {}.", recursiveSource);
return false;
}

private String sourceGetProperty(final PropertySource source, final String key) {
try {
return source.getProperty(key);
} catch (final Exception e) {
LOGGER.warn("Failed to retrieve Log4j property {} from property source {}.", key, source, e);
return null;
PropertySource recursiveSource = CURRENT_PROPERTY_SOURCE.get();
if (recursiveSource == null) {
CURRENT_PROPERTY_SOURCE.set(source);
try {
return source.getProperty(key);
} catch (final Exception e) {
LOGGER.warn("Failed to retrieve Log4j property {} from property source {}.", key, source, e);
} finally {
CURRENT_PROPERTY_SOURCE.remove();
}
}
LOGGER.warn("Recursive call to `getProperty()` from property source {}.", recursiveSource);
return null;
}

private boolean containsKey(final String key) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,52 +20,61 @@

import org.apache.logging.log4j.message.AbstractMessageFactory;
import org.apache.logging.log4j.message.DefaultFlowMessageFactory;
import org.apache.logging.log4j.message.FlowMessageFactory;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.MessageFactory;
import org.apache.logging.log4j.message.ParameterizedMessageFactory;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.ClearSystemProperty;
import org.junit.jupiter.api.TestInfo;
import org.junitpioneer.jupiter.SetSystemProperty;

@SetSystemProperty(
key = "log4j2.messageFactory",
value = "org.apache.logging.log4j.core.LoggerMessageFactoryCustomizationTest$AlternativeTestMessageFactory")
@SetSystemProperty(
key = "log4j2.flowMessageFactory",
value = "org.apache.logging.log4j.core.LoggerMessageFactoryCustomizationTest$AlternativeTestFlowMessageFactory")
class LoggerMessageFactoryCustomizationTest {

@Test
@ClearSystemProperty(key = "log4j2.messageFactory")
@ClearSystemProperty(key = "log4j2.flowMessageFactory")
void arguments_should_be_honored() {
final LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryCustomizationTest.class.getSimpleName());
final Logger logger = new Logger(
loggerContext, "arguments_should_be_honored", new TestMessageFactory(), new TestFlowMessageFactory());
assertTestMessageFactories(logger);
void arguments_should_be_honored(TestInfo testInfo) {
try (LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryCustomizationTest.class.getSimpleName())) {
Logger logger = new Logger(
loggerContext, testInfo.getDisplayName(), new TestMessageFactory(), new TestFlowMessageFactory());
assertTestMessageFactories(logger, TestMessageFactory.class, TestFlowMessageFactory.class);
}
}

@Test
@SetSystemProperty(
key = "log4j2.messageFactory",
value = "org.apache.logging.log4j.core.LoggerMessageFactoryCustomizationTest$TestMessageFactory")
@SetSystemProperty(
key = "log4j2.flowMessageFactory",
value = "org.apache.logging.log4j.core.LoggerMessageFactoryCustomizationTest$TestFlowMessageFactory")
void properties_should_be_honored() {
final LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryCustomizationTest.class.getSimpleName());
final Logger logger = new Logger(loggerContext, "properties_should_be_honored", null, null);
assertTestMessageFactories(logger);
void properties_should_be_honored(TestInfo testInfo) {
try (LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryCustomizationTest.class.getSimpleName())) {
Logger logger = loggerContext.getLogger(testInfo.getDisplayName());
assertTestMessageFactories(
logger, AlternativeTestMessageFactory.class, AlternativeTestFlowMessageFactory.class);
}
}

private static void assertTestMessageFactories(Logger logger) {
assertThat((MessageFactory) logger.getMessageFactory()).isInstanceOf(TestMessageFactory.class);
assertThat(logger.getFlowMessageFactory()).isInstanceOf(TestFlowMessageFactory.class);
private static void assertTestMessageFactories(
Logger logger,
Class<? extends MessageFactory> messageFactoryClass,
Class<? extends FlowMessageFactory> flowMessageFactoryClass) {
assertThat((MessageFactory) logger.getMessageFactory()).isInstanceOf(messageFactoryClass);
assertThat(logger.getFlowMessageFactory()).isInstanceOf(flowMessageFactoryClass);
}

public static final class TestMessageFactory extends AbstractMessageFactory {
public static class TestMessageFactory extends AbstractMessageFactory {

@Override
public Message newMessage(final String message, final Object... params) {
return ParameterizedMessageFactory.INSTANCE.newMessage(message, params);
}
}

public static final class TestFlowMessageFactory extends DefaultFlowMessageFactory {}
public static class AlternativeTestMessageFactory extends TestMessageFactory {}

public static class TestFlowMessageFactory extends DefaultFlowMessageFactory {}

public static class AlternativeTestFlowMessageFactory extends TestFlowMessageFactory {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,20 @@
import org.apache.logging.log4j.message.MessageFactory;
import org.apache.logging.log4j.message.ParameterizedMessageFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junitpioneer.jupiter.SetSystemProperty;

class LoggerMessageFactoryDefaultsTlaDisabledTest {

@Test
@SetSystemProperty(key = "log4j2.enableThreadLocals", value = "false")
void defaults_should_match_when_thread_locals_disabled() {
void defaults_should_match_when_thread_locals_disabled(TestInfo testInfo) {
assertThat(Constants.ENABLE_THREADLOCALS).isFalse();
final LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryDefaultsTlaDisabledTest.class.getSimpleName());
final Logger logger =
new Logger(loggerContext, "defaults_should_match_when_thread_locals_disabled", null, null);
assertThat((MessageFactory) logger.getMessageFactory()).isSameAs(ParameterizedMessageFactory.INSTANCE);
assertThat(logger.getFlowMessageFactory()).isSameAs(DefaultFlowMessageFactory.INSTANCE);
try (LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryDefaultsTlaDisabledTest.class.getSimpleName())) {
final Logger logger = loggerContext.getLogger(testInfo.getDisplayName());
assertThat((MessageFactory) logger.getMessageFactory()).isSameAs(ParameterizedMessageFactory.INSTANCE);
assertThat(logger.getFlowMessageFactory()).isSameAs(DefaultFlowMessageFactory.INSTANCE);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,21 @@
import org.apache.logging.log4j.message.MessageFactory;
import org.apache.logging.log4j.message.ReusableMessageFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junitpioneer.jupiter.SetSystemProperty;

class LoggerMessageFactoryDefaultsTlaEnabledTest {

@Test
@SetSystemProperty(key = "log4j2.is.webapp", value = "false")
@SetSystemProperty(key = "log4j2.enableThreadLocals", value = "true")
void defaults_should_match_when_thread_locals_enabled() {
@SetSystemProperty(key = "log4j2.isWebapp", value = "false")
@SetSystemProperty(key = "log4j2.enableThreadlocals", value = "true")
void defaults_should_match_when_thread_locals_enabled(TestInfo testInfo) {
assertThat(Constants.ENABLE_THREADLOCALS).isTrue();
final LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryDefaultsTlaEnabledTest.class.getSimpleName());
final Logger logger = new Logger(loggerContext, "defaults_should_match_when_thread_locals_enabled", null, null);
assertThat((MessageFactory) logger.getMessageFactory()).isSameAs(ReusableMessageFactory.INSTANCE);
assertThat(logger.getFlowMessageFactory()).isSameAs(DefaultFlowMessageFactory.INSTANCE);
try (LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryDefaultsTlaEnabledTest.class.getSimpleName())) {
Logger logger = loggerContext.getLogger(testInfo.getDisplayName());
assertThat((MessageFactory) logger.getMessageFactory()).isSameAs(ReusableMessageFactory.INSTANCE);
assertThat(logger.getFlowMessageFactory()).isSameAs(DefaultFlowMessageFactory.INSTANCE);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.apache.logging.log4j.ThreadContext;
import org.apache.logging.log4j.core.ContextDataInjector;
import org.apache.logging.log4j.spi.ThreadContextMap;
import org.apache.logging.log4j.util.PropertiesUtil;
import org.apache.logging.log4j.util.ProviderUtil;
import org.apache.logging.log4j.util.SortedArrayStringMap;
import org.apache.logging.log4j.util.StringMap;
Expand All @@ -59,7 +58,6 @@ public static Collection<String[]> threadContextMapClassNames() {
public String threadContextMapClassName;

private static void resetThreadContextMap() {
PropertiesUtil.getProperties().reload();
final Log4jProvider provider = (Log4jProvider) ProviderUtil.getProvider();
provider.resetThreadContextMap();
ThreadContext.init();
Expand Down
Loading

0 comments on commit 18a1deb

Please sign in to comment.