Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve InetSocketAddress and socket binding APIs. #322

Merged
merged 7 commits into from
Mar 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/src/main/scala/StreamsBasedServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import zio.clock.Clock
import zio.console.Console
import zio.duration._
import zio.nio.channels._
import zio.nio.core.SocketAddress
import zio.nio.core.InetSocketAddress
import zio.stream._

object StreamsBasedServer extends App {
Expand All @@ -16,7 +16,7 @@ object StreamsBasedServer extends App {
AsynchronousServerSocketChannel()
.use(socket =>
for {
_ <- SocketAddress.inetSocketAddress("localhost", port) >>= socket.bind
_ <- InetSocketAddress.hostName("localhost", port).flatMap(socket.bindTo(_))
_ <- ZStream
.repeatEffect(socket.accept.preallocate)
.map(_.withEarlyRelease)
Expand Down
4 changes: 2 additions & 2 deletions nio-core/src/main/scala/zio/nio/core/InetAddress.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ final class InetAddress private[nio] (private[nio] val jInetAddress: JInetAddres
IO.effect(jInetAddress.isReachable(networkInterface.jNetworkInterface, ttl, timeout))
.refineToOrDie[IOException]

def hostname: String = jInetAddress.getHostName
def hostName: String = jInetAddress.getHostName

def canonicalHostName: String = jInetAddress.getCanonicalHostName

def address: Array[Byte] = jInetAddress.getAddress
def address: Chunk[Byte] = Chunk.fromArray(jInetAddress.getAddress)

override def hashCode(): Int = jInetAddress.hashCode()

Expand Down
110 changes: 77 additions & 33 deletions nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package zio.nio.core

import java.net.{ InetSocketAddress => JInetSocketAddress, SocketAddress => JSocketAddress }
import java.net.{ UnknownHostException, InetSocketAddress => JInetSocketAddress, SocketAddress => JSocketAddress }

import zio.UIO
import zio.{ IO, UIO }

/**
* Representation of a socket address without a specific protocol.
*
* The concrete subclass [[InetSocketAddress]] is used in practice.
*/
sealed class SocketAddress private[nio] (private[nio] val jSocketAddress: JSocketAddress) {
sealed class SocketAddress protected (private[nio] val jSocketAddress: JSocketAddress) {

final override def equals(obj: Any): Boolean =
obj match {
Expand All @@ -22,14 +22,29 @@ sealed class SocketAddress private[nio] (private[nio] val jSocketAddress: JSocke
final override def toString: String = jSocketAddress.toString
}

object SocketAddress {

def fromJava(jSocketAddress: JSocketAddress): SocketAddress =
jSocketAddress match {
case inet: JInetSocketAddress =>
new InetSocketAddress(inet)
case other =>
new SocketAddress(other)
}

}

/**
* Representation of an IP Socket Address (IP address + port number).
*
* It can also be a pair (hostname + port number), in which case an attempt
* will be made to resolve the hostname.
* If resolution fails then the address is said to be unresolved but can still
* be used on some circumstances like connecting through a proxy.
* It provides an immutable object used by sockets for binding, connecting,
* However, note that network channels generally do ''not'' accept unresolved
* socket addresses.
*
* This class provides an immutable object used by sockets for binding, connecting,
* or as returned values.
*
* The wildcard is a special local IP address. It usually means "any" and can
Expand Down Expand Up @@ -75,59 +90,88 @@ final class InetSocketAddress private[nio] (private val jInetSocketAddress: JIne

}

object SocketAddress {

private[nio] def fromJava(jSocketAddress: JSocketAddress) =
jSocketAddress match {
case inet: JInetSocketAddress =>
new InetSocketAddress(inet)
case other =>
new SocketAddress(other)
}
object InetSocketAddress {

/**
* Creates a socket address where the IP address is the wildcard address and the port number a specified value.
*
* The socket address will be ''resolved''.
*/
def inetSocketAddress(port: Int): UIO[InetSocketAddress] = InetSocketAddress(port)
def wildCard(port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(port)))

/**
* Creates a socket address from an IP address and a port number.
* Creates a socket address where the IP address is the wildcard address and
* the port is ephemeral.
*
* The socket address will be ''resolved''.
*/
def inetSocketAddress(hostname: String, port: Int): UIO[InetSocketAddress] = InetSocketAddress(hostname, port)
def wildCardEphemeral: UIO[InetSocketAddress] = wildCard(0)

/**
* Creates a socket address from a hostname and a port number.
*
* An attempt will be made to resolve the hostname into an `InetAddress`.
* If that attempt fails, the socket address will be flagged as ''unresolved''.
* This method will attempt to resolve the hostname; if this fails, the returned
* socket address will be ''unresolved''.
*/
def inetSocketAddress(address: InetAddress, port: Int): UIO[InetSocketAddress] = InetSocketAddress(address, port)
def hostName(hostName: String, port: Int): UIO[InetSocketAddress] =
UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(hostName, port)))

/**
* Creates an unresolved socket address from a hostname and a port number.
* Creates a resolved socket address from a hostname and port number.
*
* No attempt will be made to resolve the hostname into an `InetAddress`.
* The socket address will be flagged as ''unresolved''.
* If the hostname cannot be resolved, fails with `UnknownHostException`.
*/
def hostNameResolved(hostName: String, port: Int): IO[UnknownHostException, InetSocketAddress] =
InetAddress.byName(hostName).flatMap(inetAddress(_, port))

/**
* Creates a socket address from a hostname, with an ephemeral port.
*
* This method will attempt to resolve the hostname; if this fails, the returned
* socket address will be ''unresolved''.
*/
def unresolvedInetSocketAddress(hostname: String, port: Int): UIO[InetSocketAddress] =
InetSocketAddress.createUnresolved(hostname, port)
def hostNameEphemeral(hostName: String): UIO[InetSocketAddress] = this.hostName(hostName, 0)

private object InetSocketAddress {
/**
* Creates a resolved socket address from a hostname, with an ephemeral port.
*
* If the hostname cannot be resolved, fails with `UnknownHostException`.
*/
def hostNameEphemeralResolved(hostName: String): IO[UnknownHostException, InetSocketAddress] =
InetAddress.byName(hostName).flatMap(inetAddressEphemeral)

def apply(port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(port)))
/**
* Creates a socket address from an IP address and a port number.
*/
def inetAddress(address: InetAddress, port: Int): UIO[InetSocketAddress] =
UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(address.jInetAddress, port)))

def apply(host: String, port: Int): UIO[InetSocketAddress] =
UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(host, port)))
/**
* Creates a socket address from an IP address, with an ephemeral port.
*/
def inetAddressEphemeral(address: InetAddress): UIO[InetSocketAddress] = inetAddress(address, 0)

def apply(addr: InetAddress, port: Int): UIO[InetSocketAddress] =
UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(addr.jInetAddress, port)))
/**
* Creates a socket address for localhost using the specified port.
*/
def localHost(port: Int): IO[UnknownHostException, InetSocketAddress] =
InetAddress.localHost.flatMap(inetAddress(_, port))

