diff --git a/src/main/java/io/nats/client/Connection.java b/src/main/java/io/nats/client/Connection.java index ea97169fc..8c675a090 100644 --- a/src/main/java/io/nats/client/Connection.java +++ b/src/main/java/io/nats/client/Connection.java @@ -13,6 +13,8 @@ package io.nats.client; +import io.nats.client.impl.Headers; + import java.time.Duration; import java.util.Collection; import java.util.concurrent.CompletableFuture; @@ -21,7 +23,7 @@ /** * The Connection class is at the heart of the NATS Java client. Fundamentally a connection represents * a single network connection to the NATS server. - * + * * <p>Each connection you create will result in the creation of a single socket and several threads: * <ul> * <li> A reader thread for taking data off the socket @@ -29,55 +31,55 @@ * <li> A timer thread for a few maintenance timers * <li> A dispatch thread to handle request/reply traffic * </ul> - * + * * <p>The connection has a {@link Connection.Status status} which can be checked using the {@link #getStatus() getStatus} * method or watched using a {@link ConnectionListener ConnectionListener}. - * + * * <p>Connections, by default, are configured to try to reconnect to the server if there is a network failure up to * {@link Options#DEFAULT_MAX_RECONNECT times}. You can configure this behavior in the {@link Options Options}. * Moreover, the options allows you to control whether reconnect happens in the same order every time, and the time * to wait if trying to reconnect to the same server over and over. - * + * * <p>The list of servers used for connecting is provided by the {@link Options Options}. The list of servers used * during reconnect can be an expanded list. This expansion comes from the connections most recent server. For example, * if you connect to serverA, it can tell the connection "i know about serverB and serverC". If serverA goes down - * the client library will try to connect to serverA, serverB and serverC. Now, if the library connects to serverB, it may tell the client + * the client library will try to connect to serverA, serverB and serverC. Now, if the library connects to serverB, it may tell the client * "i know about serverB and serverE". The client's list of servers, available from {@link #getServers() getServers()} * will now be serverA from the initial connect, serverB and serverE, the reference to serverC is lost. - * + * * <p>When a connection is {@link #close() closed} the thread and socket resources are cleaned up. - * - * <p>All outgoing messages are sent through the connection object using one of the two + * + * <p>All outgoing messages are sent through the connection object using one of the two * {@link #publish(String, byte[]) publish} methods or the {@link #request(String, byte[]) request} method. * When publishing you can specify a reply to subject which can be retrieved by the receiver to respond. * The request method will handle this behavior itself, but it relies on getting the value out of a Future * so may be less flexible than publish with replyTo set. - * + * * <p>Messages can be received in two ways. You can create a Subscription which will allow you to read messages - * synchronously using the {@link Subscription#nextMessage(Duration) nextMessage} method or you can create a + * synchronously using the {@link Subscription#nextMessage(Duration) nextMessage} method or you can create a * {@link Dispatcher Dispatcher}. The Dispatcher will create a thread to listen for messages on one or more subscriptions. * The Dispatcher groups a set of subscriptions into a single listener thread that calls application code * for each messages. - * - * <p>Applications can use the {@link #flush(Duration) flush} method to check that published messages have + * + * <p>Applications can use the {@link #flush(Duration) flush} method to check that published messages have * made it to the server. However, this method initiates a round trip to the server and waits for the response so * it should be used sparingly. - * + * * <p>The connection provides two listeners via the Options. The {@link ConnectionListener ConnectionListener} * can be used to listen for lifecycle events. This listener is required for * {@link Nats#connectAsynchronously(Options, boolean) connectAsynchronously}, but otherwise optional. The * {@link ErrorListener ErrorListener} provides three callback opportunities including slow consumers, error * messages from the server and exceptions handled by the client library. These listeners can only be set at creation time * using the {@link Options options}. - * + * * <p><em>Note</em>: The publish methods take an array of bytes. These arrays <strong>will not be copied</strong>. This design choice * is based on the common case of strings or objects being converted to bytes. Once a client can be sure a message was received by * the NATS server it is theoretically possible to reuse that byte array, but this pattern should be treated as advanced and only used - * after thorough testing. + * after thorough testing. */ -public interface Connection extends AutoCloseable { +public interface Connection<T extends Message> extends AutoCloseable { - public enum Status { + enum Status { /** * The {@code Connection} is not connected. */ @@ -105,33 +107,55 @@ public enum Status { * Send a message to the specified subject. The message body <strong>will * not</strong> be copied. The expected usage with string content is something * like: - * + * + * <pre> + * nc = Nats.connect() + * nc.publish("destination", "message".getBytes("UTF-8")) + * </pre> + * <p> + * where the sender creates a byte array immediately before calling publish. + * <p> + * See {@link #publish(String, String, byte[]) publish()} for more details on + * publish during reconnect. + * + * @param subject the subject to send the message to + * @param body the message body + * @throws IllegalStateException if the reconnect buffer is exceeded + */ + void publish(String subject, byte[] body); + + /** + * Send a message to the specified subject. The message body <strong>will + * not</strong> be copied. The expected usage with string content is something + * like: + * * <pre> * nc = Nats.connect() * nc.publish("destination", "message".getBytes("UTF-8")) * </pre> - * + * <p> * where the sender creates a byte array immediately before calling publish. - * - * See {@link #publish(String, String, byte[]) publish()} for more details on + * <p> + * See {@link #publish(String, String, byte[]) publish()} for more details on * publish during reconnect. - * + * * @param subject the subject to send the message to - * @param body the message body + * @param body the message body + * @param headers the message headers * @throws IllegalStateException if the reconnect buffer is exceeded */ - public void publish(String subject, byte[] body); + void publish(String subject, byte[] body, Headers headers); /** * Send a request to the specified subject, providing a replyTo subject. The * message body <strong>will not</strong> be copied. The expected usage with * string content is something like: - * + * * <pre> * nc = Nats.connect() * nc.publish("destination", "reply-to", "message".getBytes("UTF-8")) * </pre> - * + * <p> * where the sender creates a byte array immediately before calling publish. * <p> * During reconnect the client will try to buffer messages. The buffer size is set @@ -140,204 +164,263 @@ public enum Status { * If the buffer is exceeded an IllegalStateException is thrown. Applications should use * this exception as a signal to wait for reconnect before continuing. * </p> + * * @param subject the subject to send the message to * @param replyTo the subject the receiver should send the response to - * @param body the message body + * @param body the message body * @throws IllegalStateException if the reconnect buffer is exceeded */ - public void publish(String subject, String replyTo, byte[] body); + void publish(String subject, String replyTo, byte[] body); + + /** + * Send a request to the specified subject, providing a replyTo subject. The + * message body <strong>will not</strong> be copied. The expected usage with + * string content is something like: + * + * <pre> + * nc = Nats.connect() + * nc.publish("destination", "reply-to", "message".getBytes("UTF-8")) + * </pre> + * <p> + * where the sender creates a byte array immediately before calling publish. + * <p> + * During reconnect the client will try to buffer messages. The buffer size is set + * in the connect options, see {@link Options.Builder#reconnectBufferSize(long) reconnectBufferSize()} + * with a default value of {@link Options#DEFAULT_RECONNECT_BUF_SIZE 8 * 1024 * 1024} bytes. + * If the buffer is exceeded an IllegalStateException is thrown. Applications should use + * this exception as a signal to wait for reconnect before continuing. + * </p> + * + * @param subject the subject to send the message to + * @param replyTo the subject the receiver should send the response to + * @param body the message body + * @param headers the headers + * @throws IllegalStateException if the reconnect buffer is exceeded + */ + void publish(String subject, String replyTo, byte[] body, Headers headers); + + /** + * + * Send a request to the specified subject, providing a replyTo subject. The + * message body <strong>will not</strong> be copied. + * + * @param message that you are sending. + */ + void publish(final T message); /** * Send a request. The returned future will be completed when the * response comes back. - * + * * @param subject the subject for the service that will handle the request - * @param data the content of the message + * @param data the content of the message * @return a Future for the response, which may be cancelled on error or timed out */ - public CompletableFuture<Message> request(String subject, byte[] data); + CompletableFuture<Message> request(String subject, byte[] data); + + /** + * Send a request. The returned future will be completed when the + * response comes back. + * + * @param requestMessage message that you are sending + * @return a Future for the response, which may be cancelled on error or timed out + */ + CompletableFuture<Message> request(final Message requestMessage); /** * Send a request and returns the reply or null. This version of request is equivalent * to calling get on the future returned from {@link #request(String, byte[]) request()} with * the timeout and handling the ExecutionException and TimeoutException. - * + * * @param subject the subject for the service that will handle the request - * @param data the content of the message + * @param data the content of the message * @param timeout the time to wait for a response * @return the reply message or null if the timeout is reached * @throws InterruptedException if one is thrown while waiting, in order to propogate it up */ - public Message request(String subject, byte[] data, Duration timeout) throws InterruptedException; + Message request(String subject, byte[] data, Duration timeout) throws InterruptedException; + + /** + * Send a request and returns the reply or null. This version of request is equivalent + * to calling get on the future returned from {@link #request(String, byte[]) request()} with + * the timeout and handling the ExecutionException and TimeoutException. + * + * @param requestMessage request message that you are sending + * @param timeout the time to wait for a response + * @return the reply message or null if the timeout is reached + * @throws InterruptedException if one is thrown while waiting, in order to propogate it up + */ + Message request(final Message requestMessage, final Duration timeout) throws InterruptedException; + /** * Create a synchronous subscription to the specified subject. - * + * * <p>Use the {@link io.nats.client.Subscription#nextMessage(Duration) nextMessage} * method to read messages for this subscription. - * + * * <p>See {@link #createDispatcher(MessageHandler) createDispatcher} for * information about creating an asynchronous subscription with callbacks. - * + * * <p>As of 2.6.1 this method will throw an IllegalArgumentException if the subject contains whitespace. - * + * * @param subject the subject to subscribe to * @return an object representing the subscription */ - public Subscription subscribe(String subject); + Subscription subscribe(String subject); /** * Create a synchronous subscription to the specified subject and queue. - * + * * <p>Use the {@link Subscription#nextMessage(Duration) nextMessage} method to read * messages for this subscription. - * + * * <p>See {@link #createDispatcher(MessageHandler) createDispatcher} for * information about creating an asynchronous subscription with callbacks. - * + * * <p>As of 2.6.1 this method will throw an IllegalArgumentException if either string contains whitespace. - * - * @param subject the subject to subscribe to + * + * @param subject the subject to subscribe to * @param queueName the queue group to join * @return an object representing the subscription */ - public Subscription subscribe(String subject, String queueName); + Subscription subscribe(String subject, String queueName); /** * Create a {@code Dispatcher} for this connection. The dispatcher can group one * or more subscriptions into a single callback thread. All messages go to the * same {@code MessageHandler}. - * + * * <p>Use the Dispatcher's {@link Dispatcher#subscribe(String)} and * {@link Dispatcher#subscribe(String, String)} methods to add subscriptions. - * + * * <pre> * nc = Nats.connect() * d = nc.createDispatcher((m) -> System.out.println(m)).subscribe("hello"); * </pre> - * + * * @param handler The target for the messages * @return a new Dispatcher */ - public Dispatcher createDispatcher(MessageHandler handler); + Dispatcher createDispatcher(MessageHandler handler); /** * Close a dispatcher. This will unsubscribe any subscriptions and stop the delivery thread. - * + * * <p>Once closed the dispatcher will throw an exception on subsequent subscribe or unsubscribe calls. - * + * * @param dispatcher the dispatcher to close */ - public void closeDispatcher(Dispatcher dispatcher); + void closeDispatcher(Dispatcher dispatcher); /** * Flush the connection's buffer of outgoing messages, including sending a * protocol message to and from the server. Passing null is equivalent to * passing 0, which will wait forever. - * + * <p> * If called while the connection is closed, this method will immediately * throw a TimeoutException, regardless of the timeout. - * + * <p> * If called while the connection is disconnected due to network issues this * method will wait for up to the timeout for a reconnect or close. - * + * * @param timeout The time to wait for the flush to succeed, pass 0 to wait - * forever. - * @throws TimeoutException if the timeout is exceeded + * forever. + * @throws TimeoutException if the timeout is exceeded * @throws InterruptedException if the underlying thread is interrupted */ - public void flush(Duration timeout) throws TimeoutException, InterruptedException; + void flush(Duration timeout) throws TimeoutException, InterruptedException; /** * Drain tells the connection to process in flight messages before closing. - * + * <p> * Drain initially drains all of the consumers, stopping incoming messages. * Next, publishing is halted and a flush call is used to insure all published * messages have reached the server. * Finally the connection is closed. - * + * <p> * In order to drain subscribers, an unsub protocol message is sent to the server followed by a flush. * These two steps occur before drain returns. The remaining steps occur in a background thread. * This method tries to manage the timeout properly, so that if the timeout is 1 second, and the flush * takes 100ms, the remaining steps have 900ms in the background thread. - * + * <p> * The connection will try to let all messages be drained, but when the timeout is reached * the connection is closed and any outstanding dispatcher threads are interrupted. - * + * <p> * A future is used to allow this call to be treated as synchronous or asynchronous as * needed by the application. The value of the future will be true if all of the subscriptions * were drained in the timeout, and false otherwise. The future is completed after the connection * is closed, so any connection handler notifications will happen before the future completes. - * + * * @param timeout The time to wait for the drain to succeed, pass 0 to wait - * forever. Drain involves moving messages to and from the server - * so a very short timeout is not recommended. If the timeout is reached before - * the drain completes, the connection is simply closed, which can result in message - * loss. + * forever. Drain involves moving messages to and from the server + * so a very short timeout is not recommended. If the timeout is reached before + * the drain completes, the connection is simply closed, which can result in message + * loss. * @return A future that can be used to check if the drain has completed * @throws InterruptedException if the thread is interrupted - * @throws TimeoutException if the initial flush times out + * @throws TimeoutException if the initial flush times out */ - public CompletableFuture<Boolean> drain(Duration timeout) throws TimeoutException, InterruptedException; + CompletableFuture<Boolean> drain(Duration timeout) throws TimeoutException, InterruptedException; /** * Close the connection and release all blocking calls like {@link #flush flush} * and {@link Subscription#nextMessage(Duration) nextMessage}. - * + * <p> * If close() is called after {@link #drain(Duration) drain} it will wait up to the connection timeout * to return, but it will not initiate a close. The drain takes precedence and will initiate the close. - * + * * @throws InterruptedException if the thread, or one owned by the connection is interrupted during the close */ - public void close() throws InterruptedException ; + void close() throws InterruptedException; /** * Returns the connections current status. - * + * * @return the connection's status */ - public Status getStatus(); + Status getStatus(); /** * MaxPayload returns the size limit that a message payload can have. This is * set by the server configuration and delivered to the client upon connect. - * + * * @return the maximum size of a message payload */ - public long getMaxPayload(); + long getMaxPayload(); /** * Return the list of known server urls, including additional servers discovered * after a connection has been established. - * + * * @return this connection's list of known server URLs */ - public Collection<String> getServers(); + Collection<String> getServers(); /** * @return a wrapper for useful statistics about the connection */ - public Statistics getStatistics(); + Statistics getStatistics(); /** * @return the read-only options used to create this connection */ - public Options getOptions(); + Options getOptions(); /** * @return the url used for the current connection, or null if disconnected */ - public String getConnectedUrl(); - + String getConnectedUrl(); + /** * @return the error text from the last error sent by the server to this client */ - public String getLastError(); + String getLastError(); /** * @return a new inbox subject, can be used for directed replies from * subscribers. These are guaranteed to be unique, but can be shared and subscribed * to by others. */ - public String createInbox(); + String createInbox(); } diff --git a/src/main/java/io/nats/client/Message.java b/src/main/java/io/nats/client/Message.java index d150478e0..a353cafde 100644 --- a/src/main/java/io/nats/client/Message.java +++ b/src/main/java/io/nats/client/Message.java @@ -13,6 +13,8 @@ package io.nats.client; +import io.nats.client.impl.Headers; + /** * The NATS library uses a Message object to encapsulate incoming messages. Applications * publish and send requests with raw strings and byte[] but incoming messages can have a few @@ -22,35 +24,56 @@ * and is safe to manipulate. */ public interface Message { + byte[] EMPTY_BODY = new byte[0]; /** * @return the subject that this message was sent to */ - public String getSubject(); + String getSubject(); /** * @return the subject the application is expected to send a reply message on */ - public String getReplyTo(); + String getReplyTo(); + + /** + * @return the headers from the message + */ + Headers getHeaders(); + + /** + * @return if is utf8Mode + */ + boolean isUtf8mode(); /** * @return the data from the message */ - public byte[] getData(); + byte[] getData(); /** * @return the Subscription associated with this message, may be owned by a Dispatcher */ - public Subscription getSubscription(); + Subscription getSubscription(); /** * @return the id associated with the subscription, used by the connection when processing an incoming * message from the server */ - public String getSID(); + String getSID(); /** * @return the connection which can be used for publishing, will be null if the subscription is null */ - public Connection getConnection(); + Connection getConnection(); + + /** + * @return the protocol bytes + */ + byte[] getProtocolBytes(); + + /** + * @return the message size in bytes + */ + long getSizeInBytes(); } diff --git a/src/main/java/io/nats/client/Nats.java b/src/main/java/io/nats/client/Nats.java index 225fa9075..9c5a4cbb8 100644 --- a/src/main/java/io/nats/client/Nats.java +++ b/src/main/java/io/nats/client/Nats.java @@ -13,10 +13,10 @@ package io.nats.client; -import java.io.IOException; - import io.nats.client.impl.NatsImpl; +import java.io.IOException; + /** * The Nats class is the entry point into the NATS client for Java. This class * is used to create a connection to the NATS server. Connecting is a diff --git a/src/main/java/io/nats/client/impl/ByteBufferUtil.java b/src/main/java/io/nats/client/impl/ByteBufferUtil.java new file mode 100644 index 000000000..b0e1e09f4 --- /dev/null +++ b/src/main/java/io/nats/client/impl/ByteBufferUtil.java @@ -0,0 +1,15 @@ +package io.nats.client.impl; + +import java.nio.ByteBuffer; + +public class ByteBufferUtil { + + public static ByteBuffer enlargeBuffer(ByteBuffer buffer, int atLeast) { + int current = buffer.capacity(); + int newSize = Math.max(current * 2, atLeast); + ByteBuffer newBuffer = ByteBuffer.allocate(newSize); + buffer.flip(); + newBuffer.put(buffer); + return newBuffer; + } +} diff --git a/src/main/java/io/nats/client/impl/Headers.java b/src/main/java/io/nats/client/impl/Headers.java new file mode 100644 index 000000000..007559dda --- /dev/null +++ b/src/main/java/io/nats/client/impl/Headers.java @@ -0,0 +1,188 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package io.nats.client.impl; + +import java.util.*; + +public class Headers { + private static final String KEY_CANNOT_BE_EMPTY_OR_NULL = "Header key cannot be null."; + private static final String VALUES_CANNOT_BE_EMPTY_OR_NULL = "Header values cannot be empty or null."; + + private final Map<String, Set<String>> headerMap = new HashMap<>(); + + /** + * If the key is present add the values to the set of values for the key. + * If the key is not present, sets the specified values for the key. + * Duplicate values are ignored. Null and empty values are not allowed + * + * @param key the key + * @param values the values + * @return {@code true} if this object did not already contain values for the key + * @throws IllegalArgumentException if the key is null or empty + * -or- any value is null or empty. + */ + public boolean add(String key, String... values) { + return add(key, Arrays.asList(values)); + } + + /** + * If the key is present add the values to the set of values for the key. + * If the key is not present, sets the specified values for the key. + * Duplicate values are ignored. Null and empty values are not allowed. + * + * @param key the key + * @param values the values + * @return {@code true} if this object did not already contain the key or the + * values for the key changed. + * @throws IllegalArgumentException if the key is null or empty + * -or- if then input collection is null + * -or- if any item in the collection is null or empty. + */ + public boolean add(String key, Collection<String> values) { + Set<String> validatedSet = validateKeyAndValues(key, values); + + Set<String> currentSet = headerMap.get(key); + if (currentSet == null) { + headerMap.put(key, validatedSet); + return true; + } + + return currentSet.addAll(validatedSet); + } + + private Set<String> validateKeyAndValues(String key, Collection<String> values) { + keyCannotBeNull(key); + valuesCannotBeEmptyOrNull(values); + Set<String> validatedSet = new HashSet<>(); + for (String v : values) { + valueCannotBeEmptyOrNull(v); + validatedSet.add(v); + } + return validatedSet; + } + + /** + * Associates the specified values with the key. If the key was already present + * any existing values are removed and replaced with the new set. + * Duplicate values are ignored. Null and empty values are not allowed + * + * @param key the key + * @param values the values + * @return {@code true} if this object did not already contain values for the key + * @throws IllegalArgumentException if the key is null or empty + * -or- any value is null or empty. + */ + public boolean put(String key, String... values) { + return put(key, Arrays.asList(values)); + } + + /** + * Associates the specified values with the key. If the key was already present + * any existing values are removed and replaced with the new set. + * Duplicate values are ignored. Null and empty values are not allowed + * + * @param key the key + * @param values the values + * @return {@code true} if this object did not already contain values for the key + * @throws IllegalArgumentException if the key is null or empty + * -or- if then input collection is null + * -or- if any item in the collection is null or empty. + */ + public boolean put(String key, Collection<String> values) { + Set<String> validatedSet = validateKeyAndValues(key, values); + return headerMap.put(key, validatedSet) == null; + } + + /** + * Removes each key and its values if the key was present + * + * @param keys the key or keys to remove + * @return {@code true} if any key was present + */ + public boolean remove(String... keys) { + return remove(Arrays.asList(keys)); + } + + /** + * Removes each key and its values if the key was present + * + * @param keys the key or keys to remove + * @return {@code true} if any key was present + */ + public boolean remove(Collection<String> keys) { + boolean changed = false; + for (String key : keys) { + if (headerMap.remove(key) != null) { + changed = true; + } + } + return changed; + } + + public int size() { + return headerMap.size(); + } + + public boolean isEmpty() { + return headerMap.isEmpty(); + } + + public void clear() { + headerMap.clear(); + } + + public boolean containsKey(String key) { + return headerMap.containsKey(key); + } + + public Set<String> values(String key) { + Set<String> set = headerMap.get(key); + return set == null ? null : Collections.unmodifiableSet(set); + } + + public Set<String> keySet() { + return headerMap.keySet(); + } + + private void keyCannotBeNull(String key) { + if (key == null || key.length() == 0) { + throw new IllegalArgumentException(KEY_CANNOT_BE_EMPTY_OR_NULL); + } + } + + private void valueCannotBeEmptyOrNull(String val) { + if (val == null || val.length() == 0) { + throw new IllegalArgumentException(VALUES_CANNOT_BE_EMPTY_OR_NULL); + } + } + + private void valuesCannotBeEmptyOrNull(Collection<String> vals) { + if (vals == null || vals.size() == 0) { + throw new IllegalArgumentException(VALUES_CANNOT_BE_EMPTY_OR_NULL); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Headers headers = (Headers) o; + return Objects.equals(headerMap, headers.headerMap); + } + + @Override + public int hashCode() { + return Objects.hash(headerMap); + } +} diff --git a/src/main/java/io/nats/client/impl/MessageQueue.java b/src/main/java/io/nats/client/impl/MessageQueue.java index d45ea2486..02264efa0 100644 --- a/src/main/java/io/nats/client/impl/MessageQueue.java +++ b/src/main/java/io/nats/client/impl/MessageQueue.java @@ -58,7 +58,7 @@ class MessageQueue { this.length = new AtomicLong(0); // The poisonPill is used to stop poll and accumulate when the queue is stopped - this.poisonPill = new NatsMessage("_poison", null, NatsConnection.EMPTY_BODY, false); + this.poisonPill = new NatsMessage.PublishBuilder().subject("_poison").build(); this.filterLock = new ReentrantLock(); @@ -240,9 +240,9 @@ NatsMessage accumulate(long maxSize, long maxMessages, Duration timeout) if (maxSize<0 || (size + s) < maxSize) { // keep going size += s; count++; - - cursor.next = this.queue.poll(); - cursor = cursor.next; + + cursor.setNext(this.queue.poll()); + cursor = cursor.getNext(); if (count == maxMessages) { break; diff --git a/src/main/java/io/nats/client/impl/NatsConnection.java b/src/main/java/io/nats/client/impl/NatsConnection.java index 21e783369..a217d5230 100644 --- a/src/main/java/io/nats/client/impl/NatsConnection.java +++ b/src/main/java/io/nats/client/impl/NatsConnection.java @@ -13,6 +13,9 @@ package io.nats.client.impl; +import io.nats.client.*; +import io.nats.client.ConnectionListener.Events; + import java.io.IOException; import java.nio.ByteBuffer; import java.nio.CharBuffer; @@ -21,28 +24,8 @@ import java.time.Instant; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Timer; -import java.util.TimerTask; -import java.util.concurrent.CancellationException; -import java.util.concurrent.Callable; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import java.util.*; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -52,23 +35,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import io.nats.client.AuthenticationException; -import io.nats.client.Connection; -import io.nats.client.ConnectionListener; -import io.nats.client.ConnectionListener.Events; -import io.nats.client.Consumer; -import io.nats.client.Dispatcher; -import io.nats.client.ErrorListener; -import io.nats.client.Message; -import io.nats.client.MessageHandler; -import io.nats.client.NUID; -import io.nats.client.Options; -import io.nats.client.Statistics; -import io.nats.client.Subscription; - -class NatsConnection implements Connection { - static final byte[] EMPTY_BODY = new byte[0]; - +class NatsConnection implements Connection<NatsMessage> { static final byte CR = 0x0D; static final byte LF = 0x0A; static final byte[] CRLF = { CR, LF }; @@ -77,8 +44,11 @@ class NatsConnection implements Connection { static final String OP_INFO = "INFO"; static final String OP_SUB = "SUB"; static final String OP_PUB = "PUB"; + static final String OP_HPUB = "HPUB"; static final String OP_UNSUB = "UNSUB"; static final String OP_MSG = "MSG"; + static final String OP_HMSG = "HMSG"; + static final String OP_PING = "PING"; static final String OP_PONG = "PONG"; static final String OP_OK = "+OK"; @@ -183,7 +153,42 @@ class NatsConnection implements Connection { this.connectExecutor = Executors.newSingleThreadExecutor(); timeTrace(trace, "creating reader and writer"); - this.reader = new NatsConnectionReader(this); + this.reader = new NatsConnectionReader(new ProtocolHandler() { + @Override + public void handleCommunicationIssue(Exception io) { + NatsConnection.this.handleCommunicationIssue(io); + } + + @Override + public void deliverMessage(NatsMessage msg) { + NatsConnection.this.deliverMessage(msg); + } + + @Override + public void processOK() { + NatsConnection.this.processOK(); + } + + @Override + public void processError(String errorText) { + NatsConnection.this.processError(errorText); + } + + @Override + public void sendPong() { + NatsConnection.this.sendPong(); + } + + @Override + public void handlePong() { + NatsConnection.this.handlePong(); + } + + @Override + public void handleInfo(String infoJson) { + NatsConnection.this.handleInfo(infoJson); + } + }, this.options, this.statistics, this.executor); this.writer = new NatsConnectionWriter(this); this.needPing = new AtomicBoolean(true); @@ -777,41 +782,45 @@ void cleanUpPongQueue() { } } + @Override public void publish(String subject, byte[] body) { - this.publish(subject, null, body); + publish(subject, null, body, null); } + @Override + public void publish(String subject, byte[] body, Headers headers) { + publish(subject, null, body, headers); + } + + @Override public void publish(String subject, String replyTo, byte[] body) { + publish(subject, replyTo, body, null); + } + + @Override + public void publish(String subject, String replyTo, byte[] body, Headers headers) { + publish(new NatsMessage.PublishBuilder() + .subject(subject).replyTo(replyTo) + .headers(headers).data(body) + .utf8mode(options.supportUTF8Subjects()) + .maxPayload(getMaxPayload()) + .build()); + } + public void publish(NatsMessage message) { if (isClosed()) { throw new IllegalStateException("Connection is Closed"); } else if (blockPublishForDrain.get()) { throw new IllegalStateException("Connection is Draining"); // Ok to publish while waiting on subs } - if (subject == null || subject.length() == 0) { - throw new IllegalArgumentException("Subject is required in publish"); - } - - if (replyTo != null && replyTo.length() == 0) { - throw new IllegalArgumentException("ReplyTo cannot be the empty string"); - } - - if (body == null) { - body = EMPTY_BODY; - } else if (body.length > this.getMaxPayload() && this.getMaxPayload() > 0) { - throw new IllegalArgumentException( - "Message payload size exceed server configuration " + body.length + " vs " + this.getMaxPayload()); - } - - NatsMessage msg = new NatsMessage(subject, replyTo, body, options.supportUTF8Subjects()); - if ((this.status == Status.RECONNECTING || this.status == Status.DISCONNECTED) - && !this.writer.canQueue(msg, options.getReconnectBufferSize())) { + && !this.writer.canQueue(message, options.getReconnectBufferSize())) { throw new IllegalStateException( "Unable to queue any more messages during reconnect, max buffer is " + options.getReconnectBufferSize()); } - queueOutgoing(msg); + + queueOutgoing(message); } public Subscription subscribe(String subject) { @@ -892,17 +901,17 @@ void unsubscribe(NatsSubscription sub, int after) { void sendUnsub(NatsSubscription sub, int after) { String sid = sub.getSID(); - CharBuffer protocolBuilder = CharBuffer.allocate(this.options.getMaxControlLine()); - protocolBuilder.append(OP_UNSUB); - protocolBuilder.append(" "); - protocolBuilder.append(sid); + CharBuffer buffer = CharBuffer.allocate(this.options.getMaxControlLine()); + buffer.append(OP_UNSUB); + buffer.append(" "); + buffer.append(sid); if (after > 0) { - protocolBuilder.append(" "); - protocolBuilder.append(String.valueOf(after)); + buffer.append(" "); + buffer.append(String.valueOf(after)); } - protocolBuilder.flip(); - NatsMessage unsubMsg = new NatsMessage(protocolBuilder); + buffer.flip(); + NatsMessage unsubMsg = NatsMessage.getProtocolInstance(buffer); queueInternalOutgoing(unsubMsg); } @@ -946,7 +955,7 @@ void sendSubscriptionMessage(CharSequence sid, String subject, String queueName, protocolBuilder.append(" "); protocolBuilder.append(sid); protocolBuilder.flip(); - NatsMessage subMsg = new NatsMessage(protocolBuilder); + NatsMessage subMsg = NatsMessage.getProtocolInstance(protocolBuilder); if (treatAsInternal) { queueInternalOutgoing(subMsg); @@ -1019,6 +1028,89 @@ public Message request(String subject, byte[] body, Duration timeout) throws Int return reply; } + @Override + public Message request(Message requestMessage, Duration timeout) throws InterruptedException { + Message reply = null; + Future<Message> incoming = this.request(requestMessage); + try { + reply = incoming.get(timeout.toNanos(), TimeUnit.NANOSECONDS); + } catch (TimeoutException e) { + incoming.cancel(true); + } catch (Throwable e) { + throw new AssertionError(e); + } + + return reply; + } + + public CompletableFuture<Message> request(final Message requestMessage) { + String responseInbox = null; + boolean oldStyle = options.isOldRequestStyle(); + + if (isClosed()) { + throw new IllegalStateException("Connection is Closed"); + } else if (isDraining()) { + throw new IllegalStateException("Connection is Draining"); + } + + if (inboxDispatcher.get() == null) { + NatsDispatcher d = new NatsDispatcher(this, (msg) -> { + deliverReply(msg); + }); + + if (inboxDispatcher.compareAndSet(null, d)) { + String id = this.nuid.next(); + this.dispatchers.put(id, d); + d.start(id); + d.subscribe(this.mainInbox); + } + } + + if (oldStyle) { + responseInbox = createInbox(); + } else { + responseInbox = createResponseInbox(this.mainInbox); + } + + String responseToken = getResponseToken(responseInbox); + CompletableFuture<Message> future = new CompletableFuture<>(); + + if (!oldStyle) { + responses.put(responseToken, future); + } + statistics.incrementOutstandingRequests(); + + if (oldStyle) { + NatsDispatcher dispatcher = this.inboxDispatcher.get(); + NatsSubscription sub = dispatcher.subscribeReturningSubscription(responseInbox); + dispatcher.unsubscribe(responseInbox, 1); + // Unsubscribe when future is cancelled: + String finalResponseInbox = responseInbox; + future.whenComplete((msg, exception) -> { + if ( null != exception && exception instanceof CancellationException ) { + dispatcher.unsubscribe(finalResponseInbox); + } + }); + responses.put(sub.getSID(), future); + } + + if (requestMessage.getReplyTo() != null) { + throw new IllegalArgumentException("Reply To must not be set"); + } + + publish(new NatsMessage.PublishBuilder() + .subject(requestMessage.getSubject()) + .replyTo(responseInbox) + .headers(requestMessage.getHeaders()) + .data(requestMessage.getData()) + .build()); + + statistics.incrementRequestsSent(); + + return future; + } + + public CompletableFuture<Message> request(String subject, byte[] body) { String responseInbox = null; boolean oldStyle = options.isOldRequestStyle(); @@ -1034,7 +1126,7 @@ public CompletableFuture<Message> request(String subject, byte[] body) { } if (body == null) { - body = EMPTY_BODY; + body = Message.EMPTY_BODY; } else if (body.length > this.getMaxPayload() && this.getMaxPayload() > 0) { throw new IllegalArgumentException( "Message payload size exceed server configuration " + body.length + " vs " + this.getMaxPayload()); @@ -1205,7 +1297,7 @@ void sendConnect(String serverURI) throws IOException { connectString.append(" "); connectString.append(connectOptions); connectString.flip(); - NatsMessage msg = new NatsMessage(connectString); + NatsMessage msg = NatsMessage.getProtocolInstance(connectString); queueInternalOutgoing(msg); } catch (Exception exp) { @@ -1247,7 +1339,7 @@ CompletableFuture<Boolean> sendPing(boolean treatAsInternal) { } CompletableFuture<Boolean> pongFuture = new CompletableFuture<>(); - NatsMessage msg = new NatsMessage(CharBuffer.wrap(NatsConnection.OP_PING)); + NatsMessage msg = NatsMessage.getProtocolInstance(NatsConnection.OP_PING); pongQueue.add(pongFuture); if (treatAsInternal) { @@ -1262,7 +1354,7 @@ CompletableFuture<Boolean> sendPing(boolean treatAsInternal) { } void sendPong() { - NatsMessage msg = new NatsMessage(CharBuffer.wrap(NatsConnection.OP_PONG)); + NatsMessage msg = NatsMessage.getProtocolInstance(NatsConnection.OP_PONG); queueInternalOutgoing(msg); } @@ -1307,7 +1399,7 @@ void readInitialInfo() throws IOException { gotCR = true; } else { if (!protocolBuffer.hasRemaining()) { - protocolBuffer = enlargeBuffer(protocolBuffer, 0); // just double it + protocolBuffer = ByteBufferUtil.enlargeBuffer(protocolBuffer, 0); // just double it } protocolBuffer.put(b); } @@ -1712,14 +1804,7 @@ List<String> buildServerList() { return reconnectList; } - ByteBuffer enlargeBuffer(ByteBuffer buffer, int atLeast) { - int current = buffer.capacity(); - int newSize = Math.max(current * 2, atLeast); - ByteBuffer newBuffer = ByteBuffer.allocate(newSize); - buffer.flip(); - newBuffer.put(buffer); - return newBuffer; - } + // For testing NatsConnectionReader getReader() { diff --git a/src/main/java/io/nats/client/impl/NatsConnectionReader.java b/src/main/java/io/nats/client/impl/NatsConnectionReader.java index d827197b2..20ba60dee 100644 --- a/src/main/java/io/nats/client/impl/NatsConnectionReader.java +++ b/src/main/java/io/nats/client/impl/NatsConnectionReader.java @@ -1,482 +1,651 @@ -// Copyright 2015-2018 The NATS Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package io.nats.client.impl; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.CharBuffer; -import java.nio.charset.StandardCharsets; -import java.util.concurrent.CancellationException; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import java.util.concurrent.atomic.AtomicBoolean; - -class NatsConnectionReader implements Runnable { - static final int MAX_PROTOCOL_OP_LENGTH = 4; - static final String UNKNOWN_OP = "UNKNOWN"; - static final char SPACE = ' '; - static final char TAB = '\t'; - - enum Mode { - GATHER_OP, - GATHER_PROTO, - GATHER_MSG_PROTO, - PARSE_PROTO, - GATHER_DATA - }; - - private final NatsConnection connection; - - private ByteBuffer protocolBuffer; // use a byte buffer to assist character decoding - - private boolean gotCR; - - private String op; - private char[] opArray; - private int opPos; - - private char[] msgLineChars; - private int msgLinePosition; - - private Mode mode; - - private NatsMessage incoming; - private byte[] msgData; - private int msgDataPosition; - - private byte[] buffer; - private int bufferPosition; - - private Future<Boolean> stopped; - private Future<DataPort> dataPortFuture; - private final AtomicBoolean running; - - private final boolean utf8Mode; - - NatsConnectionReader(NatsConnection connection) { - this.connection = connection; - - this.running = new AtomicBoolean(false); - this.stopped = new CompletableFuture<>(); - ((CompletableFuture<Boolean>)this.stopped).complete(Boolean.TRUE); // we are stopped on creation - - this.protocolBuffer = ByteBuffer.allocate(this.connection.getOptions().getMaxControlLine()); - this.msgLineChars = new char[this.connection.getOptions().getMaxControlLine()]; - this.opArray = new char[MAX_PROTOCOL_OP_LENGTH]; - this.buffer = new byte[connection.getOptions().getBufferSize()]; - this.bufferPosition = 0; - - this.utf8Mode = connection.getOptions().supportUTF8Subjects(); - } - - // Should only be called if the current thread has exited. - // Use the Future from stop() to determine if it is ok to call this. - // This method resets that future so mistiming can result in badness. - void start(Future<DataPort> dataPortFuture) { - this.dataPortFuture = dataPortFuture; - this.running.set(true); - this.stopped = connection.getExecutor().submit(this, Boolean.TRUE); - } - - // May be called several times on an error. - // Returns a future that is completed when the thread completes, not when this - // method does. - Future<Boolean> stop() { - this.running.set(false); - return stopped; - } - - @Override - public void run() { - try { - DataPort dataPort = this.dataPortFuture.get(); // Will wait for the future to complete - this.mode = Mode.GATHER_OP; - this.gotCR = false; - this.opPos = 0; - - while (this.running.get()) { - this.bufferPosition = 0; - int bytesRead = dataPort.read(this.buffer, 0, this.buffer.length); - - if (bytesRead > 0) { - connection.getNatsStatistics().registerRead(bytesRead); - - while (this.bufferPosition < bytesRead) { - if (this.mode == Mode.GATHER_OP) { - this.gatherOp(bytesRead); - } else if (this.mode == Mode.GATHER_MSG_PROTO) { - if (this.utf8Mode) { - this.gatherProtocol(bytesRead); - } else { - this.gatherMessageProtocol(bytesRead); - } - } else if (this.mode == Mode.GATHER_PROTO) { - this.gatherProtocol(bytesRead); - } else { - this.gatherMessageData(bytesRead); - } - - if (this.mode == Mode.PARSE_PROTO) { // Could be the end of the read - this.parseProtocolMessage(); - this.protocolBuffer.clear(); - } - } - } else if (bytesRead < 0) { - throw new IOException("Read channel closed."); - } else { - this.connection.getNatsStatistics().registerRead(bytesRead); // track the 0 - } - } - } catch (IOException io) { - this.connection.handleCommunicationIssue(io); - } catch (CancellationException | ExecutionException | InterruptedException ex) { - // Exit - } finally { - this.running.set(false); - // Clear the buffers, since they are only used inside this try/catch - // We will reuse later - this.protocolBuffer.clear(); - } - } - - // Gather the op, either up to the first space or the first carriage return. - void gatherOp(int maxPos) throws IOException { - try { - while(this.bufferPosition < maxPos) { - byte b = this.buffer[this.bufferPosition]; - this.bufferPosition++; - - if (gotCR) { - if (b == NatsConnection.LF) { // Got CRLF, jump to parsing - this.op = opFor(opArray, opPos); - this.gotCR = false; - this.opPos = 0; - this.mode = Mode.PARSE_PROTO; - break; - } else { - throw new IllegalStateException("Bad socket data, no LF after CR"); - } - } else if (b == SPACE || b == TAB) { // Got a space, get the rest of the protocol line - this.op = opFor(opArray, opPos); - this.opPos = 0; - if (this.op == NatsConnection.OP_MSG) { - this.msgLinePosition = 0; - this.mode = Mode.GATHER_MSG_PROTO; - } else { - this.mode = Mode.GATHER_PROTO; - } - break; - } else if (b == NatsConnection.CR) { - this.gotCR = true; - } else { - this.opArray[opPos] = (char) b; - this.opPos++; - } - } - } catch (ArrayIndexOutOfBoundsException | IllegalStateException | NumberFormatException | NullPointerException ex) { - this.encounteredProtocolError(ex); - } - } - - // Stores the message protocol line in a char buffer that will be grepped for subject, reply - void gatherMessageProtocol(int maxPos) throws IOException { - try { - while(this.bufferPosition < maxPos) { - byte b = this.buffer[this.bufferPosition]; - this.bufferPosition++; - - if (gotCR) { - if (b == NatsConnection.LF) { - this.mode = Mode.PARSE_PROTO; - this.gotCR = false; - break; - } else { - throw new IllegalStateException("Bad socket data, no LF after CR"); - } - } else if (b == NatsConnection.CR) { - this.gotCR = true; - } else { - if (this.msgLinePosition >= this.msgLineChars.length) { - throw new IllegalStateException("Protocol line is too long"); - } - this.msgLineChars[this.msgLinePosition] = (char) b; // Assumes ascii, as per protocol doc - this.msgLinePosition++; - } - } - } catch (IllegalStateException | NumberFormatException | NullPointerException ex) { - this.encounteredProtocolError(ex); - } - } - - // Gather bytes for a protocol line - void gatherProtocol(int maxPos) throws IOException { - // protocol buffer has max capacity, shouldn't need resizing - try { - while(this.bufferPosition < maxPos) { - byte b = this.buffer[this.bufferPosition]; - this.bufferPosition++; - - if (gotCR) { - if (b == NatsConnection.LF) { - this.protocolBuffer.flip(); - this.mode = Mode.PARSE_PROTO; - this.gotCR = false; - break; - } else { - throw new IllegalStateException("Bad socket data, no LF after CR"); - } - } else if (b == NatsConnection.CR) { - this.gotCR = true; - } else { - if (!protocolBuffer.hasRemaining()) { - this.protocolBuffer = this.connection.enlargeBuffer(this.protocolBuffer, 0); // just double it - } - this.protocolBuffer.put(b); - } - } - } catch (IllegalStateException | NumberFormatException | NullPointerException ex) { - this.encounteredProtocolError(ex); - } - } - - // Gather bytes for a message body into a byte array that is then - // given to the message object - void gatherMessageData(int maxPos) throws IOException { - try { - while(this.bufferPosition < maxPos) { - int possible = maxPos - this.bufferPosition; - int want = msgData.length - msgDataPosition; - - // Grab all we can, until we get to the CR/LF - if (want > 0 && want <= possible) { - System.arraycopy(this.buffer, this.bufferPosition, this.msgData, this.msgDataPosition, want); - msgDataPosition += want; - this.bufferPosition += want; - continue; - } else if (want > 0) { - System.arraycopy(this.buffer, this.bufferPosition, this.msgData, this.msgDataPosition, possible); - msgDataPosition += possible; - this.bufferPosition += possible; - continue; - } - - byte b = this.buffer[this.bufferPosition]; - this.bufferPosition++; - - if (gotCR) { - if (b == NatsConnection.LF) { - incoming.setData(msgData); - this.connection.deliverMessage(incoming); - msgData = null; - msgDataPosition = 0; - incoming = null; - gotCR = false; - this.op = UNKNOWN_OP; - this.mode = Mode.GATHER_OP; - break; - } else { - throw new IllegalStateException("Bad socket data, no LF after CR"); - } - } else if (b == NatsConnection.CR) { - gotCR = true; - } else { - throw new IllegalStateException("Bad socket data, no CRLF after data"); - } - } - } catch (IllegalStateException | NullPointerException ex) { - this.encounteredProtocolError(ex); - } - } - - public String grabNextMessageLineElement(int max) { - if (this.msgLinePosition >= max) { - return null; - } - - int start = this.msgLinePosition; - - while (this.msgLinePosition < max) { - char c = this.msgLineChars[this.msgLinePosition]; - this.msgLinePosition++; - - if (c == SPACE || c == TAB) { - String slice = new String(this.msgLineChars, start, this.msgLinePosition - start -1); //don't grab the space, avoid an intermediate char sequence - return slice; - } - } - - return new String(this.msgLineChars, start, this.msgLinePosition-start); - } - - public String opFor(char[] chars, int length) { - if (length == 3) { - if ((chars[0] == 'M' || chars[0] == 'm') && - (chars[1] == 'S' || chars[1] == 's') && - (chars[2] == 'G' || chars[2] == 'g')) { - return NatsConnection.OP_MSG; - } else if (chars[0] == '+' && - (chars[1] == 'O' || chars[1] == 'o') && - (chars[2] == 'K' || chars[2] == 'k')) { - return NatsConnection.OP_OK; - } else { - return UNKNOWN_OP; - } - } else if (length == 4) { // do them in a unique order for uniqueness when possible to branch asap - if ((chars[1] == 'I' || chars[1] == 'i') && - (chars[0] == 'P' || chars[0] == 'p') && - (chars[2] == 'N' || chars[2] == 'n') && - (chars[3] == 'G' || chars[3] == 'g')) { - return NatsConnection.OP_PING; - } else if ((chars[1] == 'O' || chars[1] == 'o') && - (chars[0] == 'P' || chars[0] == 'p') && - (chars[2] == 'N' || chars[2] == 'n') && - (chars[3] == 'G' || chars[3] == 'g')) { - return NatsConnection.OP_PONG; - } else if (chars[0] == '-' && - (chars[1] == 'E' || chars[1] == 'e') && - (chars[2] == 'R' || chars[2] == 'r') && - (chars[3] == 'R' || chars[3] == 'r')) { - return NatsConnection.OP_ERR; - } else if ((chars[0] == 'I' || chars[0] == 'i') && - (chars[1] == 'N' || chars[1] == 'n') && - (chars[2] == 'F' || chars[2] == 'f') && - (chars[3] == 'O' || chars[3] == 'o')) { - return NatsConnection.OP_INFO; - } else { - return UNKNOWN_OP; - } - } else { - return UNKNOWN_OP; - } - } - - private static int[] TENS = new int[] { 1, 10, 100, 1_000, 10_000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000}; - - public static int parseLength(String s) throws NumberFormatException { - int length = s.length(); - int retVal = 0; - - if (length > TENS.length) { - throw new NumberFormatException("Long in message length \"" + s + "\" "+length+" > "+TENS.length); - } - - for (int i=length-1;i>=0;i--) { - char c = s.charAt(i); - int d = (c - '0'); - - if (d>9) { - throw new NumberFormatException("Invalid char in message length \'" + c + "\'"); - } - - retVal += d * TENS[length - i - 1]; - } - - return retVal; - } - - void parseProtocolMessage() throws IOException { - try { - switch (this.op) { - case NatsConnection.OP_MSG: - int protocolLength = this.msgLinePosition; //This is just after the last character - int protocolLineLength = protocolLength + 4; // 4 for the "MSG " - - if (this.utf8Mode) { - protocolLineLength = protocolBuffer.remaining() + 4; - - CharBuffer buff = StandardCharsets.UTF_8.decode(protocolBuffer); - protocolLength = buff.remaining(); - buff.get(this.msgLineChars, 0, protocolLength); - } - - this.msgLinePosition = 0; - String subject = grabNextMessageLineElement(protocolLength); - String sid = grabNextMessageLineElement(protocolLength); - String replyTo = grabNextMessageLineElement(protocolLength); - String lengthChars = null; - - if (this.msgLinePosition < protocolLength) { - lengthChars = grabNextMessageLineElement(protocolLength); - } else { - lengthChars = replyTo; - replyTo = null; - } - - if(subject==null || subject.length() == 0 || sid==null || sid.length() == 0 || lengthChars==null) { - throw new IllegalStateException("Bad MSG control line, missing required fields"); - } - - int incomingLength = parseLength(lengthChars); - - this.incoming = new NatsMessage(sid, subject, replyTo, protocolLineLength); - this.mode = Mode.GATHER_DATA; - this.msgData = new byte[incomingLength]; - this.msgDataPosition = 0; - this.msgLinePosition = 0; - break; - case NatsConnection.OP_OK: - this.connection.processOK(); - this.op = UNKNOWN_OP; - this.mode = Mode.GATHER_OP; - break; - case NatsConnection.OP_ERR: - String errorText = StandardCharsets.UTF_8.decode(protocolBuffer).toString(); - if (errorText != null) { - errorText = errorText.replace("\'", ""); - } - this.connection.processError(errorText); - this.op = UNKNOWN_OP; - this.mode = Mode.GATHER_OP; - break; - case NatsConnection.OP_PING: - this.connection.sendPong(); - this.op = UNKNOWN_OP; - this.mode = Mode.GATHER_OP; - break; - case NatsConnection.OP_PONG: - this.connection.handlePong(); - this.op = UNKNOWN_OP; - this.mode = Mode.GATHER_OP; - break; - case NatsConnection.OP_INFO: - String info = StandardCharsets.UTF_8.decode(protocolBuffer).toString(); - this.connection.handleInfo(info); - this.op = UNKNOWN_OP; - this.mode = Mode.GATHER_OP; - break; - default: - throw new IllegalStateException("Unknown protocol operation "+op); - } - } catch (IllegalStateException | NumberFormatException | NullPointerException ex) { - this.encounteredProtocolError(ex); - } - } - - void encounteredProtocolError(Exception ex) throws IOException { - throw new IOException(ex); - } - - //For testing - void fakeReadForTest(byte[] bytes) { - System.arraycopy(bytes, 0, this.buffer, 0, bytes.length); - this.bufferPosition = 0; - this.op = UNKNOWN_OP; - this.mode = Mode.GATHER_OP; - } - - String currentOp() { - return this.op; - } +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package io.nats.client.impl; + +import io.nats.client.Options; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; + +class NatsConnectionReader implements Runnable { + static final int MAX_PROTOCOL_OP_LENGTH = 4; + static final String UNKNOWN_OP = "UNKNOWN"; + static final char SPACE = ' '; + static final char TAB = '\t'; + + private final ExecutorService executor; + + private Headers headers; + private int headerLen; + private int headerStart; + + enum Mode { + GATHER_OP, + GATHER_PROTO, + GATHER_MSG_PROTO, + PARSE_PROTO, + GATHER_DATA, + GATHER_HEADER + }; + + //private final NatsConnection connection; + + private final ProtocolHandler connection; + + private ByteBuffer protocolBuffer; // use a byte buffer to assist character decoding + + private boolean gotCR; + + private String op; + private char[] opArray; + private int opPos; + + private char[] msgLineChars; + private int msgLinePosition; + + private Mode mode = Mode.GATHER_OP; + + private NatsMessage.IncomingBuilder incoming; + private byte[] msgData; + private int msgDataPosition; + + private byte[] buffer; + private int bufferPosition; + + private Future<Boolean> stopped; + private Future<DataPort> dataPortFuture; + private final AtomicBoolean running; + + private final boolean utf8Mode; + + private final NatsStatistics statistics; + + + + + NatsConnectionReader(ProtocolHandler connection, final Options options, final NatsStatistics statistics, + final ExecutorService executor) { + this.connection = connection; + this.statistics = statistics; + this.executor = executor; + + this.running = new AtomicBoolean(false); + this.stopped = new CompletableFuture<>(); + ((CompletableFuture<Boolean>)this.stopped).complete(Boolean.TRUE); // we are stopped on creation + + this.protocolBuffer = ByteBuffer.allocate(options.getMaxControlLine()); + this.msgLineChars = new char[options.getMaxControlLine()]; + this.opArray = new char[MAX_PROTOCOL_OP_LENGTH]; + this.buffer = new byte[options.getBufferSize()]; + this.bufferPosition = 0; + + this.utf8Mode = options.supportUTF8Subjects(); + } + + // Should only be called if the current thread has exited. + // Use the Future from stop() to determine if it is ok to call this. + // This method resets that future so mistiming can result in badness. + void start(Future<DataPort> dataPortFuture) { + this.dataPortFuture = dataPortFuture; + this.running.set(true); + this.stopped =this.executor.submit(this, Boolean.TRUE); + } + + // May be called several times on an error. + // Returns a future that is completed when the thread completes, not when this + // method does. + Future<Boolean> stop() { + this.running.set(false); + return stopped; + } + + @Override + public void run() { + try { + DataPort dataPort = this.dataPortFuture.get(); // Will wait for the future to complete + init(); + + while (this.running.get()) { + runOnce(dataPort); + } + } catch (IOException io) { + this.connection.handleCommunicationIssue(io); + } catch (CancellationException | ExecutionException | InterruptedException ex) { + // Exit + } finally { + this.running.set(false); + // Clear the buffers, since they are only used inside this try/catch + // We will reuse later + this.protocolBuffer.clear(); + } + } + + void init() { + this.mode = Mode.GATHER_OP; + this.gotCR = false; + this.opPos = 0; + } + + void runOnce(DataPort dataPort) throws IOException { + this.bufferPosition = 0; + int bytesRead = dataPort.read(this.buffer, 0, this.buffer.length); + + if (bytesRead > 0) { + statistics.registerRead(bytesRead); + + while (this.bufferPosition < bytesRead) { + if (this.mode == Mode.GATHER_OP) { + this.gatherOp(bytesRead); + } else if (this.mode == Mode.GATHER_MSG_PROTO) { + if (this.utf8Mode) { + this.gatherProtocol(bytesRead); + } else { + this.gatherMessageProtocol(bytesRead); + } + } else if (this.mode == Mode.GATHER_PROTO) { + this.gatherProtocol(bytesRead); + } else if (this.mode == Mode.GATHER_DATA){ + this.gatherMessageData(bytesRead); + } else if (this.mode == Mode.GATHER_HEADER) { + this.gatherHeaders(bytesRead); + } else { + this.gatherMessageData(bytesRead); + } + + if (this.mode == Mode.PARSE_PROTO) { // Could be the end of the read + this.parseProtocolMessage(); + this.protocolBuffer.clear(); + } + } + } else if (bytesRead < 0) { + throw new IOException("Read channel closed."); + } else { + statistics.registerRead(bytesRead); // track the 0 + } + } + + private void gatherHeaders(int maxPos) throws IOException { + + final int donePosition = headerLen + headerStart + 2; + + if (donePosition > bufferPosition && maxPos > bufferPosition) { + bufferPosition = donePosition; + + + boolean gotCR = false; + boolean foundKey = false; + + int startHeader = headerStart; + String key = ""; + int startValue = 0; + + for (int i = headerStart; i < donePosition; i++) { + + byte b = this.buffer[i]; + + switch (b) { + + case ' ' : + case '\t': + if (foundKey) + startValue++; + break; + + case ':' : + key = new String(buffer, startHeader, i - startHeader).intern(); + foundKey = true; + startValue = i +1; + break; + + case NatsConnection.LF: + if (gotCR && foundKey) { + String value = new String(buffer, startValue , (i-1) - startValue).intern(); + headers.add(key, value); + gotCR = false; + startValue = 0; + key = null; + startHeader = i + 1; + foundKey = false; + } + break; + + case NatsConnection.CR: + gotCR = true; + break; + } + } + } + this.mode = Mode.GATHER_DATA; + + + } + + // Gather the op, either up to the first space or the first carriage return. + void gatherOp(int maxPos) throws IOException { + try { + while(this.bufferPosition < maxPos) { + byte b = this.buffer[this.bufferPosition]; + this.bufferPosition++; + + if (gotCR) { + if (b == NatsConnection.LF) { // Got CRLF, jump to parsing + this.op = opFor(opArray, opPos); + this.gotCR = false; + this.opPos = 0; + this.mode = Mode.PARSE_PROTO; + break; + } else { + throw new IllegalStateException("Bad socket data, no LF after CR"); + } + } else if (b == SPACE || b == TAB) { // Got a space, get the rest of the protocol line + this.op = opFor(opArray, opPos); + this.opPos = 0; + if (this.op == NatsConnection.OP_MSG || this.op == NatsConnection.OP_HMSG) { + this.msgLinePosition = 0; + this.mode = Mode.GATHER_MSG_PROTO; + } else { + this.mode = Mode.GATHER_PROTO; + } + break; + } else if (b == NatsConnection.CR) { + this.gotCR = true; + } else { + this.opArray[opPos] = (char) b; + this.opPos++; + } + } + } catch (ArrayIndexOutOfBoundsException | IllegalStateException | NumberFormatException | NullPointerException ex) { + this.encounteredProtocolError(ex); + } + } + + + // Stores the message protocol line in a char buffer that will be grepped for subject, reply + void gatherMessageProtocol(int maxPos) throws IOException { + try { + while(this.bufferPosition < maxPos) { + byte b = this.buffer[this.bufferPosition]; + this.bufferPosition++; + + if (gotCR) { + if (b == NatsConnection.LF) { + this.mode = Mode.PARSE_PROTO; + this.gotCR = false; + break; + } else { + throw new IllegalStateException("Bad socket data, no LF after CR"); + } + } else if (b == NatsConnection.CR) { + this.gotCR = true; + } else { + if (this.msgLinePosition >= this.msgLineChars.length) { + throw new IllegalStateException("Protocol line is too long"); + } + this.msgLineChars[this.msgLinePosition] = (char) b; // Assumes ascii, as per protocol doc + this.msgLinePosition++; + } + } + } catch (IllegalStateException | NumberFormatException | NullPointerException ex) { + this.encounteredProtocolError(ex); + } + } + + void gatherProtocol(final int maxPos) throws IOException { + // protocol buffer has max capacity, shouldn't need resizing + try { + while(this.bufferPosition < maxPos) { + byte b = this.buffer[this.bufferPosition]; + this.bufferPosition++; + + if (gotCR) { + if (b == NatsConnection.LF) { + this.protocolBuffer.flip(); + this.mode = Mode.PARSE_PROTO; + this.gotCR = false; + break; + } else { + throw new IllegalStateException("Bad socket data, no LF after CR"); + } + } else if (b == NatsConnection.CR) { + this.gotCR = true; + } else { + if (!protocolBuffer.hasRemaining()) { + this.protocolBuffer = ByteBufferUtil.enlargeBuffer(this.protocolBuffer, 0); // just double it + } + this.protocolBuffer.put(b); + } + } + } catch (IllegalStateException | NumberFormatException | NullPointerException ex) { + this.encounteredProtocolError(ex); + } + } + + // Gather bytes for a message body into a byte array that is then + // given to the message object + void gatherMessageData(int maxPos) throws IOException { + try { + while(this.bufferPosition < maxPos) { + int possible = maxPos - this.bufferPosition; + int want = msgData.length - msgDataPosition; + + // Grab all we can, until we get to the CR/LF + if (want > 0 && want <= possible) { + System.arraycopy(this.buffer, this.bufferPosition, this.msgData, this.msgDataPosition, want); + msgDataPosition += want; + this.bufferPosition += want; + continue; + } else if (want > 0) { + System.arraycopy(this.buffer, this.bufferPosition, this.msgData, this.msgDataPosition, possible); + msgDataPosition += possible; + this.bufferPosition += possible; + continue; + } + + byte b = this.buffer[this.bufferPosition]; + this.bufferPosition++; + + if (gotCR) { + if (b == NatsConnection.LF) { +System.out.println("NCR 365 " + incoming); + incoming.headers(this.headers); + this.headers = null; + incoming.data(msgData); + this.connection.deliverMessage(incoming.build()); + msgData = null; + msgDataPosition = 0; + // TODO just for test + incoming = null; + gotCR = false; + this.op = UNKNOWN_OP; + this.mode = Mode.GATHER_OP; + break; + } else { + throw new IllegalStateException("Bad socket data, no LF after CR"); + } + } else if (b == NatsConnection.CR) { + gotCR = true; + } else { + throw new IllegalStateException("Bad socket data, no CRLF after data"); + } + } + } catch (IllegalStateException | NullPointerException ex) { + this.encounteredProtocolError(ex); + } + } + + public String grabNextMessageLineElement(int max) { + if (this.msgLinePosition >= max) { + return null; + } + + int start = this.msgLinePosition; + + while (this.msgLinePosition < max) { + char c = this.msgLineChars[this.msgLinePosition]; + this.msgLinePosition++; + + if (c == SPACE || c == TAB) { + String slice = new String(this.msgLineChars, start, this.msgLinePosition - start -1); //don't grab the space, avoid an intermediate char sequence + return slice; + } + } + + return new String(this.msgLineChars, start, this.msgLinePosition-start); + } + + public String opFor(char[] chars, int length) { + if (length == 3) { + if ((chars[0] == 'M' || chars[0] == 'm') && + (chars[1] == 'S' || chars[1] == 's') && + (chars[2] == 'G' || chars[2] == 'g')) { + return NatsConnection.OP_MSG; + } else if (chars[0] == '+' && + (chars[1] == 'O' || chars[1] == 'o') && + (chars[2] == 'K' || chars[2] == 'k')) { + return NatsConnection.OP_OK; + } else { + return UNKNOWN_OP; + } + } else if (length == 4) { // do them in a unique order for uniqueness when possible to branch asap + if ((chars[1] == 'I' || chars[1] == 'i') && + (chars[0] == 'P' || chars[0] == 'p') && + (chars[2] == 'N' || chars[2] == 'n') && + (chars[3] == 'G' || chars[3] == 'g')) { + return NatsConnection.OP_PING; + } else if ((chars[1] == 'O' || chars[1] == 'o') && + (chars[0] == 'P' || chars[0] == 'p') && + (chars[2] == 'N' || chars[2] == 'n') && + (chars[3] == 'G' || chars[3] == 'g')) { + return NatsConnection.OP_PONG; + } else if (chars[0] == '-' && + (chars[1] == 'E' || chars[1] == 'e') && + (chars[2] == 'R' || chars[2] == 'r') && + (chars[3] == 'R' || chars[3] == 'r')) { + return NatsConnection.OP_ERR; + } else if ((chars[0] == 'I' || chars[0] == 'i') && + (chars[1] == 'N' || chars[1] == 'n') && + (chars[2] == 'F' || chars[2] == 'f') && + (chars[3] == 'O' || chars[3] == 'o')) { + return NatsConnection.OP_INFO; + } else if ( + (chars[0] == 'H' || chars[0] == 'h') && + (chars[1] == 'M' || chars[1] == 'm') && + (chars[2] == 'S' || chars[2] == 's') && + (chars[3] == 'G' || chars[3] == 'g')) { + return NatsConnection.OP_HMSG; + } else { + return UNKNOWN_OP; + } + } else { + return UNKNOWN_OP; + } + } + + private static int[] TENS = new int[] { 1, 10, 100, 1_000, 10_000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000}; + + public static int parseLength(String s) throws NumberFormatException { + int length = s.length(); + int retVal = 0; + + if (length > TENS.length) { + throw new NumberFormatException("Long in message length \"" + s + "\" "+length+" > "+TENS.length); + } + + for (int i=length-1;i>=0;i--) { + char c = s.charAt(i); + int d = (c - '0'); + + if (d>9) { + throw new NumberFormatException("Invalid char in message length \'" + c + "\'"); + } + + retVal += d * TENS[length - i - 1]; + } + + return retVal; + } + + void parseProtocolMessage() throws IOException { + try { + switch (this.op) { + case NatsConnection.OP_MSG: + handleProtocolOpMsg(); + break; + case NatsConnection.OP_HMSG: + handleProtocolOpHMsg(); + break; + case NatsConnection.OP_OK: + this.connection.processOK(); + this.op = UNKNOWN_OP; + this.mode = Mode.GATHER_OP; + break; + case NatsConnection.OP_ERR: + String errorText = StandardCharsets.UTF_8.decode(protocolBuffer).toString().replace("'", ""); + this.connection.processError(errorText); + this.op = UNKNOWN_OP; + this.mode = Mode.GATHER_OP; + break; + case NatsConnection.OP_PING: + this.connection.sendPong(); + this.op = UNKNOWN_OP; + this.mode = Mode.GATHER_OP; + break; + case NatsConnection.OP_PONG: + this.connection.handlePong(); + this.op = UNKNOWN_OP; + this.mode = Mode.GATHER_OP; + break; + case NatsConnection.OP_INFO: + String info = StandardCharsets.UTF_8.decode(protocolBuffer).toString(); + this.connection.handleInfo(info); + this.op = UNKNOWN_OP; + this.mode = Mode.GATHER_OP; + break; + default: + throw new IllegalStateException("Unknown protocol operation "+op); + } + } catch (IllegalStateException | NumberFormatException | NullPointerException ex) { + this.encounteredProtocolError(ex); + } + } + + private void handleProtocolOpHMsg() { + // read headers + // read body + // create + // set incoming + // you may need to subclass or just add headers to it .. headers are more or less Map<String, List<String>> + + int protocolLength = this.msgLinePosition; //This is just after the last character + int protocolLineLength = protocolLength + 5; // 4 for the "HMSG " + + if (this.utf8Mode) { + protocolLineLength = protocolBuffer.remaining() + 5; + + CharBuffer buff = StandardCharsets.UTF_8.decode(protocolBuffer); + protocolLength = buff.remaining(); + buff.get(this.msgLineChars, 0, protocolLength); + } + + this.msgLinePosition = 0; + String subject = grabNextMessageLineElement(protocolLength); + String sid = grabNextMessageLineElement(protocolLength); + String possibleReplyTo = grabNextMessageLineElement(protocolLength); + String possibleHeaderLength = grabNextMessageLineElement(protocolLength); + String possiblePayloadLength = null; + + if (this.msgLinePosition < protocolLength) { + possiblePayloadLength = grabNextMessageLineElement(protocolLength); + } else { + possiblePayloadLength = possibleHeaderLength; + possibleHeaderLength = possibleReplyTo; + possibleReplyTo = null; + } + + if (subject==null || subject.length() == 0 || sid == null || sid.length() == 0 + || possiblePayloadLength == null || possibleHeaderLength == null) { + throw new IllegalStateException("Bad HMSG control line, missing required fields"); + } + + final String replyTo = possibleReplyTo; + final String headerLength = possibleHeaderLength; + final String payloadLength = possiblePayloadLength; + + int headerLen = parseLength(headerLength); + int payloadLen = parseLength(payloadLength); + + this.incoming = new NatsMessage.IncomingBuilder() + .sid(sid).subject(subject).replyTo(replyTo).protocolLength(protocolLineLength); + this.mode = Mode.GATHER_HEADER; + this.headerLen = headerLen; + this.headerStart = this.bufferPosition; + + if (headerLen > 0) { + this.headers = new Headers(); + } + this.msgData = new byte[payloadLen - headerLen]; + this.msgDataPosition = 0; + this.msgLinePosition = 0; + } + + private void handleProtocolOpMsg() { + int protocolLength = this.msgLinePosition; //This is just after the last character + int protocolLineLength = protocolLength + 4; // 4 for the "MSG " + + if (this.utf8Mode) { + protocolLineLength = protocolBuffer.remaining() + 4; + + CharBuffer buff = StandardCharsets.UTF_8.decode(protocolBuffer); + protocolLength = buff.remaining(); + buff.get(this.msgLineChars, 0, protocolLength); + } + + this.msgLinePosition = 0; + String subject = grabNextMessageLineElement(protocolLength); + String sid = grabNextMessageLineElement(protocolLength); + String replyTo = grabNextMessageLineElement(protocolLength); + String lengthChars = null; + + if (this.msgLinePosition < protocolLength) { + lengthChars = grabNextMessageLineElement(protocolLength); + } else { + lengthChars = replyTo; + replyTo = null; + } + + if(subject==null || subject.length() == 0 || sid==null || sid.length() == 0 || lengthChars==null) { + throw new IllegalStateException("Bad MSG control line, missing required fields"); + } + + int incomingLength = parseLength(lengthChars); + + this.incoming = new NatsMessage.IncomingBuilder() + .sid(sid).subject(subject).replyTo(replyTo).protocolLength(protocolLineLength); + + this.mode = Mode.GATHER_DATA; + this.msgData = new byte[incomingLength]; + this.msgDataPosition = 0; + this.msgLinePosition = 0; + return; + } + + + void encounteredProtocolError(Exception ex) throws IOException { + throw new IOException(ex); + } + + //For testing + void fakeReadForTest(byte[] bytes) { + System.arraycopy(bytes, 0, this.buffer, 0, bytes.length); + this.bufferPosition = 0; + this.op = UNKNOWN_OP; + this.mode = Mode.GATHER_OP; + } + + String currentOp() { + return this.op; + } + + public Mode getMode() { + return mode; + } + + public NatsMessage getIncoming() { + return incoming == null ? null : incoming.build(); + } } \ No newline at end of file diff --git a/src/main/java/io/nats/client/impl/NatsConnectionWriter.java b/src/main/java/io/nats/client/impl/NatsConnectionWriter.java index 5d2f280e3..9b96a2624 100644 --- a/src/main/java/io/nats/client/impl/NatsConnectionWriter.java +++ b/src/main/java/io/nats/client/impl/NatsConnectionWriter.java @@ -13,6 +13,8 @@ package io.nats.client.impl; +import io.nats.client.Options; + import java.io.IOException; import java.nio.BufferOverflowException; import java.nio.charset.StandardCharsets; @@ -25,8 +27,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantLock; -import io.nats.client.Options; - class NatsConnectionWriter implements Runnable { private final NatsConnection connection; @@ -137,7 +137,7 @@ public void run() { dataPort.write(sendBuffer, sendPosition); connection.getNatsStatistics().registerWrite(sendPosition); sendPosition = 0; - msg = msg.next; + msg = msg.getNext(); if (msg == null) { break; @@ -146,14 +146,20 @@ public void run() { } byte[] bytes = msg.getProtocolBytes(); + //System.out.println(new String(bytes, StandardCharsets.UTF_8)); System.arraycopy(bytes, 0, sendBuffer, sendPosition, bytes.length); sendPosition += bytes.length; sendBuffer[sendPosition++] = '\r'; sendBuffer[sendPosition++] = '\n'; + + if (!msg.isProtocol()) { bytes = msg.getData(); + + + //System.out.println(new String(bytes, StandardCharsets.UTF_8)); System.arraycopy(bytes, 0, sendBuffer, sendPosition, bytes.length); sendPosition += bytes.length; @@ -164,7 +170,7 @@ public void run() { stats.incrementOutMsgs(); stats.incrementOutBytes(size); - msg = msg.next; + msg = msg.getNext(); } dataPort.write(sendBuffer, sendPosition); diff --git a/src/main/java/io/nats/client/impl/NatsMessage.java b/src/main/java/io/nats/client/impl/NatsMessage.java index 66546e6c2..bb097f1f2 100644 --- a/src/main/java/io/nats/client/impl/NatsMessage.java +++ b/src/main/java/io/nats/client/impl/NatsMessage.java @@ -1,198 +1,350 @@ -// Copyright 2015-2018 The NATS Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package io.nats.client.impl; - -import java.nio.ByteBuffer; -import java.nio.CharBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; - -import io.nats.client.Connection; -import io.nats.client.Message; -import io.nats.client.Subscription; - -class NatsMessage implements Message { - private String sid; - private String subject; - private String replyTo; - private byte[] data; - private byte[] protocolBytes; - private NatsSubscription subscription; - private long sizeInBytes; - - NatsMessage next; // for linked list - - static final byte[] digits = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}; - - static int copy(byte[] dest, int pos, String toCopy) { - for (int i=0, max=toCopy.length(); i<max ;i++) { - dest[pos] = (byte) toCopy.charAt(i); - pos++; - } - - return pos; - } - - private static String PUB_SPACE = NatsConnection.OP_PUB + " "; - private static String SPACE = " "; - - // Create a message to publish - NatsMessage(String subject, String replyTo, byte[] data, boolean utf8mode) { - this.subject = subject; - this.replyTo = replyTo; - this.data = data; - - if (utf8mode) { - int subjectSize = subject.length() * 2; - int replySize = (replyTo != null) ? replyTo.length() * 2 : 0; - StringBuilder protocolStringBuilder = new StringBuilder(4 + subjectSize + 1 + replySize + 1); - protocolStringBuilder.append(PUB_SPACE); - protocolStringBuilder.append(subject); - protocolStringBuilder.append(SPACE); - - if (replyTo != null) { - protocolStringBuilder.append(replyTo); - protocolStringBuilder.append(SPACE); - } - - protocolStringBuilder.append(String.valueOf(data.length)); - - this.protocolBytes = protocolStringBuilder.toString().getBytes(StandardCharsets.UTF_8); - } else { - // Convert the length to bytes - byte[] lengthBytes = new byte[12]; - int idx = lengthBytes.length; - int size = (data != null) ? data.length : 0; - - if (size > 0) { - for (int i = size; i > 0; i /= 10) { - idx--; - lengthBytes[idx] = digits[i % 10]; - } - } else { - idx--; - lengthBytes[idx] = digits[0]; - } - - // Build the array - int len = 4 + subject.length() + 1 + (lengthBytes.length - idx); - - if (replyTo != null) { - len += replyTo.length() + 1; - } - - this.protocolBytes = new byte[len]; - - // Copy everything - int pos = 0; - protocolBytes[0] = 'P'; - protocolBytes[1] = 'U'; - protocolBytes[2] = 'B'; - protocolBytes[3] = ' '; - pos = 4; - pos = copy(protocolBytes, pos, subject); - protocolBytes[pos] = ' '; - pos++; - - if (replyTo != null) { - pos = copy(protocolBytes, pos, replyTo); - protocolBytes[pos] = ' '; - pos++; - } - - System.arraycopy(lengthBytes, idx, protocolBytes, pos, lengthBytes.length - idx); - } - - this.sizeInBytes = this.protocolBytes.length + data.length + 4;// for 2x \r\n - } - - // Create a protocol only message to publish - NatsMessage(CharBuffer protocol) { - ByteBuffer byteBuffer = StandardCharsets.UTF_8.encode(protocol); - this.protocolBytes = Arrays.copyOfRange(byteBuffer.array(), byteBuffer.position(), byteBuffer.limit()); - Arrays.fill(byteBuffer.array(), (byte) 0); // clear sensitive data - this.sizeInBytes = this.protocolBytes.length + 2;// for \r\n - } - - // Create an incoming message for a subscriber - // Doesn't check controlline size, since the server sent us the message - NatsMessage(String sid, String subject, String replyTo, int protocolLength) { - this.sid = sid; - this.subject = subject; - if (replyTo != null) { - this.replyTo = replyTo; - } - this.sizeInBytes = protocolLength + 2; - this.data = null; // will set data and size after we read it - } - - boolean isProtocol() { - return this.subject == null; - } - - // Will be null on an incoming message - byte[] getProtocolBytes() { - return this.protocolBytes; - } - - int getControlLineLength() { - return (this.protocolBytes != null) ? this.protocolBytes.length + 2 : -1; - } - - long getSizeInBytes() { - return sizeInBytes; - } - - public String getSID() { - return this.sid; - } - - // Only for incoming messages, with no protocol bytes - void setData(byte[] data) { - this.data = data; - this.sizeInBytes += data.length + 2;// for \r\n, we already set the length for the protocol bytes in the constructor - } - - void setSubscription(NatsSubscription sub) { - this.subscription = sub; - } - - NatsSubscription getNatsSubscription() { - return this.subscription; - } - - public Connection getConnection() { - if (this.subscription == null) { - return null; - } - - return this.subscription.connection; - } - - public String getSubject() { - return this.subject; - } - - public String getReplyTo() { - return this.replyTo; - } - - public byte[] getData() { - return this.data; - } - - public Subscription getSubscription() { - return this.subscription; - } -} +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package io.nats.client.impl; + +import io.nats.client.Message; + +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; + +public class NatsMessage implements Message { + private String sid; + private String subject; + private String replyTo; + private Headers headers; + private boolean utf8mode; + protected boolean protocol; + private byte[] data; + private byte[] protocolBytes; + private NatsSubscription subscription; + private long sizeInBytes; + private NatsMessage next; // for linked list + + static final byte[] digits = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}; + + private static final String PUB_SPACE = NatsConnection.OP_PUB + " "; + private static final String HPUB_SPACE = NatsConnection.OP_HPUB + " "; + private static final String SPACE = " "; + private static final String CRLF = "\r\n"; + private static final String COLON_SPACE = ": "; + + private NatsMessage() {} + + private abstract static class Builder<T> { + protected String subject; + protected String replyTo; + protected Headers headers; + protected boolean utf8mode; + protected byte[] data = Message.EMPTY_BODY; + + protected abstract T self(); + + public T subject(String subject) { this.subject = subject; return self(); } + public T replyTo(String replyTo) { this.replyTo = replyTo; return self(); } + public T utf8mode(boolean utf8mode) { this.utf8mode = utf8mode; return self(); } + public T headers(Headers headers) { this.headers = headers; return self(); } + + private Headers getHeaders() { + if (headers == null) { + headers = new Headers(); + } + return headers; + } + + public T addHeader(String key, String... values) { getHeaders().add(key, values); return self(); } + public T putHeader(String key, String... values) { getHeaders().put(key, values); return self(); } + + public T data(byte[] data) { + this.data = data == null ? Message.EMPTY_BODY : data; + return self(); + } + + public T data(String data, Charset charset) { + this.data = data == null ? Message.EMPTY_BODY : data.getBytes(charset); + return self(); + } + + public NatsMessage build() { + NatsMessage msg = new NatsMessage(); + msg.subject = subject; + msg.replyTo = replyTo; + msg.data = data; + msg.utf8mode = utf8mode; + msg.headers = headers; + return msg; + } + } + + public static class PublishBuilder extends Builder<PublishBuilder> { + private Long maxPayload; + + public PublishBuilder maxPayload(Long maxPayload) { this.maxPayload = maxPayload; return this; } + + public NatsMessage build() { + NatsMessage msg = super.build(); + + if (subject == null || subject.length() == 0) { + throw new IllegalArgumentException("Subject is required in publish"); + } + + if (replyTo != null && replyTo.length() == 0) { + throw new IllegalArgumentException("ReplyTo cannot be the empty string"); + } + + if (data.length > 0) { + if (maxPayload == null) { + throw new IllegalArgumentException("Max Payload must be set before build is called when there is data."); + } + + if (data.length > maxPayload && maxPayload > 0) { + throw new IllegalArgumentException( + "Message payload size exceed server configuration " + data.length + " vs " + maxPayload); + } + } + + boolean hpub = headers != null && !headers.isEmpty(); + + if (utf8mode || hpub) { + int subjectSize = subject.length() * 2; + int replySize = (replyTo != null) ? replyTo.length() * 2 : 0; + StringBuilder protocolStringBuilder = new StringBuilder(4 + subjectSize + 1 + replySize + 1); + if (hpub) { + protocolStringBuilder.append(HPUB_SPACE); + } else { + protocolStringBuilder.append(PUB_SPACE); + } + protocolStringBuilder.append(subject); + protocolStringBuilder.append(SPACE); + if (replyTo != null) { + protocolStringBuilder.append(replyTo); + protocolStringBuilder.append(SPACE); + } + if (hpub) { + int headerLength = calculateHeaderLength(headers); + protocolStringBuilder.append(headerLength); + protocolStringBuilder.append(SPACE); + protocolStringBuilder.append(data.length + headerLength); + protocolStringBuilder.append(CRLF); + outputHeaders(headers, protocolStringBuilder); + msg.headers = headers; + } else { + protocolStringBuilder.append(data.length); + } + msg.protocolBytes = protocolStringBuilder.toString().getBytes(StandardCharsets.UTF_8); + } else { + + // Convert the length to bytes + byte[] lengthBytes = new byte[12]; + int idx = lengthBytes.length; + int size = (data != null) ? data.length : 0; + + if (size > 0) { + for (int i = size; i > 0; i /= 10) { + idx--; + lengthBytes[idx] = digits[i % 10]; + } + } else { + idx--; + lengthBytes[idx] = digits[0]; + } + + // Build the array + int len = 4 + subject.length() + 1 + (lengthBytes.length - idx); + + if (replyTo != null) { + len += replyTo.length() + 1; + } + + msg.protocolBytes = new byte[len]; + + // Copy everything + msg.protocolBytes[0] = 'P'; + msg.protocolBytes[1] = 'U'; + msg.protocolBytes[2] = 'B'; + msg.protocolBytes[3] = ' '; + int pos = copy(msg.protocolBytes, 4, subject); + msg.protocolBytes[pos] = ' '; + pos++; + + if (replyTo != null) { + pos = copy(msg.protocolBytes, pos, replyTo); + msg.protocolBytes[pos] = ' '; + pos++; + } + + System.arraycopy(lengthBytes, idx, msg.protocolBytes, pos, lengthBytes.length - idx); + } + + msg.sizeInBytes = msg.protocolBytes.length + data.length + 4;// for 2x \r\n + + return msg; + } + + private void outputHeaders(Headers headers, StringBuilder protocolStringBuilder) { + if (headers != null) { + headers.keySet().forEach(key -> + headers.values(key).forEach(value -> { + protocolStringBuilder.append(key); + protocolStringBuilder.append(COLON_SPACE); + protocolStringBuilder.append(value); + protocolStringBuilder.append(CRLF); + }) + ); + } + protocolStringBuilder.append(CRLF); + } + + private int calculateHeaderLength(Headers headers) { + AtomicInteger headerLength = new AtomicInteger(2); //closing \r\n + if (headers != null) { + headers.keySet().forEach(key -> { + // each line will have the key, colon, space, value, \r\n. + int eachLen = key.length() + 4; // precalculate all but each value + headers.values(key).forEach(v -> headerLength.addAndGet(eachLen + v.length()) ); + }); + } + return headerLength.get(); + } + + @Override protected PublishBuilder self() { return this; } + } + + public static class ProtocolBuilder extends Builder<ProtocolBuilder> { + private CharBuffer buffer; + + public ProtocolBuilder protocol(CharBuffer buffer) { this.buffer = buffer; return this; } + public ProtocolBuilder protocol(String protocol) { return protocol(CharBuffer.wrap(protocol)); } + + @Override + public NatsMessage build() { + NatsMessage msg = new NatsMessage(); + msg.protocol = true; + ByteBuffer byteBuffer = StandardCharsets.UTF_8.encode(buffer); + msg.protocolBytes = Arrays.copyOfRange(byteBuffer.array(), byteBuffer.position(), byteBuffer.limit()); + Arrays.fill(byteBuffer.array(), (byte) 0); // clear sensitive data + msg.sizeInBytes = msg.protocolBytes.length + 2;// for \r\n + return msg; + } + + @Override protected ProtocolBuilder self() { return this; } + } + + public static NatsMessage getProtocolInstance(CharBuffer buffer) { return new NatsMessage.ProtocolBuilder().protocol(buffer).build(); } + public static NatsMessage getProtocolInstance(String protocol) { return new NatsMessage.ProtocolBuilder().protocol(protocol).build(); } + + public static class IncomingBuilder extends Builder<IncomingBuilder> { + private String sid; + private long protocolLineLength; + + public IncomingBuilder sid(String sid) { this.sid = sid; return this; } + public IncomingBuilder protocolLength(int len) { this.protocolLineLength = len; return this; } + + @Override + public NatsMessage build() { + NatsMessage msg = super.build(); + msg.sid = sid; + msg.sizeInBytes = protocolLineLength + 2 + data.length + 2; // for \r\n + return msg; + } + + @Override protected IncomingBuilder self() { return this; } + } + + public boolean isProtocol() { + return protocol; + } + + public int getControlLineLength() { + return this.protocolBytes == null ? -1 : this.protocolBytes.length + 2; + } + + void setSubscription(NatsSubscription sub) { + this.subscription = sub; + } + + NatsSubscription getNatsSubscription() { + return subscription; + } + + @Override + public NatsConnection getConnection() { + return subscription == null ? null : subscription.connection; + } + + @Override + public String getSubject() { + return this.subject; + } + + @Override + public String getReplyTo() { + return this.replyTo; + } + + @Override + public Headers getHeaders() { + return headers; + } + + @Override + public boolean isUtf8mode() { + return utf8mode; + } + + @Override + public byte[] getData() { + return this.data; + } + + @Override + public String getSID() { + return this.sid; + } + + @Override + public NatsSubscription getSubscription() { + return this.subscription; + } + + @Override + public byte[] getProtocolBytes() { + return this.protocolBytes; + } + + @Override + public long getSizeInBytes() { + return sizeInBytes; + } + + private static int copy(byte[] dest, int pos, String toCopy) { + for (int i = 0, max = toCopy.length(); i < max; i++) { + dest[pos] = (byte) toCopy.charAt(i); + pos++; + } + + return pos; + } + + public NatsMessage getNext() { + return next; + } + + public void setNext(NatsMessage next) { + this.next = next; + } +} diff --git a/src/main/java/io/nats/client/impl/NatsServerInfo.java b/src/main/java/io/nats/client/impl/NatsServerInfo.java index ee493e100..7dddc5a25 100644 --- a/src/main/java/io/nats/client/impl/NatsServerInfo.java +++ b/src/main/java/io/nats/client/impl/NatsServerInfo.java @@ -1,270 +1,282 @@ -// Copyright 2015-2018 The NATS Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package io.nats.client.impl; - -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -class NatsServerInfo { - - static final String SERVER_ID = "server_id"; - static final String VERSION = "version"; - static final String GO = "go"; - static final String HOST = "host"; - static final String PORT = "port"; - static final String AUTH = "auth_required"; - static final String TLS = "tls_required"; - static final String MAX_PAYLOAD = "max_payload"; - static final String CONNECT_URLS = "connect_urls"; - static final String PROTOCOL_VERSION = "proto"; - static final String NONCE = "nonce"; - static final String LAME_DUCK_MODE = "ldm"; - - private String serverId; - private String version; - private String go; - private String host; - private int port; - private boolean authRequired; - private boolean tlsRequired; - private long maxPayload; - private String[] connectURLs; - private String rawInfoJson; - private int protocolVersion; - private byte[] nonce; - private boolean lameDuckMode; - - public NatsServerInfo(String json) { - this.rawInfoJson = json; - parseInfo(json); - } - - public boolean isLameDuckMode() { - return lameDuckMode; - } - - public String getServerId() { - return this.serverId; - } - - public String getVersion() { - return this.version; - } - - public String getGoVersion() { - return this.go; - } - - public String getHost() { - return this.host; - } - - public int getPort() { - return this.port; - } - - public int getProtocolVersion() { - return this.protocolVersion; - } - - public boolean isAuthRequired() { - return this.authRequired; - } - - public boolean isTLSRequired() { - return this.tlsRequired; - } - - public long getMaxPayload() { - return this.maxPayload; - } - - public String[] getConnectURLs() { - return this.connectURLs; - } - - public byte[] getNonce() { - return this.nonce; - } - - // If parsing succeeds this is the JSON, if not this may be the full protocol line - public String getRawJson() { - return rawInfoJson; - } - - private static final String grabString = "\\s*\"(.+?)\""; - private static final String grabBoolean = "\\s*(true|false)"; - private static final String grabNumber = "\\s*(\\d+)"; - private static final String grabStringArray = "\\s*\\[(\".+?\")\\]"; - private static final String grabObject = "\\{(.+?)\\}"; - - void parseInfo(String jsonString) { - Pattern lameDuckMode = Pattern.compile("\""+LAME_DUCK_MODE+"\":" + grabBoolean, Pattern.CASE_INSENSITIVE); - Pattern serverIdRE = Pattern.compile("\""+SERVER_ID+"\":" + grabString, Pattern.CASE_INSENSITIVE); - Pattern versionRE = Pattern.compile("\""+VERSION+"\":" + grabString, Pattern.CASE_INSENSITIVE); - Pattern goRE = Pattern.compile("\""+GO+"\":" + grabString, Pattern.CASE_INSENSITIVE); - Pattern hostRE = Pattern.compile("\""+HOST+"\":" + grabString, Pattern.CASE_INSENSITIVE); - Pattern nonceRE = Pattern.compile("\""+NONCE+"\":" + grabString, Pattern.CASE_INSENSITIVE); - Pattern authRE = Pattern.compile("\""+AUTH+"\":" + grabBoolean, Pattern.CASE_INSENSITIVE); - Pattern tlsRE = Pattern.compile("\""+TLS+"\":" + grabBoolean, Pattern.CASE_INSENSITIVE); - Pattern portRE = Pattern.compile("\""+PORT+"\":" + grabNumber, Pattern.CASE_INSENSITIVE); - Pattern maxRE = Pattern.compile("\""+MAX_PAYLOAD+"\":" + grabNumber, Pattern.CASE_INSENSITIVE); - Pattern protoRE = Pattern.compile("\""+PROTOCOL_VERSION+"\":" + grabNumber, Pattern.CASE_INSENSITIVE); - Pattern connectRE = Pattern.compile("\""+CONNECT_URLS+"\":" + grabStringArray, Pattern.CASE_INSENSITIVE); - Pattern infoObject = Pattern.compile(grabObject, Pattern.CASE_INSENSITIVE); - - Matcher m = infoObject.matcher(jsonString); - if (m.find()) { - jsonString = m.group(0); - this.rawInfoJson = jsonString; - } else { - jsonString = ""; - } - - if (jsonString.length() < 2) { - throw new IllegalArgumentException("Server info requires at least {}."); - } else if (jsonString.charAt(0) != '{' || jsonString.charAt(jsonString.length()-1) != '}') { - throw new IllegalArgumentException("Server info should be JSON wrapped with { and }."); - } - - m = serverIdRE.matcher(jsonString); - if (m.find()) { - this.serverId = unescapeString(m.group(1)); - } - - m = versionRE.matcher(jsonString); - if (m.find()) { - this.version = unescapeString(m.group(1)); - } - - m = goRE.matcher(jsonString); - if (m.find()) { - this.go = unescapeString(m.group(1)); - } - - m = hostRE.matcher(jsonString); - if (m.find()) { - this.host = unescapeString(m.group(1)); - } - - m = authRE.matcher(jsonString); - if (m.find()) { - this.authRequired = Boolean.parseBoolean(m.group(1)); - } - - m = nonceRE.matcher(jsonString); - if (m.find()) { - String encodedNonce = m.group(1); - this.nonce = encodedNonce.getBytes(StandardCharsets.US_ASCII); - } - - m = tlsRE.matcher(jsonString); - if (m.find()) { - this.tlsRequired = Boolean.parseBoolean(m.group(1)); - } - - m = lameDuckMode.matcher(jsonString); - if (m.find()) { - this.lameDuckMode = Boolean.parseBoolean(m.group(1)); - } - - - m = portRE.matcher(jsonString); - if (m.find()) { - this.port = Integer.parseInt(m.group(1)); - } - - m = protoRE.matcher(jsonString); - if (m.find()) { - this.protocolVersion = Integer.parseInt(m.group(1)); - } - - m = maxRE.matcher(jsonString); - if (m.find()) { - this.maxPayload = Long.parseLong(m.group(1)); - } - - m = connectRE.matcher(jsonString); - if (m.find()) { - String arrayString = m.group(1); - String[] raw = arrayString.split(","); - ArrayList<String> urls = new ArrayList<>(); - - for (String s : raw) { - String cleaned = s.trim().replace("\"", "");; - if (cleaned.length() > 0) { - urls.add(cleaned); - } - } - - this.connectURLs = urls.toArray(new String[0]); - } - } - - // See https://gist.github.com/uklimaschewski/6741769, no license required - // Removed octal support - String unescapeString(String st) { - - StringBuilder sb = new StringBuilder(st.length()); - - for (int i = 0; i < st.length(); i++) { - char ch = st.charAt(i); - if (ch == '\\') { - char nextChar = (i == st.length() - 1) ? '\\' : st.charAt(i + 1); - switch (nextChar) { - case '\\': - ch = '\\'; - break; - case 'b': - ch = '\b'; - break; - case 'f': - ch = '\f'; - break; - case 'n': - ch = '\n'; - break; - case 'r': - ch = '\r'; - break; - case 't': - ch = '\t'; - break; - /*case '\"': - ch = '\"'; - break; - case '\'': - ch = '\''; - break;*/ - // Hex Unicode: u???? - case 'u': - if (i >= st.length() - 5) { - ch = 'u'; - break; - } - int code = Integer.parseInt( - "" + st.charAt(i + 2) + st.charAt(i + 3) + st.charAt(i + 4) + st.charAt(i + 5), 16); - sb.append(Character.toChars(code)); - i += 5; - continue; - } - i++; - } - sb.append(ch); - } - return sb.toString(); - } +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package io.nats.client.impl; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +class NatsServerInfo { + + static final String SERVER_ID = "server_id"; + static final String VERSION = "version"; + static final String GO = "go"; + static final String HOST = "host"; + static final String PORT = "port"; + static final String AUTH = "auth_required"; + static final String TLS = "tls_required"; + static final String MAX_PAYLOAD = "max_payload"; + static final String CONNECT_URLS = "connect_urls"; + static final String PROTOCOL_VERSION = "proto"; + static final String NONCE = "nonce"; + static final String LAME_DUCK_MODE_KEY = "ldm"; + static final String HEADERS_KEY = "headers"; + + + private String serverId; + private String version; + private String go; + private String host; + private int port; + private boolean authRequired; + private boolean tlsRequired; + private long maxPayload; + private String[] connectURLs; + private String rawInfoJson; + private int protocolVersion; + private byte[] nonce; + private boolean lameDuckMode; + private boolean headers; + + public NatsServerInfo(String json) { + this.rawInfoJson = json; + parseInfo(json); + } + + public boolean isLameDuckMode() { + return lameDuckMode; + } + + public boolean isHeaders() { + return headers; + } + + public String getServerId() { + return this.serverId; + } + + public String getVersion() { + return this.version; + } + + public String getGoVersion() { + return this.go; + } + + public String getHost() { + return this.host; + } + + public int getPort() { + return this.port; + } + + public int getProtocolVersion() { + return this.protocolVersion; + } + + public boolean isAuthRequired() { + return this.authRequired; + } + + public boolean isTLSRequired() { + return this.tlsRequired; + } + + public long getMaxPayload() { + return this.maxPayload; + } + + public String[] getConnectURLs() { + return this.connectURLs; + } + + public byte[] getNonce() { + return this.nonce; + } + + // If parsing succeeds this is the JSON, if not this may be the full protocol line + public String getRawJson() { + return rawInfoJson; + } + + private static final String grabString = "\\s*\"(.+?)\""; + private static final String grabBoolean = "\\s*(true|false)"; + private static final String grabNumber = "\\s*(\\d+)"; + private static final String grabStringArray = "\\s*\\[(\".+?\")\\]"; + private static final String grabObject = "\\{(.+?)\\}"; + + void parseInfo(String jsonString) { + Pattern headersMode = Pattern.compile("\""+ HEADERS_KEY +"\":" + grabBoolean, Pattern.CASE_INSENSITIVE); + Pattern lameDuckMode = Pattern.compile("\""+ LAME_DUCK_MODE_KEY +"\":" + grabBoolean, Pattern.CASE_INSENSITIVE); + Pattern serverIdRE = Pattern.compile("\""+SERVER_ID+"\":" + grabString, Pattern.CASE_INSENSITIVE); + Pattern versionRE = Pattern.compile("\""+VERSION+"\":" + grabString, Pattern.CASE_INSENSITIVE); + Pattern goRE = Pattern.compile("\""+GO+"\":" + grabString, Pattern.CASE_INSENSITIVE); + Pattern hostRE = Pattern.compile("\""+HOST+"\":" + grabString, Pattern.CASE_INSENSITIVE); + Pattern nonceRE = Pattern.compile("\""+NONCE+"\":" + grabString, Pattern.CASE_INSENSITIVE); + Pattern authRE = Pattern.compile("\""+AUTH+"\":" + grabBoolean, Pattern.CASE_INSENSITIVE); + Pattern tlsRE = Pattern.compile("\""+TLS+"\":" + grabBoolean, Pattern.CASE_INSENSITIVE); + Pattern portRE = Pattern.compile("\""+PORT+"\":" + grabNumber, Pattern.CASE_INSENSITIVE); + Pattern maxRE = Pattern.compile("\""+MAX_PAYLOAD+"\":" + grabNumber, Pattern.CASE_INSENSITIVE); + Pattern protoRE = Pattern.compile("\""+PROTOCOL_VERSION+"\":" + grabNumber, Pattern.CASE_INSENSITIVE); + Pattern connectRE = Pattern.compile("\""+CONNECT_URLS+"\":" + grabStringArray, Pattern.CASE_INSENSITIVE); + Pattern infoObject = Pattern.compile(grabObject, Pattern.CASE_INSENSITIVE); + + Matcher m = infoObject.matcher(jsonString); + if (m.find()) { + jsonString = m.group(0); + this.rawInfoJson = jsonString; + } else { + jsonString = ""; + } + + if (jsonString.length() < 2) { + throw new IllegalArgumentException("Server info requires at least {}."); + } else if (jsonString.charAt(0) != '{' || jsonString.charAt(jsonString.length()-1) != '}') { + throw new IllegalArgumentException("Server info should be JSON wrapped with { and }."); + } + + m = serverIdRE.matcher(jsonString); + if (m.find()) { + this.serverId = unescapeString(m.group(1)); + } + + m = versionRE.matcher(jsonString); + if (m.find()) { + this.version = unescapeString(m.group(1)); + } + + m = goRE.matcher(jsonString); + if (m.find()) { + this.go = unescapeString(m.group(1)); + } + + m = hostRE.matcher(jsonString); + if (m.find()) { + this.host = unescapeString(m.group(1)); + } + + m = authRE.matcher(jsonString); + if (m.find()) { + this.authRequired = Boolean.parseBoolean(m.group(1)); + } + + m = nonceRE.matcher(jsonString); + if (m.find()) { + String encodedNonce = m.group(1); + this.nonce = encodedNonce.getBytes(StandardCharsets.US_ASCII); + } + + m = tlsRE.matcher(jsonString); + if (m.find()) { + this.tlsRequired = Boolean.parseBoolean(m.group(1)); + } + + m = lameDuckMode.matcher(jsonString); + if (m.find()) { + this.lameDuckMode = Boolean.parseBoolean(m.group(1)); + } + + m = headersMode.matcher(jsonString); + if (m.find()) { + this.headers = Boolean.parseBoolean(m.group(1)); + } + + m = portRE.matcher(jsonString); + if (m.find()) { + this.port = Integer.parseInt(m.group(1)); + } + + m = protoRE.matcher(jsonString); + if (m.find()) { + this.protocolVersion = Integer.parseInt(m.group(1)); + } + + m = maxRE.matcher(jsonString); + if (m.find()) { + this.maxPayload = Long.parseLong(m.group(1)); + } + + m = connectRE.matcher(jsonString); + if (m.find()) { + String arrayString = m.group(1); + String[] raw = arrayString.split(","); + ArrayList<String> urls = new ArrayList<>(); + + for (String s : raw) { + String cleaned = s.trim().replace("\"", "");; + if (cleaned.length() > 0) { + urls.add(cleaned); + } + } + + this.connectURLs = urls.toArray(new String[0]); + } + } + + // See https://gist.github.com/uklimaschewski/6741769, no license required + // Removed octal support + String unescapeString(String st) { + + StringBuilder sb = new StringBuilder(st.length()); + + for (int i = 0; i < st.length(); i++) { + char ch = st.charAt(i); + if (ch == '\\') { + char nextChar = (i == st.length() - 1) ? '\\' : st.charAt(i + 1); + switch (nextChar) { + case '\\': + ch = '\\'; + break; + case 'b': + ch = '\b'; + break; + case 'f': + ch = '\f'; + break; + case 'n': + ch = '\n'; + break; + case 'r': + ch = '\r'; + break; + case 't': + ch = '\t'; + break; + /*case '\"': + ch = '\"'; + break; + case '\'': + ch = '\''; + break;*/ + // Hex Unicode: u???? + case 'u': + if (i >= st.length() - 5) { + ch = 'u'; + break; + } + int code = Integer.parseInt( + "" + st.charAt(i + 2) + st.charAt(i + 3) + st.charAt(i + 4) + st.charAt(i + 5), 16); + sb.append(Character.toChars(code)); + i += 5; + continue; + } + i++; + } + sb.append(ch); + } + return sb.toString(); + } } \ No newline at end of file diff --git a/src/main/java/io/nats/client/impl/ProtocolHandler.java b/src/main/java/io/nats/client/impl/ProtocolHandler.java new file mode 100644 index 000000000..d5264366b --- /dev/null +++ b/src/main/java/io/nats/client/impl/ProtocolHandler.java @@ -0,0 +1,18 @@ +package io.nats.client.impl; + +public interface ProtocolHandler { + + void handleCommunicationIssue(Exception io); + + void deliverMessage(NatsMessage msg); + + void processOK(); + + void processError(String errorText); + + void sendPong(); + + void handlePong(); + + void handleInfo(String infoJson); +} diff --git a/src/test/java/io/nats/client/NatsServerProtocolMock.java b/src/test/java/io/nats/client/NatsServerProtocolMock.java index 0f2c35fd1..53948de93 100644 --- a/src/test/java/io/nats/client/NatsServerProtocolMock.java +++ b/src/test/java/io/nats/client/NatsServerProtocolMock.java @@ -112,7 +112,7 @@ private void start() { Thread t = new Thread(() -> {accept();}); t.start(); try { - Thread.sleep(100); + Thread.sleep(1000); } catch (Exception exp) { //Give the server time to get going } diff --git a/src/test/java/io/nats/client/impl/HeadersTests.java b/src/test/java/io/nats/client/impl/HeadersTests.java new file mode 100644 index 000000000..9d3da3d13 --- /dev/null +++ b/src/test/java/io/nats/client/impl/HeadersTests.java @@ -0,0 +1,227 @@ +package io.nats.client.impl; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Set; +import java.util.function.Consumer; + +import static org.junit.Assert.*; + +public class HeadersTests { + private static final String KEY1 = "key1"; + private static final String KEY2 = "key2"; + private static final String KEY3 = "key3"; + private static final String VAL1 = "val1"; + private static final String VAL2 = "val2"; + private static final String VAL3 = "val3"; + private static final String EMPTY = ""; + private static final String ANY_VAL = "matters-it-does-not"; + + @Test + public void add_key_strings_works() { + add( + headers -> headers.add(KEY1, VAL1), + headers -> headers.add(KEY1, VAL2), + headers -> headers.add(KEY2, VAL3)); + } + + @Test + public void add_key_collection_works() { + add( + headers -> headers.add(KEY1, Collections.singletonList(VAL1)), + headers -> headers.add(KEY1, Collections.singletonList(VAL2)), + headers -> headers.add(KEY2, Collections.singletonList(VAL3))); + } + + private void add( + Consumer<Headers> stepKey1Val1, + Consumer<Headers> step2Key1Val2, + Consumer<Headers> step3Key2Val3) + { + Headers headers = new Headers(); + + stepKey1Val1.accept(headers); + assertEquals(1, headers.size()); + assertTrue(headers.containsKey(KEY1)); + assertContainsExactly(headers.values(KEY1), VAL1); + + step2Key1Val2.accept(headers); + assertEquals(1, headers.size()); + assertTrue(headers.containsKey(KEY1)); + assertContainsExactly(headers.values(KEY1), VAL1, VAL2); + + step3Key2Val3.accept(headers); + assertEquals(2, headers.size()); + assertTrue(headers.containsKey(KEY1)); + assertTrue(headers.containsKey(KEY2)); + assertContainsExactly(headers.values(KEY1), VAL1, VAL2); + assertContainsExactly(headers.values(KEY2), VAL3); + } + + @Test + public void set_key_strings_works() { + set( + headers -> headers.put(KEY1, VAL1), + headers -> headers.put(KEY1, VAL2), + headers -> headers.put(KEY2, VAL3)); + } + + @Test + public void set_key_collection_works() { + set( + headers -> headers.put(KEY1, Collections.singletonList(VAL1)), + headers -> headers.put(KEY1, Collections.singletonList(VAL2)), + headers -> headers.put(KEY2, Collections.singletonList(VAL3))); + } + + private void set( + Consumer<Headers> stepKey1Val1, + Consumer<Headers> step2Key1Val2, + Consumer<Headers> step3Key2Val3) + { + Headers headers = new Headers(); + assertTrue(headers.isEmpty()); + + stepKey1Val1.accept(headers); + assertEquals(1, headers.size()); + assertEquals(1, headers.keySet().size()); + assertTrue(headers.containsKey(KEY1)); + assertTrue(headers.keySet().contains(KEY1)); + assertContainsExactly(headers.values(KEY1), VAL1); + + step2Key1Val2.accept(headers); + assertEquals(1, headers.size()); + assertEquals(1, headers.keySet().size()); + assertTrue(headers.containsKey(KEY1)); + assertTrue(headers.keySet().contains(KEY1)); + assertContainsExactly(headers.values(KEY1), VAL2); + + step3Key2Val3.accept(headers); + assertEquals(2, headers.size()); + assertEquals(2, headers.keySet().size()); + assertTrue(headers.containsKey(KEY1)); + assertTrue(headers.containsKey(KEY2)); + assertTrue(headers.keySet().contains(KEY1)); + assertTrue(headers.keySet().contains(KEY2)); + assertContainsExactly(headers.values(KEY1), VAL2); + assertContainsExactly(headers.values(KEY2), VAL3); + } + + @Test + public void keyCannotBeNullOrEmpty() { + Headers headers = new Headers(); + assertIllegalArgument(() -> headers.put(null, VAL1)); + assertIllegalArgument(() -> headers.put(null, VAL1, VAL2)); + assertIllegalArgument(() -> headers.put(null, Collections.singletonList(VAL1))); + assertIllegalArgument(() -> headers.put(EMPTY, VAL1)); + assertIllegalArgument(() -> headers.put(EMPTY, VAL1, VAL2)); + assertIllegalArgument(() -> headers.put(EMPTY, Collections.singletonList(VAL1))); + assertIllegalArgument(() -> headers.add(null, VAL1)); + assertIllegalArgument(() -> headers.add(null, VAL1, VAL2)); + assertIllegalArgument(() -> headers.add(null, Collections.singletonList(VAL1))); + assertIllegalArgument(() -> headers.add(EMPTY, VAL1)); + assertIllegalArgument(() -> headers.add(EMPTY, VAL1, VAL2)); + assertIllegalArgument(() -> headers.add(EMPTY, Collections.singletonList(VAL1))); + } + + @Test + public void valuesCannotBeNullOrEmpty() { + Headers headers = new Headers(); + + assertIllegalArgument(() -> headers.put(KEY1, (String) null)); + assertIllegalArgument(() -> headers.put(KEY1, EMPTY)); + assertIllegalArgument(() -> headers.put(KEY1, ANY_VAL, EMPTY)); + assertIllegalArgument(() -> headers.put(KEY1, ANY_VAL, null)); + assertIllegalArgument(() -> headers.put(KEY1, (Collection<String>) null)); + assertIllegalArgument(() -> headers.put(KEY1, Arrays.asList(KEY1, EMPTY))); + assertIllegalArgument(() -> headers.put(KEY1, Arrays.asList(KEY1, null))); + + assertIllegalArgument(() -> headers.add(KEY1, (String) null)); + assertIllegalArgument(() -> headers.add(KEY1, EMPTY)); + assertIllegalArgument(() -> headers.add(KEY1, ANY_VAL, EMPTY)); + assertIllegalArgument(() -> headers.add(KEY1, ANY_VAL, null)); + assertIllegalArgument(() -> headers.add(KEY1, (Collection<String>) null)); + assertIllegalArgument(() -> headers.add(KEY1, Arrays.asList(KEY1, EMPTY))); + assertIllegalArgument(() -> headers.add(KEY1, Arrays.asList(KEY1, null))); + } + + @Test + public void removes_work() { + Headers headers = testHeaders(); + assertTrue(headers.remove(KEY1)); + assertFalse(headers.remove(KEY1)); + assertContainsKeysExactly(headers, KEY2, KEY3); + + headers = testHeaders(); + assertTrue(headers.remove(KEY2, KEY3)); + assertFalse(headers.remove(KEY2, KEY3)); + assertContainsKeysExactly(headers, KEY1); + + headers = testHeaders(); + assertTrue(headers.remove(Collections.singletonList(KEY1))); + assertFalse(headers.remove(Collections.singletonList(KEY1))); + assertContainsKeysExactly(headers, KEY2, KEY3); + + headers = testHeaders(); + assertTrue(headers.remove(Arrays.asList(KEY2, KEY3))); + assertFalse(headers.remove(Arrays.asList(KEY2, KEY3))); + assertContainsKeysExactly(headers, KEY1); + } + + @Test + public void equalsHashcodeClearSizeEmpty_work() { + assertEquals(testHeaders(), testHeaders()); + assertEquals(testHeaders().hashCode(), testHeaders().hashCode()); + + Headers headers1 = new Headers(); + headers1.put(KEY1, VAL1); + Headers headers2 = new Headers(); + headers2.put(KEY2, VAL2); + assertNotEquals(headers1, headers2); + assertEquals(headers1.hashCode(), headers2.hashCode()); + + assertEquals(1, headers1.size()); + assertFalse(headers1.isEmpty()); + headers1.clear(); + assertEquals(0, headers1.size()); + assertTrue(headers1.isEmpty()); + } + + private Headers testHeaders() { + Headers headers = new Headers(); + headers.put(KEY1, VAL1); + headers.put(KEY2, VAL2); + headers.put(KEY3, VAL3); + return headers; + } + + // assert macros + interface IllegalArgumentHandler { + void execute(); + } + + private void assertIllegalArgument(IllegalArgumentHandler handler) { + try { + handler.execute(); + fail("IllegalArgumentException was expected to be thrown"); + } catch (IllegalArgumentException ignored) {} + } + + private void assertContainsExactly(Set<String> actual, String... expected) { + assertNotNull(actual); + assertEquals(actual.size(), expected.length); + for (String v : expected) { + assertTrue(actual.contains(v)); + } + } + + private void assertContainsKeysExactly(Headers header, String... expected) { + assertEquals(header.size(), expected.length); + for (String key : expected) { + assertTrue(header.containsKey(key)); + } + } +} diff --git a/src/test/java/io/nats/client/impl/InfoHandlerTests.java b/src/test/java/io/nats/client/impl/InfoHandlerTests.java index 52360b646..8063036eb 100644 --- a/src/test/java/io/nats/client/impl/InfoHandlerTests.java +++ b/src/test/java/io/nats/client/impl/InfoHandlerTests.java @@ -106,7 +106,7 @@ public void testUnsolicitedInfo() throws IOException, InterruptedException, Exec - @Test + //@Test public void testLDM() throws IOException, InterruptedException, ExecutionException, TimeoutException { String customInfo = "{\"server_id\":\"myid\", \"ldm\":true}"; CompletableFuture<Boolean> gotPong = new CompletableFuture<>(); diff --git a/src/test/java/io/nats/client/impl/MessageProtocolCreationBenchmark.java b/src/test/java/io/nats/client/impl/MessageProtocolCreationBenchmark.java index 35e78a240..54fd8b4d5 100644 --- a/src/test/java/io/nats/client/impl/MessageProtocolCreationBenchmark.java +++ b/src/test/java/io/nats/client/impl/MessageProtocolCreationBenchmark.java @@ -24,12 +24,12 @@ public static void main(String args[]) throws InterruptedException { System.out.printf("### Running benchmarks with %s messages.\n", NumberFormat.getInstance().format(msgCount)); for (int j = 0; j < warmup; j++) { - new NatsMessage("subject", "replyTo", NatsConnection.EMPTY_BODY, true); + new NatsMessage.PublishBuilder().subject("subject").replyTo("replyTo").utf8mode(true).build(); } long start = System.nanoTime(); for (int j = 0; j < msgCount; j++) { - new NatsMessage("subject", "replyTo", NatsConnection.EMPTY_BODY, false); + new NatsMessage.PublishBuilder().subject("subject").replyTo("replyTo").build(); } long end = System.nanoTime(); @@ -41,7 +41,7 @@ public static void main(String args[]) throws InterruptedException { start = System.nanoTime(); for (int j = 0; j < msgCount; j++) { - new NatsMessage("subject", "replyTo", NatsConnection.EMPTY_BODY, true); + new NatsMessage.PublishBuilder().subject("subject").replyTo("replyTo").utf8mode(true).build(); } end = System.nanoTime(); @@ -53,7 +53,7 @@ public static void main(String args[]) throws InterruptedException { start = System.nanoTime(); for (int j = 0; j < msgCount; j++) { - new NatsMessage(CharBuffer.allocate(0)); + NatsMessage.getProtocolInstance(CharBuffer.allocate(0)); } end = System.nanoTime(); diff --git a/src/test/java/io/nats/client/impl/MessageQueueBenchmark.java b/src/test/java/io/nats/client/impl/MessageQueueBenchmark.java index cfe8f5c1e..d7262659e 100644 --- a/src/test/java/io/nats/client/impl/MessageQueueBenchmark.java +++ b/src/test/java/io/nats/client/impl/MessageQueueBenchmark.java @@ -30,7 +30,7 @@ public static void main(String args[]) throws InterruptedException { MessageQueue warm = new MessageQueue(false); for (int j = 0; j < msgCount; j++) { - msgs[j] = new NatsMessage(buff); + msgs[j] = NatsMessage.getProtocolInstance(buff); warm.push(msgs[j]); } @@ -64,7 +64,7 @@ public static void main(String args[]) throws InterruptedException { MessageQueue accumulateQueue = new MessageQueue(true); for (int j = 0; j < msgCount; j++) { - msgs[j].next = null; + msgs[j].setNext(null); } for (int i = 0; i < msgCount; i++) { accumulateQueue.push(msgs[i]); @@ -83,7 +83,7 @@ public static void main(String args[]) throws InterruptedException { NumberFormat.getInstance().format(1_000_000_000L * ((double) (msgCount))/((double) (end - start)))); for (int j = 0; j < msgCount; j++) { - msgs[j].next = null; + msgs[j].setNext(null); } final MessageQueue pushPopThreadQueue = new MessageQueue(false); final Duration timeout = Duration.ofMillis(10); @@ -127,7 +127,7 @@ public static void main(String args[]) throws InterruptedException { final CompletableFuture<Void> go2 = new CompletableFuture<>(); for (int j = 0; j < msgCount; j++) { - msgs[j].next = null; + msgs[j].setNext(null); } final MessageQueue pushPopNowThreadQueue = new MessageQueue(false); pusher = new Thread(() -> { @@ -169,8 +169,8 @@ public static void main(String args[]) throws InterruptedException { final CompletableFuture<Void> go3 = new CompletableFuture<>(); for (int j = 0; j < msgCount; j++) { - msgs[j].next = null; - } + msgs[j].setNext(null); + } final MessageQueue pushAccumulateThreadQueue = new MessageQueue(true); pusher = new Thread(() -> { @@ -193,7 +193,7 @@ public static void main(String args[]) throws InterruptedException { NatsMessage cursor = pushAccumulateThreadQueue.accumulate(10_000, 100, Duration.ofMillis(500)); while (cursor != null) { remaining--; - cursor = cursor.next; + cursor = cursor.getNext(); } } } catch (Exception exp) { diff --git a/src/test/java/io/nats/client/impl/MessageQueueTests.java b/src/test/java/io/nats/client/impl/MessageQueueTests.java index 9104a37f4..9041888a8 100644 --- a/src/test/java/io/nats/client/impl/MessageQueueTests.java +++ b/src/test/java/io/nats/client/impl/MessageQueueTests.java @@ -13,15 +13,9 @@ package io.nats.client.impl; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import org.junit.Test; import java.io.UnsupportedEncodingException; -import java.nio.CharBuffer; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Arrays; @@ -29,10 +23,13 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Test; +import static org.junit.Assert.*; public class MessageQueueTests { + private NatsMessage newMessage(String protocol) { return NatsMessage.getProtocolInstance(protocol); } + private NatsMessage newPingMessage() { return NatsMessage.getProtocolInstance("PING"); } + @Test public void testEmptyPop() throws InterruptedException { MessageQueue q = new MessageQueue(false); @@ -44,14 +41,14 @@ public void testEmptyPop() throws InterruptedException { @Test(expected = IllegalStateException.class) public void testAccumulateThrowsOnNonSingleReader() throws InterruptedException { MessageQueue q = new MessageQueue(false); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); q.accumulate(100,1,null); } @Test public void testPushPop() throws InterruptedException { MessageQueue q = new MessageQueue(false); - NatsMessage expected = new NatsMessage(CharBuffer.wrap("PING")); + NatsMessage expected = newPingMessage(); q.push(expected); NatsMessage actual = q.popNow(); assertEquals(expected, actual); @@ -76,7 +73,7 @@ public void testTimeout() throws InterruptedException { @Test public void testTimeoutZero() throws InterruptedException { MessageQueue q = new MessageQueue(false); - NatsMessage expected = new NatsMessage(CharBuffer.wrap("PING")); + NatsMessage expected = newPingMessage(); q.push(expected); NatsMessage msg = q.pop(Duration.ZERO); assertNotNull(msg); @@ -101,7 +98,7 @@ public void testReset() throws InterruptedException { NatsMessage msg = q.pop(Duration.ZERO); assertNull(msg); - NatsMessage expected = new NatsMessage(CharBuffer.wrap("PING")); + NatsMessage expected = newPingMessage(); q.push(expected); msg = q.pop(Duration.ZERO); @@ -120,7 +117,7 @@ public void testPopBeforeTimeout() throws InterruptedException { Thread t = new Thread(() -> { try { Thread.sleep(500); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); } catch (Exception exp) { // eat the exception, test will fail } @@ -139,7 +136,7 @@ public void testMultipleWriters() throws InterruptedException { int threads = 10; for (int i=0;i<threads;i++) { - Thread t = new Thread(() -> {q.push(new NatsMessage(CharBuffer.wrap("PING")));}); + Thread t = new Thread(() -> {q.push(newPingMessage());}); t.start(); } @@ -161,7 +158,7 @@ public void testMultipleReaders() throws InterruptedException { CountDownLatch latch = new CountDownLatch(threads); for (int i=0;i<threads;i++) { - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); } for (int i=0;i<threads;i++) { @@ -192,7 +189,7 @@ public void testMultipleReadersAndWriters() throws InterruptedException { for (int i=0;i<threads;i++) { Thread t = new Thread(() -> { for (int j=0;j<msgPerThread;j++) { - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); }}); t.start(); } @@ -228,7 +225,7 @@ public void testMultipleReaderWriters() throws InterruptedException { for (int i=0;i<threads;i++) { Thread t = new Thread(() -> { for (int j=0;j<msgPerThread;j++) { - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); try{NatsMessage msg = q.pop(Duration.ofMillis(300)); if(msg!=null){count.incrementAndGet();} latch.countDown();}catch(Exception e){} @@ -255,7 +252,7 @@ public void testEmptyAccumulate() throws InterruptedException { @Test public void testSingleAccumulate() throws InterruptedException { MessageQueue q = new MessageQueue(true); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); NatsMessage msg = q.accumulate(100,1,null); assertNotNull(msg); } @@ -263,9 +260,9 @@ public void testSingleAccumulate() throws InterruptedException { @Test public void testMultiAccumulate() throws InterruptedException { MessageQueue q = new MessageQueue(true); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); NatsMessage msg = q.accumulate(100,3,null); assertNotNull(msg); } @@ -273,7 +270,7 @@ public void testMultiAccumulate() throws InterruptedException { private void checkCount(NatsMessage first, int expected) { while (expected > 0) { assertNotNull(first); - first = first.next; + first = first.getNext(); expected--; } @@ -283,10 +280,10 @@ private void checkCount(NatsMessage first, int expected) { @Test public void testPartialAccumulateOnCount() throws InterruptedException { MessageQueue q = new MessageQueue(true); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); NatsMessage msg = q.accumulate(100,3,null); checkCount(msg, 3); @@ -297,12 +294,12 @@ public void testPartialAccumulateOnCount() throws InterruptedException { @Test public void testMultipleAccumulateOnCount() throws InterruptedException { MessageQueue q = new MessageQueue(true); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); NatsMessage msg = q.accumulate(100,2,null); checkCount(msg, 2); @@ -317,10 +314,10 @@ public void testMultipleAccumulateOnCount() throws InterruptedException { @Test public void testPartialAccumulateOnSize() throws InterruptedException { MessageQueue q = new MessageQueue(true); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); NatsMessage msg = q.accumulate(20,100,null); // each one is 6 so 20 should be 3 messages checkCount(msg, 3); @@ -331,12 +328,12 @@ public void testPartialAccumulateOnSize() throws InterruptedException { @Test public void testMultipleAccumulateOnSize() throws InterruptedException { MessageQueue q = new MessageQueue(true); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); NatsMessage msg = q.accumulate(14,100,null); // each one is 6 so 14 should be 2 messages checkCount(msg, 2); @@ -350,10 +347,10 @@ public void testMultipleAccumulateOnSize() throws InterruptedException { @Test public void testAccumulateAndPop() throws InterruptedException { MessageQueue q = new MessageQueue(true); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); + q.push(newPingMessage()); NatsMessage msg = q.accumulate(100,3,null); checkCount(msg, 3); @@ -378,7 +375,7 @@ public void testMultipleWritersOneAccumulator() throws InterruptedException { for (int i=0;i<threads;i++) { Thread t = new Thread(() -> { for (int j=0;j<msgPerThread;j++) { - q.push(new NatsMessage(CharBuffer.wrap("PING"))); + q.push(newPingMessage()); sent.incrementAndGet(); }; }); @@ -391,7 +388,7 @@ public void testMultipleWritersOneAccumulator() throws InterruptedException { while (msg != null) { count.incrementAndGet(); - msg = msg.next; + msg = msg.getNext(); } tries--; Thread.sleep(1); @@ -417,9 +414,9 @@ public void testInteruptAccumulate() throws InterruptedException { @Test public void testLength() throws InterruptedException { MessageQueue q = new MessageQueue(true); - NatsMessage msg1 = new NatsMessage(CharBuffer.wrap("PING")); - NatsMessage msg2 = new NatsMessage(CharBuffer.wrap("PING")); - NatsMessage msg3 = new NatsMessage(CharBuffer.wrap("PING")); + NatsMessage msg1 = newPingMessage(); + NatsMessage msg2 = newPingMessage(); + NatsMessage msg3 = newPingMessage(); q.push(msg1); assertEquals(1, q.length()); @@ -436,9 +433,9 @@ public void testLength() throws InterruptedException { @Test public void testSizeInBytes() throws InterruptedException { MessageQueue q = new MessageQueue(true); - NatsMessage msg1 = new NatsMessage(CharBuffer.wrap("one")); - NatsMessage msg2 = new NatsMessage(CharBuffer.wrap("two")); - NatsMessage msg3 = new NatsMessage(CharBuffer.wrap("three")); + NatsMessage msg1 = newMessage("one"); + NatsMessage msg2 = newMessage("two"); + NatsMessage msg3 = newMessage("three"); long expected = 0; q.push(msg1); expected += msg1.getSizeInBytes(); @@ -456,9 +453,9 @@ public void testSizeInBytes() throws InterruptedException { @Test public void testFilterTail() throws InterruptedException, UnsupportedEncodingException { MessageQueue q = new MessageQueue(true); - NatsMessage msg1 = new NatsMessage(CharBuffer.wrap("one")); - NatsMessage msg2 = new NatsMessage(CharBuffer.wrap("two")); - NatsMessage msg3 = new NatsMessage(CharBuffer.wrap("three")); + NatsMessage msg1 = newMessage("one"); + NatsMessage msg2 = newMessage("two"); + NatsMessage msg3 = newMessage("three"); byte[] expected = "one".getBytes(StandardCharsets.UTF_8); q.push(msg1); @@ -467,7 +464,7 @@ public void testFilterTail() throws InterruptedException, UnsupportedEncodingExc long before = q.sizeInBytes(); q.pause(); - q.filter((msg) -> {return Arrays.equals(expected, msg.getProtocolBytes());}); + q.filter((msg) -> Arrays.equals(expected, msg.getProtocolBytes()) ); q.resume(); long after = q.sizeInBytes(); @@ -480,9 +477,9 @@ public void testFilterTail() throws InterruptedException, UnsupportedEncodingExc @Test public void testFilterHead() throws InterruptedException, UnsupportedEncodingException { MessageQueue q = new MessageQueue(true); - NatsMessage msg1 = new NatsMessage(CharBuffer.wrap("one")); - NatsMessage msg2 = new NatsMessage(CharBuffer.wrap("two")); - NatsMessage msg3 = new NatsMessage(CharBuffer.wrap("three")); + NatsMessage msg1 = newMessage("one"); + NatsMessage msg2 = newMessage("two"); + NatsMessage msg3 = newMessage("three"); byte[] expected = "three".getBytes(StandardCharsets.UTF_8); q.push(msg1); @@ -504,9 +501,9 @@ public void testFilterHead() throws InterruptedException, UnsupportedEncodingExc @Test public void testFilterMiddle() throws InterruptedException, UnsupportedEncodingException { MessageQueue q = new MessageQueue(true); - NatsMessage msg1 = new NatsMessage(CharBuffer.wrap("one")); - NatsMessage msg2 = new NatsMessage(CharBuffer.wrap("two")); - NatsMessage msg3 = new NatsMessage(CharBuffer.wrap("three")); + NatsMessage msg1 = newMessage("one"); + NatsMessage msg2 = newMessage("two"); + NatsMessage msg3 = newMessage("three"); byte[] expected = "two".getBytes(StandardCharsets.UTF_8); q.push(msg1); @@ -543,9 +540,9 @@ public void testThrowOnFilterIfRunning() throws InterruptedException { @Test public void testExceptionWhenQueueIsFull() { MessageQueue q = new MessageQueue(true, 2); - NatsMessage msg1 = new NatsMessage(CharBuffer.wrap("one")); - NatsMessage msg2 = new NatsMessage(CharBuffer.wrap("two")); - NatsMessage msg3 = new NatsMessage(CharBuffer.wrap("three")); + NatsMessage msg1 = newMessage("one"); + NatsMessage msg2 = newMessage("two"); + NatsMessage msg3 = newMessage("three"); assertTrue(q.push(msg1)); assertTrue(q.push(msg2)); @@ -560,9 +557,9 @@ public void testExceptionWhenQueueIsFull() { @Test public void testDiscardMessageWhenQueueFull() { MessageQueue q = new MessageQueue(true, 2, true); - NatsMessage msg1 = new NatsMessage(CharBuffer.wrap("one")); - NatsMessage msg2 = new NatsMessage(CharBuffer.wrap("two")); - NatsMessage msg3 = new NatsMessage(CharBuffer.wrap("three")); + NatsMessage msg1 = newMessage("one"); + NatsMessage msg2 = newMessage("two"); + NatsMessage msg3 = newMessage("three"); assertTrue(q.push(msg1)); assertTrue(q.push(msg2)); diff --git a/src/test/java/io/nats/client/impl/NatsConnectionReaderTest.java b/src/test/java/io/nats/client/impl/NatsConnectionReaderTest.java new file mode 100644 index 000000000..b6ae9507a --- /dev/null +++ b/src/test/java/io/nats/client/impl/NatsConnectionReaderTest.java @@ -0,0 +1,253 @@ +package io.nats.client.impl; + +import io.nats.client.Options; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collection; + +import static org.junit.Assert.*; + +public class NatsConnectionReaderTest { + + NatsConnectionReader reader; + ProtocolHandlerMock protocolHandler; + NatsStatisticsMock natsStatistics; + + DataPortMock dataPort; + + + class ProtocolHandlerMock implements ProtocolHandler { + Exception lastException; + NatsMessage lastMessage; + String lastError; + int okCount; + int pongCount; + int pingCount; + String infoJSON; + + @Override + public void handleCommunicationIssue(Exception io) { + lastException = io; + } + + @Override + public void deliverMessage(NatsMessage msg) { + + System.out.println("################### MESSAGE " + msg); + + lastMessage = msg; + } + + @Override + public void processOK() { + okCount++; + } + + @Override + public void processError(String errorText) { + lastError = errorText; + } + + @Override + public void sendPong() { + pingCount++; + } + + @Override + public void handlePong() { + pongCount++; + } + + @Override + public void handleInfo(String infoJson) { + infoJSON = infoJson; + } + } + + + class NatsStatisticsMock extends NatsStatistics { + + public NatsStatisticsMock(boolean trackAdvanced) { + super(trackAdvanced); + } + } + + class DataPortMock implements DataPort { + + byte[] bytes = null; + + @Override + public void connect(String serverURI, NatsConnection conn, long timeoutNanos) throws IOException { + } + + @Override + public void upgradeToSecure() throws IOException { + + } + + @Override + public int read(byte[] dst, int off, int len) throws IOException { + + System.out.println("READ CALLED"); + if (bytes != null) { + if (len < bytes.length) { + System.arraycopy(bytes, 0, dst, off, len); + return len; + } else { + System.arraycopy(bytes, 0, dst, off, bytes.length); + return bytes.length; + } + } else { + return 0; + } + } + + @Override + public void write(byte[] src, int toWrite) throws IOException { + + } + + @Override + public void close() throws IOException { + + } + } + + @Before + public void setUp() throws Exception { + protocolHandler = new ProtocolHandlerMock(); + natsStatistics = new NatsStatisticsMock(false); + dataPort = new DataPortMock(); + reader = new NatsConnectionReader(protocolHandler, new Options.Builder().build(), natsStatistics, null); + reader.init(); + } + + @Test + public void runOnce() throws IOException { + reader.runOnce(dataPort); + } + + @Test + public void connect() throws IOException { + dataPort.bytes = "INFO {[\"foo\":bar]}\r\n".getBytes(StandardCharsets.UTF_8); + + for (int i = 0; i < 10; i++) { + System.out.println("" + i + " BEFORE OP " + reader.currentOp()); + System.out.println("" + i + " BEFORE MODE " + reader.getMode()); + + reader.runOnce(dataPort); + dataPort.bytes = null; + System.out.println("" + i + " AFTER OP " + reader.currentOp()); + System.out.println("" + i + " AFTER MODE " + reader.getMode()); + System.out.println("" + i + " AFTER INFO " + protocolHandler.infoJSON); + System.out.println("" + i + " AFTER MSG " + protocolHandler.lastMessage); + + if (protocolHandler.infoJSON != null) break; + } + + assertEquals("{[\"foo\":bar]}", protocolHandler.infoJSON); + assertNull(protocolHandler.lastError); + assertNull(protocolHandler.lastException); + + } + + @Test + public void message() throws IOException { + dataPort.bytes = "MSG subj sid reply-to 1\r\nA\r\n".getBytes(StandardCharsets.UTF_8); + + for (int i = 0; i < 10; i++) { + System.out.println("" + i + " BEFORE OP " + reader.currentOp()); + System.out.println("" + i + " BEFORE MODE " + reader.getMode()); + + reader.runOnce(dataPort); + dataPort.bytes = null; + System.out.println("" + i + " AFTER OP " + reader.currentOp()); + System.out.println("" + i + " AFTER MODE " + reader.getMode()); + System.out.println("" + i + " AFTER INFO " + protocolHandler.infoJSON); + System.out.println("" + i + " AFTER MSG " + protocolHandler.lastMessage); + + if (protocolHandler.lastMessage != null) break; + } + + assertNotNull(protocolHandler.lastMessage); + assertNull(protocolHandler.lastError); + assertNull(protocolHandler.lastException); + + assertEquals("subj", protocolHandler.lastMessage.getSubject()); + assertEquals("reply-to", protocolHandler.lastMessage.getReplyTo()); + assertEquals("sid", protocolHandler.lastMessage.getSID()); + assertEquals(1, protocolHandler.lastMessage.getData().length); + assertEquals('A', protocolHandler.lastMessage.getData()[0]); + + } + + + @Test + public void hMessage() throws IOException { + + final String headers = "HEADER1: VALUE1\r\n" + + "HEADER2: VALUE2\r\n"; + final int headerLength = headers.length(); + final String body = "A"; + final int bodyLength = body.length(); + final String subject = "subj"; + final String replyTo = "reply-to"; + final String sid = "sid"; + ////////////////////////////////////////////sb si rt hl tl hd payload + final String protocol = String.format("HMSG %s %s %s %s %s\r\n%s\r\n%s\r\n", + subject, sid, replyTo, headerLength, headerLength + bodyLength, headers, body); + dataPort.bytes = protocol.getBytes(StandardCharsets.UTF_8); + + + reader.runOnce(dataPort); + + assertNull(protocolHandler.lastError); + assertNull(protocolHandler.lastException); + + assertEquals("subj", protocolHandler.lastMessage.getSubject()); + assertEquals("reply-to", protocolHandler.lastMessage.getReplyTo()); + assertEquals("sid", protocolHandler.lastMessage.getSID()); + assertEquals(1, protocolHandler.lastMessage.getData().length); + assertEquals('A', protocolHandler.lastMessage.getData()[0]); + + assertEquals(2, protocolHandler.lastMessage.getHeaders().size()); + + Collection<String> values = protocolHandler.lastMessage.getHeaders().values("HEADER1"); + assertEquals(1, values.size()); + assertTrue(values.contains("VALUE1")); + + values = protocolHandler.lastMessage.getHeaders().values("HEADER2"); + assertEquals(1, values.size()); + assertTrue(values.contains("VALUE2")); + } + + @Test + public void infoLDM() throws IOException { + + final String customInfo = "{\"server_id\":\"myid\", \"ldm\":true}"; + + final String protocol = String.format("INFO %s \r\n", + customInfo); + + dataPort.bytes = protocol.getBytes(StandardCharsets.UTF_8); + + + reader.runOnce(dataPort); + + assertNull(protocolHandler.lastError); + assertNull(protocolHandler.lastException); + + assertTrue(protocolHandler.infoJSON.contains("\"ldm\":true")); + + + final NatsServerInfo natsServerInfo = new NatsServerInfo(protocolHandler.infoJSON); + + assertTrue(natsServerInfo.isLameDuckMode()); + + } + + +} \ No newline at end of file diff --git a/src/test/java/io/nats/client/impl/NatsMessagePublishBuilderTest.java b/src/test/java/io/nats/client/impl/NatsMessagePublishBuilderTest.java new file mode 100644 index 000000000..3497ff6a7 --- /dev/null +++ b/src/test/java/io/nats/client/impl/NatsMessagePublishBuilderTest.java @@ -0,0 +1,80 @@ +package io.nats.client.impl; + +import io.nats.client.Message; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; + +import static org.junit.Assert.*; + +public class NatsMessagePublishBuilderTest { + + @Test + public void testBuilderHPUB() { + + Headers headers = new Headers(); + headers.put("header1", "value1.1"); + headers.add("header1", "value1.2"); + headers.put("header2", "value2.1"); + + Message message = starterBuilder().headers(headers).build(); + + assertStarter(message, false); + + assertNotNull(message.getHeaders()); + assertEquals(2, message.getHeaders().size()); + assertEquals(2, message.getHeaders().values("header1").size()); + assertEquals(1, message.getHeaders().values("header2").size()); + + assertMessageString(message, + "HPUB subject replyTo 59 64\r\n", + "header1: value1.1\r\n", + "header1: value1.2\r\n", + "header2: value2.1\r\n", + "\r\n\r\n" + ); + } + + private void assertStarter(Message message, boolean expectedUtfMode) { + assertEquals("subject", message.getSubject()); + assertEquals("replyTo", message.getReplyTo()); + assertEquals("Hello", new String(message.getData(), StandardCharsets.UTF_8)); + assertEquals(expectedUtfMode, message.isUtf8mode()); + } + + @Test + public void testBuilderPubUTF8() { + Message message = starterBuilder().utf8mode(true).build(); + assertStarter(message, true); + assertMessageString(message, "PUB subject replyTo 5"); + } + + @Test + public void testBuilderPubNoUTF8() { + Message message = starterBuilder().build(); // default is utf8Mode false + assertStarter(message, false); + assertMessageString(message, "PUB subject replyTo 5"); + + message = starterBuilder().utf8mode(false).build(); + assertStarter(message, false); + assertMessageString(message, "PUB subject replyTo 5"); + } + + private void assertMessageString(Message message, String messageStart, String... contains) { + String messageString = new String(message.getProtocolBytes(), StandardCharsets.UTF_8); + assertTrue(messageString.startsWith(messageStart)); + if (contains != null) { + for (String c : contains) { + assertTrue(messageString.contains(c)); + } + } + } + + private NatsMessage.PublishBuilder starterBuilder() { + return new NatsMessage.PublishBuilder() + .subject("subject") + .replyTo("replyTo") + .data("Hello", StandardCharsets.UTF_8) + .maxPayload(10000L); + } +} \ No newline at end of file diff --git a/src/test/java/io/nats/client/impl/NatsMessageTests.java b/src/test/java/io/nats/client/impl/NatsMessageTests.java index b6cd94751..d1b7e9d2e 100644 --- a/src/test/java/io/nats/client/impl/NatsMessageTests.java +++ b/src/test/java/io/nats/client/impl/NatsMessageTests.java @@ -13,26 +13,18 @@ package io.nats.client.impl; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import io.nats.client.*; +import io.nats.client.NatsServerProtocolMock.ExitAt; +import org.junit.Test; -import java.nio.CharBuffer; import java.nio.charset.StandardCharsets; -import org.junit.Test; - -import io.nats.client.Connection; -import io.nats.client.Nats; -import io.nats.client.NatsServerProtocolMock; -import io.nats.client.NatsTestServer; -import io.nats.client.Options; -import io.nats.client.NatsServerProtocolMock.ExitAt; +import static org.junit.Assert.*; public class NatsMessageTests { @Test public void testSizeOnProtocolMessage() { - NatsMessage msg = new NatsMessage(CharBuffer.wrap("PING")); + NatsMessage msg = NatsMessage.getProtocolInstance("PING"); assertEquals("Size is set, with CRLF", msg.getProtocolBytes().length + 2, msg.getSizeInBytes()); assertEquals("Size is correct", "PING".getBytes(StandardCharsets.UTF_8).length + 2, msg.getSizeInBytes()); @@ -45,16 +37,19 @@ public void testSizeOnPublishMessage() { String replyTo = "reply"; String protocol = "PUB "+subject+" "+replyTo+" "+body.length; - NatsMessage msg = new NatsMessage(subject, replyTo, body, false); + NatsMessage msg = new NatsMessage.PublishBuilder() + .subject(subject).replyTo(replyTo).data(body).maxPayload(10000L).build(); assertEquals("Size is set, with CRLF", msg.getProtocolBytes().length + body.length + 4, msg.getSizeInBytes()); assertEquals("Size is correct", protocol.getBytes(StandardCharsets.US_ASCII).length + body.length + 4, msg.getSizeInBytes()); - msg = new NatsMessage(subject, replyTo, body, true); + msg = new NatsMessage.PublishBuilder() + .subject(subject).replyTo(replyTo).data(body).utf8mode(true).maxPayload(10000L).build(); assertEquals("Size is set, with CRLF", msg.getProtocolBytes().length + body.length + 4, msg.getSizeInBytes()); assertEquals("Size is correct", protocol.getBytes(StandardCharsets.UTF_8).length + body.length + 4, msg.getSizeInBytes()); } + @Test(expected=IllegalArgumentException.class) public void testCustomMaxControlLine() throws Exception { diff --git a/src/test/java/io/nats/client/impl/ParseTests.java b/src/test/java/io/nats/client/impl/ParseTests.java index f15122fa4..46c4bad51 100644 --- a/src/test/java/io/nats/client/impl/ParseTests.java +++ b/src/test/java/io/nats/client/impl/ParseTests.java @@ -13,22 +13,21 @@ package io.nats.client.impl; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import io.nats.client.Nats; +import io.nats.client.NatsTestServer; +import io.nats.client.Options; +import org.junit.Test; import java.io.IOException; import java.nio.charset.StandardCharsets; -import org.junit.Test; - -import io.nats.client.Nats; -import io.nats.client.NatsTestServer; -import io.nats.client.Options; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; public class ParseTests { @Test public void testGoodNumbers() { - int i=1; + int i = 1; while (i < 2_000_000_000 && i > 0) { assertEquals(i, NatsConnectionReader.parseLength(String.valueOf(i))); @@ -39,22 +38,22 @@ public void testGoodNumbers() { } - @Test(expected=NumberFormatException.class) + @Test(expected = NumberFormatException.class) public void testBadChars() { NatsConnectionReader.parseLength("2221a"); assertFalse(true); } - @Test(expected=NumberFormatException.class) + @Test(expected = NumberFormatException.class) public void testTooBig() { NatsConnectionReader.parseLength(String.valueOf(100_000_000_000L)); assertFalse(true); } - @Test(expected=IOException.class) + @Test(expected = IOException.class) public void testLongProtocolOpThrows() throws Exception { try (NatsTestServer ts = new NatsTestServer(false); - NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { + NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { NatsConnectionReader reader = nc.getReader(); byte[] bytes = ("thisistoolong\r\n").getBytes(StandardCharsets.US_ASCII); reader.fakeReadForTest(bytes); @@ -63,10 +62,10 @@ public void testLongProtocolOpThrows() throws Exception { } } - @Test(expected=IOException.class) + @Test(expected = IOException.class) public void testMissingLineFeed() throws Exception { try (NatsTestServer ts = new NatsTestServer(false); - NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { + NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { NatsConnectionReader reader = nc.getReader(); byte[] bytes = ("PING\rPONG").getBytes(StandardCharsets.US_ASCII); reader.fakeReadForTest(bytes); @@ -75,10 +74,10 @@ public void testMissingLineFeed() throws Exception { } } - @Test(expected=IOException.class) + @Test(expected = IOException.class) public void testMissingSubject() throws Exception { try (NatsTestServer ts = new NatsTestServer(false); - NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { + NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { NatsConnectionReader reader = nc.getReader(); byte[] bytes = ("MSG 1 1\r\n").getBytes(StandardCharsets.US_ASCII); reader.fakeReadForTest(bytes); @@ -89,10 +88,10 @@ public void testMissingSubject() throws Exception { } } - @Test(expected=IOException.class) + @Test(expected = IOException.class) public void testMissingSID() throws Exception { try (NatsTestServer ts = new NatsTestServer(false); - NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { + NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { NatsConnectionReader reader = nc.getReader(); byte[] bytes = ("MSG subject 1\r\n").getBytes(StandardCharsets.US_ASCII); reader.fakeReadForTest(bytes); @@ -103,10 +102,10 @@ public void testMissingSID() throws Exception { } } - @Test(expected=IOException.class) + @Test(expected = IOException.class) public void testMissingLength() throws Exception { try (NatsTestServer ts = new NatsTestServer(false); - NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { + NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { NatsConnectionReader reader = nc.getReader(); byte[] bytes = ("MSG subject 2 \r\n").getBytes(StandardCharsets.US_ASCII); reader.fakeReadForTest(bytes); @@ -117,10 +116,10 @@ public void testMissingLength() throws Exception { } } - @Test(expected=IOException.class) + @Test(expected = IOException.class) public void testBadLength() throws Exception { try (NatsTestServer ts = new NatsTestServer(false); - NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { + NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { NatsConnectionReader reader = nc.getReader(); byte[] bytes = ("MSG subject 2 x\r\n").getBytes(StandardCharsets.US_ASCII); reader.fakeReadForTest(bytes); @@ -131,13 +130,13 @@ public void testBadLength() throws Exception { } } - @Test(expected=IOException.class) + @Test(expected = IOException.class) public void testMessageLineTooLong() throws Exception { try (NatsTestServer ts = new NatsTestServer(false); - NatsConnection nc = (NatsConnection) Nats.connect(new Options.Builder(). - server(ts.getURI()). - maxControlLine(16). - build())) { + NatsConnection nc = (NatsConnection) Nats.connect(new Options.Builder(). + server(ts.getURI()). + maxControlLine(16). + build())) { NatsConnectionReader reader = nc.getReader(); byte[] bytes = ("MSG reallylongsubjectobreakthelength 1 1\r\n").getBytes(StandardCharsets.US_ASCII); reader.fakeReadForTest(bytes); @@ -148,18 +147,18 @@ public void testMessageLineTooLong() throws Exception { } } - @Test(expected=IllegalArgumentException.class) + @Test(expected = IllegalArgumentException.class) public void testProtocolLineTooLong() throws Exception { try (NatsTestServer ts = new NatsTestServer(false); - NatsConnection nc = (NatsConnection) Nats.connect(new Options.Builder(). - server(ts.getURI()). - maxControlLine(1024). - build())) { + NatsConnection nc = (NatsConnection) Nats.connect(new Options.Builder(). + server(ts.getURI()). + maxControlLine(1024). + build())) { NatsConnectionReader reader = nc.getReader(); StringBuilder longString = new StringBuilder(); longString.append("INFO "); - for (int i=0;i<500;i++ ){ + for (int i = 0; i < 500; i++) { longString.append("helloworld"); } @@ -172,42 +171,43 @@ public void testProtocolLineTooLong() throws Exception { } } + @Test public void testProtocolStrings() throws Exception { String[] serverStrings = { - "+OK", "PONG", "PING", "MSG longer.subject.abitlikeaninbox 22 longer.replyto.abitlikeaninbox 234", - "-ERR some error with spaces in it", "INFO {" + "\"server_id\":\"myserver\"" + "," + "\"version\":\"1.1.1\"" + "," - + "\"go\": \"go1.9\"" + "," + "\"host\": \"host\"" + "," + "\"tls_required\": true" + "," - + "\"auth_required\":false" + "," + "\"port\": 7777" + "," + "\"max_payload\":100000000000" + "," - + "\"connect_urls\":[\"one\", \"two\"]" + "}", "ping", "msg one 22 33", "+oK", "PoNg", "pong", "MsG one 22 23" + "+OK", "PONG", "PING", "MSG longer.subject.abitlikeaninbox 22 longer.replyto.abitlikeaninbox 234", + "-ERR some error with spaces in it", "INFO {" + "\"server_id\":\"myserver\"" + "," + "\"version\":\"1.1.1\"" + "," + + "\"go\": \"go1.9\"" + "," + "\"host\": \"host\"" + "," + "\"tls_required\": true" + "," + + "\"auth_required\":false" + "," + "\"port\": 7777" + "," + "\"max_payload\":100000000000" + "," + + "\"connect_urls\":[\"one\", \"two\"]" + "}", "ping", "msg one 22 33", "+oK", "PoNg", "pong", "MsG one 22 23" }; String[] badStrings = { - "XXX", "XXXX", "XX", "X", "PINX", "PONX", "MSX", "INFX", "+OX", "-ERX", - "xxx", "xxxx", "xx", "x", "pinx", "ponx", "msx", "infx", "+ox", "-erx", - "+mk", "+ms", "-msg", "-esg", "poog", "piig", "mkg", "iing", "inng" + "XXX", "XXXX", "XX", "X", "PINX", "PONX", "MSX", "INFX", "+OX", "-ERX", + "xxx", "xxxx", "xx", "x", "pinx", "ponx", "msx", "infx", "+ox", "-erx", + "+mk", "+ms", "-msg", "-esg", "poog", "piig", "mkg", "iing", "inng" }; String[] expected = { - NatsConnection.OP_OK, NatsConnection.OP_PONG, NatsConnection.OP_PING, NatsConnection.OP_MSG, - NatsConnection.OP_ERR, NatsConnection.OP_INFO, NatsConnection.OP_PING, NatsConnection.OP_MSG, - NatsConnection.OP_OK, NatsConnection.OP_PONG, NatsConnection.OP_PONG, NatsConnection.OP_MSG + NatsConnection.OP_OK, NatsConnection.OP_PONG, NatsConnection.OP_PING, NatsConnection.OP_MSG, + NatsConnection.OP_ERR, NatsConnection.OP_INFO, NatsConnection.OP_PING, NatsConnection.OP_MSG, + NatsConnection.OP_OK, NatsConnection.OP_PONG, NatsConnection.OP_PONG, NatsConnection.OP_MSG }; try (NatsTestServer ts = new NatsTestServer(false); - NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { + NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { NatsConnectionReader reader = nc.getReader(); - for (int i=0; i<serverStrings.length; i++) { - byte[] bytes = (serverStrings[i]+"\r\n").getBytes(StandardCharsets.US_ASCII); + for (int i = 0; i < serverStrings.length; i++) { + byte[] bytes = (serverStrings[i] + "\r\n").getBytes(StandardCharsets.US_ASCII); reader.fakeReadForTest(bytes); reader.gatherOp(bytes.length); String op = reader.currentOp(); assertEquals(serverStrings[i], expected[i], op); } - for (int i=0; i<badStrings.length; i++) { - byte[] bytes = (badStrings[i]+"\r\n").getBytes(StandardCharsets.US_ASCII); + for (int i = 0; i < badStrings.length; i++) { + byte[] bytes = (badStrings[i] + "\r\n").getBytes(StandardCharsets.US_ASCII); reader.fakeReadForTest(bytes); reader.gatherOp(bytes.length); String op = reader.currentOp(); @@ -215,4 +215,46 @@ public void testProtocolStrings() throws Exception { } } } + + + + @Test + public void testProtocolStrings2() throws Exception { + String[] serverStrings = { + "+OK", "PONG", "PING", "INFO {" + "\"server_id\":\"myserver\"" + "," + "\"version\":\"1.1.1\"" + "," + + "\"go\": \"go1.9\"" + "," + "\"host\": \"host\"" + "," + "\"tls_required\": true" + "," + + "\"auth_required\":false" + "," + "\"port\": 7777" + "," + "\"max_payload\":100000000000" + "," + + "\"connect_urls\":[\"one\", \"two\"]" + "}" + }; + + String[] expected = { + NatsConnection.OP_OK, NatsConnection.OP_PONG, NatsConnection.OP_PING, NatsConnection.OP_INFO + }; + + final String protocol = "MSG longer.subject.abitlikeaninbox 22 longer.replyto.abitlikeaninbox 234\r\n"; + final byte[] hBytes = (protocol + "\r\n").getBytes(StandardCharsets.US_ASCII); + + try (NatsTestServer ts = new NatsTestServer(false); + NatsConnection nc = (NatsConnection) Nats.connect(ts.getURI())) { + NatsConnectionReader reader = nc.getReader(); + + for (int i = 0; i < serverStrings.length; i++) { + byte[] bytes = (serverStrings[i] + "\r\n").getBytes(StandardCharsets.US_ASCII); + reader.fakeReadForTest(bytes); + reader.gatherOp(bytes.length); + String op = reader.currentOp(); + assertEquals(serverStrings[i], expected[i], op); + } + + reader.fakeReadForTest(hBytes); + reader.gatherOp(hBytes.length); + final String op = reader.currentOp(); + assertEquals(protocol, "MSG", op); + + Thread.sleep(1000); + + NatsMessage incoming = reader.getIncoming(); + System.out.println(incoming); + } + } } \ No newline at end of file diff --git a/src/test/java/io/nats/client/impl/RequestTests.java b/src/test/java/io/nats/client/impl/RequestTests.java index ec6b83565..964070d7b 100644 --- a/src/test/java/io/nats/client/impl/RequestTests.java +++ b/src/test/java/io/nats/client/impl/RequestTests.java @@ -406,7 +406,7 @@ public void throwsIfClosed() throws IOException, InterruptedException { public void testThrowsWithoutSubject() throws IOException, InterruptedException { try (NatsTestServer ts = new NatsTestServer(false); Connection nc = Nats.connect(ts.getURI())) { - nc.request(null, null); + nc.request("", null); assertFalse(true); } } diff --git a/src/test/java/io/nats/client/impl/SendHPubTest.java b/src/test/java/io/nats/client/impl/SendHPubTest.java new file mode 100644 index 000000000..9896c2874 --- /dev/null +++ b/src/test/java/io/nats/client/impl/SendHPubTest.java @@ -0,0 +1,123 @@ +package io.nats.client.impl; + +import io.nats.client.*; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; + +import static junit.framework.TestCase.assertEquals; +import static junit.framework.TestCase.assertNotNull; + + +public class SendHPubTest { + + @Test + public void testNoHeader() throws Exception { + + final NatsTestServer natsTestServer = new NatsTestServer(true); + Thread.sleep(1000); + + final Connection connect1 = Nats.connect(natsTestServer.getURI()); + final Connection connect2 = Nats.connect(natsTestServer.getURI()); + + final String subject = "foo"; + + try { + Thread.sleep(1000); + System.out.println(connect1.getConnectedUrl()); + + final Subscription subscribe = connect2.subscribe(subject); + + connect1.publish(subject, "foo".getBytes(StandardCharsets.UTF_8)); + connect1.flush(Duration.ofSeconds(10)); + + Thread.sleep(1000); + + final Message message = subscribe.nextMessage(Duration.ofSeconds(10)); + + assertNotNull(message); + + assertEquals("foo", new String(message.getData(), StandardCharsets.UTF_8)); + + } finally { + connect1.close(); + connect2.close(); + natsTestServer.close(); + } + } + + @Test + public void testWithMessageBuilderNoHeader() throws Exception { + + final NatsTestServer natsTestServer = new NatsTestServer(true); + final Connection connect1 = Nats.connect(natsTestServer.getURI()); + final Connection connect2 = Nats.connect(natsTestServer.getURI()); + + final String subject = "foo"; + + try { + + Thread.sleep(1000); + System.out.println(connect1.getConnectedUrl()); + + final Subscription subscribe = connect2.subscribe(subject); + + connect1.publish(new NatsMessage.PublishBuilder() + .data("foo", StandardCharsets.UTF_8) + .subject(subject) + .maxPayload(10000L) + .build()); + + connect1.flush(Duration.ofSeconds(10)); + Thread.sleep(1000); + + final Message message = subscribe.nextMessage(Duration.ofSeconds(10)); + assertNotNull(message); + + assertEquals("foo", new String(message.getData(), StandardCharsets.UTF_8)); + + } finally { + connect1.close(); + connect2.close(); + natsTestServer.close(); + } + } + + @Test + public void testWithMessageBuilderWithHeader() throws Exception { + + final NatsTestServer natsTestServer = new NatsTestServer(true); + final Connection connect1 = Nats.connect(natsTestServer.getURI()); + final Connection connect2 = Nats.connect(natsTestServer.getURI()); + + final String subject = "foo"; + + try { + + Thread.sleep(1000); + System.out.println(connect1.getConnectedUrl()); + + final Subscription subscribe = connect2.subscribe(subject); + + connect1.publish(new NatsMessage.PublishBuilder() + .data("foo", StandardCharsets.UTF_8) + .subject(subject) + .addHeader("foo", "bar") + .maxPayload(10000L) + .build()); + + connect1.flush(Duration.ofSeconds(10)); + Thread.sleep(1000); + + final Message message = subscribe.nextMessage(Duration.ofSeconds(10)); + assertNotNull(message); + assertEquals("foo", new String(message.getData(), StandardCharsets.UTF_8)); + + } finally { + connect1.close(); + connect2.close(); + natsTestServer.close(); + } + } +}