| 
 | 1 | +use std::future::Future;  | 
1 | 2 | use std::io;  | 
2 | 3 | use std::path::Path;  | 
3 | 4 | use std::pin::Pin;  | 
4 | 5 | use std::task::{ready, Context, Poll};  | 
5 |  | -use std::{  | 
6 |  | -    future::Future,  | 
7 |  | -    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},  | 
8 |  | -};  | 
9 | 6 | 
 
  | 
 | 7 | +pub use buffered::{BufferedSocket, WriteBuffer};  | 
10 | 8 | use bytes::BufMut;  | 
11 | 9 | use cfg_if::cfg_if;  | 
12 | 10 | 
 
  | 
13 |  | -pub use buffered::{BufferedSocket, WriteBuffer};  | 
14 |  | - | 
15 |  | -use crate::{io::ReadBuf, rt::spawn_blocking};  | 
 | 11 | +use crate::io::ReadBuf;  | 
16 | 12 | 
 
  | 
17 | 13 | mod buffered;  | 
18 | 14 | 
 
  | 
@@ -146,10 +142,7 @@ where  | 
146 | 142 | pub trait WithSocket {  | 
147 | 143 |     type Output;  | 
148 | 144 | 
 
  | 
149 |  | -    fn with_socket<S: Socket>(  | 
150 |  | -        self,  | 
151 |  | -        socket: S,  | 
152 |  | -    ) -> impl std::future::Future<Output = Self::Output> + Send;  | 
 | 145 | +    fn with_socket<S: Socket>(self, socket: S) -> impl Future<Output = Self::Output> + Send;  | 
153 | 146 | }  | 
154 | 147 | 
 
  | 
155 | 148 | pub struct SocketIntoBox;  | 
@@ -193,98 +186,67 @@ pub async fn connect_tcp<Ws: WithSocket>(  | 
193 | 186 |     port: u16,  | 
194 | 187 |     with_socket: Ws,  | 
195 | 188 | ) -> crate::Result<Ws::Output> {  | 
196 |  | -    // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.  | 
197 |  | -    let host = host.trim_matches(&['[', ']'][..]);  | 
198 |  | - | 
199 |  | -    let addresses = if let Ok(addr) = host.parse::<Ipv4Addr>() {  | 
200 |  | -        let addr = SocketAddrV4::new(addr, port);  | 
201 |  | -        vec![SocketAddr::V4(addr)].into_iter()  | 
202 |  | -    } else if let Ok(addr) = host.parse::<Ipv6Addr>() {  | 
203 |  | -        let addr = SocketAddrV6::new(addr, port, 0, 0);  | 
204 |  | -        vec![SocketAddr::V6(addr)].into_iter()  | 
205 |  | -    } else {  | 
206 |  | -        let host = host.to_string();  | 
207 |  | -        spawn_blocking(move || {  | 
208 |  | -            let addr = (host.as_str(), port);  | 
209 |  | -            ToSocketAddrs::to_socket_addrs(&addr)  | 
210 |  | -        })  | 
211 |  | -        .await?  | 
212 |  | -    };  | 
213 |  | - | 
214 |  | -    let mut last_err = None;  | 
215 |  | - | 
216 |  | -    // Loop through all the Socket Addresses that the hostname resolves to  | 
217 |  | -    for socket_addr in addresses {  | 
218 |  | -        match connect_tcp_address(socket_addr).await {  | 
219 |  | -            Ok(stream) => return Ok(with_socket.with_socket(stream).await),  | 
220 |  | -            Err(e) => last_err = Some(e),  | 
221 |  | -        }  | 
 | 189 | +    #[cfg(feature = "_rt-tokio")]  | 
 | 190 | +    if crate::rt::rt_tokio::available() {  | 
 | 191 | +        return Ok(with_socket  | 
 | 192 | +            .with_socket(tokio::net::TcpStream::connect((host, port)).await?)  | 
 | 193 | +            .await);  | 
222 | 194 |     }  | 
223 | 195 | 
 
  | 
224 |  | -    // If we reach this point, it means we failed to connect to any of the addresses.  | 
225 |  | -    // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.  | 
226 |  | -    Err(match last_err {  | 
227 |  | -        Some(err) => err,  | 
228 |  | -        None => io::Error::new(  | 
229 |  | -            io::ErrorKind::AddrNotAvailable,  | 
230 |  | -            "Hostname did not resolve to any addresses",  | 
231 |  | -        )  | 
232 |  | -        .into(),  | 
233 |  | -    })  | 
234 |  | -}  | 
235 |  | - | 
236 |  | -async fn connect_tcp_address(socket_addr: SocketAddr) -> crate::Result<impl Socket> {  | 
237 | 196 |     cfg_if! {  | 
238 |  | -        if #[cfg(feature = "_rt-tokio")] {  | 
239 |  | -            if crate::rt::rt_tokio::available() {  | 
240 |  | -                use tokio::net::TcpStream;  | 
241 |  | - | 
242 |  | -                let stream = TcpStream::connect(socket_addr).await?;  | 
243 |  | -                stream.set_nodelay(true)?;  | 
244 |  | - | 
245 |  | -                Ok(stream)  | 
246 |  | -            } else {  | 
247 |  | -                crate::rt::missing_rt(socket_addr)  | 
248 |  | -            }  | 
249 |  | -        } else if #[cfg(feature = "_rt-async-io")] {  | 
250 |  | -            use async_io::Async;  | 
251 |  | -            use std::net::TcpStream;  | 
252 |  | - | 
253 |  | -            let stream = Async::<TcpStream>::connect(socket_addr).await?;  | 
254 |  | -            stream.get_ref().set_nodelay(true)?;  | 
255 |  | - | 
256 |  | -            Ok(stream)  | 
 | 197 | +        if #[cfg(feature = "_rt-async-io")] {  | 
 | 198 | +            Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)  | 
257 | 199 |         } else {  | 
258 |  | -            crate::rt::missing_rt(socket_addr);  | 
259 |  | -            #[allow(unreachable_code)]  | 
260 |  | -            Ok(())  | 
 | 200 | +            crate::rt::missing_rt((host, port, with_socket))  | 
261 | 201 |         }  | 
262 | 202 |     }  | 
263 | 203 | }  | 
264 | 204 | 
 
  | 
