Skip to content

Commit

Permalink
[ML] prefer secondary auth headers on data frame analytics _explain (#…
Browse files Browse the repository at this point in the history
…63281) (#63323)

We should prefer secondary auth headers when calling _explain
  • Loading branch information
benwtrent authored Oct 6, 2020
1 parent ca68298 commit a72d7cc
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
import org.junit.Before;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;

public class ExplainDataFrameAnalyticsRestIT extends ESRestTestCase {

private static String basicAuth(String user) {
return UsernamePasswordToken.basicAuthHeaderValue(user, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
}

private static final String SUPER_USER = "x_pack_rest_user";
private static final String ML_ADMIN = "ml_admin";
private static final String BASIC_AUTH_VALUE_SUPER_USER = basicAuth(SUPER_USER);
private static final String AUTH_KEY = "Authorization";
private static final String SECONDARY_AUTH_KEY = "es-secondary-authorization";

private static RequestOptions.Builder addAuthHeader(RequestOptions.Builder builder, String user) {
builder.addHeader(AUTH_KEY, basicAuth(user));
return builder;
}

private static RequestOptions.Builder addSecondaryAuthHeader(RequestOptions.Builder builder, String user) {
builder.addHeader(SECONDARY_AUTH_KEY, basicAuth(user));
return builder;
}

@Override
protected Settings restClientSettings() {
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
}

private void setupUser(String user, List<String> roles) throws IOException {
String password = new String(SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING.getChars());

Request request = new Request("PUT", "/_security/user/" + user);
request.setJsonEntity("{"
+ " \"password\" : \"" + password + "\","
+ " \"roles\" : [ " + roles.stream().map(unquoted -> "\"" + unquoted + "\"").collect(Collectors.joining(", ")) + " ]"
+ "}");
client().performRequest(request);
}

@Before
public void setUpData() throws Exception {
// This user has admin rights on machine learning, but (importantly for the tests) no rights
// on any of the data indexes
setupUser(ML_ADMIN, Collections.singletonList("machine_learning_admin"));
addAirlineData();
}

private void addAirlineData() throws IOException {
StringBuilder bulk = new StringBuilder();

// Create index with source = enabled, doc_values = enabled, stored = false + multi-field
Request createAirlineDataRequest = new Request("PUT", "/airline-data");
createAirlineDataRequest.setJsonEntity("{"
+ " \"mappings\": {"
+ " \"properties\": {"
+ " \"time stamp\": { \"type\":\"date\"}," // space in 'time stamp' is intentional
+ " \"airline\": {"
+ " \"type\":\"keyword\""
+ " },"
+ " \"responsetime\": { \"type\":\"float\"}"
+ " }"
+ " }"
+ "}");
client().performRequest(createAirlineDataRequest);

bulk.append("{\"index\": {\"_index\": \"airline-data\", \"_id\": 1}}\n");
bulk.append("{\"time stamp\":\"2016-06-01T00:00:00Z\",\"airline\":\"AAA\",\"responsetime\":135.22}\n");
bulk.append("{\"index\": {\"_index\": \"airline-data\", \"_id\": 2}}\n");
bulk.append("{\"time stamp\":\"2016-06-01T01:59:00Z\",\"airline\":\"AAA\",\"responsetime\":541.76}\n");

bulkIndex(bulk.toString());
}

public void testExplain_GivenSecondaryHeadersAndConfig() throws IOException {
String config = "{\n" +
" \"source\": {\n" +
" \"index\": \"airline-data\"\n" +
" },\n" +
" \"analysis\": {\n" +
" \"regression\": {\n" +
" \"dependent_variable\": \"responsetime\"\n" +
" }\n" +
" }\n" +
"}";


{ // Request with secondary headers without perms
Request explain = explainRequestViaConfig(config);
RequestOptions.Builder options = explain.getOptions().toBuilder();
addAuthHeader(options, SUPER_USER);
addSecondaryAuthHeader(options, ML_ADMIN);
explain.setOptions(options);
// Should throw
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(explain));
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(403));
}
{ // request with secondary headers with perms
Request explain = explainRequestViaConfig(config);
RequestOptions.Builder options = explain.getOptions().toBuilder();
addAuthHeader(options, ML_ADMIN);
addSecondaryAuthHeader(options, SUPER_USER);
explain.setOptions(options);
// Should not throw
client().performRequest(explain);
}
}

public void testExplain_GivenSecondaryHeadersAndPreviouslyStoredConfig() throws IOException {
String config = "{\n" +
" \"source\": {\n" +
" \"index\": \"airline-data\"\n" +
" },\n" +
" \"dest\": {\n" +
" \"index\": \"response_prediction\"\n" +
" },\n" +
" \"analysis\":\n" +
" {\n" +
" \"regression\": {\n" +
" \"dependent_variable\": \"responsetime\"\n" +
" }\n" +
" }\n" +
"}";

String configId = "explain_test";

Request storeConfig = new Request("PUT", "_ml/data_frame/analytics/" + configId);
storeConfig.setJsonEntity(config);
client().performRequest(storeConfig);

{ // Request with secondary headers without perms
Request explain = explainRequestConfigId(configId);
RequestOptions.Builder options = explain.getOptions().toBuilder();
addAuthHeader(options, SUPER_USER);
addSecondaryAuthHeader(options, ML_ADMIN);
explain.setOptions(options);
// Should throw
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(explain));
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(403));
}
{ // request with secondary headers with perms
Request explain = explainRequestConfigId(configId);
RequestOptions.Builder options = explain.getOptions().toBuilder();
addAuthHeader(options, ML_ADMIN);
addSecondaryAuthHeader(options, SUPER_USER);
explain.setOptions(options);
// Should not throw
client().performRequest(explain);
}
}

