diff --git a/extension/src/countminsketch.rs b/extension/src/countminsketch.rs index 48f68c05..e8027df3 100644 --- a/extension/src/countminsketch.rs +++ b/extension/src/countminsketch.rs @@ -127,9 +127,8 @@ impl toolkit_experimental::count_min_sketch { } #[pg_extern(immutable, parallel_safe, schema = "toolkit_experimental")] -pub fn approx_count(item: String, aggregate: Option) -> i64 { - let sketch = aggregate.unwrap(); - CountMinSketch::to_internal_countminsketch(&sketch).estimate(item) +pub fn approx_count(item: String, aggregate: Option) -> Option { + aggregate.map(|sketch| CountMinSketch::to_internal_countminsketch(&sketch).estimate(item)) } #[cfg(any(test, feature = "pg_test"))] @@ -266,4 +265,19 @@ mod tests { assert_eq!(output, None) }) } + + #[pg_test] + fn test_approx_count_null_input_yields_null_output() { + Spi::execute(|client| { + let output = client + .select( + "SELECT toolkit_experimental.approx_count('1'::text, NULL::toolkit_experimental.countminsketch)", + None, + None, + ) + .first() + .get_one::(); + assert_eq!(output, None) + }) + } }