From 233748ccdd26dca083f601c8d53b44068f073aee Mon Sep 17 00:00:00 2001 From: Johnny Graettinger Date: Wed, 16 Oct 2024 15:02:14 -0500 Subject: [PATCH] automations: refine interface with separate Outcome & Action A gap in the previous interface is that task polls will often want to make database changes, and those changes should commit in the same transaction that updates the `internal.tasks` table with messages and state. Introduce an Outcome trait which is applied to the Postgres transaction it's given, and to which its returned Action is also applied. This lets the Executor implement an optimistic concurrency workflow and further lets it defer deciding what its polling Action will be until after it verifies its optimistic lock. --- crates/automations/src/executors.rs | 114 +++++++++++--------------- crates/automations/src/lib.rs | 53 ++++++++++-- crates/automations/tests/fibonacci.rs | 23 +++--- 3 files changed, 105 insertions(+), 85 deletions(-) diff --git a/crates/automations/src/executors.rs b/crates/automations/src/executors.rs index 27c67ac718..f6c18002c8 100644 --- a/crates/automations/src/executors.rs +++ b/crates/automations/src/executors.rs @@ -1,4 +1,4 @@ -use super::{server, BoxedRaw, Executor, PollOutcome, TaskType}; +use super::{server, Action, BoxedRaw, Executor, Outcome, TaskType}; use anyhow::Context; use futures::future::{BoxFuture, FutureExt}; use sqlx::types::Json as SqlJson; @@ -9,11 +9,12 @@ pub trait ObjSafe: Send + Sync + 'static { fn poll<'s>( &'s self, + pool: &'s sqlx::PgPool, task_id: models::Id, parent_id: Option, - state: &'s mut Option>, - inbox: &'s mut Option)>>>, - ) -> BoxFuture<'s, anyhow::Result>>; + state: Option>, + inbox: Option)>>>, + ) -> BoxFuture<'s, anyhow::Result<()>>; } impl ObjSafe for E { @@ -23,11 +24,12 @@ impl ObjSafe for E { fn poll<'s>( &'s self, + pool: &'s sqlx::PgPool, task_id: models::Id, parent_id: Option, - state: &'s mut Option>, - inbox: &'s mut Option)>>>, - ) -> BoxFuture<'s, anyhow::Result>> { + mut 
state: Option>, + mut inbox: Option)>>>, + ) -> BoxFuture<'s, anyhow::Result<()>> { async move { let mut state_parsed: E::State = if let Some(state) = state { serde_json::from_str(state.get()).context("failed to decode task state")? @@ -52,6 +54,7 @@ impl ObjSafe for E { let outcome = E::poll( self, + pool, task_id, parent_id, &mut state_parsed, @@ -60,21 +63,16 @@ impl ObjSafe for E { .await?; // Re-encode state for persistence. - // If we're Done, then the output state is NULL which is implicitly Default. - if matches!(outcome, PollOutcome::Done) { - *state = None - } else { - *state = Some(SqlJson( - serde_json::value::to_raw_value(&state_parsed) - .context("failed to encode inner state")?, - )); - } + state = Some(SqlJson( + serde_json::value::to_raw_value(&state_parsed) + .context("failed to encode inner state")?, + )); // Re-encode the unconsumed portion of the inbox. if inbox_parsed.is_empty() { - *inbox = None + inbox = None } else { - *inbox = Some( + inbox = Some( inbox_parsed .into_iter() .map(|(task_id, msg)| { @@ -91,19 +89,16 @@ impl ObjSafe for E { ); } - Ok(match outcome { - PollOutcome::Done => PollOutcome::Done, - PollOutcome::Send(task_id, msg) => PollOutcome::Send(task_id, msg), - PollOutcome::Sleep(interval) => PollOutcome::Sleep(interval), - PollOutcome::Spawn(task_id, task_type, msg) => { - PollOutcome::Spawn(task_id, task_type, msg) - } - PollOutcome::Suspend => PollOutcome::Suspend, - PollOutcome::Yield(msg) => PollOutcome::Yield( - serde_json::value::to_raw_value(&msg) - .context("failed to encode yielded message")?, - ), - }) + let mut txn = pool.begin().await?; + + let action = outcome + .apply(&mut *txn) + .await + .context("failed to apply task Outcome")?; + + () = persist_action(action, &mut *txn, task_id, parent_id, state, inbox).await?; + + Ok(txn.commit().await?) 
} .boxed() } @@ -119,8 +114,8 @@ pub async fn poll_task( id: task_id, type_: _, parent_id, - mut inbox, - mut state, + inbox, + state, mut last_heartbeat, }, }: server::ReadyTask, @@ -145,24 +140,9 @@ pub async fn poll_task( // Poll `executor` and `update_heartbeats` in tandem, so that a failure // to update our heartbeat also cancels the executor. - let outcome = tokio::select! { - outcome = executor.poll(task_id, parent_id, &mut state, &mut inbox) => { outcome? }, - err = &mut update_heartbeats => return Err(err), - }; - - // The possibly long-lived polling operation is now complete. - // Build a Future that commits a (hopefully) brief transaction of `outcome`. - let persist_outcome = async { - let mut txn = pool.begin().await?; - () = persist_outcome(outcome, &mut *txn, task_id, parent_id, state, inbox).await?; - Ok(txn.commit().await?) - }; - - // Poll `persist_outcome` while continuing to poll `update_heartbeats`, - // to guarantee we cannot commit an outcome after our lease is lost. tokio::select! 
{ - result = persist_outcome => result, - err = update_heartbeats => Err(err), + result = executor.poll(&pool, task_id, parent_id, state, inbox) => result, + err = &mut update_heartbeats => return Err(err), } } @@ -203,17 +183,17 @@ async fn update_heartbeat( Ok(updated.heartbeat) } -async fn persist_outcome( - outcome: PollOutcome, +async fn persist_action( + action: Action, txn: &mut sqlx::PgConnection, task_id: models::Id, parent_id: Option, - state: Option>, + mut state: Option>, inbox: Option)>>>, ) -> anyhow::Result<()> { use std::time::Duration; - if let PollOutcome::Spawn(spawn_id, spawn_type, _msg) = &outcome { + if let Action::Spawn(spawn_id, spawn_type, _msg) = &action { sqlx::query!( "SELECT internal.create_task($1, $2, $3)", *spawn_id as models::Id, @@ -225,15 +205,15 @@ async fn persist_outcome( .context("failed to spawn new task")?; } - if let Some((send_id, msg)) = match &outcome { + if let Some((send_id, msg)) = match &action { // When a task is spawned, send its first message. - PollOutcome::Spawn(spawn_id, _spawn_type, msg) => Some((*spawn_id, Some(msg))), + Action::Spawn(spawn_id, _spawn_type, msg) => Some((*spawn_id, Some(msg))), // If we're Done but have a parent, send it an EOF. - PollOutcome::Done => parent_id.map(|parent_id| (parent_id, None)), + Action::Done => parent_id.map(|parent_id| (parent_id, None)), // Send an arbitrary message to an identified task. - PollOutcome::Send(task_id, msg) => Some((*task_id, msg.as_ref())), + Action::Send(task_id, msg) => Some((*task_id, msg.as_ref())), // Yield is sugar for sending to our parent. - PollOutcome::Yield(msg) => { + Action::Yield(msg) => { let Some(parent_id) = parent_id else { anyhow::bail!("task yielded illegally, because it does not have a parent"); }; @@ -255,17 +235,19 @@ async fn persist_outcome( let wake_at_interval = if inbox.is_some() { Some(Duration::ZERO) // Always poll immediately if inbox items remain. 
} else { - match &outcome { - PollOutcome::Sleep(interval) => Some(*interval), + match &action { + Action::Sleep(interval) => Some(*interval), // These outcomes do not suspend the task, and it should wake as soon as possible. - PollOutcome::Spawn(..) | PollOutcome::Send(..) | PollOutcome::Yield(..) => { - Some(Duration::ZERO) - } + Action::Spawn(..) | Action::Send(..) | Action::Yield(..) => Some(Duration::ZERO), // Suspend indefinitely (note that NOW() + NULL::INTERVAL is NULL). - PollOutcome::Done | PollOutcome::Suspend => None, + Action::Done | Action::Suspend => None, } }; + if let Action::Done = &action { + state = None; // Set to NULL, which is implicit Default. + } + let updated = sqlx::query!( r#" UPDATE internal.tasks SET @@ -292,7 +274,7 @@ async fn persist_outcome( // If we're Done and also successfully suspended, then delete ourselves. // (Otherwise, the task has been left in a like-new state). - if matches!(&outcome, PollOutcome::Done if updated.suspended) { + if matches!(&action, Action::Done if updated.suspended) { sqlx::query!( "DELETE FROM internal.tasks WHERE task_id = $1;", task_id as models::Id, diff --git a/crates/automations/src/lib.rs b/crates/automations/src/lib.rs index 9a43187396..7fe6fb9749 100644 --- a/crates/automations/src/lib.rs +++ b/crates/automations/src/lib.rs @@ -23,16 +23,40 @@ type BoxedRaw = Box; #[sqlx(transparent)] pub struct TaskType(pub i16); -/// PollOutcome is the outcome of an `Executor::poll()` for a given task. +/// Outcome of an `Executor::poll()` for a given task, which encloses +/// an Action with which it's applied as a single transaction. +/// +/// As an example of how Executor, Outcome, and Action are used together, +/// suppose an implementation of `Executor::poll()` is called: +/// +/// - It reads DB state associated with the task using sqlx::PgPool. +/// - It performs long-running work, running outside of a DB transaction. 
+/// - It returns an Outcome implementation which encapsulates the +/// preconditions it observed, as well as its domain-specific outcome. +/// - `Outcome::apply()` is called and re-verifies preconditions using `txn`, +/// returning an error if preconditions have changed. +/// - It applies the effects of its outcome and returns a polling Action. +/// - `txn` is further used by this crate as required by the Action, and then commits. +/// +pub trait Outcome: Send { + /// Apply the effects of an Executor poll. While this is an async routine, + /// apply() runs in the context of a held transaction and should be fast. + fn apply<'s>( + self, + txn: &'s mut sqlx::PgConnection, + ) -> impl std::future::Future> + Send + 's; +} + +/// Action undertaken by an Executor task poll. #[derive(Debug)] -pub enum PollOutcome { +pub enum Action { /// Spawn a new TaskId with the given TaskType and send a first message. /// The TaskId must not exist. Spawn(models::Id, TaskType, BoxedRaw), /// Send a message (Some) or EOF (None) to another TaskId, which must exist. Send(models::Id, Option), /// Yield to send a message to this task's parent. - Yield(Yield), + Yield(BoxedRaw), /// Sleep for at-most the indicated Duration, then poll again. /// The task may be woken earlier if it receives a message. Sleep(std::time::Duration), @@ -43,28 +67,35 @@ pub enum PollOutcome { Done, } +// Action implements an Outcome with no side-effects. +impl Outcome for Action { + async fn apply<'s>(self, _txn: &'s mut sqlx::PgConnection) -> anyhow::Result { + Ok(self) + } +} + /// Executor is the core trait implemented by executors of various task types. 
pub trait Executor: Send + Sync + 'static { const TASK_TYPE: TaskType; type Receive: serde::de::DeserializeOwned + serde::Serialize + Send; type State: Default + serde::de::DeserializeOwned + serde::Serialize + Send; - type Yield: serde::Serialize; + type Outcome: Outcome; fn poll<'s>( &'s self, + pool: &'s sqlx::PgPool, task_id: models::Id, parent_id: Option, state: &'s mut Self::State, inbox: &'s mut std::collections::VecDeque<(models::Id, Option)>, - ) -> impl std::future::Future>> + Send + 's; + ) -> impl std::future::Future> + Send + 's; } -/// Server holds registered implementations of Executor, -/// and serves them. +/// Server holds registered implementations of Executors and serves them. pub struct Server(Vec>); -impl PollOutcome { +impl Action { pub fn spawn( spawn_id: models::Id, task_type: TaskType, @@ -89,6 +120,12 @@ impl PollOutcome { }, )) } + + pub fn yield_(msg: M) -> anyhow::Result { + Ok(Self::Yield( + serde_json::value::to_raw_value(&msg).context("failed to encode yielded message")?, + )) + } } pub fn next_task_id() -> models::Id { diff --git a/crates/automations/tests/fibonacci.rs b/crates/automations/tests/fibonacci.rs index 83d600a42e..db9d00f48b 100644 --- a/crates/automations/tests/fibonacci.rs +++ b/crates/automations/tests/fibonacci.rs @@ -1,4 +1,4 @@ -use automations::PollOutcome; +use automations::Action; use std::collections::VecDeque; /// Fibonacci is one of the least-efficient calculators of the Fibonacci @@ -40,8 +40,8 @@ impl automations::Executor for Fibonacci { const TASK_TYPE: automations::TaskType = automations::TaskType(32767); type Receive = Message; - type Yield = Message; type State = State; + type Outcome = Action; #[tracing::instrument( ret, @@ -51,11 +51,12 @@ impl automations::Executor for Fibonacci { )] async fn poll<'s>( &'s self, + _pool: &'s sqlx::PgPool, task_id: models::Id, parent_id: Option, state: &'s mut Self::State, inbox: &'s mut VecDeque<(models::Id, Option)>, - ) -> anyhow::Result> { + ) -> anyhow::Result 
{ if rand::random::() < self.failure_rate { return Err( anyhow::anyhow!("A no good, very bad error!").context("something bad happened") @@ -63,7 +64,7 @@ impl automations::Executor for Fibonacci { } if let State::SpawnOne(value) = state { - let spawn = PollOutcome::spawn( + let spawn = Action::spawn( automations::next_task_id(), Self::TASK_TYPE, Message { value: *value - 2 }, @@ -80,14 +81,14 @@ impl automations::Executor for Fibonacci { // Base case: (State::Init, Some((_parent_id, Some(Message { value })))) if value <= 2 => { *state = State::Finished; - Ok(PollOutcome::Yield(Message { value: 1 })) + Action::yield_(Message { value: 1 }) } // Recursive case: (State::Init, Some((_parent_id, Some(Message { value })))) => { *state = State::SpawnOne(value); - PollOutcome::spawn( + Action::spawn( automations::next_task_id(), Self::TASK_TYPE, Message { value: value - 1 }, @@ -99,7 +100,7 @@ impl automations::Executor for Fibonacci { // Sleeping at this point in the lifecycle exercises handling of // messages sent to a task that's currently being polled. () = tokio::time::sleep(self.sleep_for).await; - Ok(PollOutcome::Suspend) + Ok(Action::Suspend) } (State::Waiting { partial, pending }, Some((_child_id, Some(Message { value })))) => { @@ -107,7 +108,7 @@ impl automations::Executor for Fibonacci { partial: partial + value, pending, }; - Ok(PollOutcome::Suspend) + Ok(Action::Suspend) } (State::Waiting { partial, pending }, Some((_child_id, None))) => { @@ -116,14 +117,14 @@ impl automations::Executor for Fibonacci { partial, pending: pending - 1, }; - Ok(PollOutcome::Suspend) + Ok(Action::Suspend) } else { *state = State::Finished; - Ok(PollOutcome::Yield(Message { value: partial })) + Action::yield_(Message { value: partial }) } } - (State::Finished, None) => Ok(PollOutcome::Done), + (State::Finished, None) => Ok(Action::Done), state => anyhow::bail!("unexpected poll with state {state:?} and inbox {inbox:?}"), }