Skip to content

Commit

Permalink
Implement key loading from .ssh
Browse files Browse the repository at this point in the history
Signed-off-by: Morten Linderud <morten@linderud.pw>
  • Loading branch information
Foxboron committed Jul 29, 2023
1 parent 275b3dd commit 91d9690
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 66 deletions.
142 changes: 92 additions & 50 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log"
"net"
"os"
"os/signal"
"path"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
Expand All @@ -24,29 +26,6 @@ import (
"golang.org/x/term"
)

// Return XDG_DATA_HOME or $HOME/.local/share
func getDataHome() string {
if s, ok := os.LookupEnv("XDG_DATA_HOME"); ok {
return s
}

dirname, err := os.UserHomeDir()
if err != nil {
panic("$HOME is not defined")
}

return path.Join(dirname, ".local/share")
}

func getAgentStorage() string {
return path.Join(getDataHome(), "ssh-tpm-agent")
}

func SaveKey(k *key.Key) error {
os.MkdirAll(getAgentStorage(), 0700)
return os.WriteFile(path.Join(getAgentStorage(), "ssh.key"), key.MarshalKey(k), 0600)
}

var ErrOperationUnsupported = errors.New("operation unsupported")

type Agent struct {
Expand All @@ -56,6 +35,7 @@ type Agent struct {
listener net.Listener
quit chan interface{}
wg sync.WaitGroup
keys map[string]*key.Key
}

var _ agent.ExtendedAgent = &Agent{}
Expand Down Expand Up @@ -86,19 +66,15 @@ func (a *Agent) Close() error {
}

func (a *Agent) signers() ([]ssh.Signer, error) {
b, err := os.ReadFile(path.Join(getAgentStorage(), "ssh.key"))
if err != nil {
return nil, err
}
k, err := key.UnmarshalKey(b)
if err != nil {
return nil, err
}
s, err := ssh.NewSignerFromSigner(signer.NewTPMSigner(k, a.tpm, a.pin))
if err != nil {
return nil, fmt.Errorf("failed to prepare signer: %w", err)
var signers []ssh.Signer
for _, k := range a.keys {
s, err := ssh.NewSignerFromSigner(signer.NewTPMSigner(k, a.tpm, a.pin))
if err != nil {
return nil, fmt.Errorf("failed to prepare signer: %w", err)
}
signers = append(signers, s)
}
return []ssh.Signer{s}, nil
return signers, nil
}

func (a *Agent) Signers() ([]ssh.Signer, error) {
Expand All @@ -108,26 +84,23 @@ func (a *Agent) Signers() ([]ssh.Signer, error) {
}

func (a *Agent) List() ([]*agent.Key, error) {
var agentKeys []*agent.Key

a.mu.Lock()
defer a.mu.Unlock()
b, err := os.ReadFile(path.Join(getAgentStorage(), "ssh.key"))
if err != nil {
return nil, err
}

k, err := key.UnmarshalKey(b)
if err != nil {
return nil, err
}
for _, k := range a.keys {
pk, err := k.SSHPublicKey()
if err != nil {
return nil, err
}

pk, err := k.SSHPublicKey()
if err != nil {
return nil, err
agentKeys = append(agentKeys, &agent.Key{
Format: pk.Type(),
Blob: pk.Marshal(),
})
}
return []*agent.Key{{
Format: pk.Type(),
Blob: pk.Marshal(),
}}, nil
return agentKeys, nil
}

func (a *Agent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) {
Expand Down Expand Up @@ -196,11 +169,76 @@ func (a *Agent) serve() {
}
}

func (a *Agent) AddKey(k *key.Key) error {
sshpubkey, err := k.SSHPublicKey()
if err != nil {
return err
}
a.keys[ssh.FingerprintSHA256(sshpubkey)] = k
return nil
}

func (a *Agent) LoadKeys() error {
a.mu.Lock()
defer a.mu.Unlock()
keys, err := LoadKeys()
if err != nil {
return err
}

a.keys = keys
return nil
}

func GetSSHDir() string {
dirname, err := os.UserHomeDir()
if err != nil {
panic("$HOME is not defined")
}
return path.Join(dirname, ".ssh")
}

func LoadKeys() (map[string]*key.Key, error) {
keys := map[string]*key.Key{}
err := filepath.WalkDir(GetSSHDir(),
func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if !strings.HasSuffix(path, "tpm") {
return nil
}
f, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed reading %s", path)
}
k, err := key.DecodeKey(f)
if err != nil {
return fmt.Errorf("%s not a TPM sealed key: %v", path, err)
}
sshpubkey, err := k.SSHPublicKey()
if err != nil {
return fmt.Errorf("%s can't read ssh public key from TPM public: %v", path, err)
}
keys[ssh.FingerprintSHA256(sshpubkey)] = k
return nil
},
)
if err != nil {
log.Fatal(err)
}
return keys, nil
}

func NewAgent(socketPath string, tpmFetch func() transport.TPMCloser, pin func(*key.Key) ([]byte, error)) *Agent {
a := &Agent{
tpm: tpmFetch,
pin: pin,
quit: make(chan interface{}),
keys: make(map[string]*key.Key),
}
l, err := net.Listen("unix", socketPath)
if err != nil {
Expand Down Expand Up @@ -239,5 +277,9 @@ func RunAgent(socketPath string, tpmFetch func() transport.TPMCloser, pin func(*
}

a := execAgent(socketPath, tpmFetch, pin)

//TODO: Maybe we should allow people to not auto-load keys
a.LoadKeys()

a.Wait()
}
11 changes: 4 additions & 7 deletions cmd/ssh-tpm-agent/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"log"
"net"
"os"
"path"
"testing"
"time"
Expand Down Expand Up @@ -123,12 +122,6 @@ func TestSSHAuth(t *testing.T) {
t.Fatalf("failed getting ssh public key")
}

os.Setenv("XDG_DATA_HOME", t.TempDir())

if err := agent.SaveKey(k); err != nil {
t.Fatalf("failed saving key: %v", err)
}

hostkey, msgSent := setupServer(clientKey)

socket := path.Join(t.TempDir(), "socket")
Expand All @@ -145,6 +138,10 @@ func TestSSHAuth(t *testing.T) {
)
defer ag.Stop()

if err := ag.AddKey(k); err != nil {
t.Fatalf("failed saving key: %v", err)
}

sshClient := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
Expand Down
10 changes: 1 addition & 9 deletions cmd/ssh-tpm-keygen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,6 @@ Options:
Generate new sealed keys for ssh-tpm-agent.`

func GetSSHDir() string {
dirname, err := os.UserHomeDir()
if err != nil {
panic("$HOME is not defined")
}
return path.Join(dirname, ".ssh")
}

func getStdin(s string, args ...any) (string, error) {
fmt.Printf(s, args...)
reader := bufio.NewReader(os.Stdin)
Expand Down Expand Up @@ -94,7 +86,7 @@ func main() {

fmt.Println("Generating a sealed public/private ecdsa key pair.")

filename := path.Join(GetSSHDir(), "id_ecdsa")
filename := path.Join(agent.GetSSHDir(), "id_ecdsa")
filenameInput, err := getStdin("Enter file in which to save the key (%s): ", filename)
if err != nil {
log.Fatal(err)
Expand Down

0 comments on commit 91d9690

Please sign in to comment.