Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update SSHClient to interrupt KeepAlive Thread when disconnecting #752

Merged
merged 2 commits into from
Dec 22, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions src/main/java/net/schmizz/keepalive/KeepAlive.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import net.schmizz.sshj.transport.TransportException;
import org.slf4j.Logger;

import java.util.concurrent.TimeUnit;

public abstract class KeepAlive extends Thread {
protected final Logger log;
protected final ConnectionImpl conn;
Expand All @@ -33,50 +35,40 @@ protected KeepAlive(ConnectionImpl conn, String name) {
setDaemon(true);
}

public boolean isEnabled() {
return keepAliveInterval > 0;
}

public synchronized int getKeepAliveInterval() {
return keepAliveInterval;
}

public synchronized void setKeepAliveInterval(int keepAliveInterval) {
this.keepAliveInterval = keepAliveInterval;
if (keepAliveInterval > 0 && getState() == State.NEW) {
start();
}
notify();
}

synchronized protected int getPositiveInterval()
throws InterruptedException {
while (keepAliveInterval <= 0) {
wait();
}
return keepAliveInterval;
}

@Override
public void run() {
log.debug("Starting {}, sending keep-alive every {} seconds", getClass().getSimpleName(), keepAliveInterval);
log.debug("{} Started with interval [{} seconds]", getClass().getSimpleName(), keepAliveInterval);
try {
while (!isInterrupted()) {
final int hi = getPositiveInterval();
final int interval = getKeepAliveInterval();
if (conn.getTransport().isRunning()) {
log.debug("Sending keep-alive since {} seconds elapsed", hi);
log.debug("{} Sending after interval [{} seconds]", getClass().getSimpleName(), interval);
doKeepAlive();
}
Thread.sleep(hi * 1000);
TimeUnit.SECONDS.sleep(interval);
}
} catch (InterruptedException e) {
// Interrupt signal may be catched when sleeping.
log.trace("{} Interrupted while sleeping", getClass().getSimpleName());
} catch (Exception e) {
// If we weren't interrupted, kill the transport, then this exception was unexpected.
// Else we're in shutdown-mode already, so don't forcibly kill the transport.
if (!isInterrupted()) {
conn.getTransport().die(e);
}
}

log.debug("Stopping {}", getClass().getSimpleName());

log.debug("{} Stopped", getClass().getSimpleName());
}

protected abstract void doKeepAlive() throws TransportException, ConnectionException;
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/net/schmizz/sshj/SSHClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package net.schmizz.sshj;

import net.schmizz.keepalive.KeepAlive;
import net.schmizz.sshj.common.*;
import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.ConnectionException;
Expand Down Expand Up @@ -424,6 +425,7 @@ public void authGssApiWithMic(String username, LoginContext context, Oid support
@Override
public void disconnect()
throws IOException {
conn.getKeepAlive().interrupt();
for (LocalPortForwarder forwarder : forwarders) {
try {
forwarder.close();
Expand Down Expand Up @@ -791,6 +793,10 @@ protected void onConnect()
throws IOException {
super.onConnect();
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
final KeepAlive keepAliveThread = conn.getKeepAlive();
if (keepAliveThread.isEnabled()) {
keepAliveThread.start();
}
doKex();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,54 +15,66 @@
*/
package com.hierynomus.sshj.keepalive;

import com.hierynomus.sshj.test.KnownFailingTests;
import com.hierynomus.sshj.test.SlowTests;
import com.hierynomus.sshj.test.SshFixture;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.userauth.UserAuthException;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;

import java.io.IOException;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;

import static org.junit.Assert.fail;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

public class KeepAliveThreadTerminationTest {

private static final int KEEP_ALIVE_SECONDS = 1;

private static final long STOP_SLEEP = 1500;

@Rule
public SshFixture fixture = new SshFixture();

@Test
@Category({SlowTests.class, KnownFailingTests.class})
public void shouldCorrectlyTerminateThreadOnDisconnect() throws IOException, InterruptedException {
DefaultConfig defaultConfig = new DefaultConfig();
public void shouldNotStartThreadOnSetKeepAliveInterval() {
final SSHClient sshClient = setupClient();

final KeepAlive keepAlive = sshClient.getConnection().getKeepAlive();
assertTrue(keepAlive.isDaemon());
assertFalse(keepAlive.isAlive());
assertEquals(Thread.State.NEW, keepAlive.getState());
}

@Test
public void shouldStartThreadOnConnectAndInterruptOnDisconnect() throws IOException, InterruptedException {
final SSHClient sshClient = setupClient();

final KeepAlive keepAlive = sshClient.getConnection().getKeepAlive();
assertTrue(keepAlive.isDaemon());
assertEquals(Thread.State.NEW, keepAlive.getState());

fixture.connectClient(sshClient);
assertEquals(Thread.State.TIMED_WAITING, keepAlive.getState());

assertThrows(UserAuthException.class, () -> sshClient.authPassword("bad", "credentials"));

fixture.stopClient();
Thread.sleep(STOP_SLEEP);

assertFalse(keepAlive.isAlive());
assertEquals(Thread.State.TERMINATED, keepAlive.getState());
}

private SSHClient setupClient() {
final DefaultConfig defaultConfig = new DefaultConfig();
defaultConfig.setKeepAliveProvider(KeepAliveProvider.KEEP_ALIVE);
for (int i = 0; i < 10; i++) {
SSHClient sshClient = fixture.setupClient(defaultConfig);
fixture.connectClient(sshClient);
sshClient.getConnection().getKeepAlive().setKeepAliveInterval(1);
try {
sshClient.authPassword("bad", "credentials");
fail("Should not auth.");
} catch (UserAuthException e) {
// OK
}
fixture.stopClient();
Thread.sleep(2000);
}

ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
for (long l : threadMXBean.getAllThreadIds()) {
ThreadInfo threadInfo = threadMXBean.getThreadInfo(l);
if (threadInfo.getThreadName().equals("keep-alive") && threadInfo.getThreadState() != Thread.State.TERMINATED) {
fail("Found alive keep-alive thread in state " + threadInfo.getThreadState());
}
}
final SSHClient sshClient = fixture.setupClient(defaultConfig);
sshClient.getConnection().getKeepAlive().setKeepAliveInterval(KEEP_ALIVE_SECONDS);
return sshClient;
}
}