From 068f0b02ac78a71a0509b39707fb8f9234850db9 Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Wed, 12 Jun 2024 17:24:02 -0700 Subject: [PATCH 01/10] AWS DDB SDK client support for remote data store Signed-off-by: Arjun kumar Giri --- .../sdk/DeleteDataObjectRequest.java | 17 +- .../opensearch/sdk/GetDataObjectRequest.java | 16 +- .../opensearch/sdk/PutDataObjectRequest.java | 28 +- plugin/build.gradle | 22 ++ .../ml/sdkclient/DDBOpenSearchClient.java | 125 ++++++++ .../ml/sdkclient/SdkClientModule.java | 86 +++++- .../sdkclient/DDBOpenSearchClientTests.java | 274 ++++++++++++++++++ .../ml/sdkclient/SdkClientModuleTests.java | 19 +- 8 files changed, 569 insertions(+), 18 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java create mode 100644 plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java diff --git a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java index 31d560815d..4cbe587f75 100644 --- a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java @@ -13,6 +13,8 @@ public class DeleteDataObjectRequest { private final String index; private final String id; + private final String tenantId; + /** * Instantiate this request with an index and id. *

@@ -20,9 +22,10 @@ public class DeleteDataObjectRequest { * @param index the index location to delete the object * @param id the document id */ - public DeleteDataObjectRequest(String index, String id) { + public DeleteDataObjectRequest(String index, String id, String tenantId) { this.index = index; this.id = id; + this.tenantId = tenantId; } /** @@ -41,12 +44,17 @@ public String id() { return this.id; } + public String tenantId() { + return this.tenantId; + } + /** * Class for constructing a Builder for this Request Object */ public static class Builder { private String index = null; private String id = null; + private String tenantId = null; /** * Empty Constructor for the Builder object @@ -73,12 +81,17 @@ public Builder id(String id) { return this; } + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + /** * Builds the object * @return A {@link DeleteDataObjectRequest} */ public DeleteDataObjectRequest build() { - return new DeleteDataObjectRequest(this.index, this.id); + return new DeleteDataObjectRequest(this.index, this.id, this.tenantId); } } } diff --git a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java index 8edbb99f39..3d282dbf04 100644 --- a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java @@ -14,6 +14,7 @@ public class GetDataObjectRequest { private final String index; private final String id; + private final String tenantId; private final FetchSourceContext fetchSourceContext; /** @@ -24,9 +25,10 @@ public class GetDataObjectRequest { * @param id the document id * @param fetchSourceContext the context to use when fetching _source */ - public GetDataObjectRequest(String index, String id, FetchSourceContext fetchSourceContext) { + public GetDataObjectRequest(String index, String id, String tenantId, FetchSourceContext fetchSourceContext) { this.index = index; this.id = id; + this.tenantId = tenantId; this.fetchSourceContext = fetchSourceContext; } @@ -46,6 +48,10 @@ public String id() { return this.id; } + public String tenantId() { + return this.tenantId; + } + /** * Returns the context for fetching _source * @return the fetchSourceContext @@ -60,6 +66,7 @@ public FetchSourceContext fetchSourceContext() { public static class Builder { private String index = null; private String id = null; + private String tenantId = null; private FetchSourceContext fetchSourceContext; /** @@ -87,6 +94,11 @@ public Builder id(String id) { return this; } + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + /** * Add a fetchSourceContext to this builder * @param fetchSourceContext the fetchSourceContext @@ -102,7 +114,7 @@ public Builder fetchSourceContext(FetchSourceContext fetchSourceContext) { * @return A {@link GetDataObjectRequest} */ public GetDataObjectRequest build() { - return new GetDataObjectRequest(this.index, this.id, this.fetchSourceContext); + return new GetDataObjectRequest(this.index, this.id, this.tenantId, this.fetchSourceContext); } } } diff --git a/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java index 2d6d0a5d07..bb36150de0 100644 --- a/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java @@ -13,6 +13,8 @@ public class PutDataObjectRequest { private final String index; + private final String id; + private final String tenantId; private final ToXContentObject dataObject; /** @@ -22,8 +24,10 @@ public class PutDataObjectRequest { * @param index the index location to put the object * @param dataObject the data object */ - public PutDataObjectRequest(String index, ToXContentObject dataObject) { + public PutDataObjectRequest(String index, String id, String tenantId, ToXContentObject dataObject) { this.index = index; + this.id = id; + this.tenantId = tenantId; this.dataObject = dataObject; } @@ -35,6 +39,14 @@ public String index() { return this.index; } + public String id() { + return this.id; + } + + public String tenantId() { + return this.tenantId; + } + /** * Returns the data object * @return the data object @@ -48,6 +60,8 @@ public ToXContentObject dataObject() { */ public static class Builder { private String index = null; + private String id = null; + private String tenantId = null; private ToXContentObject dataObject = null; /** @@ -65,6 +79,16 @@ public Builder index(String index) { return this; } + public Builder id(String id) { + this.id = id; + return this; + } + + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + /** * Add a data object to this builder * @param dataObject the data object @@ -80,7 +104,7 @@ public Builder dataObject(ToXContentObject dataObject) { * @return A {@link PutDataObjectRequest} */ public PutDataObjectRequest build() { - return new PutDataObjectRequest(this.index, this.dataObject); + return new PutDataObjectRequest(this.index, this.id, this.tenantId, this.dataObject); } } } diff --git a/plugin/build.gradle b/plugin/build.gradle index 947ed22c5a..5b4f1ee3fb 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -63,6 +63,28 @@ dependencies { implementation 'com.jayway.jsonpath:json-path:2.9.0' implementation "org.opensearch.client:opensearch-java:2.10.2" + // Dynamo dependencies + implementation("software.amazon.awssdk:sdk-core:2.25.40") + implementation("software.amazon.awssdk:aws-core:2.25.40") + implementation "software.amazon.awssdk:aws-json-protocol:2.25.40" + implementation("software.amazon.awssdk:auth:2.25.40") + implementation("software.amazon.awssdk:checksums:2.25.40") + implementation("software.amazon.awssdk:checksums-spi:2.25.40") + implementation("software.amazon.awssdk:dynamodb:2.25.40") + implementation("software.amazon.awssdk:endpoints-spi:2.25.40") + implementation("software.amazon.awssdk:http-auth-aws:2.25.40") + implementation("software.amazon.awssdk:http-auth-spi:2.25.40") + implementation("software.amazon.awssdk:http-client-spi:2.25.40") + implementation("software.amazon.awssdk:identity-spi:2.25.40") + implementation "software.amazon.awssdk:json-utils:2.25.40" + implementation "software.amazon.awssdk:metrics-spi:2.25.40" + implementation("software.amazon.awssdk:profiles:2.25.40") + implementation "software.amazon.awssdk:protocol-core:2.25.40" + implementation("software.amazon.awssdk:regions:2.25.40") + implementation "software.amazon.awssdk:third-party-jackson-core:2.25.40" + implementation("software.amazon.awssdk:url-connection-client:2.25.40") + implementation("software.amazon.awssdk:utils:2.25.40") + configurations.all { resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5:5.2.4' diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java new file mode 100644 index 0000000000..2c3039c2bb --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -0,0 +1,125 @@ +package org.opensearch.ml.sdkclient; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.OpenSearchException; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sdk.DeleteDataObjectRequest; +import org.opensearch.sdk.DeleteDataObjectResponse; +import org.opensearch.sdk.GetDataObjectRequest; +import org.opensearch.sdk.GetDataObjectResponse; +import org.opensearch.sdk.PutDataObjectRequest; +import org.opensearch.sdk.PutDataObjectResponse; +import org.opensearch.sdk.SdkClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; + +@AllArgsConstructor +@Log4j2 +public class DDBOpenSearchClient implements SdkClient { + + private static final String DEFAULT_TENANT = "DEFAULT_TENANT"; + + private static final String HASH_KEY = "tenant_id"; + private static final String RANGE_KEY = "id"; + private static final String SOURCE = "source"; + + private DynamoDbClient dynamoDbClient; + @Override + public CompletionStage putDataObjectAsync(PutDataObjectRequest request, Executor executor) { + final String id = request.id() != null ? request.id() : UUID.randomUUID().toString(); + final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; + final String tableName = getTableName(request.index()); + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { + XContentBuilder builder = request.dataObject().toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); + String source = builder.toString(); + + final Map item = Map.ofEntries( + Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), + Map.entry(RANGE_KEY, AttributeValue.builder().s(id).build()), + Map.entry(SOURCE, AttributeValue.builder().s(source).build()) + ); + final PutItemRequest putItemRequest = PutItemRequest.builder() + .tableName(tableName) + .item(item) + .build(); + + dynamoDbClient.putItem(putItemRequest); + return new PutDataObjectResponse.Builder().id(id).created(true).build(); + } catch (Exception e){ + log.error("Exception while inserting data into DDB: " + e.getMessage(), e); + throw new OpenSearchException(e); + } + }), executor); + } + + @Override + public CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor) { + final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; + final GetItemRequest getItemRequest = GetItemRequest.builder() + .tableName(getTableName(request.index())) + .key(Map.ofEntries( + Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), + Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build()) + )) + .build(); + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try { + final GetItemResponse getItemResponse = dynamoDbClient.getItem(getItemRequest); + if (getItemResponse == null || getItemResponse.item() == null || getItemResponse.item().isEmpty()) { + return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.empty()).build(); + } + + String source = getItemResponse.item().get(SOURCE).s(); + XContentParser parser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, source); + return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.of(parser)).build(); + } catch (Exception e) { + log.error("Exception while fetching data from DDB: " + e.getMessage(), e); + throw new OpenSearchException(e); + } + }), executor); + } + + @Override + public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { + final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; + final DeleteItemRequest deleteItemRequest = DeleteItemRequest.builder() + .tableName(getTableName(request.index())) + .key(Map.ofEntries( + Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), + Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build()) + )).build(); + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + dynamoDbClient.deleteItem(deleteItemRequest); + return new DeleteDataObjectResponse.Builder().id(request.id()).deleted(true).build(); + }), executor); + } + + private String getTableName(String index) { + // Table name will be same as index name. As DDB table name does not support dot(.) + // it will be removed form name. + return index.replaceAll("\\.", ""); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index fb7d1d3119..504397cb30 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -8,36 +8,79 @@ */ package org.opensearch.ml.sdkclient; +import lombok.extern.log4j.Log4j2; +import lombok.extern.slf4j.Slf4j; import org.apache.http.HttpHost; import org.apache.http.conn.ssl.NoopHostnameVerifier; import org.opensearch.OpenSearchException; +import org.opensearch.SpecialPermission; import org.opensearch.client.RestClient; import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.rest_client.RestClientTransport; import org.opensearch.common.inject.AbstractModule; -import org.opensearch.core.common.Strings; +import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.profiles.ProfileFileSystemSetting; + +import java.security.AccessController; +import java.security.PrivilegedAction; /** * A module for binding this plugin's desired implementation of {@link SdkClient}. */ +@Log4j2 public class SdkClientModule extends AbstractModule { + public static final String REMOTE_METADATA_TYPE = "REMOTE_METADATA_TYPE"; public static final String REMOTE_METADATA_ENDPOINT = "REMOTE_METADATA_ENDPOINT"; public static final String REGION = "REGION"; + public static final String REMOTE_OPENSEARCH = "RemoteOpenSearch"; + public static final String AWS_DYNAMO_DB = "AWSDynamoDB"; + private final String remoteStoreType; private final String remoteMetadataEndpoint; private final String region; // not using with RestClient + static { + // Aws v2 sdk tries to load a default profile from home path which is restricted. Hence, setting these to random valid paths. + // @SuppressForbidden(reason = "Need to provide this override to v2 SDK so that path does not default to home path") + if (ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.getStringValue().isEmpty()) { + SocketAccess.doPrivileged( + () -> System.setProperty( + ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.property(), + System.getProperty("opensearch.path.conf") + ) + ); + } + if (ProfileFileSystemSetting.AWS_CONFIG_FILE.getStringValue().isEmpty()) { + SocketAccess.doPrivileged( + () -> System.setProperty(ProfileFileSystemSetting.AWS_CONFIG_FILE.property(), System.getProperty("opensearch.path.conf")) + ); + } + } + + private static final class SocketAccess { + private SocketAccess() {} + + public static T doPrivileged(PrivilegedAction operation) { + SpecialPermission.check(); + return AccessController.doPrivileged(operation); + } + } + /** * Instantiate this module using environment variables */ public SdkClientModule() { - this(System.getenv(REMOTE_METADATA_ENDPOINT), System.getenv(REGION)); + this(System.getenv(REMOTE_METADATA_TYPE), System.getenv(REMOTE_METADATA_ENDPOINT), System.getenv(REGION)); } /** @@ -45,19 +88,44 @@ public SdkClientModule() { * @param remoteMetadataEndpoint The remote endpoint * @param region The region */ - SdkClientModule(String remoteMetadataEndpoint, String region) { + SdkClientModule(String remoteStoreType, String remoteMetadataEndpoint, String region) { + this.remoteStoreType = remoteStoreType; this.remoteMetadataEndpoint = remoteMetadataEndpoint; - this.region = region; + this.region = region == null ? "us-west-2" : region; } @Override - protected void configure() { - boolean local = Strings.isNullOrEmpty(remoteMetadataEndpoint); - if (local) { + protected void configure() {/* + if (this.remoteStoreType == null) { + log.info("Using local opensearch cluster as metadata store"); bind(SdkClient.class).to(LocalClusterIndicesClient.class); - } else { - bind(SdkClient.class).toInstance(new RemoteClusterIndicesClient(createOpenSearchClient())); + return; } + + switch (this.remoteStoreType) { + case REMOTE_OPENSEARCH: + log.info("Using remote opensearch cluster as metadata store"); + bind(SdkClient.class).toInstance(new RemoteClusterIndicesClient(createOpenSearchClient())); + return; + case AWS_DYNAMO_DB: + log.info("Using dynamo DB as metadata store"); + bind(SdkClient.class).toInstance(new DDBOpenSearchClient(createDynamoDbClient())); + return; + default: + log.info("Using local opensearch cluster as metadata store"); + bind(SdkClient.class).to(LocalClusterIndicesClient.class); + }*/ + bind(SdkClient.class).toInstance(new DDBOpenSearchClient(createDynamoDbClient())); + } + + private DynamoDbClient createDynamoDbClient() { + if (this.region == null) { + throw new IllegalStateException("REGION environment variable needs to be set!"); + } + + return DynamoDbClient.builder() + .region(Region.of(this.region)) + .build(); } private OpenSearchClient createOpenSearchClient() { diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java new file mode 100644 index 0000000000..0da4bc21a5 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -0,0 +1,274 @@ +package org.opensearch.ml.sdkclient; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchException; +import org.opensearch.client.opensearch.core.IndexRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.sdk.DeleteDataObjectRequest; +import org.opensearch.sdk.DeleteDataObjectResponse; +import org.opensearch.sdk.GetDataObjectRequest; +import org.opensearch.sdk.GetDataObjectResponse; +import org.opensearch.sdk.PutDataObjectRequest; +import org.opensearch.sdk.PutDataObjectResponse; +import org.opensearch.sdk.SdkClient; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.PutItemResponse; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; + + +public class DDBOpenSearchClientTests extends OpenSearchTestCase { + + private static final String TEST_ID = "123"; + private static final String TENANT_ID = "TEST_TENANT_ID"; + private static final String TEST_INDEX = "test_index"; + private SdkClient sdkClient; + + @Mock + private DynamoDbClient dynamoDbClient; + @Captor + private ArgumentCaptor putItemRequestArgumentCaptor; + @Captor + private ArgumentCaptor getItemRequestArgumentCaptor; + @Captor + private ArgumentCaptor deleteItemRequestArgumentCaptor; + private TestDataObject testDataObject; + + + private static TestThreadPool testThreadPool = new TestThreadPool( + LocalClusterIndicesClientTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + sdkClient = new DDBOpenSearchClient(dynamoDbClient); + testDataObject = new TestDataObject("foo"); + } + + @Test + public void testPutDataObject_HappyCase() throws IOException { + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID) + .dataObject(testDataObject).build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) + .thenReturn(PutItemResponse.builder().build()); + PutDataObjectResponse response = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); + Assert.assertEquals(TEST_ID, response.id()); + Assert.assertEquals(true, response.created()); + + PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); + Assert.assertEquals(TEST_INDEX, putItemRequest.tableName()); + Assert.assertEquals(TEST_ID, putItemRequest.item().get("id").s()); + Assert.assertEquals(TENANT_ID, putItemRequest.item().get("tenant_id").s()); + XContentBuilder sourceBuilder = XContentFactory.jsonBuilder(); + XContentBuilder builder = testDataObject.toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); + Assert.assertEquals(builder.toString(), putItemRequest.item().get("source").s()); + } + + @Test + public void testPutDataObject_NullTenantId_SetsDefaultTenantId() throws IOException { + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject).build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) + .thenReturn(PutItemResponse.builder().build()); + sdkClient.putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); + + PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); + Assert.assertEquals("DEFAULT_TENANT", putItemRequest.item().get("tenant_id").s()); + } + + @Test + public void testPutDataObject_NullId_SetsDefaultTenantId() throws IOException { + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() + .index(TEST_INDEX) + .dataObject(testDataObject).build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) + .thenReturn(PutItemResponse.builder().build()); + PutDataObjectResponse response = sdkClient.putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); + + PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); + Assert.assertNotNull(putItemRequest.item().get("id").s()); + Assert.assertNotNull(response.id()); + } + + @Test + public void testPutDataObject_DDBException_ThrowsException() { + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject).build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) + .thenThrow(new RuntimeException("Test exception")); + CompletableFuture future = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); + } + + @Test + public void testGetDataObject_HappyCase() throws IOException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID).build(); + XContentBuilder sourceBuilder = XContentFactory.jsonBuilder(); + XContentBuilder builder = testDataObject.toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); + GetItemResponse getItemResponse = GetItemResponse.builder().item(Map.ofEntries( + Map.entry("source", AttributeValue.builder().s(builder.toString()).build()) + )).build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) + .thenReturn(getItemResponse); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); + GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); + Assert.assertEquals(TEST_INDEX, getItemRequest.tableName()); + Assert.assertEquals(TENANT_ID, getItemRequest.key().get("tenant_id").s()); + Assert.assertEquals(TEST_ID, getItemRequest.key().get("id").s()); + Assert.assertEquals(TEST_ID, response.id()); + Assert.assertTrue(response.parser().isPresent()); + Assert.assertEquals("foo", response.parser().get().map().get("data")); + } + + @Test + public void testGetDataObject_NoExistingDoc() throws IOException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID).build(); + GetItemResponse getItemResponse = GetItemResponse.builder().build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) + .thenReturn(getItemResponse); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Assert.assertEquals(TEST_ID, response.id()); + Assert.assertFalse(response.parser().isPresent()); + } + + @Test + public void testGetDataObject_UseDefaultTenantIdIfNull() throws IOException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID).build(); + GetItemResponse getItemResponse = GetItemResponse.builder().build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) + .thenReturn(getItemResponse); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); + GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); + Assert.assertEquals("DEFAULT_TENANT", getItemRequest.key().get("tenant_id").s()); + } + + @Test + public void testGetDataObject_DDBException_ThrowsOSException() throws IOException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID).build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) + .thenThrow(new RuntimeException("Test exception")); + CompletableFuture future = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); + } + + @Test + public void testDeleteDataObject_HappyCase() { + DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder() + .id(TEST_ID).index(TEST_INDEX).tenantId(TENANT_ID).build(); + Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())) + .thenReturn(DeleteItemResponse.builder().build()); + DeleteDataObjectResponse deleteResponse = sdkClient + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture().join(); + DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); + Assert.assertEquals(TEST_INDEX, deleteItemRequest.tableName()); + Assert.assertEquals(TENANT_ID, deleteItemRequest.key().get("tenant_id").s()); + Assert.assertEquals(TEST_ID, deleteItemRequest.key().get("id").s()); + Assert.assertEquals(TEST_ID, deleteResponse.id()); + Assert.assertTrue(deleteResponse.deleted()); + } + + @Test + public void testDeleteDataObject_NullTenantId_UsesDefaultTenantId() { + DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder() + .id(TEST_ID).index(TEST_INDEX).build(); + Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())) + .thenReturn(DeleteItemResponse.builder().build()); + DeleteDataObjectResponse deleteResponse = sdkClient + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture().join(); + DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); + Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get("tenant_id").s()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java index 707ddd46f6..b1cd9d7db6 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java @@ -10,6 +10,7 @@ import static org.mockito.Mockito.mock; +import org.junit.Before; import org.opensearch.common.inject.AbstractModule; import org.opensearch.common.inject.Guice; import org.opensearch.common.inject.Injector; @@ -22,6 +23,11 @@ @ThreadLeakScope(ThreadLeakScope.Scope.NONE) // remote http client is never closed public class SdkClientModuleTests extends OpenSearchTestCase { + @Before + public void setup() { + System.setProperty("opensearch.path.conf", "/tmp"); + } + private Module localClientModule = new AbstractModule() { @Override protected void configure() { @@ -30,16 +36,23 @@ protected void configure() { }; public void testLocalBinding() { - Injector injector = Guice.createInjector(new SdkClientModule(null, null), localClientModule); + Injector injector = Guice.createInjector(new SdkClientModule(null, null, null), localClientModule); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof LocalClusterIndicesClient); } - public void testRemoteBinding() { - Injector injector = Guice.createInjector(new SdkClientModule("http://example.org", "eu-west-3")); + public void testRemoteOpenSearchBinding() { + Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.REMOTE_OPENSEARCH, "http://example.org", "eu-west-3")); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof RemoteClusterIndicesClient); } + + public void testDDBBinding() { + Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.AWS_DYNAMO_DB, null, "eu-west-3")); + + SdkClient sdkClient = injector.getInstance(SdkClient.class); + assertTrue(sdkClient instanceof DDBOpenSearchClient); + } } From e9e209acbdd38341fd369aead431cede2aeb2e68 Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Wed, 12 Jun 2024 17:24:02 -0700 Subject: [PATCH 02/10] AWS DDB SDK client support for remote data store Signed-off-by: Arjun kumar Giri --- .../ml/sdkclient/DDBOpenSearchClient.java | 98 ++++++--- .../ml/sdkclient/SdkClientModule.java | 63 ++---- .../sdkclient/DDBOpenSearchClientTests.java | 192 +++++++++--------- .../ml/sdkclient/SdkClientModuleTests.java | 6 - 4 files changed, 175 insertions(+), 184 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index 2c3039c2bb..196e1885a9 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -1,7 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ package org.opensearch.ml.sdkclient; -import lombok.AllArgsConstructor; -import lombok.extern.log4j.Log4j2; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; + import org.opensearch.OpenSearchException; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; @@ -17,6 +32,9 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; @@ -24,16 +42,10 @@ import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; -import java.util.concurrent.Executor; - +/** + * DDB implementation of {@link SdkClient}. DDB table name will be mapped to index name. + * + */ @AllArgsConstructor @Log4j2 public class DDBOpenSearchClient implements SdkClient { @@ -45,6 +57,13 @@ public class DDBOpenSearchClient implements SdkClient { private static final String SOURCE = "source"; private DynamoDbClient dynamoDbClient; + + /** + * DDB implementation to write data objects to DDB table. Tenant ID will be used as hash key and document ID will + * be used as range key. If tenant ID is not defined a default tenant ID will be used. If document ID is not defined + * a random UUID will be generated. Data object will be written as a nested DDB attribute. + * + */ @Override public CompletionStage putDataObjectAsync(PutDataObjectRequest request, Executor executor) { final String id = request.id() != null ? request.id() : UUID.randomUUID().toString(); @@ -55,35 +74,41 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe XContentBuilder builder = request.dataObject().toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); String source = builder.toString(); - final Map item = Map.ofEntries( + final Map item = Map + .ofEntries( Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), Map.entry(RANGE_KEY, AttributeValue.builder().s(id).build()), Map.entry(SOURCE, AttributeValue.builder().s(source).build()) - ); - final PutItemRequest putItemRequest = PutItemRequest.builder() - .tableName(tableName) - .item(item) - .build(); + ); + final PutItemRequest putItemRequest = PutItemRequest.builder().tableName(tableName).item(item).build(); - dynamoDbClient.putItem(putItemRequest); - return new PutDataObjectResponse.Builder().id(id).created(true).build(); - } catch (Exception e){ + dynamoDbClient.putItem(putItemRequest); + return new PutDataObjectResponse.Builder().id(id).created(true).build(); + } catch (Exception e) { log.error("Exception while inserting data into DDB: " + e.getMessage(), e); throw new OpenSearchException(e); - } + } }), executor); } + /** + * Fetches data document from DDB. Default tenant ID will be used if tenant ID is not specified. + * + */ @Override public CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor) { final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; - final GetItemRequest getItemRequest = GetItemRequest.builder() - .tableName(getTableName(request.index())) - .key(Map.ofEntries( + final GetItemRequest getItemRequest = GetItemRequest + .builder() + .tableName(getTableName(request.index())) + .key( + Map + .ofEntries( Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build()) - )) - .build(); + ) + ) + .build(); return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { final GetItemResponse getItemResponse = dynamoDbClient.getItem(getItemRequest); @@ -93,7 +118,7 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe String source = getItemResponse.item().get(SOURCE).s(); XContentParser parser = JsonXContent.jsonXContent - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, source); + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, source); return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.of(parser)).build(); } catch (Exception e) { log.error("Exception while fetching data from DDB: " + e.getMessage(), e); @@ -102,15 +127,24 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe }), executor); } + /** + * Deletes data document from DDB. Default tenant ID will be used if tenant ID is not specified. + * + */ @Override public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; - final DeleteItemRequest deleteItemRequest = DeleteItemRequest.builder() - .tableName(getTableName(request.index())) - .key(Map.ofEntries( + final DeleteItemRequest deleteItemRequest = DeleteItemRequest + .builder() + .tableName(getTableName(request.index())) + .key( + Map + .ofEntries( Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build()) - )).build(); + ) + ) + .build(); return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { dynamoDbClient.deleteItem(deleteItemRequest); return new DeleteDataObjectResponse.Builder().id(request.id()).deleted(true).build(); diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index 504397cb30..d7be7043b9 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -8,30 +8,26 @@ */ package org.opensearch.ml.sdkclient; -import lombok.extern.log4j.Log4j2; -import lombok.extern.slf4j.Slf4j; import org.apache.http.HttpHost; import org.apache.http.conn.ssl.NoopHostnameVerifier; import org.opensearch.OpenSearchException; -import org.opensearch.SpecialPermission; import org.opensearch.client.RestClient; import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.rest_client.RestClientTransport; import org.opensearch.common.inject.AbstractModule; -import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategies; -import software.amazon.awssdk.services.dynamodb.DynamoDbClient; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.profiles.ProfileFileSystemSetting; -import java.security.AccessController; -import java.security.PrivilegedAction; +import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain; +import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; /** * A module for binding this plugin's desired implementation of {@link SdkClient}. @@ -49,33 +45,6 @@ public class SdkClientModule extends AbstractModule { private final String remoteMetadataEndpoint; private final String region; // not using with RestClient - static { - // Aws v2 sdk tries to load a default profile from home path which is restricted. Hence, setting these to random valid paths. - // @SuppressForbidden(reason = "Need to provide this override to v2 SDK so that path does not default to home path") - if (ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.getStringValue().isEmpty()) { - SocketAccess.doPrivileged( - () -> System.setProperty( - ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.property(), - System.getProperty("opensearch.path.conf") - ) - ); - } - if (ProfileFileSystemSetting.AWS_CONFIG_FILE.getStringValue().isEmpty()) { - SocketAccess.doPrivileged( - () -> System.setProperty(ProfileFileSystemSetting.AWS_CONFIG_FILE.property(), System.getProperty("opensearch.path.conf")) - ); - } - } - - private static final class SocketAccess { - private SocketAccess() {} - - public static T doPrivileged(PrivilegedAction operation) { - SpecialPermission.check(); - return AccessController.doPrivileged(operation); - } - } - /** * Instantiate this module using environment variables */ @@ -91,11 +60,11 @@ public SdkClientModule() { SdkClientModule(String remoteStoreType, String remoteMetadataEndpoint, String region) { this.remoteStoreType = remoteStoreType; this.remoteMetadataEndpoint = remoteMetadataEndpoint; - this.region = region == null ? "us-west-2" : region; + this.region = region; } @Override - protected void configure() {/* + protected void configure() { if (this.remoteStoreType == null) { log.info("Using local opensearch cluster as metadata store"); bind(SdkClient.class).to(LocalClusterIndicesClient.class); @@ -114,8 +83,7 @@ protected void configure() {/* default: log.info("Using local opensearch cluster as metadata store"); bind(SdkClient.class).to(LocalClusterIndicesClient.class); - }*/ - bind(SdkClient.class).toInstance(new DDBOpenSearchClient(createDynamoDbClient())); + } } private DynamoDbClient createDynamoDbClient() { @@ -123,9 +91,14 @@ private DynamoDbClient createDynamoDbClient() { throw new IllegalStateException("REGION environment variable needs to be set!"); } - return DynamoDbClient.builder() - .region(Region.of(this.region)) - .build(); + AwsCredentialsProviderChain credentialsProviderChain = AwsCredentialsProviderChain + .builder() + .addCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) + .addCredentialsProvider(ContainerCredentialsProvider.builder().build()) + .addCredentialsProvider(InstanceProfileCredentialsProvider.create()) + .build(); + + return DynamoDbClient.builder().region(Region.of(this.region)).credentialsProvider(credentialsProviderChain).build(); } private OpenSearchClient createOpenSearchClient() { diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index 0da4bc21a5..22faf00b25 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -1,5 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + package org.opensearch.ml.sdkclient; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.TimeUnit; + import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; @@ -10,14 +28,12 @@ import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchException; -import org.opensearch.client.opensearch.core.IndexRequest; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -29,6 +45,7 @@ import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; + import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; @@ -38,16 +55,6 @@ import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; import software.amazon.awssdk.services.dynamodb.model.PutItemResponse; -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; -import java.util.concurrent.TimeUnit; - -import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; -import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; - - public class DDBOpenSearchClientTests extends OpenSearchTestCase { private static final String TEST_ID = "123"; @@ -65,16 +72,15 @@ public class DDBOpenSearchClientTests extends OpenSearchTestCase { private ArgumentCaptor deleteItemRequestArgumentCaptor; private TestDataObject testDataObject; - private static TestThreadPool testThreadPool = new TestThreadPool( - LocalClusterIndicesClientTests.class.getName(), - new ScalingExecutorBuilder( - GENERAL_THREAD_POOL, - 1, - Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), - TimeValue.timeValueMinutes(1), - ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL - ) + LocalClusterIndicesClientTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) ); @AfterClass @@ -93,16 +99,16 @@ public void setup() { @Test public void testPutDataObject_HappyCase() throws IOException { PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() - .index(TEST_INDEX) - .id(TEST_ID) - .tenantId(TENANT_ID) - .dataObject(testDataObject).build(); - Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) - .thenReturn(PutItemResponse.builder().build()); + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID) + .dataObject(testDataObject) + .build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); PutDataObjectResponse response = sdkClient - .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture() - .join(); + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); Assert.assertEquals(TEST_ID, response.id()); Assert.assertEquals(true, response.created()); @@ -119,14 +125,12 @@ public void testPutDataObject_HappyCase() throws IOException { @Test public void testPutDataObject_NullTenantId_SetsDefaultTenantId() throws IOException { PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() - .index(TEST_INDEX) - .id(TEST_ID) - .dataObject(testDataObject).build(); - Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) - .thenReturn(PutItemResponse.builder().build()); - sdkClient.putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture() - .join(); + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + sdkClient.putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); @@ -135,14 +139,12 @@ public void testPutDataObject_NullTenantId_SetsDefaultTenantId() throws IOExcept @Test public void testPutDataObject_NullId_SetsDefaultTenantId() throws IOException { - PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() - .index(TEST_INDEX) - .dataObject(testDataObject).build(); - Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) - .thenReturn(PutItemResponse.builder().build()); - PutDataObjectResponse response = sdkClient.putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture() - .join(); + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + PutDataObjectResponse response = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); @@ -153,14 +155,14 @@ public void testPutDataObject_NullId_SetsDefaultTenantId() throws IOException { @Test public void testPutDataObject_DDBException_ThrowsException() { PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() - .index(TEST_INDEX) - .id(TEST_ID) - .dataObject(testDataObject).build(); - Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) - .thenThrow(new RuntimeException("Test exception")); + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))).thenThrow(new RuntimeException("Test exception")); CompletableFuture future = sdkClient - .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture(); + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); assertEquals(OpenSearchException.class, ce.getCause().getClass()); @@ -168,21 +170,18 @@ public void testPutDataObject_DDBException_ThrowsException() { @Test public void testGetDataObject_HappyCase() throws IOException { - GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() - .index(TEST_INDEX) - .id(TEST_ID) - .tenantId(TENANT_ID).build(); + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).tenantId(TENANT_ID).build(); XContentBuilder sourceBuilder = XContentFactory.jsonBuilder(); XContentBuilder builder = testDataObject.toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); - GetItemResponse getItemResponse = GetItemResponse.builder().item(Map.ofEntries( - Map.entry("source", AttributeValue.builder().s(builder.toString()).build()) - )).build(); - Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) - .thenReturn(getItemResponse); + GetItemResponse getItemResponse = GetItemResponse + .builder() + .item(Map.ofEntries(Map.entry("source", AttributeValue.builder().s(builder.toString()).build()))) + .build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))).thenReturn(getItemResponse); GetDataObjectResponse response = sdkClient - .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture() - .join(); + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); Assert.assertEquals(TEST_INDEX, getItemRequest.tableName()); @@ -195,33 +194,26 @@ public void testGetDataObject_HappyCase() throws IOException { @Test public void testGetDataObject_NoExistingDoc() throws IOException { - GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() - .index(TEST_INDEX) - .id(TEST_ID) - .tenantId(TENANT_ID).build(); + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).tenantId(TENANT_ID).build(); GetItemResponse getItemResponse = GetItemResponse.builder().build(); - Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) - .thenReturn(getItemResponse); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))).thenReturn(getItemResponse); GetDataObjectResponse response = sdkClient - .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture() - .join(); + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); Assert.assertEquals(TEST_ID, response.id()); Assert.assertFalse(response.parser().isPresent()); } @Test public void testGetDataObject_UseDefaultTenantIdIfNull() throws IOException { - GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() - .index(TEST_INDEX) - .id(TEST_ID).build(); + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); GetItemResponse getItemResponse = GetItemResponse.builder().build(); - Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) - .thenReturn(getItemResponse); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))).thenReturn(getItemResponse); GetDataObjectResponse response = sdkClient - .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture() - .join(); + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); Assert.assertEquals("DEFAULT_TENANT", getItemRequest.key().get("tenant_id").s()); @@ -229,15 +221,11 @@ public void testGetDataObject_UseDefaultTenantIdIfNull() throws IOException { @Test public void testGetDataObject_DDBException_ThrowsOSException() throws IOException { - GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() - .index(TEST_INDEX) - .id(TEST_ID) - .tenantId(TENANT_ID).build(); - Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) - .thenThrow(new RuntimeException("Test exception")); + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).tenantId(TENANT_ID).build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))).thenThrow(new RuntimeException("Test exception")); CompletableFuture future = sdkClient - .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture(); + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); assertEquals(OpenSearchException.class, ce.getCause().getClass()); } @@ -245,12 +233,15 @@ public void testGetDataObject_DDBException_ThrowsOSException() throws IOExceptio @Test public void testDeleteDataObject_HappyCase() { DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder() - .id(TEST_ID).index(TEST_INDEX).tenantId(TENANT_ID).build(); - Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())) - .thenReturn(DeleteItemResponse.builder().build()); + .id(TEST_ID) + .index(TEST_INDEX) + .tenantId(TENANT_ID) + .build(); + Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())).thenReturn(DeleteItemResponse.builder().build()); DeleteDataObjectResponse deleteResponse = sdkClient - .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture().join(); + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); Assert.assertEquals(TEST_INDEX, deleteItemRequest.tableName()); Assert.assertEquals(TENANT_ID, deleteItemRequest.key().get("tenant_id").s()); @@ -261,13 +252,12 @@ public void testDeleteDataObject_HappyCase() { @Test public void testDeleteDataObject_NullTenantId_UsesDefaultTenantId() { - DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder() - .id(TEST_ID).index(TEST_INDEX).build(); - Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())) - .thenReturn(DeleteItemResponse.builder().build()); + DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder().id(TEST_ID).index(TEST_INDEX).build(); + Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())).thenReturn(DeleteItemResponse.builder().build()); DeleteDataObjectResponse deleteResponse = sdkClient - .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) - .toCompletableFuture().join(); + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get("tenant_id").s()); } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java index b1cd9d7db6..4c1b3e71ff 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java @@ -10,7 +10,6 @@ import static org.mockito.Mockito.mock; -import org.junit.Before; import org.opensearch.common.inject.AbstractModule; import org.opensearch.common.inject.Guice; import org.opensearch.common.inject.Injector; @@ -23,11 +22,6 @@ @ThreadLeakScope(ThreadLeakScope.Scope.NONE) // remote http client is never closed public class SdkClientModuleTests extends OpenSearchTestCase { - @Before - public void setup() { - System.setProperty("opensearch.path.conf", "/tmp"); - } - private Module localClientModule = new AbstractModule() { @Override protected void configure() { From 13d75f43b2f56a254bec416485cd126f8b769c2e Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Fri, 14 Jun 2024 08:16:51 -0700 Subject: [PATCH 03/10] multi-tenancy for models (create, get, delete, update) + update connector (#2546) * multi-tenancy for models (create, get, delete) Signed-off-by: Dhrubo Saha * added update connector + update model Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha Signed-off-by: Arjun kumar Giri --- .../ml/client/MachineLearningClient.java | 20 ++ .../ml/client/MachineLearningNodeClient.java | 7 + .../ml/client/MachineLearningClientTest.java | 5 + .../org/opensearch/ml/common/MLModel.java | 16 +- .../connector/MLUpdateConnectorRequest.java | 4 +- .../transport/model/MLModelDeleteRequest.java | 9 +- .../transport/model/MLModelGetRequest.java | 6 +- .../transport/model/MLUpdateModelInput.java | 20 +- .../register/MLRegisterModelInput.java | 23 +- .../MLUpdateConnectorRequestTests.java | 2 +- .../MetricsCorrelation.java | 2 +- .../DeleteConnectorTransportAction.java | 41 +--- .../GetConnectorTransportAction.java | 94 ++++---- .../UpdateConnectorTransportAction.java | 74 +++++-- .../models/DeleteModelTransportAction.java | 17 +- .../models/GetModelTransportAction.java | 44 +++- .../models/UpdateModelTransportAction.java | 176 +++++++++------ .../TransportRegisterModelAction.java | 59 +++-- .../helper/ConnectorAccessControlHelper.java | 98 ++++++++- .../opensearch/ml/model/MLModelManager.java | 2 + .../ml/plugin/MachineLearningPlugin.java | 6 +- .../ml/rest/RestMLDeleteModelAction.java | 11 +- .../ml/rest/RestMLGetModelAction.java | 12 +- .../ml/rest/RestMLRegisterModelAction.java | 3 + .../ml/rest/RestMLUpdateConnectorAction.java | 4 +- .../ml/rest/RestMLUpdateModelAction.java | 11 +- .../ml/action/MLCommonsIntegTestCase.java | 2 +- .../DeleteConnectorTransportActionTests.java | 122 +++++----- .../GetConnectorTransportActionTests.java | 113 ++++------ .../UpdateConnectorTransportActionTests.java | 40 +++- .../DeleteModelTransportActionTests.java | 7 +- .../models/GetModelTransportActionTests.java | 14 +- .../UpdateModelTransportActionTests.java | 49 ++++- .../TransportRegisterModelActionTests.java | 24 +- .../ConnectorAccessControlHelperTests.java | 208 +++++++++++++++++- .../ml/rest/RestMLDeleteModelActionTests.java | 11 +- .../ml/rest/RestMLGetModelActionTests.java | 23 +- .../rest/RestMLGetModelGroupActionTests.java | 21 +- .../ml/rest/RestMLUpdateModelActionTests.java | 10 +- 39 files changed, 1001 insertions(+), 409 deletions(-) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index f0d1c24d1f..c226815992 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -135,6 +135,18 @@ default ActionFuture getModel(String modelId) { return actionFuture; } + /** + * Get MLModel and return ActionFuture. + * For more info on get model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-model-information + * @param modelId id of the model + * @return ActionFuture of ml model + */ + default ActionFuture getModel(String modelId, String tenantId) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + getModel(modelId, tenantId, actionFuture); + return actionFuture; + } + /** * Get MLModel and return model in listener * For more info on get model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-model-information @@ -143,6 +155,14 @@ default ActionFuture getModel(String modelId) { */ void getModel(String modelId, ActionListener listener); + /** + * Get MLModel and return model in listener + * For more info on get model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-model-information + * @param modelId id of the model + * @param listener action listener + */ + void getModel(String modelId, String tenantId, ActionListener listener); + /** * Get MLTask and return ActionFuture. * For more info on get task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-task-information diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 0ceee0575d..7ac4549d9e 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -166,6 +166,13 @@ public void getModel(String modelId, ActionListener listener) { client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener)); } + @Override + public void getModel(String modelId, String tenantId, ActionListener listener) { + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); + + client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener)); + } + private ActionListener getMlGetModelResponseActionListener(ActionListener listener) { ActionListener internalListener = ActionListener.wrap(predictionResponse -> { listener.onResponse(predictionResponse.getMlModel()); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 3dbc680447..f14657506f 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -150,6 +150,11 @@ public void getModel(String modelId, ActionListener listener) { listener.onResponse(mlModel); } + @Override + public void getModel(String modelId, String tenantId, ActionListener listener) { + listener.onResponse(mlModel); + } + @Override public void deleteModel(String modelId, ActionListener listener) { listener.onResponse(deleteResponse); diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 363fa4bb7d..147db00bb3 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -38,6 +38,7 @@ import java.util.Set; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID; import static org.opensearch.ml.common.CommonValue.USER; import static org.opensearch.ml.common.connector.Connector.createConnector; import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; @@ -142,6 +143,7 @@ public class MLModel implements ToXContentObject { private Connector connector; private String connectorId; private Guardrails guardrails; + private String tenantId; /** * Model interface is a map that contains the input and output fields of the model, with JSON schema as the value. @@ -206,7 +208,8 @@ public MLModel(String name, Connector connector, String connectorId, Guardrails guardrails, - Map modelInterface) { + Map modelInterface, + String tenantId) { this.name = name; this.modelGroupId = modelGroupId; this.algorithm = algorithm; @@ -241,6 +244,7 @@ public MLModel(String name, this.connectorId = connectorId; this.guardrails = guardrails; this.modelInterface = modelInterface; + this.tenantId = tenantId; } public MLModel(StreamInput input) throws IOException { @@ -305,6 +309,7 @@ public MLModel(StreamInput input) throws IOException { if (input.readBoolean()) { modelInterface = input.readMap(StreamInput::readString, StreamInput::readString); } + tenantId = input.readOptionalString(); } } @@ -388,6 +393,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalString(tenantId); } @Override @@ -495,6 +501,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (modelInterface != null) { builder.field(INTERFACE_FIELD, modelInterface); } + if (tenantId != null) { + builder.field(TENANT_ID, tenantId); + } builder.endObject(); return builder; } @@ -540,6 +549,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws String connectorId = null; Guardrails guardrails = null; Map modelInterface = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -674,6 +684,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws case INTERFACE_FIELD: modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); break; + case TENANT_ID: + tenantId = parser.text(); + break; default: parser.skipChildren(); break; @@ -714,6 +727,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws .connectorId(connectorId) .guardrails(guardrails) .modelInterface(modelInterface) + .tenantId(tenantId) .build(); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java index 089180cdc5..a09aa5520a 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java @@ -61,9 +61,9 @@ public ActionRequestValidationException validate() { return exception; } - public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException { + public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId, String tenantId) throws IOException { MLCreateConnectorInput updateContent = MLCreateConnectorInput.parse(parser, true); - + updateContent.setTenantId(tenantId); return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build(); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java index a42cf1d071..19c9d1c699 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java @@ -25,20 +25,27 @@ public class MLModelDeleteRequest extends ActionRequest { @Getter String modelId; + @Getter + String tenantId; + @Builder - public MLModelDeleteRequest(String modelId) { + public MLModelDeleteRequest(String modelId, String tenantId) { + this.modelId = modelId; + this.tenantId = tenantId; } public MLModelDeleteRequest(StreamInput input) throws IOException { super(input); this.modelId = input.readString(); + this.tenantId = input.readOptionalString(); } @Override public void writeTo(StreamOutput output) throws IOException { super.writeTo(output); output.writeString(modelId); + output.writeOptionalString(tenantId); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java index 7cad570f1d..0598bca691 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java @@ -35,12 +35,14 @@ public class MLModelGetRequest extends ActionRequest { // delete/update options, we also perform get operation. This field is to distinguish between // these two situations. boolean isUserInitiatedGetRequest; + String tenantId; @Builder - public MLModelGetRequest(String modelId, boolean returnContent, boolean isUserInitiatedGetRequest) { + public MLModelGetRequest(String modelId, boolean returnContent, boolean isUserInitiatedGetRequest, String tenantId) { this.modelId = modelId; this.returnContent = returnContent; this.isUserInitiatedGetRequest = isUserInitiatedGetRequest; + this.tenantId = tenantId; } public MLModelGetRequest(StreamInput in) throws IOException { @@ -48,6 +50,7 @@ public MLModelGetRequest(StreamInput in) throws IOException { this.modelId = in.readString(); this.returnContent = in.readBoolean(); this.isUserInitiatedGetRequest = in.readBoolean(); + this.tenantId = in.readOptionalString(); } @Override @@ -56,6 +59,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(this.modelId); out.writeBoolean(returnContent); out.writeBoolean(isUserInitiatedGetRequest); + out.writeOptionalString(tenantId); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index 03047cf692..08c983c78b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -31,6 +31,7 @@ import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID; import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; @@ -70,6 +71,7 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { private MLCreateConnectorInput connector; private Instant lastUpdateTime; private Guardrails guardrails; + private String tenantId; private Map modelInterface; @@ -77,7 +79,7 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, Boolean isEnabled, MLRateLimiter rateLimiter, MLModelConfig modelConfig, MLDeploySetting deploySetting, Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime, - Guardrails guardrails, Map modelInterface) { + Guardrails guardrails, Map modelInterface, String tenantId) { this.modelId = modelId; this.description = description; this.version = version; @@ -93,6 +95,7 @@ public MLUpdateModelInput(String modelId, String description, String version, St this.lastUpdateTime = lastUpdateTime; this.guardrails = guardrails; this.modelInterface = modelInterface; + this.tenantId = tenantId; } public MLUpdateModelInput(StreamInput in) throws IOException { @@ -130,6 +133,8 @@ public MLUpdateModelInput(StreamInput in) throws IOException { modelInterface = in.readMap(StreamInput::readString, StreamInput::readString); } } + //TODO: I will add BWC check later here. + tenantId = in.readOptionalString(); } @Override @@ -176,6 +181,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelInterface != null) { builder.field(MLModel.INTERFACE_FIELD, modelInterface); } + if (tenantId != null) { + builder.field(TENANT_ID, tenantId); + } builder.endObject(); return builder; } @@ -237,6 +245,8 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } + //TODO: I will add BWC check here later. + out.writeOptionalString(tenantId); } public static MLUpdateModelInput parse(XContentParser parser) throws IOException { @@ -255,6 +265,7 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException Instant lastUpdateTime = null; Guardrails guardrails = null; Map modelInterface = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -294,6 +305,9 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException case MLModel.INTERFACE_FIELD: modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); break; + case TENANT_ID: + tenantId = parser.text(); + break; default: parser.skipChildren(); break; @@ -303,6 +317,6 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException // automatically. return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, isEnabled, rateLimiter, modelConfig, deploySetting, updatedConnector, connectorId, connector, lastUpdateTime, guardrails, - modelInterface); + modelInterface, tenantId); } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 9eb5ba6b4f..1146d201da 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -39,6 +39,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; import static org.opensearch.ml.common.connector.Connector.createConnector; +import static org.opensearch.ml.common.input.Constants.TENANT_ID; import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; @@ -103,6 +104,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private Guardrails guardrails; private Map modelInterface; + private String tenantId; @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, @@ -127,7 +129,8 @@ public MLRegisterModelInput(FunctionName functionName, Boolean doesVersionCreateModelGroup, Boolean isHidden, Guardrails guardrails, - Map modelInterface) { + Map modelInterface, + String tenantId) { this.functionName = Objects.requireNonNullElse(functionName, FunctionName.TEXT_EMBEDDING); if (modelName == null) { throw new IllegalArgumentException("model name is null"); @@ -166,6 +169,7 @@ public MLRegisterModelInput(FunctionName functionName, this.isHidden = isHidden; this.guardrails = guardrails; this.modelInterface = modelInterface; + this.tenantId = tenantId; } public MLRegisterModelInput(StreamInput in) throws IOException { @@ -225,6 +229,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.modelInterface = in.readMap(StreamInput::readString, StreamInput::readString); } } + this.tenantId = in.readOptionalString(); } @Override @@ -306,6 +311,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } + out.writeOptionalString(tenantId); } @Override @@ -374,6 +380,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelInterface != null) { builder.field(MLModel.INTERFACE_FIELD, modelInterface); } + if (tenantId != null) { + builder.field(TENANT_ID, tenantId); + } builder.endObject(); return builder; } @@ -400,6 +409,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName Boolean isHidden = null; Guardrails guardrails = null; Map modelInterface = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -476,6 +486,9 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case MLModel.INTERFACE_FIELD: modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); break; + case TENANT_ID: + tenantId = parser.text(); + break; default: parser.skipChildren(); break; @@ -484,7 +497,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, isEnabled, rateLimiter, url, hashValue, modelFormat, modelConfig, deploySetting, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, - isHidden, guardrails, modelInterface); + isHidden, guardrails, modelInterface, tenantId); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -510,6 +523,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo Boolean isHidden = null; Guardrails guardrails = null; Map modelInterface = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -593,6 +607,9 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case MLModel.INTERFACE_FIELD: modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); break; + case TENANT_ID: + tenantId = parser.text(); + break; default: parser.skipChildren(); break; @@ -601,6 +618,6 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, isEnabled, rateLimiter, url, hashValue, modelFormat, modelConfig, deploySetting, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, - isHidden, guardrails, modelInterface); + isHidden, guardrails, modelInterface, tenantId); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java index 44e970f95c..71e7505b2d 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java @@ -72,7 +72,7 @@ public void parse_success() throws IOException { XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), null, jsonStr); parser.nextToken(); - MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId); + MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId, null); assertEquals(updateConnectorRequest.getConnectorId(), connectorId); assertTrue(updateConnectorRequest.getUpdateContent().isUpdateConnector()); assertEquals("new version", updateConnectorRequest.getUpdateContent().getVersion()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index c752456b7f..371374e189 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -361,7 +361,7 @@ public MLTask getTask(String taskId) { } public MLModel getModel(String modelId) { - MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, false); + MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, false, null); ActionFuture future = client.execute(MLModelGetAction.INSTANCE, getRequest); MLModelGetResponse response = future.actionGet(5000); return response.getMlModel(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index c9638817f7..b4220fef95 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -13,7 +13,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Objects; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; @@ -34,8 +33,6 @@ import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; -import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; -import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TenantAwareHelper; @@ -47,8 +44,6 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import com.google.common.annotations.VisibleForTesting; - import lombok.extern.log4j.Log4j2; @Log4j2 @@ -88,8 +83,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener handleConnectorAccessValidation(connectorId, tenantId, isAllowed, actionListener), @@ -129,25 +127,17 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(searchResponse -> { SearchHit[] searchHits = searchResponse.getHits().getHits(); if (searchHits.length == 0) { - handleNoModelsUsingConnector(connectorId, tenantId, actionListener); + deleteConnector(connectorId, actionListener); } else { handleModelsUsingConnector(searchHits, connectorId, actionListener); } - }, e -> handleSearchFailure(connectorId, tenantId, e, actionListener)), context::restore)); + }, e -> handleSearchFailure(connectorId, e, actionListener)), context::restore)); } catch (Exception e) { log.error("Failed to check for models using connector: " + connectorId, e); actionListener.onFailure(e); } } - private void handleNoModelsUsingConnector(String connectorId, String tenantId, ActionListener actionListener) { - if (mlFeatureEnabledSetting.isMultiTenancyEnabled() && Objects.nonNull(tenantId)) { - checkConnectorPermission(connectorId, tenantId, actionListener, () -> deleteConnector(connectorId, actionListener)); - } else { - deleteConnector(connectorId, actionListener); - } - } - private void handleModelsUsingConnector(SearchHit[] searchHits, String connectorId, ActionListener actionListener) { log.error(searchHits.length + " models are still using this connector, please delete or update the models first!"); List modelIds = new ArrayList<>(); @@ -165,32 +155,15 @@ private void handleModelsUsingConnector(SearchHit[] searchHits, String connector ); } - private void handleSearchFailure(String connectorId, String tenantId, Exception e, ActionListener actionListener) { + private void handleSearchFailure(String connectorId, Exception e, ActionListener actionListener) { if (e instanceof IndexNotFoundException) { - handleNoModelsUsingConnector(connectorId, tenantId, actionListener); + deleteConnector(connectorId, actionListener); return; } log.error("Failed to search for models using connector: {}", connectorId, e); actionListener.onFailure(e); } - // TODO: merge this method with validateConnectorAccess and use sdkClient not client. - @VisibleForTesting - void checkConnectorPermission( - String connectorId, - String tenantId, - ActionListener actionListener, - Runnable deleteAction - ) { - MLConnectorGetRequest mlConnectorGetRequest = new MLConnectorGetRequest(connectorId, tenantId, true); - client.execute(MLConnectorGetAction.INSTANCE, mlConnectorGetRequest, ActionListener.wrap(getResponse -> { - if (TenantAwareHelper - .validateTenantResource(mlFeatureEnabledSetting, tenantId, getResponse.getMlConnector().getTenantId(), actionListener)) { - deleteAction.run(); - } - }, actionListener::onFailure)); - } - private void deleteConnector(String connectorId, ActionListener actionListener) { DeleteRequest deleteRequest = new DeleteRequest(ML_CONNECTOR_INDEX, connectorId); try { diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index 42fc164500..8221b33886 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -5,9 +5,7 @@ package org.opensearch.ml.action.connector; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.opensearch.OpenSearchStatusException; @@ -20,8 +18,6 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; @@ -84,62 +80,48 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - context.restore(); - log.debug("Completed Get Connector Request, id:{}", connectorId); - if (throwable != null) { - Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); - if (cause instanceof IndexNotFoundException) { - log.error("Failed to get connector index", cause); - actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); - } else { - log.error("Failed to get ML connector {}", connectorId, cause); - actionListener.onFailure(new RuntimeException(cause)); - } - } else { - if (r != null && r.parser().isPresent()) { - try { - XContentParser parser = r.parser().get(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector mlConnector = Connector.createConnector(parser); - if (!TenantAwareHelper - .validateTenantResource(mlFeatureEnabledSetting, tenantId, mlConnector.getTenantId(), actionListener)) { - return; - } - mlConnector.removeCredential(); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + context, + getDataObjectRequest, + connectorId, + ActionListener + .wrap( + connector -> handleConnectorAccessValidation(user, tenantId, connector, actionListener), + e -> handleConnectorAccessValidationFailure(connectorId, e, actionListener) + ) + ); - if (connectorAccessControlHelper.hasPermission(user, mlConnector)) { - actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build()); - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to access this connector", - RestStatus.FORBIDDEN - ) - ); - } - } catch (Exception e) { - log.error("Failed to parse ml connector {}", r.id(), e); - actionListener.onFailure(e); - } - } else { - actionListener - .onFailure( - new OpenSearchStatusException( - "Failed to find connector with the provided connector id: " + connectorId, - RestStatus.NOT_FOUND - ) - ); - } - } - }); } catch (Exception e) { - log.error("Failed to get ML connector " + connectorId, e); + log.error("Failed to get ML connector {}", connectorId, e); actionListener.onFailure(e); } + } + + private void handleConnectorAccessValidation( + User user, + String tenantId, + Connector mlConnector, + ActionListener actionListener + ) { + if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlConnector.getTenantId(), actionListener)) { + if (connectorAccessControlHelper.hasPermission(user, mlConnector)) { + actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build()); + } else { + actionListener + .onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN)); + } + } + } + private void handleConnectorAccessValidationFailure( + String connectorId, + Exception e, + ActionListener actionListener + ) { + log.error("Failed to get ML connector: {}", connectorId, e); + actionListener.onFailure(e); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index 970d94aa48..f227b22e2a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -28,6 +28,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -35,13 +36,19 @@ import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.sdk.GetDataObjectRequest; +import org.opensearch.sdk.SdkClient; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -53,8 +60,9 @@ @FieldDefaults(level = AccessLevel.PRIVATE) public class UpdateConnectorTransportAction extends HandledTransportAction { Client client; - + private final SdkClient sdkClient; ConnectorAccessControlHelper connectorAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; MLModelManager mlModelManager; MLEngine mlEngine; volatile List trustedConnectorEndpointsRegex; @@ -64,17 +72,21 @@ public UpdateConnectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, ConnectorAccessControlHelper connectorAccessControlHelper, MLModelManager mlModelManager, Settings settings, ClusterService clusterService, - MLEngine mlEngine + MLEngine mlEngine, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new); this.client = client; + this.sdkClient = sdkClient; this.connectorAccessControlHelper = connectorAccessControlHelper; this.mlModelManager = mlModelManager; this.mlEngine = mlEngine; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); clusterService .getClusterSettings() @@ -84,27 +96,49 @@ public UpdateConnectorTransportAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request); + MLCreateConnectorInput mlCreateConnectorInput = mlUpdateConnectorAction.getUpdateContent(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, mlCreateConnectorInput.getTenantId(), listener)) { + return; + } String connectorId = mlUpdateConnectorAction.getConnectorId(); + FetchSourceContext fetchSourceContext = new FetchSourceContext(true, Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY); + GetDataObjectRequest getDataObjectRequest = new GetDataObjectRequest.Builder() + .index(ML_CONNECTOR_INDEX) + .id(connectorId) + .fetchSourceContext(fetchSourceContext) + .build(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - connectorAccessControlHelper.getConnector(client, connectorId, ActionListener.wrap(connector -> { - boolean hasPermission = connectorAccessControlHelper.validateConnectorAccess(client, connector); - if (Boolean.TRUE.equals(hasPermission)) { - connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt); - connector.validateConnectorURL(trustedConnectorEndpointsRegex); - UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); - updateRequest.doc(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); - updateUndeployedConnector(connectorId, updateRequest, listener, context); - } else { - listener - .onFailure( - new IllegalArgumentException("You don't have permission to update the connector, connector id: " + connectorId) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception); - listener.onFailure(exception); - })); + connectorAccessControlHelper + .getConnector(sdkClient, client, context, getDataObjectRequest, connectorId, ActionListener.wrap(connector -> { + if (TenantAwareHelper + .validateTenantResource( + mlFeatureEnabledSetting, + mlCreateConnectorInput.getTenantId(), + connector.getTenantId(), + listener + )) { + boolean hasPermission = connectorAccessControlHelper.validateConnectorAccess(client, connector); + if (Boolean.TRUE.equals(hasPermission)) { + connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt); + connector.validateConnectorURL(trustedConnectorEndpointsRegex); + UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); + updateRequest + .doc(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + updateUndeployedConnector(connectorId, updateRequest, listener, context); + } else { + listener + .onFailure( + new IllegalArgumentException( + "You don't have permission to update the connector, connector id: " + connectorId + ) + ); + } + } + }, exception -> { + log.error("Unable to find the connector with ID {}. Details: {}", connectorId, exception); + listener.onFailure(exception); + })); } catch (Exception e) { log.error("Failed to update ML connector for connector id {}. Details {}:", connectorId, e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 9faee40aa6..c6a8befcce 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -48,7 +48,9 @@ import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -74,6 +76,7 @@ public class DeleteModelTransportAction extends HandledTransportAction actionListener) { MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.fromActionRequest(request); String modelId = mlModelDeleteRequest.getModelId(); - MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false, false); + String tenantId = mlModelDeleteRequest.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false, false, tenantId); FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent()); GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); User user = RestActionUtils.getUserContext(client); @@ -114,6 +123,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - Client client; - NamedXContentRegistry xContentRegistry; - ClusterService clusterService; + final Client client; + final SdkClient sdkClient; + final NamedXContentRegistry xContentRegistry; + final ClusterService clusterService; - ModelAccessControlHelper modelAccessControlHelper; + final ModelAccessControlHelper modelAccessControlHelper; + + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; Settings settings; @@ -63,30 +70,43 @@ public GetModelTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, Settings settings, NamedXContentRegistry xContentRegistry, ClusterService clusterService, - ModelAccessControlHelper modelAccessControlHelper + ModelAccessControlHelper modelAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLModelGetAction.NAME, transportService, actionFilters, MLModelGetRequest::new); this.client = client; + this.sdkClient = sdkClient; this.settings = settings; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLModelGetRequest mlModelGetRequest = MLModelGetRequest.fromActionRequest(request); String modelId = mlModelGetRequest.getModelId(); + String tenantId = mlModelGetRequest.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent()); + GetDataObjectRequest getDataObjectRequest = new GetDataObjectRequest.Builder() + .index(ML_MODEL_INDEX) + .id(modelId) + .fetchSourceContext(fetchSourceContext) + .build(); GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); User user = RestActionUtils.getUserContext(client); boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { @@ -94,6 +114,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - log.error("Failed to validate Access for Model Id " + modelId, e); + log.error("Failed to validate Access for Model Id {}", modelId, e); wrappedListener.onFailure(e); })); } } catch (Exception e) { - log.error("Failed to parse ml model " + r.getId(), e); + log.error("Failed to parse ml model {}", r.getId(), e); wrappedListener.onFailure(e); } } else { @@ -147,12 +171,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { Client client; - + private final SdkClient sdkClient; Settings settings; ClusterService clusterService; ModelAccessControlHelper modelAccessControlHelper; ConnectorAccessControlHelper connectorAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; MLModelManager mlModelManager; MLModelGroupManager mlModelGroupManager; MLEngine mlEngine; @@ -88,16 +92,19 @@ public UpdateModelTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, ConnectorAccessControlHelper connectorAccessControlHelper, ModelAccessControlHelper modelAccessControlHelper, MLModelManager mlModelManager, MLModelGroupManager mlModelGroupManager, Settings settings, ClusterService clusterService, - MLEngine mlEngine + MLEngine mlEngine, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLUpdateModelAction.NAME, transportService, actionFilters, MLUpdateModelRequest::new); this.client = client; + this.sdkClient = sdkClient; this.modelAccessControlHelper = modelAccessControlHelper; this.connectorAccessControlHelper = connectorAccessControlHelper; this.mlModelManager = mlModelManager; @@ -105,6 +112,7 @@ public UpdateModelTransportAction( this.clusterService = clusterService; this.mlEngine = mlEngine; this.settings = settings; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); clusterService .getClusterSettings() @@ -116,6 +124,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { - if (!isModelDeploying(mlModel.getModelState())) { - FunctionName functionName = mlModel.getAlgorithm(); - // TODO: Support update as well as model/user level throttling in all other DLModel categories - if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { - if (mlModel.getIsHidden() != null && mlModel.getIsHidden()) { - if (isSuperAdmin) { - updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, wrappedListener); + if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), actionListener)) { + if (!isModelDeploying(mlModel.getModelState())) { + FunctionName functionName = mlModel.getAlgorithm(); + // TODO: Support update as well as model/user level throttling in all other DLModel categories + if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { + if (mlModel.getIsHidden() != null && mlModel.getIsHidden()) { + if (isSuperAdmin) { + updateRemoteOrTextEmbeddingModel(modelId, tenantId, updateModelInput, mlModel, user, wrappedListener); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) + modelAccessControlHelper + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + client, + ActionListener.wrap(hasPermission -> { + if (hasPermission) { + updateRemoteOrTextEmbeddingModel( + modelId, + tenantId, + updateModelInput, + mlModel, + user, + wrappedListener + ); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model, model ID " + + modelId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + "Permission denied: Unable to update the model with ID {}. Details: {}", + modelId, + exception + ); + wrappedListener.onFailure(exception); + }) ); } + } else { - modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { - if (hasPermission) { - updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, wrappedListener); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model, model ID " - + modelId, - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); - wrappedListener.onFailure(exception); - })); + wrappedListener + .onFailure( + new OpenSearchStatusException( + "The function category " + functionName.toString() + " is not supported at this time.", + RestStatus.FORBIDDEN + ) + ); } - } else { wrappedListener .onFailure( new OpenSearchStatusException( - "The function category " + functionName.toString() + " is not supported at this time.", - RestStatus.FORBIDDEN + "Model is deploying. Please wait for the model to complete deployment. model ID " + modelId, + RestStatus.CONFLICT ) ); } - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Model is deploying. Please wait for the model to complete deployment. model ID " + modelId, - RestStatus.CONFLICT - ) - ); } }, e -> wrappedListener @@ -196,6 +227,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (hasNewConnectorPermission) { - updateModelWithRegisteringToAnotherModelGroup( - modelId, - newModelGroupId, - user, - updateModelInput, - wrappedListener, - isUpdateModelCache - ); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to update the connector, connector id: " + newConnectorId, - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", newConnectorId, exception); - wrappedListener.onFailure(exception); - })); + connectorAccessControlHelper + .validateConnectorAccess( + sdkClient, + client, + newConnectorId, + tenantId, + mlFeatureEnabledSetting, + ActionListener.wrap(hasNewConnectorPermission -> { + if (hasNewConnectorPermission) { + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + user, + updateModelInput, + wrappedListener, + isUpdateModelCache + ); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "You don't have permission to update the connector, connector id: " + newConnectorId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", newConnectorId, exception); + wrappedListener.onFailure(exception); + }) + ); } else { wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 765a076c32..8356094803 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -63,6 +63,8 @@ import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.MLExceptionUtils; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.sdk.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -82,6 +84,7 @@ public class TransportRegisterModelAction extends HandledTransportAction listener) { MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, registerModelInput.getTenantId(), listener)) { + return; + } if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); } @@ -236,26 +244,35 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener< FunctionName functionName = registerModelInput.getFunctionName(); if (FunctionName.REMOTE == functionName) { if (Strings.isNotBlank(registerModelInput.getConnectorId())) { - connectorAccessControlHelper.validateConnectorAccess(client, registerModelInput.getConnectorId(), ActionListener.wrap(r -> { - if (Boolean.TRUE.equals(r)) { - createModelGroup(registerModelInput, listener); - } else { - listener - .onFailure( - new IllegalArgumentException( - "You don't have permission to use the connector provided, connector id: " - + registerModelInput.getConnectorId() - ) - ); - } - }, e -> { - log - .error( - "You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId(), - e - ); - listener.onFailure(e); - })); + connectorAccessControlHelper + .validateConnectorAccess( + sdkClient, + client, + registerModelInput.getConnectorId(), + registerModelInput.getTenantId(), + mlFeatureEnabledSetting, + ActionListener.wrap(r -> { + if (Boolean.TRUE.equals(r)) { + createModelGroup(registerModelInput, listener); + } else { + listener + .onFailure( + new IllegalArgumentException( + "You don't have permission to use the connector provided, connector id: " + + registerModelInput.getConnectorId() + ) + ); + } + }, e -> { + log + .error( + "You don't have permission to use the connector provided, connector id: {}", + registerModelInput.getConnectorId(), + e + ); + listener.onFailure(e); + }) + ); } else { validateInternalConnector(registerModelInput); ActionListener dryRunResultListener = ActionListener.wrap(res -> { @@ -350,7 +367,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name())); ActionListener forwardActionListener = ActionListener.wrap(res -> { - log.debug("Register model response: " + res); + log.debug("Register model response: {}", res); if (!clusterService.localNode().getId().equals(nodeId)) { mlTaskManager.remove(taskId); } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index b1096e7e38..2b9935f82e 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -8,7 +8,10 @@ package org.opensearch.ml.helper; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; +import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.apache.lucene.search.join.ScoreMode; import org.opensearch.OpenSearchStatusException; @@ -23,6 +26,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -32,9 +36,14 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.AbstractConnector; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.sdk.GetDataObjectRequest; +import org.opensearch.sdk.SdkClient; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import lombok.extern.log4j.Log4j2; @@ -70,13 +79,49 @@ public void validateConnectorAccess(Client client, String connectorId, ActionLis getConnector(client, connectorId, ActionListener.wrap(connector -> { boolean hasPermission = hasPermission(user, connector); wrappedListener.onResponse(hasPermission); - }, e -> { wrappedListener.onFailure(e); })); + }, wrappedListener::onFailure)); } catch (Exception e) { log.error("Failed to validate Access for connector:" + connectorId, e); listener.onFailure(e); } } + public void validateConnectorAccess( + SdkClient sdkClient, + Client client, + String connectorId, + String tenantId, + MLFeatureEnabledSetting mlFeatureEnabledSetting, + ActionListener listener + ) { + + User user = RestActionUtils.getUserContext(client); + if (!mlFeatureEnabledSetting.isMultiTenancyEnabled()) { + if (isAdmin(user) || accessControlNotEnabled(user)) { + listener.onResponse(true); + return; + } + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + FetchSourceContext fetchSourceContext = getFetchSourceContext(true); + GetDataObjectRequest getDataObjectRequest = new GetDataObjectRequest.Builder() + .index(ML_CONNECTOR_INDEX) + .id(connectorId) + .fetchSourceContext(fetchSourceContext) + .build(); + getConnector(sdkClient, client, context, getDataObjectRequest, connectorId, ActionListener.wrap(connector -> { + if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, connector.getTenantId(), listener)) { + boolean hasPermission = hasPermission(user, connector); + wrappedListener.onResponse(hasPermission); + } + }, wrappedListener::onFailure)); + } catch (Exception e) { + log.error("Failed to validate Access for connector:{}", connectorId, e); + listener.onFailure(e); + } + } + public boolean validateConnectorAccess(Client client, Connector connector) { User user = RestActionUtils.getUserContext(client); if (isAdmin(user) || accessControlNotEnabled(user)) { @@ -85,6 +130,8 @@ public boolean validateConnectorAccess(Client client, Connector connector) { return hasPermission(user, connector); } + // TODO will remove this method in favor of other getConnector method. This method is still being used in update model/update connect. + // I'll remove this method when I'll refactor update methods. public void getConnector(Client client, String connectorId, ActionListener listener) { GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); client.get(getRequest, ActionListener.wrap(r -> { @@ -109,6 +156,55 @@ public void getConnector(Client client, String connectorId, ActionListener listener + ) { + + sdkClient + .getDataObjectAsync(getDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + context.restore(); + log.debug("Completed Get Connector Request, id:{}", connectorId); + if (throwable != null) { + Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); + if (cause instanceof IndexNotFoundException) { + log.error("Failed to get connector index", cause); + listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to find connector {}", connectorId, cause); + listener.onFailure(new RuntimeException(cause)); + } + } else { + if (r != null && r.parser().isPresent()) { + try { + XContentParser parser = r.parser().get(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector mlConnector = Connector.createConnector(parser); + mlConnector.removeCredential(); + listener.onResponse(mlConnector); + } catch (Exception e) { + log.error("Failed to parse ml connector {}", r.id(), e); + listener.onFailure(e); + } + } else { + listener + .onFailure( + new OpenSearchStatusException( + "Failed to find connector with the provided connector id: " + connectorId, + RestStatus.NOT_FOUND + ) + ); + } + } + }); + + } + public boolean skipConnectorAccessControl(User user) { // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin // Case 2: If Security is enabled and filter is disabled, proceed with search as diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 30cc0a0567..3e91edd8b3 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -540,6 +540,7 @@ private void indexRemoteModel( .isHidden(registerModelInput.getIsHidden()) .guardrails(registerModelInput.getGuardrails()) .modelInterface(registerModelInput.getModelInterface()) + .tenantId(registerModelInput.getTenantId()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); @@ -607,6 +608,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St .isHidden(registerModelInput.getIsHidden()) .guardrails(registerModelInput.getGuardrails()) .modelInterface(registerModelInput.getModelInterface()) + .tenantId(registerModelInput.getTenantId()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 63200c56b2..09cec6245b 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -702,8 +702,8 @@ public List getRestHandlers( RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction(); RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager, mlFeatureEnabledSetting); RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(mlFeatureEnabledSetting); - RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(); - RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(); + RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); + RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting); RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction(); RestMLGetTaskAction restMLGetTaskAction = new RestMLGetTaskAction(); RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction(); @@ -723,7 +723,7 @@ public List getRestHandlers( RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(); RestMLGetModelGroupAction restMLGetModelGroupAction = new RestMLGetModelGroupAction(); RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(); - RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction(); + RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction(mlFeatureEnabledSetting); RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(clusterService, settings, mlFeatureEnabledSetting); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelAction.java index d8e0b9f3b6..29a14e94c3 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getTenantID; import java.io.IOException; import java.util.List; @@ -15,6 +16,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -27,7 +29,11 @@ public class RestMLDeleteModelAction extends BaseRestHandler { private static final String ML_DELETE_MODEL_ACTION = "ml_delete_model_action"; - public void RestMLDeleteModelAction() {} + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + public RestMLDeleteModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -43,8 +49,9 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String modelId = request.param(PARAMETER_MODEL_ID); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); - MLModelDeleteRequest mlModelDeleteRequest = new MLModelDeleteRequest(modelId); + MLModelDeleteRequest mlModelDeleteRequest = new MLModelDeleteRequest(modelId, tenantId); return channel -> client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, new RestToXContentListener<>(channel)); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java index 097bc6fb77..1f69c035c8 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.RestActionUtils.getTenantID; import static org.opensearch.ml.utils.RestActionUtils.returnContent; import java.io.IOException; @@ -17,6 +18,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelGetRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -27,10 +29,14 @@ public class RestMLGetModelAction extends BaseRestHandler { private static final String ML_GET_MODEL_ACTION = "ml_get_model_action"; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + /** * Constructor */ - public RestMLGetModelAction() {} + public RestMLGetModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -59,7 +65,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client MLModelGetRequest getRequest(RestRequest request) throws IOException { String modelId = getParameterId(request, PARAMETER_MODEL_ID); boolean returnContent = returnContent(request); - - return new MLModelGetRequest(modelId, returnContent, true); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + return new MLModelGetRequest(modelId, returnContent, true, tenantId); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java index 68fd73b20a..4d16ddef3d 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java @@ -12,6 +12,7 @@ import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_VERSION; +import static org.opensearch.ml.utils.RestActionUtils.getTenantID; import java.io.IOException; import java.util.List; @@ -93,10 +94,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client */ @VisibleForTesting MLRegisterModelRequest getRequest(RestRequest request) throws IOException { + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); boolean loadModel = request.paramAsBoolean(PARAMETER_DEPLOY_MODEL, false); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel); + mlInput.setTenantId(tenantId); if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } else if (FunctionName.isDLModel(mlInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java index b6e3822318..3a563c8abd 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.RestActionUtils.getTenantID; import java.io.IOException; import java.util.List; @@ -65,11 +66,12 @@ private MLUpdateConnectorRequest getRequest(RestRequest request) throws IOExcept } String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); try { XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - return MLUpdateConnectorRequest.parse(parser, connectorId); + return MLUpdateConnectorRequest.parse(parser, connectorId, tenantId); } catch (IllegalStateException illegalStateException) { throw new OpenSearchParseException(illegalStateException.getMessage()); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java index 5a40ae8c47..fe1dbcc415 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.RestActionUtils.getTenantID; import java.io.IOException; import java.util.List; @@ -22,6 +23,7 @@ import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -31,6 +33,11 @@ public class RestMLUpdateModelAction extends BaseRestHandler { private static final String ML_UPDATE_MODEL_ACTION = "ml_update_model_action"; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + public RestMLUpdateModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -51,7 +58,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client /** * Creates a MLUpdateModelRequest from a RestRequest - * + * * @param request RestRequest * @return MLUpdateModelRequest */ @@ -61,6 +68,7 @@ private MLUpdateModelRequest getRequest(RestRequest request) throws IOException } String modelId = getParameterId(request, PARAMETER_MODEL_ID); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -76,6 +84,7 @@ private MLUpdateModelRequest getRequest(RestRequest request) throws IOException input.setModelId(modelId); input.setVersion(null); input.setUpdatedConnector(null); + input.setTenantId(tenantId); return new MLUpdateModelRequest(input); } catch (IllegalStateException e) { throw new OpenSearchParseException(e.getMessage()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java index c4fb3f906c..463893997d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java @@ -385,7 +385,7 @@ public MLTask getTask(String taskId) { } public MLModel getModel(String modelId) { - MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, true); + MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, true, null); MLModelGetResponse response = client().execute(MLModelGetAction.INSTANCE, getRequest).actionGet(5000); return response.getMlModel(); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java index 6bb41ae889..6f10e1d715 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -23,7 +24,6 @@ import org.apache.lucene.search.TotalHits; import org.junit.AfterClass; import org.junit.Before; -import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -48,7 +48,6 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -57,8 +56,6 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; -import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; -import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; import org.opensearch.ml.settings.MLFeatureEnabledSetting; @@ -146,10 +143,10 @@ public void setup() throws IOException { ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(true); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -162,12 +159,18 @@ public static void cleanup() { ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } - public void testDeleteConnector_Success() throws IOException, InterruptedException { + public void testDeleteConnector_Success() throws InterruptedException { when(deleteResponse.getResult()).thenReturn(DELETED); PlainActionFuture future = PlainActionFuture.newFuture(); future.onResponse(deleteResponse); when(client.delete(any(DeleteRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + SearchResponse searchResponse = getEmptySearchResponse(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -192,6 +195,12 @@ public void testDeleteConnector_ModelIndexNotFoundSuccess() throws IOException, future.onResponse(deleteResponse); when(client.delete(any(DeleteRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onFailure(new IndexNotFoundException("ml_model index not found!")); @@ -209,12 +218,19 @@ public void testDeleteConnector_ModelIndexNotFoundSuccess() throws IOException, assertEquals(DELETED, captor.getValue().getResult()); } + // TODO need to check if it has any value in it or not. public void testDeleteConnector_ConnectorNotFound() throws IOException, InterruptedException { when(deleteResponse.getResult()).thenReturn(NOT_FOUND); PlainActionFuture future = PlainActionFuture.newFuture(); future.onResponse(deleteResponse); when(client.delete(any(DeleteRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + SearchResponse searchResponse = getEmptySearchResponse(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -238,6 +254,12 @@ public void testDeleteConnector_BlockedByModel() throws IOException, Interrupted future.onResponse(deleteResponse); when(client.delete(any(DeleteRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + SearchResponse searchResponse = getNonEmptySearchResponse(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -260,10 +282,10 @@ public void testDeleteConnector_BlockedByModel() throws IOException, Interrupted public void test_UserHasNoAccessException() throws IOException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(false); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -272,6 +294,13 @@ public void test_UserHasNoAccessException() throws IOException { } public void testDeleteConnector_SearchFailure() throws IOException { + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onFailure(new RuntimeException("Search Failed!")); @@ -285,8 +314,16 @@ public void testDeleteConnector_SearchFailure() throws IOException { } public void testDeleteConnector_SearchException() throws IOException { + when(client.threadPool()).thenThrow(new RuntimeException("Thread Context Error!")); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -296,6 +333,13 @@ public void testDeleteConnector_SearchException() throws IOException { public void testDeleteConnector_ResourceNotFoundException() throws IOException, InterruptedException { when(client.delete(any(DeleteRequest.class))).thenThrow(new ResourceNotFoundException("errorMessage")); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + + SearchResponse searchResponse = getEmptySearchResponse(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -316,18 +360,11 @@ public void testDeleteConnector_ResourceNotFoundException() throws IOException, } public void test_ValidationFailedException() throws IOException { - GetResponse getResponse = prepareMLConnector(null); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).search(any(), any()); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -352,55 +389,6 @@ public void testDeleteConnector_MultiTenancyEnabled_NoTenantId() throws Interrup assertEquals("You don't have permission to access this resource", argumentCaptor.getValue().getMessage()); } - @Test - public void testCheckConnectorPermission_AllowedToDelete() { - String connectorId = "connector_id"; - String tenantId = "tenant_id"; - HttpConnector connector = HttpConnector.builder().name("test_connector").protocol("http").tenantId(tenantId).build(); - MLConnectorGetResponse getResponse = new MLConnectorGetResponse(connector); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(getResponse); - return null; - }).when(client).execute(any(), any(MLConnectorGetRequest.class), any()); - - Runnable deleteAction = mock(Runnable.class); - - deleteConnectorTransportAction.checkConnectorPermission(connectorId, tenantId, actionListener, deleteAction); - - verify(deleteAction).run(); - } - - @Test - public void testCheckConnectorPermission_NotAllowedToDelete() { - // Enable multi-tenancy - when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); - String connectorId = "connector_id"; - String tenantId = "tenant_id"; - String differentTenantId = "different_tenant_id"; - HttpConnector connector = HttpConnector.builder().name("test_connector").protocol("http").tenantId(differentTenantId).build(); - MLConnectorGetResponse getResponse = new MLConnectorGetResponse(connector); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(getResponse); - return null; - }).when(client).execute(any(), any(MLConnectorGetRequest.class), any()); - - Runnable deleteAction = mock(Runnable.class); - - deleteConnectorTransportAction.checkConnectorPermission(connectorId, tenantId, actionListener, deleteAction); - - ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(captor.capture()); - Exception exception = captor.getValue(); - assert exception instanceof OpenSearchStatusException; - OpenSearchStatusException statusException = (OpenSearchStatusException) exception; - assert statusException.status() == RestStatus.FORBIDDEN; - assert statusException.getMessage().equals("You don't have permission to access this resource"); - } - public GetResponse prepareMLConnector(String tenantId) throws IOException { HttpConnector connector = HttpConnector.builder().name("test_connector").protocol("http").tenantId(tenantId).build(); XContentBuilder content = connector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index f4786898d6..ba7189df13 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -20,9 +20,11 @@ import org.junit.AfterClass; import org.junit.Before; +import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; @@ -36,11 +38,12 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; @@ -123,12 +126,6 @@ public void setup() throws IOException { ) ); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); - threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); @@ -140,12 +137,15 @@ public static void cleanup() { ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } - public void testGetConnector_UserHasNodeAccess() throws IOException, InterruptedException { + @Test + public void testGetConnector_UserHasNoAccess() throws IOException, InterruptedException { + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").tenantId("tenantId").build(); + when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(false); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(false); + ActionListener listener = invocation.getArgument(5); + listener.onResponse(httpConnector); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); GetResponse getResponse = prepareConnector(null); PlainActionFuture future = PlainActionFuture.newFuture(); @@ -162,32 +162,19 @@ public void testGetConnector_UserHasNodeAccess() throws IOException, Interrupted assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); } - public void testGetConnector_ValidateAccessFailed() throws IOException, InterruptedException { + @Test + public void testGetConnector_NullResponse() throws InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(new Exception("Failed to validate access")); + ActionListener listener = invocation.getArgument(5); + listener + .onFailure( + new OpenSearchStatusException( + "Failed to find connector with the provided connector id: connector_id", + RestStatus.NOT_FOUND + ) + ); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); - - GetResponse getResponse = prepareConnector(null); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(getResponse); - when(client.get(any(GetRequest.class))).thenReturn(future); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); - latch.await(); - - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); - } - - public void testGetConnector_NullResponse() throws InterruptedException { - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(null); - when(client.get(any(GetRequest.class))).thenReturn(future); + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -199,48 +186,20 @@ public void testGetConnector_NullResponse() throws InterruptedException { assertEquals("Failed to find connector with the provided connector id: connector_id", argumentCaptor.getValue().getMessage()); } - public void testGetConnector_IndexNotFoundException() throws InterruptedException { - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onFailure(new IndexNotFoundException("Fail to find model")); - when(client.get(any(GetRequest.class))).thenReturn(future); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); - latch.await(); - - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to find connector", argumentCaptor.getValue().getMessage()); - } - - public void testGetConnector_RuntimeException() throws InterruptedException { - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onFailure(new RuntimeException("errorMessage")); - when(client.get(any(GetRequest.class))).thenReturn(future); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); - latch.await(); - - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - // TODO: Fix this nesting - // [OpenSearchException[java.lang.RuntimeException: errorMessage]; nested: RuntimeException[errorMessage]; - assertEquals("errorMessage", argumentCaptor.getValue().getCause().getCause().getMessage()); - } - + @Test public void testGetConnector_MultiTenancyEnabled_Success() throws IOException, InterruptedException { when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); String tenantId = "test_tenant"; mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(CONNECTOR_ID).tenantId(tenantId).build(); - GetResponse getResponse = prepareConnector(tenantId); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(getResponse); - when(client.get(any(GetRequest.class))).thenReturn(future); + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").tenantId(tenantId).build(); + when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(httpConnector); + return null; + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -252,16 +211,20 @@ public void testGetConnector_MultiTenancyEnabled_Success() throws IOException, I assertEquals(tenantId, argumentCaptor.getValue().getMlConnector().getTenantId()); } + @Test public void testGetConnector_MultiTenancyEnabled_ForbiddenAccess() throws IOException, InterruptedException { when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); String tenantId = "test_tenant"; mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(CONNECTOR_ID).tenantId(tenantId).build(); - GetResponse getResponse = prepareConnector("different_tenant"); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(getResponse); - when(client.get(any(GetRequest.class))).thenReturn(future); + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").tenantId("tenantId").build(); + when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(httpConnector); + return null; + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java index b73fd907a6..d4e929ea25 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java @@ -8,6 +8,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.*; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -37,10 +39,13 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; @@ -52,12 +57,17 @@ import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.sdk.SdkClient; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -67,6 +77,17 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { private UpdateConnectorTransportAction updateConnectorTransportAction; + private static TestThreadPool testThreadPool = new TestThreadPool( + UpdateConnectorTransportActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; @@ -75,6 +96,7 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { @Mock private Client client; + private SdkClient sdkClient; @Mock private ThreadPool threadPool; @@ -85,9 +107,15 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { @Mock private TransportService transportService; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock private ActionFilters actionFilters; + @Mock + NamedXContentRegistry xContentRegistry; + @Mock private MLUpdateConnectorRequest updateRequest; @@ -116,6 +144,7 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); + sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); settings = Settings .builder() .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) @@ -126,11 +155,14 @@ public void setup() throws IOException { ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED ); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); threadContext = new ThreadContext(settings); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); String connector_id = "test_connector_id"; MLCreateConnectorInput updateContent = MLCreateConnectorInput .builder() @@ -161,11 +193,13 @@ public void setup() throws IOException { transportService, actionFilters, client, + sdkClient, connectorAccessControlHelper, mlModelManager, settings, clusterService, - mlEngine + mlEngine, + mlFeatureEnabledSetting ); when(mlModelManager.getAllModelIds()).thenReturn(new String[] {}); @@ -173,7 +207,7 @@ public void setup() throws IOException { updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); Connector connector = HttpConnector .builder() .name("test") @@ -199,7 +233,7 @@ public void setup() throws IOException { // doNothing().when(connector).update(any(), any()); listener.onResponse(connector); return null; - }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index cdd3c80184..09128f71da 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -57,6 +57,7 @@ import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -95,6 +96,9 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase { @Mock ClusterService clusterService; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + DeleteModelTransportAction deleteModelTransportAction; MLModelDeleteRequest mlModelDeleteRequest; ThreadContext threadContext; @@ -118,7 +122,8 @@ public void setup() throws IOException { settings, xContentRegistry, clusterService, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java index 8a95805aa8..a287477a86 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java @@ -40,6 +40,9 @@ import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.sdk.SdkClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -60,6 +63,8 @@ public class GetModelTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; + SdkClient sdkClient; + @Mock ActionListener actionListener; @@ -78,21 +83,28 @@ public class GetModelTransportActionTests extends OpenSearchTestCase { @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlModelGetRequest = MLModelGetRequest.builder().modelId("test_id").build(); settings = Settings.builder().build(); + sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); getModelTransportAction = spy( new GetModelTransportAction( transportService, actionFilters, client, + sdkClient, settings, xContentRegistry, clusterService, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 942a968cf0..022ef66156 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.models; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; @@ -15,6 +16,8 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -50,6 +53,8 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; @@ -58,6 +63,7 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.get.GetResult; @@ -80,8 +86,13 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.sdk.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -93,6 +104,7 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { @Mock Client client; + private SdkClient sdkClient; @Mock Task task; @@ -121,6 +133,9 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { @Mock MLModelGroupManager mlModelGroupManager; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock private ModelAccessControlHelper modelAccessControlHelper; @@ -162,11 +177,26 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { @Mock MLEngine mlEngine; + @Mock + NamedXContentRegistry xContentRegistry; + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList.of("^https://api\\.test\\.com/.*$"); + private static TestThreadPool testThreadPool = new TestThreadPool( + UpdateModelTransportActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); + sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); updateLocalModelInput = MLUpdateModelInput .builder() .modelId("test_model_id") @@ -225,6 +255,7 @@ public void setup() throws IOException { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterState.nodes()).thenReturn(nodes); when(mlModelManager.getWorkerNodes("test_model_id", FunctionName.REMOTE)).thenReturn(targetNodeIds); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); @@ -234,13 +265,15 @@ public void setup() throws IOException { transportService, actionFilters, client, + sdkClient, connectorAccessControlHelper, modelAccessControlHelper, mlModelManager, mlModelGroupManager, settings, clusterService, - mlEngine + mlEngine, + mlFeatureEnabledSetting ) ); @@ -281,12 +314,10 @@ public void setup() throws IOException { .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(true); return null; - }) - .when(connectorAccessControlHelper) - .validateConnectorAccess(any(Client.class), eq("updated_test_connector_id"), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -435,10 +466,10 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(false); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -459,13 +490,13 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener .onFailure( new RuntimeException("Any other connector access control Exception occurred. Please check log for more details.") ); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index d30ef15a5a..fd511c4ff8 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -6,7 +6,6 @@ package org.opensearch.ml.action.register; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; @@ -43,6 +42,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTask; @@ -61,6 +61,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; @@ -68,6 +69,7 @@ import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.sdk.SdkClient; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; @@ -137,6 +139,8 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { @Mock private IndexResponse indexResponse; + private SdkClient sdkClient; + ThreadContext threadContext; private TransportRegisterModelAction transportRegisterModelAction; @@ -155,6 +159,9 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { @Mock private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock + NamedXContentRegistry xContentRegistry; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -165,6 +172,7 @@ public void setup() throws IOException { .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) .build(); threadContext = new ThreadContext(settings); + sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_TRUSTED_URL_REGEX, @@ -184,6 +192,7 @@ public void setup() throws IOException { settings, threadPool, client, + sdkClient, nodeFilter, mlTaskDispatcher, mlStats, @@ -367,6 +376,7 @@ public void testRegisterModelUrlNotAllowed() throws Exception { settings, threadPool, client, + sdkClient, nodeFilter, mlTaskDispatcher, mlStats, @@ -459,10 +469,10 @@ public void test_execute_registerRemoteModel_withConnectorId_success() { when(input.getConnectorId()).thenReturn("mockConnectorId"); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(true); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), anyString(), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); MLRegisterModelResponse response = mock(MLRegisterModelResponse.class); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); @@ -476,10 +486,10 @@ public void test_execute_registerRemoteModel_withConnectorId_noPermissionToConne when(input.getConnectorId()).thenReturn("mockConnectorId"); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(false); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), anyString(), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); MLRegisterModelResponse response = mock(MLRegisterModelResponse.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -502,10 +512,10 @@ public void test_execute_registerRemoteModel_withConnectorId_connectorValidation when(input.getConnectorId()).thenReturn("mockConnectorId"); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), anyString(), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java index 30c9f6191c..68bef15120 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -8,32 +8,48 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.get.GetResult; @@ -41,10 +57,17 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.sdkclient.LocalClusterIndicesClient; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.sdk.GetDataObjectRequest; +import org.opensearch.sdk.SdkClient; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableList; @@ -60,6 +83,9 @@ public class ConnectorAccessControlHelperTests extends OpenSearchTestCase { @Mock private ActionListener actionListener; + @Mock + private ActionListener getConnectorActionListener; + @Mock private ThreadPool threadPool; @@ -71,14 +97,35 @@ public class ConnectorAccessControlHelperTests extends OpenSearchTestCase { private User user; + SdkClient sdkClient; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + + private static TestThreadPool testThreadPool = new TestThreadPool( + ConnectorAccessControlHelperTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Before public void setup() { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); + sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); threadContext = new ThreadContext(settings); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + connectorAccessControlHelper = spy(new ConnectorAccessControlHelper(clusterService, settings)); user = User.parse("mockUser|role-1,role-2|null"); getResponse = createGetResponse(null); @@ -90,14 +137,22 @@ public void setup() { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(any())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } + @Test public void test_hasPermission_user_null_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); boolean hasPermission = connectorAccessControlHelper.hasPermission(null, httpConnector); assertTrue(hasPermission); } + @Test public void test_hasPermission_connectorAccessControl_not_enabled_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); @@ -108,6 +163,7 @@ public void test_hasPermission_connectorAccessControl_not_enabled_return_true() assertTrue(hasPermission); } + @Test public void test_hasPermission_connectorOwner_is_null_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getOwner()).thenReturn(null); @@ -115,12 +171,14 @@ public void test_hasPermission_connectorOwner_is_null_return_true() { assertTrue(hasPermission); } + @Test public void test_hasPermission_user_is_admin_return_true() { User user = User.parse("admin|role-1|all_access"); boolean hasPermission = connectorAccessControlHelper.hasPermission(user, mock(HttpConnector.class)); assertTrue(hasPermission); } + @Test public void test_hasPermission_connector_isPublic_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.PUBLIC); @@ -128,6 +186,7 @@ public void test_hasPermission_connector_isPublic_return_true() { assertTrue(hasPermission); } + @Test public void test_hasPermission_connector_isPrivate_userIsOwner_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.PRIVATE); @@ -136,6 +195,7 @@ public void test_hasPermission_connector_isPrivate_userIsOwner_return_true() { assertTrue(hasPermission); } + @Test public void test_hasPermission_connector_isPrivate_userIsNotOwner_return_false() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.PRIVATE); @@ -145,6 +205,7 @@ public void test_hasPermission_connector_isPrivate_userIsNotOwner_return_false() assertFalse(hasPermission); } + @Test public void test_hasPermission_connector_isRestricted_userHasBackendRole_return_true() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); @@ -153,6 +214,7 @@ public void test_hasPermission_connector_isRestricted_userHasBackendRole_return_ assertTrue(hasPermission); } + @Test public void test_hasPermission_connector_isRestricted_userNotHasBackendRole_return_false() { HttpConnector httpConnector = mock(HttpConnector.class); when(httpConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); @@ -162,7 +224,8 @@ public void test_hasPermission_connector_isRestricted_userNotHasBackendRole_retu assertFalse(hasPermission); } - public void test_validateConnectorAccess_user_isAdmin_return_true() { + // todo: will remove this later + public void test_validateConnectorAccess_user_isAdmin_return_true_old() { String userString = "admin|role-1|all_access"; Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); ThreadContext threadContext = new ThreadContext(settings); @@ -174,7 +237,21 @@ public void test_validateConnectorAccess_user_isAdmin_return_true() { verify(actionListener).onResponse(true); } - public void test_validateConnectorAccess_user_isNotAdmin_hasNoBackendRole_return_false() { + @Test + public void test_validateConnectorAccess_user_isAdmin_return_true() { + String userString = "admin|role-1|all_access"; + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); + ThreadContext threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userString); + + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + verify(actionListener).onResponse(true); + } + + // todo will remove later. + public void test_validateConnectorAccess_user_isNotAdmin_hasNoBackendRole_return_false_old() { GetResponse getResponse = createGetResponse(ImmutableList.of("role-3")); Client client = mock(Client.class); doAnswer(invocation -> { @@ -190,12 +267,67 @@ public void test_validateConnectorAccess_user_isNotAdmin_hasNoBackendRole_return verify(actionListener).onResponse(false); } + @Test + public void test_validateConnectorAccess_user_isNotAdmin_hasNoBackendRole_return_false() throws Exception { + // Mock the client thread pool + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + // Set up user context + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + // Create HttpConnector + HttpConnector httpConnector = HttpConnector.builder() + .name("testConnector") + .protocol(ConnectorProtocols.HTTP) + .owner(user) + .description("This is test connector") + .backendRoles(Collections.singletonList("role-3")) + .accessMode(AccessMode.RESTRICTED) + .build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(httpConnector); + return null; + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); + + // Execute the validation + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + + // Verify the action listener was called with false + verify(actionListener).onResponse(false); + } + + @Test public void test_validateConnectorAccess_user_isNotAdmin_hasBackendRole_return_true() { + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + verify(actionListener).onResponse(true); + } + + // todo will remove later + public void test_validateConnectorAccess_user_isNotAdmin_hasBackendRole_return_true_old() { connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); verify(actionListener).onResponse(true); } + @Test public void test_validateConnectorAccess_connectorNotFound_return_false() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); + return null; + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + + // connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); + } + + // todo will remove later + public void test_validateConnectorAccess_connectorNotFound_return_false_old() { Client client = mock(Client.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -210,7 +342,24 @@ public void test_validateConnectorAccess_connectorNotFound_return_false() { verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); } + @Test public void test_validateConnectorAccess_searchConnectorException_return_false() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onFailure(new RuntimeException("Failed to find connector")); + return null; + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + + // connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); + connectorAccessControlHelper.validateConnectorAccess(sdkClient, client, "anyId", null, mlFeatureEnabledSetting, actionListener); + verify(actionListener, times(1)).onFailure(any(RuntimeException.class)); + } + + // todo will remove later + public void test_validateConnectorAccess_searchConnectorException_return_false_old() { Client client = mock(Client.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -225,11 +374,13 @@ public void test_validateConnectorAccess_searchConnectorException_return_false() verify(actionListener).onFailure(any(OpenSearchStatusException.class)); } + @Test public void test_skipConnectorAccessControl_userIsNull_return_true() { boolean skip = connectorAccessControlHelper.skipConnectorAccessControl(null); assertTrue(skip); } + @Test public void test_skipConnectorAccessControl_connectorAccessControl_notEnabled_return_true() { Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); @@ -239,12 +390,14 @@ public void test_skipConnectorAccessControl_connectorAccessControl_notEnabled_re assertTrue(skip); } + @Test public void test_skipConnectorAccessControl_userIsAdmin_return_true() { User user = User.parse("admin|role-1|all_access"); boolean skip = connectorAccessControlHelper.skipConnectorAccessControl(user); assertTrue(skip); } + @Test public void test_accessControlNotEnabled_connectorAccessControl_notEnabled_return_true() { Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); @@ -254,17 +407,20 @@ public void test_accessControlNotEnabled_connectorAccessControl_notEnabled_retur assertTrue(skip); } + @Test public void test_accessControlNotEnabled_userIsNull_return_true() { boolean notEnabled = connectorAccessControlHelper.accessControlNotEnabled(null); assertTrue(notEnabled); } + @Test public void test_addUserBackendRolesFilter_nullQuery() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); SearchSourceBuilder result = connectorAccessControlHelper.addUserBackendRolesFilter(user, searchSourceBuilder); assertNotNull(result); } + @Test public void test_addUserBackendRolesFilter_boolQuery() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(new BoolQueryBuilder()); @@ -272,6 +428,7 @@ public void test_addUserBackendRolesFilter_boolQuery() { assertEquals("bool", result.query().getName()); } + @Test public void test_addUserBackendRolesFilter_nonBoolQuery() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(new MatchAllQueryBuilder()); @@ -279,6 +436,40 @@ public void test_addUserBackendRolesFilter_nonBoolQuery() { assertEquals("bool", result.query().getName()); } + @Test + public void testGetConnectorHappyCase() throws IOException, InterruptedException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(CommonValue.ML_CONNECTOR_INDEX) + .id("connectorId") + .build(); + GetResponse getResponse = prepareConnector(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + client.threadPool().getThreadContext().newStoredContext(true), + getRequest, + "connectorId", + latchedActionListener + ); + latch.await(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); + verify(client, times(1)).get(requestCaptor.capture()); + assertEquals(CommonValue.ML_CONNECTOR_INDEX, requestCaptor.getValue().index()); + } + private GetResponse createGetResponse(List backendRoles) { HttpConnector httpConnector = HttpConnector .builder() @@ -289,7 +480,7 @@ private GetResponse createGetResponse(List backendRoles) { .backendRoles(Optional.ofNullable(backendRoles).orElse(ImmutableList.of("role-1"))) .accessMode(AccessMode.RESTRICTED) .build(); - XContentBuilder content = null; + XContentBuilder content; try { content = httpConnector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); } catch (IOException e) { @@ -299,4 +490,13 @@ private GetResponse createGetResponse(List backendRoles) { GetResult getResult = new GetResult(CommonValue.ML_MODEL_GROUP_INDEX, "111", 111l, 111l, 111l, true, bytesReference, null, null); return new GetResponse(getResult); } + + public GetResponse prepareConnector() throws IOException { + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").build(); + XContentBuilder content = httpConnector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + return getResponse; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java index e3ef6422d0..0b8a706df7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java @@ -18,6 +18,7 @@ import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; @@ -26,6 +27,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -44,9 +46,14 @@ public class RestMLDeleteModelActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { - restMLDeleteModelAction = new RestMLDeleteModelAction(); + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); + restMLDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -66,7 +73,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLDeleteModelAction mlDeleteModelAction = new RestMLDeleteModelAction(); + RestMLDeleteModelAction mlDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting); assertNotNull(mlDeleteModelAction); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java index a17358e213..5a6f884381 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java @@ -9,18 +9,23 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; import static org.mockito.Mockito.times; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; @@ -28,6 +33,7 @@ import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -48,9 +54,22 @@ public class RestMLGetModelActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + Settings settings; + + @Mock + private ClusterService clusterService; + @Before public void setup() { - restMLGetModelAction = new RestMLGetModelAction(); + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED))); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + restMLGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -70,7 +89,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction(); + RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); assertNotNull(mlGetModelAction); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionTests.java index 0f99b406df..f1a2e3c932 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionTests.java @@ -11,18 +11,24 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; @@ -30,6 +36,7 @@ import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -44,14 +51,26 @@ public class RestMLGetModelGroupActionTests extends OpenSearchTestCase { private RestMLGetModelGroupAction restMLGetModelGroupAction; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Mock + private ClusterService clusterService; + NodeClient client; private ThreadPool threadPool; + Settings settings; @Mock RestChannel channel; @Before public void setup() { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED))); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); restMLGetModelGroupAction = new RestMLGetModelGroupAction(); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); @@ -72,7 +91,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction(); + RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); assertNotNull(mlGetModelAction); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java index c11c7e3fb8..1a2d1f9598 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ml.utils.TestHelper.toJsonString; import java.util.HashMap; @@ -38,6 +39,7 @@ import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -56,6 +58,9 @@ public class RestMLUpdateModelActionTests extends OpenSearchTestCase { private NodeClient client; private ThreadPool threadPool; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock RestChannel channel; @@ -64,7 +69,8 @@ public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLUpdateModelAction = new RestMLUpdateModelAction(); + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + restMLUpdateModelAction = new RestMLUpdateModelAction(mlFeatureEnabledSetting); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); return null; @@ -80,7 +86,7 @@ public void tearDown() throws Exception { @Test public void testConstructor() { - RestMLUpdateModelAction UpdateModelAction = new RestMLUpdateModelAction(); + RestMLUpdateModelAction UpdateModelAction = new RestMLUpdateModelAction(mlFeatureEnabledSetting); assertNotNull(UpdateModelAction); } From 7d5afbc1e0b80a09d3fc4841818a03fc6c654bed Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 14 Jun 2024 09:49:55 -0700 Subject: [PATCH 04/10] [Feature/multi_tenancy] Add source map to GetDataObjectResponse (#2489) * Add source map to GetDataObjectResponse Signed-off-by: Daniel Widdis * Add test for map getter in clients Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis Signed-off-by: Arjun kumar Giri --- .../opensearch/sdk/GetDataObjectResponse.java | 36 +++++++++++++++---- .../sdk/GetDataObjectResponseTests.java | 6 +++- .../sdkclient/LocalClusterIndicesClient.java | 6 +++- .../sdkclient/RemoteClusterIndicesClient.java | 8 +++-- .../LocalClusterIndicesClientTests.java | 7 +++- .../RemoteClusterIndicesClientTests.java | 1 + 6 files changed, 52 insertions(+), 12 deletions(-) diff --git a/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java index 3d98ccd36c..b884cc9eb4 100644 --- a/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java +++ b/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java @@ -10,22 +10,27 @@ import org.opensearch.core.xcontent.XContentParser; +import java.util.Collections; +import java.util.Map; import java.util.Optional; public class GetDataObjectResponse { private final String id; private final Optional parser; + private final Map source; /** - * Instantiate this request with an id and parser used to recreate the data object. + * Instantiate this request with an id and parser/map used to recreate the data object. *

* For data storage implementations other than OpenSearch, the id may be referred to as a primary key. * @param id the document id - * @param parser an optional XContentParser that can be used to create the object if present. + * @param parser an optional XContentParser that can be used to create the data object if present. + * @param source the data object as a map */ - public GetDataObjectResponse(String id, Optional parser) { + public GetDataObjectResponse(String id, Optional parser, Map source) { this.id = id; this.parser = parser; + this.source = source; } /** @@ -37,12 +42,20 @@ public String id() { } /** - * Returns the parser optional + * Returns the parser optional. If present, is a representation of the data object that may be parsed. * @return the parser optional */ public Optional parser() { return this.parser; } + + /** + * Returns the source map. This is a logical representation of the data object. + * @return the source map + */ + public Map source() { + return this.source; + } /** * Class for constructing a Builder for this Response Object @@ -50,6 +63,7 @@ public Optional parser() { public static class Builder { private String id = null; private Optional parser = Optional.empty(); + private Map source = Collections.emptyMap(); /** * Empty Constructor for the Builder object @@ -68,7 +82,7 @@ public Builder id(String id) { /** * Add an optional parser to this builder - * @param parser an {@link Optional} which may contain the parser + * @param parser an {@link Optional} which may contain the data object parser * @return the updated builder */ public Builder parser(Optional parser) { @@ -76,12 +90,22 @@ public Builder parser(Optional parser) { return this; } + /** + * Add a source map to this builder + * @param source the data object as a map + * @return the updated builder + */ + public Builder source(Map source) { + this.source = source; + return this; + } + /** * Builds the response * @return A {@link GetDataObjectResponse} */ public GetDataObjectResponse build() { - return new GetDataObjectResponse(this.id, this.parser); + return new GetDataObjectResponse(this.id, this.parser, this.source); } } } diff --git a/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java b/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java index a7a900ba17..9e79593dd8 100644 --- a/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java +++ b/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java @@ -12,6 +12,7 @@ import org.junit.Test; import org.opensearch.core.xcontent.XContentParser; +import java.util.Map; import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -21,18 +22,21 @@ public class GetDataObjectResponseTests { private String testId; private XContentParser testParser; + private Map testSource; @Before public void setUp() { testId = "test-id"; testParser = mock(XContentParser.class); + testSource = Map.of("foo", "bar"); } @Test public void testGetDataObjectResponse() { - GetDataObjectResponse response = new GetDataObjectResponse.Builder().id(testId).parser(Optional.of(testParser)).build(); + GetDataObjectResponse response = new GetDataObjectResponse.Builder().id(testId).parser(Optional.of(testParser)).source(testSource).build(); assertEquals(testId, response.id()); assertEquals(testParser, response.parser().get()); + assertEquals(testSource, response.source()); } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index 6e3db16784..2af15f6bb3 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -97,7 +97,11 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe XContentParser parser = jsonXContent .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()); log.info("Retrieved data object"); - return new GetDataObjectResponse.Builder().id(getResponse.getId()).parser(Optional.of(parser)).build(); + return new GetDataObjectResponse.Builder() + .id(getResponse.getId()) + .parser(Optional.of(parser)) + .source(getResponse.getSource()) + .build(); } catch (OpenSearchStatusException | IndexNotFoundException notFound) { throw notFound; } catch (Exception e) { diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java index 5f55a4471a..8284b95f4e 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java @@ -87,11 +87,13 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe if (!getResponse.found()) { return new GetDataObjectResponse.Builder().id(getResponse.id()).build(); } - String json = new ObjectMapper().setSerializationInclusion(Include.NON_NULL).writeValueAsString(getResponse.source()); - log.info("Retrieved data object"); + // Since we use the JacksonJsonBMapper we know this is String-Object map + @SuppressWarnings("unchecked") + Map source = getResponse.source(); + String json = new ObjectMapper().setSerializationInclusion(Include.NON_NULL).writeValueAsString(source); XContentParser parser = JsonXContent.jsonXContent .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); - return new GetDataObjectResponse.Builder().id(getResponse.id()).parser(Optional.of(parser)).build(); + return new GetDataObjectResponse.Builder().id(getResponse.id()).parser(Optional.of(parser)).source(source).build(); } catch (Exception e) { throw new OpenSearchException(e); } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java index e113c03c97..ffd0e0072b 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java @@ -43,6 +43,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -143,7 +145,9 @@ public void testGetDataObject() throws IOException { GetResponse getResponse = mock(GetResponse.class); when(getResponse.isExists()).thenReturn(true); when(getResponse.getId()).thenReturn(TEST_ID); - when(getResponse.getSourceAsString()).thenReturn(testDataObject.toJson()); + String json = testDataObject.toJson(); + when(getResponse.getSourceAsString()).thenReturn(json); + when(getResponse.getSource()).thenReturn(XContentHelper.convertToMap(JsonXContent.jsonXContent, json, false)); @SuppressWarnings("unchecked") ActionFuture future = mock(ActionFuture.class); when(mockedClient.get(any(GetRequest.class))).thenReturn(future); @@ -158,6 +162,7 @@ public void testGetDataObject() throws IOException { verify(mockedClient, times(1)).get(requestCaptor.capture()); assertEquals(TEST_INDEX, requestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); + assertEquals("foo", response.source().get("data")); assertTrue(response.parser().isPresent()); XContentParser parser = response.parser().get(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java index 239f3420eb..ee7ee53f98 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java @@ -173,6 +173,7 @@ public void testGetDataObject() throws IOException { assertEquals(TEST_INDEX, getRequestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); + assertEquals("foo", response.source().get("data")); assertTrue(response.parser().isPresent()); XContentParser parser = response.parser().get(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); From 86aea1bb6b5b44dfb72a9c28e50093fb6904e497 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 14 Jun 2024 16:06:16 -0700 Subject: [PATCH 05/10] [Feature/multi_tenancy] Add UpdateDataObject interface, Client, and Connector Implementations (#2520) * Restore original exception handling expectations Signed-off-by: Daniel Widdis * Add UpdateDataObject to interface and implementations Signed-off-by: Daniel Widdis * Implement UpdateConnector action Signed-off-by: Daniel Widdis * Move CompletionException handling to a common method Signed-off-by: Daniel Widdis * Add tests for SDKClient exceptions refactored from Transport Action Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis Signed-off-by: Arjun kumar Giri --- .../java/org/opensearch/sdk/SdkClient.java | 59 ++++++-- .../sdk/UpdateDataObjectRequest.java | 108 ++++++++++++++ .../sdk/UpdateDataObjectResponse.java | 140 ++++++++++++++++++ .../org/opensearch/sdk/SdkClientTests.java | 58 +++++++- .../DeleteConnectorTransportAction.java | 12 +- .../GetConnectorTransportAction.java | 1 - .../TransportCreateConnectorAction.java | 9 +- .../UpdateConnectorTransportAction.java | 58 ++++++-- .../helper/ConnectorAccessControlHelper.java | 9 +- .../sdkclient/LocalClusterIndicesClient.java | 80 +++++++--- .../sdkclient/RemoteClusterIndicesClient.java | 74 ++++++++- .../DeleteConnectorTransportActionTests.java | 4 +- .../GetConnectorTransportActionTests.java | 1 - .../UpdateConnectorTransportActionTests.java | 103 ++++++++----- .../ConnectorAccessControlHelperTests.java | 66 ++++++++- .../LocalClusterIndicesClientTests.java | 115 +++++++++++--- .../RemoteClusterIndicesClientTests.java | 93 +++++++++++- 17 files changed, 842 insertions(+), 148 deletions(-) create mode 100644 common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java create mode 100644 common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java diff --git a/common/src/main/java/org/opensearch/sdk/SdkClient.java b/common/src/main/java/org/opensearch/sdk/SdkClient.java index 9fb195e13f..78f3d8b9a5 100644 --- a/common/src/main/java/org/opensearch/sdk/SdkClient.java +++ b/common/src/main/java/org/opensearch/sdk/SdkClient.java @@ -42,11 +42,7 @@ default PutDataObjectResponse putDataObject(PutDataObjectRequest request) { try { return putDataObjectAsync(request).toCompletableFuture().join(); } catch (CompletionException e) { - Throwable cause = e.getCause(); - if (cause instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throw new OpenSearchException(cause); + throw unwrapAndConvertToRuntime(e); } } @@ -76,11 +72,37 @@ default GetDataObjectResponse getDataObject(GetDataObjectRequest request) { try { return getDataObjectAsync(request).toCompletableFuture().join(); } catch (CompletionException e) { - Throwable cause = e.getCause(); - if (cause instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throw new OpenSearchException(cause); + throw unwrapAndConvertToRuntime(e); + } + } + + /** + * Update a data object/document in a table/index. + * @param request A request identifying the data object to update + * @param executor the executor to use for asynchronous execution + * @return A completion stage encapsulating the response or exception + */ + public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor); + + /** + * Update a data object/document in a table/index. + * @param request A request identifying the data object to update + * @return A completion stage encapsulating the response or exception + */ + default CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request) { + return updateDataObjectAsync(request, ForkJoinPool.commonPool()); + } + + /** + * Update a data object/document in a table/index. + * @param request A request identifying the data object to update + * @return A response on success. Throws {@link OpenSearchException} wrapping the cause on exception. + */ + default UpdateDataObjectResponse updateDataObject(UpdateDataObjectRequest request) { + try { + return updateDataObjectAsync(request).toCompletableFuture().join(); + } catch (CompletionException e) { + throw unwrapAndConvertToRuntime(e); } } @@ -110,11 +132,18 @@ default DeleteDataObjectResponse deleteDataObject(DeleteDataObjectRequest reques try { return deleteDataObjectAsync(request).toCompletableFuture().join(); } catch (CompletionException e) { - Throwable cause = e.getCause(); - if (cause instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throw new OpenSearchException(cause); + throw unwrapAndConvertToRuntime(e); + } + } + + private static RuntimeException unwrapAndConvertToRuntime(CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + if (cause instanceof RuntimeException) { + return (RuntimeException) cause; } + return new OpenSearchException(cause); } } diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java new file mode 100644 index 0000000000..25891a1167 --- /dev/null +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import org.opensearch.core.xcontent.ToXContentObject; + +public class UpdateDataObjectRequest { + + private final String index; + private final String id; + private final ToXContentObject dataObject; + + /** + * Instantiate this request with an index and data object. + *

+ * For data storage implementations other than OpenSearch, an index may be referred to as a table and the data object may be referred to as an item. + * @param index the index location to update the object + * @param id the document id + * @param dataObject the data object + */ + public UpdateDataObjectRequest(String index, String id, ToXContentObject dataObject) { + this.index = index; + this.id = id; + this.dataObject = dataObject; + } + + /** + * Returns the index + * @return the index + */ + public String index() { + return this.index; + } + + /** + * Returns the document id + * @return the id + */ + public String id() { + return this.id; + } + + /** + * Returns the data object + * @return the data object + */ + public ToXContentObject dataObject() { + return this.dataObject; + } + + /** + * Class for constructing a Builder for this Request Object + */ + public static class Builder { + private String index = null; + private String id = null; + private ToXContentObject dataObject = null; + + /** + * Empty Constructor for the Builder object + */ + public Builder() {} + + /** + * Add an index to this builder + * @param index the index to put the object + * @return the updated builder + */ + public Builder index(String index) { + this.index = index; + return this; + } + + /** + * Add an id to this builder + * @param id the document id + * @return the updated builder + */ + public Builder id(String id) { + this.id = id; + return this; + } + + /** + * Add a data object to this builder + * @param dataObject the data object + * @return the updated builder + */ + public Builder dataObject(ToXContentObject dataObject) { + this.dataObject = dataObject; + return this; + } + + /** + * Builds the request + * @return A {@link UpdateDataObjectRequest} + */ + public UpdateDataObjectRequest build() { + return new UpdateDataObjectRequest(this.index, this.id, this.dataObject); + } + } +} diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java new file mode 100644 index 0000000000..56711c60d6 --- /dev/null +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; +import org.opensearch.core.common.Strings; +import org.opensearch.core.index.shard.ShardId; + +public class UpdateDataObjectResponse { + private final String id; + private final ShardId shardId; + private final ShardInfo shardInfo; + private final boolean updated; + + /** + * Instantiate this request with an id and update status. + *

+ * For data storage implementations other than OpenSearch, the id may be referred to as a primary key. + * @param id the document id + * @param shardId the shard id + * @param shardInfo the shard info + * @param updated Whether the object was updated. + */ + public UpdateDataObjectResponse(String id, ShardId shardId, ShardInfo shardInfo, boolean updated) { + this.id = id; + this.shardId = shardId; + this.shardInfo = shardInfo; + this.updated = updated; + } + + /** + * Returns the document id + * @return the id + */ + public String id() { + return id; + } + + /** + * Returns the shard id. + * @return the shard id, or a generated id if shards are not applicable + */ + public ShardId shardId() { + return shardId; + } + + /** + * Returns the shard info. + * @return the shard info, or generated info if shards are not applicable + */ + public ShardInfo shardInfo() { + return shardInfo; + } + + /** + * Returns whether update was successful + * @return true if update was successful + */ + public boolean updated() { + return updated; + } + + /** + * Class for constructing a Builder for this Response Object + */ + public static class Builder { + private String id = null; + private ShardId shardId = null; + private ShardInfo shardInfo = null; + private boolean updated = false; + + /** + * Empty Constructor for the Builder object + */ + public Builder() {} + + /** + * Add an id to this builder + * @param id the id to add + * @return the updated builder + */ + public Builder id(String id) { + this.id = id; + return this; + } + + /** + * Adds a shard id to this builder + * @param shardId the shard id to add + * @return the updated builder + */ + public Builder shardId(ShardId shardId) { + this.shardId = shardId; + return this; + } + + /** + * Adds a generated shard id to this builder + * @param indexName the index name to generate a shard id + * @return the updated builder + */ + public Builder shardId(String indexName) { + this.shardId = new ShardId(indexName, Strings.UNKNOWN_UUID_VALUE, 0); + return this; + } + + /** + * Adds shard information (statistics) to this builder + * @param shardInfo the shard info to add + * @return the updated builder + */ + public Builder shardInfo(ShardInfo shardInfo) { + this.shardInfo = shardInfo; + return this; + } + /** + * Add a updated status to this builder + * @param updated the updated status to add + * @return the updated builder + */ + public Builder updated(boolean updated) { + this.updated = updated; + return this; + } + + /** + * Builds the object + * @return A {@link UpdateDataObjectResponse} + */ + public UpdateDataObjectResponse build() { + return new UpdateDataObjectResponse(this.id, this.shardId, this.shardInfo, this.updated); + } + } +} diff --git a/common/src/test/java/org/opensearch/sdk/SdkClientTests.java b/common/src/test/java/org/opensearch/sdk/SdkClientTests.java index 08b3732a42..141c63775f 100644 --- a/common/src/test/java/org/opensearch/sdk/SdkClientTests.java +++ b/common/src/test/java/org/opensearch/sdk/SdkClientTests.java @@ -13,6 +13,8 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.core.rest.RestStatus; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; @@ -40,11 +42,15 @@ public class SdkClientTests { @Mock private GetDataObjectResponse getResponse; @Mock + private UpdateDataObjectRequest updateRequest; + @Mock + private UpdateDataObjectResponse updateResponse; + @Mock private DeleteDataObjectRequest deleteRequest; @Mock private DeleteDataObjectResponse deleteResponse; - private RuntimeException testException; + private OpenSearchStatusException testException; private InterruptedException interruptedException; @Before @@ -61,12 +67,17 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe return CompletableFuture.completedFuture(getResponse); } + @Override + public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor) { + return CompletableFuture.completedFuture(updateResponse); + } + @Override public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { return CompletableFuture.completedFuture(deleteResponse); } }); - testException = new RuntimeException(); + testException = new OpenSearchStatusException("Test", RestStatus.BAD_REQUEST); interruptedException = new InterruptedException(); } @@ -81,10 +92,10 @@ public void testPutDataObjectException() { when(sdkClient.putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { sdkClient.putDataObject(putRequest); }); - assertEquals(testException, exception.getCause()); + assertEquals(testException, exception); assertFalse(Thread.interrupted()); verify(sdkClient).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class)); } @@ -113,10 +124,10 @@ public void testGetDataObjectException() { when(sdkClient.getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { sdkClient.getDataObject(getRequest); }); - assertEquals(testException, exception.getCause()); + assertEquals(testException, exception); assertFalse(Thread.interrupted()); verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class)); } @@ -134,6 +145,37 @@ public void testGetDataObjectInterrupted() { verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class)); } + + @Test + public void testUpdateDataObjectSuccess() { + assertEquals(updateResponse, sdkClient.updateDataObject(updateRequest)); + verify(sdkClient).updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class)); + } + + @Test + public void testUpdateDataObjectException() { + when(sdkClient.updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class))) + .thenReturn(CompletableFuture.failedFuture(testException)); + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { + sdkClient.updateDataObject(updateRequest); + }); + assertEquals(testException, exception); + assertFalse(Thread.interrupted()); + verify(sdkClient).updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class)); + } + + @Test + public void testUpdateDataObjectInterrupted() { + when(sdkClient.updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class))) + .thenReturn(CompletableFuture.failedFuture(interruptedException)); + OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { + sdkClient.updateDataObject(updateRequest); + }); + assertEquals(interruptedException, exception.getCause()); + assertTrue(Thread.interrupted()); + verify(sdkClient).updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class)); + } + @Test public void testDeleteDataObjectSuccess() { assertEquals(deleteResponse, sdkClient.deleteDataObject(deleteRequest)); @@ -144,10 +186,10 @@ public void testDeleteDataObjectSuccess() { public void testDeleteDataObjectException() { when(sdkClient.deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { sdkClient.deleteDataObject(deleteRequest); }); - assertEquals(testException, exception.getCause()); + assertEquals(testException, exception); assertFalse(Thread.interrupted()); verify(sdkClient).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class)); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index b4220fef95..b5e6fb81e0 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -14,6 +14,7 @@ import java.util.Arrays; import java.util.List; +import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.delete.DeleteRequest; @@ -123,7 +124,7 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A sourceBuilder.query(QueryBuilders.matchQuery(TENANT_ID, tenantId)); } searchRequest.source(sourceBuilder); - // TODO: User SDK client not client. + // TODO: Use SDK client not client. client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(searchResponse -> { SearchHit[] searchHits = searchResponse.getHits().getHits(); if (searchHits.length == 0) { @@ -186,8 +187,13 @@ private void handleDeleteResponse( ActionListener actionListener ) { if (throwable != null) { - log.error("Failed to delete ML connector: {}", connectorId, throwable); - actionListener.onFailure(new RuntimeException(throwable)); + Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); + log.error("Failed to delete ML connector: {}", connectorId, cause); + if (cause instanceof Exception) { + actionListener.onFailure((Exception) cause); + } else { + actionListener.onFailure(new OpenSearchException(cause)); + } } else { log.info("Connector deletion result: {}, connector id: {}", response.deleted(), response.id()); DeleteResponse deleteResponse = new DeleteResponse(response.shardId(), response.id(), 0, 0, 0, response.deleted()); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index 8221b33886..9d09bfffbc 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -93,7 +93,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener handleConnectorAccessValidationFailure(connectorId, e, actionListener) ) ); - } catch (Exception e) { log.error("Failed to get ML connector {}", connectorId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index e7cb3c40ac..7f4253e59d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -12,6 +12,7 @@ import java.util.HashSet; import java.util.List; +import org.opensearch.OpenSearchException; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -147,7 +148,13 @@ private void indexConnector(Connector connector, ActionListener { context.restore(); if (throwable != null) { - listener.onFailure(new RuntimeException(throwable)); + Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); + log.error("Failed to create ML connector", cause); + if (cause instanceof Exception) { + listener.onFailure((Exception) cause); + } else { + listener.onFailure(new OpenSearchException(cause)); + } } else { log.info("Connector creation result: {}, connector id: {}", r.created(), r.id()); MLCreateConnectorResponse response = new MLCreateConnectorResponse(r.id()); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index f227b22e2a..7b4fa7e787 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -7,31 +7,30 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.DocWriteResponse.Result; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -46,6 +45,8 @@ import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.sdk.GetDataObjectRequest; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -60,7 +61,8 @@ @FieldDefaults(level = AccessLevel.PRIVATE) public class UpdateConnectorTransportAction extends HandledTransportAction { Client client; - private final SdkClient sdkClient; + SdkClient sdkClient; + ConnectorAccessControlHelper connectorAccessControlHelper; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; MLModelManager mlModelManager; @@ -122,10 +124,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - log.error("Unable to find the connector with ID {}. Details: {}", connectorId, exception); + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception); listener.onFailure(exception); })); } catch (Exception e) { @@ -147,7 +151,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener, ThreadContext.StoredContext context ) { @@ -159,10 +163,15 @@ private void updateUndeployedConnector( sourceBuilder.query(boolQueryBuilder); searchRequest.source(sourceBuilder); + // TODO: Use SDK client not client. client.search(searchRequest, ActionListener.wrap(searchResponse -> { SearchHit[] searchHits = searchResponse.getHits().getHits(); if (searchHits.length == 0) { - client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + sdkClient + .updateDataObjectAsync(updateDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener, context)); + }); } else { log.error(searchHits.length + " models are still using this connector, please undeploy the models first!"); List modelIds = new ArrayList<>(); @@ -181,15 +190,36 @@ private void updateUndeployedConnector( } }, e -> { if (e instanceof IndexNotFoundException) { - client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + sdkClient + .updateDataObjectAsync(updateDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener, context)); + }); return; } log.error("Failed to update ML connector: " + connectorId, e); listener.onFailure(e); - })); } + private void handleUpdateDataObjectCompletionStage( + UpdateDataObjectResponse r, + Throwable throwable, + ActionListener updateListener + ) { + if (throwable != null) { + Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); + if (cause instanceof Exception) { + updateListener.onFailure((Exception) cause); + } else { + updateListener.onFailure(new OpenSearchException(cause)); + } + } else { + log.info("Connector update result: {}, connector id: {}", r.updated(), r.id()); + updateListener.onResponse(new UpdateResponse(r.shardId(), r.id(), 0, 0, 0, r.updated() ? Result.UPDATED : Result.CREATED)); + } + } + private ActionListener getUpdateResponseListener( String connectorId, ActionListener actionListener, diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index 2b9935f82e..76b58e7695 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -14,6 +14,7 @@ import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.client.Client; @@ -176,8 +177,12 @@ public void getConnector( log.error("Failed to get connector index", cause); listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); } else { - log.error("Failed to find connector {}", connectorId, cause); - listener.onFailure(new RuntimeException(cause)); + log.error("Failed to get ML connector " + connectorId, cause); + if (cause instanceof Exception) { + listener.onFailure((Exception) cause); + } else { + listener.onFailure(new OpenSearchException(cause)); + } } } else { if (r != null && r.parser().isPresent()) { diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index 2af15f6bb3..dca67bd15b 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -10,10 +10,12 @@ import static org.opensearch.action.DocWriteResponse.Result.CREATED; import static org.opensearch.action.DocWriteResponse.Result.DELETED; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Optional; @@ -21,7 +23,6 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; -import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; @@ -29,13 +30,15 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -43,6 +46,8 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import lombok.extern.log4j.Log4j2; @@ -79,8 +84,12 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe .actionGet(); log.info("Creation status for id {}: {}", indexResponse.getId(), indexResponse.getResult()); return new PutDataObjectResponse.Builder().id(indexResponse.getId()).created(indexResponse.getResult() == CREATED).build(); - } catch (Exception e) { - throw new OpenSearchException(e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException( + "Failed to parse data object to put in index " + request.index(), + RestStatus.BAD_REQUEST + ); } }), executor); } @@ -90,7 +99,9 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { log.info("Getting {} from {}", request.id(), request.index()); - GetResponse getResponse = client.get(new GetRequest(request.index(), request.id())).actionGet(); + GetResponse getResponse = client + .get(new GetRequest(request.index(), request.id()).fetchSourceContext(request.fetchSourceContext())) + .actionGet(); if (getResponse == null || !getResponse.isExists()) { return new GetDataObjectResponse.Builder().id(request.id()).build(); } @@ -102,30 +113,55 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe .parser(Optional.of(parser)) .source(getResponse.getSource()) .build(); - } catch (OpenSearchStatusException | IndexNotFoundException notFound) { - throw notFound; - } catch (Exception e) { - throw new OpenSearchException(e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parser creation error + throw new OpenSearchStatusException( + "Failed to create parser for data object retrieved from index " + request.index(), + RestStatus.INTERNAL_SERVER_ERROR + ); } }), executor); } @Override - public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { - try { - log.info("Deleting {} from {}", request.id(), request.index()); - DeleteResponse deleteResponse = client.delete(new DeleteRequest(request.index(), request.id())).actionGet(); - log.info("Deletion status for id {}: {}", deleteResponse.getId(), deleteResponse.getResult()); - return new DeleteDataObjectResponse.Builder() - .id(deleteResponse.getId()) - .shardId(deleteResponse.getShardId()) - .shardInfo(deleteResponse.getShardInfo()) - .deleted(deleteResponse.getResult() == DELETED) + public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor) { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { + log.info("Updating {} from {}", request.id(), request.index()); + UpdateResponse updateResponse = client + .update( + new UpdateRequest(request.index(), request.id()).doc(request.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)) + ) + .actionGet(); + log.info("Update status for id {}: {}", updateResponse.getId(), updateResponse.getResult()); + return new UpdateDataObjectResponse.Builder() + .id(updateResponse.getId()) + .shardId(updateResponse.getShardId()) + .shardInfo(updateResponse.getShardInfo()) + .updated(updateResponse.getResult() == UPDATED) .build(); - } catch (Exception e) { - throw new OpenSearchException(e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException( + "Failed to parse data object to update in index " + request.index(), + RestStatus.BAD_REQUEST + ); } }), executor); } + + @Override + public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + log.info("Deleting {} from {}", request.id(), request.index()); + DeleteResponse deleteResponse = client.delete(new DeleteRequest(request.index(), request.id())).actionGet(); + log.info("Deletion status for id {}: {}", deleteResponse.getId(), deleteResponse.getResult()); + return new DeleteDataObjectResponse.Builder() + .id(deleteResponse.getId()) + .shardId(deleteResponse.getShardId()) + .shardInfo(deleteResponse.getShardInfo()) + .deleted(deleteResponse.getResult() == DELETED) + .build(); + }), executor); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java index 8284b95f4e..35c5fb5743 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java @@ -10,7 +10,9 @@ import static org.opensearch.client.opensearch._types.Result.Created; import static org.opensearch.client.opensearch._types.Result.Deleted; +import static org.opensearch.client.opensearch._types.Result.Updated; +import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Map; @@ -19,7 +21,7 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; -import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.opensearch.core.DeleteRequest; @@ -28,9 +30,15 @@ import org.opensearch.client.opensearch.core.GetResponse; import org.opensearch.client.opensearch.core.IndexRequest; import org.opensearch.client.opensearch.core.IndexResponse; +import org.opensearch.client.opensearch.core.UpdateRequest; +import org.opensearch.client.opensearch.core.UpdateResponse; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; @@ -39,6 +47,8 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.databind.ObjectMapper; @@ -70,8 +80,12 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe IndexResponse indexResponse = openSearchClient.index(indexRequest); log.info("Creation status for id {}: {}", indexResponse.id(), indexResponse.result()); return new PutDataObjectResponse.Builder().id(indexResponse.id()).created(indexResponse.result() == Created).build(); - } catch (Exception e) { - throw new OpenSearchException("Error occurred while indexing data object", e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException( + "Failed to parse data object to put in index " + request.index(), + RestStatus.BAD_REQUEST + ); } }), executor); } @@ -94,8 +108,50 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe XContentParser parser = JsonXContent.jsonXContent .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); return new GetDataObjectResponse.Builder().id(getResponse.id()).parser(Optional.of(parser)).source(source).build(); - } catch (Exception e) { - throw new OpenSearchException(e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parser creation error + throw new OpenSearchStatusException( + "Failed to create parser for data object retrieved from index " + request.index(), + RestStatus.INTERNAL_SERVER_ERROR + ); + } + }), executor); + } + + @Override + public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor) { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + @SuppressWarnings("unchecked") + Class> documentType = (Class>) (Class) Map.class; + request.dataObject().toXContent(builder, ToXContent.EMPTY_PARAMS); + Map docMap = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, builder.toString()) + .map(); + UpdateRequest, ?> updateRequest = new UpdateRequest.Builder, Map>() + .index(request.index()) + .id(request.id()) + .doc(docMap) + .build(); + log.info("Updating {} in {}", request.id(), request.index()); + UpdateResponse> updateResponse = openSearchClient.update(updateRequest, documentType); + log.info("Update status for id {}: {}", updateResponse.id(), updateResponse.result()); + ShardInfo shardInfo = new ShardInfo( + updateResponse.shards().total().intValue(), + updateResponse.shards().successful().intValue() + ); + return new UpdateDataObjectResponse.Builder() + .id(updateResponse.id()) + .shardId(updateResponse.index()) + .shardInfo(shardInfo) + .updated(updateResponse.result() == Updated) + .build(); + } catch (IOException e) { + // Rethrow unchecked exception on update IOException + throw new OpenSearchStatusException( + "Parsing error updating data object " + request.id() + " in index " + request.index(), + RestStatus.BAD_REQUEST + ); } }), executor); } @@ -118,8 +174,12 @@ public CompletionStage deleteDataObjectAsync(DeleteDat .shardInfo(shardInfo) .deleted(deleteResponse.result() == Deleted) .build(); - } catch (Exception e) { - throw new OpenSearchException("Error occurred while deleting data object", e); + } catch (IOException e) { + // Rethrow unchecked exception on deletion IOException + throw new OpenSearchStatusException( + "IOException occurred while deleting data object " + request.id() + " from index " + request.index(), + RestStatus.INTERNAL_SERVER_ERROR + ); } }), executor); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java index 6f10e1d715..3cf3595e2a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java @@ -354,9 +354,7 @@ public void testDeleteConnector_ResourceNotFoundException() throws IOException, ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - // TODO: fix all this exception nesting - // java.util.concurrent.CompletionException: OpenSearchException[ResourceNotFoundException[errorMessage]]; nested: ResourceNotFoundException[errorMessage]; - assertEquals("errorMessage", argumentCaptor.getValue().getCause().getCause().getCause().getMessage()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); } public void test_ValidationFailedException() throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index ba7189df13..f125321d15 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -186,7 +186,6 @@ public void testGetConnector_NullResponse() throws InterruptedException { assertEquals("Failed to find connector with the provided connector id: connector_id", argumentCaptor.getValue().getMessage()); } - @Test public void testGetConnector_MultiTenancyEnabled_Success() throws IOException, InterruptedException { when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java index d4e929ea25..fe89661967 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.*; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; @@ -20,19 +21,25 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.DocWriteResponse.Result; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -75,8 +82,6 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { - private UpdateConnectorTransportAction updateConnectorTransportAction; - private static TestThreadPool testThreadPool = new TestThreadPool( UpdateConnectorTransportActionTests.class.getName(), new ScalingExecutorBuilder( @@ -88,6 +93,8 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { ) ); + private UpdateConnectorTransportAction updateConnectorTransportAction; + @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; @@ -98,6 +105,9 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { private Client client; private SdkClient sdkClient; + @Mock + private NamedXContentRegistry xContentRegistry; + @Mock private ThreadPool threadPool; @@ -113,13 +123,9 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { @Mock private ActionFilters actionFilters; - @Mock - NamedXContentRegistry xContentRegistry; - @Mock private MLUpdateConnectorRequest updateRequest; - @Mock private UpdateResponse updateResponse; @Mock @@ -138,13 +144,16 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { private MLEngine mlEngine; + private static final String TEST_CONNECTOR_ID = "test_connector_id"; private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); + sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); + settings = Settings .builder() .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) @@ -163,14 +172,13 @@ public void setup() throws IOException { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); - String connector_id = "test_connector_id"; MLCreateConnectorInput updateContent = MLCreateConnectorInput .builder() .updateConnector(true) .version("2") .description("updated description") .build(); - when(updateRequest.getConnectorId()).thenReturn(connector_id); + when(updateRequest.getConnectorId()).thenReturn(TEST_CONNECTOR_ID); when(updateRequest.getUpdateContent()).thenReturn(updateContent); SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); @@ -236,8 +244,13 @@ public void setup() throws IOException { }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); } + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + @Test - public void testExecuteConnectorAccessControlSuccess() { + public void testExecuteConnectorAccessControlSuccess() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -246,14 +259,16 @@ public void testExecuteConnectorAccessControlSuccess() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(updateResponse); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(updateResponse); + when(client.update(any(UpdateRequest.class))).thenReturn(future); - updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + ActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + updateConnectorTransportAction.doExecute(task, updateRequest, latchedActionListener); + latch.await(); + + verify(actionListener).onResponse(any(UpdateResponse.class)); } @Test @@ -294,7 +309,7 @@ public void testExecuteConnectorAccessControlException() { } @Test - public void testExecuteUpdateWrongStatus() { + public void testExecuteUpdateWrongStatus() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -303,19 +318,23 @@ public void testExecuteUpdateWrongStatus() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(updateResponse); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, Result.CREATED); + future.onResponse(updateResponse); + when(client.update(any(UpdateRequest.class))).thenReturn(future); - updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + ActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + updateConnectorTransportAction.doExecute(task, updateRequest, latchedActionListener); + latch.await(); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(Result.CREATED, argumentCaptor.getValue().getResult()); } @Test - public void testExecuteUpdateException() { + public void testExecuteUpdateException() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -324,13 +343,13 @@ public void testExecuteUpdateException() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("update document failure")); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + when(client.update(any(UpdateRequest.class))).thenThrow(new RuntimeException("update document failure")); + + CountDownLatch latch = new CountDownLatch(1); + ActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + updateConnectorTransportAction.doExecute(task, updateRequest, latchedActionListener); + latch.await(); - updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("update document failure", argumentCaptor.getValue().getMessage()); @@ -371,7 +390,7 @@ public void testExecuteSearchResponseError() { } @Test - public void testExecuteSearchIndexNotFoundError() { + public void testExecuteSearchIndexNotFoundError() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -409,14 +428,18 @@ public void testExecuteSearchIndexNotFoundError() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(updateResponse); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(updateResponse); + when(client.update(any(UpdateRequest.class))).thenReturn(future); - updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + ActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + updateConnectorTransportAction.doExecute(task, updateRequest, latchedActionListener); + latch.await(); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(Result.UPDATED, argumentCaptor.getValue().getResult()); } private SearchResponse noneEmptySearchResponse() throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java index 68bef15120..696fd634cf 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -52,6 +52,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -443,11 +444,7 @@ public void testGetConnectorHappyCase() throws IOException, InterruptedException .id("connectorId") .build(); GetResponse getResponse = prepareConnector(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); + PlainActionFuture future = PlainActionFuture.newFuture(); future.onResponse(getResponse); when(client.get(any(GetRequest.class))).thenReturn(future); @@ -470,6 +467,65 @@ public void testGetConnectorHappyCase() throws IOException, InterruptedException assertEquals(CommonValue.ML_CONNECTOR_INDEX, requestCaptor.getValue().index()); } + @Test + public void testGetConnectorException() throws IOException, InterruptedException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(CommonValue.ML_CONNECTOR_INDEX) + .id("connectorId") + .build(); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("Failed to get connector")); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + client.threadPool().getThreadContext().newStoredContext(true), + getRequest, + "connectorId", + latchedActionListener + ); + latch.await(); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getConnectorActionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetConnectorIndexNotFound() throws IOException, InterruptedException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(CommonValue.ML_CONNECTOR_INDEX) + .id("connectorId") + .build(); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new IndexNotFoundException("Index not found")); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + client.threadPool().getThreadContext().newStoredContext(true), + getRequest, + "connectorId", + latchedActionListener + ); + latch.await(); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(getConnectorActionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find connector", argumentCaptor.getValue().getMessage()); + assertEquals(RestStatus.NOT_FOUND, argumentCaptor.getValue().status()); + } + private GetResponse createGetResponse(List backendRoles) { HttpConnector httpConnector = HttpConnector .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java index ffd0e0072b..8b13419ac0 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java @@ -9,7 +9,6 @@ package org.opensearch.ml.sdkclient; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -28,7 +27,6 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchException; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; @@ -38,6 +36,8 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; @@ -45,8 +45,6 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.action.ActionResponse; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; @@ -56,6 +54,8 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; @@ -125,18 +125,17 @@ public void testPutDataObject() throws IOException { public void testPutDataObject_Exception() throws IOException { PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IOException("test")); - return null; - }).when(mockedClient).index(any(IndexRequest.class), any()); + ArgumentCaptor indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); + when(mockedClient.index(indexRequestCaptor.capture())).thenThrow(new UnsupportedOperationException("test")); CompletableFuture future = sdkClient .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + Throwable cause = ce.getCause(); + assertEquals(UnsupportedOperationException.class, cause.getClass()); + assertEquals("test", cause.getMessage()); } public void testGetDataObject() throws IOException { @@ -194,18 +193,91 @@ public void testGetDataObject_NotFound() throws IOException { public void testGetDataObject_Exception() throws IOException { GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IOException("test")); - return null; - }).when(mockedClient).get(any(GetRequest.class), any()); + ArgumentCaptor getRequestCaptor = ArgumentCaptor.forClass(GetRequest.class); + when(mockedClient.get(getRequestCaptor.capture())).thenThrow(new UnsupportedOperationException("test")); CompletableFuture future = sdkClient .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + Throwable cause = ce.getCause(); + assertEquals(UnsupportedOperationException.class, cause.getClass()); + assertEquals("test", cause.getMessage()); + } + + public void testUpdateDataObject() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + UpdateResponse updateResponse = mock(UpdateResponse.class); + when(updateResponse.getId()).thenReturn(TEST_ID); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.UPDATED); + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.update(any(UpdateRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(updateResponse); + + UpdateDataObjectResponse response = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + verify(mockedClient, times(1)).update(requestCaptor.capture()); + assertEquals(TEST_INDEX, requestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertTrue(response.updated()); + } + + public void testUpdateDataObject_NotFound() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + UpdateResponse updateResponse = mock(UpdateResponse.class); + when(updateResponse.getId()).thenReturn(TEST_ID); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.update(any(UpdateRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(updateResponse); + + UpdateDataObjectResponse response = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + verify(mockedClient, times(1)).update(requestCaptor.capture()); + assertEquals(TEST_INDEX, requestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertFalse(response.updated()); + } + + public void testUpdateDataObject_Exception() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + ArgumentCaptor updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + when(mockedClient.update(updateRequestCaptor.capture())).thenThrow(new UnsupportedOperationException("test")); + + CompletableFuture future = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + Throwable cause = ce.getCause(); + assertEquals(UnsupportedOperationException.class, cause.getClass()); + assertEquals("test", cause.getMessage()); } public void testDeleteDataObject() throws IOException { @@ -236,17 +308,16 @@ public void testDeleteDataObject() throws IOException { public void testDeleteDataObject_Exception() throws IOException { DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IOException("test")); - return null; - }).when(mockedClient).delete(any(DeleteRequest.class), any()); + ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); + when(mockedClient.delete(deleteRequestCaptor.capture())).thenThrow(new UnsupportedOperationException("test")); CompletableFuture future = sdkClient .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + Throwable cause = ce.getCause(); + assertEquals(UnsupportedOperationException.class, cause.getClass()); + assertEquals("test", cause.getMessage()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java index ee7ee53f98..2018a69a0b 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.ml.sdkclient; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; @@ -24,7 +25,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.opensearch._types.Result; import org.opensearch.client.opensearch._types.ShardStatistics; @@ -34,6 +35,8 @@ import org.opensearch.client.opensearch.core.GetResponse; import org.opensearch.client.opensearch.core.IndexRequest; import org.opensearch.client.opensearch.core.IndexResponse; +import org.opensearch.client.opensearch.core.UpdateRequest; +import org.opensearch.client.opensearch.core.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; @@ -45,6 +48,8 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; @@ -148,7 +153,7 @@ public void testPutDataObject_Exception() throws IOException { .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); } @SuppressWarnings({ "unchecked", "rawtypes" }) @@ -213,7 +218,87 @@ public void testGetDataObject_Exception() throws IOException { .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); + } + + public void testUpdateDataObject() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + UpdateResponse> updateResponse = new UpdateResponse.Builder>() + .id(TEST_ID) + .index(TEST_INDEX) + .primaryTerm(0) + .result(Result.Updated) + .seqNo(0) + .shards(new ShardStatistics.Builder().failed(0).successful(1).total(1).build()) + .version(0) + .build(); + + @SuppressWarnings("unchecked") + ArgumentCaptor, ?>> updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + when(mockedOpenSearchClient.update(updateRequestCaptor.capture(), any())).thenReturn(updateResponse); + + UpdateDataObjectResponse response = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(TEST_INDEX, updateRequestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertTrue(response.updated()); + } + + public void testUpdateDataObject_NotFound() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + UpdateResponse> updateResponse = new UpdateResponse.Builder>() + .id(TEST_ID) + .index(TEST_INDEX) + .primaryTerm(0) + .result(Result.Created) + .seqNo(0) + .shards(new ShardStatistics.Builder().failed(0).successful(1).total(1).build()) + .version(0) + .build(); + + @SuppressWarnings("unchecked") + ArgumentCaptor, ?>> updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + when(mockedOpenSearchClient.update(updateRequestCaptor.capture(), any())).thenReturn(updateResponse); + + UpdateDataObjectResponse response = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(TEST_INDEX, updateRequestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertFalse(response.updated()); + } + + public void testtUpdateDataObject_Exception() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + ArgumentCaptor> updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + when(mockedOpenSearchClient.update(updateRequestCaptor.capture(), any())).thenThrow(new IOException("test")); + + CompletableFuture future = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); } public void testDeleteDataObject() throws IOException { @@ -285,6 +370,6 @@ public void testDeleteDataObject_Exception() throws IOException { .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); } } From 20381bee570dca24208b1925e2512aaaac3aab9e Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Mon, 17 Jun 2024 10:23:26 -0700 Subject: [PATCH 06/10] Addressed CR comment Signed-off-by: Arjun kumar Giri --- .../sdk/DeleteDataObjectRequest.java | 10 +++- .../opensearch/sdk/GetDataObjectRequest.java | 9 ++++ .../opensearch/sdk/PutDataObjectRequest.java | 18 +++++++ .../sdk/UpdateDataObjectRequest.java | 23 +++++++- .../ml/sdkclient/DDBOpenSearchClient.java | 52 ++++++++++--------- .../ml/sdkclient/SdkClientModule.java | 10 ++-- 6 files changed, 89 insertions(+), 33 deletions(-) diff --git a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java index 4cbe587f75..d2108be87f 100644 --- a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java @@ -12,7 +12,6 @@ public class DeleteDataObjectRequest { private final String index; private final String id; - private final String tenantId; /** @@ -44,6 +43,10 @@ public String id() { return this.id; } + /** + * Returns the tenant id + * @return the tenantId + */ public String tenantId() { return this.tenantId; } @@ -81,6 +84,11 @@ public Builder id(String id) { return this; } + /** + * Add a tenant id to this builder + * @param tenantId the tenant id + * @return the updated builder + */ public Builder tenantId(String tenantId) { this.tenantId = tenantId; return this; diff --git a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java index 3d282dbf04..d38e227f83 100644 --- a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java @@ -48,6 +48,10 @@ public String id() { return this.id; } + /** + * Returns the tenant id + * @return the tenantId + */ public String tenantId() { return this.tenantId; } @@ -94,6 +98,11 @@ public Builder id(String id) { return this; } + /** + * Add a tenant id to this builder + * @param tenantId the tenant id + * @return the updated builder + */ public Builder tenantId(String tenantId) { this.tenantId = tenantId; return this; diff --git a/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java index bb36150de0..9052ef6d40 100644 --- a/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java @@ -39,10 +39,18 @@ public String index() { return this.index; } + /** + * Returns the document id + * @return the id + */ public String id() { return this.id; } + /** + * Returns the tenant id + * @return the tenantId + */ public String tenantId() { return this.tenantId; } @@ -79,11 +87,21 @@ public Builder index(String index) { return this; } + /** + * Add an id to this builder + * @param id the documet id + * @return the updated builder + */ public Builder id(String id) { this.id = id; return this; } + /** + * Add a tenant id to this builder + * @param tenantId the tenant id + * @return the updated builder + */ public Builder tenantId(String tenantId) { this.tenantId = tenantId; return this; diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java index 25891a1167..81670091f6 100644 --- a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java @@ -14,6 +14,9 @@ public class UpdateDataObjectRequest { private final String index; private final String id; + + private final String tenantId; + private final ToXContentObject dataObject; /** @@ -24,9 +27,10 @@ public class UpdateDataObjectRequest { * @param id the document id * @param dataObject the data object */ - public UpdateDataObjectRequest(String index, String id, ToXContentObject dataObject) { + public UpdateDataObjectRequest(String index, String id, String tenantId, ToXContentObject dataObject) { this.index = index; this.id = id; + this.tenantId = tenantId; this.dataObject = dataObject; } @@ -45,6 +49,14 @@ public String index() { public String id() { return this.id; } + + /** + * Returns the tenant id + * @return the tenantId + */ + public String tenantId() { + return this.tenantId; + } /** * Returns the data object @@ -60,6 +72,8 @@ public ToXContentObject dataObject() { public static class Builder { private String index = null; private String id = null; + + private String tenantId = null; private ToXContentObject dataObject = null; /** @@ -87,6 +101,11 @@ public Builder id(String id) { return this; } + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + /** * Add a data object to this builder * @param dataObject the data object @@ -102,7 +121,7 @@ public Builder dataObject(ToXContentObject dataObject) { * @return A {@link UpdateDataObjectRequest} */ public UpdateDataObjectRequest build() { - return new UpdateDataObjectRequest(this.index, this.id, this.dataObject); + return new UpdateDataObjectRequest(this.index, this.id, this.tenantId, this.dataObject); } } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index 196e1885a9..631327a421 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -8,6 +8,7 @@ */ package org.opensearch.ml.sdkclient; +import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Map; @@ -17,13 +18,13 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; -import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; @@ -32,6 +33,8 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; @@ -70,24 +73,17 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; final String tableName = getTableName(request.index()); return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { - try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { - XContentBuilder builder = request.dataObject().toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); - String source = builder.toString(); + String source = Strings.toString(MediaTypeRegistry.JSON, request.dataObject()); + final Map item = Map + .ofEntries( + Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), + Map.entry(RANGE_KEY, AttributeValue.builder().s(id).build()), + Map.entry(SOURCE, AttributeValue.builder().s(source).build()) + ); + final PutItemRequest putItemRequest = PutItemRequest.builder().tableName(tableName).item(item).build(); - final Map item = Map - .ofEntries( - Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), - Map.entry(RANGE_KEY, AttributeValue.builder().s(id).build()), - Map.entry(SOURCE, AttributeValue.builder().s(source).build()) - ); - final PutItemRequest putItemRequest = PutItemRequest.builder().tableName(tableName).item(item).build(); - - dynamoDbClient.putItem(putItemRequest); - return new PutDataObjectResponse.Builder().id(id).created(true).build(); - } catch (Exception e) { - log.error("Exception while inserting data into DDB: " + e.getMessage(), e); - throw new OpenSearchException(e); - } + dynamoDbClient.putItem(putItemRequest); + return new PutDataObjectResponse.Builder().id(id).created(true).build(); }), executor); } @@ -120,13 +116,19 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe XContentParser parser = JsonXContent.jsonXContent .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, source); return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.of(parser)).build(); - } catch (Exception e) { - log.error("Exception while fetching data from DDB: " + e.getMessage(), e); - throw new OpenSearchException(e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException("Failed to parse data object " + request.id(), RestStatus.BAD_REQUEST); } }), executor); } + @Override + public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor) { + // TODO: Implement update + return null; + } + /** * Deletes data document from DDB. Default tenant ID will be used if tenant ID is not specified. * @@ -153,7 +155,7 @@ public CompletionStage deleteDataObjectAsync(DeleteDat private String getTableName(String index) { // Table name will be same as index name. As DDB table name does not support dot(.) - // it will be removed form name. + // it will be removed from name. return index.replaceAll("\\.", ""); } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index d7be7043b9..70fa2be136 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -41,7 +41,7 @@ public class SdkClientModule extends AbstractModule { public static final String REMOTE_OPENSEARCH = "RemoteOpenSearch"; public static final String AWS_DYNAMO_DB = "AWSDynamoDB"; - private final String remoteStoreType; + private final String remoteMetadataType; private final String remoteMetadataEndpoint; private final String region; // not using with RestClient @@ -57,21 +57,21 @@ public SdkClientModule() { * @param remoteMetadataEndpoint The remote endpoint * @param region The region */ - SdkClientModule(String remoteStoreType, String remoteMetadataEndpoint, String region) { - this.remoteStoreType = remoteStoreType; + SdkClientModule(String remoteMetadataType, String remoteMetadataEndpoint, String region) { + this.remoteMetadataType = remoteMetadataType; this.remoteMetadataEndpoint = remoteMetadataEndpoint; this.region = region; } @Override protected void configure() { - if (this.remoteStoreType == null) { + if (this.remoteMetadataType == null) { log.info("Using local opensearch cluster as metadata store"); bind(SdkClient.class).to(LocalClusterIndicesClient.class); return; } - switch (this.remoteStoreType) { + switch (this.remoteMetadataType) { case REMOTE_OPENSEARCH: log.info("Using remote opensearch cluster as metadata store"); bind(SdkClient.class).toInstance(new RemoteClusterIndicesClient(createOpenSearchClient())); From 82bddc85815dbbe5ef8d3e2f2e4ad057f54ba9e5 Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Mon, 17 Jun 2024 11:15:53 -0700 Subject: [PATCH 07/10] Added javadoc based on feedback Signed-off-by: Arjun kumar Giri --- .../java/org/opensearch/sdk/DeleteDataObjectRequest.java | 1 + .../main/java/org/opensearch/sdk/GetDataObjectRequest.java | 1 + .../java/org/opensearch/sdk/UpdateDataObjectRequest.java | 3 ++- .../java/org/opensearch/ml/sdkclient/SdkClientModule.java | 1 + .../opensearch/ml/sdkclient/DDBOpenSearchClientTests.java | 5 ++--- 5 files changed, 7 insertions(+), 4 deletions(-) diff --git a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java index d2108be87f..70e9506388 100644 --- a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java @@ -20,6 +20,7 @@ public class DeleteDataObjectRequest { * For data storage implementations other than OpenSearch, an index may be referred to as a table and the id may be referred to as a primary key. * @param index the index location to delete the object * @param id the document id + * @param tenantId the tenant id */ public DeleteDataObjectRequest(String index, String id, String tenantId) { this.index = index; diff --git a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java index d38e227f83..ee81411709 100644 --- a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java @@ -23,6 +23,7 @@ public class GetDataObjectRequest { * For data storage implementations other than OpenSearch, an index may be referred to as a table and the id may be referred to as a primary key. * @param index the index location to get the object * @param id the document id + * @param tenantId the tenant id * @param fetchSourceContext the context to use when fetching _source */ public GetDataObjectRequest(String index, String id, String tenantId, FetchSourceContext fetchSourceContext) { diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java index e7736272fa..d7ae56e704 100644 --- a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java @@ -23,6 +23,7 @@ public class UpdateDataObjectRequest { * For data storage implementations other than OpenSearch, an index may be referred to as a table and the data object may be referred to as an item. * @param index the index location to update the object * @param id the document id + * @param tenantId the tenant id * @param dataObject the data object */ public UpdateDataObjectRequest(String index, String id, String tenantId, ToXContentObject dataObject) { @@ -100,7 +101,7 @@ public Builder id(String id) { /** * Add a tenant ID to this builder - * @param id the tenant id + * @param tenantId the tenant id * @return the updated builder */ public Builder tenantId(String tenantId) { diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index 70fa2be136..e181011392 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -54,6 +54,7 @@ public SdkClientModule() { /** * Instantiate this module specifying the endpoint and region. Package private for testing. + * @param remoteMetadataType Type of remote metadata store * @param remoteMetadataEndpoint The remote endpoint * @param region The region */ diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index 22faf00b25..aed028e42c 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -27,7 +27,6 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchException; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; @@ -165,7 +164,7 @@ public void testPutDataObject_DDBException_ThrowsException() { .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + assertEquals(RuntimeException.class, ce.getCause().getClass()); } @Test @@ -227,7 +226,7 @@ public void testGetDataObject_DDBException_ThrowsOSException() throws IOExceptio .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + assertEquals(RuntimeException.class, ce.getCause().getClass()); } @Test From 8315ee337b2f0d921b6eba5a67d628b3c7e75ef9 Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Mon, 8 Jul 2024 08:59:48 -0700 Subject: [PATCH 08/10] Set tenant ID for predict request Signed-off-by: Arjun kumar Giri --- .../ml/engine/indices/MLIndicesHandler.java | 2 +- plugin/build.gradle | 1 + .../CreateControllerTransportAction.java | 2 +- .../DeleteControllerTransportAction.java | 2 +- .../GetControllerTransportAction.java | 2 +- .../UpdateControllerTransportAction.java | 2 +- .../deploy/TransportDeployModelAction.java | 4 +- .../TransportDeployModelOnNodeAction.java | 1 + .../models/GetModelTransportAction.java | 1 + .../models/UpdateModelTransportAction.java | 2 +- .../TransportPredictionTaskAction.java | 2 +- .../TransportRegisterModelAction.java | 1 + .../TransportUndeployModelsAction.java | 2 +- .../helper/ConnectorAccessControlHelper.java | 1 + .../opensearch/ml/model/MLModelManager.java | 140 ++++++++++-------- .../ml/plugin/MachineLearningPlugin.java | 12 +- .../ml/rest/RestMLPredictionAction.java | 7 +- .../ml/sdkclient/DDBOpenSearchClient.java | 30 ++-- .../sdkclient/LocalClusterIndicesClient.java | 6 +- .../ml/sdkclient/SdkClientModule.java | 65 ++++++-- .../ml/task/MLPredictTaskRunner.java | 46 ++---- .../CreateControllerTransportActionTests.java | 13 +- .../DeleteControllerTransportActionTests.java | 20 +-- .../GetControllerTransportActionTests.java | 16 +- .../UpdateControllerTransportActionTests.java | 32 ++-- .../TransportDeployModelActionTests.java | 38 ++--- ...TransportDeployModelOnNodeActionTests.java | 10 +- .../UpdateModelTransportActionTests.java | 81 +++++----- .../TransportUndeployModelsActionTests.java | 22 +-- .../ml/model/MLModelManagerTests.java | 117 +++++++-------- .../sdkclient/DDBOpenSearchClientTests.java | 4 +- .../ml/sdkclient/SdkClientModuleTests.java | 8 +- .../ml/task/MLPredictTaskRunnerTests.java | 40 ++--- 33 files changed, 393 insertions(+), 339 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index 0b88d9ca19..779cab1ffc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -180,7 +180,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) */ public void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener listener) { IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); - if (indexMetaData == null) { + if (indexMetaData == null || indexMetaData.mapping() == null) { listener.onResponse(Boolean.FALSE); return; } diff --git a/plugin/build.gradle b/plugin/build.gradle index 5b4f1ee3fb..7579b2a839 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -84,6 +84,7 @@ dependencies { implementation "software.amazon.awssdk:third-party-jackson-core:2.25.40" implementation("software.amazon.awssdk:url-connection-client:2.25.40") implementation("software.amazon.awssdk:utils:2.25.40") + implementation("software.amazon.awssdk:apache-client:2.25.40") configurations.all { diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java index 5439d73619..3172b337f2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java @@ -99,7 +99,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); Boolean isHidden = mlModel.getIsHidden(); if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java index 92b8095ad4..37b4757017 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java @@ -86,7 +86,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java index 26c59decdf..183c081da1 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java @@ -85,7 +85,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java index dab8410ad0..9b378f3334 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java @@ -91,7 +91,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); Boolean isHidden = mlModel.getIsHidden(); if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 76e17e9675..6151e645a3 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -149,7 +149,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); Boolean isHidden = mlModel.getIsHidden(); if (!TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), listener)) { @@ -284,7 +284,7 @@ private void deployModel( mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); mlTask.setTaskId(taskId); - if (algorithm == FunctionName.REMOTE) { + if (algorithm == FunctionName.REMOTE && !mlFeatureEnabledSetting.isMultiTenancyEnabled()) { mlTaskManager.add(mlTask, eligibleNodeIds); deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener); return; diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java index 495ea771f2..e86fda3ef8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java @@ -224,6 +224,7 @@ private void deployModel( mlModelManager .deployModel( modelId, + null, modelContentHash, functionName, deployToAllNodes, diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index a33cb3bbe3..381e9f529c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -103,6 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - mlModelManager.getModel(sdkClient, modelId, null, excludes, ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.wrap(mlModel -> { if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), actionListener)) { if (!isModelDeploying(mlModel.getModelState())) { FunctionName functionName = mlModel.getAlgorithm(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 60c50d6716..9e95d94304 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -211,7 +211,7 @@ public void onFailure(Exception e) { modelActionListener.onResponse(cachedMlModel); } else { // For multi-node cluster, the function name is null in cache, so should always get model first. - mlModelManager.getModel(modelId, modelActionListener); + mlModelManager.getModel(modelId, tenantId, modelActionListener); } } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index f22e51b24a..982a874e3d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -419,6 +419,7 @@ private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterMode .backendRoles(registerModelInput.getBackendRoles()) .modelAccessMode(registerModelInput.getAccessMode()) .isAddAllBackendRoles(registerModelInput.getAddAllBackendRoles()) + .tenantId(registerModelInput.getTenantId()) .build(); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index 734b9209a6..ba82d91e00 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -182,7 +182,7 @@ private void validateAccess(String modelId, String tenantId, ActionListener { + mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> { if (!TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), listener)) { return; } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index bf719164b2..ff650e3c41 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -112,6 +112,7 @@ public void validateConnectorAccess( GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest .builder() .index(ML_CONNECTOR_INDEX) + .tenantId(tenantId) .id(connectorId) .fetchSourceContext(fetchSourceContext) .build(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 7522a2affe..4e61b8b4a4 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -10,6 +10,7 @@ import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.NOT_FOUND; @@ -102,7 +103,6 @@ import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; @@ -164,6 +164,7 @@ public class MLModelManager { public static final long MODEL_FILE_SIZE_LIMIT = 4L * 1024 * 1024 * 1024;// 4GB private final Client client; + private final SdkClient sdkClient; private final ClusterService clusterService; private final ScriptService scriptService; private final ThreadPool threadPool; @@ -196,6 +197,7 @@ public MLModelManager( ClusterService clusterService, ScriptService scriptService, Client client, + SdkClient sdkClient, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, ModelHelper modelHelper, @@ -209,6 +211,7 @@ public MLModelManager( DiscoveryNodeHelper nodeHelper ) { this.client = client; + this.sdkClient = sdkClient; this.threadPool = threadPool; this.xContentRegistry = xContentRegistry; this.modelHelper = modelHelper; @@ -367,7 +370,11 @@ public void registerMLRemoteModel( mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelGroupId = mlRegisterModelInput.getModelGroupId(); - GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest.builder().index(ML_MODEL_GROUP_INDEX).id(modelGroupId).build(); + GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest.builder() + .index(ML_MODEL_GROUP_INDEX) + .tenantId(mlRegisterModelInput.getTenantId()) + .id(modelGroupId) + .build(); sdkClient .getDataObjectAsync(getModelGroupRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) .whenComplete((r, throwable) -> { @@ -388,10 +395,10 @@ public void registerMLRemoteModel( */ modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); - UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest - .builder() + UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest.builder() .index(ML_MODEL_GROUP_INDEX) .id(modelGroupId) + .tenantId(mlRegisterModelInput.getTenantId()) // TODO need to track these for concurrency // .setIfSeqNo(seqNo) // .setIfPrimaryTerm(primaryTerm) @@ -539,7 +546,7 @@ private UpdateRequest createUpdateModelGroupRequest( } private int incrementLatestVersion(Map modelGroupSourceMap) { - return (int) modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD) + 1; + return Integer.parseInt(modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD).toString()) + 1; } private void indexRemoteModel( @@ -582,10 +589,10 @@ private void indexRemoteModel( .tenantId(registerModelInput.getTenantId()) .build(); - PutDataObjectRequest putModelMetaRequest = PutDataObjectRequest - .builder() + PutDataObjectRequest putModelMetaRequest = PutDataObjectRequest.builder() .index(ML_MODEL_INDEX) .id(Boolean.TRUE.equals(registerModelInput.getIsHidden()) ? modelName : null) + .tenantId(registerModelInput.getTenantId()) .dataObject(mlModelMeta) .build(); @@ -934,7 +941,7 @@ private void updateModelRegisterStateAsDone( void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) { String[] modelNodeIds = registerModelInput.getModelNodeIds(); log.debug("start deploying model after registering, modelId: {} on nodes: {}", modelId, Arrays.toString(modelNodeIds)); - MLDeployModelRequest request = new MLDeployModelRequest(modelId, null, modelNodeIds, false, true, true); + MLDeployModelRequest request = new MLDeployModelRequest(modelId, registerModelInput.getTenantId(), modelNodeIds, false, true, true); ActionListener listener = ActionListener .wrap(r -> log.debug("model deployed, response {}", r), e -> log.error("Failed to deploy model", e)); client.execute(MLDeployModelAction.INSTANCE, request, listener); @@ -1001,6 +1008,7 @@ private void handleException(FunctionName functionName, String taskId, Exception * into memory. * * @param modelId model id + * @param tenantId tenant id * @param modelContentHash model content hash value * @param functionName function name * @param mlTask ML task @@ -1008,6 +1016,7 @@ private void handleException(FunctionName functionName, String taskId, Exception */ public void deployModel( String modelId, + String tenantId, String modelContentHash, FunctionName functionName, boolean deployToAllNodes, @@ -1050,7 +1059,7 @@ public void deployModel( if (!autoDeployModel) { checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); } - this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { + this.getModel(modelId, tenantId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); modelCacheHelper.setModelInfo(modelId, mlModel); if (FunctionName.REMOTE == mlModel.getAlgorithm() @@ -1175,7 +1184,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou return; } log.info("Set connector {} for the model: {}", mlModel.getConnectorId(), modelId); - getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); @@ -1252,7 +1261,7 @@ public synchronized void updateModelCache(String modelId, ActionListener wrappedListener.onResponse("Successfully updated model cache for the remote model " + modelId); log.info("Completed the model cache update for the remote model {}", modelId); } else { - getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); wrappedListener.onResponse("Successfully updated model cache for the remote model " + modelId); @@ -1297,7 +1306,7 @@ public synchronized void deployControllerWithDeployedModel(String modelId, Actio wrappedListener.onResponse("Successfully deployed model controller for the remote model " + modelId); log.info("Deployed model controller for the remote model {}", modelId); } else { - getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); wrappedListener.onResponse("Successfully deployed model controller for the remote model " + modelId); @@ -1335,7 +1344,7 @@ public synchronized void undeployController(String modelId, ActionListener { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); wrappedListener.onResponse("Successfully undeployed model controller for the remote model " + modelId); @@ -1601,55 +1610,35 @@ public MLGuard getMLGuard(String modelId) { * @param listener action listener */ public void getModel(String modelId, ActionListener listener) { - getModel(modelId, null, null, listener); + getModel(modelId, null, listener); } - // TODO remove when all usages are migrated to SDK version /** - * Get model from model index with includes/excludes filter. + * Get model from model index. * * @param modelId model id - * @param includes fields included - * @param excludes fields excluded + * @param tenantId tenant id * @param listener action listener */ - public void getModel(String modelId, String[] includes, String[] excludes, ActionListener listener) { - GetRequest getRequest = new GetRequest(); - FetchSourceContext fetchContext = new FetchSourceContext(true, includes, excludes); - getRequest.index(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchContext); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - String algorithmName = r.getSource().get(ALGORITHM_FIELD).toString(); - - MLModel mlModel = MLModel.parse(parser, algorithmName); - mlModel.setModelId(modelId); - listener.onResponse(mlModel); - } catch (Exception e) { - log.error("Failed to parse ml task{}", r.getId(), e); - listener.onFailure(e); - } - } else { - listener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); - } - }, listener::onFailure)); + public void getModel(String modelId, String tenantId, ActionListener listener) { + getModel(modelId, tenantId, null, null, listener); } + // TODO remove when all usages are migrated to SDK version /** * Get model from model index with includes/excludes filter. * - * @param sdkClient the SdkClient instance * @param modelId model id + * @param tenantId tenant id * @param includes fields included * @param excludes fields excluded * @param listener action listener */ - public void getModel(SdkClient sdkClient, String modelId, String[] includes, String[] excludes, ActionListener listener) { - GetDataObjectRequest getRequest = GetDataObjectRequest - .builder() + public void getModel(String modelId, String tenantId, String[] includes, String[] excludes, ActionListener listener) { + GetDataObjectRequest getRequest = GetDataObjectRequest.builder() .index(ML_MODEL_INDEX) .id(modelId) + .tenantId(tenantId) .fetchSourceContext(new FetchSourceContext(true, includes, excludes)) .build(); sdkClient.getDataObjectAsync(getRequest, client.threadPool().executor(GENERAL_THREAD_POOL)).whenComplete((r, throwable) -> { @@ -1713,30 +1702,53 @@ public void getController(String modelId, ActionListener listener) * Get connector from connector index. * * @param connectorId connector id + * @param tenantId tenant id * @param listener action listener */ - private void getConnector(String connectorId, ActionListener listener) { - GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector connector = Connector.createConnector(parser); - listener.onResponse(connector); - } catch (Exception e) { - log.error("Failed to parse connector:" + connectorId); - listener.onFailure(e); + private void getConnector(String connectorId, String tenantId, ActionListener listener) { + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest.builder() + .index(ML_CONNECTOR_INDEX) + .id(connectorId) + .tenantId(tenantId) + .build(); + + sdkClient + .getDataObjectAsync(getDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + log.debug("Completed Get Connector Request, id:{}", connectorId); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (cause instanceof IndexNotFoundException) { + log.error("Failed to get connector index", cause); + listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML connector " + connectorId, cause); + listener.onFailure(cause); + } + } else { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, gr.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector connector = Connector.createConnector(parser); + listener.onResponse(connector); + } catch (Exception e) { + log.error("Failed to parse connector:" + connectorId); + listener.onFailure(e); + } + } else { + listener + .onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND)); + } + } catch (Exception e) { + listener.onFailure(e); + } } - } else { - listener.onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND)); - } - }, e -> { - log.error("Failed to get connector", e); - listener.onFailure(new OpenSearchStatusException("Failed to get connector:" + connectorId, RestStatus.NOT_FOUND)); - })); + }); } /** diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index cfaa019b5f..204d3c479d 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -27,7 +27,9 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Injector; import org.opensearch.common.inject.Module; +import org.opensearch.common.inject.ModulesBuilder; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Setting; @@ -287,6 +289,7 @@ import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; +import org.opensearch.sdk.SdkClient; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.search.pipeline.SearchResponseProcessor; @@ -434,7 +437,7 @@ public MachineLearningPlugin(Settings settings) { @Override public Collection createGuiceModules() { - return List.of(new SdkClientModule()); + return List.of(new SdkClientModule(null, null)); } @SneakyThrows @@ -461,6 +464,12 @@ public Collection createComponents( Settings settings = environment.settings(); Path dataPath = environment.dataFiles()[0]; Path configFile = environment.configFile(); + ModulesBuilder modules = new ModulesBuilder(); + modules.add(new SdkClientModule(client, xContentRegistry)); + Injector injector = modules.createInjector(); + + // Get the injected SdkClient instance from the injector + SdkClient sdkClient = injector.getInstance(SdkClient.class); mlIndicesHandler = new MLIndicesHandler(clusterService, client); encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); @@ -503,6 +512,7 @@ public Collection createComponents( clusterService, scriptService, client, + sdkClient, threadPool, xContentRegistry, modelHelper, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index da909f5474..860bcbfef8 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -109,7 +109,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client } }); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - modelManager.getModel(modelId, ActionListener.runBefore(listener, () -> context.restore())); + modelManager + .getModel( + modelId, + getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request), + ActionListener.runBefore(listener, () -> context.restore()) + ); } }; } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index 030313ae05..6e42e2cbcb 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -16,6 +16,7 @@ import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; @@ -54,7 +55,9 @@ import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeAction; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.AttributeValueUpdate; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; @@ -110,12 +113,7 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe final PutItemRequest putItemRequest = PutItemRequest.builder().tableName(tableName).item(item).build(); dynamoDbClient.putItem(putItemRequest); - String simulatedIndexResponse = simulateOpenSearchResponse( - request.index(), - request.id(), - source, - Map.of("result", "created") - ); + String simulatedIndexResponse = simulateOpenSearchResponse(request.index(), id, source, Map.of("result", "created")); return PutDataObjectResponse.builder().id(id).parser(createParser(simulatedIndexResponse)).build(); } catch (IOException e) { // Rethrow unchecked exception on XContent parsing error @@ -185,12 +183,26 @@ public CompletionStage updateDataObjectAsync(UpdateDat String source = Strings.toString(MediaTypeRegistry.JSON, request.dataObject()); JsonNode jsonNode = OBJECT_MAPPER.readTree(source); Map updateItem = JsonTransformer.convertJsonObjectToDDBAttributeMap(jsonNode); - updateItem.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); - updateItem.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build()); + updateItem.remove(HASH_KEY); + updateItem.remove(RANGE_KEY); + Map updateAttributeValue = updateItem + .entrySet() + .stream() + .collect( + Collectors + .toMap( + Map.Entry::getKey, + entry -> AttributeValueUpdate.builder().action(AttributeAction.PUT).value(entry.getValue()).build() + ) + ); + Map updateKey = new HashMap<>(); + updateKey.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); + updateKey.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build()); UpdateItemRequest updateItemRequest = UpdateItemRequest .builder() .tableName(getTableName(request.index())) - .key(updateItem) + .key(updateKey) + .attributeUpdates(updateAttributeValue) .build(); dynamoDbClient.updateItem(updateItemRequest); diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index 54551cc28a..3682b95ff2 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -102,17 +102,13 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe public CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor) { return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { - log.info("Getting {} from {}", request.id(), request.index()); GetResponse getResponse = client .get(new GetRequest(request.index(), request.id()).fetchSourceContext(request.fetchSourceContext())) .actionGet(); if (getResponse == null) { - log.info("Null GetResponse"); return GetDataObjectResponse.builder().id(request.id()).parser(null).build(); } - log.info("Retrieved data object"); - return GetDataObjectResponse - .builder() + return GetDataObjectResponse.builder() .id(getResponse.getId()) .parser(createParser(getResponse)) .source(getResponse.getSource()) diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index cc9b161146..e296109bbd 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -9,13 +9,14 @@ package org.opensearch.ml.sdkclient; import org.apache.http.HttpHost; -import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.impl.client.BasicCredentialsProvider; import org.opensearch.OpenSearchException; -import org.opensearch.client.RestClient; -import org.opensearch.client.json.jackson.JacksonJsonpMapper; +import org.opensearch.client.Client; import org.opensearch.client.opensearch.OpenSearchClient; -import org.opensearch.client.transport.rest_client.RestClientTransport; +import org.opensearch.client.transport.aws.AwsSdk2Transport; +import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; import org.opensearch.common.inject.AbstractModule; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.sdk.SdkClient; import com.fasterxml.jackson.annotation.JsonInclude; @@ -27,6 +28,8 @@ import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; @@ -45,12 +48,20 @@ public class SdkClientModule extends AbstractModule { private final String remoteMetadataType; private final String remoteMetadataEndpoint; private final String region; // not using with RestClient + private Client client; + private NamedXContentRegistry namedXContentRegistry; /** * Instantiate this module using environment variables */ - public SdkClientModule() { - this(System.getenv(REMOTE_METADATA_TYPE), System.getenv(REMOTE_METADATA_ENDPOINT), System.getenv(REGION)); + public SdkClientModule(Client client, NamedXContentRegistry namedXContentRegistry) { + this( + client, + namedXContentRegistry, + System.getenv(REMOTE_METADATA_TYPE), + System.getenv(REMOTE_METADATA_ENDPOINT), + System.getenv(REGION) + ); } /** @@ -59,7 +70,15 @@ public SdkClientModule() { * @param remoteMetadataEndpoint The remote endpoint * @param region The region */ - SdkClientModule(String remoteMetadataType, String remoteMetadataEndpoint, String region) { + SdkClientModule( + Client client, + NamedXContentRegistry namedXContentRegistry, + String remoteMetadataType, + String remoteMetadataEndpoint, + String region + ) { + this.client = client; + this.namedXContentRegistry = namedXContentRegistry; this.remoteMetadataType = remoteMetadataType; this.remoteMetadataEndpoint = remoteMetadataEndpoint; this.region = region; @@ -69,7 +88,7 @@ public SdkClientModule() { protected void configure() { if (this.remoteMetadataType == null) { log.info("Using local opensearch cluster as metadata store"); - bind(SdkClient.class).to(LocalClusterIndicesClient.class); + bindLocalClient(); return; } @@ -85,7 +104,15 @@ protected void configure() { return; default: log.info("Using local opensearch cluster as metadata store"); - bind(SdkClient.class).to(LocalClusterIndicesClient.class); + bindLocalClient(); + } + } + + private void bindLocalClient() { + if (client == null) { + bind(SdkClient.class).to(LocalClusterIndicesClient.class); + } else { + bind(SdkClient.class).toInstance(new LocalClusterIndicesClient(this.client, this.namedXContentRegistry)); } } @@ -106,23 +133,35 @@ private DynamoDbClient createDynamoDbClient() { private OpenSearchClient createOpenSearchClient() { try { + BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); // Basic http(not-s) client using RestClient. - RestClient restClient = RestClient + SdkHttpClient httpClient = ApacheHttpClient.builder().build(); + AwsSdk2Transport awsSdk2Transport = new AwsSdk2Transport( + httpClient, + HttpHost.create(remoteMetadataEndpoint).getHostName(), + "aoss", + Region.of(region), + AwsSdk2TransportOptions.builder().build() + ); + /*RestClient restClient = RestClient // This HttpHost syntax works with export REMOTE_METADATA_ENDPOINT=http://127.0.0.1:9200 .builder(HttpHost.create(remoteMetadataEndpoint)) .setStrictDeprecationMode(true) .setHttpClientConfigCallback(httpClientBuilder -> { try { - return httpClientBuilder.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); + return httpClientBuilder + .setDefaultCredentialsProvider(credentialsProvider) + .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); } catch (Exception e) { throw new OpenSearchException(e); } }) - .build(); + .build();*/ ObjectMapper objectMapper = new ObjectMapper() .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE) .setSerializationInclusion(JsonInclude.Include.NON_NULL); - return new OpenSearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper(objectMapper))); + // return new OpenSearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper(objectMapper))); + return new OpenSearchClient(awsSdk2Transport); } catch (Exception e) { throw new OpenSearchException(e); } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index a2fba5f959..1e75108e7a 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -5,9 +5,6 @@ package org.opensearch.ml.task; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; import static org.opensearch.ml.permission.AccessController.checkUserPermissions; import static org.opensearch.ml.permission.AccessController.getUserContext; @@ -23,23 +20,18 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionListenerResponseHandler; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ThreadedActionListener; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -154,7 +146,7 @@ public void dispatchTask( if (workerNodes == null || workerNodes.length == 0) { if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlModelManager.getModel(modelId, ActionListener.runBefore(ActionListener.wrap(model -> { + mlModelManager.getModel(modelId, request.getTenantId(), ActionListener.runBefore(ActionListener.wrap(model -> { Boolean isHidden = model.getIsHidden(); if (!checkModelAutoDeployEnabled(model)) { final String errorMsg = getErrorMessage( @@ -245,7 +237,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener dataFrameActionListener = ActionListener.wrap(dataSet -> { MLInput newInput = mlInput.toBuilder().inputDataset(dataSet).build(); - predict(modelId, mlTask, newInput, listener); + predict(modelId, request.getTenantId(), mlTask, newInput, listener); }, e -> { log.error("Failed to generate DataFrame from search query", e); handleAsyncMLTaskFailure(mlTask, e); @@ -258,7 +250,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener { predict(modelId, mlTask, mlInput, listener); }); + threadPool.executor(threadPoolName).execute(() -> { predict(modelId, request.getTenantId(), mlTask, mlInput, listener); }); break; } } @@ -274,7 +266,7 @@ private String getPredictThreadPool(FunctionName functionName) { return functionName == FunctionName.REMOTE ? REMOTE_PREDICT_THREAD_POOL : PREDICT_THREAD_POOL; } - private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListener listener) { + private void predict(String modelId, String tenantId, MLTask mlTask, MLInput mlInput, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); @@ -305,8 +297,8 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe .state(MLTaskState.RUNNING) .workerNodes(Arrays.asList(clusterService.localNode().getId())) .build(); - mlModelManager.deployModel(modelId, null, functionName, false, true, mlDeployTask, ActionListener.wrap(s -> { - runPredict(modelId, mlTask, mlInput, functionName, internalListener); + mlModelManager.deployModel(modelId, tenantId, null, functionName, false, true, mlDeployTask, ActionListener.wrap(s -> { + runPredict(modelId, tenantId, mlTask, mlInput, functionName, internalListener); }, e -> { log.error("Failed to auto deploy model " + modelId, e); internalListener.onFailure(e); @@ -314,11 +306,12 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe return; } - runPredict(modelId, mlTask, mlInput, functionName, internalListener); + runPredict(modelId, tenantId, mlTask, mlInput, functionName, internalListener); } private void runPredict( String modelId, + String tenantId, MLTask mlTask, MLInput mlInput, FunctionName algorithm, @@ -367,21 +360,12 @@ private void runPredict( // search model by model id. try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) { - ActionListener getModelListener = ActionListener.wrap(r -> { - if (r == null || !r.isExists()) { + ActionListener getModelListener = ActionListener.wrap(mlModel -> { + if (mlModel == null) { internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId.")); return; } - try ( - XContentParser xContentParser = XContentType.JSON - .xContent() - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser); - GetResponse getResponse = r; - String algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString(); - MLModel mlModel = MLModel.parse(xContentParser, algorithmName); - mlModel.setModelId(modelId); + try { User resourceUser = mlModel.getUser(); User requestUser = getUserContext(client); if (!checkUserPermissions(requestUser, resourceUser, modelId)) { @@ -416,10 +400,10 @@ private void runPredict( log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e); handlePredictFailure(mlTask, internalListener, e, true, modelId); }); - GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId()); - client - .get( - getRequest, + mlModelManager + .getModel( + mlTask.getModelId(), + tenantId, threadedActionListener( mlTask.getFunctionName(), ActionListener.runBefore(getModelListener, () -> context.restore()) diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java index c9a4a1a6d5..2a9d507de6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java @@ -5,9 +5,7 @@ package org.opensearch.ml.action.controller; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -28,6 +26,7 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.Version; import org.opensearch.action.DocWriteResponse; @@ -169,10 +168,10 @@ public void setup() throws IOException { }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), Mockito.isNull(), any(), any(), isA(ActionListener.class)); when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { @@ -247,10 +246,10 @@ public void testCreateControllerWithModelAccessControlOtherException() { @Test public void testCreateControllerWithModelNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java index 1e49ab2fd7..4fb9a968cf 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java @@ -160,10 +160,10 @@ public void setup() throws IOException { }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -216,10 +216,10 @@ public void testDeleteControllerWithModelAccessControlNoPermissionHiddenModel() when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); @@ -255,10 +255,10 @@ public void testDeleteControllerWithModelAccessControlOtherExceptionHiddenModel( when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener @@ -280,10 +280,10 @@ public void testDeleteControllerWithModelAccessControlOtherExceptionHiddenModel( @Test public void testDeleteControllerWithGetModelNotFoundSuccess() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); verify(actionListener).onResponse(deleteResponse); @@ -320,10 +320,10 @@ public void testDeleteControllerWithGetControllerOtherException() { @Test public void testDeleteControllerWithGetModelNotFoundWithGetControllerOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java index 489e71e080..49b771271c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java @@ -5,9 +5,7 @@ package org.opensearch.ml.action.controller; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -107,10 +105,10 @@ public void setup() throws IOException { mlControllerGetRequest = MLControllerGetRequest.builder().modelId("testModelId").build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -170,10 +168,10 @@ public void testGetControllerWithModelAccessControlOtherException() { @Test public void testGetControllerWithGetModelNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -187,10 +185,10 @@ public void testGetControllerWithGetModelNotFound() { @Test public void testGetControllerWithGetModelOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java index fd378647e9..f1a87a7dfb 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java @@ -179,10 +179,10 @@ public void setup() throws IOException { }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getModelId()).thenReturn("testModelId"); @@ -246,10 +246,10 @@ public void testUpdateControllerWithModelAccessControlNoPermissionHiddenModel() when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); @@ -285,10 +285,10 @@ public void testUpdateControllerWithModelAccessControlOtherExceptionHiddenModel( when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Permission denied: Unable to create the model controller for the model. Details: ")); @@ -328,10 +328,10 @@ public void testUpdateControllerWithControllerEnabledNullHiddenModel() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); @@ -396,10 +396,10 @@ public void testUpdateControllerWithModelFunctionUnsupported() { @Test public void tesUpdateControllerWithGetModelNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -506,10 +506,10 @@ public void testUpdateControllerWithUndeploySuccessPartiallyFailuresHiddenModel( when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); when(mlModel.getModelId()).thenReturn("testModelId"); doAnswer(invocation -> { - ActionListener mllistener = invocation.getArgument(3); + ActionListener mllistener = invocation.getArgument(4); mllistener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); List failures = List .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); @@ -556,10 +556,10 @@ public void testUpdateControllerWithUndeployNullResponseHiddenModel() { when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); when(mlModel.getModelId()).thenReturn("testModelId"); doAnswer(invocation -> { - ActionListener mllistener = invocation.getArgument(3); + ActionListener mllistener = invocation.getArgument(4); mllistener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); @@ -606,10 +606,10 @@ public void testUpdateControllerWithUndeployOtherExceptionHiddenModel() { when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); when(mlModel.getModelId()).thenReturn("testModelId"); doAnswer(invocation -> { - ActionListener mllistener = invocation.getArgument(3); + ActionListener mllistener = invocation.getArgument(4); mllistener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index 8b8ee5234f..bce5c4b2dd 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -212,10 +212,10 @@ public void testDoExecute_success() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); IndexResponse indexResponse = mock(IndexResponse.class); when(indexResponse.getId()).thenReturn("mockIndexId"); @@ -234,10 +234,10 @@ public void testDoExecute_success_not_userInitiatedRequest() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(false); @@ -279,10 +279,10 @@ public void testDoExecute_success_hidden_model() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); when(mlModel.getIsHidden()).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); IndexResponse indexResponse = mock(IndexResponse.class); when(indexResponse.getId()).thenReturn("mockIndexId"); @@ -325,10 +325,10 @@ public void testDoExecute_no_permission_hidden_model() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); when(mlModel.getIsHidden()).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); IndexResponse indexResponse = mock(IndexResponse.class); when(indexResponse.getId()).thenReturn("mockIndexId"); @@ -351,10 +351,10 @@ public void testDoExecute_userHasNoAccessException() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -373,10 +373,10 @@ public void testDoExecuteRemoteInferenceDisabled() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); ActionListener deployModelResponseListener = mock(ActionListener.class); @@ -390,10 +390,10 @@ public void testDoExecuteLocalInferenceDisabled() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false); ActionListener deployModelResponseListener = mock(ActionListener.class); @@ -407,10 +407,10 @@ public void test_ValidationFailedException() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -491,7 +491,7 @@ public void testDoExecute_whenDeployModelRequestNodeIdsEmpty_thenMLResourceNotFo public void testDoExecute_whenGetModelHasNPE_exception() { doThrow(NullPointerException.class) .when(mlModelManager) - .getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + .getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); ActionListener deployModelResponseListener = mock(ActionListener.class); transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); @@ -502,10 +502,10 @@ public void testDoExecute_whenThreadPoolExecutorException_TaskRemoved() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); IndexResponse indexResponse = mock(IndexResponse.class); when(indexResponse.getId()).thenReturn("mockIndexId"); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java index 83852cc68f..2fa105957e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java @@ -207,7 +207,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(5); listener.onResponse("successful"); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); doAnswer(invocation -> { ActionListenerResponseHandler handler = invocation.getArgument(3); @@ -313,7 +313,7 @@ public void testNodeOperation_FailToSendForwardRequest() { ActionListener listener = invocation.getArgument(4); listener.onResponse("ok"); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); doAnswer(invocation -> { TransportResponseHandler handler = invocation.getArgument(3); handler.handleException(new TransportException("error")); @@ -331,7 +331,7 @@ public void testNodeOperation_Exception() { ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Something went wrong")); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); final MLDeployModelNodesRequest nodesRequest = prepareRequest(localNode.getId()); final MLDeployModelNodeRequest request = action.newNodeRequest(nodesRequest); final MLDeployModelNodeResponse response = action.nodeOperation(request); @@ -342,7 +342,7 @@ public void testNodeOperation_Exception() { public void testNodeOperation_DeployModelRuntimeException() { doThrow(new RuntimeException("error")) .when(mlModelManager) - .deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + .deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); final MLDeployModelNodesRequest nodesRequest = prepareRequest(localNode.getId()); final MLDeployModelNodeRequest request = action.newNodeRequest(nodesRequest); final MLDeployModelNodeResponse response = action.nodeOperation(request); @@ -355,7 +355,7 @@ public void testNodeOperation_MLLimitExceededException() { ActionListener listener = invocation.getArgument(4); listener.onFailure(new MLLimitExceededException("Limit exceeded exception")); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); final MLDeployModelNodesRequest nodesRequest = prepareRequest(localNode.getId()); final MLDeployModelNodeRequest request = action.newNodeRequest(nodesRequest); final MLDeployModelNodeResponse response = action.nodeOperation(request); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 5129bfa16a..5e5f4e8199 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -364,13 +364,13 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(localModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); listener.onResponse(localModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); MLModelGroup modelGroup = MLModelGroup .builder() @@ -445,7 +445,7 @@ public void testUpdateRemoteModelWithLocalInformationSuccess() throws Interrupte ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -465,7 +465,7 @@ public void testUpdateExternalRemoteModelWithExternalRemoteInformationSuccess() ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -485,7 +485,7 @@ public void testUpdateInternalRemoteModelWithInternalRemoteInformationSuccess() ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -505,7 +505,7 @@ public void testUpdateHiddenRemoteModelWithRemoteInformationSuccess() throws Int ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doReturn(true).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); CountDownLatch latch = new CountDownLatch(1); @@ -526,7 +526,7 @@ public void testUpdateHiddenRemoteModelPermissionError() { ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doReturn(false).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -541,7 +541,7 @@ public void testUpdateRemoteModelWithNoExternalConnectorFound() { ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModelWithInternalConnector); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -559,7 +559,7 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(5); @@ -583,7 +583,7 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(5); @@ -712,7 +712,7 @@ public void testUpdateModelWithModelNotFound() { ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -726,7 +726,7 @@ public void testUpdateModelWithFunctionNameFieldNotFound() { ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModelWithNullFunctionName); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -756,7 +756,7 @@ public void testUpdateLocalModelWithUnsupportedFunction() { ActionListener listener = invocation.getArgument(4); listener.onResponse(localModelWithUnsupportedFunction); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -773,7 +773,7 @@ public void testUpdateRequestDocIOException() throws IOException, InterruptedExc ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); @@ -800,7 +800,7 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(eq(sdkClient), eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); @@ -957,7 +957,7 @@ public void testUpdateModelStateDeployingException() { ActionListener listener = invocation.getArgument(4); listener.onResponse(testDeployingModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); @@ -975,7 +975,7 @@ public void testUpdateModelStateLoadingException() { ActionListener listener = invocation.getArgument(4); listener.onResponse(testDeployingModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); @@ -993,7 +993,7 @@ public void testUpdateModelCacheModelStateDeployedSuccess() throws InterruptedEx ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1022,7 +1022,7 @@ public void testUpdateModelCacheModelWithIsModelEnabledSuccess() throws Interrup ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1053,7 +1053,7 @@ public void testUpdateModelCacheModelWithoutUpdateConnectorWithRateLimiterSucces ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1085,7 +1085,7 @@ public void testUpdateModelCacheModelWithRateLimiterSuccess() throws Interrupted ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1116,7 +1116,7 @@ public void testUpdateModelWithPartialRateLimiterSuccess() throws InterruptedExc ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").build(); MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); @@ -1142,7 +1142,7 @@ public void testUpdateModelCacheModelWithPartialRateLimiterSuccess() throws Inte ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1173,7 +1173,7 @@ public void testUpdateModelCacheUpdateResponseListenerWithNullUpdateResponse() t ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); PlainActionFuture future = PlainActionFuture.newFuture(); future.onResponse(null); @@ -1206,7 +1206,7 @@ public void testUpdateModelCacheModelWithUndeploySuccessEmptyFailures() throws I ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1239,7 +1239,7 @@ public void testUpdateControllerWithUndeploySuccessPartiallyFailures() throws In ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1271,7 +1271,7 @@ public void testUpdateControllerWithUndeployNullResponse() throws InterruptedExc ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1302,7 +1302,7 @@ public void testUpdateControllerWithUndeployOtherException() throws InterruptedE ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -1337,7 +1337,7 @@ public void testUpdateModelCacheModelStateDeployedWrongStatus() throws Interrupt ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1366,7 +1366,7 @@ public void testUpdateModelCacheModelStateDeployedUpdateModelCacheException() th ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1411,7 +1411,7 @@ public void testUpdateModelCacheModelStateDeployedUpdateException() throws Inter ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1442,7 +1442,7 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupSuccess() throws Int ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1475,7 +1475,7 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupWrongStatus() throws ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1503,7 +1503,7 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateModelCacheExce ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1547,7 +1547,7 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateException() th ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1577,7 +1577,7 @@ public void testUpdateModelCacheModelStateLoadedSuccess() throws InterruptedExce ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1606,7 +1606,7 @@ public void testUpdateModelCacheModelStatePartiallyDeployedSuccess() throws Inte ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1635,7 +1635,7 @@ public void testUpdateModelCacheModelStatePartiallyLoadedSuccess() throws Interr ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1677,6 +1677,7 @@ private MLModel prepareMLModel(String functionName, MLModelState modelState, boo mlModel = MLModel .builder() .name("test_name") + .tenantId("tenant_id") .modelId("test_model_id") .modelGroupId("test_model_group_id") .description("test_description") @@ -1797,7 +1798,7 @@ public void testUpdateModelStatePartiallyLoadedException() { ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), anyString(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); @@ -1822,7 +1823,7 @@ public void testUpdateModelStatePartiallyDeployedException() { ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), anyString(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 4cf82f948f..d7b6088aab 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -188,10 +188,10 @@ public void setup() throws IOException { .isHidden(false) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); } @AfterClass @@ -213,10 +213,10 @@ public void testHiddenModelSuccess() { .isHidden(true) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); @@ -247,10 +247,10 @@ public void testHiddenModelPermissionError() { .isHidden(true) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); @@ -292,7 +292,7 @@ public void testDoExecute() { public void testDoExecute_modelAccessControl_notEnabled() { when(modelAccessControlHelper.isModelAccessControlEnabled()).thenReturn(false); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); @@ -327,17 +327,19 @@ public void testDoExecute_validate_false() { public void testDoExecute_getModel_exception() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("runtime exception")); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onFailure(isA(RuntimeException.class)); } public void testDoExecute_validateAccess_exception() { - doThrow(new RuntimeException("runtime exception")).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + doThrow(new RuntimeException("runtime exception")) + .when(mlModelManager) + .getModel(any(), any(), any(), any(), isA(ActionListener.class)); MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onFailure(isA(RuntimeException.class)); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 326721803d..250ce2bd98 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -39,9 +39,7 @@ import static org.opensearch.ml.utils.MockHelper.mock_MLIndicesHandler_initModelIndex_failure; import static org.opensearch.ml.utils.MockHelper.mock_client_ThreadContext; import static org.opensearch.ml.utils.MockHelper.mock_client_ThreadContext_Exception; -import static org.opensearch.ml.utils.MockHelper.mock_client_get_NotExist; import static org.opensearch.ml.utils.MockHelper.mock_client_get_NullResponse; -import static org.opensearch.ml.utils.MockHelper.mock_client_get_failure; import static org.opensearch.ml.utils.MockHelper.mock_client_index; import static org.opensearch.ml.utils.MockHelper.mock_client_index_failure; import static org.opensearch.ml.utils.MockHelper.mock_client_update; @@ -90,9 +88,14 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.breaker.MemoryCircuitBreaker; import org.opensearch.ml.breaker.ThresholdCircuitBreaker; @@ -207,7 +210,7 @@ public class MLModelManagerTests extends OpenSearchTestCase { private MLTask pretrainedMLTask; @Before - public void setup() throws URISyntaxException { + public void setup() throws URISyntaxException, IOException { String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; MockitoAnnotations.openMocks(this); @@ -287,6 +290,7 @@ public void setup() throws URISyntaxException { clusterService, scriptService, client, + sdkClient, threadPool, xContentRegistry, modelHelper, @@ -339,6 +343,16 @@ public void setup() throws URISyntaxException { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); } + public void setupGetModel(MLModel model) throws IOException { + XContentBuilder content = model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any(GetRequest.class))).thenReturn(future); + } + @AfterClass public static void cleanup() { ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); @@ -633,9 +647,12 @@ public void testDeployModel_FailedToGetModel() { when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_threadpool(threadPool, taskExecutorService); - mock_client_get_failure(client); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("get doc failure")); + when(client.get(any(GetRequest.class))).thenReturn(future); + mock_client_ThreadContext(client, threadPool, threadContext); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -648,26 +665,13 @@ public void testDeployModel_FailedToGetModel() { ); } - public void testDeployModel_NullGetModelResponse() { + public void testDeployModel_NullGetModelResponse() throws IOException { MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() .modelType("bert") .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) .embeddingDimension(384) .build(); - model = MLModel - .builder() - .modelId(modelId) - .modelState(MLModelState.DEPLOYING) - .algorithm(FunctionName.TEXT_EMBEDDING) - .name(modelName) - .version(version) - .totalChunks(2) - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .modelConfig(modelConfig) - .modelContentHash(modelContentHashValue) - .modelContentSizeInBytes(modelContentSize) - .build(); String[] nodes = new String[] { "node1", "node2" }; mlTask.setWorkerNodes(List.of(nodes)); ActionListener listener = mock(ActionListener.class); @@ -676,7 +680,10 @@ public void testDeployModel_NullGetModelResponse() { when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_threadpool(threadPool, taskExecutorService); mock_client_get_NullResponse(client); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(null); + when(client.get(any(GetRequest.class))).thenReturn(future); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -689,26 +696,13 @@ public void testDeployModel_NullGetModelResponse() { ); } - public void testDeployModel_GetModelResponse_NotExist() { + public void testDeployModel_GetModelResponse_NotExist() throws IOException { MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() .modelType("bert") .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) .embeddingDimension(384) .build(); - model = MLModel - .builder() - .modelId(modelId) - .modelState(MLModelState.DEPLOYING) - .algorithm(FunctionName.TEXT_EMBEDDING) - .name(modelName) - .version(version) - .totalChunks(2) - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .modelConfig(modelConfig) - .modelContentHash(modelContentHashValue) - .modelContentSizeInBytes(modelContentSize) - .build(); String[] nodes = new String[] { "node1", "node2" }; mlTask.setWorkerNodes(List.of(nodes)); ActionListener listener = mock(ActionListener.class); @@ -716,8 +710,15 @@ public void testDeployModel_GetModelResponse_NotExist() { when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_threadpool(threadPool, taskExecutorService); - mock_client_get_NotExist(client); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + XContentBuilder content = model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", -2, 0, 111l, false, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any(GetRequest.class))).thenReturn(future); + + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -730,7 +731,7 @@ public void testDeployModel_GetModelResponse_NotExist() { ); } - public void testDeployModel_GetModelResponse_wrong_hash_value() { + public void testDeployModel_GetModelResponse_wrong_hash_value() throws IOException { MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() .modelType("bert") @@ -759,10 +760,10 @@ public void testDeployModel_GetModelResponse_wrong_hash_value() { when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_client_ThreadContext(client, threadPool, threadContext); mock_threadpool(threadPool, taskExecutorService); - setUpMock_GetModel(model); - setUpMock_GetModel(modelChunk0); - setUpMock_GetModel(modelChunk0); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + setupGetModel(model); + setupGetModel(modelChunk0); + setupGetModel(modelChunk0); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -781,7 +782,7 @@ public void testDeployModel_GetModelResponse_wrong_hash_value() { ); } - public void testDeployModel_GetModelResponse_FailedToDeploy() { + public void testDeployModel_GetModelResponse_FailedToDeploy() throws IOException { MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() .modelType("bert") @@ -812,7 +813,7 @@ public void testDeployModel_GetModelResponse_FailedToDeploy() { setUpMock_GetModelChunks(model); // setUpMock_GetModel(modelChunk0); // setUpMock_GetModel(modelChunk1); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -828,7 +829,7 @@ public void testDeployModel_GetModelResponse_FailedToDeploy() { public void testDeployModel_ModelAlreadyDeployed() { when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(true); ActionListener listener = mock(ActionListener.class); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); ArgumentCaptor response = ArgumentCaptor.forClass(String.class); verify(listener).onResponse(response.capture()); assertEquals("successful", response.getValue()); @@ -843,7 +844,7 @@ public void testDeployModel_ExceedMaxDeployedModel() { when(modelCacheHelper.getDeployedModels()).thenReturn(models); when(modelCacheHelper.getLocalDeployedModels()).thenReturn(models); ActionListener listener = mock(ActionListener.class); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); ArgumentCaptor failure = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(failure.capture()); assertEquals("Exceed max local model per node limit", failure.getValue().getMessage()); @@ -878,7 +879,7 @@ public void testDeployModel_ThreadPoolException() { ActionListener listener = mock(ActionListener.class); FunctionName functionName = FunctionName.TEXT_EMBEDDING; - modelManager.deployModel(modelId, modelContentHashValue, functionName, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, functionName, true, false, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_FAILURE_COUNT)); } @@ -1037,7 +1038,7 @@ private void testDeployModel_FailedToRetrieveModelChunks(boolean lastChunk) { ActionListener listener = mock(ActionListener.class); FunctionName functionName = FunctionName.TEXT_EMBEDDING; - modelManager.deployModel(modelId, modelContentHashValue, functionName, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, functionName, true, false, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); verify(mlStats).getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT)); @@ -1067,46 +1068,46 @@ private void setUpMock_GetModel(MLModel model) { private void setUpMock_GetModelChunks(MLModel model) { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(model); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(modelChunk0); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(modelChunk1); return null; - }).when(modelManager).getModel(any(), any()); + }).when(modelManager).getModel(any(), any(), any()); } private void setUpMock_GetModelMeta_FailedToGetFirstChunk(MLModel model) { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(model); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Failed to get model")); return null; - }).when(modelManager).getModel(any(), any()); + }).when(modelManager).getModel(any(), any(), any()); } private void setUpMock_GetModelMeta_FailedToGetLastChunk(MLModel model) { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(model); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(modelChunk0); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Failed to get model")); return null; - }).when(modelManager).getModel(any(), any()); + }).when(modelManager).getModel(any(), any(), any()); } private void setUpMock_DownloadModelFileFailure() { diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index e85c12d714..b7fe5841a4 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -384,7 +384,7 @@ public void updateDataObjectAsync_HappyCase() { assertEquals(TEST_INDEX, updateItemRequest.tableName()); assertEquals(TEST_ID, updateItemRequest.key().get("id").s()); assertEquals(TENANT_ID, updateItemRequest.key().get("tenant_id").s()); - assertEquals("foo", updateItemRequest.key().get("data").s()); + assertEquals("foo", updateItemRequest.attributeUpdates().get("data").value().s()); } @@ -408,7 +408,7 @@ public void updateDataObjectAsync_HappyCaseWithMap() { assertEquals(TEST_INDEX, updateItemRequest.tableName()); assertEquals(TEST_ID, updateItemRequest.key().get("id").s()); assertEquals(TENANT_ID, updateItemRequest.key().get("tenant_id").s()); - assertEquals("bar", updateItemRequest.key().get("foo").s()); + assertEquals("bar", updateItemRequest.attributeUpdates().get("foo").value().s()); } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java index 8667450d9c..12cac0afdd 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java @@ -30,21 +30,23 @@ protected void configure() { }; public void testLocalBinding() { - Injector injector = Guice.createInjector(new SdkClientModule(null, null, null), localClientModule); + Injector injector = Guice.createInjector(new SdkClientModule(null, null, null, null, null), localClientModule); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof LocalClusterIndicesClient); } public void testRemoteOpenSearchBinding() { - Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.REMOTE_OPENSEARCH, "http://example.org", "eu-west-3")); + Injector injector = Guice + .createInjector(new SdkClientModule(null, null, SdkClientModule.REMOTE_OPENSEARCH, "http://example.org", "eu-west-3")); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof RemoteClusterIndicesClient); } public void testDDBBinding() { - Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.AWS_DYNAMO_DB, "http://example.org", "eu-west-3")); + Injector injector = Guice + .createInjector(new SdkClientModule(null, null, SdkClientModule.AWS_DYNAMO_DB, "http://example.org", "eu-west-3")); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof DDBOpenSearchClient); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index cbde703543..20dc04527c 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -233,7 +233,6 @@ public void testExecuteTask_OnLocalNode() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlTaskManager).add(any(MLTask.class)); - verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); } @@ -243,17 +242,16 @@ public void testExecuteTask_OnLocalNode_QueryInput() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); verify(mlTaskManager).add(any(MLTask.class)); - verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); } public void testExecuteTask_OnLocalNode_RemoteModelAutoDeploy() { setupMocks(true, false, false, false); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(2); actionListener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any()); + }).when(mlModelManager).getModel(any(), any(), any()); when(mlModelManager.addModelToAutoDeployCache("111", mlModel)).thenReturn(mlModel); taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); verify(client).execute(any(), any(), any()); @@ -276,7 +274,6 @@ public void testExecuteTask_NoPermission() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlTaskManager).add(any(MLTask.class)); verify(mlTaskManager).remove(anyString()); - verify(client).get(any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); assertEquals("User: test_user does not have permissions to run predict by model: 111", argumentCaptor.getValue().getMessage()); @@ -294,7 +291,6 @@ public void testExecuteTask_OnLocalNode_GetModelFail() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlTaskManager).add(any(MLTask.class)); - verify(client).get(any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); @@ -380,13 +376,12 @@ public void testExecuteTask_OnLocalNode_prediction_exception() { assertEquals("runtime exception", argumentCaptor.getValue().getMessage()); } - public void testExecuteTask_OnLocalNode_NullGetResponse() { + public void testExecuteTask_OnLocalNode_NullMLModel() { setupMocks(true, false, false, true); taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlTaskManager).add(any(MLTask.class)); - verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -432,7 +427,7 @@ public void testValidateModelTensorOutputFailed() { taskRunner.validateOutputSchema("testId", modelTensorOutput); } - private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, boolean failedToGetModel, boolean nullGetResponse) { + private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, boolean failedToGetModel, boolean nullMlModel) { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); if (runOnLocalNode) { @@ -466,23 +461,16 @@ private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, return null; }).when(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); } - - if (nullGetResponse) { - getResponse = null; - } - - if (failedToGetModel) { - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + if (failedToGetModel) { actionListener.onFailure(new RuntimeException(errorMessage)); - return null; - }).when(client).get(any(), any()); - } else { - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - } + } else if (nullMlModel) { + actionListener.onResponse(null); + } else { + actionListener.onResponse(mlModel); + } + return null; + }).when(mlModelManager).getModel(any(), any(), any()); } } From 9159fe9bf7e3dad914ecfbbb1226142deb05cd07 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 3 Jul 2024 12:38:44 -0700 Subject: [PATCH 09/10] Simplify instantiating Data Object Request/Response builders (#2608) Signed-off-by: Daniel Widdis --- .github/workflows/CI-workflow.yml | 4 +-- plugin/build.gradle | 1 - .../CreateControllerTransportAction.java | 1 + .../DeleteControllerTransportAction.java | 1 + .../GetControllerTransportAction.java | 1 + .../UpdateControllerTransportAction.java | 1 + .../deploy/TransportDeployModelAction.java | 5 +++- .../opensearch/ml/model/MLModelManager.java | 17 +++++++----- .../sdkclient/LocalClusterIndicesClient.java | 3 ++- .../ml/sdkclient/SdkClientModule.java | 27 +++++++------------ 10 files changed, 32 insertions(+), 29 deletions(-) diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index 7077a70031..29a78fe81d 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -145,8 +145,8 @@ jobs: - name: Generate Password For Admin id: genpass run: | - PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-') - echo "password={$PASSWORD}" >> $GITHUB_OUTPUT + PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-') + echo "password={$PASSWORD}" >> $GITHUB_OUTPUT - name: Run Docker Image if: env.imagePresent == 'true' run: | diff --git a/plugin/build.gradle b/plugin/build.gradle index 7579b2a839..5b4f1ee3fb 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -84,7 +84,6 @@ dependencies { implementation "software.amazon.awssdk:third-party-jackson-core:2.25.40" implementation("software.amazon.awssdk:url-connection-client:2.25.40") implementation("software.amazon.awssdk:utils:2.25.40") - implementation("software.amazon.awssdk:apache-client:2.25.40") configurations.all { diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java index 3172b337f2..44e7e6f37a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java @@ -99,6 +99,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + // TODO: Add support for multi tenancy mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); Boolean isHidden = mlModel.getIsHidden(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java index 37b4757017..a5b3931ff6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java @@ -86,6 +86,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + // TODO: Add support for multi tenancy mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java index 183c081da1..a5020a52f8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java @@ -85,6 +85,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java index 9b378f3334..95fb42c5a8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java @@ -91,6 +91,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + // TODO: Add support for multi tenancy mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); Boolean isHidden = mlModel.getIsHidden(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 6151e645a3..1e7d0cce51 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -284,7 +284,10 @@ private void deployModel( mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); mlTask.setTaskId(taskId); - if (algorithm == FunctionName.REMOTE && !mlFeatureEnabledSetting.isMultiTenancyEnabled()) { + if (algorithm == FunctionName.REMOTE) { + if (!mlFeatureEnabledSetting.isMultiTenancyEnabled()) { + return; + } mlTaskManager.add(mlTask, eligibleNodeIds); deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener); return; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 4e61b8b4a4..96e756d739 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -9,8 +9,8 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.NOT_FOUND; @@ -370,7 +370,8 @@ public void registerMLRemoteModel( mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelGroupId = mlRegisterModelInput.getModelGroupId(); - GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest.builder() + GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest + .builder() .index(ML_MODEL_GROUP_INDEX) .tenantId(mlRegisterModelInput.getTenantId()) .id(modelGroupId) @@ -395,7 +396,8 @@ public void registerMLRemoteModel( */ modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); - UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest.builder() + UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest + .builder() .index(ML_MODEL_GROUP_INDEX) .id(modelGroupId) .tenantId(mlRegisterModelInput.getTenantId()) @@ -589,7 +591,8 @@ private void indexRemoteModel( .tenantId(registerModelInput.getTenantId()) .build(); - PutDataObjectRequest putModelMetaRequest = PutDataObjectRequest.builder() + PutDataObjectRequest putModelMetaRequest = PutDataObjectRequest + .builder() .index(ML_MODEL_INDEX) .id(Boolean.TRUE.equals(registerModelInput.getIsHidden()) ? modelName : null) .tenantId(registerModelInput.getTenantId()) @@ -1635,7 +1638,8 @@ public void getModel(String modelId, String tenantId, ActionListener li * @param listener action listener */ public void getModel(String modelId, String tenantId, String[] includes, String[] excludes, ActionListener listener) { - GetDataObjectRequest getRequest = GetDataObjectRequest.builder() + GetDataObjectRequest getRequest = GetDataObjectRequest + .builder() .index(ML_MODEL_INDEX) .id(modelId) .tenantId(tenantId) @@ -1706,7 +1710,8 @@ public void getController(String modelId, ActionListener listener) * @param listener action listener */ private void getConnector(String connectorId, String tenantId, ActionListener listener) { - GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest.builder() + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest + .builder() .index(ML_CONNECTOR_INDEX) .id(connectorId) .tenantId(tenantId) diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index 3682b95ff2..7bc2bd9468 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -108,7 +108,8 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe if (getResponse == null) { return GetDataObjectResponse.builder().id(request.id()).parser(null).build(); } - return GetDataObjectResponse.builder() + return GetDataObjectResponse + .builder() .id(getResponse.getId()) .parser(createParser(getResponse)) .source(getResponse.getSource()) diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index e296109bbd..e5b1e4706f 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -9,12 +9,14 @@ package org.opensearch.ml.sdkclient; import org.apache.http.HttpHost; +import org.apache.http.conn.ssl.NoopHostnameVerifier; import org.apache.http.impl.client.BasicCredentialsProvider; import org.opensearch.OpenSearchException; import org.opensearch.client.Client; +import org.opensearch.client.RestClient; +import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; -import org.opensearch.client.transport.aws.AwsSdk2Transport; -import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; +import org.opensearch.client.transport.rest_client.RestClientTransport; import org.opensearch.common.inject.AbstractModule; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.sdk.SdkClient; @@ -28,8 +30,6 @@ import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; @@ -135,33 +135,24 @@ private OpenSearchClient createOpenSearchClient() { try { BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); // Basic http(not-s) client using RestClient. - SdkHttpClient httpClient = ApacheHttpClient.builder().build(); - AwsSdk2Transport awsSdk2Transport = new AwsSdk2Transport( - httpClient, - HttpHost.create(remoteMetadataEndpoint).getHostName(), - "aoss", - Region.of(region), - AwsSdk2TransportOptions.builder().build() - ); - /*RestClient restClient = RestClient + RestClient restClient = RestClient // This HttpHost syntax works with export REMOTE_METADATA_ENDPOINT=http://127.0.0.1:9200 .builder(HttpHost.create(remoteMetadataEndpoint)) .setStrictDeprecationMode(true) .setHttpClientConfigCallback(httpClientBuilder -> { try { return httpClientBuilder - .setDefaultCredentialsProvider(credentialsProvider) - .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); + .setDefaultCredentialsProvider(credentialsProvider) + .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); } catch (Exception e) { throw new OpenSearchException(e); } }) - .build();*/ + .build(); ObjectMapper objectMapper = new ObjectMapper() .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE) .setSerializationInclusion(JsonInclude.Include.NON_NULL); - // return new OpenSearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper(objectMapper))); - return new OpenSearchClient(awsSdk2Transport); + return new OpenSearchClient(new RestClientTransport(restClient, new JacksonJsonpMapper(objectMapper))); } catch (Exception e) { throw new OpenSearchException(e); } From 53b41132c8b6a6c65f1541491ee0520366132757 Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Wed, 10 Jul 2024 10:18:27 -0700 Subject: [PATCH 10/10] Addressed comments Signed-off-by: Arjun kumar Giri --- .../sdk/UpdateDataObjectRequest.java | 2 +- .../deploy/TransportDeployModelAction.java | 3 ++- .../ml/plugin/MachineLearningPlugin.java | 3 +++ .../GetControllerTransportActionTests.java | 4 ++- .../TransportDeployModelActionTests.java | 26 +++++++++++++++++++ 5 files changed, 35 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java index aee473b5f9..553257eb7d 100644 --- a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java @@ -136,7 +136,7 @@ public Builder id(String id) { return this; } - /** + /** * Add a tenant ID to this builder * @param tenantId the tenant id * @return the updated builder diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 1e7d0cce51..059782c82f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -285,7 +285,8 @@ private void deployModel( String taskId = response.getId(); mlTask.setTaskId(taskId); if (algorithm == FunctionName.REMOTE) { - if (!mlFeatureEnabledSetting.isMultiTenancyEnabled()) { + if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) { + listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name())); return; } mlTaskManager.add(mlTask, eligibleNodeIds); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 204d3c479d..b64820dda6 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -437,6 +437,8 @@ public MachineLearningPlugin(Settings settings) { @Override public Collection createGuiceModules() { + // TODO: SDKClientModule is initialized both in createGuiceModules and createComponents. Unify these + // approaches to prevent multiple instances of SDKClient. return List.of(new SdkClientModule(null, null)); } @@ -464,6 +466,7 @@ public Collection createComponents( Settings settings = environment.settings(); Path dataPath = environment.dataFiles()[0]; Path configFile = environment.configFile(); + // TODO: Rather than recreating SDKClientModule reuse module created as part of createGuiceModules ModulesBuilder modules = new ModulesBuilder(); modules.add(new SdkClientModule(client, xContentRegistry)); Injector injector = modules.createInjector(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java index 49b771271c..d4ff67fa14 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java @@ -5,7 +5,9 @@ package org.opensearch.ml.action.controller; -import static org.mockito.ArgumentMatchers.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index bce5c4b2dd..879219c3da 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -386,6 +386,32 @@ public void testDoExecuteRemoteInferenceDisabled() { assertEquals(REMOTE_INFERENCE_DISABLED_ERR_MSG, argumentCaptor.getValue().getMessage()); } + public void testDoExecuteRemoteInference_MultiNodeEnabled() { + MLModel mlModel = mock(MLModel.class); + when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); + when(mlModel.getTenantId()).thenReturn("test_tenant"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + IndexResponse indexResponse = mock(IndexResponse.class); + when(indexResponse.getId()).thenReturn("mockIndexId"); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class)); + + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + ActionListener deployModelResponseListener = mock(ActionListener.class); + when(mlDeployModelRequest.getTenantId()).thenReturn("test_tenant"); + transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class); + verify(deployModelResponseListener).onResponse(argumentCaptor.capture()); + assertEquals("CREATED", argumentCaptor.getValue().getStatus()); + } + public void testDoExecuteLocalInferenceDisabled() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);