diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index a0d87b3d64..68768eb54b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -91,6 +91,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { try ( @@ -104,20 +105,20 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } else { logException("Failed to get model group", e, log); - listener.onFailure(e); + wrappedListener.onFailure(e); } })); } catch (Exception e) { @@ -188,15 +189,16 @@ private void updateModelGroup(String modelGroupId, Map source, A UpdateRequest updateModelGroupRequest = new UpdateRequest(); updateModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId).doc(source); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); client .update( updateModelGroupRequest, - ActionListener.wrap(r -> { listener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> { + ActionListener.wrap(r -> { wrappedListener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> { if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } else { log.error("Failed to update model group", e, log); - listener.onFailure(new MLValidationException("Failed to update Model Group")); + wrappedListener.onFailure(new MLValidationException("Failed to update Model Group")); } }) );