Merge pull request #2 from asikowitz/file-dict-serialization
Implement serialization and deserialization for file dict
asikowitz authored Mar 1, 2023
2 parents 03a0a50 + a30ff4e commit a0c6d85
Showing 4 changed files with 151 additions and 18 deletions.
@@ -254,7 +254,6 @@ def _aggregate_access_events(
AggregatedDataset(
bucket_start_time=floored_ts,
resource=resource,
- user_email_pattern=self.config.user_email_pattern,
),
)

@@ -269,6 +268,7 @@ def _aggregate_access_events(
username,
event.query,
metadata.columns,
+ user_email_pattern=self.config.user_email_pattern,
)
return datasets
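Taken together, these two hunks move user_email_pattern from the AggregatedDataset constructor to the call that records an individual read entry (the enclosing method isn't named in the hunk). A minimal sketch of how such an email gate can work, assuming DataHub's AllowDenyPattern config type; the record_read helper here is hypothetical:

```python
from datahub.configuration.common import AllowDenyPattern

# Illustrative pattern: ignore reads attributed to service accounts.
pattern = AllowDenyPattern(deny=[r"service-account@.*"])

def record_read(username: str, query: str) -> None:
    # Hypothetical helper: only record the entry if the email passes the pattern.
    if not pattern.allowed(username):
        return
    # ... aggregate the (username, query) read entry here ...
```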

@@ -2,7 +2,7 @@
import dataclasses
import logging
from datetime import datetime
- from typing import Callable, ClassVar, Counter, Generic, List, Optional, TypeVar
+ from typing import Callable, Counter, Generic, List, Optional, TypeVar

import pydantic
from pydantic.fields import Field
54 changes: 44 additions & 10 deletions metadata-ingestion/src/datahub/utilities/file_backed_collections.py
@@ -1,7 +1,23 @@
import collections
import sqlite3
import tempfile
- from typing import Generic, Iterator, MutableMapping, Optional, OrderedDict, TypeVar
+ from typing import (
+     Any,
+     Callable,
+     Generic,
+     Iterator,
+     List,
+     MutableMapping,
+     Optional,
+     OrderedDict,
+     Sequence,
+     Tuple,
+     TypeVar,
+     Union,
+ )

+ # https://docs.python.org/3/library/sqlite3.html#sqlite-and-python-types
+ SqliteValue = Union[int, float, str, bytes, None]

_VT = TypeVar("_VT")

@@ -10,21 +26,31 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
"""A dictionary that stores its data in a temporary SQLite database.
This is useful for storing large amounts of data that don't fit in memory.
For performance, implements an in-memory LRU cache using an OrderedDict,
and sets a generous journal size limit.
"""

+ serializer: Callable[[_VT], SqliteValue]
+ deserializer: Callable[[Any], _VT]

_cache_max_size: int
_cache_eviction_batch_size: int
_filename: str
- _conn: sqlite3.Connection

+ _conn: sqlite3.Connection
_active_object_cache: OrderedDict[str, _VT]

def __init__(
self,
+ serializer: Callable[[_VT], SqliteValue],
+ deserializer: Callable[[Any], _VT],
filename: Optional[str] = None,
cache_max_size: int = 2000,
cache_eviction_batch_size: int = 200,
):
+ self._serializer = serializer
+ self._deserializer = deserializer
self._cache_max_size = cache_max_size
self._cache_eviction_batch_size = cache_eviction_batch_size
self._filename = filename or tempfile.mktemp()
@@ -61,10 +87,10 @@ def _add_to_cache(self, key: str, value: _VT) -> None:
self._prune_cache(num_items_to_prune)

def _prune_cache(self, num_items_to_prune: int) -> None:
- items_to_write = []
+ items_to_write: List[Tuple[str, SqliteValue]] = []
for _ in range(num_items_to_prune):
key, value = self._active_object_cache.popitem(last=False)
- items_to_write.append((key, value))
+ items_to_write.append((key, self._serializer(value)))

self._conn.executemany(
"INSERT OR REPLACE INTO data (key, value) VALUES (?, ?)", items_to_write
@@ -79,23 +105,28 @@ def __getitem__(self, key: str) -> _VT:
return self._active_object_cache[key]

cursor = self._conn.execute("SELECT value FROM data WHERE key = ?", (key,))
- result = cursor.fetchone()
+ result: Sequence[SqliteValue] = cursor.fetchone()
if result is None:
raise KeyError(key)

- self._add_to_cache(key, result[0])
- return result[0]
+ deserialized_result = self._deserializer(result[0])
+ self._add_to_cache(key, deserialized_result)
+ return deserialized_result

def __setitem__(self, key: str, value: _VT) -> None:
self._add_to_cache(key, value)

def __delitem__(self, key: str) -> None:
- self[key]  # raise KeyError if key doesn't exist

+ in_cache = False
if key in self._active_object_cache:
del self._active_object_cache[key]
+ in_cache = True

- self._conn.execute("DELETE FROM data WHERE key = ?", (key,))
+ n_deleted = self._conn.execute(
+     "DELETE FROM data WHERE key = ?", (key,)
+ ).rowcount
+ if not in_cache and not n_deleted:
+     raise KeyError(key)

def __iter__(self) -> Iterator[str]:
cursor = self._conn.execute("SELECT key FROM data")
@@ -124,3 +155,6 @@ def __repr__(self) -> str:

def close(self) -> None:
self._conn.close()

+ def __del__(self) -> None:
+     self.close()
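With the serializer/deserializer pair threaded through __init__, FileBackedDict can now round-trip arbitrary value types through SQLite. A minimal usage sketch based on the constructor signature above; the JSON codec and example keys are illustrative, not part of the commit:

```python
import json

from datahub.utilities.file_backed_collections import FileBackedDict

# Values must serialize to a SqliteValue (int, float, str, bytes, or None);
# json.dumps/json.loads is one simple codec for dict-shaped values.
cache = FileBackedDict[dict](
    serializer=json.dumps,
    deserializer=json.loads,
    cache_max_size=2000,  # entries held in the in-memory LRU cache
    cache_eviction_batch_size=200,  # entries flushed to SQLite per eviction
)

cache["config"] = {"retries": 3}
assert cache["config"] == {"retries": 3}
cache.close()
```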
@@ -1,10 +1,20 @@
+ import json
+ from collections import Counter
+ from dataclasses import asdict, dataclass
+ from typing import Dict

import pytest

from datahub.utilities.file_backed_collections import FileBackedDict


def test_file_dict():
- cache = FileBackedDict[int](cache_max_size=10, cache_eviction_batch_size=10)
+ cache = FileBackedDict[int](
+     serializer=lambda x: x,
+     deserializer=lambda x: x,
+     cache_max_size=10,
+     cache_eviction_batch_size=10,
+ )

for i in range(100):
cache[f"key-{i}"] = i
@@ -37,11 +47,13 @@ def test_file_dict():
assert cache["key-3"] == 99
cache["key-3"] = 3

- # Test deleting a key.
+ # Test deleting keys, in and out of cache
del cache["key-0"]
assert len(cache) == 99
del cache["key-99"]
assert len(cache) == 98
with pytest.raises(KeyError):
cache["key-0"]
cache["key-99"]

# Test deleting a key that doesn't exist.
with pytest.raises(KeyError):
@@ -50,12 +62,99 @@
# Test adding another key.
cache["a"] = 1
assert cache["a"] == 1
- assert len(cache) == 100
- assert sorted(cache) == sorted(["a"] + [f"key-{i}" for i in range(1, 100)])
+ assert len(cache) == 99
+ assert sorted(cache) == sorted(["a"] + [f"key-{i}" for i in range(1, 99)])

# Test deleting most things.
- for i in range(1, 100):
+ for i in range(1, 99):
assert cache[f"key-{i}"] == i
del cache[f"key-{i}"]
assert len(cache) == 1
assert cache["a"] == 1


def test_file_dict_serialization():
@dataclass(frozen=True)
class Label:
a: str
b: int

@dataclass
class Main:
x: int
y: Dict[Label, float]

def to_dict(self) -> Dict:
d: Dict = {"x": self.x}
str_y = {json.dumps(asdict(k)): v for k, v in self.y.items()}
d["y"] = json.dumps(str_y)
return d

@classmethod
def from_dict(cls, d: Dict) -> "Main":
str_y = json.loads(d["y"])
y = {}
for k, v in str_y.items():
k_str = json.loads(k)
label = Label(k_str["a"], k_str["b"])
y[label] = v

return cls(d["x"], y)

serializer_calls = 0
deserializer_calls = 0

def serialize(m: Main) -> str:
nonlocal serializer_calls
serializer_calls += 1
print(serializer_calls, m)
return json.dumps(m.to_dict())

def deserialize(s: str) -> Main:
nonlocal deserializer_calls
deserializer_calls += 1
return Main.from_dict(json.loads(s))

cache = FileBackedDict[Main](
serializer=serialize,
deserializer=deserialize,
cache_max_size=0,
cache_eviction_batch_size=0,
)
first = Main(3, {Label("one", 1): 0.1, Label("two", 2): 0.2})
second = Main(-100, {Label("z", 26): 0.26})

cache["first"] = first
cache["second"] = second
assert serializer_calls == 2
assert deserializer_calls == 0

assert cache["second"] == second
assert cache["first"] == first
assert serializer_calls == 4 # Items written to cache on every access
assert deserializer_calls == 2


def test_file_dict_stores_counter():
cache = FileBackedDict[Counter[str]](
serializer=json.dumps,
deserializer=lambda s: Counter(json.loads(s)),
cache_max_size=1,
cache_eviction_batch_size=0,
)

n = 5
in_memory_counters: Dict[int, Counter[str]] = {}
for i in range(n):
cache[str(i)] = Counter()
in_memory_counters[i] = Counter()
for j in range(n):
if i == j:
cache[str(i)][str(j)] += 100
in_memory_counters[i][str(j)] += 100
cache[str(i)][str(j)] += j
in_memory_counters[i][str(j)] += j

for i in range(n):
assert in_memory_counters[i] == cache[str(i)]
assert in_memory_counters[i].most_common(2) == cache[str(i)].most_common(2)
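The Counter test passes because __getitem__ returns the live deserialized object and re-inserts it into the LRU cache, so in-place mutations such as cache[str(i)][str(j)] += j are captured when the entry is later evicted and re-serialized by _prune_cache. A cautionary sketch of that behavior; the names are illustrative, and the write-back-on-eviction reading is an inference from the diff above:

```python
import json
from collections import Counter

from datahub.utilities.file_backed_collections import FileBackedDict

counts = FileBackedDict[Counter](
    serializer=json.dumps,
    deserializer=lambda s: Counter(json.loads(s)),
)

counts["doc-1"] = Counter()
counts["doc-1"]["hello"] += 1  # mutates the value while it sits in the cache

# A reference that outlives the cache entry is not tracked: if "doc-1" has been
# evicted, mutations through `stale` are never serialized unless the entry is
# re-assigned explicitly.
stale = counts["doc-1"]
stale["world"] += 1
counts["doc-1"] = stale  # explicit re-assignment persists the change
```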
