diff --git a/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskScheduler.java b/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskScheduler.java index 0925b7fa1fc..b245a605cb5 100644 --- a/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskScheduler.java +++ b/core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskScheduler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,8 +19,12 @@ import java.util.Date; import java.util.concurrent.ScheduledFuture; +import org.springframework.core.task.TaskExecutor; import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.Trigger; +import org.springframework.security.concurrent.DelegatingSecurityContextRunnable; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.util.Assert; /** @@ -32,45 +36,67 @@ */ public class DelegatingSecurityContextTaskScheduler implements TaskScheduler { - private final TaskScheduler taskScheduler; + private final TaskScheduler delegate; + + private final SecurityContext securityContext; /** - * Creates a new {@link DelegatingSecurityContextTaskScheduler} - * @param taskScheduler the {@link TaskScheduler} + * Creates a new {@link DelegatingSecurityContextTaskScheduler} that uses the + * specified {@link SecurityContext}. + * @param delegateTaskScheduler the {@link TaskScheduler} to delegate to. Cannot be + * null. + * @param securityContext the {@link SecurityContext} to use for each + * {@link DelegatingSecurityContextRunnable} or null to default to the current + * {@link SecurityContext} */ - public DelegatingSecurityContextTaskScheduler(TaskScheduler taskScheduler) { - Assert.notNull(taskScheduler, "Task scheduler must not be null"); - this.taskScheduler = taskScheduler; + public DelegatingSecurityContextTaskScheduler(TaskScheduler delegateTaskScheduler, + SecurityContext securityContext) { + Assert.notNull(delegateTaskScheduler, "delegateTaskScheduler cannot be null"); + this.delegate = delegateTaskScheduler; + this.securityContext = securityContext; + } + + /** + * Creates a new {@link DelegatingSecurityContextTaskScheduler} that uses the current + * {@link SecurityContext} from the {@link SecurityContextHolder}. + * @param delegate the {@link TaskExecutor} to delegate to. Cannot be null. + */ + public DelegatingSecurityContextTaskScheduler(TaskScheduler delegate) { + this(delegate, null); } @Override public ScheduledFuture schedule(Runnable task, Trigger trigger) { - return this.taskScheduler.schedule(task, trigger); + return this.delegate.schedule(wrap(task), trigger); } @Override public ScheduledFuture schedule(Runnable task, Date startTime) { - return this.taskScheduler.schedule(task, startTime); + return this.delegate.schedule(wrap(task), startTime); } @Override public ScheduledFuture scheduleAtFixedRate(Runnable task, Date startTime, long period) { - return this.taskScheduler.scheduleAtFixedRate(task, startTime, period); + return this.delegate.scheduleAtFixedRate(wrap(task), startTime, period); } @Override public ScheduledFuture scheduleAtFixedRate(Runnable task, long period) { - return this.taskScheduler.scheduleAtFixedRate(task, period); + return this.delegate.scheduleAtFixedRate(wrap(task), period); } @Override public ScheduledFuture scheduleWithFixedDelay(Runnable task, Date startTime, long delay) { - return this.taskScheduler.scheduleWithFixedDelay(task, startTime, delay); + return this.delegate.scheduleWithFixedDelay(wrap(task), startTime, delay); } @Override public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) { - return this.taskScheduler.scheduleWithFixedDelay(task, delay); + return this.delegate.scheduleWithFixedDelay(wrap(task), delay); + } + + private Runnable wrap(Runnable delegate) { + return DelegatingSecurityContextRunnable.create(delegate, this.securityContext); } } diff --git a/core/src/test/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskSchedulerTests.java b/core/src/test/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskSchedulerTests.java index 30dfff81806..065d8ca4b74 100644 --- a/core/src/test/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskSchedulerTests.java +++ b/core/src/test/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskSchedulerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,9 @@ package org.springframework.security.scheduling; -import java.time.Duration; import java.time.Instant; import java.util.Date; +import java.util.concurrent.ScheduledFuture; import org.junit.After; import org.junit.Before; @@ -28,11 +28,16 @@ import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.Trigger; +import org.springframework.scheduling.concurrent.ConcurrentTaskScheduler; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.verify; /** @@ -47,22 +52,30 @@ public class DelegatingSecurityContextTaskSchedulerTests { @Mock private TaskScheduler scheduler; + @Mock + private SecurityContext securityContext; + @Mock private Runnable runnable; @Mock private Trigger trigger; + private SecurityContext originalSecurityContext; + private DelegatingSecurityContextTaskScheduler delegatingSecurityContextTaskScheduler; @Before public void setup() { MockitoAnnotations.initMocks(this); - this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(this.scheduler); + this.originalSecurityContext = SecurityContextHolder.createEmptyContext(); + this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(this.scheduler, + this.securityContext); } @After public void cleanup() { + SecurityContextHolder.clearContext(); this.delegatingSecurityContextTaskScheduler = null; } @@ -71,6 +84,36 @@ public void testSchedulerIsNotNull() { assertThatIllegalArgumentException().isThrownBy(() -> new DelegatingSecurityContextTaskScheduler(null)); } + @Test + public void testSchedulerCurrentSecurityContext() throws Exception { + willAnswer((invocation) -> { + assertThat(SecurityContextHolder.getContext()).isEqualTo(this.originalSecurityContext); + return null; + }).given(this.runnable).run(); + TaskScheduler delegateTaskScheduler = new ConcurrentTaskScheduler(); + this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(delegateTaskScheduler); + assertWrapped(this.runnable); + } + + @Test + public void testSchedulerExplicitSecurityContext() throws Exception { + willAnswer((invocation) -> { + assertThat(SecurityContextHolder.getContext()).isEqualTo(this.securityContext); + return null; + }).given(this.runnable).run(); + TaskScheduler delegateTaskScheduler = new ConcurrentTaskScheduler(); + this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(delegateTaskScheduler, + this.securityContext); + assertWrapped(this.runnable); + } + + private void assertWrapped(Runnable runnable) throws Exception { + ScheduledFuture schedule = this.delegatingSecurityContextTaskScheduler.schedule(runnable, new Date()); + schedule.get(); + verify(this.runnable).run(); + assertThat(SecurityContextHolder.getContext()).isEqualTo(this.originalSecurityContext); + } + @Test public void testSchedulerWithRunnableAndTrigger() { this.delegatingSecurityContextTaskScheduler.schedule(this.runnable, this.trigger); @@ -87,7 +130,6 @@ public void testSchedulerWithRunnableAndInstant() { @Test public void testScheduleAtFixedRateWithRunnableAndDate() { Date date = new Date(1544751374L); - Duration duration = Duration.ofSeconds(4L); this.delegatingSecurityContextTaskScheduler.scheduleAtFixedRate(this.runnable, date, 1000L); verify(this.scheduler).scheduleAtFixedRate(isA(Runnable.class), isA(Date.class), eq(1000L)); }