def createUnresolved(host: String, port: Int): UIO[InetSocketAddress] =
UIO.effectTotal(new InetSocketAddress(JInetSocketAddress.createUnresolved(host, port)))
/**
* Creates an unresolved socket address from a hostname and a port number.
*
* No attempt will be made to resolve the hostname into an `InetAddress`.
* The socket address will be flagged as ''unresolved''.
*/
def unresolvedHostName(hostName: String, port: Int): UIO[InetSocketAddress] =
UIO.effectTotal(new InetSocketAddress(JInetSocketAddress.createUnresolved(hostName, port)))

/**
* Creates an unresolved socket address from a hostname using an ephemeral port.
*
* No attempt will be made to resolve the hostname into an `InetAddress`.
* The socket address will be flagged as ''unresolved''.
*/
def unresolvedHostNameEphemeral(hostName: String): UIO[InetSocketAddress] = unresolvedHostName(hostName, 0)

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,16 @@ abstract class AsynchronousByteChannel private[channels] (protected val channel:

final class AsynchronousServerSocketChannel(protected val channel: JAsynchronousServerSocketChannel) extends Channel {

/**
* Binds the channel's socket to a local address and configures the socket
* to listen for connections.
*/
def bind(address: SocketAddress): IO[IOException, Unit] =
IO.effect(channel.bind(address.jSocketAddress)).refineToOrDie[IOException].unit
def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog)

def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog)

/**
* Binds the channel's socket to a local address and configures the socket
* to listen for connections, up to backlog pending connection.
*/
def bind(address: SocketAddress, backlog: Int): IO[IOException, Unit] =
IO.effect(channel.bind(address.jSocketAddress, backlog)).refineToOrDie[IOException].unit
def bind(address: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] =
IO.effect(channel.bind(address.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit

def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] =
IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit
Expand Down Expand Up @@ -119,13 +116,17 @@ object AsynchronousServerSocketChannel {
.toNioManaged
}

class AsynchronousSocketChannel(override protected val channel: JAsynchronousSocketChannel)
final class AsynchronousSocketChannel(override protected val channel: JAsynchronousSocketChannel)
extends AsynchronousByteChannel(channel) {

final def bind(address: SocketAddress): IO[IOException, Unit] =
IO.effect(channel.bind(address.jSocketAddress)).refineToOrDie[IOException].unit
def bindTo(address: SocketAddress): IO[IOException, Unit] = bind(Some(address))

final def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] =
def bindAuto: IO[IOException, Unit] = bind(None)

def bind(address: Option[SocketAddress]): IO[IOException, Unit] =
IO.effect(channel.bind(address.map(_.jSocketAddress).orNull)).refineToOrDie[IOException].unit

def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] =
IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit

