diff --git a/pkg/util/hkdf/hkdf.go b/pkg/util/hkdf/hkdf.go index 7b0da4a63..0617a520c 100644 --- a/pkg/util/hkdf/hkdf.go +++ b/pkg/util/hkdf/hkdf.go @@ -50,7 +50,7 @@ const ( ) // expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. -func expandLabel(secret []byte, label string, context []byte, length int) []byte { +func expandLabel(secret []byte, label string, context []byte, length int, cipherId uint32) []byte { var hkdfLabel cryptobyte.Builder hkdfLabel.AddUint16(uint16(length)) hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { @@ -63,13 +63,13 @@ func expandLabel(secret []byte, label string, context []byte, length int) []byte out := make([]byte, length) var transcript crypto.Hash - switch length { - case 32: + switch uint16(cipherId & 0x0000FFFF) { + case TLS_AES_128_GCM_SHA256, TLS_CHACHA20_POLY1305_SHA256: transcript = crypto.SHA256 - case 48: + case TLS_AES_256_GCM_SHA384: transcript = crypto.SHA384 default: - panic(fmt.Sprintf("non-tls 1.3 hash found, length: %d", length)) + panic(fmt.Sprintf("Unknown cipher: %d", cipherId)) } n, err := hkdf.Expand(transcript.New, secret, hkdfLabel.BytesOrPanic()).Read(out) if err != nil || n != length { @@ -80,6 +80,6 @@ func expandLabel(secret []byte, label string, context []byte, length int) []byte // from crypto/tls/key_schedule.go line 35 // DeriveSecret implements Derive-Secret from RFC 8446, Section 7.1. -func DeriveSecret(secret []byte, label string, transcript hash.Hash) []byte { - return expandLabel(secret, label, transcript.Sum(nil), transcript.Size()) +func DeriveSecret(secret []byte, label string, transcript hash.Hash, cipherId uint32) []byte { + return expandLabel(secret, label, transcript.Sum(nil), transcript.Size(), cipherId) } diff --git a/pkg/util/hkdf/hkdf_test.go b/pkg/util/hkdf/hkdf_test.go index 4ab9a9ce7..8847f5ff2 100644 --- a/pkg/util/hkdf/hkdf_test.go +++ b/pkg/util/hkdf/hkdf_test.go @@ -10,19 +10,14 @@ import ( func TestHkdf(t *testing.T) { t.Log("TestHkdf") //TODO - var cipher_id = 50336513 + var cipherId uint32 = 50336513 var transcript hash.Hash - // test with different cipher_id - switch cipher_id & 0x0000FFFF { - case 0x1301: - t.Log("TLS_AES_128_GCM_SHA256") + // test with different cipherID + switch uint16(cipherId & 0x0000FFFF) { + case TLS_AES_128_GCM_SHA256, TLS_CHACHA20_POLY1305_SHA256: transcript = crypto.SHA256.New() - case 0x1302: - t.Log("TLS_AES_256_GCM_SHA384") + case TLS_AES_256_GCM_SHA384: transcript = crypto.SHA384.New() - case 0x1303: - t.Log("TLS_CHACHA20_POLY1305_SHA256") - transcript = crypto.SHA256.New() default: t.Log("Unknown cipher") } @@ -42,17 +37,17 @@ func TestHkdf(t *testing.T) { transcript.Write(handshakeTrafficHash) clientSecret := DeriveSecret(handshakeSecret, - ClientHandshakeTrafficLabel, transcript) + ClientHandshakeTrafficLabel, transcript, cipherId) t.Logf("%s: %x", KeyLogLabelClientHandshake, clientSecret) serverHandshakeSecret := DeriveSecret(handshakeSecret, - ServerHandshakeTrafficLabel, transcript) + ServerHandshakeTrafficLabel, transcript, cipherId) t.Logf("%s: %x", KeyLogLabelServerHandshake, serverHandshakeSecret) transcript = crypto.SHA256.New() transcript.Write(serverFinishedHash) trafficSecret := DeriveSecret(masterSecret, - ClientApplicationTrafficLabel, transcript) + ClientApplicationTrafficLabel, transcript, cipherId) t.Logf("%s: %x", KeyLogLabelClientTraffic, trafficSecret) transcript = crypto.SHA256.New() @@ -63,7 +58,7 @@ func TestHkdf(t *testing.T) { //hs.transcript.Write(finished.marshal()) transcript.Write(serverFinishedHash) serverSecret := DeriveSecret(masterSecret, - ServerApplicationTrafficLabel, transcript) + ServerApplicationTrafficLabel, transcript, cipherId) t.Logf("%s: %x", KeyLogLabelServerTraffic, serverSecret) t.Logf("%s: %x", KeyLogLabelServerTraffic, exporterMasterSecret[:transcript.Size()]) diff --git a/user/module/probe_openssl.go b/user/module/probe_openssl.go index 0d71794b8..e5c2bd8a0 100644 --- a/user/module/probe_openssl.go +++ b/user/module/probe_openssl.go @@ -500,18 +500,18 @@ func (this *MOpenSSLProbe) saveMasterSecret(secretEvent *event.MasterSecretEvent return } transcript.Write(secretEvent.HandshakeTrafficHash[:]) - clientSecret := hkdf.DeriveSecret(secretEvent.HandshakeSecret[:], hkdf.ClientHandshakeTrafficLabel, transcript) + clientSecret := hkdf.DeriveSecret(secretEvent.HandshakeSecret[:], hkdf.ClientHandshakeTrafficLabel, transcript, secretEvent.CipherId) b = bytes.NewBufferString(fmt.Sprintf("%s %02x %02x\n", hkdf.KeyLogLabelClientHandshake, secretEvent.ClientRandom, clientSecret)) - serverHandshakeSecret := hkdf.DeriveSecret(secretEvent.HandshakeSecret[:], hkdf.ServerHandshakeTrafficLabel, transcript) + serverHandshakeSecret := hkdf.DeriveSecret(secretEvent.HandshakeSecret[:], hkdf.ServerHandshakeTrafficLabel, transcript, secretEvent.CipherId) b.WriteString(fmt.Sprintf("%s %02x %02x\n", hkdf.KeyLogLabelServerHandshake, secretEvent.ClientRandom, serverHandshakeSecret)) transcript.Reset() transcript.Write(secretEvent.ServerFinishedHash[:]) - trafficSecret := hkdf.DeriveSecret(secretEvent.MasterSecret[:], hkdf.ClientApplicationTrafficLabel, transcript) + trafficSecret := hkdf.DeriveSecret(secretEvent.MasterSecret[:], hkdf.ClientApplicationTrafficLabel, transcript, secretEvent.CipherId) b.WriteString(fmt.Sprintf("%s %02x %02x\n", hkdf.KeyLogLabelClientTraffic, secretEvent.ClientRandom, trafficSecret)) - serverSecret := hkdf.DeriveSecret(secretEvent.MasterSecret[:], hkdf.ServerApplicationTrafficLabel, transcript) + serverSecret := hkdf.DeriveSecret(secretEvent.MasterSecret[:], hkdf.ServerApplicationTrafficLabel, transcript, secretEvent.CipherId) b.WriteString(fmt.Sprintf("%s %02x %02x\n", hkdf.KeyLogLabelServerTraffic, secretEvent.ClientRandom, serverSecret)) // TODO MasterSecret sum