Skip to content

Commit

Permalink
[Issue 10816][Proxy] Refresh client auth token
Browse files Browse the repository at this point in the history
  • Loading branch information
nosov.kirill committed Apr 4, 2022
1 parent bdc3024 commit 486f891
Show file tree
Hide file tree
Showing 4 changed files with 371 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
public class ClientCnx extends PulsarHandler {

protected final Authentication authentication;
private State state;
protected State state;

private final ConcurrentLongHashMap<TimedCompletableFuture<? extends Object>> pendingRequests =
ConcurrentLongHashMap.<TimedCompletableFuture<? extends Object>>newBuilder()
Expand All @@ -129,11 +129,11 @@ public class ClientCnx extends PulsarHandler {
.concurrencyLevel(1)
.build();

private final CompletableFuture<Void> connectionFuture = new CompletableFuture<Void>();
protected final CompletableFuture<Void> connectionFuture = new CompletableFuture<Void>();
private final ConcurrentLinkedQueue<RequestTime> requestTimeoutQueue = new ConcurrentLinkedQueue<>();
private final Semaphore pendingLookupRequestSemaphore;
private final Semaphore maxLookupRequestSemaphore;
private final EventLoopGroup eventLoopGroup;
protected final EventLoopGroup eventLoopGroup;

private static final AtomicIntegerFieldUpdater<ClientCnx> NUMBER_OF_REJECTED_REQUESTS_UPDATER =
AtomicIntegerFieldUpdater.newUpdater(ClientCnx.class, "numberOfRejectRequests");
Expand All @@ -146,7 +146,7 @@ public class ClientCnx extends PulsarHandler {
private final int maxNumberOfRejectedRequestPerConnection;
private final int rejectedRequestResetTimeSec = 60;
private final int protocolVersion;
private final long operationTimeoutMs;
protected final long operationTimeoutMs;

protected String proxyToTargetBrokerAddress = null;
// Remote hostName with which client is connected
Expand All @@ -164,7 +164,7 @@ public class ClientCnx extends PulsarHandler {
protected AuthenticationDataProvider authenticationDataProvider;
private TransactionBufferHandler transactionBufferHandler;

enum State {
protected enum State {
None, SentConnectFrame, Ready, Failed, Connecting
}

Expand Down Expand Up @@ -242,28 +242,30 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
log.info("{} Connected through proxy to target broker at {}", ctx.channel(), proxyToTargetBrokerAddress);
}
// Send CONNECT command
ctx.writeAndFlush(newConnectCommand())
.addListener(future -> {
if (future.isSuccess()) {
if (log.isDebugEnabled()) {
log.debug("Complete: {}", future.isSuccess());
}
state = State.SentConnectFrame;
} else {
log.warn("Error during handshake", future.cause());
ctx.close();
}
});
sendConnectCommand();
}

protected ByteBuf newConnectCommand() throws Exception {
protected void sendConnectCommand() throws Exception {
// mutual authentication is to auth between `remoteHostName` and this client for this channel.
// each channel will have a mutual client/server pair, mutual client evaluateChallenge with init data,
// and return authData to server.
authenticationDataProvider = authentication.getAuthData(remoteHostName);
AuthData authData = authenticationDataProvider.authenticate(AuthData.INIT_AUTH_DATA);
return Commands.newConnect(authentication.getAuthMethodName(), authData, this.protocolVersion,
PulsarVersion.getVersion(), proxyToTargetBrokerAddress, null, null, null);
ByteBuf command = Commands.newConnect(authentication.getAuthMethodName(), authData, this.protocolVersion,
PulsarVersion.getVersion(), proxyToTargetBrokerAddress, null, null, null);

ctx.writeAndFlush(command)
.addListener(future -> {
if (future.isSuccess()) {
if (log.isDebugEnabled()) {
log.debug("Complete: {}", future.isSuccess());
}
state = State.SentConnectFrame;
} else {
log.warn("Error during handshake", future.cause());
ctx.close();
}
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,43 +20,129 @@

import io.netty.buffer.ByteBuf;
import io.netty.channel.EventLoopGroup;
import java.util.Arrays;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.apache.pulsar.PulsarVersion;
import org.apache.pulsar.client.api.PulsarClientException.TimeoutException;
import org.apache.pulsar.client.impl.ClientCnx;
import org.apache.pulsar.client.impl.conf.ClientConfigurationData;
import org.apache.pulsar.common.api.AuthData;
import org.apache.pulsar.common.api.proto.CommandAuthChallenge;
import org.apache.pulsar.common.protocol.Commands;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ProxyClientCnx extends ClientCnx {

String clientAuthRole;
AuthData clientAuthData;
String clientAuthMethod;
int protocolVersion;
private final boolean forwardAuthorizationCredentials;
private final Supplier<CompletableFuture<AuthData>> clientAuthDataSupplier;

public ProxyClientCnx(ClientConfigurationData conf, EventLoopGroup eventLoopGroup, String clientAuthRole,
AuthData clientAuthData, String clientAuthMethod, int protocolVersion) {
Supplier<CompletableFuture<AuthData>> clientAuthDataSupplier,
String clientAuthMethod, int protocolVersion, boolean forwardAuthorizationCredentials) {
super(conf, eventLoopGroup);
this.clientAuthRole = clientAuthRole;
this.clientAuthData = clientAuthData;
this.clientAuthMethod = clientAuthMethod;
this.protocolVersion = protocolVersion;
this.forwardAuthorizationCredentials = forwardAuthorizationCredentials;
this.clientAuthDataSupplier = clientAuthDataSupplier;
}

@Override
protected ByteBuf newConnectCommand() throws Exception {
if (log.isDebugEnabled()) {
log.debug("New Connection opened via ProxyClientCnx with params clientAuthRole = {},"
+ " clientAuthData = {}, clientAuthMethod = {}",
clientAuthRole, clientAuthData, clientAuthMethod);
protected void sendConnectCommand() throws Exception {
CompletableFuture<ByteBuf> connectCommandFuture = newConnectCommand();
if (!connectCommandFuture.isDone()) {
eventLoopGroup.schedule(() -> {
connectCommandFuture.completeExceptionally(
new TimeoutException("New connect command timeout after ms " + operationTimeoutMs)
);
}, operationTimeoutMs, TimeUnit.MILLISECONDS);
}

connectCommandFuture.whenComplete((data, th) -> {
if (th == null) {
// Send CONNECT command
ctx.writeAndFlush(data).addListener(future -> {
if (future.isSuccess()) {
if (log.isDebugEnabled()) {
log.debug("Complete: {}", future.isSuccess());
}
state = State.SentConnectFrame;
} else {
log.warn("Error during handshake", future.cause());
ctx.close();
}
});
} else {
log.warn("Error during handshake", th);
ctx.close();
}
});
}

private CompletableFuture<ByteBuf> newConnectCommand() throws Exception {
authenticationDataProvider = authentication.getAuthData(remoteHostName);
AuthData authData = authenticationDataProvider.authenticate(AuthData.INIT_AUTH_DATA);
return Commands.newConnect(authentication.getAuthMethodName(), authData, this.protocolVersion,
PulsarVersion.getVersion(), proxyToTargetBrokerAddress, clientAuthRole, clientAuthData,
clientAuthMethod);

return clientAuthDataSupplier.get().thenApply(clientAuthData -> {
if (log.isDebugEnabled()) {
log.debug("New Connection opened via ProxyClientCnx with params clientAuthRole = {},"
+ " clientAuthData = {}, clientAuthMethod = {}",
clientAuthRole, clientAuthData, clientAuthMethod);
}

return Commands.newConnect(authentication.getAuthMethodName(), authData, this.protocolVersion,
PulsarVersion.getVersion(), proxyToTargetBrokerAddress, clientAuthRole, clientAuthData,
clientAuthMethod);
});
}

@Override
protected void handleAuthChallenge(CommandAuthChallenge authChallenge) {
boolean isRefresh = Arrays.equals(
AuthData.REFRESH_AUTH_DATA_BYTES,
authChallenge.getChallenge().getAuthData()
);

if (!forwardAuthorizationCredentials || !isRefresh) {
super.handleAuthChallenge(authChallenge);
return;
}

clientAuthDataSupplier.get()
.thenAccept(authData -> sendAuthResponse(authData, clientAuthMethod))
.exceptionally(ex -> {
log.error("{} Error refresh auth data: {}", ctx.channel(), ex);
connectionFuture.completeExceptionally(ex);
close();
return null;
});
}

private void sendAuthResponse(AuthData authData, String authMethod) {
ByteBuf response = Commands.newAuthResponse(
authMethod,
authData,
protocolVersion,
PulsarVersion.getVersion()
);

if (log.isDebugEnabled()) {
log.debug("{} Mutual auth {}", ctx.channel(), authentication.getAuthMethodName());
}

ctx.writeAndFlush(response).addListener(writeFuture -> {
if (!writeFuture.isSuccess()) {
log.warn("{} Failed to send response for mutual auth to broker: {}", ctx.channel(),
writeFuture.cause().getMessage());
connectionFuture.completeExceptionally(writeFuture.cause());
}
});
}

private static final Logger log = LoggerFactory.getLogger(ProxyClientCnx.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@
import io.netty.handler.ssl.SslHandler;
import java.net.SocketAddress;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import javax.naming.AuthenticationException;
import javax.net.ssl.SSLSession;
Expand Down Expand Up @@ -91,6 +94,8 @@ public class ProxyConnection extends PulsarHandler {
private int protocolVersionToAdvertise;
private String proxyToBrokerUrl;
private HAProxyMessage haProxyMessage;
private final AtomicReference<List<CompletableFuture<AuthData>>> authFutureList =
new AtomicReference<>(Collections.emptyList());

private static final byte[] EMPTY_CREDENTIALS = new byte[0];

Expand Down Expand Up @@ -236,8 +241,9 @@ private synchronized void completeConnect(AuthData clientData) throws PulsarClie
}
if (this.connectionPool == null) {
this.connectionPool = new ProxyConnectionPool(clientConf, service.getWorkerGroup(),
() -> new ProxyClientCnx(clientConf, service.getWorkerGroup(), clientAuthRole, clientAuthData,
clientAuthMethod, protocolVersionToAdvertise));
() -> new ProxyClientCnx(clientConf, service.getWorkerGroup(), clientAuthRole,
this::getOrRefreshClientAuthData, clientAuthMethod, protocolVersionToAdvertise,
service.getConfiguration().isForwardAuthorizationCredentials()));
} else {
LOG.error("BUG! Connection Pool has already been created for proxy connection to {} state {} role {}",
remoteAddress, state, clientAuthRole);
Expand Down Expand Up @@ -315,11 +321,16 @@ private void doAuthentication(AuthData clientData) throws Exception {
// authentication has completed, will send newConnected command.
if (authState.isComplete()) {
clientAuthRole = authState.getAuthRole();
if (LOG.isDebugEnabled()) {
LOG.debug("[{}] Client successfully authenticated with {} role {}",
remoteAddress, authMethod, clientAuthRole);
if (state == State.Init || state == State.Connecting) {
if (LOG.isDebugEnabled()) {
LOG.debug("[{}] Client successfully authenticated with {} role {}",
remoteAddress, authMethod, clientAuthRole);
}
completeConnect(clientData);
} else {
updateClientAuthData(clientData);
LOG.debug("[{}] Refreshed authentication credentials for role {}", remoteAddress, clientAuthRole);
}
completeConnect(clientData);
return;
}

Expand Down Expand Up @@ -410,7 +421,7 @@ remoteAddress, protocolVersionToAdvertise, getRemoteEndpointProtocolVersion(),

@Override
protected void handleAuthResponse(CommandAuthResponse authResponse) {
checkArgument(state == State.Connecting);
checkArgument(state == State.Connecting || state == State.ProxyLookupRequests);
checkArgument(authResponse.hasResponse());
checkArgument(authResponse.getResponse().hasAuthData() && authResponse.getResponse().hasAuthMethodName());

Expand Down Expand Up @@ -479,6 +490,57 @@ ClientConfigurationData createClientConfiguration() {
return clientConf;
}

private CompletableFuture<AuthData> getOrRefreshClientAuthData() {
boolean forwardAuth = service.getConfiguration().isForwardAuthorizationCredentials();

if (!forwardAuth || authState == null || !authState.isExpired()) {
return CompletableFuture.completedFuture(clientAuthData);
}

CompletableFuture<AuthData> result = new CompletableFuture<>();
List<CompletableFuture<AuthData>> prevFutureList = authFutureList.getAndUpdate(lst -> {
List<CompletableFuture<AuthData>> newFutureList = new LinkedList<>(lst);
newFutureList.add(result);
return newFutureList;
});

// only first sends request
if (!prevFutureList.isEmpty()) {
return result;
}

try {
AuthData authData = authState.refreshAuthentication();

ctx.writeAndFlush(Commands.newAuthChallenge(authMethod, authData, protocolVersionToAdvertise))
.addListener(writeFuture -> {
if (writeFuture.isSuccess()) {
LOG.debug("[{}] Sent auth challenge to client to refresh credentials with method: {}.",
remoteAddress, authMethod);
} else {
LOG.warn("{} Failed to send request for mutual auth to client: {}", ctx.channel(),
writeFuture.cause().getMessage());

authFutureList.getAndSet(Collections.emptyList()).forEach(future -> {
future.completeExceptionally(writeFuture.cause());
});
}
});
} catch (Exception e) {
LOG.warn("{} Failed to send request for mutual auth to client: {}", ctx.channel(), e);
authFutureList.getAndSet(Collections.emptyList()).forEach(future -> {
future.completeExceptionally(e);
});
}
return result;
}

private void updateClientAuthData(AuthData clientData) {
this.clientAuthData = clientData;
authFutureList.getAndSet(Collections.emptyList())
.forEach(future -> future.complete(clientData));
}

private static int getProtocolVersionToAdvertise(CommandConnect connect) {
return Math.min(connect.getProtocolVersion(), Commands.getCurrentProtocolVersion());
}
Expand Down
Loading

0 comments on commit 486f891

Please sign in to comment.