diff --git a/zuul-core/src/main/java/com/netflix/zuul/netty/server/push/PushConnectionRegistry.java b/zuul-core/src/main/java/com/netflix/zuul/netty/server/push/PushConnectionRegistry.java index d7d71bf96b..2409c8f5b0 100644 --- a/zuul-core/src/main/java/com/netflix/zuul/netty/server/push/PushConnectionRegistry.java +++ b/zuul-core/src/main/java/com/netflix/zuul/netty/server/push/PushConnectionRegistry.java @@ -36,7 +36,7 @@ public class PushConnectionRegistry { private final SecureRandom secureTokenGenerator; @Inject - private PushConnectionRegistry() { + PushConnectionRegistry() { clientPushConnectionMap = new ConcurrentHashMap<>(1024 * 32); secureTokenGenerator = new SecureRandom(); } diff --git a/zuul-core/src/main/java/com/netflix/zuul/netty/server/push/PushRegistrationHandler.java b/zuul-core/src/main/java/com/netflix/zuul/netty/server/push/PushRegistrationHandler.java index 020dca8d83..f43abaca7d 100644 --- a/zuul-core/src/main/java/com/netflix/zuul/netty/server/push/PushRegistrationHandler.java +++ b/zuul-core/src/main/java/com/netflix/zuul/netty/server/push/PushRegistrationHandler.java @@ -15,12 +15,16 @@ */ package com.netflix.zuul.netty.server.push; +import com.google.common.annotations.VisibleForTesting; import com.netflix.config.CachedDynamicBooleanProperty; import com.netflix.config.CachedDynamicIntProperty; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; import io.netty.util.concurrent.ScheduledFuture; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,8 +48,7 @@ public class PushRegistrationHandler extends ChannelInboundHandlerAdapter { protected final AtomicBoolean destroyed; private ChannelHandlerContext ctx; private volatile PushConnection pushConnection; - private ScheduledFuture keepAliveTask; - + private final List> scheduledFutures; public static final CachedDynamicIntProperty PUSH_REGISTRY_TTL = new CachedDynamicIntProperty("zuul.push.registry.ttl.seconds", 30 * 60); public static final CachedDynamicIntProperty RECONNECT_DITHER = new CachedDynamicIntProperty("zuul.push.reconnect.dither.seconds", 3 * 60); @@ -54,13 +57,14 @@ public class PushRegistrationHandler extends ChannelInboundHandlerAdapter { public static final CachedDynamicBooleanProperty KEEP_ALIVE_ENABLED = new CachedDynamicBooleanProperty("zuul.push.keepalive.enabled", true); public static final CachedDynamicIntProperty KEEP_ALIVE_INTERVAL = new CachedDynamicIntProperty("zuul.push.keepalive.interval.seconds", 3 * 60); - private static Logger logger = LoggerFactory.getLogger(PushRegistrationHandler.class); + private static final Logger logger = LoggerFactory.getLogger(PushRegistrationHandler.class); public PushRegistrationHandler(PushConnectionRegistry pushConnectionRegistry, PushProtocol pushProtocol) { this.pushConnectionRegistry = pushConnectionRegistry; this.pushProtocol = pushProtocol; this.destroyed = new AtomicBoolean(); + this.scheduledFutures = Collections.synchronizedList(new ArrayList<>()); } protected final boolean isAuthenticated() { @@ -68,8 +72,7 @@ protected final boolean isAuthenticated() { } private void tearDown() { - if (! destroyed.get()) { - destroyed.set(true); + if (! destroyed.getAndSet(true)) { if (authEvent != null) { // We should only remove the PushConnection entry from the registry if it's still this pushConnection. String clientID = authEvent.getClientIdentity(); @@ -81,10 +84,8 @@ private void tearDown() { logger.debug("Closing connection for {}", authEvent); } } - if (keepAliveTask != null) { - keepAliveTask.cancel(false); - keepAliveTask = null; - } + scheduledFutures.forEach(f -> f.cancel(false)); + scheduledFutures.clear(); } @Override @@ -119,7 +120,8 @@ private void requestClientToCloseConnection() { // Application level protocol for asking client to close connection ctx.writeAndFlush(pushProtocol.goAwayMessage()); // Force close connection if client doesn't close in reasonable time after we made request - ctx.executor().schedule(() -> forceCloseConnectionFromServerSide(), CLIENT_CLOSE_GRACE_PERIOD.get(), TimeUnit.SECONDS); + scheduledFutures.add(ctx.executor().schedule(this::forceCloseConnectionFromServerSide, + CLIENT_CLOSE_GRACE_PERIOD.get(), TimeUnit.SECONDS)); } else { forceCloseConnectionFromServerSide(); } @@ -180,13 +182,24 @@ else if (evt instanceof PushUserAuth) { * event loop doesn't block */ protected void registerClient(ChannelHandlerContext ctx, PushUserAuth authEvent, - PushConnection conn, PushConnectionRegistry registry) { + PushConnection conn, PushConnectionRegistry registry) { registry.put(authEvent.getClientIdentity(), conn); //Make client reconnect after ttl seconds by closing this connection to limit stickiness of the client - ctx.executor().schedule(this::requestClientToCloseConnection, ditheredReconnectDeadline(), TimeUnit.SECONDS); + scheduledFutures.add(ctx.executor().schedule(this::requestClientToCloseConnection, ditheredReconnectDeadline(), + TimeUnit.SECONDS)); if (KEEP_ALIVE_ENABLED.get()) { - keepAliveTask = ctx.executor().scheduleWithFixedDelay(this::keepAlive, KEEP_ALIVE_INTERVAL.get(), KEEP_ALIVE_INTERVAL.get(), TimeUnit.SECONDS); + scheduledFutures.add(ctx.executor().scheduleWithFixedDelay(this::keepAlive, KEEP_ALIVE_INTERVAL.get(), + KEEP_ALIVE_INTERVAL.get(), TimeUnit.SECONDS)); } } + @VisibleForTesting + PushConnection getPushConnection() { + return pushConnection; + } + + @VisibleForTesting + List> getScheduledFutures() { + return scheduledFutures; + } } diff --git a/zuul-core/src/test/java/com/netflix/zuul/netty/server/push/PushRegistrationHandlerTest.java b/zuul-core/src/test/java/com/netflix/zuul/netty/server/push/PushRegistrationHandlerTest.java new file mode 100644 index 0000000000..0bec052e6f --- /dev/null +++ b/zuul-core/src/test/java/com/netflix/zuul/netty/server/push/PushRegistrationHandlerTest.java @@ -0,0 +1,222 @@ +/* + * Copyright 2022 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.netflix.zuul.netty.server.push; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import com.google.common.util.concurrent.MoreExecutors; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultEventLoop; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.util.concurrent.ScheduledFuture; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +/** + * @author Justin Guerra + * @since 8/31/22 + */ +public class PushRegistrationHandlerTest { + + private static ExecutorService EXECUTOR; + + @Captor + private ArgumentCaptor scheduledCaptor; + + @Captor + private ArgumentCaptor writeCaptor; + + @Mock + private ChannelHandlerContext context; + + @Mock + private ChannelFuture channelFuture; + + @Mock + private ChannelPipeline pipelineMock; + + @Mock + private Channel channel; + + private PushConnectionRegistry registry; + private PushRegistrationHandler handler; + private DefaultEventLoop eventLoopSpy; + private TestAuth successfulAuth; + + @BeforeClass + public static void classSetup() { + EXECUTOR = Executors.newSingleThreadExecutor(); + } + + @AfterClass + public static void classCleanup() { + MoreExecutors.shutdownAndAwaitTermination(EXECUTOR, 5, TimeUnit.SECONDS); + } + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + registry = new PushConnectionRegistry(); + handler = new PushRegistrationHandler(registry, PushProtocol.WEBSOCKET); + successfulAuth = new TestAuth(true); + + eventLoopSpy = spy(new DefaultEventLoop(EXECUTOR)); + doReturn(eventLoopSpy).when(context).executor(); + doReturn(channelFuture).when(context).writeAndFlush(writeCaptor.capture()); + doReturn(pipelineMock).when(context).pipeline(); + doReturn(channel).when(context).channel(); + } + + @Test + public void closeIfNotAuthenticated() throws Exception { + doHandshakeComplete(); + + Runnable scheduledTask = scheduledCaptor.getValue(); + scheduledTask.run(); + + validateConnectionClosed(1000, "Server closed connection"); + } + + @Test + public void authFailed() throws Exception { + doHandshakeComplete(); + handler.userEventTriggered(context, new TestAuth(false)); + validateConnectionClosed(1008, "Auth failed"); + } + + @Test + public void authSuccess() throws Exception { + doHandshakeComplete(); + authenticateChannel(); + } + + @Test + public void requestClientToCloseInactiveConnection() throws Exception { + doHandshakeComplete(); + Mockito.reset(eventLoopSpy); + authenticateChannel(); + verify(eventLoopSpy).schedule(scheduledCaptor.capture(), anyLong(), eq(TimeUnit.SECONDS)); + Runnable requestClientToClose = scheduledCaptor.getValue(); + + requestClientToClose.run(); + validateConnectionClosed(1000, "Server closed connection"); + } + + @Test + public void requestClientToClose() throws Exception { + doHandshakeComplete(); + Mockito.reset(eventLoopSpy); + authenticateChannel(); + verify(eventLoopSpy).schedule(scheduledCaptor.capture(), anyLong(), eq(TimeUnit.SECONDS)); + Runnable requestClientToClose = scheduledCaptor.getValue(); + + int taskListSize = handler.getScheduledFutures().size(); + doReturn(true).when(channel).isActive(); + requestClientToClose.run(); + assertEquals(taskListSize + 1, handler.getScheduledFutures().size()); + Object capture = writeCaptor.getValue(); + assertTrue(capture instanceof TextWebSocketFrame); + TextWebSocketFrame frame = (TextWebSocketFrame) capture; + assertEquals("_CLOSE_", frame.text()); + } + + @Test + public void channelInactiveCancelsTasks() throws Exception { + doHandshakeComplete(); + TestAuth testAuth = new TestAuth(true); + authenticateChannel(); + + List> copyOfFutures = new ArrayList<>(handler.getScheduledFutures()); + + handler.channelInactive(context); + assertNull(registry.get(testAuth.getClientIdentity())); + assertTrue(handler.getScheduledFutures().isEmpty()); + copyOfFutures.forEach(f -> assertTrue(f.isCancelled())); + verify(context).close(); + } + + private void doHandshakeComplete() throws Exception { + handler.userEventTriggered(context, PushProtocol.WEBSOCKET.getHandshakeCompleteEvent()); + assertNotNull(handler.getPushConnection()); + verify(eventLoopSpy).schedule(scheduledCaptor.capture(), anyLong(), eq(TimeUnit.SECONDS)); + } + + private void authenticateChannel() throws Exception { + handler.userEventTriggered(context, successfulAuth); + assertNotNull(registry.get(successfulAuth.getClientIdentity())); + assertEquals(2, handler.getScheduledFutures().size()); + verify(pipelineMock).remove(PushAuthHandler.NAME); + } + + private void validateConnectionClosed(int expected, String messaged) { + Object capture = writeCaptor.getValue(); + assertTrue(capture instanceof CloseWebSocketFrame); + CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) capture; + assertEquals(expected, closeFrame.statusCode()); + assertEquals(messaged, closeFrame.reasonText()); + verify(channelFuture).addListener(ChannelFutureListener.CLOSE); + } + + + private static class TestAuth implements PushUserAuth { + + private final boolean success; + + public TestAuth(boolean success) { + this.success = success; + } + + @Override + public boolean isSuccess() { + return success; + } + + @Override + public int statusCode() { + return 0; + } + + @Override + public String getClientIdentity() { + return "whatever"; + } + } + +} \ No newline at end of file