diff --git a/docs-src/content/functions/crypto.yml b/docs-src/content/functions/crypto.yml index ee7f4c5d3..3a61fa3bd 100644 --- a/docs-src/content/functions/crypto.yml +++ b/docs-src/content/functions/crypto.yml @@ -9,6 +9,47 @@ preamble: | recommended to have your resident security experts inspect gomplate's code before using gomplate for critical security infrastructure!_ funcs: + - name: crypto.SSH + released: v4.2.0 # I hope so + description: | + Namespace for ssh functions + pipeline: false + examples: + - | + $ gomplate -i '{{ crypto.ssh }}' + + + - name: crypto.SSH.PublicKey + released: v4.2.0 + description: | + Loads [Secure Shell](https://en.wikipedia.org/wiki/Secure_Shell) public key + pipeline: true + arguments: + - name: name + required: false + description: the name of the key in `~/.ssh` or the absolute path to it. The default value is defined by `IdentityFile` in `~/.ssh/config`. If not specified, `~/.ssh/id_rsa.pub` is used. + examples: + - | + $ cat ~/.ssh/id_rsa.pub + gxAedO6GSFC7X+feNqKydIqKlq82R9cnjJPuPLbVvWPB+r08PeJobl++6d9m8EQorpokS+ntqnr35QnIBDWLHk139KhWkOjDOvUHJd6pjOOLhSVapmKPOz1dST4QCweET59STvLHHjNVQfJtWI9zVl4X9S4SoiLDkUUyge+9UnqyA9bAr2P4NkVWZYgf3QnrqoWpRGHz1F7JgV+VmGOlh/Kmc6Q== email@example.com + + $ gomplate -i '{{ crypto.SSH.PublicKey }}' + + - | + $ gomplate -i '{{ crypto.SSH.PublicKey.Marshal }}' + ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCnEosV4dTgI6CL4YgM4Tfzs6CKdvLL/tarxipWrgEcdwn0TqFn3PmvxSOQWXbQci1Rl2I+U6X3Z4qQ3fafEOlF/bDbwfnY/eUpr9dHnVe1FCbX0tVzCR7OMHg7vGnF3Mta5E9MXMBKupiukgH51hH6fosr90Cvuhj0vsmO3jQL+i1yQxgbc14RCMQuIUZqAA/1Y9JWtucYe4X2uRyby/m2qtHA08kjPTREVd1cMSTM6rCdxnjXgJn7I416ybWnNIwwYeU8q2aKNPIhndSnIBMdDQnnxRCQHgWZXGjF8K8dVl1r3lJWbg/XMXKDWwLXbhRXZwR7/6HDamsV9fkY5Sld9VfKesNiCjaWLlnbe3d6NbdveBcBO6DgDFcshvvtOyu4quBly8EJFpyfeo5V8XQTIVMcLxehXMZNlk0C0PGKQx4xHdxTwFw9IFPbuGNRqRIRwC0YEH3TR4+xBp/gxAedO6GSFC7X+feNqKydIqKlq82R9cnjJPuPLbVvWPB+r08PeJobl++6d9m8EQorpokS+ntqnr35QnIBDWLHk139KhWkOjDOvUHJd6pjOOLhSVapmKPOz1dST4QCweET59STvLHHjNVQfJtWI9zVl4X9S4SoiLDkUUyge+9UnqyA9bAr2P4NkVWZYgf3QnrqoWpRGHz1F7JgV+VmGOlh/Kmc6Q== email@example.com + - | + $ gomplate -i '{{ crypto.SSH.PublicKey.Blob | base64.Encode }}' + AAAAB3NzaC1yc2EAAAADAQABAAACAQCnEosV4dTgI6CL4YgM4Tfzs6CKdvLL/tarxipWrgEcdwn0TqFn3PmvxSOQWXbQci1Rl2I+U6X3Z4qQ3fafEOlF/bDbwfnY/eUpr9dHnVe1FCbX0tVzCR7OMHg7vGnF3Mta5E9MXMBKupiukgH51hH6fosr90Cvuhj0vsmO3jQL+i1yQxgbc14RCMQuIUZqAA/1Y9JWtucYe4X2uRyby/m2qtHA08kjPTREVd1cMSTM6rCdxnjXgJn7I416ybWnNIwwYeU8q2aKNPIhndSnIBMdDQnnxRCQHgWZXGjF8K8dVl1r3lJWbg/XMXKDWwLXbhRXZwR7/6HDamsV9fkY5Sld9VfKesNiCjaWLlnbe3d6NbdveBcBO6DgDFcshvvtOyu4quBly8EJFpyfeo5V8XQTIVMcLxehXMZNlk0C0PGKQx4xHdxTwFw9IFPbuGNRqRIRwC0YEH3TR4+xBp/gxAedO6GSFC7X+feNqKydIqKlq82R9cnjJPuPLbVvWPB+r08PeJobl++6d9m8EQorpokS+ntqnr35QnIBDWLHk139KhWkOjDOvUHJd6pjOOLhSVapmKPOz1dST4QCweET59STvLHHjNVQfJtWI9zVl4X9S4SoiLDkUUyge+9UnqyA9bAr2P4NkVWZYgf3QnrqoWpRGHz1F7JgV+VmGOlh/Kmc6Q== + - | + $ gomplate -i '{{ crypto.SSH.PublicKey.Comment }}' + email@example.com + - | + $ gomplate -i '{{ crypto.SSH.PublicKey.Format }}' + ssh-rsa + - | + $ gomplate -i '{{ (crypto.SSH.PublicKey "e2e_id_ed25519").Marshal }}' + ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBCLlDopq1aotlRUMw6oJ7Snr+qa+r5X8qxADTuYJumN e2e_key - name: crypto.Bcrypt released: v2.6.0 description: | diff --git a/docs/content/functions/crypto.md b/docs/content/functions/crypto.md index c049e3334..2229798c7 100644 --- a/docs/content/functions/crypto.md +++ b/docs/content/functions/crypto.md @@ -14,6 +14,75 @@ however, and so can not guarantee correctness of implementation. It is recommended to have your resident security experts inspect gomplate's code before using gomplate for critical security infrastructure!_ +## `crypto.SSH` + +Namespace for ssh functions + +_Added in gomplate [v4.2.0](https://github.com/hairyhenderson/gomplate/releases/tag/v4.2.0)_ +### Usage + +``` +crypto.SSH +``` + + +### Examples + +```console +$ gomplate -i '{{ crypto.ssh }}' + +``` + +## `crypto.SSH.PublicKey` + +Loads [Secure Shell](https://en.wikipedia.org/wiki/Secure_Shell) public key + +_Added in gomplate [v4.2.0](https://github.com/hairyhenderson/gomplate/releases/tag/v4.2.0)_ +### Usage + +``` +crypto.SSH.PublicKey [name] +``` +``` +name | crypto.SSH.PublicKey +``` + +### Arguments + +| name | description | +|------|-------------| +| `name` | _(optional)_ the name of the key in `~/.ssh` or the absolute path to it. The default value is defined by `IdentityFile` in `~/.ssh/config`. If not specified, `~/.ssh/id_rsa.pub` is used. | + +### Examples + +```console +$ cat ~/.ssh/id_rsa.pub +gxAedO6GSFC7X+feNqKydIqKlq82R9cnjJPuPLbVvWPB+r08PeJobl++6d9m8EQorpokS+ntqnr35QnIBDWLHk139KhWkOjDOvUHJd6pjOOLhSVapmKPOz1dST4QCweET59STvLHHjNVQfJtWI9zVl4X9S4SoiLDkUUyge+9UnqyA9bAr2P4NkVWZYgf3QnrqoWpRGHz1F7JgV+VmGOlh/Kmc6Q== email@example.com + +$ gomplate -i '{{ crypto.SSH.PublicKey }}' + +``` +```console +$ gomplate -i '{{ crypto.SSH.PublicKey.Marshal }}' +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCnEosV4dTgI6CL4YgM4Tfzs6CKdvLL/tarxipWrgEcdwn0TqFn3PmvxSOQWXbQci1Rl2I+U6X3Z4qQ3fafEOlF/bDbwfnY/eUpr9dHnVe1FCbX0tVzCR7OMHg7vGnF3Mta5E9MXMBKupiukgH51hH6fosr90Cvuhj0vsmO3jQL+i1yQxgbc14RCMQuIUZqAA/1Y9JWtucYe4X2uRyby/m2qtHA08kjPTREVd1cMSTM6rCdxnjXgJn7I416ybWnNIwwYeU8q2aKNPIhndSnIBMdDQnnxRCQHgWZXGjF8K8dVl1r3lJWbg/XMXKDWwLXbhRXZwR7/6HDamsV9fkY5Sld9VfKesNiCjaWLlnbe3d6NbdveBcBO6DgDFcshvvtOyu4quBly8EJFpyfeo5V8XQTIVMcLxehXMZNlk0C0PGKQx4xHdxTwFw9IFPbuGNRqRIRwC0YEH3TR4+xBp/gxAedO6GSFC7X+feNqKydIqKlq82R9cnjJPuPLbVvWPB+r08PeJobl++6d9m8EQorpokS+ntqnr35QnIBDWLHk139KhWkOjDOvUHJd6pjOOLhSVapmKPOz1dST4QCweET59STvLHHjNVQfJtWI9zVl4X9S4SoiLDkUUyge+9UnqyA9bAr2P4NkVWZYgf3QnrqoWpRGHz1F7JgV+VmGOlh/Kmc6Q== email@example.com +``` +```console +$ gomplate -i '{{ crypto.SSH.PublicKey.Blob | base64.Encode }}' +AAAAB3NzaC1yc2EAAAADAQABAAACAQCnEosV4dTgI6CL4YgM4Tfzs6CKdvLL/tarxipWrgEcdwn0TqFn3PmvxSOQWXbQci1Rl2I+U6X3Z4qQ3fafEOlF/bDbwfnY/eUpr9dHnVe1FCbX0tVzCR7OMHg7vGnF3Mta5E9MXMBKupiukgH51hH6fosr90Cvuhj0vsmO3jQL+i1yQxgbc14RCMQuIUZqAA/1Y9JWtucYe4X2uRyby/m2qtHA08kjPTREVd1cMSTM6rCdxnjXgJn7I416ybWnNIwwYeU8q2aKNPIhndSnIBMdDQnnxRCQHgWZXGjF8K8dVl1r3lJWbg/XMXKDWwLXbhRXZwR7/6HDamsV9fkY5Sld9VfKesNiCjaWLlnbe3d6NbdveBcBO6DgDFcshvvtOyu4quBly8EJFpyfeo5V8XQTIVMcLxehXMZNlk0C0PGKQx4xHdxTwFw9IFPbuGNRqRIRwC0YEH3TR4+xBp/gxAedO6GSFC7X+feNqKydIqKlq82R9cnjJPuPLbVvWPB+r08PeJobl++6d9m8EQorpokS+ntqnr35QnIBDWLHk139KhWkOjDOvUHJd6pjOOLhSVapmKPOz1dST4QCweET59STvLHHjNVQfJtWI9zVl4X9S4SoiLDkUUyge+9UnqyA9bAr2P4NkVWZYgf3QnrqoWpRGHz1F7JgV+VmGOlh/Kmc6Q== +``` +```console +$ gomplate -i '{{ crypto.SSH.PublicKey.Comment }}' +email@example.com +``` +```console +$ gomplate -i '{{ crypto.SSH.PublicKey.Format }}' + ssh-rsa +``` +```console +$ gomplate -i '{{ (crypto.SSH.PublicKey "e2e_id_ed25519").Marshal }}' +ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBCLlDopq1aotlRUMw6oJ7Snr+qa+r5X8qxADTuYJumN e2e_key +``` + ## `crypto.Bcrypt` Uses the [bcrypt](https://en.wikipedia.org/wiki/Bcrypt) password hashing algorithm to generate the hash of a given string. Wraps the [`golang.org/x/crypto/brypt`](https://godoc.org/golang.org/x/crypto/bcrypt) package. diff --git a/go.mod b/go.mod index 230855b3f..e1c63cc18 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/itchyny/gojq v0.12.16 github.com/johannesboyne/gofakes3 v0.0.0-20240217095638-c55a48f17be6 github.com/joho/godotenv v1.5.1 + github.com/kevinburke/ssh_config v1.2.0 github.com/lmittmann/tint v1.0.5 github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.9.0 @@ -124,7 +125,6 @@ require ( github.com/itchyny/timefmt-go v0.1.6 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/internal/funcs/crypto.go b/internal/funcs/crypto.go index 8807d5181..2de7b545b 100644 --- a/internal/funcs/crypto.go +++ b/internal/funcs/crypto.go @@ -23,7 +23,11 @@ import ( func CreateCryptoFuncs(ctx context.Context) map[string]interface{} { f := map[string]interface{}{} - ns := &CryptoFuncs{ctx} + ns := &CryptoFuncs{ + ctx: ctx, + ssh: newSSHFuncs(ctx), + } + ns.self = ns f["crypto"] = func() interface{} { return ns } return f @@ -31,7 +35,15 @@ func CreateCryptoFuncs(ctx context.Context) map[string]interface{} { // CryptoFuncs - type CryptoFuncs struct { + namespace + ctx context.Context + ssh *SSHFuncs +} + +// SSH - +func (f CryptoFuncs) SSH() *SSHFuncs { + return f.ssh } // PBKDF2 - Run the Password-Based Key Derivation Function #2 as defined in diff --git a/internal/funcs/file.go b/internal/funcs/file.go index 16012cf12..036013a54 100644 --- a/internal/funcs/file.go +++ b/internal/funcs/file.go @@ -4,6 +4,7 @@ import ( "context" "io/fs" "path/filepath" + "sync" osfs "github.com/hack-pad/hackpadfs/os" "github.com/hairyhenderson/gomplate/v4/conv" @@ -11,16 +12,16 @@ import ( "github.com/hairyhenderson/gomplate/v4/internal/iohelpers" ) +var ( + fsys fs.FS + fsysOnce sync.Once +) + // CreateFileFuncs - func CreateFileFuncs(ctx context.Context) map[string]interface{} { - fsys, err := datafs.FSysForPath(ctx, "/") - if err != nil { - fsys = datafs.WrapWdFS(osfs.NewFS()) - } - ns := &FileFuncs{ ctx: ctx, - fs: fsys, + fs: getFS(ctx), } return map[string]interface{}{ @@ -28,6 +29,18 @@ func CreateFileFuncs(ctx context.Context) map[string]interface{} { } } +func getFS(ctx context.Context) fs.FS { + fsysOnce.Do(func() { + var err error + fsys, err = datafs.FSysForPath(ctx, "/") + if err != nil { + fsys = datafs.WrapWdFS(osfs.NewFS()) + } + }) + + return fsys +} + // FileFuncs - type FileFuncs struct { ctx context.Context diff --git a/internal/funcs/log.go b/internal/funcs/log.go new file mode 100644 index 000000000..b6464ebb7 --- /dev/null +++ b/internal/funcs/log.go @@ -0,0 +1,12 @@ +package funcs + +import ( + "context" + "log/slog" +) + +const TraceFuncsLevel = slog.LevelDebug - 1 + +func trace(ctx context.Context, msg string, attrs ...slog.Attr) { + slog.LogAttrs(ctx, TraceFuncsLevel, msg, attrs...) +} diff --git a/internal/funcs/namespace.go b/internal/funcs/namespace.go new file mode 100644 index 000000000..1fc4c313a --- /dev/null +++ b/internal/funcs/namespace.go @@ -0,0 +1,65 @@ +package funcs + +import ( + "fmt" + "reflect" + "strings" +) + +var _ fmt.Stringer = (*namespace)(nil) + +type namespace struct { + self any //must be pointer to outer struct +} + +func (n *namespace) String() string { + ns := n.self + if ns == nil { + return "" + } + + nsType := reflect.TypeOf(ns) + if nsType.Kind() != reflect.Pointer || nsType.Elem().Kind() != reflect.Struct { + panic("invalid namespace type " + nsType.String() + ": must be pointer to struct") + } + + var public []string + public = appendPublicMethods(nsType, public) + nsType = nsType.Elem() + public = appendPublicFields(nsType, public) + + nsName := nsType.String() + nsName = strings.TrimPrefix(nsName, "funcs.") + nsName = strings.TrimSuffix(nsName, "Funcs") + + return fmt.Sprintf("", nsName, public) +} + +func appendPublicFields(nsType reflect.Type, public []string) []string { + for _, field := range reflect.VisibleFields(nsType) { + if !field.IsExported() { + continue + } + + public = append(public, field.Name) + } + + return public +} + +func appendPublicMethods(nsType reflect.Type, public []string) []string { + for i := range nsType.NumMethod() { + method := nsType.Method(i) + if !method.IsExported() { + continue + } + + if method.Name == "String" && nsType.Implements(reflect.TypeFor[fmt.Stringer]()) { + continue + } + + public = append(public, method.Name) + } + + return public +} diff --git a/internal/funcs/ssh.go b/internal/funcs/ssh.go new file mode 100644 index 000000000..d60d1ad38 --- /dev/null +++ b/internal/funcs/ssh.go @@ -0,0 +1,283 @@ +package funcs + +import ( + "context" + "errors" + "fmt" + "io" + "io/fs" + "log/slog" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/hairyhenderson/gomplate/v4/base64" + "github.com/hairyhenderson/gomplate/v4/conv" + "github.com/kevinburke/ssh_config" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +// CreateSSHFuncs - +func CreateSSHFuncs(ctx context.Context) map[string]any { + ns := newSSHFuncs(ctx) + + return map[string]any{ + "ssh": func() any { return ns }, + } +} + +func newSSHFuncs(ctx context.Context) *SSHFuncs { + ns := &SSHFuncs{ + ctx: ctx, + fs: getFS(ctx), + } + + ns.self = ns + return ns +} + +type SSHFuncs struct { + namespace + + ctx context.Context + fs fs.FS + + homeDir string + homeDirOnce sync.Once + + conf *sshClientConf + confOnce sync.Once +} + +// PublicKey - +func (f *SSHFuncs) PublicKey(args ...any) (key *PublicKey, err error) { + slog.Debug("PublicKey start", slog.Any("args", args)) + defer func() { slog.Debug("PublicKey end", slog.Any("key", key), slog.Any("err", err)) }() + + var nameAny any + if len(args) > 0 { + nameAny = args[0] + } + + keyPath, err := f.toPublicKeyPath(nameAny) + if err != nil { + return nil, err + } + + keyFile, err := f.fs.Open(keyPath + ".pub") + if err != nil { + return nil, fmt.Errorf("open key file: %w", err) + } + defer keyFile.Close() + + //TODO: add support for pem-encoded keys + rawKeyData, err := io.ReadAll(keyFile) + if err != nil { + return nil, fmt.Errorf("read key file: %w", err) + } + + out, comment, _, _, err := ssh.ParseAuthorizedKey(rawKeyData) + if err != nil { + return nil, fmt.Errorf("ssh: parse authorized key: %w", err) + } + + agentKey := agent.Key{ + Format: out.Type(), + Blob: out.Marshal(), + Comment: comment, + } + + return newPublicKey(agentKey), nil +} + +type PublicKey struct { + namespace + k agent.Key +} + +// Format - +func (k PublicKey) Format() string { return k.k.Format } + +// Blob - +func (k PublicKey) Blob() []byte { return k.k.Blob } + +// Comment - +func (k PublicKey) Comment() string { return k.k.Comment } + +// Marshal - +func (k PublicKey) Marshal() string { + return k.k.String() +} + +// Verify - +// TODO: documentation +func (k PublicKey) Verify(args ...any) bool { + return k.MustVerify(args...) == nil +} + +// MustVerify - +// TODO: documentation +func (k PublicKey) MustVerify(args ...any) error { + var data, signature, format any + if len(args) > 1 { + data, signature = args[0], args[1] + } + if len(args) > 2 { + format = args[2] + } + + sshSig, err := toSSHSignature(signature, format) + if err != nil { + return err + } + + return k.k.Verify(toBytes(data), sshSig) +} + +func newPublicKey(k agent.Key) *PublicKey { + key := &PublicKey{k: k} + key.self = key + return key +} + +func (f *SSHFuncs) toPublicKeyPath(nameAny any) (string, error) { + if nameAny == nil { + conf, err := f.getConfig() + if err != nil { + return "", fmt.Errorf("get config: %w", err) + } + + return conf.IdentityFile, nil + } + + if path := conv.ToString(nameAny); filepath.IsAbs(path) { + // we will add it in fs.Open argument anyway + return strings.TrimSuffix(path, ".pub"), nil + } + + homeDir, err := f.getHomeDir() + if err != nil { + return "", fmt.Errorf("get home dir: %w", err) + } + + return filepath.Join(homeDir, ".ssh", conv.ToString(nameAny)), nil + +} + +type sshClientConf struct { + IdentityFile string +} + +func (c *sshClientConf) setDefaults(homeDir string) *sshClientConf { + if c.IdentityFile == "" { + c.IdentityFile = filepath.Join(homeDir, ".ssh", "id_rsa") + } + + return c +} + +func (f *SSHFuncs) getConfig() (*sshClientConf, error) { + var err error + f.confOnce.Do(func() { + var homeDir string + homeDir, err = f.getHomeDir() + if err != nil { + err = fmt.Errorf("get home dir: %w", err) + return + } + + f.conf, err = f.getConfigOnce(homeDir) + if f.conf == nil { + f.conf = new(sshClientConf) + } + + f.conf.setDefaults(homeDir) + }) + + return f.conf, err +} + +func (f *SSHFuncs) getConfigOnce(homeDir string) (*sshClientConf, error) { + file, err := f.fs.Open(filepath.Join(homeDir, ".ssh", "config")) + if errors.Is(err, fs.ErrNotExist) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("open config file: %w", err) + } + defer file.Close() + + cfg, err := ssh_config.Decode(file) + if err != nil { + return nil, fmt.Errorf("fail to read ssh config: %w", err) + } + + identityFiles, err := cfg.GetAll("", "IdentityFile") + if err != nil { + return nil, fmt.Errorf("get IdentityFile directives: %w", err) + } + + conf := &sshClientConf{ + IdentityFile: getLast(identityFiles), + } + + return conf, nil +} + +func (f *SSHFuncs) getHomeDir() (string, error) { + var err error + if f.homeDir != "" { + return f.homeDir, nil + } + + f.homeDirOnce.Do(func() { + f.homeDir, err = os.UserHomeDir() + }) + + return f.homeDir, err +} + +func (f *SSHFuncs) reset() { + f.homeDirOnce = sync.Once{} + f.confOnce = sync.Once{} +} + +func (f *SSHFuncs) trace(msg string, attrs ...slog.Attr) { + trace(f.ctx, msg, attrs...) +} + +func toSSHSignature(sig, f any) (*ssh.Signature, error) { + format := conv.ToString(f) + if f == nil { + format = ssh.KeyAlgoRSA + } + + switch sig := sig.(type) { + case *ssh.Signature: + return sig, nil + case string, []byte, byter, fmt.Stringer: + return &ssh.Signature{Format: format, Blob: maybeBase64toBytes(sig)}, nil + default: + return nil, fmt.Errorf("could not convert %T to *ssh.Signature", sig) + } +} + +func getLast[T any](list []T) T { + if len(list) == 0 { + var zero T + return zero + } + + return list[len(list)-1] +} + +func maybeBase64toBytes(in any) []byte { + decoded, err := base64.Decode(conv.ToString(in)) + if err == nil { + return decoded + } + + return toBytes(in) +} diff --git a/internal/funcs/ssh_test.go b/internal/funcs/ssh_test.go new file mode 100644 index 000000000..6d8f15826 --- /dev/null +++ b/internal/funcs/ssh_test.go @@ -0,0 +1,210 @@ +package funcs + +import ( + "context" + "crypto" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/sha512" + "encoding/base64" + "fmt" + "io/fs" + "math/big" + "strconv" + "strings" + "testing" + "testing/fstest" + + "github.com/hairyhenderson/gomplate/v4/internal/datafs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +func TestCreateSSHFuncs(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + // Run this a bunch to catch race conditions + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + fmap := CreateSSHFuncs(ctx) + actual := fmap["ssh"].(func() any) + + assert.Equal(t, ctx, actual().(*SSHFuncs).ctx) + }) + } +} + +func TestPublicKey(t *testing.T) { + t.Parallel() + + const ( + idRSAData = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCnEosV4dTgI6CL4YgM4Tfzs6CKdvLL/tarxipWrgEcdwn0TqFn3PmvxSOQWXbQci1Rl2I+U6X3Z4qQ3fafEOlF/bDbwfnY/eUpr9dHnVe1FCbX0tVzCR7OMHg7vGnF3Mta5E9MXMBKupiukgH51hH6fosr90Cvuhj0vsmO3jQL+i1yQxgbc14RCMQuIUZqAA/1Y9JWtucYe4X2uRyby/m2qtHA08kjPTREVd1cMSTM6rCdxnjXgJn7I416ybWnNIwwYeU8q2aKNPIhndSnIBMdDQnnxRCQHgWZXGjF8K8dVl1r3lJWbg/XMXKDWwLXbhRXZwR7/6HDamsV9fkY5Sld9VfKesNiCjaWLlnbe3d6NbdveBcBO6DgDFcshvvtOyu4quBly8EJFpyfeo5V8XQTIVMcLxehXMZNlk0C0PGKQx4xHdxTwFw9IFPbuGNRqRIRwC0YEH3TR4+xBp/gxAedO6GSFC7X+feNqKydIqKlq82R9cnjJPuPLbVvWPB+r08PeJobl++6d9m8EQorpokS+ntqnr35QnIBDWLHk139KhWkOjDOvUHJd6pjOOLhSVapmKPOz1dST4QCweET59STvLHHjNVQfJtWI9zVl4X9S4SoiLDkUUyge+9UnqyA9bAr2P4NkVWZYgf3QnrqoWpRGHz1F7JgV+VmGOlh/Kmc6Q== email@example.com" + idED25519Data = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBCLlDopq1aotlRUMw6oJ7Snr+qa+r5X8qxADTuYJumN other_email@example.com\n" + ) + + fsys := fstest.MapFS{ + "home/user/.ssh": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "home/user/.ssh/id_rsa.pub": &fstest.MapFile{Data: []byte(idRSAData)}, + "home/user/.ssh/id_ed25519.pub": &fstest.MapFile{Data: []byte(idED25519Data)}, + } + ff := &SSHFuncs{fs: datafs.WrapWdFS(fsys), homeDir: "/home/user"} + + defaultKey, err := ff.PublicKey(nil) + require.NoError(t, err) + require.NotNil(t, defaultKey) + + ed25519Key, err := ff.PublicKey("id_ed25519") + require.NoError(t, err) + require.NotNil(t, ed25519Key) + + assert.Equal(t, "ssh-rsa", defaultKey.Format()) + assert.Equal(t, idRSAData, defaultKey.Marshal()) + assert.Equal(t, "email@example.com", defaultKey.Comment()) + + expectedED25519 := strings.Join(strings.Fields(idED25519Data), " ") + assert.Equal(t, "ssh-ed25519", ed25519Key.Format()) + assert.Equal(t, expectedED25519, ed25519Key.Marshal()) + assert.Equal(t, "other_email@example.com", ed25519Key.Comment()) + + fsys["home/user/.ssh/config"] = &fstest.MapFile{Data: []byte("IdentityFile /home/user/.ssh/id_ed25519")} + ff.reset() + defaultKey, err = ff.PublicKey(nil) + require.NoError(t, err) + require.NotNil(t, defaultKey) + + assert.Equal(t, "ssh-ed25519", defaultKey.Format()) + assert.Equal(t, expectedED25519, defaultKey.Marshal()) + assert.Equal(t, "other_email@example.com", defaultKey.Comment()) + + fsys["home/user/.ssh/config"] = &fstest.MapFile{} + ff.reset() + defaultKey, err = ff.PublicKey(nil) + require.NoError(t, err) + require.NotNil(t, defaultKey) + + assert.Equal(t, "ssh-rsa", defaultKey.Format()) + assert.Equal(t, idRSAData, defaultKey.Marshal()) + assert.Equal(t, "email@example.com", defaultKey.Comment()) +} + +func TestRSAPublicKeyVerify(t *testing.T) { + t.Parallel() + + keyPair, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey := &keyPair.PublicKey + + fsys := fstest.MapFS{ + "home/user/.ssh": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "home/user/.ssh/config": &fstest.MapFile{Data: []byte("IdentityFile /home/user/.ssh/id_ed25519")}, + "home/user/.ssh/id_ed25519.pub": &fstest.MapFile{Data: marshalAuthorizedKey(publicKey)}, + } + ff := &SSHFuncs{fs: datafs.WrapWdFS(fsys), homeDir: "/home/user"} + + defaultKey, err := ff.PublicKey(nil) + require.NoError(t, err) + require.NotNil(t, defaultKey) + + msg := []byte("My message") + + msgHash := sha512.New() + _, err = msgHash.Write(msg) + require.NoError(t, err) + + digest := msgHash.Sum(nil) + + signature, err := rsa.SignPKCS1v15(rand.Reader, keyPair, crypto.SHA512, digest) + require.NoError(t, err) + + sshSignature := &ssh.Signature{ + Format: ssh.KeyAlgoRSASHA512, + Blob: signature, + } + assert.NoError(t, defaultKey.MustVerify(msg, sshSignature, nil)) + assert.NoError(t, defaultKey.MustVerify(msg, sshSignature.Blob, sshSignature.Format)) + assert.NoError(t, defaultKey.MustVerify(msg, base64Encode(sshSignature.Blob), sshSignature.Format)) +} + +func TestED25519Verify(t *testing.T) { + t.Parallel() + + publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + fsys := fstest.MapFS{ + "home/user/.ssh": &fstest.MapFile{Mode: fs.ModeDir | 0o777}, + "home/user/.ssh/id_rsa.pub": &fstest.MapFile{Data: marshalAuthorizedKey(publicKey)}, + } + ff := &SSHFuncs{fs: datafs.WrapWdFS(fsys), homeDir: "/home/user"} + + defaultKey, err := ff.PublicKey(nil) + require.NoError(t, err) + require.NotNil(t, defaultKey) + + msg := []byte("My message") + signature := ed25519.Sign(privateKey, msg) + + sshSignature := &ssh.Signature{ + Format: ssh.KeyAlgoED25519, + Blob: signature, + } + assert.NoError(t, defaultKey.MustVerify(msg, sshSignature, nil)) + assert.NoError(t, defaultKey.MustVerify(msg, sshSignature.Blob, sshSignature.Format)) + assert.NoError(t, defaultKey.MustVerify(msg, base64Encode(sshSignature.Blob), sshSignature.Format)) +} + +func marshalAuthorizedKey(key any) []byte { + var sshKey agent.Key + + switch key := key.(type) { + case *rsa.PublicKey: + sshKey.Format = ssh.KeyAlgoRSA + sshKey.Blob = ssh.Marshal(struct { + Format string + E *big.Int + N *big.Int + }{ + sshKey.Format, + big.NewInt(int64(key.E)), + key.N, + }) + case ed25519.PublicKey: + sshKey.Format = ssh.KeyAlgoED25519 + sshKey.Blob = ssh.Marshal(struct { + Format string + Blob []byte + }{ + sshKey.Format, + key, + }) + //TODO: + //case *ecdsa.PublicKey: + // switch bitSize := key.Params().BitSize; bitSize { + // case 256: + // sshKey.Format = ssh.KeyAlgoECDSA256 + // case 384: + // sshKey.Format = ssh.KeyAlgoECDSA384 + // case 521: + // sshKey.Format = ssh.KeyAlgoECDSA521 + // default: + // panic(fmt.Errorf("ecdsa: unsupported bit size %d", bitSize)) + // } + default: + panic(fmt.Errorf("unsupported key type %T", key)) + } + + return ssh.MarshalAuthorizedKey(&sshKey) +} + +func base64Encode(blob []byte) []byte { + buf := make([]byte, base64.StdEncoding.EncodedLen(len(blob))) + base64.StdEncoding.Encode(buf, blob) + return buf +}