diff --git a/pty.go b/pty.go index 3a8b599..6d54a20 100644 --- a/pty.go +++ b/pty.go @@ -13,6 +13,8 @@ func NewPtyWriter(w io.Writer) io.Writer { } } +var _ io.Writer = ptyWriter{} + type ptyWriter struct { w io.Writer } @@ -29,3 +31,27 @@ func (w ptyWriter) Write(p []byte) (int, error) { } return n, err } + +// NewPtyReadWriter return an io.ReadWriter that delegates the read to the +// given io.ReadWriter, and the writes to a ptyWriter. +func NewPtyReadWriter(rw io.ReadWriter) io.ReadWriter { + return readWriterDelegate{ + w: NewPtyWriter(rw), + r: rw, + } +} + +var _ io.ReadWriter = readWriterDelegate{} + +type readWriterDelegate struct { + w io.Writer + r io.Reader +} + +func (rw readWriterDelegate) Read(p []byte) (n int, err error) { + return rw.r.Read(p) +} + +func (rw readWriterDelegate) Write(p []byte) (n int, err error) { + return rw.w.Write(p) +} diff --git a/session.go b/session.go index 331bf3c..20fbf78 100644 --- a/session.go +++ b/session.go @@ -83,9 +83,11 @@ type Session interface { // During the time that no channel is registered, breaks are ignored. Break(c chan<- bool) - // SafeStderr returns the Stderr io.Writer that handles replacing \n with - // \r\n when there's an active Pty. - SafeStderr() io.Writer + // Stderr returns an io.ReadWriter that writes to this channel + // with the extended data type set to stderr. Stderr may + // safely be read and written from a different goroutine than + // Read and Write respectively. + Stderr() io.ReadWriter } // maxSigBufSize is how many signals will be buffered @@ -131,11 +133,11 @@ type session struct { breakCh chan<- bool } -func (sess *session) SafeStderr() io.Writer { +func (sess *session) Stderr() io.ReadWriter { if sess.pty != nil { - return NewPtyWriter(sess.Stderr()) + return NewPtyReadWriter(sess.Channel.Stderr()) } - return sess.Stderr() + return sess.Channel.Stderr() } func (sess *session) Write(p []byte) (int, error) { diff --git a/session_test.go b/session_test.go index 301be61..e17d284 100644 --- a/session_test.go +++ b/session_test.go @@ -6,6 +6,7 @@ import ( "io" "net" "testing" + "time" gossh "golang.org/x/crypto/ssh" ) @@ -236,7 +237,8 @@ func TestPtyWriter(t *testing.T) { session, _, cleanup := newTestSession(t, &Server{ Handler: func(s Session) { _, _ = fmt.Fprintln(s, "foo\nbar") - _, _ = fmt.Fprintln(s.SafeStderr(), "many\nerrors") + time.Sleep(10 * time.Millisecond) + _, _ = fmt.Fprintln(s.Stderr(), "many\nerrors") _ = s.Exit(0) }, }, nil)