diff --git a/server/src/main/java/org/opensearch/action/ActionModule.java b/server/src/main/java/org/opensearch/action/ActionModule.java index 640ccf1217051..89f185e302f69 100644 --- a/server/src/main/java/org/opensearch/action/ActionModule.java +++ b/server/src/main/java/org/opensearch/action/ActionModule.java @@ -282,6 +282,7 @@ import org.opensearch.common.inject.TypeLiteral; import org.opensearch.common.inject.multibindings.MapBinder; import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.extensions.action.ExtensionAction; import org.opensearch.extensions.action.ExtensionProxyAction; import org.opensearch.extensions.action.ExtensionTransportAction; import org.opensearch.common.settings.IndexScopedSettings; @@ -715,6 +716,8 @@ public void reg // Remote Store actions.register(RestoreRemoteStoreAction.INSTANCE, TransportRestoreRemoteStoreAction.class); + // TODO: Remove this and its tests, it is no longer used + // Need to migrate its NAME prefix to the dynamic action name if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) { // ExtensionProxyAction actions.register(ExtensionProxyAction.INSTANCE, ExtensionTransportAction.class); @@ -983,25 +986,25 @@ public RestController getRestController() { } public static class DynamicActionRegistry { - private final Map> registry = new ConcurrentHashMap<>(); + private final Map registry = new ConcurrentHashMap<>(); - public void registerExtensionAction(ActionHandler handler) { - requireNonNull(handler, "action handler is required"); - String name = handler.getAction().name(); + public void registerExtensionAction(ExtensionAction extensionAction) { + requireNonNull(extensionAction, "extension action is required"); + String name = extensionAction.name(); requireNonNull(name, "name is required"); - if (registry.putIfAbsent(name, handler) != null) { - throw new IllegalArgumentException("action handler for name [" + name + "] already registered"); + if (registry.putIfAbsent(name, extensionAction) != null) { + throw new IllegalArgumentException("extension action for name [" + name + "] already registered"); } } public void unregisterExtensionAction(String name) { requireNonNull(name, "name is required"); if (registry.remove(name) == null) { - throw new IllegalArgumentException("action handler for name [" + name + "] was not registered"); + throw new IllegalArgumentException("extension action for name [" + name + "] was not registered"); } } - public ActionHandler get(String name) { + public ExtensionAction get(String name) { requireNonNull(name, "name is required"); return registry.get(name); } diff --git a/server/src/main/java/org/opensearch/client/node/NodeClient.java b/server/src/main/java/org/opensearch/client/node/NodeClient.java index 56cb7c406744a..a3be70f5ba5ec 100644 --- a/server/src/main/java/org/opensearch/client/node/NodeClient.java +++ b/server/src/main/java/org/opensearch/client/node/NodeClient.java @@ -34,18 +34,25 @@ import org.opensearch.action.ActionType; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionModule; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionResponse; +import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.TransportAction; import org.opensearch.client.Client; import org.opensearch.client.support.AbstractClient; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.settings.Settings; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.extensions.action.ExtensionAction; +import org.opensearch.extensions.action.ExtensionTransportAction; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskListener; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.RemoteClusterService; +import org.opensearch.transport.TransportService; import java.util.Map; import java.util.function.Supplier; @@ -58,12 +65,16 @@ public class NodeClient extends AbstractClient { private Map actions; + private DynamicActionRegistry extensionActions; + private ActionFilters actionFilters; /** * The id of the local {@link DiscoveryNode}. Useful for generating task ids from tasks returned by * {@link #executeLocally(ActionType, ActionRequest, TaskListener)}. */ private Supplier localNodeId; private RemoteClusterService remoteClusterService; + private TransportService transportService; + private ExtensionsManager extensionsManager; private NamedWriteableRegistry namedWriteableRegistry; public NodeClient(Settings settings, ThreadPool threadPool) { @@ -82,6 +93,21 @@ public void initialize( this.namedWriteableRegistry = namedWriteableRegistry; } + public void initialize( + Map actions, + ActionModule actionModule, + TransportService transportService, + ExtensionsManager extensionsManager, + Supplier localNodeId, + NamedWriteableRegistry namedWriteableRegistry + ) { + initialize(actions, localNodeId, transportService.getRemoteClusterService(), namedWriteableRegistry); + this.extensionActions = actionModule.getExtensionActions(); + this.actionFilters = actionModule.getActionFilters(); + this.transportService = transportService; + this.extensionsManager = extensionsManager; + } + @Override public void close() { // nothing really to do @@ -140,7 +166,20 @@ private Transpo if (actions == null) { throw new IllegalStateException("NodeClient has not been initialized"); } + // Get from action map if it exists TransportAction transportAction = actions.get(action); + // Fallback to dynamic extension action map + if (transportAction == null && extensionActions != null && action instanceof ExtensionAction) { + ExtensionAction extensionAction = extensionActions.get(action.name()); + if (extensionAction != null) { + transportAction = (TransportAction) new ExtensionTransportAction( + action.name(), + transportService, + actionFilters, + extensionsManager + ); + } + } if (transportAction == null) { throw new IllegalStateException("failed to find action [" + action + "] to execute"); } diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java index 8b38bfd6dd47c..dda19718e49d8 100644 --- a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java +++ b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java @@ -195,10 +195,11 @@ public void initializeServicesAndRestHandler( ); this.client = client; this.extensionTransportActionsHandler = new ExtensionTransportActionsHandler( + this, extensionIdMap, transportService, client, - actionModule.getExtensionActions() + actionModule ); registerRequestHandler(); } diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java new file mode 100644 index 0000000000000..023a8d50bcd55 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.action; + +import org.opensearch.action.ActionType; + +/** + * An {@link ActionType} to be used in extension action transport handlers. + * + * @opensearch.internal + */ +public class ExtensionAction extends ActionType { + + private final String uniqueId; + + /** + * Create an instance of this action to register in the dynamic actions map. + * + * @param uniqueId The uniqueId of the extension which will run this action. + * @param name The fully qualified class name of the extension's action to execute. + */ + ExtensionAction(String uniqueId, String name) { + super(name, ExtensionActionResponse::new); + this.uniqueId = uniqueId; + } + + /** + * Gets the uniqueId of the extension which will run this action. + * + * @return the uniqueId + */ + public String getUniqueId() { + return this.uniqueId; + } +} diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java index 5976db78002eb..389edf769f39c 100644 --- a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java @@ -10,37 +10,27 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; +import org.opensearch.action.support.TransportAction; import org.opensearch.extensions.ExtensionsManager; -import org.opensearch.node.Node; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; /** - * The main proxy transport action used to proxy a transport request from extension to another extension + * A transport action used to proxy a transport request from extension to another extension * * @opensearch.internal */ -public class ExtensionTransportAction extends HandledTransportAction { +public class ExtensionTransportAction extends TransportAction { - private final String nodeName; - private final ClusterService clusterService; private final ExtensionsManager extensionsManager; - @Inject public ExtensionTransportAction( - Settings settings, + String actionName, TransportService transportService, ActionFilters actionFilters, - ClusterService clusterService, ExtensionsManager extensionsManager ) { - super(ExtensionProxyAction.NAME, transportService, actionFilters, ExtensionActionRequest::new); - this.nodeName = Node.NODE_NAME_SETTING.get(settings); - this.clusterService = clusterService; + super(actionName, actionFilters, transportService.getTaskManager()); this.extensionsManager = extensionsManager; } diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java index ab80ac99538ed..e647e8138dcf7 100644 --- a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java @@ -11,13 +11,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionModule; import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.client.node.NodeClient; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.extensions.DiscoveryExtensionNode; import org.opensearch.extensions.AcknowledgedResponse; import org.opensearch.extensions.ExtensionsManager; -import org.opensearch.plugins.ActionPlugin.ActionHandler; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.ActionNotFoundTransportException; import org.opensearch.transport.TransportException; @@ -27,10 +27,10 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.HashMap; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -41,23 +41,22 @@ */ public class ExtensionTransportActionsHandler { private static final Logger logger = LogManager.getLogger(ExtensionTransportActionsHandler.class); - private Map actionsMap; + private final Map actionsMap = new ConcurrentHashMap<>(); private final Map extensionIdMap; private final TransportService transportService; private final NodeClient client; - private DynamicActionRegistry dynamicActionRegistry; + private final DynamicActionRegistry dynamicActionRegistry; public ExtensionTransportActionsHandler( Map extensionIdMap, TransportService transportService, NodeClient client, - DynamicActionRegistry dynamicActionRegistry + ActionModule actionModule ) { - this.actionsMap = new HashMap<>(); this.extensionIdMap = extensionIdMap; this.transportService = transportService; this.client = client; - this.dynamicActionRegistry = dynamicActionRegistry; + this.dynamicActionRegistry = actionModule.getExtensionActions(); } /** @@ -73,9 +72,7 @@ void registerAction(String action, DiscoveryExtensionNode extension) throws Ille throw new IllegalArgumentException("The action [" + action + "] you are trying to register is already registered"); } // Register the action in the action module's extension actions map - dynamicActionRegistry.registerExtensionAction( - new ActionHandler<>(new ExtensionAction(action, extension.getId()), ExtensionTransportAction.class) - ); + dynamicActionRegistry.registerExtensionAction(new ExtensionAction(action, extension.getId())); } /** @@ -95,9 +92,6 @@ public DiscoveryExtensionNode getExtension(String action) { * @return A {@link AcknowledgedResponse} indicating success. */ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransportActionsRequest transportActionsRequest) { - /* - * We are proxying the transport Actions through ExtensionProxyAction, so we really dont need to register dynamic actions for now. - */ logger.debug("Register Transport Actions request recieved {}", transportActionsRequest); DiscoveryExtensionNode extension = extensionIdMap.get(transportActionsRequest.getUniqueId()); try { @@ -123,8 +117,8 @@ public TransportResponse handleTransportActionRequestFromExtension(TransportActi String uniqueId = request.getUniqueId(); final TransportActionResponseToExtension response = new TransportActionResponseToExtension(new byte[0]); // Validate that this action has been registered - ActionHandler handler = dynamicActionRegistry.get(actionName); - if (handler == null) { + ExtensionAction extensionAction = dynamicActionRegistry.get(actionName); + if (extensionAction == null) { byte[] responseBytes = ("Request failed: action [" + actionName + "] is not registered for extension [" + uniqueId + "].") .getBytes(StandardCharsets.UTF_8); response.setResponseBytes(responseBytes); @@ -138,8 +132,7 @@ public TransportResponse handleTransportActionRequestFromExtension(TransportActi } final CompletableFuture inProgressFuture = new CompletableFuture<>(); client.execute( - // TODO change this to the registered action type - ExtensionProxyAction.INSTANCE, + extensionAction, new ExtensionActionRequest(request.getAction(), request.getRequestBytes()), new ActionListener() { @Override diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 5b6123196b072..e94516012760f 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -1114,12 +1114,7 @@ protected Node( this.pluginLifecycleComponents = Collections.unmodifiableList(pluginLifecycleComponents); if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) { client.initialize(injector.getInstance(new Key>() { - }), - actionModule.getExtensionActions(), - () -> clusterService.localNode().getId(), - transportService.getRemoteClusterService(), - namedWriteableRegistry - ); + }), actionModule, transportService, extensionsManager, () -> clusterService.localNode().getId(), namedWriteableRegistry); } else { client.initialize(injector.getInstance(new Key>() { }), () -> clusterService.localNode().getId(), transportService.getRemoteClusterService(), namedWriteableRegistry);