Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: backup to stdout and other paths and import from stdin #139

Merged
merged 8 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 88 additions & 49 deletions cmd/backup_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
"regexp"
"strings"
Expand All @@ -15,6 +14,12 @@ import (
"github.com/spf13/cobra"
)

var backupOutputFile string

func init() {
BackupKeysCmd.Flags().StringVarP(&backupOutputFile, "output", "o", "charm-keys-backup.tar", "keys backup filepath")
}

// BackupKeysCmd is the cobra.Command to back up a user's account SSH keys.
var BackupKeysCmd = &cobra.Command{
Use: "backup-keys",
Expand All @@ -24,20 +29,6 @@ var BackupKeysCmd = &cobra.Command{
Args: cobra.NoArgs,
DisableFlagsInUseLine: true,
RunE: func(cmd *cobra.Command, args []string) error {
const filename = "charm-keys-backup.tar"

cwd, err := os.Getwd()
if err != nil {
return err
}

// Don't overwrite backup file
keyPath := path.Join(cwd, filename)
if fileOrDirectoryExists(keyPath) {
fmt.Printf("Not creating backup file: %s already exists.\n\n", code(filename))
os.Exit(1)
}

cfg, err := client.ConfigFromEnv()
if err != nil {
return err
Expand All @@ -57,11 +48,42 @@ var BackupKeysCmd = &cobra.Command{
return err
}

if err := createTar(dd, filename); err != nil {
backupPath := backupOutputFile
if backupPath == "-" {
exp := regexp.MustCompilePOSIX("charm_(rsa|ed25519)$")
paths, err := getKeyPaths(dd, exp)
if err != nil {
return err
}
if len(paths) != 1 {
return fmt.Errorf("backup to stdout only works with 1 key, you have %d", len(paths))
}
bts, err := os.ReadFile(paths[0])
if err != nil {
return err
}
_, _ = fmt.Fprint(cmd.OutOrStdout(), string(bts))
return nil
}

if !strings.HasSuffix(backupPath, ".tar") {
backupPath = backupPath + ".tar"
}

if fileOrDirectoryExists(backupPath) {
fmt.Printf("Not creating backup file: %s already exists.\n\n", code(backupPath))
os.Exit(1)
}

if err := os.MkdirAll(filepath.Dir(backupPath), 0o754); err != nil {
return err
}

if err := createTar(dd, backupPath); err != nil {
return err
}

fmt.Printf("Done! Saved keys to %s.\n\n", code(filename))
fmt.Printf("Done! Saved keys to %s.\n\n", code(backupPath))
return nil
},
}
Expand Down Expand Up @@ -128,49 +150,66 @@ func createTar(source string, target string) error {

exp := regexp.MustCompilePOSIX("charm_(rsa|ed25519)(.pub)?$")

if err := filepath.Walk(source,
func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
paths, err := getKeyPaths(source, exp)
if err != nil {
return err
}

if !exp.MatchString(path) {
return nil
}
for _, path := range paths {
info, err := os.Stat(path)
if err != nil {
return err
}

header, err := tar.FileInfoHeader(info, info.Name())
if err != nil {
return err
}
header, err := tar.FileInfoHeader(info, info.Name())
if err != nil {
return err
}

if baseDir != "" {
header.Name = filepath.Join(baseDir, strings.TrimPrefix(path, source))
}
if baseDir != "" {
header.Name = filepath.Join(baseDir, strings.TrimPrefix(path, source))
}

if err := tarball.WriteHeader(header); err != nil {
return err
}
if err := tarball.WriteHeader(header); err != nil {
return err
}

if info.IsDir() {
return nil
}
if info.IsDir() {
return nil
}

file, err := os.Open(path)
if err != nil {
return err
}
defer file.Close() // nolint:errcheck
file, err := os.Open(path)
if err != nil {
return err
}
defer file.Close() // nolint:errcheck

if _, err := io.Copy(tarball, file); err != nil {
return err
}
return file.Close()
}); err != nil {
return err
if _, err := io.Copy(tarball, file); err != nil {
return err
}
if err := file.Close(); err != nil {
return err
}
}

if err := tarball.Close(); err != nil {
return err
}
return tarfile.Close()
}

func getKeyPaths(source string, filter *regexp.Regexp) ([]string, error) {
var result []string
err := filepath.Walk(source, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

if filter.MatchString(path) {
result = append(result, path)
}

return nil
})
return result, err
}
17 changes: 17 additions & 0 deletions cmd/backup_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package cmd

import (
"archive/tar"
"bytes"
"io"
"os"
"path/filepath"
"testing"

"github.com/charmbracelet/charm/testserver"
"golang.org/x/crypto/ssh"
)

func TestBackupKeysCmd(t *testing.T) {
Expand Down Expand Up @@ -55,3 +57,18 @@ func TestBackupKeysCmd(t *testing.T) {
t.Errorf("expected at least 2 files (public and private keys), got %d: %v", len(paths), paths)
}
}

func TestBackupToStdout(t *testing.T) {
_ = testserver.SetupTestServer(t)
var b bytes.Buffer

BackupKeysCmd.SetArgs([]string{"-o", "-"})
BackupKeysCmd.SetOut(&b)
if err := BackupKeysCmd.Execute(); err != nil {
t.Fatalf("command failed: %s", err)
}

if _, err := ssh.ParsePrivateKey(b.Bytes()); err != nil {
t.Fatalf("expected no error, got %v", err)
}
}
86 changes: 71 additions & 15 deletions cmd/import_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/charmbracelet/charm/client"
"github.com/charmbracelet/charm/ui/common"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
)

type (
Expand All @@ -36,7 +37,7 @@ var (
Hidden: false,
Short: "Import previously backed up Charm account keys.",
Long: paragraph(fmt.Sprintf("%s previously backed up Charm account keys.", keyword("Import"))),
Args: cobra.ExactArgs(1),
Args: cobra.MaximumNArgs(1),
DisableFlagsInUseLine: false,
RunE: func(cmd *cobra.Command, args []string) error {
cfg, err := client.ConfigFromEnv()
Expand All @@ -61,37 +62,60 @@ var (
return err
}

path := "-"
if len(args) > 0 {
path = args[0]
}
if !empty && !forceImportOverwrite {
if common.IsTTY() {
return newImportConfirmationTUI(args[0], dd).Start()
return newImportConfirmationTUI(cmd.InOrStdin(), path, dd).Start()
}
return fmt.Errorf("not overwriting the existing keys in %s; to force, use -f", dd)
}

err = untar(args[0], dd)
if err != nil {
return err
if isStdin(path) {
if err := restoreFromReader(cmd.InOrStdin(), dd); err != nil {
return err
}
} else {
if err := untar(path, dd); err != nil {
return err
}
}

paragraph(fmt.Sprintf("Done! Keys imported to %s", code(dd)))
return nil
},
}
)

func untarCmd(tarPath, dataPath string) tea.Cmd {
func isStdin(path string) bool {
fi, _ := os.Stdin.Stat()
return (fi.Mode()&os.ModeNamedPipe) != 0 || path == "-"
}

func restoreCmd(r io.Reader, path, dataPath string) tea.Cmd {
return func() tea.Msg {
if err := untar(tarPath, dataPath); err != nil {
if isStdin(path) {
if err := restoreFromReader(r, dataPath); err != nil {
return confirmationErrMsg{err}
}
return confirmationSuccessMsg{}
}

if err := untar(path, dataPath); err != nil {
return confirmationErrMsg{err}
}
return confirmationSuccessMsg{}
}
}

type confirmationTUI struct {
state confirmationState
yes bool
err error
tarPath, dataPath string
reader io.Reader
state confirmationState
yes bool
err error
path, dataPath string
}

func (m confirmationTUI) Init() tea.Cmd {
Expand All @@ -112,14 +136,14 @@ func (m confirmationTUI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case "enter":
if m.yes {
m.state = confirmed
return m, untarCmd(m.tarPath, m.dataPath)
return m, restoreCmd(m.reader, m.path, m.dataPath)
}
m.state = cancelling
return m, tea.Quit
case "y":
m.yes = true
m.state = confirmed
return m, untarCmd(m.tarPath, m.dataPath)
return m, restoreCmd(m.reader, m.path, m.dataPath)
default:
if m.state == ready {
m.yes = false
Expand Down Expand Up @@ -169,6 +193,37 @@ func isEmpty(name string) (bool, error) {
return false, err
}

func restoreFromReader(r io.Reader, dd string) error {
bts, err := io.ReadAll(r)
if err != nil {
return err
}

signer, err := ssh.ParsePrivateKey(bts)
if err != nil {
return fmt.Errorf("invalid private key: %w", err)
}

if signer.PublicKey().Type() != "ssh-ed25519" {
return fmt.Errorf("only ed25519 keys are allowed, yours is %s", signer.PublicKey().Type())
}

keypath := filepath.Join(dd, "charm_ed25519")
if err := os.WriteFile(keypath, bts, 0o600); err != nil {
return err
}

if err := os.WriteFile(
keypath+".pub",
ssh.MarshalAuthorizedKey(signer.PublicKey()),
0o600,
); err != nil {
return err
}

return nil
}

func untar(tarball, targetDir string) error {
reader, err := os.Open(tarball)
if err != nil {
Expand Down Expand Up @@ -225,10 +280,11 @@ func untar(tarball, targetDir string) error {

// Import Confirmation TUI

func newImportConfirmationTUI(tarPath, dataPath string) *tea.Program {
func newImportConfirmationTUI(r io.Reader, tarPath, dataPath string) *tea.Program {
return tea.NewProgram(confirmationTUI{
reader: r,
state: ready,
tarPath: tarPath,
path: tarPath,
dataPath: dataPath,
})
}
Expand Down
Loading