diff --git a/src/connector/src/connector_common/common.rs b/src/connector/src/connector_common/common.rs index 1c911c5a3992f..b522ae2eda560 100644 --- a/src/connector/src/connector_common/common.rs +++ b/src/connector/src/connector_common/common.rs @@ -639,6 +639,7 @@ impl NatsCommon { stream: String, split_id: String, start_sequence: NatsOffset, + mut config: jetstream::consumer::pull::Config, ) -> ConnectorResult< async_nats::jetstream::consumer::Consumer, > { @@ -649,10 +650,6 @@ impl NatsCommon { .replace(',', "-") .replace(['.', '>', '*', ' ', '\t'], "_"); let name = format!("risingwave-consumer-{}-{}", subject_name, split_id); - let mut config = jetstream::consumer::pull::Config { - ack_policy: jetstream::consumer::AckPolicy::None, - ..Default::default() - }; let deliver_policy = match start_sequence { NatsOffset::Earliest => DeliverPolicy::All, @@ -671,6 +668,7 @@ impl NatsCommon { }, NatsOffset::None => DeliverPolicy::All, }; + let consumer = stream .get_or_create_consumer(&name, { config.deliver_policy = deliver_policy; diff --git a/src/connector/src/lib.rs b/src/connector/src/lib.rs index 02a3b8c84b50f..7ec84b14088e9 100644 --- a/src/connector/src/lib.rs +++ b/src/connector/src/lib.rs @@ -112,6 +112,25 @@ where } } +pub(crate) fn deserialize_optional_u64_seq_from_string<'de, D>( + deserializer: D, +) -> std::result::Result>, D::Error> +where + D: de::Deserializer<'de>, +{ + let s: Option = de::Deserialize::deserialize(deserializer)?; + if let Some(s) = s { + let numbers = s + .split(',') + .map(|s| s.trim().parse()) + .collect::, _>>() + .map_err(|_| de::Error::invalid_value(de::Unexpected::Str(&s), &"invalid number")); + Ok(Some(numbers?)) + } else { + Ok(None) + } +} + pub(crate) fn deserialize_bool_from_string<'de, D>(deserializer: D) -> Result where D: de::Deserializer<'de>, diff --git a/src/connector/src/source/nats/mod.rs b/src/connector/src/source/nats/mod.rs index 7ef9c74ee7601..0ba35d20269dc 100644 --- a/src/connector/src/source/nats/mod.rs +++ b/src/connector/src/source/nats/mod.rs @@ -17,22 +17,57 @@ pub mod source; pub mod split; use std::collections::HashMap; +use std::time::Duration; +use async_nats::jetstream::consumer::pull::Config; +use async_nats::jetstream::consumer::{AckPolicy, ReplayPolicy}; use serde::Deserialize; +use serde_with::{serde_as, DisplayFromStr}; use with_options::WithOptions; use crate::connector_common::NatsCommon; use crate::source::nats::enumerator::NatsSplitEnumerator; use crate::source::nats::source::{NatsSplit, NatsSplitReader}; use crate::source::SourceProperties; +use crate::{ + deserialize_optional_string_seq_from_string, deserialize_optional_u64_seq_from_string, +}; pub const NATS_CONNECTOR: &str = "nats"; +pub struct AckPolicyWrapper; + +impl AckPolicyWrapper { + pub fn parse_str(s: &str) -> Result { + match s { + "none" => Ok(AckPolicy::None), + "all" => Ok(AckPolicy::All), + "explicit" => Ok(AckPolicy::Explicit), + _ => Err(format!("Invalid AckPolicy '{}'", s)), + } + } +} + +pub struct ReplayPolicyWrapper; + +impl ReplayPolicyWrapper { + pub fn parse_str(s: &str) -> Result { + match s { + "instant" => Ok(ReplayPolicy::Instant), + "original" => Ok(ReplayPolicy::Original), + _ => Err(format!("Invalid ReplayPolicy '{}'", s)), + } + } +} + #[derive(Clone, Debug, Deserialize, WithOptions)] pub struct NatsProperties { #[serde(flatten)] pub common: NatsCommon, + #[serde(flatten)] + pub nats_properties_consumer: NatsPropertiesConsumer, + #[serde(rename = "scan.startup.mode")] pub scan_startup_mode: Option, @@ -49,6 +84,173 @@ pub struct NatsProperties { pub unknown_fields: HashMap, } +impl NatsProperties { + pub fn set_config(&self, c: &mut Config) { + self.nats_properties_consumer.set_config(c); + } +} + +/// Properties for the async-nats library. +/// See +#[serde_as] +#[derive(Clone, Debug, Deserialize, WithOptions)] +pub struct NatsPropertiesConsumer { + #[serde(rename = "consumer.deliver_subject")] + pub deliver_subject: Option, + + #[serde(rename = "consumer.durable_name")] + pub durable_name: Option, + + #[serde(rename = "consumer.name")] + pub name: Option, + + #[serde(rename = "consumer.description")] + pub description: Option, + + #[serde(rename = "consumer.deliver_policy")] + #[serde_as(as = "Option")] + pub deliver_policy: Option, + + #[serde(rename = "consumer.ack_policy")] + #[serde_as(as = "Option")] + pub ack_policy: Option, + + #[serde(rename = "consumer.ack_wait.sec")] + #[serde_as(as = "Option")] + pub ack_wait: Option, + + #[serde(rename = "consumer.max_deliver")] + #[serde_as(as = "Option")] + pub max_deliver: Option, + + #[serde(rename = "consumer.filter_subject")] + pub filter_subject: Option, + + #[serde(rename = "consumer.filter_subjects")] + #[serde(deserialize_with = "deserialize_optional_string_seq_from_string")] + pub filter_subjects: Option>, + + #[serde(rename = "consumer.replay_policy")] + #[serde_as(as = "Option")] + pub replay_policy: Option, + + #[serde(rename = "consumer.rate_limit")] + #[serde_as(as = "Option")] + pub rate_limit: Option, + + #[serde(rename = "consumer.sample_frequency")] + #[serde_as(as = "Option")] + pub sample_frequency: Option, + + #[serde(rename = "consumer.max_waiting")] + #[serde_as(as = "Option")] + pub max_waiting: Option, + + #[serde(rename = "consumer.max_ack_pending")] + #[serde_as(as = "Option")] + pub max_ack_pending: Option, + + #[serde(rename = "consumer.headers_only")] + #[serde_as(as = "Option")] + pub headers_only: Option, + + #[serde(rename = "consumer.max_batch")] + #[serde_as(as = "Option")] + pub max_batch: Option, + + #[serde(rename = "consumer.max_bytes")] + #[serde_as(as = "Option")] + pub max_bytes: Option, + + #[serde(rename = "consumer.max_expires.sec")] + #[serde_as(as = "Option")] + pub max_expires: Option, + + #[serde(rename = "consumer.inactive_threshold.sec")] + #[serde_as(as = "Option")] + pub inactive_threshold: Option, + + #[serde(rename = "consumer.num.replicas", alias = "consumer.num_replicas")] + #[serde_as(as = "Option")] + pub num_replicas: Option, + + #[serde(rename = "consumer.memory_storage")] + #[serde_as(as = "Option")] + pub memory_storage: Option, + + #[serde(rename = "consumer.backoff.sec")] + #[serde(deserialize_with = "deserialize_optional_u64_seq_from_string")] + pub backoff: Option>, +} + +impl NatsPropertiesConsumer { + pub fn set_config(&self, c: &mut Config) { + if let Some(v) = &self.name { + c.name = Some(v.clone()) + } + if let Some(v) = &self.durable_name { + c.durable_name = Some(v.clone()) + } + if let Some(v) = &self.description { + c.description = Some(v.clone()) + } + if let Some(v) = &self.ack_policy { + c.ack_policy = AckPolicyWrapper::parse_str(v).unwrap() + } + if let Some(v) = &self.ack_wait { + c.ack_wait = Duration::from_secs(*v) + } + if let Some(v) = &self.max_deliver { + c.max_deliver = *v + } + if let Some(v) = &self.filter_subject { + c.filter_subject = v.clone() + } + if let Some(v) = &self.filter_subjects { + c.filter_subjects = v.clone() + } + if let Some(v) = &self.replay_policy { + c.replay_policy = ReplayPolicyWrapper::parse_str(v).unwrap() + } + if let Some(v) = &self.rate_limit { + c.rate_limit = *v + } + if let Some(v) = &self.sample_frequency { + c.sample_frequency = *v + } + if let Some(v) = &self.max_waiting { + c.max_waiting = *v + } + if let Some(v) = &self.max_ack_pending { + c.max_ack_pending = *v + } + if let Some(v) = &self.headers_only { + c.headers_only = *v + } + if let Some(v) = &self.max_batch { + c.max_batch = *v + } + if let Some(v) = &self.max_bytes { + c.max_bytes = *v + } + if let Some(v) = &self.max_expires { + c.max_expires = Duration::from_secs(*v) + } + if let Some(v) = &self.inactive_threshold { + c.inactive_threshold = Duration::from_secs(*v) + } + if let Some(v) = &self.num_replicas { + c.num_replicas = *v + } + if let Some(v) = &self.memory_storage { + c.memory_storage = *v + } + if let Some(v) = &self.backoff { + c.backoff = v.iter().map(|&x| Duration::from_secs(x)).collect() + } + } +} + impl SourceProperties for NatsProperties { type Split = NatsSplit; type SplitEnumerator = NatsSplitEnumerator; @@ -62,3 +264,93 @@ impl crate::source::UnknownFields for NatsProperties { self.unknown_fields.clone() } } + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + + use maplit::btreemap; + + use super::*; + + #[test] + fn test_parse_config_consumer() { + let config: BTreeMap = btreemap! { + "stream".to_string() => "risingwave".to_string(), + + // NATS common + "subject".to_string() => "subject1".to_string(), + "server_url".to_string() => "nats-server:4222".to_string(), + "connect_mode".to_string() => "plain".to_string(), + "type".to_string() => "append-only".to_string(), + + // NATS properties consumer + "consumer.name".to_string() => "foobar".to_string(), + "consumer.durable_name".to_string() => "durable_foobar".to_string(), + "consumer.description".to_string() => "A description".to_string(), + "consumer.ack_policy".to_string() => "all".to_string(), + "consumer.ack_wait.sec".to_string() => "10".to_string(), + "consumer.max_deliver".to_string() => "10".to_string(), + "consumer.filter_subject".to_string() => "subject".to_string(), + "consumer.filter_subjects".to_string() => "subject1,subject2".to_string(), + "consumer.replay_policy".to_string() => "instant".to_string(), + "consumer.rate_limit".to_string() => "100".to_string(), + "consumer.sample_frequency".to_string() => "1".to_string(), + "consumer.max_waiting".to_string() => "5".to_string(), + "consumer.max_ack_pending".to_string() => "100".to_string(), + "consumer.headers_only".to_string() => "true".to_string(), + "consumer.max_batch".to_string() => "10".to_string(), + "consumer.max_bytes".to_string() => "1024".to_string(), + "consumer.max_expires.sec".to_string() => "24".to_string(), + "consumer.inactive_threshold.sec".to_string() => "10".to_string(), + "consumer.num_replicas".to_string() => "3".to_string(), + "consumer.memory_storage".to_string() => "true".to_string(), + "consumer.backoff.sec".to_string() => "2,10,15".to_string(), + + }; + + let props: NatsProperties = + serde_json::from_value(serde_json::to_value(config).unwrap()).unwrap(); + + assert_eq!( + props.nats_properties_consumer.name, + Some("foobar".to_string()) + ); + assert_eq!( + props.nats_properties_consumer.durable_name, + Some("durable_foobar".to_string()) + ); + assert_eq!( + props.nats_properties_consumer.description, + Some("A description".to_string()) + ); + assert_eq!( + props.nats_properties_consumer.ack_policy, + Some("all".to_string()) + ); + assert_eq!(props.nats_properties_consumer.ack_wait, Some(10)); + assert_eq!( + props.nats_properties_consumer.filter_subjects, + Some(vec!["subject1".to_string(), "subject2".to_string()]) + ); + assert_eq!( + props.nats_properties_consumer.replay_policy, + Some("instant".to_string()) + ); + assert_eq!(props.nats_properties_consumer.rate_limit, Some(100)); + assert_eq!(props.nats_properties_consumer.sample_frequency, Some(1)); + assert_eq!(props.nats_properties_consumer.max_waiting, Some(5)); + assert_eq!(props.nats_properties_consumer.max_ack_pending, Some(100)); + assert_eq!(props.nats_properties_consumer.headers_only, Some(true)); + assert_eq!(props.nats_properties_consumer.max_batch, Some(10)); + assert_eq!(props.nats_properties_consumer.max_bytes, Some(1024)); + assert_eq!(props.nats_properties_consumer.max_expires, Some(24)); + assert_eq!(props.nats_properties_consumer.inactive_threshold, Some(10)); + assert_eq!(props.nats_properties_consumer.num_replicas, Some(3)); + assert_eq!(props.nats_properties_consumer.memory_storage, Some(true)); + assert_eq!( + props.nats_properties_consumer.backoff, + Some(vec![2, 10, 15]) + ); + } +} diff --git a/src/connector/src/source/nats/source/reader.rs b/src/connector/src/source/nats/source/reader.rs index 916378a263979..45d13017e0ada 100644 --- a/src/connector/src/source/nats/source/reader.rs +++ b/src/connector/src/source/nats/source/reader.rs @@ -79,14 +79,21 @@ impl SplitReader for NatsSplitReader { start_position => start_position.to_owned(), }; + let mut config = consumer::pull::Config { + ..Default::default() + }; + properties.set_config(&mut config); + let consumer = properties .common .build_consumer( properties.stream.clone(), split_id.to_string(), start_position.clone(), + config, ) .await?; + Ok(Self { consumer, properties, diff --git a/src/connector/src/with_options.rs b/src/connector/src/with_options.rs index eef5ccbd9cbfa..90dcfc5b1d88f 100644 --- a/src/connector/src/with_options.rs +++ b/src/connector/src/with_options.rs @@ -50,12 +50,14 @@ impl WithOptions impl WithOptions for Option {} impl WithOptions for Vec {} +impl WithOptions for Vec {} impl WithOptions for HashMap {} impl WithOptions for BTreeMap {} impl WithOptions for String {} impl WithOptions for bool {} impl WithOptions for usize {} +impl WithOptions for u8 {} impl WithOptions for u16 {} impl WithOptions for u32 {} impl WithOptions for u64 {} diff --git a/src/connector/with_options_source.yaml b/src/connector/with_options_source.yaml index 4a208465265e7..b3f1a3769f19a 100644 --- a/src/connector/with_options_source.yaml +++ b/src/connector/with_options_source.yaml @@ -507,6 +507,77 @@ NatsProperties: - name: max_message_size field_type: i32 required: false + - name: consumer.deliver_subject + field_type: String + required: false + - name: consumer.durable_name + field_type: String + required: false + - name: consumer.name + field_type: String + required: false + - name: consumer.description + field_type: String + required: false + - name: consumer.deliver_policy + field_type: String + required: false + - name: consumer.ack_policy + field_type: String + required: false + - name: consumer.ack_wait.sec + field_type: u64 + required: false + - name: consumer.max_deliver + field_type: i64 + required: false + - name: consumer.filter_subject + field_type: String + required: false + - name: consumer.filter_subjects + field_type: Vec + required: false + - name: consumer.replay_policy + field_type: String + required: false + - name: consumer.rate_limit + field_type: u64 + required: false + - name: consumer.sample_frequency + field_type: u8 + required: false + - name: consumer.max_waiting + field_type: i64 + required: false + - name: consumer.max_ack_pending + field_type: i64 + required: false + - name: consumer.headers_only + field_type: bool + required: false + - name: consumer.max_batch + field_type: i64 + required: false + - name: consumer.max_bytes + field_type: i64 + required: false + - name: consumer.max_expires.sec + field_type: u64 + required: false + - name: consumer.inactive_threshold.sec + field_type: u64 + required: false + - name: consumer.num.replicas + field_type: usize + required: false + alias: + - consumer.num_replicas + - name: consumer.memory_storage + field_type: bool + required: false + - name: consumer.backoff.sec + field_type: Vec + required: false - name: scan.startup.mode field_type: String required: false