Skip to content

Commit

Permalink
throw exception when model group not found during update request
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna committed Oct 6, 2023
1 parent 1f43b28 commit 4cf256b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
Expand Down Expand Up @@ -90,39 +89,41 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
MLUpdateModelGroupInput updateModelGroupInput = updateModelGroupRequest.getUpdateModelGroupInput();
String modelGroupId = updateModelGroupInput.getModelGroupID();
User user = RestActionUtils.getUserContext(client);
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
try (
XContentParser parser = MLNodeUtils
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
try (
XContentParser parser = MLNodeUtils
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup);
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user);
} else {
validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
}
} else {
wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
logException("Failed to get model group", e, log);
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user);
} catch (Exception e) {
log.error("Failed to parse ml model group" + modelGroup.getId(), e);
wrappedListener.onFailure(e);
}
}));
} catch (Exception e) {
logException("Failed to Update model group", e, log);
listener.onFailure(e);
}
} else {
validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
updateModelGroup(modelGroupId, new HashMap<>(), updateModelGroupInput, listener, user);
} else {
wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
logException("Failed to get model group", e, log);
wrappedListener.onFailure(e);
}
}));
} catch (Exception e) {
logException("Failed to Update model group", e, log);
listener.onFailure(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.get.GetResult;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLModelGroup;
Expand Down Expand Up @@ -367,6 +368,20 @@ public void test_FailedToGetModelGroupException() {
assertEquals("Failed to get model group", argumentCaptor.getValue().getMessage());
}

public void test_ModelGroupIndexNotFoundException() {
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new IndexNotFoundException("Fail to find model group"));
return null;
}).when(client).get(any(), any());

MLUpdateModelGroupRequest actionRequest = prepareRequest(null, AccessMode.RESTRICTED, null);
transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Fail to find model group", argumentCaptor.getValue().getMessage());
}

public void test_FailedToUpdatetModelGroupException() {
doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
Expand Down Expand Up @@ -414,15 +429,16 @@ public void test_ModelGroupNameNotUnique() throws IOException {
}

public void test_ExceptionSecurityDisabledCluster() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule
.expectMessage(
"You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster."
);
when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false);

MLUpdateModelGroupRequest actionRequest = prepareRequest(null, null, true);
transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster.",
argumentCaptor.getValue().getMessage()
);
}

private MLUpdateModelGroupRequest prepareRequest(List<String> backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) {
Expand Down

0 comments on commit 4cf256b

Please sign in to comment.