265 |  | -// Work around `impl Socket`` and 'unability to specify test build cargo feature'.  | 
266 |  | -// `connect_tcp_address` compilation would fail without this impl with  | 
267 |  | -// 'cannot infer return type' error.  | 
268 |  | -impl Socket for () {  | 
269 |  | -    fn try_read(&mut self, _: &mut dyn ReadBuf) -> io::Result<usize> {  | 
270 |  | -        unreachable!()  | 
271 |  | -    }  | 
 | 205 | +/// Open a TCP socket to `host` and `port`.  | 
 | 206 | +///  | 
 | 207 | +/// If `host` is a hostname, attempt to connect to each address it resolves to.  | 
 | 208 | +///  | 
 | 209 | +/// This implements the same behavior as [`tokio::net::TcpStream::connect()`].  | 
 | 210 | +#[cfg(feature = "_rt-async-io")]  | 
 | 211 | +async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> {  | 
 | 212 | +    use async_io::Async;  | 
 | 213 | +    use std::net::{IpAddr, TcpStream, ToSocketAddrs};  | 
272 | 214 | 
 
  | 
273 |  | -    fn try_write(&mut self, _: &[u8]) -> io::Result<usize> {  | 
274 |  | -        unreachable!()  | 
275 |  | -    }  | 
 | 215 | +    // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.  | 
 | 216 | +    let host = host.trim_matches(&['[', ']'][..]);  | 
276 | 217 | 
 
  | 
277 |  | -    fn poll_read_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {  | 
278 |  | -        unreachable!()  | 
 | 218 | +    if let Ok(addr) = host.parse::<IpAddr>() {  | 
 | 219 | +        return Ok(Async::<TcpStream>::connect((addr, port)).await?);  | 
279 | 220 |     }  | 
280 | 221 | 
 
  | 
281 |  | -    fn poll_write_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {  | 
282 |  | -        unreachable!()  | 
283 |  | -    }  | 
 | 222 | +    let host = host.to_string();  | 
 | 223 | + | 
 | 224 | +    let addresses = crate::rt::spawn_blocking(move || {  | 
 | 225 | +        let addr = (host.as_str(), port);  | 
 | 226 | +        ToSocketAddrs::to_socket_addrs(&addr)  | 
 | 227 | +    })  | 
 | 228 | +    .await?;  | 
284 | 229 | 
 
  | 
285 |  | -    fn poll_shutdown(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {  | 
286 |  | -        unreachable!()  | 
 | 230 | +    let mut last_err = None;  | 
 | 231 | + | 
 | 232 | +    // Loop through all the Socket Addresses that the hostname resolves to  | 
 | 233 | +    for socket_addr in addresses {  | 
 | 234 | +        match Async::<TcpStream>::connect(socket_addr).await {  | 
 | 235 | +            Ok(stream) => return Ok(stream),  | 
 | 236 | +            Err(e) => last_err = Some(e),  | 
 | 237 | +        }  | 
287 | 238 |     }  | 
 | 239 | + | 
 | 240 | +    // If we reach this point, it means we failed to connect to any of the addresses.  | 
 | 241 | +    // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.  | 
 | 242 | +    Err(last_err  | 
 | 243 | +        .unwrap_or_else(|| {  | 
 | 244 | +            io::Error::new(  | 
 | 245 | +                io::ErrorKind::AddrNotAvailable,  | 
 | 246 | +                "Hostname did not resolve to any addresses",  | 
 | 247 | +            )  | 
 | 248 | +        })  | 
 | 249 | +        .into())  | 
288 | 250 | }  | 
289 | 251 | 
 
  | 
290 | 252 | /// Connect a Unix Domain Socket at the given path.  | 
 | 
0 commit comments