Skip to content

Commit

Permalink
Support the execution of SSH commands with optional input.
Browse files Browse the repository at this point in the history
Signed-off-by: Gordon Messmer <gordon.messmer@gmail.com>
  • Loading branch information
gordonmessmer committed Feb 1, 2024
1 parent 983ccb9 commit dbf804a
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
15 changes: 13 additions & 2 deletions pkg/ssh/connection_golang.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ func golangConnectionDial(options ConnectionDialOptions) (*ConnectionDialReport,
}

func golangConnectionExec(options ConnectionExecOptions) (*ConnectionExecReport, error) {
return golangConnectionExecWithInput(options, nil)
}

func golangConnectionExecWithInput(options ConnectionExecOptions, input *os.File) (*ConnectionExecReport, error) {
if !strings.HasPrefix(options.Host, "ssh://") {
options.Host = "ssh://" + options.Host
}
Expand All @@ -117,7 +121,7 @@ func golangConnectionExec(options ConnectionExecOptions) (*ConnectionExecReport,
return nil, fmt.Errorf("failed to connect: %w", err)
}

out, err := ExecRemoteCommand(dialAdd, strings.Join(options.Args, " "))
out, err := ExecRemoteCommandWithInput(dialAdd, strings.Join(options.Args, " "), input)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -189,6 +193,10 @@ func golangConnectionScp(options ConnectionScpOptions) (*ConnectionScpReport, er
// ExecRemoteCommand takes a ssh client connection and a command to run and executes the
// command on the specified client. The function returns the Stdout from the client or the Stderr
func ExecRemoteCommand(dial *ssh.Client, run string) ([]byte, error) {
return ExecRemoteCommandWithInput(dial, run, nil)
}

func ExecRemoteCommandWithInput(dial *ssh.Client, run string, input *os.File) ([]byte, error) {
sess, err := dial.NewSession() // new ssh client session
if err != nil {
return nil, err
Expand All @@ -198,7 +206,10 @@ func ExecRemoteCommand(dial *ssh.Client, run string) ([]byte, error) {
var buffer bytes.Buffer
var bufferErr bytes.Buffer
sess.Stdout = &buffer // output from client funneled into buffer
sess.Stderr = &bufferErr // err form client funneled into buffer
sess.Stderr = &bufferErr // err from client funneled into buffer
if input != nil {
sess.Stdin = input
}
if err := sess.Run(run); err != nil { // run the command on the ssh client
return nil, fmt.Errorf("%v: %w", bufferErr.String(), err)
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/ssh/connection_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"os"
"os/exec"
"regexp"
"strings"
Expand Down Expand Up @@ -93,6 +94,10 @@ func nativeConnectionCreate(options ConnectionCreateOptions) error {
}

func nativeConnectionExec(options ConnectionExecOptions) (*ConnectionExecReport, error) {
return nativeConnectionExecWithInput(options, nil)
}

func nativeConnectionExecWithInput(options ConnectionExecOptions, input *os.File) (*ConnectionExecReport, error) {
dst, uri, err := Validate(options.User, options.Host, options.Port, options.Identity)
if err != nil {
return nil, err
Expand Down Expand Up @@ -126,6 +131,9 @@ func nativeConnectionExec(options ConnectionExecOptions) (*ConnectionExecReport,
info := exec.Command(ssh, args...)
info.Stdout = output
info.Stderr = errors
if input != nil {
info.Stdin = input
}
err = info.Run()
if err != nil {
return nil, err
Expand Down
9 changes: 7 additions & 2 deletions pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ssh

import (
"fmt"
"os"

"golang.org/x/crypto/ssh"
)
Expand All @@ -27,15 +28,19 @@ func Dial(options *ConnectionDialOptions, kind EngineMode) (*ssh.Client, error)
}

func Exec(options *ConnectionExecOptions, kind EngineMode) (string, error) {
return ExecWithInput(options, kind, nil)
}

func ExecWithInput(options *ConnectionExecOptions, kind EngineMode, input *os.File) (string, error) {
var rep *ConnectionExecReport
var err error
if kind == NativeMode {
rep, err = nativeConnectionExec(*options)
rep, err = nativeConnectionExecWithInput(*options, input)
if err != nil {
return "", err
}
} else {
rep, err = golangConnectionExec(*options)
rep, err = golangConnectionExecWithInput(*options, input)
if err != nil {
return "", err
}
Expand Down
20 changes: 20 additions & 0 deletions pkg/ssh/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,26 @@ func TestExec(t *testing.T) {
require.Error(t, err, "failed to connect: ssh: handshake failed: ssh: disconnect, reason 2: Too many authentication failures")
}

func TestExecWithInput(t *testing.T) {
options := ConnectionExecOptions{
Port: 22,
Host: "localhost",
Args: []string{"md5sum"},
}

input, err := os.Open("/etc/fstab")
require.NoError(t, err)
defer input.Close()

_, err = ExecWithInput(&options, NativeMode, input)
// exit status 255 is what you get when ssh is not enabled or the connection failed
// this means up to that point, everything worked
require.Error(t, err, "exit status 255")

_, err = ExecWithInput(&options, GolangMode, input)
require.Error(t, err, "failed to connect: ssh: handshake failed: ssh: disconnect, reason 2: Too many authentication failures")
}

func TestDial(t *testing.T) {
options := ConnectionDialOptions{
Port: 22,
Expand Down

0 comments on commit dbf804a

Please sign in to comment.