From ed92dee9f5d847268fb7d5a1d2400ecebfafbe54 Mon Sep 17 00:00:00 2001
From: Dmytro Nosan <dimanosan@gmail.com>
Date: Fri, 15 Nov 2024 18:07:50 +0200
Subject: [PATCH] Add TaskDecorator support for scheduled tasks

---
 .../task/TaskSchedulingConfigurations.java    | 12 +++-
 .../TaskSchedulingAutoConfigurationTests.java | 36 +++++++++++
 .../task/SimpleAsyncTaskSchedulerBuilder.java | 40 +++++++++---
 .../task/ThreadPoolTaskSchedulerBuilder.java  | 62 +++++++++++++++----
 .../SimpleAsyncTaskSchedulerBuilderTests.java | 10 ++-
 .../ThreadPoolTaskSchedulerBuilderTests.java  | 10 ++-
 6 files changed, 145 insertions(+), 25 deletions(-)

diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/task/TaskSchedulingConfigurations.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/task/TaskSchedulingConfigurations.java
index 0112171fa7af..b0cb0d187c20 100644
--- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/task/TaskSchedulingConfigurations.java
+++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/task/TaskSchedulingConfigurations.java
@@ -29,6 +29,7 @@
 import org.springframework.boot.task.ThreadPoolTaskSchedulerCustomizer;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
+import org.springframework.core.task.TaskDecorator;
 import org.springframework.scheduling.TaskScheduler;
 import org.springframework.scheduling.concurrent.SimpleAsyncTaskScheduler;
 import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
