diff --git a/lib/chirp-workflow/core/Cargo.toml b/lib/chirp-workflow/core/Cargo.toml index d2458a7a3d..7f87051f5d 100644 --- a/lib/chirp-workflow/core/Cargo.toml +++ b/lib/chirp-workflow/core/Cargo.toml @@ -17,7 +17,6 @@ indoc = "2.0.5" lazy_static = "1.4" prost = "0.12.4" prost-types = "0.12.4" -rand = "0.8.5" rivet-cache = { path = "../../cache/build" } rivet-connection = { path = "../../connection" } rivet-metrics = { path = "../../metrics" } diff --git a/lib/chirp-workflow/core/src/compat.rs b/lib/chirp-workflow/core/src/compat.rs index ae84824ed5..2039593620 100644 --- a/lib/chirp-workflow/core/src/compat.rs +++ b/lib/chirp-workflow/core/src/compat.rs @@ -102,7 +102,7 @@ where M: Message, B: Debug + Clone, { - let msg_ctx = MessageCtx::new(ctx.conn(), ctx.req_id(), ctx.ray_id()) + let msg_ctx = MessageCtx::new(ctx.conn(), ctx.ray_id()) .await .map_err(GlobalError::raw)?; diff --git a/lib/chirp-workflow/core/src/ctx/api.rs b/lib/chirp-workflow/core/src/ctx/api.rs index 699487b4f2..3b1e6051aa 100644 --- a/lib/chirp-workflow/core/src/ctx/api.rs +++ b/lib/chirp-workflow/core/src/ctx/api.rs @@ -13,7 +13,7 @@ use crate::{ }, db::DatabaseHandle, error::WorkflowResult, - message::{Message, ReceivedMessage}, + message::{Message, NatsMessage}, operation::{Operation, OperationInput}, signal::Signal, workflow::{Workflow, WorkflowInput}, @@ -55,7 +55,7 @@ impl ApiCtx { (), ); - let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?; + let msg_ctx = MessageCtx::new(&conn, ray_id).await?; Ok(ApiCtx { ray_id, @@ -129,7 +129,7 @@ impl ApiCtx { pub async fn tail_read( &self, tags: serde_json::Value, - ) -> GlobalResult>> + ) -> GlobalResult>> where M: Message, { diff --git a/lib/chirp-workflow/core/src/ctx/backfill.rs b/lib/chirp-workflow/core/src/ctx/backfill.rs index bf6fea2108..beb260ff93 100644 --- a/lib/chirp-workflow/core/src/ctx/backfill.rs +++ b/lib/chirp-workflow/core/src/ctx/backfill.rs @@ -11,7 +11,7 @@ use std::{ }; use uuid::Uuid; -use crate::util::Location; +use crate::utils::Location; // Yes type Query = Box< diff --git a/lib/chirp-workflow/core/src/ctx/message.rs b/lib/chirp-workflow/core/src/ctx/message.rs index 37da81121c..f2979f2da7 100644 --- a/lib/chirp-workflow/core/src/ctx/message.rs +++ b/lib/chirp-workflow/core/src/ctx/message.rs @@ -13,7 +13,8 @@ use uuid::Uuid; use crate::{ error::{WorkflowError, WorkflowResult}, - message::{self, Message, MessageWrapper, ReceivedMessage, TraceEntry}, + message::{redis_keys, Message, NatsMessage, NatsMessageWrapper}, + utils, }; /// Time (in ms) that we subtract from the anchor grace period in order to @@ -29,29 +30,15 @@ pub struct MessageCtx { /// Used for writing to message tails. This cache is ephemeral. redis_chirp_ephemeral: RedisPool, - req_id: Uuid, ray_id: Uuid, - trace: Vec, } impl MessageCtx { - pub async fn new( - conn: &rivet_connection::Connection, - req_id: Uuid, - ray_id: Uuid, - ) -> WorkflowResult { + pub async fn new(conn: &rivet_connection::Connection, ray_id: Uuid) -> WorkflowResult { Ok(MessageCtx { nats: conn.nats().await?, redis_chirp_ephemeral: conn.redis_chirp_ephemeral().await?, - req_id, ray_id, - trace: conn - .chirp() - .trace() - .iter() - .cloned() - .map(TryInto::try_into) - .collect::>>()?, }) } } @@ -109,7 +96,7 @@ impl MessageCtx { M: Message, { let tags_str = cjson::to_string(&tags).map_err(WorkflowError::SerializeMessageTags)?; - let nats_subject = message::serialize_message_nats_subject::(&tags_str); + let nats_subject = M::nats_subject(); let duration_since_epoch = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_else(|err| unreachable!("time is broken: {}", err)); @@ -124,12 +111,11 @@ impl MessageCtx { // Serialize message let req_id = Uuid::new_v4(); - let message = MessageWrapper { + let message = NatsMessageWrapper { req_id: req_id, ray_id: self.ray_id, - tags: tags.clone(), + tags, ts, - trace: self.trace.clone(), allow_recursive: false, // TODO: body: &body_buf, }; @@ -278,8 +264,7 @@ impl MessageCtx { where M: Message, { - let tags_str = cjson::to_string(opts.tags).map_err(WorkflowError::SerializeMessageTags)?; - let nats_subject = message::serialize_message_nats_subject::(&tags_str); + let nats_subject = M::nats_subject(); // Create subscription and flush immediately. tracing::info!(%nats_subject, tags = ?opts.tags, "creating subscription"); @@ -296,7 +281,7 @@ impl MessageCtx { } // Return handle - let subscription = SubscriptionHandle::new(nats_subject, subscription, self.req_id); + let subscription = SubscriptionHandle::new(nats_subject, subscription, opts.tags.clone()); Ok(subscription) } @@ -305,7 +290,7 @@ impl MessageCtx { pub async fn tail_read( &self, tags: serde_json::Value, - ) -> WorkflowResult>> + ) -> WorkflowResult>> where M: Message, { @@ -320,7 +305,7 @@ impl MessageCtx { // Deserialize message let message = if let Some(message_buf) = message_buf { - let message = ReceivedMessage::::deserialize(message_buf.as_slice())?; + let message = NatsMessage::::deserialize(message_buf.as_slice())?; tracing::info!(?message, "immediate read tail message"); let recv_lag = (rivet_util::timestamp::now() as f64 - message.ts as f64) / 1000.; @@ -410,7 +395,7 @@ where _guard: DropGuard, subject: String, subscription: nats::Subscriber, - req_id: Uuid, + pub tags: serde_json::Value, } impl Debug for SubscriptionHandle @@ -420,6 +405,7 @@ where fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SubscriptionHandle") .field("subject", &self.subject) + .field("tags", &self.tags) .finish() } } @@ -429,7 +415,7 @@ where M: Message, { #[tracing::instrument(level = "debug", skip_all)] - fn new(subject: String, subscription: nats::Subscriber, req_id: Uuid) -> Self { + fn new(subject: String, subscription: nats::Subscriber, tags: serde_json::Value) -> Self { let token = CancellationToken::new(); { @@ -458,7 +444,7 @@ where _guard: token.drop_guard(), subject, subscription, - req_id, + tags, } } @@ -466,26 +452,7 @@ where /// /// This future can be safely dropped. #[tracing::instrument] - pub async fn next(&mut self) -> WorkflowResult> { - self.next_inner(false).await - } - - // TODO: Add a full config struct to pass to `next` that impl's `Default` - /// Waits for the next message in the subscription that originates from the - /// parent request ID via trace. - /// - /// This future can be safely dropped. - #[tracing::instrument] - pub async fn next_with_trace( - &mut self, - filter_trace: bool, - ) -> WorkflowResult> { - self.next_inner(filter_trace).await - } - - /// This future can be safely dropped. - #[tracing::instrument(level = "trace")] - async fn next_inner(&mut self, filter_trace: bool) -> WorkflowResult> { + pub async fn next(&mut self) -> WorkflowResult> { tracing::info!("waiting for message"); loop { @@ -501,47 +468,22 @@ where } }; - if filter_trace { - let message_wrapper = - ReceivedMessage::::deserialize_wrapper(&nats_message.payload[..])?; - - // Check if the message trace stack originates from this client - // - // We intentionally use the request ID instead of just checking the ray ID because - // there may be multiple calls to `message_with_subscribe` within the same ray. - // Explicitly checking the parent request ensures the response is unique to this - // message. - if message_wrapper - .trace - .iter() - .rev() - .any(|trace_entry| trace_entry.req_id == self.req_id) - { - let message = ReceivedMessage::::deserialize(&nats_message.payload[..])?; - tracing::info!(?message, "received message"); - - return Ok(message); - } - } else { - let message = ReceivedMessage::::deserialize(&nats_message.payload[..])?; - tracing::info!(?message, "received message"); + let message_wrapper = NatsMessage::::deserialize_wrapper(&nats_message.payload[..])?; - let recv_lag = (rivet_util::timestamp::now() as f64 - message.ts as f64) / 1000.; - crate::metrics::MESSAGE_RECV_LAG - .with_label_values(&[M::NAME]) - .observe(recv_lag); + // Check if the subscription tags match a subset of the message tags + if utils::is_value_subset(&self.tags, &message_wrapper.tags) { + let message = NatsMessage::::deserialize_from_wrapper(message_wrapper)?; + tracing::info!(?message, "received message"); return Ok(message); } - // Message not from parent, continue with loop + // Message tags don't match, continue with loop } } /// Converts the subscription in to a stream. - pub fn into_stream( - self, - ) -> impl futures_util::Stream>> { + pub fn into_stream(self) -> impl futures_util::Stream>> { futures_util::stream::try_unfold(self, |mut sub| async move { let message = sub.next().await?; Ok(Some((message, sub))) @@ -569,7 +511,7 @@ pub enum TailAnchorResponse where M: Message + Debug, { - Message(ReceivedMessage), + Message(NatsMessage), /// Anchor was older than the TTL of the message. AnchorExpired, @@ -589,30 +531,3 @@ where } } } - -mod redis_keys { - use std::{ - collections::hash_map::DefaultHasher, - hash::{Hash, Hasher}, - }; - - use crate::message::Message; - - /// HASH - pub fn message_tail(tags_str: &str) -> String - where - M: Message, - { - // Get hash of the tags - let mut hasher = DefaultHasher::new(); - tags_str.hash(&mut hasher); - - format!("{{topic:{}:{:x}}}:tail", M::NAME, hasher.finish()) - } - - pub mod message_tail { - pub const REQUEST_ID: &str = "r"; - pub const TS: &str = "t"; - pub const BODY: &str = "b"; - } -} diff --git a/lib/chirp-workflow/core/src/ctx/standalone.rs b/lib/chirp-workflow/core/src/ctx/standalone.rs index 705176a56e..0c03a5ffb2 100644 --- a/lib/chirp-workflow/core/src/ctx/standalone.rs +++ b/lib/chirp-workflow/core/src/ctx/standalone.rs @@ -12,7 +12,8 @@ use crate::{ }, db::DatabaseHandle, error::WorkflowResult, - message::{Message, ReceivedMessage}, + listen::Listen, + message::{Message, NatsMessage}, operation::{Operation, OperationInput}, signal::Signal, workflow::{Workflow, WorkflowInput}, @@ -54,7 +55,7 @@ impl StandaloneCtx { (), ); - let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?; + let msg_ctx = MessageCtx::new(&conn, ray_id).await?; Ok(StandaloneCtx { ray_id, @@ -92,6 +93,15 @@ impl StandaloneCtx { builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body) } + // /// Listens for a signal indefinitely. + // pub async fn listen(&mut self) -> GlobalResult { + // tracing::info!(name=%self.name, "listening for signal"); + + // let ctx = ListenCtx::new(self); + + // T::listen(&ctx).await + // } + #[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))] pub async fn op( &self, @@ -128,7 +138,7 @@ impl StandaloneCtx { pub async fn tail_read( &self, tags: serde_json::Value, - ) -> GlobalResult>> + ) -> GlobalResult>> where M: Message, { diff --git a/lib/chirp-workflow/core/src/ctx/test.rs b/lib/chirp-workflow/core/src/ctx/test.rs index 7585824c3a..f8aef11372 100644 --- a/lib/chirp-workflow/core/src/ctx/test.rs +++ b/lib/chirp-workflow/core/src/ctx/test.rs @@ -13,10 +13,10 @@ use crate::{ }, db::{DatabaseHandle, DatabasePgNats}, error::WorkflowError, - message::{Message, ReceivedMessage}, + message::{Message, NatsMessage}, operation::{Operation, OperationInput}, signal::Signal, - util, + utils, workflow::{Workflow, WorkflowInput}, }; @@ -50,7 +50,7 @@ impl TestCtx { .expect("failed to create chirp client"); let cache = rivet_cache::CacheInner::from_env(pools.clone()).expect("failed to create cache"); - let conn = util::new_conn( + let conn = utils::new_conn( &shared_client, &pools, &cache, @@ -73,7 +73,7 @@ impl TestCtx { let db = DatabasePgNats::from_pools(pools.crdb().unwrap(), pools.nats_option().clone().unwrap()); - let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await.unwrap(); + let msg_ctx = MessageCtx::new(&conn, ray_id).await.unwrap(); TestCtx { name: service_name, @@ -176,7 +176,7 @@ impl TestCtx { pub async fn tail_read( &self, tags: serde_json::Value, - ) -> GlobalResult>> + ) -> GlobalResult>> where M: Message, { diff --git a/lib/chirp-workflow/core/src/ctx/workflow.rs b/lib/chirp-workflow/core/src/ctx/workflow.rs index a534f95f2c..53ac516a34 100644 --- a/lib/chirp-workflow/core/src/ctx/workflow.rs +++ b/lib/chirp-workflow/core/src/ctx/workflow.rs @@ -16,13 +16,13 @@ use crate::{ db::{DatabaseHandle, PulledWorkflow}, error::{WorkflowError, WorkflowResult}, event::Event, - executable::{closure, AsyncResult, Executable}, + executable::{AsyncResult, Executable}, listen::{CustomListener, Listen}, message::Message, metrics, registry::RegistryHandle, signal::Signal, - util::{ + utils::{ self, time::{DurationToMillis, TsToMillis}, GlobalErrorExt, Location, @@ -80,7 +80,7 @@ impl WorkflowCtx { conn: rivet_connection::Connection, workflow: PulledWorkflow, ) -> GlobalResult { - let msg_ctx = MessageCtx::new(&conn, workflow.workflow_id, workflow.ray_id).await?; + let msg_ctx = MessageCtx::new(&conn, workflow.ray_id).await?; Ok(WorkflowCtx { workflow_id: workflow.workflow_id, @@ -558,15 +558,6 @@ impl WorkflowCtx { exec.try_execute(self).await } - /// Spawns a new thread to execute workflow steps in. - pub fn spawn(&mut self, f: F) -> tokio::task::JoinHandle> - where - F: for<'a> FnOnce(&'a mut WorkflowCtx) -> AsyncResult<'a, T> + Send + 'static, - { - let mut ctx = self.branch(); - tokio::task::spawn(async move { closure(f).execute(&mut ctx).await }) - } - /// Tests if the given error is unrecoverable. If it is, allows the user to run recovery code safely. /// Should always be used when trying to handle activity errors manually. pub fn catch_unrecoverable( @@ -928,7 +919,7 @@ impl WorkflowCtx { /// For debugging, pretty prints the current location pub(crate) fn loc(&self) -> String { - util::format_location(&self.full_location()) + utils::format_location(&self.full_location()) } pub fn name(&self) -> &str { diff --git a/lib/chirp-workflow/core/src/db/mod.rs b/lib/chirp-workflow/core/src/db/mod.rs index 8f7e5f2c35..9064039f93 100644 --- a/lib/chirp-workflow/core/src/db/mod.rs +++ b/lib/chirp-workflow/core/src/db/mod.rs @@ -6,7 +6,7 @@ use crate::{ activity::ActivityId, error::{WorkflowError, WorkflowResult}, event::Event, - util::Location, + utils::Location, workflow::Workflow, }; @@ -17,6 +17,15 @@ pub type DatabaseHandle = Arc; #[async_trait::async_trait] pub trait Database: Send { + /// When using a wake worker instead of a polling worker, this function will return once the worker + /// should fetch the database again. + async fn wake(&self) -> WorkflowResult<()> { + unimplemented!( + "{} does not implement Database::wake", + std::any::type_name::() + ); + } + /// Writes a new workflow to the database. async fn dispatch_workflow( &self, diff --git a/lib/chirp-workflow/core/src/db/pg_nats.rs b/lib/chirp-workflow/core/src/db/pg_nats.rs index ec8c73351a..5d53ac360a 100644 --- a/lib/chirp-workflow/core/src/db/pg_nats.rs +++ b/lib/chirp-workflow/core/src/db/pg_nats.rs @@ -1,8 +1,10 @@ use std::{sync::Arc, time::Duration}; +use futures_util::{stream::FuturesUnordered, FutureExt, StreamExt, TryStreamExt}; use indoc::indoc; use rivet_pools::prelude::NatsPool; use sqlx::{pool::PoolConnection, Acquire, PgPool, Postgres}; +use tokio::sync::Mutex; use tracing::Instrument; use uuid::Uuid; @@ -30,11 +32,17 @@ const MAX_QUERY_RETRIES: usize = 16; pub struct DatabasePgNats { pool: PgPool, nats: NatsPool, + sub: Mutex>, } impl DatabasePgNats { pub fn from_pools(pool: PgPool, nats: NatsPool) -> Arc { - Arc::new(DatabasePgNats { pool, nats }) + Arc::new(DatabasePgNats { + pool, + // Lazy load the nats sub + sub: Mutex::new(None), + nats, + }) } async fn conn(&self) -> WorkflowResult> { @@ -52,7 +60,7 @@ impl DatabasePgNats { let nats = self.nats.clone(); let spawn_res = tokio::task::Builder::new() - .name("chirp_workflow::DatabasePgNats::wake") + .name("wake") .spawn( async move { // Fail gracefully @@ -117,6 +125,25 @@ impl DatabasePgNats { #[async_trait::async_trait] impl Database for DatabasePgNats { + async fn wake(&self) -> WorkflowResult<()> { + let mut sub = self.sub.try_lock().map_err(WorkflowError::WakeLock)?; + + // Initialize sub + if sub.is_none() { + *sub = Some( + self.nats + .subscribe(message::WORKER_WAKE_SUBJECT) + .await + .map_err(|x| WorkflowError::CreateSubscription(x.into()))?, + ); + } + + match sub.as_mut().expect("unreachable").next().await { + Some(_) => Ok(()), + None => Err(WorkflowError::SubscriptionUnsubscribed), + } + } + async fn dispatch_workflow( &self, ray_id: Uuid, diff --git a/lib/chirp-workflow/core/src/error.rs b/lib/chirp-workflow/core/src/error.rs index 2631e34a36..f8a0bd77dc 100644 --- a/lib/chirp-workflow/core/src/error.rs +++ b/lib/chirp-workflow/core/src/error.rs @@ -78,6 +78,9 @@ pub enum WorkflowError { #[error("serialize message tags: {0:?}")] SerializeMessageTags(cjson::Error), + #[error("decode message tags: {0}")] + DeserializeMessageTags(serde_json::Error), + #[error("serialize loop output: {0}")] SerializeLoopOutput(serde_json::Error), @@ -134,6 +137,9 @@ pub enum WorkflowError { #[error("sleeping until {0}")] Sleep(i64), + + #[error("cannot have multiple concurrent calls to Database::wake")] + WakeLock(tokio::sync::TryLockError), } impl WorkflowError { diff --git a/lib/chirp-workflow/core/src/event.rs b/lib/chirp-workflow/core/src/event.rs index cf4fd90317..3a33f786c6 100644 --- a/lib/chirp-workflow/core/src/event.rs +++ b/lib/chirp-workflow/core/src/event.rs @@ -10,7 +10,7 @@ use crate::{ SignalEventRow, SignalSendEventRow, SleepEventRow, SubWorkflowEventRow, }, error::{WorkflowError, WorkflowResult}, - util::Location, + utils::Location, }; /// An event that happened in the workflow run. diff --git a/lib/chirp-workflow/core/src/lib.rs b/lib/chirp-workflow/core/src/lib.rs index e8e16f5eb3..4532793292 100644 --- a/lib/chirp-workflow/core/src/lib.rs +++ b/lib/chirp-workflow/core/src/lib.rs @@ -13,6 +13,6 @@ pub mod operation; pub mod prelude; pub mod registry; mod signal; -pub mod util; +pub mod utils; mod worker; pub mod workflow; diff --git a/lib/chirp-workflow/core/src/message.rs b/lib/chirp-workflow/core/src/message.rs index 7f08b1a74c..abcd5ba03e 100644 --- a/lib/chirp-workflow/core/src/message.rs +++ b/lib/chirp-workflow/core/src/message.rs @@ -1,6 +1,5 @@ use std::fmt::Debug; -use rivet_operation::prelude::proto::chirp; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use uuid::Uuid; @@ -11,58 +10,59 @@ pub const WORKER_WAKE_SUBJECT: &str = "chirp.workflow.worker.wake"; pub trait Message: Debug + Send + Sync + Serialize + DeserializeOwned + 'static { const NAME: &'static str; const TAIL_TTL: std::time::Duration; -} -pub fn serialize_message_nats_subject(tags_str: &str) -> String -where - M: Message, -{ - format!("chirp.workflow.msg.{}.{}", M::NAME, tags_str,) + fn nats_subject() -> String { + format!("chirp.workflow.msg.{}", Self::NAME) + } } -/// A message received from a Chirp subscription. +/// A message received from a NATS subscription. #[derive(Debug)] -pub struct ReceivedMessage +pub struct NatsMessage where M: Message, { pub(crate) ray_id: Uuid, pub(crate) req_id: Uuid, pub(crate) ts: i64, - pub(crate) trace: Vec, pub(crate) body: M, } -impl ReceivedMessage +impl NatsMessage where M: Message, { #[tracing::instrument(skip(buf))] pub(crate) fn deserialize(buf: &[u8]) -> WorkflowResult { - // Deserialize the wrapper let message_wrapper = Self::deserialize_wrapper(buf)?; + Self::deserialize_from_wrapper(message_wrapper) + } + + #[tracing::instrument(skip(wrapper))] + pub(crate) fn deserialize_from_wrapper( + wrapper: NatsMessageWrapper<'_>, + ) -> WorkflowResult { // Deserialize the body - let body = serde_json::from_str::(message_wrapper.body.get()) + let body = serde_json::from_str(wrapper.body.get()) .map_err(WorkflowError::DeserializeMessageBody)?; - Ok(ReceivedMessage { - ray_id: message_wrapper.ray_id, - req_id: message_wrapper.req_id, - ts: message_wrapper.ts, - trace: message_wrapper.trace, + Ok(NatsMessage { + ray_id: wrapper.ray_id, + req_id: wrapper.req_id, + ts: wrapper.ts, body, }) } // Only returns the message wrapper #[tracing::instrument(skip(buf))] - pub(crate) fn deserialize_wrapper<'a>(buf: &'a [u8]) -> WorkflowResult> { + pub(crate) fn deserialize_wrapper<'a>(buf: &'a [u8]) -> WorkflowResult> { serde_json::from_slice(buf).map_err(WorkflowError::DeserializeMessage) } } -impl std::ops::Deref for ReceivedMessage +impl std::ops::Deref for NatsMessage where M: Message, { @@ -73,7 +73,7 @@ where } } -impl ReceivedMessage +impl NatsMessage where M: Message, { @@ -93,42 +93,42 @@ where pub fn body(&self) -> &M { &self.body } - - pub fn trace(&self) -> &[TraceEntry] { - &self.trace - } } #[derive(Serialize, Deserialize)] -pub(crate) struct MessageWrapper<'a> { +pub(crate) struct NatsMessageWrapper<'a> { pub(crate) ray_id: Uuid, pub(crate) req_id: Uuid, pub(crate) tags: serde_json::Value, pub(crate) ts: i64, - pub(crate) trace: Vec, #[serde(borrow)] pub(crate) body: &'a serde_json::value::RawValue, pub(crate) allow_recursive: bool, } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct TraceEntry { - context_name: String, - pub(crate) req_id: Uuid, - ts: i64, -} +pub mod redis_keys { + use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + }; -impl TryFrom for TraceEntry { - type Error = WorkflowError; - - fn try_from(value: chirp::TraceEntry) -> WorkflowResult { - Ok(TraceEntry { - context_name: value.context_name.clone(), - req_id: value - .req_id - .map(|id| id.as_uuid()) - .ok_or(WorkflowError::MissingMessageData)?, - ts: value.ts, - }) + use super::Message; + + /// HASH + pub fn message_tail(tags_str: &str) -> String + where + M: Message, + { + // Get hash of the tags + let mut hasher = DefaultHasher::new(); + tags_str.hash(&mut hasher); + + format!("{{topic:{}:{:x}}}:tail", M::NAME, hasher.finish()) + } + + pub mod message_tail { + pub const REQUEST_ID: &str = "r"; + pub const TS: &str = "t"; + pub const BODY: &str = "b"; } } diff --git a/lib/chirp-workflow/core/src/prelude.rs b/lib/chirp-workflow/core/src/prelude.rs index f3210d078b..b23c467fb7 100644 --- a/lib/chirp-workflow/core/src/prelude.rs +++ b/lib/chirp-workflow/core/src/prelude.rs @@ -25,7 +25,7 @@ pub use crate::{ operation::Operation, registry::Registry, signal::{join_signal, Signal}, - util::GlobalErrorExt, + utils::GlobalErrorExt, worker::Worker, workflow::Workflow, }; diff --git a/lib/chirp-workflow/core/src/signal.rs b/lib/chirp-workflow/core/src/signal.rs index 4b8b662c44..e45749c279 100644 --- a/lib/chirp-workflow/core/src/signal.rs +++ b/lib/chirp-workflow/core/src/signal.rs @@ -32,28 +32,11 @@ pub trait Signal { /// ```` #[macro_export] macro_rules! join_signal { - (pub $join:ident, [$($signals:ident),* $(,)?]) => { - pub enum $join { + ($vis:vis $join:ident, [$($signals:ident),* $(,)?]) => { + $vis enum $join { $($signals($signals)),* } - join_signal!(@ $join, [$($signals),*]); - }; - (pub($($vis:tt)*) $join:ident, [$($signals:ident),* $(,)?]) => { - pub($($vis)*) enum $join { - $($signals($signals)),* - } - - join_signal!(@ $join, [$($signals),*]); - }; - ($join:ident, [$($signals:ident),* $(,)?]) => { - enum $join { - $($signals($signals)),* - } - - join_signal!(@ $join, [$($signals),*]); - }; - (@ $join:ident, [$($signals:ident),*]) => { #[async_trait::async_trait] impl Listen for $join { async fn listen(ctx: &chirp_workflow::prelude::ListenCtx) -> chirp_workflow::prelude::WorkflowResult { diff --git a/lib/chirp-workflow/core/src/util.rs b/lib/chirp-workflow/core/src/utils.rs similarity index 83% rename from lib/chirp-workflow/core/src/util.rs rename to lib/chirp-workflow/core/src/utils.rs index 9f72dafb59..19b0b8514a 100644 --- a/lib/chirp-workflow/core/src/util.rs +++ b/lib/chirp-workflow/core/src/utils.rs @@ -1,14 +1,10 @@ -use global_error::{macros::*, GlobalError, GlobalResult}; -use rand::Rng; +use global_error::{GlobalError, GlobalResult}; use uuid::Uuid; use crate::error::WorkflowError; pub type Location = Box<[usize]>; -// How often the `inject_fault` function fails in percent -const FAULT_RATE: usize = 80; - /// Allows for checking if a global error returned from an activity is recoverable. pub trait GlobalErrorExt { fn is_workflow_recoverable(&self) -> bool; @@ -114,14 +110,6 @@ pub mod time { } } -pub fn inject_fault() -> GlobalResult<()> { - if rand::thread_rng().gen_range(0..100) < FAULT_RATE { - bail!("This is a random panic!"); - } - - Ok(()) -} - pub(crate) fn new_conn( shared_client: &chirp_client::SharedClientHandle, pools: &rivet_pools::Pools, @@ -147,6 +135,24 @@ pub(crate) fn new_conn( rivet_connection::Connection::new(client, pools.clone(), cache.clone()) } +/// Returns true if `subset` is a subset of `superset`. +pub fn is_value_subset(subset: &serde_json::Value, superset: &serde_json::Value) -> bool { + match (subset, superset) { + (serde_json::Value::Object(sub_obj), serde_json::Value::Object(super_obj)) => { + sub_obj.iter().all(|(k, sub_val)| { + super_obj + .get(k) + .map_or(false, |super_val| is_value_subset(sub_val, super_val)) + }) + } + (serde_json::Value::Array(sub_arr), serde_json::Value::Array(super_arr)) => sub_arr + .iter() + .zip(super_arr) + .all(|(sub_val, super_val)| is_value_subset(sub_val, super_val)), + _ => subset == superset, + } +} + pub fn format_location(loc: &Location) -> String { let mut s = "{".to_string(); diff --git a/lib/chirp-workflow/core/src/worker.rs b/lib/chirp-workflow/core/src/worker.rs index 37da8f1a03..dce3437d18 100644 --- a/lib/chirp-workflow/core/src/worker.rs +++ b/lib/chirp-workflow/core/src/worker.rs @@ -1,13 +1,9 @@ -use futures_util::StreamExt; use global_error::GlobalResult; use tokio::time::Duration; use tracing::Instrument; use uuid::Uuid; -use crate::{ - ctx::WorkflowCtx, db::DatabaseHandle, error::WorkflowError, message, registry::RegistryHandle, - util, -}; +use crate::{ctx::WorkflowCtx, db::DatabaseHandle, registry::RegistryHandle, utils}; pub const TICK_INTERVAL: Duration = Duration::from_secs(5); @@ -29,7 +25,8 @@ impl Worker { } } - pub async fn start(mut self, pools: rivet_pools::Pools) -> GlobalResult<()> { + /// Polls the database periodically + pub async fn poll_start(mut self, pools: rivet_pools::Pools) -> GlobalResult<()> { tracing::info!( worker_instance_id=?self.worker_instance_id, "starting worker instance with {} registered workflows", @@ -49,7 +46,8 @@ impl Worker { } } - pub async fn start_with_nats(mut self, pools: rivet_pools::Pools) -> GlobalResult<()> { + /// Polls the database periodically or wakes immediately when `Database::wake` finishes + pub async fn wake_start(mut self, pools: rivet_pools::Pools) -> GlobalResult<()> { tracing::info!( worker_instance_id=?self.worker_instance_id, "starting worker instance with {} registered workflows", @@ -63,25 +61,10 @@ impl Worker { let mut interval = tokio::time::interval(TICK_INTERVAL); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - // Create a subscription to the wake subject which receives messages whenever the worker should be - // awoken - let mut sub = pools - .nats()? - .subscribe(message::WORKER_WAKE_SUBJECT) - .await - .map_err(|x| WorkflowError::CreateSubscription(x.into()))?; - loop { tokio::select! { _ = interval.tick() => {}, - msg = sub.next() => { - match msg { - Some(_) => interval.reset(), - None => { - return Err(WorkflowError::SubscriptionUnsubscribed.into()); - } - } - } + res = self.db.wake() => res?, } self.tick(&shared_client, &pools, &cache).await?; @@ -111,7 +94,7 @@ impl Worker { .pull_workflows(self.worker_instance_id, &filter) .await?; for workflow in workflows { - let conn = util::new_conn( + let conn = utils::new_conn( &shared_client, pools, cache, @@ -127,7 +110,7 @@ impl Worker { async move { // Sleep until deadline if let Some(wake_deadline_ts) = wake_deadline_ts { - util::time::sleep_until_ts(wake_deadline_ts as u64).await; + utils::time::sleep_until_ts(wake_deadline_ts as u64).await; } if let Err(err) = ctx.run().await { diff --git a/svc/Cargo.lock b/svc/Cargo.lock index a44da54e06..21854f3c57 100644 --- a/svc/Cargo.lock +++ b/svc/Cargo.lock @@ -2275,7 +2275,6 @@ dependencies = [ "lazy_static", "prost 0.12.6", "prost-types 0.12.6", - "rand", "rivet-cache", "rivet-connection", "rivet-metrics", @@ -6480,6 +6479,19 @@ version = "0.0.1" dependencies = [ "chirp-workflow", "serde", +] + +[[package]] +name = "pegboard-ws" +version = "0.0.1" +dependencies = [ + "chirp-client", + "chirp-workflow", + "rivet-connection", + "rivet-health-checks", + "rivet-metrics", + "rivet-runtime", + "serde", "tokio-tungstenite 0.23.1", ] diff --git a/svc/Cargo.toml b/svc/Cargo.toml index db74fb8855..b17d8ca2f0 100644 --- a/svc/Cargo.toml +++ b/svc/Cargo.toml @@ -179,6 +179,7 @@ members = [ "pkg/nomad/standalone/monitor", "pkg/nsfw/ops/image-score", "pkg/pegboard", + "pkg/pegboard/standalone/ws", "pkg/perf/ops/log-get", "pkg/profanity/ops/check", "pkg/region/ops/get", diff --git a/svc/pkg/pegboard/Cargo.toml b/svc/pkg/pegboard/Cargo.toml index 0650aa1ace..769e38b4a7 100644 --- a/svc/pkg/pegboard/Cargo.toml +++ b/svc/pkg/pegboard/Cargo.toml @@ -7,5 +7,4 @@ license = "Apache-2.0" [dependencies] chirp-workflow = { path = "../../../lib/chirp-workflow/core" } -tokio-tungstenite = "0.23.1" serde = { version = "1.0.198", features = ["derive"] } diff --git a/svc/pkg/pegboard/src/workflows/ws.rs b/svc/pkg/pegboard/src/workflows/ws.rs deleted file mode 100644 index ed91211405..0000000000 --- a/svc/pkg/pegboard/src/workflows/ws.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::{ - collections::HashMap, - net::SocketAddr, - sync::{Arc, RwLock}, -}; - -use chirp_workflow::prelude::*; -use futures_util::FutureExt; -// use tokio::net::{TcpListener, TcpStream}; -// use tokio_tungstenite::tungstenite::protocol::Message; - -#[derive(Debug, Serialize, Deserialize)] -pub struct Input {} - -#[workflow] -pub async fn pegboard_ws(ctx: &mut WorkflowCtx, input: &Input) -> GlobalResult<()> { - // let addr = "127.0.0.1:8080"; - // let listener = TcpListener::bind(&addr).await?; - // println!("Listening on: {}", addr); - - let conns = Arc::new(RwLock::new(HashMap::<(), ()>::new())); - - ctx.try_join(( - closure(|ctx| socket_thread(ctx, conns.clone()).boxed()), - closure(|ctx| signal_thread(ctx, conns.clone()).boxed()), - )) - .await?; - - Ok(()) -} - -async fn socket_thread( - ctx: &mut WorkflowCtx, - conns: Arc>>, -) -> GlobalResult<()> { - ctx.repeat(|ctx| { - async move { - if let Ok((stream, addr)) = listener.accept().await { - handle_connection(stream, addr).await; - } else { - tracing::error!("failed to connect websocket"); - } - - Ok(Loop::Continue) - }.boxed() - ) - - Ok(()) -} - -async fn signal_thread( - ctx: &mut WorkflowCtx, - conns: Arc>>, -) -> GlobalResult<()> { - Ok(()) -} - -async fn handle_connection(ctx, raw_stream: TcpStream, addr: SocketAddr) { - ctx.spawn(|ctx| async move { - let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?; - let (mut write, mut read) = ws_stream.split(); - - println!("New WebSocket connection: {}", addr); - - while let Some(Ok(msg)) = read.next().await { - if msg.is_text() || msg.is_binary() { - write.send(msg).await?; - } - } - - Ok(()) - }.boxed()).await -} - -#[derive(Debug, Serialize, Deserialize, Hash)] -struct FooInput {} - -#[activity(Foo)] -async fn foo(ctx: &ActivityCtx, input: &FooInput) -> GlobalResult<()> { - Ok(()) -} diff --git a/svc/pkg/pegboard/standalone/ws/Cargo.toml b/svc/pkg/pegboard/standalone/ws/Cargo.toml new file mode 100644 index 0000000000..1491131593 --- /dev/null +++ b/svc/pkg/pegboard/standalone/ws/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "pegboard-ws" +version = "0.0.1" +edition = "2018" +authors = ["Rivet Gaming, LLC "] +license = "Apache-2.0" + +[dependencies] +chirp-client = { path = "../../../../../lib/chirp/client" } +chirp-workflow = { path = "../../../../../lib/chirp-workflow/core" } +rivet-connection = { path = "../../../../../lib/connection" } +rivet-health-checks = { path = "../../../../../lib/health-checks" } +rivet-metrics = { path = "../../../../../lib/metrics" } +rivet-runtime = { path = "../../../../../lib/runtime" } +serde = { version = "1.0", features = ["derive"] } +tokio-tungstenite = "0.23.1" diff --git a/svc/pkg/pegboard/standalone/ws/Service.toml b/svc/pkg/pegboard/standalone/ws/Service.toml new file mode 100644 index 0000000000..744ec39f0d --- /dev/null +++ b/svc/pkg/pegboard/standalone/ws/Service.toml @@ -0,0 +1,8 @@ +[service] +name = "pegboard-ws" + +[runtime] +kind = "rust" + +[headless] +singleton = true diff --git a/svc/pkg/pegboard/standalone/ws/src/lib.rs b/svc/pkg/pegboard/standalone/ws/src/lib.rs new file mode 100644 index 0000000000..977d7505bb --- /dev/null +++ b/svc/pkg/pegboard/standalone/ws/src/lib.rs @@ -0,0 +1,105 @@ +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; + +use chirp_workflow::prelude::*; +use futures_util::{stream::SplitSink, SinkExt, StreamExt, TryStreamExt}; +use serde_json::json; +use tokio::{ + net::{TcpListener, TcpStream}, + sync::RwLock, +}; +use tokio_tungstenite::{tungstenite::protocol::Message, WebSocketStream}; + +type Connections = HashMap, Message>>; + +#[tracing::instrument(skip_all)] +pub async fn run_from_env(pools: rivet_pools::Pools) -> GlobalResult<()> { + let client = chirp_client::SharedClient::from_env(pools.clone())?.wrap_new("pegboard-ws"); + let cache = rivet_cache::CacheInner::from_env(pools.clone())?; + let ctx = StandaloneCtx::new( + chirp_workflow::compat::db_from_pools(&pools).await?, + rivet_connection::Connection::new(client, pools, cache), + "pegboard-ws", + ) + .await?; + + let conns: Arc> = Arc::new(RwLock::new(HashMap::new())); + + tokio::try_join!( + socket_thread(&ctx, conns.clone()), + signal_thread(&ctx, conns.clone()), + )?; + + Ok(()) +} + +async fn socket_thread(ctx: &StandaloneCtx, conns: Arc>) -> GlobalResult<()> { + let addr = ("127.0.0.1", 8080); + let listener = TcpListener::bind(&addr).await?; + tracing::info!("Listening on: {:?}", addr); + + loop { + match listener.accept().await { + Ok((stream, addr)) => handle_connection(ctx, conns.clone(), stream, addr).await, + Err(err) => tracing::error!(?err, "failed to connect websocket"), + } + } +} + +async fn signal_thread(ctx: &StandaloneCtx, conns: Arc>) -> GlobalResult<()> { + loop { + let sig = ctx.listen::(&json!({})).await?; + + { + let conns = conns.read().await; + + if let Some(write) = conns.get(sig.client_id) { + write.send(sig); + } + } + } +} + +async fn handle_connection( + ctx: &StandaloneCtx, + conns: Arc>, + raw_stream: TcpStream, + addr: SocketAddr, +) { + let ctx = ctx.clone(); + + tokio::spawn(async move { + let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?; + let (tx, mut rx) = ws_stream.split(); + + let client_id = Uuid::new_v4(); + + { + let mut conns = conns.write().await; + + if let Some(mut old_tx) = conns.insert(client_id, tx) { + tracing::error!( + ?client_id, + "client already connected, overwriting old connection" + ); + old_tx.send(Message::Close(None)).await?; + } + } + + // todo!("check if client exists in sql"); + + while let Ok(msg) = rx.try_next().await { + ctx.signal(ClientEvent {}) + .tag("client_id", client_id) + .send() + .await?; + } + + GlobalResult::Ok(()) + }); +} + +#[signal("pegboard_command")] +pub struct Command {} + +#[signal("pegboard_client_event")] +pub struct ClientEvent {} diff --git a/svc/pkg/pegboard/standalone/ws/src/main.rs b/svc/pkg/pegboard/standalone/ws/src/main.rs new file mode 100644 index 0000000000..98b9edbb05 --- /dev/null +++ b/svc/pkg/pegboard/standalone/ws/src/main.rs @@ -0,0 +1,23 @@ +use chirp_workflow::prelude::*; + +fn main() -> GlobalResult<()> { + rivet_runtime::run(start()).unwrap() +} + +async fn start() -> GlobalResult<()> { + let pools = rivet_pools::from_env("pegboard-ws").await?; + + tokio::task::Builder::new() + .name("pegboard_ws::health_checks") + .spawn(rivet_health_checks::run_standalone( + rivet_health_checks::Config { + pools: Some(pools.clone()), + }, + ))?; + + tokio::task::Builder::new() + .name("pegboard_ws::metrics") + .spawn(rivet_metrics::run_standalone())?; + + pegboard_ws::run_from_env(pools.clone()).await +} diff --git a/svc/pkg/pegboard/standalone/ws/tests/integration.rs b/svc/pkg/pegboard/standalone/ws/tests/integration.rs new file mode 100644 index 0000000000..6c8ea4d0f2 --- /dev/null +++ b/svc/pkg/pegboard/standalone/ws/tests/integration.rs @@ -0,0 +1 @@ +// TODO: