From 106b40bc223b5cc544686c340b4eafed54a17a45 Mon Sep 17 00:00:00 2001 From: Ketan Verma Date: Thu, 21 Apr 2022 19:21:44 +0530 Subject: [PATCH 1/6] Support task resource tracking in OpenSearch Reopens changes from #2639 (reverted in #3046) to add a framework for task resource tracking. Currently, SearchTask and SearchShardTask support resource tracking but it can be extended to any other task. Changes since #2639: * Replaced the usage of AutoQueueAdjustingExecutorBuilder with ResizableExecutorBuilder * Resolved merge conflicts * Fixed broken tests Signed-off-by: Ketan Verma --- .../admin/cluster/node/tasks/TasksIT.java | 6 + .../tasks/list/TransportListTasksAction.java | 13 +- .../action/search/SearchShardTask.java | 5 + .../opensearch/action/search/SearchTask.java | 5 + .../action/support/TransportAction.java | 78 +- .../org/opensearch/cluster/ClusterModule.java | 2 + .../common/settings/ClusterSettings.java | 4 +- .../util/concurrent/OpenSearchExecutors.java | 18 +- .../common/util/concurrent/ThreadContext.java | 16 +- .../main/java/org/opensearch/node/Node.java | 14 +- .../main/java/org/opensearch/tasks/Task.java | 17 +- .../org/opensearch/tasks/TaskManager.java | 27 +- .../tasks/TaskResourceTrackingService.java | 255 +++++++ .../opensearch/tasks/ThreadResourceInfo.java | 10 +- .../threadpool/ResizableExecutorBuilder.java | 25 +- .../RunnableTaskExecutionListener.java | 33 + .../threadpool/TaskAwareRunnable.java | 90 +++ .../org/opensearch/threadpool/ThreadPool.java | 16 +- .../transport/RequestHandlerRegistry.java | 4 + .../tasks/RecordingTaskManagerListener.java | 3 + .../node/tasks/ResourceAwareTasksTests.java | 670 ++++++++++++++++++ .../node/tasks/TaskManagerTestCase.java | 17 +- .../bulk/TransportBulkActionIngestTests.java | 3 +- .../util/concurrent/ThreadContextTests.java | 10 + .../snapshots/SnapshotResiliencyTests.java | 3 + .../opensearch/tasks/TaskManagerTests.java | 6 +- .../TaskResourceTrackingServiceTests.java | 97 +++ .../test/tasks/MockTaskManager.java | 16 + .../test/tasks/MockTaskManagerListener.java | 3 + .../opensearch/threadpool/TestThreadPool.java | 20 +- 30 files changed, 1425 insertions(+), 61 deletions(-) create mode 100644 server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java create mode 100644 server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java create mode 100644 server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java create mode 100644 server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java create mode 100644 server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java diff --git a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java index d37285211f774..61059f83f0e77 100644 --- a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java @@ -470,6 +470,9 @@ public void onTaskUnregistered(Task task) {} @Override public void waitForTaskCompletion(Task task) {} + + @Override + public void taskExecutionStarted(Task task, Boolean closeableInvoked) {} }); } // Need to run the task in a separate thread because node client's .execute() is blocked by our task listener @@ -651,6 +654,9 @@ public void waitForTaskCompletion(Task task) { waitForWaitingToStart.countDown(); } + @Override + public void taskExecutionStarted(Task task, Boolean closeableInvoked) {} + @Override public void onTaskRegistered(Task task) {} diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java index 796ea023edd40..aede3fe5b1cc0 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java @@ -42,6 +42,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskInfo; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -65,8 +66,15 @@ public static long waitForCompletionTimeout(TimeValue timeout) { private static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = timeValueSeconds(30); + private final TaskResourceTrackingService taskResourceTrackingService; + @Inject - public TransportListTasksAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters) { + public TransportListTasksAction( + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + TaskResourceTrackingService taskResourceTrackingService + ) { super( ListTasksAction.NAME, clusterService, @@ -77,6 +85,7 @@ public TransportListTasksAction(ClusterService clusterService, TransportService TaskInfo::new, ThreadPool.Names.MANAGEMENT ); + this.taskResourceTrackingService = taskResourceTrackingService; } @Override @@ -106,6 +115,8 @@ protected void processTasks(ListTasksRequest request, Consumer operation) } taskManager.waitForTaskCompletion(task, timeoutNanos); }); + } else { + operation = operation.andThen(taskResourceTrackingService::refreshResourceStats); } super.processTasks(request, operation); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java index e5f25595b9ec8..57831896db714 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java @@ -51,6 +51,11 @@ public SearchShardTask(long id, String type, String action, String description, super(id, type, action, description, parentTaskId, headers); } + @Override + public boolean supportsResourceTracking() { + return true; + } + @Override public boolean shouldCancelChildrenOnCancellation() { return false; diff --git a/server/src/main/java/org/opensearch/action/search/SearchTask.java b/server/src/main/java/org/opensearch/action/search/SearchTask.java index 89f23bb9bdaeb..987485fe44c65 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTask.java @@ -80,6 +80,11 @@ public final String getDescription() { return descriptionSupplier.get(); } + @Override + public boolean supportsResourceTracking() { + return true; + } + /** * Attach a {@link SearchProgressListener} to this task. */ diff --git a/server/src/main/java/org/opensearch/action/support/TransportAction.java b/server/src/main/java/org/opensearch/action/support/TransportAction.java index a60648e50ff31..71ae187b48c4e 100644 --- a/server/src/main/java/org/opensearch/action/support/TransportAction.java +++ b/server/src/main/java/org/opensearch/action/support/TransportAction.java @@ -40,6 +40,7 @@ import org.opensearch.action.ActionResponse; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskId; @@ -93,31 +94,39 @@ public final Task execute(Request request, ActionListener listener) { */ final Releasable unregisterChildNode = registerChildNode(request.getParentTask()); final Task task; + try { task = taskManager.register("transport", actionName, request); } catch (TaskCancelledException e) { unregisterChildNode.close(); throw e; } - execute(task, request, new ActionListener() { - @Override - public void onResponse(Response response) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onResponse(response); + + ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task); + try { + execute(task, request, new ActionListener() { + @Override + public void onResponse(Response response) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onResponse(response); + } } - } - @Override - public void onFailure(Exception e) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onFailure(e); + @Override + public void onFailure(Exception e) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onFailure(e); + } } - } - }); + }); + } finally { + storedContext.close(); + } + return task; } @@ -134,25 +143,30 @@ public final Task execute(Request request, TaskListener listener) { unregisterChildNode.close(); throw e; } - execute(task, request, new ActionListener() { - @Override - public void onResponse(Response response) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onResponse(task, response); + ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task); + try { + execute(task, request, new ActionListener() { + @Override + public void onResponse(Response response) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onResponse(task, response); + } } - } - @Override - public void onFailure(Exception e) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onFailure(task, e); + @Override + public void onFailure(Exception e) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onFailure(task, e); + } } - } - }); + }); + } finally { + storedContext.close(); + } return task; } diff --git a/server/src/main/java/org/opensearch/cluster/ClusterModule.java b/server/src/main/java/org/opensearch/cluster/ClusterModule.java index 900dceb8564c9..f8ba520e465e2 100644 --- a/server/src/main/java/org/opensearch/cluster/ClusterModule.java +++ b/server/src/main/java/org/opensearch/cluster/ClusterModule.java @@ -94,6 +94,7 @@ import org.opensearch.script.ScriptMetadata; import org.opensearch.snapshots.SnapshotsInfoService; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.tasks.TaskResultsService; import java.util.ArrayList; @@ -396,6 +397,7 @@ protected void configure() { bind(NodeMappingRefreshAction.class).asEagerSingleton(); bind(MappingUpdatedAction.class).asEagerSingleton(); bind(TaskResultsService.class).asEagerSingleton(); + bind(TaskResourceTrackingService.class).asEagerSingleton(); bind(AllocationDeciders.class).toInstance(allocationDeciders); bind(ShardsAllocator.class).toInstance(shardsAllocator); } diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index bee3428188026..fe322992cd3e5 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -41,6 +41,7 @@ import org.opensearch.index.ShardIndexingPressureMemoryManager; import org.opensearch.index.ShardIndexingPressureSettings; import org.opensearch.index.ShardIndexingPressureStore; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction; import org.opensearch.action.admin.indices.close.TransportCloseIndexAction; @@ -575,7 +576,8 @@ public void apply(Settings value, Settings current, Settings previous) { ShardIndexingPressureMemoryManager.THROUGHPUT_DEGRADATION_LIMITS, ShardIndexingPressureMemoryManager.SUCCESSFUL_REQUEST_ELAPSED_TIMEOUT, ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS, - IndexingPressure.MAX_INDEXING_BYTES + IndexingPressure.MAX_INDEXING_BYTES, + TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED ) ) ); diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java index 14f9486b4baf0..ec1024bbe5f30 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java @@ -40,6 +40,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.node.Node; +import org.opensearch.threadpool.RunnableTaskExecutionListener; +import org.opensearch.threadpool.TaskAwareRunnable; import java.util.List; import java.util.Optional; @@ -55,6 +57,7 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; /** @@ -182,13 +185,24 @@ public static OpenSearchThreadPoolExecutor newResizable( int size, int queueCapacity, ThreadFactory threadFactory, - ThreadContext contextHolder + ThreadContext contextHolder, + AtomicReference runnableTaskListener ) { if (queueCapacity <= 0) { throw new IllegalArgumentException("queue capacity for [" + name + "] executor must be positive, got: " + queueCapacity); } + Function runnableWrapper; + if (runnableTaskListener != null) { + runnableWrapper = (runnable) -> { + TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable, runnableTaskListener); + return new TimedRunnable(taskAwareRunnable); + }; + } else { + runnableWrapper = TimedRunnable::new; + } + return new QueueResizableOpenSearchThreadPoolExecutor( name, size, @@ -196,7 +210,7 @@ public static OpenSearchThreadPoolExecutor newResizable( 0, TimeUnit.MILLISECONDS, new ResizableBlockingQueue<>(ConcurrentCollections.newBlockingQueue(), queueCapacity), - TimedRunnable::new, + runnableWrapper, threadFactory, new OpenSearchAbortPolicy(), contextHolder diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index 5e2381c949c00..5b9a77c75dddb 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -66,6 +66,7 @@ import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_COUNT; import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_SIZE; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; /** * A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with @@ -135,16 +136,23 @@ public StoredContext stashContext() { * This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user. * Otherwise when context is stash, it should be empty. */ + + ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT; + if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) { - ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putHeaders( + threadContextStruct = threadContextStruct.putHeaders( MapBuilder.newMapBuilder() .put(Task.X_OPAQUE_ID, context.requestHeaders.get(Task.X_OPAQUE_ID)) .immutableMap() ); - threadLocal.set(threadContextStruct); - } else { - threadLocal.set(DEFAULT_CONTEXT); } + + if (context.transientHeaders.containsKey(TASK_ID)) { + threadContextStruct = threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID)); + } + + threadLocal.set(threadContextStruct); + return () -> { // If the node and thus the threadLocal get closed while this task // is still executing, we don't want this runnable to fail with an diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index bc5f17759d498..d3f0912cab638 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -39,10 +39,12 @@ import org.opensearch.common.util.FeatureFlags; import org.opensearch.cluster.routing.allocation.AwarenessReplicaBalance; import org.opensearch.index.IndexingPressureService; +import org.opensearch.index.store.RemoteDirectoryFactory; import org.opensearch.indices.replication.SegmentReplicationSourceFactory; import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.SegmentReplicationSourceService; -import org.opensearch.index.store.RemoteDirectoryFactory; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.Assertions; import org.opensearch.Build; @@ -219,6 +221,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.UnaryOperator; import java.util.stream.Collectors; @@ -338,6 +341,7 @@ public static class DiscoverySettings { private final LocalNodeFactory localNodeFactory; private final NodeService nodeService; final NamedWriteableRegistry namedWriteableRegistry; + private final AtomicReference runnableTaskListener; public Node(Environment environment) { this(environment, Collections.emptyList(), true); @@ -447,7 +451,8 @@ protected Node( final List> executorBuilders = pluginsService.getExecutorBuilders(settings); - final ThreadPool threadPool = new ThreadPool(settings, executorBuilders.toArray(new ExecutorBuilder[0])); + runnableTaskListener = new AtomicReference<>(); + final ThreadPool threadPool = new ThreadPool(settings, runnableTaskListener, executorBuilders.toArray(new ExecutorBuilder[0])); resourcesToClose.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS)); final ResourceWatcherService resourceWatcherService = new ResourceWatcherService(settings, threadPool); resourcesToClose.add(resourceWatcherService); @@ -1095,6 +1100,11 @@ public Node start() throws NodeValidationException { TransportService transportService = injector.getInstance(TransportService.class); transportService.getTaskManager().setTaskResultsService(injector.getInstance(TaskResultsService.class)); transportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(transportService)); + + TaskResourceTrackingService taskResourceTrackingService = injector.getInstance(TaskResourceTrackingService.class); + transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); + runnableTaskListener.set(taskResourceTrackingService); + transportService.start(); assert localNodeFactory.getNode() != null; assert transportService.getLocalNode().equals(localNodeFactory.getNode()) diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index 522ecac5ef698..c8d5312b2e1bc 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -32,8 +32,6 @@ package org.opensearch.tasks; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NamedWriteable; @@ -55,8 +53,6 @@ */ public class Task { - private static final Logger logger = LogManager.getLogger(Task.class); - /** * The request header to mark tasks with specific ids */ @@ -291,7 +287,7 @@ public void startThreadResourceTracking(long threadId, ResourceStatsType statsTy ); } } - threadResourceInfoList.add(new ThreadResourceInfo(statsType, resourceUsageMetrics)); + threadResourceInfoList.add(new ThreadResourceInfo(threadId, statsType, resourceUsageMetrics)); } /** @@ -338,6 +334,17 @@ public void stopThreadResourceTracking(long threadId, ResourceStatsType statsTyp throw new IllegalStateException("cannot update final values if active thread resource entry is not present"); } + /** + * Individual tasks can override this if they want to support task resource tracking. We just need to make sure that + * the ThreadPool on which the task runs on have runnable wrapper similar to + * {@link org.opensearch.common.util.concurrent.OpenSearchExecutors#newResizable} + * + * @return true if resource tracking is supported by the task + */ + public boolean supportsResourceTracking() { + return false; + } + /** * Report of the internal status of a task. These can vary wildly from task * to task because each task is implemented differently but we should try diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 9bb931ea7f0aa..3121eab18fd6b 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -91,7 +91,9 @@ public class TaskManager implements ClusterStateApplier { private static final TimeValue WAIT_FOR_COMPLETION_POLL = timeValueMillis(100); - /** Rest headers that are copied to the task */ + /** + * Rest headers that are copied to the task + */ private final List taskHeaders; private final ThreadPool threadPool; @@ -105,6 +107,7 @@ public class TaskManager implements ClusterStateApplier { private final Map banedParents = new ConcurrentHashMap<>(); private TaskResultsService taskResultsService; + private final SetOnce taskResourceTrackingService = new SetOnce<>(); private volatile DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES; @@ -127,6 +130,10 @@ public void setTaskCancellationService(TaskCancellationService taskCancellationS this.cancellationService.set(taskCancellationService); } + public void setTaskResourceTrackingService(TaskResourceTrackingService taskResourceTrackingService) { + this.taskResourceTrackingService.set(taskResourceTrackingService); + } + /** * Registers a task without parent task */ @@ -204,6 +211,11 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { */ public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); + + if (taskResourceTrackingService.get() != null && task.supportsResourceTracking()) { + taskResourceTrackingService.get().stopTracking(task); + } + if (task instanceof CancellableTask) { CancellableTaskHolder holder = cancellableTasks.remove(task.getId()); if (holder != null) { @@ -363,6 +375,7 @@ public int getBanCount() { * Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing. *

