Skip to content

Commit d43bfab

Browse files
committed
chore: Handle sftp reconnections
Signed-off-by: Javier Aliaga <javier@diagrid.io>
1 parent 1d759c9 commit d43bfab

File tree

4 files changed

+540
-29
lines changed

4 files changed

+540
-29
lines changed

bindings/sftp/proxy/proxy.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
Copyright 2025 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package sftp
15+
16+
import (
17+
"errors"
18+
"io"
19+
"log"
20+
"net"
21+
"sync/atomic"
22+
)
23+
24+
type Proxy struct {
25+
ListenAddr string
26+
UpstreamAddr string
27+
Client net.Conn
28+
Server net.Conn
29+
ReconnectionCount atomic.Int32
30+
Listener net.Listener
31+
}
32+
33+
func (p *Proxy) ListenAndServe() error {
34+
ln, err := net.Listen("tcp", p.ListenAddr)
35+
if err != nil {
36+
log.Fatalf("listen: %v", err)
37+
}
38+
log.Printf("Proxy listening on %s -> %s", p.ListenAddr, p.UpstreamAddr)
39+
p.Listener = ln
40+
41+
for {
42+
client, err := ln.Accept()
43+
if err != nil {
44+
if errors.Is(err, net.ErrClosed) {
45+
return nil
46+
}
47+
log.Printf("accept error: %v", err)
48+
continue
49+
}
50+
go p.handle(client)
51+
}
52+
}
53+
54+
func (p *Proxy) handle(client net.Conn) {
55+
defer client.Close()
56+
57+
// Connect to upstream SFTP server
58+
server, err := net.Dial("tcp", p.UpstreamAddr)
59+
if err != nil {
60+
log.Printf("dial upstream: %v", err)
61+
return
62+
}
63+
defer server.Close()
64+
65+
p.Client = client
66+
p.Server = server
67+
p.ReconnectionCount.Add(1)
68+
errCh := make(chan error, 2)
69+
70+
// client -> server
71+
go func() {
72+
_, cErr := io.Copy(server, client)
73+
errCh <- cErr
74+
}()
75+
76+
// server -> client
77+
go func() {
78+
_, cErr := io.Copy(client, server)
79+
errCh <- cErr
80+
}()
81+
82+
// When either direction ends, close both ends
83+
if err := <-errCh; err != nil && !isUsefullyClosed(err) {
84+
log.Printf("proxy stream ended with error: %v", err)
85+
}
86+
}
87+
88+
func (p *Proxy) KillServerConn() error {
89+
return p.Server.Close()
90+
}
91+
92+
func (p *Proxy) Close() {
93+
if p.Client != nil {
94+
_ = p.Client.Close()
95+
}
96+
97+
if p.Server != nil {
98+
_ = p.Server.Close()
99+
}
100+
101+
if p.Listener != nil {
102+
_ = p.Listener.Close()
103+
}
104+
105+
p.ReconnectionCount.Store(0)
106+
}
107+
108+
// isUsefullyClosed filters common close conditions from logging noise
109+
func isUsefullyClosed(err error) bool {
110+
return err == io.EOF || errors.Is(err, net.ErrClosed)
111+
}

bindings/sftp/sftp.go

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
/*
2+
Copyright 2025 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
114
package sftp
215

316
import (
@@ -25,9 +38,9 @@ const (
2538

2639
// Sftp is a binding for file operations on sftp server.
2740
type Sftp struct {
28-
metadata *sftpMetadata
29-
logger logger.Logger
30-
sftpClient *sftpClient.Client
41+
metadata *sftpMetadata
42+
logger logger.Logger
43+
c *Client
3144
}
3245

3346
// sftpMetadata defines the sftp metadata.
@@ -115,19 +128,12 @@ func (sftp *Sftp) Init(_ context.Context, metadata bindings.Metadata) error {
115128
HostKeyCallback: hostKeyCallback,
116129
}
117130

118-
sshClient, err := ssh.Dial("tcp", m.Address, config)
119-
if err != nil {
120-
return fmt.Errorf("sftp binding error: error create ssh client: %w", err)
121-
}
122-
123-
newSftpClient, err := sftpClient.NewClient(sshClient)
131+
sftp.metadata = m
132+
sftp.c, err = newClient(m.Address, config)
124133
if err != nil {
125-
return fmt.Errorf("sftp binding error: error create sftp client: %w", err)
134+
return fmt.Errorf("sftp binding error: create sftp client error: %w", err)
126135
}
127136

128-
sftp.metadata = m
129-
sftp.sftpClient = newSftpClient
130-
131137
return nil
132138
}
133139

@@ -161,14 +167,9 @@ func (sftp *Sftp) create(_ context.Context, req *bindings.InvokeRequest) (*bindi
161167
return nil, fmt.Errorf("sftp binding error: %w", err)
162168
}
163169

164-
dir, fileName := sftpClient.Split(path)
170+
c := sftp.c
165171

166-
err = sftp.sftpClient.MkdirAll(dir)
167-
if err != nil {
168-
return nil, fmt.Errorf("sftp binding error: error create dir %s: %w", dir, err)
169-
}
170-
171-
file, err := sftp.sftpClient.Create(path)
172+
file, fileName, err := c.create(path)
172173
if err != nil {
173174
return nil, fmt.Errorf("sftp binding error: error create file %s: %w", path, err)
174175
}
@@ -211,7 +212,9 @@ func (sftp *Sftp) list(_ context.Context, req *bindings.InvokeRequest) (*binding
211212
return nil, fmt.Errorf("sftp binding error: %w", err)
212213
}
213214

214-
files, err := sftp.sftpClient.ReadDir(path)
215+
c := sftp.c
216+
217+
files, err := c.list(path)
215218
if err != nil {
216219
return nil, fmt.Errorf("sftp binding error: error read dir %s: %w", path, err)
217220
}
@@ -246,7 +249,9 @@ func (sftp *Sftp) get(_ context.Context, req *bindings.InvokeRequest) (*bindings
246249
return nil, fmt.Errorf("sftp binding error: %w", err)
247250
}
248251

249-
file, err := sftp.sftpClient.Open(path)
252+
c := sftp.c
253+
254+
file, err := c.get(path)
250255
if err != nil {
251256
return nil, fmt.Errorf("sftp binding error: error open file %s: %w", path, err)
252257
}
@@ -272,7 +277,9 @@ func (sftp *Sftp) delete(_ context.Context, req *bindings.InvokeRequest) (*bindi
272277
return nil, fmt.Errorf("sftp binding error: %w", err)
273278
}
274279

275-
err = sftp.sftpClient.Remove(path)
280+
c := sftp.c
281+
282+
err = c.delete(path)
276283
if err != nil {
277284
return nil, fmt.Errorf("sftp binding error: error remove file %s: %w", path, err)
278285
}
@@ -296,7 +303,7 @@ func (sftp *Sftp) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
296303
}
297304

298305
func (sftp *Sftp) Close() error {
299-
return sftp.sftpClient.Close()
306+
return sftp.c.Close()
300307
}
301308

302309
func (metadata sftpMetadata) getPath(requestMetadata map[string]string) (path string, err error) {

0 commit comments

Comments
 (0)