Skip to content

Commit

Permalink
fix(workflows): filter messages by tags (#1142)
Browse files Browse the repository at this point in the history
<!-- Please make sure there is an issue that this PR is correlated to. -->

## Changes

<!-- If there are frontend changes, please include screenshots. -->
  • Loading branch information
MasterPtato committed Oct 9, 2024
1 parent 3cfbfb5 commit f07c270
Show file tree
Hide file tree
Showing 27 changed files with 339 additions and 326 deletions.
1 change: 0 additions & 1 deletion lib/chirp-workflow/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ indoc = "2.0.5"
lazy_static = "1.4"
prost = "0.12.4"
prost-types = "0.12.4"
rand = "0.8.5"
rivet-cache = { path = "../../cache/build" }
rivet-connection = { path = "../../connection" }
rivet-metrics = { path = "../../metrics" }
Expand Down
2 changes: 1 addition & 1 deletion lib/chirp-workflow/core/src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ where
M: Message,
B: Debug + Clone,
{
let msg_ctx = MessageCtx::new(ctx.conn(), ctx.req_id(), ctx.ray_id())
let msg_ctx = MessageCtx::new(ctx.conn(), ctx.ray_id())
.await
.map_err(GlobalError::raw)?;

Expand Down
6 changes: 3 additions & 3 deletions lib/chirp-workflow/core/src/ctx/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
},
db::DatabaseHandle,
error::WorkflowResult,
message::{Message, ReceivedMessage},
message::{Message, NatsMessage},
operation::{Operation, OperationInput},
signal::Signal,
workflow::{Workflow, WorkflowInput},
Expand Down Expand Up @@ -55,7 +55,7 @@ impl ApiCtx {
(),
);

let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;
let msg_ctx = MessageCtx::new(&conn, ray_id).await?;

Ok(ApiCtx {
ray_id,
Expand Down Expand Up @@ -129,7 +129,7 @@ impl ApiCtx {
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> GlobalResult<Option<ReceivedMessage<M>>>
) -> GlobalResult<Option<NatsMessage<M>>>
where
M: Message,
{
Expand Down
2 changes: 1 addition & 1 deletion lib/chirp-workflow/core/src/ctx/backfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{
};
use uuid::Uuid;

use crate::util::Location;
use crate::utils::Location;

// Yes
type Query = Box<
Expand Down
131 changes: 23 additions & 108 deletions lib/chirp-workflow/core/src/ctx/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use uuid::Uuid;

use crate::{
error::{WorkflowError, WorkflowResult},
message::{self, Message, MessageWrapper, ReceivedMessage, TraceEntry},
message::{redis_keys, Message, NatsMessage, NatsMessageWrapper},
utils,
};

/// Time (in ms) that we subtract from the anchor grace period in order to
Expand All @@ -29,29 +30,15 @@ pub struct MessageCtx {
/// Used for writing to message tails. This cache is ephemeral.
redis_chirp_ephemeral: RedisPool,

req_id: Uuid,
ray_id: Uuid,
trace: Vec<TraceEntry>,
}

impl MessageCtx {
pub async fn new(
conn: &rivet_connection::Connection,
req_id: Uuid,
ray_id: Uuid,
) -> WorkflowResult<Self> {
pub async fn new(conn: &rivet_connection::Connection, ray_id: Uuid) -> WorkflowResult<Self> {
Ok(MessageCtx {
nats: conn.nats().await?,
redis_chirp_ephemeral: conn.redis_chirp_ephemeral().await?,
req_id,
ray_id,
trace: conn
.chirp()
.trace()
.iter()
.cloned()
.map(TryInto::try_into)
.collect::<WorkflowResult<Vec<_>>>()?,
})
}
}
Expand Down Expand Up @@ -109,7 +96,7 @@ impl MessageCtx {
M: Message,
{
let tags_str = cjson::to_string(&tags).map_err(WorkflowError::SerializeMessageTags)?;
let nats_subject = message::serialize_message_nats_subject::<M>(&tags_str);
let nats_subject = M::nats_subject();
let duration_since_epoch = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|err| unreachable!("time is broken: {}", err));
Expand All @@ -124,12 +111,11 @@ impl MessageCtx {

// Serialize message
let req_id = Uuid::new_v4();
let message = MessageWrapper {
let message = NatsMessageWrapper {
req_id: req_id,
ray_id: self.ray_id,
tags: tags.clone(),
tags,
ts,
trace: self.trace.clone(),
allow_recursive: false, // TODO:
body: &body_buf,
};
Expand Down Expand Up @@ -278,8 +264,7 @@ impl MessageCtx {
where
M: Message,
{
let tags_str = cjson::to_string(opts.tags).map_err(WorkflowError::SerializeMessageTags)?;
let nats_subject = message::serialize_message_nats_subject::<M>(&tags_str);
let nats_subject = M::nats_subject();

// Create subscription and flush immediately.
tracing::info!(%nats_subject, tags = ?opts.tags, "creating subscription");
Expand All @@ -296,7 +281,7 @@ impl MessageCtx {
}

// Return handle
let subscription = SubscriptionHandle::new(nats_subject, subscription, self.req_id);
let subscription = SubscriptionHandle::new(nats_subject, subscription, opts.tags.clone());
Ok(subscription)
}

Expand All @@ -305,7 +290,7 @@ impl MessageCtx {
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> WorkflowResult<Option<ReceivedMessage<M>>>
) -> WorkflowResult<Option<NatsMessage<M>>>
where
M: Message,
{
Expand All @@ -320,7 +305,7 @@ impl MessageCtx {

// Deserialize message
let message = if let Some(message_buf) = message_buf {
let message = ReceivedMessage::<M>::deserialize(message_buf.as_slice())?;
let message = NatsMessage::<M>::deserialize(message_buf.as_slice())?;
tracing::info!(?message, "immediate read tail message");

let recv_lag = (rivet_util::timestamp::now() as f64 - message.ts as f64) / 1000.;
Expand Down Expand Up @@ -410,7 +395,7 @@ where
_guard: DropGuard,
subject: String,
subscription: nats::Subscriber,
req_id: Uuid,
pub tags: serde_json::Value,
}

impl<M> Debug for SubscriptionHandle<M>
Expand All @@ -420,6 +405,7 @@ where
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SubscriptionHandle")
.field("subject", &self.subject)
.field("tags", &self.tags)
.finish()
}
}
Expand All @@ -429,7 +415,7 @@ where
M: Message,
{
#[tracing::instrument(level = "debug", skip_all)]
fn new(subject: String, subscription: nats::Subscriber, req_id: Uuid) -> Self {
fn new(subject: String, subscription: nats::Subscriber, tags: serde_json::Value) -> Self {
let token = CancellationToken::new();

{
Expand Down Expand Up @@ -458,34 +444,15 @@ where
_guard: token.drop_guard(),
subject,
subscription,
req_id,
tags,
}
}

/// Waits for the next message in the subscription.
///
/// This future can be safely dropped.
#[tracing::instrument]
pub async fn next(&mut self) -> WorkflowResult<ReceivedMessage<M>> {
self.next_inner(false).await
}

// TODO: Add a full config struct to pass to `next` that impl's `Default`
/// Waits for the next message in the subscription that originates from the
/// parent request ID via trace.
///
/// This future can be safely dropped.
#[tracing::instrument]
pub async fn next_with_trace(
&mut self,
filter_trace: bool,
) -> WorkflowResult<ReceivedMessage<M>> {
self.next_inner(filter_trace).await
}

/// This future can be safely dropped.
#[tracing::instrument(level = "trace")]
async fn next_inner(&mut self, filter_trace: bool) -> WorkflowResult<ReceivedMessage<M>> {
pub async fn next(&mut self) -> WorkflowResult<NatsMessage<M>> {
tracing::info!("waiting for message");

loop {
Expand All @@ -501,47 +468,22 @@ where
}
};

if filter_trace {
let message_wrapper =
ReceivedMessage::<M>::deserialize_wrapper(&nats_message.payload[..])?;

// Check if the message trace stack originates from this client
//
// We intentionally use the request ID instead of just checking the ray ID because
// there may be multiple calls to `message_with_subscribe` within the same ray.
// Explicitly checking the parent request ensures the response is unique to this
// message.
if message_wrapper
.trace
.iter()
.rev()
.any(|trace_entry| trace_entry.req_id == self.req_id)
{
let message = ReceivedMessage::<M>::deserialize(&nats_message.payload[..])?;
tracing::info!(?message, "received message");

return Ok(message);
}
} else {
let message = ReceivedMessage::<M>::deserialize(&nats_message.payload[..])?;
tracing::info!(?message, "received message");
let message_wrapper = NatsMessage::<M>::deserialize_wrapper(&nats_message.payload[..])?;

let recv_lag = (rivet_util::timestamp::now() as f64 - message.ts as f64) / 1000.;
crate::metrics::MESSAGE_RECV_LAG
.with_label_values(&[M::NAME])
.observe(recv_lag);
// Check if the subscription tags match a subset of the message tags
if utils::is_value_subset(&self.tags, &message_wrapper.tags) {
let message = NatsMessage::<M>::deserialize_from_wrapper(message_wrapper)?;
tracing::info!(?message, "received message");

return Ok(message);
}

// Message not from parent, continue with loop
// Message tags don't match, continue with loop
}
}

/// Converts the subscription in to a stream.
pub fn into_stream(
self,
) -> impl futures_util::Stream<Item = WorkflowResult<ReceivedMessage<M>>> {
pub fn into_stream(self) -> impl futures_util::Stream<Item = WorkflowResult<NatsMessage<M>>> {
futures_util::stream::try_unfold(self, |mut sub| async move {
let message = sub.next().await?;
Ok(Some((message, sub)))
Expand Down Expand Up @@ -569,7 +511,7 @@ pub enum TailAnchorResponse<M>
where
M: Message + Debug,
{
Message(ReceivedMessage<M>),
Message(NatsMessage<M>),

/// Anchor was older than the TTL of the message.
AnchorExpired,
Expand All @@ -589,30 +531,3 @@ where
}
}
}

mod redis_keys {
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};

use crate::message::Message;

/// HASH
pub fn message_tail<M>(tags_str: &str) -> String
where
M: Message,
{
// Get hash of the tags
let mut hasher = DefaultHasher::new();
tags_str.hash(&mut hasher);

format!("{{topic:{}:{:x}}}:tail", M::NAME, hasher.finish())
}

pub mod message_tail {
pub const REQUEST_ID: &str = "r";
pub const TS: &str = "t";
pub const BODY: &str = "b";
}
}
16 changes: 13 additions & 3 deletions lib/chirp-workflow/core/src/ctx/standalone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use crate::{
},
db::DatabaseHandle,
error::WorkflowResult,
message::{Message, ReceivedMessage},
listen::Listen,
message::{Message, NatsMessage},
operation::{Operation, OperationInput},
signal::Signal,
workflow::{Workflow, WorkflowInput},
Expand Down Expand Up @@ -54,7 +55,7 @@ impl StandaloneCtx {
(),
);

let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await?;
let msg_ctx = MessageCtx::new(&conn, ray_id).await?;

Ok(StandaloneCtx {
ray_id,
Expand Down Expand Up @@ -92,6 +93,15 @@ impl StandaloneCtx {
builder::signal::SignalBuilder::new(self.db.clone(), self.ray_id, body)
}

// /// Listens for a signal indefinitely.
// pub async fn listen<T: Listen>(&mut self) -> GlobalResult<T> {
// tracing::info!(name=%self.name, "listening for signal");

// let ctx = ListenCtx::new(self);

// T::listen(&ctx).await
// }

#[tracing::instrument(err, skip_all, fields(operation = I::Operation::NAME))]
pub async fn op<I>(
&self,
Expand Down Expand Up @@ -128,7 +138,7 @@ impl StandaloneCtx {
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> GlobalResult<Option<ReceivedMessage<M>>>
) -> GlobalResult<Option<NatsMessage<M>>>
where
M: Message,
{
Expand Down
10 changes: 5 additions & 5 deletions lib/chirp-workflow/core/src/ctx/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ use crate::{
},
db::{DatabaseHandle, DatabasePgNats},
error::WorkflowError,
message::{Message, ReceivedMessage},
message::{Message, NatsMessage},
operation::{Operation, OperationInput},
signal::Signal,
util,
utils,
workflow::{Workflow, WorkflowInput},
};

Expand Down Expand Up @@ -50,7 +50,7 @@ impl TestCtx {
.expect("failed to create chirp client");
let cache =
rivet_cache::CacheInner::from_env(pools.clone()).expect("failed to create cache");
let conn = util::new_conn(
let conn = utils::new_conn(
&shared_client,
&pools,
&cache,
Expand All @@ -73,7 +73,7 @@ impl TestCtx {

let db =
DatabasePgNats::from_pools(pools.crdb().unwrap(), pools.nats_option().clone().unwrap());
let msg_ctx = MessageCtx::new(&conn, req_id, ray_id).await.unwrap();
let msg_ctx = MessageCtx::new(&conn, ray_id).await.unwrap();

TestCtx {
name: service_name,
Expand Down Expand Up @@ -176,7 +176,7 @@ impl TestCtx {
pub async fn tail_read<M>(
&self,
tags: serde_json::Value,
) -> GlobalResult<Option<ReceivedMessage<M>>>
) -> GlobalResult<Option<NatsMessage<M>>>
where
M: Message,
{
Expand Down
Loading

0 comments on commit f07c270

Please sign in to comment.