From 0ea79326253cee081867ef461be3c76d82ad30ef Mon Sep 17 00:00:00 2001 From: Martin Volf Date: Fri, 19 Jan 2024 21:47:12 +0100 Subject: [PATCH 1/3] connected sockets can be passed to the library fixes hierynomus/sshj#924 Signed-off-by: Martin Volf --- .../java/net/schmizz/sshj/SocketClient.java | 20 +++-- .../net/schmizz/sshj/ConnectedSocketTest.java | 78 +++++++++++++++++++ 2 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 src/test/java/net/schmizz/sshj/ConnectedSocketTest.java diff --git a/src/main/java/net/schmizz/sshj/SocketClient.java b/src/main/java/net/schmizz/sshj/SocketClient.java index d7971243..5447b557 100644 --- a/src/main/java/net/schmizz/sshj/SocketClient.java +++ b/src/main/java/net/schmizz/sshj/SocketClient.java @@ -65,7 +65,9 @@ public void connect(String hostname, int port) throws IOException { this.hostname = hostname; this.port = port; socket = socketFactory.createSocket(); - socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + if (! socket.isConnected()) { + socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + } onConnect(); } } @@ -77,8 +79,10 @@ public void connect(String hostname, int port, InetAddress localAddr, int localP this.hostname = hostname; this.port = port; socket = socketFactory.createSocket(); - socket.bind(new InetSocketAddress(localAddr, localPort)); - socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + if (! socket.isConnected()) { + socket.bind(new InetSocketAddress(localAddr, localPort)); + socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + } onConnect(); } } @@ -104,7 +108,9 @@ public void connect(InetAddress host) throws IOException { public void connect(InetAddress host, int port) throws IOException { this.port = port; socket = socketFactory.createSocket(); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + if (! socket.isConnected()) { + socket.connect(new InetSocketAddress(host, port), connectTimeout); + } onConnect(); } @@ -112,8 +118,10 @@ public void connect(InetAddress host, int port, InetAddress localAddr, int local throws IOException { this.port = port; socket = socketFactory.createSocket(); - socket.bind(new InetSocketAddress(localAddr, localPort)); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + if (! socket.isConnected()) { + socket.bind(new InetSocketAddress(localAddr, localPort)); + socket.connect(new InetSocketAddress(host, port), connectTimeout); + } onConnect(); } diff --git a/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java new file mode 100644 index 00000000..0efe86c9 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java @@ -0,0 +1,78 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj; + +import com.hierynomus.sshj.test.SshServerExtension; +import net.schmizz.sshj.SSHClient; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import javax.net.SocketFactory; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + + +public class ConnectedSocketTest { + @RegisterExtension + public SshServerExtension fixture = new SshServerExtension(); + + @BeforeEach + public void setupClient() throws IOException { + SSHClient defaultClient = fixture.setupDefaultClient(); + } + + @Test + public void connectsIfUnconnected() { + assertDoesNotThrow(() -> fixture.connectClient(fixture.getClient())); + } + + @Test + public void handlesConnected() throws IOException { + Socket socket = SocketFactory.getDefault().createSocket(); + SocketFactory factory = new SocketFactory() { + @Override + public Socket createSocket() { + return socket; + } + @Override + public Socket createSocket(InetAddress host, int port) { + return socket; + } + @Override + public Socket createSocket(InetAddress address, int port, + InetAddress localAddress, int localPort) { + return socket; + } + @Override + public Socket createSocket(String host, int port) { + return socket; + } + @Override + public Socket createSocket(String host, int port, + InetAddress localHost, int localPort) { + return socket; + } + }; + socket.connect(new InetSocketAddress("localhost", fixture.getServer().getPort())); + fixture.getClient().setSocketFactory(factory); + assertDoesNotThrow(() -> fixture.connectClient(fixture.getClient())); + } +} From 5ff91e2c769437c85da1828be3d99ccdd130f362 Mon Sep 17 00:00:00 2001 From: Martin Volf Date: Fri, 26 Jan 2024 08:05:04 +0100 Subject: [PATCH 2/3] removed pointless socket check; test coverage improved Signed-off-by: Martin Volf --- .../java/net/schmizz/sshj/SocketClient.java | 12 ++---- .../net/schmizz/sshj/ConnectedSocketTest.java | 39 ++++++++++++++++--- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/src/main/java/net/schmizz/sshj/SocketClient.java b/src/main/java/net/schmizz/sshj/SocketClient.java index 5447b557..e4809e0d 100644 --- a/src/main/java/net/schmizz/sshj/SocketClient.java +++ b/src/main/java/net/schmizz/sshj/SocketClient.java @@ -79,10 +79,8 @@ public void connect(String hostname, int port, InetAddress localAddr, int localP this.hostname = hostname; this.port = port; socket = socketFactory.createSocket(); - if (! socket.isConnected()) { - socket.bind(new InetSocketAddress(localAddr, localPort)); - socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); - } + socket.bind(new InetSocketAddress(localAddr, localPort)); + socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); onConnect(); } } @@ -118,10 +116,8 @@ public void connect(InetAddress host, int port, InetAddress localAddr, int local throws IOException { this.port = port; socket = socketFactory.createSocket(); - if (! socket.isConnected()) { - socket.bind(new InetSocketAddress(localAddr, localPort)); - socket.connect(new InetSocketAddress(host, port), connectTimeout); - } + socket.bind(new InetSocketAddress(localAddr, localPort)); + socket.connect(new InetSocketAddress(host, port), connectTimeout); onConnect(); } diff --git a/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java index 0efe86c9..87bf3f3a 100644 --- a/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java +++ b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java @@ -20,11 +20,17 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.apache.sshd.server.SshServer; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; +import java.util.stream.Stream; + import javax.net.SocketFactory; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -39,13 +45,34 @@ public void setupClient() throws IOException { SSHClient defaultClient = fixture.setupDefaultClient(); } - @Test - public void connectsIfUnconnected() { - assertDoesNotThrow(() -> fixture.connectClient(fixture.getClient())); + private static interface Connector { + void connect(SshServerExtension fx) throws IOException; + } + + private static void connectViaHostname(SshServerExtension fx) throws IOException { + SshServer server = fx.getServer(); + fx.getClient().connect(server.getHost(), server.getPort()); + } + + private static void connectViaAddr(SshServerExtension fx) throws IOException { + SshServer server = fx.getServer(); + InetAddress addr = InetAddress.getByName(server.getHost()); + fx.getClient().connect(addr, server.getPort()); + } + + private static Stream connectMethods() { + return Stream.of(fx -> connectViaHostname(fx), fx -> connectViaAddr(fx)); + } + + @ParameterizedTest + @MethodSource("connectMethods") + public void connectsIfUnconnected(Connector connector) { + assertDoesNotThrow(() -> connector.connect(fixture)); } - @Test - public void handlesConnected() throws IOException { + @ParameterizedTest + @MethodSource("connectMethods") + public void handlesConnected(Connector connector) throws IOException { Socket socket = SocketFactory.getDefault().createSocket(); SocketFactory factory = new SocketFactory() { @Override @@ -73,6 +100,6 @@ public Socket createSocket(String host, int port, }; socket.connect(new InetSocketAddress("localhost", fixture.getServer().getPort())); fixture.getClient().setSocketFactory(factory); - assertDoesNotThrow(() -> fixture.connectClient(fixture.getClient())); + assertDoesNotThrow(() -> connector.connect(fixture)); } } From fa5b5ceef1b0fc5a468db3412113bfe1b140d56a Mon Sep 17 00:00:00 2001 From: Martin Volf Date: Mon, 29 Jan 2024 09:30:00 +0100 Subject: [PATCH 3/3] better test coverage Signed-off-by: Martin Volf --- src/test/java/net/schmizz/sshj/ConnectedSocketTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java index 87bf3f3a..1424d62d 100644 --- a/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java +++ b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java @@ -51,7 +51,7 @@ private static interface Connector { private static void connectViaHostname(SshServerExtension fx) throws IOException { SshServer server = fx.getServer(); - fx.getClient().connect(server.getHost(), server.getPort()); + fx.getClient().connect("localhost", server.getPort()); } private static void connectViaAddr(SshServerExtension fx) throws IOException {