Skip to content

Commit

Permalink
Set tenant ID for predict request (#2619)
Browse files Browse the repository at this point in the history
* AWS DDB SDK client support for remote data store

Signed-off-by: Arjun kumar Giri <arjung@amazon.com>

* AWS DDB SDK client support for remote data store

Signed-off-by: Arjun kumar Giri <arjung@amazon.com>

* multi-tenancy for models (create, get, delete, update) + update connector (#2546)

* multi-tenancy for models (create, get, delete)

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* added update connector + update model

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

---------

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
Signed-off-by: Arjun kumar Giri <arjung@amazon.com>

* [Feature/multi_tenancy] Add source map to GetDataObjectResponse (#2489)

* Add source map to GetDataObjectResponse

Signed-off-by: Daniel Widdis <widdis@gmail.com>

* Add test for map getter in clients

Signed-off-by: Daniel Widdis <widdis@gmail.com>

---------

Signed-off-by: Daniel Widdis <widdis@gmail.com>
Signed-off-by: Arjun kumar Giri <arjung@amazon.com>

* [Feature/multi_tenancy] Add UpdateDataObject interface, Client, and Connector Implementations (#2520)

* Restore original exception handling expectations

Signed-off-by: Daniel Widdis <widdis@gmail.com>

* Add UpdateDataObject to interface and implementations

Signed-off-by: Daniel Widdis <widdis@gmail.com>

* Implement UpdateConnector action

Signed-off-by: Daniel Widdis <widdis@gmail.com>

* Move CompletionException handling to a common method

Signed-off-by: Daniel Widdis <widdis@gmail.com>

* Add tests for SDKClient exceptions refactored from Transport Action

Signed-off-by: Daniel Widdis <widdis@gmail.com>

---------

Signed-off-by: Daniel Widdis <widdis@gmail.com>
Signed-off-by: Arjun kumar Giri <arjung@amazon.com>

* Addressed CR comment

Signed-off-by: Arjun kumar Giri <arjung@amazon.com>

* Added javadoc based on feedback

Signed-off-by: Arjun kumar Giri <arjung@amazon.com>

* Set tenant ID for predict request

Signed-off-by: Arjun kumar Giri <arjung@amazon.com>

* Simplify instantiating Data Object Request/Response builders (#2608)

Signed-off-by: Daniel Widdis <widdis@gmail.com>

* Addressed comments

Signed-off-by: Arjun kumar Giri <arjung@amazon.com>

---------

Signed-off-by: Arjun kumar Giri <arjung@amazon.com>
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
Signed-off-by: Daniel Widdis <widdis@gmail.com>
Signed-off-by: arjunkumargiri <142054468+arjunkumargiri@users.noreply.github.com>
Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
Co-authored-by: Daniel Widdis <widdis@gmail.com>
  • Loading branch information
3 people authored Jul 11, 2024
1 parent a421f49 commit 8c80f43
Show file tree
Hide file tree
Showing 33 changed files with 411 additions and 324 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ jobs:
- name: Generate Password For Admin
id: genpass
run: |
PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-')
echo "password={$PASSWORD}" >> $GITHUB_OUTPUT
PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-')
echo "password={$PASSWORD}" >> $GITHUB_OUTPUT
- name: Run Docker Image
if: env.imagePresent == 'true'
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener<Boolean> listener)
*/
public void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener<Boolean> listener) {
IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName);
if (indexMetaData == null) {
if (indexMetaData == null || indexMetaData.mapping() == null) {
listener.onResponse(Boolean.FALSE);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCrea

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLCreateControllerResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
Boolean isHidden = mlModel.getIsHidden();
if (functionName == TEXT_EMBEDDING || functionName == REMOTE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
Boolean isHidden = mlModel.getIsHidden();
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCont
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLController controller = MLController.parse(parser);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
Boolean isHidden = mlModel.getIsHidden();
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<UpdateResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
Boolean isHidden = mlModel.getIsHidden();
if (functionName == TEXT_EMBEDDING || functionName == REMOTE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLDeployModelResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
Boolean isHidden = mlModel.getIsHidden();
if (!TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), listener)) {
Expand Down Expand Up @@ -285,6 +285,10 @@ private void deployModel(
String taskId = response.getId();
mlTask.setTaskId(taskId);
if (algorithm == FunctionName.REMOTE) {
if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) {
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
return;
}
mlTaskManager.add(mlTask, eligibleNodeIds);
deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ private void deployModel(
mlModelManager
.deployModel(
modelId,
null,
modelContentHash,
functionName,
deployToAllNodes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
.builder()
.index(ML_MODEL_INDEX)
.id(modelId)
.tenantId(tenantId)
.fetchSourceContext(fetchSourceContext)
.build();
User user = RestActionUtils.getUserContext(client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<UpdateResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
mlModelManager.getModel(sdkClient, modelId, null, excludes, ActionListener.wrap(mlModel -> {
mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.wrap(mlModel -> {
if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), actionListener)) {
if (!isModelDeploying(mlModel.getModelState())) {
FunctionName functionName = mlModel.getAlgorithm();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ public void onFailure(Exception e) {
modelActionListener.onResponse(cachedMlModel);
} else {
// For multi-node cluster, the function name is null in cache, so should always get model first.
mlModelManager.getModel(modelId, modelActionListener);
mlModelManager.getModel(modelId, tenantId, modelActionListener);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterMode
.backendRoles(registerModelInput.getBackendRoles())
.modelAccessMode(registerModelInput.getAccessMode())
.isAddAllBackendRoles(registerModelInput.getAddAllBackendRoles())
.tenantId(registerModelInput.getTenantId())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ private void validateAccess(String modelId, String tenantId, ActionListener<Bool
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlModelManager.getModel(modelId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> {
mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> {
if (!TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), listener)) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ public void validateConnectorAccess(
GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest
.builder()
.index(ML_CONNECTOR_INDEX)
.tenantId(tenantId)
.id(connectorId)
.fetchSourceContext(fetchSourceContext)
.build();
Expand Down
Loading

0 comments on commit 8c80f43

Please sign in to comment.