From 81f57625a27f4fe43e6569cc60469bf9cdbbcb8a Mon Sep 17 00:00:00 2001
From: kolapapa <346512016@qq.com>
Date: Wed, 31 Mar 2021 01:49:10 +0800
Subject: [PATCH] feat(client): http_connector supports `set_interface`

---
 Cargo.toml                 |  2 +-
 src/client/connect/http.rs | 90 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 91 insertions(+), 1 deletion(-)

diff --git a/Cargo.toml b/Cargo.toml
index 02d6d333c4..f96745d674 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,7 +42,7 @@ want = "0.3"
 # Optional
 
 libc = { version = "0.2", optional = true }
-socket2 = { version = "0.4", optional = true }
+socket2 = { version = "0.4", optional = true, features = ["all"] }
 
 [dev-dependencies]
 futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs
index 4437c86380..38eeb91326 100644
--- a/src/client/connect/http.rs
+++ b/src/client/connect/http.rs
@@ -79,6 +79,7 @@ struct Config {
     reuse_address: bool,
     send_buffer_size: Option<usize>,
     recv_buffer_size: Option<usize>,
+    interface: Option<String>,
 }
 
 // ===== impl HttpConnector =====
@@ -119,6 +120,7 @@ impl<R> HttpConnector<R> {
                 reuse_address: false,
                 send_buffer_size: None,
                 recv_buffer_size: None,
+                interface: None,
             }),
             resolver,
         }
@@ -228,6 +230,25 @@ impl<R> HttpConnector<R> {
         self
     }
 
+    /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
+    ///
+    /// If a socket is bound to an interface, only packets received from that particular
+    /// interface are processed by the socket. Note that this only works for some socket
+    /// types, particularly AF_INET sockets.
+    ///
+    /// On Linux it can be used to specify a [VRF], but the binary needs
+    /// to either have `CAP_NET_RAW` or to be run as root.
+    ///
+    /// This function is only available on Android、Fuchsia and Linux.
+    ///
+    /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt
+    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
+    #[inline]
+    pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
+        self.config_mut().interface = Some(interface.into());
+        self
+    }
+
     // private
 
     fn config_mut(&mut self) -> &mut Config {
@@ -612,6 +633,14 @@ fn connect(
     )
     .map_err(ConnectError::m("tcp bind local error"))?;
 
+    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
+    // That this only works for some socket types, particularly AF_INET sockets.
+    if config.interface.is_some() {
+        socket
+            .bind_device(config.interface.as_ref().map(|iface| iface.as_bytes()))
+            .map_err(ConnectError::m("tcp bind interface error"))?;
+    }
+
     #[cfg(unix)]
     let socket = unsafe {
         // Safety: `from_raw_fd` is only safe to call if ownership of the raw
@@ -756,6 +785,14 @@ mod tests {
         (ip_v4, ip_v6)
     }
 
+    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
+    fn default_interface() -> Option<String> {
+        pnet_datalink::interfaces()
+            .iter()
+            .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
+            .map(|e| e.name.clone())
+    }
+
     #[tokio::test]
     async fn test_errors_missing_scheme() {
         let dst = "example.domain".parse().unwrap();
@@ -804,6 +841,58 @@ mod tests {
         }
     }
 
+    // NOTE: pnet crate that we use in this test doesn't compile on Windows
+    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
+    #[tokio::test]
+    #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
+    async fn interface() {
+        use socket2::{Domain, Protocol, Socket, Type};
+        use std::net::TcpListener;
+        let _ = pretty_env_logger::try_init();
+
+        let interface: Option<String> = default_interface();
+
+        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
+        let port = server4.local_addr().unwrap().port();
+
+        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
+
+        let assert_interface_name =
+            |dst: String,
+             server: TcpListener,
+             bind_iface: Option<String>,
+             expected_interface: Option<String>| async move {
+                let mut connector = HttpConnector::new();
+                if let Some(iface) = bind_iface {
+                    connector.set_interface(iface);
+                }
+
+                connect(connector, dst.parse().unwrap()).await.unwrap();
+                let domain = Domain::for_address(server.local_addr().unwrap());
+                let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap();
+
+                assert_eq!(
+                    socket.device().unwrap().as_deref(),
+                    expected_interface.as_deref().map(|val| val.as_bytes())
+                );
+            };
+
+        assert_interface_name(
+            format!("http://127.0.0.1:{}", port),
+            server4,
+            interface.clone(),
+            interface.clone(),
+        )
+        .await;
+        assert_interface_name(
+            format!("http://[::1]:{}", port),
+            server6,
+            interface.clone(),
+            interface.clone(),
+        )
+        .await;
+    }
+
     #[test]
     #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
     fn client_happy_eyeballs() {
@@ -933,6 +1022,7 @@ mod tests {
                         enforce_http: false,
                         send_buffer_size: None,
                         recv_buffer_size: None,
+                        interface: None,
                     };
                     let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                     let start = Instant::now();