Skip to content

Commit a331e1d

Browse files
committed
chore: Add sftp reconnections
Signed-off-by: Javier Aliaga <javier@diagrid.io>
1 parent d38062a commit a331e1d

File tree

8 files changed

+531
-207
lines changed

8 files changed

+531
-207
lines changed

bindings/sftp/integration/README.md

Lines changed: 0 additions & 42 deletions
This file was deleted.

bindings/sftp/integration/docker-compose.yaml

Lines changed: 0 additions & 34 deletions
This file was deleted.

bindings/sftp/proxy/proxy.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
Copyright 2025 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package sftp
15+
16+
import (
17+
"errors"
18+
"io"
19+
"log"
20+
"net"
21+
"sync/atomic"
22+
)
23+
24+
type Proxy struct {
25+
ListenAddr string
26+
UpstreamAddr string
27+
Client net.Conn
28+
Server net.Conn
29+
ReconnectionCount atomic.Int32
30+
Listener net.Listener
31+
}
32+
33+
func (p *Proxy) ListenAndServe() error {
34+
ln, err := net.Listen("tcp", p.ListenAddr)
35+
if err != nil {
36+
log.Fatalf("listen: %v", err)
37+
}
38+
log.Printf("Proxy listening on %s -> %s", p.ListenAddr, p.UpstreamAddr)
39+
p.Listener = ln
40+
41+
for {
42+
client, err := ln.Accept()
43+
if err != nil {
44+
if errors.Is(err, net.ErrClosed) {
45+
return nil
46+
}
47+
log.Printf("accept error: %v", err)
48+
continue
49+
}
50+
go p.handle(client)
51+
}
52+
}
53+
54+
func (p *Proxy) handle(client net.Conn) {
55+
defer client.Close()
56+
57+
// Connect to upstream SFTP server
58+
server, err := net.Dial("tcp", p.UpstreamAddr)
59+
if err != nil {
60+
log.Printf("dial upstream: %v", err)
61+
return
62+
}
63+
defer server.Close()
64+
65+
p.Client = client
66+
p.Server = server
67+
p.ReconnectionCount.Add(1)
68+
errCh := make(chan error, 2)
69+
70+
// client -> server
71+
go func() {
72+
_, cErr := io.Copy(server, client)
73+
errCh <- cErr
74+
}()
75+
76+
// server -> client
77+
go func() {
78+
_, cErr := io.Copy(client, server)
79+
errCh <- cErr
80+
}()
81+
82+
// When either direction ends, close both ends
83+
if err := <-errCh; err != nil && !isUsefullyClosed(err) {
84+
log.Printf("proxy stream ended with error: %v", err)
85+
}
86+
}
87+
88+
func (p *Proxy) KillServerConn() error {
89+
return p.Client.Close()
90+
}
91+
92+
func (p *Proxy) Close() {
93+
if p.Client != nil {
94+
_ = p.Client.Close()
95+
}
96+
97+
if p.Server != nil {
98+
_ = p.Server.Close()
99+
}
100+
101+
if p.Listener != nil {
102+
_ = p.Listener.Close()
103+
}
104+
105+
p.ReconnectionCount.Store(0)
106+
}
107+
108+
// isUsefullyClosed filters common close conditions from logging noise
109+
func isUsefullyClosed(err error) bool {
110+
return err == io.EOF || errors.Is(err, net.ErrClosed)
111+
}

bindings/sftp/sftp.go

