From 00de0d9f5ab3acf3a6077f0f46a44c4d7e5a43be Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 25 Jan 2021 19:45:19 +0200 Subject: [PATCH 1/2] [ML] Restore data counts on resuming data frame analytics Now that data frame analytics jobs can be resumed straight into the inference phase, we need to ensure data counts are persisted at the end of the analysis step and restored when the job is started again. This commit removes the need for storing the progress on start as a task parameter. Instead, when the task gets assigned we now restore all stats by making a call to the get stats API. Additionally, we now ensure that an allocated task that hasn't had its `StatsHolder` restored yet is treated as a stopped task from the get stats API, which means we will report the stored stats. Relates #67623 --- .../action/StartDataFrameAnalyticsAction.java | 34 +++++++------------ .../xpack/core/ml/MlTasksTests.java | 3 +- ...taFrameAnalyticsActionTaskParamsTests.java | 9 ----- .../DataFrameAnalyticsConfigProviderIT.java | 2 +- ...sportGetDataFrameAnalyticsStatsAction.java | 16 ++++++--- ...ransportStartDataFrameAnalyticsAction.java | 27 ++++++--------- .../dataframe/DataFrameAnalyticsManager.java | 4 +-- .../ml/dataframe/DataFrameAnalyticsTask.java | 9 +++-- .../ml/dataframe/stats/DataCountsTracker.java | 22 +++++++++++- .../xpack/ml/dataframe/stats/StatsHolder.java | 15 ++++---- .../steps/AbstractDataFrameAnalyticsStep.java | 6 +++- .../ml/dataframe/steps/AnalysisStep.java | 2 ++ .../xpack/ml/dataframe/steps/FinalStep.java | 9 +++-- .../ml/dataframe/steps/InferenceStep.java | 2 ++ ...ortStartDataFrameAnalyticsActionTests.java | 10 +++--- ...portStopDataFrameAnalyticsActionTests.java | 2 +- .../MlAutoscalingDeciderServiceTests.java | 2 +- .../DataFrameAnalyticsTaskTests.java | 9 +++-- .../inference/InferenceRunnerTests.java | 3 +- .../process/AnalyticsProcessManagerTests.java | 16 ++++++--- .../AnalyticsResultProcessorTests.java | 7 +++- .../ml/dataframe/stats/StatsHolderTests.java | 17 ++++++---- .../xpack/ml/job/JobNodeSelectorTests.java | 4 +-- .../ml/process/MlMemoryTrackerTests.java | 4 +-- 24 files changed, 136 insertions(+), 98 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java index c402d15c39aa6..702ada4d67a00 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java @@ -9,7 +9,6 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.master.MasterNodeRequest; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -30,7 +29,6 @@ import java.io.IOException; import java.util.Collections; -import java.util.List; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -147,17 +145,13 @@ public static class TaskParams implements PersistentTaskParams { public static final Version VERSION_INTRODUCED = Version.V_7_3_0; public static final Version VERSION_DESTINATION_INDEX_MAPPINGS_CHANGED = Version.V_7_10_0; - private static final ParseField PROGRESS_ON_START = new ParseField("progress_on_start"); - - @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, - a -> new TaskParams((String) a[0], (String) a[1], (List) a[2], (Boolean) a[3])); + a -> new TaskParams((String) a[0], (String) a[1], (Boolean) a[2])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.ID); PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.VERSION); - PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS_ON_START); PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DataFrameAnalyticsConfig.ALLOW_LAZY_START); } @@ -167,25 +161,24 @@ public static TaskParams fromXContent(XContentParser parser) { private final String id; private final Version version; - private final List progressOnStart; private final boolean allowLazyStart; - public TaskParams(String id, Version version, List progressOnStart, boolean allowLazyStart) { + public TaskParams(String id, Version version, boolean allowLazyStart) { this.id = Objects.requireNonNull(id); this.version = Objects.requireNonNull(version); - this.progressOnStart = Collections.unmodifiableList(progressOnStart); this.allowLazyStart = allowLazyStart; } - private TaskParams(String id, String version, @Nullable List progressOnStart, Boolean allowLazyStart) { - this(id, Version.fromString(version), progressOnStart == null ? Collections.emptyList() : progressOnStart, - allowLazyStart != null && allowLazyStart); + private TaskParams(String id, String version, Boolean allowLazyStart) { + this(id, Version.fromString(version), allowLazyStart != null && allowLazyStart); } public TaskParams(StreamInput in) throws IOException { this.id = in.readString(); this.version = Version.readVersion(in); - this.progressOnStart = in.readList(PhaseProgress::new); + if (in.getVersion().before(Version.V_8_0_0)) { + in.readList(PhaseProgress::new); + } this.allowLazyStart = in.readBoolean(); } @@ -197,10 +190,6 @@ public Version getVersion() { return version; } - public List getProgressOnStart() { - return progressOnStart; - } - public boolean isAllowLazyStart() { return allowLazyStart; } @@ -219,7 +208,10 @@ public Version getMinimalSupportedVersion() { public void writeTo(StreamOutput out) throws IOException { out.writeString(id); Version.writeVersion(version, out); - out.writeList(progressOnStart); + if (out.getVersion().before(Version.V_8_0_0)) { + // Previous versions expect a list of phase progress objects. + out.writeList(Collections.emptyList()); + } out.writeBoolean(allowLazyStart); } @@ -228,7 +220,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); builder.field(DataFrameAnalyticsConfig.VERSION.getPreferredName(), version); - builder.field(PROGRESS_ON_START.getPreferredName(), progressOnStart); builder.field(DataFrameAnalyticsConfig.ALLOW_LAZY_START.getPreferredName(), allowLazyStart); builder.endObject(); return builder; @@ -236,7 +227,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public int hashCode() { - return Objects.hash(id, version, progressOnStart, allowLazyStart); + return Objects.hash(id, version, allowLazyStart); } @Override @@ -247,7 +238,6 @@ public boolean equals(Object o) { TaskParams other = (TaskParams) o; return Objects.equals(id, other.id) && Objects.equals(version, other.version) - && Objects.equals(progressOnStart, other.progressOnStart) && Objects.equals(allowLazyStart, other.allowLazyStart); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java index 9567b67a2a807..715c898bcd4ac 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import java.net.InetAddress; -import java.util.Collections; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.empty; @@ -248,7 +247,7 @@ private static PersistentTasksCustomMetadata.PersistentTask createDataFrameAn boolean isStale) { PersistentTasksCustomMetadata.Builder builder = PersistentTasksCustomMetadata.builder(); builder.addTask(MlTasks.dataFrameAnalyticsTaskId(jobId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, - new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, Collections.emptyList(), false), + new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, false), new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment")); if (state != null) { builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(jobId), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java index 2a5dac66aaa7d..e2d2197018e28 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java @@ -9,11 +9,8 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; -import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; import static org.elasticsearch.test.VersionUtils.randomVersion; @@ -26,15 +23,9 @@ protected StartDataFrameAnalyticsAction.TaskParams doParseInstance(XContentParse @Override protected StartDataFrameAnalyticsAction.TaskParams createTestInstance() { - int phaseCount = randomIntBetween(0, 5); - List progressOnStart = new ArrayList<>(phaseCount); - for (int i = 0; i < phaseCount; i++) { - progressOnStart.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100))); - } return new StartDataFrameAnalyticsAction.TaskParams( randomAlphaOfLength(10), randomVersion(random()), - progressOnStart, randomBoolean()); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java index 62c9bfa57f967..2e117d3d4a91d 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java @@ -354,7 +354,7 @@ private static ClusterState clusterStateWithRunningAnalyticsTask(String analytic builder.addTask( MlTasks.dataFrameAnalyticsTaskId(analyticsId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, - new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, emptyList(), false), + new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, false), new PersistentTasksCustomMetadata.Assignment("node", "test assignment")); builder.updateTaskState( MlTasks.dataFrameAnalyticsTaskId(analyticsId), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index a57715aed9085..f7cc8fd8ff99e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -55,6 +55,7 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; import org.elasticsearch.xpack.ml.dataframe.StoredProgress; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.utils.persistence.MlParserUtils; import java.util.ArrayList; @@ -107,12 +108,19 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D ActionListener updateProgressListener = ActionListener.wrap( aVoid -> { + StatsHolder statsHolder = task.getStatsHolder(); + if (statsHolder == null) { + // The task has just been assigned and has not been initialized with its stats holder yet. + // We return empty result here so that we treat it as a stopped task and return its stored stats. + listener.onResponse(new QueryPage<>(Collections.emptyList(), 0, GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); + return; + } Stats stats = buildStats( task.getParams().getId(), - task.getStatsHolder().getProgressTracker().report(), - task.getStatsHolder().getDataCountsTracker().report(task.getParams().getId()), - task.getStatsHolder().getMemoryUsage(), - task.getStatsHolder().getAnalysisStats() + statsHolder.getProgressTracker().report(), + statsHolder.getDataCountsTracker().report(), + statsHolder.getMemoryUsage(), + statsHolder.getAnalysisStats() ); listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1, GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index 9f6b6894e7905..4db85ff2246d6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -79,6 +79,7 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.job.JobNodeSelector; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; @@ -185,7 +186,6 @@ public void onFailure(Exception e) { new TaskParams( request.getId(), startContext.config.getVersion(), - startContext.progressOnStart, startContext.config.isAllowLazyStart()); persistentTasksService.sendStartRequest( MlTasks.dataFrameAnalyticsTaskId(request.getId()), @@ -484,13 +484,11 @@ public void onTimeout(TimeValue timeout) { private static class StartContext { private final DataFrameAnalyticsConfig config; - private final List progressOnStart; private final DataFrameAnalyticsTask.StartingState startingState; private volatile ExtractedFields extractedFields; private StartContext(DataFrameAnalyticsConfig config, List progressOnStart) { this.config = config; - this.progressOnStart = progressOnStart; this.startingState = DataFrameAnalyticsTask.determineStartingState(config.getId(), progressOnStart); } } @@ -671,26 +669,21 @@ protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, Pe return; } - ActionListener progressListener = ActionListener.wrap( - storedProgress -> { - if (storedProgress != null) { - dfaTask.getStatsHolder().setProgressTracker(storedProgress.get()); - } + // Execute task + ActionListener statsListener = ActionListener.wrap( + statsResponse -> { + GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0); + dfaTask.setStatsHolder( + new StatsHolder(stats.getProgress(), stats.getMemoryUsage(), stats.getAnalysisStats(), stats.getDataCounts())); executeTask(dfaTask); }, dfaTask::setFailed ); + // Get stats to initialize in memory stats tracking ActionListener templateCheckListener = ActionListener.wrap( - ok -> { - if (analyticsState != DataFrameAnalyticsState.STOPPED) { - // If the state is not stopped it means the task is reassigning and - // we need to update the progress from the last stored progress doc. - searchProgressFromIndex(params.getId(), progressListener); - } else { - progressListener.onResponse(null); - } - }, + ok -> executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, + new GetDataFrameAnalyticsStatsAction.Request(params.getId()), statsListener), error -> { Throwable cause = ExceptionsHelper.unwrapCause(error); logger.error( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 648d691ed3d84..bada15f1b7127 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -178,8 +178,8 @@ private void executeStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c ActionListener stepListener = ActionListener.wrap( stepResponse -> { if (stepResponse.isTaskComplete()) { - LOGGER.info("[{}] Marking task completed", config.getId()); - task.markAsCompleted(); + // We always want to perform the final step as it tidies things up + executeStep(task, config, new FinalStep(client, task, auditor, config)); return; } switch (step.name()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java index 3d4222d730ea1..310a76da10736 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java @@ -18,6 +18,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; import org.elasticsearch.client.ParentTaskAssigningClient; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.json.JsonXContent; @@ -60,7 +61,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S private final StartDataFrameAnalyticsAction.TaskParams taskParams; private volatile boolean isStopping; private volatile boolean isMarkAsCompletedCalled; - private final StatsHolder statsHolder; + private volatile StatsHolder statsHolder; private volatile DataFrameAnalyticsStep currentStep; public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map headers, @@ -71,7 +72,6 @@ public DataFrameAnalyticsTask(long id, String type, String action, TaskId parent this.analyticsManager = Objects.requireNonNull(analyticsManager); this.auditor = Objects.requireNonNull(auditor); this.taskParams = Objects.requireNonNull(taskParams); - this.statsHolder = new StatsHolder(taskParams.getProgressOnStart()); } public void setStep(DataFrameAnalyticsStep step) { @@ -86,6 +86,11 @@ public boolean isStopping() { return isStopping; } + public void setStatsHolder(StatsHolder statsHolder) { + this.statsHolder = Objects.requireNonNull(statsHolder); + } + + @Nullable public StatsHolder getStatsHolder() { return statsHolder; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java index bed9f52b448cf..4a4d3fe728ea6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java @@ -8,12 +8,22 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; +import java.util.Objects; + public class DataCountsTracker { + private final String jobId; private volatile long trainingDocsCount; private volatile long testDocsCount; private volatile long skippedDocsCount; + public DataCountsTracker(DataCounts dataCounts) { + this.jobId = Objects.requireNonNull(dataCounts.getJobId()); + this.trainingDocsCount = dataCounts.getTrainingDocsCount(); + this.testDocsCount = dataCounts.getTestDocsCount(); + this.skippedDocsCount = dataCounts.getSkippedDocsCount(); + } + public void incrementTrainingDocsCount() { trainingDocsCount++; } @@ -26,7 +36,7 @@ public void incrementSkippedDocsCount() { skippedDocsCount++; } - public DataCounts report(String jobId) { + public DataCounts report() { return new DataCounts( jobId, trainingDocsCount, @@ -34,4 +44,14 @@ public DataCounts report(String jobId) { skippedDocsCount ); } + + public void reset() { + trainingDocsCount = 0; + testDocsCount = 0; + skippedDocsCount = 0; + } + + public void resetTestDocsCount() { + testDocsCount = 0; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java index 1b74ee6ec47d1..59ba0c51e40bd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java @@ -5,7 +5,9 @@ */ package org.elasticsearch.xpack.ml.dataframe.stats; +import org.elasticsearch.common.Nullable; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; @@ -23,15 +25,12 @@ public class StatsHolder { private final AtomicReference analysisStatsHolder; private final DataCountsTracker dataCountsTracker; - public StatsHolder(List progressOnStart) { - progressTracker = new ProgressTracker(progressOnStart); - memoryUsageHolder = new AtomicReference<>(); - analysisStatsHolder = new AtomicReference<>(); - dataCountsTracker = new DataCountsTracker(); - } - - public void setProgressTracker(List progress) { + public StatsHolder(List progress, @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, + DataCounts dataCounts) { progressTracker = new ProgressTracker(progress); + memoryUsageHolder = new AtomicReference<>(memoryUsage); + analysisStatsHolder = new AtomicReference<>(analysisStats); + dataCountsTracker = new DataCountsTracker(dataCounts); } /** diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java index db9db30c2b0d8..51dcbae69d5bd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java @@ -57,7 +57,7 @@ protected TaskId getParentTaskId() { @Override public final void execute(ActionListener listener) { logger.debug(() -> new ParameterizedMessage("[{}] Executing step [{}]", config.getId(), name())); - if (task.isStopping()) { + if (task.isStopping() && shouldSkipIfTaskIsStopping()) { logger.debug(() -> new ParameterizedMessage("[{}] task is stopping before starting [{}] step", config.getId(), name())); listener.onResponse(new StepResponse(true)); return; @@ -76,4 +76,8 @@ protected void refreshDestAsync(ActionListener refreshListener) executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, parentTaskClient, RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex()), refreshListener); } + + protected boolean shouldSkipIfTaskIsStopping() { + return true; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AnalysisStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AnalysisStep.java index 495028c5f48d6..1d0e20f546fab 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AnalysisStep.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AnalysisStep.java @@ -48,6 +48,8 @@ public void updateProgress(ActionListener listener) { @Override protected void doExecute(ActionListener listener) { + task.getStatsHolder().getDataCountsTracker().reset(); + final ParentTaskAssigningClient parentTaskClient = parentTaskClient(); // Update state to ANALYZING and start process ActionListener dataExtractorFactoryListener = ActionListener.wrap( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java index 6bffefde369ea..827a46ae80345 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java @@ -60,7 +60,7 @@ public Name name() { protected void doExecute(ActionListener listener) { ActionListener refreshListener = ActionListener.wrap( - refreshResponse -> listener.onResponse(new StepResponse(true)), + refreshResponse -> listener.onResponse(new StepResponse(false)), listener::onFailure ); @@ -73,7 +73,7 @@ protected void doExecute(ActionListener listener) { } private void indexDataCounts(ActionListener listener) { - DataCounts dataCounts = task.getStatsHolder().getDataCountsTracker().report(config.getId()); + DataCounts dataCounts = task.getStatsHolder().getDataCountsTracker().report(); try (XContentBuilder builder = XContentFactory.jsonBuilder()) { dataCounts.toXContent(builder, new ToXContent.MapParams( Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); @@ -111,4 +111,9 @@ public void updateProgress(ActionListener listener) { // No progress to update listener.onResponse(null); } + + @Override + protected boolean shouldSkipIfTaskIsStopping() { + return false; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java index 381f6090c907a..84af2fe107442 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java @@ -62,6 +62,8 @@ protected void doExecute(ActionListener listener) { return; } + task.getStatsHolder().getDataCountsTracker().resetTestDocsCount(); + ActionListener modelIdListener = ActionListener.wrap( modelId -> runInference(modelId, listener), listener::onFailure diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java index f49eb74ace686..f112bf52ea555 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java @@ -50,7 +50,7 @@ public class TransportStartDataFrameAnalyticsActionTests extends ESTestCase { // Cannot assign the node because upgrade mode is enabled public void testGetAssignment_UpgradeModeIsEnabled() { TaskExecutor executor = createTaskExecutor(); - TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, Collections.emptyList(), false); + TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, false); ClusterState clusterState = ClusterState.builder(new ClusterName("_name")) .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(true).build())) @@ -64,7 +64,7 @@ public void testGetAssignment_UpgradeModeIsEnabled() { // Cannot assign the node because there are no existing nodes in the cluster state public void testGetAssignment_NoNodes() { TaskExecutor executor = createTaskExecutor(); - TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, Collections.emptyList(), false); + TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, false); ClusterState clusterState = ClusterState.builder(new ClusterName("_name")) .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build())) @@ -78,7 +78,7 @@ public void testGetAssignment_NoNodes() { // Cannot assign the node because none of the existing nodes is an ML node public void testGetAssignment_NoMlNodes() { TaskExecutor executor = createTaskExecutor(); - TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, Collections.emptyList(), false); + TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, false); ClusterState clusterState = ClusterState.builder(new ClusterName("_name")) .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build())) @@ -104,7 +104,7 @@ public void testGetAssignment_NoMlNodes() { // - _node_name2 is too old (version 7.9.2) public void testGetAssignment_MlNodesAreTooOld() { TaskExecutor executor = createTaskExecutor(); - TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, Collections.emptyList(), false); + TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, false); ClusterState clusterState = ClusterState.builder(new ClusterName("_name")) .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build())) @@ -131,7 +131,7 @@ public void testGetAssignment_MlNodesAreTooOld() { // In such a case destination index will be created from scratch so that its mappings are up-to-date. public void testGetAssignment_MlNodeIsNewerThanTheMlJobButTheAssignmentSuceeds() { TaskExecutor executor = createTaskExecutor(); - TaskParams params = new TaskParams(JOB_ID, Version.V_7_9_0, Collections.emptyList(), false); + TaskParams params = new TaskParams(JOB_ID, Version.V_7_9_0, false); ClusterState clusterState = ClusterState.builder(new ClusterName("_name")) .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build())) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsActionTests.java index 33a1581d73fca..0476844e8cf63 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsActionTests.java @@ -62,7 +62,7 @@ private static void addAnalyticsTask(PersistentTasksCustomMetadata.Builder build private static void addAnalyticsTask(PersistentTasksCustomMetadata.Builder builder, String analyticsId, String nodeId, DataFrameAnalyticsState state, boolean allowLazyStart) { builder.addTask(MlTasks.dataFrameAnalyticsTaskId(analyticsId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, - new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, Collections.emptyList(), allowLazyStart), + new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, allowLazyStart), new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment")); if (state != null) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderServiceTests.java index 6e20689ecaca3..71ed53f9bd371 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderServiceTests.java @@ -489,7 +489,7 @@ public static void addAnalyticsTask(String jobId, builder.addTask( MlTasks.dataFrameAnalyticsTaskId(jobId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, - new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, Collections.emptyList(), true), + new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, true), nodeId == null ? AWAITING_LAZY_ASSIGNMENT : new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment") ); if (jobState != null) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java index aae85f27e65ef..fcd3002a8be9f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java @@ -31,10 +31,12 @@ import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; +import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.dataframe.steps.DataFrameAnalyticsStep; import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; @@ -163,7 +165,7 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0)); StartDataFrameAnalyticsAction.TaskParams taskParams = new StartDataFrameAnalyticsAction.TaskParams( - "task_id", Version.CURRENT, progress, false); + "task_id", Version.CURRENT, false); SearchResponse searchResponse = mock(SearchResponse.class); when(searchResponse.getHits()).thenReturn(searchHits); @@ -180,6 +182,7 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl new DataFrameAnalyticsTask( 123, "type", "action", null, Map.of(), client, analyticsManager, auditor, taskParams); task.init(persistentTasksService, taskManager, "task-id", 42); + task.setStatsHolder(new StatsHolder(progress, null, null, new DataCounts("test_job"))); task.persistProgress(client, "task_id", runnable); @@ -243,7 +246,6 @@ private void testSetFailed(boolean nodeShuttingDown) throws IOException { new StartDataFrameAnalyticsAction.TaskParams( "job-id", Version.CURRENT, - progress, false); SearchResponse searchResponse = mock(SearchResponse.class); @@ -257,6 +259,7 @@ private void testSetFailed(boolean nodeShuttingDown) throws IOException { new DataFrameAnalyticsTask( 123, "type", "action", null, Map.of(), client, analyticsManager, auditor, taskParams); task.init(persistentTasksService, taskManager, "task-id", 42); + task.setStatsHolder(new StatsHolder(progress, null, null, new DataCounts("test_job"))); task.setStep(new StubReindexingStep(task.getStatsHolder().getProgressTracker())); Exception exception = new Exception("some exception"); @@ -301,7 +304,7 @@ private static Answer withResponse(Response response) { }; } - private class StubReindexingStep implements DataFrameAnalyticsStep { + private static class StubReindexingStep implements DataFrameAnalyticsStep { private final ProgressTracker progressTracker; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java index ae21775da7a4e..ca6cbab18e51a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; @@ -178,6 +179,6 @@ private LocalModel localModelInferences(InferenceResults first, InferenceResults private InferenceRunner createInferenceRunner(ExtractedFields extractedFields) { return new InferenceRunner(Settings.EMPTY, client, modelLoadingService, resultsPersisterService, parentTaskId, config, - extractedFields, progressTracker, new DataCountsTracker()); + extractedFields, progressTracker, new DataCountsTracker(new DataCounts(config.getId()))); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index 77105f213ea26..7f49a32bcff11 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; @@ -97,8 +98,7 @@ public void setUpMocks() { task = mock(DataFrameAnalyticsTask.class); when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID); - when(task.getStatsHolder()).thenReturn(new StatsHolder( - ProgressTracker.fromZeroes(Collections.singletonList("analyzing"), false).report())); + when(task.getStatsHolder()).thenReturn(newStatsHolder()); when(task.getParentTaskId()).thenReturn(new TaskId("")); dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID, false, @@ -117,10 +117,16 @@ public void setUpMocks() { processFactory, auditor, trainedModelProvider, resultsPersisterService, 1); } + private StatsHolder newStatsHolder() { + return new StatsHolder(ProgressTracker.fromZeroes(Collections.singletonList("analyzing"), false).report(), + null, + null, + new DataCounts(CONFIG_ID)); + } + public void testRunJob_TaskIsStopping() { when(task.isStopping()).thenReturn(true); - when(task.getParams()).thenReturn( - new StartDataFrameAnalyticsAction.TaskParams("data_frame_id", Version.CURRENT, Collections.emptyList(), false)); + when(task.getParams()).thenReturn(new StartDataFrameAnalyticsAction.TaskParams("data_frame_id", Version.CURRENT, false)); processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, ActionListener.wrap( stepResponse -> { @@ -209,7 +215,7 @@ public void testRunJob_Ok() { public void testRunJob_ProcessNotAliveAfterStart() { when(process.isProcessAlive()).thenReturn(false); when(task.getParams()).thenReturn( - new StartDataFrameAnalyticsAction.TaskParams("data_frame_id", Version.CURRENT, Collections.emptyList(), false)); + new StartDataFrameAnalyticsAction.TaskParams("data_frame_id", Version.CURRENT, false)); processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, ActionListener.wrap( stepResponse -> fail("Expected error but listener got a response instead"), diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index a77637bc59bd1..2f83fb4a166dd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; @@ -59,7 +60,11 @@ public class AnalyticsResultProcessorTests extends ESTestCase { private AnalyticsProcess process; private DataFrameRowsJoiner dataFrameRowsJoiner; - private StatsHolder statsHolder = new StatsHolder(ProgressTracker.fromZeroes(Collections.singletonList("analyzing"), false).report()); + private StatsHolder statsHolder = new StatsHolder( + ProgressTracker.fromZeroes(Collections.singletonList("analyzing"), false).report(), + null, + null, + new DataCounts(JOB_ID)); private TrainedModelProvider trainedModelProvider; private DataFrameAnalyticsAuditor auditor; private StatsPersister statsPersister; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java index 7a5cf841d582f..2a2721a835102 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.dataframe.stats; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import java.util.Arrays; @@ -29,7 +30,7 @@ public void testAdjustProgressTracker_GivenZeroProgress() { new PhaseProgress("writing_results", 0) ) ); - StatsHolder statsHolder = new StatsHolder(phases); + StatsHolder statsHolder = newStatsHolder(phases); statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), false); @@ -55,7 +56,7 @@ public void testAdjustProgressTracker_GivenSameAnalysisPhases() { new PhaseProgress("writing_results", 50) ) ); - StatsHolder statsHolder = new StatsHolder(phases); + StatsHolder statsHolder = newStatsHolder(phases); statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), false); @@ -81,7 +82,7 @@ public void testAdjustProgressTracker_GivenDifferentAnalysisPhases() { new PhaseProgress("writing_results", 50) ) ); - StatsHolder statsHolder = new StatsHolder(phases); + StatsHolder statsHolder = newStatsHolder(phases); statsHolder.adjustProgressTracker(Arrays.asList("c", "d"), false); @@ -107,7 +108,7 @@ public void testAdjustProgressTracker_GivenReindexingProgressIncomplete() { new PhaseProgress("writing_results", 50) ) ); - StatsHolder statsHolder = new StatsHolder(phases); + StatsHolder statsHolder = newStatsHolder(phases); statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), false); @@ -133,7 +134,7 @@ public void testAdjustProgressTracker_GivenAllPhasesCompleteExceptInference() { new PhaseProgress("inference", 20) ) ); - StatsHolder statsHolder = new StatsHolder(phases); + StatsHolder statsHolder = newStatsHolder(phases); statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), true); @@ -159,7 +160,7 @@ public void testResetProgressTracker() { new PhaseProgress("writing_results", 50) ) ); - StatsHolder statsHolder = new StatsHolder(phases); + StatsHolder statsHolder = newStatsHolder(phases); statsHolder.resetProgressTracker(Arrays.asList("a", "b"), false); @@ -174,4 +175,8 @@ public void testResetProgressTracker() { assertThat(phaseProgresses.get(3).getProgressPercent(), equalTo(0)); assertThat(phaseProgresses.get(4).getProgressPercent(), equalTo(0)); } + + private static StatsHolder newStatsHolder(List progress) { + return new StatsHolder(progress, null, null, new DataCounts("test_job")); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java index f5d82171289ce..4d831d714e939 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java @@ -819,7 +819,7 @@ static void addDataFrameAnalyticsJobTask(String id, String nodeId, DataFrameAnal static void addDataFrameAnalyticsJobTask(String id, String nodeId, DataFrameAnalyticsState state, PersistentTasksCustomMetadata.Builder builder, boolean isStale, boolean allowLazyStart) { builder.addTask(MlTasks.dataFrameAnalyticsTaskId(id), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, - new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT, Collections.emptyList(), allowLazyStart), + new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT, allowLazyStart), new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment")); if (state != null) { builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(id), @@ -828,6 +828,6 @@ static void addDataFrameAnalyticsJobTask(String id, String nodeId, DataFrameAnal } private static TaskParams createTaskParams(String id) { - return new TaskParams(id, Version.CURRENT, Collections.emptyList(), false); + return new TaskParams(id, Version.CURRENT, false); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java index 4cb597a5d8b86..e9d3846f7090c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java @@ -267,7 +267,7 @@ private PersistentTasksCustomMetadata.PersistentTask ma PersistentTasksCustomMetadata.PersistentTask makeTestDataFrameAnalyticsTask(String id, boolean allowLazyStart) { return new PersistentTasksCustomMetadata.PersistentTask<>(MlTasks.dataFrameAnalyticsTaskId(id), - MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT, - Collections.emptyList(), allowLazyStart), 0, PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT); + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT, allowLazyStart), + 0, PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT); } } From 4a721fdc5d0ec0d07aa84b9e7381fdb41b78615e Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 25 Jan 2021 21:14:09 +0200 Subject: [PATCH 2/2] Add unit test for `DataCountsTracker` --- .../stats/DataCountsTrackerTests.java | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTrackerTests.java diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTrackerTests.java new file mode 100644 index 0000000000000..6c73bc774f68c --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTrackerTests.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.stats; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; + +import static org.hamcrest.Matchers.equalTo; + +public class DataCountsTrackerTests extends ESTestCase { + + private static final String JOB_ID = "test"; + + public void testReset() { + DataCountsTracker dataCountsTracker = new DataCountsTracker(new DataCounts(JOB_ID, 10, 20, 30)); + dataCountsTracker.reset(); + DataCounts resetDataCounts = dataCountsTracker.report(); + assertThat(resetDataCounts, equalTo(new DataCounts(JOB_ID))); + } + + public void testResetTestDocsCount() { + DataCountsTracker dataCountsTracker = new DataCountsTracker(new DataCounts(JOB_ID, 10, 20, 30)); + dataCountsTracker.resetTestDocsCount(); + DataCounts resetDataCounts = dataCountsTracker.report(); + assertThat(resetDataCounts, equalTo(new DataCounts(JOB_ID, 10, 0, 30))); + } +}