Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implementing the Copy trait for Extension #126

Merged
merged 1 commit into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ impl TcpConnector<'_> {
(Some(cidr), _) => match cidr {
IpCidr::V4(cidr) => {
let ip = IpAddr::V4(
assign_ipv4_from_extension(cidr, self.inner.cidr_range, &extension).await,
assign_ipv4_from_extension(cidr, self.inner.cidr_range, extension).await,
);
Ok(SocketAddr::new(ip, 0))
}
IpCidr::V6(cidr) => {
let ip = IpAddr::V6(
assign_ipv6_from_extension(cidr, self.inner.cidr_range, &extension).await,
assign_ipv6_from_extension(cidr, self.inner.cidr_range, extension).await,
);
Ok(SocketAddr::new(ip, 0))
}
Expand Down Expand Up @@ -228,7 +228,7 @@ impl TcpConnector<'_> {
let mut last_err = None;

for target_addr in addrs {
match self.connect(target_addr, &extension).await {
match self.connect(target_addr, extension).await {
Ok(stream) => return Ok(stream),
Err(e) => last_err = Some(e),
};
Expand Down Expand Up @@ -353,7 +353,7 @@ impl TcpConnector<'_> {
pub async fn connect(
&self,
target_addr: SocketAddr,
extension: &Extension,
extension: Extension,
) -> std::io::Result<TcpStream> {
match (self.inner.cidr, self.inner.fallback) {
(None, Some(fallback)) => {
Expand Down Expand Up @@ -417,7 +417,7 @@ impl TcpConnector<'_> {
&self,
target_addr: SocketAddr,
cidr: IpCidr,
extension: &Extension,
extension: Extension,
) -> std::io::Result<TcpStream> {
let socket = self.create_socket_with_cidr(cidr, extension).await?;
socket.connect(target_addr).await
Expand Down Expand Up @@ -490,7 +490,7 @@ impl TcpConnector<'_> {
target_addr: SocketAddr,
cidr: IpCidr,
fallback: IpAddr,
extension: &Extension,
extension: Extension,
) -> std::io::Result<TcpStream> {
match self.connect_with_cidr(target_addr, cidr, extension).await {
Ok(first) => Ok(first),
Expand Down Expand Up @@ -562,7 +562,7 @@ impl TcpConnector<'_> {
async fn create_socket_with_cidr(
&self,
cidr: IpCidr,
extension: &Extension,
extension: Extension,
) -> std::io::Result<TcpSocket> {
match cidr {
IpCidr::V4(cidr) => {
Expand Down Expand Up @@ -625,9 +625,9 @@ impl UdpConnector<'_> {
pub async fn bind_socket(&self, extension: Extension) -> std::io::Result<UdpSocket> {
match (self.inner.cidr, self.inner.fallback) {
(None, Some(fallback)) => self.create_socket_with_addr(fallback).await,
(Some(cidr), None) => self.create_socket_with_cidr(cidr, &extension).await,
(Some(cidr), None) => self.create_socket_with_cidr(cidr, extension).await,
(Some(cidr), Some(fallback)) => {
self.create_socket_with_cidr_and_fallback(cidr, fallback, &extension)
self.create_socket_with_cidr_and_fallback(cidr, fallback, extension)
.await
}
(None, None) => UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 0))).await,
Expand Down Expand Up @@ -765,7 +765,7 @@ impl UdpConnector<'_> {
async fn create_socket_with_cidr(
&self,
cidr: IpCidr,
extension: &Extension,
extension: Extension,
) -> std::io::Result<UdpSocket> {
match cidr {
IpCidr::V4(cidr) => {
Expand Down Expand Up @@ -805,7 +805,7 @@ impl UdpConnector<'_> {
&self,
cidr: IpCidr,
fallback: IpAddr,
extension: &Extension,
extension: Extension,
) -> std::io::Result<UdpSocket> {
match self.create_socket_with_cidr(cidr, extension).await {
Ok(first) => Ok(first),
Expand Down Expand Up @@ -871,19 +871,19 @@ impl HttpConnector<'_> {
let mut connector = self.inner.http.clone();
match (self.inner.cidr, self.inner.fallback) {
(Some(IpCidr::V4(cidr)), Some(IpAddr::V6(v6))) => {
let v4 = assign_ipv4_from_extension(cidr, self.inner.cidr_range, &extension).await;
let v4 = assign_ipv4_from_extension(cidr, self.inner.cidr_range, extension).await;
connector.set_local_addresses(v4, v6);
}
(Some(IpCidr::V4(cidr)), None) => {
let v4 = assign_ipv4_from_extension(cidr, self.inner.cidr_range, &extension).await;
let v4 = assign_ipv4_from_extension(cidr, self.inner.cidr_range, extension).await;
connector.set_local_address(Some(v4.into()));
}
(Some(IpCidr::V6(cidr)), Some(IpAddr::V4(v4))) => {
let v6 = assign_ipv6_from_extension(cidr, self.inner.cidr_range, &extension).await;
let v6 = assign_ipv6_from_extension(cidr, self.inner.cidr_range, extension).await;
connector.set_local_addresses(v4, v6);
}
(Some(IpCidr::V6(cidr)), None) => {
let v6 = assign_ipv6_from_extension(cidr, self.inner.cidr_range, &extension).await;
let v6 = assign_ipv6_from_extension(cidr, self.inner.cidr_range, extension).await;
connector.set_local_address(Some(v6.into()));
}
(None, addr) => connector.set_local_address(addr),
Expand Down Expand Up @@ -940,7 +940,7 @@ fn error(last_err: Option<std::io::Error>) -> std::io::Error {
async fn assign_ipv4_from_extension(
cidr: Ipv4Cidr,
cidr_range: Option<u8>,
extension: &Extension,
extension: Extension,
) -> Ipv4Addr {
if let Some(combined) = combined(extension).await {
match extension {
Expand Down Expand Up @@ -975,7 +975,7 @@ async fn assign_ipv4_from_extension(
async fn assign_ipv6_from_extension(
cidr: Ipv6Cidr,
cidr_range: Option<u8>,
extension: &Extension,
extension: Extension,
) -> Ipv6Addr {
if let Some(combined) = combined(extension).await {
match extension {
Expand Down Expand Up @@ -1137,8 +1137,8 @@ fn assign_ipv6_with_range(cidr: Ipv6Cidr, range: u8, combined: u128) -> Ipv6Addr
///
/// Returns an `Option<u64>` which is `Some(combined_value)` if the operation
/// is applicable and successful, or `None` if the `extension` variant does not
async fn combined(extension: &Extension) -> Option<u64> {
match *extension {
async fn combined(extension: Extension) -> Option<u64> {
match extension {
Extension::Range(value) => Some(value),
Extension::Session(value) => Some(value),
Extension::TTL(ttl) => tokio::task::spawn_blocking(move || {
Expand Down Expand Up @@ -1201,7 +1201,7 @@ mod tests {
async fn test_assign_ipv4_from_extension() {
let cidr = "2001:470:e953::/48".parse().unwrap();
let extension = Extension::Session(0x12345);
let ipv6_address = assign_ipv6_from_extension(cidr, None, &extension).await;
let ipv6_address = assign_ipv6_from_extension(cidr, None, extension).await;
assert_eq!(
ipv6_address,
std::net::Ipv6Addr::from([0x2001, 0x470, 0xe953, 0, 0, 0, 1, 0x2345])
Expand Down
2 changes: 1 addition & 1 deletion src/extension.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/// Enum representing different types of extensions.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, Default)]
#[derive(Clone, Copy, Debug, Default)]
pub enum Extension {
#[default]
None,
Expand Down
2 changes: 1 addition & 1 deletion src/socks/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async fn hanlde_connect_proxy(
.connect_with_domain((domain, port), extension)
.await
}
Address::SocketAddress(socket_addr) => connector.connect(socket_addr, &extension).await,
Address::SocketAddress(socket_addr) => connector.connect(socket_addr, extension).await,
};

match target_stream {
Expand Down
Loading