Skip to content

Commit

Permalink
[Backport 2.x] Add profile transport action to AD client (#1123)
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Ohlsen <ohltyler@amazon.com>
  • Loading branch information
ohltyler authored Dec 27, 2023
1 parent 2ea79db commit 4e6ca48
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;

Expand Down Expand Up @@ -40,7 +41,7 @@ default ActionFuture<SearchResponse> searchAnomalyDetectors(SearchRequest search
*/
default ActionFuture<SearchResponse> searchAnomalyResults(SearchRequest searchRequest) {
PlainActionFuture<SearchResponse> actionFuture = PlainActionFuture.newFuture();
searchAnomalyDetectors(searchRequest, actionFuture);
searchAnomalyResults(searchRequest, actionFuture);
return actionFuture;
}

Expand All @@ -51,4 +52,22 @@ default ActionFuture<SearchResponse> searchAnomalyResults(SearchRequest searchRe
*/
void searchAnomalyResults(SearchRequest searchRequest, ActionListener<SearchResponse> listener);

/**
* Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector
* @param detectorId the detector ID to fetch the profile for
* @return ActionFuture of ADTaskProfileResponse
*/
default ActionFuture<ADTaskProfileResponse> getDetectorProfile(String detectorId) {
PlainActionFuture<ADTaskProfileResponse> actionFuture = PlainActionFuture.newFuture();
getDetectorProfile(detectorId, actionFuture);
return actionFuture;
}

/**
* Get detector profile - refer to https://opensearch.org/docs/latest/observing-your-data/ad/api/#profile-detector
* @param detectorId the detector ID to fetch the profile for
* @param listener a listener to be notified of the result
*/
void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,29 @@

package org.opensearch.ad.client;

import java.util.function.Function;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.transport.ADTaskProfileAction;
import org.opensearch.ad.transport.ADTaskProfileRequest;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.ad.transport.SearchAnomalyDetectorAction;
import org.opensearch.ad.transport.SearchAnomalyResultAction;
import org.opensearch.ad.util.DiscoveryNodeFilterer;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;

public class AnomalyDetectionNodeClient implements AnomalyDetectionClient {
private final Client client;
private final DiscoveryNodeFilterer nodeFilterer;

public AnomalyDetectionNodeClient(Client client) {
public AnomalyDetectionNodeClient(Client client, ClusterService clusterService) {
this.client = client;
this.nodeFilterer = new DiscoveryNodeFilterer(clusterService);
}

@Override
Expand All @@ -38,4 +49,34 @@ public void searchAnomalyResults(SearchRequest searchRequest, ActionListener<Sea
ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure)
);
}

@Override
public void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener) {
final DiscoveryNode[] eligibleNodes = this.nodeFilterer.getEligibleDataNodes();
ADTaskProfileRequest profileRequest = new ADTaskProfileRequest(detectorId, eligibleNodes);
this.client.execute(ADTaskProfileAction.INSTANCE, profileRequest, getADTaskProfileResponseActionListener(listener));
}

// We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic
// ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins.
private ActionListener<ADTaskProfileResponse> getADTaskProfileResponseActionListener(ActionListener<ADTaskProfileResponse> listener) {
ActionListener<ADTaskProfileResponse> internalListener = ActionListener
.wrap(profileResponse -> { listener.onResponse(profileResponse); }, listener::onFailure);
ActionListener<ADTaskProfileResponse> actionListener = wrapActionListener(internalListener, actionResponse -> {
ADTaskProfileResponse response = ADTaskProfileResponse.fromActionResponse(actionResponse);
return response;
});
return actionListener;
}

private <T extends ActionResponse> ActionListener<T> wrapActionListener(
final ActionListener<T> listener,
final Function<ActionResponse, T> recreate
) {
ActionListener<T> actionListener = ActionListener.wrap(r -> {
listener.onResponse(recreate.apply(r));
;
}, e -> { listener.onFailure(e); });
return actionListener;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@

package org.opensearch.ad.transport;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;

import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.cluster.ClusterName;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

Expand All @@ -40,4 +46,18 @@ public List<ADTaskProfileNodeResponse> readNodesFrom(StreamInput in) throws IOEx
return in.readList(ADTaskProfileNodeResponse::readNodeResponse);
}

public static ADTaskProfileResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof ADTaskProfileResponse) {
return (ADTaskProfileResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new ADTaskProfileResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into ADTaskProfileResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,21 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.core.action.ActionListener;

public class AnomalyDetectionClientTests {

AnomalyDetectionClient anomalyDetectionClient;

@Mock
SearchResponse searchResponse;
SearchResponse searchDetectorsResponse;

@Mock
SearchResponse searchResultsResponse;

@Mock
ADTaskProfileResponse profileResponse;

@Before
public void setUp() {
Expand All @@ -30,24 +37,34 @@ public void setUp() {
anomalyDetectionClient = new AnomalyDetectionClient() {
@Override
public void searchAnomalyDetectors(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
listener.onResponse(searchDetectorsResponse);
}

@Override
public void searchAnomalyResults(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
listener.onResponse(searchResultsResponse);
}

@Override
public void getDetectorProfile(String detectorId, ActionListener<ADTaskProfileResponse> listener) {
listener.onResponse(profileResponse);
}
};
}

@Test
public void searchAnomalyDetectors() {
assertEquals(searchResponse, anomalyDetectionClient.searchAnomalyDetectors(new SearchRequest()).actionGet());
assertEquals(searchDetectorsResponse, anomalyDetectionClient.searchAnomalyDetectors(new SearchRequest()).actionGet());
}

@Test
public void searchAnomalyResults() {
assertEquals(searchResponse, anomalyDetectionClient.searchAnomalyResults(new SearchRequest()).actionGet());
assertEquals(searchResultsResponse, anomalyDetectionClient.searchAnomalyResults(new SearchRequest()).actionGet());
}

@Test
public void getDetectorProfile() {
assertEquals(profileResponse, anomalyDetectionClient.getDetectorProfile("foo").actionGet());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,43 @@

package org.opensearch.ad.client;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.opensearch.ad.TestHelpers.matchAllRequest;
import static org.opensearch.ad.indices.AnomalyDetectionIndices.ALL_AD_RESULTS_INDEX_PATTERN;
import static org.opensearch.ad.model.AnomalyDetector.DETECTOR_TYPE_FIELD;

import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ad.HistoricalAnalysisIntegTestCase;
import org.opensearch.ad.TestHelpers;
import org.opensearch.ad.constant.CommonName;
import org.opensearch.ad.model.ADTaskProfile;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyDetectorType;
import org.opensearch.ad.transport.ADTaskProfileAction;
import org.opensearch.ad.transport.ADTaskProfileNodeResponse;
import org.opensearch.ad.transport.ADTaskProfileResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
Expand All @@ -37,22 +53,26 @@
// The exhaustive set of transport action scenarios are within the respective transport action
// test suites themselves. We do not want to unnecessarily duplicate all of those tests here.
public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTestCase {
private final Logger logger = LogManager.getLogger(this.getClass());

private String indexName = "test-data";
private Instant startTime = Instant.now().minus(2, ChronoUnit.DAYS);
private Client clientSpy;
private AnomalyDetectionNodeClient adClient;
private PlainActionFuture<SearchResponse> future;
private PlainActionFuture<SearchResponse> searchResponseFuture;
private PlainActionFuture<ADTaskProfileResponse> profileFuture;

@Before
public void setup() {
adClient = new AnomalyDetectionNodeClient(client());
clientSpy = spy(client());
adClient = new AnomalyDetectionNodeClient(clientSpy, clusterService());
}

@Test
public void testSearchAnomalyDetectors_NoIndices() {
deleteIndexIfExists(AnomalyDetector.ANOMALY_DETECTORS_INDEX);

SearchResponse searchResponse = adClient.searchAnomalyDetectors(matchAllRequest()).actionGet(10000);
SearchResponse searchResponse = adClient.searchAnomalyDetectors(TestHelpers.matchAllRequest()).actionGet(10000);
assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value);
}

Expand All @@ -61,13 +81,13 @@ public void testSearchAnomalyDetectors_Empty() throws IOException {
deleteIndexIfExists(AnomalyDetector.ANOMALY_DETECTORS_INDEX);
createDetectorIndex();

SearchResponse searchResponse = adClient.searchAnomalyDetectors(matchAllRequest()).actionGet(10000);
SearchResponse searchResponse = adClient.searchAnomalyDetectors(TestHelpers.matchAllRequest()).actionGet(10000);
assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value);
}

@Test
public void searchAnomalyDetectors_Populated() throws IOException {
ingestTestData(indexName, startTime, 1, "test", 3000);
ingestTestData(indexName, startTime, 1, "test", 10);
String detectorType = AnomalyDetectorType.SINGLE_ENTITY.name();
AnomalyDetector detector = TestHelpers
.randomAnomalyDetector(
Expand All @@ -93,18 +113,18 @@ public void searchAnomalyDetectors_Populated() throws IOException {

@Test
public void testSearchAnomalyResults_NoIndices() {
future = mock(PlainActionFuture.class);
searchResponseFuture = mock(PlainActionFuture.class);
SearchRequest request = new SearchRequest().indices(new String[] {});

adClient.searchAnomalyResults(request, future);
verify(future).onFailure(any(IllegalArgumentException.class));
adClient.searchAnomalyResults(request, searchResponseFuture);
verify(searchResponseFuture).onFailure(any(IllegalArgumentException.class));
}

@Test
public void testSearchAnomalyResults_Empty() throws IOException {
createADResultIndex();
SearchResponse searchResponse = adClient
.searchAnomalyResults(matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN))
.searchAnomalyResults(TestHelpers.matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN))
.actionGet(10000);
assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value);
}
Expand All @@ -116,11 +136,52 @@ public void testSearchAnomalyResults_Populated() throws IOException {
String adResultId = createADResult(TestHelpers.randomAnomalyDetectResult());

SearchResponse searchResponse = adClient
.searchAnomalyResults(matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN))
.searchAnomalyResults(TestHelpers.matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN))
.actionGet(10000);
assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value);

assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value);
assertEquals(adResultId, searchResponse.getInternalResponse().hits().getAt(0).getId());
}

@Test
public void testGetDetectorProfile_NoIndices() throws ExecutionException, InterruptedException {
deleteIndexIfExists(AnomalyDetector.ANOMALY_DETECTORS_INDEX);
deleteIndexIfExists(ALL_AD_RESULTS_INDEX_PATTERN);
deleteIndexIfExists(CommonName.DETECTION_STATE_INDEX);

profileFuture = mock(PlainActionFuture.class);
ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000);
List<ADTaskProfileNodeResponse> responses = response.getNodes();

assertNotEquals(0, responses.size());
assertEquals(null, responses.get(0).getAdTaskProfile());
verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any());

}

