Skip to content

Commit

Permalink
[ML] Parent datafeed actions to the datafeed's persistent task (#81143)…
Browse files Browse the repository at this point in the history
… (#81152)

The vast majority of a datafeed's actions are executed from the
data extractor. This includes the heaviest actions, which are the
searches. This commit passes a `ParentTaskAssigningClient` to
`DataExtractorFactory.create`, which ensures that the client used by
any extractor will set the corresponding task id: the action
task id for preview datafeed and for the master operation stage of the
start datafeed action, and the persistent task id for the datafeed's
operations after it has started.
  • Loading branch information
dimitris-athanasiou authored Nov 30, 2021
1 parent a774a7b commit f4a2da8
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
Expand Down Expand Up @@ -53,6 +55,7 @@ public class TransportPreviewDatafeedAction extends HandledTransportAction<Previ

private final ThreadPool threadPool;
private final Client client;
private final ClusterService clusterService;
private final JobConfigProvider jobConfigProvider;
private final DatafeedConfigProvider datafeedConfigProvider;
private final NamedXContentRegistry xContentRegistry;
Expand All @@ -65,13 +68,15 @@ public TransportPreviewDatafeedAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
ClusterService clusterService,
JobConfigProvider jobConfigProvider,
DatafeedConfigProvider datafeedConfigProvider,
NamedXContentRegistry xContentRegistry
) {
super(PreviewDatafeedAction.NAME, transportService, actionFilters, PreviewDatafeedAction.Request::new);
this.threadPool = threadPool;
this.client = client;
this.clusterService = clusterService;
this.jobConfigProvider = jobConfigProvider;
this.datafeedConfigProvider = datafeedConfigProvider;
this.xContentRegistry = xContentRegistry;
Expand All @@ -84,12 +89,12 @@ public TransportPreviewDatafeedAction(
protected void doExecute(Task task, PreviewDatafeedAction.Request request, ActionListener<PreviewDatafeedAction.Response> listener) {
ActionListener<DatafeedConfig> datafeedConfigActionListener = ActionListener.wrap(datafeedConfig -> {
if (request.getJobConfig() != null) {
previewDatafeed(datafeedConfig, request.getJobConfig().build(new Date()), listener);
previewDatafeed(task, datafeedConfig, request.getJobConfig().build(new Date()), listener);
return;
}
jobConfigProvider.getJob(
datafeedConfig.getJobId(),
ActionListener.wrap(jobBuilder -> previewDatafeed(datafeedConfig, jobBuilder.build(), listener), listener::onFailure)
ActionListener.wrap(jobBuilder -> previewDatafeed(task, datafeedConfig, jobBuilder.build(), listener), listener::onFailure)
);
}, listener::onFailure);
if (request.getDatafeedConfig() != null) {
Expand All @@ -102,7 +107,12 @@ protected void doExecute(Task task, PreviewDatafeedAction.Request request, Actio
}
}

private void previewDatafeed(DatafeedConfig datafeedConfig, Job job, ActionListener<PreviewDatafeedAction.Response> listener) {
private void previewDatafeed(
Task task,
DatafeedConfig datafeedConfig,
Job job,
ActionListener<PreviewDatafeedAction.Response> listener
) {
DatafeedConfig.Builder previewDatafeedBuilder = buildPreviewDatafeed(datafeedConfig);
useSecondaryAuthIfAvailable(securityContext, () -> {
previewDatafeedBuilder.setHeaders(filterSecurityHeaders(threadPool.getThreadContext().getHeaders()));
Expand All @@ -111,7 +121,7 @@ private void previewDatafeed(DatafeedConfig datafeedConfig, Job job, ActionListe
// requesting the preview doesn't have permission to search the relevant indices.
DatafeedConfig previewDatafeedConfig = previewDatafeedBuilder.build();
DataExtractorFactory.create(
client,
new ParentTaskAssigningClient(client, clusterService.localNode(), task),
previewDatafeedConfig,
job,
xContentRegistry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
Expand Down Expand Up @@ -251,7 +252,7 @@ public void onFailure(Exception e) {
remoteAliases,
(cn) -> remoteClusterService.getConnection(cn).getVersion()
);
createDataExtractor(job, datafeedConfigHolder.get(), params, waitForTaskListener);
createDataExtractor(task, job, datafeedConfigHolder.get(), params, waitForTaskListener);
}
},
e -> listener.onFailure(
Expand All @@ -264,7 +265,7 @@ public void onFailure(Exception e) {
)
);
} else {
createDataExtractor(job, datafeedConfigHolder.get(), params, waitForTaskListener);
createDataExtractor(task, job, datafeedConfigHolder.get(), params, waitForTaskListener);
}
};

Expand Down Expand Up @@ -343,13 +344,14 @@ static void checkRemoteClusterVersions(

/** Creates {@link DataExtractorFactory} solely for the purpose of validation i.e. verifying that it can be created. */
private void createDataExtractor(
Task task,
Job job,
DatafeedConfig datafeed,
StartDatafeedAction.DatafeedParams params,
ActionListener<PersistentTasksCustomMetadata.PersistentTask<StartDatafeedAction.DatafeedParams>> listener
) {
DataExtractorFactory.create(
client,
new ParentTaskAssigningClient(client, clusterService.localNode(), task),
datafeed,
job,
xContentRegistry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class DatafeedJobBuilder {
private final Supplier<Long> currentTimeSupplier;
private final JobResultsPersister jobResultsPersister;
private final boolean remoteClusterClient;
private final String nodeName;
private final ClusterService clusterService;

private volatile long delayedDataCheckFreq;

Expand All @@ -65,8 +65,8 @@ public DatafeedJobBuilder(
this.currentTimeSupplier = Objects.requireNonNull(currentTimeSupplier);
this.jobResultsPersister = Objects.requireNonNull(jobResultsPersister);
this.remoteClusterClient = DiscoveryNode.isRemoteClusterClient(settings);
this.nodeName = clusterService.getNodeName();
this.delayedDataCheckFreq = DELAYED_DATA_CHECK_FREQ.get(settings).millis();
this.clusterService = Objects.requireNonNull(clusterService);
clusterService.getClusterSettings().addSettingsUpdateConsumer(DELAYED_DATA_CHECK_FREQ, this::setDelayedDataCheckFreq);
}

Expand All @@ -75,7 +75,7 @@ private void setDelayedDataCheckFreq(TimeValue value) {
}

void build(TransportStartDatafeedAction.DatafeedTask task, DatafeedContext context, ActionListener<DatafeedJob> listener) {
final ParentTaskAssigningClient parentTaskAssigningClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
final ParentTaskAssigningClient parentTaskAssigningClient = new ParentTaskAssigningClient(client, clusterService.localNode(), task);
final DatafeedConfig datafeedConfig = context.getDatafeedConfig();
final Job job = context.getJob();
final long latestFinalBucketEndMs = context.getRestartTimeInfo().getLatestFinalBucketTimeMs() == null
Expand Down Expand Up @@ -155,7 +155,12 @@ private void checkRemoteIndicesAreAvailable(DatafeedConfig datafeedConfig) {
List<String> remoteIndices = RemoteClusterLicenseChecker.remoteIndices(datafeedConfig.getIndices());
if (remoteIndices.isEmpty() == false) {
throw ExceptionsHelper.badRequestException(
Messages.getMessage(Messages.DATAFEED_NEEDS_REMOTE_CLUSTER_SEARCH, datafeedConfig.getId(), remoteIndices, nodeName)
Messages.getMessage(
Messages.DATAFEED_NEEDS_REMOTE_CLUSTER_SEARCH,
datafeedConfig.getId(),
remoteIndices,
clusterService.getNodeName()
)
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
*/
package org.elasticsearch.xpack.ml.datafeed;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlocks;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.routing.OperationRouting;
import org.elasticsearch.cluster.service.ClusterApplierService;
import org.elasticsearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -38,6 +44,8 @@
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.elasticsearch.test.NodeRoles.nonRemoteClusterClientNode;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
Expand Down Expand Up @@ -78,11 +86,25 @@ public void init() {
)
)
);
final DiscoveryNode localNode = new DiscoveryNode(
"test_node",
buildNewFakeTransportAddress(),
emptyMap(),
emptySet(),
Version.CURRENT
);
clusterService = new ClusterService(
Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test_node").build(),
clusterSettings,
threadPool
);
clusterService.getClusterApplierService()
.setInitialState(
ClusterState.builder(new ClusterName("DatafeedJobBuilderTests"))
.nodes(DiscoveryNodes.builder().add(localNode).localNodeId(localNode.getId()).masterNodeId(localNode.getId()))
.blocks(ClusterBlocks.EMPTY_CLUSTER_BLOCK)
.build()
);

datafeedJobBuilder = new DatafeedJobBuilder(
client,
Expand Down

0 comments on commit f4a2da8

Please sign in to comment.