Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic support for unix domain sockets #187

Merged
merged 11 commits into from
Apr 2, 2019
31 changes: 22 additions & 9 deletions Sources/KituraNet/ClientRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ public class ClientRequest {
/// A semaphore used to make ClientRequest.end() synchronous
let waitSemaphore = DispatchSemaphore(value: 0)

// Socket path for Unix domain sockets
var unixDomainSocketPath: String?

/**
Client request options enum. This allows the client to specify certain parameteres such as HTTP headers, HTTP methods, host names, and SSL credentials.

Expand Down Expand Up @@ -292,9 +295,11 @@ public class ClientRequest {
/// Initializes a `ClientRequest` instance
///
/// - Parameter options: An array of `Options' describing the request
/// - Parameter unixDomainSocketPath: Specifies a path of a Unix domain socket that the client should connect to.
/// - Parameter callback: The closure of type `Callback` to be used for the callback.
djones6 marked this conversation as resolved.
Show resolved Hide resolved
init(options: [Options], callback: @escaping Callback) {
init(options: [Options], unixDomainSocketPath: String? = nil, callback: @escaping Callback) {

self.unixDomainSocketPath = unixDomainSocketPath
self.callback = callback

var theSchema = "http://"
Expand Down Expand Up @@ -558,9 +563,13 @@ public class ClientRequest {

do {
guard let bootstrap = bootstrap else { return }
channel = try bootstrap.connect(host: hostName, port: Int(self.port!)).wait()
if let unixDomainSocketPath = self.unixDomainSocketPath {
channel = try bootstrap.connect(unixDomainSocketPath: unixDomainSocketPath).wait()
} else {
channel = try bootstrap.connect(host: hostName, port: Int(self.port!)).wait()
}
} catch let error {
Log.error("Connection to \(hostName):\(self.port ?? 80) failed with error: \(error)")
Log.error("Connection to \(hostName): self.unixDomainSocketPath ?? \(self.port ?? 80) failed with error: \(error)")
callback(nil)
return
}
Expand Down Expand Up @@ -759,13 +768,17 @@ class HTTPClientHandler: ChannelInboundHandler {
}
if url.starts(with: "/") {
let scheme = URL(string: clientRequest.url)?.scheme
let port = clientRequest.port.map { UInt16($0) }.map { $0.toInt16() }!
let request = ClientRequest(options: [.schema(scheme!),
.hostname(clientRequest.hostName!),
.port(port),
.path(url)],
callback: clientRequest.callback)
var options: [ClientRequest.Options] = [.schema(scheme!), .hostname(clientRequest.hostName!), .path(url)]
let request: ClientRequest
if let socketPath = self.clientRequest.unixDomainSocketPath {
request = ClientRequest(options: options, unixDomainSocketPath: socketPath, callback: clientRequest.callback)
} else {
let port = clientRequest.port.map { UInt16($0) }.map { $0.toInt16() }!
options.append(.port(port))
request = ClientRequest(options: options, callback: clientRequest.callback)
}
request.maxRedirects = self.clientRequest.maxRedirects - 1

// The next request can be asynchronously moved to a DispatchQueue.
// ClientRequest.end() calls connect().wait(), so we better move this to a dispatch queue.
// Because ClientRequest.end() is blocking, we mark the current task complete after the new task also completes.
Expand Down
12 changes: 7 additions & 5 deletions Sources/KituraNet/HTTP/HTTP.swift
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,20 @@ public class HTTP {
Create a new `ClientRequest` using a list of options.

- Parameter options: a list of `ClientRequest.Options`.
- Parameter callback: closure to run after the request.
- Parameter unixDomainSocketPath: the path of a Unix domain socket that this client should connect to (defaults to `nil`).
- Parameter callback: The closure to run after the request completes. The `ClientResponse?` parameter allows access to the response from the server.
- Returns: a `ClientRequest` instance

### Usage Example: ###
````swift
let request = HTTP.request([ClientRequest.Options]) {response in
...
let myOptions: [ClientRequest.Options] = [.hostname("localhost"), .port("8080")]
let request = HTTP.request(myOptions) { response in
// Process the ClientResponse
}
````
*/
public static func request(_ options: [ClientRequest.Options], callback: @escaping ClientRequest.Callback) -> ClientRequest {
return ClientRequest(options: options, callback: callback)
public static func request(_ options: [ClientRequest.Options], unixDomainSocketPath: String? = nil, callback: @escaping ClientRequest.Callback) -> ClientRequest {
return ClientRequest(options: options, unixDomainSocketPath: unixDomainSocketPath, callback: callback)
}

/**
Expand Down
1 change: 0 additions & 1 deletion Sources/KituraNet/HTTP/HTTPRequestHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ internal class HTTPRequestHandler: ChannelInboundHandler {

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
let request = self.unwrapInboundIn(data)

// If an error response was already sent, we'd want to spare running through this for now.
// If an upgrade to WebSocket fails, both `errorCaught` and `channelRead` are triggered.
// We'd want to return the error via `errorCaught`.
Expand Down
97 changes: 83 additions & 14 deletions Sources/KituraNet/HTTP/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ import LoggerAPI
import NIOWebSocket
import CLinuxHelpers

#if os(Linux)
import Glibc
#else
import Darwin
#endif

// MARK: HTTPServer
/**
An HTTP server that listens for connections on a socket.
Expand All @@ -49,16 +55,12 @@ public class HTTPServer: Server {
*/
public var delegate: ServerDelegate?

/**
Port number for listening for new connections.

### Usage Example: ###
````swift
httpServer.port = 8080
````
*/
/// The TCP port on which this server listens for new connections. If `nil`, this server does not listen on a TCP socket.
public private(set) var port: Int?

/// The Unix domain socket path on which this server listens for new connections. If `nil`, this server does not listen on a Unix socket.
public private(set) var unixDomainSocketPath: String?

private var _state: ServerState = .unknown

private let syncQ = DispatchQueue(label: "HTTPServer.syncQ")
Expand Down Expand Up @@ -225,8 +227,31 @@ public class HTTPServer: Server {
return nil
}

// Sockets could either be TCP/IP sockets or Unix domain sockets
private enum SocketType {
// An TCP/IP socket has an associated port number
case tcp(Int)
// A unix domain socket has an associated filename
case unix(String)
}

/**
Listens for connections on a socket.
Listens for connections on a Unix socket.

### Usage Example: ###
````swift
try server.listen(unixDomainSocketPath: "/my/path")
````

- Parameter unixDomainSocketPath: Unix socket path for new connections, eg. "/my/path"
*/
public func listen(unixDomainSocketPath: String) throws {
self.unixDomainSocketPath = unixDomainSocketPath
try listen(.unix(unixDomainSocketPath))
}

/**
Listens for connections on a TCP socket.

### Usage Example: ###
````swift
Expand All @@ -237,6 +262,10 @@ public class HTTPServer: Server {
*/
public func listen(on port: Int) throws {
self.port = port
try listen(.tcp(port))
}

private func listen(_ socket: SocketType) throws {

if let tlsConfig = tlsConfig {
do {
Expand Down Expand Up @@ -275,20 +304,40 @@ public class HTTPServer: Server {
}
.childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)

let listenerDescription: String
do {
serverChannel = try bootstrap.bind(host: "0.0.0.0", port: port).wait()
self.port = serverChannel?.localAddress?.port.map { Int($0) }
switch socket {
case SocketType.tcp(let port):
serverChannel = try bootstrap.bind(host: "0.0.0.0", port: port).wait()
self.port = serverChannel?.localAddress?.port.map { Int($0) }
listenerDescription = "port \(self.port ?? port)"
case SocketType.unix(let unixDomainSocketPath):
// Ensure the path doesn't exist...
#if os(Linux)
_ = Glibc.unlink(unixDomainSocketPath)
#else
_ = Darwin.unlink(unixDomainSocketPath)
#endif
serverChannel = try bootstrap.bind(unixDomainSocketPath: unixDomainSocketPath).wait()
self.unixDomainSocketPath = unixDomainSocketPath
listenerDescription = "path \(unixDomainSocketPath)"
}
self.state = .started
self.lifecycleListener.performStartCallbacks()
} catch let error {
self.state = .failed
self.lifecycleListener.performFailCallbacks(with: error)
Log.error("Error trying to bind to \(port): \(error)")
switch socket {
case .tcp(let port):
Log.error("Error trying to bind to \(port): \(error)")
case .unix(let socketPath):
Log.error("Error trying to bind to \(socketPath): \(error)")
}
throw error
}

djones6 marked this conversation as resolved.
Show resolved Hide resolved
Log.info("Listening on port \(self.port!)")
Log.verbose("Options for port \(self.port!): maxPendingConnections: \(maxPendingConnections), allowPortReuse: \(self.allowPortReuse)")
Log.info("Listening on \(listenerDescription)")
Log.verbose("Options for \(listenerDescription): maxPendingConnections: \(maxPendingConnections), allowPortReuse: \(self.allowPortReuse)")

let queuedBlock = DispatchWorkItem(block: {
guard let serverChannel = self.serverChannel else { return }
Expand Down Expand Up @@ -323,6 +372,26 @@ public class HTTPServer: Server {
return server
}

/**
Static method to create a new HTTP server and have it listen for connections on a Unix domain socket.

### Usage Example: ###
````swift
let server = HTTPServer.listen(unixDomainSocketPath: "/my/path", delegate: self)
````

- Parameter unixDomainSocketPath: The path of the Unix domain socket that this server should listen on.
- Parameter delegate: The delegate handler for HTTP connections.

- Returns: A new instance of a `HTTPServer`.
*/
public static func listen(unixDomainSocketPath: String, delegate: ServerDelegate?) throws -> HTTPServer {
let server = HTTP.createServer()
server.delegate = delegate
try server.listen(unixDomainSocketPath: unixDomainSocketPath)
return server
}

/**
Listen for connections on a socket.

Expand Down
30 changes: 22 additions & 8 deletions Tests/KituraNetTests/KituraNIOTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class KituraNetTest: XCTestCase {
var useSSL = useSSLDefault
var port = portDefault

var unixDomainSocketPath: String? = nil

static let sslConfig: SSLService.Configuration = {
let sslConfigDir = URL(fileURLWithPath: #file).appendingPathComponent("../SSLConfig")

Expand Down Expand Up @@ -62,15 +64,19 @@ class KituraNetTest: XCTestCase {
func doTearDown() {
}

func startServer(_ delegate: ServerDelegate?, port: Int = portDefault, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault) throws -> HTTPServer {
func startServer(_ delegate: ServerDelegate?, unixDomainSocketPath: String? = nil, port: Int = portDefault, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault) throws -> HTTPServer {

let server = HTTP.createServer()
server.delegate = delegate
server.allowPortReuse = allowPortReuse
if useSSL {
server.sslConfig = KituraNetTest.sslConfig
}
try server.listen(on: port)
if let unixDomainSocketPath = unixDomainSocketPath {
try server.listen(unixDomainSocketPath: unixDomainSocketPath)
} else {
server.allowPortReuse = allowPortReuse
try server.listen(on: port)
}
return server
}

Expand All @@ -87,13 +93,21 @@ class KituraNetTest: XCTestCase {
return (server, serverPort)
}

func performServerTest(_ delegate: ServerDelegate?, port: Int = portDefault, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault,

func performServerTest(_ delegate: ServerDelegate?, unixDomainSocketPath: String? = nil, port: Int = portDefault, useSSL: Bool = useSSLDefault, allowPortReuse: Bool = portReuseDefault,
line: Int = #line, asyncTasks: (XCTestExpectation) -> Void...) {

do {
var server: HTTPServer
var ephemeralPort: Int = 0
self.useSSL = useSSL
let (server, ephemeralPort) = try startEphemeralServer(delegate, useSSL: useSSL, allowPortReuse: allowPortReuse)
self.port = ephemeralPort
if let unixDomainSocketPath = unixDomainSocketPath {
server = try startServer(delegate, unixDomainSocketPath: unixDomainSocketPath, useSSL: useSSL, allowPortReuse: allowPortReuse)
self.unixDomainSocketPath = unixDomainSocketPath
} else {
(server, ephemeralPort) = try startEphemeralServer(delegate, useSSL: useSSL, allowPortReuse: allowPortReuse)
self.port = ephemeralPort
}
defer {
server.stop()
}
Expand Down Expand Up @@ -145,7 +159,7 @@ class KituraNetTest: XCTestCase {
}
}*/

func performRequest(_ method: String, path: String, hostname: String = "localhost", close: Bool=true, callback: @escaping ClientRequest.Callback,
func performRequest(_ method: String, path: String, unixDomainSocketPath: String? = nil, hostname: String = "localhost", close: Bool=true, callback: @escaping ClientRequest.Callback,
headers: [String: String]? = nil, requestModifier: ((ClientRequest) -> Void)? = nil) {

var allHeaders = [String: String]()
Expand All @@ -163,7 +177,7 @@ class KituraNetTest: XCTestCase {
options.append(.disableSSLVerification)
}

let req = HTTP.request(options, callback: callback)
let req = HTTP.request(options, unixDomainSocketPath: unixDomainSocketPath, callback: callback)
if let requestModifier = requestModifier {
requestModifier(req)
}
Expand Down
Loading