@@ -67,7 +68,8 @@ static class ThreadPoolTaskSchedulerBuilderConfiguration {
 		@Bean
 		@ConditionalOnMissingBean
 		ThreadPoolTaskSchedulerBuilder threadPoolTaskSchedulerBuilder(TaskSchedulingProperties properties,
-				ObjectProvider<ThreadPoolTaskSchedulerCustomizer> threadPoolTaskSchedulerCustomizers) {
+				ObjectProvider<ThreadPoolTaskSchedulerCustomizer> threadPoolTaskSchedulerCustomizers,
+				ObjectProvider<TaskDecorator> taskDecorator) {
 			TaskSchedulingProperties.Shutdown shutdown = properties.getShutdown();
 			ThreadPoolTaskSchedulerBuilder builder = new ThreadPoolTaskSchedulerBuilder();
 			builder = builder.poolSize(properties.getPool().getSize());
@@ -75,6 +77,7 @@ ThreadPoolTaskSchedulerBuilder threadPoolTaskSchedulerBuilder(TaskSchedulingProp
 			builder = builder.awaitTerminationPeriod(shutdown.getAwaitTerminationPeriod());
 			builder = builder.threadNamePrefix(properties.getThreadNamePrefix());
 			builder = builder.customizers(threadPoolTaskSchedulerCustomizers);
+			builder = builder.taskDecorator(taskDecorator.getIfUnique());
 			return builder;
 		}
 
@@ -87,10 +90,14 @@ static class SimpleAsyncTaskSchedulerBuilderConfiguration {
 
 		private final ObjectProvider<SimpleAsyncTaskSchedulerCustomizer> taskSchedulerCustomizers;
 
+		private final ObjectProvider<TaskDecorator> taskDecorator;
+
 		SimpleAsyncTaskSchedulerBuilderConfiguration(TaskSchedulingProperties properties,
-				ObjectProvider<SimpleAsyncTaskSchedulerCustomizer> taskSchedulerCustomizers) {
+				ObjectProvider<SimpleAsyncTaskSchedulerCustomizer> taskSchedulerCustomizers,
+				ObjectProvider<TaskDecorator> taskDecorator) {
 			this.properties = properties;
 			this.taskSchedulerCustomizers = taskSchedulerCustomizers;
+			this.taskDecorator = taskDecorator;
 		}
 
 		@Bean
@@ -117,6 +124,7 @@ private SimpleAsyncTaskSchedulerBuilder builder() {
 			if (shutdown.isAwaitTermination()) {
 				builder = builder.taskTerminationTimeout(shutdown.getAwaitTerminationPeriod());
 			}
+			builder = builder.taskDecorator(this.taskDecorator.getIfUnique());
 			return builder;
 		}
 
diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/task/TaskSchedulingAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/task/TaskSchedulingAutoConfigurationTests.java
index 74dc49d97403..62963d383229 100644
--- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/task/TaskSchedulingAutoConfigurationTests.java
+++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/task/TaskSchedulingAutoConfigurationTests.java
@@ -41,6 +41,7 @@
 import org.springframework.boot.test.context.runner.ApplicationContextRunner;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
+import org.springframework.core.task.TaskDecorator;
 import org.springframework.core.task.TaskExecutor;
 import org.springframework.scheduling.TaskScheduler;
 import org.springframework.scheduling.annotation.EnableScheduling;
@@ -50,6 +51,7 @@
 import org.springframework.scheduling.config.ScheduledTaskRegistrar;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
 
 /**
  * Tests for {@link TaskSchedulingAutoConfiguration}.
@@ -154,6 +156,30 @@ void simpleAsyncTaskSchedulerBuilderShouldApplyCustomizers() {
 			});
 	}
 
+	@Test
+	void simpleAsyncTaskSchedulerBuilderShouldApplyTaskDecorator() {
+		this.contextRunner.withUserConfiguration(SchedulingConfiguration.class, TaskDecoratorConfig.class)
+			.run((context) -> {
+				assertThat(context).hasSingleBean(SimpleAsyncTaskSchedulerBuilder.class);
+				assertThat(context).hasSingleBean(TaskDecorator.class);
+				TaskDecorator taskDecorator = context.getBean(TaskDecorator.class);
+				SimpleAsyncTaskSchedulerBuilder builder = context.getBean(SimpleAsyncTaskSchedulerBuilder.class);
+				assertThat(builder).extracting("taskDecorator").isSameAs(taskDecorator);
+			});
+	}
+
+	@Test
+	void threadPoolTaskSchedulerBuilderShouldApplyTaskDecorator() {
+		this.contextRunner.withUserConfiguration(SchedulingConfiguration.class, TaskDecoratorConfig.class)
+			.run((context) -> {
+				assertThat(context).hasSingleBean(ThreadPoolTaskSchedulerBuilder.class);
+				assertThat(context).hasSingleBean(TaskDecorator.class);
+				TaskDecorator taskDecorator = context.getBean(TaskDecorator.class);
+				ThreadPoolTaskSchedulerBuilder builder = context.getBean(ThreadPoolTaskSchedulerBuilder.class);
+				assertThat(builder).extracting("taskDecorator").isSameAs(taskDecorator);
+			});
+	}
+
 	@Test
 	void enableSchedulingWithNoTaskExecutorAppliesCustomizers() {
 		this.contextRunner.withPropertyValues("spring.task.scheduling.thread-name-prefix=scheduling-test-")
@@ -305,4 +331,14 @@ static class TestTaskScheduler extends ThreadPoolTaskScheduler {
 
 	}
 
+	@Configuration(proxyBeanMethods = false)
+	static class TaskDecoratorConfig {
+
+		@Bean
+		TaskDecorator mockTaskDecorator() {
+			return mock(TaskDecorator.class);
+		}
+
+	}
+
 }
diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/task/SimpleAsyncTaskSchedulerBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/task/SimpleAsyncTaskSchedulerBuilder.java
index 4e2f4069bd8c..6e5df9055bfd 100644
--- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/task/SimpleAsyncTaskSchedulerBuilder.java
+++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/task/SimpleAsyncTaskSchedulerBuilder.java
@@ -1,5 +1,5 @@
 /*
- * Copyright 2012-2023 the original author or authors.
+ * Copyright 2012-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
 import java.util.Set;
 
 import org.springframework.boot.context.properties.PropertyMapper;
+import org.springframework.core.task.TaskDecorator;
 import org.springframework.scheduling.concurrent.SimpleAsyncTaskScheduler;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
@@ -51,17 +52,26 @@ public class SimpleAsyncTaskSchedulerBuilder {
 
 	private final Duration taskTerminationTimeout;
 
+	private final TaskDecorator taskDecorator;
+
+	/**
+	 * Constructs a new {@code SimpleAsyncTaskSchedulerBuilder} with default settings.
+	 * Initializes a builder instance with all fields set to {@code null}, allowing for
+	 * further customization through its fluent API methods.
+	 */
 	public SimpleAsyncTaskSchedulerBuilder() {
-		this(null, null, null, null, null);
+		this(null, null, null, null, null, null);
 	}
 
 	private SimpleAsyncTaskSchedulerBuilder(String threadNamePrefix, Integer concurrencyLimit, Boolean virtualThreads,
-			Set<SimpleAsyncTaskSchedulerCustomizer> taskSchedulerCustomizers, Duration taskTerminationTimeout) {
+			Set<SimpleAsyncTaskSchedulerCustomizer> taskSchedulerCustomizers, Duration taskTerminationTimeout,
+			TaskDecorator taskDecorator) {
 		this.threadNamePrefix = threadNamePrefix;
 		this.concurrencyLimit = concurrencyLimit;
 		this.virtualThreads = virtualThreads;
 		this.customizers = taskSchedulerCustomizers;
 		this.taskTerminationTimeout = taskTerminationTimeout;
+		this.taskDecorator = taskDecorator;
 	}
 
 	/**
@@ -71,7 +81,7 @@ private SimpleAsyncTaskSchedulerBuilder(String threadNamePrefix, Integer concurr
 	 */
 	public SimpleAsyncTaskSchedulerBuilder threadNamePrefix(String threadNamePrefix) {
 		return new SimpleAsyncTaskSchedulerBuilder(threadNamePrefix, this.concurrencyLimit, this.virtualThreads,
-				this.customizers, this.taskTerminationTimeout);
+				this.customizers, this.taskTerminationTimeout, this.taskDecorator);
 	}
 
 	/**
@@ -81,7 +91,7 @@ public SimpleAsyncTaskSchedulerBuilder threadNamePrefix(String threadNamePrefix)
 	 */
 	public SimpleAsyncTaskSchedulerBuilder concurrencyLimit(Integer concurrencyLimit) {
 		return new SimpleAsyncTaskSchedulerBuilder(this.threadNamePrefix, concurrencyLimit, this.virtualThreads,
-				this.customizers, this.taskTerminationTimeout);
+				this.customizers, this.taskTerminationTimeout, this.taskDecorator);
 	}
 
 	/**
@@ -91,7 +101,7 @@ public SimpleAsyncTaskSchedulerBuilder concurrencyLimit(Integer concurrencyLimit
 	 */
 	public SimpleAsyncTaskSchedulerBuilder virtualThreads(Boolean virtualThreads) {
 		return new SimpleAsyncTaskSchedulerBuilder(this.threadNamePrefix, this.concurrencyLimit, virtualThreads,
-				this.customizers, this.taskTerminationTimeout);
+				this.customizers, this.taskTerminationTimeout, this.taskDecorator);
 	}
 
 	/**
@@ -102,7 +112,7 @@ public SimpleAsyncTaskSchedulerBuilder virtualThreads(Boolean virtualThreads) {
 	 */
 	public SimpleAsyncTaskSchedulerBuilder taskTerminationTimeout(Duration taskTerminationTimeout) {
 		return new SimpleAsyncTaskSchedulerBuilder(this.threadNamePrefix, this.concurrencyLimit, this.virtualThreads,
-				this.customizers, taskTerminationTimeout);
+				this.customizers, taskTerminationTimeout, this.taskDecorator);
 	}
 
 	/**
@@ -132,7 +142,7 @@ public SimpleAsyncTaskSchedulerBuilder customizers(
 			Iterable<? extends SimpleAsyncTaskSchedulerCustomizer> customizers) {
 		Assert.notNull(customizers, "Customizers must not be null");
 		return new SimpleAsyncTaskSchedulerBuilder(this.threadNamePrefix, this.concurrencyLimit, this.virtualThreads,
-				append(null, customizers), this.taskTerminationTimeout);
+				append(null, customizers), this.taskTerminationTimeout, this.taskDecorator);
 	}
 
 	/**
@@ -160,7 +170,18 @@ public SimpleAsyncTaskSchedulerBuilder additionalCustomizers(
 			Iterable<? extends SimpleAsyncTaskSchedulerCustomizer> customizers) {
 		Assert.notNull(customizers, "Customizers must not be null");
 		return new SimpleAsyncTaskSchedulerBuilder(this.threadNamePrefix, this.concurrencyLimit, this.virtualThreads,
-				append(this.customizers, customizers), this.taskTerminationTimeout);
+				append(this.customizers, customizers), this.taskTerminationTimeout, this.taskDecorator);
+	}
+
+	/**
+	 * Set the task decorator to be used by the {@link SimpleAsyncTaskScheduler}.
+	 * @param taskDecorator the task decorator to set
+	 * @return a new builder instance
+	 * @since 3.5.0
+	 */
+	public SimpleAsyncTaskSchedulerBuilder taskDecorator(TaskDecorator taskDecorator) {
+		return new SimpleAsyncTaskSchedulerBuilder(this.threadNamePrefix, this.concurrencyLimit, this.virtualThreads,
+				this.customizers, this.taskTerminationTimeout, taskDecorator);
 	}
 
 	/**
@@ -187,6 +208,7 @@ public <T extends SimpleAsyncTaskScheduler> T configure(T taskScheduler) {
 		map.from(this.concurrencyLimit).to(taskScheduler::setConcurrencyLimit);
 		map.from(this.virtualThreads).to(taskScheduler::setVirtualThreads);
 		map.from(this.taskTerminationTimeout).as(Duration::toMillis).to(taskScheduler::setTaskTerminationTimeout);
+		map.from(this.taskDecorator).to(taskScheduler::setTaskDecorator);
 		if (!CollectionUtils.isEmpty(this.customizers)) {
 			this.customizers.forEach((customizer) -> customizer.customize(taskScheduler));
 		}
diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/task/ThreadPoolTaskSchedulerBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/task/ThreadPoolTaskSchedulerBuilder.java
index a36e48308ee4..9815554056e3 100644
--- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/task/ThreadPoolTaskSchedulerBuilder.java
+++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/task/ThreadPoolTaskSchedulerBuilder.java
@@ -1,5 +1,5 @@
 /*
- * Copyright 2012-2023 the original author or authors.
+ * Copyright 2012-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
 import java.util.Set;
 
 import org.springframework.boot.context.properties.PropertyMapper;
+import org.springframework.core.task.TaskDecorator;
 import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
@@ -48,23 +49,48 @@ public class ThreadPoolTaskSchedulerBuilder {
 
 	private final String threadNamePrefix;
 
+	private final TaskDecorator taskDecorator;
+
 	private final Set<ThreadPoolTaskSchedulerCustomizer> customizers;
 
+	/**
+	 * Default constructor for creating a new instance of
+	 * {@code ThreadPoolTaskSchedulerBuilder}. Initializes a builder instance with all
+	 * fields set to {@code null}, allowing for further customization through its fluent
+	 * API methods.
+	 */
 	public ThreadPoolTaskSchedulerBuilder() {
-		this.poolSize = null;
-		this.awaitTermination = null;
-		this.awaitTerminationPeriod = null;
-		this.threadNamePrefix = null;
-		this.customizers = null;
+		this(null, null, null, null, null, null);
 	}
 
