diff --git a/CHANGELOG.md b/CHANGELOG.md index b642f4ac..97656b86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # CHANGELOG +## 4.3.0 / 2024.XX.XX + +* [FEATURE] Add support for `SOCK_STREAM` Unix sockets. See [#228][] + ## 4.2.1 / 2023.03.10 * [FEATURE] Add support for `DD_DOGSTATSD_URL`. See [#217][] @@ -232,6 +236,7 @@ Fork from [indeedeng/java-dogstatsd-client] (https://github.com/indeedeng/java-d [#203]: https://github.com/DataDog/java-dogstatsd-client/issues/203 [#211]: https://github.com/DataDog/java-dogstatsd-client/issues/211 [#217]: https://github.com/DataDog/java-dogstatsd-client/issues/217 +[#228]: https://github.com/DataDog/java-dogstatsd-client/pull/228 [@PatrickAuld]: https://github.com/PatrickAuld [@blevz]: https://github.com/blevz diff --git a/README.md b/README.md index 95931d52..08964b51 100644 --- a/README.md +++ b/README.md @@ -23,13 +23,23 @@ The client jar is distributed via Maven central, and can be downloaded [from Mav ### Unix Domain Socket support -As an alternative to UDP, Agent v6 can receive metrics via a UNIX Socket (on Linux only). This library supports transmission via this protocol. To use it, pass the socket path as a hostname, and `0` as port. +As an alternative to UDP, Agent v6 can receive metrics via a UNIX Socket (on Linux only). This library supports transmission via this protocol. To use it +use the `address()` method of the builder and pass the path to the socket with the `unix://` prefix: + +```java +StatsDClient client = new NonBlockingStatsDClientBuilder() + .address("unix:///var/run/datadog/dsd.socket") + .build(); +``` By default, all exceptions are ignored, mimicking UDP behaviour. When using Unix Sockets, transmission errors trigger exceptions you can choose to handle by passing a `StatsDClientErrorHandler`: - Connection error because of an invalid/missing socket triggers a `java.io.IOException: No such file or directory`. - If DogStatsD's reception buffer were to fill up and the non blocking client is used, the send times out after 100ms and throw either a `java.io.IOException: No buffer space available` or a `java.io.IOException: Resource temporarily unavailable`. +The default UDS transport is using `SOCK_DATAGRAM` sockets. We also have experimental support for `SOCK_STREAM` sockets which can +be enabled by using the `unixstream://` instead of `unix://`. This is not recommended for production use at this time. + ## Configuration Once your DogStatsD client is installed, instantiate it in your code: diff --git a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java index 812f0063..b4953cfc 100644 --- a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java +++ b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java @@ -18,7 +18,6 @@ import java.util.concurrent.Callable; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; /** @@ -99,6 +98,7 @@ String tag() { public static final boolean DEFAULT_ENABLE_AGGREGATION = true; public static final boolean DEFAULT_ENABLE_ORIGIN_DETECTION = true; + public static final int SOCKET_CONNECT_TIMEOUT_MS = 1000; public static final String CLIENT_TAG = "client:java"; public static final String CLIENT_VERSION_TAG = "client_version:"; @@ -241,6 +241,9 @@ protected static String format(ThreadLocal formatter, Number value * The client tries to read the container ID by parsing the file /proc/self/cgroup. * This is not supported on Windows. * The client prioritizes the value passed via or entityID or DD_ENTITY_ID (if set) over the container ID. + * @param connectionTimeout + * the timeout in milliseconds for connecting to the StatsD server. Applies to unix sockets only. + * It is also used to detect if a connection is still alive and re-establish a new one if needed. * @throws StatsDClientException * if the client could not be started */ @@ -250,7 +253,7 @@ private NonBlockingStatsDClient(final String prefix, final int queueSize, final final int maxPacketSizeBytes, String entityID, final int poolSize, final int processorWorkers, final int senderWorkers, boolean blocking, final boolean enableTelemetry, final int telemetryFlushInterval, final int aggregationFlushInterval, final int aggregationShards, final ThreadFactory customThreadFactory, - String containerID, final boolean originDetectionEnabled) + String containerID, final boolean originDetectionEnabled, final int connectionTimeout) throws StatsDClientException { if ((prefix != null) && (!prefix.isEmpty())) { @@ -297,7 +300,7 @@ private NonBlockingStatsDClient(final String prefix, final int queueSize, final } try { - clientChannel = createByteChannel(addressLookup, timeout, bufferSize); + clientChannel = createByteChannel(addressLookup, timeout, connectionTimeout, bufferSize); ThreadFactory threadFactory = customThreadFactory != null ? customThreadFactory : new StatsDThreadFactory(); @@ -316,7 +319,7 @@ private NonBlockingStatsDClient(final String prefix, final int queueSize, final telemetryClientChannel = clientChannel; telemetryStatsDProcessor = statsDProcessor; } else { - telemetryClientChannel = createByteChannel(telemetryAddressLookup, timeout, bufferSize); + telemetryClientChannel = createByteChannel(telemetryAddressLookup, timeout, connectionTimeout, bufferSize); // similar settings, but a single worker and non-blocking. telemetryStatsDProcessor = createProcessor(queueSize, handler, maxPacketSizeBytes, @@ -377,7 +380,7 @@ public NonBlockingStatsDClient(final NonBlockingStatsDClientBuilder builder) thr builder.blocking, builder.enableTelemetry, builder.telemetryFlushInterval, (builder.enableAggregation ? builder.aggregationFlushInterval : 0), builder.aggregationShards, builder.threadFactory, builder.containerID, - builder.originDetectionEnabled); + builder.originDetectionEnabled, builder.connectionTimeout); } protected StatsDProcessor createProcessor(final int queueSize, final StatsDClientErrorHandler handler, @@ -478,11 +481,29 @@ StringBuilder tagString(final String[] tags, StringBuilder builder) { return tagString(tags, constantTagsRendered, builder); } - ClientChannel createByteChannel(Callable addressLookup, int timeout, int bufferSize) throws Exception { + ClientChannel createByteChannel( + Callable addressLookup, int timeout, int connectionTimeout, int bufferSize) + throws Exception { final SocketAddress address = addressLookup.call(); if (address instanceof NamedPipeSocketAddress) { return new NamedPipeClientChannel((NamedPipeSocketAddress) address); } + if (address instanceof UnixSocketAddressWithTransport) { + UnixSocketAddressWithTransport unixAddr = ((UnixSocketAddressWithTransport) address); + + // TODO: Maybe introduce a `UnixClientChannel` that can handle both stream and datagram sockets? This would + // Allow us to support `unix://` for both kind of sockets like in go. + switch (unixAddr.getTransportType()) { + case UDS_STREAM: + return new UnixStreamClientChannel(unixAddr.getAddress(), timeout, connectionTimeout, bufferSize); + case UDS_DATAGRAM: + case UDS: + return new UnixDatagramClientChannel(unixAddr.getAddress(), timeout, bufferSize); + default: + throw new IllegalArgumentException("Unsupported transport type: " + unixAddr.getTransportType()); + } + } + // We keep this for backward compatibility try { if (Class.forName("jnr.unixsocket.UnixSocketAddress").isInstance(address)) { return new UnixDatagramClientChannel(address, timeout, bufferSize); diff --git a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java index a3cf9aab..a8ad3fc1 100644 --- a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java +++ b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java @@ -1,5 +1,6 @@ package com.timgroup.statsd; +import jnr.constants.platform.Sock; import jnr.unixsocket.UnixSocketAddress; import java.net.InetAddress; @@ -34,6 +35,7 @@ public class NonBlockingStatsDClientBuilder implements Cloneable { public int aggregationFlushInterval = StatsDAggregator.DEFAULT_FLUSH_INTERVAL; public int aggregationShards = StatsDAggregator.DEFAULT_SHARDS; public boolean originDetectionEnabled = NonBlockingStatsDClient.DEFAULT_ENABLE_ORIGIN_DETECTION; + public int connectionTimeout = NonBlockingStatsDClient.SOCKET_CONNECT_TIMEOUT_MS; public Callable addressLookup; public Callable telemetryAddressLookup; @@ -71,6 +73,11 @@ public NonBlockingStatsDClientBuilder timeout(int val) { return this; } + public NonBlockingStatsDClientBuilder connectionTimeout(int val) { + connectionTimeout = val; + return this; + } + public NonBlockingStatsDClientBuilder bufferPoolSize(int val) { bufferPoolSize = val; return this; @@ -126,6 +133,16 @@ public NonBlockingStatsDClientBuilder namedPipe(String val) { return this; } + public NonBlockingStatsDClientBuilder address(String address) { + addressLookup = getAddressLookupFromUrl(address); + return this; + } + + public NonBlockingStatsDClientBuilder telemetryAddress(String address) { + telemetryAddressLookup = getAddressLookupFromUrl(address); + return this; + } + public NonBlockingStatsDClientBuilder prefix(String val) { prefix = val; return this; @@ -283,9 +300,12 @@ private Callable getAddressLookupFromUrl(String url) { return staticAddress(uriHost, uriPort); } - if (parsed.getScheme().equals("unix")) { + if (parsed.getScheme().startsWith("unix")) { String uriPath = parsed.getPath(); - return staticAddress(uriPath, 0); + return staticUnixResolution( + uriPath, + UnixSocketAddressWithTransport.TransportType.fromScheme(parsed.getScheme()) + ); } return null; @@ -304,7 +324,10 @@ public static Callable volatileAddressResolution(final String hos if (port == 0) { return new Callable() { @Override public SocketAddress call() throws UnknownHostException { - return new UnixSocketAddress(hostname); + return new UnixSocketAddressWithTransport( + new UnixSocketAddress(hostname), + UnixSocketAddressWithTransport.TransportType.UDS + ); } }; } else { @@ -343,6 +366,17 @@ protected static Callable staticNamedPipeResolution(String namedP }; } + protected static Callable staticUnixResolution( + final String path, + final UnixSocketAddressWithTransport.TransportType transportType) { + return new Callable() { + @Override public SocketAddress call() { + final UnixSocketAddress socketAddress = new UnixSocketAddress(path); + return new UnixSocketAddressWithTransport(socketAddress, transportType); + } + }; + } + private static Callable staticAddress(final String hostname, final int port) { try { return staticAddressResolution(hostname, port); diff --git a/src/main/java/com/timgroup/statsd/UnixSocketAddressWithTransport.java b/src/main/java/com/timgroup/statsd/UnixSocketAddressWithTransport.java new file mode 100644 index 00000000..0fbca0d9 --- /dev/null +++ b/src/main/java/com/timgroup/statsd/UnixSocketAddressWithTransport.java @@ -0,0 +1,70 @@ +package com.timgroup.statsd; + +import java.net.SocketAddress; +import java.util.Objects; + +public class UnixSocketAddressWithTransport extends SocketAddress { + + private final SocketAddress address; + private final TransportType transportType; + + public enum TransportType { + UDS_STREAM("uds-stream"), + UDS_DATAGRAM("uds-datagram"), + UDS("uds"); + + private final String transportType; + + TransportType(String transportType) { + this.transportType = transportType; + } + + String getTransportType() { + return transportType; + } + + static TransportType fromScheme(String scheme) { + switch (scheme) { + case "unixstream": + return UDS_STREAM; + case "unixgram": + return UDS_DATAGRAM; + case "unix": + return UDS; + default: + break; + } + throw new IllegalArgumentException("Unknown scheme: " + scheme); + } + } + + public UnixSocketAddressWithTransport(final SocketAddress address, final TransportType transportType) { + this.address = address; + this.transportType = transportType; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + UnixSocketAddressWithTransport that = (UnixSocketAddressWithTransport) other; + return Objects.equals(address, that.address) && transportType == that.transportType; + } + + @Override + public int hashCode() { + return Objects.hash(address, transportType); + } + + SocketAddress getAddress() { + return address; + } + + TransportType getTransportType() { + return transportType; + } +} diff --git a/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java b/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java new file mode 100644 index 00000000..8b4ab17a --- /dev/null +++ b/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java @@ -0,0 +1,171 @@ +package com.timgroup.statsd; + +import jnr.unixsocket.UnixSocketAddress; +import jnr.unixsocket.UnixSocketChannel; +import jnr.unixsocket.UnixSocketOptions; + +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.SocketChannel; + +/** + * A ClientChannel for Unix domain sockets. + */ +public class UnixStreamClientChannel implements ClientChannel { + private final UnixSocketAddress address; + private final int timeout; + private final int connectionTimeout; + private final int bufferSize; + + + private SocketChannel delegate; + private final ByteBuffer delimiterBuffer = ByteBuffer.allocateDirect(Integer.SIZE / Byte.SIZE).order(ByteOrder.LITTLE_ENDIAN); + + /** + * Creates a new NamedPipeClientChannel with the given address. + * + * @param address Location of named pipe + */ + UnixStreamClientChannel(SocketAddress address, int timeout, int connectionTimeout, int bufferSize) throws IOException { + this.delegate = null; + this.address = (UnixSocketAddress) address; + this.timeout = timeout; + this.connectionTimeout = connectionTimeout; + this.bufferSize = bufferSize; + } + + @Override + public boolean isOpen() { + return delegate.isConnected(); + } + + @Override + public synchronized int write(ByteBuffer src) throws IOException { + connectIfNeeded(); + + int size = src.remaining(); + int written = 0; + if (size == 0) { + return 0; + } + delimiterBuffer.clear(); + delimiterBuffer.putInt(size); + delimiterBuffer.flip(); + + try { + long deadline = System.nanoTime() + timeout * 1_000_000L; + written = writeAll(delimiterBuffer, true, deadline); + if (written > 0) { + written += writeAll(src, false, deadline); + } + } catch (IOException e) { + // If we get an exception, it's unrecoverable, we close the channel and try to reconnect + disconnect(); + throw e; + } + + // If we haven't written anything, we have a timeout + if (written == 0) { + throw new IOException("Write timed out"); + } + + return size; + } + + /** + * Writes all bytes from the given buffer to the channel. + * @param bb buffer to write + * @param canReturnOnTimeout if true, we return if the channel is blocking and we haven't written anything yet + * @param deadline deadline for the write + * @return number of bytes written + * @throws IOException if the channel is closed or an error occurs + */ + public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout, long deadline) throws IOException { + int remaining = bb.remaining(); + int written = 0; + while (remaining > 0) { + int read = delegate.write(bb); + + // If we haven't written anything yet, we can still return + if (read == 0 && canReturnOnTimeout && written == 0) { + return written; + } + + remaining -= read; + written += read; + + if (deadline > 0 && System.nanoTime() > deadline) { + throw new IOException("Write timed out"); + } + } + return written; + } + + private void connectIfNeeded() throws IOException { + if (delegate == null) { + connect(); + } + } + + private void disconnect() throws IOException { + if (delegate != null) { + delegate.close(); + delegate = null; + } + } + + private void connect() throws IOException { + if (this.delegate != null) { + try { + disconnect(); + } catch (IOException e) { + // ignore to be sure we don't stay with a broken delegate forever. + } + } + + UnixSocketChannel delegate = UnixSocketChannel.create(); + + long deadline = System.nanoTime() + connectionTimeout * 1_000_000L; + if (connectionTimeout > 0) { + // Set connect timeout, this should work at least on linux + // https://elixir.bootlin.com/linux/v5.7.4/source/net/unix/af_unix.c#L1696 + // We'd have better timeout support if we used Java 16's native Unix domain socket support (JEP 380) + delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, connectionTimeout); + } + if (!delegate.connect(address)) { + if (connectionTimeout > 0 && System.nanoTime() > deadline) { + throw new IOException("Connection timed out"); + } + if (!delegate.finishConnect()) { + throw new IOException("Connection failed"); + } + } + + if (timeout > 0) { + delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, timeout); + } else { + delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, 0); + } + if (bufferSize > 0) { + delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize); + } + this.delegate = delegate; + } + + @Override + public void close() throws IOException { + disconnect(); + } + + @Override + public String getTransportType() { + return "uds-stream"; + } + + @Override + public String toString() { + return "[" + getTransportType() + "] " + address; + } +} diff --git a/src/test/java/com/timgroup/statsd/BuilderAddressTest.java b/src/test/java/com/timgroup/statsd/BuilderAddressTest.java index 569ac837..430ba926 100644 --- a/src/test/java/com/timgroup/statsd/BuilderAddressTest.java +++ b/src/test/java/com/timgroup/statsd/BuilderAddressTest.java @@ -9,13 +9,10 @@ import jnr.unixsocket.UnixSocketAddress; -import org.junit.Assume; import org.junit.Before; import org.junit.Test; import org.junit.Rule; import org.junit.contrib.java.lang.system.EnvironmentVariables; -import org.junit.runners.MethodSorters; -import org.junit.function.ThrowingRunnable; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -42,15 +39,6 @@ public BuilderAddressTest(String url, String host, String port, String pipe, Soc this.expected = expected; } - static boolean isJnrAvailable() { - try { - Class.forName("jnr.unixsocket.UnixDatagramChannel"); - return true; - } catch (ClassNotFoundException e) { - return false; - } - } - static final private int defaultPort = NonBlockingStatsDClient.DEFAULT_DOGSTATSD_PORT; @Parameters @@ -83,19 +71,38 @@ public static Collection parameters() { { null, "1.1.1.1", "9999", "foo", new NamedPipeSocketAddress("\\\\.\\pipe\\foo") }, })); - if (isJnrAvailable()) { + if (TestHelpers.isJnrAvailable()) { + // Here we use FakeUnixSocketAddress instead of UnixSocketAddress to make sure we can always run the tests without jnr-unixsock. + + UnixSocketAddressWithTransport unixDsd = new UnixSocketAddressWithTransport(new FakeUnixSocketAddress("/dsd.sock"), UnixSocketAddressWithTransport.TransportType.UDS); + UnixSocketAddressWithTransport unixDgramDsd = new UnixSocketAddressWithTransport(new FakeUnixSocketAddress("/dsd.sock"), UnixSocketAddressWithTransport.TransportType.UDS_DATAGRAM); + UnixSocketAddressWithTransport unixStreamDsd = new UnixSocketAddressWithTransport(new FakeUnixSocketAddress("/dsd.sock"), UnixSocketAddressWithTransport.TransportType.UDS_STREAM); + params.addAll(Arrays.asList(new Object[][]{ - { "unix:///dsd.sock", null, null, null, new UnixSocketAddress("/dsd.sock") }, - { "unix://unused/dsd.sock", null, null, null, new UnixSocketAddress("/dsd.sock") }, - { "unix://unused:9999/dsd.sock", null, null, null, new UnixSocketAddress("/dsd.sock") }, - { null, "/dsd.sock", "0", null, new UnixSocketAddress("/dsd.sock") }, - { "unix:///dsd.sock", "1.1.1.1", "9999", null, new UnixSocketAddress("/dsd.sock") }, + { "unix:///dsd.sock", null, null, null, unixDsd }, + { "unix://unused/dsd.sock", null, null, null, unixDsd }, + { "unix://unused:9999/dsd.sock", null, null, null, unixDsd}, + { null, "/dsd.sock", "0", null, unixDsd }, + { "unix:///dsd.sock", "1.1.1.1", "9999", null, unixDsd }, + { "unixgram:///dsd.sock", null, null, null, unixDgramDsd }, + { "unixstream:///dsd.sock", null, null, null, unixStreamDsd }, })); } return params; } + static class FakeUnixSocketAddress extends SocketAddress { + final String path; + public FakeUnixSocketAddress(String path) { + this.path = path; + } + + public String getPath() { + return path; + } + } + @Before public void set() { set(NonBlockingStatsDClient.DD_DOGSTATSD_URL_ENV_VAR, url); @@ -118,7 +125,17 @@ public void address_resolution() throws Exception { // Default configuration matches env vars b = new NonBlockingStatsDClientBuilder().resolve(); - assertEquals(expected, b.addressLookup.call()); + SocketAddress actual = b.addressLookup.call(); + + // Make it possible to run this code even if we don't have jnr-unixsocket. + if (expected instanceof UnixSocketAddressWithTransport) { + UnixSocketAddressWithTransport a = (UnixSocketAddressWithTransport)actual; + UnixSocketAddressWithTransport e = (UnixSocketAddressWithTransport)expected; + assertEquals(((FakeUnixSocketAddress)e.getAddress()).getPath(), ((UnixSocketAddress)a.getAddress()).path()); + assertEquals(e.getTransportType(), a.getTransportType()); + } else { + assertEquals(expected, actual); + } // Explicit configuration is used regardless of environment variables. b = new NonBlockingStatsDClientBuilder().hostname("2.2.2.2").resolve(); diff --git a/src/test/java/com/timgroup/statsd/DummyStatsDServer.java b/src/test/java/com/timgroup/statsd/DummyStatsDServer.java index 0609a7b9..1aa4dc0a 100644 --- a/src/test/java/com/timgroup/statsd/DummyStatsDServer.java +++ b/src/test/java/com/timgroup/statsd/DummyStatsDServer.java @@ -10,6 +10,8 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import static com.timgroup.statsd.NonBlockingStatsDClient.DEFAULT_UDS_MAX_PACKET_SIZE_BYTES; + abstract class DummyStatsDServer implements Closeable { private final List messagesReceived = new ArrayList(); private AtomicInteger packetsReceived = new AtomicInteger(0); @@ -20,7 +22,7 @@ protected void listen() { Thread thread = new Thread(new Runnable() { @Override public void run() { - final ByteBuffer packet = ByteBuffer.allocate(1500); + final ByteBuffer packet = ByteBuffer.allocate(DEFAULT_UDS_MAX_PACKET_SIZE_BYTES); while(isOpen()) { if (freeze) { @@ -33,12 +35,7 @@ public void run() { ((Buffer)packet).clear(); // Cast necessary to handle Java9 covariant return types // see: https://jira.mongodb.org/browse/JAVA-2559 for ref. receive(packet); - packetsReceived.addAndGet(1); - - packet.flip(); - for (String msg : StandardCharsets.UTF_8.decode(packet).toString().split("\n")) { - addMessage(msg); - } + handlePacket(packet); } catch (IOException e) { } } @@ -49,6 +46,25 @@ public void run() { thread.start(); } + protected boolean sleepIfFrozen() { + if (freeze) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + } + } + return freeze; + } + + protected void handlePacket(ByteBuffer packet) { + packetsReceived.addAndGet(1); + + packet.flip(); + for (String msg : StandardCharsets.UTF_8.decode(packet).toString().split("\n")) { + addMessage(msg); + } + } + public void waitForMessage() { waitForMessage(null); } diff --git a/src/test/java/com/timgroup/statsd/NonBlockingStatsDClientTest.java b/src/test/java/com/timgroup/statsd/NonBlockingStatsDClientTest.java index d706fb16..96ebbf57 100644 --- a/src/test/java/com/timgroup/statsd/NonBlockingStatsDClientTest.java +++ b/src/test/java/com/timgroup/statsd/NonBlockingStatsDClientTest.java @@ -1806,7 +1806,7 @@ public NonBlockingStatsDClient build() { this.originDetectionEnabled(false); return new NonBlockingStatsDClient(resolve()) { @Override - ClientChannel createByteChannel(Callable addressLookup, int timeout, int bufferSize) throws Exception { + ClientChannel createByteChannel(Callable addressLookup, int timeout, int connectionTimeout, int bufferSize) throws Exception { return new DatagramClientChannel(addressLookup.call()) { @Override public int write(ByteBuffer data) throws IOException { @@ -1845,7 +1845,7 @@ public NonBlockingStatsDClient build() { this.bufferPoolSize(1); return new NonBlockingStatsDClient(resolve()) { @Override - ClientChannel createByteChannel(Callable addressLookup, int timeout, int bufferSize) throws Exception { + ClientChannel createByteChannel(Callable addressLookup, int timeout, int connectionTimeout, int bufferSize) throws Exception { return new DatagramClientChannel(addressLookup.call()) { @Override public int write(ByteBuffer data) throws IOException { diff --git a/src/test/java/com/timgroup/statsd/TelemetryTest.java b/src/test/java/com/timgroup/statsd/TelemetryTest.java index 9f4d1a8d..68824fa2 100644 --- a/src/test/java/com/timgroup/statsd/TelemetryTest.java +++ b/src/test/java/com/timgroup/statsd/TelemetryTest.java @@ -1,5 +1,6 @@ package com.timgroup.statsd; +import java.util.logging.Logger; import org.junit.After; import org.junit.AfterClass; import org.junit.Assume; @@ -21,6 +22,14 @@ public class TelemetryTest { @Override public void handle(final Exception ex) { /* No-op */ } }; + private static final StatsDClientErrorHandler LOGGING_HANDLER = new StatsDClientErrorHandler() { + + Logger log = Logger.getLogger(StatsDClientErrorHandler.class.getName()); + @Override public void handle(final Exception ex) { + log.warning("Got exception: " + ex); + } + }; + // fakeProcessor store messages from the telemetry only public static class FakeProcessor extends StatsDProcessor { public final List messages = new ArrayList<>(); @@ -110,7 +119,7 @@ private static String computeTelemetryTags() throws IOException, Exception { @BeforeClass public static void start() throws IOException, Exception { server = new UDPDummyStatsDServer(STATSD_SERVER_PORT); - fakeProcessor = new FakeProcessor(NO_OP_HANDLER); + fakeProcessor = new FakeProcessor(LOGGING_HANDLER); client.telemetry.processor = fakeProcessor; telemetryClient.telemetry.processor = fakeProcessor; @@ -376,7 +385,7 @@ public void telemetry_flushInterval() throws Exception { @Test(timeout = 5000L) public void telemetry_droppedData() throws Exception { - Assume.assumeTrue(UnixSocketTest.isUdsAvailable()); + Assume.assumeTrue(TestHelpers.isUdsAvailable()); // fails to send any data on the network, producing packets dropped NonBlockingStatsDClient clientError = new NonBlockingStatsDClientBuilder() diff --git a/src/test/java/com/timgroup/statsd/TestHelpers.java b/src/test/java/com/timgroup/statsd/TestHelpers.java new file mode 100644 index 00000000..1eebd9e8 --- /dev/null +++ b/src/test/java/com/timgroup/statsd/TestHelpers.java @@ -0,0 +1,26 @@ +package com.timgroup.statsd; + +public class TestHelpers +{ + static boolean isLinux() { + return System.getProperty("os.name").toLowerCase().contains("linux"); + } + + static boolean isMac() { + return System.getProperty("os.name").toLowerCase().contains("mac"); + } + + // Check if jnr.unixsocket is on the classpath. + static boolean isJnrAvailable() { + try { + Class.forName("jnr.unixsocket.UnixDatagramChannel"); + return true; + } catch (ClassNotFoundException e) { + return false; + } + } + + static boolean isUdsAvailable() { + return (isLinux() || isMac()) && isJnrAvailable(); + } +} diff --git a/src/test/java/com/timgroup/statsd/UnixSocketDummyStatsDServer.java b/src/test/java/com/timgroup/statsd/UnixDatagramSocketDummyStatsDServer.java similarity index 73% rename from src/test/java/com/timgroup/statsd/UnixSocketDummyStatsDServer.java rename to src/test/java/com/timgroup/statsd/UnixDatagramSocketDummyStatsDServer.java index fde642d6..753411e0 100644 --- a/src/test/java/com/timgroup/statsd/UnixSocketDummyStatsDServer.java +++ b/src/test/java/com/timgroup/statsd/UnixDatagramSocketDummyStatsDServer.java @@ -6,20 +6,22 @@ import jnr.unixsocket.UnixDatagramChannel; import jnr.unixsocket.UnixSocketAddress; -public class UnixSocketDummyStatsDServer extends DummyStatsDServer { +import static com.timgroup.statsd.NonBlockingStatsDClient.DEFAULT_UDS_MAX_PACKET_SIZE_BYTES; + +public class UnixDatagramSocketDummyStatsDServer extends DummyStatsDServer { private final DatagramChannel server; - public UnixSocketDummyStatsDServer(String socketPath) throws IOException { + public UnixDatagramSocketDummyStatsDServer(String socketPath) throws IOException { server = UnixDatagramChannel.open(); server.bind(new UnixSocketAddress(socketPath)); this.listen(); } + @Override protected boolean isOpen() { return server.isOpen(); } - @Override protected void receive(ByteBuffer packet) throws IOException { server.receive(packet); } diff --git a/src/test/java/com/timgroup/statsd/UnixSocketTest.java b/src/test/java/com/timgroup/statsd/UnixSocketTest.java index ec2068e4..374436de 100644 --- a/src/test/java/com/timgroup/statsd/UnixSocketTest.java +++ b/src/test/java/com/timgroup/statsd/UnixSocketTest.java @@ -1,5 +1,6 @@ package com.timgroup.statsd; +import java.util.logging.Logger; import org.junit.After; import org.junit.Assume; import org.junit.Before; @@ -22,37 +23,19 @@ public class UnixSocketTest implements StatsDClientErrorHandler { private static NonBlockingStatsDClient clientAggregate; private static DummyStatsDServer server; private static File socketFile; + private volatile Exception lastException = new Exception(); + private static Logger log = Logger.getLogger(StatsDClientErrorHandler.class.getName()); + public synchronized void handle(Exception exception) { + log.info("Got exception: " + exception); lastException = exception; } - static boolean isLinux() { - return System.getProperty("os.name").toLowerCase().contains("linux"); - } - - static boolean isMac() { - return System.getProperty("os.name").toLowerCase().contains("mac"); - } - - static boolean isJnrAvailable() { - try { - Class.forName("jnr.unixsocket.UnixDatagramChannel"); - return true; - } catch (ClassNotFoundException e) { - return false; - } - } - - // Check if jnr.unixsocket is on the classpath. - static boolean isUdsAvailable() { - return (isLinux() || isMac()) && isJnrAvailable(); - } - @BeforeClass public static void supportedOnly() throws IOException { - Assume.assumeTrue(isUdsAvailable()); + Assume.assumeTrue(TestHelpers.isUdsAvailable()); } @Before @@ -62,7 +45,8 @@ public void start() throws IOException { socketFile = new File(tmpFolder, "socket.sock"); socketFile.deleteOnExit(); - server = new UnixSocketDummyStatsDServer(socketFile.toString()); + server = new UnixDatagramSocketDummyStatsDServer(socketFile.toString()); + client = new NonBlockingStatsDClientBuilder().prefix("my.prefix") .hostname(socketFile.toString()) .port(0) @@ -112,6 +96,7 @@ public void sends_to_statsd() throws Exception { @Test(timeout = 10000L) public void resist_dsd_restart() throws Exception { + // Send one metric, check that it works. client.gauge("mycount", 10); server.waitForMessage(); assertThat(server.messagesReceived(), contains("my.prefix.mycount:10|g")); @@ -127,8 +112,8 @@ public void resist_dsd_restart() throws Exception { assertThat(lastException.getMessage(), containsString("Connection refused")); // Delete the socket file, client should throw an IOException - lastException = new Exception(); socketFile.delete(); + lastException = new Exception(); client.gauge("mycount", 21); while(lastException.getMessage() == null) { @@ -138,12 +123,13 @@ public void resist_dsd_restart() throws Exception { // Re-open the server, next send should work OK lastException = new Exception(); - DummyStatsDServer server2 = new UnixSocketDummyStatsDServer(socketFile.toString()); + DummyStatsDServer server2 = new UnixDatagramSocketDummyStatsDServer(socketFile.toString()); client.gauge("mycount", 30); - server2.waitForMessage(); + server2.waitForMessage(); assertThat(server2.messagesReceived(), hasItem("my.prefix.mycount:30|g")); + server2.clear(); assertThat(lastException.getMessage(), nullValue()); server2.close(); @@ -159,16 +145,18 @@ public void resist_dsd_timeout() throws Exception { // Freeze the server to simulate dsd being overwhelmed server.freeze(); - while(lastException.getMessage() == null) { + while (lastException.getMessage() == null) { client.gauge("mycount", 20); Thread.sleep(10); // We need to fill the buffer, setting a shorter sleep } - String excMessage = isLinux() ? "Resource temporarily unavailable" : "No buffer space available"; + String excMessage = TestHelpers.isLinux() ? "Resource temporarily unavailable" : "No buffer space available"; assertThat(lastException.getMessage(), containsString(excMessage)); // Make sure we recover after we resume listening server.clear(); server.unfreeze(); + + // Now make sure we can receive gauges with 30 while (!server.messagesReceived().contains("my.prefix.mycount:30|g")) { server.clear(); client.gauge("mycount", 30); diff --git a/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java b/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java new file mode 100644 index 00000000..ea743c8e --- /dev/null +++ b/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java @@ -0,0 +1,140 @@ +package com.timgroup.statsd; + +import java.io.IOException; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.logging.Logger; +import jnr.unixsocket.UnixServerSocketChannel; +import jnr.unixsocket.UnixSocketAddress; +import jnr.unixsocket.UnixSocketChannel; + +import static com.timgroup.statsd.NonBlockingStatsDClient.DEFAULT_UDS_MAX_PACKET_SIZE_BYTES; + +public class UnixStreamSocketDummyStatsDServer extends DummyStatsDServer { + private final UnixServerSocketChannel server; + private final ConcurrentLinkedQueue channels = new ConcurrentLinkedQueue<>(); + + private final Logger logger = Logger.getLogger(UnixStreamSocketDummyStatsDServer.class.getName()); + + public UnixStreamSocketDummyStatsDServer(String socketPath) throws IOException { + server = UnixServerSocketChannel.open(); + server.configureBlocking(true); + server.socket().bind(new UnixSocketAddress(socketPath)); + this.listen(); + } + + @Override + protected boolean isOpen() { + return server.isOpen(); + } + + @Override + protected void receive(ByteBuffer packet) throws IOException { + // This is unused because we re-implement listen() to fit our needs + } + + @Override + protected void listen() { + logger.info("Listening on " + server.getLocalSocketAddress()); + Thread thread = new Thread(new Runnable() { + @Override + public void run() { + while(isOpen()) { + if (sleepIfFrozen()) { + continue; + } + try { + logger.info("Waiting for connection"); + UnixSocketChannel clientChannel = server.accept(); + if (clientChannel != null) { + clientChannel.configureBlocking(true); + try { + logger.info("Accepted connection from " + clientChannel.getRemoteSocketAddress()); + } catch (Exception e) { + logger.warning("Failed to get remote socket address"); + } + channels.add(clientChannel); + readChannel(clientChannel); + } + } catch (IOException e) { + } + } + } + }); + thread.setDaemon(true); + thread.start(); + } + + public void readChannel(final UnixSocketChannel clientChannel) { + logger.info("Reading from " + clientChannel); + Thread thread = new Thread(new Runnable() { + @Override + public void run() { + final ByteBuffer packet = ByteBuffer.allocate(DEFAULT_UDS_MAX_PACKET_SIZE_BYTES); + + while(clientChannel.isOpen()) { + if (sleepIfFrozen()) { + continue; + } + ((Buffer)packet).clear(); // Cast necessary to handle Java9 covariant return types + // see: https://jira.mongodb.org/browse/JAVA-2559 for ref. + if (readPacket(clientChannel, packet)) { + handlePacket(packet); + } else { + try { + clientChannel.close(); + } catch (IOException e) { + logger.warning("Failed to close channel: " + e); + } + } + + } + logger.info("Disconnected from " + clientChannel); + } + }); + thread.setDaemon(true); + thread.start(); + } + + private boolean readPacket(SocketChannel channel, ByteBuffer packet) { + try { + ByteBuffer delimiterBuffer = ByteBuffer.allocate(Integer.SIZE / Byte.SIZE).order(ByteOrder.LITTLE_ENDIAN); + + int read = channel.read(delimiterBuffer); + + delimiterBuffer.flip(); + if (read <= 0) { + // There was nothing to read + return false; + } + + int packetSize = delimiterBuffer.getInt(); + if (packetSize > packet.capacity()) { + throw new IOException("Packet size too large"); + } + + packet.limit(packetSize); + while (packet.hasRemaining() && channel.isConnected()) { + channel.read(packet); + } + return true; + } catch (IOException e) { + return false; + } + } + + public void close() throws IOException { + try { + server.close(); + for (UnixSocketChannel channel : channels) { + channel.close(); + } + } catch (Exception e) { + //ignore + } + } + +} diff --git a/src/test/java/com/timgroup/statsd/UnixStreamSocketTest.java b/src/test/java/com/timgroup/statsd/UnixStreamSocketTest.java new file mode 100644 index 00000000..45af036c --- /dev/null +++ b/src/test/java/com/timgroup/statsd/UnixStreamSocketTest.java @@ -0,0 +1,174 @@ +package com.timgroup.statsd; + +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Assume; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import java.io.IOException; +import java.io.File; +import java.nio.file.Files; + +import static org.hamcrest.CoreMatchers.anyOf; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasItem; +import static org.junit.Assert.assertEquals; + +public class UnixStreamSocketTest implements StatsDClientErrorHandler { + private static File tmpFolder; + private static NonBlockingStatsDClient client; + private static NonBlockingStatsDClient clientAggregate; + private static DummyStatsDServer server; + private static File socketFile; + + private volatile Exception lastException = new Exception(); + + private static Logger log = Logger.getLogger(StatsDClientErrorHandler.class.getName()); + + public synchronized void handle(Exception exception) { + log.info("Got exception: " + exception); + lastException = exception; + } + + @BeforeClass + public static void supportedOnly() throws IOException { + Assume.assumeTrue(TestHelpers.isUdsAvailable()); + } + + @Before + public void start() throws IOException { + tmpFolder = Files.createTempDirectory(System.getProperty("java-dsd-test")).toFile(); + tmpFolder.deleteOnExit(); + socketFile = new File(tmpFolder, "socket.sock"); + socketFile.deleteOnExit(); + + server = new UnixStreamSocketDummyStatsDServer(socketFile.toString()); + + client = new NonBlockingStatsDClientBuilder().prefix("my.prefix") + .address("unixstream://" + socketFile.getPath()) + .port(0) + .queueSize(1) + .timeout(500) // non-zero timeout to ensure exception triggered if socket buffer full. + .connectionTimeout(500) + .socketBufferSize(1024 * 1024) + .enableAggregation(false) + .errorHandler(this) + .originDetectionEnabled(false) + .build(); + + clientAggregate = new NonBlockingStatsDClientBuilder().prefix("my.prefix") + .address("unixstream://" + socketFile.getPath()) + .port(0) + .queueSize(1) + .timeout(500) // non-zero timeout to ensure exception triggered if socket buffer full. + .connectionTimeout(500) + .socketBufferSize(1024 * 1024) + .enableAggregation(false) + .errorHandler(this) + .originDetectionEnabled(false) + .build(); + } + + @After + public void stop() throws Exception { + client.stop(); + clientAggregate.stop(); + server.close(); + } + + @Test + public void assert_default_uds_size() throws Exception { + assertEquals(client.statsDProcessor.bufferPool.getBufferSize(), NonBlockingStatsDClient.DEFAULT_UDS_MAX_PACKET_SIZE_BYTES); + } + + @Test(timeout = 5000L) + public void sends_to_statsd() throws Exception { + for(long i = 0; i < 5 ; i++) { + client.gauge("mycount", i); + server.waitForMessage(); + String expected = String.format("my.prefix.mycount:%d|g", i); + assertThat(server.messagesReceived(), contains(expected)); + server.clear(); + } + assertThat(lastException.getMessage(), nullValue()); + } + + @Test(timeout = 10000L) + public void resist_dsd_restart() throws Exception { + // Send one metric, check that it works. + client.gauge("mycount", 10); + server.waitForMessage(); + assertThat(server.messagesReceived(), contains("my.prefix.mycount:10|g")); + server.clear(); + assertThat(lastException.getMessage(), nullValue()); + + // Close the server, client should throw an IOException + server.close(); + while(lastException.getMessage() == null) { + client.gauge("mycount", 20); + Thread.sleep(10); + } + // Depending on the state of the client at that point we might get different messages. + assertThat(lastException.getMessage(), anyOf(containsString("Connection refused"), containsString("Broken pipe"))); + + // Delete the socket file, client should throw an IOException + lastException = new Exception(); + socketFile.delete(); + + client.gauge("mycount", 21); + while(lastException.getMessage() == null) { + Thread.sleep(10); + } + assertThat(lastException.getMessage(), containsString("No such file or directory")); + + // Re-open the server, next send should work OK + DummyStatsDServer server2; + server2 = new UnixStreamSocketDummyStatsDServer(socketFile.toString()); + + lastException = new Exception(); + + client.gauge("mycount", 30); + server2.waitForMessage(); + assertThat(server2.messagesReceived(), hasItem("my.prefix.mycount:30|g")); + + server2.clear(); + assertThat(lastException.getMessage(), nullValue()); + server2.close(); + } + + @Test(timeout = 10000L) + public void resist_dsd_timeout() throws Exception { + client.gauge("mycount", 10); + server.waitForMessage(); + assertThat(server.messagesReceived(), contains("my.prefix.mycount:10|g")); + server.clear(); + assertThat(lastException.getMessage(), nullValue()); + + // Freeze the server to simulate dsd being overwhelmed + server.freeze(); + + while (lastException.getMessage() == null) { + client.gauge("mycount", 20); + + } + String excMessage = "Write timed out"; + assertThat(lastException.getMessage(), containsString(excMessage)); + + // Make sure we recover after we resume listening + server.clear(); + server.unfreeze(); + + // Now make sure we can receive gauges with 30 + while (!server.messagesReceived().contains("my.prefix.mycount:30|g")) { + server.clear(); + client.gauge("mycount", 30); + server.waitForMessage(); + } + assertThat(server.messagesReceived(), hasItem("my.prefix.mycount:30|g")); + server.clear(); + } +}