diff --git a/src/main/java/net/schmizz/sshj/SocketClient.java b/src/main/java/net/schmizz/sshj/SocketClient.java index d7971243a..e4809e0d2 100644 --- a/src/main/java/net/schmizz/sshj/SocketClient.java +++ b/src/main/java/net/schmizz/sshj/SocketClient.java @@ -65,7 +65,9 @@ public void connect(String hostname, int port) throws IOException { this.hostname = hostname; this.port = port; socket = socketFactory.createSocket(); - socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + if (! socket.isConnected()) { + socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + } onConnect(); } } @@ -104,7 +106,9 @@ public void connect(InetAddress host) throws IOException { public void connect(InetAddress host, int port) throws IOException { this.port = port; socket = socketFactory.createSocket(); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + if (! socket.isConnected()) { + socket.connect(new InetSocketAddress(host, port), connectTimeout); + } onConnect(); } diff --git a/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java new file mode 100644 index 000000000..1424d62dd --- /dev/null +++ b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java @@ -0,0 +1,105 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj; + +import com.hierynomus.sshj.test.SshServerExtension; +import net.schmizz.sshj.SSHClient; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.apache.sshd.server.SshServer; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.util.stream.Stream; + +import javax.net.SocketFactory; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + + +public class ConnectedSocketTest { + @RegisterExtension + public SshServerExtension fixture = new SshServerExtension(); + + @BeforeEach + public void setupClient() throws IOException { + SSHClient defaultClient = fixture.setupDefaultClient(); + } + + private static interface Connector { + void connect(SshServerExtension fx) throws IOException; + } + + private static void connectViaHostname(SshServerExtension fx) throws IOException { + SshServer server = fx.getServer(); + fx.getClient().connect("localhost", server.getPort()); + } + + private static void connectViaAddr(SshServerExtension fx) throws IOException { + SshServer server = fx.getServer(); + InetAddress addr = InetAddress.getByName(server.getHost()); + fx.getClient().connect(addr, server.getPort()); + } + + private static Stream connectMethods() { + return Stream.of(fx -> connectViaHostname(fx), fx -> connectViaAddr(fx)); + } + + @ParameterizedTest + @MethodSource("connectMethods") + public void connectsIfUnconnected(Connector connector) { + assertDoesNotThrow(() -> connector.connect(fixture)); + } + + @ParameterizedTest + @MethodSource("connectMethods") + public void handlesConnected(Connector connector) throws IOException { + Socket socket = SocketFactory.getDefault().createSocket(); + SocketFactory factory = new SocketFactory() { + @Override + public Socket createSocket() { + return socket; + } + @Override + public Socket createSocket(InetAddress host, int port) { + return socket; + } + @Override + public Socket createSocket(InetAddress address, int port, + InetAddress localAddress, int localPort) { + return socket; + } + @Override + public Socket createSocket(String host, int port) { + return socket; + } + @Override + public Socket createSocket(String host, int port, + InetAddress localHost, int localPort) { + return socket; + } + }; + socket.connect(new InetSocketAddress("localhost", fixture.getServer().getPort())); + fixture.getClient().setSocketFactory(factory); + assertDoesNotThrow(() -> connector.connect(fixture)); + } +}