diff --git a/src/main/java/com/blade/task/TaskManager.java b/src/main/java/com/blade/task/TaskManager.java index 14a28719f..08d2e9157 100644 --- a/src/main/java/com/blade/task/TaskManager.java +++ b/src/main/java/com/blade/task/TaskManager.java @@ -19,12 +19,10 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; -import lombok.var; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; import static com.blade.kit.BladeKit.getStartedSymbol; @@ -41,6 +39,9 @@ public final class TaskManager { private final static Map TASK_MAP = new HashMap<>(8); + private final static ReentrantReadWriteLock rrw = new ReentrantReadWriteLock(); + private final static Lock readLock = rrw.readLock(); + private final static Lock writeLock = rrw.writeLock(); private static CronExecutorService cronExecutorService; @@ -57,21 +58,45 @@ public static CronExecutorService getExecutorService() { } public static void addTask(Task task) { - TASK_MAP.put(task.getName(), task); + writeLock.lock(); + try { + TASK_MAP.put(task.getName(), task); + } finally { + writeLock.unlock(); + } log.info("{}Add task [{}]", getStartedSymbol(), task.getName()); } public static List getTasks() { - return new ArrayList<>(TASK_MAP.values()); + Collection values; + readLock.lock(); + try { + values = Optional.ofNullable(TASK_MAP.values()).orElse(Collections.EMPTY_LIST); + } finally { + readLock.unlock(); + } + return new ArrayList<>(values); } public static Task getTask(String name) { - return TASK_MAP.get(name); + readLock.lock(); + try { + return TASK_MAP.get(name); + } finally { + readLock.unlock(); + } + } public static boolean stopTask(String name) { - var task = TASK_MAP.get(name); - return task.stop(); + Task task; + readLock.lock(); + try { + task = TASK_MAP.get(name); + } finally { + readLock.unlock(); + } + return task == null ? Boolean.FALSE : task.stop(); } } diff --git a/src/test/java/com/blade/task/TaskManagerTest.java b/src/test/java/com/blade/task/TaskManagerTest.java new file mode 100644 index 000000000..5fea19d00 --- /dev/null +++ b/src/test/java/com/blade/task/TaskManagerTest.java @@ -0,0 +1,30 @@ +package com.blade.task; + +import java.util.concurrent.CountDownLatch; +import java.util.stream.IntStream; +import org.junit.Assert; +import org.junit.Test; + +/** + * @author PSH + * @date 2019/03/16 + */ +public class TaskManagerTest { + + @Test + public void testAddTaskMultiThreading() throws Exception { + + final int tackCount = 500; + CountDownLatch downLatch = new CountDownLatch(tackCount); + IntStream.range(0, tackCount).forEach(i -> { + Task task = new Task("task-" + i, null, Integer.MAX_VALUE); + new Thread(() -> { + TaskManager.addTask(task); + downLatch.countDown(); + }).start(); + }); + + downLatch.await(); + Assert.assertEquals(tackCount, TaskManager.getTasks().size()); + } +}