Skip to content

Commit

Permalink
Generate transport actions dynamically
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <widdis@gmail.com>
  • Loading branch information
dbwiddis committed Mar 17, 2023
1 parent 76671b7 commit eea1a21
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 47 deletions.
19 changes: 11 additions & 8 deletions server/src/main/java/org/opensearch/action/ActionModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@
import org.opensearch.common.inject.TypeLiteral;
import org.opensearch.common.inject.multibindings.MapBinder;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.extensions.action.ExtensionAction;
import org.opensearch.extensions.action.ExtensionProxyAction;
import org.opensearch.extensions.action.ExtensionTransportAction;
import org.opensearch.common.settings.IndexScopedSettings;
Expand Down Expand Up @@ -715,6 +716,8 @@ public <Request extends ActionRequest, Response extends ActionResponse> void reg
// Remote Store
actions.register(RestoreRemoteStoreAction.INSTANCE, TransportRestoreRemoteStoreAction.class);

// TODO: Remove this and its tests, it is no longer used
// Need to migrate its NAME prefix to the dynamic action name
if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) {
// ExtensionProxyAction
actions.register(ExtensionProxyAction.INSTANCE, ExtensionTransportAction.class);
Expand Down Expand Up @@ -983,25 +986,25 @@ public RestController getRestController() {
}

public static class DynamicActionRegistry {
private final Map<String, ActionHandler<?, ?>> registry = new ConcurrentHashMap<>();
private final Map<String, ExtensionAction> registry = new ConcurrentHashMap<>();

public void registerExtensionAction(ActionHandler<?, ?> handler) {
requireNonNull(handler, "action handler is required");
String name = handler.getAction().name();
public void registerExtensionAction(ExtensionAction extensionAction) {
requireNonNull(extensionAction, "extension action is required");
String name = extensionAction.name();
requireNonNull(name, "name is required");
if (registry.putIfAbsent(name, handler) != null) {
throw new IllegalArgumentException("action handler for name [" + name + "] already registered");
if (registry.putIfAbsent(name, extensionAction) != null) {
throw new IllegalArgumentException("extension action for name [" + name + "] already registered");
}
}

public void unregisterExtensionAction(String name) {
requireNonNull(name, "name is required");
if (registry.remove(name) == null) {
throw new IllegalArgumentException("action handler for name [" + name + "] was not registered");
throw new IllegalArgumentException("extension action for name [" + name + "] was not registered");
}
}

public ActionHandler<?, ?> get(String name) {
public ExtensionAction get(String name) {
requireNonNull(name, "name is required");
return registry.get(name);
}
Expand Down
39 changes: 39 additions & 0 deletions server/src/main/java/org/opensearch/client/node/NodeClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,25 @@

import org.opensearch.action.ActionType;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionModule;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.TransportAction;
import org.opensearch.client.Client;
import org.opensearch.client.support.AbstractClient;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.settings.Settings;
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.extensions.action.ExtensionAction;
import org.opensearch.extensions.action.ExtensionTransportAction;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskListener;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.RemoteClusterService;
import org.opensearch.transport.TransportService;

import java.util.Map;
import java.util.function.Supplier;
Expand All @@ -58,12 +65,16 @@
public class NodeClient extends AbstractClient {

private Map<ActionType, TransportAction> actions;
private DynamicActionRegistry extensionActions;
private ActionFilters actionFilters;
/**
* The id of the local {@link DiscoveryNode}. Useful for generating task ids from tasks returned by
* {@link #executeLocally(ActionType, ActionRequest, TaskListener)}.
*/
private Supplier<String> localNodeId;
private RemoteClusterService remoteClusterService;
private TransportService transportService;
private ExtensionsManager extensionsManager;
private NamedWriteableRegistry namedWriteableRegistry;

public NodeClient(Settings settings, ThreadPool threadPool) {
Expand All @@ -82,6 +93,21 @@ public void initialize(
this.namedWriteableRegistry = namedWriteableRegistry;
}

public void initialize(
Map<ActionType, TransportAction> actions,
ActionModule actionModule,
TransportService transportService,
ExtensionsManager extensionsManager,
Supplier<String> localNodeId,
NamedWriteableRegistry namedWriteableRegistry
) {
initialize(actions, localNodeId, transportService.getRemoteClusterService(), namedWriteableRegistry);
this.extensionActions = actionModule.getExtensionActions();
this.actionFilters = actionModule.getActionFilters();
this.transportService = transportService;
this.extensionsManager = extensionsManager;
}

@Override
public void close() {
// nothing really to do
Expand Down Expand Up @@ -140,7 +166,20 @@ private <Request extends ActionRequest, Response extends ActionResponse> Transpo
if (actions == null) {
throw new IllegalStateException("NodeClient has not been initialized");
}
// Get from action map if it exists
TransportAction<Request, Response> transportAction = actions.get(action);
// Fallback to dynamic extension action map
if (transportAction == null && extensionActions != null && action instanceof ExtensionAction) {
ExtensionAction extensionAction = extensionActions.get(action.name());
if (extensionAction != null) {
transportAction = (TransportAction<Request, Response>) new ExtensionTransportAction(
action.name(),
transportService,
actionFilters,
extensionsManager
);
}
}
if (transportAction == null) {
throw new IllegalStateException("failed to find action [" + action + "] to execute");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,11 @@ public void initializeServicesAndRestHandler(
);
this.client = client;
this.extensionTransportActionsHandler = new ExtensionTransportActionsHandler(
this,
extensionIdMap,
transportService,
client,
actionModule.getExtensionActions()
actionModule
);
registerRequestHandler();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.extensions.action;

import org.opensearch.action.ActionType;

/**
* An {@link ActionType} to be used in extension action transport handlers.
*
* @opensearch.internal
*/
public class ExtensionAction extends ActionType<ExtensionActionResponse> {

private final String uniqueId;

/**
* Create an instance of this action to register in the dynamic actions map.
*
* @param uniqueId The uniqueId of the extension which will run this action.
* @param name The fully qualified class name of the extension's action to execute.
*/
ExtensionAction(String uniqueId, String name) {
super(name, ExtensionActionResponse::new);
this.uniqueId = uniqueId;
}

/**
* Gets the uniqueId of the extension which will run this action.
*
* @return the uniqueId
*/
public String getUniqueId() {
return this.uniqueId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,27 @@

import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.action.support.TransportAction;
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.node.Node;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/**
* The main proxy transport action used to proxy a transport request from extension to another extension
* A transport action used to proxy a transport request from extension to another extension
*
* @opensearch.internal
*/
public class ExtensionTransportAction extends HandledTransportAction<ExtensionActionRequest, ExtensionActionResponse> {
public class ExtensionTransportAction extends TransportAction<ExtensionActionRequest, ExtensionActionResponse> {

private final String nodeName;
private final ClusterService clusterService;
private final ExtensionsManager extensionsManager;

@Inject
public ExtensionTransportAction(
Settings settings,
String actionName,
TransportService transportService,
ActionFilters actionFilters,
ClusterService clusterService,
ExtensionsManager extensionsManager
) {
super(ExtensionProxyAction.NAME, transportService, actionFilters, ExtensionActionRequest::new);
this.nodeName = Node.NODE_NAME_SETTING.get(settings);
this.clusterService = clusterService;
super(actionName, actionFilters, transportService.getTaskManager());
this.extensionsManager = extensionsManager;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionModule;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.extensions.DiscoveryExtensionNode;
import org.opensearch.extensions.AcknowledgedResponse;
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.plugins.ActionPlugin.ActionHandler;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.ActionNotFoundTransportException;
import org.opensearch.transport.TransportException;
Expand All @@ -27,10 +27,10 @@

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

Expand All @@ -41,23 +41,22 @@
*/
public class ExtensionTransportActionsHandler {
private static final Logger logger = LogManager.getLogger(ExtensionTransportActionsHandler.class);
private Map<String, DiscoveryExtensionNode> actionsMap;
private final Map<String, DiscoveryExtensionNode> actionsMap = new ConcurrentHashMap<>();
private final Map<String, DiscoveryExtensionNode> extensionIdMap;
private final TransportService transportService;
private final NodeClient client;
private DynamicActionRegistry dynamicActionRegistry;
private final DynamicActionRegistry dynamicActionRegistry;

public ExtensionTransportActionsHandler(
Map<String, DiscoveryExtensionNode> extensionIdMap,
TransportService transportService,
NodeClient client,
DynamicActionRegistry dynamicActionRegistry
ActionModule actionModule
) {
this.actionsMap = new HashMap<>();
this.extensionIdMap = extensionIdMap;
this.transportService = transportService;
this.client = client;
this.dynamicActionRegistry = dynamicActionRegistry;
this.dynamicActionRegistry = actionModule.getExtensionActions();
}

/**
Expand All @@ -73,9 +72,7 @@ void registerAction(String action, DiscoveryExtensionNode extension) throws Ille
throw new IllegalArgumentException("The action [" + action + "] you are trying to register is already registered");
}
// Register the action in the action module's extension actions map
dynamicActionRegistry.registerExtensionAction(
new ActionHandler<>(new ExtensionAction(action, extension.getId()), ExtensionTransportAction.class)
);
dynamicActionRegistry.registerExtensionAction(new ExtensionAction(action, extension.getId()));
}

/**
Expand All @@ -95,9 +92,6 @@ public DiscoveryExtensionNode getExtension(String action) {
* @return A {@link AcknowledgedResponse} indicating success.
*/
public TransportResponse handleRegisterTransportActionsRequest(RegisterTransportActionsRequest transportActionsRequest) {
/*
* We are proxying the transport Actions through ExtensionProxyAction, so we really dont need to register dynamic actions for now.
*/
logger.debug("Register Transport Actions request recieved {}", transportActionsRequest);
DiscoveryExtensionNode extension = extensionIdMap.get(transportActionsRequest.getUniqueId());
try {
Expand All @@ -123,8 +117,8 @@ public TransportResponse handleTransportActionRequestFromExtension(TransportActi
String uniqueId = request.getUniqueId();
final TransportActionResponseToExtension response = new TransportActionResponseToExtension(new byte[0]);
// Validate that this action has been registered
ActionHandler<?, ?> handler = dynamicActionRegistry.get(actionName);
if (handler == null) {
ExtensionAction extensionAction = dynamicActionRegistry.get(actionName);
if (extensionAction == null) {
byte[] responseBytes = ("Request failed: action [" + actionName + "] is not registered for extension [" + uniqueId + "].")
.getBytes(StandardCharsets.UTF_8);
response.setResponseBytes(responseBytes);
Expand All @@ -138,8 +132,7 @@ public TransportResponse handleTransportActionRequestFromExtension(TransportActi
}
final CompletableFuture<ExtensionActionResponse> inProgressFuture = new CompletableFuture<>();
client.execute(
// TODO change this to the registered action type
ExtensionProxyAction.INSTANCE,
extensionAction,
new ExtensionActionRequest(request.getAction(), request.getRequestBytes()),
new ActionListener<ExtensionActionResponse>() {
@Override
Expand Down
7 changes: 1 addition & 6 deletions server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -1114,12 +1114,7 @@ protected Node(
this.pluginLifecycleComponents = Collections.unmodifiableList(pluginLifecycleComponents);
if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) {
client.initialize(injector.getInstance(new Key<Map<ActionType, TransportAction>>() {
}),
actionModule.getExtensionActions(),
() -> clusterService.localNode().getId(),
transportService.getRemoteClusterService(),
namedWriteableRegistry
);
}), actionModule, transportService, extensionsManager, () -> clusterService.localNode().getId(), namedWriteableRegistry);
} else {
client.initialize(injector.getInstance(new Key<Map<ActionType, TransportAction>>() {
}), () -> clusterService.localNode().getId(), transportService.getRemoteClusterService(), namedWriteableRegistry);
Expand Down

0 comments on commit eea1a21

Please sign in to comment.