Skip to content

Commit

Permalink
SSH command WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
iximiuz committed Mar 4, 2024
1 parent 7f2541d commit a6c73fa
Show file tree
Hide file tree
Showing 4 changed files with 406 additions and 1 deletion.
111 changes: 111 additions & 0 deletions cmd/ssh/ssh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package ssh

import (
"context"
"fmt"
"log/slog"
"math/rand"
"os"
"time"

"github.com/spf13/cobra"

"github.com/iximiuz/labctl/internal/labcli"
"github.com/iximiuz/labctl/internal/portforward"
"github.com/iximiuz/labctl/internal/ssh"
)

type sshOptions struct {
playID string
machine string
}

func NewCommand(cli labcli.CLI) *cobra.Command {
var opts sshOptions

cmd := &cobra.Command{
Use: "ssh [flags] <playground-id>",
Short: `Start SSH session to the target playground`,
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
opts.playID = args[0]

return labcli.WrapStatusError(runSSHSession(cmd.Context(), cli, &opts))
},
}

flags := cmd.Flags()

flags.StringVarP(
&opts.machine,
"machine",
"m",
"",
`Target machine (default: the first machine in the playground)`,
)

return cmd
}

func runSSHSession(ctx context.Context, cli labcli.CLI, opts *sshOptions) error {
p, err := cli.Client().GetPlay(ctx, opts.playID)
if err != nil {
return fmt.Errorf("couldn't get playground: %w", err)
}

if opts.machine == "" {
opts.machine = p.Machines[0].Name
}

tunnel, err := portforward.StartTunnel(ctx, cli.Client(), portforward.TunnelOptions{
PlayID: opts.playID,
Machine: opts.machine,
SSHDirPath: cli.Config().SSHDirPath,
})
if err != nil {
return fmt.Errorf("couldn't start tunnel: %w", err)
}

var (
localPort = 40000 + rand.Intn(20000)
errCh = make(chan error, 100)
)

ctx, cancel := context.WithCancel(ctx)
defer cancel()

go func() {
if err := tunnel.Forward(ctx, portforward.ForwardingSpec{
LocalPort: fmt.Sprintf("%d", localPort),
RemotePort: "22",
}, errCh); err != nil {
errCh <- err
}
}()

go func() {
for {
select {
case <-ctx.Done():
return

case err := <-errCh:
slog.Debug("Tunnel error: %v", err)
}
}
}()

time.Sleep(2 * time.Second)

client := ssh.NewClient(fmt.Sprintf("localhost:%d", localPort), "root", cli.Config().SSHDirPath)
if err := client.Shell(ctx, &ssh.SessionIO{
Stdin: cli.InputStream(),
Stdout: os.Stdout,
Stderr: os.Stderr,
AllocPTY: true,
}, "bash"); err != nil {
return fmt.Errorf("couldn't start SSH session: %w", err)
}

return nil
}
284 changes: 284 additions & 0 deletions internal/ssh/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
//go:build !windows

package ssh

import (
"context"
"errors"
"io"
"log/slog"
"net"
"os"
"os/signal"
"sync"
"syscall"

"golang.org/x/crypto/ssh"
"golang.org/x/term"
)

const (
DefaultHeight = 40
DefaultWidth = 80
)

var modes = ssh.TerminalModes{
ssh.ECHO: 0, // disable echoing
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}

type Client struct {
addr string
user string

sshKeyPath string

client *ssh.Client
conn ssh.Conn
}

func NewClient(addr, user, sshKeyPath string) *Client {
return &Client{
addr: addr,
user: user,
sshKeyPath: sshKeyPath,
}
}

type connResp struct {
err error
conn ssh.Conn
client *ssh.Client
}

func (c *Client) Connect(ctx context.Context) error {
privateKey, err := ReadPrivateKey(c.sshKeyPath)
if err != nil {
return err
}

keySigner, err := ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return err
}

var d net.Dialer
tcpConn, err := d.DialContext(ctx, "tcp", c.addr)
if err != nil {
return err
}

conf := &ssh.ClientConfig{
User: c.user,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(keySigner),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
HostKeyAlgorithms: []string{ssh.KeyAlgoED25519},
}

