Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PushRegistrationHandler holds onto ChannelHandlerContext references #1294

Merged
merged 3 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class PushConnectionRegistry {
private final SecureRandom secureTokenGenerator;

@Inject
private PushConnectionRegistry() {
PushConnectionRegistry() {
clientPushConnectionMap = new ConcurrentHashMap<>(1024 * 32);
secureTokenGenerator = new SecureRandom();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
*/
package com.netflix.zuul.netty.server.push;

import com.google.common.annotations.VisibleForTesting;
import com.netflix.config.CachedDynamicBooleanProperty;
import com.netflix.config.CachedDynamicIntProperty;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.util.concurrent.ScheduledFuture;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -44,8 +48,7 @@ public class PushRegistrationHandler extends ChannelInboundHandlerAdapter {
protected final AtomicBoolean destroyed;
private ChannelHandlerContext ctx;
private volatile PushConnection pushConnection;
private ScheduledFuture<?> keepAliveTask;

private final List<ScheduledFuture<?>> scheduledFutures;

public static final CachedDynamicIntProperty PUSH_REGISTRY_TTL = new CachedDynamicIntProperty("zuul.push.registry.ttl.seconds", 30 * 60);
public static final CachedDynamicIntProperty RECONNECT_DITHER = new CachedDynamicIntProperty("zuul.push.reconnect.dither.seconds", 3 * 60);
Expand All @@ -54,22 +57,22 @@ public class PushRegistrationHandler extends ChannelInboundHandlerAdapter {
public static final CachedDynamicBooleanProperty KEEP_ALIVE_ENABLED = new CachedDynamicBooleanProperty("zuul.push.keepalive.enabled", true);
public static final CachedDynamicIntProperty KEEP_ALIVE_INTERVAL = new CachedDynamicIntProperty("zuul.push.keepalive.interval.seconds", 3 * 60);

private static Logger logger = LoggerFactory.getLogger(PushRegistrationHandler.class);
private static final Logger logger = LoggerFactory.getLogger(PushRegistrationHandler.class);


public PushRegistrationHandler(PushConnectionRegistry pushConnectionRegistry, PushProtocol pushProtocol) {
this.pushConnectionRegistry = pushConnectionRegistry;
this.pushProtocol = pushProtocol;
this.destroyed = new AtomicBoolean();
this.scheduledFutures = Collections.synchronizedList(new ArrayList<>());
}

protected final boolean isAuthenticated() {
return (authEvent != null && authEvent.isSuccess());
}

private void tearDown() {
if (! destroyed.get()) {
destroyed.set(true);
if (! destroyed.getAndSet(true)) {
if (authEvent != null) {
// We should only remove the PushConnection entry from the registry if it's still this pushConnection.
String clientID = authEvent.getClientIdentity();
Expand All @@ -81,10 +84,8 @@ private void tearDown() {
logger.debug("Closing connection for {}", authEvent);
}
}
if (keepAliveTask != null) {
keepAliveTask.cancel(false);
keepAliveTask = null;
}
scheduledFutures.forEach(f -> f.cancel(false));
scheduledFutures.clear();
}

@Override
Expand Down Expand Up @@ -119,7 +120,8 @@ private void requestClientToCloseConnection() {
// Application level protocol for asking client to close connection
ctx.writeAndFlush(pushProtocol.goAwayMessage());
// Force close connection if client doesn't close in reasonable time after we made request
ctx.executor().schedule(() -> forceCloseConnectionFromServerSide(), CLIENT_CLOSE_GRACE_PERIOD.get(), TimeUnit.SECONDS);
scheduledFutures.add(ctx.executor().schedule(this::forceCloseConnectionFromServerSide,
CLIENT_CLOSE_GRACE_PERIOD.get(), TimeUnit.SECONDS));
} else {
forceCloseConnectionFromServerSide();
}
Expand Down Expand Up @@ -180,13 +182,24 @@ else if (evt instanceof PushUserAuth) {
* event loop doesn't block
*/
protected void registerClient(ChannelHandlerContext ctx, PushUserAuth authEvent,
PushConnection conn, PushConnectionRegistry registry) {
PushConnection conn, PushConnectionRegistry registry) {
registry.put(authEvent.getClientIdentity(), conn);
//Make client reconnect after ttl seconds by closing this connection to limit stickiness of the client
ctx.executor().schedule(this::requestClientToCloseConnection, ditheredReconnectDeadline(), TimeUnit.SECONDS);
scheduledFutures.add(ctx.executor().schedule(this::requestClientToCloseConnection, ditheredReconnectDeadline(),
TimeUnit.SECONDS));
if (KEEP_ALIVE_ENABLED.get()) {
keepAliveTask = ctx.executor().scheduleWithFixedDelay(this::keepAlive, KEEP_ALIVE_INTERVAL.get(), KEEP_ALIVE_INTERVAL.get(), TimeUnit.SECONDS);
scheduledFutures.add(ctx.executor().scheduleWithFixedDelay(this::keepAlive, KEEP_ALIVE_INTERVAL.get(),
KEEP_ALIVE_INTERVAL.get(), TimeUnit.SECONDS));
}
}

@VisibleForTesting
PushConnection getPushConnection() {
return pushConnection;
}

@VisibleForTesting
List<ScheduledFuture<?>> getScheduledFutures() {
return scheduledFutures;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/*
* Copyright 2022 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.netflix.zuul.netty.server.push;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import com.google.common.util.concurrent.MoreExecutors;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultEventLoop;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.ScheduledFuture;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

/**
* @author Justin Guerra
* @since 8/31/22
*/
public class PushRegistrationHandlerTest {

private static ExecutorService EXECUTOR;

@Captor
private ArgumentCaptor<Runnable> scheduledCaptor;

@Captor
private ArgumentCaptor<Object> writeCaptor;

@Mock
private ChannelHandlerContext context;

@Mock
private ChannelFuture channelFuture;

@Mock
private ChannelPipeline pipelineMock;

@Mock
private Channel channel;

private PushConnectionRegistry registry;
private PushRegistrationHandler handler;
private DefaultEventLoop eventLoopSpy;
private TestAuth successfulAuth;

@BeforeClass
public static void classSetup() {
EXECUTOR = Executors.newSingleThreadExecutor();
}

@AfterClass
public static void classCleanup() {
MoreExecutors.shutdownAndAwaitTermination(EXECUTOR, 5, TimeUnit.SECONDS);
}

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
registry = new PushConnectionRegistry();
handler = new PushRegistrationHandler(registry, PushProtocol.WEBSOCKET);
successfulAuth = new TestAuth(true);

eventLoopSpy = spy(new DefaultEventLoop(EXECUTOR));
doReturn(eventLoopSpy).when(context).executor();
doReturn(channelFuture).when(context).writeAndFlush(writeCaptor.capture());
doReturn(pipelineMock).when(context).pipeline();
doReturn(channel).when(context).channel();
}

@Test
public void closeIfNotAuthenticated() throws Exception {
doHandshakeComplete();

Runnable scheduledTask = scheduledCaptor.getValue();
scheduledTask.run();

validateConnectionClosed(1000, "Server closed connection");
}

@Test
public void authFailed() throws Exception {
doHandshakeComplete();
handler.userEventTriggered(context, new TestAuth(false));
validateConnectionClosed(1008, "Auth failed");
}

@Test
public void authSuccess() throws Exception {
doHandshakeComplete();
authenticateChannel();
}

@Test
public void requestClientToCloseInactiveConnection() throws Exception {
doHandshakeComplete();
Mockito.reset(eventLoopSpy);
authenticateChannel();
verify(eventLoopSpy).schedule(scheduledCaptor.capture(), anyLong(), eq(TimeUnit.SECONDS));
Runnable requestClientToClose = scheduledCaptor.getValue();

requestClientToClose.run();
validateConnectionClosed(1000, "Server closed connection");
}

@Test
public void requestClientToClose() throws Exception {
doHandshakeComplete();
Mockito.reset(eventLoopSpy);
authenticateChannel();
verify(eventLoopSpy).schedule(scheduledCaptor.capture(), anyLong(), eq(TimeUnit.SECONDS));
Runnable requestClientToClose = scheduledCaptor.getValue();

int taskListSize = handler.getScheduledFutures().size();
doReturn(true).when(channel).isActive();
requestClientToClose.run();
assertEquals(taskListSize + 1, handler.getScheduledFutures().size());
Object capture = writeCaptor.getValue();
assertTrue(capture instanceof TextWebSocketFrame);
TextWebSocketFrame frame = (TextWebSocketFrame) capture;
assertEquals("_CLOSE_", frame.text());
}

@Test
public void channelInactiveCancelsTasks() throws Exception {
doHandshakeComplete();
TestAuth testAuth = new TestAuth(true);
authenticateChannel();

List<ScheduledFuture<?>> copyOfFutures = new ArrayList<>(handler.getScheduledFutures());

handler.channelInactive(context);
assertNull(registry.get(testAuth.getClientIdentity()));
assertTrue(handler.getScheduledFutures().isEmpty());
copyOfFutures.forEach(f -> assertTrue(f.isCancelled()));
verify(context).close();
}

private void doHandshakeComplete() throws Exception {
handler.userEventTriggered(context, PushProtocol.WEBSOCKET.getHandshakeCompleteEvent());
assertNotNull(handler.getPushConnection());
verify(eventLoopSpy).schedule(scheduledCaptor.capture(), anyLong(), eq(TimeUnit.SECONDS));
}

private void authenticateChannel() throws Exception {
handler.userEventTriggered(context, successfulAuth);
assertNotNull(registry.get(successfulAuth.getClientIdentity()));
assertEquals(2, handler.getScheduledFutures().size());
verify(pipelineMock).remove(PushAuthHandler.NAME);
}

private void validateConnectionClosed(int expected, String messaged) {
Object capture = writeCaptor.getValue();
assertTrue(capture instanceof CloseWebSocketFrame);
CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) capture;
assertEquals(expected, closeFrame.statusCode());
assertEquals(messaged, closeFrame.reasonText());
verify(channelFuture).addListener(ChannelFutureListener.CLOSE);
}


private static class TestAuth implements PushUserAuth {

private final boolean success;

public TestAuth(boolean success) {
this.success = success;
}

@Override
public boolean isSuccess() {
return success;
}

@Override
public int statusCode() {
return 0;
}

@Override
public String getClientIdentity() {
return "whatever";
}
}

}