diff --git a/Sources/NIOHTTP1/HTTPTypes.swift b/Sources/NIOHTTP1/HTTPTypes.swift index 594d5bd564..9388ec109e 100644 --- a/Sources/NIOHTTP1/HTTPTypes.swift +++ b/Sources/NIOHTTP1/HTTPTypes.swift @@ -19,27 +19,76 @@ let headerSeparator: StaticString = ": " /// A representation of the request line and header fields of a HTTP request. public struct HTTPRequestHead: Equatable { + private final class _Storage { + var method: HTTPMethod + var rawURI: URI + var version: HTTPVersion + + init(method: HTTPMethod, rawURI: URI, version: HTTPVersion) { + self.method = method + self.rawURI = rawURI + self.version = version + } + + func copy() -> _Storage { + return .init(method: self.method, rawURI: self.rawURI, version: self.version) + } + } + + private var _storage: _Storage + + /// The header fields for this HTTP request. + // warning: do not put this in `_Storage` as it'd trigger a CoW on every mutation + public var headers: HTTPHeaders + /// The HTTP method for this request. - public var method: HTTPMethod + public var method: HTTPMethod { + get { + return self._storage.method + } + set { + if !isKnownUniquelyReferenced(&self._storage) { + self._storage = self._storage.copy() + } + self._storage.method = newValue + } + } // Internal representation of the URI. - private var rawURI: URI + private var rawURI: URI { + get { + return self._storage.rawURI + } + set { + if !isKnownUniquelyReferenced(&self._storage) { + self._storage = self._storage.copy() + } + self._storage.rawURI = newValue + } + } /// The URI used on this request. public var uri: String { - get { - return String(uri: rawURI) + get { + return String(uri: rawURI) } - set { + set { rawURI = .string(newValue) } } /// The version for this HTTP request. - public var version: HTTPVersion - - /// The header fields for this HTTP request. - public var headers: HTTPHeaders + public var version: HTTPVersion { + get { + return self._storage.version + } + set { + if !isKnownUniquelyReferenced(&self._storage) { + self._storage = self._storage.copy() + } + self._storage.version = newValue + } + } /// Create a `HTTPRequestHead` /// @@ -57,10 +106,8 @@ public struct HTTPRequestHead: Equatable { /// - Parameter rawURI: The URI used on this request. /// - Parameter headers: The headers for this HTTP request. init(version: HTTPVersion, method: HTTPMethod, rawURI: URI, headers: HTTPHeaders) { - self.version = version - self.method = method - self.rawURI = rawURI self.headers = headers + self._storage = _Storage(method: method, rawURI: rawURI, version: version) } public static func ==(lhs: HTTPRequestHead, rhs: HTTPRequestHead) -> Bool { @@ -142,24 +189,58 @@ extension HTTPRequestHead { /// A representation of the status line and header fields of a HTTP response. public struct HTTPResponseHead: Equatable { - /// The HTTP response status. - public var status: HTTPResponseStatus + private final class _Storage { + var status: HTTPResponseStatus + var version: HTTPVersion + init(status: HTTPResponseStatus, version: HTTPVersion) { + self.status = status + self.version = version + } + func copy() -> _Storage { + return .init(status: self.status, version: self.version) + } + } - /// The HTTP version that corresponds to this response. - public var version: HTTPVersion + private var _storage: _Storage /// The HTTP headers on this response. + // warning: do not put this in `_Storage` as it'd trigger a CoW on every mutation public var headers: HTTPHeaders + /// The HTTP response status. + public var status: HTTPResponseStatus { + get { + return self._storage.status + } + set { + if !isKnownUniquelyReferenced(&self._storage) { + self._storage = self._storage.copy() + } + self._storage.status = newValue + } + } + + /// The HTTP version that corresponds to this response. + public var version: HTTPVersion { + get { + return self._storage.version + } + set { + if !isKnownUniquelyReferenced(&self._storage) { + self._storage = self._storage.copy() + } + self._storage.version = newValue + } + } + /// Create a `HTTPResponseHead` /// /// - Parameter version: The version for this HTTP response. /// - Parameter status: The status for this HTTP response. /// - Parameter headers: The headers for this HTTP response. public init(version: HTTPVersion, status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) { - self.version = version - self.status = status self.headers = headers + self._storage = _Storage(status: status, version: version) } public static func ==(lhs: HTTPResponseHead, rhs: HTTPResponseHead) -> Bool { @@ -220,10 +301,35 @@ private extension UInt8 { /// can be represented appropriately. public struct HTTPHeaders: CustomStringConvertible { + private final class _Storage { + var buffer: ByteBuffer + var headers: [HTTPHeader] + var continuous: Bool = true + + init(buffer: ByteBuffer, headers: [HTTPHeader], continuous: Bool) { + self.buffer = buffer + self.headers = headers + self.continuous = continuous + } + + func copy() -> _Storage { + return .init(buffer: self.buffer, headers: self.headers, continuous: self.continuous) + } + } + private var _storage: _Storage + // Because we use CoW implementations HTTPHeaders is also CoW - fileprivate var buffer: ByteBuffer - fileprivate var headers: [HTTPHeader] - fileprivate var continuous: Bool = true + fileprivate var buffer: ByteBuffer { + return self._storage.buffer + } + + fileprivate var headers: [HTTPHeader] { + return self._storage.headers + } + + fileprivate var continuous: Bool { + return self._storage.continuous + } /// Returns the `String` for the given `HTTPHeaderIndex`. /// @@ -251,8 +357,7 @@ public struct HTTPHeaders: CustomStringConvertible { /// Constructor used by our decoder to construct headers without the need of converting bytes to string. init(buffer: ByteBuffer, headers: [HTTPHeader]) { - self.buffer = buffer - self.headers = headers + self._storage = _Storage(buffer: buffer, headers: headers, continuous: true) } /// Construct a `HTTPHeaders` structure. @@ -295,13 +400,16 @@ public struct HTTPHeaders: CustomStringConvertible { /// - Parameter value: The header field value to add for the given name. public mutating func add(name: String, value: String) { precondition(!name.utf8.contains(where: { !$0.isASCII }), "name must be ASCII") + if !isKnownUniquelyReferenced(&self._storage) { + self._storage = self._storage.copy() + } let nameStart = self.buffer.writerIndex - let nameLength = self.buffer.write(string: name)! - self.buffer.write(staticString: headerSeparator) + let nameLength = self._storage.buffer.write(string: name)! + self._storage.buffer.write(staticString: headerSeparator) let valueStart = self.buffer.writerIndex - let valueLength = self.buffer.write(string: value)! - self.headers.append(HTTPHeader(name: HTTPHeaderIndex(start: nameStart, length: nameLength), value: HTTPHeaderIndex(start: valueStart, length: valueLength))) - self.buffer.write(staticString: crlf) + let valueLength = self._storage.buffer.write(string: value)! + self._storage.headers.append(HTTPHeader(name: HTTPHeaderIndex(start: nameStart, length: nameLength), value: HTTPHeaderIndex(start: valueStart, length: valueLength))) + self._storage.buffer.write(staticString: crlf) } /// Add a header name/value pair to the block, replacing any previous values for the @@ -346,10 +454,14 @@ public struct HTTPHeaders: CustomStringConvertible { return } + if !isKnownUniquelyReferenced(&self._storage) { + self._storage = self._storage.copy() + } + array.forEach { - self.headers.remove(at: $0) + self._storage.headers.remove(at: $0) } - self.continuous = false + self._storage.continuous = false } /// Retrieve all of the values for a give header field name from the block. diff --git a/Tests/NIOHTTP1Tests/HTTPTest+XCTest.swift b/Tests/NIOHTTP1Tests/HTTPTest+XCTest.swift index b4d8ac9761..03471251a4 100644 --- a/Tests/NIOHTTP1Tests/HTTPTest+XCTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPTest+XCTest.swift @@ -34,6 +34,8 @@ extension HTTPTest { ("test1ByteHTTPBody", test1ByteHTTPBody), ("testHTTPPipeliningWithBody", testHTTPPipeliningWithBody), ("testChunkedBody", testChunkedBody), + ("testHTTPRequestHeadCoWWorks", testHTTPRequestHeadCoWWorks), + ("testHTTPResponseHeadCoWWorks", testHTTPResponseHeadCoWWorks), ] } } diff --git a/Tests/NIOHTTP1Tests/HTTPTest.swift b/Tests/NIOHTTP1Tests/HTTPTest.swift index 952223ddde..86aebb45a2 100644 --- a/Tests/NIOHTTP1Tests/HTTPTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPTest.swift @@ -206,4 +206,62 @@ class HTTPTest: XCTestCase { trailers.add(name: "Something", value: "Else") try checkHTTPRequest(HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .POST, uri: "/"), body: "100", trailers: trailers) } + + func testHTTPRequestHeadCoWWorks() throws { + let headers = HTTPHeaders([("foo", "bar")]) + var httpReq = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/uri") + httpReq.headers = headers + + var modVersion = httpReq + modVersion.version = HTTPVersion(major: 2, minor: 0) + XCTAssertEqual(HTTPVersion(major: 1, minor: 1), httpReq.version) + XCTAssertEqual(HTTPVersion(major: 2, minor: 0), modVersion.version) + + var modMethod = httpReq + modMethod.method = .POST + XCTAssertEqual(.GET, httpReq.method) + XCTAssertEqual(.POST, modMethod.method) + + var modURI = httpReq + modURI.uri = "/changed" + XCTAssertEqual("/uri", httpReq.uri) + XCTAssertEqual("/changed", modURI.uri) + + var modHeaders = httpReq + modHeaders.headers.add(name: "qux", value: "quux") + XCTAssertEqual(httpReq.headers, headers) + XCTAssertNotEqual(httpReq, modHeaders) + modHeaders.headers.remove(name: "foo") + XCTAssertEqual(httpReq.headers, headers) + XCTAssertNotEqual(httpReq, modHeaders) + modHeaders.headers.remove(name: "qux") + modHeaders.headers.add(name: "foo", value: "bar") + XCTAssertEqual(httpReq, modHeaders) + } + + func testHTTPResponseHeadCoWWorks() throws { + let headers = HTTPHeaders([("foo", "bar")]) + let httpRes = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers) + + var modVersion = httpRes + modVersion.version = HTTPVersion(major: 2, minor: 0) + XCTAssertEqual(HTTPVersion(major: 1, minor: 1), httpRes.version) + XCTAssertEqual(HTTPVersion(major: 2, minor: 0), modVersion.version) + + var modStatus = httpRes + modStatus.status = .notFound + XCTAssertEqual(.ok, httpRes.status) + XCTAssertEqual(.notFound, modStatus.status) + + var modHeaders = httpRes + modHeaders.headers.add(name: "qux", value: "quux") + XCTAssertEqual(httpRes.headers, headers) + XCTAssertNotEqual(httpRes, modHeaders) + modHeaders.headers.remove(name: "foo") + XCTAssertEqual(httpRes.headers, headers) + XCTAssertNotEqual(httpRes, modHeaders) + modHeaders.headers.remove(name: "qux") + modHeaders.headers.add(name: "foo", value: "bar") + XCTAssertEqual(httpRes, modHeaders) + } }