From cea7d96f2db3e9216964c01a520ddeb84e33b410 Mon Sep 17 00:00:00 2001 From: Simonas Kazlauskas Date: Thu, 29 Oct 2015 21:04:05 +0200 Subject: [PATCH] Implement Partial{,Eq} for JoinHandle & os::Thread --- src/liblibc/lib.rs | 1 + src/libstd/sys/unix/thread.rs | 10 ++++++++++ src/libstd/sys/windows/handle.rs | 11 ++++++++++- src/libstd/sys/windows/thread.rs | 1 + src/libstd/thread/mod.rs | 21 +++++++++++++++++++++ 5 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/liblibc/lib.rs b/src/liblibc/lib.rs index 13902d674379e..5d2700b1f2a8a 100644 --- a/src/liblibc/lib.rs +++ b/src/liblibc/lib.rs @@ -7029,6 +7029,7 @@ pub mod funcs { dwOptions: DWORD) -> BOOL; pub fn CloseHandle(hObject: HANDLE) -> BOOL; + pub fn CompareObjectHandles(h1: HANDLE, h2: HANDLE) -> BOOL; pub fn OpenProcess(dwDesiredAccess: DWORD, bInheritHandle: BOOL, dwProcessId: DWORD) diff --git a/src/libstd/sys/unix/thread.rs b/src/libstd/sys/unix/thread.rs index 3eedb76c21b72..7ba1615d26e74 100644 --- a/src/libstd/sys/unix/thread.rs +++ b/src/libstd/sys/unix/thread.rs @@ -26,6 +26,7 @@ use time::Duration; use sys_common::thread::*; +#[derive(Eq)] pub struct Thread { id: libc::pthread_t, } @@ -169,6 +170,14 @@ impl Thread { } } +impl PartialEq for Thread { + fn eq(&self, other: &Self) -> bool { + unsafe { + pthread_equal(self.id, other.id) != 0 + } + } +} + impl Drop for Thread { fn drop(&mut self) { let ret = unsafe { pthread_detach(self.id) }; @@ -403,6 +412,7 @@ extern { value: *mut libc::c_void) -> libc::c_int; fn pthread_join(native: libc::pthread_t, value: *mut *mut libc::c_void) -> libc::c_int; + fn pthread_equal(t1: libc::pthread_t, t2: libc::pthread_t) -> libc::c_int; fn pthread_attr_init(attr: *mut libc::pthread_attr_t) -> libc::c_int; fn pthread_attr_destroy(attr: *mut libc::pthread_attr_t) -> libc::c_int; fn pthread_attr_setstacksize(attr: *mut libc::pthread_attr_t, diff --git a/src/libstd/sys/windows/handle.rs b/src/libstd/sys/windows/handle.rs index a9e9b0e252077..812dd922e7bc0 100644 --- a/src/libstd/sys/windows/handle.rs +++ b/src/libstd/sys/windows/handle.rs @@ -20,6 +20,7 @@ use sys::cvt; /// An owned container for `HANDLE` object, closing them on Drop. /// /// All methods are inherited through a `Deref` impl to `RawHandle` +#[derive(PartialEq, Eq)] pub struct Handle(RawHandle); /// A wrapper type for `HANDLE` objects to give them proper Send/Sync inference @@ -27,7 +28,7 @@ pub struct Handle(RawHandle); /// /// This does **not** drop the handle when it goes out of scope, use `Handle` /// instead for that. -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Eq)] pub struct RawHandle(HANDLE); unsafe impl Send for RawHandle {} @@ -106,3 +107,11 @@ impl RawHandle { Ok(Handle::new(ret)) } } + +impl PartialEq for RawHandle { + fn eq(&self, other: &Self) -> bool { + unsafe { + libc::CompareObjectHandles(self.0, other.0) != libc::FALSE + } + } +} diff --git a/src/libstd/sys/windows/thread.rs b/src/libstd/sys/windows/thread.rs index cf1b3ebddb97b..4fef56970721a 100644 --- a/src/libstd/sys/windows/thread.rs +++ b/src/libstd/sys/windows/thread.rs @@ -20,6 +20,7 @@ use sys::handle::Handle; use sys_common::thread::*; use time::Duration; +#[derive(PartialEq, Eq)] pub struct Thread { handle: Handle } diff --git a/src/libstd/thread/mod.rs b/src/libstd/thread/mod.rs index 9b8f63997b642..69ed668413052 100644 --- a/src/libstd/thread/mod.rs +++ b/src/libstd/thread/mod.rs @@ -590,6 +590,14 @@ impl JoinInner { } } +impl PartialEq for JoinInner { + fn eq(&self, other: &Self) -> bool { + self.native == other.native + } +} + +impl Eq for JoinInner {} + /// An owned permission to join on a thread (block on its termination). /// /// A `JoinHandle` *detaches* the child thread when it is dropped. @@ -597,6 +605,7 @@ impl JoinInner { /// Due to platform restrictions, it is not possible to `Clone` this /// handle: the ability to join a child thread is a uniquely-owned /// permission. +#[derive(Eq, PartialEq)] #[stable(feature = "rust1", since = "1.0.0")] pub struct JoinHandle(JoinInner); @@ -633,6 +642,7 @@ mod tests { use any::Any; use sync::mpsc::{channel, Sender}; + use sync::{Barrier, Arc}; use result; use super::{Builder}; use thread; @@ -665,6 +675,17 @@ mod tests { rx.recv().unwrap(); } + #[test] + fn test_thread_guard_equality() { + let barrier = Arc::new(Barrier::new(2)); + let b = barrier.clone(); + let h = thread::spawn(move|| { + b.wait(); + }); + assert!(h == h); + barrier.wait(); + } + #[test] fn test_join_panic() { match thread::spawn(move|| {