diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index cd6b1845cd9d1..e501f68be0f5b 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -32,6 +32,8 @@ package org.opensearch.tasks; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NamedWriteable; @@ -45,6 +47,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; /** @@ -54,6 +57,8 @@ */ public class Task { + private static final Logger logger = LogManager.getLogger(Task.class); + /** * The request header to mark tasks with specific ids */ @@ -77,9 +82,13 @@ public class Task { private final List resourceTrackingListeners; - // Keeps track of the number of active resource tracking threads for this task. It is initialized to 1 to track - // the task's own/self thread. When this value becomes 0, all threads have been marked inactive and the resource - // tracking can be stopped for this task. + private final AtomicBoolean isResourceTrackingCompleted = new AtomicBoolean(false); + + /** + * Keeps track of the number of active resource tracking threads for this task. It is initialized to 1 to track + * the task's own/self thread. When this value becomes 0, all threads have been marked inactive and the resource + * tracking can be stopped for this task. + */ private final AtomicInteger numActiveResourceTrackingThreads = new AtomicInteger(1); /** @@ -310,7 +319,13 @@ public void startThreadResourceTracking(long threadId, ResourceStatsType statsTy } threadResourceInfoList.add(new ThreadResourceInfo(threadId, statsType, resourceUsageMetrics)); incrementResourceTrackingThreads(); - resourceTrackingListeners.forEach(listener -> listener.onTaskExecutionStartedOnThread(this, threadId)); + resourceTrackingListeners.forEach(listener -> { + try { + listener.onTaskExecutionStartedOnThread(this, threadId); + } catch (Exception e) { + logger.warn("failure in listener during handling of onTaskExecutionStartedOnThread", e); + } + }); } /** @@ -328,7 +343,13 @@ public void updateThreadResourceStats(long threadId, ResourceStatsType statsType // the active entry present in the list is updated if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); - resourceTrackingListeners.forEach(listener -> listener.onTaskResourceStatsUpdated(this)); + resourceTrackingListeners.forEach(listener -> { + try { + listener.onTaskResourceStatsUpdated(this); + } catch (Exception e) { + logger.warn("failure in listener during handling of onTaskResourceStatsUpdated", e); + } + }); return; } } @@ -352,7 +373,14 @@ public void stopThreadResourceTracking(long threadId, ResourceStatsType statsTyp threadResourceInfo.setActive(false); threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); decrementResourceTrackingThreads(); - resourceTrackingListeners.forEach(listener -> listener.onTaskExecutionFinishedOnThread(this, threadId)); + resourceTrackingListeners.forEach(listener -> { + try { + listener.onTaskExecutionFinishedOnThread(this, threadId); + } catch (Exception e) { + logger.warn("failure in listener during handling of onTaskExecutionFinishedOnThread", e); + } + }); + return; } } @@ -422,15 +450,30 @@ public int incrementResourceTrackingThreads() { /** * Decrements the number of active resource tracking threads. - * When this value becomes zero, the onTaskResourceTrackingCompleted method is called on all registered listeners. + * This method is called when threads finish execution, and also when the task is unregistered (to mark the task's + * own thread as complete). When the active thread count becomes zero, the onTaskResourceTrackingCompleted method + * is called exactly once on all registered listeners. + * + * Since a task is unregistered after the message is processed, it implies that the threads responsible to produce + * the response must have started prior to it (i.e. startThreadResourceTracking called before unregister). + * This ensures that the number of active threads doesn't drop to zero pre-maturely. + * + * Rarely, some threads may even start execution after the task is unregistered. As resource stats are piggy-backed + * with the response, any thread usage info captured after the task is unregistered may be irrelevant. * * @return the number of active resource tracking threads. */ public int decrementResourceTrackingThreads() { int count = numActiveResourceTrackingThreads.decrementAndGet(); - if (count == 0) { - resourceTrackingListeners.forEach(listener -> listener.onTaskResourceTrackingCompleted(this)); + if (count == 0 && isResourceTrackingCompleted.compareAndSet(false, true)) { + resourceTrackingListeners.forEach(listener -> { + try { + listener.onTaskResourceTrackingCompleted(this); + } catch (Exception e) { + logger.warn("failure in listener during handling of onTaskResourceTrackingCompleted", e); + } + }); } return count; diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 4197dd2b5eaaa..b98f1e44df349 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -213,7 +213,7 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); - // Decrement the task's self-thread as a part of unregistration. + // Decrement the task's self-thread as part of unregistration. task.decrementResourceTrackingThreads(); if (task instanceof CancellableTask) { diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 71b829e023385..c3cad117390e4 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -110,12 +110,6 @@ public void stopTracking(Task task) { if (isCurrentThreadWorkingOnTask(task)) { taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); } - - List threadsWorkingOnTask = getThreadsWorkingOnTask(task); - if (threadsWorkingOnTask.size() > 0) { - logger.warn("No thread should be active when task finishes. Active threads: {}", threadsWorkingOnTask); - assert false : "No thread should be marked active when task finishes"; - } } catch (Exception e) { logger.warn("Failed while trying to mark the task execution on current thread completed.", e); assert false; @@ -165,11 +159,10 @@ private void refreshResourceStats(Task resourceAwareTask) { @Override public void taskExecutionStartedOnThread(long taskId, long threadId) { try { - if (resourceAwareTasks.containsKey(taskId)) { + final Task task = resourceAwareTasks.get(taskId); + if (task != null) { logger.debug("Task execution started on thread. Task: {}, Thread: {}", taskId, threadId); - - resourceAwareTasks.get(taskId) - .startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + task.startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); } } catch (Exception e) { logger.warn(new ParameterizedMessage("Failed to mark thread execution started for task: [{}]", taskId), e); @@ -187,10 +180,10 @@ public void taskExecutionStartedOnThread(long taskId, long threadId) { @Override public void taskExecutionFinishedOnThread(long taskId, long threadId) { try { - if (resourceAwareTasks.containsKey(taskId)) { + final Task task = resourceAwareTasks.get(taskId); + if (task != null) { logger.debug("Task execution finished on thread. Task: {}, Thread: {}", taskId, threadId); - resourceAwareTasks.get(taskId) - .stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + task.stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); } } catch (Exception e) { logger.warn(new ParameterizedMessage("Failed to mark thread execution finished for task: [{}]", taskId), e);