77 "fmt"
88 "io"
99 "reflect"
10+ "sync"
1011
1112 sftpClient "github.com/pkg/sftp"
1213 "golang.org/x/crypto/ssh"
@@ -25,9 +26,30 @@ const (
2526
2627// Sftp is a binding for file operations on sftp server.
2728type Sftp struct {
28- metadata * sftpMetadata
29- logger logger.Logger
30- sftpClient * sftpClient.Client
29+ metadata * sftpMetadata
30+ logger logger.Logger
31+ sftpClient * sftpClient.Client
32+ sshClient * ssh.Client
33+ clientConfig * ssh.ClientConfig
34+ lock sync.RWMutex
35+ }
36+
37+ func (sftp * Sftp ) Client () (* sftpClient.Client , error ) {
38+ sftp .lock .RLock ()
39+ current := sftp .sftpClient
40+ sftp .lock .RUnlock ()
41+
42+ if current != nil {
43+ if _ , err := current .Getwd (); err == nil {
44+ return current , nil
45+ }
46+ }
47+
48+ err := sftp .handleReconnection ()
49+ if err != nil {
50+ return nil , err
51+ }
52+ return sftp .sftpClient , nil
3153}
3254
3355// sftpMetadata defines the sftp metadata.
@@ -122,9 +144,11 @@ func (sftp *Sftp) Init(_ context.Context, metadata bindings.Metadata) error {
122144
123145 newSftpClient , err := sftpClient .NewClient (sshClient )
124146 if err != nil {
147+ _ = sshClient .Close ()
125148 return fmt .Errorf ("sftp binding error: error create sftp client: %w" , err )
126149 }
127150
151+ sftp .clientConfig = config
128152 sftp .metadata = m
129153 sftp .sftpClient = newSftpClient
130154
@@ -163,12 +187,17 @@ func (sftp *Sftp) create(_ context.Context, req *bindings.InvokeRequest) (*bindi
163187
164188 dir , fileName := sftpClient .Split (path )
165189
166- err = sftp .sftpClient .MkdirAll (dir )
190+ c , err := sftp .Client ()
191+ if err != nil {
192+ return nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
193+ }
194+
195+ err = c .MkdirAll (dir )
167196 if err != nil {
168197 return nil , fmt .Errorf ("sftp binding error: error create dir %s: %w" , dir , err )
169198 }
170199
171- file , err := sftp . sftpClient .Create (path )
200+ file , err := c .Create (path )
172201 if err != nil {
173202 return nil , fmt .Errorf ("sftp binding error: error create file %s: %w" , path , err )
174203 }
@@ -211,7 +240,12 @@ func (sftp *Sftp) list(_ context.Context, req *bindings.InvokeRequest) (*binding
211240 return nil , fmt .Errorf ("sftp binding error: %w" , err )
212241 }
213242
214- files , err := sftp .sftpClient .ReadDir (path )
243+ c , err := sftp .Client ()
244+ if err != nil {
245+ return nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
246+ }
247+
248+ files , err := c .ReadDir (path )
215249 if err != nil {
216250 return nil , fmt .Errorf ("sftp binding error: error read dir %s: %w" , path , err )
217251 }
@@ -246,7 +280,12 @@ func (sftp *Sftp) get(_ context.Context, req *bindings.InvokeRequest) (*bindings
246280 return nil , fmt .Errorf ("sftp binding error: %w" , err )
247281 }
248282
249- file , err := sftp .sftpClient .Open (path )
283+ c , err := sftp .Client ()
284+ if err != nil {
285+ return nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
286+ }
287+
288+ file , err := c .Open (path )
250289 if err != nil {
251290 return nil , fmt .Errorf ("sftp binding error: error open file %s: %w" , path , err )
252291 }
@@ -272,7 +311,11 @@ func (sftp *Sftp) delete(_ context.Context, req *bindings.InvokeRequest) (*bindi
272311 return nil , fmt .Errorf ("sftp binding error: %w" , err )
273312 }
274313
275- err = sftp .sftpClient .Remove (path )
314+ c , err := sftp .Client ()
315+ if err != nil {
316+ return nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
317+ }
318+ err = c .Remove (path )
276319 if err != nil {
277320 return nil , fmt .Errorf ("sftp binding error: error remove file %s: %w" , path , err )
278321 }
@@ -296,6 +339,8 @@ func (sftp *Sftp) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
296339}
297340
298341func (sftp * Sftp ) Close () error {
342+ sftp .lock .Lock ()
343+ defer sftp .lock .Unlock ()
299344 return sftp .sftpClient .Close ()
300345}
301346
@@ -330,3 +375,41 @@ func (sftp *Sftp) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
330375 metadata .GetMetadataInfoFromStructType (reflect .TypeOf (metadataStruct ), & metadataInfo , metadata .BindingType )
331376 return
332377}
378+
379+ func (sftp * Sftp ) handleReconnection () error {
380+ sftp .lock .Lock ()
381+ defer sftp .lock .Unlock ()
382+
383+ // Re-check after acquiring the write lock
384+ if sftp .sftpClient != nil {
385+ if _ , err := sftp .sftpClient .Getwd (); err == nil {
386+ return nil
387+ }
388+ _ = sftp .sftpClient .Close ()
389+ sftp .sftpClient = nil
390+ }
391+ if sftp .sshClient != nil {
392+ _ = sftp .sshClient .Close ()
393+ sftp .sshClient = nil
394+ }
395+
396+ if sftp .metadata == nil || sftp .clientConfig == nil {
397+ return fmt .Errorf ("sftp binding error: client not initialized" )
398+ }
399+
400+ sshClient , err := ssh .Dial ("tcp" , sftp .metadata .Address , sftp .clientConfig )
401+ if err != nil {
402+ return fmt .Errorf ("sftp binding error: error create ssh client: %w" , err )
403+ }
404+
405+ newSftpClient , err := sftpClient .NewClient (sshClient )
406+ if err != nil {
407+ _ = sshClient .Close ()
408+ return fmt .Errorf ("sftp binding error: error create sftp client: %w" , err )
409+ }
410+
411+ sftp .sshClient = sshClient
412+ sftp .sftpClient = newSftpClient
413+
414+ return nil
415+ }
0 commit comments