diff --git a/ssh/server.go b/ssh/server.go index 7f0c236a9a..c2dfe3268c 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -213,6 +213,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha } else { for _, algo := range fullConf.PublicKeyAuthAlgorithms { if !contains(supportedPubKeyAuthAlgos, algo) { + c.Close() return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo) } } @@ -220,6 +221,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha // Check if the config contains any unsupported key exchanges for _, kex := range fullConf.KeyExchanges { if _, ok := serverForbiddenKexAlgos[kex]; ok { + c.Close() return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex) } } diff --git a/ssh/server_test.go b/ssh/server_test.go index 2145dce06f..a9b2bce0c1 100644 --- a/ssh/server_test.go +++ b/ssh/server_test.go @@ -5,7 +5,11 @@ package ssh import ( + "io" + "net" + "sync/atomic" "testing" + "time" ) func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) { @@ -59,27 +63,78 @@ func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) { } func TestNewServerConnValidationErrors(t *testing.T) { - c1, c2, err := netPipe() - if err != nil { - t.Fatalf("netPipe: %v", err) - } - defer c1.Close() - defer c2.Close() - serverConf := &ServerConfig{ PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01}, } - _, _, _, err = NewServerConn(c1, serverConf) + c := &markerConn{} + _, _, _, err := NewServerConn(c, serverConf) if err == nil { t.Fatal("NewServerConn with invalid public key auth algorithms succeeded") } + if !c.isClosed() { + t.Fatal("NewServerConn with invalid public key auth algorithms left connection open") + } + if c.isUsed() { + t.Fatal("NewServerConn with invalid public key auth algorithms used connection") + } + serverConf = &ServerConfig{ Config: Config{ KeyExchanges: []string{kexAlgoDHGEXSHA256}, }, } - _, _, _, err = NewServerConn(c1, serverConf) + c = &markerConn{} + _, _, _, err = NewServerConn(c, serverConf) if err == nil { t.Fatal("NewServerConn with unsupported key exchange succeeded") } + if !c.isClosed() { + t.Fatal("NewServerConn with unsupported key exchange left connection open") + } + if c.isUsed() { + t.Fatal("NewServerConn with unsupported key exchange used connection") + } +} + +type markerConn struct { + closed uint32 + used uint32 } + +func (c *markerConn) isClosed() bool { + return atomic.LoadUint32(&c.closed) != 0 +} + +func (c *markerConn) isUsed() bool { + return atomic.LoadUint32(&c.used) != 0 +} + +func (c *markerConn) Close() error { + atomic.StoreUint32(&c.closed, 1) + return nil +} + +func (c *markerConn) Read(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.EOF + } +} + +func (c *markerConn) Write(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.ErrClosedPipe + } +} + +func (*markerConn) LocalAddr() net.Addr { return nil } +func (*markerConn) RemoteAddr() net.Addr { return nil } + +func (*markerConn) SetDeadline(t time.Time) error { return nil } +func (*markerConn) SetReadDeadline(t time.Time) error { return nil } +func (*markerConn) SetWriteDeadline(t time.Time) error { return nil }