diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go index 3287b01c4a..4a31a3c3ec 100644 --- a/ssl/test/runner/conn.go +++ b/ssl/test/runner/conn.go @@ -137,8 +137,6 @@ func (c *Conn) init() { c.out.config = c.config c.in.conn = c c.out.conn = c - - c.out.updateOutSeq() } // Access to net.Conn methods. @@ -188,7 +186,6 @@ type halfConn struct { recordNumberEncrypter recordNumberEncrypter mac macFunction seq [8]byte // 64-bit sequence number - outSeq [8]byte // Mapped sequence number nextCipher any // next encryption state nextMac macFunction // next MAC algorithm @@ -289,7 +286,7 @@ func (hc *halfConn) resetCipher() { } // incSeq increments the sequence number. -func (hc *halfConn) incSeq(isOutgoing bool) { +func (hc *halfConn) incSeq() { limit := 0 increment := uint64(1) if hc.isDTLS { @@ -308,8 +305,6 @@ func (hc *halfConn) incSeq(isOutgoing bool) { if increment != 0 { panic("TLS: sequence number wraparound") } - - hc.updateOutSeq() } // incNextSeq increments the starting sequence number for the next epoch. @@ -345,8 +340,6 @@ func (hc *halfConn) incEpoch() { hc.seq[i] = 0 } } - - hc.updateOutSeq() } func (hc *halfConn) setEpoch(epoch uint16) { @@ -359,21 +352,20 @@ func (hc *halfConn) setEpoch(epoch uint16) { for i := range hc.nextSeq { hc.nextSeq[i] = 0 } - hc.updateOutSeq() } -func (hc *halfConn) updateOutSeq() { - if hc.config.Bugs.SequenceNumberMapping != nil { - seqU64 := binary.BigEndian.Uint64(hc.seq[:]) - seqU64 = hc.config.Bugs.SequenceNumberMapping(seqU64) - binary.BigEndian.PutUint64(hc.outSeq[:], seqU64) - - // The DTLS epoch cannot be changed. - copy(hc.outSeq[:2], hc.seq[:2]) - return +func (hc *halfConn) sequenceNumberForOutput() []byte { + if !hc.isDTLS || hc.config.Bugs.SequenceNumberMapping == nil { + return hc.seq[:] } - copy(hc.outSeq[:], hc.seq[:]) + var seq [8]byte + seqU64 := binary.BigEndian.Uint64(hc.seq[:]) + seqU64 = hc.config.Bugs.SequenceNumberMapping(seqU64) + binary.BigEndian.PutUint64(seq[:], seqU64) + // The DTLS epoch cannot be changed. + copy(seq[:2], hc.seq[:2]) + return seq[:] } func (hc *halfConn) explicitIVLen() int { @@ -560,7 +552,7 @@ func (hc *halfConn) decrypt(seq []byte, recordHeaderLen int, record []byte) (ok return false, 0, nil, alertBadRecordMAC } } - hc.incSeq(false) + hc.incSeq() return true, contentType, payload, 0 } @@ -636,7 +628,7 @@ func (c *Conn) useDTLSPlaintextHeader() bool { // (which must be in the last two bytes of the header) should be computed for // the unencrypted, unpadded payload. It will be updated, potentially in-place, // with the final length. -func (hc *halfConn) encrypt(record, payload []byte, typ recordType, headerLen int, headerHasLength bool) ([]byte, error) { +func (hc *halfConn) encrypt(record, payload []byte, typ recordType, headerLen int, headerHasLength bool, seq []byte) ([]byte, error) { prefixLen := len(record) header := record[prefixLen-headerLen:] explicitIVLen := hc.explicitIVLen() @@ -662,7 +654,7 @@ func (hc *halfConn) encrypt(record, payload []byte, typ recordType, headerLen in } if hc.mac != nil { - record = append(record, hc.computeMAC(hc.outSeq[:], header, payload)...) + record = append(record, hc.computeMAC(seq, header, payload)...) } explicitIV := record[prefixLen : prefixLen+explicitIVLen] @@ -674,13 +666,13 @@ func (hc *halfConn) encrypt(record, payload []byte, typ recordType, headerLen in } c.XORKeyStream(record[prefixLen:], record[prefixLen:]) case *tlsAead: - nonce := hc.outSeq[:] + nonce := seq if hc.isDTLS && hc.version >= VersionTLS13 && !hc.conn.useDTLSPlaintextHeader() { // Unlike DTLS 1.2, DTLS 1.3's nonce construction does not use // the epoch number. We store the epoch and nonce numbers // together, so make a copy without the epoch. nonce = make([]byte, 8) - copy(nonce[2:], hc.outSeq[2:]) + copy(nonce[2:], seq[2:]) } // Save the explicit IV, if not empty. @@ -695,7 +687,7 @@ func (hc *halfConn) encrypt(record, payload []byte, typ recordType, headerLen in if hc.version < VersionTLS13 { // (D)TLS 1.2's AD is seq_num || type || version || plaintext length additionalData = make([]byte, 13) - copy(additionalData, hc.outSeq[:]) + copy(additionalData, seq) copy(additionalData[8:], header[:3]) additionalData[11] = byte(len(payload) >> 8) additionalData[12] = byte(len(payload)) @@ -736,7 +728,7 @@ func (hc *halfConn) encrypt(record, payload []byte, typ recordType, headerLen in record[prefixLen-2] = byte(n >> 8) record[prefixLen-1] = byte(n) } - hc.incSeq(true) + hc.incSeq() return record, nil } @@ -1292,7 +1284,7 @@ func (c *Conn) doWriteRecord(typ recordType, data []byte) (n int, err error) { record[3] = byte(m >> 8) // encrypt will update this record[4] = byte(m) - record, err = c.out.encrypt(record, data[:m], typ, tlsRecordHeaderLen, true /* header has length */) + record, err = c.out.encrypt(record, data[:m], typ, tlsRecordHeaderLen, true /* header has length */, c.out.seq[:]) if err != nil { return } @@ -1470,7 +1462,7 @@ func (c *Conn) skipPacket(packet []byte) error { return errors.New("tls: sequence mismatch") } copy(c.in.seq[2:], seq) - c.in.incSeq(false) + c.in.incSeq() } else { if bytes.Compare(seq, c.in.nextSeq[:]) < 0 { return errors.New("tls: sequence mismatch") diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go index 201b6b87db..5eef430d93 100644 --- a/ssl/test/runner/dtls.go +++ b/ssl/test/runner/dtls.go @@ -411,7 +411,7 @@ func (c *Conn) dtlsFlushHandshake() error { // appendDTLS13RecordHeader appends to b the record header for a record of length // recordLen. -func (c *Conn) appendDTLS13RecordHeader(b []byte, recordLen int) []byte { +func (c *Conn) appendDTLS13RecordHeader(b, seq []byte, recordLen int) []byte { // Set the top 3 bits on the type byte to indicate the DTLS 1.3 record // header format. typ := byte(0x20) @@ -428,12 +428,12 @@ func (c *Conn) appendDTLS13RecordHeader(b []byte, recordLen int) []byte { typ |= 0x04 } // Set the epoch bits - typ |= c.out.outSeq[1] & 0x3 + typ |= seq[1] & 0x3 b = append(b, typ) if c.config.DTLSUseShortSeqNums { - b = append(b, c.out.outSeq[7]) + b = append(b, seq[7]) } else { - b = append(b, c.out.outSeq[6], c.out.outSeq[7]) + b = append(b, seq[6], seq[7]) } if !c.config.DTLSRecordHeaderOmitLength { b = append(b, byte(recordLen>>8), byte(recordLen)) @@ -467,21 +467,22 @@ func (c *Conn) dtlsPackRecord(typ recordType, data []byte, mustPack bool) (n int useDTLS13RecordHeader := c.out.version >= VersionTLS13 && c.out.cipher != nil && !c.useDTLSPlaintextHeader() headerHasLength := true record := make([]byte, 0, dtlsMaxRecordHeaderLen+len(data)+c.out.maxEncryptOverhead(len(data))) + seq := c.out.sequenceNumberForOutput() if useDTLS13RecordHeader { - record = c.appendDTLS13RecordHeader(record, len(data)) + record = c.appendDTLS13RecordHeader(record, seq, len(data)) headerHasLength = !c.config.DTLSRecordHeaderOmitLength } else { record = append(record, byte(typ)) record = append(record, byte(vers>>8)) record = append(record, byte(vers)) // DTLS records include an explicit sequence number. - record = append(record, c.out.outSeq[:]...) + record = append(record, seq...) record = append(record, byte(len(data)>>8)) record = append(record, byte(len(data))) } recordHeaderLen := len(record) - record, err = c.out.encrypt(record, data, typ, recordHeaderLen, headerHasLength) + record, err = c.out.encrypt(record, data, typ, recordHeaderLen, headerHasLength, seq) if err != nil { return }