Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve automations interface #1709

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 48 additions & 66 deletions crates/automations/src/executors.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{server, BoxedRaw, Executor, PollOutcome, TaskType};
use super::{server, Action, BoxedRaw, Executor, Outcome, TaskType};
use anyhow::Context;
use futures::future::{BoxFuture, FutureExt};
use sqlx::types::Json as SqlJson;
Expand All @@ -9,11 +9,12 @@ pub trait ObjSafe: Send + Sync + 'static {

fn poll<'s>(
&'s self,
pool: &'s sqlx::PgPool,
task_id: models::Id,
parent_id: Option<models::Id>,
state: &'s mut Option<SqlJson<BoxedRaw>>,
inbox: &'s mut Option<Vec<SqlJson<(models::Id, Option<BoxedRaw>)>>>,
) -> BoxFuture<'s, anyhow::Result<PollOutcome<BoxedRaw>>>;
state: Option<SqlJson<BoxedRaw>>,
inbox: Option<Vec<SqlJson<(models::Id, Option<BoxedRaw>)>>>,
) -> BoxFuture<'s, anyhow::Result<()>>;
}

impl<E: Executor> ObjSafe for E {
Expand All @@ -23,11 +24,12 @@ impl<E: Executor> ObjSafe for E {

fn poll<'s>(
&'s self,
pool: &'s sqlx::PgPool,
task_id: models::Id,
parent_id: Option<models::Id>,
state: &'s mut Option<SqlJson<BoxedRaw>>,
inbox: &'s mut Option<Vec<SqlJson<(models::Id, Option<BoxedRaw>)>>>,
) -> BoxFuture<'s, anyhow::Result<PollOutcome<BoxedRaw>>> {
mut state: Option<SqlJson<BoxedRaw>>,
mut inbox: Option<Vec<SqlJson<(models::Id, Option<BoxedRaw>)>>>,
) -> BoxFuture<'s, anyhow::Result<()>> {
async move {
let mut state_parsed: E::State = if let Some(state) = state {
serde_json::from_str(state.get()).context("failed to decode task state")?
Expand All @@ -52,6 +54,7 @@ impl<E: Executor> ObjSafe for E {

let outcome = E::poll(
self,
pool,
task_id,
parent_id,
&mut state_parsed,
Expand All @@ -60,21 +63,16 @@ impl<E: Executor> ObjSafe for E {
.await?;

// Re-encode state for persistence.
// If we're Done, then the output state is NULL which is implicitly Default.
if matches!(outcome, PollOutcome::Done) {
*state = None
} else {
*state = Some(SqlJson(
serde_json::value::to_raw_value(&state_parsed)
.context("failed to encode inner state")?,
));
}
state = Some(SqlJson(
serde_json::value::to_raw_value(&state_parsed)
.context("failed to encode inner state")?,
));

// Re-encode the unconsumed portion of the inbox.
if inbox_parsed.is_empty() {
*inbox = None
inbox = None
} else {
*inbox = Some(
inbox = Some(
inbox_parsed
.into_iter()
.map(|(task_id, msg)| {
Expand All @@ -91,19 +89,16 @@ impl<E: Executor> ObjSafe for E {
);
}

Ok(match outcome {
PollOutcome::Done => PollOutcome::Done,
PollOutcome::Send(task_id, msg) => PollOutcome::Send(task_id, msg),
PollOutcome::Sleep(interval) => PollOutcome::Sleep(interval),
PollOutcome::Spawn(task_id, task_type, msg) => {
PollOutcome::Spawn(task_id, task_type, msg)
}
PollOutcome::Suspend => PollOutcome::Suspend,
PollOutcome::Yield(msg) => PollOutcome::Yield(
serde_json::value::to_raw_value(&msg)
.context("failed to encode yielded message")?,
),
})
let mut txn = pool.begin().await?;

let action = outcome
.apply(&mut *txn)
.await
.context("failed to apply task Outcome")?;

() = persist_action(action, &mut *txn, task_id, parent_id, state, inbox).await?;

Ok(txn.commit().await?)
}
.boxed()
}
Expand All @@ -119,8 +114,8 @@ pub async fn poll_task(
id: task_id,
type_: _,
parent_id,
mut inbox,
mut state,
inbox,
state,
mut last_heartbeat,
},
}: server::ReadyTask,
Expand All @@ -145,24 +140,9 @@ pub async fn poll_task(

// Poll `executor` and `update_heartbeats` in tandem, so that a failure
// to update our heartbeat also cancels the executor.
let outcome = tokio::select! {
outcome = executor.poll(task_id, parent_id, &mut state, &mut inbox) => { outcome? },
err = &mut update_heartbeats => return Err(err),
};

// The possibly long-lived polling operation is now complete.
// Build a Future that commits a (hopefully) brief transaction of `outcome`.
let persist_outcome = async {
let mut txn = pool.begin().await?;
() = persist_outcome(outcome, &mut *txn, task_id, parent_id, state, inbox).await?;
Ok(txn.commit().await?)
};

// Poll `persist_outcome` while continuing to poll `update_heartbeats`,
// to guarantee we cannot commit an outcome after our lease is lost.
tokio::select! {
result = persist_outcome => result,
err = update_heartbeats => Err(err),
result = executor.poll(&pool, task_id, parent_id, state, inbox) => result,
err = &mut update_heartbeats => return Err(err),
}
}

Expand Down Expand Up @@ -203,17 +183,17 @@ async fn update_heartbeat(
Ok(updated.heartbeat)
}

async fn persist_outcome(
outcome: PollOutcome<BoxedRaw>,
async fn persist_action(
action: Action,
txn: &mut sqlx::PgConnection,
task_id: models::Id,
parent_id: Option<models::Id>,
state: Option<SqlJson<BoxedRaw>>,
mut state: Option<SqlJson<BoxedRaw>>,
inbox: Option<Vec<SqlJson<(models::Id, Option<BoxedRaw>)>>>,
) -> anyhow::Result<()> {
use std::time::Duration;

if let PollOutcome::Spawn(spawn_id, spawn_type, _msg) = &outcome {
if let Action::Spawn(spawn_id, spawn_type, _msg) = &action {
sqlx::query!(
"SELECT internal.create_task($1, $2, $3)",
*spawn_id as models::Id,
Expand All @@ -225,15 +205,15 @@ async fn persist_outcome(
.context("failed to spawn new task")?;
}

if let Some((send_id, msg)) = match &outcome {
if let Some((send_id, msg)) = match &action {
// When a task is spawned, send its first message.
PollOutcome::Spawn(spawn_id, _spawn_type, msg) => Some((*spawn_id, Some(msg))),
Action::Spawn(spawn_id, _spawn_type, msg) => Some((*spawn_id, Some(msg))),
// If we're Done but have a parent, send it an EOF.
PollOutcome::Done => parent_id.map(|parent_id| (parent_id, None)),
Action::Done => parent_id.map(|parent_id| (parent_id, None)),
// Send an arbitrary message to an identified task.
PollOutcome::Send(task_id, msg) => Some((*task_id, msg.as_ref())),
Action::Send(task_id, msg) => Some((*task_id, msg.as_ref())),
// Yield is sugar for sending to our parent.
PollOutcome::Yield(msg) => {
Action::Yield(msg) => {
let Some(parent_id) = parent_id else {
anyhow::bail!("task yielded illegally, because it does not have a parent");
};
Expand All @@ -255,17 +235,19 @@ async fn persist_outcome(
let wake_at_interval = if inbox.is_some() {
Some(Duration::ZERO) // Always poll immediately if inbox items remain.
} else {
match &outcome {
PollOutcome::Sleep(interval) => Some(*interval),
match &action {
Action::Sleep(interval) => Some(*interval),
// These outcomes do not suspend the task, and it should wake as soon as possible.
PollOutcome::Spawn(..) | PollOutcome::Send(..) | PollOutcome::Yield(..) => {
Some(Duration::ZERO)
}
Action::Spawn(..) | Action::Send(..) | Action::Yield(..) => Some(Duration::ZERO),
// Suspend indefinitely (note that NOW() + NULL::INTERVAL is NULL).
PollOutcome::Done | PollOutcome::Suspend => None,
Action::Done | Action::Suspend => None,
}
};

if let Action::Done = &action {
state = None; // Set to NULL, which is implicit Default.
}

let updated = sqlx::query!(
r#"
UPDATE internal.tasks SET
Expand All @@ -292,7 +274,7 @@ async fn persist_outcome(

// If we're Done and also successfully suspended, then delete ourselves.
// (Otherwise, the task has been left in a like-new state).
if matches!(&outcome, PollOutcome::Done if updated.suspended) {
if matches!(&action, Action::Done if updated.suspended) {
sqlx::query!(
"DELETE FROM internal.tasks WHERE task_id = $1;",
task_id as models::Id,
Expand Down
53 changes: 45 additions & 8 deletions crates/automations/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,40 @@ type BoxedRaw = Box<serde_json::value::RawValue>;
#[sqlx(transparent)]
pub struct TaskType(pub i16);

/// PollOutcome is the outcome of an `Executor::poll()` for a given task.
/// Outcome of an `Executor::poll()` for a given task. An Outcome encloses
/// an Action, and its effects are applied together with that Action
/// as a single transaction.
///
/// As an example of how Executor, Outcome, and Action are used together,
/// suppose an implementation of `Executor::poll()` is called:
///
/// - It reads DB state associated with the task using sqlx::PgPool.
/// - It performs long-running work, running outside of a DB transaction.
/// - It returns an Outcome implementation which encapsulates the
///   preconditions it observed, as well as its domain-specific outcome.
/// - `Outcome::apply()` is called and re-verifies preconditions using `txn`,
///   returning an error if preconditions have changed.
/// - It applies the effects of its outcome and returns a polling Action.
/// - `txn` is further used by this crate as required by the Action,
///   and is then committed.
///
pub trait Outcome: Send {
/// Apply the effects of an Executor poll. While this is an async routine,
/// apply() runs in the context of a held transaction and should be fast.
///
/// `apply()` consumes `self` and resolves to the Action which the caller
/// should carry out next. The returned future borrows `txn` for `'s`
/// and must be `Send`.
fn apply<'s>(
self,
txn: &'s mut sqlx::PgConnection,
) -> impl std::future::Future<Output = anyhow::Result<Action>> + Send + 's;
}

/// Action undertaken by an Executor task poll.
#[derive(Debug)]
pub enum PollOutcome<Yield> {
pub enum Action {
/// Spawn a new TaskId with the given TaskType and send a first message.
/// The TaskId must not exist.
Spawn(models::Id, TaskType, BoxedRaw),
/// Send a message (Some) or EOF (None) to another TaskId, which must exist.
Send(models::Id, Option<BoxedRaw>),
/// Yield to send a message to this task's parent.
Yield(Yield),
Yield(BoxedRaw),
/// Sleep for at most the indicated Duration, then poll again.
/// The task may be woken earlier if it receives a message.
Sleep(std::time::Duration),
Expand All @@ -43,28 +67,35 @@ pub enum PollOutcome<Yield> {
Done,
}

// A bare Action is itself an Outcome: applying it has no side-effects
// and does not touch the transaction, it simply hands the Action back.
impl Outcome for Action {
    async fn apply<'s>(self, _txn: &'s mut sqlx::PgConnection) -> anyhow::Result<Self> {
        let action = self;
        Ok(action)
    }
}

/// Executor is the core trait implemented by executors of various task types.
pub trait Executor: Send + Sync + 'static {
const TASK_TYPE: TaskType;

type Receive: serde::de::DeserializeOwned + serde::Serialize + Send;
type State: Default + serde::de::DeserializeOwned + serde::Serialize + Send;
type Yield: serde::Serialize;
type Outcome: Outcome;

fn poll<'s>(
&'s self,
pool: &'s sqlx::PgPool,
task_id: models::Id,
parent_id: Option<models::Id>,
state: &'s mut Self::State,
inbox: &'s mut std::collections::VecDeque<(models::Id, Option<Self::Receive>)>,
) -> impl std::future::Future<Output = anyhow::Result<PollOutcome<Self::Yield>>> + Send + 's;
) -> impl std::future::Future<Output = anyhow::Result<Self::Outcome>> + Send + 's;
}

/// Server holds registered implementations of Executor,
/// and serves them.
/// Server holds registered implementations of Executors and serves them.
pub struct Server(Vec<Arc<dyn executors::ObjSafe>>);

impl<Yield> PollOutcome<Yield> {
impl Action {
pub fn spawn<M: serde::Serialize>(
spawn_id: models::Id,
task_type: TaskType,
Expand All @@ -89,6 +120,12 @@ impl<Yield> PollOutcome<Yield> {
},
))
}

/// Build an `Action::Yield` which carries `msg`, serialized as raw JSON.
///
/// Returns an error with the context "failed to encode yielded message"
/// if `msg` cannot be serialized.
pub fn yield_<M: serde::Serialize>(msg: M) -> anyhow::Result<Self> {
    let encoded = serde_json::value::to_raw_value(&msg);
    let raw = encoded.context("failed to encode yielded message")?;
    Ok(Self::Yield(raw))
}
}

pub fn next_task_id() -> models::Id {
Expand Down
Loading
Loading