Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 255 additions & 0 deletions bindings/sftp/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sftp

import (
"errors"
"fmt"
"io"
"os"
"strings"
"sync"
"syscall"

sftpClient "github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)

type Client struct {
sshClient *ssh.Client
sftpClient *sftpClient.Client
address string
config *ssh.ClientConfig
lock sync.RWMutex
rLock sync.Mutex
Comment on lines +34 to +35
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need two locks?

Copy link
Contributor Author

@javier-aliaga javier-aliaga Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because:

  • lock

    • Controls that only one goroutin trying to reconnect and verifies with ping() if the connection is valid or not
  • rlock

    • We do want to prevent multiple goroutines trying to use the client while swapping the connection and ensure the consistency of the client. We cannot swap the client until we have the full lock

}

func newClient(address string, config *ssh.ClientConfig) (*Client, error) {
if address == "" || config == nil {
return nil, errors.New("sftp binding error: client not initialized")
}

sshClient, err := ssh.Dial("tcp", address, config)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error create ssh client: %w", err)
}

newSftpClient, err := sftpClient.NewClient(sshClient)
if err != nil {
_ = sshClient.Close()
return nil, fmt.Errorf("sftp binding error: error create sftp client: %w", err)
}

return &Client{
sshClient: sshClient,
sftpClient: newSftpClient,
address: address,
config: config,
}, nil
}

func (c *Client) Close() error {
_ = c.sshClient.Close()
c.lock.Lock()
defer c.lock.Unlock()
return c.sftpClient.Close()
}

func (c *Client) list(path string) ([]os.FileInfo, error) {
var fi []os.FileInfo

fn := func() error {
var err error
c.lock.RLock()
defer c.lock.RUnlock()
fi, err = c.sftpClient.ReadDir(path)
return err
}

err := withReconnection(c, fn)
if err != nil {
return nil, err
}

return fi, nil
}

func (c *Client) create(path string) (*sftpClient.File, string, error) {
dir, fileName := sftpClient.Split(path)

var file *sftpClient.File

createFn := func() error {
c.lock.RLock()
defer c.lock.RUnlock()
cErr := c.sftpClient.MkdirAll(dir)
if cErr != nil {
return fmt.Errorf("sftp binding error: error create dir %s: %w", dir, cErr)
}

file, cErr = c.sftpClient.Create(path)
if cErr != nil {
return fmt.Errorf("sftp binding error: error create file %s: %w", path, cErr)
}

return nil
}

rErr := withReconnection(c, createFn)
if rErr != nil {
return nil, "", rErr
}

return file, fileName, nil
}

func (c *Client) get(path string) (*sftpClient.File, error) {
var f *sftpClient.File

fn := func() error {
var err error
c.lock.RLock()
defer c.lock.RUnlock()
f, err = c.sftpClient.Open(path)
return err
}

err := withReconnection(c, fn)
if err != nil {
return nil, err
}

return f, nil
}

func (c *Client) delete(path string) error {
fn := func() error {
var err error
c.lock.RLock()
defer c.lock.RUnlock()
err = c.sftpClient.Remove(path)
return err
}

err := withReconnection(c, fn)
if err != nil {
return err
}

return nil
}

func (c *Client) ping() error {
c.lock.RLock()
defer c.lock.RUnlock()
_, err := c.sftpClient.Getwd()
if err != nil {
return err
}
return nil
}

func withReconnection(c *Client, fn func() error) error {
err := fn()
if err == nil {
return nil
}

if !shouldReconnect(err) {
return err
}

rErr := doReconnect(c)
if rErr != nil {
return errors.Join(err, rErr)
}

err = fn()
if err != nil {
return err
}

return nil
}

func doReconnect(c *Client) error {
c.rLock.Lock()
defer c.rLock.Unlock()
Comment on lines +187 to +188
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is hard to follow the logic when we have 2 locks


err := c.ping()
if !shouldReconnect(err) {
return nil
}

sshClient, err := ssh.Dial("tcp", c.address, c.config)
if err != nil {
return fmt.Errorf("sftp binding error: error create ssh client: %w", err)
}

Comment on lines +195 to +199
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please can we move client creation to a single func?

newSftpClient, err := sftpClient.NewClient(sshClient)
if err != nil {
_ = sshClient.Close()
return fmt.Errorf("sftp binding error: error create sftp client: %w", err)
}

// Swap under short lock; close old clients after unlocking.
c.lock.Lock()
oldSftp := c.sftpClient
oldSSH := c.sshClient
c.sftpClient = newSftpClient
c.sshClient = sshClient
c.lock.Unlock()

if oldSftp != nil {
_ = oldSftp.Close()
}
if oldSSH != nil {
_ = oldSSH.Close()
}

return nil
}

