diff --git a/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py b/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py
index 5534172770b9c..49c45d3101a0c 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py
@@ -254,7 +254,6 @@ def _aggregate_access_events(
                 AggregatedDataset(
                     bucket_start_time=floored_ts,
                     resource=resource,
-                    user_email_pattern=self.config.user_email_pattern,
                 ),
             )
 
@@ -269,6 +268,7 @@ def _aggregate_access_events(
                 username,
                 event.query,
                 metadata.columns,
+                user_email_pattern=self.config.user_email_pattern,
             )
         return datasets
 
diff --git a/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py b/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py
index 8d4e0e6b8e49a..83ed3683f9d50 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py
@@ -2,7 +2,7 @@
 import dataclasses
 import logging
 from datetime import datetime
-from typing import Callable, ClassVar, Counter, Generic, List, Optional, TypeVar
+from typing import Callable, Counter, Generic, List, Optional, TypeVar
 
 import pydantic
 from pydantic.fields import Field
diff --git a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py
index 330650dbae342..49d0b2667b1be 100644
--- a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py
+++ b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py
@@ -1,7 +1,23 @@
 import collections
 import sqlite3
 import tempfile
-from typing import Generic, Iterator, MutableMapping, Optional, OrderedDict, TypeVar
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterator,
+    List,
+    MutableMapping,
+    Optional,
+    OrderedDict,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+)
+
+# https://docs.python.org/3/library/sqlite3.html#sqlite-and-python-types
+SqliteValue = Union[int, float, str, bytes, None]
 
 _VT = TypeVar("_VT")
 
@@ -10,21 +26,31 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     """A dictionary that stores its data in a temporary SQLite database.
 
     This is useful for storing large amounts of data that don't fit in memory.
+
+    For performance, implements an in-memory LRU cache using an OrderedDict,
+    and sets a generous journal size limit.
     """
 
+    serializer: Callable[[_VT], SqliteValue]
+    deserializer: Callable[[Any], _VT]
+
     _cache_max_size: int
     _cache_eviction_batch_size: int
 
     _filename: str
-    _conn: sqlite3.Connection
+    _conn: sqlite3.Connection
     _active_object_cache: OrderedDict[str, _VT]
 
     def __init__(
         self,
+        serializer: Callable[[_VT], SqliteValue],
+        deserializer: Callable[[Any], _VT],
         filename: Optional[str] = None,
        cache_max_size: int = 2000,
        cache_eviction_batch_size: int = 200,
     ):
+        self._serializer = serializer
+        self._deserializer = deserializer
         self._cache_max_size = cache_max_size
         self._cache_eviction_batch_size = cache_eviction_batch_size
         self._filename = filename or tempfile.mktemp()
@@ -61,10 +87,10 @@ def _add_to_cache(self, key: str, value: _VT) -> None:
             self._prune_cache(num_items_to_prune)
 
     def _prune_cache(self, num_items_to_prune: int) -> None:
-        items_to_write = []
+        items_to_write: List[Tuple[str, SqliteValue]] = []
         for _ in range(num_items_to_prune):
             key, value = self._active_object_cache.popitem(last=False)
-            items_to_write.append((key, value))
+            items_to_write.append((key, self._serializer(value)))
         self._conn.executemany(
             "INSERT OR REPLACE INTO data (key, value) VALUES (?, ?)", items_to_write
         )
@@ -79,23 +105,28 @@ def __getitem__(self, key: str) -> _VT:
             return self._active_object_cache[key]
 
         cursor = self._conn.execute("SELECT value FROM data WHERE key = ?", (key,))
-        result = cursor.fetchone()
+        result: Sequence[SqliteValue] = cursor.fetchone()
         if result is None:
             raise KeyError(key)
 
-        self._add_to_cache(key, result[0])
-        return result[0]
+        deserialized_result = self._deserializer(result[0])
+        self._add_to_cache(key, deserialized_result)
+        return deserialized_result
 
     def __setitem__(self, key: str, value: _VT) -> None:
         self._add_to_cache(key, value)
 
     def __delitem__(self, key: str) -> None:
-        self[key]  # raise KeyError if key doesn't exist
-
+        in_cache = False
         if key in self._active_object_cache:
             del self._active_object_cache[key]
+            in_cache = True
 
-        self._conn.execute("DELETE FROM data WHERE key = ?", (key,))
+        n_deleted = self._conn.execute(
+            "DELETE FROM data WHERE key = ?", (key,)
+        ).rowcount
+        if not in_cache and not n_deleted:
+            raise KeyError(key)
 
     def __iter__(self) -> Iterator[str]:
         cursor = self._conn.execute("SELECT key FROM data")
@@ -124,3 +155,6 @@ def __repr__(self) -> str:
 
     def close(self) -> None:
         self._conn.close()
+
+    def __del__(self) -> None:
+        self.close()
diff --git a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py
index 0d9421cdff442..9195a22665d3f 100644
--- a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py
+++ b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py
@@ -1,10 +1,20 @@
+import json
+from collections import Counter
+from dataclasses import asdict, dataclass
+from typing import Dict
+
 import pytest
 
 from datahub.utilities.file_backed_collections import FileBackedDict
 
 
 def test_file_dict():
-    cache = FileBackedDict[int](cache_max_size=10, cache_eviction_batch_size=10)
+    cache = FileBackedDict[int](
+        serializer=lambda x: x,
+        deserializer=lambda x: x,
+        cache_max_size=10,
+        cache_eviction_batch_size=10,
+    )
 
     for i in range(100):
         cache[f"key-{i}"] = i
@@ -37,11 +47,13 @@ def test_file_dict():
     assert cache["key-3"] == 99
     cache["key-3"] = 3
 
-    # Test deleting a key.
+    # Test deleting keys, in and out of cache
     del cache["key-0"]
-    assert len(cache) == 99
+    del cache["key-99"]
+    assert len(cache) == 98
     with pytest.raises(KeyError):
         cache["key-0"]
+        cache["key-99"]
 
     # Test deleting a key that doesn't exist.
     with pytest.raises(KeyError):
@@ -50,12 +62,99 @@ def test_file_dict():
     # Test adding another key.
     cache["a"] = 1
     assert cache["a"] == 1
-    assert len(cache) == 100
-    assert sorted(cache) == sorted(["a"] + [f"key-{i}" for i in range(1, 100)])
+    assert len(cache) == 99
+    assert sorted(cache) == sorted(["a"] + [f"key-{i}" for i in range(1, 99)])
 
     # Test deleting most things.
-    for i in range(1, 100):
+    for i in range(1, 99):
         assert cache[f"key-{i}"] == i
         del cache[f"key-{i}"]
     assert len(cache) == 1
     assert cache["a"] == 1
+
+
+def test_file_dict_serialization():
+    @dataclass(frozen=True)
+    class Label:
+        a: str
+        b: int
+
+    @dataclass
+    class Main:
+        x: int
+        y: Dict[Label, float]
+
+        def to_dict(self) -> Dict:
+            d: Dict = {"x": self.x}
+            str_y = {json.dumps(asdict(k)): v for k, v in self.y.items()}
+            d["y"] = json.dumps(str_y)
+            return d
+
+        @classmethod
+        def from_dict(cls, d: Dict) -> "Main":
+            str_y = json.loads(d["y"])
+            y = {}
+            for k, v in str_y.items():
+                k_str = json.loads(k)
+                label = Label(k_str["a"], k_str["b"])
+                y[label] = v
+
+            return cls(d["x"], y)
+
+    serializer_calls = 0
+    deserializer_calls = 0
+
+    def serialize(m: Main) -> str:
+        nonlocal serializer_calls
+        serializer_calls += 1
+        print(serializer_calls, m)
+        return json.dumps(m.to_dict())
+
+    def deserialize(s: str) -> Main:
+        nonlocal deserializer_calls
+        deserializer_calls += 1
+        return Main.from_dict(json.loads(s))
+
+    cache = FileBackedDict[Main](
+        serializer=serialize,
+        deserializer=deserialize,
+        cache_max_size=0,
+        cache_eviction_batch_size=0,
+    )
+    first = Main(3, {Label("one", 1): 0.1, Label("two", 2): 0.2})
+    second = Main(-100, {Label("z", 26): 0.26})
+
+    cache["first"] = first
+    cache["second"] = second
+    assert serializer_calls == 2
+    assert deserializer_calls == 0
+
+    assert cache["second"] == second
+    assert cache["first"] == first
+    assert serializer_calls == 4  # Items written to cache on every access
+    assert deserializer_calls == 2
+
+
+def test_file_dict_stores_counter():
+    cache = FileBackedDict[Counter[str]](
+        serializer=json.dumps,
+        deserializer=lambda s: Counter(json.loads(s)),
+        cache_max_size=1,
+        cache_eviction_batch_size=0,
+    )
+
+    n = 5
+    in_memory_counters: Dict[int, Counter[str]] = {}
+    for i in range(n):
+        cache[str(i)] = Counter()
+        in_memory_counters[i] = Counter()
+        for j in range(n):
+            if i == j:
+                cache[str(i)][str(j)] += 100
+                in_memory_counters[i][str(j)] += 100
+            cache[str(i)][str(j)] += j
+            in_memory_counters[i][str(j)] += j
+
+    for i in range(n):
+        assert in_memory_counters[i] == cache[str(i)]
+        assert in_memory_counters[i].most_common(2) == cache[str(i)].most_common(2)
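Usage sketch (reviewer note, not part of the patch): a minimal example of how the new
serializer/deserializer hooks could be wired from calling code, based only on the
constructor and methods shown in this diff. The JSON round-trip, the dict value type,
the variable names, and the dataset-URN key are illustrative assumptions, not taken
from the DataHub sources.

    import json

    from datahub.utilities.file_backed_collections import FileBackedDict

    # Values are converted to a SQLite-storable type (str here, via json.dumps)
    # when they are evicted from the in-memory LRU cache, and deserialized again
    # when __getitem__ reads them back out of the database.
    usage_counts = FileBackedDict[dict](
        serializer=json.dumps,
        deserializer=json.loads,
        cache_max_size=10,
        cache_eviction_batch_size=5,
    )

    usage_counts["urn:li:dataset:example"] = {"reads": 3, "writes": 1}
    assert usage_counts["urn:li:dataset:example"] == {"reads": 3, "writes": 1}

    try:
        del usage_counts["missing-key"]  # __delitem__ raises KeyError for absent keys
    except KeyError:
        pass

    usage_counts.close()  # also invoked by the newly added __del__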