private static Request explainRequestViaConfig(String config) {
Request request = new Request("POST", "_ml/data_frame/analytics/_explain");
request.setJsonEntity(config);
return request;
}

private static Request explainRequestConfigId(String id) {
return new Request("POST", "_ml/data_frame/analytics/" + id + "/_explain");
}

private void bulkIndex(String bulk) throws IOException {
Request bulkRequest = new Request("POST", "/_bulk");
bulkRequest.setJsonEntity(bulk);
bulkRequest.addParameter("refresh", "true");
bulkRequest.addParameter("pretty", null);
String bulkResponse = EntityUtils.toString(client().performRequest(bulkRequest).getEntity());
assertThat(bulkResponse, not(containsString("\"errors\": false")));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,21 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.xpack.core.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
Expand All @@ -37,6 +42,9 @@
import java.util.Objects;
import java.util.Optional;

import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;

/**
* Provides explanations on aspects of the given data frame analytics spec like memory estimation, field selection, etc.
* Redirects to a different node if the current node is *not* an ML node.
Expand All @@ -49,20 +57,28 @@ public class TransportExplainDataFrameAnalyticsAction
private final ClusterService clusterService;
private final NodeClient client;
private final MemoryUsageEstimationProcessManager processManager;
private final SecurityContext securityContext;
private final ThreadPool threadPool;

@Inject
public TransportExplainDataFrameAnalyticsAction(TransportService transportService,
ActionFilters actionFilters,
ClusterService clusterService,
NodeClient client,
XPackLicenseState licenseState,
MemoryUsageEstimationProcessManager processManager) {
MemoryUsageEstimationProcessManager processManager,
Settings settings,
ThreadPool threadPool) {
super(ExplainDataFrameAnalyticsAction.NAME, transportService, actionFilters, PutDataFrameAnalyticsAction.Request::new);
this.transportService = transportService;
this.clusterService = Objects.requireNonNull(clusterService);
this.client = Objects.requireNonNull(client);
this.licenseState = licenseState;
this.processManager = Objects.requireNonNull(processManager);
this.threadPool = threadPool;
this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ?
new SecurityContext(settings, threadPool.getThreadContext()) :
null;
}

@Override
Expand All @@ -84,17 +100,38 @@ protected void doExecute(Task task,

private void explain(Task task, PutDataFrameAnalyticsAction.Request request,
ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory =
new ExtractedFieldsDetectorFactory(new ParentTaskAssigningClient(client, task.getParentTaskId()));
extractedFieldsDetectorFactory.createFromSource(
request.getConfig(),
ActionListener.wrap(
extractedFieldsDetector -> explain(task, request, extractedFieldsDetector, listener),
listener::onFailure)

final ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(
new ParentTaskAssigningClient(client, task.getParentTaskId())
);
if (licenseState.isSecurityEnabled()) {
useSecondaryAuthIfAvailable(this.securityContext, () -> {
// Set the auth headers (preferring the secondary headers) to the caller's.
// Regardless if the config was previously stored or not.
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder(request.getConfig())
.setHeaders(filterSecurityHeaders(threadPool.getThreadContext().getHeaders()))
.build();
extractedFieldsDetectorFactory.createFromSource(
config,
ActionListener.wrap(
extractedFieldsDetector -> explain(task, config, extractedFieldsDetector, listener),
listener::onFailure
)
);
});
} else {
extractedFieldsDetectorFactory.createFromSource(
request.getConfig(),
ActionListener.wrap(
extractedFieldsDetector -> explain(task, request.getConfig(), extractedFieldsDetector, listener),
listener::onFailure
)
);
}

}

private void explain(Task task, PutDataFrameAnalyticsAction.Request request, ExtractedFieldsDetector extractedFieldsDetector,
private void explain(Task task, DataFrameAnalyticsConfig config, ExtractedFieldsDetector extractedFieldsDetector,
ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();

Expand All @@ -103,7 +140,7 @@ private void explain(Task task, PutDataFrameAnalyticsAction.Request request, Ext
listener::onFailure
);

estimateMemoryUsage(task, request, fieldExtraction.v1(), memoryEstimationListener);
estimateMemoryUsage(task, config, fieldExtraction.v1(), memoryEstimationListener);
}

/**
Expand All @@ -112,15 +149,15 @@ private void explain(Task task, PutDataFrameAnalyticsAction.Request request, Ext
* the ML node.
*/
private void estimateMemoryUsage(Task task,
PutDataFrameAnalyticsAction.Request request,
DataFrameAnalyticsConfig config,
ExtractedFields extractedFields,
ActionListener<MemoryEstimation> listener) {
final String estimateMemoryTaskId = "memory_usage_estimation_" + task.getId();
DataFrameDataExtractorFactory extractorFactory = DataFrameDataExtractorFactory.createForSourceIndices(
new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, request.getConfig(), extractedFields);
new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, config, extractedFields);
processManager.runJobAsync(
estimateMemoryTaskId,
request.getConfig(),
config,
extractorFactory,
ActionListener.wrap(
result -> listener.onResponse(
Expand Down

0 comments on commit a72d7cc

Please sign in to comment.