Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(source): add NATS source consumer parameters #17615

Merged
merged 5 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/connector/src/connector_common/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ impl NatsCommon {
stream: String,
split_id: String,
start_sequence: NatsOffset,
mut config: jetstream::consumer::pull::Config,
) -> ConnectorResult<
async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
> {
Expand All @@ -649,10 +650,6 @@ impl NatsCommon {
.replace(',', "-")
.replace(['.', '>', '*', ' ', '\t'], "_");
let name = format!("risingwave-consumer-{}-{}", subject_name, split_id);
let mut config = jetstream::consumer::pull::Config {
ack_policy: jetstream::consumer::AckPolicy::None,
..Default::default()
};

let deliver_policy = match start_sequence {
NatsOffset::Earliest => DeliverPolicy::All,
Expand All @@ -671,6 +668,7 @@ impl NatsCommon {
},
NatsOffset::None => DeliverPolicy::All,
};

let consumer = stream
.get_or_create_consumer(&name, {
config.deliver_policy = deliver_policy;
Expand Down
19 changes: 19 additions & 0 deletions src/connector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,25 @@ where
}
}

pub(crate) fn deserialize_duration_seq_from_string<'de, D>(
deserializer: D,
) -> std::result::Result<Option<Vec<Duration>>, D::Error>
tabVersion marked this conversation as resolved.
Show resolved Hide resolved
where
D: de::Deserializer<'de>,
{
let s: Option<String> = de::Deserialize::deserialize(deserializer)?;
if let Some(s) = s {
let durations = s
.split(',')
.map(|s| s.trim().parse().map(Duration::from_secs))
.collect::<Result<Vec<Duration>, _>>()
.map_err(|_| de::Error::invalid_value(de::Unexpected::Str(&s), &"invalid duration"));
Ok(Some(durations?))
} else {
Ok(None)
}
}

pub(crate) fn deserialize_bool_from_string<'de, D>(deserializer: D) -> Result<bool, D::Error>
where
D: de::Deserializer<'de>,
Expand Down
303 changes: 303 additions & 0 deletions src/connector/src/source/nats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,55 @@ pub mod source;
pub mod split;

use std::collections::HashMap;
use std::time::Duration;

use async_nats::jetstream::consumer::pull::Config;
use async_nats::jetstream::consumer::{AckPolicy, ReplayPolicy};
use serde::Deserialize;
use serde_with::{serde_as, DisplayFromStr, DurationSeconds};
use with_options::WithOptions;

use crate::connector_common::NatsCommon;
use crate::source::nats::enumerator::NatsSplitEnumerator;
use crate::source::nats::source::{NatsSplit, NatsSplitReader};
use crate::source::SourceProperties;
use crate::{deserialize_duration_seq_from_string, deserialize_optional_string_seq_from_string};

pub const NATS_CONNECTOR: &str = "nats";

pub struct AckPolicyWrapper;

impl AckPolicyWrapper {
pub fn parse_str(s: &str) -> AckPolicy {
match s {
"none" => AckPolicy::None,
"all" => AckPolicy::All,
"explicit" => AckPolicy::Explicit,
_ => AckPolicy::None,
tabVersion marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

pub struct ReplayPolicyWrapper;

impl ReplayPolicyWrapper {
pub fn parse_str(s: &str) -> ReplayPolicy {
match s {
"instant" => ReplayPolicy::Instant,
"original" => ReplayPolicy::Original,
_ => ReplayPolicy::Instant,
}
}
}

#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct NatsProperties {
#[serde(flatten)]
pub common: NatsCommon,

#[serde(flatten)]
pub nats_properties_consumer: NatsPropertiesConsumer,

#[serde(rename = "scan.startup.mode")]
pub scan_startup_mode: Option<String>,

Expand All @@ -49,6 +82,173 @@ pub struct NatsProperties {
pub unknown_fields: HashMap<String, String>,
}

impl NatsProperties {
pub fn set_config(&self, c: &mut Config) {
self.nats_properties_consumer.set_config(c);
}
}

/// Properties for the async-nats library.
/// See <https://docs.rs/async-nats/latest/async_nats/jetstream/consumer/struct.Config.html>
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct NatsPropertiesConsumer {
#[serde(rename = "consumer.deliver_subject")]
pub deliver_subject: Option<String>,

#[serde(rename = "consumer.durable_name")]
pub durable_name: Option<String>,

#[serde(rename = "consumer.name")]
pub name: Option<String>,

#[serde(rename = "consumer.description")]
pub description: Option<String>,

#[serde(rename = "consumer.deliver_policy")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub deliver_policy: Option<String>,

#[serde(rename = "consumer.ack_policy")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub ack_policy: Option<String>,
Comment on lines +114 to +116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can try to deser it into Enum directly, serde can help us reject unrecognized str.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to do this, but unfortunately couldn't find a clean way to do it -- the enums from the NATS crate do not implement FromStr, so it doesn’t seem possible to deserialize directly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nevermind, I will try refactor it later. Thanks for your contribution.


#[serde(rename = "consumer.ack_wait")]
#[serde_as(as = "Option<DurationSeconds<String>>")]
pub ack_wait: Option<Duration>,

#[serde(rename = "consumer.max_deliver")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_deliver: Option<i64>,

#[serde(rename = "consumer.filter_subject")]
pub filter_subject: Option<String>,

#[serde(rename = "consumer.filter_subjects")]
#[serde(deserialize_with = "deserialize_optional_string_seq_from_string")]
pub filter_subjects: Option<Vec<String>>,

#[serde(rename = "consumer.replay_policy")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub replay_policy: Option<String>,

#[serde(rename = "consumer.rate_limit")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub rate_limit: Option<u64>,

#[serde(rename = "consumer.sample_frequency")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub sample_frequency: Option<u8>,

#[serde(rename = "consumer.max_waiting")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_waiting: Option<i64>,

#[serde(rename = "consumer.max_ack_pending")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_ack_pending: Option<i64>,

#[serde(rename = "consumer.headers_only")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub headers_only: Option<bool>,

#[serde(rename = "consumer.max_batch")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_batch: Option<i64>,

#[serde(rename = "consumer.max_bytes")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_bytes: Option<i64>,

#[serde(rename = "consumer.max_expires")]
#[serde_as(as = "Option<DurationSeconds<String>>")]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

pub max_expires: Option<Duration>,

#[serde(rename = "consumer.inactive_threshold")]
#[serde_as(as = "Option<DurationSeconds<String>>")]
pub inactive_threshold: Option<Duration>,

#[serde(rename = "consumer.num.replicas", alias = "consumer.num_replicas")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub num_replicas: Option<usize>,

#[serde(rename = "consumer.memory_storage")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub memory_storage: Option<bool>,

#[serde(rename = "consumer.backoff")]
#[serde(deserialize_with = "deserialize_duration_seq_from_string")]
pub backoff: Option<Vec<Duration>>,
}

impl NatsPropertiesConsumer {
pub fn set_config(&self, c: &mut Config) {
if let Some(v) = &self.name {
c.name = Some(v.clone())
}
if let Some(v) = &self.durable_name {
c.durable_name = Some(v.clone())
}
if let Some(v) = &self.description {
c.description = Some(v.clone())
}
if let Some(v) = &self.ack_policy {
c.ack_policy = AckPolicyWrapper::parse_str(v)
}
if let Some(v) = &self.ack_wait {
c.ack_wait = *v
}
if let Some(v) = &self.max_deliver {
c.max_deliver = *v
}
if let Some(v) = &self.filter_subject {
c.filter_subject = v.clone()
}
if let Some(v) = &self.filter_subjects {
c.filter_subjects = v.clone()
}
if let Some(v) = &self.replay_policy {
c.replay_policy = ReplayPolicyWrapper::parse_str(v)
}
if let Some(v) = &self.rate_limit {
c.rate_limit = *v
}
if let Some(v) = &self.sample_frequency {
c.sample_frequency = *v
}
if let Some(v) = &self.max_waiting {
c.max_waiting = *v
}
if let Some(v) = &self.max_ack_pending {
c.max_ack_pending = *v
}
if let Some(v) = &self.headers_only {
c.headers_only = *v
}
if let Some(v) = &self.max_batch {
c.max_batch = *v
}
if let Some(v) = &self.max_bytes {
c.max_bytes = *v
}
if let Some(v) = &self.max_expires {
c.max_expires = *v
}
if let Some(v) = &self.inactive_threshold {
c.inactive_threshold = *v
}
if let Some(v) = &self.num_replicas {
c.num_replicas = *v
}
if let Some(v) = &self.memory_storage {
c.memory_storage = *v
}
if let Some(v) = &self.backoff {
c.backoff = v.clone()
}
}
}

impl SourceProperties for NatsProperties {
type Split = NatsSplit;
type SplitEnumerator = NatsSplitEnumerator;
Expand All @@ -62,3 +262,106 @@ impl crate::source::UnknownFields for NatsProperties {
self.unknown_fields.clone()
}
}

#[cfg(test)]
mod test {
use std::collections::BTreeMap;

use maplit::btreemap;

use super::*;

#[test]
fn test_parse_config_consumer() {
let config: BTreeMap<String, String> = btreemap! {
"stream".to_string() => "risingwave".to_string(),

// NATS common
"subject".to_string() => "subject1".to_string(),
"server_url".to_string() => "nats-server:4222".to_string(),
"connect_mode".to_string() => "plain".to_string(),
"type".to_string() => "append-only".to_string(),

// NATS properties consumer
"consumer.name".to_string() => "foobar".to_string(),
"consumer.durable_name".to_string() => "durable_foobar".to_string(),
"consumer.description".to_string() => "A description".to_string(),
"consumer.ack_policy".to_string() => "all".to_string(),
"consumer.ack_wait".to_string() => "10".to_string(),
"consumer.max_deliver".to_string() => "10".to_string(),
"consumer.filter_subject".to_string() => "subject".to_string(),
"consumer.filter_subjects".to_string() => "subject1,subject2".to_string(),
"consumer.replay_policy".to_string() => "instant".to_string(),
"consumer.rate_limit".to_string() => "100".to_string(),
"consumer.sample_frequency".to_string() => "1".to_string(),
"consumer.max_waiting".to_string() => "5".to_string(),
"consumer.max_ack_pending".to_string() => "100".to_string(),
"consumer.headers_only".to_string() => "true".to_string(),
"consumer.max_batch".to_string() => "10".to_string(),
"consumer.max_bytes".to_string() => "1024".to_string(),
"consumer.max_expires".to_string() => "24".to_string(),
"consumer.inactive_threshold".to_string() => "10".to_string(),
"consumer.num_replicas".to_string() => "3".to_string(),
"consumer.memory_storage".to_string() => "true".to_string(),
"consumer.backoff".to_string() => "2,10,15".to_string(),

};

let props: NatsProperties =
serde_json::from_value(serde_json::to_value(config).unwrap()).unwrap();

assert_eq!(
props.nats_properties_consumer.name,
Some("foobar".to_string())
);
assert_eq!(
props.nats_properties_consumer.durable_name,
Some("durable_foobar".to_string())
);
assert_eq!(
props.nats_properties_consumer.description,
Some("A description".to_string())
);
assert_eq!(
props.nats_properties_consumer.ack_policy,
Some("all".to_string())
);
assert_eq!(
props.nats_properties_consumer.ack_wait,
Some(Duration::from_secs(10))
);
assert_eq!(
props.nats_properties_consumer.filter_subjects,
Some(vec!["subject1".to_string(), "subject2".to_string()])
);
assert_eq!(
props.nats_properties_consumer.replay_policy,
Some("instant".to_string())
);
assert_eq!(props.nats_properties_consumer.rate_limit, Some(100));
assert_eq!(props.nats_properties_consumer.sample_frequency, Some(1));
assert_eq!(props.nats_properties_consumer.max_waiting, Some(5));
assert_eq!(props.nats_properties_consumer.max_ack_pending, Some(100));
assert_eq!(props.nats_properties_consumer.headers_only, Some(true));
assert_eq!(props.nats_properties_consumer.max_batch, Some(10));
assert_eq!(props.nats_properties_consumer.max_bytes, Some(1024));
assert_eq!(
props.nats_properties_consumer.max_expires,
Some(Duration::from_secs(24))
);
assert_eq!(
props.nats_properties_consumer.inactive_threshold,
Some(Duration::from_secs(10))
);
assert_eq!(props.nats_properties_consumer.num_replicas, Some(3));
assert_eq!(props.nats_properties_consumer.memory_storage, Some(true));
assert_eq!(
props.nats_properties_consumer.backoff,
Some(vec![
Duration::from_secs(2),
Duration::from_secs(10),
Duration::from_secs(15)
])
);
}
}
Loading
Loading