From 0aea14891f7df394488d139fa8f4a0f788fe7cd0 Mon Sep 17 00:00:00 2001 From: Daniel Adam Date: Fri, 2 Aug 2024 16:15:21 +0200 Subject: [PATCH] Fix segfault in State::serialize method The method gets invoked from public API function Conn::ConnectionState but the cipherSuite pointer member might not have been initialized yet. Invoking ConnectionState too early causes a segfault. Issue is fixed by changing the return type of Conn::ConnectionState from State to (State, bool) and returning (State{}, false) if the cipherSuite has not been set. --- conn.go | 8 +++- conn_test.go | 115 ++++++++++++++++++++++++++++++++++++++++------ flight4handler.go | 12 ++++- flight5handler.go | 8 +++- resume_test.go | 10 +++- state.go | 28 ++++++++--- 6 files changed, 152 insertions(+), 29 deletions(-) diff --git a/conn.go b/conn.go index 459cdcf5..09fb5741 100644 --- a/conn.go +++ b/conn.go @@ -411,10 +411,14 @@ func (c *Conn) Close() error { // ConnectionState returns basic DTLS details about the connection. // Note that this replaced the `Export` function of v1. -func (c *Conn) ConnectionState() State { +func (c *Conn) ConnectionState() (State, bool) { c.lock.RLock() defer c.lock.RUnlock() - return *c.state.clone() + stateClone, err := c.state.clone() + if err != nil { + return State{}, false + } + return *stateClone, true } // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile diff --git a/conn_test.go b/conn_test.go index c66e0381..283021ff 100644 --- a/conn_test.go +++ b/conn_test.go @@ -497,28 +497,40 @@ func TestExportKeyingMaterial(t *testing.T) { c.setLocalEpoch(0) c.setRemoteEpoch(0) - state := c.ConnectionState() + state, ok := c.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } _, err := state.ExportKeyingMaterial(exportLabel, nil, 0) if !errors.Is(err, errHandshakeInProgress) { t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err) } c.setLocalEpoch(1) - state = c.ConnectionState() + state, ok = c.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } _, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0) if !errors.Is(err, errContextUnsupported) { t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err) } for k := range invalidKeyingLabels() { - state = c.ConnectionState() + state, ok = c.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } _, err = state.ExportKeyingMaterial(k, nil, 0) if !errors.Is(err, errReservedExportKeyingMaterial) { t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err) } } - state = c.ConnectionState() + state, ok = c.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10) if err != nil { t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err) @@ -527,7 +539,10 @@ func TestExportKeyingMaterial(t *testing.T) { } c.state.isClient = true - state = c.ConnectionState() + state, ok = c.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10) if err != nil { t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err) @@ -669,7 +684,11 @@ func TestPSK(t *testing.T) { t.Fatalf("TestPSK: Server failed(%v)", err) } - actualPSKIdentityHint := server.ConnectionState().IdentityHint + state, ok := server.ConnectionState() + if !ok { + t.Fatalf("TestPSK: Server ConnectionState failed") + } + actualPSKIdentityHint := state.IdentityHint if !bytes.Equal(actualPSKIdentityHint, test.ClientIdentity) { t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ClientIdentity, actualPSKIdentityHint) } @@ -1194,7 +1213,11 @@ func TestClientCertificate(t *testing.T) { t.Errorf("Client failed(%v)", res.err) } - actualClientCert := server.ConnectionState().PeerCertificates + state, ok := server.ConnectionState() + if !ok { + t.Error("Server connection state not available") + } + actualClientCert := state.PeerCertificates if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert { if actualClientCert == nil { t.Errorf("Client did not provide a certificate") @@ -1221,7 +1244,11 @@ func TestClientCertificate(t *testing.T) { } } - actualServerCert := res.c.ConnectionState().PeerCertificates + clientState, ok := res.c.ConnectionState() + if !ok { + t.Error("Client connection state not available") + } + actualServerCert := clientState.PeerCertificates if actualServerCert == nil { t.Errorf("Server did not provide a certificate") } @@ -2889,8 +2916,12 @@ func TestSessionResume(t *testing.T) { t.Fatalf("TestSessionResume: Server failed(%v)", err) } - actualSessionID := server.ConnectionState().SessionID - actualMasterSecret := server.ConnectionState().masterSecret + state, ok := server.ConnectionState() + if !ok { + t.Fatal("TestSessionResume: ConnectionState failed") + } + actualSessionID := state.SessionID + actualMasterSecret := state.masterSecret if !bytes.Equal(actualSessionID, id) { t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID) } @@ -2940,8 +2971,12 @@ func TestSessionResume(t *testing.T) { t.Fatalf("TestSessionResumetion: Server failed(%v)", err) } - actualSessionID := server.ConnectionState().SessionID - actualMasterSecret := server.ConnectionState().masterSecret + state, ok := server.ConnectionState() + if !ok { + t.Fatal("TestSessionResumetion: ConnectionState failed") + } + actualSessionID := state.SessionID + actualMasterSecret := state.masterSecret ss, _ := s2.Get(actualSessionID) if !bytes.Equal(actualMasterSecret, ss.Secret) { t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret) @@ -3071,8 +3106,8 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { t.Fatal(err) } else if err := c.Close(); err != nil { t.Fatal(err) - } else if c.ConnectionState().cipherSuite.ID() != test.expectedCipher { - t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, c.ConnectionState().cipherSuite.ID()) + } else if state, ok := c.ConnectionState(); !ok || state.cipherSuite.ID() != test.expectedCipher { + t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, state.cipherSuite.ID()) } }) } @@ -3527,3 +3562,55 @@ func TestFragmentBuffer_Retransmission(t *testing.T) { t.Fatal("fragment should be retransmission") } } + +func TestConnectionState(t *testing.T) { + ca, cb := dpipe.Pipe() + + // Setup client + clientCfg := &Config{} + clientCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + clientCfg.Certificates = []tls.Certificate{clientCert} + clientCfg.InsecureSkipVerify = true + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientCfg) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = client.Close() + }() + + _, ok := client.ConnectionState() + if ok { + t.Fatal("ConnectionState should be nil") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + c := make(chan error) + go func() { + errC := client.HandshakeContext(ctx) + c <- errC + }() + + // Setup server + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = server.Close() + }() + + err = <-c + if err != nil { + t.Fatal(err) + } + + _, ok = client.ConnectionState() + if !ok { + t.Fatal("ConnectionState should not be nil") + } +} diff --git a/flight4handler.go b/flight4handler.go index 5f867688..7e4ae12f 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -183,7 +183,11 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous { if cfg.verifyConnection != nil { - if err := cfg.verifyConnection(state.clone()); err != nil { + stateClone, err := state.clone() + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if err := cfg.verifyConnection(stateClone); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } @@ -210,7 +214,11 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh // go to flight6 } if cfg.verifyConnection != nil { - if err := cfg.verifyConnection(state.clone()); err != nil { + stateClone, err := state.clone() + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if err := cfg.verifyConnection(stateClone); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } diff --git a/flight5handler.go b/flight5handler.go index 2f6e7b3e..7e940cdc 100644 --- a/flight5handler.go +++ b/flight5handler.go @@ -344,8 +344,12 @@ func initializeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCo } } if cfg.verifyConnection != nil { - if err = cfg.verifyConnection(state.clone()); err != nil { - return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + stateClone, errC := state.clone() + if errC != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errC + } + if errC = cfg.verifyConnection(stateClone); errC != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errC } } diff --git a/resume_test.go b/resume_test.go index fa4ca0a4..4f79adb3 100644 --- a/resume_test.go +++ b/resume_test.go @@ -18,7 +18,10 @@ import ( "github.com/pion/transport/v3/test" ) -var errMessageMissmatch = errors.New("messages missmatch") +var ( + errMessageMissmatch = errors.New("messages missmatch") + errInvalidConnectionState = errors.New("failed to get connection state") +) func TestResumeClient(t *testing.T) { DoTestResume(t, Client, Server) @@ -120,7 +123,10 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.PacketConn, net.Add } // Serialize and deserialize state - state := local.ConnectionState() + state, ok := local.ConnectionState() + if !ok { + fatal(t, errChan, errInvalidConnectionState) + } var b []byte b, err = state.MarshalBinary() if err != nil { diff --git a/state.go b/state.go index f2d6df6f..f1afb857 100644 --- a/state.go +++ b/state.go @@ -6,6 +6,7 @@ package dtls import ( "bytes" "encoding/gob" + "errors" "sync/atomic" "github.com/pion/dtls/v3/pkg/crypto/elliptic" @@ -87,15 +88,25 @@ type serializedState struct { NegotiatedProtocol string } -func (s *State) clone() *State { - serialized := s.serialize() +var errCipherSuiteNotSet = &InternalError{Err: errors.New("cipher suite not set")} //nolint:goerr113 + +func (s *State) clone() (*State, error) { + serialized, err := s.serialize() + if err != nil { + return nil, err + } state := &State{} state.deserialize(*serialized) - return state + return state, err } -func (s *State) serialize() *serializedState { +func (s *State) serialize() (*serializedState, error) { + if s.cipherSuite == nil { + return nil, errCipherSuiteNotSet + } + cipherSuiteID := uint16(s.cipherSuite.ID()) + // Marshal random values localRnd := s.localRandom.MarshalFixed() remoteRnd := s.remoteRandom.MarshalFixed() @@ -104,7 +115,7 @@ func (s *State) serialize() *serializedState { return &serializedState{ LocalEpoch: s.getLocalEpoch(), RemoteEpoch: s.getRemoteEpoch(), - CipherSuiteID: uint16(s.cipherSuite.ID()), + CipherSuiteID: cipherSuiteID, MasterSecret: s.masterSecret, SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]), LocalRandom: localRnd, @@ -117,7 +128,7 @@ func (s *State) serialize() *serializedState { RemoteConnectionID: s.remoteConnectionID, IsClient: s.isClient, NegotiatedProtocol: s.NegotiatedProtocol, - } + }, nil } func (s *State) deserialize(serialized serializedState) { @@ -187,7 +198,10 @@ func (s *State) initCipherSuite() error { // MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation func (s *State) MarshalBinary() ([]byte, error) { - serialized := s.serialize() + serialized, err := s.serialize() + if err != nil { + return nil, err + } var buf bytes.Buffer enc := gob.NewEncoder(&buf)