diff --git a/examples/src/main/scala/StreamsBasedServer.scala b/examples/src/main/scala/StreamsBasedServer.scala index fae9b9ab..20413c0d 100644 --- a/examples/src/main/scala/StreamsBasedServer.scala +++ b/examples/src/main/scala/StreamsBasedServer.scala @@ -3,7 +3,7 @@ import zio.clock.Clock import zio.console.Console import zio.duration._ import zio.nio.channels._ -import zio.nio.core.SocketAddress +import zio.nio.core.InetSocketAddress import zio.stream._ object StreamsBasedServer extends App { @@ -16,7 +16,7 @@ object StreamsBasedServer extends App { AsynchronousServerSocketChannel() .use(socket => for { - _ <- SocketAddress.inetSocketAddress("localhost", port) >>= socket.bind + _ <- InetSocketAddress.hostName("localhost", port).flatMap(socket.bindTo(_)) _ <- ZStream .repeatEffect(socket.accept.preallocate) .map(_.withEarlyRelease) diff --git a/nio-core/src/main/scala/zio/nio/core/InetAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetAddress.scala index 3a901cf1..02cb3483 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetAddress.scala @@ -42,11 +42,11 @@ final class InetAddress private[nio] (private[nio] val jInetAddress: JInetAddres IO.effect(jInetAddress.isReachable(networkInterface.jNetworkInterface, ttl, timeout)) .refineToOrDie[IOException] - def hostname: String = jInetAddress.getHostName + def hostName: String = jInetAddress.getHostName def canonicalHostName: String = jInetAddress.getCanonicalHostName - def address: Array[Byte] = jInetAddress.getAddress + def address: Chunk[Byte] = Chunk.fromArray(jInetAddress.getAddress) override def hashCode(): Int = jInetAddress.hashCode() diff --git a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala index 7fd6ab26..e4a3e58b 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala @@ -1,15 +1,15 @@ package zio.nio.core -import java.net.{ InetSocketAddress => JInetSocketAddress, SocketAddress => JSocketAddress } +import java.net.{ UnknownHostException, InetSocketAddress => JInetSocketAddress, SocketAddress => JSocketAddress } -import zio.UIO +import zio.{ IO, UIO } /** * Representation of a socket address without a specific protocol. * * The concrete subclass [[InetSocketAddress]] is used in practice. */ -sealed class SocketAddress private[nio] (private[nio] val jSocketAddress: JSocketAddress) { +sealed class SocketAddress protected (private[nio] val jSocketAddress: JSocketAddress) { final override def equals(obj: Any): Boolean = obj match { @@ -22,6 +22,18 @@ sealed class SocketAddress private[nio] (private[nio] val jSocketAddress: JSocke final override def toString: String = jSocketAddress.toString } +object SocketAddress { + + def fromJava(jSocketAddress: JSocketAddress): SocketAddress = + jSocketAddress match { + case inet: JInetSocketAddress => + new InetSocketAddress(inet) + case other => + new SocketAddress(other) + } + +} + /** * Representation of an IP Socket Address (IP address + port number). * @@ -29,7 +41,10 @@ sealed class SocketAddress private[nio] (private[nio] val jSocketAddress: JSocke * will be made to resolve the hostname. * If resolution fails then the address is said to be unresolved but can still * be used on some circumstances like connecting through a proxy. - * It provides an immutable object used by sockets for binding, connecting, + * However, note that network channels generally do ''not'' accept unresolved + * socket addresses. + * + * This class provides an immutable object used by sockets for binding, connecting, * or as returned values. * * The wildcard is a special local IP address. It usually means "any" and can @@ -75,59 +90,88 @@ final class InetSocketAddress private[nio] (private val jInetSocketAddress: JIne } -object SocketAddress { - - private[nio] def fromJava(jSocketAddress: JSocketAddress) = - jSocketAddress match { - case inet: JInetSocketAddress => - new InetSocketAddress(inet) - case other => - new SocketAddress(other) - } +object InetSocketAddress { /** * Creates a socket address where the IP address is the wildcard address and the port number a specified value. * * The socket address will be ''resolved''. */ - def inetSocketAddress(port: Int): UIO[InetSocketAddress] = InetSocketAddress(port) + def wildCard(port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(port))) /** - * Creates a socket address from an IP address and a port number. + * Creates a socket address where the IP address is the wildcard address and + * the port is ephemeral. * * The socket address will be ''resolved''. */ - def inetSocketAddress(hostname: String, port: Int): UIO[InetSocketAddress] = InetSocketAddress(hostname, port) + def wildCardEphemeral: UIO[InetSocketAddress] = wildCard(0) /** * Creates a socket address from a hostname and a port number. * - * An attempt will be made to resolve the hostname into an `InetAddress`. - * If that attempt fails, the socket address will be flagged as ''unresolved''. + * This method will attempt to resolve the hostname; if this fails, the returned + * socket address will be ''unresolved''. */ - def inetSocketAddress(address: InetAddress, port: Int): UIO[InetSocketAddress] = InetSocketAddress(address, port) + def hostName(hostName: String, port: Int): UIO[InetSocketAddress] = + UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(hostName, port))) /** - * Creates an unresolved socket address from a hostname and a port number. + * Creates a resolved socket address from a hostname and port number. * - * No attempt will be made to resolve the hostname into an `InetAddress`. - * The socket address will be flagged as ''unresolved''. + * If the hostname cannot be resolved, fails with `UnknownHostException`. + */ + def hostNameResolved(hostName: String, port: Int): IO[UnknownHostException, InetSocketAddress] = + InetAddress.byName(hostName).flatMap(inetAddress(_, port)) + + /** + * Creates a socket address from a hostname, with an ephemeral port. + * + * This method will attempt to resolve the hostname; if this fails, the returned + * socket address will be ''unresolved''. */ - def unresolvedInetSocketAddress(hostname: String, port: Int): UIO[InetSocketAddress] = - InetSocketAddress.createUnresolved(hostname, port) + def hostNameEphemeral(hostName: String): UIO[InetSocketAddress] = this.hostName(hostName, 0) - private object InetSocketAddress { + /** + * Creates a resolved socket address from a hostname, with an ephemeral port. + * + * If the hostname cannot be resolved, fails with `UnknownHostException`. + */ + def hostNameEphemeralResolved(hostName: String): IO[UnknownHostException, InetSocketAddress] = + InetAddress.byName(hostName).flatMap(inetAddressEphemeral) - def apply(port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(port))) + /** + * Creates a socket address from an IP address and a port number. + */ + def inetAddress(address: InetAddress, port: Int): UIO[InetSocketAddress] = + UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(address.jInetAddress, port))) - def apply(host: String, port: Int): UIO[InetSocketAddress] = - UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(host, port))) + /** + * Creates a socket address from an IP address, with an ephemeral port. + */ + def inetAddressEphemeral(address: InetAddress): UIO[InetSocketAddress] = inetAddress(address, 0) - def apply(addr: InetAddress, port: Int): UIO[InetSocketAddress] = - UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(addr.jInetAddress, port))) + /** + * Creates a socket address for localhost using the specified port. + */ + def localHost(port: Int): IO[UnknownHostException, InetSocketAddress] = + InetAddress.localHost.flatMap(inetAddress(_, port)) - def createUnresolved(host: String, port: Int): UIO[InetSocketAddress] = - UIO.effectTotal(new InetSocketAddress(JInetSocketAddress.createUnresolved(host, port))) + /** + * Creates an unresolved socket address from a hostname and a port number. + * + * No attempt will be made to resolve the hostname into an `InetAddress`. + * The socket address will be flagged as ''unresolved''. + */ + def unresolvedHostName(hostName: String, port: Int): UIO[InetSocketAddress] = + UIO.effectTotal(new InetSocketAddress(JInetSocketAddress.createUnresolved(hostName, port))) + + /** + * Creates an unresolved socket address from a hostname using an ephemeral port. + * + * No attempt will be made to resolve the hostname into an `InetAddress`. + * The socket address will be flagged as ''unresolved''. + */ + def unresolvedHostNameEphemeral(hostName: String): UIO[InetSocketAddress] = unresolvedHostName(hostName, 0) - } } diff --git a/nio-core/src/main/scala/zio/nio/core/channels/AsynchronousChannel.scala b/nio-core/src/main/scala/zio/nio/core/channels/AsynchronousChannel.scala index bc605ec3..53ba58b7 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/AsynchronousChannel.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/AsynchronousChannel.scala @@ -65,19 +65,16 @@ abstract class AsynchronousByteChannel private[channels] (protected val channel: final class AsynchronousServerSocketChannel(protected val channel: JAsynchronousServerSocketChannel) extends Channel { - /** - * Binds the channel's socket to a local address and configures the socket - * to listen for connections. - */ - def bind(address: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress)).refineToOrDie[IOException].unit + def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog) + + def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog) /** * Binds the channel's socket to a local address and configures the socket * to listen for connections, up to backlog pending connection. */ - def bind(address: SocketAddress, backlog: Int): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress, backlog)).refineToOrDie[IOException].unit + def bind(address: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] = + IO.effect(channel.bind(address.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit @@ -119,13 +116,17 @@ object AsynchronousServerSocketChannel { .toNioManaged } -class AsynchronousSocketChannel(override protected val channel: JAsynchronousSocketChannel) +final class AsynchronousSocketChannel(override protected val channel: JAsynchronousSocketChannel) extends AsynchronousByteChannel(channel) { - final def bind(address: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress)).refineToOrDie[IOException].unit + def bindTo(address: SocketAddress): IO[IOException, Unit] = bind(Some(address)) - final def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = + def bindAuto: IO[IOException, Unit] = bind(None) + + def bind(address: Option[SocketAddress]): IO[IOException, Unit] = + IO.effect(channel.bind(address.map(_.jSocketAddress).orNull)).refineToOrDie[IOException].unit + + def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit final def shutdownInput: IO[IOException, Unit] = IO.effect(channel.shutdownInput()).refineToOrDie[IOException].unit diff --git a/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala b/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala index f8599ee0..5d5c20f0 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala @@ -15,6 +15,10 @@ final class DatagramChannel private[channels] (override protected[channels] val with SelectableChannel with ScatteringByteChannel { + def bindTo(local: SocketAddress): IO[IOException, Unit] = bind(Some(local)) + + def bindAuto: IO[IOException, Unit] = bind(None) + /** * Binds this channel's underlying socket to the given local address. Passing `None` binds to an * automatically assigned local address. @@ -57,7 +61,7 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the local address if the socket is bound, otherwise `None` */ def localAddress: IO[IOException, Option[SocketAddress]] = - IO.effect(channel.getLocalAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(channel.getLocalAddress()).refineToOrDie[IOException].map(a => Option(a).map(SocketAddress.fromJava)) /** * Receives a datagram via this channel into the given [[zio.nio.core.ByteBuffer]]. @@ -66,7 +70,9 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the socket address of the datagram's source, if available. */ def receive(dst: ByteBuffer): IO[IOException, Option[SocketAddress]] = - IO.effect(channel.receive(dst.byteBuffer)).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(channel.receive(dst.byteBuffer)) + .refineToOrDie[IOException] + .map(a => Option(a).map(SocketAddress.fromJava)) /** * Optionally returns the remote socket address that this channel's underlying socket is connected to. @@ -74,7 +80,7 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the remote address if the socket is connected, otherwise `None` */ def remoteAddress: IO[IOException, Option[SocketAddress]] = - IO.effect(channel.getRemoteAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(channel.getRemoteAddress()).refineToOrDie[IOException].map(a => Option(a).map(SocketAddress.fromJava)) /** * Sends a datagram via this channel to the given target [[zio.nio.core.SocketAddress]]. diff --git a/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala b/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala index 9cc8b9b9..7b9a51ee 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala @@ -63,8 +63,12 @@ final class SocketChannel(override protected[channels] val channel: JSocketChann with GatheringByteChannel with ScatteringByteChannel { - def bind(local: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit + def bindTo(address: SocketAddress): IO[IOException, Unit] = bind(Some(address)) + + def bindAuto: IO[IOException, Unit] = bind(None) + + def bind(local: Option[SocketAddress]): IO[IOException, Unit] = + IO.effect(channel.bind(local.map(_.jSocketAddress).orNull)).refineToOrDie[IOException].unit def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit @@ -110,12 +114,12 @@ object SocketChannel { } final class ServerSocketChannel(override protected val channel: JServerSocketChannel) extends SelectableChannel { + def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog) - def bind(local: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit + def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog) - def bind(local: SocketAddress, backlog: Int): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress, backlog)).refineToOrDie[IOException].unit + def bind(local: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] = + IO.effect(channel.bind(local.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit @@ -139,7 +143,7 @@ final class ServerSocketChannel(override protected val channel: JServerSocketCha }) val localAddress: IO[IOException, SocketAddress] = - IO.effect(new SocketAddress(channel.getLocalAddress())).refineToOrDie[IOException] + IO.effect(SocketAddress.fromJava(channel.getLocalAddress())).refineToOrDie[IOException] } object ServerSocketChannel { diff --git a/nio-core/src/main/scala/zio/nio/core/channels/SelectionKey.scala b/nio-core/src/main/scala/zio/nio/core/channels/SelectionKey.scala index 62ddb1d7..67248a3d 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/SelectionKey.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/SelectionKey.scala @@ -56,8 +56,7 @@ final class SelectionKey(private[nio] val selectionKey: jc.SelectionKey) { * {{{ * for { * _ <- selector.select - * selectedKeys <- selector.selectedKeys - * _ <- IO.foreach_(selectedKeys) { key => + * _ <- selector.foreachSelectedKey { key => * key.matchChannel { readyOps => { * case channel: ServerSocketChannel if readyOps(Operation.Accept) => * // use `channel` to accept connection @@ -68,7 +67,7 @@ final class SelectionKey(private[nio] val selectionKey: jc.SelectionKey) { * IO.when(readyOps(Operation.Write)) { * // use `channel` to write * } - * } } *> selector.removeKey(key) + * } } * } * } yield () * }}} @@ -126,6 +125,14 @@ final class SelectionKey(private[nio] val selectionKey: jc.SelectionKey) { final def attachment: UIO[Option[AnyRef]] = IO.effectTotal(selectionKey.attachment()).map(Option(_)) - override def toString: String = selectionKey.toString() + override def toString: String = selectionKey.toString + + override def hashCode(): Int = selectionKey.hashCode() + + override def equals(obj: Any): Boolean = + obj match { + case other: SelectionKey => selectionKey.equals(other.selectionKey) + case _ => false + } } diff --git a/nio-core/src/test/scala/zio/nio/core/channels/ChannelSpec.scala b/nio-core/src/test/scala/zio/nio/core/channels/ChannelSpec.scala index 2f37951a..6ae0c962 100644 --- a/nio-core/src/test/scala/zio/nio/core/channels/ChannelSpec.scala +++ b/nio-core/src/test/scala/zio/nio/core/channels/ChannelSpec.scala @@ -1,11 +1,11 @@ package zio.nio.core.channels -import java.io.{ EOFException, FileNotFoundException, IOException } - import zio.nio.core.{ BaseSpec, Buffer, EffectOps, SocketAddress } -import zio.{ IO, _ } -import zio.test._ import zio.test.Assertion._ +import zio.test._ +import zio.{ IO, _ } + +import java.io.{ EOFException, FileNotFoundException, IOException } object ChannelSpec extends BaseSpec { @@ -14,23 +14,22 @@ object ChannelSpec extends BaseSpec { testM("read/write") { def echoServer(started: Promise[Nothing, SocketAddress]): IO[Exception, Unit] = for { - address <- SocketAddress.inetSocketAddress(0) - sink <- Buffer.byte(3) - _ <- AsynchronousServerSocketChannel - .open() - .use { server => - for { - _ <- server.bind(address) - addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) - _ <- started.succeed(addr) - _ <- server.accept.use { worker => - worker.read(sink) *> - sink.flip *> - worker.write(sink) - } - } yield () - } - .fork + sink <- Buffer.byte(3) + _ <- AsynchronousServerSocketChannel + .open() + .use { server => + for { + _ <- server.bindAuto() + addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) + _ <- started.succeed(addr) + _ <- server.accept.use { worker => + worker.read(sink) *> + sink.flip *> + worker.write(sink) + } + } yield () + } + .fork } yield () def echoClient(address: SocketAddress): IO[Exception, Boolean] = @@ -59,22 +58,21 @@ object ChannelSpec extends BaseSpec { testM("read should fail when connection close") { def server(started: Promise[Nothing, SocketAddress]): IO[Exception, Fiber[Exception, Boolean]] = for { - address <- SocketAddress.inetSocketAddress(0) - result <- AsynchronousServerSocketChannel - .open() - .use { server => - for { - _ <- server.bind(address) - addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) - _ <- started.succeed(addr) - result <- server.accept - .use(worker => worker.readChunk(3) *> worker.readChunk(3) *> ZIO.succeed(false)) - .catchSome { case _: java.io.EOFException => - ZIO.succeed(true) - } - } yield result - } - .fork + result <- AsynchronousServerSocketChannel + .open() + .use { server => + for { + _ <- server.bindAuto() + addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) + _ <- started.succeed(addr) + result <- server.accept + .use(worker => worker.readChunk(3) *> worker.readChunk(3) *> ZIO.succeed(false)) + .catchSome { case _: java.io.EOFException => + ZIO.succeed(true) + } + } yield result + } + .fork } yield result def client(address: SocketAddress): IO[Exception, Unit] = @@ -102,25 +100,23 @@ object ChannelSpec extends BaseSpec { } def server( - address: SocketAddress, started: Promise[Nothing, SocketAddress] ): Managed[IOException, Fiber[Exception, Unit]] = for { server <- AsynchronousServerSocketChannel.open() - _ <- server.bind(address).toManaged_ + _ <- server.bindAuto().toManaged_ addr <- server.localAddress.someOrElseM(IO.die(new NoSuchElementException)).toManaged_ _ <- started.succeed(addr).toManaged_ worker <- server.accept.unit.fork } yield worker for { - address <- SocketAddress.inetSocketAddress(0) serverStarted1 <- Promise.make[Nothing, SocketAddress] - _ <- server(address, serverStarted1).use { s1 => + _ <- server(serverStarted1).use { s1 => serverStarted1.await.flatMap(client).zipRight(s1.join) } serverStarted2 <- Promise.make[Nothing, SocketAddress] - _ <- server(address, serverStarted2).use { s2 => + _ <- server(serverStarted2).use { s2 => serverStarted2.await.flatMap(client).zipRight(s2.join) } } yield assertCompletes diff --git a/nio-core/src/test/scala/zio/nio/core/channels/DatagramChannelSpec.scala b/nio-core/src/test/scala/zio/nio/core/channels/DatagramChannelSpec.scala index 973742c3..df39257f 100644 --- a/nio-core/src/test/scala/zio/nio/core/channels/DatagramChannelSpec.scala +++ b/nio-core/src/test/scala/zio/nio/core/channels/DatagramChannelSpec.scala @@ -1,11 +1,11 @@ package zio.nio.core.channels -import java.io.IOException - +import zio.{ IO, _ } import zio.nio.core._ import zio.test.Assertion._ import zio.test._ -import zio.{ IO, _ } + +import java.io.IOException object DatagramChannelSpec extends BaseSpec { @@ -14,19 +14,18 @@ object DatagramChannelSpec extends BaseSpec { testM("read/write") { def echoServer(started: Promise[Nothing, SocketAddress]): IO[IOException, Unit] = for { - address <- SocketAddress.inetSocketAddress(0) - sink <- Buffer.byte(3) - _ <- DatagramChannel.open.use { server => - for { - _ <- server.bind(Some(address)) - addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) - _ <- started.succeed(addr) - retAddress <- server.receive(sink) - addr <- IO.fromOption(retAddress) - _ <- sink.flip - _ <- server.send(sink, addr) - } yield () - }.fork + sink <- Buffer.byte(3) + _ <- DatagramChannel.open.use { server => + for { + _ <- server.bindAuto + addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) + _ <- started.succeed(addr) + retAddress <- server.receive(sink) + addr <- IO.fromOption(retAddress) + _ <- sink.flip + _ <- server.send(sink, addr) + } yield () + }.fork } yield () def echoClient(address: SocketAddress): IO[IOException, Boolean] = @@ -56,28 +55,27 @@ object DatagramChannelSpec extends BaseSpec { def client(address: SocketAddress): IO[IOException, Unit] = DatagramChannel.open.use(_.connect(address).unit) def server( - address: SocketAddress, + address: Option[SocketAddress], started: Promise[Nothing, SocketAddress] ): IO[Nothing, Fiber[IOException, Unit]] = for { worker <- DatagramChannel.open.use { server => for { - _ <- server.bind(Some(address)) - addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) + _ <- server.bind(address) + addr <- server.localAddress.someOrElseM(ZIO.dieMessage("Local address must be bound")) _ <- started.succeed(addr) } yield () }.fork } yield worker for { - address <- SocketAddress.inetSocketAddress(0) serverStarted <- Promise.make[Nothing, SocketAddress] - s1 <- server(address, serverStarted) + s1 <- server(None, serverStarted) addr <- serverStarted.await _ <- client(addr) _ <- s1.join serverStarted2 <- Promise.make[Nothing, SocketAddress] - s2 <- server(addr, serverStarted2) + s2 <- server(Some(addr), serverStarted2) _ <- serverStarted2.await _ <- client(addr) _ <- s2.join diff --git a/nio-core/src/test/scala/zio/nio/core/channels/SelectorSpec.scala b/nio-core/src/test/scala/zio/nio/core/channels/SelectorSpec.scala index d25a28d3..249eb9b0 100644 --- a/nio-core/src/test/scala/zio/nio/core/channels/SelectorSpec.scala +++ b/nio-core/src/test/scala/zio/nio/core/channels/SelectorSpec.scala @@ -65,13 +65,12 @@ object SelectorSpec extends BaseSpec { } yield () for { - address <- SocketAddress.inetSocketAddress(0).toManaged_ scope <- Managed.scope selector <- Selector.open channel <- ServerSocketChannel.open _ <- Managed.fromEffect { for { - _ <- channel.bind(address) + _ <- channel.bindAuto() _ <- channel.configureBlocking(false) _ <- channel.register(selector, Operation.Accept) buffer <- Buffer.byte(256) diff --git a/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala b/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala index e8abbcf1..41d2faaf 100644 --- a/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala +++ b/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala @@ -67,17 +67,22 @@ class AsynchronousServerSocketChannel(private val channel: JAsynchronousServerSo /** * Binds the channel's socket to a local address and configures the socket - * to listen for connections. + * to listen for connections, up to backlog pending connection. */ - final def bind(address: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress)).refineToOrDie[IOException].unit + final def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog) + + /** + * Binds the channel's socket to an automatically assigned local address and configures the socket + * to listen for connections, up to backlog pending connection. + */ + final def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog) /** * Binds the channel's socket to a local address and configures the socket * to listen for connections, up to backlog pending connection. */ - final def bind(address: SocketAddress, backlog: Int): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress, backlog)).refineToOrDie[IOException].unit + final def bind(address: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] = + IO.effect(channel.bind(address.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit final def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit @@ -101,7 +106,7 @@ class AsynchronousServerSocketChannel(private val channel: JAsynchronousServerSo */ final def localAddress: IO[IOException, Option[SocketAddress]] = IO.effect( - Option(channel.getLocalAddress).map(new SocketAddress(_)) + Option(channel.getLocalAddress).map(SocketAddress.fromJava) ).refineToOrDie[IOException] /** diff --git a/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala b/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala index cb06c3ba..26e684db 100644 --- a/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala +++ b/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala @@ -43,7 +43,7 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the local address if the socket is bound, otherwise `None` */ def localAddress: IO[IOException, Option[SocketAddress]] = - IO.effect(channel.getLocalAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(Option(channel.getLocalAddress()).map(SocketAddress.fromJava)).refineToOrDie[IOException] /** * Receives a datagram via this channel into the given [[zio.nio.core.ByteBuffer]]. @@ -52,7 +52,9 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the socket address of the datagram's source, if available. */ def receive(dst: ByteBuffer): IO[IOException, Option[SocketAddress]] = - IO.effect(channel.receive(dst.byteBuffer)).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(channel.receive(dst.byteBuffer)) + .refineToOrDie[IOException] + .map(a => Option(a).map(SocketAddress.fromJava)) /** * Optionally returns the remote socket address that this channel's underlying socket is connected to. @@ -60,7 +62,7 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the remote address if the socket is connected, otherwise `None` */ def remoteAddress: IO[IOException, Option[SocketAddress]] = - IO.effect(channel.getRemoteAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(Option(channel.getRemoteAddress()).map(SocketAddress.fromJava)).refineToOrDie[IOException] /** * Sends a datagram via this channel to the given target [[zio.nio.core.SocketAddress]]. @@ -109,6 +111,21 @@ object DatagramChannel { */ def bind(local: Option[SocketAddress]): Managed[IOException, DatagramChannel] = open.flatMap(_.bind(local).toManaged_) + /** + * Opens a datagram channel bound to the given local address as a managed resource. + * + * @param local the local address + * @return a datagram channel bound to the local address + */ + def bindTo(local: SocketAddress): Managed[IOException, DatagramChannel] = open.flatMap(_.bind(Some(local)).toManaged_) + + /** + * Opens a datagram channel bound to an automatically assigned local address as a managed resource. + * + * @return a datagram channel bound to the local address + */ + def bindAuto: Managed[IOException, DatagramChannel] = open.flatMap(_.bind(None).toManaged_) + /** * Opens a datagram channel connected to the given remote address as a managed resource. * diff --git a/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala b/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala index 612c7952..fdc5d44e 100644 --- a/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala +++ b/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala @@ -1,5 +1,11 @@ package zio.nio.channels +import zio.nio.channels.spi.SelectorProvider +import zio.nio.core.SocketAddress +import zio.nio.core.channels.SelectionKey +import zio.nio.core.channels.SelectionKey.Operation +import zio.{ IO, Managed, UIO } + import java.io.IOException import java.net.{ SocketOption, ServerSocket => JServerSocket, Socket => JSocket } import java.nio.channels.{ @@ -8,12 +14,6 @@ import java.nio.channels.{ SocketChannel => JSocketChannel } -import zio.{ IO, Managed, UIO } -import zio.nio.channels.spi.SelectorProvider -import zio.nio.core.{ SocketAddress } -import zio.nio.core.channels.SelectionKey -import zio.nio.core.channels.SelectionKey.Operation - trait SelectableChannel extends Channel { protected val channel: JSelectableChannel @@ -61,8 +61,12 @@ final class SocketChannel private[channels] (override protected[channels] val ch with GatheringByteChannel with ScatteringByteChannel { - final def bind(local: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit + final def bindTo(address: SocketAddress): IO[IOException, Unit] = bind(Some(address)) + + final def bindAuto: IO[IOException, Unit] = bind(None) + + final def bind(local: Option[SocketAddress]): IO[IOException, Unit] = + IO.effect(channel.bind(local.map(_.jSocketAddress).orNull)).refineToOrDie[IOException].unit final def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit @@ -122,14 +126,12 @@ object SocketChannel { final class ServerSocketChannel private (override protected val channel: JServerSocketChannel) extends SelectableChannel { - final def bind(local: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit + final def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog) - final def bind(local: SocketAddress, backlog: Int): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress, backlog)).refineToOrDie[IOException].unit + final def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog) - final def setOption[T](name: SocketOption[T], value: T): IO[Exception, Unit] = - IO.effect(channel.setOption(name, value)).refineToOrDie[Exception].unit + final def bind(local: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] = + IO.effect(channel.bind(local.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit final val socket: UIO[JServerSocket] = IO.effectTotal(channel.socket()) @@ -145,7 +147,7 @@ final class ServerSocketChannel private (override protected val channel: JServer IO.effect(Option(channel.accept()).map(new SocketChannel(_))).refineToOrDie[IOException] final val localAddress: IO[IOException, SocketAddress] = - IO.effect(new SocketAddress(channel.getLocalAddress())).refineToOrDie[IOException] + IO.effect(SocketAddress.fromJava(channel.getLocalAddress())).refineToOrDie[IOException] } object ServerSocketChannel { diff --git a/nio/src/main/scala/zio/nio/channels/Selector.scala b/nio/src/main/scala/zio/nio/channels/Selector.scala index aab6ac9f..eeda518c 100644 --- a/nio/src/main/scala/zio/nio/channels/Selector.scala +++ b/nio/src/main/scala/zio/nio/channels/Selector.scala @@ -1,16 +1,15 @@ package zio.nio.channels -import java.io.IOException -import java.nio.channels.{ ClosedSelectorException, Selector => JSelector, SelectionKey => JSelectionKey } - -import zio.{ IO, Managed, UIO, ZIO } import com.github.ghik.silencer.silent +import zio.blocking.Blocking import zio.duration.Duration import zio.nio.channels.spi.SelectorProvider import zio.nio.core.channels.SelectionKey -import zio.blocking -import zio.blocking.Blocking +import zio.{ IO, Managed, UIO, ZIO, blocking } +import java.io.IOException +import java.nio.channels.{ ClosedSelectorException, SelectionKey => JSelectionKey, Selector => JSelector } +import scala.collection.mutable import scala.jdk.CollectionConverters._ class Selector(private[nio] val selector: JSelector) { @@ -24,12 +23,40 @@ class Selector(private[nio] val selector: JSelector) { .map(_.asScala.toSet[JSelectionKey].map(new SelectionKey(_))) .refineToOrDie[ClosedSelectorException] + /** + * Returns this selector's selected-key set. + * + * Note that the returned set it mutable - keys may be removed from, but not directly added to it. + * Any attempt to add an object to the key set will cause an `UnsupportedOperationException` to be thrown. + * The selected-key set is not thread-safe. + */ @silent - final val selectedKeys: IO[ClosedSelectorException, Set[SelectionKey]] = + final val selectedKeys: IO[ClosedSelectorException, mutable.Set[SelectionKey]] = IO.effect(selector.selectedKeys()) - .map(_.asScala.toSet[JSelectionKey].map(new SelectionKey(_))) + .map(_.asScala.map(new SelectionKey(_))) .refineToOrDie[ClosedSelectorException] + /** + * Performs an effect with each selected key. + * + * If the result of effect is true, the key will be removed from the selected-key set, which is + * usually what you want after successfully handling a selected key. + */ + def foreachSelectedKey[R, E](f: SelectionKey => ZIO[R, E, Boolean]): ZIO[R, E, Unit] = + ZIO.effectTotal(selector.selectedKeys().iterator()).flatMap { iter => + def loop: ZIO[R, E, Unit] = + ZIO.effectSuspendTotal { + if (iter.hasNext) { + val key = iter.next() + f(new SelectionKey(key)).flatMap(ZIO.when(_)(ZIO.effectTotal(iter.remove()))) *> + loop + } else + ZIO.unit + } + + loop + } + final def removeKey(key: SelectionKey): IO[ClosedSelectorException, Unit] = IO.effect(selector.selectedKeys().remove(key.selectionKey)) .unit diff --git a/nio/src/test/scala/zio/nio/channels/ChannelSpec.scala b/nio/src/test/scala/zio/nio/channels/ChannelSpec.scala index 0653b543..8ff8d7e1 100644 --- a/nio/src/test/scala/zio/nio/channels/ChannelSpec.scala +++ b/nio/src/test/scala/zio/nio/channels/ChannelSpec.scala @@ -1,7 +1,7 @@ package zio.nio.channels import zio.nio.BaseSpec -import zio.nio.core.{ Buffer, SocketAddress } +import zio.nio.core.{ Buffer, InetSocketAddress, SocketAddress } import zio.{ IO, _ } import zio.test._ import zio.test.Assertion._ @@ -14,11 +14,11 @@ object ChannelSpec extends BaseSpec { testM("read/write") { def echoServer(started: Promise[Nothing, SocketAddress]): IO[Exception, Unit] = for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) sink <- Buffer.byte(3) _ <- AsynchronousServerSocketChannel().use { server => for { - _ <- server.bind(address) + _ <- server.bindTo(address) addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) _ <- started.succeed(addr) _ <- server.accept.use { worker => @@ -56,10 +56,10 @@ object ChannelSpec extends BaseSpec { testM("read should fail when connection close") { def server(started: Promise[Nothing, SocketAddress]): IO[Exception, Fiber[Exception, Boolean]] = for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) result <- AsynchronousServerSocketChannel().use { server => for { - _ <- server.bind(address) + _ <- server.bindTo(address) addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) _ <- started.succeed(addr) result <- server.accept @@ -100,7 +100,7 @@ object ChannelSpec extends BaseSpec { for { worker <- AsynchronousServerSocketChannel().use { server => for { - _ <- server.bind(address) + _ <- server.bindTo(address) addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) _ <- started.succeed(addr) worker <- server.accept.use(_ => ZIO.unit) @@ -109,7 +109,7 @@ object ChannelSpec extends BaseSpec { } yield worker for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) serverStarted <- Promise.make[Nothing, SocketAddress] s1 <- server(address, serverStarted) addr <- serverStarted.await @@ -125,8 +125,8 @@ object ChannelSpec extends BaseSpec { testM("accept should be interruptible") { AsynchronousServerSocketChannel().use { server => for { - addr <- SocketAddress.inetSocketAddress(0) - _ <- server.bind(addr) + addr <- InetSocketAddress.wildCard(0) + _ <- server.bindTo(addr) fiber <- server.accept.useNow.fork _ <- fiber.interrupt result <- fiber.await @@ -136,9 +136,9 @@ object ChannelSpec extends BaseSpec { // this would best be tagged as an regression test. for now just run manually when suspicious. testM("accept should not leak resources") { val server = for { - addr <- SocketAddress.inetSocketAddress(8081).toManaged_ + addr <- InetSocketAddress.wildCard(8081).toManaged_ channel <- AsynchronousServerSocketChannel() - _ <- channel.bind(addr).toManaged_ + _ <- channel.bindTo(addr).toManaged_ _ <- AsynchronousSocketChannel().use(channel => channel.connect(addr)).forever.toManaged_.fork } yield channel val interruptAccept = server.use( diff --git a/nio/src/test/scala/zio/nio/channels/DatagramChannelSpec.scala b/nio/src/test/scala/zio/nio/channels/DatagramChannelSpec.scala index cfab45b9..1baadd97 100644 --- a/nio/src/test/scala/zio/nio/channels/DatagramChannelSpec.scala +++ b/nio/src/test/scala/zio/nio/channels/DatagramChannelSpec.scala @@ -1,7 +1,7 @@ package zio.nio.channels import zio.nio._ -import zio.nio.core.{ Buffer, SocketAddress } +import zio.nio.core.{ Buffer, InetSocketAddress, SocketAddress } import zio.test.Assertion._ import zio.test._ import zio._ @@ -13,7 +13,7 @@ object DatagramChannelSpec extends BaseSpec { testM("read/write") { def echoServer(started: Promise[Nothing, SocketAddress]): IO[Exception, Unit] = for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) sink <- Buffer.byte(3) _ <- DatagramChannel .bind(Some(address)) @@ -72,7 +72,7 @@ object DatagramChannelSpec extends BaseSpec { } yield worker for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) serverStarted <- Promise.make[Nothing, SocketAddress] s1 <- server(address, serverStarted) addr <- serverStarted.await diff --git a/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala b/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala index c2dec9fa..ef38585d 100644 --- a/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala +++ b/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala @@ -1,11 +1,10 @@ package zio.nio.channels import java.nio.channels.CancelledKeyException - import zio._ import zio.blocking.Blocking import zio.clock.Clock -import zio.nio.core.{ Buffer, ByteBuffer, SocketAddress } +import zio.nio.core.{ Buffer, ByteBuffer, InetSocketAddress, SocketAddress } import zio.nio.core.channels.SelectionKey.Operation import zio.nio.BaseSpec import zio.test._ @@ -39,40 +38,39 @@ object SelectorSpec extends BaseSpec { buffer: ByteBuffer ): ZIO[Blocking, Exception, Unit] = for { - _ <- selector.select - selectedKeys <- selector.selectedKeys - _ <- IO.foreach(selectedKeys) { key => - IO.whenM(safeStatusCheck(key.isAcceptable)) { - for { - clientOpt <- channel.accept - client = clientOpt.get - _ <- client.configureBlocking(false) - _ <- client.register(selector, Operation.Read) - } yield () - } *> - IO.whenM(safeStatusCheck(key.isReadable)) { - IO.effectSuspendTotal { - val sClient = key.channel - val client = sClient.asInstanceOf[zio.nio.core.channels.SocketChannel] - for { - _ <- client.read(buffer) - _ <- buffer.flip - _ <- client.write(buffer) - _ <- buffer.clear - _ <- client.close - } yield () - } - } *> - selector.removeKey(key) - } + _ <- selector.select + _ <- selector.foreachSelectedKey { key => + IO.whenM(safeStatusCheck(key.isAcceptable)) { + for { + clientOpt <- channel.accept + client = clientOpt.get + _ <- client.configureBlocking(false) + _ <- client.register(selector, Operation.Read) + } yield () + } *> + IO.whenM(safeStatusCheck(key.isReadable)) { + IO.effectSuspendTotal { + val sClient = key.channel + val client = sClient.asInstanceOf[zio.nio.core.channels.SocketChannel] + for { + _ <- client.read(buffer) + _ <- buffer.flip + _ <- client.write(buffer) + _ <- buffer.clear + _ <- client.close + } yield () + } + }.as(true) + } + _ <- selector.selectedKeys.filterOrDieMessage(_.isEmpty)("Selected key set should be empty") } yield () for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) _ <- Selector.make.use { selector => ServerSocketChannel.open.use { channel => for { - _ <- channel.bind(address) + _ <- channel.bindTo(address) _ <- channel.configureBlocking(false) _ <- channel.register(selector, Operation.Accept) buffer <- Buffer.byte(256)