Skip to content

Commit

Permalink
feat: add check_version to wf (#1560)
Browse files Browse the repository at this point in the history
<!-- Please make sure there is an issue that this PR is correlated to. -->
Fixes RVT-4340
## Changes

<!-- If there are frontend changes, please include screenshots. -->
  • Loading branch information
MasterPtato committed Dec 10, 2024
1 parent ed768ab commit a3c99f8
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 51 deletions.
6 changes: 3 additions & 3 deletions packages/common/chirp-workflow/core/src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
common,
message::{MessageCtx, SubscriptionHandle},
},
db::{DatabaseHandle, DatabasePgNats},
db::{DatabaseCrdbNats, DatabaseHandle},
message::{AsTags, Message},
operation::{Operation, OperationInput},
signal::Signal,
Expand Down Expand Up @@ -112,13 +112,13 @@ async fn db_from_ctx<B: Debug + Clone>(
let crdb = ctx.crdb().await?;
let nats = ctx.conn().nats().await?;

Ok(DatabasePgNats::from_pools(crdb, nats))
Ok(DatabaseCrdbNats::from_pools(crdb, nats))
}

// Get crdb pool as a trait object
pub async fn db_from_pools(pools: &rivet_pools::Pools) -> GlobalResult<DatabaseHandle> {
let crdb = pools.crdb()?;
let nats = pools.nats()?;

Ok(DatabasePgNats::from_pools(crdb, nats))
Ok(DatabaseCrdbNats::from_pools(crdb, nats))
}
8 changes: 5 additions & 3 deletions packages/common/chirp-workflow/core/src/ctx/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
message::{SubscriptionHandle, TailAnchor, TailAnchorResponse},
MessageCtx,
},
db::{DatabaseHandle, DatabasePgNats},
db::{DatabaseCrdbNats, DatabaseHandle},
message::{AsTags, Message, NatsMessage},
operation::{Operation, OperationInput},
signal::Signal,
Expand Down Expand Up @@ -68,8 +68,10 @@ impl TestCtx {
(),
);

let db =
DatabasePgNats::from_pools(pools.crdb().unwrap(), pools.nats_option().clone().unwrap());
let db = DatabaseCrdbNats::from_pools(
pools.crdb().unwrap(),
pools.nats_option().clone().unwrap(),
);
let msg_ctx = MessageCtx::new(&conn, ray_id).await.unwrap();

