Skip to content

Commit

Permalink
Merge pull request #62 from outbrain/cms
Browse files Browse the repository at this point in the history
More primitive, much faster counter
  • Loading branch information
SkBlaz authored Jan 31, 2024
2 parents ef54a09 + 6f31fc0 commit 2dd2d38
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 36 deletions.
7 changes: 7 additions & 0 deletions outrank/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def main():
help='Name of the target attribute for ranking. Note that this can be any other feature for most implemented heuristics.',
)

parser.add_argument(
'--max_unique_hist_constraint',
type=int,
default=30_000,
help='Max number of unique values for which counts are recalled.',
)

parser.add_argument(
'--transformers',
type=str,
Expand Down
8 changes: 0 additions & 8 deletions outrank/algorithms/sketches/counting_cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(self, depth=6, width=2**15, M=None):
self.width = width
self.hash_seeds = np.array(np.random.randint(low=0, high=2**31 - 1, size=depth), dtype=np.uint32)
self.M = np.zeros((depth, width), dtype=np.int32) if M is None else M
self.tmp_vals = set()

@staticmethod
@njit
Expand All @@ -33,8 +32,6 @@ def _add(M, x, depth, width, hash_seeds, delta=1):
M[i, location] += delta

def add(self, x, delta=1):
if len(self.tmp_vals) < 10 ** 4 or sys.getsizeof(self.tmp_vals) / (10 ** 3) < 100.0:
self.tmp_vals.add(x)
CountMinSketch._add(self.M, x, self.depth, self.width, self.hash_seeds, delta)

def batch_add(self, lst, delta=1):
Expand All @@ -47,10 +44,6 @@ def query(self, x):
def get_matrix(self):
return self.M

def stream_hist_update(self):
""" A bit hacky way to aggregate cms results """
return Counter(self.query(x) for x in self.tmp_vals)


if __name__ == '__main__':
from collections import Counter
Expand All @@ -69,4 +62,3 @@ def stream_hist_update(self):
print(cms.query(5))

print(Counter(items)) # Print the exact counts for comparison
print(cms.stream_hist_update())
35 changes: 35 additions & 0 deletions outrank/algorithms/sketches/counting_counters_ordinary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from collections import Counter


class PrimitiveConstrainedCounter:
    """
    A bounded exact value counter.

    Tracks exact per-value counts with a plain ``collections.Counter``, but
    stops admitting new work once the number of *distinct* values reaches
    ``bound`` — a memory safety valve for high-cardinality columns.

    Note: the bound is checked before each insertion, so a single large
    ``batch_add`` call may push the distinct count past ``bound``; after
    that, all further additions (including to already-seen values) are
    no-ops.
    """

    def __init__(self, bound: int = (10**4) * 3):
        # Maximum number of distinct values to track before counting stops.
        self.max_bound_thr = bound
        self.default_counter: Counter = Counter()

    def batch_add(self, lst):
        """Count an iterable of values in one pass (no-op once the bound is hit)."""
        if len(self.default_counter) < self.max_bound_thr:
            # In-place update is a single O(len(lst)) pass; the previous
            # `self.default_counter + Counter(lst)` rebuilt the whole counter
            # on every batch. Results are identical for positive increments.
            self.default_counter.update(lst)

    def add(self, val):
        """Count a single value (no-op once the bound is hit)."""
        if len(self.default_counter) < self.max_bound_thr:
            self.default_counter[val] += 1


if __name__ == '__main__':
    # Smoke-test demo: exact counts from the bounded counter should match
    # collections.Counter for a small, repeated alphabet.
    from collections import Counter

    counter = PrimitiveConstrainedCounter()

    items = [1, 1, 2, 3, 3, 3, 4, 5, 2] * 10000
    counter.batch_add(items)  # Use the batch_add function

    print(Counter(items))  # Print the exact counts for comparison
    print(counter.default_counter)  # Counts recalled by the bounded counter
8 changes: 4 additions & 4 deletions outrank/core_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import tqdm

from outrank.algorithms.importance_estimator import get_importances_estimate_pairwise
from outrank.algorithms.sketches.counting_cms import CountMinSketch
from outrank.algorithms.sketches.counting_counters_ordinary import PrimitiveConstrainedCounter
from outrank.algorithms.sketches.counting_ultiloglog import (
HyperLogLogWCache as HyperLogLog,
)
Expand Down Expand Up @@ -421,7 +421,7 @@ def compute_value_counts(input_dataframe: pd.DataFrame, args: Any):
del GLOBAL_RARE_VALUE_STORAGE[to_remove_val]


def compute_cardinalities(input_dataframe: pd.DataFrame, pbar: Any) -> None:
def compute_cardinalities(input_dataframe: pd.DataFrame, pbar: Any, max_unique_hist_constraint: int) -> None:
"""Compute cardinalities of features, incrementally"""

global GLOBAL_CARDINALITY_STORAGE
Expand All @@ -434,7 +434,7 @@ def compute_cardinalities(input_dataframe: pd.DataFrame, pbar: Any) -> None:
)

