Skip to content
This repository has been archived by the owner on Dec 7, 2023. It is now read-only.

Unify ssh and exec commands #580

Merged
merged 6 commits into from
Apr 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/ignite/cmd/vmcmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,6 @@ func NewCmdSSH(out io.Writer) *cobra.Command {

func addSSHFlags(fs *pflag.FlagSet, sf *run.SSHFlags) {
fs.StringVarP(&sf.IdentityFile, "identity", "i", "", "Override the vm's default identity file")
fs.Uint32VarP(&sf.Timeout, "timeout", "t", 10, "Timeout waiting for connection in seconds")
fs.Uint32Var(&sf.Timeout, "timeout", 10, "Timeout waiting for connection in seconds")
fs.BoolVarP(&sf.Tty, "tty", "t", true, "Allocate a pseudo-TTY")
}
116 changes: 1 addition & 115 deletions cmd/ignite/run/exec.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
package run

import (
"fmt"
"io/ioutil"
"net"
"os"
"path"
"time"

"github.com/alessio/shellescape"
api "github.com/weaveworks/ignite/pkg/apis/ignite"
"github.com/weaveworks/ignite/pkg/constants"
"github.com/weaveworks/ignite/pkg/util"
"golang.org/x/crypto/ssh"
)

// ExecFlags contains the flags supported by the exec command.
Expand Down Expand Up @@ -41,108 +30,5 @@ func (ef *ExecFlags) NewExecOptions(vmMatch string, command ...string) (eo *exec

// Exec executes command in a VM based on the provided execOptions.
func Exec(eo *execOptions) error {
// Check if the VM is running
if !eo.vm.Running() {
return fmt.Errorf("VM %q is not running", eo.vm.GetUID())
}

// Get the IP address
ipAddrs := eo.vm.Status.IPAddresses
if len(ipAddrs) == 0 {
return fmt.Errorf("VM %q has no usable IP addresses", eo.vm.GetUID())
}

// If an external identity file is specified, use it instead of the internal one
privKeyFile := eo.IdentityFile
if len(privKeyFile) == 0 {
privKeyFile = path.Join(eo.vm.ObjectPath(), fmt.Sprintf(constants.VM_SSH_KEY_TEMPLATE, eo.vm.GetUID()))
if !util.FileExists(privKeyFile) {
return fmt.Errorf("no private key found for VM %q", eo.vm.GetUID())
}
}

signer, err := newSignerForKey(privKeyFile)
if err != nil {
return fmt.Errorf("unable to create signer for private key: %v", err)
}

// Create an SSH client, and connect, we will use this to exec
config := newSSHConfig(signer, eo.Timeout)
client, err := ssh.Dial("tcp", net.JoinHostPort(ipAddrs[0].String(), "22"), config)
if err != nil {
return fmt.Errorf("failed to dial: %v", err)
}

// Run the command, DO NOT wrap this error as the caller can check for the command exit
// code in the ssh.ExitError type
return runSSHCommand(client, eo.Tty, eo.command)
}

func newSignerForKey(keyPath string) (ssh.Signer, error) {
key, err := ioutil.ReadFile(keyPath)
if err != nil {
return nil, fmt.Errorf("unable to read private key: %v", err)
}

// Create the Signer for this private key.
return ssh.ParsePrivateKey(key)
}

func newSSHConfig(publicKey ssh.Signer, timeout uint32) *ssh.ClientConfig {
return &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{
ssh.PublicKeys(publicKey),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: use ssh.FixedPublicKey instead
Timeout: time.Second * time.Duration(timeout),
}
}

func runSSHCommand(client *ssh.Client, tty bool, command []string) error {
// create a session for the command
session, err := client.NewSession()
if err != nil {
return fmt.Errorf("failed to create session: %v", err)
}
defer session.Close()

if tty {
// get a pty
// TODO: should these be based on the host terminal?
// TODO: should we request something other than xterm?
// TODO: we should probably configure the terminal modes
modes := ssh.TerminalModes{}
if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
return fmt.Errorf("request for pseudo terminal failed: %v", err)
}
}

// Connect input / output
// TODO: these should come from the cobra command instead of hardcoding os.Stderr etc.
session.Stderr = os.Stderr
session.Stdout = os.Stdout
session.Stdin = os.Stdin

/*
Do not wrap this error so the caller can check for the exit code
If the remote server does not send an exit status, an error of type *ExitMissingError is returned.
If the command completes unsuccessfully or is interrupted by a signal, the error is of type *ExitError.
Other error types may be returned for I/O problems.
*/
return session.Run(joinShellCommand(command))
}

// joinShellCommand joins command parts into a single string safe for passing to sh -c (or SSH)
func joinShellCommand(command []string) string {
joined := command[0]
if len(command) == 1 {
return joined
}
for _, arg := range command[1:] {
// NOTE: we need to escape / quote to ensure that
// each component of command... is read as a single shell word
joined += " " + shellescape.Quote(arg)
}
return joined
return runSSH(eo.vm, eo.IdentityFile, eo.command, eo.Tty, eo.Timeout)
}
197 changes: 163 additions & 34 deletions cmd/ignite/run/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,82 +2,211 @@ package run

import (
"fmt"
"io/ioutil"
"net"
"os"
"path"
"time"

"github.com/alessio/shellescape"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"

api "github.com/weaveworks/ignite/pkg/apis/ignite"
"github.com/weaveworks/ignite/pkg/constants"
"github.com/weaveworks/ignite/pkg/util"
)

const (
defaultTerm = "xterm"
defaultSSHPort = "22"
defaultSSHNetwork = "tcp"
)

// SSHFlags contains the flags supported by the ssh command.
type SSHFlags struct {
Timeout uint32
IdentityFile string
Tty bool
}

type sshOptions struct {
*SSHFlags
vm *api.VM
}

// NewSSHOptions returns ssh options for a given VM.
func (sf *SSHFlags) NewSSHOptions(vmMatch string) (so *sshOptions, err error) {
so = &sshOptions{SSHFlags: sf}
so.vm, err = getVMForMatch(vmMatch)
return
}

// SSH starts a ssh session as per the provided ssh options.
func SSH(so *sshOptions) error {
// Check if the VM is running
if !so.vm.Running() {
return fmt.Errorf("VM %q is not running", so.vm.GetUID())
return runSSH(so.vm, so.IdentityFile, []string{}, so.Tty, so.Timeout)
}

// runSSH creates and runs ssh session based on the provided arguments.
// If the command list is empty, ssh shell is created, else the ssh command is
// executed.
func runSSH(vm *api.VM, privKeyFile string, command []string, tty bool, timeout uint32) error {
// Check if the VM is running.
if !vm.Running() {
return fmt.Errorf("VM %q is not running", vm.GetUID())
}

ipAddrs := so.vm.Status.IPAddresses
// Get the IP address.
ipAddrs := vm.Status.IPAddresses
if len(ipAddrs) == 0 {
return fmt.Errorf("VM %q has no usable IP addresses", so.vm.GetUID())
return fmt.Errorf("VM %q has no usable IP addresses", vm.GetUID())
}

// We're dealing with local VMs in a trusted (internal) subnet, disable some warnings for convenience
// TODO: For security, track the known_hosts internally, do something about the IP collisions (if needed)
sshOpts := []string{
"LogLevel=ERROR", // Warning: Permanently added '<ip>' (ECDSA) to the list of known hosts.
// We get this if the VM happens to get an address that another container has used:
"UserKnownHostsFile=/dev/null", // WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!
"StrictHostKeyChecking=no", // The authenticity of host ***** can't be established
fmt.Sprintf("ConnectTimeout=%d", so.Timeout),
// Get private key file path.
if len(privKeyFile) == 0 {
privKeyFile = path.Join(vm.ObjectPath(), fmt.Sprintf(constants.VM_SSH_KEY_TEMPLATE, vm.GetUID()))
if !util.FileExists(privKeyFile) {
return fmt.Errorf("no private key found for VM %q", vm.GetUID())
}
}

sshArgs := append(make([]string, 0, len(sshOpts)*2+3),
fmt.Sprintf("root@%s", ipAddrs[0]))
// Create a new ssh signer for the private key.
signer, err := newSignerForKey(privKeyFile)
if err != nil {
return fmt.Errorf("unable to create signer for private key: %v", err)
}

for _, opt := range sshOpts {
sshArgs = append(sshArgs, "-o", opt)
// Defer exit here and set the exit code based on any ssh error, so that
// this ssh command returns the correct ssh exit code. Since this function
// results in an os.Exit, any error returned by this function won't be
// received by the caller. Print the error to make the errror message
// visible and set the error code when an error is found.
exitCode := 0
defer func() {
os.Exit(exitCode)
}()

// printErrAndSetExitCode is used to print an error message, set exit code
// and return nil. This is needed because once the ssh connection is
// estabilish, to return the error code of the actual ssh session, instead
// of returning an error, the runSSH function defers os.Exit with the ssh
// exit code. For showing any error to the user, it needs to be printed.
printErrAndSetExitCode := func(errMsg error, exitCode *int, code int) error {
log.Errorf("%v\n", errMsg)
*exitCode = code
return nil
}

sshArgs = append(sshArgs, "-i")
// Create an SSH client, and connect.
config := newSSHConfig(signer, timeout)
client, err := ssh.Dial(defaultSSHNetwork, net.JoinHostPort(ipAddrs[0].String(), defaultSSHPort), config)
if err != nil {
return printErrAndSetExitCode(fmt.Errorf("failed to dial: %v", err), &exitCode, 1)
}
defer client.Close()

// If an external identity file is specified, use it instead of the internal one
if len(so.IdentityFile) > 0 {
sshArgs = append(sshArgs, so.IdentityFile)
} else {
privKeyFile := path.Join(so.vm.ObjectPath(), fmt.Sprintf(constants.VM_SSH_KEY_TEMPLATE, so.vm.GetUID()))
if !util.FileExists(privKeyFile) {
return fmt.Errorf("no private key found for VM %q", so.vm.GetUID())
// Create a session.
session, err := client.NewSession()
if err != nil {
return printErrAndSetExitCode(fmt.Errorf("failed to create session: %v", err), &exitCode, 1)
}
defer session.Close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before we defer the session.Close, we can defer an os.Exit (then they will run in reverse order)

The Exit should take a code variable:
https://stackoverflow.com/a/24601700


// Configure tty if requested.
if tty {
// Get stdin file descriptor reference.
fd := int(os.Stdin.Fd())

// Store the raw state of the terminal.
state, err := terminal.MakeRaw(fd)
if err != nil {
return printErrAndSetExitCode(fmt.Errorf("failed to make terminal raw: %v", err), &exitCode, 1)
}
defer terminal.Restore(fd, state)

sshArgs = append(sshArgs, privKeyFile)
}
// Get the terminal dimensions.
w, h, err := terminal.GetSize(fd)
if err != nil {
return printErrAndSetExitCode(fmt.Errorf("failed to get terminal size: %v", err), &exitCode, 1)
}

// Set terminal modes.
modes := ssh.TerminalModes{
ssh.ECHO: 1,
}

// SSH into the VM
if code, err := util.ExecForeground("ssh", sshArgs...); err != nil {
if code != 255 {
return fmt.Errorf("SSH into VM %q failed: %v", so.vm.GetUID(), err)
// Read the TERM environment variable and use it to request the PTY.
term := os.Getenv("TERM")
if term == "" {
term = defaultTerm
}

// Code 255 is used for signaling a connection error, be it caused by
// a failed connection attempt or disconnection by VM reboot.
log.Warnf("SSH command terminated")
if err := session.RequestPty(term, h, w, modes); err != nil {
return printErrAndSetExitCode(fmt.Errorf("request for pseudo terminal failed: %v", err), &exitCode, 1)
}
}

// Connect input / output.
// TODO: these should come from the cobra command instead of hardcoding
// os.Stderr etc.
session.Stderr = os.Stderr
session.Stdout = os.Stdout
session.Stdin = os.Stdin

if len(command) == 0 {
if err := session.Shell(); err != nil {
return printErrAndSetExitCode(fmt.Errorf("failed to start shell: %v", err), &exitCode, 1)
}

if err := session.Wait(); err != nil {
if e, ok := err.(*ssh.ExitError); ok {
return printErrAndSetExitCode(err, &exitCode, e.ExitStatus())
}
return printErrAndSetExitCode(fmt.Errorf("failed waiting for session to exit: %v", err), &exitCode, 1)
}
} else {
if err := session.Run(joinShellCommand(command)); err != nil {
if e, ok := err.(*ssh.ExitError); ok {
return printErrAndSetExitCode(err, &exitCode, e.ExitStatus())
}
return printErrAndSetExitCode(fmt.Errorf("failed to run shell command: %s", err), &exitCode, 1)
}
}
return nil
}

func newSignerForKey(keyPath string) (ssh.Signer, error) {
key, err := ioutil.ReadFile(keyPath)
if err != nil {
return nil, fmt.Errorf("unable to read private key: %v", err)
}

// Create the Signer for this private key.
return ssh.ParsePrivateKey(key)
}

func newSSHConfig(publicKey ssh.Signer, timeout uint32) *ssh.ClientConfig {
return &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{
ssh.PublicKeys(publicKey),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: use ssh.FixedPublicKey instead
Timeout: time.Second * time.Duration(timeout),
}
}

// joinShellCommand joins command parts into a single string safe for passing to sh -c (or SSH)
func joinShellCommand(command []string) string {
joined := command[0]
if len(command) == 1 {
return joined
}
for _, arg := range command[1:] {
// NOTE: we need to escape / quote to ensure that
// each component of command... is read as a single shell word
joined += " " + shellescape.Quote(arg)
}
return joined
}
3 changes: 2 additions & 1 deletion docs/cli/ignite/ignite_ssh.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ ignite ssh <vm> [flags]
```
-h, --help help for ssh
-i, --identity string Override the vm's default identity file
-t, --timeout uint32 Timeout waiting for connection in seconds (default 10)
--timeout uint32 Timeout waiting for connection in seconds (default 10)
-t, --tty Allocate a pseudo-TTY (default true)
```

### Options inherited from parent commands
Expand Down
Loading