TestCtx {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
activity::{Activity, ActivityInput},
builder::workflow as builder,
ctx::{workflow::Loop, WorkflowCtx},
error::WorkflowError,
executable::{AsyncResult, Executable},
listen::{CustomListener, Listen},
message::Message,
Expand Down Expand Up @@ -47,7 +48,10 @@ impl<'a> VersionedWorkflowCtx<'a> {
self.inner
}

pub fn version(&self) -> usize {
// Handles version of branches via addition. If the inner workflow ctx is version 2 and this version is 2,
// the actual stored version will be 3. Not public because it just denotes the version of the context,
// use `check_version` instead.
fn version(&self) -> usize {
self.inner.version() + self.version - 1
}

Expand Down Expand Up @@ -143,6 +147,39 @@ impl<'a> VersionedWorkflowCtx<'a> {
wrap!(self, "sleep", { self.inner.sleep_until(time).await })
}

/// Returns the version of the current event in history. If no event exists, returns current version and
/// inserts a version check event.
pub async fn check_version(&mut self) -> GlobalResult<usize> {
if self.version == 0 {
return Err(GlobalError::raw(WorkflowError::InvalidVersion(
"version for `check_version` must be greater than 0".into(),
)));
}

if let Some(step_version) = self
.inner
.cursor()
.compare_version_check()
.map_err(GlobalError::raw)?
{
Ok(step_version + 1 - self.inner.version())
} else {
tracing::debug!(name=%self.inner.name(), id=%self.inner.workflow_id(), "inserting version check");

self.inner
.db()
.commit_workflow_version_check_event(
self.inner.workflow_id(),
&self.inner.cursor().current_location(),
self.version + self.inner.version() - 1,
self.inner.loop_location(),
)
.await?;

Ok(self.version + 1 - self.inner.version())
}
}

pub async fn listen_with_timeout<T: Listen>(
&mut self,
duration: impl DurationToMillis,
Expand Down
30 changes: 15 additions & 15 deletions packages/common/chirp-workflow/core/src/ctx/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1188,36 +1188,35 @@ impl WorkflowCtx {
Ok(())
}

/// Returns true if the workflow has never reached this point before and is consistent for all future
/// executions of this workflow.
pub async fn is_new(&mut self) -> GlobalResult<bool> {
// Existing event
let is_new = if let Some(is_new) = self
/// Returns the version of the current event in history. If no event exists, returns `current_version` and
/// inserts a version check event.
pub async fn check_version(&mut self, current_version: usize) -> GlobalResult<usize> {
if current_version == 0 {
return Err(GlobalError::raw(WorkflowError::InvalidVersion(
"version for `check_version` must be greater than 0".into(),
)));
}

if let Some(step_version) = self
.cursor
.compare_version_check()
.map_err(GlobalError::raw)?
{
is_new
Ok(step_version + 1 - self.version)
} else {
tracing::debug!(name=%self.name, id=%self.workflow_id, "inserting version check");

self.db
.commit_workflow_version_check_event(
self.workflow_id,
&self.cursor.current_location(),
current_version + self.version - 1,
self.loop_location(),
)
.await?;

true
};

if is_new {
// Move to next event
self.cursor.inc();
Ok(current_version + 1 - self.version)
}

Ok(is_new)
}
}

Expand Down Expand Up @@ -1258,7 +1257,8 @@ impl WorkflowCtx {
self.ray_id
}

pub fn version(&self) -> usize {
// Not public because this only denotes the version of the context, use `check_version` instead.
pub(crate) fn version(&self) -> usize {
self.version
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ const TXN_RETRY: Duration = Duration::from_millis(100);
/// Maximum times a query ran by this database adapter is retried.
const MAX_QUERY_RETRIES: usize = 16;

pub struct DatabasePgNats {
pub struct DatabaseCrdbNats {
pool: PgPool,
nats: NatsPool,
sub: Mutex<Option<rivet_pools::prelude::nats::Subscriber>>,
}

impl DatabasePgNats {
pub fn from_pools(pool: PgPool, nats: NatsPool) -> Arc<DatabasePgNats> {
Arc::new(DatabasePgNats {
impl DatabaseCrdbNats {
pub fn from_pools(pool: PgPool, nats: NatsPool) -> Arc<DatabaseCrdbNats> {
Arc::new(DatabaseCrdbNats {
pool,
// Lazy load the nats sub
sub: Mutex::new(None),
Expand Down Expand Up @@ -125,7 +125,7 @@ impl DatabasePgNats {
}

#[async_trait::async_trait]
impl Database for DatabasePgNats {
impl Database for DatabaseCrdbNats {
async fn wake(&self) -> WorkflowResult<()> {
let mut sub = self.sub.try_lock().map_err(WorkflowError::WakeLock)?;

Expand Down Expand Up @@ -454,7 +454,7 @@ impl Database for DatabasePgNats {
sw.location,
sw.location2,
version,
4 AS event_type, -- pg_nats::types::EventType
4 AS event_type, -- crdb_nats::types::EventType
w.workflow_name AS name,
sw.sub_workflow_id AS auxiliary_id,
NULL AS hash,
Expand All @@ -476,7 +476,7 @@ impl Database for DatabasePgNats {
location,
location2,
version,
5 AS event_type, -- pg_nats::types::EventType
5 AS event_type, -- crdb_nats::types::EventType
NULL AS name,
NULL AS auxiliary_id,
NULL AS hash,
Expand All @@ -496,7 +496,7 @@ impl Database for DatabasePgNats {
location,
location2,
version,
6 AS event_type, -- pg_nats::types::EventType
6 AS event_type, -- crdb_nats::types::EventType
NULL AS name,
NULL AS auxiliary_id,
NULL AS hash,
Expand All @@ -516,7 +516,7 @@ impl Database for DatabasePgNats {
ARRAY[] AS location,
location AS location2,
version,
7 AS event_type, -- pg_nats::types::EventType
7 AS event_type, -- crdb_nats::types::EventType
NULL AS name,
NULL AS auxiliary_id,
NULL AS hash,
Expand All @@ -536,7 +536,7 @@ impl Database for DatabasePgNats {
ARRAY[] AS location,
location AS location2,
1 AS version, -- Default
8 AS event_type, -- pg_nats::types::EventType
8 AS event_type, -- crdb_nats::types::EventType
event_name AS name,
NULL AS auxiliary_id,
NULL AS hash,
Expand All @@ -556,7 +556,7 @@ impl Database for DatabasePgNats {
ARRAY[] AS location,
location AS location2,
1 AS version, -- Default
9 AS event_type, -- pg_nats::types::EventType
9 AS event_type, -- crdb_nats::types::EventType
NULL AS name,
NULL AS auxiliary_id,
NULL AS hash,
Expand Down Expand Up @@ -1420,20 +1420,22 @@ impl Database for DatabasePgNats {
&self,
from_workflow_id: Uuid,
location: &Location,
version: usize,
loop_location: Option<&Location>,
) -> WorkflowResult<()> {
self.query(|| async {
sqlx::query(indoc!(
"
INSERT INTO db_workflow.workflow_version_check_events(
workflow_id, location, loop_location
workflow_id, location, version, loop_location
)
VALUES($1, $2, $3)
VALUES($1, $2, $3, $4)
RETURNING 1
",
))
.bind(from_workflow_id)
.bind(location)
.bind(version as i64)
.bind(loop_location)
.execute(&mut *self.conn().await?)
.await
Expand Down
5 changes: 3 additions & 2 deletions packages/common/chirp-workflow/core/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use crate::{
workflow::Workflow,
};

mod pg_nats;
pub use pg_nats::DatabasePgNats;
mod crdb_nats;
pub use crdb_nats::DatabaseCrdbNats;

pub type DatabaseHandle = Arc<dyn Database + Sync>;

Expand Down Expand Up @@ -222,6 +222,7 @@ pub trait Database: Send {
&self,
from_workflow_id: Uuid,
location: &Location,
version: usize,
loop_location: Option<&Location>,
) -> WorkflowResult<()>;
}
Expand Down
3 changes: 3 additions & 0 deletions packages/common/chirp-workflow/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ pub enum WorkflowError {

#[error("serialize location: {0}")]
SerializeLocation(serde_json::Error),

#[error("invalid version: {0}")]
InvalidVersion(String),
}

impl WorkflowError {
Expand Down
Loading

0 comments on commit a3c99f8

Please sign in to comment.