Skip to content

Commit

Permalink
[7.x][ML] Restore data counts on resuming data frame analytics (#67937)…
Browse files Browse the repository at this point in the history
… (#67979)

Now that data frame analytics jobs can be resumed straight into
the inference phase, we need to ensure data counts are persisted
at the end of the analysis step and restored when the job is
started again.

This commit removes the need for storing the progress on start
as a task parameter. Instead, when the task gets assigned we now
restore all stats by making a call to the get stats API. Additionally,
we now ensure that an allocated task that hasn't had its `StatsHolder`
restored yet is treated as a stopped task by the get stats API, which
means we will report the stored stats.

Relates #67623

Backport of #67937
  • Loading branch information
dimitris-athanasiou authored Jan 26, 2021
1 parent feab69b commit 3bb6e7d
Show file tree
Hide file tree
Showing 25 changed files with 165 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.client.ElasticsearchClient;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
Expand All @@ -32,7 +31,6 @@

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -156,17 +154,13 @@ public static class TaskParams implements XPackPlugin.XPackPersistentTaskParams
public static final Version VERSION_INTRODUCED = Version.V_7_3_0;
public static final Version VERSION_DESTINATION_INDEX_MAPPINGS_CHANGED = Version.V_7_10_0;

private static final ParseField PROGRESS_ON_START = new ParseField("progress_on_start");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true,
a -> new TaskParams((String) a[0], (String) a[1], (List<PhaseProgress>) a[2], (Boolean) a[3]));
a -> new TaskParams((String) a[0], (String) a[1], (Boolean) a[2]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.ID);
PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.VERSION);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS_ON_START);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DataFrameAnalyticsConfig.ALLOW_LAZY_START);
}

