Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(flags): dynamic cohort matching in rust #25776

Merged
merged 41 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
ca431b5
unifying some types
dmarticus Oct 15, 2024
e89f169
in progress but not done yet
dmarticus Oct 15, 2024
4f20e07
Merge branch 'master' into feat/static-cohorts-rust
dmarticus Oct 23, 2024
ed00224
oh lol right let's actually ship
dmarticus Oct 23, 2024
fb8aab8
or default
dmarticus Oct 23, 2024
899a99c
Merge branch 'master' into feat/dynamic-cohorts-rust
dmarticus Oct 24, 2024
896c31a
Merge branch 'master' into feat/dynamic-cohorts-rust
dmarticus Oct 24, 2024
eeea8cc
let's goooo
dmarticus Oct 24, 2024
d02baec
Merge branch 'feat/dynamic-cohorts-rust' of github.com:PostHog/postho…
dmarticus Oct 24, 2024
39dad2d
Merge branch 'master' into feat/dynamic-cohorts-rust
dmarticus Oct 24, 2024
db8cd8d
modeled the data correctly this time :sweat:
dmarticus Oct 24, 2024
43cda76
clippy my frickin GUY
dmarticus Oct 24, 2024
8d2ab85
some light renaming
dmarticus Oct 25, 2024
9ccf479
yeah
dmarticus Oct 25, 2024
797adbe
remove printlns
dmarticus Oct 25, 2024
71def67
add note about not handling groups
dmarticus Oct 28, 2024
27af814
saving a working version that supports caching, since this is the rig…
dmarticus Oct 29, 2024
4c49bc4
new life
dmarticus Oct 30, 2024
d4af2f0
clippy u dawg
dmarticus Oct 30, 2024
870f719
traverse the dependency graph post-cache access
dmarticus Oct 31, 2024
57d9885
cleaning up
dmarticus Oct 31, 2024
9eb0f18
adding more tests
dmarticus Oct 31, 2024
3cfc590
test for the cohort cache
dmarticus Oct 31, 2024
3e8e5d2
a few things
dmarticus Oct 31, 2024
3528b31
Merge branch 'master' into feat/dynamic-cohorts-rust
dmarticus Oct 31, 2024
77059f3
Merge branch 'master' into feat/dynamic-cohorts-rust
dmarticus Nov 1, 2024
09317c4
use global cohort cache
dmarticus Nov 1, 2024
43e8692
less yapping
dmarticus Nov 1, 2024
3a65683
appeasing the linter
dmarticus Nov 2, 2024
a5812e6
that should do it
dmarticus Nov 2, 2024
fd52b24
clean up
dmarticus Nov 4, 2024
59f7c10
rename
dmarticus Nov 4, 2024
8066aff
bit more
dmarticus Nov 4, 2024
4d5ecd9
collapse condition
dmarticus Nov 4, 2024
4012ebe
Merge branch 'master' into feat/dynamic-cohorts-rust
dmarticus Nov 4, 2024
8ededb1
resolve conflicts
dmarticus Nov 6, 2024
fe37b04
working on it
dmarticus Nov 7, 2024
bc38940
Merge branch 'feat/dynamic-cohorts-rust' of github.com:PostHog/postho…
dmarticus Nov 7, 2024
41d3db3
not this either
dmarticus Nov 7, 2024
0a409f4
docs
dmarticus Nov 7, 2024
0dd1c0b
Merge branch 'master' into feat/dynamic-cohorts-rust
dmarticus Nov 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions rust/feature-flags/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ health = { path = "../common/health" }
common-metrics = { path = "../common/metrics" }
tower = { workspace = true }
derive_builder = "0.20.1"
petgraph = "0.6.5"
oliverb123 marked this conversation as resolved.
Show resolved Hide resolved
moka = { version = "0.12.8", features = ["future"] }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

