Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic support for unix domain sockets #187

Merged
merged 11 commits into from
Apr 2, 2019
13 changes: 11 additions & 2 deletions Sources/KituraNet/ClientRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ public class ClientRequest {
/// A semaphore used to make ClientRequest.end() synchronous
let waitSemaphore = DispatchSemaphore(value: 0)

// Socket path for Unix domain sockets
private var unixDomainSocketPath: String?

/**
Client request options enum. This allows the client to specify certain parameteres such as HTTP headers, HTTP methods, host names, and SSL credentials.

Expand Down Expand Up @@ -292,9 +295,11 @@ public class ClientRequest {
/// Initializes a `ClientRequest` instance
///
/// - Parameter options: An array of `Options' describing the request
/// - Parameter unixDomainSocketPath: Specifies a path of a Unix domain socket that the client should connect to.
/// - Parameter callback: The closure of type `Callback` to be used for the callback.
djones6 marked this conversation as resolved.
Show resolved Hide resolved
init(options: [Options], callback: @escaping Callback) {
init(options: [Options], unixDomainSocketPath: String? = nil, callback: @escaping Callback) {

self.unixDomainSocketPath = unixDomainSocketPath
self.callback = callback

var theSchema = "http://"
Expand Down Expand Up @@ -558,7 +563,11 @@ public class ClientRequest {

do {
guard let bootstrap = bootstrap else { return }
channel = try bootstrap.connect(host: hostName, port: Int(self.port!)).wait()
if let unixDomainSocketPath = self.unixDomainSocketPath {
channel = try bootstrap.connect(unixDomainSocketPath: unixDomainSocketPath).wait()
} else {
channel = try bootstrap.connect(host: hostName, port: Int(self.port!)).wait()
}
} catch let error {
Log.error("Connection to \(hostName):\(self.port ?? 80) failed with error: \(error)")
callback(nil)
Expand Down
12 changes: 7 additions & 5 deletions Sources/KituraNet/HTTP/HTTP.swift
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,20 @@ public class HTTP {
Create a new `ClientRequest` using a list of options.

- Parameter options: a list of `ClientRequest.Options`.
- Parameter callback: closure to run after the request.
- Parameter unixDomainSocketPath: the path of a Unix domain socket that this client should connect to (defaults to `nil`).
- Parameter callback: The closure to run after the request completes. The `ClientResponse?` parameter allows access to the response from the server.
- Returns: a `ClientRequest` instance

### Usage Example: ###
````swift
let request = HTTP.request([ClientRequest.Options]) {response in
...
let myOptions: [ClientRequest.Options] = [.hostname("localhost"), .port("8080")]
let request = HTTP.request(myOptions) { response in
// Process the ClientResponse
}
````
*/
public static func request(_ options: [ClientRequest.Options], callback: @escaping ClientRequest.Callback) -> ClientRequest {
return ClientRequest(options: options, callback: callback)
public static func request(_ options: [ClientRequest.Options], unixDomainSocketPath: String? = nil, callback: @escaping ClientRequest.Callback) -> ClientRequest {
return ClientRequest(options: options, unixDomainSocketPath: unixDomainSocketPath, callback: callback)
}

/**
Expand Down
1 change: 0 additions & 1 deletion Sources/KituraNet/HTTP/HTTPRequestHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ internal class HTTPRequestHandler: ChannelInboundHandler {

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
let request = self.unwrapInboundIn(data)

// If an error response was already sent, we'd want to spare running through this for now.
// If an upgrade to WebSocket fails, both `errorCaught` and `channelRead` are triggered.
// We'd want to return the error via `errorCaught`.
Expand Down
74 changes: 63 additions & 11 deletions Sources/KituraNet/HTTP/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,12 @@ public class HTTPServer: Server {
*/
public var delegate: ServerDelegate?

/**
Port number for listening for new connections.

### Usage Example: ###
````swift
httpServer.port = 8080
````
*/
/// The TCP port on which this server listens for new connections. If `nil`, this server does not listen on a TCP socket.
public private(set) var port: Int?

/// The Unix domain socket path on which this server listens for new connections. If `nil`, this server does not listen on a Unix socket.
public private(set) var unixDomainSocketPath: String?

private var _state: ServerState = .unknown

private let syncQ = DispatchQueue(label: "HTTPServer.syncQ")
Expand Down Expand Up @@ -225,8 +221,31 @@ public class HTTPServer: Server {
return nil
}

// Sockets could either be TCP/IP sockets or Unix domain sockets
private enum SocketType {
// An TCP/IP socket has an associated port number
case tcp(Int)
// A unix domain socket has an associated filename
case unix(String)
}

/**
Listens for connections on a socket.
Listens for connections on a Unix socket.

### Usage Example: ###
````swift
try server.listen(unixDomainSocketPath: "/my/path")
````

- Parameter unixDomainSocketPath: Unix socket path for new connections, eg. "/my/path"
*/
public func listen(unixDomainSocketPath: String) throws {
self.unixDomainSocketPath = unixDomainSocketPath
try listen(.unix(unixDomainSocketPath))
}

/**
Listens for connections on a TCP socket.

### Usage Example: ###
````swift
Expand All @@ -237,6 +256,10 @@ public class HTTPServer: Server {
*/
public func listen(on port: Int) throws {
self.port = port
try listen(.tcp(port))
}

private func listen(_ socket: SocketType) throws {

if let tlsConfig = tlsConfig {
do {
Expand Down Expand Up @@ -276,14 +299,23 @@ public class HTTPServer: Server {
.childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)

do {
serverChannel = try bootstrap.bind(host: "0.0.0.0", port: port).wait()
if case let SocketType.tcp(port) = socket {
serverChannel = try bootstrap.bind(host: "0.0.0.0", port: port).wait()
} else if case let SocketType.unix(unixDomainSocketPath) = socket {
serverChannel = try bootstrap.bind(unixDomainSocketPath: unixDomainSocketPath).wait()
}
self.port = serverChannel?.localAddress?.port.map { Int($0) }
self.state = .started
self.lifecycleListener.performStartCallbacks()
} catch let error {
self.state = .failed
self.lifecycleListener.performFailCallbacks(with: error)
Log.error("Error trying to bind to \(port): \(error)")
switch socket {
case .tcp(let port):
Log.error("Error trying to bind to \(port): \(error)")
case .unix(let socketPath):
Log.error("Error trying to bind to \(socketPath): \(error)")
}
throw error
}

djones6 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -323,6 +355,26 @@ public class HTTPServer: Server {
return server
}

/**
Static method to create a new HTTP server and have it listen for connections on a Unix domain socket.

### Usage Example: ###
````swift
let server = HTTPServer.listen(unixDomainSocketPath: "/my/path", delegate: self)
````

- Parameter unixDomainSocketPath: The path of the Unix domain socket that this server should listen on.
- Parameter delegate: The delegate handler for HTTP connections.

- Returns: A new instance of a `HTTPServer`.
*/
public static func listen(unixDomainSocketPath: String, delegate: ServerDelegate?) throws -> HTTPServer {
let server = HTTP.createServer()
server.delegate = delegate
try server.listen(unixDomainSocketPath: unixDomainSocketPath)
return server
}

/**
Listen for connections on a socket.

Expand Down
30 changes: 22 additions & 8 deletions Tests/KituraNetTests/KituraNIOTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class KituraNetTest: XCTestCase {
var useSSL = useSSLDefault
var port = portDefault

var unixDomainSocketPath: String? = nil

static let sslConfig: SSLService.Configuration = {
let sslConfigDir = URL(fileURLWithPath: #file).appendingPathComponent("../SSLConfig")

Expand Down Expand Up @@ -62,15 +64,19 @@ class KituraNetTest: XCTestCase {
func doTearDown() {
}

func startServer(_ delegate: ServerDelegate?, port: Int = portDefault, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault) throws -> HTTPServer {
func startServer(_ delegate: ServerDelegate?, unixDomainSocketPath: String? = nil, port: Int = portDefault, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault) throws -> HTTPServer {

let server = HTTP.createServer()
server.delegate = delegate
server.allowPortReuse = allowPortReuse
if useSSL {
server.sslConfig = KituraNetTest.sslConfig
}
try server.listen(on: port)
if let unixDomainSocketPath = unixDomainSocketPath {
try server.listen(unixDomainSocketPath: unixDomainSocketPath)
} else {
server.allowPortReuse = allowPortReuse
try server.listen(on: port)
}
return server
}

Expand All @@ -87,13 +93,21 @@ class KituraNetTest: XCTestCase {
return (server, serverPort)
}

func performServerTest(_ delegate: ServerDelegate?, port: Int = portDefault, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault,

func performServerTest(_ delegate: ServerDelegate?, unixDomainSocketPath: String? = nil, port: Int = portDefault, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault,
line: Int = #line, asyncTasks: (XCTestExpectation) -> Void...) {

do {
var server: HTTPServer
var ephemeralPort: Int = 0
self.useSSL = useSSL
let (server, ephemeralPort) = try startEphemeralServer(delegate, useSSL: useSSL, allowPortReuse: allowPortReuse)
self.port = ephemeralPort
if let unixDomainSocketPath = unixDomainSocketPath {
server = try startServer(delegate, unixDomainSocketPath: unixDomainSocketPath, useSSL: useSSL, allowPortReuse: allowPortReuse)
self.unixDomainSocketPath = unixDomainSocketPath
} else {
(server, ephemeralPort) = try startEphemeralServer(delegate, useSSL: useSSL, allowPortReuse: allowPortReuse)
self.port = ephemeralPort
}
defer {
server.stop()
}
Expand Down Expand Up @@ -145,7 +159,7 @@ class KituraNetTest: XCTestCase {
}
}*/

func performRequest(_ method: String, path: String, hostname: String = "localhost", close: Bool=true, callback: @escaping ClientRequest.Callback,
func performRequest(_ method: String, path: String, unixDomainSocketPath: String? = nil, hostname: String = "localhost", close: Bool=true, callback: @escaping ClientRequest.Callback,
headers: [String: String]? = nil, requestModifier: ((ClientRequest) -> Void)? = nil) {

var allHeaders = [String: String]()
Expand All @@ -163,7 +177,7 @@ class KituraNetTest: XCTestCase {
options.append(.disableSSLVerification)
}

let req = HTTP.request(options, callback: callback)
let req = HTTP.request(options, unixDomainSocketPath: unixDomainSocketPath, callback: callback)
if let requestModifier = requestModifier {
requestModifier(req)
}
Expand Down
103 changes: 103 additions & 0 deletions Tests/KituraNetTests/UnixDomainSocketTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/**
* Copyright IBM Corporation 2016
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
**/

import Foundation

import XCTest

@testable import KituraNet

class UnixDomainSocketTests: KituraNetTest {

static var allTests: [(String, (UnixDomainSocketTests) -> () throws -> Void)] {
return [
("testPostRequestWithUnixDomainSocket", testPostRequestWithUnixDomainSocket),
]
}

override func setUp() {
doSetUp()

// set up the unix socket file
#if os(Linux)
let temporaryDirectory = "/tmp"
#else
var temporaryDirectory: String
if #available(OSX 10.12, *) {
temporaryDirectory = FileManager.default.temporaryDirectory.path
} else {
temporaryDirectory = "/tmp"
}
#endif
self.socketFilePath = temporaryDirectory + "/" + String(ProcessInfo.processInfo.globallyUniqueString.prefix(20))
}

override func tearDown() {
doTearDown()
let fileURL = URL(fileURLWithPath: socketFilePath)
let fm = FileManager.default
do {
try fm.removeItem(at: fileURL)
} catch {
XCTFail(error.localizedDescription)
}
}

private var socketFilePath = ""
private let delegate = TestServerDelegate()

func testPostRequestWithUnixDomainSocket() {
performServerTest(delegate, unixDomainSocketPath: socketFilePath, useSSL: false, asyncTasks: { expectation in
let payload = "[" + contentTypesString + "," + contentTypesString + "]"
self.performRequest("post", path: "/uds", unixDomainSocketPath: self.socketFilePath, callback: {response in
XCTAssertEqual(response?.statusCode, HTTPStatusCode.OK, "Status code wasn't .Ok was \(String(describing: response?.statusCode))")
do {
let expected = "Read \(payload.count) bytes"
var data = Data()
let count = try response?.readAllData(into: &data)
XCTAssertEqual(count, expected.count, "Result should have been \(expected.count) bytes, was \(String(describing: count)) bytes")
let postValue = String(data: data, encoding: .utf8)
if let postValue = postValue {
XCTAssertEqual(postValue, expected)
} else {
XCTFail("postValue's value wasn't an UTF8 string")
}
} catch {
XCTFail("Failed reading the body of the response")
}
expectation.fulfill()
}) {request in
request.write(from: payload)
}
})
}

private class TestServerDelegate: ServerDelegate {
func handle(request: ServerRequest, response: ServerResponse) {
var body = Data()
do {
let length = try request.readAllData(into: &body)
let result = "Read \(length) bytes"
response.headers["Content-Type"] = ["text/plain"]
response.headers["Content-Length"] = ["\(result.count)"]

try response.end(text: result)
} catch {
print("Error reading body or writing response")
}
}
}
}
3 changes: 2 additions & 1 deletion Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,6 @@ XCTMain([
testCase(ClientE2ETests.allTests.shuffled()),
testCase(PipeliningTests.allTests.shuffled()),
testCase(RegressionTests.allTests.shuffled()),
testCase(MonitoringTests.allTests.shuffled())
testCase(MonitoringTests.allTests.shuffled()),
testCase(UnixDomainSocketTests.allTests.shuffled())
].shuffled())