// shouldReconnect returns true if the error looks like a transport-level failure
func shouldReconnect(err error) bool {
if err == nil {
return false
}

// Network/timeout conditions
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, syscall.ECONNRESET) {
return true
}

// Common wrapped network error messages
msg := strings.ToLower(err.Error())
switch {
case strings.Contains(msg, "use of closed network connection"),
strings.Contains(msg, "connection reset by peer"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not use typed errors here? syscall.ECONNRESET

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add to the previous if(L230), but I will keep the message here just in case.

strings.Contains(msg, "broken pipe"),
strings.Contains(msg, "connection refused"),
strings.Contains(msg, "network is unreachable"),
strings.Contains(msg, "no such host"):
return true
}

// SFTP status errors that are logical, not connectivity (avoid reconnect)
if errors.Is(err, sftpClient.ErrSSHFxPermissionDenied) ||
errors.Is(err, sftpClient.ErrSSHFxNoSuchFile) ||
errors.Is(err, sftpClient.ErrSSHFxOpUnsupported) {
return false
}

return true
}
55 changes: 31 additions & 24 deletions bindings/sftp/sftp.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
/*
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sftp

import (
Expand Down Expand Up @@ -25,9 +38,9 @@ const (

// Sftp is a binding for file operations on sftp server.
type Sftp struct {
metadata *sftpMetadata
logger logger.Logger
sftpClient *sftpClient.Client
metadata *sftpMetadata
logger logger.Logger
c *Client
}

// sftpMetadata defines the sftp metadata.
Expand Down Expand Up @@ -115,19 +128,12 @@ func (sftp *Sftp) Init(_ context.Context, metadata bindings.Metadata) error {
HostKeyCallback: hostKeyCallback,
}

sshClient, err := ssh.Dial("tcp", m.Address, config)
if err != nil {
return fmt.Errorf("sftp binding error: error create ssh client: %w", err)
}

newSftpClient, err := sftpClient.NewClient(sshClient)
sftp.metadata = m
sftp.c, err = newClient(m.Address, config)
if err != nil {
return fmt.Errorf("sftp binding error: error create sftp client: %w", err)
return fmt.Errorf("sftp binding error: create sftp client error: %w", err)
}

sftp.metadata = m
sftp.sftpClient = newSftpClient

return nil
}

Expand Down Expand Up @@ -161,14 +167,9 @@ func (sftp *Sftp) create(_ context.Context, req *bindings.InvokeRequest) (*bindi
return nil, fmt.Errorf("sftp binding error: %w", err)
}

dir, fileName := sftpClient.Split(path)
c := sftp.c

err = sftp.sftpClient.MkdirAll(dir)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error create dir %s: %w", dir, err)
}

file, err := sftp.sftpClient.Create(path)
file, fileName, err := c.create(path)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error create file %s: %w", path, err)
}
Expand Down Expand Up @@ -211,7 +212,9 @@ func (sftp *Sftp) list(_ context.Context, req *bindings.InvokeRequest) (*binding
return nil, fmt.Errorf("sftp binding error: %w", err)
}

files, err := sftp.sftpClient.ReadDir(path)
c := sftp.c

files, err := c.list(path)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error read dir %s: %w", path, err)
}
Expand Down Expand Up @@ -246,7 +249,9 @@ func (sftp *Sftp) get(_ context.Context, req *bindings.InvokeRequest) (*bindings
return nil, fmt.Errorf("sftp binding error: %w", err)
}

file, err := sftp.sftpClient.Open(path)
c := sftp.c

file, err := c.get(path)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error open file %s: %w", path, err)
}
Expand All @@ -272,7 +277,9 @@ func (sftp *Sftp) delete(_ context.Context, req *bindings.InvokeRequest) (*bindi
return nil, fmt.Errorf("sftp binding error: %w", err)
}

err = sftp.sftpClient.Remove(path)
c := sftp.c

err = c.delete(path)
if err != nil {
return nil, fmt.Errorf("sftp binding error: error remove file %s: %w", path, err)
}
Expand All @@ -296,7 +303,7 @@ func (sftp *Sftp) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bin
}

func (sftp *Sftp) Close() error {
return sftp.sftpClient.Close()
return sftp.c.Close()
}

func (metadata sftpMetadata) getPath(requestMetadata map[string]string) (path string, err error) {
Expand Down
Loading
Loading