+	/**
+	 * Constructs a new {@code ThreadPoolTaskSchedulerBuilder} instance with the specified
+	 * configuration.
+	 * @param poolSize the maximum allowed number of threads
+	 * @param awaitTermination whether the executor should wait for scheduled tasks to
+	 * complete on shutdown
+	 * @param awaitTerminationPeriod the maximum time the executor is supposed to block on
+	 * shutdown
+	 * @param threadNamePrefix the prefix to use for the names of newly created threads
+	 * @param taskSchedulerCustomizers the customizers to apply to the
+	 * {@link ThreadPoolTaskScheduler}
+	 * @deprecated since 3.5.0 for removal in 3.7.0 in favor of the default constructor
+	 */
+	@Deprecated(since = "3.5.0", forRemoval = true)
 	public ThreadPoolTaskSchedulerBuilder(Integer poolSize, Boolean awaitTermination, Duration awaitTerminationPeriod,
 			String threadNamePrefix, Set<ThreadPoolTaskSchedulerCustomizer> taskSchedulerCustomizers) {
+		this(poolSize, awaitTermination, awaitTerminationPeriod, threadNamePrefix, taskSchedulerCustomizers, null);
+	}
+
+	private ThreadPoolTaskSchedulerBuilder(Integer poolSize, Boolean awaitTermination, Duration awaitTerminationPeriod,
+			String threadNamePrefix, Set<ThreadPoolTaskSchedulerCustomizer> taskSchedulerCustomizers,
+			TaskDecorator taskDecorator) {
 		this.poolSize = poolSize;
 		this.awaitTermination = awaitTermination;
 		this.awaitTerminationPeriod = awaitTerminationPeriod;
 		this.threadNamePrefix = threadNamePrefix;
 		this.customizers = taskSchedulerCustomizers;
+		this.taskDecorator = taskDecorator;
 	}
 
 	/**
@@ -74,7 +100,7 @@ public ThreadPoolTaskSchedulerBuilder(Integer poolSize, Boolean awaitTermination
 	 */
 	public ThreadPoolTaskSchedulerBuilder poolSize(int poolSize) {
 		return new ThreadPoolTaskSchedulerBuilder(poolSize, this.awaitTermination, this.awaitTerminationPeriod,
-				this.threadNamePrefix, this.customizers);
+				this.threadNamePrefix, this.customizers, this.taskDecorator);
 	}
 
 	/**
@@ -87,7 +113,7 @@ public ThreadPoolTaskSchedulerBuilder poolSize(int poolSize) {
 	 */
 	public ThreadPoolTaskSchedulerBuilder awaitTermination(boolean awaitTermination) {
 		return new ThreadPoolTaskSchedulerBuilder(this.poolSize, awaitTermination, this.awaitTerminationPeriod,
-				this.threadNamePrefix, this.customizers);
+				this.threadNamePrefix, this.customizers, this.taskDecorator);
 	}
 
 	/**
@@ -101,7 +127,7 @@ public ThreadPoolTaskSchedulerBuilder awaitTermination(boolean awaitTermination)
 	 */
 	public ThreadPoolTaskSchedulerBuilder awaitTerminationPeriod(Duration awaitTerminationPeriod) {
 		return new ThreadPoolTaskSchedulerBuilder(this.poolSize, this.awaitTermination, awaitTerminationPeriod,
-				this.threadNamePrefix, this.customizers);
+				this.threadNamePrefix, this.customizers, this.taskDecorator);
 	}
 
 	/**
@@ -111,7 +137,18 @@ public ThreadPoolTaskSchedulerBuilder awaitTerminationPeriod(Duration awaitTermi
 	 */
 	public ThreadPoolTaskSchedulerBuilder threadNamePrefix(String threadNamePrefix) {
 		return new ThreadPoolTaskSchedulerBuilder(this.poolSize, this.awaitTermination, this.awaitTerminationPeriod,
-				threadNamePrefix, this.customizers);
+				threadNamePrefix, this.customizers, this.taskDecorator);
+	}
+
+	/**
+	 * Set the {@link TaskDecorator} to be applied to the {@link ThreadPoolTaskScheduler}.
+	 * @param taskDecorator the task decorator to set
+	 * @return a new builder instance
+	 * @since 3.5.0
+	 */
+	public ThreadPoolTaskSchedulerBuilder taskDecorator(TaskDecorator taskDecorator) {
+		return new ThreadPoolTaskSchedulerBuilder(this.poolSize, this.awaitTermination, this.awaitTerminationPeriod,
+				this.threadNamePrefix, this.customizers, taskDecorator);
 	}
 
 	/**
@@ -143,7 +180,7 @@ public ThreadPoolTaskSchedulerBuilder customizers(
 			Iterable<? extends ThreadPoolTaskSchedulerCustomizer> customizers) {
 		Assert.notNull(customizers, "Customizers must not be null");
 		return new ThreadPoolTaskSchedulerBuilder(this.poolSize, this.awaitTermination, this.awaitTerminationPeriod,
-				this.threadNamePrefix, append(null, customizers));
+				this.threadNamePrefix, append(null, customizers), this.taskDecorator);
 	}
 
 	/**
@@ -173,7 +210,7 @@ public ThreadPoolTaskSchedulerBuilder additionalCustomizers(
 			Iterable<? extends ThreadPoolTaskSchedulerCustomizer> customizers) {
 		Assert.notNull(customizers, "Customizers must not be null");
 		return new ThreadPoolTaskSchedulerBuilder(this.poolSize, this.awaitTermination, this.awaitTerminationPeriod,
-				this.threadNamePrefix, append(this.customizers, customizers));
+				this.threadNamePrefix, append(this.customizers, customizers), this.taskDecorator);
 	}
 
 	/**
@@ -199,6 +236,7 @@ public <T extends ThreadPoolTaskScheduler> T configure(T taskScheduler) {
 		map.from(this.awaitTermination).to(taskScheduler::setWaitForTasksToCompleteOnShutdown);
 		map.from(this.awaitTerminationPeriod).asInt(Duration::getSeconds).to(taskScheduler::setAwaitTerminationSeconds);
 		map.from(this.threadNamePrefix).to(taskScheduler::setThreadNamePrefix);
+		map.from(this.taskDecorator).to(taskScheduler::setTaskDecorator);
 		if (!CollectionUtils.isEmpty(this.customizers)) {
 			this.customizers.forEach((customizer) -> customizer.customize(taskScheduler));
 		}
diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/task/SimpleAsyncTaskSchedulerBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/task/SimpleAsyncTaskSchedulerBuilderTests.java
index 9cb06c5f3213..f2ac75112a8e 100644
--- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/task/SimpleAsyncTaskSchedulerBuilderTests.java
+++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/task/SimpleAsyncTaskSchedulerBuilderTests.java
@@ -1,5 +1,5 @@
 /*
- * Copyright 2012-2023 the original author or authors.
+ * Copyright 2012-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@
 import org.junit.jupiter.api.condition.EnabledForJreRange;
 import org.junit.jupiter.api.condition.JRE;
 
+import org.springframework.core.task.TaskDecorator;
 import org.springframework.scheduling.concurrent.SimpleAsyncTaskScheduler;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -134,4 +135,11 @@ void taskTerminationTimeoutShouldApply() {
 		assertThat(scheduler).extracting("taskTerminationTimeout").isEqualTo(1000L);
 	}
 
+	@Test
+	void taskDecoratorShouldApply() {
+		TaskDecorator taskDecorator = mock(TaskDecorator.class);
+		SimpleAsyncTaskScheduler scheduler = this.builder.taskDecorator(taskDecorator).build();
+		assertThat(scheduler).extracting("taskDecorator").isSameAs(taskDecorator);
+	}
+
 }
diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/task/ThreadPoolTaskSchedulerBuilderTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/task/ThreadPoolTaskSchedulerBuilderTests.java
index 11b4f15f49af..9411700bfb15 100644
--- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/task/ThreadPoolTaskSchedulerBuilderTests.java
+++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/task/ThreadPoolTaskSchedulerBuilderTests.java
@@ -1,5 +1,5 @@
 /*
- * Copyright 2012-2023 the original author or authors.
+ * Copyright 2012-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
 
 import org.junit.jupiter.api.Test;
 
+import org.springframework.core.task.TaskDecorator;
 import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -131,4 +132,11 @@ void additionalCustomizersShouldAddToExisting() {
 		then(customizer2).should().customize(scheduler);
 	}
 
+	@Test
+	void taskDecoratorShouldApply() {
+		TaskDecorator taskDecorator = mock(TaskDecorator.class);
+		ThreadPoolTaskScheduler scheduler = this.builder.taskDecorator(taskDecorator).build();
+		assertThat(scheduler).extracting("taskDecorator").isSameAs(taskDecorator);
+	}
+
 }