Skip to content

Commit

Permalink
ssh: add support for extension negotiation (rfc 8308)
Browse files Browse the repository at this point in the history
This adds support for SSH extension negotiation via SSH_MSG_EXT_INFO as defined in RFC 8308 [1]. It also adds support for the `server-sig-algs` extension on both the client and server sides.

[1] https://datatracker.ietf.org/doc/html/rfc8308

Fixes golang/go#49269

Change-Id: Iab374d1254ec83eabdb5f433b95ff39a1a540cc3
GitHub-Last-Rev: 29bf9ec
GitHub-Pull-Request: golang#197
  • Loading branch information
aphistic authored and DominicLavery committed Jan 18, 2022
1 parent 5e0467b commit 3687956
Show file tree
Hide file tree
Showing 9 changed files with 719 additions and 44 deletions.
47 changes: 41 additions & 6 deletions ssh/client_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,31 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
return err
}
packet, err := c.transport.readPacket()
if err != nil {
return err
}

var serviceAccept serviceAcceptMsg
if err := Unmarshal(packet, &serviceAccept); err != nil {
return err
readAcceptLoop:
for {
packet, err := c.transport.readPacket()
if err != nil {
return err
}

switch packet[0] {
case msgExtInfo:
var extInfo extInfoMsg
if err := Unmarshal(packet, &extInfo); err != nil {
return err
}
c.transport.extensions = extInfo.Extensions
continue
case msgServiceAccept:
if err := Unmarshal(packet, &serviceAccept); err != nil {
return err
}
break readAcceptLoop
default:
return fmt.Errorf("ssh: unexpected message received")
}
}

// during the authentication phase the client first attempts the "none" method
Expand Down Expand Up @@ -337,6 +355,14 @@ func handleAuthResponse(c packetConn) (authResult, []string, error) {
}

switch packet[0] {
case msgExtInfo:
var extInfo extInfoMsg
if err := Unmarshal(packet, &extInfo); err != nil {
return authFailure, nil, err
}
if transport, ok := c.(*handshakeTransport); ok {
transport.extensions = extInfo.Extensions
}
case msgUserAuthBanner:
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
Expand Down Expand Up @@ -420,6 +446,15 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe

// like handleAuthResponse, but with less options.
switch packet[0] {
case msgExtInfo:
var extInfo extInfoMsg
if err := Unmarshal(packet, &extInfo); err != nil {
return authFailure, nil, err
}
if transport, ok := c.(*handshakeTransport); ok {
transport.extensions = extInfo.Extensions
}
continue
case msgUserAuthBanner:
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
Expand Down
128 changes: 128 additions & 0 deletions ssh/client_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,75 @@ func TestClientAuthPublicKey(t *testing.T) {
}
}

func TestClientAuthPublicKeyExtensions(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()

certChecker := CertChecker{
IsUserAuthority: func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
},
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
return nil, nil
}

return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
},
IsRevoked: func(c *Certificate) bool {
return c.Serial == 666
},
}
serverConfig := &ServerConfig{
PublicKeyCallback: certChecker.Authenticate,
}
serverConfig.AddHostKey(testSigners["rsa"])

go newServer(c1, serverConfig)
clientConn, _, _, err := NewClientConn(c2, "", &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
})
if err != nil {
t.Fatalf("NewClientConn: %v", err)
}

conn, ok := clientConn.(*connection)
if !ok {
t.Fatalf("conn is not a *connection")
}

rawServerSigAlgs, ok := conn.transport.extensions[ExtServerSigAlgs]
if !ok {
t.Fatalf("did not receive server-sig-algs extension")
}

serverSigAlgs := strings.Split(string(rawServerSigAlgs), ",")
if len(serverSigAlgs) == 0 {
t.Fatalf("did not receive any server-sig-algs")
}

for _, expectedAlg := range supportedSigAlgs() {
hasAlg := false
for _, receivedAlg := range serverSigAlgs {
if receivedAlg == expectedAlg {
hasAlg = true
break
}
}
if !hasAlg {
t.Errorf("server-sig-algs did not have expected alg: %s", expectedAlg)
}
}
}

func TestAuthMethodPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Expand All @@ -131,6 +200,65 @@ func TestAuthMethodPassword(t *testing.T) {
}
}

func TestClientAuthPasswordExtensions(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()

serverConfig := &ServerConfig{
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == clientPassword {
return nil, nil
}
return nil, errors.New("password auth failed")
},
}
serverConfig.AddHostKey(testSigners["rsa"])

go newServer(c1, serverConfig)
clientConn, _, _, err := NewClientConn(c2, "", &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
Password(clientPassword),
},
HostKeyCallback: InsecureIgnoreHostKey(),
})
if err != nil {
t.Fatalf("NewClientConn: %v", err)
}

