diff --git a/rust/kernel/task.rs b/rust/kernel/task.rs index fb17196a6664b6..40daf6eff42bb4 100644 --- a/rust/kernel/task.rs +++ b/rust/kernel/task.rs @@ -4,8 +4,12 @@ //! //! C header: [`include/linux/sched.h`](../../../../include/linux/sched.h). -use crate::{bindings, ARef, AlwaysRefCounted}; -use core::{cell::UnsafeCell, marker::PhantomData, ops::Deref, ptr}; +use crate::{ + bindings, c_str, c_types, error::from_kernel_err_ptr, types::PointerWrapper, ARef, + AlwaysRefCounted, Result, ScopeGuard, +}; +use alloc::boxed::Box; +use core::{cell::UnsafeCell, fmt, marker::PhantomData, ops::Deref, ptr}; /// Wraps the kernel's `struct task_struct`. /// @@ -101,6 +105,88 @@ impl Task { // SAFETY: By the type invariant, we know that `self.0` is valid. unsafe { bindings::signal_pending(self.0.get()) != 0 } } + + /// Starts a new kernel thread and runs it. + /// + /// # Examples + /// + /// Launches 10 threads and waits for them to complete. + /// + /// ``` + /// use kernel::task::Task; + /// use kernel::sync::{CondVar, Mutex}; + /// use core::sync::atomic::{AtomicU32, Ordering}; + /// + /// kernel::init_static_sync! { + /// static COUNT: Mutex = 0; + /// static COUNT_IS_ZERO: CondVar; + /// } + /// + /// fn threadfn() { + /// pr_info!("Running from thread {}\n", Task::current().pid()); + /// let mut guard = COUNT.lock(); + /// *guard -= 1; + /// if *guard == 0 { + /// COUNT_IS_ZERO.notify_all(); + /// } + /// } + /// + /// // Set count to 10 and spawn 10 threads. + /// *COUNT.lock() = 10; + /// for i in 0..10 { + /// Task::spawn(fmt!("test{i}"), threadfn).unwrap(); + /// } + /// + /// // Wait for count to drop to zero. + /// let mut guard = COUNT.lock(); + /// while (*guard != 0) { + /// COUNT_IS_ZERO.wait(&mut guard); + /// } + /// ``` + pub fn spawn( + name: fmt::Arguments<'_>, + func: T, + ) -> Result> { + unsafe extern "C" fn threadfn( + arg: *mut c_types::c_void, + ) -> c_types::c_int { + // SAFETY: The thread argument is always a `Box` because it is only called via the + // thread creation below. + let bfunc = unsafe { Box::::from_pointer(arg) }; + bfunc(); + 0 + } + + let arg = Box::try_new(func)?.into_pointer(); + + // SAFETY: `arg` was just created with a call to `into_pointer` above. + let guard = ScopeGuard::new(|| unsafe { + Box::::from_pointer(arg); + }); + + // SAFETY: The function pointer is always valid (as long as the module remains loaded). + // Ownership of `raw` is transferred to the new thread (if one is actually created), so it + // remains valid. Lastly, the C format string is a constant that require formatting as the + // one and only extra argument. + let ktask = from_kernel_err_ptr(unsafe { + bindings::kthread_create_on_node( + Some(threadfn::), + arg as _, + bindings::NUMA_NO_NODE, + c_str!("%pA").as_char_ptr(), + &name as *const _ as *const c_types::c_void, + ) + })?; + + // SAFETY: Since the kthread creation succeeded and we haven't run it yet, we know the task + // is valid. + let task = unsafe { &*(ktask as *const Task) }.into(); + + // SAFETY: Since the kthread creation succeeded, we know `ktask` is valid. + unsafe { bindings::wake_up_process(ktask) }; + guard.dismiss(); + Ok(task) + } } // SAFETY: The type invariants guarantee that `Task` is always ref-counted.