diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 2fc07e9cb..cdcccd77f 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -119,7 +119,7 @@ use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::ErrorKind; use tokio::time::sleep; -use url::Url; +use url::{Host, Url}; use bytes::{Bytes, BytesMut}; use serde::{Deserialize, Serialize}; @@ -1165,7 +1165,15 @@ impl ServerAddr { /// Returns the host. pub fn host(&self) -> &str { - self.0.host_str().unwrap() + match self.0.host() { + Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(), + // `host_str()` for Ipv6 includes the []s + Some(Host::Ipv6 { .. }) => { + let host = self.0.host_str().unwrap(); + &host[1..host.len() - 1] + } + None => "", + } } /// Returns the port. @@ -1264,3 +1272,26 @@ pub(crate) enum Authorization { CallbackArg1>, ), } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn server_address_ipv6() { + let address = ServerAddr::from_str("nats://[::]").unwrap(); + assert_eq!(address.host(), "::") + } + + #[test] + fn serverr_address_ipv4() { + let address = ServerAddr::from_str("nats://127.0.0.1").unwrap(); + assert_eq!(address.host(), "127.0.0.1") + } + + #[test] + fn serverr_address_domain() { + let address = ServerAddr::from_str("nats://example.com").unwrap(); + assert_eq!(address.host(), "example.com") + } +} diff --git a/nats/src/connector.rs b/nats/src/connector.rs index 56c748ed6..b0e728a6b 100644 --- a/nats/src/connector.rs +++ b/nats/src/connector.rs @@ -20,7 +20,7 @@ use std::str::FromStr; use std::sync::Arc; use std::thread; use std::time::Duration; -use url::Url; +use url::{Host, Url}; use webpki::DNSNameRef; @@ -625,7 +625,15 @@ impl ServerAddress { /// Returns the host. pub fn host(&self) -> &str { - self.0.host_str().unwrap() + match self.0.host() { + Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(), + // `host_str()` for Ipv6 includes the []s + Some(Host::Ipv6 { .. }) => { + let host = self.0.host_str().unwrap(); + &host[1..host.len() - 1] + } + None => "", + } } /// Returns the port. @@ -712,6 +720,24 @@ impl IntoServerList for io::Result> { mod tests { use super::*; + #[test] + fn server_address_ipv6() { + let address = ServerAddress::from_str("nats://[::]").unwrap(); + assert_eq!(address.host(), "::") + } + + #[test] + fn server_address_ipv4() { + let address = ServerAddress::from_str("nats://127.0.0.1").unwrap(); + assert_eq!(address.host(), "127.0.0.1") + } + + #[test] + fn server_address_domain() { + let address = ServerAddress::from_str("nats://example.com").unwrap(); + assert_eq!(address.host(), "example.com") + } + #[test] fn server_address_no_auth() { let address = ServerAddress::from_str("nats://localhost").unwrap();