diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index 02c25afa9c..94d3443b59 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -147,8 +147,8 @@ jobs: - name: Generate Password For Admin id: genpass run: | - PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-') - echo "password={$PASSWORD}" >> $GITHUB_OUTPUT + PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-') + echo "password={$PASSWORD}" >> $GITHUB_OUTPUT - name: Run Docker Image if: env.imagePresent == 'true' run: | diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index 0b88d9ca19..779cab1ffc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -180,7 +180,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) */ public void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener listener) { IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); - if (indexMetaData == null) { + if (indexMetaData == null || indexMetaData.mapping() == null) { listener.onResponse(Boolean.FALSE); return; } diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java index 5439d73619..44e7e6f37a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java @@ -99,7 +99,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + // TODO: Add support for multi tenancy + mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); Boolean isHidden = mlModel.getIsHidden(); if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java index 92b8095ad4..a5b3931ff6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java @@ -86,7 +86,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + // TODO: Add support for multi tenancy + mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java index 26c59decdf..a5020a52f8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java @@ -85,7 +85,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + // TODO: Add support for multi tenancy + mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java index dab8410ad0..95fb42c5a8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java @@ -91,7 +91,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + // TODO: Add support for multi tenancy + mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); Boolean isHidden = mlModel.getIsHidden(); if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 76e17e9675..059782c82f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -149,7 +149,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); Boolean isHidden = mlModel.getIsHidden(); if (!TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), listener)) { @@ -285,6 +285,10 @@ private void deployModel( String taskId = response.getId(); mlTask.setTaskId(taskId); if (algorithm == FunctionName.REMOTE) { + if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) { + listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name())); + return; + } mlTaskManager.add(mlTask, eligibleNodeIds); deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener); return; diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java index 495ea771f2..e86fda3ef8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java @@ -224,6 +224,7 @@ private void deployModel( mlModelManager .deployModel( modelId, + null, modelContentHash, functionName, deployToAllNodes, diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index a33cb3bbe3..381e9f529c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -103,6 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - mlModelManager.getModel(sdkClient, modelId, null, excludes, ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.wrap(mlModel -> { if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), actionListener)) { if (!isModelDeploying(mlModel.getModelState())) { FunctionName functionName = mlModel.getAlgorithm(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 60c50d6716..9e95d94304 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -211,7 +211,7 @@ public void onFailure(Exception e) { modelActionListener.onResponse(cachedMlModel); } else { // For multi-node cluster, the function name is null in cache, so should always get model first. - mlModelManager.getModel(modelId, modelActionListener); + mlModelManager.getModel(modelId, tenantId, modelActionListener); } } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index f22e51b24a..982a874e3d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -419,6 +419,7 @@ private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterMode .backendRoles(registerModelInput.getBackendRoles()) .modelAccessMode(registerModelInput.getAccessMode()) .isAddAllBackendRoles(registerModelInput.getAddAllBackendRoles()) + .tenantId(registerModelInput.getTenantId()) .build(); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index 734b9209a6..ba82d91e00 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -182,7 +182,7 @@ private void validateAccess(String modelId, String tenantId, ActionListener { + mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> { if (!TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), listener)) { return; } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index bf719164b2..ff650e3c41 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -112,6 +112,7 @@ public void validateConnectorAccess( GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest .builder() .index(ML_CONNECTOR_INDEX) + .tenantId(tenantId) .id(connectorId) .fetchSourceContext(fetchSourceContext) .build(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 5466f11839..187a0ccbc1 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -9,6 +9,7 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; @@ -102,7 +103,6 @@ import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; @@ -164,6 +164,7 @@ public class MLModelManager { public static final long MODEL_FILE_SIZE_LIMIT = 4L * 1024 * 1024 * 1024;// 4GB private final Client client; + private final SdkClient sdkClient; private final ClusterService clusterService; private final ScriptService scriptService; private final ThreadPool threadPool; @@ -196,6 +197,7 @@ public MLModelManager( ClusterService clusterService, ScriptService scriptService, Client client, + SdkClient sdkClient, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, ModelHelper modelHelper, @@ -209,6 +211,7 @@ public MLModelManager( DiscoveryNodeHelper nodeHelper ) { this.client = client; + this.sdkClient = sdkClient; this.threadPool = threadPool; this.xContentRegistry = xContentRegistry; this.modelHelper = modelHelper; @@ -367,7 +370,12 @@ public void registerMLRemoteModel( mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelGroupId = mlRegisterModelInput.getModelGroupId(); - GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest.builder().index(ML_MODEL_GROUP_INDEX).id(modelGroupId).build(); + GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest + .builder() + .index(ML_MODEL_GROUP_INDEX) + .tenantId(mlRegisterModelInput.getTenantId()) + .id(modelGroupId) + .build(); sdkClient .getDataObjectAsync(getModelGroupRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) .whenComplete((r, throwable) -> { @@ -383,6 +391,7 @@ public void registerMLRemoteModel( .builder() .index(ML_MODEL_GROUP_INDEX) .id(modelGroupId) + .tenantId(mlRegisterModelInput.getTenantId()) .ifSeqNo(getModelGroupResponse.getSeqNo()) .ifPrimaryTerm(getModelGroupResponse.getPrimaryTerm()) .dataObject(modelGroupSourceMap) @@ -528,7 +537,7 @@ private UpdateRequest createUpdateModelGroupRequest( } private int incrementLatestVersion(Map modelGroupSourceMap) { - return (int) modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD) + 1; + return Integer.parseInt(modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD).toString()) + 1; } private void indexRemoteModel( @@ -575,6 +584,7 @@ private void indexRemoteModel( .builder() .index(ML_MODEL_INDEX) .id(Boolean.TRUE.equals(registerModelInput.getIsHidden()) ? modelName : null) + .tenantId(registerModelInput.getTenantId()) .dataObject(mlModelMeta) .build(); @@ -923,7 +933,7 @@ private void updateModelRegisterStateAsDone( void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) { String[] modelNodeIds = registerModelInput.getModelNodeIds(); log.debug("start deploying model after registering, modelId: {} on nodes: {}", modelId, Arrays.toString(modelNodeIds)); - MLDeployModelRequest request = new MLDeployModelRequest(modelId, null, modelNodeIds, false, true, true); + MLDeployModelRequest request = new MLDeployModelRequest(modelId, registerModelInput.getTenantId(), modelNodeIds, false, true, true); ActionListener listener = ActionListener .wrap(r -> log.debug("model deployed, response {}", r), e -> log.error("Failed to deploy model", e)); client.execute(MLDeployModelAction.INSTANCE, request, listener); @@ -990,6 +1000,7 @@ private void handleException(FunctionName functionName, String taskId, Exception * into memory. * * @param modelId model id + * @param tenantId tenant id * @param modelContentHash model content hash value * @param functionName function name * @param mlTask ML task @@ -997,6 +1008,7 @@ private void handleException(FunctionName functionName, String taskId, Exception */ public void deployModel( String modelId, + String tenantId, String modelContentHash, FunctionName functionName, boolean deployToAllNodes, @@ -1039,7 +1051,7 @@ public void deployModel( if (!autoDeployModel) { checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); } - this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { + this.getModel(modelId, tenantId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); modelCacheHelper.setModelInfo(modelId, mlModel); if (FunctionName.REMOTE == mlModel.getAlgorithm() @@ -1164,7 +1176,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou return; } log.info("Set connector {} for the model: {}", mlModel.getConnectorId(), modelId); - getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); @@ -1241,7 +1253,7 @@ public synchronized void updateModelCache(String modelId, ActionListener wrappedListener.onResponse("Successfully updated model cache for the remote model " + modelId); log.info("Completed the model cache update for the remote model {}", modelId); } else { - getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); wrappedListener.onResponse("Successfully updated model cache for the remote model " + modelId); @@ -1286,7 +1298,7 @@ public synchronized void deployControllerWithDeployedModel(String modelId, Actio wrappedListener.onResponse("Successfully deployed model controller for the remote model " + modelId); log.info("Deployed model controller for the remote model {}", modelId); } else { - getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); wrappedListener.onResponse("Successfully deployed model controller for the remote model " + modelId); @@ -1324,7 +1336,7 @@ public synchronized void undeployController(String modelId, ActionListener { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); wrappedListener.onResponse("Successfully undeployed model controller for the remote model " + modelId); @@ -1590,55 +1602,36 @@ public MLGuard getMLGuard(String modelId) { * @param listener action listener */ public void getModel(String modelId, ActionListener listener) { - getModel(modelId, null, null, listener); + getModel(modelId, null, listener); } - // TODO remove when all usages are migrated to SDK version /** - * Get model from model index with includes/excludes filter. + * Get model from model index. * * @param modelId model id - * @param includes fields included - * @param excludes fields excluded + * @param tenantId tenant id * @param listener action listener */ - public void getModel(String modelId, String[] includes, String[] excludes, ActionListener listener) { - GetRequest getRequest = new GetRequest(); - FetchSourceContext fetchContext = new FetchSourceContext(true, includes, excludes); - getRequest.index(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchContext); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - String algorithmName = r.getSource().get(ALGORITHM_FIELD).toString(); - - MLModel mlModel = MLModel.parse(parser, algorithmName); - mlModel.setModelId(modelId); - listener.onResponse(mlModel); - } catch (Exception e) { - log.error("Failed to parse ml task{}", r.getId(), e); - listener.onFailure(e); - } - } else { - listener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); - } - }, listener::onFailure)); + public void getModel(String modelId, String tenantId, ActionListener listener) { + getModel(modelId, tenantId, null, null, listener); } + // TODO remove when all usages are migrated to SDK version /** * Get model from model index with includes/excludes filter. * - * @param sdkClient the SdkClient instance * @param modelId model id + * @param tenantId tenant id * @param includes fields included * @param excludes fields excluded * @param listener action listener */ - public void getModel(SdkClient sdkClient, String modelId, String[] includes, String[] excludes, ActionListener listener) { + public void getModel(String modelId, String tenantId, String[] includes, String[] excludes, ActionListener listener) { GetDataObjectRequest getRequest = GetDataObjectRequest .builder() .index(ML_MODEL_INDEX) .id(modelId) + .tenantId(tenantId) .fetchSourceContext(new FetchSourceContext(true, includes, excludes)) .build(); sdkClient.getDataObjectAsync(getRequest, client.threadPool().executor(GENERAL_THREAD_POOL)).whenComplete((r, throwable) -> { @@ -1702,30 +1695,54 @@ public void getController(String modelId, ActionListener listener) * Get connector from connector index. * * @param connectorId connector id + * @param tenantId tenant id * @param listener action listener */ - private void getConnector(String connectorId, ActionListener listener) { - GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector connector = Connector.createConnector(parser); - listener.onResponse(connector); - } catch (Exception e) { - log.error("Failed to parse connector:" + connectorId); - listener.onFailure(e); + private void getConnector(String connectorId, String tenantId, ActionListener listener) { + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest + .builder() + .index(ML_CONNECTOR_INDEX) + .id(connectorId) + .tenantId(tenantId) + .build(); + + sdkClient + .getDataObjectAsync(getDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + log.debug("Completed Get Connector Request, id:{}", connectorId); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (cause instanceof IndexNotFoundException) { + log.error("Failed to get connector index", cause); + listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML connector " + connectorId, cause); + listener.onFailure(cause); + } + } else { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, gr.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector connector = Connector.createConnector(parser); + listener.onResponse(connector); + } catch (Exception e) { + log.error("Failed to parse connector:" + connectorId); + listener.onFailure(e); + } + } else { + listener + .onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND)); + } + } catch (Exception e) { + listener.onFailure(e); + } } - } else { - listener.onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND)); - } - }, e -> { - log.error("Failed to get connector", e); - listener.onFailure(new OpenSearchStatusException("Failed to get connector:" + connectorId, RestStatus.NOT_FOUND)); - })); + }); } /** diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index cfaa019b5f..b64820dda6 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -27,7 +27,9 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Injector; import org.opensearch.common.inject.Module; +import org.opensearch.common.inject.ModulesBuilder; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Setting; @@ -287,6 +289,7 @@ import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; +import org.opensearch.sdk.SdkClient; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.search.pipeline.SearchResponseProcessor; @@ -434,7 +437,9 @@ public MachineLearningPlugin(Settings settings) { @Override public Collection createGuiceModules() { - return List.of(new SdkClientModule()); + // TODO: SDKClientModule is initialized both in createGuiceModules and createComponents. Unify these + // approaches to prevent multiple instances of SDKClient. + return List.of(new SdkClientModule(null, null)); } @SneakyThrows @@ -461,6 +466,13 @@ public Collection createComponents( Settings settings = environment.settings(); Path dataPath = environment.dataFiles()[0]; Path configFile = environment.configFile(); + // TODO: Rather than recreating SDKClientModule reuse module created as part of createGuiceModules + ModulesBuilder modules = new ModulesBuilder(); + modules.add(new SdkClientModule(client, xContentRegistry)); + Injector injector = modules.createInjector(); + + // Get the injected SdkClient instance from the injector + SdkClient sdkClient = injector.getInstance(SdkClient.class); mlIndicesHandler = new MLIndicesHandler(clusterService, client); encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); @@ -503,6 +515,7 @@ public Collection createComponents( clusterService, scriptService, client, + sdkClient, threadPool, xContentRegistry, modelHelper, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index da909f5474..860bcbfef8 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -109,7 +109,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client } }); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - modelManager.getModel(modelId, ActionListener.runBefore(listener, () -> context.restore())); + modelManager + .getModel( + modelId, + getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request), + ActionListener.runBefore(listener, () -> context.restore()) + ); } }; } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index a90aa47d45..d4aca6cb6b 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -16,6 +16,7 @@ import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; @@ -54,7 +55,9 @@ import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeAction; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.AttributeValueUpdate; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; @@ -116,12 +119,7 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe // If document exists, overwrite and increment and return SEQ_NO dynamoDbClient.putItem(putItemRequest); // TODO need to pass seqNo to simulated response - String simulatedIndexResponse = simulateOpenSearchResponse( - request.index(), - request.id(), - source, - Map.of("result", "created") - ); + String simulatedIndexResponse = simulateOpenSearchResponse(request.index(), id, source, Map.of("result", "created")); return PutDataObjectResponse.builder().id(id).parser(createParser(simulatedIndexResponse)).build(); } catch (IOException e) { // Rethrow unchecked exception on XContent parsing error @@ -192,12 +190,26 @@ public CompletionStage updateDataObjectAsync(UpdateDat String source = Strings.toString(MediaTypeRegistry.JSON, request.dataObject()); JsonNode jsonNode = OBJECT_MAPPER.readTree(source); Map updateItem = JsonTransformer.convertJsonObjectToDDBAttributeMap(jsonNode); - updateItem.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); - updateItem.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build()); + updateItem.remove(HASH_KEY); + updateItem.remove(RANGE_KEY); + Map updateAttributeValue = updateItem + .entrySet() + .stream() + .collect( + Collectors + .toMap( + Map.Entry::getKey, + entry -> AttributeValueUpdate.builder().action(AttributeAction.PUT).value(entry.getValue()).build() + ) + ); + Map updateKey = new HashMap<>(); + updateKey.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); + updateKey.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build()); UpdateItemRequest.Builder updateItemRequestBuilder = UpdateItemRequest .builder() .tableName(getTableName(request.index())) - .key(updateItem); + .key(updateKey) + .attributeUpdates(updateAttributeValue); if (request.ifSeqNo() != null) { // Get current document version and put in attribute map. Ignore primary term on DDB. int currentSeqNo = jsonNode.has(SEQ_NO_KEY) ? jsonNode.get(SEQ_NO_KEY).asInt() : 0; @@ -209,7 +221,6 @@ public CompletionStage updateDataObjectAsync(UpdateDat ); } UpdateItemRequest updateItemRequest = updateItemRequestBuilder.build(); - // TODO need to add an incremented seqNo here dynamoDbClient.updateItem(updateItemRequest); // TODO need to pass seqNo to simulated response String simulatedUpdateResponse = simulateOpenSearchResponse(request.index(), request.id(), source, Map.of("found", true)); diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index c81dddb65c..84ee7e8b47 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -103,15 +103,12 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe public CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor) { return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { - log.info("Getting {} from {}", request.id(), request.index()); GetResponse getResponse = client .get(new GetRequest(request.index(), request.id()).fetchSourceContext(request.fetchSourceContext())) .actionGet(); if (getResponse == null) { - log.info("Null GetResponse"); return GetDataObjectResponse.builder().id(request.id()).parser(null).build(); } - log.info("Retrieved data object"); return GetDataObjectResponse .builder() .id(getResponse.getId()) diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index cc9b161146..e5b1e4706f 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -10,12 +10,15 @@ import org.apache.http.HttpHost; import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.impl.client.BasicCredentialsProvider; import org.opensearch.OpenSearchException; +import org.opensearch.client.Client; import org.opensearch.client.RestClient; import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.rest_client.RestClientTransport; import org.opensearch.common.inject.AbstractModule; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.sdk.SdkClient; import com.fasterxml.jackson.annotation.JsonInclude; @@ -45,12 +48,20 @@ public class SdkClientModule extends AbstractModule { private final String remoteMetadataType; private final String remoteMetadataEndpoint; private final String region; // not using with RestClient + private Client client; + private NamedXContentRegistry namedXContentRegistry; /** * Instantiate this module using environment variables */ - public SdkClientModule() { - this(System.getenv(REMOTE_METADATA_TYPE), System.getenv(REMOTE_METADATA_ENDPOINT), System.getenv(REGION)); + public SdkClientModule(Client client, NamedXContentRegistry namedXContentRegistry) { + this( + client, + namedXContentRegistry, + System.getenv(REMOTE_METADATA_TYPE), + System.getenv(REMOTE_METADATA_ENDPOINT), + System.getenv(REGION) + ); } /** @@ -59,7 +70,15 @@ public SdkClientModule() { * @param remoteMetadataEndpoint The remote endpoint * @param region The region */ - SdkClientModule(String remoteMetadataType, String remoteMetadataEndpoint, String region) { + SdkClientModule( + Client client, + NamedXContentRegistry namedXContentRegistry, + String remoteMetadataType, + String remoteMetadataEndpoint, + String region + ) { + this.client = client; + this.namedXContentRegistry = namedXContentRegistry; this.remoteMetadataType = remoteMetadataType; this.remoteMetadataEndpoint = remoteMetadataEndpoint; this.region = region; @@ -69,7 +88,7 @@ public SdkClientModule() { protected void configure() { if (this.remoteMetadataType == null) { log.info("Using local opensearch cluster as metadata store"); - bind(SdkClient.class).to(LocalClusterIndicesClient.class); + bindLocalClient(); return; } @@ -85,7 +104,15 @@ protected void configure() { return; default: log.info("Using local opensearch cluster as metadata store"); - bind(SdkClient.class).to(LocalClusterIndicesClient.class); + bindLocalClient(); + } + } + + private void bindLocalClient() { + if (client == null) { + bind(SdkClient.class).to(LocalClusterIndicesClient.class); + } else { + bind(SdkClient.class).toInstance(new LocalClusterIndicesClient(this.client, this.namedXContentRegistry)); } } @@ -106,6 +133,7 @@ private DynamoDbClient createDynamoDbClient() { private OpenSearchClient createOpenSearchClient() { try { + BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); // Basic http(not-s) client using RestClient. RestClient restClient = RestClient // This HttpHost syntax works with export REMOTE_METADATA_ENDPOINT=http://127.0.0.1:9200 @@ -113,7 +141,9 @@ private OpenSearchClient createOpenSearchClient() { .setStrictDeprecationMode(true) .setHttpClientConfigCallback(httpClientBuilder -> { try { - return httpClientBuilder.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); + return httpClientBuilder + .setDefaultCredentialsProvider(credentialsProvider) + .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); } catch (Exception e) { throw new OpenSearchException(e); } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index a2fba5f959..1e75108e7a 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -5,9 +5,6 @@ package org.opensearch.ml.task; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; import static org.opensearch.ml.permission.AccessController.checkUserPermissions; import static org.opensearch.ml.permission.AccessController.getUserContext; @@ -23,23 +20,18 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionListenerResponseHandler; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ThreadedActionListener; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -154,7 +146,7 @@ public void dispatchTask( if (workerNodes == null || workerNodes.length == 0) { if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlModelManager.getModel(modelId, ActionListener.runBefore(ActionListener.wrap(model -> { + mlModelManager.getModel(modelId, request.getTenantId(), ActionListener.runBefore(ActionListener.wrap(model -> { Boolean isHidden = model.getIsHidden(); if (!checkModelAutoDeployEnabled(model)) { final String errorMsg = getErrorMessage( @@ -245,7 +237,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener dataFrameActionListener = ActionListener.wrap(dataSet -> { MLInput newInput = mlInput.toBuilder().inputDataset(dataSet).build(); - predict(modelId, mlTask, newInput, listener); + predict(modelId, request.getTenantId(), mlTask, newInput, listener); }, e -> { log.error("Failed to generate DataFrame from search query", e); handleAsyncMLTaskFailure(mlTask, e); @@ -258,7 +250,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener { predict(modelId, mlTask, mlInput, listener); }); + threadPool.executor(threadPoolName).execute(() -> { predict(modelId, request.getTenantId(), mlTask, mlInput, listener); }); break; } } @@ -274,7 +266,7 @@ private String getPredictThreadPool(FunctionName functionName) { return functionName == FunctionName.REMOTE ? REMOTE_PREDICT_THREAD_POOL : PREDICT_THREAD_POOL; } - private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListener listener) { + private void predict(String modelId, String tenantId, MLTask mlTask, MLInput mlInput, ActionListener listener) { ActionListener internalListener = wrappedCleanupListener(listener, mlTask.getTaskId()); // track ML task count and add ML task into cache mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); @@ -305,8 +297,8 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe .state(MLTaskState.RUNNING) .workerNodes(Arrays.asList(clusterService.localNode().getId())) .build(); - mlModelManager.deployModel(modelId, null, functionName, false, true, mlDeployTask, ActionListener.wrap(s -> { - runPredict(modelId, mlTask, mlInput, functionName, internalListener); + mlModelManager.deployModel(modelId, tenantId, null, functionName, false, true, mlDeployTask, ActionListener.wrap(s -> { + runPredict(modelId, tenantId, mlTask, mlInput, functionName, internalListener); }, e -> { log.error("Failed to auto deploy model " + modelId, e); internalListener.onFailure(e); @@ -314,11 +306,12 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe return; } - runPredict(modelId, mlTask, mlInput, functionName, internalListener); + runPredict(modelId, tenantId, mlTask, mlInput, functionName, internalListener); } private void runPredict( String modelId, + String tenantId, MLTask mlTask, MLInput mlInput, FunctionName algorithm, @@ -367,21 +360,12 @@ private void runPredict( // search model by model id. try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) { - ActionListener getModelListener = ActionListener.wrap(r -> { - if (r == null || !r.isExists()) { + ActionListener getModelListener = ActionListener.wrap(mlModel -> { + if (mlModel == null) { internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId.")); return; } - try ( - XContentParser xContentParser = XContentType.JSON - .xContent() - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser); - GetResponse getResponse = r; - String algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString(); - MLModel mlModel = MLModel.parse(xContentParser, algorithmName); - mlModel.setModelId(modelId); + try { User resourceUser = mlModel.getUser(); User requestUser = getUserContext(client); if (!checkUserPermissions(requestUser, resourceUser, modelId)) { @@ -416,10 +400,10 @@ private void runPredict( log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e); handlePredictFailure(mlTask, internalListener, e, true, modelId); }); - GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId()); - client - .get( - getRequest, + mlModelManager + .getModel( + mlTask.getModelId(), + tenantId, threadedActionListener( mlTask.getFunctionName(), ActionListener.runBefore(getModelListener, () -> context.restore()) diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java index c9a4a1a6d5..2a9d507de6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java @@ -5,9 +5,7 @@ package org.opensearch.ml.action.controller; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -28,6 +26,7 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.Version; import org.opensearch.action.DocWriteResponse; @@ -169,10 +168,10 @@ public void setup() throws IOException { }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), Mockito.isNull(), any(), any(), isA(ActionListener.class)); when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { @@ -247,10 +246,10 @@ public void testCreateControllerWithModelAccessControlOtherException() { @Test public void testCreateControllerWithModelNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java index 1e49ab2fd7..4fb9a968cf 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java @@ -160,10 +160,10 @@ public void setup() throws IOException { }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -216,10 +216,10 @@ public void testDeleteControllerWithModelAccessControlNoPermissionHiddenModel() when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); @@ -255,10 +255,10 @@ public void testDeleteControllerWithModelAccessControlOtherExceptionHiddenModel( when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener @@ -280,10 +280,10 @@ public void testDeleteControllerWithModelAccessControlOtherExceptionHiddenModel( @Test public void testDeleteControllerWithGetModelNotFoundSuccess() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); verify(actionListener).onResponse(deleteResponse); @@ -320,10 +320,10 @@ public void testDeleteControllerWithGetControllerOtherException() { @Test public void testDeleteControllerWithGetModelNotFoundWithGetControllerOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java index 489e71e080..d4ff67fa14 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java @@ -107,10 +107,10 @@ public void setup() throws IOException { mlControllerGetRequest = MLControllerGetRequest.builder().modelId("testModelId").build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -170,10 +170,10 @@ public void testGetControllerWithModelAccessControlOtherException() { @Test public void testGetControllerWithGetModelNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -187,10 +187,10 @@ public void testGetControllerWithGetModelNotFound() { @Test public void testGetControllerWithGetModelOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java index fd378647e9..f1a87a7dfb 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java @@ -179,10 +179,10 @@ public void setup() throws IOException { }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getModelId()).thenReturn("testModelId"); @@ -246,10 +246,10 @@ public void testUpdateControllerWithModelAccessControlNoPermissionHiddenModel() when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); @@ -285,10 +285,10 @@ public void testUpdateControllerWithModelAccessControlOtherExceptionHiddenModel( when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Permission denied: Unable to create the model controller for the model. Details: ")); @@ -328,10 +328,10 @@ public void testUpdateControllerWithControllerEnabledNullHiddenModel() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); @@ -396,10 +396,10 @@ public void testUpdateControllerWithModelFunctionUnsupported() { @Test public void tesUpdateControllerWithGetModelNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), any(), isA(ActionListener.class)); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -506,10 +506,10 @@ public void testUpdateControllerWithUndeploySuccessPartiallyFailuresHiddenModel( when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); when(mlModel.getModelId()).thenReturn("testModelId"); doAnswer(invocation -> { - ActionListener mllistener = invocation.getArgument(3); + ActionListener mllistener = invocation.getArgument(4); mllistener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); List failures = List .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); @@ -556,10 +556,10 @@ public void testUpdateControllerWithUndeployNullResponseHiddenModel() { when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); when(mlModel.getModelId()).thenReturn("testModelId"); doAnswer(invocation -> { - ActionListener mllistener = invocation.getArgument(3); + ActionListener mllistener = invocation.getArgument(4); mllistener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); @@ -606,10 +606,10 @@ public void testUpdateControllerWithUndeployOtherExceptionHiddenModel() { when(mlModel.getIsHidden()).thenReturn(Boolean.TRUE); when(mlModel.getModelId()).thenReturn("testModelId"); doAnswer(invocation -> { - ActionListener mllistener = invocation.getArgument(3); + ActionListener mllistener = invocation.getArgument(4); mllistener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index 8b8ee5234f..879219c3da 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -212,10 +212,10 @@ public void testDoExecute_success() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); IndexResponse indexResponse = mock(IndexResponse.class); when(indexResponse.getId()).thenReturn("mockIndexId"); @@ -234,10 +234,10 @@ public void testDoExecute_success_not_userInitiatedRequest() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(false); @@ -279,10 +279,10 @@ public void testDoExecute_success_hidden_model() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); when(mlModel.getIsHidden()).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); IndexResponse indexResponse = mock(IndexResponse.class); when(indexResponse.getId()).thenReturn("mockIndexId"); @@ -325,10 +325,10 @@ public void testDoExecute_no_permission_hidden_model() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); when(mlModel.getIsHidden()).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); IndexResponse indexResponse = mock(IndexResponse.class); when(indexResponse.getId()).thenReturn("mockIndexId"); @@ -351,10 +351,10 @@ public void testDoExecute_userHasNoAccessException() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -373,10 +373,10 @@ public void testDoExecuteRemoteInferenceDisabled() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); ActionListener deployModelResponseListener = mock(ActionListener.class); @@ -386,14 +386,40 @@ public void testDoExecuteRemoteInferenceDisabled() { assertEquals(REMOTE_INFERENCE_DISABLED_ERR_MSG, argumentCaptor.getValue().getMessage()); } + public void testDoExecuteRemoteInference_MultiNodeEnabled() { + MLModel mlModel = mock(MLModel.class); + when(mlModel.getAlgorithm()).thenReturn(FunctionName.REMOTE); + when(mlModel.getTenantId()).thenReturn("test_tenant"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + IndexResponse indexResponse = mock(IndexResponse.class); + when(indexResponse.getId()).thenReturn("mockIndexId"); + listener.onResponse(indexResponse); + return null; + }).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class)); + + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + ActionListener deployModelResponseListener = mock(ActionListener.class); + when(mlDeployModelRequest.getTenantId()).thenReturn("test_tenant"); + transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class); + verify(deployModelResponseListener).onResponse(argumentCaptor.capture()); + assertEquals("CREATED", argumentCaptor.getValue().getStatus()); + } + public void testDoExecuteLocalInferenceDisabled() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false); ActionListener deployModelResponseListener = mock(ActionListener.class); @@ -407,10 +433,10 @@ public void test_ValidationFailedException() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -491,7 +517,7 @@ public void testDoExecute_whenDeployModelRequestNodeIdsEmpty_thenMLResourceNotFo public void testDoExecute_whenGetModelHasNPE_exception() { doThrow(NullPointerException.class) .when(mlModelManager) - .getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + .getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); ActionListener deployModelResponseListener = mock(ActionListener.class); transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); @@ -502,10 +528,10 @@ public void testDoExecute_whenThreadPoolExecutorException_TaskRemoved() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); + }).when(mlModelManager).getModel(anyString(), any(), isNull(), any(String[].class), Mockito.isA(ActionListener.class)); IndexResponse indexResponse = mock(IndexResponse.class); when(indexResponse.getId()).thenReturn("mockIndexId"); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java index 83852cc68f..2fa105957e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java @@ -207,7 +207,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(5); listener.onResponse("successful"); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); doAnswer(invocation -> { ActionListenerResponseHandler handler = invocation.getArgument(3); @@ -313,7 +313,7 @@ public void testNodeOperation_FailToSendForwardRequest() { ActionListener listener = invocation.getArgument(4); listener.onResponse("ok"); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); doAnswer(invocation -> { TransportResponseHandler handler = invocation.getArgument(3); handler.handleException(new TransportException("error")); @@ -331,7 +331,7 @@ public void testNodeOperation_Exception() { ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Something went wrong")); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); final MLDeployModelNodesRequest nodesRequest = prepareRequest(localNode.getId()); final MLDeployModelNodeRequest request = action.newNodeRequest(nodesRequest); final MLDeployModelNodeResponse response = action.nodeOperation(request); @@ -342,7 +342,7 @@ public void testNodeOperation_Exception() { public void testNodeOperation_DeployModelRuntimeException() { doThrow(new RuntimeException("error")) .when(mlModelManager) - .deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + .deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); final MLDeployModelNodesRequest nodesRequest = prepareRequest(localNode.getId()); final MLDeployModelNodeRequest request = action.newNodeRequest(nodesRequest); final MLDeployModelNodeResponse response = action.nodeOperation(request); @@ -355,7 +355,7 @@ public void testNodeOperation_MLLimitExceededException() { ActionListener listener = invocation.getArgument(4); listener.onFailure(new MLLimitExceededException("Limit exceeded exception")); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(), any(Boolean.class), any(), any(), any()); final MLDeployModelNodesRequest nodesRequest = prepareRequest(localNode.getId()); final MLDeployModelNodeRequest request = action.newNodeRequest(nodesRequest); final MLDeployModelNodeResponse response = action.nodeOperation(request); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 5129bfa16a..5e5f4e8199 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -364,13 +364,13 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(3); listener.onResponse(localModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); listener.onResponse(localModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); MLModelGroup modelGroup = MLModelGroup .builder() @@ -445,7 +445,7 @@ public void testUpdateRemoteModelWithLocalInformationSuccess() throws Interrupte ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -465,7 +465,7 @@ public void testUpdateExternalRemoteModelWithExternalRemoteInformationSuccess() ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -485,7 +485,7 @@ public void testUpdateInternalRemoteModelWithInternalRemoteInformationSuccess() ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -505,7 +505,7 @@ public void testUpdateHiddenRemoteModelWithRemoteInformationSuccess() throws Int ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doReturn(true).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); CountDownLatch latch = new CountDownLatch(1); @@ -526,7 +526,7 @@ public void testUpdateHiddenRemoteModelPermissionError() { ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doReturn(false).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -541,7 +541,7 @@ public void testUpdateRemoteModelWithNoExternalConnectorFound() { ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModelWithInternalConnector); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -559,7 +559,7 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(5); @@ -583,7 +583,7 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(5); @@ -712,7 +712,7 @@ public void testUpdateModelWithModelNotFound() { ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -726,7 +726,7 @@ public void testUpdateModelWithFunctionNameFieldNotFound() { ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModelWithNullFunctionName); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -756,7 +756,7 @@ public void testUpdateLocalModelWithUnsupportedFunction() { ActionListener listener = invocation.getArgument(4); listener.onResponse(localModelWithUnsupportedFunction); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -773,7 +773,7 @@ public void testUpdateRequestDocIOException() throws IOException, InterruptedExc ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); @@ -800,7 +800,7 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(eq(sdkClient), eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); @@ -957,7 +957,7 @@ public void testUpdateModelStateDeployingException() { ActionListener listener = invocation.getArgument(4); listener.onResponse(testDeployingModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); @@ -975,7 +975,7 @@ public void testUpdateModelStateLoadingException() { ActionListener listener = invocation.getArgument(4); listener.onResponse(testDeployingModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); @@ -993,7 +993,7 @@ public void testUpdateModelCacheModelStateDeployedSuccess() throws InterruptedEx ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1022,7 +1022,7 @@ public void testUpdateModelCacheModelWithIsModelEnabledSuccess() throws Interrup ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1053,7 +1053,7 @@ public void testUpdateModelCacheModelWithoutUpdateConnectorWithRateLimiterSucces ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1085,7 +1085,7 @@ public void testUpdateModelCacheModelWithRateLimiterSuccess() throws Interrupted ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1116,7 +1116,7 @@ public void testUpdateModelWithPartialRateLimiterSuccess() throws InterruptedExc ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").build(); MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); @@ -1142,7 +1142,7 @@ public void testUpdateModelCacheModelWithPartialRateLimiterSuccess() throws Inte ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1173,7 +1173,7 @@ public void testUpdateModelCacheUpdateResponseListenerWithNullUpdateResponse() t ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); PlainActionFuture future = PlainActionFuture.newFuture(); future.onResponse(null); @@ -1206,7 +1206,7 @@ public void testUpdateModelCacheModelWithUndeploySuccessEmptyFailures() throws I ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1239,7 +1239,7 @@ public void testUpdateControllerWithUndeploySuccessPartiallyFailures() throws In ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1271,7 +1271,7 @@ public void testUpdateControllerWithUndeployNullResponse() throws InterruptedExc ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1302,7 +1302,7 @@ public void testUpdateControllerWithUndeployOtherException() throws InterruptedE ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -1337,7 +1337,7 @@ public void testUpdateModelCacheModelStateDeployedWrongStatus() throws Interrupt ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1366,7 +1366,7 @@ public void testUpdateModelCacheModelStateDeployedUpdateModelCacheException() th ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1411,7 +1411,7 @@ public void testUpdateModelCacheModelStateDeployedUpdateException() throws Inter ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1442,7 +1442,7 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupSuccess() throws Int ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1475,7 +1475,7 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupWrongStatus() throws ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1503,7 +1503,7 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateModelCacheExce ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1547,7 +1547,7 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateException() th ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1577,7 +1577,7 @@ public void testUpdateModelCacheModelStateLoadedSuccess() throws InterruptedExce ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1606,7 +1606,7 @@ public void testUpdateModelCacheModelStatePartiallyDeployedSuccess() throws Inte ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1635,7 +1635,7 @@ public void testUpdateModelCacheModelStatePartiallyLoadedSuccess() throws Interr ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1677,6 +1677,7 @@ private MLModel prepareMLModel(String functionName, MLModelState modelState, boo mlModel = MLModel .builder() .name("test_name") + .tenantId("tenant_id") .modelId("test_model_id") .modelGroupId("test_model_group_id") .description("test_description") @@ -1797,7 +1798,7 @@ public void testUpdateModelStatePartiallyLoadedException() { ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), anyString(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); @@ -1822,7 +1823,7 @@ public void testUpdateModelStatePartiallyDeployedException() { ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(any(SdkClient.class), eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), anyString(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 4cf82f948f..d7b6088aab 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -188,10 +188,10 @@ public void setup() throws IOException { .isHidden(false) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); } @AfterClass @@ -213,10 +213,10 @@ public void testHiddenModelSuccess() { .isHidden(true) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); @@ -247,10 +247,10 @@ public void testHiddenModelPermissionError() { .isHidden(true) .build(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); @@ -292,7 +292,7 @@ public void testDoExecute() { public void testDoExecute_modelAccessControl_notEnabled() { when(modelAccessControlHelper.isModelAccessControlEnabled()).thenReturn(false); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(true); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); @@ -327,17 +327,19 @@ public void testDoExecute_validate_false() { public void testDoExecute_getModel_exception() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("runtime exception")); return null; - }).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onFailure(isA(RuntimeException.class)); } public void testDoExecute_validateAccess_exception() { - doThrow(new RuntimeException("runtime exception")).when(mlModelManager).getModel(any(), any(), any(), isA(ActionListener.class)); + doThrow(new RuntimeException("runtime exception")) + .when(mlModelManager) + .getModel(any(), any(), any(), any(), isA(ActionListener.class)); MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); transportUndeployModelsAction.doExecute(task, request, actionListener); verify(actionListener).onFailure(isA(RuntimeException.class)); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 326721803d..250ce2bd98 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -39,9 +39,7 @@ import static org.opensearch.ml.utils.MockHelper.mock_MLIndicesHandler_initModelIndex_failure; import static org.opensearch.ml.utils.MockHelper.mock_client_ThreadContext; import static org.opensearch.ml.utils.MockHelper.mock_client_ThreadContext_Exception; -import static org.opensearch.ml.utils.MockHelper.mock_client_get_NotExist; import static org.opensearch.ml.utils.MockHelper.mock_client_get_NullResponse; -import static org.opensearch.ml.utils.MockHelper.mock_client_get_failure; import static org.opensearch.ml.utils.MockHelper.mock_client_index; import static org.opensearch.ml.utils.MockHelper.mock_client_index_failure; import static org.opensearch.ml.utils.MockHelper.mock_client_update; @@ -90,9 +88,14 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.breaker.MemoryCircuitBreaker; import org.opensearch.ml.breaker.ThresholdCircuitBreaker; @@ -207,7 +210,7 @@ public class MLModelManagerTests extends OpenSearchTestCase { private MLTask pretrainedMLTask; @Before - public void setup() throws URISyntaxException { + public void setup() throws URISyntaxException, IOException { String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; MockitoAnnotations.openMocks(this); @@ -287,6 +290,7 @@ public void setup() throws URISyntaxException { clusterService, scriptService, client, + sdkClient, threadPool, xContentRegistry, modelHelper, @@ -339,6 +343,16 @@ public void setup() throws URISyntaxException { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); } + public void setupGetModel(MLModel model) throws IOException { + XContentBuilder content = model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any(GetRequest.class))).thenReturn(future); + } + @AfterClass public static void cleanup() { ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); @@ -633,9 +647,12 @@ public void testDeployModel_FailedToGetModel() { when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_threadpool(threadPool, taskExecutorService); - mock_client_get_failure(client); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("get doc failure")); + when(client.get(any(GetRequest.class))).thenReturn(future); + mock_client_ThreadContext(client, threadPool, threadContext); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -648,26 +665,13 @@ public void testDeployModel_FailedToGetModel() { ); } - public void testDeployModel_NullGetModelResponse() { + public void testDeployModel_NullGetModelResponse() throws IOException { MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() .modelType("bert") .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) .embeddingDimension(384) .build(); - model = MLModel - .builder() - .modelId(modelId) - .modelState(MLModelState.DEPLOYING) - .algorithm(FunctionName.TEXT_EMBEDDING) - .name(modelName) - .version(version) - .totalChunks(2) - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .modelConfig(modelConfig) - .modelContentHash(modelContentHashValue) - .modelContentSizeInBytes(modelContentSize) - .build(); String[] nodes = new String[] { "node1", "node2" }; mlTask.setWorkerNodes(List.of(nodes)); ActionListener listener = mock(ActionListener.class); @@ -676,7 +680,10 @@ public void testDeployModel_NullGetModelResponse() { when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_threadpool(threadPool, taskExecutorService); mock_client_get_NullResponse(client); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(null); + when(client.get(any(GetRequest.class))).thenReturn(future); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -689,26 +696,13 @@ public void testDeployModel_NullGetModelResponse() { ); } - public void testDeployModel_GetModelResponse_NotExist() { + public void testDeployModel_GetModelResponse_NotExist() throws IOException { MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() .modelType("bert") .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) .embeddingDimension(384) .build(); - model = MLModel - .builder() - .modelId(modelId) - .modelState(MLModelState.DEPLOYING) - .algorithm(FunctionName.TEXT_EMBEDDING) - .name(modelName) - .version(version) - .totalChunks(2) - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .modelConfig(modelConfig) - .modelContentHash(modelContentHashValue) - .modelContentSizeInBytes(modelContentSize) - .build(); String[] nodes = new String[] { "node1", "node2" }; mlTask.setWorkerNodes(List.of(nodes)); ActionListener listener = mock(ActionListener.class); @@ -716,8 +710,15 @@ public void testDeployModel_GetModelResponse_NotExist() { when(modelCacheHelper.getDeployedModels()).thenReturn(new String[] {}); when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_threadpool(threadPool, taskExecutorService); - mock_client_get_NotExist(client); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + XContentBuilder content = model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", -2, 0, 111l, false, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any(GetRequest.class))).thenReturn(future); + + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -730,7 +731,7 @@ public void testDeployModel_GetModelResponse_NotExist() { ); } - public void testDeployModel_GetModelResponse_wrong_hash_value() { + public void testDeployModel_GetModelResponse_wrong_hash_value() throws IOException { MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() .modelType("bert") @@ -759,10 +760,10 @@ public void testDeployModel_GetModelResponse_wrong_hash_value() { when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_client_ThreadContext(client, threadPool, threadContext); mock_threadpool(threadPool, taskExecutorService); - setUpMock_GetModel(model); - setUpMock_GetModel(modelChunk0); - setUpMock_GetModel(modelChunk0); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + setupGetModel(model); + setupGetModel(modelChunk0); + setupGetModel(modelChunk0); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -781,7 +782,7 @@ public void testDeployModel_GetModelResponse_wrong_hash_value() { ); } - public void testDeployModel_GetModelResponse_FailedToDeploy() { + public void testDeployModel_GetModelResponse_FailedToDeploy() throws IOException { MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() .modelType("bert") @@ -812,7 +813,7 @@ public void testDeployModel_GetModelResponse_FailedToDeploy() { setUpMock_GetModelChunks(model); // setUpMock_GetModel(modelChunk0); // setUpMock_GetModel(modelChunk1); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -828,7 +829,7 @@ public void testDeployModel_GetModelResponse_FailedToDeploy() { public void testDeployModel_ModelAlreadyDeployed() { when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(true); ActionListener listener = mock(ActionListener.class); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); ArgumentCaptor response = ArgumentCaptor.forClass(String.class); verify(listener).onResponse(response.capture()); assertEquals("successful", response.getValue()); @@ -843,7 +844,7 @@ public void testDeployModel_ExceedMaxDeployedModel() { when(modelCacheHelper.getDeployedModels()).thenReturn(models); when(modelCacheHelper.getLocalDeployedModels()).thenReturn(models); ActionListener listener = mock(ActionListener.class); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); ArgumentCaptor failure = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(failure.capture()); assertEquals("Exceed max local model per node limit", failure.getValue().getMessage()); @@ -878,7 +879,7 @@ public void testDeployModel_ThreadPoolException() { ActionListener listener = mock(ActionListener.class); FunctionName functionName = FunctionName.TEXT_EMBEDDING; - modelManager.deployModel(modelId, modelContentHashValue, functionName, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, functionName, true, false, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_FAILURE_COUNT)); } @@ -1037,7 +1038,7 @@ private void testDeployModel_FailedToRetrieveModelChunks(boolean lastChunk) { ActionListener listener = mock(ActionListener.class); FunctionName functionName = FunctionName.TEXT_EMBEDDING; - modelManager.deployModel(modelId, modelContentHashValue, functionName, true, false, mlTask, listener); + modelManager.deployModel(modelId, null, modelContentHashValue, functionName, true, false, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); verify(mlStats).getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT)); @@ -1067,46 +1068,46 @@ private void setUpMock_GetModel(MLModel model) { private void setUpMock_GetModelChunks(MLModel model) { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(model); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(modelChunk0); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(modelChunk1); return null; - }).when(modelManager).getModel(any(), any()); + }).when(modelManager).getModel(any(), any(), any()); } private void setUpMock_GetModelMeta_FailedToGetFirstChunk(MLModel model) { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(model); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Failed to get model")); return null; - }).when(modelManager).getModel(any(), any()); + }).when(modelManager).getModel(any(), any(), any()); } private void setUpMock_GetModelMeta_FailedToGetLastChunk(MLModel model) { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(model); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(modelChunk0); return null; }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Failed to get model")); return null; - }).when(modelManager).getModel(any(), any()); + }).when(modelManager).getModel(any(), any(), any()); } private void setUpMock_DownloadModelFileFailure() { diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index 3103d1d86c..5a393dbc14 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -391,7 +391,7 @@ public void updateDataObjectAsync_HappyCase() { assertEquals(TEST_INDEX, updateItemRequest.tableName()); assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s()); assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s()); - assertEquals("foo", updateItemRequest.key().get("data").s()); + assertEquals("foo", updateItemRequest.attributeUpdates().get("data").value().s()); } @@ -415,8 +415,7 @@ public void updateDataObjectAsync_HappyCaseWithMap() { assertEquals(TEST_INDEX, updateItemRequest.tableName()); assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s()); assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s()); - assertEquals("bar", updateItemRequest.key().get("foo").s()); - + assertEquals("bar", updateItemRequest.attributeUpdates().get("foo").value().s()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java index 8667450d9c..12cac0afdd 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java @@ -30,21 +30,23 @@ protected void configure() { }; public void testLocalBinding() { - Injector injector = Guice.createInjector(new SdkClientModule(null, null, null), localClientModule); + Injector injector = Guice.createInjector(new SdkClientModule(null, null, null, null, null), localClientModule); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof LocalClusterIndicesClient); } public void testRemoteOpenSearchBinding() { - Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.REMOTE_OPENSEARCH, "http://example.org", "eu-west-3")); + Injector injector = Guice + .createInjector(new SdkClientModule(null, null, SdkClientModule.REMOTE_OPENSEARCH, "http://example.org", "eu-west-3")); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof RemoteClusterIndicesClient); } public void testDDBBinding() { - Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.AWS_DYNAMO_DB, "http://example.org", "eu-west-3")); + Injector injector = Guice + .createInjector(new SdkClientModule(null, null, SdkClientModule.AWS_DYNAMO_DB, "http://example.org", "eu-west-3")); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof DDBOpenSearchClient); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index cbde703543..20dc04527c 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -233,7 +233,6 @@ public void testExecuteTask_OnLocalNode() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlTaskManager).add(any(MLTask.class)); - verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); } @@ -243,17 +242,16 @@ public void testExecuteTask_OnLocalNode_QueryInput() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); verify(mlTaskManager).add(any(MLTask.class)); - verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); } public void testExecuteTask_OnLocalNode_RemoteModelAutoDeploy() { setupMocks(true, false, false, false); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(2); actionListener.onResponse(mlModel); return null; - }).when(mlModelManager).getModel(any(), any()); + }).when(mlModelManager).getModel(any(), any(), any()); when(mlModelManager.addModelToAutoDeployCache("111", mlModel)).thenReturn(mlModel); taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); verify(client).execute(any(), any(), any()); @@ -276,7 +274,6 @@ public void testExecuteTask_NoPermission() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlTaskManager).add(any(MLTask.class)); verify(mlTaskManager).remove(anyString()); - verify(client).get(any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); assertEquals("User: test_user does not have permissions to run predict by model: 111", argumentCaptor.getValue().getMessage()); @@ -294,7 +291,6 @@ public void testExecuteTask_OnLocalNode_GetModelFail() { taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlTaskManager).add(any(MLTask.class)); - verify(client).get(any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); @@ -380,13 +376,12 @@ public void testExecuteTask_OnLocalNode_prediction_exception() { assertEquals("runtime exception", argumentCaptor.getValue().getMessage()); } - public void testExecuteTask_OnLocalNode_NullGetResponse() { + public void testExecuteTask_OnLocalNode_NullMLModel() { setupMocks(true, false, false, true); taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); verify(mlTaskManager).add(any(MLTask.class)); - verify(client).get(any(), any()); verify(mlTaskManager).remove(anyString()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(argumentCaptor.capture()); @@ -432,7 +427,7 @@ public void testValidateModelTensorOutputFailed() { taskRunner.validateOutputSchema("testId", modelTensorOutput); } - private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, boolean failedToGetModel, boolean nullGetResponse) { + private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, boolean failedToGetModel, boolean nullMlModel) { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); if (runOnLocalNode) { @@ -466,23 +461,16 @@ private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, return null; }).when(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); } - - if (nullGetResponse) { - getResponse = null; - } - - if (failedToGetModel) { - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + if (failedToGetModel) { actionListener.onFailure(new RuntimeException(errorMessage)); - return null; - }).when(client).get(any(), any()); - } else { - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - } + } else if (nullMlModel) { + actionListener.onResponse(null); + } else { + actionListener.onResponse(mlModel); + } + return null; + }).when(mlModelManager).getModel(any(), any(), any()); } }