Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[24.0 backport] ssh: fix error on commandconn close, add ping and default timeout #4395

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions cli/command/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -327,13 +326,8 @@ func (cli *DockerCli) getInitTimeout() time.Duration {

func (cli *DockerCli) initializeFromClient() {
ctx := context.Background()
if !strings.HasPrefix(cli.dockerEndpoint.Host, "ssh://") {
// @FIXME context.WithTimeout doesn't work with connhelper / ssh connections
// time="2020-04-10T10:16:26Z" level=warning msg="commandConn.CloseWrite: commandconn: failed to wait: signal: killed"
var cancel func()
ctx, cancel = context.WithTimeout(ctx, cli.getInitTimeout())
defer cancel()
}
ctx, cancel := context.WithTimeout(ctx, cli.getInitTimeout())
defer cancel()

ping, err := cli.client.Ping(ctx)
if err != nil {
Expand Down
206 changes: 104 additions & 102 deletions cli/connhelper/commandconn/commandconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

Expand Down Expand Up @@ -64,100 +65,86 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) {

// commandConn implements net.Conn
type commandConn struct {
cmd *exec.Cmd
cmdExited bool
cmdWaitErr error
cmdMutex sync.Mutex
stdin io.WriteCloser
stdout io.ReadCloser
stderrMu sync.Mutex
stderr bytes.Buffer
stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed
stdinClosed bool
stdoutClosed bool
localAddr net.Addr
remoteAddr net.Addr
cmdMutex sync.Mutex // for cmd, cmdWaitErr
cmd *exec.Cmd
cmdWaitErr error
cmdExited atomic.Bool
stdin io.WriteCloser
stdout io.ReadCloser
stderrMu sync.Mutex // for stderr
stderr bytes.Buffer
stdinClosed atomic.Bool
stdoutClosed atomic.Bool
closing atomic.Bool
localAddr net.Addr
remoteAddr net.Addr
}

// killIfStdioClosed kills the cmd if both stdin and stdout are closed.
func (c *commandConn) killIfStdioClosed() error {
c.stdioClosedMu.Lock()
stdioClosed := c.stdoutClosed && c.stdinClosed
c.stdioClosedMu.Unlock()
if !stdioClosed {
return nil
// kill terminates the process. On Windows it kills the process directly,
// whereas on other platforms, a SIGTERM is sent, before forcefully terminating
// the process after 3 seconds.
func (c *commandConn) kill() {
if c.cmdExited.Load() {
return
}
return c.kill()
}

// killAndWait tries sending SIGTERM to the process before sending SIGKILL.
func killAndWait(cmd *exec.Cmd) error {
c.cmdMutex.Lock()
var werr error
if runtime.GOOS != "windows" {
werrCh := make(chan error)
go func() { werrCh <- cmd.Wait() }()
cmd.Process.Signal(syscall.SIGTERM)
go func() { werrCh <- c.cmd.Wait() }()
_ = c.cmd.Process.Signal(syscall.SIGTERM)
select {
case werr = <-werrCh:
case <-time.After(3 * time.Second):
cmd.Process.Kill()
_ = c.cmd.Process.Kill()
werr = <-werrCh
}
} else {
cmd.Process.Kill()
werr = cmd.Wait()
_ = c.cmd.Process.Kill()
werr = c.cmd.Wait()
}
return werr
c.cmdWaitErr = werr
c.cmdMutex.Unlock()
c.cmdExited.Store(true)
}

// kill returns nil if the command terminated, regardless of the exit status.
func (c *commandConn) kill() error {
var werr error
c.cmdMutex.Lock()
if c.cmdExited {
werr = c.cmdWaitErr
} else {
werr = killAndWait(c.cmd)
c.cmdWaitErr = werr
c.cmdExited = true
}
c.cmdMutex.Unlock()
if werr == nil {
return nil
}
wExitErr, ok := werr.(*exec.ExitError)
if ok {
if wExitErr.ProcessState.Exited() {
return nil
}
// handleEOF handles io.EOF errors while reading or writing from the underlying
// command pipes.
//
// When we've received an EOF we expect that the command will
// be terminated soon. As such, we call Wait() on the command
// and return EOF or the error depending on whether the command
// exited with an error.
//
// If Wait() does not return within 10s, an error is returned.
func (c *commandConn) handleEOF(err error) error {
if err != io.EOF {
return err
}
return errors.Wrapf(werr, "commandconn: failed to wait")
}

func (c *commandConn) onEOF(eof error) error {
// when we got EOF, the command is going to be terminated
var werr error
c.cmdMutex.Lock()
if c.cmdExited {
defer c.cmdMutex.Unlock()

var werr error
if c.cmdExited.Load() {
werr = c.cmdWaitErr
} else {
werrCh := make(chan error)
go func() { werrCh <- c.cmd.Wait() }()
select {
case werr = <-werrCh:
c.cmdWaitErr = werr
c.cmdExited = true
c.cmdExited.Store(true)
case <-time.After(10 * time.Second):
c.cmdMutex.Unlock()
c.stderrMu.Lock()
stderr := c.stderr.String()
c.stderrMu.Unlock()
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr)
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr)
}
}
c.cmdMutex.Unlock()

if werr == nil {
return eof
return err
}
c.stderrMu.Lock()
stderr := c.stderr.String()
Expand All @@ -166,71 +153,86 @@ func (c *commandConn) onEOF(eof error) error {
}

func ignorableCloseError(err error) bool {
errS := err.Error()
ss := []string{
os.ErrClosed.Error(),
return strings.Contains(err.Error(), os.ErrClosed.Error())
}

func (c *commandConn) Read(p []byte) (int, error) {
n, err := c.stdout.Read(p)
// check after the call to Read, since
// it is blocking, and while waiting on it
// Close might get called
if c.closing.Load() {
// If we're currently closing the connection,
// we don't want to call handleEOF.
return n, err
}
for _, s := range ss {
if strings.Contains(errS, s) {
return true
}

return n, c.handleEOF(err)
}

func (c *commandConn) Write(p []byte) (int, error) {
n, err := c.stdin.Write(p)
// check after the call to Write, since
// it is blocking, and while waiting on it
// Close might get called
if c.closing.Load() {
// If we're currently closing the connection,
// we don't want to call handleEOF.
return n, err
}
return false

return n, c.handleEOF(err)
}

// CloseRead allows commandConn to implement halfCloser
func (c *commandConn) CloseRead() error {
// NOTE: maybe already closed here
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseRead: %v", err)
return err
}
c.stdioClosedMu.Lock()
c.stdoutClosed = true
c.stdioClosedMu.Unlock()
if err := c.killIfStdioClosed(); err != nil {
logrus.Warnf("commandConn.CloseRead: %v", err)
}
return nil
}
c.stdoutClosed.Store(true)

func (c *commandConn) Read(p []byte) (int, error) {
n, err := c.stdout.Read(p)
if err == io.EOF {
err = c.onEOF(err)
if c.stdinClosed.Load() {
c.kill()
}
return n, err

return nil
}

// CloseWrite allows commandConn to implement halfCloser
func (c *commandConn) CloseWrite() error {
// NOTE: maybe already closed here
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseWrite: %v", err)
}
c.stdioClosedMu.Lock()
c.stdinClosed = true
c.stdioClosedMu.Unlock()
if err := c.killIfStdioClosed(); err != nil {
logrus.Warnf("commandConn.CloseWrite: %v", err)
return err
}
return nil
}
c.stdinClosed.Store(true)

func (c *commandConn) Write(p []byte) (int, error) {
n, err := c.stdin.Write(p)
if err == io.EOF {
err = c.onEOF(err)
if c.stdoutClosed.Load() {
c.kill()
}
return n, err
return nil
}

// Close is the net.Conn func that gets called
// by the transport when a dial is cancelled
// due to its context timing out. Any blocked
// Read or Write calls will be unblocked and
// return errors. It will block until the underlying
// command has terminated.
func (c *commandConn) Close() error {
var err error
if err = c.CloseRead(); err != nil {
c.closing.Store(true)
defer c.closing.Store(false)

if err := c.CloseRead(); err != nil {
logrus.Warnf("commandConn.Close: CloseRead: %v", err)
return err
}
if err = c.CloseWrite(); err != nil {
if err := c.CloseWrite(); err != nil {
logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
return err
}
return err

return nil
}

func (c *commandConn) LocalAddr() net.Addr {
Expand Down
Loading
Loading