if column not in GLOBAL_COUNTS_STORAGE:
GLOBAL_COUNTS_STORAGE[column] = CountMinSketch()
GLOBAL_COUNTS_STORAGE[column] = PrimitiveConstrainedCounter(max_unique_hist_constraint)

[GLOBAL_COUNTS_STORAGE[column].add(value) for value in input_dataframe[column].values]

Expand Down Expand Up @@ -553,7 +553,7 @@ def compute_batch_ranking(
feature_memory_consumption = compute_feature_memory_consumption(
input_dataframe, args,
)
compute_cardinalities(input_dataframe, pbar)
compute_cardinalities(input_dataframe, pbar, args.max_unique_hist_constraint)

if args.task == 'identify_rare_values':
compute_value_counts(input_dataframe, args)
Expand Down
2 changes: 1 addition & 1 deletion outrank/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def parse_namespace(namespace_path: str) -> tuple[set[str], dict[str, str]]:
if type_name == 'f32':
float_set.add(feature)
except Exception as es:
logging.error(f'\U0001F631 {es} -- {namespace_parts}')
pass

return float_set, id_feature_map

Expand Down
2 changes: 1 addition & 1 deletion outrank/task_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def outrank_task_conduct_ranking(args: Any) -> None:
with open(f'{args.output_folder}/value_repetitions.json', 'w') as out_counts:
out_dict = {}
for k, v in GLOBAL_ITEM_COUNTS.items():
actual_hist = np.array([k + v for k, v in v.stream_hist_update().items()])
actual_hist = np.array(list(v.default_counter.values()))
more_than = lambda n, ary: len(np.where(ary > n)[0])
out_dict[k] = {x: more_than(x, actual_hist) for x in [0] + [1 * 10 ** x for x in range(6)]}
out_counts.write(json.dumps(out_dict))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _read_description():
packages = [x for x in setuptools.find_packages() if x != 'test']
setuptools.setup(
name='outrank',
version='0.95.8',
version='0.96.0',
description='OutRank: Feature ranking for massive sparse data sets.',
long_description=_read_description(),
long_description_content_type='text/markdown',
Expand Down
21 changes: 0 additions & 21 deletions tests/cms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def test_init(self):
self.assertEqual(self.cms.width, self.width)
self.assertEqual(self.cms.M.shape, (self.depth, self.width))
self.assertEqual(len(self.cms.hash_seeds), self.depth)
self.assertIsInstance(self.cms.tmp_vals, set)

def test_add_and_query_single_element(self):
# Test adding a single element and querying it
Expand All @@ -46,26 +45,6 @@ def test_batch_add_and_query(self):
for elem in set(elements):
self.assertGreaterEqual(self.cms.query(elem), 10)

def test_stream_hist_update(self):
self.cms.add('foo')
self.cms.add('foo')
self.cms.add('bar')

hist = self.cms.stream_hist_update()

# Note: we cannot test for exact counts because the CountMinSketch is a probabilistic data structure
# and may overcount. However, we never expect it to undercount an element.
self.assertGreaterEqual(hist[self.cms.query('foo')], 1)
self.assertGreaterEqual(hist[self.cms.query('bar')], 1)

def test_overflow_protection(self):
# This test ensures that the set doesn't grow beyond its allowed size and memory usage
for i in range(100001):
self.cms.add(f'element{i}')

self.assertLessEqual(len(self.cms.tmp_vals), 100000)
self.assertLessEqual(sys.getsizeof(self.cms.tmp_vals) / (10 ** 3), 4200.0)

def test_hash_uniformity(self):
# Basic check for hash function's distribution
seeds = np.array(np.random.randint(low=0, high=2**31 - 1, size=self.depth), dtype=np.uint32)
Expand Down

0 comments on commit 2dd2d38

Please sign in to comment.