respCh := make(chan connResp)

// ssh.NewClientConn doesn't take a context, so we need to handle cancelation on our end
go func() {
conn, chans, reqs, err := ssh.NewClientConn(tcpConn, tcpConn.RemoteAddr().String(), conf)
if err != nil {
respCh <- connResp{err: err}
return
}

client := ssh.NewClient(conn, chans, reqs)

respCh <- connResp{nil, conn, client}
}()

for {
select {
case <-ctx.Done():
return ctx.Err()

case resp := <-respCh:
if resp.err != nil {
return resp.err
}
c.conn = resp.conn
c.client = resp.client
return nil
}
}
}

func (c *Client) Shell(ctx context.Context, sessIO *SessionIO, cmd string) error {
if c.client == nil {
if err := c.Connect(ctx); err != nil {
return err
}
}

sess, err := c.client.NewSession()
if err != nil {
return err
}
defer sess.Close()

return sessIO.attach(ctx, sess, cmd)
}

func (c *Client) Close() error {
if c.conn != nil {
if err := c.conn.Close(); err != nil {
return err
}
}

c.conn = nil
return nil
}

type SessionIO struct {
Stdin io.Reader
Stdout io.WriteCloser
Stderr io.WriteCloser

AllocPTY bool
TermEnv string
}

func (s *SessionIO) attach(ctx context.Context, sess *ssh.Session, cmd string) error {
if s.AllocPTY {
width, height := DefaultWidth, DefaultHeight

if fd, ok := getFd(s.Stdin); ok {
state, err := term.MakeRaw(fd)
if err != nil {
return err
}
defer term.Restore(fd, state)
}

if w, h, err := s.getAndWatchSize(ctx, sess); err == nil {
width, height = w, h
}

if err := sess.RequestPty(s.TermEnv, height, width, modes); err != nil {
return err
}
}

var closeStdin sync.Once
stdin, err := sess.StdinPipe()
if err != nil {
return err
}
defer closeStdin.Do(func() {
stdin.Close()
})

stdout, err := sess.StdoutPipe()
if err != nil {
return err
}

stderr, err := sess.StderrPipe()
if err != nil {
return err
}

go func() {
defer closeStdin.Do(func() {
stdin.Close()
})
if s.Stdin != nil {
io.Copy(stdin, s.Stdin)
}
}()
if s.Stdout != nil {
go io.Copy(s.Stdout, stdout)
}

if s.Stderr != nil {
go io.Copy(s.Stderr, stderr)
}

cmdC := make(chan error, 1)
go func() {
defer close(cmdC)

if cmd == "" {
err = sess.Shell()
} else {
err = sess.Run(cmd)
}

if err != nil && err != io.EOF {
cmdC <- err
}
}()

select {
case err := <-cmdC:
return err

case <-ctx.Done():
return errors.New("session forcibly closed; the remote process may still be running")
}
}

func (s *SessionIO) getAndWatchSize(ctx context.Context, sess *ssh.Session) (int, int, error) {
fd, ok := getFd(s.Stdin)
if !ok {
return 0, 0, errors.New("could not get console handle")
}

width, height, err := term.GetSize(fd)
if err != nil {
return 0, 0, err
}

go func() {
if err := watchWindowSize(ctx, fd, sess); err != nil {
slog.Debug("Error watching window size", err)
}
}()

return width, height, nil
}

func watchWindowSize(ctx context.Context, fd int, sess *ssh.Session) error {
sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGWINCH)

for {
select {
case <-sigc:
case <-ctx.Done():
return nil
}

width, height, err := term.GetSize(fd)
if err != nil {
return err
}

if err := sess.WindowChange(height, width); err != nil {
return err
}
}
}

// FdReader is an io.Reader with an Fd function
type FdReader interface {
io.Reader
Fd() uintptr
}

func getFd(reader io.Reader) (fd int, ok bool) {
fdthing, ok := reader.(FdReader)
if !ok {
return 0, false
}

fd = int(fdthing.Fd())
return fd, term.IsTerminal(fd)
}
Loading

0 comments on commit a6c73fa

Please sign in to comment.