Skip to content

Commit

Permalink
Support multiple SqliteDict in a single dbfile
Browse files Browse the repository at this point in the history
  • Loading branch information
binh-vu committed May 20, 2024
1 parent f14e95b commit 963052e
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 51 deletions.
12 changes: 9 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# Changelog

## [2.12.10] (2024-03-08)
## [2.13.0] - 2024-05-20

### Added

- Support multiple SqliteDict in a single dbfile

## [2.12.10] - 2024-03-08

### Fixed

- Fix RocksDBLoader for NDJson should not keep double quotes around string value (e.g., `"Human"` instead of `Human`).

## [2.4.0](https://github.com/binh-vu/hugedict/tree/2.4.0) (2022-08-02)
## [2.4.0](https://github.com/binh-vu/hugedict/tree/2.4.0) - 2022-08-02

[Full Changelog](https://github.com/binh-vu/hugedict/compare/2.3.3...2.4.0)

Expand All @@ -19,7 +25,7 @@

- \[mrsw\] Invalid address when the URL is too long [\#8](https://github.com/binh-vu/hugedict/issues/8)

## [2.3.3](https://github.com/binh-vu/hugedict/tree/2.3.3) (2022-07-31)
## [2.3.3](https://github.com/binh-vu/hugedict/tree/2.3.3) - 2022-07-31

[Full Changelog](https://github.com/binh-vu/hugedict/compare/2.3.2...2.3.3)

Expand Down
156 changes: 110 additions & 46 deletions hugedict/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@
import functools
import pickle
import sqlite3
from dataclasses import dataclass
from enum import Enum
from inspect import signature
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Generic,
Iterable,
Iterator,
Optional,
Tuple,
TypeVar,
Union,
overload,
)

import orjson
from timer import Timer
Expand All @@ -15,7 +27,11 @@
from hugedict.types import F, HugeMutableMapping, V

SqliteKey = TypeVar("SqliteKey", bound=Union[str, int, bytes])
TupleValue = TypeVar("TupleValue", bound=tuple)
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")

DEFAULT_TIMEOUT = 5.0


class SqliteDictFieldType(str, Enum):
Expand All @@ -24,6 +40,16 @@ class SqliteDictFieldType(str, Enum):
bytes = "BLOB"


@dataclass
class SqliteDictArgs(Generic[V]):
tablename: str
keytype: SqliteDictFieldType
ser_value: Callable[[V], bytes]
deser_value: Callable[[bytes], V]
valuetype: SqliteDictFieldType = SqliteDictFieldType.bytes
timeout: float = DEFAULT_TIMEOUT


class SqliteDict(HugeMutableMapping[SqliteKey, V]):
"""A mutable mapping backed by sqlite. This mapping is slower than key-value db but offers
concurrency read-write operators.
Expand All @@ -37,21 +63,29 @@ class SqliteDict(HugeMutableMapping[SqliteKey, V]):

def __init__(
self,
path: Union[str, Path],
path: Union[str, Path, sqlite3.Connection],
keytype: SqliteDictFieldType,
ser_value: Callable[[V], bytes],
deser_value: Callable[[bytes], V],
valuetype: SqliteDictFieldType = SqliteDictFieldType.bytes,
timeout: float = 5.0,
timeout: float = DEFAULT_TIMEOUT,
table_name: str = "data",
):
self.dbfile = Path(path)
need_init = not self.dbfile.exists()
self.db = sqlite3.connect(str(self.dbfile), timeout=timeout)
if need_init:
with self.db:
self.db.execute(
f"CREATE TABLE data(key {keytype.value} PRIMARY KEY, value {valuetype.value})"
)
if isinstance(path, (str, Path)):
self.dbfile = Path(path)
self.table_name = table_name
need_init = not self.dbfile.exists()
self.db = sqlite3.connect(str(self.dbfile), timeout=timeout)
if need_init:
with self.db:
self.db.execute(
f"CREATE TABLE {self.table_name}(key {keytype.value} PRIMARY KEY, value {valuetype.value})"
)
else:
assert isinstance(
path, sqlite3.Connection
), f"path must be a str, Path, or Connection but get {type(path)}"
self.db = path

self.ser_value = ser_value
self.deser_value = deser_value
Expand All @@ -72,17 +106,68 @@ def int(
) -> SqliteDict[str, V]:
return SqliteDict(path, SqliteDictFieldType.int, ser_value, deser_value)

@staticmethod
def mul2(
path: Union[str, Path],
args1: SqliteDictArgs[V1],
args2: SqliteDictArgs[V2],
) -> tuple[SqliteDict[str, V1], SqliteDict[str, V2]]:
d1, d2 = SqliteDict.multiple(path, args1, args2)
return d1, d2

@staticmethod
def mul3(
path: Union[str, Path],
args1: SqliteDictArgs[V1],
args2: SqliteDictArgs[V2],
args3: SqliteDictArgs[V3],
) -> tuple[SqliteDict[str, V1], SqliteDict[str, V2], SqliteDict[str, V3]]:
d1, d2, d3 = SqliteDict.multiple(path, args1, args2, args3)
return d1, d2, d3

@staticmethod
def multiple(
path: Union[str, Path],
*args: SqliteDictArgs[Any],
) -> tuple[SqliteDict[str, Any], ...]:
dbfile = Path(path)
need_init = not dbfile.exists()
timeout = max(a.timeout for a in args)
db = sqlite3.connect(str(dbfile), timeout=timeout)
if need_init:
with db:
for a in args:
db.execute(
f"CREATE TABLE {a.tablename}(key {a.keytype.value} PRIMARY KEY, value {a.valuetype.value})"
)

out = []
for a in args:
out.append(
SqliteDict(
path=db,
keytype=a.keytype,
ser_value=a.ser_value,
deser_value=a.deser_value,
valuetype=a.valuetype,
timeout=a.timeout,
table_name=a.tablename,
)
)
return tuple(out)

def __contains__(self, key: SqliteKey):
return (
self.db.execute(
"SELECT EXISTS ( SELECT 1 FROM data WHERE key = ? LIMIT 1)", (key,)
f"SELECT EXISTS ( SELECT 1 FROM {self.table_name} WHERE key = ? LIMIT 1)",
(key,),
).fetchone()[0]
== 1
)

def __getitem__(self, key: SqliteKey) -> V:
record = self.db.execute(
"SELECT value FROM data WHERE key = ?", (key,)
f"SELECT value FROM {self.table_name} WHERE key = ?", (key,)
).fetchone()
if record is None:
raise KeyError(key)
Expand All @@ -91,38 +176,40 @@ def __getitem__(self, key: SqliteKey) -> V:
def __setitem__(self, key: SqliteKey, value: V):
with self.db:
self.db.execute(
"INSERT INTO data VALUES (:key, :value) ON CONFLICT(key) DO UPDATE SET value = :value",
f"INSERT INTO {self.table_name} VALUES (:key, :value) ON CONFLICT(key) DO UPDATE SET value = :value",
{"key": key, "value": self.ser_value(value)},
)

def __delitem__(self, key: SqliteKey) -> None:
with self.db:
self.db.execute("DELETE FROM data WHERE key = ?", (key,))
self.db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))

def __iter__(self) -> Iterator[SqliteKey]:
return (key[0] for key in self.db.execute("SELECT key FROM data"))
return (key[0] for key in self.db.execute(f"SELECT key FROM {self.table_name}"))

def __len__(self) -> int:
return self.db.execute("SELECT COUNT(*) FROM data").fetchone()[0]
return self.db.execute(f"SELECT COUNT(*) FROM {self.table_name}").fetchone()[0]

def keys(self) -> Iterator[SqliteKey]:
return (key[0] for key in self.db.execute("SELECT key FROM data"))
return (key[0] for key in self.db.execute(f"SELECT key FROM {self.table_name}"))

def values(self) -> Iterator[V]:
return (
self.deser_value(value[0])
for value in self.db.execute("SELECT value FROM data")
for value in self.db.execute(f"SELECT value FROM {self.table_name}")
)

def items(self) -> Iterator[Tuple[SqliteKey, V]]:
return (
(key, self.deser_value(value))
for key, value in self.db.execute("SELECT key, value FROM data")
for key, value in self.db.execute(
f"SELECT key, value FROM {self.table_name}"
)
)

def get(self, key: SqliteKey, default=None):
record = self.db.execute(
"SELECT value FROM data WHERE key = ?", (key,)
f"SELECT value FROM {self.table_name} WHERE key = ?", (key,)
).fetchone()
if record is None:
return default
Expand All @@ -131,7 +218,7 @@ def get(self, key: SqliteKey, default=None):
def batch_insert(self, items: Iterable[Tuple[SqliteKey, V]]):
with self.db:
self.db.executemany(
"INSERT INTO data VALUES (:key, :value) ON CONFLICT(key) DO UPDATE SET value = :value",
f"INSERT INTO {self.table_name} VALUES (:key, :value) ON CONFLICT(key) DO UPDATE SET value = :value",
[{"key": key, "value": self.ser_value(value)} for key, value in items],
)

Expand Down Expand Up @@ -224,26 +311,3 @@ def fn(*args, **kwargs):
return fn

return wrapper_fn # type: ignore


class SqliteDictTuple(HugeMutableMapping[SqliteKey, TupleValue]):
def __init__(
self,
path: Union[str, Path],
keytype: SqliteDictFieldType,
ser_value: Callable[[TupleValue], bytes],
deser_value: Callable[[bytes], TupleValue],
valuetype: SqliteDictFieldType = SqliteDictFieldType.bytes,
timeout: float = 5.0,
):
self.dbfile = Path(path)
need_init = not self.dbfile.exists()
self.db = sqlite3.connect(str(self.dbfile), timeout=timeout)
if need_init:
with self.db:
self.db.execute(
f"CREATE TABLE data(key {keytype.value} PRIMARY KEY, value {valuetype.value})"
)

self.ser_value = ser_value
self.deser_value = deser_value
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "hugedict"
version = "2.12.10"
version = "2.13.0"
authors = [{ name = "Binh Vu", email = "binh@toan2.com" }]
description = "A dictionary-like object that is friendly with multiprocessing and uses key-value databases (e.g., RocksDB) as the underlying storage."
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_sqlitedict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_concurrent_access(self, mapping: SqliteDict):

# start a new process and write a new key
commands = [
"from hugedict.sqlitedict import SqliteDict",
"from hugedict.sqlite import SqliteDict",
f"map = SqliteDict.str('{str(mapping.dbfile.absolute())}', ser_value=str.encode, deser_value=bytes.decode)",
f"nkey = '{nkey}'",
f"nvalue = '{nvalue}'",
Expand Down

0 comments on commit 963052e

Please sign in to comment.