Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/kyron/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,15 @@ rust_binary(
visibility = ["//visibility:public"],
deps = _EXAMPLE_DEPS,
)

rust_binary(
name = "safety_task",
srcs = [
"examples/safety_task.rs",
],
proc_macro_deps = [
"//src/kyron-macros:runtime_macros",
],
visibility = ["//visibility:public"],
deps = _EXAMPLE_DEPS,
)
69 changes: 69 additions & 0 deletions src/kyron/examples/safety_task.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//
// Copyright (c) 2025 Contributors to the Eclipse Foundation
//
// See the NOTICE file(s) distributed with this work for additional
// information regarding copyright ownership.
//
// This program and the accompanying materials are made available under the
// terms of the Apache License Version 2.0 which is available at
// <https://www.apache.org/licenses/LICENSE-2.0>
//
// SPDX-License-Identifier: Apache-2.0
//

use kyron::prelude::*;
use kyron::safety;
use kyron::spawn_on_dedicated;
use kyron_foundation::prelude::*;

async fn failing_safety_task() -> Result<(), String> {
info!("Worker-N: failing_safety_task");
Err("Intentional failure".to_string())
}

async fn passing_safety_task() -> Result<(), String> {
info!("Worker-N: passing_safety_task");
Ok(())
}

async fn passing_non_safety_task() -> Result<(), String> {
info!("Dedicated worker (dw1): passing_non_safety_task");
Ok(())
}

fn main() {
tracing_subscriber::fmt()
.with_target(false) // Optional: Remove module path
.with_max_level(Level::DEBUG)
.with_thread_ids(true)
.with_thread_names(true)
.init();

// Create runtime
let (builder, _engine_id) = kyron::runtime::RuntimeBuilder::new().with_engine(
ExecutionEngineBuilder::new()
.task_queue_size(256)
.enable_safety_worker(ThreadParameters::default())
.with_dedicated_worker("dw1".into(), ThreadParameters::default())
.workers(2),
);

let mut runtime = builder.build().unwrap();
// Put programs into runtime and run them
runtime.block_on(async move {
let handle1 = safety::spawn(failing_safety_task());
let handle2 = safety::spawn(passing_safety_task());
let handle3 = spawn_on_dedicated(passing_non_safety_task(), "dw1".into());

info!("=============================== Spawned all tasks ===============================");

let _ = handle1.await;
info!("Safety worker: Since safety task fails, safety worker executes parent task from this statement onwards.");
let _ = handle2.await;
let _ = handle3.await;

info!("Safety worker: Program finished running.");
});

info!("Exit.");
}
22 changes: 22 additions & 0 deletions src/kyron/src/scheduler/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ pub(crate) struct WorkerContext {
/// Helper flag to check if safety was enabled in runtime builder
is_safety_enabled: bool,

/// This flag is used to schedule parent task of failing safety task into safety worker
schedule_safety: Cell<bool>,

wakeup_time: Cell<Option<u64>>,
}

Expand Down Expand Up @@ -399,6 +402,7 @@ impl ContextBuilder {
worker_id: Cell::new(self.worker_id.expect("Worker type must be set in context builder!")),
handler: RefCell::new(Some(Rc::new(self.handle.expect("Handler type must be set in context builder!")))),
is_safety_enabled: self.is_with_safety,
schedule_safety: Cell::new(false),
wakeup_time: Cell::new(None),
drivers: Some(self.drivers),
}
Expand Down Expand Up @@ -444,6 +448,24 @@ pub(crate) fn ctx_get_worker_id() -> WorkerId {
})
}

///
/// Set schedule safety flag
///
#[allow(dead_code)] // To avoid error when runtime mocking feature is enabled
pub(crate) fn ctx_set_schedule_safety(val: bool) {
CTX.try_with(|ctx| ctx.borrow().as_ref().expect("Called before CTX init?").schedule_safety.set(val))
.unwrap_or_default();
}

///
/// Get schedule safety flag and clear
///
#[allow(dead_code)]
pub(crate) fn ctx_get_schedule_safety() -> bool {
CTX.try_with(|ctx| ctx.borrow().as_ref().expect("Called before CTX init?").schedule_safety.replace(false))
.unwrap_or_default()
}

///
/// Check if safety was enabled
///
Expand Down
101 changes: 75 additions & 26 deletions src/kyron/src/scheduler/join_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
use kyron_foundation::prelude::*;
use kyron_foundation::{not_recoverable_error, prelude::CommonErrors};

