Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,8 +19,12 @@
import java.util.Date;
import java.util.concurrent.ScheduledFuture;

import org.springframework.core.task.TaskExecutor;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.Trigger;
import org.springframework.security.concurrent.DelegatingSecurityContextRunnable;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.util.Assert;

/**
Expand All @@ -32,45 +36,67 @@
*/
public class DelegatingSecurityContextTaskScheduler implements TaskScheduler {

private final TaskScheduler taskScheduler;
private final TaskScheduler delegate;

private final SecurityContext securityContext;

/**
* Creates a new {@link DelegatingSecurityContextTaskScheduler}
* @param taskScheduler the {@link TaskScheduler}
* Creates a new {@link DelegatingSecurityContextTaskScheduler} that uses the
* specified {@link SecurityContext}.
* @param delegateTaskScheduler the {@link TaskScheduler} to delegate to. Cannot be
* null.
* @param securityContext the {@link SecurityContext} to use for each
* {@link DelegatingSecurityContextRunnable} or null to default to the current
* {@link SecurityContext}
*/
public DelegatingSecurityContextTaskScheduler(TaskScheduler taskScheduler) {
Assert.notNull(taskScheduler, "Task scheduler must not be null");
this.taskScheduler = taskScheduler;
public DelegatingSecurityContextTaskScheduler(TaskScheduler delegateTaskScheduler,
SecurityContext securityContext) {
Assert.notNull(delegateTaskScheduler, "delegateTaskScheduler cannot be null");
this.delegate = delegateTaskScheduler;
this.securityContext = securityContext;
}

/**
* Creates a new {@link DelegatingSecurityContextTaskScheduler} that uses the current
* {@link SecurityContext} from the {@link SecurityContextHolder}.
* @param delegate the {@link TaskExecutor} to delegate to. Cannot be null.
*/
public DelegatingSecurityContextTaskScheduler(TaskScheduler delegate) {
this(delegate, null);
}

@Override
public ScheduledFuture<?> schedule(Runnable task, Trigger trigger) {
return this.taskScheduler.schedule(task, trigger);
return this.delegate.schedule(wrap(task), trigger);
}

@Override
public ScheduledFuture<?> schedule(Runnable task, Date startTime) {
return this.taskScheduler.schedule(task, startTime);
return this.delegate.schedule(wrap(task), startTime);
}

@Override
public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, Date startTime, long period) {
return this.taskScheduler.scheduleAtFixedRate(task, startTime, period);
return this.delegate.scheduleAtFixedRate(wrap(task), startTime, period);
}

@Override
public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, long period) {
return this.taskScheduler.scheduleAtFixedRate(task, period);
return this.delegate.scheduleAtFixedRate(wrap(task), period);
}

@Override
public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, Date startTime, long delay) {
return this.taskScheduler.scheduleWithFixedDelay(task, startTime, delay);
return this.delegate.scheduleWithFixedDelay(wrap(task), startTime, delay);
}

@Override
public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, long delay) {
return this.taskScheduler.scheduleWithFixedDelay(task, delay);
return this.delegate.scheduleWithFixedDelay(wrap(task), delay);
}

private Runnable wrap(Runnable delegate) {
return DelegatingSecurityContextRunnable.create(delegate, this.securityContext);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,9 +16,9 @@

package org.springframework.security.scheduling;

import java.time.Duration;
import java.time.Instant;
import java.util.Date;
import java.util.concurrent.ScheduledFuture;

import org.junit.After;
import org.junit.Before;
Expand All @@ -28,11 +28,16 @@

import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.Trigger;
import org.springframework.scheduling.concurrent.ConcurrentTaskScheduler;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.verify;

/**
Expand All @@ -47,22 +52,30 @@ public class DelegatingSecurityContextTaskSchedulerTests {
@Mock
private TaskScheduler scheduler;

@Mock
private SecurityContext securityContext;

@Mock
private Runnable runnable;

@Mock
private Trigger trigger;

private SecurityContext originalSecurityContext;

private DelegatingSecurityContextTaskScheduler delegatingSecurityContextTaskScheduler;

@Before
public void setup() {
MockitoAnnotations.initMocks(this);
this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(this.scheduler);
this.originalSecurityContext = SecurityContextHolder.createEmptyContext();
this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(this.scheduler,
this.securityContext);
}

@After
public void cleanup() {
SecurityContextHolder.clearContext();
this.delegatingSecurityContextTaskScheduler = null;
}

Expand All @@ -71,6 +84,36 @@ public void testSchedulerIsNotNull() {
assertThatIllegalArgumentException().isThrownBy(() -> new DelegatingSecurityContextTaskScheduler(null));
}

@Test
public void testSchedulerCurrentSecurityContext() throws Exception {
willAnswer((invocation) -> {
assertThat(SecurityContextHolder.getContext()).isEqualTo(this.originalSecurityContext);
return null;
}).given(this.runnable).run();
TaskScheduler delegateTaskScheduler = new ConcurrentTaskScheduler();
this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(delegateTaskScheduler);
assertWrapped(this.runnable);
}

@Test
public void testSchedulerExplicitSecurityContext() throws Exception {
willAnswer((invocation) -> {
assertThat(SecurityContextHolder.getContext()).isEqualTo(this.securityContext);
return null;
}).given(this.runnable).run();
TaskScheduler delegateTaskScheduler = new ConcurrentTaskScheduler();
this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(delegateTaskScheduler,
this.securityContext);
assertWrapped(this.runnable);
}

private void assertWrapped(Runnable runnable) throws Exception {
ScheduledFuture<?> schedule = this.delegatingSecurityContextTaskScheduler.schedule(runnable, new Date());
schedule.get();
verify(this.runnable).run();
assertThat(SecurityContextHolder.getContext()).isEqualTo(this.originalSecurityContext);
}

@Test
public void testSchedulerWithRunnableAndTrigger() {
this.delegatingSecurityContextTaskScheduler.schedule(this.runnable, this.trigger);
Expand All @@ -87,7 +130,6 @@ public void testSchedulerWithRunnableAndInstant() {
@Test
public void testScheduleAtFixedRateWithRunnableAndDate() {
Date date = new Date(1544751374L);
Duration duration = Duration.ofSeconds(4L);
this.delegatingSecurityContextTaskScheduler.scheduleAtFixedRate(this.runnable, date, 1000L);
verify(this.scheduler).scheduleAtFixedRate(isA(Runnable.class), isA(Date.class), eq(1000L));
}
Expand Down