Lines changed: 29 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
/*
2+
Copyright 2025 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
114
package sftp
215

316
import (
@@ -7,7 +20,6 @@ import (
720
"fmt"
821
"io"
922
"reflect"
10-
"sync"
1123

1224
sftpClient "github.com/pkg/sftp"
1325
"golang.org/x/crypto/ssh"
@@ -26,30 +38,9 @@ const (
2638

2739
// Sftp is a binding for file operations on sftp server.
2840
type Sftp struct {
29-
metadata *sftpMetadata
30-
logger logger.Logger
31-
sftpClient *sftpClient.Client
32-
sshClient *ssh.Client
33-
clientConfig *ssh.ClientConfig
34-
lock sync.RWMutex
35-
}
36-
37-
func (sftp *Sftp) Client() (*sftpClient.Client, error) {
38-
sftp.lock.RLock()
39-
current := sftp.sftpClient
40-
sftp.lock.RUnlock()
41-
42-
if current != nil {
43-
if _, err := current.Getwd(); err == nil {
44-
return current, nil
45-
}
46-
}
47-
48-
err := sftp.handleReconnection()
49-
if err != nil {
50-
return nil, err
51-
}
52-
return sftp.sftpClient, nil
41+
metadata *sftpMetadata
42+
logger logger.Logger
43+
c *Client
5344
}
5445

5546
// sftpMetadata defines the sftp metadata.
@@ -137,21 +128,12 @@ func (sftp *Sftp) Init(_ context.Context, metadata bindings.Metadata) error {
137128
HostKeyCallback: hostKeyCallback,
138129
}
139130

140-
sshClient, err := ssh.Dial("tcp", m.Address, config)
141-
if err != nil {
142-
return fmt.Errorf("sftp binding error: error create ssh client: %w", err)
143-
}
144-
145-
newSftpClient, err := sftpClient.NewClient(sshClient)
131+
sftp.metadata = m
132+
sftp.c, err = newClient(m.Address, config)
146133
if err != nil {
147-
_ = sshClient.Close()
148-
return fmt.Errorf("sftp binding error: error create sftp client: %w", err)
134+
return fmt.Errorf("sftp binding error: create sftp client error: %w", err)
149135
}
150136

151-
sftp.clientConfig = config
152-
sftp.metadata = m
153-
sftp.sftpClient = newSftpClient
154-
155137
return nil
156138
}
157139

@@ -185,19 +167,9 @@ func (sftp *Sftp) create(_ context.Context, req *bindings.InvokeRequest) (*bindi
185167
return nil, fmt.Errorf("sftp binding error: %w", err)
186168
}
187169

188-
dir, fileName := sftpClient.Split(path)
170+
c := sftp.c
189171

190-
c, err := sftp.Client()
191-
if err != nil {
192-
return nil, fmt.Errorf("sftp binding error: error getting sftp client: %w", err)
193-
}
194-
195-
err = c.MkdirAll(dir)
196-
if err != nil {
197-
return nil, fmt.Errorf("sftp binding error: error create dir %s: %w", dir, err)
198-
}
199-
200-
file, err := c.Create(path)
172+
file, fileName, err := c.create(path)
201173
if err != nil {
202174
return nil, fmt.Errorf("sftp binding error: error create file %s: %w", path, err)
203175
}
@@ -240,12 +212,9 @@ func (sftp *Sftp) list(_ context.Context, req *bindings.InvokeRequest) (*binding
240212
return nil, fmt.Errorf("sftp binding error: %w", err)
241213
}
242214

243-
c, err := sftp.Client()
244-
if err != nil {
245-
return nil, fmt.Errorf("sftp binding error: error getting sftp client: %w", err)
246-
}
215+
c := sftp.c
247216

248-
files, err := c.ReadDir(path)
217+
files, err := c.list(path)
249218
if err != nil {
250219
return nil, fmt.Errorf("sftp binding error: error read dir %s: %w", path, err)
251220
}
@@ -280,12 +249,9 @@ func (sftp *Sftp) get(_ context.Context, req *bindings.InvokeRequest) (*bindings
280249
return nil, fmt.Errorf("sftp binding error: %w", err)
281250
}
282251

283-
c, err := sftp.Client()
284-
if err != nil {
285-
return nil, fmt.Errorf("sftp binding error: error getting sftp client: %w", err)
286-
}
252+
c := sftp.c
287253

288-
file, err := c.Open(path)
254+
file, err := c.get(path)
289255
if err != nil {
290256
return nil, fmt.Errorf("sftp binding error: error open file %s: %w", path, err)
291257
}
@@ -311,11 +277,9 @@ func (sftp *Sftp) delete(_ context.Context, req *bindings.InvokeRequest) (*bindi
311277
return nil, fmt.Errorf("sftp binding error: %w", err)
312278
}
313279

314-
c, err := sftp.Client()
315-
if err != nil {
316-
return nil, fmt.Errorf("sftp binding error: error getting sftp client: %w", err)
317-
}
318-
err = c.Remove(path)
280+
c := sftp.c
281+
282+
err = c.delete(path)
319283
if err != nil {
320284
return nil, fmt.Errorf("sftp binding error: error remove file %s: %w", path, err)
321285
}
@@ -339,9 +303,7 @@ func (sftp *Sftp) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
339303
}
340304

341305
func (sftp *Sftp) Close() error {
342-
sftp.lock.Lock()
343-
defer sftp.lock.Unlock()
344-
return sftp.sftpClient.Close()
306+
return sftp.c.Close()
345307
}
346308

347309
func (metadata sftpMetadata) getPath(requestMetadata map[string]string) (path string, err error) {
@@ -375,41 +337,3 @@ func (sftp *Sftp) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
375337
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.BindingType)
376338
return
377339
}
378-
379-
func (sftp *Sftp) handleReconnection() error {
380-
sftp.lock.Lock()
381-
defer sftp.lock.Unlock()
382-
383-
// Re-check after acquiring the write lock
384-
if sftp.sftpClient != nil {
385-
if _, err := sftp.sftpClient.Getwd(); err == nil {
386-
return nil
387-
}
388-
_ = sftp.sftpClient.Close()
389-
sftp.sftpClient = nil
390-
}
391-
if sftp.sshClient != nil {
392-
_ = sftp.sshClient.Close()
393-
sftp.sshClient = nil
394-
}
395-
396-
if sftp.metadata == nil || sftp.clientConfig == nil {
397-
return fmt.Errorf("sftp binding error: client not initialized")
398-
}
399-
400-
sshClient, err := ssh.Dial("tcp", sftp.metadata.Address, sftp.clientConfig)
401-
if err != nil {
402-
return fmt.Errorf("sftp binding error: error create ssh client: %w", err)
403-
}
404-
405-
newSftpClient, err := sftpClient.NewClient(sshClient)
406-
if err != nil {
407-
_ = sshClient.Close()
408-
return fmt.Errorf("sftp binding error: error create sftp client: %w", err)
409-
}
410-
411-
sftp.sshClient = sshClient
412-
sftp.sftpClient = newSftpClient
413-
414-
return nil
415-
}

0 commit comments

Comments
 (0)