Skip to content

Commit

Permalink
Update TCPClient to detect and handle TCP socket closures.
Browse files Browse the repository at this point in the history
  • Loading branch information
Brandon Dahler committed Apr 20, 2022
1 parent 11369ac commit f79486f
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,55 +17,53 @@
package software.amazon.cloudwatchlogs.emf.sinks;

import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import lombok.extern.slf4j.Slf4j;

/** A client that would connect to a TCP socket. */
@Slf4j
public class TCPClient implements SocketClient {

private final Endpoint endpoint;
private Socket socket;
private SocketChannel socketChannel;
private boolean shouldConnect = true;

private final ByteBuffer readBuffer = ByteBuffer.allocate(1);

public TCPClient(Endpoint endpoint) {
this.endpoint = endpoint;
}

private void connect() {
try {
socket = createSocket();
socket.connect(new InetSocketAddress(endpoint.getHost(), endpoint.getPort()));
socketChannel = SocketChannel.open();
socketChannel.connect(new InetSocketAddress(endpoint.getHost(), endpoint.getPort()));
shouldConnect = false;
} catch (Exception e) {
shouldConnect = true;
throw new RuntimeException("Failed to connect to the socket.", e);
}
}

protected Socket createSocket() {
return new Socket();
}

@Override
public synchronized void sendMessage(String message) {
if (socket == null || socket.isClosed() || shouldConnect) {
if (socketChannel == null || !socketChannel.isConnected() || shouldConnect) {
connect();
}

OutputStream os;
try {
os = socket.getOutputStream();
} catch (IOException e) {
shouldConnect = true;
throw new RuntimeException(
"Failed to write message to the socket. Failed to open output stream.", e);
}
socketChannel.configureBlocking(true);
socketChannel.write(ByteBuffer.wrap(message.getBytes()));

// Execute a non-blocking, single-byte read to detect if there was a connection closure.
// No actual data is expected to be read.
readBuffer.clear();

socketChannel.configureBlocking(false);
socketChannel.read(readBuffer);

try {
os.write(message.getBytes());
} catch (Exception e) {
shouldConnect = true;
throw new RuntimeException("Failed to write message to the socket.", e);
Expand All @@ -74,8 +72,8 @@ public synchronized void sendMessage(String message) {

@Override
public void close() throws IOException {
if (socket != null) {
socket.close();
if (socketChannel != null) {
socketChannel.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,61 @@

package software.amazon.cloudwatchlogs.emf.sinks;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertThrows;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.Socket;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import org.junit.Test;

public class TCPClientTest {

@Test
public void testSendMessage() throws IOException {
Socket socket = mock(Socket.class);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
when(socket.getOutputStream()).thenReturn(bos);
doNothing().when(socket).connect(any());
Endpoint endpoint = Endpoint.DEFAULT_TCP_ENDPOINT;
InetSocketAddress socketAddress =
new InetSocketAddress(endpoint.getHost(), endpoint.getPort());

TCPClient client =
new TCPClient(endpoint) {
@Override
protected Socket createSocket() {
return socket;
}
};
ServerSocketChannel serverListener = ServerSocketChannel.open();
serverListener.bind(socketAddress);

TCPClient client = new TCPClient(endpoint);

String message = "Test message";
client.sendMessage(message);

assertEquals(bos.toString(), message);
byte[] messageBytes = message.getBytes(StandardCharsets.UTF_8);
ByteBuffer receiveBuffer = ByteBuffer.allocate(messageBytes.length);

SocketChannel serverChannel = serverListener.accept();
serverChannel.read(receiveBuffer);

assertArrayEquals(receiveBuffer.array(), messageBytes);
}


@Test
public void testDetectSocketClosure() throws IOException {
Endpoint endpoint = Endpoint.DEFAULT_TCP_ENDPOINT;
InetSocketAddress socketAddress =
new InetSocketAddress(endpoint.getHost(), endpoint.getPort());

ServerSocketChannel serverListener = ServerSocketChannel.open();
serverListener.bind(socketAddress);

TCPClient client = new TCPClient(endpoint);

String message = "Test message";
client.sendMessage(message);

SocketChannel serverChannel = serverListener.accept();
serverChannel.close();

assertThrows(RuntimeException.class, () -> client.sendMessage(message));
}

}

0 comments on commit f79486f

Please sign in to comment.