Skip to content

Commit

Permalink
feat: allocate real pty (#8)
Browse files Browse the repository at this point in the history
* feat: allocate real pty

This adds a new PtyHandler to handle allocating PTYs and storing them in
a platform specific field in `Pty`. This PR is backward-compatible,
it defaults to EmulatePty handler that sets the `emulatePty` field in
context and uses `PtyWriter` to preserve the current behavor.

* fix: update pty godoc

* fix: convert pty handlers to server options

* fix: consume resize events on pty

* feat: support windows conpty

* feat: add pty start process example

* fix: return tty name

* fix: ptystart example for unix

* fix: imports

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* fix: update

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

* chore: deps

---------

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
Co-authored-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
  • Loading branch information
aymanbagabas and caarlos0 authored Jan 17, 2024
1 parent 7e1d867 commit 7ed763a
Show file tree
Hide file tree
Showing 12 changed files with 876 additions and 13 deletions.
59 changes: 59 additions & 0 deletions _examples/ssh-ptystart/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package main

import (
"fmt"
"io"
"log"
"os"
"os/exec"
"runtime"
"time"

"github.com/charmbracelet/ssh"
)

func main() {
ssh.Handle(func(s ssh.Session) {
log.Printf("connected %s %s %q", s.User(), s.RemoteAddr(), s.RawCommand())
defer log.Printf("disconnected %s %s", s.User(), s.RemoteAddr())

pty, _, ok := s.Pty()
if !ok {
io.WriteString(s, "No PTY requested.\n")
s.Exit(1)
return
}

name := "bash"
if runtime.GOOS == "windows" {
name = "powershell.exe"
}
cmd := exec.Command(name)
cmd.Env = append(os.Environ(), "SSH_TTY="+pty.Name(), fmt.Sprintf("TERM=%s", pty.Term))
if err := pty.Start(cmd); err != nil {
fmt.Fprintln(s, err.Error())
s.Exit(1)
return
}

if runtime.GOOS == "windows" {
// ProcessState gets populated by pty.Start waiting on the process
// to exit.
for cmd.ProcessState == nil {
time.Sleep(100 * time.Millisecond)
}

s.Exit(cmd.ProcessState.ExitCode())
} else {
if err := cmd.Wait(); err != nil {
fmt.Fprintln(s, err)
s.Exit(cmd.ProcessState.ExitCode())
}
}
})

log.Println("starting ssh server on port 2222...")
if err := ssh.ListenAndServe(":2222", nil, ssh.AllocatePty()); err != nil && err != ssh.ErrServerClosed {
log.Fatal(err)
}
}
4 changes: 4 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ var (
// ContextKeyPublicKey is a context key for use with Contexts in this package.
// The associated value will be of type PublicKey.
ContextKeyPublicKey = &contextKey{"public-key"}

// ContextKeySession is a context key for use with Contexts in this package.
// The associated value will be of type Session.
ContextKeySession = &contextKey{"session"}
)

// Context is a package specific context interface. It exposes connection
Expand Down
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ go 1.17

require (
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/charmbracelet/x/exp/term v0.0.0-20240117030132-5a84c80527c7
github.com/creack/pty v1.1.21
github.com/u-root/u-root v0.11.0
golang.org/x/crypto v0.17.0
golang.org/x/sys v0.16.0
)

require golang.org/x/sys v0.15.0 // indirect
require github.com/charmbracelet/x/errors v0.0.0-20240117030013-d31dba354651 // indirect
342 changes: 341 additions & 1 deletion go.sum

Large diffs are not rendered by default.

30 changes: 29 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func HostKeyPEM(bytes []byte) Option {
// denying PTY requests.
func NoPty() Option {
return func(srv *Server) error {
srv.PtyCallback = func(ctx Context, pty Pty) bool {
srv.PtyCallback = func(Context, Pty) bool {
return false
}
return nil
Expand All @@ -82,3 +82,31 @@ func WrapConn(fn ConnCallback) Option {
return nil
}
}

var contextKeyEmulatePty = &contextKey{"emulate-pty"}

func emulatePtyHandler(ctx Context, _ Session, _ Pty) (func() error, error) {
ctx.SetValue(contextKeyEmulatePty, true)
return func() error { return nil }, nil
}

// EmulatePty returns a functional option that fakes a PTY. It uses PtyWriter
// underneath.
func EmulatePty() Option {
return func(s *Server) error {
s.PtyHandler = emulatePtyHandler
return nil
}
}

// AllocatePty returns a functional option that allocates a PTY. Implementers
// who wish to use an actual PTY should use this along with the platform
// specific PTY implementation defined in pty_*.go.
func AllocatePty() Option {
return func(s *Server) error {
s.PtyHandler = func(_ Context, s Session, pty Pty) (func() error, error) {
return s.(*session).ptyAllocate(pty.Term, pty.Window, pty.Modes)
}
return nil
}
}
14 changes: 14 additions & 0 deletions pty.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ package ssh

import (
"bytes"
"errors"
"io"
"os/exec"
)

// ErrUnsupported is returned when the platform does not support PTY.
var ErrUnsupported = errors.New("pty unsupported")

// NewPtyWriter creates a writer that handles when the session has a active
// PTY, replacing the \n with \r\n.
func NewPtyWriter(w io.Writer) io.Writer {
Expand Down Expand Up @@ -55,3 +60,12 @@ func (rw readWriterDelegate) Read(p []byte) (n int, err error) {
func (rw readWriterDelegate) Write(p []byte) (n int, err error) {
return rw.w.Write(p)
}

// Start starts a *exec.Cmd attached to the Session. If a PTY is allocated,
// it will use that for I/O.
// On Windows, the process execution lifecycle is not managed by Go and has to
// be managed manually. This means that c.Wait() won't work.
// See https://github.com/charmbracelet/x/blob/main/exp/term/windows/conpty/conpty_windows.go
func (p *Pty) Start(c *exec.Cmd) error {
return p.start(c)
}
44 changes: 44 additions & 0 deletions pty_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//go:build !linux && !darwin && !freebsd && !dragonfly && !netbsd && !openbsd && !solaris && !windows
// +build !linux,!darwin,!freebsd,!dragonfly,!netbsd,!openbsd,!solaris,!windows

package ssh

import (
"os/exec"

"golang.org/x/crypto/ssh"
)

type impl struct{}

func (i *impl) IsZero() bool {
return true
}

func (i *impl) Name() string {
return ""
}

func (i *impl) Read(p []byte) (n int, err error) {
return 0, ErrUnsupported
}

func (i *impl) Write(p []byte) (n int, err error) {
return 0, ErrUnsupported
}

func (i *impl) Resize(w int, h int) error {
return ErrUnsupported
}

func (i *impl) Close() error {
return nil
}

func (*impl) start(*exec.Cmd) error {
return ErrUnsupported
}

func newPty(Context, string, Window, ssh.TerminalModes) (impl, error) {
return impl{}, ErrUnsupported
}
199 changes: 199 additions & 0 deletions pty_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build darwin dragonfly freebsd linux netbsd openbsd solaris

package ssh

import (
"fmt"
"os"
"os/exec"
"syscall"

"github.com/creack/pty"
"github.com/u-root/u-root/pkg/termios"
"golang.org/x/crypto/ssh"
"golang.org/x/sys/unix"
)

type impl struct {
// Master is the master PTY file descriptor.
Master *os.File

// Slave is the slave PTY file descriptor.
Slave *os.File
}

func (i *impl) IsZero() bool {
return i.Master == nil && i.Slave == nil
}

// Name returns the name of the slave PTY.
func (i *impl) Name() string {
return i.Slave.Name()
}

// Read implements ptyInterface.
func (i *impl) Read(p []byte) (n int, err error) {
return i.Master.Read(p)
}

// Write implements ptyInterface.
func (i *impl) Write(p []byte) (n int, err error) {
return i.Master.Write(p)
}

func (i *impl) Close() error {
if err := i.Master.Close(); err != nil {
return err
}
return i.Slave.Close()
}

func (i *impl) Resize(w int, h int) (rErr error) {
conn, err := i.Master.SyscallConn()
if err != nil {
return err
}

return conn.Control(func(fd uintptr) {
rErr = termios.SetWinSize(fd, &termios.Winsize{
Winsize: unix.Winsize{
Row: uint16(h),
Col: uint16(w),
},
})
})
}

func (i *impl) start(c *exec.Cmd) error {
c.Stdin, c.Stdout, c.Stderr = i.Slave, i.Slave, i.Slave
if c.SysProcAttr == nil {
c.SysProcAttr = &syscall.SysProcAttr{}
}
c.SysProcAttr.Setctty = true
c.SysProcAttr.Setsid = true
return c.Start()
}

func newPty(_ Context, _ string, win Window, modes ssh.TerminalModes) (_ impl, rErr error) {
ptm, pts, err := pty.Open()
if err != nil {
return impl{}, err
}

conn, err := ptm.SyscallConn()
if err != nil {
return impl{}, err
}

if err := conn.Control(func(fd uintptr) {
rErr = applyTerminalModesToFd(fd, win.Width, win.Height, modes)
}); err != nil {
return impl{}, err
}

return impl{Master: ptm, Slave: pts}, rErr
}

func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes) error {
// Get the current TTY configuration.
tios, err := termios.GTTY(int(fd))
if err != nil {
return fmt.Errorf("GTTY: %w", err)
}

// Apply the modes from the SSH request.
tios.Row = height
tios.Col = width

for c, v := range modes {
if c == ssh.TTY_OP_ISPEED {
tios.Ispeed = int(v)
continue
}
if c == ssh.TTY_OP_OSPEED {
tios.Ospeed = int(v)
continue
}
k, ok := terminalModeFlagNames[c]
if !ok {
continue
}
if _, ok := tios.CC[k]; ok {
tios.CC[k] = uint8(v)
continue
}
if _, ok := tios.Opts[k]; ok {
tios.Opts[k] = v > 0
continue
}
}

// Save the new TTY configuration.
if _, err := tios.STTY(int(fd)); err != nil {
return fmt.Errorf("STTY: %w", err)
}

return nil
}

// terminalModeFlagNames maps the SSH terminal mode flags to mnemonic
// names used by the termios package.
var terminalModeFlagNames = map[uint8]string{
ssh.VINTR: "intr",
ssh.VQUIT: "quit",
ssh.VERASE: "erase",
ssh.VKILL: "kill",
ssh.VEOF: "eof",
ssh.VEOL: "eol",
ssh.VEOL2: "eol2",
ssh.VSTART: "start",
ssh.VSTOP: "stop",
ssh.VSUSP: "susp",
ssh.VDSUSP: "dsusp",
ssh.VREPRINT: "rprnt",
ssh.VWERASE: "werase",
ssh.VLNEXT: "lnext",
ssh.VFLUSH: "flush",
ssh.VSWTCH: "swtch",
ssh.VSTATUS: "status",
ssh.VDISCARD: "discard",
ssh.IGNPAR: "ignpar",
ssh.PARMRK: "parmrk",
ssh.INPCK: "inpck",
ssh.ISTRIP: "istrip",
ssh.INLCR: "inlcr",
ssh.IGNCR: "igncr",
ssh.ICRNL: "icrnl",
ssh.IUCLC: "iuclc",
ssh.IXON: "ixon",
ssh.IXANY: "ixany",
ssh.IXOFF: "ixoff",
ssh.IMAXBEL: "imaxbel",
ssh.IUTF8: "iutf8",
ssh.ISIG: "isig",
ssh.ICANON: "icanon",
ssh.XCASE: "xcase",
ssh.ECHO: "echo",
ssh.ECHOE: "echoe",
ssh.ECHOK: "echok",
ssh.ECHONL: "echonl",
ssh.NOFLSH: "noflsh",
ssh.TOSTOP: "tostop",
ssh.IEXTEN: "iexten",
ssh.ECHOCTL: "echoctl",
ssh.ECHOKE: "echoke",
ssh.PENDIN: "pendin",
ssh.OPOST: "opost",
ssh.OLCUC: "olcuc",
ssh.ONLCR: "onlcr",
ssh.OCRNL: "ocrnl",
ssh.ONOCR: "onocr",
ssh.ONLRET: "onlret",
ssh.CS7: "cs7",
ssh.CS8: "cs8",
ssh.PARENB: "parenb",
ssh.PARODD: "parodd",
ssh.TTY_OP_ISPEED: "tty_op_ispeed",
ssh.TTY_OP_OSPEED: "tty_op_ospeed",
}
Loading

0 comments on commit 7ed763a

Please sign in to comment.