Expand All @@ -176,28 +170,23 @@ public static TaskParams fromXContent(XContentParser parser) {

private final String id;
private final Version version;
private final List<PhaseProgress> progressOnStart;
private final boolean allowLazyStart;

public TaskParams(String id, Version version, List<PhaseProgress> progressOnStart, boolean allowLazyStart) {
public TaskParams(String id, Version version, boolean allowLazyStart) {
this.id = Objects.requireNonNull(id);
this.version = Objects.requireNonNull(version);
this.progressOnStart = Collections.unmodifiableList(progressOnStart);
this.allowLazyStart = allowLazyStart;
}

private TaskParams(String id, String version, @Nullable List<PhaseProgress> progressOnStart, Boolean allowLazyStart) {
this(id, Version.fromString(version), progressOnStart == null ? Collections.emptyList() : progressOnStart,
allowLazyStart != null && allowLazyStart);
/**
 * Parser-facing constructor: accepts the raw values produced by {@code PARSER}
 * and delegates to the public constructor.
 *
 * @param id             the analytics job id
 * @param version        the config version in its string form
 * @param allowLazyStart optional flag; an absent ({@code null}) value is treated as {@code false}
 */
private TaskParams(String id, String version, Boolean allowLazyStart) {
    this(id, Version.fromString(version), Boolean.TRUE.equals(allowLazyStart));
}

public TaskParams(StreamInput in) throws IOException {
this.id = in.readString();
this.version = Version.readVersion(in);
if (in.getVersion().onOrAfter(Version.V_7_5_0)) {
this.progressOnStart = in.readList(PhaseProgress::new);
} else {
this.progressOnStart = Collections.emptyList();
if (in.getVersion().onOrAfter(Version.V_7_5_0) && in.getVersion().before(Version.V_7_12_0)) {
in.readList(PhaseProgress::new);
}
if (in.getVersion().onOrAfter(Version.V_7_5_0)) {
this.allowLazyStart = in.readBoolean();
Expand All @@ -214,10 +203,6 @@ public Version getVersion() {
return version;
}

public List<PhaseProgress> getProgressOnStart() {
return progressOnStart;
}

/**
 * @return the {@code allowLazyStart} flag held by these task params
 */
public boolean isAllowLazyStart() {
return allowLazyStart;
}
Expand All @@ -236,8 +221,9 @@ public Version getMinimalSupportedVersion() {
public void writeTo(StreamOutput out) throws IOException {
out.writeString(id);
Version.writeVersion(version, out);
if (out.getVersion().onOrAfter(Version.V_7_5_0)) {
out.writeList(progressOnStart);
if (out.getVersion().onOrAfter(Version.V_7_5_0) && out.getVersion().before(Version.V_7_12_0)) {
// Previous versions expect a list of phase progress objects.
out.writeList(Collections.emptyList());
}
if (out.getVersion().onOrAfter(Version.V_7_5_0)) {
out.writeBoolean(allowLazyStart);
Expand All @@ -249,15 +235,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id);
builder.field(DataFrameAnalyticsConfig.VERSION.getPreferredName(), version);
builder.field(PROGRESS_ON_START.getPreferredName(), progressOnStart);
builder.field(DataFrameAnalyticsConfig.ALLOW_LAZY_START.getPreferredName(), allowLazyStart);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(id, version, progressOnStart, allowLazyStart);
return Objects.hash(id, version, allowLazyStart);
}

@Override
Expand All @@ -268,7 +253,6 @@ public boolean equals(Object o) {
TaskParams other = (TaskParams) o;
return Objects.equals(id, other.id)
&& Objects.equals(version, other.version)
&& Objects.equals(progressOnStart, other.progressOnStart)
&& Objects.equals(allowLazyStart, other.allowLazyStart);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;

import java.net.InetAddress;
import java.util.Collections;

import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
Expand Down Expand Up @@ -248,7 +247,7 @@ private static PersistentTasksCustomMetadata.PersistentTask<?> createDataFrameAn
boolean isStale) {
PersistentTasksCustomMetadata.Builder builder = PersistentTasksCustomMetadata.builder();
builder.addTask(MlTasks.dataFrameAnalyticsTaskId(jobId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, Collections.emptyList(), false),
new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, false),
new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment"));
if (state != null) {
builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(jobId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.test.VersionUtils.randomVersion;

Expand All @@ -26,15 +23,9 @@ protected StartDataFrameAnalyticsAction.TaskParams doParseInstance(XContentParse

@Override
protected StartDataFrameAnalyticsAction.TaskParams createTestInstance() {
int phaseCount = randomIntBetween(0, 5);
List<PhaseProgress> progressOnStart = new ArrayList<>(phaseCount);
for (int i = 0; i < phaseCount; i++) {
progressOnStart.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)));
}
return new StartDataFrameAnalyticsAction.TaskParams(
randomAlphaOfLength(10),
randomVersion(random()),
progressOnStart,
randomBoolean());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -354,7 +353,7 @@ private static ClusterState clusterStateWithRunningAnalyticsTask(String analytic
builder.addTask(
MlTasks.dataFrameAnalyticsTaskId(analyticsId),
MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, emptyList(), false),
new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, false),
new PersistentTasksCustomMetadata.Assignment("node", "test assignment"));
builder.updateTaskState(
MlTasks.dataFrameAnalyticsTaskId(analyticsId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.utils.persistence.MlParserUtils;

import java.util.ArrayList;
Expand Down Expand Up @@ -107,12 +108,19 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D

ActionListener<Void> updateProgressListener = ActionListener.wrap(
aVoid -> {
StatsHolder statsHolder = task.getStatsHolder();
if (statsHolder == null) {
// The task has just been assigned and has not been initialized with its stats holder yet.
// We return empty result here so that we treat it as a stopped task and return its stored stats.
listener.onResponse(new QueryPage<>(Collections.emptyList(), 0, GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
return;
}
Stats stats = buildStats(
task.getParams().getId(),
task.getStatsHolder().getProgressTracker().report(),
task.getStatsHolder().getDataCountsTracker().report(task.getParams().getId()),
task.getStatsHolder().getMemoryUsage(),
task.getStatsHolder().getAnalysisStats()
statsHolder.getProgressTracker().report(),
statsHolder.getDataCountsTracker().report(),
statsHolder.getMemoryUsage(),
statsHolder.getAnalysisStats()
);
listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1,
GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.job.JobNodeSelector;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
Expand Down Expand Up @@ -182,7 +183,6 @@ public void onFailure(Exception e) {
new TaskParams(
request.getId(),
startContext.config.getVersion(),
startContext.progressOnStart,
startContext.config.isAllowLazyStart());
persistentTasksService.sendStartRequest(
MlTasks.dataFrameAnalyticsTaskId(request.getId()),
Expand Down Expand Up @@ -479,13 +479,11 @@ public void onTimeout(TimeValue timeout) {

private static class StartContext {
private final DataFrameAnalyticsConfig config;
private final List<PhaseProgress> progressOnStart;
private final DataFrameAnalyticsTask.StartingState startingState;
private volatile ExtractedFields extractedFields;

private StartContext(DataFrameAnalyticsConfig config, List<PhaseProgress> progressOnStart) {
this.config = config;
this.progressOnStart = progressOnStart;
this.startingState = DataFrameAnalyticsTask.determineStartingState(config.getId(), progressOnStart);
}
}
Expand Down Expand Up @@ -666,26 +664,21 @@ protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, Pe
return;
}

ActionListener<StoredProgress> progressListener = ActionListener.wrap(
storedProgress -> {
if (storedProgress != null) {
dfaTask.getStatsHolder().setProgressTracker(storedProgress.get());
}
// Execute task
ActionListener<GetDataFrameAnalyticsStatsAction.Response> statsListener = ActionListener.wrap(
statsResponse -> {
GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0);
dfaTask.setStatsHolder(
new StatsHolder(stats.getProgress(), stats.getMemoryUsage(), stats.getAnalysisStats(), stats.getDataCounts()));
executeTask(dfaTask);
},
dfaTask::setFailed
);

// Get stats to initialize in memory stats tracking
ActionListener<Boolean> templateCheckListener = ActionListener.wrap(
ok -> {
if (analyticsState != DataFrameAnalyticsState.STOPPED) {
// If the state is not stopped it means the task is reassigning and
// we need to update the progress from the last stored progress doc.
searchProgressFromIndex(params.getId(), progressListener);
} else {
progressListener.onResponse(null);
}
},
ok -> executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE,
new GetDataFrameAnalyticsStatsAction.Request(params.getId()), statsListener),
error -> {
Throwable cause = ExceptionsHelper.unwrapCause(error);
logger.error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ private void executeStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c
ActionListener<StepResponse> stepListener = ActionListener.wrap(
stepResponse -> {
if (stepResponse.isTaskComplete()) {
LOGGER.info("[{}] Marking task completed", config.getId());
task.markAsCompleted();
// We always want to perform the final step as it tidies things up
executeStep(task, config, new FinalStep(client, task, auditor, config));
return;
}
switch (step.name()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
Expand Down Expand Up @@ -60,7 +61,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
private final StartDataFrameAnalyticsAction.TaskParams taskParams;
private volatile boolean isStopping;
private volatile boolean isMarkAsCompletedCalled;
private final StatsHolder statsHolder;
private volatile StatsHolder statsHolder;
private volatile DataFrameAnalyticsStep currentStep;

public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map<String, String> headers,
Expand All @@ -71,7 +72,6 @@ public DataFrameAnalyticsTask(long id, String type, String action, TaskId parent
this.analyticsManager = Objects.requireNonNull(analyticsManager);
this.auditor = Objects.requireNonNull(auditor);
this.taskParams = Objects.requireNonNull(taskParams);
this.statsHolder = new StatsHolder(taskParams.getProgressOnStart());
}

public void setStep(DataFrameAnalyticsStep step) {
Expand All @@ -86,6 +86,11 @@ public boolean isStopping() {
return isStopping;
}

/**
 * Installs the stats holder for this task.
 *
 * @param statsHolder the holder to use; must not be {@code null}
 * @throws NullPointerException if {@code statsHolder} is {@code null}
 */
public void setStatsHolder(StatsHolder statsHolder) {
    // Validate first so a null argument never replaces the current holder.
    Objects.requireNonNull(statsHolder);
    this.statsHolder = statsHolder;
}

/**
 * Returns the stats holder currently set on this task.
 *
 * @return the stats holder, or {@code null} when none has been set yet;
 *         callers are expected to handle the {@code null} case
 */
@Nullable
public StatsHolder getStatsHolder() {
return statsHolder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,22 @@

import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;

import java.util.Objects;

public class DataCountsTracker {

private final String jobId;
private volatile long trainingDocsCount;
private volatile long testDocsCount;
private volatile long skippedDocsCount;

public DataCountsTracker(DataCounts dataCounts) {
this.jobId = Objects.requireNonNull(dataCounts.getJobId());
this.trainingDocsCount = dataCounts.getTrainingDocsCount();
this.testDocsCount = dataCounts.getTestDocsCount();
this.skippedDocsCount = dataCounts.getSkippedDocsCount();
}

/**
 * Adds one to the training docs counter.
 */
public void incrementTrainingDocsCount() {
// NOTE(review): the counter is a volatile long and ++ is a read-modify-write,
// so concurrent callers could lose updates. Presumably there is a single
// writer thread — confirm against the callers.
trainingDocsCount++;
}
Expand All @@ -26,12 +36,22 @@ public void incrementSkippedDocsCount() {
skippedDocsCount++;
}

public DataCounts report(String jobId) {
/**
 * Builds a {@link DataCounts} from the tracker's job id and the current counter values.
 *
 * @return a new {@code DataCounts} holding the training, test and skipped docs counts
 */
public DataCounts report() {
    // Read the counters in the same order they are passed to the constructor.
    final long training = trainingDocsCount;
    final long test = testDocsCount;
    final long skipped = skippedDocsCount;
    return new DataCounts(jobId, training, test, skipped);
}

/**
 * Sets the training, test and skipped docs counters back to zero.
 * The job id is left unchanged.
 */
public void reset() {
// NOTE(review): these are three separate volatile writes, so a concurrent
// report() may observe a partially reset state — confirm this is acceptable.
trainingDocsCount = 0;
testDocsCount = 0;
skippedDocsCount = 0;
}

/**
 * Sets only the test docs counter back to zero; the training and skipped
 * docs counters are not touched.
 */
public void resetTestDocsCount() {
testDocsCount = 0;
}
}
Loading

0 comments on commit 3bb6e7d

Please sign in to comment.