diff --git a/src/itest/groovy/com/hierynomus/sshj/ManyChannelsSpec.groovy b/src/itest/groovy/com/hierynomus/sshj/ManyChannelsSpec.groovy new file mode 100644 index 00000000..d2caab6e --- /dev/null +++ b/src/itest/groovy/com/hierynomus/sshj/ManyChannelsSpec.groovy @@ -0,0 +1,74 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj + +import net.schmizz.sshj.SSHClient +import net.schmizz.sshj.common.IOUtils +import net.schmizz.sshj.connection.channel.direct.Session +import spock.lang.Specification + +import java.util.concurrent.* + +import static org.codehaus.groovy.runtime.IOGroovyMethods.withCloseable + +class ManyChannelsSpec extends Specification { + + def "should work with many channels without nonexistent channel error (GH issue #805)"() { + given: + SshdContainer sshd = new SshdContainer.Builder() + .withSshdConfig("""${SshdContainer.Builder.DEFAULT_SSHD_CONFIG} + MaxSessions 200 + """.stripMargin()) + .build() + sshd.start() + SSHClient client = sshd.getConnectedClient() + client.authPublickey("sshj", "src/test/resources/id_rsa") + + when: + List> futures = [] + ExecutorService executorService = Executors.newCachedThreadPool() + + for (int i in 0..20) { + futures.add(executorService.submit((Callable) { + return execute(client) + })) + } + executorService.shutdown() + executorService.awaitTermination(1, TimeUnit.DAYS) + + then: + futures*.get().findAll { it != null }.empty + + cleanup: + client.close() + } + + + private static Exception execute(SSHClient sshClient) { + try { + for (def i in 0..100) { + withCloseable (sshClient.startSession()) {sshSession -> + Session.Command sshCommand = sshSession.exec("ls -la") + IOUtils.readFully(sshCommand.getInputStream()).toString() + sshCommand.close() + } + } + } catch (Exception e) { + return e + } + return null + } +} diff --git a/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java b/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java index 1cd3e5c9..cb237343 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java @@ -304,6 +304,25 @@ public boolean isOpen() { } } + // Prevent CHANNEL_CLOSE to be sent between isOpen and a Transport.write call in the runnable, otherwise + // a disconnect with a "packet referred to nonexistent channel" message can occur. + // + // This particularly happens when the transport.Reader thread passes an eof from the server to the + // ChannelInputStream, the reading library-user thread returns, and closes the channel at the same time as the + // transport.Reader thread receives the subsequent CHANNEL_CLOSE from the server. + boolean whileOpen(TransportRunnable runnable) throws TransportException, ConnectionException { + openCloseLock.lock(); + try { + if (isOpen()) { + runnable.run(); + return true; + } + } finally { + openCloseLock.unlock(); + } + return false; + } + private void gotChannelRequest(SSHPacket buf) throws ConnectionException, TransportException { final String reqType; @@ -427,5 +446,8 @@ public String toString() { + rwin + " >"; } + public interface TransportRunnable { + void run() throws TransportException, ConnectionException; + } } diff --git a/src/main/java/net/schmizz/sshj/connection/channel/ChannelOutputStream.java b/src/main/java/net/schmizz/sshj/connection/channel/ChannelOutputStream.java index 7aa53153..29701b18 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/ChannelOutputStream.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/ChannelOutputStream.java @@ -30,7 +30,7 @@ */ public final class ChannelOutputStream extends OutputStream implements ErrorNotifiable { - private final Channel chan; + private final AbstractChannel chan; private final Transport trans; private final Window.Remote win; @@ -47,6 +47,12 @@ private final class DataBuffer { private final SSHPacket packet = new SSHPacket(Message.CHANNEL_DATA); private final Buffer.PlainBuffer leftOvers = new Buffer.PlainBuffer(); + private final AbstractChannel.TransportRunnable packetWriteRunnable = new AbstractChannel.TransportRunnable() { + @Override + public void run() throws TransportException { + trans.write(packet); + } + }; DataBuffer() { headerOffset = packet.rpos(); @@ -99,8 +105,9 @@ boolean flush(int bufferSize, boolean canAwaitExpansion) throws TransportExcepti if (leftOverBytes > 0) { leftOvers.putRawBytes(packet.array(), packet.wpos(), leftOverBytes); } - - trans.write(packet); + if (!chan.whileOpen(packetWriteRunnable)) { + throwStreamClosed(); + } win.consume(writeNow); packet.rpos(headerOffset); @@ -119,7 +126,7 @@ boolean flush(int bufferSize, boolean canAwaitExpansion) throws TransportExcepti } - public ChannelOutputStream(Channel chan, Transport trans, Window.Remote win) { + public ChannelOutputStream(AbstractChannel chan, Transport trans, Window.Remote win) { this.chan = chan; this.trans = trans; this.win = win; @@ -157,7 +164,7 @@ private void checkClose() throws SSHException { if (error != null) { throw error; } else { - throw new ConnectionException("Stream closed"); + throwStreamClosed(); } } } @@ -165,9 +172,14 @@ private void checkClose() throws SSHException { @Override public synchronized void close() throws IOException { // Not closed yet, and underlying channel is open to flush the data to. - if (!closed.getAndSet(true) && chan.isOpen()) { - buffer.flush(false); - trans.write(new SSHPacket(Message.CHANNEL_EOF).putUInt32(chan.getRecipient())); + if (!closed.getAndSet(true)) { + chan.whileOpen(new AbstractChannel.TransportRunnable() { + @Override + public void run() throws TransportException, ConnectionException { + buffer.flush(false); + trans.write(new SSHPacket(Message.CHANNEL_EOF).putUInt32(chan.getRecipient())); + } + }); } } @@ -188,4 +200,7 @@ public String toString() { return "< ChannelOutputStream for Channel #" + chan.getID() + " >"; } + private static void throwStreamClosed() throws ConnectionException { + throw new ConnectionException("Stream closed"); + } }