diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index c160306550..f5c62d7a6d 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -40,6 +40,8 @@ public class MLModel implements ToXContentObject { @Deprecated public static final String ALGORITHM_FIELD = "algorithm"; + + public static final String TENANT_ID_FIELD = "tenant_id"; public static final String FUNCTION_NAME_FIELD = "function_name"; public static final String MODEL_NAME_FIELD = "name"; public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; @@ -133,6 +135,7 @@ public class MLModel implements ToXContentObject { private Connector connector; private String connectorId; private Guardrails guardrails; + private String tenantId; @Builder(toBuilder = true) public MLModel(String name, @@ -166,7 +169,8 @@ public MLModel(String name, Boolean isHidden, Connector connector, String connectorId, - Guardrails guardrails) { + Guardrails guardrails, + String tenantId) { this.name = name; this.modelGroupId = modelGroupId; this.algorithm = algorithm; @@ -200,6 +204,7 @@ public MLModel(String name, this.connector = connector; this.connectorId = connectorId; this.guardrails = guardrails; + this.tenantId = tenantId; } public MLModel(StreamInput input) throws IOException { @@ -442,6 +447,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (guardrails != null) { builder.field(GUARDRAILS_FIELD, guardrails); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } builder.endObject(); return builder; } @@ -486,6 +494,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws Connector connector = null; String connectorId = null; Guardrails guardrails = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -617,6 +626,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws case GUARDRAILS_FIELD: guardrails = Guardrails.parse(parser); break; + case TENANT_ID_FIELD: + tenantId = parser.text(); + break; default: parser.skipChildren(); break; @@ -656,6 +668,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws .connector(connector) .connectorId(connectorId) .guardrails(guardrails) + .tenantId(tenantId) .build(); } diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 8a8cf5ff04..1ab75e20f6 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -35,6 +35,8 @@ public class MLModelGroup implements ToXContentObject { public static final String ACCESS = "access"; //assigned to public, private, or null when model group created public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //unique ID assigned to each model group + + public static final String TENANT_ID_FIELD = "tenant_id"; public static final String CREATED_TIME_FIELD = "created_time"; //model group created time stamp public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; //updated whenever a new model version is created @@ -42,6 +44,7 @@ public class MLModelGroup implements ToXContentObject { @Setter private String name; private String description; + @Setter private int latestVersion; private List backendRoles; private User owner; @@ -50,7 +53,10 @@ public class MLModelGroup implements ToXContentObject { private String modelGroupId; + private String tenantId; + private Instant createdTime; + @Setter private Instant lastUpdatedTime; @@ -58,6 +64,7 @@ public class MLModelGroup implements ToXContentObject { public MLModelGroup(String name, String description, int latestVersion, List backendRoles, User owner, String access, String modelGroupId, + String tenantId, Instant createdTime, Instant lastUpdatedTime) { this.name = Objects.requireNonNull(name, "model group name must not be null"); @@ -67,6 +74,7 @@ public MLModelGroup(String name, String description, int latestVersion, this.owner = owner; this.access = access; this.modelGroupId = modelGroupId; + this.tenantId = tenantId; this.createdTime = createdTime; this.lastUpdatedTime = lastUpdatedTime; } @@ -132,6 +140,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelGroupId != null) { builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } if (createdTime != null) { builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); } @@ -150,6 +161,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { User owner = null; String access = null; String modelGroupId = null; + String tenantId = null; Instant createdTime = null; Instant lastUpdateTime = null; @@ -184,6 +196,9 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { case MODEL_GROUP_ID_FIELD: modelGroupId = parser.text(); break; + case TENANT_ID_FIELD: + tenantId = parser.text(); + break; case CREATED_TIME_FIELD: createdTime = Instant.ofEpochMilli(parser.longValue()); break; @@ -203,6 +218,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { .owner(owner) .access(access) .modelGroupId(modelGroupId) + .tenantId(tenantId) .createdTime(createdTime) .lastUpdatedTime(lastUpdateTime) .build(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/AbstractGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/AbstractGetRequest.java new file mode 100644 index 0000000000..8f962a70f8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/AbstractGetRequest.java @@ -0,0 +1,29 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.common.transport; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import org.opensearch.action.ActionRequest; +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +@NoArgsConstructor +@AllArgsConstructor +public abstract class AbstractGetRequest extends ActionRequest { + @Setter + @Getter + private String tenantId; + + public AbstractGetRequest(StreamInput in) throws IOException { + super(in); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java index 51f4616d5f..dae7b78f49 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java @@ -14,6 +14,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.transport.AbstractGetRequest; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -24,11 +25,10 @@ @Getter @Setter -public class MLConnectorGetRequest extends ActionRequest { +public class MLConnectorGetRequest extends AbstractGetRequest { String connectorId; - String tenantId; boolean returnContent; @Builder diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java index 7cad570f1d..7c65d9c6d2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java @@ -16,6 +16,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.transport.AbstractGetRequest; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -27,7 +28,7 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLModelGetRequest extends ActionRequest { +public class MLModelGetRequest extends AbstractGetRequest { String modelId; boolean returnContent; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 41cb8d6a04..deaab9837b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -7,6 +7,7 @@ import lombok.Builder; import lombok.Data; +import lombok.Setter; import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -96,6 +97,9 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private Boolean isHidden; private Guardrails guardrails; + @Setter + private String tenantId; + @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, String modelName, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index b6b5af85b4..7e6ce8b16b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -58,6 +58,12 @@ public void initModelIndexIfAbsent(ActionListener listener) { initMLIndexIfAbsent(MLIndex.MODEL, listener); } + public Boolean initModelIndexIfAbsent() throws ExecutionException, InterruptedException { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + initModelIndexIfAbsent(actionFuture); + return actionFuture.get(); + } + public void initMLTaskIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.TASK, listener); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index 0d5fc5f659..6d50ce8b01 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -5,38 +5,26 @@ package org.opensearch.ml.action.models; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; -import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD; -import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; -import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; - import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; +import org.opensearch.ml.dao.model.ModelDao; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.utils.RestActionUtils; -import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -46,18 +34,21 @@ import lombok.experimental.FieldDefaults; import lombok.extern.log4j.Log4j2; +import java.util.Optional; + @Log4j2 @FieldDefaults(level = AccessLevel.PRIVATE) public class GetModelTransportAction extends HandledTransportAction { Client client; - NamedXContentRegistry xContentRegistry; ClusterService clusterService; ModelAccessControlHelper modelAccessControlHelper; Settings settings; + ModelDao modelDao; + @Inject public GetModelTransportAction( TransportService transportService, @@ -66,96 +57,71 @@ public GetModelTransportAction( Settings settings, NamedXContentRegistry xContentRegistry, ClusterService clusterService, - ModelAccessControlHelper modelAccessControlHelper + ModelAccessControlHelper modelAccessControlHelper, + ModelDao modelDao ) { super(MLModelGetAction.NAME, transportService, actionFilters, MLModelGetRequest::new); this.client = client; this.settings = settings; - this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; + this.modelDao = modelDao; } @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLModelGetRequest mlModelGetRequest = MLModelGetRequest.fromActionRequest(request); String modelId = mlModelGetRequest.getModelId(); - FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent()); - GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); User user = RestActionUtils.getUserContext(client); boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); - - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - String algorithmName = r.getSource().get(ALGORITHM_FIELD).toString(); - Boolean isHidden = (Boolean) r.getSource().get(IS_HIDDEN_FIELD); - MLModel mlModel = MLModel.parse(parser, algorithmName); - if (isHidden != null && isHidden) { - if (isSuperAdmin || !mlModelGetRequest.isUserInitiatedGetRequest()) { - wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } - } else { - modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else { - log.debug("Completed Get Model Request, id:{}", modelId); - Connector connector = mlModel.getConnector(); - if (connector != null) { - connector.removeCredential(); - } - wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); - } - }, e -> { - log.error("Failed to validate Access for Model Id " + modelId, e); - wrappedListener.onFailure(e); - })); - } - } catch (Exception e) { - log.error("Failed to parse ml model " + r.getId(), e); - wrappedListener.onFailure(e); - } + Optional modelOptional = modelDao.getModel(modelId, mlModelGetRequest.isReturnContent()); + if (modelOptional.isPresent()) { + MLModel mlModel = modelOptional.get(); + Boolean isHidden = mlModel.getIsHidden(); + if (isHidden != null && isHidden) { + if (isSuperAdmin || !mlModelGetRequest.isUserInitiatedGetRequest()) { + actionListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); } else { - wrappedListener + actionListener .onFailure( new OpenSearchStatusException( - "Failed to find model with the provided model id: " + modelId, - RestStatus.NOT_FOUND + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN ) ); } - }, e -> { - if (e instanceof IndexNotFoundException) { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model")); - } else { - log.error("Failed to get ML model " + modelId, e); - wrappedListener.onFailure(e); - } - })); - } catch (Exception e) { - log.error("Failed to get ML model " + modelId, e); - actionListener.onFailure(e); + } else { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { + actionListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else { + log.debug("Completed Get Model Request, id:{}", modelId); + Connector connector = mlModel.getConnector(); + if (connector != null) { + connector.removeCredential(); + } + actionListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); + }}, e -> { + log.error("Failed to validate Access for Model Id " + modelId, e); + actionListener.onFailure(e); + })); + } + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model with the provided model id: " + modelId, + RestStatus.NOT_FOUND + ) + ); } - } // this method is only to stub static method. diff --git a/plugin/src/main/java/org/opensearch/ml/dao/model/ModelDao.java b/plugin/src/main/java/org/opensearch/ml/dao/model/ModelDao.java new file mode 100644 index 0000000000..e41f7b2881 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/dao/model/ModelDao.java @@ -0,0 +1,19 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.dao.model; + +import org.opensearch.ml.common.MLModel; + +import java.util.Optional; + +public interface ModelDao { + + String createModel(MLModel mlModel); + + Optional getModel(String modelId, boolean isReturnContent); +} diff --git a/plugin/src/main/java/org/opensearch/ml/dao/model/OpenSearchRestModelDao.java b/plugin/src/main/java/org/opensearch/ml/dao/model/OpenSearchRestModelDao.java new file mode 100644 index 0000000000..4d178ebe15 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/dao/model/OpenSearchRestModelDao.java @@ -0,0 +1,55 @@ +package org.opensearch.ml.dao.model; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.opensearch.core.GetResponse; +import org.opensearch.client.opensearch.core.IndexRequest; +import org.opensearch.client.opensearch.core.IndexResponse; +import org.opensearch.ml.common.MLModel; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Optional; + +@Log4j2 +public class OpenSearchRestModelDao implements ModelDao { + + private static final String ML_MODEL_INDEX = "oasis_ml_model"; + + private OpenSearchClient openSearchClient; + + public OpenSearchRestModelDao(OpenSearchClient openSearchClient) { + this.openSearchClient = openSearchClient; + } + @Override + public String createModel(MLModel mlModel) { + try { + IndexRequest indexRequest = new IndexRequest.Builder().index(ML_MODEL_INDEX) + .document(mlModel).build(); + final IndexResponse indexResponse = AccessController.doPrivileged((PrivilegedAction) () -> { + try { + return openSearchClient.index(indexRequest); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + return indexResponse.id(); + } catch (Exception e) { + log.error("Exception : " + e); + throw e; + } + } + + @Override + public Optional getModel(String modelId, boolean isReturnContent) { + GetResponse getResponse = AccessController.doPrivileged((PrivilegedAction>) () -> { + try { + return openSearchClient.get(getRequest -> getRequest.index(ML_MODEL_INDEX).id(modelId), MLModel.class); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + return getResponse.source() == null ? Optional.empty() : Optional.of(getResponse.source()); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/dao/model/OpenSearchTransportModelDao.java b/plugin/src/main/java/org/opensearch/ml/dao/model/OpenSearchTransportModelDao.java new file mode 100644 index 0000000000..b11d23f901 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/dao/model/OpenSearchTransportModelDao.java @@ -0,0 +1,97 @@ +package org.opensearch.ml.dao.model; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.search.fetch.subphase.FetchSourceContext; + +import java.util.Optional; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; + +@Log4j2 +public class OpenSearchTransportModelDao implements ModelDao { + + private Client client; + private MLIndicesHandler mlIndicesHandler; + + private NamedXContentRegistry xContentRegistry; + + public OpenSearchTransportModelDao(Client client, + MLIndicesHandler mlIndicesHandler, + NamedXContentRegistry xContentRegistry) { + this.client = client; + this.mlIndicesHandler = mlIndicesHandler; + this.xContentRegistry = xContentRegistry; + } + + @Override + public String createModel(MLModel mlModel) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlIndicesHandler.initModelIndexIfAbsent(); + IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); + indexRequest.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + final IndexResponse indexResponse = client.index(indexRequest).actionGet(); + context.restore(); + return indexResponse.getId(); + } catch (Exception e) { + log.error("Failed to create model!", e); + return null; + } + } + + @Override + public Optional getModel(String modelId, boolean isReturnContent) { + FetchSourceContext fetchSourceContext = getFetchSourceContext(isReturnContent); + GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); + + ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext(); + try { + GetResponse r = client.get(getRequest).actionGet(); + log.debug("Completed Get Model Request, id:{}", modelId); + + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + String algorithmName = r.getSource().get(ALGORITHM_FIELD).toString(); + MLModel mlModel = MLModel.parse(parser, algorithmName); + return Optional.of(mlModel); + } catch (Exception e) { + log.error("Failed to parse ml model" + r.getId(), e); + throw e; + } + } + return Optional.empty(); + } catch(Exception e) { + if (e instanceof IndexNotFoundException) { + log.error("Failed to get model index", e); + throw new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND); + } else { + log.error("Failed to get ML model " + modelId, e); + throw new IllegalStateException("Failed to get ML model " + modelId, e); + } + } finally { + context.restore(); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index e83493f4e5..7312fe267a 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -62,6 +62,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -120,6 +121,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; +import org.opensearch.ml.dao.model.ModelDao; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.ModelHelper; @@ -144,6 +146,7 @@ import lombok.extern.log4j.Log4j2; + /** * Manager class for ML models. It contains ML model related operations like * register, deploy model etc. @@ -173,15 +176,15 @@ public class MLModelManager { private volatile Integer maxRegisterTasksPerNode; private volatile Integer maxDeployTasksPerNode; - public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet - .of( - MLModelState.TRAINED, - MLModelState.REGISTERED, - MLModelState.DEPLOYED, - MLModelState.PARTIALLY_DEPLOYED, - MLModelState.DEPLOY_FAILED, - MLModelState.UNDEPLOYED - ); + private ModelDao modelDao; + + public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet.of(MLModelState.TRAINED, + MLModelState.REGISTERED, + MLModelState.DEPLOYED, + MLModelState.PARTIALLY_DEPLOYED, + MLModelState.DEPLOY_FAILED, + MLModelState.UNDEPLOYED + ); public MLModelManager( ClusterService clusterService, @@ -197,7 +200,8 @@ public MLModelManager( MLTaskManager mlTaskManager, MLModelCacheHelper modelCacheHelper, MLEngine mlEngine, - DiscoveryNodeHelper nodeHelper + DiscoveryNodeHelper nodeHelper, + ModelDao modelDao ) { this.client = client; this.threadPool = threadPool; @@ -212,7 +216,7 @@ public MLModelManager( this.mlTaskManager = mlTaskManager; this.mlEngine = mlEngine; this.nodeHelper = nodeHelper; - + this.modelDao = modelDao; this.maxModelPerNode = ML_COMMONS_MAX_MODELS_PER_NODE.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MAX_MODELS_PER_NODE, it -> maxModelPerNode = it); @@ -243,8 +247,7 @@ public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, if (modelGroup.isExists()) { Map modelGroupSource = modelGroup.getSourceAsMap(); int updatedVersion = incrementLatestVersion(modelGroupSource); - UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( - modelGroupSource, + UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest(modelGroupSource, modelGroupId, modelGroup.getSeqNo(), modelGroup.getPrimaryTerm(), @@ -318,8 +321,7 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput log.debug("Index model meta doc successfully {}", modelName); wrappedListener.onResponse(response.getId()); }, e -> { - deleteOrUpdateModelGroup( - mlRegisterModelMetaInput.getModelGroupId(), + deleteOrUpdateModelGroup(mlRegisterModelMetaInput.getModelGroupId(), mlRegisterModelMetaInput.getDoesVersionCreateModelGroup(), version ); @@ -337,66 +339,19 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput } /** - * * @param mlRegisterModelInput register model input for remote models * @param mlTask ML task * @param listener action listener */ public void registerMLRemoteModel( - MLRegisterModelInput mlRegisterModelInput, - MLTask mlTask, - ActionListener listener + MLRegisterModelInput mlRegisterModelInput, MLTask mlTask, ActionListener listener ) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode); mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); - - String modelGroupId = mlRegisterModelInput.getModelGroupId(); - GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - client.get(getModelGroupRequest, ActionListener.wrap(getModelGroupResponse -> { - if (getModelGroupResponse.isExists()) { - Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); - int updatedVersion = incrementLatestVersion(modelGroupSourceMap); - UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( - modelGroupSourceMap, - modelGroupId, - getModelGroupResponse.getSeqNo(), - getModelGroupResponse.getPrimaryTerm(), - updatedVersion - ); - client.update(updateModelGroupRequest, ActionListener.wrap(r -> { - indexRemoteModel(mlRegisterModelInput, mlTask, updatedVersion + "", listener); - }, e -> { - log.error("Failed to update model group " + modelGroupId, e); - handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); - listener.onFailure(e); - })); - } else { - log.error("Model group response is empty"); - handleException( - mlRegisterModelInput.getFunctionName(), - mlTask.getTaskId(), - new MLValidationException("Model group not found") - ); - listener.onFailure(new MLResourceNotFoundException("Model Group Response is empty for " + modelGroupId)); - } - }, error -> { - if (error instanceof IndexNotFoundException) { - log.error("Model group Index is missing"); - handleException( - mlRegisterModelInput.getFunctionName(), - mlTask.getTaskId(), - new MLResourceNotFoundException("Failed to get model group due to index missing") - ); - listener.onFailure(error); - } else { - log.error("Failed to get model group", error); - handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), error); - listener.onFailure(error); - } - })); + indexRemoteModel(mlRegisterModelInput, mlTask, null, listener); } catch (Exception e) { log.error("Failed to register remote model", e); handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); @@ -507,69 +462,50 @@ private void indexRemoteModel( MLTask mlTask, String modelVersion, ActionListener listener - ) { + ) throws ExecutionException, InterruptedException { String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - String modelName = registerModelInput.getModelName(); - String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; - Instant now = Instant.now(); - if (registerModelInput.getConnector() != null) { - registerModelInput.getConnector().encrypt(mlEngine::encrypt); - } - - mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(boolResponse -> { - MLModel mlModelMeta = MLModel - .builder() - .name(modelName) - .algorithm(functionName) - .modelGroupId(registerModelInput.getModelGroupId()) - .version(version) - .description(registerModelInput.getDescription()) - .rateLimiter(registerModelInput.getRateLimiter()) - .modelFormat(registerModelInput.getModelFormat()) - .modelState(MLModelState.REGISTERED) - .connector(registerModelInput.getConnector()) - .connectorId(registerModelInput.getConnectorId()) - .modelConfig(registerModelInput.getModelConfig()) - .deploySetting(registerModelInput.getDeploySetting()) - .createdTime(now) - .lastUpdateTime(now) - .isHidden(registerModelInput.getIsHidden()) - .guardrails(registerModelInput.getGuardrails()) - .build(); - - IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); - if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { - indexModelMetaRequest.id(modelName); - } - indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); - indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - // index remote model doc - ActionListener indexListener = ActionListener.wrap(modelMetaRes -> { - String modelId = modelMetaRes.getId(); - mlTask.setModelId(modelId); - log.info("create new model meta doc {} for upload task {}", modelId, taskId); - mlTaskManager.updateMLTask(taskId, Map.of(MODEL_ID_FIELD, modelId, STATE_FIELD, COMPLETED), 5000, true); - if (registerModelInput.isDeployModel()) { - deployModelAfterRegistering(registerModelInput, modelId); - } - listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name(), modelId)); - }, e -> { - log.error("Failed to index model meta doc", e); - handleException(functionName, taskId, e); - listener.onFailure(e); - }); + String modelName = registerModelInput.getModelName(); + String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; + Instant now = Instant.now(); + if (registerModelInput.getConnector() != null) { + registerModelInput.getConnector().encrypt(mlEngine::encrypt); + } - client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, indexListener)); - }, error -> { - // failed to initialize the model index - log.error("Failed to init model index", error); - handleException(functionName, taskId, error); - listener.onFailure(error); - })); + Boolean created = mlIndicesHandler.initModelIndexIfAbsent(); + if (!Boolean.TRUE.equals(created)) { + listener.onFailure(new RuntimeException("Failed to init model index")); + } + MLModel mlModelMeta = MLModel + .builder() + .name(modelName) + .algorithm(functionName) + .modelGroupId(registerModelInput.getModelGroupId()) + .version(version) + .description(registerModelInput.getDescription()) + .rateLimiter(registerModelInput.getRateLimiter()) + .modelFormat(registerModelInput.getModelFormat()) + .modelState(MLModelState.REGISTERED) + .connector(registerModelInput.getConnector()) + .connectorId(registerModelInput.getConnectorId()) + .modelConfig(registerModelInput.getModelConfig()) + .deploySetting(registerModelInput.getDeploySetting()) + .createdTime(now) + .lastUpdateTime(now) + .isHidden(registerModelInput.getIsHidden()) + .guardrails(registerModelInput.getGuardrails()) + .tenantId(registerModelInput.getTenantId()) + .build(); + + String modelId = modelDao.createModel(mlModelMeta); + // index remote model doc + mlTask.setModelId(modelId); + log.info("create new model meta doc {} for upload task {}", modelId, taskId); + mlTaskManager.updateMLTask(taskId, Map.of(MODEL_ID_FIELD, modelId, STATE_FIELD, COMPLETED), 5000, true); + if (registerModelInput.isDeployModel()) { + deployModelAfterRegistering(registerModelInput, modelId); } + listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name(), modelId)); } @VisibleForTesting @@ -603,6 +539,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St .lastUpdateTime(now) .isHidden(registerModelInput.getIsHidden()) .guardrails(registerModelInput.getGuardrails()) + .tenantId(registerModelInput.getTenantId()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { @@ -872,7 +809,7 @@ private void updateModelRegisterStateAsDone( void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) { String[] modelNodeIds = registerModelInput.getModelNodeIds(); log.debug("start deploying model after registering, modelId: {} on nodes: {}", modelId, Arrays.toString(modelNodeIds)); - MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, true, true); + MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, false, true); ActionListener listener = ActionListener .wrap(r -> log.debug("model deployed, response {}", r), e -> log.error("Failed to deploy model", e)); client.execute(MLDeployModelAction.INSTANCE, request, listener); diff --git a/plugin/src/main/java/org/opensearch/ml/module/MetaDataAccessModule.java b/plugin/src/main/java/org/opensearch/ml/module/MetaDataAccessModule.java index edfa835a33..18ab6772e0 100644 --- a/plugin/src/main/java/org/opensearch/ml/module/MetaDataAccessModule.java +++ b/plugin/src/main/java/org/opensearch/ml/module/MetaDataAccessModule.java @@ -12,6 +12,8 @@ import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.regions.Region; +import java.util.Optional; + @RequiredArgsConstructor public class MetaDataAccessModule extends AbstractModule { public static final String REMOTE_METADATA_ENDPOINT = "REMOTE_METADATA_ENDPOINT"; @@ -25,14 +27,14 @@ public ConnectorDao createConnectorDao() { return new OpenSearchRestConnectorDao(createOpenSearchClient()); } - private OpenSearchClient createOpenSearchClient() { + public OpenSearchClient createOpenSearchClient() { SdkHttpClient httpClient = ApacheHttpClient.builder().build(); try { return new OpenSearchClient( new AwsSdk2Transport( httpClient, - System.getenv(REMOTE_METADATA_ENDPOINT), - Region.of(System.getenv(REGION)), + Optional.ofNullable(System.getenv(REMOTE_METADATA_ENDPOINT)).orElse("http://localhost:9200"), + Region.of(Optional.ofNullable(System.getenv(REGION)).orElse("us-east-1")), AwsSdk2TransportOptions.builder().build() ) ); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 9db49f0343..750f7856fa 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -27,6 +27,7 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.AbstractModule; import org.opensearch.common.inject.Module; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.IndexScopedSettings; @@ -154,6 +155,8 @@ import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheAction; import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaAction; import org.opensearch.ml.common.transport.upload_chunk.MLUploadModelChunkAction; +import org.opensearch.ml.dao.model.ModelDao; +import org.opensearch.ml.dao.model.OpenSearchRestModelDao; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.MLEngineClassLoader; import org.opensearch.ml.engine.ModelHelper; @@ -351,6 +354,10 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc private ScriptService scriptService; private Encryptor encryptor; + private MetaDataAccessModule metaDataAccessModule; + + private ModelDao modelDao; + public MachineLearningPlugin(Settings settings) { // Handle this here as this feature is tied to Search/Query API, not to a ml-common API // and as such, it can't be lazy-loaded when a ml-commons API is invoked. @@ -424,7 +431,16 @@ public MachineLearningPlugin(Settings settings) { @Override public Collection createGuiceModules() { - Collection modules = Arrays.asList(new MetaDataAccessModule()); + List modules = new ArrayList<>(); + metaDataAccessModule = new MetaDataAccessModule(); + modules.add(metaDataAccessModule); + modules.add(new AbstractModule() { + + @Override + protected void configure() { + bind(ModelDao.class).to(OpenSearchRestModelDao.class); + } + }); return modules; } @@ -452,6 +468,7 @@ public Collection createComponents( Settings settings = environment.settings(); Path dataPath = environment.dataFiles()[0]; Path configFile = environment.configFile(); + modelDao = new OpenSearchRestModelDao(metaDataAccessModule.createOpenSearchClient()); encryptor = new EncryptorImpl(clusterService, client); @@ -504,7 +521,8 @@ public Collection createComponents( mlTaskManager, modelCacheHelper, mlEngine, - nodeHelper + nodeHelper, + modelDao ); mlInputDatasetHandler = new MLInputDatasetHandler(client); modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings); @@ -667,7 +685,8 @@ public Collection createComponents( clusterManagerEventListener, mlCircuitBreakerService, mlModelAutoRedeployer, - cmHandler + cmHandler, + modelDao ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java index 097bc6fb77..4022338b2c 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java @@ -15,6 +15,7 @@ import java.util.Locale; import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.input.Constants; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.rest.BaseRestHandler; @@ -59,7 +60,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client MLModelGetRequest getRequest(RestRequest request) throws IOException { String modelId = getParameterId(request, PARAMETER_MODEL_ID); boolean returnContent = returnContent(request); + String tenantId = request.getHeaders().get(Constants.TENANT_ID).get(0); - return new MLModelGetRequest(modelId, returnContent, true); + MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, returnContent, true); + mlModelGetRequest.setTenantId(tenantId); + return mlModelGetRequest; } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java index 68fd73b20a..3e8cf4456e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java @@ -22,6 +22,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.input.Constants; import org.opensearch.ml.common.transport.register.MLRegisterModelAction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; @@ -102,6 +103,8 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException { } else if (FunctionName.isDLModel(mlInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); } + String tenantId = request.getHeaders().get(Constants.TENANT_ID).get(0); + mlInput.setTenantId(tenantId); return new MLRegisterModelRequest(mlInput); } }