conn, ok := clientConn.(*connection)
if !ok {
t.Fatalf("conn is not a *connection")
}

rawServerSigAlgs, ok := conn.transport.extensions[ExtServerSigAlgs]
if !ok {
t.Fatalf("did not receive server-sig-algs extension")
}

serverSigAlgs := strings.Split(string(rawServerSigAlgs), ",")
if len(serverSigAlgs) == 0 {
t.Fatalf("did not receive any server-sig-algs")
}

for _, expectedAlg := range supportedSigAlgs() {
hasAlg := false
for _, receivedAlg := range serverSigAlgs {
if receivedAlg == expectedAlg {
hasAlg = true
break
}
}
if !hasAlg {
t.Errorf("server-sig-algs did not have expected alg: %s", expectedAlg)
}
}
}

func TestAuthMethodFallback(t *testing.T) {
var passwordCalled bool
config := &ClientConfig{
Expand Down
51 changes: 51 additions & 0 deletions ssh/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ const (
serviceSSH = "ssh-connection"
)

// These are string constants related to extensions and extension negotiation
const (
extInfoServer = "ext-info-s"
extInfoClient = "ext-info-c"

ExtServerSigAlgs = "server-sig-algs"
// extDelayCompression = "delay-compression"
// extNoFlowControl = "no-flow-control"
// extElevation = "elevation"
)

// defaultExtensions lists extensions enabled by default.
var defaultExtensions = []string{
ExtServerSigAlgs,
}

// supportedCiphers lists ciphers we support but might not recommend.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
Expand Down Expand Up @@ -108,6 +124,18 @@ var hashFuncs = map[string]crypto.Hash{
CertAlgoECDSA521v01: crypto.SHA512,
}

// supportedSigAlgs returns a slice of algorithms supported for pubkey authentication
// in no particular order.
func supportedSigAlgs() []string {
// TODO(kxd) I'm not sure if hashFuncs is the best place to get this set but it seemed
// like a sensible first step. Should this be a curated list?
var serverSigAlgs []string
for k := range hashFuncs {
serverSigAlgs = append(serverSigAlgs, k)
}
return serverSigAlgs
}

// unexpectedMessageError results when the SSH message that we received didn't
// match what we wanted.
func unexpectedMessageError(expected, got uint8) error {
Expand All @@ -130,6 +158,16 @@ func findCommon(what string, client []string, server []string) (common string, e
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
}

// hasString returns true if string "a" is in slice of strings "x", false otherwise.
func hasString(a string, x []string) bool {
for _, s := range x {
if a == s {
return true
}
}
return false
}

// directionAlgorithms records algorithm choices in one direction (either read or write)
type directionAlgorithms struct {
Cipher string
Expand Down Expand Up @@ -165,6 +203,11 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if err != nil {
return
} else if result.kex == extInfoClient || result.kex == extInfoServer {
// According to RFC8308 section 2.2 if either the client or server extension signal
// is chosen as the kex algorithm the parties must disconnect.
// chosen
return result, fmt.Errorf("ssh: invalid kex algorithm chosen")
}

result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
Expand Down Expand Up @@ -238,6 +281,10 @@ type Config struct {
// The allowed MAC algorithms. If unspecified then a sensible default
// is used.
MACs []string

// A list of enabled extensions. If unspecified then a sensible
// default is used
Extensions []string
}

// SetDefaults sets sensible values for unset fields in config. This is
Expand Down Expand Up @@ -267,6 +314,10 @@ func (c *Config) SetDefaults() {
c.MACs = supportedMACs
}

if c.Extensions == nil {
c.Extensions = defaultExtensions
}

if c.RekeyThreshold == 0 {
// cipher specific default
} else if c.RekeyThreshold < minRekeyThreshold {
Expand Down
22 changes: 22 additions & 0 deletions ssh/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,28 @@ func TestFindAgreedAlgorithms(t *testing.T) {
wantErr: true,
},

testcase{
name: "server ext info kex chosen",
serverIn: kexInitMsg{
KexAlgos: []string{extInfoServer},
},
clientIn: kexInitMsg{
KexAlgos: []string{extInfoServer},
},
wantErr: true,
},

testcase{
name: "client ext info kex chosen",
serverIn: kexInitMsg{
KexAlgos: []string{extInfoClient},
},
clientIn: kexInitMsg{
KexAlgos: []string{extInfoClient},
},
wantErr: true,
},

testcase{
name: "client decides cipher",
serverIn: kexInitMsg{
Expand Down
Loading

0 comments on commit 3687956

Please sign in to comment.