diff --git a/stdlib/public/core/DoubleWidth.swift.gyb b/stdlib/public/core/DoubleWidth.swift.gyb index 0e95abb490895..23689779ec52d 100644 --- a/stdlib/public/core/DoubleWidth.swift.gyb +++ b/stdlib/public/core/DoubleWidth.swift.gyb @@ -314,8 +314,8 @@ public struct DoubleWidth : let initialOffset = q.leadingZeroBitCount + (DoubleWidth.bitWidth - rhs.leadingZeroBitCount) - 1 - // TODO(performance): Use &>> instead here? // Start with remainder capturing the high bits of q. + // (These need to be smart shifts, as initialOffset can be > q.bitWidth) var r = q >> Magnitude(DoubleWidth.bitWidth - initialOffset) q <<= Magnitude(initialOffset) @@ -396,8 +396,9 @@ public struct DoubleWidth : let mid2 = sum(b.carry, c.carry, d.partial) let low = DoubleWidth((mid1.partial, a.partial)) - let high = DoubleWidth( - (High(mid2.carry + d.carry), mid1.carry + mid2.partial)) + let high = DoubleWidth(( + High(mid2.carry + d.carry), mid1.carry + mid2.partial + )) if isNegative { let (lowComplement, overflow) = (~low).addingReportingOverflow(1) @@ -433,6 +434,7 @@ public struct DoubleWidth : return } + // Shift is larger than this type's bit width. if rhs._storage.high != (0 as High) || rhs._storage.low >= DoubleWidth.bitWidth { @@ -476,32 +478,30 @@ public struct DoubleWidth : } public static func &<<=(lhs: inout DoubleWidth, rhs: DoubleWidth) { + // Need to use smart shifts here, since rhs can be > Base.bitWidth let rhs = rhs & DoubleWidth(DoubleWidth.bitWidth - 1) lhs._storage.high <<= High(rhs._storage.low) - if Base.bitWidth > rhs._storage.low { - lhs._storage.high |= High(truncatingIfNeeded: lhs._storage.low >> - (numericCast(Base.bitWidth) - rhs._storage.low)) - } else { - lhs._storage.high |= High(truncatingIfNeeded: lhs._storage.low << - (rhs._storage.low - numericCast(Base.bitWidth))) - } + + let lowInHigh = Base.bitWidth > rhs._storage.low + ? lhs._storage.low >> (numericCast(Base.bitWidth) - rhs._storage.low) + : lhs._storage.low << (rhs._storage.low - numericCast(Base.bitWidth)) + lhs._storage.high |= High(truncatingIfNeeded: lowInHigh) + lhs._storage.low <<= rhs._storage.low } public static func &>>=(lhs: inout DoubleWidth, rhs: DoubleWidth) { + // Need to use smart shifts here, since rhs can be > Base.bitWidth let rhs = rhs & DoubleWidth(DoubleWidth.bitWidth - 1) lhs._storage.low >>= rhs._storage.low - if Base.bitWidth > rhs._storage.low { - lhs._storage.low |= Low( - truncatingIfNeeded: - lhs._storage.high << (numericCast(Base.bitWidth) - rhs._storage.low)) - } else { - lhs._storage.low |= Low( - truncatingIfNeeded: lhs._storage.high >> - (rhs._storage.low - numericCast(Base.bitWidth))) - } + + let highInLow = Base.bitWidth > rhs._storage.low + ? lhs._storage.high << (numericCast(Base.bitWidth) - rhs._storage.low) + : lhs._storage.high >> (rhs._storage.low - numericCast(Base.bitWidth)) + lhs._storage.low |= Low(truncatingIfNeeded: highInLow) + lhs._storage.high >>= High(truncatingIfNeeded: rhs._storage.low) } @@ -573,8 +573,10 @@ binaryOperators = [ @_transparent public var byteSwapped: DoubleWidth { - return DoubleWidth((High(truncatingIfNeeded: low.byteSwapped), - Low(truncatingIfNeeded: high.byteSwapped))) + return DoubleWidth(( + High(truncatingIfNeeded: low.byteSwapped), + Low(truncatingIfNeeded: high.byteSwapped) + )) } } diff --git a/test/stdlib/Integers.swift.gyb b/test/stdlib/Integers.swift.gyb index c2e6000ce8f02..175bef86c109b 100644 --- a/test/stdlib/Integers.swift.gyb +++ b/test/stdlib/Integers.swift.gyb @@ -761,22 +761,30 @@ dwTests.test("TwoWords") { } dwTests.test("Bitshifts") { - typealias DWU16 = DoubleWidth - typealias DWU32 = DoubleWidth - typealias DWU64 = DoubleWidth + typealias DWU64 = DoubleWidth>> + typealias DWI64 = DoubleWidth>> - func f(_ x: UInt64) { - let y = DWU64(x) - for i in -65...65 { + func f(_ x: T, type: U.Type) { + let y = U(x) + expectEqual(T.bitWidth, U.bitWidth) + for i in -(T.bitWidth + 1)...(T.bitWidth + 1) { expectTrue(x << i == y << i) expectTrue(x >> i == y >> i) + + expectTrue(x &<< i == y &<< i) + expectTrue(x &>> i == y &>> i) } } - f(1) - f(~(~0 >> 1)) - f(.max) - f(0b11110000_10100101_11110000_10100101_11110000_10100101_11110000_10100101) + f(1 as UInt64, type: DWU64.self) + f(~(~0 as UInt64 >> 1), type: DWU64.self) + f(UInt64.max, type: DWU64.self) + f(0b11110000_10100101_11110000_10100101_11110000_10100101_11110000_10100101 as UInt64, type: DWU64.self) + + f(1 as Int64, type: DWI64.self) + f(Int64.min, type: DWI64.self) + f(Int64.max, type: DWI64.self) + f(0b01010101_10100101_11110000_10100101_11110000_10100101_11110000_10100101 as Int64, type: DWI64.self) } dwTests.test("Remainder/DividingBy0") {