Skip to content

Commit

Permalink
refactor: keeping the same Stderr method
Browse files Browse the repository at this point in the history
  • Loading branch information
caarlos0 committed Oct 24, 2022
1 parent 62448dd commit 993008d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
26 changes: 26 additions & 0 deletions pty.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ func NewPtyWriter(w io.Writer) io.Writer {
}
}

var _ io.Writer = ptyWriter{}

type ptyWriter struct {
w io.Writer
}
Expand All @@ -29,3 +31,27 @@ func (w ptyWriter) Write(p []byte) (int, error) {
}
return n, err
}

// NewPtyReadWriter return an io.ReadWriter that delegates the read to the
// given io.ReadWriter, and the writes to a ptyWriter.
func NewPtyReadWriter(rw io.ReadWriter) io.ReadWriter {
return readWriterDelegate{
w: NewPtyWriter(rw),
r: rw,
}
}

var _ io.ReadWriter = readWriterDelegate{}

type readWriterDelegate struct {
w io.Writer
r io.Reader
}

func (rw readWriterDelegate) Read(p []byte) (n int, err error) {
return rw.r.Read(p)
}

func (rw readWriterDelegate) Write(p []byte) (n int, err error) {
return rw.w.Write(p)
}
14 changes: 8 additions & 6 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ type Session interface {
// During the time that no channel is registered, breaks are ignored.
Break(c chan<- bool)

// SafeStderr returns the Stderr io.Writer that handles replacing \n with
// \r\n when there's an active Pty.
SafeStderr() io.Writer
// Stderr returns an io.ReadWriter that writes to this channel
// with the extended data type set to stderr. Stderr may
// safely be read and written from a different goroutine than
// Read and Write respectively.
Stderr() io.ReadWriter
}

// maxSigBufSize is how many signals will be buffered
Expand Down Expand Up @@ -131,11 +133,11 @@ type session struct {
breakCh chan<- bool
}

func (sess *session) SafeStderr() io.Writer {
func (sess *session) Stderr() io.ReadWriter {
if sess.pty != nil {
return NewPtyWriter(sess.Stderr())
return NewPtyReadWriter(sess.Channel.Stderr())
}
return sess.Stderr()
return sess.Channel.Stderr()
}

func (sess *session) Write(p []byte) (int, error) {
Expand Down
4 changes: 3 additions & 1 deletion session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net"
"testing"
"time"

gossh "golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -236,7 +237,8 @@ func TestPtyWriter(t *testing.T) {
session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
_, _ = fmt.Fprintln(s, "foo\nbar")
_, _ = fmt.Fprintln(s.SafeStderr(), "many\nerrors")
time.Sleep(10 * time.Millisecond)
_, _ = fmt.Fprintln(s.Stderr(), "many\nerrors")
_ = s.Exit(0)
},
}, nil)
Expand Down

0 comments on commit 993008d

Please sign in to comment.