From b550a83e80dcb2eeb081c8dc59fef322f5135446 Mon Sep 17 00:00:00 2001 From: Lachlan O'Dea Date: Tue, 29 Sep 2020 11:13:06 +1000 Subject: [PATCH 1/7] Improve InetSocketAddress and socket binding APIs. Use meaningful names for the InetSocketAddress constructors. Make binding to an automatically assigned socket address explicit. --- .../src/main/scala/StreamsBasedServer.scala | 4 +- .../zio/nio/core/InetSocketAddress.scala | 54 ++++++------- .../core/channels/AsynchronousChannel.scala | 25 +++--- .../nio/core/channels/DatagramChannel.scala | 4 + .../nio/core/channels/SelectableChannel.scala | 16 ++-- .../zio/nio/core/channels/ChannelSpec.scala | 80 +++++++++---------- .../core/channels/DatagramChannelSpec.scala | 42 +++++----- .../zio/nio/core/channels/SelectorSpec.scala | 3 +- .../nio/channels/AsynchronousChannel.scala | 15 ++-- .../zio/nio/channels/DatagramChannel.scala | 15 ++++ .../zio/nio/channels/SelectableChannel.scala | 30 +++---- .../scala/zio/nio/channels/ChannelSpec.scala | 22 ++--- .../nio/channels/DatagramChannelSpec.scala | 6 +- .../scala/zio/nio/channels/SelectorSpec.scala | 7 +- 14 files changed, 173 insertions(+), 150 deletions(-) diff --git a/examples/src/main/scala/StreamsBasedServer.scala b/examples/src/main/scala/StreamsBasedServer.scala index fae9b9ab..a9282048 100644 --- a/examples/src/main/scala/StreamsBasedServer.scala +++ b/examples/src/main/scala/StreamsBasedServer.scala @@ -3,7 +3,7 @@ import zio.clock.Clock import zio.console.Console import zio.duration._ import zio.nio.channels._ -import zio.nio.core.SocketAddress +import zio.nio.core.InetSocketAddress import zio.stream._ object StreamsBasedServer extends App { @@ -16,7 +16,7 @@ object StreamsBasedServer extends App { AsynchronousServerSocketChannel() .use(socket => for { - _ <- SocketAddress.inetSocketAddress("localhost", port) >>= socket.bind + _ <- InetSocketAddress.hostname("localhost", port).flatMap(socket.bindTo(_)) _ <- ZStream .repeatEffect(socket.accept.preallocate) .map(_.withEarlyRelease) diff --git a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala index 7fd6ab26..5fc585de 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala @@ -22,6 +22,18 @@ sealed class SocketAddress private[nio] (private[nio] val jSocketAddress: JSocke final override def toString: String = jSocketAddress.toString } +object SocketAddress { + + def fromJava(jSocketAddress: JSocketAddress): SocketAddress = + jSocketAddress match { + case inet: JInetSocketAddress => + new InetSocketAddress(inet) + case other => + new SocketAddress(other) + } + +} + /** * Representation of an IP Socket Address (IP address + port number). * @@ -75,29 +87,26 @@ final class InetSocketAddress private[nio] (private val jInetSocketAddress: JIne } -object SocketAddress { - - private[nio] def fromJava(jSocketAddress: JSocketAddress) = - jSocketAddress match { - case inet: JInetSocketAddress => - new InetSocketAddress(inet) - case other => - new SocketAddress(other) - } +object InetSocketAddress { /** * Creates a socket address where the IP address is the wildcard address and the port number a specified value. * * The socket address will be ''resolved''. */ - def inetSocketAddress(port: Int): UIO[InetSocketAddress] = InetSocketAddress(port) + def wildCard(port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(port))) + + def wildCardEphemeral: UIO[InetSocketAddress] = wildCard(0) /** * Creates a socket address from an IP address and a port number. * * The socket address will be ''resolved''. */ - def inetSocketAddress(hostname: String, port: Int): UIO[InetSocketAddress] = InetSocketAddress(hostname, port) + def hostname(hostname: String, port: Int): UIO[InetSocketAddress] = + UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(hostname, port))) + + def hostnameEphemeral(hostname: String): UIO[InetSocketAddress] = this.hostname(hostname, 0) /** * Creates a socket address from a hostname and a port number. @@ -105,7 +114,10 @@ object SocketAddress { * An attempt will be made to resolve the hostname into an `InetAddress`. * If that attempt fails, the socket address will be flagged as ''unresolved''. */ - def inetSocketAddress(address: InetAddress, port: Int): UIO[InetSocketAddress] = InetSocketAddress(address, port) + def inetAddress(address: InetAddress, port: Int): UIO[InetSocketAddress] = + UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(address.jInetAddress, port))) + + def inetAddressEphemeral(address: InetAddress): UIO[InetSocketAddress] = inetAddress(address, 0) /** * Creates an unresolved socket address from a hostname and a port number. @@ -113,21 +125,9 @@ object SocketAddress { * No attempt will be made to resolve the hostname into an `InetAddress`. * The socket address will be flagged as ''unresolved''. */ - def unresolvedInetSocketAddress(hostname: String, port: Int): UIO[InetSocketAddress] = - InetSocketAddress.createUnresolved(hostname, port) - - private object InetSocketAddress { - - def apply(port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(port))) - - def apply(host: String, port: Int): UIO[InetSocketAddress] = - UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(host, port))) - - def apply(addr: InetAddress, port: Int): UIO[InetSocketAddress] = - UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(addr.jInetAddress, port))) + def unresolvedHostname(hostname: String, port: Int): UIO[InetSocketAddress] = + UIO.effectTotal(new InetSocketAddress(JInetSocketAddress.createUnresolved(hostname, port))) - def createUnresolved(host: String, port: Int): UIO[InetSocketAddress] = - UIO.effectTotal(new InetSocketAddress(JInetSocketAddress.createUnresolved(host, port))) + def unresolvedHostnameEphemeral(hostname: String): UIO[InetSocketAddress] = unresolvedHostname(hostname, 0) - } } diff --git a/nio-core/src/main/scala/zio/nio/core/channels/AsynchronousChannel.scala b/nio-core/src/main/scala/zio/nio/core/channels/AsynchronousChannel.scala index bc605ec3..53ba58b7 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/AsynchronousChannel.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/AsynchronousChannel.scala @@ -65,19 +65,16 @@ abstract class AsynchronousByteChannel private[channels] (protected val channel: final class AsynchronousServerSocketChannel(protected val channel: JAsynchronousServerSocketChannel) extends Channel { - /** - * Binds the channel's socket to a local address and configures the socket - * to listen for connections. - */ - def bind(address: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress)).refineToOrDie[IOException].unit + def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog) + + def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog) /** * Binds the channel's socket to a local address and configures the socket * to listen for connections, up to backlog pending connection. */ - def bind(address: SocketAddress, backlog: Int): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress, backlog)).refineToOrDie[IOException].unit + def bind(address: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] = + IO.effect(channel.bind(address.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit @@ -119,13 +116,17 @@ object AsynchronousServerSocketChannel { .toNioManaged } -class AsynchronousSocketChannel(override protected val channel: JAsynchronousSocketChannel) +final class AsynchronousSocketChannel(override protected val channel: JAsynchronousSocketChannel) extends AsynchronousByteChannel(channel) { - final def bind(address: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress)).refineToOrDie[IOException].unit + def bindTo(address: SocketAddress): IO[IOException, Unit] = bind(Some(address)) - final def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = + def bindAuto: IO[IOException, Unit] = bind(None) + + def bind(address: Option[SocketAddress]): IO[IOException, Unit] = + IO.effect(channel.bind(address.map(_.jSocketAddress).orNull)).refineToOrDie[IOException].unit + + def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit final def shutdownInput: IO[IOException, Unit] = IO.effect(channel.shutdownInput()).refineToOrDie[IOException].unit diff --git a/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala b/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala index f8599ee0..5a516699 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala @@ -15,6 +15,10 @@ final class DatagramChannel private[channels] (override protected[channels] val with SelectableChannel with ScatteringByteChannel { + def bindTo(local: SocketAddress): IO[IOException, Unit] = bind(Some(local)) + + def bindAuto: IO[IOException, Unit] = bind(None) + /** * Binds this channel's underlying socket to the given local address. Passing `None` binds to an * automatically assigned local address. diff --git a/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala b/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala index 9cc8b9b9..c694f168 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala @@ -63,8 +63,12 @@ final class SocketChannel(override protected[channels] val channel: JSocketChann with GatheringByteChannel with ScatteringByteChannel { - def bind(local: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit + def bindTo(address: SocketAddress): IO[IOException, Unit] = bind(Some(address)) + + def bindAuto: IO[IOException, Unit] = bind(None) + + def bind(local: Option[SocketAddress]): IO[IOException, Unit] = + IO.effect(channel.bind(local.map(_.jSocketAddress).orNull)).refineToOrDie[IOException].unit def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit @@ -110,12 +114,12 @@ object SocketChannel { } final class ServerSocketChannel(override protected val channel: JServerSocketChannel) extends SelectableChannel { + def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog) - def bind(local: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit + def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog) - def bind(local: SocketAddress, backlog: Int): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress, backlog)).refineToOrDie[IOException].unit + def bind(local: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] = + IO.effect(channel.bind(local.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit diff --git a/nio-core/src/test/scala/zio/nio/core/channels/ChannelSpec.scala b/nio-core/src/test/scala/zio/nio/core/channels/ChannelSpec.scala index 2f37951a..6ae0c962 100644 --- a/nio-core/src/test/scala/zio/nio/core/channels/ChannelSpec.scala +++ b/nio-core/src/test/scala/zio/nio/core/channels/ChannelSpec.scala @@ -1,11 +1,11 @@ package zio.nio.core.channels -import java.io.{ EOFException, FileNotFoundException, IOException } - import zio.nio.core.{ BaseSpec, Buffer, EffectOps, SocketAddress } -import zio.{ IO, _ } -import zio.test._ import zio.test.Assertion._ +import zio.test._ +import zio.{ IO, _ } + +import java.io.{ EOFException, FileNotFoundException, IOException } object ChannelSpec extends BaseSpec { @@ -14,23 +14,22 @@ object ChannelSpec extends BaseSpec { testM("read/write") { def echoServer(started: Promise[Nothing, SocketAddress]): IO[Exception, Unit] = for { - address <- SocketAddress.inetSocketAddress(0) - sink <- Buffer.byte(3) - _ <- AsynchronousServerSocketChannel - .open() - .use { server => - for { - _ <- server.bind(address) - addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) - _ <- started.succeed(addr) - _ <- server.accept.use { worker => - worker.read(sink) *> - sink.flip *> - worker.write(sink) - } - } yield () - } - .fork + sink <- Buffer.byte(3) + _ <- AsynchronousServerSocketChannel + .open() + .use { server => + for { + _ <- server.bindAuto() + addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) + _ <- started.succeed(addr) + _ <- server.accept.use { worker => + worker.read(sink) *> + sink.flip *> + worker.write(sink) + } + } yield () + } + .fork } yield () def echoClient(address: SocketAddress): IO[Exception, Boolean] = @@ -59,22 +58,21 @@ object ChannelSpec extends BaseSpec { testM("read should fail when connection close") { def server(started: Promise[Nothing, SocketAddress]): IO[Exception, Fiber[Exception, Boolean]] = for { - address <- SocketAddress.inetSocketAddress(0) - result <- AsynchronousServerSocketChannel - .open() - .use { server => - for { - _ <- server.bind(address) - addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) - _ <- started.succeed(addr) - result <- server.accept - .use(worker => worker.readChunk(3) *> worker.readChunk(3) *> ZIO.succeed(false)) - .catchSome { case _: java.io.EOFException => - ZIO.succeed(true) - } - } yield result - } - .fork + result <- AsynchronousServerSocketChannel + .open() + .use { server => + for { + _ <- server.bindAuto() + addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) + _ <- started.succeed(addr) + result <- server.accept + .use(worker => worker.readChunk(3) *> worker.readChunk(3) *> ZIO.succeed(false)) + .catchSome { case _: java.io.EOFException => + ZIO.succeed(true) + } + } yield result + } + .fork } yield result def client(address: SocketAddress): IO[Exception, Unit] = @@ -102,25 +100,23 @@ object ChannelSpec extends BaseSpec { } def server( - address: SocketAddress, started: Promise[Nothing, SocketAddress] ): Managed[IOException, Fiber[Exception, Unit]] = for { server <- AsynchronousServerSocketChannel.open() - _ <- server.bind(address).toManaged_ + _ <- server.bindAuto().toManaged_ addr <- server.localAddress.someOrElseM(IO.die(new NoSuchElementException)).toManaged_ _ <- started.succeed(addr).toManaged_ worker <- server.accept.unit.fork } yield worker for { - address <- SocketAddress.inetSocketAddress(0) serverStarted1 <- Promise.make[Nothing, SocketAddress] - _ <- server(address, serverStarted1).use { s1 => + _ <- server(serverStarted1).use { s1 => serverStarted1.await.flatMap(client).zipRight(s1.join) } serverStarted2 <- Promise.make[Nothing, SocketAddress] - _ <- server(address, serverStarted2).use { s2 => + _ <- server(serverStarted2).use { s2 => serverStarted2.await.flatMap(client).zipRight(s2.join) } } yield assertCompletes diff --git a/nio-core/src/test/scala/zio/nio/core/channels/DatagramChannelSpec.scala b/nio-core/src/test/scala/zio/nio/core/channels/DatagramChannelSpec.scala index 973742c3..df39257f 100644 --- a/nio-core/src/test/scala/zio/nio/core/channels/DatagramChannelSpec.scala +++ b/nio-core/src/test/scala/zio/nio/core/channels/DatagramChannelSpec.scala @@ -1,11 +1,11 @@ package zio.nio.core.channels -import java.io.IOException - +import zio.{ IO, _ } import zio.nio.core._ import zio.test.Assertion._ import zio.test._ -import zio.{ IO, _ } + +import java.io.IOException object DatagramChannelSpec extends BaseSpec { @@ -14,19 +14,18 @@ object DatagramChannelSpec extends BaseSpec { testM("read/write") { def echoServer(started: Promise[Nothing, SocketAddress]): IO[IOException, Unit] = for { - address <- SocketAddress.inetSocketAddress(0) - sink <- Buffer.byte(3) - _ <- DatagramChannel.open.use { server => - for { - _ <- server.bind(Some(address)) - addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) - _ <- started.succeed(addr) - retAddress <- server.receive(sink) - addr <- IO.fromOption(retAddress) - _ <- sink.flip - _ <- server.send(sink, addr) - } yield () - }.fork + sink <- Buffer.byte(3) + _ <- DatagramChannel.open.use { server => + for { + _ <- server.bindAuto + addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) + _ <- started.succeed(addr) + retAddress <- server.receive(sink) + addr <- IO.fromOption(retAddress) + _ <- sink.flip + _ <- server.send(sink, addr) + } yield () + }.fork } yield () def echoClient(address: SocketAddress): IO[IOException, Boolean] = @@ -56,28 +55,27 @@ object DatagramChannelSpec extends BaseSpec { def client(address: SocketAddress): IO[IOException, Unit] = DatagramChannel.open.use(_.connect(address).unit) def server( - address: SocketAddress, + address: Option[SocketAddress], started: Promise[Nothing, SocketAddress] ): IO[Nothing, Fiber[IOException, Unit]] = for { worker <- DatagramChannel.open.use { server => for { - _ <- server.bind(Some(address)) - addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) + _ <- server.bind(address) + addr <- server.localAddress.someOrElseM(ZIO.dieMessage("Local address must be bound")) _ <- started.succeed(addr) } yield () }.fork } yield worker for { - address <- SocketAddress.inetSocketAddress(0) serverStarted <- Promise.make[Nothing, SocketAddress] - s1 <- server(address, serverStarted) + s1 <- server(None, serverStarted) addr <- serverStarted.await _ <- client(addr) _ <- s1.join serverStarted2 <- Promise.make[Nothing, SocketAddress] - s2 <- server(addr, serverStarted2) + s2 <- server(Some(addr), serverStarted2) _ <- serverStarted2.await _ <- client(addr) _ <- s2.join diff --git a/nio-core/src/test/scala/zio/nio/core/channels/SelectorSpec.scala b/nio-core/src/test/scala/zio/nio/core/channels/SelectorSpec.scala index d25a28d3..249eb9b0 100644 --- a/nio-core/src/test/scala/zio/nio/core/channels/SelectorSpec.scala +++ b/nio-core/src/test/scala/zio/nio/core/channels/SelectorSpec.scala @@ -65,13 +65,12 @@ object SelectorSpec extends BaseSpec { } yield () for { - address <- SocketAddress.inetSocketAddress(0).toManaged_ scope <- Managed.scope selector <- Selector.open channel <- ServerSocketChannel.open _ <- Managed.fromEffect { for { - _ <- channel.bind(address) + _ <- channel.bindAuto() _ <- channel.configureBlocking(false) _ <- channel.register(selector, Operation.Accept) buffer <- Buffer.byte(256) diff --git a/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala b/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala index e8abbcf1..8fa283fa 100644 --- a/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala +++ b/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala @@ -67,17 +67,22 @@ class AsynchronousServerSocketChannel(private val channel: JAsynchronousServerSo /** * Binds the channel's socket to a local address and configures the socket - * to listen for connections. + * to listen for connections, up to backlog pending connection. */ - final def bind(address: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress)).refineToOrDie[IOException].unit + final def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog) + + /** + * Binds the channel's socket to an automatically assigned local address and configures the socket + * to listen for connections, up to backlog pending connection. + */ + final def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog) /** * Binds the channel's socket to a local address and configures the socket * to listen for connections, up to backlog pending connection. */ - final def bind(address: SocketAddress, backlog: Int): IO[IOException, Unit] = - IO.effect(channel.bind(address.jSocketAddress, backlog)).refineToOrDie[IOException].unit + final def bind(address: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] = + IO.effect(channel.bind(address.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit final def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit diff --git a/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala b/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala index cb06c3ba..2fc56851 100644 --- a/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala +++ b/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala @@ -109,6 +109,21 @@ object DatagramChannel { */ def bind(local: Option[SocketAddress]): Managed[IOException, DatagramChannel] = open.flatMap(_.bind(local).toManaged_) + /** + * Opens a datagram channel bound to the given local address as a managed resource. + * + * @param local the local address + * @return a datagram channel bound to the local address + */ + def bindTo(local: SocketAddress): Managed[IOException, DatagramChannel] = open.flatMap(_.bind(Some(local)).toManaged_) + + /** + * Opens a datagram channel bound to an automatically assigned local address as a managed resource. + * + * @return a datagram channel bound to the local address + */ + def bindAuto: Managed[IOException, DatagramChannel] = open.flatMap(_.bind(None).toManaged_) + /** * Opens a datagram channel connected to the given remote address as a managed resource. * diff --git a/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala b/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala index 612c7952..1cb81935 100644 --- a/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala +++ b/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala @@ -1,5 +1,11 @@ package zio.nio.channels +import zio.nio.channels.spi.SelectorProvider +import zio.nio.core.SocketAddress +import zio.nio.core.channels.SelectionKey +import zio.nio.core.channels.SelectionKey.Operation +import zio.{ IO, Managed, UIO } + import java.io.IOException import java.net.{ SocketOption, ServerSocket => JServerSocket, Socket => JSocket } import java.nio.channels.{ @@ -8,12 +14,6 @@ import java.nio.channels.{ SocketChannel => JSocketChannel } -import zio.{ IO, Managed, UIO } -import zio.nio.channels.spi.SelectorProvider -import zio.nio.core.{ SocketAddress } -import zio.nio.core.channels.SelectionKey -import zio.nio.core.channels.SelectionKey.Operation - trait SelectableChannel extends Channel { protected val channel: JSelectableChannel @@ -61,8 +61,12 @@ final class SocketChannel private[channels] (override protected[channels] val ch with GatheringByteChannel with ScatteringByteChannel { - final def bind(local: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit + final def bindTo(address: SocketAddress): IO[IOException, Unit] = bind(Some(address)) + + final def bindAuto: IO[IOException, Unit] = bind(None) + + final def bind(local: Option[SocketAddress]): IO[IOException, Unit] = + IO.effect(channel.bind(local.map(_.jSocketAddress).orNull)).refineToOrDie[IOException].unit final def setOption[T](name: SocketOption[T], value: T): IO[IOException, Unit] = IO.effect(channel.setOption(name, value)).refineToOrDie[IOException].unit @@ -122,14 +126,12 @@ object SocketChannel { final class ServerSocketChannel private (override protected val channel: JServerSocketChannel) extends SelectableChannel { - final def bind(local: SocketAddress): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress)).refineToOrDie[IOException].unit + final def bindTo(local: SocketAddress, backlog: Int = 0): IO[IOException, Unit] = bind(Some(local), backlog) - final def bind(local: SocketAddress, backlog: Int): IO[IOException, Unit] = - IO.effect(channel.bind(local.jSocketAddress, backlog)).refineToOrDie[IOException].unit + final def bindAuto(backlog: Int = 0): IO[IOException, Unit] = bind(None, backlog) - final def setOption[T](name: SocketOption[T], value: T): IO[Exception, Unit] = - IO.effect(channel.setOption(name, value)).refineToOrDie[Exception].unit + final def bind(local: Option[SocketAddress], backlog: Int = 0): IO[IOException, Unit] = + IO.effect(channel.bind(local.map(_.jSocketAddress).orNull, backlog)).refineToOrDie[IOException].unit final val socket: UIO[JServerSocket] = IO.effectTotal(channel.socket()) diff --git a/nio/src/test/scala/zio/nio/channels/ChannelSpec.scala b/nio/src/test/scala/zio/nio/channels/ChannelSpec.scala index 0653b543..8ff8d7e1 100644 --- a/nio/src/test/scala/zio/nio/channels/ChannelSpec.scala +++ b/nio/src/test/scala/zio/nio/channels/ChannelSpec.scala @@ -1,7 +1,7 @@ package zio.nio.channels import zio.nio.BaseSpec -import zio.nio.core.{ Buffer, SocketAddress } +import zio.nio.core.{ Buffer, InetSocketAddress, SocketAddress } import zio.{ IO, _ } import zio.test._ import zio.test.Assertion._ @@ -14,11 +14,11 @@ object ChannelSpec extends BaseSpec { testM("read/write") { def echoServer(started: Promise[Nothing, SocketAddress]): IO[Exception, Unit] = for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) sink <- Buffer.byte(3) _ <- AsynchronousServerSocketChannel().use { server => for { - _ <- server.bind(address) + _ <- server.bindTo(address) addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) _ <- started.succeed(addr) _ <- server.accept.use { worker => @@ -56,10 +56,10 @@ object ChannelSpec extends BaseSpec { testM("read should fail when connection close") { def server(started: Promise[Nothing, SocketAddress]): IO[Exception, Fiber[Exception, Boolean]] = for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) result <- AsynchronousServerSocketChannel().use { server => for { - _ <- server.bind(address) + _ <- server.bindTo(address) addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) _ <- started.succeed(addr) result <- server.accept @@ -100,7 +100,7 @@ object ChannelSpec extends BaseSpec { for { worker <- AsynchronousServerSocketChannel().use { server => for { - _ <- server.bind(address) + _ <- server.bindTo(address) addr <- server.localAddress.flatMap(opt => IO.effect(opt.get).orDie) _ <- started.succeed(addr) worker <- server.accept.use(_ => ZIO.unit) @@ -109,7 +109,7 @@ object ChannelSpec extends BaseSpec { } yield worker for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) serverStarted <- Promise.make[Nothing, SocketAddress] s1 <- server(address, serverStarted) addr <- serverStarted.await @@ -125,8 +125,8 @@ object ChannelSpec extends BaseSpec { testM("accept should be interruptible") { AsynchronousServerSocketChannel().use { server => for { - addr <- SocketAddress.inetSocketAddress(0) - _ <- server.bind(addr) + addr <- InetSocketAddress.wildCard(0) + _ <- server.bindTo(addr) fiber <- server.accept.useNow.fork _ <- fiber.interrupt result <- fiber.await @@ -136,9 +136,9 @@ object ChannelSpec extends BaseSpec { // this would best be tagged as an regression test. for now just run manually when suspicious. testM("accept should not leak resources") { val server = for { - addr <- SocketAddress.inetSocketAddress(8081).toManaged_ + addr <- InetSocketAddress.wildCard(8081).toManaged_ channel <- AsynchronousServerSocketChannel() - _ <- channel.bind(addr).toManaged_ + _ <- channel.bindTo(addr).toManaged_ _ <- AsynchronousSocketChannel().use(channel => channel.connect(addr)).forever.toManaged_.fork } yield channel val interruptAccept = server.use( diff --git a/nio/src/test/scala/zio/nio/channels/DatagramChannelSpec.scala b/nio/src/test/scala/zio/nio/channels/DatagramChannelSpec.scala index cfab45b9..1baadd97 100644 --- a/nio/src/test/scala/zio/nio/channels/DatagramChannelSpec.scala +++ b/nio/src/test/scala/zio/nio/channels/DatagramChannelSpec.scala @@ -1,7 +1,7 @@ package zio.nio.channels import zio.nio._ -import zio.nio.core.{ Buffer, SocketAddress } +import zio.nio.core.{ Buffer, InetSocketAddress, SocketAddress } import zio.test.Assertion._ import zio.test._ import zio._ @@ -13,7 +13,7 @@ object DatagramChannelSpec extends BaseSpec { testM("read/write") { def echoServer(started: Promise[Nothing, SocketAddress]): IO[Exception, Unit] = for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) sink <- Buffer.byte(3) _ <- DatagramChannel .bind(Some(address)) @@ -72,7 +72,7 @@ object DatagramChannelSpec extends BaseSpec { } yield worker for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) serverStarted <- Promise.make[Nothing, SocketAddress] s1 <- server(address, serverStarted) addr <- serverStarted.await diff --git a/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala b/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala index c2dec9fa..f04f9fe5 100644 --- a/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala +++ b/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala @@ -1,11 +1,10 @@ package zio.nio.channels import java.nio.channels.CancelledKeyException - import zio._ import zio.blocking.Blocking import zio.clock.Clock -import zio.nio.core.{ Buffer, ByteBuffer, SocketAddress } +import zio.nio.core.{ Buffer, ByteBuffer, InetSocketAddress, SocketAddress } import zio.nio.core.channels.SelectionKey.Operation import zio.nio.BaseSpec import zio.test._ @@ -68,11 +67,11 @@ object SelectorSpec extends BaseSpec { } yield () for { - address <- SocketAddress.inetSocketAddress(0) + address <- InetSocketAddress.wildCard(0) _ <- Selector.make.use { selector => ServerSocketChannel.open.use { channel => for { - _ <- channel.bind(address) + _ <- channel.bindTo(address) _ <- channel.configureBlocking(false) _ <- channel.register(selector, Operation.Accept) buffer <- Buffer.byte(256) From 9b2b7014f5b7a6e944097c5ad8ac750753f82b66 Mon Sep 17 00:00:00 2001 From: Lachlan O'Dea Date: Fri, 9 Oct 2020 20:43:29 +1100 Subject: [PATCH 2/7] Improve InetSocketAddress API. --- .../zio/nio/core/InetSocketAddress.scala | 54 ++++++++++++++++--- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala index 5fc585de..095af550 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala @@ -1,8 +1,8 @@ package zio.nio.core -import java.net.{ InetSocketAddress => JInetSocketAddress, SocketAddress => JSocketAddress } +import java.net.{ UnknownHostException, InetSocketAddress => JInetSocketAddress, SocketAddress => JSocketAddress } -import zio.UIO +import zio.{ IO, UIO } /** * Representation of a socket address without a specific protocol. @@ -41,7 +41,10 @@ object SocketAddress { * will be made to resolve the hostname. * If resolution fails then the address is said to be unresolved but can still * be used on some circumstances like connecting through a proxy. - * It provides an immutable object used by sockets for binding, connecting, + * However, note that network channels generally do ''not'' accept unresolved + * socket addresses. + * + * This class provides an immutable object used by sockets for binding, connecting, * or as returned values. * * The wildcard is a special local IP address. It usually means "any" and can @@ -96,27 +99,56 @@ object InetSocketAddress { */ def wildCard(port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(port))) + /** + * Creates a socket address where the IP address is the wildcard address and + * the port is ephemeral. + * + * The socket address wll be ''resolved''. + */ def wildCardEphemeral: UIO[InetSocketAddress] = wildCard(0) /** - * Creates a socket address from an IP address and a port number. + * Creates a socket address from a hostname and a port number. * - * The socket address will be ''resolved''. + * This method will attempt to resolve the hostname; if this fails, the returned + * socket address will be ''unresolved''. */ def hostname(hostname: String, port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(hostname, port))) + /** + * Creates a resolved socket address from a hostname and port number. + * + * If the hostname cannot be resolved, fails with `UnknownHostException`. + */ + def hostnameResolved(hostname: String, port: Int): IO[UnknownHostException, InetSocketAddress] = + InetAddress.byName(hostname).flatMap(inetAddress(_, port)) + + /** + * Creates a socket address from a hostname, with an ephemeral port. + * + * This method will attempt to resolve the hostname; if this fails, the returned + * socket address will be ''unresolved''. + */ def hostnameEphemeral(hostname: String): UIO[InetSocketAddress] = this.hostname(hostname, 0) /** - * Creates a socket address from a hostname and a port number. + * Creates a resolved socket address from a hostname, with an ephemeral port. * - * An attempt will be made to resolve the hostname into an `InetAddress`. - * If that attempt fails, the socket address will be flagged as ''unresolved''. + * If the hostname cannot be resolved, fails with `UnknownHostException`. + */ + def hostnameEphemeralResolved(hostname: String): IO[UnknownHostException, InetSocketAddress] = + InetAddress.byName(hostname).flatMap(inetAddressEphemeral) + + /** + * Creates a socket address from an IP address and a port number. */ def inetAddress(address: InetAddress, port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(address.jInetAddress, port))) + /** + * Creates a socket address from an IP address, with an ephemeral port. + */ def inetAddressEphemeral(address: InetAddress): UIO[InetSocketAddress] = inetAddress(address, 0) /** @@ -128,6 +160,12 @@ object InetSocketAddress { def unresolvedHostname(hostname: String, port: Int): UIO[InetSocketAddress] = UIO.effectTotal(new InetSocketAddress(JInetSocketAddress.createUnresolved(hostname, port))) + /** + * Creates an unresolved socket address from a hostname using an ephemeral port. + * + * No attempt will be made to resolve the hostname into an `InetAddress`. + * The socket address will be flagged as ''unresolved''. + */ def unresolvedHostnameEphemeral(hostname: String): UIO[InetSocketAddress] = unresolvedHostname(hostname, 0) } From 6f48c67b10f4693a2b0aaaaf2f282626f95cd168 Mon Sep 17 00:00:00 2001 From: Lachlan O'Dea Date: Sat, 10 Oct 2020 15:34:56 +1100 Subject: [PATCH 3/7] Improvements to address handling. --- .../main/scala/zio/nio/core/InetAddress.scala | 2 +- .../zio/nio/core/InetSocketAddress.scala | 2 +- .../nio/core/channels/DatagramChannel.scala | 8 +-- .../nio/core/channels/SelectableChannel.scala | 2 +- .../zio/nio/core/channels/SelectionKey.scala | 15 ++++-- .../nio/channels/AsynchronousChannel.scala | 2 +- .../zio/nio/channels/DatagramChannel.scala | 8 +-- .../zio/nio/channels/SelectableChannel.scala | 2 +- .../scala/zio/nio/channels/Selector.scala | 43 +++++++++++++--- .../scala/zio/nio/channels/SelectorSpec.scala | 51 +++++++++---------- 10 files changed, 86 insertions(+), 49 deletions(-) diff --git a/nio-core/src/main/scala/zio/nio/core/InetAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetAddress.scala index 3a901cf1..d62d64eb 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetAddress.scala @@ -46,7 +46,7 @@ final class InetAddress private[nio] (private[nio] val jInetAddress: JInetAddres def canonicalHostName: String = jInetAddress.getCanonicalHostName - def address: Array[Byte] = jInetAddress.getAddress + def address: Chunk[Byte] = Chunk.fromArray(jInetAddress.getAddress) override def hashCode(): Int = jInetAddress.hashCode() diff --git a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala index 095af550..e68f1478 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala @@ -9,7 +9,7 @@ import zio.{ IO, UIO } * * The concrete subclass [[InetSocketAddress]] is used in practice. */ -sealed class SocketAddress private[nio] (private[nio] val jSocketAddress: JSocketAddress) { +sealed class SocketAddress protected (private[nio] val jSocketAddress: JSocketAddress) { final override def equals(obj: Any): Boolean = obj match { diff --git a/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala b/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala index 5a516699..5d5c20f0 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/DatagramChannel.scala @@ -61,7 +61,7 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the local address if the socket is bound, otherwise `None` */ def localAddress: IO[IOException, Option[SocketAddress]] = - IO.effect(channel.getLocalAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(channel.getLocalAddress()).refineToOrDie[IOException].map(a => Option(a).map(SocketAddress.fromJava)) /** * Receives a datagram via this channel into the given [[zio.nio.core.ByteBuffer]]. @@ -70,7 +70,9 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the socket address of the datagram's source, if available. */ def receive(dst: ByteBuffer): IO[IOException, Option[SocketAddress]] = - IO.effect(channel.receive(dst.byteBuffer)).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(channel.receive(dst.byteBuffer)) + .refineToOrDie[IOException] + .map(a => Option(a).map(SocketAddress.fromJava)) /** * Optionally returns the remote socket address that this channel's underlying socket is connected to. @@ -78,7 +80,7 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the remote address if the socket is connected, otherwise `None` */ def remoteAddress: IO[IOException, Option[SocketAddress]] = - IO.effect(channel.getRemoteAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(channel.getRemoteAddress()).refineToOrDie[IOException].map(a => Option(a).map(SocketAddress.fromJava)) /** * Sends a datagram via this channel to the given target [[zio.nio.core.SocketAddress]]. diff --git a/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala b/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala index c694f168..7b9a51ee 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/SelectableChannel.scala @@ -143,7 +143,7 @@ final class ServerSocketChannel(override protected val channel: JServerSocketCha }) val localAddress: IO[IOException, SocketAddress] = - IO.effect(new SocketAddress(channel.getLocalAddress())).refineToOrDie[IOException] + IO.effect(SocketAddress.fromJava(channel.getLocalAddress())).refineToOrDie[IOException] } object ServerSocketChannel { diff --git a/nio-core/src/main/scala/zio/nio/core/channels/SelectionKey.scala b/nio-core/src/main/scala/zio/nio/core/channels/SelectionKey.scala index 62ddb1d7..67248a3d 100644 --- a/nio-core/src/main/scala/zio/nio/core/channels/SelectionKey.scala +++ b/nio-core/src/main/scala/zio/nio/core/channels/SelectionKey.scala @@ -56,8 +56,7 @@ final class SelectionKey(private[nio] val selectionKey: jc.SelectionKey) { * {{{ * for { * _ <- selector.select - * selectedKeys <- selector.selectedKeys - * _ <- IO.foreach_(selectedKeys) { key => + * _ <- selector.foreachSelectedKey { key => * key.matchChannel { readyOps => { * case channel: ServerSocketChannel if readyOps(Operation.Accept) => * // use `channel` to accept connection @@ -68,7 +67,7 @@ final class SelectionKey(private[nio] val selectionKey: jc.SelectionKey) { * IO.when(readyOps(Operation.Write)) { * // use `channel` to write * } - * } } *> selector.removeKey(key) + * } } * } * } yield () * }}} @@ -126,6 +125,14 @@ final class SelectionKey(private[nio] val selectionKey: jc.SelectionKey) { final def attachment: UIO[Option[AnyRef]] = IO.effectTotal(selectionKey.attachment()).map(Option(_)) - override def toString: String = selectionKey.toString() + override def toString: String = selectionKey.toString + + override def hashCode(): Int = selectionKey.hashCode() + + override def equals(obj: Any): Boolean = + obj match { + case other: SelectionKey => selectionKey.equals(other.selectionKey) + case _ => false + } } diff --git a/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala b/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala index 8fa283fa..41d2faaf 100644 --- a/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala +++ b/nio/src/main/scala/zio/nio/channels/AsynchronousChannel.scala @@ -106,7 +106,7 @@ class AsynchronousServerSocketChannel(private val channel: JAsynchronousServerSo */ final def localAddress: IO[IOException, Option[SocketAddress]] = IO.effect( - Option(channel.getLocalAddress).map(new SocketAddress(_)) + Option(channel.getLocalAddress).map(SocketAddress.fromJava) ).refineToOrDie[IOException] /** diff --git a/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala b/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala index 2fc56851..26e684db 100644 --- a/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala +++ b/nio/src/main/scala/zio/nio/channels/DatagramChannel.scala @@ -43,7 +43,7 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the local address if the socket is bound, otherwise `None` */ def localAddress: IO[IOException, Option[SocketAddress]] = - IO.effect(channel.getLocalAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(Option(channel.getLocalAddress()).map(SocketAddress.fromJava)).refineToOrDie[IOException] /** * Receives a datagram via this channel into the given [[zio.nio.core.ByteBuffer]]. @@ -52,7 +52,9 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the socket address of the datagram's source, if available. */ def receive(dst: ByteBuffer): IO[IOException, Option[SocketAddress]] = - IO.effect(channel.receive(dst.byteBuffer)).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(channel.receive(dst.byteBuffer)) + .refineToOrDie[IOException] + .map(a => Option(a).map(SocketAddress.fromJava)) /** * Optionally returns the remote socket address that this channel's underlying socket is connected to. @@ -60,7 +62,7 @@ final class DatagramChannel private[channels] (override protected[channels] val * @return the remote address if the socket is connected, otherwise `None` */ def remoteAddress: IO[IOException, Option[SocketAddress]] = - IO.effect(channel.getRemoteAddress()).refineToOrDie[IOException].map(a => Option(a).map(new SocketAddress(_))) + IO.effect(Option(channel.getRemoteAddress()).map(SocketAddress.fromJava)).refineToOrDie[IOException] /** * Sends a datagram via this channel to the given target [[zio.nio.core.SocketAddress]]. diff --git a/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala b/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala index 1cb81935..fdc5d44e 100644 --- a/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala +++ b/nio/src/main/scala/zio/nio/channels/SelectableChannel.scala @@ -147,7 +147,7 @@ final class ServerSocketChannel private (override protected val channel: JServer IO.effect(Option(channel.accept()).map(new SocketChannel(_))).refineToOrDie[IOException] final val localAddress: IO[IOException, SocketAddress] = - IO.effect(new SocketAddress(channel.getLocalAddress())).refineToOrDie[IOException] + IO.effect(SocketAddress.fromJava(channel.getLocalAddress())).refineToOrDie[IOException] } object ServerSocketChannel { diff --git a/nio/src/main/scala/zio/nio/channels/Selector.scala b/nio/src/main/scala/zio/nio/channels/Selector.scala index aab6ac9f..eeda518c 100644 --- a/nio/src/main/scala/zio/nio/channels/Selector.scala +++ b/nio/src/main/scala/zio/nio/channels/Selector.scala @@ -1,16 +1,15 @@ package zio.nio.channels -import java.io.IOException -import java.nio.channels.{ ClosedSelectorException, Selector => JSelector, SelectionKey => JSelectionKey } - -import zio.{ IO, Managed, UIO, ZIO } import com.github.ghik.silencer.silent +import zio.blocking.Blocking import zio.duration.Duration import zio.nio.channels.spi.SelectorProvider import zio.nio.core.channels.SelectionKey -import zio.blocking -import zio.blocking.Blocking +import zio.{ IO, Managed, UIO, ZIO, blocking } +import java.io.IOException +import java.nio.channels.{ ClosedSelectorException, SelectionKey => JSelectionKey, Selector => JSelector } +import scala.collection.mutable import scala.jdk.CollectionConverters._ class Selector(private[nio] val selector: JSelector) { @@ -24,12 +23,40 @@ class Selector(private[nio] val selector: JSelector) { .map(_.asScala.toSet[JSelectionKey].map(new SelectionKey(_))) .refineToOrDie[ClosedSelectorException] + /** + * Returns this selector's selected-key set. + * + * Note that the returned set it mutable - keys may be removed from, but not directly added to it. + * Any attempt to add an object to the key set will cause an `UnsupportedOperationException` to be thrown. + * The selected-key set is not thread-safe. + */ @silent - final val selectedKeys: IO[ClosedSelectorException, Set[SelectionKey]] = + final val selectedKeys: IO[ClosedSelectorException, mutable.Set[SelectionKey]] = IO.effect(selector.selectedKeys()) - .map(_.asScala.toSet[JSelectionKey].map(new SelectionKey(_))) + .map(_.asScala.map(new SelectionKey(_))) .refineToOrDie[ClosedSelectorException] + /** + * Performs an effect with each selected key. + * + * If the result of effect is true, the key will be removed from the selected-key set, which is + * usually what you want after successfully handling a selected key. + */ + def foreachSelectedKey[R, E](f: SelectionKey => ZIO[R, E, Boolean]): ZIO[R, E, Unit] = + ZIO.effectTotal(selector.selectedKeys().iterator()).flatMap { iter => + def loop: ZIO[R, E, Unit] = + ZIO.effectSuspendTotal { + if (iter.hasNext) { + val key = iter.next() + f(new SelectionKey(key)).flatMap(ZIO.when(_)(ZIO.effectTotal(iter.remove()))) *> + loop + } else + ZIO.unit + } + + loop + } + final def removeKey(key: SelectionKey): IO[ClosedSelectorException, Unit] = IO.effect(selector.selectedKeys().remove(key.selectionKey)) .unit diff --git a/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala b/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala index f04f9fe5..ef38585d 100644 --- a/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala +++ b/nio/src/test/scala/zio/nio/channels/SelectorSpec.scala @@ -38,32 +38,31 @@ object SelectorSpec extends BaseSpec { buffer: ByteBuffer ): ZIO[Blocking, Exception, Unit] = for { - _ <- selector.select - selectedKeys <- selector.selectedKeys - _ <- IO.foreach(selectedKeys) { key => - IO.whenM(safeStatusCheck(key.isAcceptable)) { - for { - clientOpt <- channel.accept - client = clientOpt.get - _ <- client.configureBlocking(false) - _ <- client.register(selector, Operation.Read) - } yield () - } *> - IO.whenM(safeStatusCheck(key.isReadable)) { - IO.effectSuspendTotal { - val sClient = key.channel - val client = sClient.asInstanceOf[zio.nio.core.channels.SocketChannel] - for { - _ <- client.read(buffer) - _ <- buffer.flip - _ <- client.write(buffer) - _ <- buffer.clear - _ <- client.close - } yield () - } - } *> - selector.removeKey(key) - } + _ <- selector.select + _ <- selector.foreachSelectedKey { key => + IO.whenM(safeStatusCheck(key.isAcceptable)) { + for { + clientOpt <- channel.accept + client = clientOpt.get + _ <- client.configureBlocking(false) + _ <- client.register(selector, Operation.Read) + } yield () + } *> + IO.whenM(safeStatusCheck(key.isReadable)) { + IO.effectSuspendTotal { + val sClient = key.channel + val client = sClient.asInstanceOf[zio.nio.core.channels.SocketChannel] + for { + _ <- client.read(buffer) + _ <- buffer.flip + _ <- client.write(buffer) + _ <- buffer.clear + _ <- client.close + } yield () + } + }.as(true) + } + _ <- selector.selectedKeys.filterOrDieMessage(_.isEmpty)("Selected key set should be empty") } yield () for { From 918f97718562a9686b3a2c49c8b4328a621030f5 Mon Sep 17 00:00:00 2001 From: Lachlan O'Dea Date: Mon, 26 Oct 2020 18:30:13 +1100 Subject: [PATCH 4/7] Add localhost constructor to InetSocketAddress. --- .../src/main/scala/zio/nio/core/InetSocketAddress.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala index e68f1478..74f4b70c 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala @@ -151,6 +151,12 @@ object InetSocketAddress { */ def inetAddressEphemeral(address: InetAddress): UIO[InetSocketAddress] = inetAddress(address, 0) + /** + * Creates a socket address for localhost using the specified port. + */ + def localHost(port: Int): IO[UnknownHostException, InetSocketAddress] = + InetAddress.localHost.flatMap(inetAddress(_, port)) + /** * Creates an unresolved socket address from a hostname and a port number. * From fea81b8a4eb34839c29615c7a625a8febb3398b7 Mon Sep 17 00:00:00 2001 From: Steven Vroonland Date: Sat, 13 Mar 2021 12:42:36 +0100 Subject: [PATCH 5/7] Fix name conflict --- examples/src/main/scala/StreamsBasedServer.scala | 2 +- .../src/main/scala/zio/nio/core/InetSocketAddress.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala/StreamsBasedServer.scala b/examples/src/main/scala/StreamsBasedServer.scala index a9282048..20413c0d 100644 --- a/examples/src/main/scala/StreamsBasedServer.scala +++ b/examples/src/main/scala/StreamsBasedServer.scala @@ -16,7 +16,7 @@ object StreamsBasedServer extends App { AsynchronousServerSocketChannel() .use(socket => for { - _ <- InetSocketAddress.hostname("localhost", port).flatMap(socket.bindTo(_)) + _ <- InetSocketAddress.hostName("localhost", port).flatMap(socket.bindTo(_)) _ <- ZStream .repeatEffect(socket.accept.preallocate) .map(_.withEarlyRelease) diff --git a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala index 74f4b70c..bd24e3e8 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala @@ -113,8 +113,8 @@ object InetSocketAddress { * This method will attempt to resolve the hostname; if this fails, the returned * socket address will be ''unresolved''. */ - def hostname(hostname: String, port: Int): UIO[InetSocketAddress] = - UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(hostname, port))) + def hostName(hostName: String, port: Int): UIO[InetSocketAddress] = + UIO.effectTotal(new InetSocketAddress(new JInetSocketAddress(hostName, port))) /** * Creates a resolved socket address from a hostname and port number. @@ -130,7 +130,7 @@ object InetSocketAddress { * This method will attempt to resolve the hostname; if this fails, the returned * socket address will be ''unresolved''. */ - def hostnameEphemeral(hostname: String): UIO[InetSocketAddress] = this.hostname(hostname, 0) + def hostnameEphemeral(hostname: String): UIO[InetSocketAddress] = this.hostName(hostname, 0) /** * Creates a resolved socket address from a hostname, with an ephemeral port. From cfc05506d5898e4c61c9ade047dad20a26ec2dba Mon Sep 17 00:00:00 2001 From: Steven Vroonland Date: Sat, 13 Mar 2021 12:43:10 +0100 Subject: [PATCH 6/7] Fix typo --- nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala index bd24e3e8..e16d7dfd 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala @@ -103,7 +103,7 @@ object InetSocketAddress { * Creates a socket address where the IP address is the wildcard address and * the port is ephemeral. * - * The socket address wll be ''resolved''. + * The socket address will be ''resolved''. */ def wildCardEphemeral: UIO[InetSocketAddress] = wildCard(0) From 8acdd43fc07d99dd316b64e233c5acebc4d582d8 Mon Sep 17 00:00:00 2001 From: Steven Vroonland Date: Sat, 13 Mar 2021 12:51:30 +0100 Subject: [PATCH 7/7] Consistent casing of hostName --- .../main/scala/zio/nio/core/InetAddress.scala | 2 +- .../scala/zio/nio/core/InetSocketAddress.scala | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/nio-core/src/main/scala/zio/nio/core/InetAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetAddress.scala index d62d64eb..02cb3483 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetAddress.scala @@ -42,7 +42,7 @@ final class InetAddress private[nio] (private[nio] val jInetAddress: JInetAddres IO.effect(jInetAddress.isReachable(networkInterface.jNetworkInterface, ttl, timeout)) .refineToOrDie[IOException] - def hostname: String = jInetAddress.getHostName + def hostName: String = jInetAddress.getHostName def canonicalHostName: String = jInetAddress.getCanonicalHostName diff --git a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala index e16d7dfd..e4a3e58b 100644 --- a/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala +++ b/nio-core/src/main/scala/zio/nio/core/InetSocketAddress.scala @@ -121,8 +121,8 @@ object InetSocketAddress { * * If the hostname cannot be resolved, fails with `UnknownHostException`. */ - def hostnameResolved(hostname: String, port: Int): IO[UnknownHostException, InetSocketAddress] = - InetAddress.byName(hostname).flatMap(inetAddress(_, port)) + def hostNameResolved(hostName: String, port: Int): IO[UnknownHostException, InetSocketAddress] = + InetAddress.byName(hostName).flatMap(inetAddress(_, port)) /** * Creates a socket address from a hostname, with an ephemeral port. @@ -130,15 +130,15 @@ object InetSocketAddress { * This method will attempt to resolve the hostname; if this fails, the returned * socket address will be ''unresolved''. */ - def hostnameEphemeral(hostname: String): UIO[InetSocketAddress] = this.hostName(hostname, 0) + def hostNameEphemeral(hostName: String): UIO[InetSocketAddress] = this.hostName(hostName, 0) /** * Creates a resolved socket address from a hostname, with an ephemeral port. * * If the hostname cannot be resolved, fails with `UnknownHostException`. */ - def hostnameEphemeralResolved(hostname: String): IO[UnknownHostException, InetSocketAddress] = - InetAddress.byName(hostname).flatMap(inetAddressEphemeral) + def hostNameEphemeralResolved(hostName: String): IO[UnknownHostException, InetSocketAddress] = + InetAddress.byName(hostName).flatMap(inetAddressEphemeral) /** * Creates a socket address from an IP address and a port number. @@ -163,8 +163,8 @@ object InetSocketAddress { * No attempt will be made to resolve the hostname into an `InetAddress`. * The socket address will be flagged as ''unresolved''. */ - def unresolvedHostname(hostname: String, port: Int): UIO[InetSocketAddress] = - UIO.effectTotal(new InetSocketAddress(JInetSocketAddress.createUnresolved(hostname, port))) + def unresolvedHostName(hostName: String, port: Int): UIO[InetSocketAddress] = + UIO.effectTotal(new InetSocketAddress(JInetSocketAddress.createUnresolved(hostName, port))) /** * Creates an unresolved socket address from a hostname using an ephemeral port. @@ -172,6 +172,6 @@ object InetSocketAddress { * No attempt will be made to resolve the hostname into an `InetAddress`. * The socket address will be flagged as ''unresolved''. */ - def unresolvedHostnameEphemeral(hostname: String): UIO[InetSocketAddress] = unresolvedHostname(hostname, 0) + def unresolvedHostNameEphemeral(hostName: String): UIO[InetSocketAddress] = unresolvedHostName(hostName, 0) }