Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Restore data counts on resuming data frame analytics #67937

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
Expand All @@ -30,7 +29,6 @@

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -147,17 +145,13 @@ public static class TaskParams implements PersistentTaskParams {
public static final Version VERSION_INTRODUCED = Version.V_7_3_0;
public static final Version VERSION_DESTINATION_INDEX_MAPPINGS_CHANGED = Version.V_7_10_0;

private static final ParseField PROGRESS_ON_START = new ParseField("progress_on_start");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true,
a -> new TaskParams((String) a[0], (String) a[1], (List<PhaseProgress>) a[2], (Boolean) a[3]));
a -> new TaskParams((String) a[0], (String) a[1], (Boolean) a[2]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.ID);
PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.VERSION);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS_ON_START);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DataFrameAnalyticsConfig.ALLOW_LAZY_START);
}

Expand All @@ -167,25 +161,24 @@ public static TaskParams fromXContent(XContentParser parser) {

private final String id;
private final Version version;
private final List<PhaseProgress> progressOnStart;
private final boolean allowLazyStart;

public TaskParams(String id, Version version, List<PhaseProgress> progressOnStart, boolean allowLazyStart) {
public TaskParams(String id, Version version, boolean allowLazyStart) {
this.id = Objects.requireNonNull(id);
this.version = Objects.requireNonNull(version);
this.progressOnStart = Collections.unmodifiableList(progressOnStart);
this.allowLazyStart = allowLazyStart;
}

private TaskParams(String id, String version, @Nullable List<PhaseProgress> progressOnStart, Boolean allowLazyStart) {
this(id, Version.fromString(version), progressOnStart == null ? Collections.emptyList() : progressOnStart,
allowLazyStart != null && allowLazyStart);
/**
 * Constructor taking the raw {@code (String, String, Boolean)} values, i.e. the same
 * argument shape the {@code PARSER} lambda passes.
 * The version string is converted with {@code Version.fromString} and a {@code null}
 * {@code allowLazyStart} is treated as {@code false} before delegating to the
 * typed constructor.
 */
private TaskParams(String id, String version, Boolean allowLazyStart) {
this(id, Version.fromString(version), allowLazyStart != null && allowLazyStart);
}

public TaskParams(StreamInput in) throws IOException {
this.id = in.readString();
this.version = Version.readVersion(in);
this.progressOnStart = in.readList(PhaseProgress::new);
if (in.getVersion().before(Version.V_8_0_0)) {
in.readList(PhaseProgress::new);
}
this.allowLazyStart = in.readBoolean();
}

Expand All @@ -197,10 +190,6 @@ public Version getVersion() {
return version;
}

public List<PhaseProgress> getProgressOnStart() {
return progressOnStart;
}

/** Returns the {@code allowLazyStart} flag stored in these task params. */
public boolean isAllowLazyStart() {
return allowLazyStart;
}
Expand All @@ -219,7 +208,10 @@ public Version getMinimalSupportedVersion() {
public void writeTo(StreamOutput out) throws IOException {
out.writeString(id);
Version.writeVersion(version, out);
out.writeList(progressOnStart);
if (out.getVersion().before(Version.V_8_0_0)) {
// Previous versions expect a list of phase progress objects.
out.writeList(Collections.emptyList());
}
out.writeBoolean(allowLazyStart);
}

Expand All @@ -228,15 +220,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id);
builder.field(DataFrameAnalyticsConfig.VERSION.getPreferredName(), version);
builder.field(PROGRESS_ON_START.getPreferredName(), progressOnStart);
builder.field(DataFrameAnalyticsConfig.ALLOW_LAZY_START.getPreferredName(), allowLazyStart);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(id, version, progressOnStart, allowLazyStart);
return Objects.hash(id, version, allowLazyStart);
}

@Override
Expand All @@ -247,7 +238,6 @@ public boolean equals(Object o) {
TaskParams other = (TaskParams) o;
return Objects.equals(id, other.id)
&& Objects.equals(version, other.version)
&& Objects.equals(progressOnStart, other.progressOnStart)
&& Objects.equals(allowLazyStart, other.allowLazyStart);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;

import java.net.InetAddress;
import java.util.Collections;

import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
Expand Down Expand Up @@ -248,7 +247,7 @@ private static PersistentTasksCustomMetadata.PersistentTask<?> createDataFrameAn
boolean isStale) {
PersistentTasksCustomMetadata.Builder builder = PersistentTasksCustomMetadata.builder();
builder.addTask(MlTasks.dataFrameAnalyticsTaskId(jobId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, Collections.emptyList(), false),
new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, false),
new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment"));
if (state != null) {
builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(jobId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.test.VersionUtils.randomVersion;

Expand All @@ -26,15 +23,9 @@ protected StartDataFrameAnalyticsAction.TaskParams doParseInstance(XContentParse

@Override
protected StartDataFrameAnalyticsAction.TaskParams createTestInstance() {
int phaseCount = randomIntBetween(0, 5);
List<PhaseProgress> progressOnStart = new ArrayList<>(phaseCount);
for (int i = 0; i < phaseCount; i++) {
progressOnStart.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)));
}
return new StartDataFrameAnalyticsAction.TaskParams(
randomAlphaOfLength(10),
randomVersion(random()),
progressOnStart,
randomBoolean());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ private static ClusterState clusterStateWithRunningAnalyticsTask(String analytic
builder.addTask(
MlTasks.dataFrameAnalyticsTaskId(analyticsId),
MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, emptyList(), false),
new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, false),
new PersistentTasksCustomMetadata.Assignment("node", "test assignment"));
builder.updateTaskState(
MlTasks.dataFrameAnalyticsTaskId(analyticsId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.utils.persistence.MlParserUtils;

import java.util.ArrayList;
Expand Down Expand Up @@ -107,12 +108,19 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D

ActionListener<Void> updateProgressListener = ActionListener.wrap(
aVoid -> {
StatsHolder statsHolder = task.getStatsHolder();
if (statsHolder == null) {
// The task has just been assigned and has not been initialized with its stats holder yet.
// We return an empty result here so that the task is treated as stopped and its stored stats are returned.
listener.onResponse(new QueryPage<>(Collections.emptyList(), 0, GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
return;
}
Stats stats = buildStats(
task.getParams().getId(),
task.getStatsHolder().getProgressTracker().report(),
task.getStatsHolder().getDataCountsTracker().report(task.getParams().getId()),
task.getStatsHolder().getMemoryUsage(),
task.getStatsHolder().getAnalysisStats()
statsHolder.getProgressTracker().report(),
statsHolder.getDataCountsTracker().report(),
statsHolder.getMemoryUsage(),
statsHolder.getAnalysisStats()
);
listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1,
GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.job.JobNodeSelector;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
Expand Down Expand Up @@ -185,7 +186,6 @@ public void onFailure(Exception e) {
new TaskParams(
request.getId(),
startContext.config.getVersion(),
startContext.progressOnStart,
startContext.config.isAllowLazyStart());
persistentTasksService.sendStartRequest(
MlTasks.dataFrameAnalyticsTaskId(request.getId()),
Expand Down Expand Up @@ -484,13 +484,11 @@ public void onTimeout(TimeValue timeout) {

private static class StartContext {
private final DataFrameAnalyticsConfig config;
private final List<PhaseProgress> progressOnStart;
private final DataFrameAnalyticsTask.StartingState startingState;
private volatile ExtractedFields extractedFields;

private StartContext(DataFrameAnalyticsConfig config, List<PhaseProgress> progressOnStart) {
this.config = config;
this.progressOnStart = progressOnStart;
this.startingState = DataFrameAnalyticsTask.determineStartingState(config.getId(), progressOnStart);
}
}
Expand Down Expand Up @@ -671,26 +669,21 @@ protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, Pe
return;
}

ActionListener<StoredProgress> progressListener = ActionListener.wrap(
storedProgress -> {
if (storedProgress != null) {
dfaTask.getStatsHolder().setProgressTracker(storedProgress.get());
}
// Execute task
ActionListener<GetDataFrameAnalyticsStatsAction.Response> statsListener = ActionListener.wrap(
statsResponse -> {
GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0);
dfaTask.setStatsHolder(
new StatsHolder(stats.getProgress(), stats.getMemoryUsage(), stats.getAnalysisStats(), stats.getDataCounts()));
executeTask(dfaTask);
},
dfaTask::setFailed
);

// Get stats to initialize in memory stats tracking
ActionListener<Boolean> templateCheckListener = ActionListener.wrap(
ok -> {
if (analyticsState != DataFrameAnalyticsState.STOPPED) {
// If the state is not stopped it means the task is reassigning and
// we need to update the progress from the last stored progress doc.
searchProgressFromIndex(params.getId(), progressListener);
} else {
progressListener.onResponse(null);
}
},
ok -> executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE,
new GetDataFrameAnalyticsStatsAction.Request(params.getId()), statsListener),
error -> {
Throwable cause = ExceptionsHelper.unwrapCause(error);
logger.error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ private void executeStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c
ActionListener<StepResponse> stepListener = ActionListener.wrap(
stepResponse -> {
if (stepResponse.isTaskComplete()) {
LOGGER.info("[{}] Marking task completed", config.getId());
task.markAsCompleted();
// We always want to perform the final step as it tidies things up
executeStep(task, config, new FinalStep(client, task, auditor, config));
return;
}
switch (step.name()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
Expand Down Expand Up @@ -60,7 +61,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
private final StartDataFrameAnalyticsAction.TaskParams taskParams;
private volatile boolean isStopping;
private volatile boolean isMarkAsCompletedCalled;
private final StatsHolder statsHolder;
private volatile StatsHolder statsHolder;
private volatile DataFrameAnalyticsStep currentStep;

public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map<String, String> headers,
Expand All @@ -71,7 +72,6 @@ public DataFrameAnalyticsTask(long id, String type, String action, TaskId parent
this.analyticsManager = Objects.requireNonNull(analyticsManager);
this.auditor = Objects.requireNonNull(auditor);
this.taskParams = Objects.requireNonNull(taskParams);
this.statsHolder = new StatsHolder(taskParams.getProgressOnStart());
}

public void setStep(DataFrameAnalyticsStep step) {
Expand All @@ -86,6 +86,11 @@ public boolean isStopping() {
return isStopping;
}

/**
 * Installs the stats holder for this task.
 *
 * @param statsHolder the holder to use; must not be {@code null}
 * @throws NullPointerException if {@code statsHolder} is {@code null}
 */
public void setStatsHolder(StatsHolder statsHolder) {
this.statsHolder = Objects.requireNonNull(statsHolder);
}

/**
 * Returns the stats holder currently held by this task.
 *
 * @return the current holder; annotated {@code @Nullable}, presumably because it can be
 *         read before {@code setStatsHolder} has been called, so callers should null-check
 */
@Nullable
public StatsHolder getStatsHolder() {
return statsHolder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,22 @@

import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;

import java.util.Objects;

public class DataCountsTracker {

private final String jobId;
private volatile long trainingDocsCount;
private volatile long testDocsCount;
private volatile long skippedDocsCount;

/**
 * Creates a tracker seeded with the job id and the training, test and skipped
 * document counts read from {@code dataCounts}.
 *
 * @param dataCounts the counts to start from; its job id must not be {@code null}
 * @throws NullPointerException if {@code dataCounts} or its job id is {@code null}
 */
public DataCountsTracker(DataCounts dataCounts) {
this.jobId = Objects.requireNonNull(dataCounts.getJobId());
this.trainingDocsCount = dataCounts.getTrainingDocsCount();
this.testDocsCount = dataCounts.getTestDocsCount();
this.skippedDocsCount = dataCounts.getSkippedDocsCount();
}

/** Adds one to the training docs count. */
public void incrementTrainingDocsCount() {
// NOTE(review): the field is volatile, but ++ is a read-modify-write, so concurrent
// callers could lose an increment — confirm this is only ever called from one thread.
trainingDocsCount++;
}
Expand All @@ -26,12 +36,22 @@ public void incrementSkippedDocsCount() {
skippedDocsCount++;
}

public DataCounts report(String jobId) {
public DataCounts report() {
return new DataCounts(
jobId,
trainingDocsCount,
testDocsCount,
skippedDocsCount
);
}

/** Sets the training, test and skipped docs counts back to zero; the job id is left unchanged. */
public void reset() {
trainingDocsCount = 0;
testDocsCount = 0;
skippedDocsCount = 0;
}

/** Sets only the test docs count back to zero; the training and skipped counts are untouched. */
public void resetTestDocsCount() {
testDocsCount = 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
*/
package org.elasticsearch.xpack.ml.dataframe.stats;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;

Expand All @@ -23,15 +25,12 @@ public class StatsHolder {
private final AtomicReference<AnalysisStats> analysisStatsHolder;
private final DataCountsTracker dataCountsTracker;

public StatsHolder(List<PhaseProgress> progressOnStart) {
progressTracker = new ProgressTracker(progressOnStart);
memoryUsageHolder = new AtomicReference<>();
analysisStatsHolder = new AtomicReference<>();
dataCountsTracker = new DataCountsTracker();
}

public void setProgressTracker(List<PhaseProgress> progress) {
/**
 * Creates a stats holder initialized from previously known stats.
 *
 * @param progress      phase progress used to build the {@code ProgressTracker}
 * @param memoryUsage   initial memory usage, may be {@code null}; stored in an {@code AtomicReference}
 * @param analysisStats initial analysis stats, may be {@code null}; stored in an {@code AtomicReference}
 * @param dataCounts    initial data counts handed to a new {@code DataCountsTracker}
 *                      (not annotated {@code @Nullable}, unlike the two parameters above)
 */
public StatsHolder(List<PhaseProgress> progress, @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats,
DataCounts dataCounts) {
progressTracker = new ProgressTracker(progress);
memoryUsageHolder = new AtomicReference<>(memoryUsage);
analysisStatsHolder = new AtomicReference<>(analysisStats);
dataCountsTracker = new DataCountsTracker(dataCounts);
}

/**
Expand Down
Loading