Skip to content

Commit

Permalink
Adds a check to skip serialization-deserialization if request is for …
Browse files Browse the repository at this point in the history
…same node (opensearch-project#2765)

Signed-off-by: Darshit Chanpura <dchanp@amazon.com>
Signed-off-by: Craig Perkins <cwperx@amazon.com>
Co-authored-by: Craig Perkins <cwperx@amazon.com>
  • Loading branch information
DarshitChanpura and cwperks authored Jul 10, 2023
1 parent 4409701 commit 8d636c4
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import org.opensearch.action.support.ActionFilter;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.component.Lifecycle.State;
Expand Down Expand Up @@ -211,6 +212,7 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin
private volatile ConfigurationRepository cr;
private volatile AdminDNs adminDns;
private volatile ClusterService cs;
private static volatile DiscoveryNode localNode;
private volatile AuditLog auditLog;
private volatile BackendRegistry backendRegistry;
private volatile SslExceptionHandler sslExceptionHandler;
Expand Down Expand Up @@ -1799,11 +1801,12 @@ public List<String> getSettingsFilter() {
}

@Override
public void onNodeStarted() {
public void onNodeStarted(DiscoveryNode localNode) {
log.info("Node started");
if (!SSLConfig.isSslOnlyMode() && !client && !disabled) {
cr.initOnNodeStart();
}
this.localNode = localNode;
final Set<ModuleInfo> securityModules = ReflectionHelper.getModulesLoaded();
log.info("{} OpenSearch Security modules loaded so far: {}", securityModules.size(), securityModules);
}
Expand Down Expand Up @@ -1883,6 +1886,14 @@ private static String handleKeyword(final String field) {
return field;
}

public static DiscoveryNode getLocalNode() {
return localNode;
}

public static void setLocalNode(DiscoveryNode node) {
localNode = node;
}

public static class GuiceHolder implements LifecycleComponent {

private static RepositoriesService repositoriesService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.search.SearchAction;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
Expand Down Expand Up @@ -131,7 +132,6 @@ public <T extends TransportResponse> void sendRequestDecorate(
TransportRequestOptions options,
TransportResponseHandler<T> handler
) {

final Map<String, String> origHeaders0 = getThreadContext().getHeaders();
final User user0 = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
final String injectedUserString = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER);
Expand All @@ -146,6 +146,9 @@ public <T extends TransportResponse> void sendRequestDecorate(
final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS);

final boolean isDebugEnabled = log.isDebugEnabled();
final DiscoveryNode localNode = OpenSearchSecurityPlugin.getLocalNode();
boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode());

try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) {
final TransportResponseHandler<T> restoringHandler = new RestoringTransportResponseHandler<T>(handler, stashedContext);
getThreadContext().putHeader("_opendistro_security_remotecn", cs.getClusterName().value());
Expand Down Expand Up @@ -223,7 +226,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL

getThreadContext().putHeader(headerMap);

ensureCorrectHeaders(remoteAddress0, user0, origin0, injectedUserString, injectedRolesString);
ensureCorrectHeaders(remoteAddress0, user0, origin0, injectedUserString, injectedRolesString, isSameNodeRequest);

if (isActionTraceEnabled()) {
getThreadContext().putHeader(
Expand All @@ -249,7 +252,8 @@ private void ensureCorrectHeaders(
final User origUser,
final String origin,
final String injectedUserString,
final String injectedRolesString
final String injectedRolesString,
boolean isSameNodeRequest
) {
// keep original address

Expand All @@ -263,30 +267,49 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE
getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADER, Origin.LOCAL.toString());
}

TransportAddress transportAddress = null;
if (remoteAdr != null && remoteAdr instanceof TransportAddress) {

String remoteAddressHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER);

if (remoteAddressHeader == null) {
getThreadContext().putHeader(
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER,
Base64Helper.serializeObject(((TransportAddress) remoteAdr).address())
);
transportAddress = (TransportAddress) remoteAdr;
}
}

String userHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
// we put headers as transient for same node requests
if (isSameNodeRequest) {
if (transportAddress != null) {
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, transportAddress);
}

if (userHeader == null) {
if (origUser != null) {
getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(origUser));
// if request is going to be handled by same node, we directly put transient value as the thread context is not going to be
// stah.
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, origUser);
} else if (StringUtils.isNotEmpty(injectedRolesString)) {
getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER, injectedRolesString);
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES, injectedRolesString);
} else if (StringUtils.isNotEmpty(injectedUserString)) {
getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER_HEADER, injectedUserString);
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, injectedUserString);
}
} else {
if (transportAddress != null) {
getThreadContext().putHeader(
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER,
Base64Helper.serializeObject(transportAddress.address())
);
}
}

final String userHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
if (userHeader == null) {
// put as headers for other requests
if (origUser != null) {
getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(origUser));
} else if (StringUtils.isNotEmpty(injectedRolesString)) {
getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER, injectedRolesString);
} else if (StringUtils.isNotEmpty(injectedUserString)) {
getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER_HEADER, injectedUserString);
}
}
}
}