* This method is called when a parent task that has children is cancelled. + * * @return a list of pending cancellable child tasks */ public List setBan(TaskId parentTaskId, String reason) { @@ -450,6 +463,18 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { throw new OpenSearchTimeoutException("Timed out waiting for completion of [{}]", task); } + /** + * Takes actions when a task is registered and its execution starts + * + * @param task getting executed. + * @return AutoCloseable to free up resources (clean up thread context) when task execution block returns + */ + public ThreadContext.StoredContext taskExecutionStarted(Task task) { + if (taskResourceTrackingService.get() == null) return () -> {}; + + return taskResourceTrackingService.get().startTracking(task); + } + private static class CancellableTaskHolder { private final CancellableTask task; private boolean finished = false; diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java new file mode 100644 index 0000000000000..71b829e023385 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -0,0 +1,255 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import com.sun.management.ThreadMXBean; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ConcurrentCollections; +import org.opensearch.common.util.concurrent.ConcurrentMapLong; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.threadpool.RunnableTaskExecutionListener; +import org.opensearch.threadpool.ThreadPool; + +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.opensearch.tasks.ResourceStatsType.WORKER_STATS; + +/** + * Service that helps track resource usage of tasks running on a node. + */ +@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") +public class TaskResourceTrackingService implements RunnableTaskExecutionListener { + + private static final Logger logger = LogManager.getLogger(TaskManager.class); + + public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( + "task_resource_tracking.enabled", + true, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + public static final String TASK_ID = "TASK_ID"; + + private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); + + private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); + private final ThreadPool threadPool; + private volatile boolean taskResourceTrackingEnabled; + + @Inject + public TaskResourceTrackingService(Settings settings, ClusterSettings clusterSettings, ThreadPool threadPool) { + this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings); + this.threadPool = threadPool; + clusterSettings.addSettingsUpdateConsumer(TASK_RESOURCE_TRACKING_ENABLED, this::setTaskResourceTrackingEnabled); + } + + public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) { + this.taskResourceTrackingEnabled = taskResourceTrackingEnabled; + } + + public boolean isTaskResourceTrackingEnabled() { + return taskResourceTrackingEnabled; + } + + public boolean isTaskResourceTrackingSupported() { + return threadMXBean.isThreadAllocatedMemorySupported() && threadMXBean.isThreadAllocatedMemoryEnabled(); + } + + /** + * Executes logic only if task supports resource tracking and resource tracking setting is enabled. + *

+ * 1. Starts tracking the task in map of resourceAwareTasks. + * 2. Adds Task Id in thread context to make sure it's available while task is processed across multiple threads. + * + * @param task for which resources needs to be tracked + * @return Autocloseable stored context to restore ThreadContext to the state before this method changed it. + */ + public ThreadContext.StoredContext startTracking(Task task) { + if (task.supportsResourceTracking() == false + || isTaskResourceTrackingEnabled() == false + || isTaskResourceTrackingSupported() == false) { + return () -> {}; + } + + logger.debug("Starting resource tracking for task: {}", task.getId()); + resourceAwareTasks.put(task.getId(), task); + return addTaskIdToThreadContext(task); + } + + /** + * Stops tracking task registered earlier for tracking. + *

+ * It doesn't have feature enabled check to avoid any issues if setting was disable while the task was in progress. + *

+ * It's also responsible to stop tracking the current thread's resources against this task if not already done. + * This happens when the thread executing the request logic itself calls the unregister method. So in this case unregister + * happens before runnable finishes. + * + * @param task task which has finished and doesn't need resource tracking. + */ + public void stopTracking(Task task) { + logger.debug("Stopping resource tracking for task: {}", task.getId()); + try { + if (isCurrentThreadWorkingOnTask(task)) { + taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); + } + + List threadsWorkingOnTask = getThreadsWorkingOnTask(task); + if (threadsWorkingOnTask.size() > 0) { + logger.warn("No thread should be active when task finishes. Active threads: {}", threadsWorkingOnTask); + assert false : "No thread should be marked active when task finishes"; + } + } catch (Exception e) { + logger.warn("Failed while trying to mark the task execution on current thread completed.", e); + assert false; + } finally { + resourceAwareTasks.remove(task.getId()); + } + } + + /** + * Refreshes the resource stats for the tasks provided by looking into which threads are actively working on these + * and how much resources these have consumed till now. + * + * @param tasks for which resource stats needs to be refreshed. + */ + public void refreshResourceStats(Task... tasks) { + if (isTaskResourceTrackingEnabled() == false || isTaskResourceTrackingSupported() == false) { + return; + } + + for (Task task : tasks) { + if (task.supportsResourceTracking() && resourceAwareTasks.containsKey(task.getId())) { + refreshResourceStats(task); + } + } + } + + private void refreshResourceStats(Task resourceAwareTask) { + try { + logger.debug("Refreshing resource stats for Task: {}", resourceAwareTask.getId()); + List threadsWorkingOnTask = getThreadsWorkingOnTask(resourceAwareTask); + threadsWorkingOnTask.forEach( + threadId -> resourceAwareTask.updateThreadResourceStats(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)) + ); + } catch (IllegalStateException e) { + logger.debug("Resource stats already updated."); + } + + } + + /** + * Called when a thread starts working on a task's runnable. + * + * @param taskId of the task for which runnable is starting + * @param threadId of the thread which will be executing the runnable and we need to check resource usage for this + * thread + */ + @Override + public void taskExecutionStartedOnThread(long taskId, long threadId) { + try { + if (resourceAwareTasks.containsKey(taskId)) { + logger.debug("Task execution started on thread. Task: {}, Thread: {}", taskId, threadId); + + resourceAwareTasks.get(taskId) + .startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } catch (Exception e) { + logger.warn(new ParameterizedMessage("Failed to mark thread execution started for task: [{}]", taskId), e); + assert false; + } + + } + + /** + * Called when a thread finishes working on a task's runnable. + * + * @param taskId of the task for which runnable is complete + * @param threadId of the thread which executed the runnable and we need to check resource usage for this thread + */ + @Override + public void taskExecutionFinishedOnThread(long taskId, long threadId) { + try { + if (resourceAwareTasks.containsKey(taskId)) { + logger.debug("Task execution finished on thread. Task: {}, Thread: {}", taskId, threadId); + resourceAwareTasks.get(taskId) + .stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } catch (Exception e) { + logger.warn(new ParameterizedMessage("Failed to mark thread execution finished for task: [{}]", taskId), e); + assert false; + } + } + + public Map getResourceAwareTasks() { + return Collections.unmodifiableMap(resourceAwareTasks); + } + + private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { + ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( + ResourceStats.MEMORY, + threadMXBean.getThreadAllocatedBytes(threadId) + ); + ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId)); + return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; + } + + private boolean isCurrentThreadWorkingOnTask(Task task) { + long threadId = Thread.currentThread().getId(); + List threadResourceInfos = task.getResourceStats().getOrDefault(threadId, Collections.emptyList()); + + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) { + return true; + } + } + return false; + } + + private List getThreadsWorkingOnTask(Task task) { + List activeThreads = new ArrayList<>(); + for (List threadResourceInfos : task.getResourceStats().values()) { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) { + activeThreads.add(threadResourceInfo.getThreadId()); + } + } + } + return activeThreads; + } + + /** + * Adds Task Id in the ThreadContext. + *

+ * Stashes the existing ThreadContext and preserves all the existing ThreadContext's data in the new ThreadContext + * as well. + * + * @param task for which Task Id needs to be added in ThreadContext. + * @return StoredContext reference to restore the ThreadContext from which we created a new one. + * Caller can call context.restore() to get the existing ThreadContext back. + */ + private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { + ThreadContext threadContext = threadPool.getThreadContext(); + ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true, Collections.singletonList(TASK_ID)); + threadContext.putTransient(TASK_ID, task.getId()); + return storedContext; + } + +} diff --git a/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java b/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java index a0b38649b6420..de49d86d1d5c4 100644 --- a/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java +++ b/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java @@ -17,11 +17,13 @@ * @opensearch.internal */ public class ThreadResourceInfo { + private final long threadId; private volatile boolean isActive = true; private final ResourceStatsType statsType; private final ResourceUsageInfo resourceUsageInfo; - public ThreadResourceInfo(ResourceStatsType statsType, ResourceUsageMetric... resourceUsageMetrics) { + public ThreadResourceInfo(long threadId, ResourceStatsType statsType, ResourceUsageMetric... resourceUsageMetrics) { + this.threadId = threadId; this.statsType = statsType; this.resourceUsageInfo = new ResourceUsageInfo(resourceUsageMetrics); } @@ -45,12 +47,16 @@ public ResourceStatsType getStatsType() { return statsType; } + public long getThreadId() { + return threadId; + } + public ResourceUsageInfo getResourceUsageInfo() { return resourceUsageInfo; } @Override public String toString() { - return resourceUsageInfo + ", stats_type=" + statsType + ", is_active=" + isActive; + return resourceUsageInfo + ", stats_type=" + statsType + ", is_active=" + isActive + ", threadId=" + threadId; } } diff --git a/server/src/main/java/org/opensearch/threadpool/ResizableExecutorBuilder.java b/server/src/main/java/org/opensearch/threadpool/ResizableExecutorBuilder.java index fd9ca1d3b5f3b..23f8a8979f821 100644 --- a/server/src/main/java/org/opensearch/threadpool/ResizableExecutorBuilder.java +++ b/server/src/main/java/org/opensearch/threadpool/ResizableExecutorBuilder.java @@ -20,6 +20,7 @@ import java.util.Locale; import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicReference; /** * A builder for resizable executors. @@ -30,12 +31,26 @@ public final class ResizableExecutorBuilder extends ExecutorBuilder sizeSetting; private final Setting queueSizeSetting; + private final AtomicReference runnableTaskListener; - ResizableExecutorBuilder(final Settings settings, final String name, final int size, final int queueSize) { - this(settings, name, size, queueSize, "thread_pool." + name); + ResizableExecutorBuilder( + final Settings settings, + final String name, + final int size, + final int queueSize, + final AtomicReference runnableTaskListener + ) { + this(settings, name, size, queueSize, "thread_pool." + name, runnableTaskListener); } - public ResizableExecutorBuilder(final Settings settings, final String name, final int size, final int queueSize, final String prefix) { + public ResizableExecutorBuilder( + final Settings settings, + final String name, + final int size, + final int queueSize, + final String prefix, + final AtomicReference runnableTaskListener + ) { super(name); final String sizeKey = settingsKey(prefix, "size"); this.sizeSetting = new Setting<>( @@ -50,6 +65,7 @@ public ResizableExecutorBuilder(final Settings settings, final String name, fina queueSize, new Setting.Property[] { Setting.Property.NodeScope, Setting.Property.Dynamic } ); + this.runnableTaskListener = runnableTaskListener; } @Override @@ -77,7 +93,8 @@ ThreadPool.ExecutorHolder build(final ResizableExecutorSettings settings, final size, queueSize, threadFactory, - threadContext + threadContext, + runnableTaskListener ); final ThreadPool.Info info = new ThreadPool.Info( name(), diff --git a/server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java b/server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java new file mode 100644 index 0000000000000..03cd66f80d044 --- /dev/null +++ b/server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.threadpool; + +/** + * Listener for events when a runnable execution starts or finishes on a thread and is aware of the task for which the + * runnable is associated to. + */ +public interface RunnableTaskExecutionListener { + + /** + * Sends an update when ever a task's execution start on a thread + * + * @param taskId of task which has started + * @param threadId of thread which is executing the task + */ + void taskExecutionStartedOnThread(long taskId, long threadId); + + /** + * + * Sends an update when task execution finishes on a thread + * + * @param taskId of task which has finished + * @param threadId of thread which executed the task + */ + void taskExecutionFinishedOnThread(long taskId, long threadId); +} diff --git a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java new file mode 100644 index 0000000000000..183b9b2f4cf9a --- /dev/null +++ b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.threadpool; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.concurrent.WrappedRunnable; +import org.opensearch.tasks.TaskManager; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; + +import static java.lang.Thread.currentThread; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; + +/** + * Responsible for wrapping the original task's runnable and sending updates on when it starts and finishes to + * entities listening to the events. + * + * It's able to associate runnable with a task with the help of task Id available in thread context. + */ +public class TaskAwareRunnable extends AbstractRunnable implements WrappedRunnable { + + private static final Logger logger = LogManager.getLogger(TaskManager.class); + + private final Runnable original; + private final ThreadContext threadContext; + private final AtomicReference runnableTaskListener; + + public TaskAwareRunnable( + final ThreadContext threadContext, + final Runnable original, + final AtomicReference runnableTaskListener + ) { + this.original = original; + this.threadContext = threadContext; + this.runnableTaskListener = runnableTaskListener; + } + + @Override + public void onFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } + + @Override + public boolean isForceExecution() { + return original instanceof AbstractRunnable && ((AbstractRunnable) original).isForceExecution(); + } + + @Override + public void onRejection(final Exception e) { + if (original instanceof AbstractRunnable) { + ((AbstractRunnable) original).onRejection(e); + } else { + ExceptionsHelper.reThrowIfNotNull(e); + } + } + + @Override + protected void doRun() throws Exception { + assert runnableTaskListener.get() != null : "Listener should be attached"; + Long taskId = threadContext.getTransient(TASK_ID); + if (Objects.nonNull(taskId)) { + runnableTaskListener.get().taskExecutionStartedOnThread(taskId, currentThread().getId()); + } else { + logger.debug("Task Id not available in thread context. Skipping update. Thread Info: {}", Thread.currentThread()); + } + try { + original.run(); + } finally { + if (Objects.nonNull(taskId)) { + runnableTaskListener.get().taskExecutionFinishedOnThread(taskId, currentThread().getId()); + } + } + } + + @Override + public Runnable unwrap() { + return original; + } +} diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index cc8d81d2a7b4b..928b4871590c6 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -69,6 +69,7 @@ import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import static java.util.Collections.unmodifiableMap; @@ -200,6 +201,14 @@ public Collection builders() { ); public ThreadPool(final Settings settings, final ExecutorBuilder... customBuilders) { + this(settings, null, customBuilders); + } + + public ThreadPool( + final Settings settings, + final AtomicReference runnableTaskListener, + final ExecutorBuilder... customBuilders + ) { assert Node.NODE_NAME_SETTING.exists(settings); final Map builders = new HashMap<>(); @@ -211,8 +220,11 @@ public ThreadPool(final Settings settings, final ExecutorBuilder... customBui builders.put(Names.WRITE, new FixedExecutorBuilder(settings, Names.WRITE, allocatedProcessors, 10000)); builders.put(Names.GET, new FixedExecutorBuilder(settings, Names.GET, allocatedProcessors, 1000)); builders.put(Names.ANALYZE, new FixedExecutorBuilder(settings, Names.ANALYZE, 1, 16)); - builders.put(Names.SEARCH, new ResizableExecutorBuilder(settings, Names.SEARCH, searchThreadPoolSize(allocatedProcessors), 1000)); - builders.put(Names.SEARCH_THROTTLED, new ResizableExecutorBuilder(settings, Names.SEARCH_THROTTLED, 1, 100)); + builders.put( + Names.SEARCH, + new ResizableExecutorBuilder(settings, Names.SEARCH, searchThreadPoolSize(allocatedProcessors), 1000, runnableTaskListener) + ); + builders.put(Names.SEARCH_THROTTLED, new ResizableExecutorBuilder(settings, Names.SEARCH_THROTTLED, 1, 100, runnableTaskListener)); builders.put(Names.MANAGEMENT, new ScalingExecutorBuilder(Names.MANAGEMENT, 1, 5, TimeValue.timeValueMinutes(5))); // no queue as this means clients will need to handle rejections on listener queue even if the operation succeeded // the assumption here is that the listeners should be very lightweight on the listeners side diff --git a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java index b65b72b745a01..dbd6f651a6b3c 100644 --- a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java +++ b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java @@ -37,6 +37,7 @@ import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskManager; @@ -86,6 +87,8 @@ public Request newRequest(StreamInput in) throws IOException { public void processMessageReceived(Request request, TransportChannel channel) throws Exception { final Task task = taskManager.register(channel.getChannelType(), action, request); + ThreadContext.StoredContext contextToRestore = taskManager.taskExecutionStarted(task); + Releasable unregisterTask = () -> taskManager.unregister(task); try { if (channel instanceof TcpTransportChannel && task instanceof CancellableTask) { @@ -104,6 +107,7 @@ public void processMessageReceived(Request request, TransportChannel channel) th unregisterTask = null; } finally { Releasables.close(unregisterTask); + contextToRestore.restore(); } } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java index 7756eb12bb3f4..9bd44185baf24 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java @@ -75,6 +75,9 @@ public synchronized void onTaskUnregistered(Task task) { @Override public void waitForTaskCompletion(Task task) {} + @Override + public void taskExecutionStarted(Task task, Boolean closeableInvoked) {} + public synchronized List> getEvents() { return Collections.unmodifiableList(new ArrayList<>(events)); } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java new file mode 100644 index 0000000000000..7af7b1929cf35 --- /dev/null +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -0,0 +1,670 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.admin.cluster.node.tasks; + +import com.sun.management.ThreadMXBean; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; +import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksRequest; +import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksResponse; +import org.opensearch.action.support.ActionTestUtils; +import org.opensearch.action.support.nodes.BaseNodeRequest; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancelledException; +import org.opensearch.tasks.TaskId; +import org.opensearch.tasks.TaskInfo; +import org.opensearch.tasks.ThreadResourceInfo; +import org.opensearch.test.tasks.MockTaskManager; +import org.opensearch.test.tasks.MockTaskManagerListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; + +@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") +public class ResourceAwareTasksTests extends TaskManagerTestCase { + + private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); + + public static class ResourceAwareNodeRequest extends BaseNodeRequest { + protected String requestName; + + public ResourceAwareNodeRequest() { + super(); + } + + public ResourceAwareNodeRequest(StreamInput in) throws IOException { + super(in); + requestName = in.readString(); + } + + public ResourceAwareNodeRequest(NodesRequest request) { + requestName = request.requestName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(requestName); + } + + @Override + public String getDescription() { + return "ResourceAwareNodeRequest[" + requestName + "]"; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return false; + } + + @Override + public boolean supportsResourceTracking() { + return true; + } + }; + } + } + + public static class NodesRequest extends BaseNodesRequest { + private final String requestName; + + private NodesRequest(StreamInput in) throws IOException { + super(in); + requestName = in.readString(); + } + + public NodesRequest(String requestName, String... nodesIds) { + super(nodesIds); + this.requestName = requestName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(requestName); + } + + @Override + public String getDescription() { + return "NodesRequest[" + requestName + "]"; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } + }; + } + } + + /** + * Simulates a task which executes work on search executor. + */ + class ResourceAwareNodesAction extends AbstractTestNodesAction { + private final TaskTestContext taskTestContext; + private final boolean blockForCancellation; + + ResourceAwareNodesAction( + String actionName, + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + boolean shouldBlock, + TaskTestContext taskTestContext + ) { + super(actionName, threadPool, clusterService, transportService, NodesRequest::new, ResourceAwareNodeRequest::new); + this.taskTestContext = taskTestContext; + this.blockForCancellation = shouldBlock; + } + + @Override + protected ResourceAwareNodeRequest newNodeRequest(NodesRequest request) { + return new ResourceAwareNodeRequest(request); + } + + @Override + protected NodeResponse nodeOperation(ResourceAwareNodeRequest request, Task task) { + assert task.supportsResourceTracking(); + + AtomicLong threadId = new AtomicLong(); + Future result = threadPool.executor(ThreadPool.Names.SEARCH).submit(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } + + @Override + @SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") + protected void doRun() { + taskTestContext.memoryConsumptionWhenExecutionStarts = threadMXBean.getThreadAllocatedBytes( + Thread.currentThread().getId() + ); + threadId.set(Thread.currentThread().getId()); + + if (taskTestContext.operationStartValidator != null) { + try { + taskTestContext.operationStartValidator.accept(threadId.get()); + } catch (AssertionError error) { + throw new RuntimeException(error); + } + } + + Object[] allocation1 = new Object[1000000]; // 4MB + + if (blockForCancellation) { + // Simulate a job that takes forever to finish + // Using periodic checks method to identify that the task was cancelled + try { + boolean taskCancelled = waitUntil(((CancellableTask) task)::isCancelled); + if (taskCancelled) { + throw new TaskCancelledException("Task Cancelled"); + } else { + fail("It should have thrown an exception"); + } + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } + + } + + Object[] allocation2 = new Object[1000000]; // 4MB + } + }); + + try { + result.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e.getCause()); + } finally { + if (taskTestContext.operationFinishedValidator != null) { + try { + // Wait until Task.stopThreadResourceTracking is called and threads are marked inactive + // before performing validation checks. + assertTrue( + "threads should not be marked active when task finishes", + waitUntil(() -> !hasActiveThreads(task), 5, TimeUnit.SECONDS) + ); + } catch (InterruptedException ignored) {} + taskTestContext.operationFinishedValidator.accept(threadId.get()); + } + } + + return new NodeResponse(clusterService.localNode()); + } + + @Override + protected NodeResponse nodeOperation(ResourceAwareNodeRequest request) { + throw new UnsupportedOperationException("the task parameter is required"); + } + } + + private TaskTestContext startResourceAwareNodesAction( + TestNode node, + boolean blockForCancellation, + TaskTestContext taskTestContext, + ActionListener listener + ) { + NodesRequest request = new NodesRequest("Test Request", node.getNodeId()); + + taskTestContext.requestCompleteLatch = new CountDownLatch(1); + + ResourceAwareNodesAction action = new ResourceAwareNodesAction( + "internal:resourceAction", + threadPool, + node.clusterService, + node.transportService, + blockForCancellation, + taskTestContext + ); + taskTestContext.mainTask = action.execute(request, listener); + return taskTestContext; + } + + private static class TaskTestContext { + private Task mainTask; + private CountDownLatch requestCompleteLatch; + private Consumer operationStartValidator; + private Consumer operationFinishedValidator; + private long memoryConsumptionWhenExecutionStarts; + } + + public void testBasicTaskResourceTracking() throws Exception { + setup(true, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + + // One thread is currently working on task but not finished + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); + assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); + }; + + taskTestContext.operationFinishedValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + + // Thread has finished working on the task's runnable + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + + long expectedArrayAllocationOverhead = 2 * 4000000; // Task's memory overhead due to array allocations + long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); + + assertMemoryUsageWithinLimits( + actualTaskMemoryOverhead - taskTestContext.memoryConsumptionWhenExecutionStarts, + expectedArrayAllocationOverhead + ); + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + + public void testTaskResourceTrackingDuringTaskCancellation() throws Exception { + setup(true, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + + // One thread is currently working on task but not finished + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); + assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); + }; + + taskTestContext.operationFinishedValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + + // Thread has finished working on the task's runnable + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + + // allocations are completed before the task is cancelled + long expectedArrayAllocationOverhead = 4000000; // Task's memory overhead due to array allocations + long taskCancellationOverhead = 30000; // Task cancellation overhead ~ 30Kb + long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); + + long expectedOverhead = expectedArrayAllocationOverhead + taskCancellationOverhead; + assertMemoryUsageWithinLimits( + actualTaskMemoryOverhead - taskTestContext.memoryConsumptionWhenExecutionStarts, + expectedOverhead + ); + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], true, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Cancel main task + CancelTasksRequest request = new CancelTasksRequest(); + request.setReason("Cancelling request to verify Task resource tracking behaviour"); + request.setTaskId(new TaskId(testNodes[0].getNodeId(), taskTestContext.mainTask.getId())); + ActionTestUtils.executeBlocking(testNodes[0].transportCancelTasksAction, request); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertEquals(0, resourceTasks.size()); + assertNull(throwableReference.get()); + assertNotNull(responseReference.get()); + assertEquals(1, responseReference.get().failureCount()); + assertEquals(TaskCancelledException.class, findActualException(responseReference.get().failures().get(0)).getClass()); + } + + public void testTaskResourceTrackingDisabled() throws Exception { + setup(false, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { assertEquals(0, resourceTasks.size()); }; + + taskTestContext.operationFinishedValidator = threadId -> { assertEquals(0, resourceTasks.size()); }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + + public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Exception { + setup(true, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + // One thread is currently working on task but not finished + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); + assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); + + testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(false); + }; + + taskTestContext.operationFinishedValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + // Thread has finished working on the task's runnable + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + + long expectedArrayAllocationOverhead = 2 * 4000000; // Task's memory overhead due to array allocations + long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); + + assertMemoryUsageWithinLimits( + actualTaskMemoryOverhead - taskTestContext.memoryConsumptionWhenExecutionStarts, + expectedArrayAllocationOverhead + ); + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + + public void testTaskResourceTrackingEnabledWhileTaskInProgress() throws Exception { + setup(false, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + assertEquals(0, resourceTasks.size()); + + testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + }; + + taskTestContext.operationFinishedValidator = threadId -> { assertEquals(0, resourceTasks.size()); }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + + public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException { + setup(true, false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + ListTasksResponse listTasksResponse = ActionTestUtils.executeBlocking( + testNodes[0].transportListTasksAction, + new ListTasksRequest().setActions("internal:resourceAction*").setDetailed(true) + ); + + TaskInfo taskInfo = listTasksResponse.getTasks().get(1); + + assertNotNull(taskInfo.getResourceStats()); + assertNotNull(taskInfo.getResourceStats().getResourceUsageInfo()); + assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total").getCpuTimeInNanos() > 0); + assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total").getMemoryInBytes() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + + public void testTaskIdPersistsInThreadContext() throws InterruptedException { + setup(true, true); + + final List taskIdsAddedToThreadContext = new ArrayList<>(); + final List taskIdsRemovedFromThreadContext = new ArrayList<>(); + AtomicLong actualTaskIdInThreadContext = new AtomicLong(-1); + AtomicLong expectedTaskIdInThreadContext = new AtomicLong(-2); + + ((MockTaskManager) testNodes[0].transportService.getTaskManager()).addListener(new MockTaskManagerListener() { + @Override + public void waitForTaskCompletion(Task task) {} + + @Override + public void taskExecutionStarted(Task task, Boolean closeableInvoked) { + if (closeableInvoked) { + taskIdsRemovedFromThreadContext.add(task.getId()); + } else { + taskIdsAddedToThreadContext.add(task.getId()); + } + } + + @Override + public void onTaskRegistered(Task task) {} + + @Override + public void onTaskUnregistered(Task task) { + if (task.getAction().equals("internal:resourceAction[n]")) { + expectedTaskIdInThreadContext.set(task.getId()); + actualTaskIdInThreadContext.set(threadPool.getThreadContext().getTransient(TASK_ID)); + } + } + }); + + TaskTestContext taskTestContext = new TaskTestContext(); + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + taskTestContext.requestCompleteLatch.await(); + + assertEquals(expectedTaskIdInThreadContext.get(), actualTaskIdInThreadContext.get()); + assertThat(taskIdsAddedToThreadContext, containsInAnyOrder(taskIdsRemovedFromThreadContext.toArray())); + } + + private void setup(boolean resourceTrackingEnabled, boolean useMockTaskManager) { + Settings settings = Settings.builder() + .put("task_resource_tracking.enabled", resourceTrackingEnabled) + .put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), useMockTaskManager) + .build(); + setupTestNodes(settings); + connectNodes(testNodes[0]); + + runnableTaskListener.set(testNodes[0].taskResourceTrackingService); + } + + private Throwable findActualException(Exception e) { + Throwable throwable = e.getCause(); + while (throwable.getCause() != null) { + throwable = throwable.getCause(); + } + return throwable; + } + + private void assertTasksRequestFinishedSuccessfully(int activeResourceTasks, NodesResponse nodesResponse, Throwable throwable) { + assertEquals(0, activeResourceTasks); + assertNull(throwable); + assertNotNull(nodesResponse); + assertEquals(0, nodesResponse.failureCount()); + } + + private void assertMemoryUsageWithinLimits(long actual, long expected) { + // 5% buffer up to 200 KB to account for classloading overhead. + long maxOverhead = Math.min(200000, expected * 5 / 100); + assertThat(actual, lessThanOrEqualTo(expected + maxOverhead)); + } + + private boolean hasActiveThreads(Task task) { + for (List threadResourceInfos : task.getResourceStats().values()) { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) { + return true; + } + } + } + + return false; + } +} diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index 4383b21aa7e74..fd6f5d17a3a80 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -59,8 +59,10 @@ import org.opensearch.indices.breaker.NoneCircuitBreakerService; import org.opensearch.tasks.TaskCancellationService; import org.opensearch.tasks.TaskManager; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.tasks.MockTaskManager; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -74,6 +76,7 @@ import java.util.List; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import static java.util.Collections.emptyMap; @@ -89,10 +92,12 @@ public abstract class TaskManagerTestCase extends OpenSearchTestCase { protected ThreadPool threadPool; protected TestNode[] testNodes; protected int nodesCount; + protected AtomicReference runnableTaskListener; @Before public void setupThreadPool() { - threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName()); + runnableTaskListener = new AtomicReference<>(); + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener); } public void setupTestNodes(Settings settings) { @@ -225,14 +230,22 @@ protected TaskManager createTaskManager(Settings settings, ThreadPool threadPool transportService.start(); clusterService = createClusterService(threadPool, discoveryNode.get()); clusterService.addStateApplier(transportService.getTaskManager()); + taskResourceTrackingService = new TaskResourceTrackingService(settings, clusterService.getClusterSettings(), threadPool); + transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); ActionFilters actionFilters = new ActionFilters(emptySet()); - transportListTasksAction = new TransportListTasksAction(clusterService, transportService, actionFilters); + transportListTasksAction = new TransportListTasksAction( + clusterService, + transportService, + actionFilters, + taskResourceTrackingService + ); transportCancelTasksAction = new TransportCancelTasksAction(clusterService, transportService, actionFilters); transportService.acceptIncomingRequests(); } public final ClusterService clusterService; public final TransportService transportService; + public final TaskResourceTrackingService taskResourceTrackingService; private final SetOnce discoveryNode = new SetOnce<>(); public final TransportListTasksAction transportListTasksAction; public final TransportCancelTasksAction transportCancelTasksAction; diff --git a/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java index 4b98870422ce8..202f1b7dcb5b4 100644 --- a/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java @@ -91,6 +91,7 @@ import static java.util.Collections.emptyMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.sameInstance; +import static org.mockito.Answers.RETURNS_MOCKS; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyString; @@ -224,7 +225,7 @@ public void setupAction() { remoteResponseHandler = ArgumentCaptor.forClass(TransportResponseHandler.class); // setup services that will be called by action - transportService = mock(TransportService.class); + transportService = mock(TransportService.class, RETURNS_MOCKS); clusterService = mock(ClusterService.class); localIngest = true; // setup nodes for local and remote diff --git a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java index 9c70accaca3e4..64286e47b4966 100644 --- a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java +++ b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java @@ -48,6 +48,7 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.sameInstance; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; public class ThreadContextTests extends OpenSearchTestCase { @@ -154,6 +155,15 @@ public void testNewContextWithClearedTransients() { assertEquals(1, threadContext.getResponseHeaders().get("baz").size()); } + public void testStashContextWithPreservedTransients() { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + threadContext.putTransient("foo", "bar"); + threadContext.putTransient(TASK_ID, 1); + threadContext.stashContext(); + assertNull(threadContext.getTransient("foo")); + assertEquals(1, (int) threadContext.getTransient(TASK_ID)); + } + public void testStashWithOrigin() { final String origin = randomAlphaOfLengthBetween(4, 16); final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 1369309218519..2e0e0b5f09698 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -203,6 +203,7 @@ import org.opensearch.search.fetch.FetchPhase; import org.opensearch.search.query.QueryPhase; import org.opensearch.snapshots.mockstore.MockEventuallyConsistentRepository; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.disruption.DisruptableMockTransport; import org.opensearch.threadpool.ThreadPool; @@ -1754,6 +1755,8 @@ public void onFailure(final Exception e) { final IndexNameExpressionResolver indexNameExpressionResolver = new IndexNameExpressionResolver( new ThreadContext(Settings.EMPTY) ); + transportService.getTaskManager() + .setTaskResourceTrackingService(new TaskResourceTrackingService(settings, clusterSettings, threadPool)); repositoriesService = new RepositoriesService( settings, clusterService, diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index 0f09b0de34206..ab49109eb8247 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -40,6 +40,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.FakeTcpChannel; @@ -59,6 +60,7 @@ import java.util.Set; import java.util.concurrent.Phaser; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.everyItem; @@ -67,10 +69,12 @@ public class TaskManagerTests extends OpenSearchTestCase { private ThreadPool threadPool; + private AtomicReference runnableTaskListener; @Before public void setupThreadPool() { - threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName()); + runnableTaskListener = new AtomicReference<>(); + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener); } @After diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java new file mode 100644 index 0000000000000..8ba23c5d3219c --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java @@ -0,0 +1,97 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.action.admin.cluster.node.tasks.TransportTasksActionTests; +import org.opensearch.action.search.SearchTask; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import java.util.HashMap; +import java.util.concurrent.atomic.AtomicReference; + +import static org.opensearch.tasks.ResourceStats.MEMORY; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; + +public class TaskResourceTrackingServiceTests extends OpenSearchTestCase { + + private ThreadPool threadPool; + private TaskResourceTrackingService taskResourceTrackingService; + + @Before + public void setup() { + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), new AtomicReference<>()); + taskResourceTrackingService = new TaskResourceTrackingService( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), + threadPool + ); + } + + @After + public void terminateThreadPool() { + terminate(threadPool); + } + + public void testThreadContextUpdateOnTrackingStart() { + taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + + Task task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>()); + + String key = "KEY"; + String value = "VALUE"; + + // Prepare thread context + threadPool.getThreadContext().putHeader(key, value); + threadPool.getThreadContext().putTransient(key, value); + threadPool.getThreadContext().addResponseHeader(key, value); + + ThreadContext.StoredContext storedContext = taskResourceTrackingService.startTracking(task); + + // All headers should be preserved and Task Id should also be included in thread context + verifyThreadContextFixedHeaders(key, value); + assertEquals((long) threadPool.getThreadContext().getTransient(TASK_ID), task.getId()); + + storedContext.restore(); + + // Post restore only task id should be removed from the thread context + verifyThreadContextFixedHeaders(key, value); + assertNull(threadPool.getThreadContext().getTransient(TASK_ID)); + } + + public void testStopTrackingHandlesCurrentActiveThread() { + taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + Task task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>()); + ThreadContext.StoredContext storedContext = taskResourceTrackingService.startTracking(task); + long threadId = Thread.currentThread().getId(); + taskResourceTrackingService.taskExecutionStartedOnThread(task.getId(), threadId); + + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getResourceStats().get(threadId).get(0).getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue()); + + taskResourceTrackingService.stopTracking(task); + + // Makes sure stop tracking marks the current active thread inactive and refreshes the resource stats before returning. + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + assertTrue(task.getResourceStats().get(threadId).get(0).getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue() > 0); + } + + private void verifyThreadContextFixedHeaders(String key, String value) { + assertEquals(threadPool.getThreadContext().getHeader(key), value); + assertEquals(threadPool.getThreadContext().getTransient(key), value); + assertEquals(threadPool.getThreadContext().getResponseHeaders().get(key).get(0), value); + } + +} diff --git a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java index e60871f67ea54..677ec7a0a6600 100644 --- a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java +++ b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java @@ -39,6 +39,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Setting.Property; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskAwareRequest; import org.opensearch.tasks.TaskManager; @@ -127,6 +128,21 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { super.waitForTaskCompletion(task, untilInNanos); } + @Override + public ThreadContext.StoredContext taskExecutionStarted(Task task) { + for (MockTaskManagerListener listener : listeners) { + listener.taskExecutionStarted(task, false); + } + + ThreadContext.StoredContext storedContext = super.taskExecutionStarted(task); + return () -> { + for (MockTaskManagerListener listener : listeners) { + listener.taskExecutionStarted(task, true); + } + storedContext.restore(); + }; + } + public void addListener(MockTaskManagerListener listener) { listeners.add(listener); } diff --git a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java index eb8361ac552fc..f15f878995aa2 100644 --- a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java +++ b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java @@ -43,4 +43,7 @@ public interface MockTaskManagerListener { void onTaskUnregistered(Task task); void waitForTaskCompletion(Task task); + + void taskExecutionStarted(Task task, Boolean closeableInvoked); + } diff --git a/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java b/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java index 5f8611d99f0a0..2d97d5bffee01 100644 --- a/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java +++ b/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java @@ -40,6 +40,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicReference; public class TestThreadPool extends ThreadPool { @@ -47,12 +48,29 @@ public class TestThreadPool extends ThreadPool { private volatile boolean returnRejectingExecutor = false; private volatile ThreadPoolExecutor rejectingExecutor; + public TestThreadPool( + String name, + AtomicReference runnableTaskListener, + ExecutorBuilder... customBuilders + ) { + this(name, Settings.EMPTY, runnableTaskListener, customBuilders); + } + public TestThreadPool(String name, ExecutorBuilder... customBuilders) { this(name, Settings.EMPTY, customBuilders); } public TestThreadPool(String name, Settings settings, ExecutorBuilder... customBuilders) { - super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), customBuilders); + this(name, settings, null, customBuilders); + } + + public TestThreadPool( + String name, + Settings settings, + AtomicReference runnableTaskListener, + ExecutorBuilder... customBuilders + ) { + super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), runnableTaskListener, customBuilders); } @Override From 3190e4469ed1cac36cc4924e8d89ed5c4fe440c9 Mon Sep 17 00:00:00 2001 From: Ketan Verma Date: Sat, 23 Jul 2022 16:38:21 +0530 Subject: [PATCH 2/6] Fixed a race-condition when Task is unregistered before its threads are stopped Signed-off-by: Ketan Verma --- .../main/java/org/opensearch/tasks/Task.java | 20 +++++++++++++++ .../tasks/TaskResourceTrackingService.java | 1 + .../node/tasks/ResourceAwareTasksTests.java | 25 +++---------------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index c8d5312b2e1bc..2a4e4c00b75bb 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -45,6 +45,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Phaser; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; /** * Current task information @@ -74,6 +77,8 @@ public class Task { private final Map> resourceStats; + private final Phaser resourceTrackingThreadsBarrier = new Phaser(1); + /** * The task's start time as a wall clock time since epoch ({@link System#currentTimeMillis()} style). */ @@ -288,6 +293,7 @@ public void startThreadResourceTracking(long threadId, ResourceStatsType statsTy } } threadResourceInfoList.add(new ThreadResourceInfo(threadId, statsType, resourceUsageMetrics)); + resourceTrackingThreadsBarrier.register(); } /** @@ -327,6 +333,7 @@ public void stopThreadResourceTracking(long threadId, ResourceStatsType statsTyp if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { threadResourceInfo.setActive(false); threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); + resourceTrackingThreadsBarrier.arriveAndDeregister(); return; } } @@ -377,4 +384,17 @@ public TaskResult result(DiscoveryNode node, ActionResponse response) throws IOE throw new IllegalStateException("response has to implement ToXContent to be able to store the results"); } } + + /** + * Awaits for stopThreadResourceTracking to be called for all threads of the current task. + * @throws InterruptedException if thread interrupted while waiting + * @throws TimeoutException if timed out while waiting + */ + public void awaitResourceTrackingThreadsCompletion() throws InterruptedException, TimeoutException { + resourceTrackingThreadsBarrier.awaitAdvanceInterruptibly( + resourceTrackingThreadsBarrier.arriveAndDeregister(), + 10L, + TimeUnit.SECONDS + ); + } } diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 71b829e023385..b822cfa12fb59 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -111,6 +111,7 @@ public void stopTracking(Task task) { taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); } + task.awaitResourceTrackingThreadsCompletion(); List threadsWorkingOnTask = getThreadsWorkingOnTask(task); if (threadsWorkingOnTask.size() > 0) { logger.warn("No thread should be active when task finishes. Active threads: {}", threadsWorkingOnTask); diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index 7af7b1929cf35..c969399e2e13f 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -28,7 +28,6 @@ import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; -import org.opensearch.tasks.ThreadResourceInfo; import org.opensearch.test.tasks.MockTaskManager; import org.opensearch.test.tasks.MockTaskManagerListener; import org.opensearch.threadpool.ThreadPool; @@ -42,7 +41,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -215,13 +214,9 @@ protected void doRun() { } finally { if (taskTestContext.operationFinishedValidator != null) { try { - // Wait until Task.stopThreadResourceTracking is called and threads are marked inactive - // before performing validation checks. - assertTrue( - "threads should not be marked active when task finishes", - waitUntil(() -> !hasActiveThreads(task), 5, TimeUnit.SECONDS) - ); - } catch (InterruptedException ignored) {} + // Wait for threads to be marked inactive before performing validation checks. + task.awaitResourceTrackingThreadsCompletion(); + } catch (InterruptedException | TimeoutException ignored) {} taskTestContext.operationFinishedValidator.accept(threadId.get()); } } @@ -655,16 +650,4 @@ private void assertMemoryUsageWithinLimits(long actual, long expected) { long maxOverhead = Math.min(200000, expected * 5 / 100); assertThat(actual, lessThanOrEqualTo(expected + maxOverhead)); } - - private boolean hasActiveThreads(Task task) { - for (List threadResourceInfos : task.getResourceStats().values()) { - for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { - if (threadResourceInfo.isActive()) { - return true; - } - } - } - - return false; - } } From 61cbf794e428654c139efb4403683abfaccedd88 Mon Sep 17 00:00:00 2001 From: Ketan Verma Date: Wed, 27 Jul 2022 01:50:34 +0530 Subject: [PATCH 3/6] Replaced await with callbacks to mark task resource tracking completion Signed-off-by: Ketan Verma --- .../main/java/org/opensearch/tasks/Task.java | 72 +++++++++++---- .../org/opensearch/tasks/TaskManager.java | 16 +++- .../tasks/TaskResourceTrackingListener.java | 42 +++++++++ .../tasks/TaskResourceTrackingService.java | 1 - .../node/tasks/ResourceAwareTasksTests.java | 87 +++++++++---------- 5 files changed, 148 insertions(+), 70 deletions(-) create mode 100644 server/src/main/java/org/opensearch/tasks/TaskResourceTrackingListener.java diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index 2a4e4c00b75bb..cd6b1845cd9d1 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -45,9 +45,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.Phaser; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; /** * Current task information @@ -77,7 +75,12 @@ public class Task { private final Map> resourceStats; - private final Phaser resourceTrackingThreadsBarrier = new Phaser(1); + private final List resourceTrackingListeners; + + // Keeps track of the number of active resource tracking threads for this task. It is initialized to 1 to track + // the task's own/self thread. When this value becomes 0, all threads have been marked inactive and the resource + // tracking can be stopped for this task. + private final AtomicInteger numActiveResourceTrackingThreads = new AtomicInteger(1); /** * The task's start time as a wall clock time since epoch ({@link System#currentTimeMillis()} style). @@ -90,7 +93,18 @@ public class Task { private final long startTimeNanos; public Task(long id, String type, String action, String description, TaskId parentTask, Map headers) { - this(id, type, action, description, parentTask, System.currentTimeMillis(), System.nanoTime(), headers, new ConcurrentHashMap<>()); + this( + id, + type, + action, + description, + parentTask, + System.currentTimeMillis(), + System.nanoTime(), + headers, + new ConcurrentHashMap<>(), + Collections.synchronizedList(new ArrayList<>()) + ); } public Task( @@ -102,7 +116,8 @@ public Task( long startTime, long startTimeNanos, Map headers, - ConcurrentHashMap> resourceStats + ConcurrentHashMap> resourceStats, + List resourceTrackingListeners ) { this.id = id; this.type = type; @@ -113,6 +128,7 @@ public Task( this.startTimeNanos = startTimeNanos; this.headers = headers; this.resourceStats = resourceStats; + this.resourceTrackingListeners = resourceTrackingListeners; } /** @@ -293,7 +309,8 @@ public void startThreadResourceTracking(long threadId, ResourceStatsType statsTy } } threadResourceInfoList.add(new ThreadResourceInfo(threadId, statsType, resourceUsageMetrics)); - resourceTrackingThreadsBarrier.register(); + incrementResourceTrackingThreads(); + resourceTrackingListeners.forEach(listener -> listener.onTaskExecutionStartedOnThread(this, threadId)); } /** @@ -311,6 +328,7 @@ public void updateThreadResourceStats(long threadId, ResourceStatsType statsType // the active entry present in the list is updated if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); + resourceTrackingListeners.forEach(listener -> listener.onTaskResourceStatsUpdated(this)); return; } } @@ -333,7 +351,8 @@ public void stopThreadResourceTracking(long threadId, ResourceStatsType statsTyp if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { threadResourceInfo.setActive(false); threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); - resourceTrackingThreadsBarrier.arriveAndDeregister(); + decrementResourceTrackingThreads(); + resourceTrackingListeners.forEach(listener -> listener.onTaskExecutionFinishedOnThread(this, threadId)); return; } } @@ -386,15 +405,34 @@ public TaskResult result(DiscoveryNode node, ActionResponse response) throws IOE } /** - * Awaits for stopThreadResourceTracking to be called for all threads of the current task. - * @throws InterruptedException if thread interrupted while waiting - * @throws TimeoutException if timed out while waiting + * Registers a TaskResourceTrackingListener callback listener on this task. */ - public void awaitResourceTrackingThreadsCompletion() throws InterruptedException, TimeoutException { - resourceTrackingThreadsBarrier.awaitAdvanceInterruptibly( - resourceTrackingThreadsBarrier.arriveAndDeregister(), - 10L, - TimeUnit.SECONDS - ); + public void addTaskResourceTrackingListener(TaskResourceTrackingListener listener) { + resourceTrackingListeners.add(listener); + } + + /** + * Increments the number of active resource tracking threads. + * + * @return the number of active resource tracking threads. + */ + public int incrementResourceTrackingThreads() { + return numActiveResourceTrackingThreads.incrementAndGet(); + } + + /** + * Decrements the number of active resource tracking threads. + * When this value becomes zero, the onTaskResourceTrackingCompleted method is called on all registered listeners. + * + * @return the number of active resource tracking threads. + */ + public int decrementResourceTrackingThreads() { + int count = numActiveResourceTrackingThreads.decrementAndGet(); + + if (count == 0) { + resourceTrackingListeners.forEach(listener -> listener.onTaskResourceTrackingCompleted(this)); + } + + return count; } } diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 3121eab18fd6b..4197dd2b5eaaa 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -85,7 +85,7 @@ * * @opensearch.internal */ -public class TaskManager implements ClusterStateApplier { +public class TaskManager implements ClusterStateApplier, TaskResourceTrackingListener { private static final Logger logger = LogManager.getLogger(TaskManager.class); @@ -153,6 +153,7 @@ public Task register(String type, String action, TaskAwareRequest request) { } } Task task = request.createTask(taskIdGenerator.incrementAndGet(), type, action, request.getParentTask(), headers); + task.addTaskResourceTrackingListener(this); Objects.requireNonNull(task); assert task.getParentTaskId().equals(request.getParentTask()) : "Request [ " + request + "] didn't preserve it parentTaskId"; if (logger.isTraceEnabled()) { @@ -212,9 +213,8 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); - if (taskResourceTrackingService.get() != null && task.supportsResourceTracking()) { - taskResourceTrackingService.get().stopTracking(task); - } + // Decrement the task's self-thread as a part of unregistration. + task.decrementResourceTrackingThreads(); if (task instanceof CancellableTask) { CancellableTaskHolder holder = cancellableTasks.remove(task.getId()); @@ -700,4 +700,12 @@ public void cancelTaskAndDescendants(CancellableTask task, String reason, boolea throw new IllegalStateException("TaskCancellationService is not initialized"); } } + + @Override + public void onTaskResourceTrackingCompleted(Task task) { + // Stop tracking the task once the last thread has been marked inactive. + if (taskResourceTrackingService.get() != null && task.supportsResourceTracking()) { + taskResourceTrackingService.get().stopTracking(task); + } + } } diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingListener.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingListener.java new file mode 100644 index 0000000000000..e29c6ec8a0256 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingListener.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +/** + * Listener for events related to resource tracking of a Task. + */ +public interface TaskResourceTrackingListener { + + /** + * Invoked when a task execution is started on a thread. + * @param task object + * @param threadId on which execution of task started + */ + default void onTaskExecutionStartedOnThread(Task task, long threadId) {} + + /** + * Invoked when a task execution is finished on a thread. + * @param task object + * @param threadId on which execution of task finished. + */ + default void onTaskExecutionFinishedOnThread(Task task, long threadId) {} + + /** + * Invoked when a task's resource stats are updated. + * @param task object + */ + default void onTaskResourceStatsUpdated(Task task) {} + + /** + * Invoked when a task's resource tracking is completed. + * This happens when all of task's threads are marked inactive. + * @param task object + */ + default void onTaskResourceTrackingCompleted(Task task) {} +} diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index b822cfa12fb59..71b829e023385 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -111,7 +111,6 @@ public void stopTracking(Task task) { taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); } - task.awaitResourceTrackingThreadsCompletion(); List threadsWorkingOnTask = getThreadsWorkingOnTask(task); if (threadsWorkingOnTask.size() > 0) { logger.warn("No thread should be active when task finishes. Active threads: {}", threadsWorkingOnTask); diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index c969399e2e13f..c7f53754d015c 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -28,6 +28,7 @@ import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; +import org.opensearch.tasks.TaskResourceTrackingListener; import org.opensearch.test.tasks.MockTaskManager; import org.opensearch.test.tasks.MockTaskManagerListener; import org.opensearch.threadpool.ThreadPool; @@ -41,10 +42,9 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; +import java.util.function.BiConsumer; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -177,12 +177,20 @@ protected void doRun() { ); threadId.set(Thread.currentThread().getId()); + // operationStartValidator will be called just before the task execution. if (taskTestContext.operationStartValidator != null) { - try { - taskTestContext.operationStartValidator.accept(threadId.get()); - } catch (AssertionError error) { - throw new RuntimeException(error); - } + taskTestContext.operationStartValidator.accept(task, threadId.get()); + } + + // operationFinishedValidator will be called just after all task threads are marked inactive and + // the task is unregistered. + if (taskTestContext.operationFinishedValidator != null) { + task.addTaskResourceTrackingListener(new TaskResourceTrackingListener() { + @Override + public void onTaskResourceTrackingCompleted(Task task) { + taskTestContext.operationFinishedValidator.accept(task, threadId.get()); + } + }); } Object[] allocation1 = new Object[1000000]; // 4MB @@ -211,14 +219,6 @@ protected void doRun() { result.get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException(e.getCause()); - } finally { - if (taskTestContext.operationFinishedValidator != null) { - try { - // Wait for threads to be marked inactive before performing validation checks. - task.awaitResourceTrackingThreadsCompletion(); - } catch (InterruptedException | TimeoutException ignored) {} - taskTestContext.operationFinishedValidator.accept(threadId.get()); - } } return new NodeResponse(clusterService.localNode()); @@ -255,8 +255,8 @@ private TaskTestContext startResourceAwareNodesAction( private static class TaskTestContext { private Task mainTask; private CountDownLatch requestCompleteLatch; - private Consumer operationStartValidator; - private Consumer operationFinishedValidator; + private BiConsumer operationStartValidator; + private BiConsumer operationFinishedValidator; private long memoryConsumptionWhenExecutionStarts; } @@ -269,9 +269,7 @@ public void testBasicTaskResourceTracking() throws Exception { Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); - taskTestContext.operationStartValidator = threadId -> { - Task task = resourceTasks.values().stream().findAny().get(); - + taskTestContext.operationStartValidator = (task, threadId) -> { // One thread is currently working on task but not finished assertEquals(1, resourceTasks.size()); assertEquals(1, task.getResourceStats().size()); @@ -281,11 +279,9 @@ public void testBasicTaskResourceTracking() throws Exception { assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); }; - taskTestContext.operationFinishedValidator = threadId -> { - Task task = resourceTasks.values().stream().findAny().get(); - + taskTestContext.operationFinishedValidator = (task, threadId) -> { // Thread has finished working on the task's runnable - assertEquals(1, resourceTasks.size()); + assertEquals(0, resourceTasks.size()); assertEquals(1, task.getResourceStats().size()); assertEquals(1, task.getResourceStats().get(threadId).size()); assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); @@ -317,7 +313,7 @@ public void onFailure(Exception e) { // Waiting for whole request to complete and return successfully till client taskTestContext.requestCompleteLatch.await(); - assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); } public void testTaskResourceTrackingDuringTaskCancellation() throws Exception { @@ -329,9 +325,7 @@ public void testTaskResourceTrackingDuringTaskCancellation() throws Exception { Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); - taskTestContext.operationStartValidator = threadId -> { - Task task = resourceTasks.values().stream().findAny().get(); - + taskTestContext.operationStartValidator = (task, threadId) -> { // One thread is currently working on task but not finished assertEquals(1, resourceTasks.size()); assertEquals(1, task.getResourceStats().size()); @@ -341,11 +335,9 @@ public void testTaskResourceTrackingDuringTaskCancellation() throws Exception { assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); }; - taskTestContext.operationFinishedValidator = threadId -> { - Task task = resourceTasks.values().stream().findAny().get(); - + taskTestContext.operationFinishedValidator = (task, threadId) -> { // Thread has finished working on the task's runnable - assertEquals(1, resourceTasks.size()); + assertEquals(0, resourceTasks.size()); assertEquals(1, task.getResourceStats().size()); assertEquals(1, task.getResourceStats().get(threadId).size()); assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); @@ -402,9 +394,9 @@ public void testTaskResourceTrackingDisabled() throws Exception { Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); - taskTestContext.operationStartValidator = threadId -> { assertEquals(0, resourceTasks.size()); }; + taskTestContext.operationStartValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); }; - taskTestContext.operationFinishedValidator = threadId -> { assertEquals(0, resourceTasks.size()); }; + taskTestContext.operationFinishedValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); }; startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { @Override @@ -423,7 +415,7 @@ public void onFailure(Exception e) { // Waiting for whole request to complete and return successfully till client taskTestContext.requestCompleteLatch.await(); - assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); } public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Exception { @@ -435,8 +427,7 @@ public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Excepti Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); - taskTestContext.operationStartValidator = threadId -> { - Task task = resourceTasks.values().stream().findAny().get(); + taskTestContext.operationStartValidator = (task, threadId) -> { // One thread is currently working on task but not finished assertEquals(1, resourceTasks.size()); assertEquals(1, task.getResourceStats().size()); @@ -448,10 +439,9 @@ public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Excepti testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(false); }; - taskTestContext.operationFinishedValidator = threadId -> { - Task task = resourceTasks.values().stream().findAny().get(); + taskTestContext.operationFinishedValidator = (task, threadId) -> { // Thread has finished working on the task's runnable - assertEquals(1, resourceTasks.size()); + assertEquals(0, resourceTasks.size()); assertEquals(1, task.getResourceStats().size()); assertEquals(1, task.getResourceStats().get(threadId).size()); assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); @@ -483,7 +473,7 @@ public void onFailure(Exception e) { // Waiting for whole request to complete and return successfully till client taskTestContext.requestCompleteLatch.await(); - assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); } public void testTaskResourceTrackingEnabledWhileTaskInProgress() throws Exception { @@ -495,13 +485,13 @@ public void testTaskResourceTrackingEnabledWhileTaskInProgress() throws Exceptio Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); - taskTestContext.operationStartValidator = threadId -> { + taskTestContext.operationStartValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(true); }; - taskTestContext.operationFinishedValidator = threadId -> { assertEquals(0, resourceTasks.size()); }; + taskTestContext.operationFinishedValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); }; startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { @Override @@ -520,7 +510,7 @@ public void onFailure(Exception e) { // Waiting for whole request to complete and return successfully till client taskTestContext.requestCompleteLatch.await(); - assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); } public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException { @@ -533,7 +523,7 @@ public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); - taskTestContext.operationStartValidator = threadId -> { + taskTestContext.operationStartValidator = (task, threadId) -> { ListTasksResponse listTasksResponse = ActionTestUtils.executeBlocking( testNodes[0].transportListTasksAction, new ListTasksRequest().setActions("internal:resourceAction*").setDetailed(true) @@ -547,6 +537,8 @@ public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total").getMemoryInBytes() > 0); }; + taskTestContext.operationFinishedValidator = (task, threadId) -> { assertEquals(0, resourceTasks.size()); }; + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { @Override public void onResponse(NodesResponse listTasksResponse) { @@ -564,7 +556,7 @@ public void onFailure(Exception e) { // Waiting for whole request to complete and return successfully till client taskTestContext.requestCompleteLatch.await(); - assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + assertTasksRequestFinishedSuccessfully(responseReference.get(), throwableReference.get()); } public void testTaskIdPersistsInThreadContext() throws InterruptedException { @@ -638,8 +630,7 @@ private Throwable findActualException(Exception e) { return throwable; } - private void assertTasksRequestFinishedSuccessfully(int activeResourceTasks, NodesResponse nodesResponse, Throwable throwable) { - assertEquals(0, activeResourceTasks); + private void assertTasksRequestFinishedSuccessfully(NodesResponse nodesResponse, Throwable throwable) { assertNull(throwable); assertNotNull(nodesResponse); assertEquals(0, nodesResponse.failureCount()); From a09a60acbbbad255f2b6626dd08ce94492cdeb97 Mon Sep 17 00:00:00 2001 From: Ketan Verma Date: Thu, 28 Jul 2022 13:37:35 +0530 Subject: [PATCH 4/6] Improved error handling for callbacks and relaxed assertions Signed-off-by: Ketan Verma --- .../main/java/org/opensearch/tasks/Task.java | 61 ++++++++++++++++--- .../org/opensearch/tasks/TaskManager.java | 2 +- .../tasks/TaskResourceTrackingService.java | 19 ++---- 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index cd6b1845cd9d1..e501f68be0f5b 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -32,6 +32,8 @@ package org.opensearch.tasks; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NamedWriteable; @@ -45,6 +47,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; /** @@ -54,6 +57,8 @@ */ public class Task { + private static final Logger logger = LogManager.getLogger(Task.class); + /** * The request header to mark tasks with specific ids */ @@ -77,9 +82,13 @@ public class Task { private final List resourceTrackingListeners; - // Keeps track of the number of active resource tracking threads for this task. It is initialized to 1 to track - // the task's own/self thread. When this value becomes 0, all threads have been marked inactive and the resource - // tracking can be stopped for this task. + private final AtomicBoolean isResourceTrackingCompleted = new AtomicBoolean(false); + + /** + * Keeps track of the number of active resource tracking threads for this task. It is initialized to 1 to track + * the task's own/self thread. When this value becomes 0, all threads have been marked inactive and the resource + * tracking can be stopped for this task. + */ private final AtomicInteger numActiveResourceTrackingThreads = new AtomicInteger(1); /** @@ -310,7 +319,13 @@ public void startThreadResourceTracking(long threadId, ResourceStatsType statsTy } threadResourceInfoList.add(new ThreadResourceInfo(threadId, statsType, resourceUsageMetrics)); incrementResourceTrackingThreads(); - resourceTrackingListeners.forEach(listener -> listener.onTaskExecutionStartedOnThread(this, threadId)); + resourceTrackingListeners.forEach(listener -> { + try { + listener.onTaskExecutionStartedOnThread(this, threadId); + } catch (Exception e) { + logger.warn("failure in listener during handling of onTaskExecutionStartedOnThread", e); + } + }); } /** @@ -328,7 +343,13 @@ public void updateThreadResourceStats(long threadId, ResourceStatsType statsType // the active entry present in the list is updated if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); - resourceTrackingListeners.forEach(listener -> listener.onTaskResourceStatsUpdated(this)); + resourceTrackingListeners.forEach(listener -> { + try { + listener.onTaskResourceStatsUpdated(this); + } catch (Exception e) { + logger.warn("failure in listener during handling of onTaskResourceStatsUpdated", e); + } + }); return; } } @@ -352,7 +373,14 @@ public void stopThreadResourceTracking(long threadId, ResourceStatsType statsTyp threadResourceInfo.setActive(false); threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); decrementResourceTrackingThreads(); - resourceTrackingListeners.forEach(listener -> listener.onTaskExecutionFinishedOnThread(this, threadId)); + resourceTrackingListeners.forEach(listener -> { + try { + listener.onTaskExecutionFinishedOnThread(this, threadId); + } catch (Exception e) { + logger.warn("failure in listener during handling of onTaskExecutionFinishedOnThread", e); + } + }); + return; } } @@ -422,15 +450,30 @@ public int incrementResourceTrackingThreads() { /** * Decrements the number of active resource tracking threads. - * When this value becomes zero, the onTaskResourceTrackingCompleted method is called on all registered listeners. + * This method is called when threads finish execution, and also when the task is unregistered (to mark the task's + * own thread as complete). When the active thread count becomes zero, the onTaskResourceTrackingCompleted method + * is called exactly once on all registered listeners. + * + * Since a task is unregistered after the message is processed, it implies that the threads responsible to produce + * the response must have started prior to it (i.e. startThreadResourceTracking called before unregister). + * This ensures that the number of active threads doesn't drop to zero pre-maturely. + * + * Rarely, some threads may even start execution after the task is unregistered. As resource stats are piggy-backed + * with the response, any thread usage info captured after the task is unregistered may be irrelevant. * * @return the number of active resource tracking threads. */ public int decrementResourceTrackingThreads() { int count = numActiveResourceTrackingThreads.decrementAndGet(); - if (count == 0) { - resourceTrackingListeners.forEach(listener -> listener.onTaskResourceTrackingCompleted(this)); + if (count == 0 && isResourceTrackingCompleted.compareAndSet(false, true)) { + resourceTrackingListeners.forEach(listener -> { + try { + listener.onTaskResourceTrackingCompleted(this); + } catch (Exception e) { + logger.warn("failure in listener during handling of onTaskResourceTrackingCompleted", e); + } + }); } return count; diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 4197dd2b5eaaa..b98f1e44df349 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -213,7 +213,7 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); - // Decrement the task's self-thread as a part of unregistration. + // Decrement the task's self-thread as part of unregistration. task.decrementResourceTrackingThreads(); if (task instanceof CancellableTask) { diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 71b829e023385..c3cad117390e4 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -110,12 +110,6 @@ public void stopTracking(Task task) { if (isCurrentThreadWorkingOnTask(task)) { taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); } - - List threadsWorkingOnTask = getThreadsWorkingOnTask(task); - if (threadsWorkingOnTask.size() > 0) { - logger.warn("No thread should be active when task finishes. Active threads: {}", threadsWorkingOnTask); - assert false : "No thread should be marked active when task finishes"; - } } catch (Exception e) { logger.warn("Failed while trying to mark the task execution on current thread completed.", e); assert false; @@ -165,11 +159,10 @@ private void refreshResourceStats(Task resourceAwareTask) { @Override public void taskExecutionStartedOnThread(long taskId, long threadId) { try { - if (resourceAwareTasks.containsKey(taskId)) { + final Task task = resourceAwareTasks.get(taskId); + if (task != null) { logger.debug("Task execution started on thread. Task: {}, Thread: {}", taskId, threadId); - - resourceAwareTasks.get(taskId) - .startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + task.startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); } } catch (Exception e) { logger.warn(new ParameterizedMessage("Failed to mark thread execution started for task: [{}]", taskId), e); @@ -187,10 +180,10 @@ public void taskExecutionStartedOnThread(long taskId, long threadId) { @Override public void taskExecutionFinishedOnThread(long taskId, long threadId) { try { - if (resourceAwareTasks.containsKey(taskId)) { + final Task task = resourceAwareTasks.get(taskId); + if (task != null) { logger.debug("Task execution finished on thread. Task: {}, Thread: {}", taskId, threadId); - resourceAwareTasks.get(taskId) - .stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + task.stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); } } catch (Exception e) { logger.warn(new ParameterizedMessage("Failed to mark thread execution finished for task: [{}]", taskId), e); From d89de65f34f44dd21e49941bee0be63fa905be14 Mon Sep 17 00:00:00 2001 From: Ketan Verma Date: Tue, 2 Aug 2022 00:07:13 +0530 Subject: [PATCH 5/6] Improved error handling and simplified task resource tracking completion listener Signed-off-by: Ketan Verma --- .../main/java/org/opensearch/tasks/Task.java | 57 +++++++------------ .../org/opensearch/tasks/TaskManager.java | 29 ++++++---- .../tasks/TaskResourceTrackingListener.java | 42 -------------- .../node/tasks/ResourceAwareTasksTests.java | 11 +++- 4 files changed, 47 insertions(+), 92 deletions(-) delete mode 100644 server/src/main/java/org/opensearch/tasks/TaskResourceTrackingListener.java diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index e501f68be0f5b..4f6724041902d 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -34,7 +34,9 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.ActionResponse; +import org.opensearch.action.NotifyOnceListener; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NamedWriteable; import org.opensearch.common.xcontent.ToXContent; @@ -47,7 +49,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; /** @@ -80,9 +81,7 @@ public class Task { private final Map> resourceStats; - private final List resourceTrackingListeners; - - private final AtomicBoolean isResourceTrackingCompleted = new AtomicBoolean(false); + private final List> resourceTrackingCompletionListeners; /** * Keeps track of the number of active resource tracking threads for this task. It is initialized to 1 to track @@ -112,7 +111,7 @@ public Task(long id, String type, String action, String description, TaskId pare System.nanoTime(), headers, new ConcurrentHashMap<>(), - Collections.synchronizedList(new ArrayList<>()) + new ArrayList<>() ); } @@ -126,7 +125,7 @@ public Task( long startTimeNanos, Map headers, ConcurrentHashMap> resourceStats, - List resourceTrackingListeners + List> resourceTrackingCompletionListeners ) { this.id = id; this.type = type; @@ -137,7 +136,7 @@ public Task( this.startTimeNanos = startTimeNanos; this.headers = headers; this.resourceStats = resourceStats; - this.resourceTrackingListeners = resourceTrackingListeners; + this.resourceTrackingCompletionListeners = resourceTrackingCompletionListeners; } /** @@ -319,13 +318,6 @@ public void startThreadResourceTracking(long threadId, ResourceStatsType statsTy } threadResourceInfoList.add(new ThreadResourceInfo(threadId, statsType, resourceUsageMetrics)); incrementResourceTrackingThreads(); - resourceTrackingListeners.forEach(listener -> { - try { - listener.onTaskExecutionStartedOnThread(this, threadId); - } catch (Exception e) { - logger.warn("failure in listener during handling of onTaskExecutionStartedOnThread", e); - } - }); } /** @@ -343,13 +335,6 @@ public void updateThreadResourceStats(long threadId, ResourceStatsType statsType // the active entry present in the list is updated if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); - resourceTrackingListeners.forEach(listener -> { - try { - listener.onTaskResourceStatsUpdated(this); - } catch (Exception e) { - logger.warn("failure in listener during handling of onTaskResourceStatsUpdated", e); - } - }); return; } } @@ -373,14 +358,6 @@ public void stopThreadResourceTracking(long threadId, ResourceStatsType statsTyp threadResourceInfo.setActive(false); threadResourceInfo.recordResourceUsageMetrics(resourceUsageMetrics); decrementResourceTrackingThreads(); - resourceTrackingListeners.forEach(listener -> { - try { - listener.onTaskExecutionFinishedOnThread(this, threadId); - } catch (Exception e) { - logger.warn("failure in listener during handling of onTaskExecutionFinishedOnThread", e); - } - }); - return; } } @@ -433,10 +410,10 @@ public TaskResult result(DiscoveryNode node, ActionResponse response) throws IOE } /** - * Registers a TaskResourceTrackingListener callback listener on this task. + * Registers a task resource tracking completion listener on this task. */ - public void addTaskResourceTrackingListener(TaskResourceTrackingListener listener) { - resourceTrackingListeners.add(listener); + public void addResourceTrackingCompletionListener(NotifyOnceListener listener) { + resourceTrackingCompletionListeners.add(listener); } /** @@ -466,14 +443,20 @@ public int incrementResourceTrackingThreads() { public int decrementResourceTrackingThreads() { int count = numActiveResourceTrackingThreads.decrementAndGet(); - if (count == 0 && isResourceTrackingCompleted.compareAndSet(false, true)) { - resourceTrackingListeners.forEach(listener -> { + if (count == 0) { + List listenerExceptions = new ArrayList<>(); + resourceTrackingCompletionListeners.forEach(listener -> { try { - listener.onTaskResourceTrackingCompleted(this); - } catch (Exception e) { - logger.warn("failure in listener during handling of onTaskResourceTrackingCompleted", e); + listener.onResponse(this); + } catch (Exception e1) { + try { + listener.onFailure(e1); + } catch (Exception e2) { + listenerExceptions.add(e2); + } } }); + ExceptionsHelper.maybeThrowRuntimeAndSuppress(listenerExceptions); } return count; diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index b98f1e44df349..95f2ef0c4d731 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -44,6 +44,7 @@ import org.opensearch.OpenSearchTimeoutException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionResponse; +import org.opensearch.action.NotifyOnceListener; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterStateApplier; import org.opensearch.cluster.node.DiscoveryNode; @@ -85,7 +86,7 @@ * * @opensearch.internal */ -public class TaskManager implements ClusterStateApplier, TaskResourceTrackingListener { +public class TaskManager implements ClusterStateApplier { private static final Logger logger = LogManager.getLogger(TaskManager.class); @@ -153,13 +154,29 @@ public Task register(String type, String action, TaskAwareRequest request) { } } Task task = request.createTask(taskIdGenerator.incrementAndGet(), type, action, request.getParentTask(), headers); - task.addTaskResourceTrackingListener(this); Objects.requireNonNull(task); assert task.getParentTaskId().equals(request.getParentTask()) : "Request [ " + request + "] didn't preserve it parentTaskId"; if (logger.isTraceEnabled()) { logger.trace("register {} [{}] [{}] [{}]", task.getId(), type, action, task.getDescription()); } + if (task.supportsResourceTracking()) { + task.addResourceTrackingCompletionListener(new NotifyOnceListener<>() { + @Override + protected void innerOnResponse(Task task) { + // Stop tracking the task once the last thread has been marked inactive. + if (taskResourceTrackingService.get() != null && task.supportsResourceTracking()) { + taskResourceTrackingService.get().stopTracking(task); + } + } + + @Override + protected void innerOnFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } + }); + } + if (task instanceof CancellableTask) { registerCancellableTask(task); } else { @@ -700,12 +717,4 @@ public void cancelTaskAndDescendants(CancellableTask task, String reason, boolea throw new IllegalStateException("TaskCancellationService is not initialized"); } } - - @Override - public void onTaskResourceTrackingCompleted(Task task) { - // Stop tracking the task once the last thread has been marked inactive. - if (taskResourceTrackingService.get() != null && task.supportsResourceTracking()) { - taskResourceTrackingService.get().stopTracking(task); - } - } } diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingListener.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingListener.java deleted file mode 100644 index e29c6ec8a0256..0000000000000 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingListener.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.tasks; - -/** - * Listener for events related to resource tracking of a Task. - */ -public interface TaskResourceTrackingListener { - - /** - * Invoked when a task execution is started on a thread. - * @param task object - * @param threadId on which execution of task started - */ - default void onTaskExecutionStartedOnThread(Task task, long threadId) {} - - /** - * Invoked when a task execution is finished on a thread. - * @param task object - * @param threadId on which execution of task finished. - */ - default void onTaskExecutionFinishedOnThread(Task task, long threadId) {} - - /** - * Invoked when a task's resource stats are updated. - * @param task object - */ - default void onTaskResourceStatsUpdated(Task task) {} - - /** - * Invoked when a task's resource tracking is completed. - * This happens when all of task's threads are marked inactive. - * @param task object - */ - default void onTaskResourceTrackingCompleted(Task task) {} -} diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index c7f53754d015c..c269bf6e7fa77 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -11,6 +11,7 @@ import com.sun.management.ThreadMXBean; import org.opensearch.ExceptionsHelper; import org.opensearch.action.ActionListener; +import org.opensearch.action.NotifyOnceListener; import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksRequest; import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksResponse; @@ -28,7 +29,6 @@ import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; -import org.opensearch.tasks.TaskResourceTrackingListener; import org.opensearch.test.tasks.MockTaskManager; import org.opensearch.test.tasks.MockTaskManagerListener; import org.opensearch.threadpool.ThreadPool; @@ -185,11 +185,16 @@ protected void doRun() { // operationFinishedValidator will be called just after all task threads are marked inactive and // the task is unregistered. if (taskTestContext.operationFinishedValidator != null) { - task.addTaskResourceTrackingListener(new TaskResourceTrackingListener() { + task.addResourceTrackingCompletionListener(new NotifyOnceListener<>() { @Override - public void onTaskResourceTrackingCompleted(Task task) { + protected void innerOnResponse(Task task) { taskTestContext.operationFinishedValidator.accept(task, threadId.get()); } + + @Override + protected void innerOnFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } }); } From 202760277ecdd8d61e411032ee051236b25fdb12 Mon Sep 17 00:00:00 2001 From: Ketan Verma Date: Tue, 2 Aug 2022 12:14:14 +0530 Subject: [PATCH 6/6] Avoid registering listeners on already completed tasks Signed-off-by: Ketan Verma --- server/src/main/java/org/opensearch/tasks/Task.java | 12 +++++++++--- .../main/java/org/opensearch/tasks/TaskManager.java | 9 ++++++++- .../cluster/node/tasks/ResourceAwareTasksTests.java | 6 +++++- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index 4f6724041902d..d052d374ef1f0 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -410,10 +410,16 @@ public TaskResult result(DiscoveryNode node, ActionResponse response) throws IOE } /** - * Registers a task resource tracking completion listener on this task. + * Registers a task resource tracking completion listener on this task if resource tracking is still active. + * Returns true on successful subscription, false otherwise. */ - public void addResourceTrackingCompletionListener(NotifyOnceListener listener) { - resourceTrackingCompletionListeners.add(listener); + public boolean addResourceTrackingCompletionListener(NotifyOnceListener listener) { + if (numActiveResourceTrackingThreads.get() > 0) { + resourceTrackingCompletionListeners.add(listener); + return true; + } + + return false; } /** diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 95f2ef0c4d731..01b6b8ea603bf 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -161,7 +161,7 @@ public Task register(String type, String action, TaskAwareRequest request) { } if (task.supportsResourceTracking()) { - task.addResourceTrackingCompletionListener(new NotifyOnceListener<>() { + boolean success = task.addResourceTrackingCompletionListener(new NotifyOnceListener<>() { @Override protected void innerOnResponse(Task task) { // Stop tracking the task once the last thread has been marked inactive. @@ -175,6 +175,13 @@ protected void innerOnFailure(Exception e) { ExceptionsHelper.reThrowIfNotNull(e); } }); + + if (success == false) { + logger.debug( + "failed to register a completion listener as task resource tracking has already completed [taskId={}]", + task.getId() + ); + } } if (task instanceof CancellableTask) { diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index c269bf6e7fa77..654d5cde7bb00 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -185,7 +185,7 @@ protected void doRun() { // operationFinishedValidator will be called just after all task threads are marked inactive and // the task is unregistered. if (taskTestContext.operationFinishedValidator != null) { - task.addResourceTrackingCompletionListener(new NotifyOnceListener<>() { + boolean success = task.addResourceTrackingCompletionListener(new NotifyOnceListener<>() { @Override protected void innerOnResponse(Task task) { taskTestContext.operationFinishedValidator.accept(task, threadId.get()); @@ -196,6 +196,10 @@ protected void innerOnFailure(Exception e) { ExceptionsHelper.reThrowIfNotNull(e); } }); + + if (success == false) { + fail("failed to register a completion listener as task resource tracking has already completed"); + } } Object[] allocation1 = new Object[1000000]; // 4MB