Skip to content

Commit

Permalink
Perform NullAway build-time checks in spring-test
Browse files Browse the repository at this point in the history
Closes gh-32475
  • Loading branch information
sdeleuze committed Mar 26, 2024
1 parent 96d9081 commit 996e66a
Show file tree
Hide file tree
Showing 27 changed files with 56 additions and 10 deletions.
3 changes: 1 addition & 2 deletions gradle/spring-module.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ tasks.withType(JavaCompile).configureEach {
option("NullAway:AnnotatedPackages", "org.springframework")
option("NullAway:UnannotatedSubPackages", "org.springframework.instrument,org.springframework.context.index," +
"org.springframework.asm,org.springframework.cglib,org.springframework.objenesis," +
"org.springframework.javapoet,org.springframework.aot.nativex.substitution,org.springframework.aot.nativex.feature," +
"org.springframework.test,org.springframework.mock")
"org.springframework.javapoet,org.springframework.aot.nativex.substitution,org.springframework.aot.nativex.feature")
}
}
tasks.compileJava {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ private GenericApplicationContext loadContextForAotProcessing(
}
catch (Exception ex) {
Throwable cause = (ex instanceof ContextLoadException cle ? cle.getCause() : ex);
Assert.state(cause != null, "Cause must not be null");
throw new TestContextAotException(
"Failed to load ApplicationContext for AOT processing for test class [%s]"
.formatted(testClass.getName()), cause);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ else if (enforceExistingDefinition) {
registry.registerBeanDefinition(beanName, beanDefinition);

Object override = overrideMetadata.createOverride(beanName, existingBeanDefinition, null);
Assert.state(this.beanFactory != null, "ConfigurableListableBeanFactory must not be null");
if (this.beanFactory.isSingleton(beanName)) {
// Now we have an instance (the override) that we can register.
// At this stage we don't expect a singleton instance to be present,
Expand Down Expand Up @@ -222,6 +223,7 @@ protected final Object wrapIfNecessary(Object bean, String beanName) throws Bean
final OverrideMetadata metadata = this.earlyOverrideMetadata.get(beanName);
if (metadata != null && metadata.getBeanOverrideStrategy() == BeanOverrideStrategy.WRAP_EARLY_BEAN) {
bean = metadata.createOverride(beanName, null, bean);
Assert.state(this.beanFactory != null, "ConfigurableListableBeanFactory must not be null");
metadata.track(bean, this.beanFactory);
}
return bean;
Expand All @@ -234,6 +236,7 @@ private RootBeanDefinition createBeanDefinition(OverrideMetadata metadata) {
}

private Set<String> getExistingBeanNames(ResolvableType resolvableType) {
Assert.state(this.beanFactory != null, "ConfigurableListableBeanFactory must not be null");
Set<String> beans = new LinkedHashSet<>(
Arrays.asList(this.beanFactory.getBeanNamesForType(resolvableType, true, false)));
Class<?> type = resolvableType.resolve(Object.class);
Expand Down Expand Up @@ -274,6 +277,7 @@ private void inject(Field field, Object target, String beanName) {
try {
ReflectionUtils.makeAccessible(field);
Object existingValue = ReflectionUtils.getField(field, target);
Assert.state(this.beanFactory != null, "ConfigurableListableBeanFactory must not be null");
Object bean = this.beanFactory.getBean(beanName, field.getType());
if (existingValue == bean) {
return;
Expand Down Expand Up @@ -308,7 +312,7 @@ public static void register(BeanDefinitionRegistry registry, @Nullable Set<Overr
constructorArgs.addIndexedArgumentValue(0, new LinkedHashSet<OverrideMetadata>()));
ConstructorArgumentValues.ValueHolder constructorArg =
definition.getConstructorArgumentValues().getIndexedArgumentValue(0, Set.class);
@SuppressWarnings("unchecked")
@SuppressWarnings({"unchecked", "NullAway"})
Set<OverrideMetadata> existing = (Set<OverrideMetadata>) constructorArg.getValue();
if (overrideMetadata != null && existing != null) {
existing.addAll(overrideMetadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.lang.Nullable;
import org.springframework.test.context.ContextConfigurationAttributes;
import org.springframework.test.context.ContextCustomizer;
import org.springframework.test.context.ContextCustomizerFactory;
Expand All @@ -37,6 +38,7 @@
public class BeanOverrideContextCustomizerFactory implements ContextCustomizerFactory {

@Override
@Nullable
public ContextCustomizer createContextCustomizer(Class<?> testClass,
List<ContextConfigurationAttributes> configAttributes) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import org.springframework.beans.BeanUtils;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
Expand Down Expand Up @@ -96,7 +97,9 @@ private void parseField(Field field, Class<?> source) {

BeanOverride beanOverride = mergedAnnotation.synthesize();
BeanOverrideProcessor processor = BeanUtils.instantiateClass(beanOverride.value());
Annotation composedAnnotation = mergedAnnotation.getMetaSource().synthesize();
MergedAnnotation<?> metaSource = mergedAnnotation.getMetaSource();
Assert.state(metaSource != null, "Meta-annotation source must not be null");
Annotation composedAnnotation = metaSource.synthesize();
ResolvableType typeToOverride = processor.getOrDeduceType(field, composedAnnotation, source);

Assert.state(overrideAnnotationFound.compareAndSet(false, true),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public String getBeanOverrideDescription() {
}

@Override
protected Object createOverride(String beanName, BeanDefinition existingBeanDefinition, Object existingBeanInstance) {
protected Object createOverride(String beanName, @Nullable BeanDefinition existingBeanDefinition, @Nullable Object existingBeanInstance) {
return createMock(beanName);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.NativeDetector;
import org.springframework.core.Ordered;
import org.springframework.lang.Nullable;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.support.AbstractTestExecutionListener;

Expand Down Expand Up @@ -101,6 +102,7 @@ private void resetMocks(ConfigurableApplicationContext applicationContext, MockR
}
}

@Nullable
private Object getBean(ConfigurableListableBeanFactory beanFactory, String name) {
try {
if (isStandardBeanOrSingletonFactoryBean(beanFactory, name)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ private String[] getScripts(Sql sql, Class<?> testClass, @Nullable Method testMe
* Detect a default SQL script by implementing the algorithm defined in
* {@link Sql#scripts}.
*/
@SuppressWarnings("NullAway")
private String detectDefaultScript(Class<?> testClass, @Nullable Method testMethod, boolean classLevel) {
Assert.state(classLevel || testMethod != null, "Method-level @Sql requires a testMethod");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ private static Store getStore(ExtensionContext context) {
* the supplied {@link TestContextManager}.
* @since 6.1
*/
@SuppressWarnings("NullAway")
private static void registerMethodInvoker(TestContextManager testContextManager, ExtensionContext context) {
testContextManager.getTestContext().setMethodInvoker(context.getExecutableInvoker()::invoke);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ protected void dirtyContext(TestContext testContext, @Nullable HierarchyMode hie
* @since 4.2
* @see #dirtyContext
*/
@SuppressWarnings("NullAway")
protected void beforeOrAfterTestMethod(TestContext testContext, MethodMode requiredMethodMode,
ClassMode requiredClassMode) throws Exception {

Expand Down Expand Up @@ -135,6 +136,7 @@ else if (logger.isDebugEnabled()) {
* @since 4.2
* @see #dirtyContext
*/
@SuppressWarnings("NullAway")
protected void beforeOrAfterTestClass(TestContext testContext, ClassMode requiredClassMode) throws Exception {
Assert.notNull(testContext, "TestContext must not be null");
Assert.notNull(requiredClassMode, "requiredClassMode must not be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ static Map<String, List<ContextConfigurationAttributes>> buildContextHierarchyMa
* @throws IllegalArgumentException if the supplied class is {@code null} or if
* {@code @ContextConfiguration} is not <em>present</em> on the supplied class
*/
@SuppressWarnings("NullAway")
static List<ContextConfigurationAttributes> resolveContextConfigurationAttributes(Class<?> testClass) {
Assert.notNull(testClass, "Class must not be null");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ else if (!duplicationDetected(currentAttributes, previousAttributes)) {
return mergedAttributes;
}

@SuppressWarnings("NullAway")
private static boolean duplicationDetected(TestPropertySourceAttributes currentAttributes,
@Nullable TestPropertySourceAttributes previousAttributes) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ public final int getOrder() {
* @see #getTransactionManager(TestContext, String)
*/
@Override
@SuppressWarnings("NullAway")
public void beforeTestMethod(final TestContext testContext) throws Exception {
Method testMethod = testContext.getTestMethod();
Class<?> testClass = testContext.getTestClass();
Expand Down Expand Up @@ -414,6 +415,7 @@ protected PlatformTransactionManager getTransactionManager(TestContext testConte
* @return the <em>default rollback</em> flag for the supplied test context
* @throws Exception if an error occurs while determining the default rollback flag
*/
@SuppressWarnings("NullAway")
protected final boolean isDefaultRollback(TestContext testContext) throws Exception {
Class<?> testClass = testContext.getTestClass();
Rollback rollback = TestContextAnnotationUtils.findMergedAnnotation(testClass, Rollback.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ public String getResourceBasePath() {
*/
@Override
public boolean equals(@Nullable Object other) {
return (this == other || (super.equals(other) &&
this.resourceBasePath.equals(((WebMergedContextConfiguration) other).resourceBasePath)));
return (this == other || (super.equals(other) && other instanceof WebMergedContextConfiguration otherConfiguration &&
this.resourceBasePath.equals(otherConfiguration.resourceBasePath)));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public MediaTypeAssert isCompatibleWith(String mediaType) {
}


@SuppressWarnings("NullAway")
private MediaType parseMediaType(String value) {
try {
return MediaType.parseMediaType(value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.test.util;

import org.springframework.lang.Contract;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;

Expand All @@ -33,6 +34,7 @@ public abstract class AssertionErrors {
* Fail a test with the given message.
* @param message a message that describes the reason for the failure
*/
@Contract("_ -> fail")
public static void fail(String message) {
throw new AssertionError(message);
}
Expand Down Expand Up @@ -65,6 +67,7 @@ public static void fail(String message, @Nullable Object expected, @Nullable Obj
* @param message a message that describes the reason for the failure
* @param condition the condition to test for
*/
@Contract("_, false -> fail")
public static void assertTrue(String message, boolean condition) {
if (!condition) {
fail(message);
Expand All @@ -78,6 +81,7 @@ public static void assertTrue(String message, boolean condition) {
* @param condition the condition to test for
* @since 5.2.1
*/
@Contract("_, true -> fail")
public static void assertFalse(String message, boolean condition) {
if (condition) {
fail(message);
Expand All @@ -91,6 +95,7 @@ public static void assertFalse(String message, boolean condition) {
* @param object the object to check
* @since 5.2.1
*/
@Contract("_, !null -> fail")
public static void assertNull(String message, @Nullable Object object) {
assertTrue(message, object == null);
}
Expand All @@ -102,6 +107,7 @@ public static void assertNull(String message, @Nullable Object object) {
* @param object the object to check
* @since 5.1.8
*/
@Contract("_, null -> fail")
public static void assertNotNull(String message, @Nullable Object object) {
assertTrue(message, object != null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ public static void setField(
* @see ReflectionUtils#setField(Field, Object, Object)
* @see AopTestUtils#getUltimateTargetObject(Object)
*/
@SuppressWarnings("NullAway")
public static void setField(@Nullable Object targetObject, @Nullable Class<?> targetClass,
@Nullable String name, @Nullable Object value, @Nullable Class<?> type) {

Expand Down Expand Up @@ -259,6 +260,7 @@ public static Object getField(Class<?> targetClass, String name) {
* @see AopTestUtils#getUltimateTargetObject(Object)
*/
@Nullable
@SuppressWarnings("NullAway")
public static Object getField(@Nullable Object targetObject, @Nullable Class<?> targetClass, String name) {
Assert.isTrue(targetObject != null || targetClass != null,
"Either targetObject or targetClass for the field must be specified");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public abstract class ModelAndViewAssert {
* @param expectedType expected type of the model value
* @return the model value
*/
@SuppressWarnings("unchecked")
@SuppressWarnings({"unchecked", "NullAway"})
public static <T> T assertAndReturnModelAttributeOfType(ModelAndView mav, String modelName, Class<T> expectedType) {
Map<String, Object> model = mav.getModel();
Object obj = model.get(modelName);
Expand Down Expand Up @@ -109,6 +109,7 @@ public static void assertModelAttributeValue(ModelAndView mav, String modelName,
* @param mav the ModelAndView to test against (never {@code null})
* @param expectedModel the expected model
*/
@SuppressWarnings("NullAway")
public static void assertModelAttributeValues(ModelAndView mav, Map<String, Object> expectedModel) {
Map<String, Object> model = mav.getModel();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public UriAssert matchesAntPattern(String uriPattern) {
return this;
}

@SuppressWarnings("NullAway")
private String buildUri(String uriTemplate, Object... uriVars) {
try {
return UriComponentsBuilder.fromUriString(uriTemplate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ public static RequestMatcher queryParamList(String name, Matcher<? super List<St
* @see #queryParam(String, String...)
*/
@SafeVarargs
@SuppressWarnings("NullAway")
public static RequestMatcher queryParam(String name, Matcher<? super String>... matchers) {
return request -> {
MultiValueMap<String, String> params = getQueryParams(request);
Expand Down Expand Up @@ -185,6 +186,7 @@ public static RequestMatcher queryParam(String name, Matcher<? super String>...
* @see #queryParamList(String, Matcher)
* @see #queryParam(String, Matcher...)
*/
@SuppressWarnings("NullAway")
public static RequestMatcher queryParam(String name, String... expectedValues) {
return request -> {
MultiValueMap<String, String> params = getQueryParams(request);
Expand Down Expand Up @@ -362,7 +364,7 @@ private static void assertValueCount(
if (values == null) {
fail(message + " to exist but was null");
}
if (count > values.size()) {
else if (count > values.size()) {
fail(message + " to have at least <" + count + "> values but found " + values);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ public ResponseSpec exchange() {
DefaultWebTestClient.this.entityResultConsumer, getResponseTimeout());
}

@SuppressWarnings("NullAway")
private ClientRequest.Builder initRequestBuilder() {
return ClientRequest.create(this.httpMethod, initUri())
.headers(headersToUse -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class WiretapConnector implements ClientHttpConnector {


@Override
@SuppressWarnings("NullAway")
public Mono<ClientHttpResponse> connect(HttpMethod method, URI uri,
Function<? super ClientHttpRequest, Mono<Void>> requestCallback) {

Expand Down Expand Up @@ -181,6 +182,7 @@ public Publisher<? extends Publisher<? extends DataBuffer>> getNestedPublisherTo
return this.publisherNested;
}

@SuppressWarnings("NullAway")
public Mono<byte[]> getContent() {
return Mono.defer(() -> {
if (this.content.scan(Scannable.Attr.TERMINATED) == Boolean.TRUE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ public Object getAsyncResult() {
}

@Override
@SuppressWarnings("NullAway")
public Object getAsyncResult(long timeToWait) {
if (this.mockRequest.getAsyncContext() != null && timeToWait == -1) {
long requestTimeout = this.mockRequest.getAsyncContext().getTimeout();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,25 @@ public MockHttpServletResponse getResponse() {
}

@Override
@Nullable
public Object getHandler() {
return getTarget().getHandler();
}

@Override
@Nullable
public HandlerInterceptor[] getInterceptors() {
return getTarget().getInterceptors();
}

@Override
@Nullable
public ModelAndView getModelAndView() {
return getTarget().getModelAndView();
}

@Override
@Nullable
public Exception getResolvedException() {
return getTarget().getResolvedException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ public MvcResultAssert hasViewName(String viewName) {
}


@SuppressWarnings("NullAway")
private ModelAndView getModelAndView() {
ModelAndView modelAndView = this.actual.getModelAndView();
Assertions.assertThat(modelAndView).as("ModelAndView").isNotNull();
Expand Down
Loading

0 comments on commit 996e66a

Please sign in to comment.