Skip to content

Commit bd16750

Browse files
Fixes issue #79: do not close user SSH connections
The issue was fixed in the way proposed in #79, an additional field called `CloseHandler` was added to the `Client` struct. This handler can either be `EmptyHandler` which is equivalent to a no-op, or a `CloseSSHClient` which closes the passed SSH client when executed. The `EmptyHandler` is used by default unless `Client.Connect()` is called to establish the SSH connection. The reasoning is that whenever someone passes their own `ssh.Client` to the library, `Connect()` will not be called, thus signaling that we should not manage the `ssh.Client`. To ensure the correctness of this fix, two additional test cases were added (1) `TestUserSuppliedSSHClientDoesNotClose` which creates an `scp` client using a existing `ssh.Client` and ensures that the client is not closed by attempting to create a new session from it. (2) `TestSSHClientNoLeak` which uses `Connect()` to establish the SSH connection and ensures that its underlying `ssh.Client` is no longer functioning using the same mechanism as (1).
1 parent 6d16fff commit bd16750

File tree

3 files changed

+108
-5
lines changed

3 files changed

+108
-5
lines changed

client.go

+34-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2021 Bram Vandenbogaerde And Contributors
1+
/* Copyright (c) 2024 Bram Vandenbogaerde And Contributors
22
* You may use, distribute or modify this code under the
33
* terms of the Mozilla Public License 2.0, which is distributed
44
* along with the source code.
@@ -21,6 +21,27 @@ import (
2121
"golang.org/x/crypto/ssh"
2222
)
2323

24+
// Callback for freeing managed resources
25+
type ICloseHandler interface {
26+
Close()
27+
}
28+
29+
// Close handler equivalent to a no-op. Used by default
30+
// when no resources have to be cleaned.
31+
type EmptyHandler struct{}
32+
33+
func (EmptyHandler) Close() {}
34+
35+
// Close handler to close an SSH client
36+
type CloseSSHCLient struct {
37+
// Reference to the used SSH client
38+
sshClient *ssh.Client
39+
}
40+
41+
func (scp CloseSSHCLient) Close() {
42+
scp.sshClient.Close()
43+
}
44+
2445
type PassThru func(r io.Reader, total int64) io.Reader
2546

2647
type Client struct {
@@ -39,6 +60,10 @@ type Client struct {
3960

4061
// RemoteBinary the absolute path to the remote SCP binary.
4162
RemoteBinary string
63+
64+
// Handler called when calling `Close` to clean up any remaining
65+
// resources managed by `Client`.
66+
closeHandler ICloseHandler
4267
}
4368

4469
// Connect connects to the remote SSH server, returns error if it couldn't establish a session to the SSH server.
@@ -49,9 +74,16 @@ func (a *Client) Connect() error {
4974
}
5075

5176
a.sshClient = client
77+
a.closeHandler = CloseSSHCLient{sshClient: client}
5278
return nil
5379
}
5480

81+
// Returns the underlying SSH client, this should be used carefully as
82+
// it will be closed by `client.Close`.
83+
func (a *Client) SSHClient() *ssh.Client {
84+
return a.sshClient
85+
}
86+
5587
// CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
5688
func (a *Client) CopyFromFile(ctx context.Context, file os.File, remotePath string, permissions string) error {
5789
return a.CopyFromFilePassThru(ctx, file, remotePath, permissions, nil)
@@ -347,7 +379,5 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
347379
}
348380

349381
func (a *Client) Close() {
350-
if a.sshClient != nil {
351-
a.sshClient.Close()
352-
}
382+
a.closeHandler.Close()
353383
}

configurer.go

+1
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,6 @@ func (c *ClientConfigurer) Create() Client {
7878
Timeout: c.timeout,
7979
RemoteBinary: c.remoteBinary,
8080
sshClient: c.sshClient,
81+
closeHandler: EmptyHandler{},
8182
}
8283
}

tests/basic_test.go

+73-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package tests
1+
package scp
22

33
import (
44
"context"
@@ -317,3 +317,75 @@ func TestFileNotFound(t *testing.T) {
317317
t.Errorf("Expected %v, got %v", expected, err.Error())
318318
}
319319
}
320+
321+
func TestUserSuppliedSSHClientDoesNotClose(t *testing.T) {
322+
// create the SSH connection
323+
clientConfig, err := buildClientConfig()
324+
if err != nil {
325+
t.Error("Could not build client config", clientConfig)
326+
}
327+
328+
sshClient, err := ssh.Dial("tcp", "127.0.0.1:2244", &clientConfig)
329+
if err != nil {
330+
t.Error("Could not establish SSH connection", err)
331+
}
332+
defer sshClient.Close()
333+
334+
// create the SCP client
335+
client, err := scp.NewClientBySSH(sshClient)
336+
if err != nil {
337+
t.Error("Could not create SCP client", err)
338+
}
339+
340+
// copy a file for good measure
341+
342+
f, _ := os.Open("./data/upload_file.txt")
343+
defer f.Close()
344+
345+
err = client.CopyFile(context.Background(), f, "/data/test.txt", "0777")
346+
347+
if err != nil {
348+
t.Error("Could not copy file to remote", err)
349+
}
350+
351+
// then close the SCP client
352+
client.Close()
353+
354+
var session *ssh.Session
355+
356+
// ensure that the SSH client is still opened
357+
// we do so by creating a new session, if this fails
358+
// the SSH connection was already closed
359+
if session, err = sshClient.NewSession(); err != nil {
360+
t.Fatal("SSH session was already closed.")
361+
}
362+
363+
session.Close()
364+
}
365+
366+
// Ensure that the underlying SSH client managed by the library is correctly closed
367+
// after closing the SCP connection
368+
func TestSSHClientNoLeak(t *testing.T) {
369+
client := establishConnection(t)
370+
371+
// copy a file for good measure
372+
f, _ := os.Open("./data/upload_file.txt")
373+
defer f.Close()
374+
375+
err := client.CopyFile(context.Background(), f, "/data/test.txt", "0777")
376+
377+
if err != nil {
378+
t.Error("Could not copy file to remote", err)
379+
}
380+
381+
// then close the SCP client
382+
client.Close()
383+
384+
// ensure that the SSH client is still opened
385+
// we do so by creating a new session, if this fails
386+
// the SSH connection was already closed
387+
if session, err := client.SSHClient().NewSession(); err == nil {
388+
session.Close()
389+
t.Fatal("SSH session was not closed.")
390+
}
391+
}

0 commit comments

Comments
 (0)