diff --git a/crates/automations/src/executors.rs b/crates/automations/src/executors.rs index 27c67ac718..f6c18002c8 100644 --- a/crates/automations/src/executors.rs +++ b/crates/automations/src/executors.rs @@ -1,4 +1,4 @@ -use super::{server, BoxedRaw, Executor, PollOutcome, TaskType}; +use super::{server, Action, BoxedRaw, Executor, Outcome, TaskType}; use anyhow::Context; use futures::future::{BoxFuture, FutureExt}; use sqlx::types::Json as SqlJson; @@ -9,11 +9,12 @@ pub trait ObjSafe: Send + Sync + 'static { fn poll<'s>( &'s self, + pool: &'s sqlx::PgPool, task_id: models::Id, parent_id: Option, - state: &'s mut Option>, - inbox: &'s mut Option)>>>, - ) -> BoxFuture<'s, anyhow::Result>>; + state: Option>, + inbox: Option)>>>, + ) -> BoxFuture<'s, anyhow::Result<()>>; } impl ObjSafe for E { @@ -23,11 +24,12 @@ impl ObjSafe for E { fn poll<'s>( &'s self, + pool: &'s sqlx::PgPool, task_id: models::Id, parent_id: Option, - state: &'s mut Option>, - inbox: &'s mut Option)>>>, - ) -> BoxFuture<'s, anyhow::Result>> { + mut state: Option>, + mut inbox: Option)>>>, + ) -> BoxFuture<'s, anyhow::Result<()>> { async move { let mut state_parsed: E::State = if let Some(state) = state { serde_json::from_str(state.get()).context("failed to decode task state")? @@ -52,6 +54,7 @@ impl ObjSafe for E { let outcome = E::poll( self, + pool, task_id, parent_id, &mut state_parsed, @@ -60,21 +63,16 @@ impl ObjSafe for E { .await?; // Re-encode state for persistence. - // If we're Done, then the output state is NULL which is implicitly Default. - if matches!(outcome, PollOutcome::Done) { - *state = None - } else { - *state = Some(SqlJson( - serde_json::value::to_raw_value(&state_parsed) - .context("failed to encode inner state")?, - )); - } + state = Some(SqlJson( + serde_json::value::to_raw_value(&state_parsed) + .context("failed to encode inner state")?, + )); // Re-encode the unconsumed portion of the inbox. if inbox_parsed.is_empty() { - *inbox = None + inbox = None } else { - *inbox = Some( + inbox = Some( inbox_parsed .into_iter() .map(|(task_id, msg)| { @@ -91,19 +89,16 @@ impl ObjSafe for E { ); } - Ok(match outcome { - PollOutcome::Done => PollOutcome::Done, - PollOutcome::Send(task_id, msg) => PollOutcome::Send(task_id, msg), - PollOutcome::Sleep(interval) => PollOutcome::Sleep(interval), - PollOutcome::Spawn(task_id, task_type, msg) => { - PollOutcome::Spawn(task_id, task_type, msg) - } - PollOutcome::Suspend => PollOutcome::Suspend, - PollOutcome::Yield(msg) => PollOutcome::Yield( - serde_json::value::to_raw_value(&msg) - .context("failed to encode yielded message")?, - ), - }) + let mut txn = pool.begin().await?; + + let action = outcome + .apply(&mut *txn) + .await + .context("failed to apply task Outcome")?; + + () = persist_action(action, &mut *txn, task_id, parent_id, state, inbox).await?; + + Ok(txn.commit().await?) } .boxed() } @@ -119,8 +114,8 @@ pub async fn poll_task( id: task_id, type_: _, parent_id, - mut inbox, - mut state, + inbox, + state, mut last_heartbeat, }, }: server::ReadyTask, @@ -145,24 +140,9 @@ pub async fn poll_task( // Poll `executor` and `update_heartbeats` in tandem, so that a failure // to update our heartbeat also cancels the executor. - let outcome = tokio::select! { - outcome = executor.poll(task_id, parent_id, &mut state, &mut inbox) => { outcome? }, - err = &mut update_heartbeats => return Err(err), - }; - - // The possibly long-lived polling operation is now complete. - // Build a Future that commits a (hopefully) brief transaction of `outcome`. - let persist_outcome = async { - let mut txn = pool.begin().await?; - () = persist_outcome(outcome, &mut *txn, task_id, parent_id, state, inbox).await?; - Ok(txn.commit().await?) - }; - - // Poll `persist_outcome` while continuing to poll `update_heartbeats`, - // to guarantee we cannot commit an outcome after our lease is lost. tokio::select! { - result = persist_outcome => result, - err = update_heartbeats => Err(err), + result = executor.poll(&pool, task_id, parent_id, state, inbox) => result, + err = &mut update_heartbeats => return Err(err), } } @@ -203,17 +183,17 @@ async fn update_heartbeat( Ok(updated.heartbeat) } -async fn persist_outcome( - outcome: PollOutcome, +async fn persist_action( + action: Action, txn: &mut sqlx::PgConnection, task_id: models::Id, parent_id: Option, - state: Option>, + mut state: Option>, inbox: Option)>>>, ) -> anyhow::Result<()> { use std::time::Duration; - if let PollOutcome::Spawn(spawn_id, spawn_type, _msg) = &outcome { + if let Action::Spawn(spawn_id, spawn_type, _msg) = &action { sqlx::query!( "SELECT internal.create_task($1, $2, $3)", *spawn_id as models::Id, @@ -225,15 +205,15 @@ async fn persist_outcome( .context("failed to spawn new task")?; } - if let Some((send_id, msg)) = match &outcome { + if let Some((send_id, msg)) = match &action { // When a task is spawned, send its first message. - PollOutcome::Spawn(spawn_id, _spawn_type, msg) => Some((*spawn_id, Some(msg))), + Action::Spawn(spawn_id, _spawn_type, msg) => Some((*spawn_id, Some(msg))), // If we're Done but have a parent, send it an EOF. - PollOutcome::Done => parent_id.map(|parent_id| (parent_id, None)), + Action::Done => parent_id.map(|parent_id| (parent_id, None)), // Send an arbitrary message to an identified task. - PollOutcome::Send(task_id, msg) => Some((*task_id, msg.as_ref())), + Action::Send(task_id, msg) => Some((*task_id, msg.as_ref())), // Yield is sugar for sending to our parent. - PollOutcome::Yield(msg) => { + Action::Yield(msg) => { let Some(parent_id) = parent_id else { anyhow::bail!("task yielded illegally, because it does not have a parent"); }; @@ -255,17 +235,19 @@ async fn persist_outcome( let wake_at_interval = if inbox.is_some() { Some(Duration::ZERO) // Always poll immediately if inbox items remain. } else { - match &outcome { - PollOutcome::Sleep(interval) => Some(*interval), + match &action { + Action::Sleep(interval) => Some(*interval), // These outcomes do not suspend the task, and it should wake as soon as possible. - PollOutcome::Spawn(..) | PollOutcome::Send(..) | PollOutcome::Yield(..) => { - Some(Duration::ZERO) - } + Action::Spawn(..) | Action::Send(..) | Action::Yield(..) => Some(Duration::ZERO), // Suspend indefinitely (note that NOW() + NULL::INTERVAL is NULL). - PollOutcome::Done | PollOutcome::Suspend => None, + Action::Done | Action::Suspend => None, } }; + if let Action::Done = &action { + state = None; // Set to NULL, which is implicit Default. + } + let updated = sqlx::query!( r#" UPDATE internal.tasks SET @@ -292,7 +274,7 @@ async fn persist_outcome( // If we're Done and also successfully suspended, then delete ourselves. // (Otherwise, the task has been left in a like-new state). - if matches!(&outcome, PollOutcome::Done if updated.suspended) { + if matches!(&action, Action::Done if updated.suspended) { sqlx::query!( "DELETE FROM internal.tasks WHERE task_id = $1;", task_id as models::Id, diff --git a/crates/automations/src/lib.rs b/crates/automations/src/lib.rs index 9a43187396..7fe6fb9749 100644 --- a/crates/automations/src/lib.rs +++ b/crates/automations/src/lib.rs @@ -23,16 +23,40 @@ type BoxedRaw = Box; #[sqlx(transparent)] pub struct TaskType(pub i16); -/// PollOutcome is the outcome of an `Executor::poll()` for a given task. +/// Outcome of an `Executor::poll()` for a given task, which encloses +/// an Action with which it's applied as a single transaction. +/// +/// As an example of how Executor, Outcome, and Action are used together, +/// suppose an implementation of `Executor::poll()` is called: +/// +/// - It reads DB state associated with the task using sqlx::PgPool. +/// - It performs long-running work, running outside of a DB transaction. +/// - It returns an Outcome implementation which encapsulates the +/// preconditions it observed, as well as its domain-specific outcome. +/// - `Outcome::apply()` is called and re-verifies preconditions using `txn`, +/// returning an error if preconditions have changed. +/// - It applies the effects of its outcome and returns a polling Action. +/// - `txn` is further by this crate as required by the Action, and then commits. +/// +pub trait Outcome: Send { + /// Apply the effects of an Executor poll. While this is an async routine, + /// apply() runs in the context of a held transaction and should be fast. + fn apply<'s>( + self, + txn: &'s mut sqlx::PgConnection, + ) -> impl std::future::Future> + Send + 's; +} + +/// Action undertaken by an Executor task poll. #[derive(Debug)] -pub enum PollOutcome { +pub enum Action { /// Spawn a new TaskId with the given TaskType and send a first message. /// The TaskId must not exist. Spawn(models::Id, TaskType, BoxedRaw), /// Send a message (Some) or EOF (None) to another TaskId, which must exist. Send(models::Id, Option), /// Yield to send a message to this task's parent. - Yield(Yield), + Yield(BoxedRaw), /// Sleep for at-most the indicated Duration, then poll again. /// The task may be woken earlier if it receives a message. Sleep(std::time::Duration), @@ -43,28 +67,35 @@ pub enum PollOutcome { Done, } +// Action implements an Outcome with no side-effects. +impl Outcome for Action { + async fn apply<'s>(self, _txn: &'s mut sqlx::PgConnection) -> anyhow::Result { + Ok(self) + } +} + /// Executor is the core trait implemented by executors of various task types. pub trait Executor: Send + Sync + 'static { const TASK_TYPE: TaskType; type Receive: serde::de::DeserializeOwned + serde::Serialize + Send; type State: Default + serde::de::DeserializeOwned + serde::Serialize + Send; - type Yield: serde::Serialize; + type Outcome: Outcome; fn poll<'s>( &'s self, + pool: &'s sqlx::PgPool, task_id: models::Id, parent_id: Option, state: &'s mut Self::State, inbox: &'s mut std::collections::VecDeque<(models::Id, Option)>, - ) -> impl std::future::Future>> + Send + 's; + ) -> impl std::future::Future> + Send + 's; } -/// Server holds registered implementations of Executor, -/// and serves them. +/// Server holds registered implementations of Executors and serves them. pub struct Server(Vec>); -impl PollOutcome { +impl Action { pub fn spawn( spawn_id: models::Id, task_type: TaskType, @@ -89,6 +120,12 @@ impl PollOutcome { }, )) } + + pub fn yield_(msg: M) -> anyhow::Result { + Ok(Self::Yield( + serde_json::value::to_raw_value(&msg).context("failed to encode yielded message")?, + )) + } } pub fn next_task_id() -> models::Id { diff --git a/crates/automations/tests/fibonacci.rs b/crates/automations/tests/fibonacci.rs index 83d600a42e..db9d00f48b 100644 --- a/crates/automations/tests/fibonacci.rs +++ b/crates/automations/tests/fibonacci.rs @@ -1,4 +1,4 @@ -use automations::PollOutcome; +use automations::Action; use std::collections::VecDeque; /// Fibonacci is one of the least-efficient calculators of the Fibonacci @@ -40,8 +40,8 @@ impl automations::Executor for Fibonacci { const TASK_TYPE: automations::TaskType = automations::TaskType(32767); type Receive = Message; - type Yield = Message; type State = State; + type Outcome = Action; #[tracing::instrument( ret, @@ -51,11 +51,12 @@ impl automations::Executor for Fibonacci { )] async fn poll<'s>( &'s self, + _pool: &'s sqlx::PgPool, task_id: models::Id, parent_id: Option, state: &'s mut Self::State, inbox: &'s mut VecDeque<(models::Id, Option)>, - ) -> anyhow::Result> { + ) -> anyhow::Result { if rand::random::() < self.failure_rate { return Err( anyhow::anyhow!("A no good, very bad error!").context("something bad happened") @@ -63,7 +64,7 @@ impl automations::Executor for Fibonacci { } if let State::SpawnOne(value) = state { - let spawn = PollOutcome::spawn( + let spawn = Action::spawn( automations::next_task_id(), Self::TASK_TYPE, Message { value: *value - 2 }, @@ -80,14 +81,14 @@ impl automations::Executor for Fibonacci { // Base case: (State::Init, Some((_parent_id, Some(Message { value })))) if value <= 2 => { *state = State::Finished; - Ok(PollOutcome::Yield(Message { value: 1 })) + Action::yield_(Message { value: 1 }) } // Recursive case: (State::Init, Some((_parent_id, Some(Message { value })))) => { *state = State::SpawnOne(value); - PollOutcome::spawn( + Action::spawn( automations::next_task_id(), Self::TASK_TYPE, Message { value: value - 1 }, @@ -99,7 +100,7 @@ impl automations::Executor for Fibonacci { // Sleeping at this point in the lifecycle exercises handling of // messages sent to a task that's currently being polled. () = tokio::time::sleep(self.sleep_for).await; - Ok(PollOutcome::Suspend) + Ok(Action::Suspend) } (State::Waiting { partial, pending }, Some((_child_id, Some(Message { value })))) => { @@ -107,7 +108,7 @@ impl automations::Executor for Fibonacci { partial: partial + value, pending, }; - Ok(PollOutcome::Suspend) + Ok(Action::Suspend) } (State::Waiting { partial, pending }, Some((_child_id, None))) => { @@ -116,14 +117,14 @@ impl automations::Executor for Fibonacci { partial, pending: pending - 1, }; - Ok(PollOutcome::Suspend) + Ok(Action::Suspend) } else { *state = State::Finished; - Ok(PollOutcome::Yield(Message { value: partial })) + Action::yield_(Message { value: partial }) } } - (State::Finished, None) => Ok(PollOutcome::Done), + (State::Finished, None) => Ok(Action::Done), state => anyhow::bail!("unexpected poll with state {state:?} and inbox {inbox:?}"), }