@Test
public void testGetDetectorProfile_Populated() {
DiscoveryNode localNode = clusterService().localNode();
ADTaskProfile adTaskProfile = new ADTaskProfile("foo-task-id", 0, 0L, false, 0, 0L, localNode.getId());

doAnswer(invocation -> {
Object[] args = invocation.getArguments();

ActionListener<ADTaskProfileResponse> listener = (ActionListener<ADTaskProfileResponse>) args[2];
ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(localNode, adTaskProfile, null);

List<ADTaskProfileNodeResponse> nodeResponses = Arrays.asList(nodeResponse);
listener.onResponse(new ADTaskProfileResponse(new ClusterName("test-cluster"), nodeResponses, Collections.emptyList()));

return null;
}).when(clientSpy).execute(any(ADTaskProfileAction.class), any(), any());

ADTaskProfileResponse response = adClient.getDetectorProfile("foo").actionGet(10000);
String responseTaskId = response.getNodes().get(0).getAdTaskProfile().getTaskId();

assertNotEquals(0, response.getNodes().size());
assertEquals(responseTaskId, adTaskProfile.getTaskId());
verify(clientSpy, times(1)).execute(any(ADTaskProfileAction.class), any(), any());
}

}
Loading

0 comments on commit 4e6ca48

Please sign in to comment.