Skip to content

Commit

Permalink
Add headers.contains(...) method and make use of it in HTTPRequestEnc…
Browse files Browse the repository at this point in the history
…oder + minimize access to headers in encoder when possible
  • Loading branch information
normanmaurer committed Apr 9, 2018
1 parent 635b927 commit df862fa
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 11 deletions.
24 changes: 13 additions & 11 deletions Sources/NIOHTTP1/HTTPEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,6 @@ private func writeHead(wrapOutboundOut: (IOData) -> NIOAny, writeStartLine: (ino
ctx.write(wrapOutboundOut(.byteBuffer(buffer)), promise: promise)
}

private func isChunkedPart(_ headers: HTTPHeaders) -> Bool {
return headers["transfer-encoding"].contains("chunked")
}

/// Adjusts the response/request headers to ensure that the response/request will be well-framed.
///
/// This method strips Content-Length and Transfer-Encoding headers from responses/requests that must
Expand All @@ -95,16 +91,23 @@ private func isChunkedPart(_ headers: HTTPHeaders) -> Bool {
///
/// Note that for HTTP/1.0 if there is no Content-Length then the response should be followed
/// by connection close. We require that the user send that connection close: we don't do it.
private func sanitizeTransportHeaders(hasBody: HTTPMethod.HasBody, headers: inout HTTPHeaders, version: HTTPVersion) {
///
/// Returns true if its chunked.
private func sanitizeTransportHeaders(hasBody: HTTPMethod.HasBody, headers: inout HTTPHeaders, version: HTTPVersion) -> Bool {
switch hasBody {
case .no:
headers.remove(name: "content-length")
headers.remove(name: "transfer-encoding")
case .yes where headers["content-length"].count == 0 && version.major == 1 && version.minor >= 1:
return false
case .yes where version.major == 1 && version.minor >= 1:
if headers.contains(name: "content-length") {
return false
}
headers.replaceOrAdd(name: "transfer-encoding", value: "chunked")
return true
case .yes, .unlikely:
/* leave alone */
()
return headers.contains(name: "transfer-encoding", value: "chunked")
}
}

Expand All @@ -123,9 +126,8 @@ public final class HTTPRequestEncoder: ChannelOutboundHandler {
public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
switch self.unwrapOutboundIn(data) {
case .head(var request):
sanitizeTransportHeaders(hasBody: request.method.hasRequestBody, headers: &request.headers, version: request.version)

self.isChunked = isChunkedPart(request.headers)
self.isChunked = sanitizeTransportHeaders(hasBody: request.method.hasRequestBody, headers: &request.headers, version: request.version)

writeHead(wrapOutboundOut: self.wrapOutboundOut, writeStartLine: { buffer in
request.method.write(buffer: &buffer)
Expand Down Expand Up @@ -158,9 +160,9 @@ public final class HTTPResponseEncoder: ChannelOutboundHandler {
public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
switch self.unwrapOutboundIn(data) {
case .head(var response):
sanitizeTransportHeaders(hasBody: response.status.mayHaveResponseBody ? .yes : .no, headers: &response.headers, version: response.version)

self.isChunked = isChunkedPart(response.headers)
self.isChunked = sanitizeTransportHeaders(hasBody: response.status.mayHaveResponseBody ? .yes : .no, headers: &response.headers, version: response.version)

writeHead(wrapOutboundOut: self.wrapOutboundOut, writeStartLine: { buffer in
response.version.write(buffer: &buffer)
buffer.write(staticString: " ")
Expand Down
24 changes: 24 additions & 0 deletions Sources/NIOHTTP1/HTTPTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,30 @@ public struct HTTPHeaders: CustomStringConvertible {
return array
}

/// Checks if a header is present
///
/// - parameters:
/// - name: The name of the header
/// - value: The value of the header or nil if the name should not be checked
// - returns: `true` if a header with the name (and value) exists, `false` otherwise.
public func contains(name: String, value: String? = nil) -> Bool {
guard !self.headers.isEmpty else {
return false
}

let utf8 = name.utf8
let stringLength = utf8.count
for header in headers {
if stringLength == header.name.length && self.buffer.equalCaseInsensitiveASCII(seq: utf8, at: header.name) {
if let value = value {
return value.count == header.value.length && self.buffer.getString(at: header.value.start, length: header.value.length) == value
}
return true
}
}
return false
}

/// Serializes this HTTP header block to bytes suitable for writing to the wire.
///
/// - Parameter buffer: A buffer to write the serialized bytes into. Will increment
Expand Down

0 comments on commit df862fa

Please sign in to comment.