final def shutdownInput: IO[IOException, Unit] = IO.effect(channel.shutdownInput()).refineToOrDie[IOException].unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ final class DatagramChannel private[channels] (override protected[channels] val
with SelectableChannel
with ScatteringByteChannel {

def bindTo(local: SocketAddress): IO[IOException, Unit] = bind(Some(local))

def bindAuto: IO[IOException, Unit] = bind(None)

/**
* Binds this channel's underlying socket to the given local address. Passing `None` binds to an
* automatically assigned local address.
Expand Down Expand Up @@ -57,7 +61,7 @@ final class DatagramChannel private[channels] (override protected[channels] val
* @return the local address if the socket is bound, otherwise `None`
*/
def localAddress: IO[IOException, Option[SocketAddress]] =
IO.effect(channel.getLocalAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_)))
IO.effect(channel.getLocalAddress()).refineToOrDie[IOException].map(a => Option(a).map(SocketAddress.fromJava))

/**
* Receives a datagram via this channel into the given [[zio.nio.core.ByteBuffer]].
Expand All @@ -66,15 +70,17 @@ final class DatagramChannel private[channels] (override protected[channels] val
* @return the socket address of the datagram's source, if available.
*/
def receive(dst: ByteBuffer): IO[IOException, Option[SocketAddress]] =
IO.effect(channel.receive(dst.byteBuffer)).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_)))
IO.effect(channel.receive(dst.byteBuffer))
.refineToOrDie[IOException]
.map(a => Option(a).map(SocketAddress.fromJava))

/**
* Optionally returns the remote socket address that this channel's underlying socket is connected to.
*
* @return the remote address if the socket is connected, otherwise `None`
*/
def remoteAddress: IO[IOException, Option[SocketAddress]] =
IO.effect(channel.getRemoteAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_)))
IO.effect(channel.getRemoteAddress()).refineToOrDie[IOException].map(a => Option(a).map(SocketAddress.fromJava))

/**
* Sends a datagram via this channel to the given target [[zio.nio.core.SocketAddress]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,12 @@ final class SocketChannel(override protected[channels] val channel: JSocketChann
with GatheringByteChannel
with ScatteringByteChannel {

def bind(local: SocketAddress): IO[IOException, Unit] =
IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit
def bindTo(address: SocketAddress): IO[IOException, Unit] = bind(Some(address))

def bindAuto: IO[IOException, Unit] = bind(None)

def bind(local: Option[SocketAddress]): IO[IOException, Unit] =
IO.effect(channel.bind(local.map(_.jSocketAddress).orNull)).refineToOrDie[IOException].unit

def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] =
IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit
Expand Down Expand Up @@ -110,12 +114,12 @@ object SocketChannel {
}

final class ServerSocketChannel(override protected val channel: JServerSocketChannel) extends SelectableChannel {
def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog)

def bind(local: SocketAddress): IO[IOException, Unit] =
IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit
def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog)

def bind(local: SocketAddress, backlog: Int): IO[IOException, Unit] =
IO.effect(channel.bind(local.jSocketAddress, backlog)).refineToOrDie[IOException].unit
def bind(local: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] =
IO.effect(channel.bind(local.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit

def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] =
IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit
Expand All @@ -139,7 +143,7 @@ final class ServerSocketChannel(override protected val channel: JServerSocketCha
})

val localAddress: IO[IOException, SocketAddress] =
IO.effect(new SocketAddress(channel.getLocalAddress())).refineToOrDie[IOException]
IO.effect(SocketAddress.fromJava(channel.getLocalAddress())).refineToOrDie[IOException]
}

object ServerSocketChannel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ final class SelectionKey(private[nio] val selectionKey: jc.SelectionKey) {
* {{{
* for {
* _ <- selector.select
* selectedKeys <- selector.selectedKeys
* _ <- IO.foreach_(selectedKeys) { key =>
* _ <- selector.foreachSelectedKey { key =>
* key.matchChannel { readyOps => {
* case channel: ServerSocketChannel if readyOps(Operation.Accept) =>
* // use `channel` to accept connection
Expand All @@ -68,7 +67,7 @@ final class SelectionKey(private[nio] val selectionKey: jc.SelectionKey) {
* IO.when(readyOps(Operation.Write)) {
* // use `channel` to write
* }
* } } *> selector.removeKey(key)
* } }
* }
* } yield ()
* }}}
Expand Down Expand Up @@ -126,6 +125,14 @@ final class SelectionKey(private[nio] val selectionKey: jc.SelectionKey) {

final def attachment: UIO[Option[AnyRef]] = IO.effectTotal(selectionKey.attachment()).map(Option(_))

override def toString: String = selectionKey.toString()
override def toString: String = selectionKey.toString

override def hashCode(): Int = selectionKey.hashCode()

override def equals(obj: Any): Boolean =
obj match {
case other: SelectionKey => selectionKey.equals(other.selectionKey)
case _ => false
}

}
Loading