private ThreadContext getThreadContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ protected void messageReceivedDecorate(
final TransportChannel transportChannel,
Task task
) throws Exception {

String resolvedActionClass = request.getClass().getSimpleName();

if (request instanceof BulkShardRequest) {
Expand Down Expand Up @@ -142,7 +141,31 @@ protected void messageReceivedDecorate(
}

// bypass non-netty requests
if (channelType.equals("direct")) {
if (getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER) != null
|| getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER) != null
|| getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES) != null
|| getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS) != null) {

final String rolesValidation = getThreadContext().getHeader(
ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION_HEADER
);
if (!Strings.isNullOrEmpty(rolesValidation)) {
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION, rolesValidation);
}

if (isActionTraceEnabled()) {
getThreadContext().putHeader(
"_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID().toString(),
Thread.currentThread().getName()
+ " DIR -> "
+ transportChannel.getChannelType()
+ " "
+ getThreadContext().getHeaders()
);
}

putInitialActionClassHeader(initialActionClassValue, resolvedActionClass);
} else {
final String userHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
final String injectedRolesHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER);
final String injectedUserHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER_HEADER);
Expand All @@ -162,15 +185,15 @@ protected void messageReceivedDecorate(
);
}

final String originalRemoteAddress = getThreadContext().getHeader(
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER
);
String originalRemoteAddress = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER);

if (!Strings.isNullOrEmpty(originalRemoteAddress)) {
getThreadContext().putTransient(
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS,
new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress))
);
} else {
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, request.remoteAddress());
}

final String rolesValidation = getThreadContext().getHeader(
Expand All @@ -179,20 +202,9 @@ protected void messageReceivedDecorate(
if (!Strings.isNullOrEmpty(rolesValidation)) {
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION, rolesValidation);
}
}

if (isActionTraceEnabled()) {
getThreadContext().putHeader(
"_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID().toString(),
Thread.currentThread().getName()
+ " DIR -> "
+ transportChannel.getChannelType()
+ " "
+ getThreadContext().getHeaders()
);
}

putInitialActionClassHeader(initialActionClassValue, resolvedActionClass);

if (channelType.equals("direct")) {
super.messageReceivedDecorate(request, handler, transportChannel, task);
return;
}
Expand Down Expand Up @@ -272,58 +284,10 @@ protected void messageReceivedDecorate(

// network intercluster request or cross search cluster request
// CS-SUPPRESS-SINGLE: RegexpSingleline Used to allow/disallow TLS connections to extensions
if (HeaderHelper.isInterClusterRequest(getThreadContext())
if (!(HeaderHelper.isInterClusterRequest(getThreadContext())
|| HeaderHelper.isTrustedClusterRequest(getThreadContext())
|| HeaderHelper.isExtensionRequest(getThreadContext())) {
|| HeaderHelper.isExtensionRequest(getThreadContext()))) {
// CS-ENFORCE-SINGLE

final String userHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
final String injectedRolesHeader = getThreadContext().getHeader(
ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER
);
final String injectedUserHeader = getThreadContext().getHeader(
ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER_HEADER
);

if (Strings.isNullOrEmpty(userHeader)) {
// Keeping role injection with higher priority as plugins under OpenSearch will be using this
// on transport layer
if (!Strings.isNullOrEmpty(injectedRolesHeader)) {
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES, injectedRolesHeader);
} else if (!Strings.isNullOrEmpty(injectedUserHeader)) {
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, injectedUserHeader);
}
} else {
getThreadContext().putTransient(
ConfigConstants.OPENDISTRO_SECURITY_USER,
Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader))
);
}

String originalRemoteAddress = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER);

if (!Strings.isNullOrEmpty(originalRemoteAddress)) {
getThreadContext().putTransient(
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS,
new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress))
);
} else {
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, request.remoteAddress());
}

final String rolesValidation = getThreadContext().getHeader(
ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION_HEADER
);
if (!Strings.isNullOrEmpty(rolesValidation)) {
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION, rolesValidation);
}

} else {
// this is a netty request from a non-server node (maybe also be internal: or a shard request)
// and therefore issued by a transport client

// since OS 2.0 we do not support this any longer because transport client no longer available

final OpenSearchException exception = ExceptionUtils.createTransportClientNoLongerSupportedException();
log.error(exception.toString());
transportChannel.sendResponse(exception);
Expand All @@ -346,9 +310,8 @@ protected void messageReceivedDecorate(
}

putInitialActionClassHeader(initialActionClassValue, resolvedActionClass);

super.messageReceivedDecorate(request, handler, transportChannel, task);
}
super.messageReceivedDecorate(request, handler, transportChannel, task);
} finally {

if (isActionTraceEnabled()) {
Expand Down
Loading

0 comments on commit 8d636c4

Please sign in to comment.