use crate::scheduler::task::task_context::TaskContext;
use crate::{
futures::{FutureInternalReturn, FutureState},
TaskRef,
Expand Down Expand Up @@ -74,34 +75,52 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
if was_set {
FutureInternalReturn::default()
} else {
// Check whether there is safety error for the completed task and this task is running on async worker
// if this task is already running on safety worker/dedicated worker, do not set the flag to schedule on safety worker.
if self.for_task.get_task_safety_error() && TaskContext::is_task_running_on_async_worker() {
// Set the flag to wake this task into safety worker
TaskContext::set_flag_to_wake_parent_task_into_safety();
waker.wake_by_ref();
FutureInternalReturn::polled()
} else {
let mut ret: Result<T, CommonErrors> = Err(CommonErrors::NoData);
let ret_as_ptr = &mut ret as *mut _;
self.for_task.get_return_val(ret_as_ptr as *mut u8);

match ret {
Ok(v) => FutureInternalReturn::ready(Ok(v)),
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
Err(e) => {
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
}
}
}
}
}
FutureState::Polled => {
let waker = cx.waker();

// Set the waker, return values tells what have happen and took care about correct synchronization
let was_set = self.for_task.set_join_handle_waker(waker.clone());

if was_set {
FutureInternalReturn::default()
} else {
// Safety belows forms AqrRel so waker is really written before we do marking
let mut ret: Result<T, CommonErrors> = Err(CommonErrors::NoData);
let ret_as_ptr = &mut ret as *mut _;
self.for_task.get_return_val(ret_as_ptr as *mut u8);

match ret {
Ok(v) => FutureInternalReturn::ready(Ok(v)),
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
Err(e) => {
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
}
}
}
}
FutureState::Polled => {
// Safety belows forms AqrRel so waker is really written before we do marking
let mut ret: Result<T, CommonErrors> = Err(CommonErrors::NoData);
let ret_as_ptr = &mut ret as *mut _;
self.for_task.get_return_val(ret_as_ptr as *mut u8);

match ret {
Ok(v) => FutureInternalReturn::ready(Ok(v)),
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
Err(e) => {
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
}
}
}
FutureState::Finished => {
not_recoverable_error!("Future polled after it finished!");
}
Expand Down Expand Up @@ -256,6 +275,40 @@ mod tests {
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
}
}

#[test]
fn test_join_handle_waker_is_set_in_polled_state_also() {
let scheduler = create_mock_scheduler();

{
// Data is present before first poll of join handle
let task = ArcInternal::new(AsyncTask::new(box_future(test_function::<u32>()), 1, scheduler.clone()));

let handle = JoinHandle::<u32>::new(TaskRef::new(task.clone()));

let mut poller = TestingFuturePoller::new(handle);

let waker_mock1 = TrackableWaker::new();
let waker1 = waker_mock1.get_waker();

let waker_mock2 = TrackableWaker::new();
let waker2 = waker_mock2.get_waker();

let _ = poller.poll_with_waker(&waker1);
// Now in polled state, poll again with waker2
let _ = poller.poll_with_waker(&waker2);
{
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
task.poll(&mut cx); // task done
}

assert!(!waker_mock1.was_waked());
// this should be TRUE
assert!(waker_mock2.was_waked());
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
}
}
}

#[cfg(test)]
Expand All @@ -277,8 +330,9 @@ mod tests {

#[test]
fn test_join_handler_mt_get_result() {
let builder = Builder::new();

let mut builder = Builder::new();
// Limit preemption to avoid loom error "Model exceeded maximum number of branches."
builder.preemption_bound = Some(4);
builder.check(|| {
let scheduler = create_mock_scheduler();

Expand All @@ -299,22 +353,17 @@ mod tests {

let waker_mock = TrackableWaker::new();
let waker = waker_mock.get_waker();
let mut was_pending = false;

loop {
match poller.poll_with_waker(&waker) {
Poll::Ready(v) => {
assert_eq!(v, Ok(1234));

if was_pending {
assert!(waker_mock.was_waked());
}
// Note:
// Cannot check whether the waker was woken or not since the waker is set in the join handle poll every time if task is not yet done.
// So depending on the interleaving, the task may finish before the waker is set.

break;
}
Poll::Pending => {
was_pending = true;
}
Poll::Pending => {}
}
loom::hint::spin_loop();
}
Expand Down
14 changes: 9 additions & 5 deletions src/kyron/src/scheduler/safety_waker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//

use super::task::async_task::*;
use crate::scheduler::task::task_context::TaskContext;
use core::task::{RawWaker, RawWakerVTable, Waker};

fn clone_waker(data: *const ()) -> RawWaker {
Expand All @@ -30,13 +31,18 @@ fn wake(data: *const ()) {
let task_header_ptr = data as *const TaskHeader;
let task_ref = unsafe { TaskRef::from_raw(task_header_ptr) };

// Just clear the flag which might have been set by async worker before calling wake/wake_by_ref
// for the scenario where the join handle poll is executed by safety worker and waker is set
TaskContext::clear_schedule_safety_flag();
task_ref.schedule_safety();
}

fn wake_by_ref(data: *const ()) {
let task_header_ptr = data as *const TaskHeader;
let task_ref = unsafe { TaskRef::from_raw(task_header_ptr) };

// Just clear the flag which might have been set by async worker before calling wake/wake_by_ref
TaskContext::clear_schedule_safety_flag();
task_ref.schedule_safety_by_ref();

::core::mem::forget(task_ref); // don't touch refcount from our data since this is done by drop_waker
Expand All @@ -55,11 +61,9 @@ static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_waker, wake, wake_by_r
///
/// Waker will store internally a pointer to the ref counted Task.
///
pub(crate) unsafe fn create_safety_waker(waker: Waker) -> Waker {
let raw_waker = RawWaker::new(waker.data(), &VTABLE);

// Forget original as we took over the ownership, so ref count
::core::mem::forget(waker);
pub(crate) fn create_safety_waker(ptr: TaskRef) -> Waker {
let ptr = TaskRef::into_raw(ptr); // Extracts the pointer from TaskRef not decreasing it's reference count. Since we have a clone here, ref cnt was already increased
let raw_waker = RawWaker::new(ptr as *const (), &VTABLE);

// Convert RawWaker to Waker
unsafe { Waker::from_raw(raw_waker) }
Expand Down
Loading
Loading