1+ /*
2+ Copyright 2025 The Dapr Authors
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+ http://www.apache.org/licenses/LICENSE-2.0
7+ Unless required by applicable law or agreed to in writing, software
8+ distributed under the License is distributed on an "AS IS" BASIS,
9+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+ See the License for the specific language governing permissions and
11+ limitations under the License.
12+ */
13+
114package sftp
215
316import (
@@ -7,7 +20,6 @@ import (
720 "fmt"
821 "io"
922 "reflect"
10- "sync"
1123
1224 sftpClient "github.com/pkg/sftp"
1325 "golang.org/x/crypto/ssh"
@@ -26,30 +38,9 @@ const (
2638
2739// Sftp is a binding for file operations on sftp server.
2840type Sftp struct {
29- metadata * sftpMetadata
30- logger logger.Logger
31- sftpClient * sftpClient.Client
32- sshClient * ssh.Client
33- clientConfig * ssh.ClientConfig
34- lock sync.RWMutex
35- }
36-
37- func (sftp * Sftp ) Client () (* sftpClient.Client , error ) {
38- sftp .lock .RLock ()
39- current := sftp .sftpClient
40- sftp .lock .RUnlock ()
41-
42- if current != nil {
43- if _ , err := current .Getwd (); err == nil {
44- return current , nil
45- }
46- }
47-
48- err := sftp .handleReconnection ()
49- if err != nil {
50- return nil , err
51- }
52- return sftp .sftpClient , nil
41+ metadata * sftpMetadata
42+ logger logger.Logger
43+ c * Client
5344}
5445
5546// sftpMetadata defines the sftp metadata.
@@ -137,21 +128,12 @@ func (sftp *Sftp) Init(_ context.Context, metadata bindings.Metadata) error {
137128 HostKeyCallback : hostKeyCallback ,
138129 }
139130
140- sshClient , err := ssh .Dial ("tcp" , m .Address , config )
141- if err != nil {
142- return fmt .Errorf ("sftp binding error: error create ssh client: %w" , err )
143- }
144-
145- newSftpClient , err := sftpClient .NewClient (sshClient )
131+ sftp .metadata = m
132+ sftp .c , err = newClient (m .Address , config )
146133 if err != nil {
147- _ = sshClient .Close ()
148- return fmt .Errorf ("sftp binding error: error create sftp client: %w" , err )
134+ return fmt .Errorf ("sftp binding error: create sftp client error: %w" , err )
149135 }
150136
151- sftp .clientConfig = config
152- sftp .metadata = m
153- sftp .sftpClient = newSftpClient
154-
155137 return nil
156138}
157139
@@ -185,19 +167,9 @@ func (sftp *Sftp) create(_ context.Context, req *bindings.InvokeRequest) (*bindi
185167 return nil , fmt .Errorf ("sftp binding error: %w" , err )
186168 }
187169
188- dir , fileName := sftpClient . Split ( path )
170+ c := sftp . c
189171
190- c , err := sftp .Client ()
191- if err != nil {
192- return nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
193- }
194-
195- err = c .MkdirAll (dir )
196- if err != nil {
197- return nil , fmt .Errorf ("sftp binding error: error create dir %s: %w" , dir , err )
198- }
199-
200- file , err := c .Create (path )
172+ file , fileName , err := c .create (path )
201173 if err != nil {
202174 return nil , fmt .Errorf ("sftp binding error: error create file %s: %w" , path , err )
203175 }
@@ -240,12 +212,9 @@ func (sftp *Sftp) list(_ context.Context, req *bindings.InvokeRequest) (*binding
240212 return nil , fmt .Errorf ("sftp binding error: %w" , err )
241213 }
242214
243- c , err := sftp .Client ()
244- if err != nil {
245- return nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
246- }
215+ c := sftp .c
247216
248- files , err := c .ReadDir (path )
217+ files , err := c .list (path )
249218 if err != nil {
250219 return nil , fmt .Errorf ("sftp binding error: error read dir %s: %w" , path , err )
251220 }
@@ -280,12 +249,9 @@ func (sftp *Sftp) get(_ context.Context, req *bindings.InvokeRequest) (*bindings
280249 return nil , fmt .Errorf ("sftp binding error: %w" , err )
281250 }
282251
283- c , err := sftp .Client ()
284- if err != nil {
285- return nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
286- }
252+ c := sftp .c
287253
288- file , err := c .Open (path )
254+ file , err := c .get (path )
289255 if err != nil {
290256 return nil , fmt .Errorf ("sftp binding error: error open file %s: %w" , path , err )
291257 }
@@ -311,11 +277,9 @@ func (sftp *Sftp) delete(_ context.Context, req *bindings.InvokeRequest) (*bindi
311277 return nil , fmt .Errorf ("sftp binding error: %w" , err )
312278 }
313279
314- c , err := sftp .Client ()
315- if err != nil {
316- return nil , fmt .Errorf ("sftp binding error: error getting sftp client: %w" , err )
317- }
318- err = c .Remove (path )
280+ c := sftp .c
281+
282+ err = c .delete (path )
319283 if err != nil {
320284 return nil , fmt .Errorf ("sftp binding error: error remove file %s: %w" , path , err )
321285 }
@@ -339,9 +303,7 @@ func (sftp *Sftp) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
339303}
340304
341305func (sftp * Sftp ) Close () error {
342- sftp .lock .Lock ()
343- defer sftp .lock .Unlock ()
344- return sftp .sftpClient .Close ()
306+ return sftp .c .Close ()
345307}
346308
347309func (metadata sftpMetadata ) getPath (requestMetadata map [string ]string ) (path string , err error ) {
@@ -375,41 +337,3 @@ func (sftp *Sftp) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
375337 metadata .GetMetadataInfoFromStructType (reflect .TypeOf (metadataStruct ), & metadataInfo , metadata .BindingType )
376338 return
377339}
378-
379- func (sftp * Sftp ) handleReconnection () error {
380- sftp .lock .Lock ()
381- defer sftp .lock .Unlock ()
382-
383- // Re-check after acquiring the write lock
384- if sftp .sftpClient != nil {
385- if _ , err := sftp .sftpClient .Getwd (); err == nil {
386- return nil
387- }
388- _ = sftp .sftpClient .Close ()
389- sftp .sftpClient = nil
390- }
391- if sftp .sshClient != nil {
392- _ = sftp .sshClient .Close ()
393- sftp .sshClient = nil
394- }
395-
396- if sftp .metadata == nil || sftp .clientConfig == nil {
397- return fmt .Errorf ("sftp binding error: client not initialized" )
398- }
399-
400- sshClient , err := ssh .Dial ("tcp" , sftp .metadata .Address , sftp .clientConfig )
401- if err != nil {
402- return fmt .Errorf ("sftp binding error: error create ssh client: %w" , err )
403- }
404-
405- newSftpClient , err := sftpClient .NewClient (sshClient )
406- if err != nil {
407- _ = sshClient .Close ()
408- return fmt .Errorf ("sftp binding error: error create sftp client: %w" , err )
409- }
410-
411- sftp .sshClient = sshClient
412- sftp .sftpClient = newSftpClient
413-
414- return nil
415- }
0 commit comments