caching lib with support for TTL and feature weighting

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a heads up, this is already in the workspace, you can probably pull it in (we're using it in error tracking).


[lints]
workspace = true
Expand Down
24 changes: 21 additions & 3 deletions rust/feature-flags/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub enum FlagError {
#[error("Row not found in postgres")]
RowNotFound,
#[error("failed to parse redis cache data")]
DataParsingError,
RedisDataParsingError,
#[error("failed to update redis cache")]
CacheUpdateError,
#[error("redis unavailable")]
Expand All @@ -102,6 +102,12 @@ pub enum FlagError {
TimeoutError,
#[error("No group type mappings")]
NoGroupTypeMappings,
#[error("Cohort not found")]
CohortNotFound(String),
#[error("Failed to parse cohort filters")]
CohortFiltersParsingError,
#[error("Cohort dependency cycle")]
CohortDependencyCycle(String),
}

impl IntoResponse for FlagError {
Expand Down Expand Up @@ -138,7 +144,7 @@ impl IntoResponse for FlagError {
FlagError::TokenValidationError => {
(StatusCode::UNAUTHORIZED, "The provided API key is invalid or has expired. Please check your API key and try again.".to_string())
}
FlagError::DataParsingError => {
FlagError::RedisDataParsingError => {
tracing::error!("Data parsing error: {:?}", self);
(
StatusCode::SERVICE_UNAVAILABLE,
Expand Down Expand Up @@ -194,6 +200,18 @@ impl IntoResponse for FlagError {
"The requested row was not found in the database. Please try again later or contact support if the problem persists.".to_string(),
)
}
FlagError::CohortNotFound(msg) => {
tracing::error!("Cohort not found: {}", msg);
(StatusCode::NOT_FOUND, msg)
}
FlagError::CohortFiltersParsingError => {
tracing::error!("Failed to parse cohort filters: {:?}", self);
(StatusCode::BAD_REQUEST, "Failed to parse cohort filters. Please try again later or contact support if the problem persists.".to_string())
}
FlagError::CohortDependencyCycle(msg) => {
tracing::error!("Cohort dependency cycle: {}", msg);
(StatusCode::BAD_REQUEST, msg)
}
}
.into_response()
}
Expand All @@ -205,7 +223,7 @@ impl From<CustomRedisError> for FlagError {
CustomRedisError::NotFound => FlagError::TokenValidationError,
CustomRedisError::PickleError(e) => {
tracing::error!("failed to fetch data: {}", e);
FlagError::DataParsingError
FlagError::RedisDataParsingError
}
CustomRedisError::Timeout(_) => FlagError::TimeoutError,
CustomRedisError::Other(e) => {
Expand Down
221 changes: 221 additions & 0 deletions rust/feature-flags/src/cohort_cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
use crate::api::FlagError;
use crate::cohort_models::Cohort;
use crate::flag_matching::{PostgresReader, TeamId};
use moka::future::Cache;
use std::time::Duration;

/// CohortCacheManager manages the in-memory cache of cohorts using `moka` for caching.
///
/// Features:
/// - **TTL**: Each cache entry expires after 5 minutes.
/// - **Size-based eviction**: The cache evicts least recently used entries when the maximum capacity is reached.
///
/// ```text
/// CohortCacheManager {
/// postgres_reader: PostgresReader,
/// per_team_cohorts: Cache<TeamId, Vec<Cohort>> {
/// // Example:
/// 2: [
/// Cohort { id: 1, name: "Power Users", filters: {...} },
/// Cohort { id: 2, name: "Churned", filters: {...} }
/// ],
/// 5: [
/// Cohort { id: 3, name: "Beta Users", filters: {...} }
/// ]
/// }
/// }
/// ```
///
#[derive(Clone)]
pub struct CohortCacheManager {
postgres_reader: PostgresReader,
per_team_cohort_cache: Cache<TeamId, Vec<Cohort>>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, same as below re: postgres_reader I suppose, but I know from the type that the cache is per-team (it's got TeamId as a key), and I know it's caching cohorts. You can be shorter here, the type shows up everywhere it's used.

This is purely taste though, if you disagree feel free to ignore.

}
oliverb123 marked this conversation as resolved.
Show resolved Hide resolved

impl CohortCacheManager {
pub fn new(
postgres_reader: PostgresReader,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, but if the variable name is the same as the type name, I go for stuff like "pr: PostgresReader" - the ide tells me everything I need to know about it anyway. I'd make it reader: PostgresReader in the struct declaration

max_capacity: Option<u64>,
ttl_seconds: Option<u64>,
) -> Self {
// We use the size of the cohort list (i.e., the number of cohorts for a given team)as the weight of the entry
let weigher =
|_: &TeamId, value: &Vec<Cohort>| -> u32 { value.len().try_into().unwrap_or(u32::MAX) };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a thought about casts generally, I think this is totally fine and you shouldn't pay the CI time to change it:

I'd almost argue for a raw unwrap here (or an expect with a helpful message), under the consideration you probably do want to fail loudly if a team has more than u32::MAX cohorts, but also, you'll never end up in this situation because fetching them would bring down postgres, you'd OOM, etc, so I'd then almost go for an as cast instead, knowing the truncation will never happen.


let cache = Cache::builder()
.time_to_live(Duration::from_secs(ttl_seconds.unwrap_or(300))) // Default to 5 minutes
.weigher(weigher)
.max_capacity(max_capacity.unwrap_or(10_000)) // Default to 10,000 cohorts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This default strikes me as quite low, I'd bump it an order of magnitude (or set it an order of magnitude larger) - that's a pure gut feeling though.

.build();

Self {
postgres_reader,
per_team_cohort_cache: cache,
}
}

/// Retrieves cohorts for a given team.
///
/// If the cohorts are not present in the cache or have expired, it fetches them from the database,
/// caches the result upon successful retrieval, and then returns it.
pub async fn get_cohorts_for_team(&self, team_id: TeamId) -> Result<Vec<Cohort>, FlagError> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above, I know it's for_team because you ask me for a team id (this is the last I'll leave of these, mostly just highlighting you can often be more concise without losing information the caller needs here, but again, taste)

if let Some(cached_cohorts) = self.per_team_cohort_cache.get(&team_id).await {
return Ok(cached_cohorts.clone());
}
let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only a note: I'm a fan of taking a manager-wide lock here (called, say, fetch_lock: Mutex<()> in the struct decl), to prevent multiple fetches to the same cohorts, but that's totally an optimisation you can skip until it becomes a problem (basically until postgres gives out).

self.per_team_cohort_cache
.insert(team_id, fetched_cohorts.clone())
.await;

Ok(fetched_cohorts)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::cohort_models::Cohort;
use crate::test_utils::{
insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client,
setup_pg_writer_client,
};
use std::sync::Arc;
use tokio::time::{sleep, Duration};

/// Helper function to setup a new team for testing.
async fn setup_test_team(
writer_client: Arc<dyn crate::database::Client + Send + Sync>,
) -> Result<TeamId, anyhow::Error> {
let team = crate::test_utils::insert_new_team_in_pg(writer_client, None).await?;
Ok(team.id)
}

/// Helper function to insert a cohort for a team.
async fn setup_test_cohort(
writer_client: Arc<dyn crate::database::Client + Send + Sync>,
team_id: TeamId,
name: Option<String>,
) -> Result<Cohort, anyhow::Error> {
let filters = serde_json::json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}});
insert_cohort_for_team_in_pg(writer_client, team_id, name, filters, false).await
}

/// Tests that cache entries expire after the specified TTL.
#[tokio::test]
async fn test_cache_expiry() -> Result<(), anyhow::Error> {
let writer_client = setup_pg_writer_client(None).await;
let reader_client = setup_pg_reader_client(None).await;

let team_id = setup_test_team(writer_client.clone()).await?;
let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?;

// Initialize CohortCacheManager with a short TTL for testing
let cohort_cache = CohortCacheManager::new(
reader_client.clone(),
Some(100),
Some(1), // 1-second TTL
);

let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?;
assert_eq!(cohorts.len(), 1);
assert_eq!(cohorts[0].team_id, team_id);

let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
assert!(cached_cohorts.is_some());

// Wait for TTL to expire
sleep(Duration::from_secs(2)).await;

// Attempt to retrieve from cache again
let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
assert!(cached_cohorts.is_none(), "Cache entry should have expired");

Ok(())
}

/// Tests that the cache correctly evicts least recently used entries based on the weigher.
#[tokio::test]
async fn test_cache_weigher() -> Result<(), anyhow::Error> {
let writer_client = setup_pg_writer_client(None).await;
let reader_client = setup_pg_reader_client(None).await;

// Define a smaller max_capacity for testing
let max_capacity: u64 = 3;

let cohort_cache = CohortCacheManager::new(reader_client.clone(), Some(max_capacity), None);

let mut inserted_team_ids = Vec::new();

// Insert multiple teams and their cohorts
for _ in 0..max_capacity {
let team = insert_new_team_in_pg(writer_client.clone(), None).await?;
let team_id = team.id;
inserted_team_ids.push(team_id);
setup_test_cohort(writer_client.clone(), team_id, None).await?;
cohort_cache.get_cohorts_for_team(team_id).await?;
}

cohort_cache.per_team_cohort_cache.run_pending_tasks().await;
let cache_size = cohort_cache.per_team_cohort_cache.entry_count();
assert_eq!(
cache_size, max_capacity,
"Cache size should be equal to max_capacity"
);

let new_team = insert_new_team_in_pg(writer_client.clone(), None).await?;
let new_team_id = new_team.id;
setup_test_cohort(writer_client.clone(), new_team_id, None).await?;
cohort_cache.get_cohorts_for_team(new_team_id).await?;

cohort_cache.per_team_cohort_cache.run_pending_tasks().await;
let cache_size_after = cohort_cache.per_team_cohort_cache.entry_count();
assert_eq!(
cache_size_after, max_capacity,
"Cache size should remain equal to max_capacity after eviction"
);

let evicted_team_id = &inserted_team_ids[0];
let cached_cohorts = cohort_cache
.per_team_cohort_cache
.get(evicted_team_id)
.await;
assert!(
cached_cohorts.is_none(),
"Least recently used cache entry should have been evicted"
);

let cached_new_team = cohort_cache.per_team_cohort_cache.get(&new_team_id).await;
assert!(
cached_new_team.is_some(),
"Newly added cache entry should be present"
);

Ok(())
}

#[tokio::test]
async fn test_get_cohorts_for_team() -> Result<(), anyhow::Error> {
let writer_client = setup_pg_writer_client(None).await;
let reader_client = setup_pg_reader_client(None).await;
let team_id = setup_test_team(writer_client.clone()).await?;
let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?;
let cohort_cache = CohortCacheManager::new(reader_client.clone(), None, None);

let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
assert!(cached_cohorts.is_none(), "Cache should initially be empty");

let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?;
assert_eq!(cohorts.len(), 1);
assert_eq!(cohorts[0].team_id, team_id);

let cached_cohorts = cohort_cache
.per_team_cohort_cache
.get(&team_id)
.await
.unwrap();
assert_eq!(cached_cohorts.len(), 1);
assert_eq!(cached_cohorts[0].team_id, team_id);

Ok(())
}
}
50 changes: 50 additions & 0 deletions rust/feature-flags/src/cohort_models.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use crate::flag_definitions::PropertyFilter;
use serde::{Deserialize, Serialize};
use sqlx::FromRow;

#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Cohort {
pub id: i32,
pub name: String,
pub description: Option<String>,
pub team_id: i32,
pub deleted: bool,
pub filters: serde_json::Value,
pub query: Option<serde_json::Value>,
pub version: Option<i32>,
pub pending_version: Option<i32>,
pub count: Option<i32>,
pub is_calculating: bool,
pub is_static: bool,
pub errors_calculating: i32,
pub groups: serde_json::Value,
pub created_by_id: Option<i32>,
}

pub type CohortId = i32;

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
pub enum CohortPropertyType {
AND,
OR,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CohortProperty {
pub properties: InnerCohortProperty,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InnerCohortProperty {
#[serde(rename = "type")]
pub prop_type: CohortPropertyType,
pub values: Vec<CohortValues>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CohortValues {
#[serde(rename = "type")]
pub prop_type: String,
pub values: Vec<PropertyFilter>,
}
Loading
Loading