From 4ce7e9c7814f5b61961cfae8b985846fd4105bf8 Mon Sep 17 00:00:00 2001 From: Dylan Martin Date: Sat, 16 Nov 2024 00:44:55 +0100 Subject: [PATCH] feat(flags): dynamic cohort matching in rust (#25776) --- rust/Cargo.lock | 6 + rust/feature-flags/Cargo.toml | 2 + rust/feature-flags/src/api.rs | 24 +- rust/feature-flags/src/cohort_cache.rs | 221 +++++ rust/feature-flags/src/cohort_models.rs | 50 + rust/feature-flags/src/cohort_operations.rs | 369 ++++++++ rust/feature-flags/src/flag_definitions.rs | 38 +- rust/feature-flags/src/flag_matching.rs | 880 ++++++++++++++++-- rust/feature-flags/src/flag_request.rs | 21 +- rust/feature-flags/src/lib.rs | 5 +- .../src/{utils.rs => metrics_utils.rs} | 0 rust/feature-flags/src/property_matching.rs | 43 +- rust/feature-flags/src/request_handler.rs | 24 +- rust/feature-flags/src/router.rs | 6 +- rust/feature-flags/src/server.rs | 4 + rust/feature-flags/src/team.rs | 6 +- rust/feature-flags/src/test_utils.rs | 74 +- .../tests/test_flag_matching_consistency.rs | 9 +- 18 files changed, 1671 insertions(+), 111 deletions(-) create mode 100644 rust/feature-flags/src/cohort_cache.rs create mode 100644 rust/feature-flags/src/cohort_models.rs create mode 100644 rust/feature-flags/src/cohort_operations.rs rename rust/feature-flags/src/{utils.rs => metrics_utils.rs} (100%) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index b99943cc4e557..9a263a8787831 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1808,7 +1808,9 @@ dependencies = [ "futures", "health", "maxminddb", + "moka", "once_cell", + "petgraph", "rand", "redis", "regex", @@ -3046,9 +3048,13 @@ version = "0.12.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32cf62eb4dd975d2dde76432fb1075c49e3ee2331cf36f1f8fd4b66550d32b6f" dependencies = [ + "async-lock 3.4.0", + "async-trait", "crossbeam-channel", "crossbeam-epoch", "crossbeam-utils", + "event-listener 5.3.1", + "futures-util", "once_cell", "parking_lot", "quanta 0.12.2", diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 4cf4016767be6..4099fd8ab06fd 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -39,6 +39,8 @@ health = { path = "../common/health" } common-metrics = { path = "../common/metrics" } tower = { workspace = true } derive_builder = "0.20.1" +petgraph = "0.6.5" +moka = { version = "0.12.8", features = ["future"] } [lints] workspace = true diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index 4430476d28a52..9d6b649719bd2 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -89,7 +89,7 @@ pub enum FlagError { #[error("Row not found in postgres")] RowNotFound, #[error("failed to parse redis cache data")] - DataParsingError, + RedisDataParsingError, #[error("failed to update redis cache")] CacheUpdateError, #[error("redis unavailable")] @@ -102,6 +102,12 @@ pub enum FlagError { TimeoutError, #[error("No group type mappings")] NoGroupTypeMappings, + #[error("Cohort not found")] + CohortNotFound(String), + #[error("Failed to parse cohort filters")] + CohortFiltersParsingError, + #[error("Cohort dependency cycle")] + CohortDependencyCycle(String), } impl IntoResponse for FlagError { @@ -138,7 +144,7 @@ impl IntoResponse for FlagError { FlagError::TokenValidationError => { (StatusCode::UNAUTHORIZED, "The provided API key is invalid or has expired. 
Please check your API key and try again.".to_string()) } - FlagError::DataParsingError => { + FlagError::RedisDataParsingError => { tracing::error!("Data parsing error: {:?}", self); ( StatusCode::SERVICE_UNAVAILABLE, @@ -194,6 +200,18 @@ impl IntoResponse for FlagError { "The requested row was not found in the database. Please try again later or contact support if the problem persists.".to_string(), ) } + FlagError::CohortNotFound(msg) => { + tracing::error!("Cohort not found: {}", msg); + (StatusCode::NOT_FOUND, msg) + } + FlagError::CohortFiltersParsingError => { + tracing::error!("Failed to parse cohort filters: {:?}", self); + (StatusCode::BAD_REQUEST, "Failed to parse cohort filters. Please try again later or contact support if the problem persists.".to_string()) + } + FlagError::CohortDependencyCycle(msg) => { + tracing::error!("Cohort dependency cycle: {}", msg); + (StatusCode::BAD_REQUEST, msg) + } } .into_response() } @@ -205,7 +223,7 @@ impl From for FlagError { CustomRedisError::NotFound => FlagError::TokenValidationError, CustomRedisError::PickleError(e) => { tracing::error!("failed to fetch data: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError } CustomRedisError::Timeout(_) => FlagError::TimeoutError, CustomRedisError::Other(e) => { diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs new file mode 100644 index 0000000000000..68894c19f88e2 --- /dev/null +++ b/rust/feature-flags/src/cohort_cache.rs @@ -0,0 +1,221 @@ +use crate::api::FlagError; +use crate::cohort_models::Cohort; +use crate::flag_matching::{PostgresReader, TeamId}; +use moka::future::Cache; +use std::time::Duration; + +/// CohortCacheManager manages the in-memory cache of cohorts using `moka` for caching. +/// +/// Features: +/// - **TTL**: Each cache entry expires after 5 minutes. +/// - **Size-based eviction**: The cache evicts least recently used entries when the maximum capacity is reached. +/// +/// ```text +/// CohortCacheManager { +/// postgres_reader: PostgresReader, +/// per_team_cohorts: Cache> { +/// // Example: +/// 2: [ +/// Cohort { id: 1, name: "Power Users", filters: {...} }, +/// Cohort { id: 2, name: "Churned", filters: {...} } +/// ], +/// 5: [ +/// Cohort { id: 3, name: "Beta Users", filters: {...} } +/// ] +/// } +/// } +/// ``` +/// +#[derive(Clone)] +pub struct CohortCacheManager { + postgres_reader: PostgresReader, + per_team_cohort_cache: Cache>, +} + +impl CohortCacheManager { + pub fn new( + postgres_reader: PostgresReader, + max_capacity: Option, + ttl_seconds: Option, + ) -> Self { + // We use the size of the cohort list (i.e., the number of cohorts for a given team)as the weight of the entry + let weigher = + |_: &TeamId, value: &Vec| -> u32 { value.len().try_into().unwrap_or(u32::MAX) }; + + let cache = Cache::builder() + .time_to_live(Duration::from_secs(ttl_seconds.unwrap_or(300))) // Default to 5 minutes + .weigher(weigher) + .max_capacity(max_capacity.unwrap_or(10_000)) // Default to 10,000 cohorts + .build(); + + Self { + postgres_reader, + per_team_cohort_cache: cache, + } + } + + /// Retrieves cohorts for a given team. + /// + /// If the cohorts are not present in the cache or have expired, it fetches them from the database, + /// caches the result upon successful retrieval, and then returns it. 
+ pub async fn get_cohorts_for_team(&self, team_id: TeamId) -> Result, FlagError> { + if let Some(cached_cohorts) = self.per_team_cohort_cache.get(&team_id).await { + return Ok(cached_cohorts.clone()); + } + let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?; + self.per_team_cohort_cache + .insert(team_id, fetched_cohorts.clone()) + .await; + + Ok(fetched_cohorts) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cohort_models::Cohort; + use crate::test_utils::{ + insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client, + setup_pg_writer_client, + }; + use std::sync::Arc; + use tokio::time::{sleep, Duration}; + + /// Helper function to setup a new team for testing. + async fn setup_test_team( + writer_client: Arc, + ) -> Result { + let team = crate::test_utils::insert_new_team_in_pg(writer_client, None).await?; + Ok(team.id) + } + + /// Helper function to insert a cohort for a team. + async fn setup_test_cohort( + writer_client: Arc, + team_id: TeamId, + name: Option, + ) -> Result { + let filters = serde_json::json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}); + insert_cohort_for_team_in_pg(writer_client, team_id, name, filters, false).await + } + + /// Tests that cache entries expire after the specified TTL. + #[tokio::test] + async fn test_cache_expiry() -> Result<(), anyhow::Error> { + let writer_client = setup_pg_writer_client(None).await; + let reader_client = setup_pg_reader_client(None).await; + + let team_id = setup_test_team(writer_client.clone()).await?; + let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?; + + // Initialize CohortCacheManager with a short TTL for testing + let cohort_cache = CohortCacheManager::new( + reader_client.clone(), + Some(100), + Some(1), // 1-second TTL + ); + + let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?; + assert_eq!(cohorts.len(), 1); + assert_eq!(cohorts[0].team_id, team_id); + + let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await; + assert!(cached_cohorts.is_some()); + + // Wait for TTL to expire + sleep(Duration::from_secs(2)).await; + + // Attempt to retrieve from cache again + let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await; + assert!(cached_cohorts.is_none(), "Cache entry should have expired"); + + Ok(()) + } + + /// Tests that the cache correctly evicts least recently used entries based on the weigher. 
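Because the weigher counts cohorts rather than teams, `max_capacity` is effectively a budget on the total number of cached cohorts across all teams, which is what the eviction test below exercises. A minimal standalone sketch of that behaviour (assuming only the `moka` crate with its `future` feature and a tokio runtime; the `Vec<&str>` value is a stand-in for `Vec<Cohort>`):

```rust
use moka::future::Cache;
use std::time::Duration;

#[tokio::main]
async fn main() {
    // An entry's weight is the number of cohorts cached for that team, so
    // max_capacity(3) bounds the total cohort count, not the number of teams.
    let cache: Cache<i32, Vec<&'static str>> = Cache::builder()
        .time_to_live(Duration::from_secs(300))
        .weigher(|_team_id, cohorts: &Vec<&'static str>| cohorts.len() as u32)
        .max_capacity(3)
        .build();

    cache.insert(1, vec!["Power Users", "Churned"]).await; // weight 2
    cache.insert(2, vec!["Beta Users"]).await; // weight 1, budget now full
    cache.insert(3, vec!["New Team"]).await; // pushes total weight over 3

    // Eviction is processed asynchronously; flush pending work before checking.
    cache.run_pending_tasks().await;
    assert!(cache.weighted_size() <= 3);
}
```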
+ #[tokio::test] + async fn test_cache_weigher() -> Result<(), anyhow::Error> { + let writer_client = setup_pg_writer_client(None).await; + let reader_client = setup_pg_reader_client(None).await; + + // Define a smaller max_capacity for testing + let max_capacity: u64 = 3; + + let cohort_cache = CohortCacheManager::new(reader_client.clone(), Some(max_capacity), None); + + let mut inserted_team_ids = Vec::new(); + + // Insert multiple teams and their cohorts + for _ in 0..max_capacity { + let team = insert_new_team_in_pg(writer_client.clone(), None).await?; + let team_id = team.id; + inserted_team_ids.push(team_id); + setup_test_cohort(writer_client.clone(), team_id, None).await?; + cohort_cache.get_cohorts_for_team(team_id).await?; + } + + cohort_cache.per_team_cohort_cache.run_pending_tasks().await; + let cache_size = cohort_cache.per_team_cohort_cache.entry_count(); + assert_eq!( + cache_size, max_capacity, + "Cache size should be equal to max_capacity" + ); + + let new_team = insert_new_team_in_pg(writer_client.clone(), None).await?; + let new_team_id = new_team.id; + setup_test_cohort(writer_client.clone(), new_team_id, None).await?; + cohort_cache.get_cohorts_for_team(new_team_id).await?; + + cohort_cache.per_team_cohort_cache.run_pending_tasks().await; + let cache_size_after = cohort_cache.per_team_cohort_cache.entry_count(); + assert_eq!( + cache_size_after, max_capacity, + "Cache size should remain equal to max_capacity after eviction" + ); + + let evicted_team_id = &inserted_team_ids[0]; + let cached_cohorts = cohort_cache + .per_team_cohort_cache + .get(evicted_team_id) + .await; + assert!( + cached_cohorts.is_none(), + "Least recently used cache entry should have been evicted" + ); + + let cached_new_team = cohort_cache.per_team_cohort_cache.get(&new_team_id).await; + assert!( + cached_new_team.is_some(), + "Newly added cache entry should be present" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_get_cohorts_for_team() -> Result<(), anyhow::Error> { + let writer_client = setup_pg_writer_client(None).await; + let reader_client = setup_pg_reader_client(None).await; + let team_id = setup_test_team(writer_client.clone()).await?; + let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?; + let cohort_cache = CohortCacheManager::new(reader_client.clone(), None, None); + + let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await; + assert!(cached_cohorts.is_none(), "Cache should initially be empty"); + + let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?; + assert_eq!(cohorts.len(), 1); + assert_eq!(cohorts[0].team_id, team_id); + + let cached_cohorts = cohort_cache + .per_team_cohort_cache + .get(&team_id) + .await + .unwrap(); + assert_eq!(cached_cohorts.len(), 1); + assert_eq!(cached_cohorts[0].team_id, team_id); + + Ok(()) + } +} diff --git a/rust/feature-flags/src/cohort_models.rs b/rust/feature-flags/src/cohort_models.rs new file mode 100644 index 0000000000000..d109983901772 --- /dev/null +++ b/rust/feature-flags/src/cohort_models.rs @@ -0,0 +1,50 @@ +use crate::flag_definitions::PropertyFilter; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct Cohort { + pub id: i32, + pub name: String, + pub description: Option, + pub team_id: i32, + pub deleted: bool, + pub filters: serde_json::Value, + pub query: Option, + pub version: Option, + pub pending_version: Option, + pub count: Option, + pub is_calculating: bool, + pub is_static: bool, + pub 
errors_calculating: i32, + pub groups: serde_json::Value, + pub created_by_id: Option, +} + +pub type CohortId = i32; + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all = "UPPERCASE")] +pub enum CohortPropertyType { + AND, + OR, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CohortProperty { + pub properties: InnerCohortProperty, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct InnerCohortProperty { + #[serde(rename = "type")] + pub prop_type: CohortPropertyType, + pub values: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CohortValues { + #[serde(rename = "type")] + pub prop_type: String, + pub values: Vec, +} diff --git a/rust/feature-flags/src/cohort_operations.rs b/rust/feature-flags/src/cohort_operations.rs new file mode 100644 index 0000000000000..ea4214ccdc08b --- /dev/null +++ b/rust/feature-flags/src/cohort_operations.rs @@ -0,0 +1,369 @@ +use std::collections::HashSet; +use std::sync::Arc; +use tracing::instrument; + +use crate::cohort_models::{Cohort, CohortId, CohortProperty, InnerCohortProperty}; +use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter}; + +impl Cohort { + /// Returns a cohort from postgres given a cohort_id and team_id + #[instrument(skip_all)] + pub async fn from_pg( + client: Arc, + cohort_id: i32, + team_id: i32, + ) -> Result { + let mut conn = client.get_connection().await.map_err(|e| { + tracing::error!("Failed to get database connection: {}", e); + // TODO should I model my errors more generally? Like, yes, everything behind this API is technically a FlagError, + // but I'm not sure if accessing Cohort definitions should be a FlagError (vs idk, a CohortError? A more general API error?) 
+ FlagError::DatabaseUnavailable + })?; + + let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE id = $1 AND team_id = $2"; + let cohort = sqlx::query_as::<_, Cohort>(query) + .bind(cohort_id) + .bind(team_id) + .fetch_optional(&mut *conn) + .await + .map_err(|e| { + tracing::error!("Failed to fetch cohort from database: {}", e); + FlagError::Internal(format!("Database query error: {}", e)) + })?; + + cohort.ok_or_else(|| { + FlagError::CohortNotFound(format!( + "Cohort with id {} not found for team {}", + cohort_id, team_id + )) + }) + } + + /// Returns all cohorts for a given team + #[instrument(skip_all)] + pub async fn list_from_pg( + client: Arc, + team_id: i32, + ) -> Result, FlagError> { + let mut conn = client.get_connection().await.map_err(|e| { + tracing::error!("Failed to get database connection: {}", e); + FlagError::DatabaseUnavailable + })?; + + let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE team_id = $1"; + let cohorts = sqlx::query_as::<_, Cohort>(query) + .bind(team_id) + .fetch_all(&mut *conn) + .await + .map_err(|e| { + tracing::error!("Failed to fetch cohorts from database: {}", e); + FlagError::Internal(format!("Database query error: {}", e)) + })?; + + Ok(cohorts) + } + + /// Parses the filters JSON into a CohortProperty structure + // TODO: this doesn't handle the deprecated "groups" field, see + // https://github.com/PostHog/posthog/blob/feat/dynamic-cohorts-rust/posthog/models/cohort/cohort.py#L114-L169 + // I'll handle that in a separate PR. 
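`parse_filters` round-trips the stored `filters` JSON through the serde models in cohort_models.rs and then flattens the nested values. A condensed sketch of that deserialization, using simplified stand-in structs rather than the real `CohortProperty` / `PropertyFilter` types:

```rust
#![allow(dead_code)]
use serde::Deserialize;
use serde_json::{json, Value};

// Simplified stand-ins for CohortProperty / InnerCohortProperty / CohortValues.
#[derive(Deserialize)]
struct CohortProperty {
    properties: InnerCohortProperty,
}

#[derive(Deserialize)]
struct InnerCohortProperty {
    #[serde(rename = "type")]
    prop_type: String,
    values: Vec<CohortValues>,
}

#[derive(Deserialize)]
struct CohortValues {
    #[serde(rename = "type")]
    prop_type: String,
    values: Vec<Value>, // the real code uses Vec<PropertyFilter> here
}

fn main() {
    let filters = json!({
        "properties": {
            "type": "OR",
            "values": [{
                "type": "OR",
                "values": [
                    { "key": "email", "type": "person",
                      "value": "@posthog.com", "operator": "icontains" }
                ]
            }]
        }
    });

    let parsed: CohortProperty = serde_json::from_value(filters).expect("valid filters");
    // to_property_filters() performs this flattening over the two nested levels.
    let flat: Vec<&Value> = parsed
        .properties
        .values
        .iter()
        .flat_map(|group| &group.values)
        .collect();
    assert_eq!(flat.len(), 1);
    assert_eq!(flat[0]["key"], "email");
}
```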
+ pub fn parse_filters(&self) -> Result, FlagError> { + let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone()) + .map_err(|e| { + tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e); + FlagError::CohortFiltersParsingError + })?; + Ok(cohort_property + .properties + .to_property_filters() + .into_iter() + .filter(|f| !(f.key == "id" && f.prop_type == "cohort")) + .collect()) + } + + /// Extracts dependent CohortIds from the cohort's filters + pub fn extract_dependencies(&self) -> Result, FlagError> { + let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone()) + .map_err(|e| { + tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e); + FlagError::CohortFiltersParsingError + })?; + + let mut dependencies = HashSet::new(); + Self::traverse_filters(&cohort_property.properties, &mut dependencies)?; + Ok(dependencies) + } + + /// Recursively traverses the filter tree to find cohort dependencies + /// + /// Example filter tree structure: + /// ```json + /// { + /// "properties": { + /// "type": "OR", + /// "values": [ + /// { + /// "type": "OR", + /// "values": [ + /// { + /// "key": "id", + /// "value": 123, + /// "type": "cohort", + /// "operator": "exact" + /// }, + /// { + /// "key": "email", + /// "value": "@posthog.com", + /// "type": "person", + /// "operator": "icontains" + /// } + /// ] + /// } + /// ] + /// } + /// } + /// ``` + fn traverse_filters( + inner: &InnerCohortProperty, + dependencies: &mut HashSet, + ) -> Result<(), FlagError> { + for cohort_values in &inner.values { + for filter in &cohort_values.values { + if filter.is_cohort() { + // Assuming the value is a single integer CohortId + if let Some(cohort_id) = filter.value.as_i64() { + dependencies.insert(cohort_id as CohortId); + } else { + return Err(FlagError::CohortFiltersParsingError); + } + } + // NB: we don't support nested cohort properties, so we don't need to traverse further + } + } + Ok(()) + } +} + +impl InnerCohortProperty { + /// Flattens the nested cohort property structure into a list of property filters. 
+ /// + /// The cohort property structure in Postgres looks like: + /// ```json + /// { + /// "type": "OR", + /// "values": [ + /// { + /// "type": "OR", + /// "values": [ + /// { + /// "key": "email", + /// "value": "@posthog.com", + /// "type": "person", + /// "operator": "icontains" + /// }, + /// { + /// "key": "age", + /// "value": 25, + /// "type": "person", + /// "operator": "gt" + /// } + /// ] + /// } + /// ] + /// } + /// ``` + pub fn to_property_filters(&self) -> Vec { + self.values + .iter() + .flat_map(|value| &value.values) + .cloned() + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + cohort_models::{CohortPropertyType, CohortValues}, + test_utils::{ + insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client, + setup_pg_writer_client, + }, + }; + use serde_json::json; + + #[tokio::test] + async fn test_cohort_from_pg() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + let cohort = insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + None, + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$initial_browser_version", "type": "person", "value": ["125"], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert cohort"); + + let fetched_cohort = Cohort::from_pg(postgres_reader, cohort.id, team.id) + .await + .expect("Failed to fetch cohort"); + + assert_eq!(fetched_cohort.id, cohort.id); + assert_eq!(fetched_cohort.name, "Test Cohort"); + assert_eq!(fetched_cohort.team_id, team.id); + } + + #[tokio::test] + async fn test_list_from_pg() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + // Insert multiple cohorts for the team + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Cohort 1".to_string()), + json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "age", "type": "person", "value": [30], "negation": false, "operator": "gt"}]}]}}), + false, + ) + .await + .expect("Failed to insert cohort1"); + + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Cohort 2".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "property", "values": [{"key": "country", "type": "person", "value": ["USA"], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert cohort2"); + + let cohorts = Cohort::list_from_pg(postgres_reader, team.id) + .await + .expect("Failed to list cohorts"); + + assert_eq!(cohorts.len(), 2); + let names: HashSet = cohorts.into_iter().map(|c| c.name).collect(); + assert!(names.contains("Cohort 1")); + assert!(names.contains("Cohort 2")); + } + + #[test] + fn test_cohort_parse_filters() { + let cohort = Cohort { + id: 1, + name: "Test Cohort".to_string(), + description: None, + team_id: 1, + deleted: false, + filters: json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$initial_browser_version", "type": "person", "value": ["125"], "negation": false, "operator": "exact"}]}]}}), + query: None, + version: None, + pending_version: None, + count: None, + is_calculating: false, + is_static: false, + 
errors_calculating: 0, + groups: json!({}), + created_by_id: None, + }; + + let result = cohort.parse_filters().unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].key, "$initial_browser_version"); + assert_eq!(result[0].value, json!(["125"])); + assert_eq!(result[0].prop_type, "person"); + } + + #[test] + fn test_cohort_property_to_property_filters() { + let cohort_property = InnerCohortProperty { + prop_type: CohortPropertyType::AND, + values: vec![CohortValues { + prop_type: "property".to_string(), + values: vec![ + PropertyFilter { + key: "email".to_string(), + value: json!("test@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + negation: None, + }, + PropertyFilter { + key: "age".to_string(), + value: json!(25), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + negation: None, + }, + ], + }], + }; + + let result = cohort_property.to_property_filters(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].key, "email"); + assert_eq!(result[0].value, json!("test@example.com")); + assert_eq!(result[1].key, "age"); + assert_eq!(result[1].value, json!(25)); + } + + #[tokio::test] + async fn test_extract_dependencies() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + // Insert a single cohort that is dependent on another cohort + let dependent_cohort = insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Dependent Cohort".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$browser", "type": "person", "value": ["Safari"], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert dependent_cohort"); + + // Insert main cohort with a single dependency + let main_cohort = insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Main Cohort".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "id", "type": "cohort", "value": dependent_cohort.id, "negation": false}]}]}}), + false, + ) + .await + .expect("Failed to insert main_cohort"); + + let fetched_main_cohort = Cohort::from_pg(postgres_reader.clone(), main_cohort.id, team.id) + .await + .expect("Failed to fetch main cohort"); + + println!("fetched_main_cohort: {:?}", fetched_main_cohort); + + let dependencies = fetched_main_cohort.extract_dependencies().unwrap(); + let expected_dependencies: HashSet = + [dependent_cohort.id].iter().cloned().collect(); + + assert_eq!(dependencies, expected_dependencies); + } +} diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index baebaa04da30e..d62ecc9e0e0c1 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -1,4 +1,7 @@ -use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient}; +use crate::{ + api::FlagError, cohort_models::CohortId, database::Client as DatabaseClient, + redis::Client as RedisClient, +}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::instrument; @@ -7,7 +10,7 @@ use tracing::instrument; // TODO: Add integration tests across repos to ensure this doesn't happen. 
pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_"; -#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum OperatorType { Exact, @@ -25,6 +28,8 @@ pub enum OperatorType { IsDateExact, IsDateAfter, IsDateBefore, + In, + NotIn, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -36,10 +41,28 @@ pub struct PropertyFilter { pub value: serde_json::Value, pub operator: Option, #[serde(rename = "type")] + // TODO: worth making a enum here to differentiate between cohort and person filters? pub prop_type: String, + pub negation: Option, pub group_type_index: Option, } +impl PropertyFilter { + /// Checks if the filter is a cohort filter + pub fn is_cohort(&self) -> bool { + self.key == "id" && self.prop_type == "cohort" + } + + /// Returns the cohort id if the filter is a cohort filter, or None if it's not a cohort filter + /// or if the value cannot be parsed as a cohort id + pub fn get_cohort_id(&self) -> Option { + if !self.is_cohort() { + return None; + } + self.value.as_i64().map(|id| id as CohortId) + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FlagGroupType { pub properties: Option>, @@ -68,6 +91,9 @@ pub struct FlagFilters { pub super_groups: Option>, } +// TODO: see if you can combine these two structs, like we do with cohort models +// this will require not deserializing on read and instead doing it lazily, on-demand +// (which, tbh, is probably a better idea) #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FeatureFlag { pub id: i32, @@ -142,7 +168,7 @@ impl FeatureFlagList { tracing::error!("failed to parse data to flags list: {}", e); println!("failed to parse data: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; Ok(FeatureFlagList { flags: flags_list }) @@ -174,7 +200,7 @@ impl FeatureFlagList { .map(|row| { let filters = serde_json::from_value(row.filters).map_err(|e| { tracing::error!("Failed to deserialize filters for flag {}: {}", row.key, e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; Ok(FeatureFlag { @@ -200,7 +226,7 @@ impl FeatureFlagList { ) -> Result<(), FlagError> { let payload = serde_json::to_string(&flags.flags).map_err(|e| { tracing::error!("Failed to serialize flags: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; client @@ -1095,7 +1121,7 @@ mod tests { .expect("Failed to set malformed JSON in Redis"); let result = FeatureFlagList::from_redis(redis_client, team.id).await; - assert!(matches!(result, Err(FlagError::DataParsingError))); + assert!(matches!(result, Err(FlagError::RedisDataParsingError))); // Test database query error (using a non-existent table) let result = sqlx::query("SELECT * FROM non_existent_table") diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index bdcd542f09854..571fe9c84b40a 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,30 +1,34 @@ use crate::{ api::{FlagError, FlagValue, FlagsResponse}, + cohort_cache::CohortCacheManager, + cohort_models::{Cohort, CohortId}, database::Client as DatabaseClient, feature_flag_match_reason::FeatureFlagMatchReason, - flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, PropertyFilter}, + flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, OperatorType, PropertyFilter}, metrics_consts::{FLAG_EVALUATION_ERROR_COUNTER, 
FLAG_HASH_KEY_WRITES_COUNTER}, + metrics_utils::parse_exception_for_prometheus_label, property_matching::match_property, - utils::parse_exception_for_prometheus_label, }; use anyhow::Result; use common_metrics::inc; +use petgraph::algo::{is_cyclic_directed, toposort}; +use petgraph::graph::DiGraph; use serde_json::Value; use sha1::{Digest, Sha1}; use sqlx::{postgres::PgQueryResult, Acquire, FromRow}; use std::fmt::Write; use std::sync::Arc; use std::{ - collections::{HashMap, HashSet}, + collections::{HashMap, HashSet, VecDeque}, time::Duration, }; use tokio::time::{sleep, timeout}; use tracing::{error, info}; -type TeamId = i32; -type GroupTypeIndex = i32; -type PostgresReader = Arc; -type PostgresWriter = Arc; +pub type TeamId = i32; +pub type GroupTypeIndex = i32; +pub type PostgresReader = Arc; +pub type PostgresWriter = Arc; #[derive(Debug)] struct SuperConditionEvaluation { @@ -182,6 +186,7 @@ pub struct FeatureFlagMatcher { pub team_id: TeamId, pub postgres_reader: PostgresReader, pub postgres_writer: PostgresWriter, + pub cohort_cache: Arc, group_type_mapping_cache: GroupTypeMappingCache, properties_cache: PropertiesCache, groups: HashMap, @@ -195,8 +200,8 @@ impl FeatureFlagMatcher { team_id: TeamId, postgres_reader: PostgresReader, postgres_writer: PostgresWriter, + cohort_cache: Arc, group_type_mapping_cache: Option, - properties_cache: Option, groups: Option>, ) -> Self { FeatureFlagMatcher { @@ -204,10 +209,11 @@ impl FeatureFlagMatcher { team_id, postgres_reader: postgres_reader.clone(), postgres_writer: postgres_writer.clone(), + cohort_cache, group_type_mapping_cache: group_type_mapping_cache .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, postgres_reader.clone())), - properties_cache: properties_cache.unwrap_or_default(), groups: groups.unwrap_or_default(), + properties_cache: PropertiesCache::default(), } } @@ -732,12 +738,29 @@ impl FeatureFlagMatcher { .await; } - // NB: we can only evaluate group or person properties, not both - let properties_to_check = self - .get_properties_to_check(feature_flag, property_overrides, flag_property_filters) + // Separate cohort and non-cohort filters + let (cohort_filters, non_cohort_filters): (Vec, Vec) = + flag_property_filters + .iter() + .cloned() + .partition(|prop| prop.is_cohort()); + + // Get the properties we need to check for in this condition match from the flag + any overrides + let target_properties = self + .get_properties_to_check(feature_flag, property_overrides, &non_cohort_filters) .await?; - if !all_properties_match(flag_property_filters, &properties_to_check) { + // Evaluate non-cohort filters first, since they're cheaper to evaluate and we can return early if they don't match + if !all_properties_match(&non_cohort_filters, &target_properties) { + return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); + } + + // Evaluate cohort filters, if any. + if !cohort_filters.is_empty() + && !self + .evaluate_cohort_filters(&cohort_filters, &target_properties) + .await? + { return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); } } @@ -805,6 +828,37 @@ impl FeatureFlagMatcher { } } + /// Evaluates dynamic cohort property filters + /// + /// NB: This method first caches all of the cohorts associated with the team, which allows us to avoid + /// hitting the database for each cohort filter. 
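Once each referenced cohort has been resolved to a boolean match, the filters' `IN` / `NOT_IN` operators are applied on top of those results (see `apply_cohort_membership_logic` further down). A condensed sketch of that rule, with stand-in types for the operator and the cohort-match map:

```rust
use std::collections::HashMap;

enum Membership {
    In,    // stand-in for OperatorType::In
    NotIn, // stand-in for OperatorType::NotIn
}

// Every cohort filter must pass; NOT_IN inverts the underlying cohort match,
// and an unknown cohort is treated as not matching.
fn memberships_pass(filters: &[(i32, Membership)], matches: &HashMap<i32, bool>) -> bool {
    filters.iter().all(|(cohort_id, op)| {
        let matched = matches.get(cohort_id).copied().unwrap_or(false);
        match op {
            Membership::In => matched,
            Membership::NotIn => !matched,
        }
    })
}

fn main() {
    let matches = HashMap::from([(1, true), (2, false)]);
    assert!(memberships_pass(
        &[(1, Membership::In), (2, Membership::NotIn)],
        &matches
    ));
    assert!(!memberships_pass(&[(2, Membership::In)], &matches));
}
```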
+ pub async fn evaluate_cohort_filters( + &self, + cohort_property_filters: &[PropertyFilter], + target_properties: &HashMap, + ) -> Result { + // At the start of the request, fetch all of the cohorts for the team from the cache + // This method also caches the cohorts in memory for the duration of the application, so we don't need to fetch from + // the database again until we restart the application. + let cohorts = self.cohort_cache.get_cohorts_for_team(self.team_id).await?; + + // Store cohort match results in a HashMap to avoid re-evaluating the same cohort multiple times, + // since the same cohort could appear in multiple property filters. This is especially important + // because evaluating a cohort requires evaluating all of its dependencies, which can be expensive. + let mut cohort_matches = HashMap::new(); + for filter in cohort_property_filters { + let cohort_id = filter + .get_cohort_id() + .ok_or(FlagError::CohortFiltersParsingError)?; + let match_result = + evaluate_cohort_dependencies(cohort_id, target_properties, cohorts.clone())?; + cohort_matches.insert(cohort_id, match_result); + } + + // Apply cohort membership logic (IN|NOT_IN) + apply_cohort_membership_logic(cohort_property_filters, &cohort_matches) + } + /// Check if a super condition matches for a feature flag. /// /// This function evaluates the super conditions of a feature flag to determine if any of them should be enabled. @@ -1048,6 +1102,172 @@ impl FeatureFlagMatcher { } } +/// Evaluates a single cohort and its dependencies. +/// This uses a topological sort to evaluate dependencies first, which is necessary +/// because a cohort can depend on another cohort, and we need to respect the dependency order. +fn evaluate_cohort_dependencies( + initial_cohort_id: CohortId, + target_properties: &HashMap, + cohorts: Vec, +) -> Result { + let cohort_dependency_graph = + build_cohort_dependency_graph(initial_cohort_id, cohorts.clone())?; + + // We need to sort cohorts topologically to ensure we evaluate dependencies before the cohorts that depend on them. + // For example, if cohort A depends on cohort B, we need to evaluate B first to know if A matches. + // This also helps detect cycles - if cohort A depends on B which depends on A, toposort will fail. 
+ let sorted_cohort_ids_as_graph_nodes = + toposort(&cohort_dependency_graph, None).map_err(|e| { + FlagError::CohortDependencyCycle(format!("Cyclic dependency detected: {:?}", e)) + })?; + + // Store evaluation results for each cohort in a map, so we can look up whether a cohort matched + // when evaluating cohorts that depend on it, and also return the final result for the initial cohort + let mut evaluation_results = HashMap::new(); + + // Iterate through the sorted nodes in reverse order (so that we can evaluate dependencies first) + for node in sorted_cohort_ids_as_graph_nodes.into_iter().rev() { + let cohort_id = cohort_dependency_graph[node]; + let cohort = cohorts + .iter() + .find(|c| c.id == cohort_id) + .ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?; + let property_filters = cohort.parse_filters()?; + let dependencies = cohort.extract_dependencies()?; + + // Check if all dependencies have been met (i.e., previous cohorts matched) + let dependencies_met = dependencies + .iter() + .all(|dep_id| evaluation_results.get(dep_id).copied().unwrap_or(false)); + + // If dependencies are not met, mark the current cohort as not matched and continue + // NB: We don't want to _exit_ here, since the non-matching cohort could be wrapped in a `not_in` operator + // and we want to evaluate all cohorts to determine if the initial cohort matches. + if !dependencies_met { + evaluation_results.insert(cohort_id, false); + continue; + } + + // Evaluate all property filters for the current cohort + let all_filters_match = property_filters + .iter() + .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); + + // Store the evaluation result for the current cohort + evaluation_results.insert(cohort_id, all_filters_match); + } + + // Retrieve and return the evaluation result for the initial cohort + evaluation_results + .get(&initial_cohort_id) + .copied() + .ok_or_else(|| FlagError::CohortNotFound(initial_cohort_id.to_string())) +} + +/// Apply cohort membership logic (i.e., IN|NOT_IN) +fn apply_cohort_membership_logic( + cohort_filters: &[PropertyFilter], + cohort_matches: &HashMap, +) -> Result { + for filter in cohort_filters { + let cohort_id = filter + .get_cohort_id() + .ok_or(FlagError::CohortFiltersParsingError)?; + let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); + let operator = filter.operator.unwrap_or(OperatorType::In); + + // Combine the operator logic directly within this method + let membership_match = match operator { + OperatorType::In => matches, + OperatorType::NotIn => !matches, + // Currently supported operators are IN and NOT IN + // Any other operator defaults to false + _ => false, + }; + + // If any filter does not match, return false early + if !membership_match { + return Ok(false); + } + } + // All filters matched + Ok(true) +} + +/// Constructs a dependency graph for cohorts. +/// +/// Example dependency graph: +/// ```text +/// A B +/// | /| +/// | / | +/// | / | +/// C D +/// \ / +/// \ / +/// E +/// ``` +/// In this example: +/// - Cohorts A and B are root nodes (no dependencies) +/// - C depends on A and B +/// - D depends on B +/// - E depends on C and D +/// +/// The graph is acyclic, which is required for valid cohort dependencies. 
+fn build_cohort_dependency_graph( + initial_cohort_id: CohortId, + cohorts: Vec, +) -> Result, FlagError> { + let mut graph = DiGraph::new(); + let mut node_map = HashMap::new(); + let mut queue = VecDeque::new(); + // This implements a breadth-first search (BFS) traversal to build a directed graph of cohort dependencies. + // Starting from the initial cohort, we: + // 1. Add each cohort as a node in the graph + // 2. Track visited nodes in a map to avoid duplicates + // 3. For each cohort, get its dependencies and add directed edges from the cohort to its dependencies + // 4. Queue up any unvisited dependencies to process their dependencies later + // This builds up the full dependency graph level by level, which we can later check for cycles + queue.push_back(initial_cohort_id); + node_map.insert(initial_cohort_id, graph.add_node(initial_cohort_id)); + + while let Some(cohort_id) = queue.pop_front() { + let cohort = cohorts + .iter() + .find(|c| c.id == cohort_id) + .ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?; + let dependencies = cohort.extract_dependencies()?; + for dep_id in dependencies { + // Retrieve the current node **before** mutable borrowing + // This is safe because we're not mutating the node map, + // and it keeps the borrow checker happy + let current_node = node_map[&cohort_id]; + // Add dependency node if we haven't seen this cohort ID before in our traversal. + // This happens when we discover a new dependency that wasn't previously + // encountered while processing other cohorts in the graph. + let dep_node = node_map + .entry(dep_id) + .or_insert_with(|| graph.add_node(dep_id)); + + graph.add_edge(current_node, *dep_node, ()); + + if !node_map.contains_key(&dep_id) { + queue.push_back(dep_id); + } + } + } + + // Check for cycles, this is an directed acyclic graph so we use is_cyclic_directed + if is_cyclic_directed(&graph) { + return Err(FlagError::CohortDependencyCycle(format!( + "Cyclic dependency detected starting at cohort {}", + initial_cohort_id + ))); + } + + Ok(graph) +} + /// Fetch and locally cache all properties for a given distinct ID and team ID. /// /// This function fetches both person and group properties for a specified distinct ID and team ID. 
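The graph evaluation above leans on two petgraph helpers: `toposort` to order cohorts so that dependencies can be evaluated first (by walking the order in reverse), and `is_cyclic_directed` as the cycle guard. A minimal standalone sketch of that ordering on a two-cohort chain, with cohort ids as node weights and an edge pointing from a cohort to its dependency:

```rust
use petgraph::algo::{is_cyclic_directed, toposort};
use petgraph::graph::DiGraph;

fn main() {
    // Cohort 1 depends on cohort 2; edges point from dependent to dependency,
    // as in build_cohort_dependency_graph.
    let mut graph = DiGraph::<i32, ()>::new();
    let a = graph.add_node(1);
    let b = graph.add_node(2);
    graph.add_edge(a, b, ());

    assert!(!is_cyclic_directed(&graph));

    // With this edge direction, toposort places dependents before their
    // dependencies, so the evaluator iterates the order in reverse.
    let order = toposort(&graph, None).expect("acyclic");
    let ids: Vec<i32> = order.into_iter().rev().map(|node| graph[node]).collect();
    assert_eq!(ids, vec![2, 1]);
}
```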
@@ -1443,8 +1663,8 @@ mod tests { OperatorType, }, test_utils::{ - insert_flag_for_team_in_pg, insert_new_team_in_pg, insert_person_for_team_in_pg, - setup_pg_reader_client, setup_pg_writer_client, + insert_cohort_for_team_in_pg, insert_flag_for_team_in_pg, insert_new_team_in_pg, + insert_person_for_team_in_pg, setup_pg_reader_client, setup_pg_writer_client, }, }; @@ -1485,6 +1705,7 @@ mod tests { async fn test_fetch_properties_from_pg_to_match() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await @@ -1534,7 +1755,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -1547,7 +1768,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -1560,7 +1781,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -1573,6 +1794,7 @@ mod tests { async fn test_person_property_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1590,6 +1812,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -1611,7 +1834,7 @@ mod tests { team.id, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -1633,6 +1856,7 @@ mod tests { async fn test_group_property_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1650,6 +1874,7 @@ mod tests { operator: None, prop_type: "group".to_string(), group_type_index: Some(1), + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -1664,10 +1889,12 @@ mod tests { None, ); - let mut cache = GroupTypeMappingCache::new(team.id, postgres_reader.clone()); + let mut group_type_mapping_cache = + GroupTypeMappingCache::new(team.id, postgres_reader.clone()); let group_types_to_indexes = [("organization".to_string(), 1)].into_iter().collect(); - cache.group_types_to_indexes = group_types_to_indexes; - cache.group_indexes_to_types = [(1, "organization".to_string())].into_iter().collect(); + group_type_mapping_cache.group_types_to_indexes = group_types_to_indexes; + group_type_mapping_cache.group_indexes_to_types = + [(1, "organization".to_string())].into_iter().collect(); let groups = HashMap::from([("organization".to_string(), json!("org_123"))]); @@ -1684,8 +1911,8 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - Some(cache), - None, + cohort_cache.clone(), + Some(group_type_mapping_cache), Some(groups), ); @@ -1708,14 +1935,14 @@ mod tests { let flag = create_test_flag_with_variants(1); let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - - let mut cache = GroupTypeMappingCache::new(1, postgres_reader.clone()); + 
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let mut group_type_mapping_cache = GroupTypeMappingCache::new(1, postgres_reader.clone()); let group_types_to_indexes = [("group_type_1".to_string(), 1)].into_iter().collect(); let group_type_index_to_name = [(1, "group_type_1".to_string())].into_iter().collect(); - cache.group_types_to_indexes = group_types_to_indexes; - cache.group_indexes_to_types = group_type_index_to_name; + group_type_mapping_cache.group_types_to_indexes = group_types_to_indexes; + group_type_mapping_cache.group_indexes_to_types = group_type_index_to_name; let groups = HashMap::from([("group_type_1".to_string(), json!("group_key_1"))]); @@ -1724,8 +1951,8 @@ mod tests { 1, postgres_reader.clone(), postgres_writer.clone(), - Some(cache), - None, + cohort_cache.clone(), + Some(group_type_mapping_cache), Some(groups), ); let variant = matcher.get_matching_variant(&flag, None).await.unwrap(); @@ -1740,6 +1967,7 @@ mod tests { async fn test_get_matching_variant_with_db() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1751,7 +1979,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -1765,6 +1993,7 @@ mod tests { async fn test_is_condition_match_empty_properties() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -1797,7 +2026,7 @@ mod tests { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -1854,6 +2083,7 @@ mod tests { async fn test_overrides_avoid_db_lookups() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1871,6 +2101,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -1893,7 +2124,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -1923,6 +2154,7 @@ mod tests { async fn test_fallback_to_db_when_overrides_insufficient() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1941,6 +2173,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "age".to_string(), @@ -1948,6 +2181,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]), rollout_percentage: Some(100.0), @@ -1982,7 +2216,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2006,6 
+2240,7 @@ mod tests { async fn test_property_fetching_and_caching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2025,7 +2260,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2050,6 +2285,7 @@ mod tests { async fn test_property_caching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2069,7 +2305,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2102,7 +2338,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2150,6 +2386,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "age".to_string(), @@ -2157,6 +2394,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]; @@ -2170,6 +2408,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "cohort".to_string(), @@ -2177,6 +2416,7 @@ mod tests { operator: None, prop_type: "cohort".to_string(), group_type_index: None, + negation: None, }, ]; @@ -2189,6 +2429,7 @@ mod tests { async fn test_concurrent_flag_evaluation() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2218,13 +2459,14 @@ mod tests { let flag_clone = flag.clone(); let postgres_reader_clone = postgres_reader.clone(); let postgres_writer_clone = postgres_writer.clone(); + let cohort_cache_clone = cohort_cache.clone(); handles.push(tokio::spawn(async move { let mut matcher = FeatureFlagMatcher::new( format!("test_user_{}", i), team.id, postgres_reader_clone, postgres_writer_clone, - None, + cohort_cache_clone, None, None, ); @@ -2246,6 +2488,7 @@ mod tests { async fn test_property_operators() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2264,6 +2507,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "email".to_string(), @@ -2271,6 +2515,7 @@ mod tests { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]), rollout_percentage: Some(100.0), @@ -2300,7 +2545,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2314,7 +2559,7 @@ mod tests { async fn test_empty_hashed_identifier() { let postgres_reader = 
setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2341,7 +2586,7 @@ mod tests { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -2355,6 +2600,7 @@ mod tests { async fn test_rollout_percentage() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let mut flag = create_test_flag( Some(1), None, @@ -2381,7 +2627,7 @@ mod tests { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -2402,7 +2648,7 @@ mod tests { async fn test_uneven_variant_distribution() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let mut flag = create_test_flag_with_variants(1); // Adjust variant rollout percentages to be uneven @@ -2432,7 +2678,7 @@ mod tests { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -2464,6 +2710,7 @@ mod tests { async fn test_missing_properties_in_db() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2491,6 +2738,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2510,7 +2758,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2524,6 +2772,7 @@ mod tests { async fn test_malformed_property_data() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2551,6 +2800,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2570,7 +2820,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2585,6 +2835,7 @@ mod tests { async fn test_get_match_with_insufficient_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2603,6 +2854,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "age".to_string(), @@ -2610,6 +2862,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]), rollout_percentage: Some(100.0), @@ -2644,7 +2897,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2661,6 +2914,7 @@ mod tests 
{ async fn test_evaluation_reasons() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2687,7 +2941,7 @@ mod tests { 1, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2705,6 +2959,7 @@ mod tests { async fn test_complex_conditions() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2723,6 +2978,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2734,6 +2990,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2763,7 +3020,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2777,6 +3034,7 @@ mod tests { async fn test_super_condition_matches_boolean() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2795,6 +3053,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(0.0), variant: None, @@ -2806,6 +3065,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2826,6 +3086,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2850,7 +3111,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2860,7 +3121,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2870,7 +3131,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2897,6 +3158,7 @@ mod tests { async fn test_super_condition_matches_string() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2924,6 +3186,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(0.0), variant: None, @@ -2935,6 +3198,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2955,6 +3219,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), 
group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2970,7 +3235,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2986,6 +3251,7 @@ mod tests { async fn test_super_condition_matches_and_false() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3013,6 +3279,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(0.0), variant: None, @@ -3024,6 +3291,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3044,6 +3312,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3059,7 +3328,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -3069,7 +3338,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -3079,7 +3348,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -3116,6 +3385,473 @@ mod tests { assert_eq!(result_another_id.condition_index, Some(2)); } + #[tokio::test] + async fn test_basic_cohort_matching() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with the condition that matches the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = 
matcher.get_match(&flag, None, None).await.unwrap(); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_not_in_cohort_matching() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with a condition that does not match the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "130", + "negation": false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that do not match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a NotIn cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_not_in_cohort_matching_user_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with a condition that matches the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a NotIn cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + 
aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + // The user matches the cohort, but the flag is set to NotIn, so it should evaluate to false + assert!(!result.matches); + } + + #[tokio::test] + async fn test_cohort_dependent_on_another_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a base cohort + let base_cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a dependent cohort that includes the base cohort + let dependent_cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "id", + "type": "cohort", + "value": base_cohort_row.id, + "negation": false, + "operator": "in" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that match the base cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a cohort filter that depends on another cohort + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(dependent_cohort_row.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_in_cohort_matching_user_not_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with a condition that does not match the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "130", + "negation": 
false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that do not match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 125})), + ) + .await + .unwrap(); + + // Define a flag with an In cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + // The user does not match the cohort, and the flag is set to In, so it should evaluate to false + assert!(!result.matches); + } + #[tokio::test] async fn test_set_feature_flag_hash_key_overrides_success() { let postgres_reader = setup_pg_reader_client(None).await; @@ -3123,7 +3859,7 @@ mod tests { let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); - let distinct_id = "user1".to_string(); + let distinct_id = "user2".to_string(); // Insert person insert_person_for_team_in_pg(postgres_reader.clone(), team.id, distinct_id.clone(), None) @@ -3148,7 +3884,7 @@ mod tests { Some(true), // ensure_experience_continuity ); - // need to convert flag to FeatureFlagRow + // Convert flag to FeatureFlagRow let flag_row = FeatureFlagRow { id: flag.id, team_id: flag.team_id, @@ -3165,8 +3901,8 @@ mod tests { .await .unwrap(); - // Attempt to set hash key override - let result = set_feature_flag_hash_key_overrides( + // Set hash key override + set_feature_flag_hash_key_overrides( postgres_writer.clone(), team.id, vec![distinct_id.clone()], @@ -3175,9 +3911,7 @@ mod tests { .await .unwrap(); - assert!(result, "Hash key override should be set successfully"); - - // Retrieve the hash key overrides + // Retrieve hash key overrides let overrides = get_feature_flag_hash_key_overrides( postgres_reader.clone(), team.id, @@ -3186,14 +3920,10 @@ mod tests { .await .unwrap(); - assert!( - !overrides.is_empty(), - "At least one hash key override should be set" - ); assert_eq!( overrides.get("test_flag"), Some(&"hash_key_2".to_string()), - "Hash key override for 'test_flag' should match the set value" + "Hash key override should match the set value" ); } @@ -3271,10 +4001,12 @@ mod tests { "Hash key override should match the set value" ); } + #[tokio::test] async fn test_evaluate_feature_flags_with_experience_continuity() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3304,6 +4036,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3337,7 +4070,7 @@ mod tests { team.id, 
postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ) @@ -3356,6 +4089,7 @@ mod tests { async fn test_evaluate_feature_flags_with_continuity_missing_override() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3385,6 +4119,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3408,7 +4143,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ) @@ -3427,6 +4162,7 @@ mod tests { async fn test_evaluate_all_feature_flags_mixed_continuity() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3456,6 +4192,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3484,6 +4221,7 @@ mod tests { operator: Some(OperatorType::Gt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3517,7 +4255,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ) diff --git a/rust/feature-flags/src/flag_request.rs b/rust/feature-flags/src/flag_request.rs index 771c216834c96..1cf64eb879ac4 100644 --- a/rust/feature-flags/src/flag_request.rs +++ b/rust/feature-flags/src/flag_request.rs @@ -158,8 +158,8 @@ impl FlagRequest { pub async fn get_flags_from_cache_or_pg( &self, team_id: i32, - redis_client: Arc, - pg_client: Arc, + redis_client: &Arc, + pg_client: &Arc, ) -> Result { let mut cache_hit = false; let flags = match FeatureFlagList::from_redis(redis_client.clone(), team_id).await { @@ -167,10 +167,14 @@ impl FlagRequest { cache_hit = true; Ok(flags) } - Err(_) => match FeatureFlagList::from_pg(pg_client, team_id).await { + Err(_) => match FeatureFlagList::from_pg(pg_client.clone(), team_id).await { Ok(flags) => { - if let Err(e) = - FeatureFlagList::update_flags_in_redis(redis_client, team_id, &flags).await + if let Err(e) = FeatureFlagList::update_flags_in_redis( + redis_client.clone(), + team_id, + &flags, + ) + .await { tracing::warn!("Failed to update Redis cache: {}", e); // TODO add new metric category for this @@ -206,7 +210,6 @@ mod tests { TEAM_FLAGS_CACHE_PREFIX, }; use crate::flag_request::FlagRequest; - use crate::redis::Client as RedisClient; use crate::team::Team; use crate::test_utils::{insert_new_team_in_redis, setup_pg_reader_client, setup_redis_client}; use bytes::Bytes; @@ -360,6 +363,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(50.0), variant: None, @@ -402,6 +406,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -426,7 +431,7 @@ mod tests { // Test fetching from Redis let result = flag_request - 
.get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone()) + .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) .await; assert!(result.is_ok()); let fetched_flags = result.unwrap(); @@ -483,7 +488,7 @@ mod tests { .expect("Failed to remove flags from Redis"); let result = flag_request - .get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone()) + .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) .await; assert!(result.is_ok()); // Verify that the flags were re-added to Redis diff --git a/rust/feature-flags/src/lib.rs b/rust/feature-flags/src/lib.rs index 051b3e27697f3..67659bfcf9dcd 100644 --- a/rust/feature-flags/src/lib.rs +++ b/rust/feature-flags/src/lib.rs @@ -1,4 +1,7 @@ pub mod api; +pub mod cohort_cache; +pub mod cohort_models; +pub mod cohort_operations; pub mod config; pub mod database; pub mod feature_flag_match_reason; @@ -8,13 +11,13 @@ pub mod flag_matching; pub mod flag_request; pub mod geoip; pub mod metrics_consts; +pub mod metrics_utils; pub mod property_matching; pub mod redis; pub mod request_handler; pub mod router; pub mod server; pub mod team; -pub mod utils; pub mod v0_endpoint; // Test modules don't need to be compiled with main binary diff --git a/rust/feature-flags/src/utils.rs b/rust/feature-flags/src/metrics_utils.rs similarity index 100% rename from rust/feature-flags/src/utils.rs rename to rust/feature-flags/src/metrics_utils.rs diff --git a/rust/feature-flags/src/property_matching.rs b/rust/feature-flags/src/property_matching.rs index 8d12fe6ab5e9d..84479f131611f 100644 --- a/rust/feature-flags/src/property_matching.rs +++ b/rust/feature-flags/src/property_matching.rs @@ -44,7 +44,7 @@ pub fn match_property( } let key = &property.key; - let operator = property.operator.clone().unwrap_or(OperatorType::Exact); + let operator = property.operator.unwrap_or(OperatorType::Exact); let value = &property.value; let match_value = matching_property_values.get(key); @@ -193,6 +193,12 @@ pub fn match_property( // Ok(false) // } } + OperatorType::In | OperatorType::NotIn => { + // TODO: we handle these in cohort matching, so we can just return false here + // because by the time we match properties, we've already decomposed the cohort + // filter into multiple property filters + Ok(false) + } } } @@ -260,6 +266,7 @@ mod test_match_properties { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -313,6 +320,7 @@ mod test_match_properties { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -335,6 +343,7 @@ mod test_match_properties { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -379,6 +388,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNot), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -416,6 +426,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNot), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -490,6 +501,7 @@ mod test_match_properties { operator: Some(OperatorType::IsSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -538,6 +550,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), 
group_type_index: None, + negation: None, }; assert!(match_property( @@ -595,6 +608,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -634,6 +648,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -674,6 +689,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( &property_b, @@ -708,6 +724,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -730,6 +747,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( &property_d, @@ -760,6 +778,7 @@ mod test_match_properties { operator: Some(OperatorType::Gt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -802,6 +821,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -848,6 +868,7 @@ mod test_match_properties { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -889,6 +910,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -935,6 +957,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1013,6 +1036,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNot), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1034,6 +1058,7 @@ mod test_match_properties { operator: Some(OperatorType::IsSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1049,6 +1074,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1070,6 +1096,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1085,6 +1112,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1118,6 +1146,7 @@ mod test_match_properties { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1137,6 +1166,7 @@ mod test_match_properties { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1152,6 +1182,7 @@ mod test_match_properties { operator: Some(OperatorType::IsSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1167,6 +1198,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNotSet), 
prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1203,6 +1235,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1218,6 +1251,7 @@ mod test_match_properties { operator: Some(OperatorType::NotIcontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1233,6 +1267,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1248,6 +1283,7 @@ mod test_match_properties { operator: Some(OperatorType::NotRegex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1263,6 +1299,7 @@ mod test_match_properties { operator: Some(OperatorType::Gt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1278,6 +1315,7 @@ mod test_match_properties { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1293,6 +1331,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1308,6 +1347,7 @@ mod test_match_properties { operator: Some(OperatorType::Lte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1324,6 +1364,7 @@ mod test_match_properties { operator: Some(OperatorType::IsDateBefore), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index 5e0be8faacc59..538c6845d2a02 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -1,5 +1,6 @@ use crate::{ api::{FlagError, FlagsResponse}, + cohort_cache::CohortCacheManager, database::Client, flag_definitions::FeatureFlagList, flag_matching::{FeatureFlagMatcher, GroupTypeMappingCache}, @@ -69,6 +70,7 @@ pub struct FeatureFlagEvaluationContext { feature_flags: FeatureFlagList, postgres_reader: Arc, postgres_writer: Arc, + cohort_cache: Arc, #[builder(default)] person_property_overrides: Option>, #[builder(default)] @@ -108,18 +110,16 @@ pub async fn process_request(context: RequestContext) -> Result = state.postgres_reader.clone(); - let postgres_writer_dyn: Arc = state.postgres_writer.clone(); - let evaluation_context = FeatureFlagEvaluationContextBuilder::default() .team_id(team_id) .distinct_id(distinct_id) .feature_flags(feature_flags_from_cache_or_pg) - .postgres_reader(postgres_reader_dyn) - .postgres_writer(postgres_writer_dyn) + .postgres_reader(state.postgres_reader.clone()) + .postgres_writer(state.postgres_writer.clone()) + .cohort_cache(state.cohort_cache.clone()) .person_property_overrides(person_property_overrides) .group_property_overrides(group_property_overrides) .groups(groups) @@ -224,8 +224,8 @@ pub async fn evaluate_feature_flags(context: FeatureFlagEvaluationContext) -> Fl context.team_id, context.postgres_reader, context.postgres_writer, + context.cohort_cache, Some(group_type_mapping_cache), - None, // TODO maybe remove this from the matcher struct, since it's used internally but not passed around context.groups, ); 
feature_flag_matcher @@ -359,6 +359,7 @@ mod tests { async fn test_evaluate_feature_flags() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, @@ -374,6 +375,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), // Set to 100% to ensure it's always on variant: None, @@ -397,6 +399,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .person_property_overrides(Some(person_properties)) .build() .expect("Failed to build FeatureFlagEvaluationContext"); @@ -505,6 +508,7 @@ mod tests { async fn test_evaluate_feature_flags_multiple_flags() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flags = vec![ FeatureFlag { name: Some("Flag 1".to_string()), @@ -556,6 +560,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .build() .expect("Failed to build FeatureFlagEvaluationContext"); @@ -608,6 +613,7 @@ mod tests { async fn test_evaluate_feature_flags_with_overrides() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -627,6 +633,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "group".to_string(), group_type_index: Some(0), + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -655,6 +662,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .group_property_overrides(Some(group_property_overrides)) .groups(Some(groups)) .build() @@ -688,6 +696,7 @@ mod tests { let long_id = "a".repeat(1000); let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, @@ -717,6 +726,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .build() .expect("Failed to build FeatureFlagEvaluationContext"); diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 505f18adfb008..e34ea31a3c65a 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -9,11 +9,12 @@ use health::HealthRegistry; use tower::limit::ConcurrencyLimitLayer; use crate::{ + cohort_cache::CohortCacheManager, config::{Config, TeamIdsToTrack}, database::Client as DatabaseClient, geoip::GeoIpClient, + metrics_utils::team_id_label_filter, redis::Client as RedisClient, - utils::team_id_label_filter, v0_endpoint, }; @@ -22,6 +23,7 @@ pub struct State { pub redis: Arc, pub postgres_reader: Arc, pub 
postgres_writer: Arc, + pub cohort_cache: Arc, // TODO does this need a better name than just `cohort_cache`? pub geoip: Arc, pub team_ids_to_track: TeamIdsToTrack, } @@ -30,6 +32,7 @@ pub fn router( redis: Arc, postgres_reader: Arc, postgres_writer: Arc, + cohort_cache: Arc, geoip: Arc, liveness: HealthRegistry, config: Config, @@ -42,6 +45,7 @@ where redis, postgres_reader, postgres_writer, + cohort_cache, geoip, team_ids_to_track: config.team_ids_to_track.clone(), }; diff --git a/rust/feature-flags/src/server.rs b/rust/feature-flags/src/server.rs index c9e238fa8fd4e..69ff759ddfcdf 100644 --- a/rust/feature-flags/src/server.rs +++ b/rust/feature-flags/src/server.rs @@ -6,6 +6,7 @@ use std::time::Duration; use health::{HealthHandle, HealthRegistry}; use tokio::net::TcpListener; +use crate::cohort_cache::CohortCacheManager; use crate::config::Config; use crate::database::get_pool; use crate::geoip::GeoIpClient; @@ -54,6 +55,8 @@ where } }; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let health = HealthRegistry::new("liveness"); // TODO - we don't have a more complex health check yet, but we should add e.g. some around DB operations @@ -67,6 +70,7 @@ where redis_client, postgres_reader, postgres_writer, + cohort_cache, geoip_service, health, config, diff --git a/rust/feature-flags/src/team.rs b/rust/feature-flags/src/team.rs index 0fa75f0bd3db7..f13cf29094b85 100644 --- a/rust/feature-flags/src/team.rs +++ b/rust/feature-flags/src/team.rs @@ -42,7 +42,7 @@ impl Team { // TODO: Consider an LRU cache for teams as well, with small TTL to skip redis/pg lookups let team: Team = serde_json::from_str(&serialized_team).map_err(|e| { tracing::error!("failed to parse data to team: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; Ok(team) @@ -55,7 +55,7 @@ impl Team { ) -> Result<(), FlagError> { let serialized_team = serde_json::to_string(&team).map_err(|e| { tracing::error!("Failed to serialize team: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; client @@ -173,7 +173,7 @@ mod tests { let client = setup_redis_client(None); match Team::from_redis(client.clone(), team.api_token.clone()).await { - Err(FlagError::DataParsingError) => (), + Err(FlagError::RedisDataParsingError) => (), Err(other) => panic!("Expected DataParsingError, got {:?}", other), Ok(_) => panic!("Expected DataParsingError"), }; diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index 32a2016bf756b..22f7753d97910 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -1,11 +1,12 @@ use anyhow::Error; use axum::async_trait; use serde_json::{json, Value}; -use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, PgPool, Postgres}; +use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, Postgres}; use std::sync::Arc; use uuid::Uuid; use crate::{ + cohort_models::Cohort, config::{Config, DEFAULT_TEST_CONFIG}, database::{get_pool, Client, CustomDatabaseError}, flag_definitions::{self, FeatureFlag, FeatureFlagRow}, @@ -23,7 +24,9 @@ pub fn random_string(prefix: &str, length: usize) -> String { format!("{}{}", prefix, suffix) } -pub async fn insert_new_team_in_redis(client: Arc) -> Result { +pub async fn insert_new_team_in_redis( + client: Arc, +) -> Result { let id = rand::thread_rng().gen_range(0..10_000_000); let token = random_string("phc_", 12); let team = Team { @@ -48,7 +51,7 @@ pub async fn 
insert_new_team_in_redis(client: Arc) -> Result, + client: Arc, team_id: i32, json_value: Option, ) -> Result<(), Error> { @@ -88,7 +91,7 @@ pub async fn insert_flags_for_team_in_redis( Ok(()) } -pub fn setup_redis_client(url: Option) -> Arc { +pub fn setup_redis_client(url: Option) -> Arc { let redis_url = match url { Some(value) => value, None => "redis://localhost:6379/".to_string(), @@ -130,7 +133,7 @@ pub fn create_flag_from_json(json_value: Option) -> Vec { flags } -pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { +pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { let config = config.unwrap_or(&DEFAULT_TEST_CONFIG); Arc::new( get_pool(&config.read_database_url, config.max_pg_connections) @@ -139,7 +142,7 @@ pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { ) } -pub async fn setup_pg_writer_client(config: Option<&Config>) -> Arc { +pub async fn setup_pg_writer_client(config: Option<&Config>) -> Arc { let config = config.unwrap_or(&DEFAULT_TEST_CONFIG); Arc::new( get_pool(&config.write_database_url, config.max_pg_connections) @@ -261,7 +264,7 @@ pub async fn insert_new_team_in_pg( } pub async fn insert_flag_for_team_in_pg( - client: Arc, + client: Arc, team_id: i32, flag: Option, ) -> Result { @@ -310,7 +313,7 @@ pub async fn insert_flag_for_team_in_pg( } pub async fn insert_person_for_team_in_pg( - client: Arc, + client: Arc, team_id: i32, distinct_id: String, properties: Option, @@ -352,3 +355,58 @@ pub async fn insert_person_for_team_in_pg( Ok(()) } + +pub async fn insert_cohort_for_team_in_pg( + client: Arc, + team_id: i32, + name: Option, + filters: serde_json::Value, + is_static: bool, +) -> Result { + let cohort = Cohort { + id: 0, // Placeholder, will be updated after insertion + name: name.unwrap_or("Test Cohort".to_string()), + description: Some("Description for cohort".to_string()), + team_id, + deleted: false, + filters, + query: None, + version: Some(1), + pending_version: None, + count: None, + is_calculating: false, + is_static, + errors_calculating: 0, + groups: serde_json::json!([]), + created_by_id: None, + }; + + let mut conn = client.get_connection().await?; + let row: (i32,) = sqlx::query_as( + r#"INSERT INTO posthog_cohort + (name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id) VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + RETURNING id"#, + ) + .bind(&cohort.name) + .bind(&cohort.description) + .bind(cohort.team_id) + .bind(cohort.deleted) + .bind(&cohort.filters) + .bind(&cohort.query) + .bind(cohort.version) + .bind(cohort.pending_version) + .bind(cohort.count) + .bind(cohort.is_calculating) + .bind(cohort.is_static) + .bind(cohort.errors_calculating) + .bind(&cohort.groups) + .bind(cohort.created_by_id) + .fetch_one(&mut *conn) + .await?; + + // Update the cohort_row with the actual id generated by sqlx + let id = row.0; + + Ok(Cohort { id, ..cohort }) +} diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs index 94f4f67dcdc56..c632d28bc151d 100644 --- a/rust/feature-flags/tests/test_flag_matching_consistency.rs +++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs @@ -1,3 +1,6 @@ +use std::sync::Arc; + +use feature_flags::cohort_cache::CohortCacheManager; use feature_flags::feature_flag_match_reason::FeatureFlagMatchReason; /// These tests are common between all libraries doing local 
evaluation of feature flags. /// This ensures there are no mismatches between implementations. @@ -110,6 +113,7 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { for (i, result) in results.iter().enumerate().take(1000) { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let distinct_id = format!("distinct_id_{}", i); @@ -118,7 +122,7 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ) @@ -1209,6 +1213,7 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { for (i, result) in results.iter().enumerate().take(1000) { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let distinct_id = format!("distinct_id_{}", i); let feature_flag_match = FeatureFlagMatcher::new( @@ -1216,7 +1221,7 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, )