diff --git a/knownhosts.go b/knownhosts.go index 4dad777..7835726 100644 --- a/knownhosts.go +++ b/knownhosts.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "sort" "strings" @@ -34,12 +35,17 @@ func (hkcb HostKeyCallback) HostKeyCallback() ssh.HostKeyCallback { return ssh.HostKeyCallback(hkcb) } +type PublicKey struct { + ssh.PublicKey + cert bool +} + // HostKeys returns a slice of known host public keys for the supplied host:port // found in the known_hosts file(s), or an empty slice if the host is not // already known. For hosts that have multiple known_hosts entries (for // different key types), the result will be sorted by known_hosts filename and // line number. -func (hkcb HostKeyCallback) HostKeys(hostWithPort string) (keys []ssh.PublicKey) { +func (hkcb HostKeyCallback) HostKeys(hostWithPort string) (keys []PublicKey) { var keyErr *xknownhosts.KeyError placeholderAddr := &net.TCPAddr{IP: []byte{0, 0, 0, 0}} placeholderPubKey := &fakePublicKey{} @@ -53,14 +59,47 @@ func (hkcb HostKeyCallback) HostKeys(hostWithPort string) (keys []ssh.PublicKey) return (kkeys[i].Filename == kkeys[j].Filename && kkeys[i].Line < kkeys[j].Line) } sort.Slice(kkeys, knownKeyLess) - keys = make([]ssh.PublicKey, len(kkeys)) - for n := range kkeys { - keys[n] = kkeys[n].Key + keys = make([]PublicKey, len(kkeys)) + for n, k := range kkeys { + content, err := ioutil.ReadFile(k.Filename) + if err != nil { + continue + } + lines := strings.Split(string(content), "\n") + line := lines[k.Line-1] + isCert := strings.HasPrefix(line, "@cert-authority") + + keys[n] = PublicKey{ + PublicKey: k.Key, + cert: isCert, + } } } return keys } +func keyTypeToCertType(keyType string) string { + switch keyType { + case ssh.KeyAlgoRSA: + return ssh.CertAlgoRSAv01 + case ssh.KeyAlgoDSA: + return ssh.CertAlgoDSAv01 + case ssh.KeyAlgoECDSA256: + return ssh.CertAlgoECDSA256v01 + case ssh.KeyAlgoSKECDSA256: + return ssh.CertAlgoSKECDSA256v01 + case ssh.KeyAlgoECDSA384: + return ssh.CertAlgoECDSA384v01 + case ssh.KeyAlgoECDSA521: + return ssh.CertAlgoECDSA521v01 + case ssh.KeyAlgoED25519: + return ssh.CertAlgoED25519v01 + case ssh.KeyAlgoSKED25519: + return ssh.CertAlgoSKED25519v01 + } + return "" +} + // HostKeyAlgorithms returns a slice of host key algorithms for the supplied // host:port found in the known_hosts file(s), or an empty slice if the host // is not already known. The result may be used in ssh.ClientConfig's @@ -84,14 +123,27 @@ func (hkcb HostKeyCallback) HostKeyAlgorithms(hostWithPort string) (algos []stri } for _, key := range hostKeys { typ := key.Type() - if typ == ssh.KeyAlgoRSA { - // KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, - // not public key formats, so they can't appear as a PublicKey.Type. - // The corresponding PublicKey.Type is KeyAlgoRSA. See RFC 8332, Section 2. - addAlgo(ssh.KeyAlgoRSASHA512) - addAlgo(ssh.KeyAlgoRSASHA256) + if key.cert { + certType := keyTypeToCertType(typ) + if certType == ssh.CertAlgoRSAv01 { + + // CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a + // Certificate.Type (or PublicKey.Type), but only in + // ClientConfig.HostKeyAlgorithms. + addAlgo(ssh.CertAlgoRSASHA256v01) + addAlgo(ssh.CertAlgoRSASHA512v01) + } + addAlgo(certType) + } else { + if typ == ssh.KeyAlgoRSA { + // KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, + // not public key formats, so they can't appear as a PublicKey.Type. + // The corresponding PublicKey.Type is KeyAlgoRSA. See RFC 8332, Section 2. + addAlgo(ssh.KeyAlgoRSASHA512) + addAlgo(ssh.KeyAlgoRSASHA256) + } + addAlgo(typ) } - addAlgo(typ) } return algos }