Merge pull request #60 from outbrain/cms
Count-min-sketch with an extension
SkBlaz authored Jan 26, 2024
2 parents f29ff14 + 2164df7 commit 04d7b2d
Showing 5 changed files with 166 additions and 2 deletions.
72 changes: 72 additions & 0 deletions outrank/algorithms/sketches/counting_cms.py
@@ -0,0 +1,72 @@
from __future__ import annotations

import sys
from collections import Counter

import numpy as np
from numba import njit
from numba import prange


@njit
def cms_hash(x, seed, width):
x_hash = np.uint32(hash(x))
return (x_hash + seed) % width

class CountMinSketch:
"""
A memory-efficient implementation of the count min sketch algorithm with optimized hashing using Numba JIT.
"""

def __init__(self, depth=6, width=2**22, M=None):
self.depth = depth
self.width = width
self.hash_seeds = np.array(np.random.randint(low=0, high=2**31 - 1, size=depth), dtype=np.uint32)
self.M = np.zeros((depth, width), dtype=np.int32) if M is None else M
self.tmp_vals = set()

@staticmethod
@njit
def _add(M, x, depth, width, hash_seeds, delta=1):
for i in prange(depth):
location = cms_hash(x, hash_seeds[i], width)
M[i, location] += delta

def add(self, x, delta=1):
if len(self.tmp_vals) < 10 ** 5 or sys.getsizeof(self.tmp_vals) / (10 ** 3) < 100.0:
self.tmp_vals.add(x)
CountMinSketch._add(self.M, x, self.depth, self.width, self.hash_seeds, delta)

def batch_add(self, lst, delta=1):
for x in lst:
self.add(x, delta)

def query(self, x):
return min(self.M[i][cms_hash(x, self.hash_seeds[i], self.width)] for i in range(self.depth))

def get_matrix(self):
return self.M

def stream_hist_update(self):
""" A bit hacky way to aggregate cms results """
return Counter(self.query(x) for x in self.tmp_vals)


if __name__ == '__main__':
from collections import Counter

depth = 8
width = 2**22
cms = CountMinSketch(depth, width)

items = [1, 1, 2, 3, 3, 3, 4, 5, 2] * 1000
cms.batch_add(items) # Use the batch_add function

print(cms.query(3)) # Query for frequency estimates
print(cms.query(1))
print(cms.query(2))
print(cms.query(4))
print(cms.query(5))

print(Counter(items)) # Print the exact counts for comparison
print(cms.stream_hist_update())
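
For orientation, the default parameters above (depth=6, width=2**22) can be read against the textbook count-min bound: with width w, depth d and pairwise-independent hashing, a query overestimates the true count by at most (e/w)*N with probability at least 1 - e^(-d), where N is the total number of additions. The seeded modular hash used here is simpler than pairwise-independent hashing, so the figures below are only a rough back-of-the-envelope guide:

import math

# Back-of-the-envelope only: the textbook bound assumes pairwise-independent
# hashes, while cms_hash above uses a simple seeded offset, so treat these
# figures as orientation rather than a guarantee.
depth, width = 6, 2**22            # defaults of CountMinSketch above
eps = math.e / width               # overestimate factor, relative to total count N
delta = math.exp(-depth)           # probability that the eps * N bound is exceeded

print(f'overestimate <= {eps:.1e} * N with probability >= {1 - delta:.4f}')
# overestimate <= 6.5e-07 * N with probability >= 0.9975
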
10 changes: 9 additions & 1 deletion outrank/core_ranking.py
@@ -22,6 +22,7 @@
import tqdm

from outrank.algorithms.importance_estimator import get_importances_estimate_pairwise
from outrank.algorithms.sketches.counting_cms import CountMinSketch
from outrank.algorithms.sketches.counting_ultiloglog import (
HyperLogLogWCache as HyperLogLog,
)
@@ -38,6 +39,7 @@
logger.setLevel(logging.DEBUG)
random.seed(a=123, version=2)
GLOBAL_CARDINALITY_STORAGE: dict[Any, Any] = dict()
GLOBAL_COUNTS_STORAGE: dict[Any, Any] = dict()
GLOBAL_RARE_VALUE_STORAGE: dict[str, Any] = Counter()
GLOBAL_PRIOR_COMB_COUNTS: dict[Any, int] = Counter()
IGNORED_VALUES = set()
@@ -431,6 +433,11 @@ def compute_cardinalities(input_dataframe: pd.DataFrame, pbar: Any) -> None:
HYPERLL_ERROR_BOUND,
)

if column not in GLOBAL_COUNTS_STORAGE:
GLOBAL_COUNTS_STORAGE[column] = CountMinSketch()

for value in input_dataframe[column].values:
    GLOBAL_COUNTS_STORAGE[column].add(value)

for unique_value in set(input_dataframe[column].unique()):
if unique_value:
GLOBAL_CARDINALITY_STORAGE[column].add(
@@ -615,7 +622,7 @@ def estimate_importances_minibatches(
delimiter: str = '\t',
feature_construction_mode: bool = False,
logger: Any = None,
) -> tuple[list[dict[str, Any]], Any, dict[Any, Any], list[dict[str, Any]], list[dict[str, set[str]]], defaultdict[str, list[set[str]]], dict[str, Any], dict[str, Any]]:
) -> tuple[list[dict[str, Any]], Any, dict[Any, Any], list[dict[str, Any]], list[dict[str, set[str]]], defaultdict[str, list[set[str]]], dict[str, Any], dict[str, Any], dict[str, Any]]:
"""Interaction score estimator - suitable for example for csv-like input data types.
This type of data is normally a single large csv, meaning that minibatch processing needs to
happen during incremental handling of the file (that"s not the case for pre-separated ob data)
@@ -744,4 +751,5 @@ def estimate_importances_minibatches(
local_coverage_object,
GLOBAL_RARE_VALUE_STORAGE.copy(),
GLOBAL_PRIOR_COMB_COUNTS.copy(),
GLOBAL_COUNTS_STORAGE.copy(),
)
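
The change above gives every column its own CountMinSketch in GLOBAL_COUNTS_STORAGE, fills it as minibatches stream through compute_cardinalities, and returns a copy as the new last element of the result tuple. A minimal sketch of that accumulate-and-query pattern in isolation (the column name and values are made up for illustration, and a smaller width keeps the example light):

from outrank.algorithms.sketches.counting_cms import CountMinSketch

# Hypothetical column and values; in core_ranking.py the values come from
# input_dataframe[column].values inside compute_cardinalities.
counts_storage: dict[str, CountMinSketch] = {}
column, values = 'feature_a', ['x', 'y', 'x', 'z', 'x']

if column not in counts_storage:
    counts_storage[column] = CountMinSketch(depth=6, width=2**16)  # smaller width than the default
for value in values:
    counts_storage[column].add(value)

print(counts_storage[column].query('x'))  # >= 3; a count-min sketch never undercounts
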
9 changes: 9 additions & 0 deletions outrank/task_ranking.py
@@ -100,6 +100,7 @@ def outrank_task_conduct_ranking(args: Any) -> None:
coverage_object,
RARE_VALUE_STORAGE,
GLOBAL_PRIOR_COMB_COUNTS,
GLOBAL_ITEM_COUNTS,
) = estimate_importances_minibatches(**cmd_arguments)

global_bounds_storage += bounds_object_storage
@@ -276,6 +277,14 @@ def outrank_task_conduct_ranking(args: Any) -> None:
os.path.join(args.output_folder, 'pairwise_ranks.tsv'), sep='\t', index=False,
)

with open(f'{args.output_folder}/value_repetitions.json', 'w') as out_counts:
out_dict = {}
for k, v in GLOBAL_ITEM_COUNTS.items():
actual_hist = np.array([k + v for k, v in v.stream_hist_update().items()])
more_than = lambda n, ary: len(np.where(ary > n)[0])
out_dict[k] = {x: more_than(x, actual_hist) for x in [0] + [1 * 10 ** x for x in range(6)]}
out_counts.write(json.dumps(out_dict))

with open(f'{args.output_folder}/combination_estimation_counts.json', 'w') as out_counts:
out_dict = {str(k): v for k, v in GLOBAL_PRIOR_COMB_COUNTS.items()}
out_counts.write(json.dumps(out_dict))
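
The value_repetitions.json block above summarises, per column, how many entries of the aggregated repetition histogram exceed each cutoff in 0, 1, 10, ..., 100000. A small, self-contained illustration of that thresholding step, with a made-up histogram standing in for actual_hist:

import numpy as np

# Hypothetical stand-in for a per-column actual_hist array.
actual_hist = np.array([3, 12, 150, 1_200, 45_000])

def more_than(n, ary):
    # number of histogram entries strictly greater than the cutoff n
    return len(np.where(ary > n)[0])

thresholds = [0] + [10 ** x for x in range(6)]   # 0, 1, 10, ..., 100000
summary = {n: more_than(n, actual_hist) for n in thresholds}
print(summary)  # {0: 5, 1: 5, 10: 4, 100: 3, 1000: 2, 10000: 1, 100000: 0}
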
2 changes: 1 addition & 1 deletion setup.py
@@ -23,7 +23,7 @@ def _read_description():
packages = [x for x in setuptools.find_packages() if x != 'test']
setuptools.setup(
name='outrank',
version='0.95.6',
version='0.95.7',
description='OutRank: Feature ranking for massive sparse data sets.',
long_description=_read_description(),
long_description_content_type='text/markdown',
75 changes: 75 additions & 0 deletions tests/cms_test.py
@@ -0,0 +1,75 @@
from __future__ import annotations

import sys
import unittest

import numpy as np

from outrank.algorithms.sketches.counting_cms import cms_hash
from outrank.algorithms.sketches.counting_cms import CountMinSketch


class TestCountMinSketch(unittest.TestCase):

def setUp(self):
# Set up a CountMinSketch instance with known parameters for testing
self.depth = 6
self.width = 2**10 # smaller width for testing purposes
self.cms = CountMinSketch(self.depth, self.width)

def test_init(self):
self.assertEqual(self.cms.depth, self.depth)
self.assertEqual(self.cms.width, self.width)
self.assertEqual(self.cms.M.shape, (self.depth, self.width))
self.assertEqual(len(self.cms.hash_seeds), self.depth)
self.assertIsInstance(self.cms.tmp_vals, set)

def test_add_and_query_single_element(self):
# Test adding a single element and querying it
element = 'test_element'
self.cms.add(element)
# The queried count should be at least 1 (could be higher due to hash collisions)
self.assertGreaterEqual(self.cms.query(element), 1)

def test_add_and_query_multiple_elements(self):
elements = ['foo', 'bar', 'baz', 'qux', 'quux']
for elem in elements:
self.cms.add(elem)

for elem in elements:
self.assertGreaterEqual(self.cms.query(elem), 1)

def test_batch_add_and_query(self):
elements = ['foo', 'bar', 'baz'] * 10
self.cms.batch_add(elements)

for elem in set(elements):
self.assertGreaterEqual(self.cms.query(elem), 10)

def test_stream_hist_update(self):
self.cms.add('foo')
self.cms.add('foo')
self.cms.add('bar')

hist = self.cms.stream_hist_update()

# Note: we cannot test for exact counts because the CountMinSketch is a probabilistic data structure
# and may overcount. However, we never expect it to undercount an element.
self.assertGreaterEqual(hist[self.cms.query('foo')], 1)
self.assertGreaterEqual(hist[self.cms.query('bar')], 1)

def test_overflow_protection(self):
# This test ensures that the set doesn't grow beyond its allowed size and memory usage
for i in range(100001):
self.cms.add(f'element{i}')

self.assertLessEqual(len(self.cms.tmp_vals), 100000)
self.assertLessEqual(sys.getsizeof(self.cms.tmp_vals) / (10 ** 3), 4200.0)

def test_hash_uniformity(self):
# Basic check for hash function's distribution
seeds = np.array(np.random.randint(low=0, high=2**31 - 1, size=self.depth), dtype=np.uint32)
hashes = [cms_hash(i, seeds[0], self.width) for i in range(1000)]
# Expect fewer collisions over a small sample with a large width
unique_hashes = len(set(hashes))
self.assertGreater(unique_hashes, 900)
