Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Client-Server Architecture support #5

Merged
merged 3 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ CACHE_LMDB
CHACHE_SQLITE3.db
CACHE_LEVELDB
Log

flaxkv/server
FLAXKV_DB

.idea
.ipynb_checkpoints
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ intuitive manner. You can use it just like a Python dictionary without worrying

## TODO

- [ ] Client-Server Architecture
- [x] Client-Server Architecture
- [ ] Benchmark
---

Expand All @@ -85,6 +85,9 @@ from flaxkv import dictdb
import numpy as np

db = dictdb('./test_db')
# or run server `flaxkv run --port 8000`, then:
# db = dictdb('http://localhost:8000', remote=True)

db[1] = 1
db[1.1] = 1 / 3
db['key'] = 'value'
Expand Down
5 changes: 4 additions & 1 deletion README_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@

## TODO

- [ ] 客户端-服务器架构
- [x] 客户端-服务器架构
- [ ] 性能测试

---
Expand All @@ -90,6 +90,9 @@ from flaxkv import dictdb
import numpy as np

db = dictdb('./test_db')
# or run server `flaxkv run --port 8000`, then:
# db = dictdb('http://localhost:8000', remote=True)

db[1] = 1
db[1.1] = 1 / 3
db['key'] = 'value'
Expand Down
22 changes: 17 additions & 5 deletions flaxkv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from .base import LevelDBDict, LMDBDict
from .log import enable_logging
from .serve.client import RemoteDictDB

__version__ = "0.1.6"

Expand All @@ -22,15 +22,27 @@
"dbdict",
"LMDBDict",
"LevelDBDict",
"enable_logging",
"RemoteDictDB",
]


def dictdb(path, backend='lmdb', rebuild=False, raw=False, **kwargs):
def dictdb(
path_or_url: str,
backend='lmdb',
remote=False,
db_name=None,
rebuild=False,
raw=False,
**kwargs,
):
if remote:
return RemoteDictDB(
path_or_url, db_name=db_name, rebuild=rebuild, backend=backend
)
if backend == 'lmdb':
return LMDBDict(path, rebuild=rebuild, raw=raw, **kwargs)
return LMDBDict(path_or_url, rebuild=rebuild, raw=raw, **kwargs)
elif backend == 'leveldb':
return LevelDBDict(path, rebuild=rebuild, raw=raw, **kwargs)
return LevelDBDict(path_or_url, rebuild=rebuild, raw=raw, **kwargs)
else:
raise ValueError(f"Unsupported DB type {backend}.")

Expand Down
41 changes: 41 additions & 0 deletions flaxkv/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import platform

import fire
import uvicorn


class Cli:
@staticmethod
def run(port=8000, workers=1, **kwargs):
"""
Runs the application using the Uvicorn server.

Args:
port (int): The port number on which to run the server. Default is 8000.
workers (int): The number of worker processes to run. Default is 1.

Returns:
None
"""

if platform.system() == "Windows":
os.environ["TZ"] = ""

uvicorn.run(
app="flaxkv.serve.app:app",
host="0.0.0.0",
port=port,
workers=workers,
app_dir="..",
ssl_keyfile=kwargs.get("ssl_keyfile", None),
ssl_certfile=kwargs.get("ssl_certfile", None),
)


def main():
fire.Fire(Cli)


if __name__ == "__main__":
main()
30 changes: 20 additions & 10 deletions flaxkv/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class _enc_prefix:
dict = b'd'
array = b'a'

def __init__(self, db_type, path, rebuild=False, raw=False, **kwargs):
def __init__(self, db_type, path, rebuild=False, raw=False, db_name=None, **kwargs):
"""
Initializes the BaseDBDict class which provides a dictionary-like interface to a database.

Expand All @@ -69,6 +69,7 @@ def __init__(self, db_type, path, rebuild=False, raw=False, **kwargs):
else:
logger.disable('flaxkv')

self._db_name = db_name
self._db_manager = DBManager(
db_type=db_type, db_path=path, rebuild=rebuild, **kwargs
)
Expand Down Expand Up @@ -221,7 +222,10 @@ def get_db_value(self, key: str):
Returns:
value: The encoded value associated with the key.
"""
return self._static_view.get(encode(key))
if self.raw:
return self._static_view.get(key)
else:
return self._static_view.get(encode(key))

def get_batch(self, keys):
"""
Expand Down Expand Up @@ -356,7 +360,7 @@ def _write_buffer_to_db(
self.delete_buffer_set = self.delete_buffer_set - delete_buffer_set_snapshot
self.buffer_dict = self._diff_buffer(self.buffer_dict, buffer_dict_snapshot)

self._db_manager.close_view(self._static_view)
self._db_manager.close_static_view(self._static_view)
self._static_view = self._db_manager.new_static_view()
self._logger.info(
f"write {self._db_manager.db_type.upper()} buffer to db successfully-{current_write_num=}-{self._latest_write_num=}"
Expand Down Expand Up @@ -423,7 +427,10 @@ def pop(self, key, default=None):
value = self.buffer_dict.pop(key)
return value
else:
value = self._static_view.get(encode(key))
if self.raw:
value = self._static_view.get(key)
else:
value = self._static_view.get(encode(key))
return decode(value)
else:
return default
Expand All @@ -444,14 +451,17 @@ def __contains__(self, key):
if key in self.delete_buffer_set:
return False

return self._static_view.get(encode(key)) is not None
if self.raw:
return self._static_view.get(key) is not None
else:
return self._static_view.get(encode(key)) is not None

def clear(self):
"""
Clears the database and resets the buffer.
"""
with self._buffer_lock:
self._db_manager.close_view(self._static_view)
self._db_manager.close_static_view(self._static_view)
self._db_manager.close()
self._close_background_worker(write=False)

Expand Down Expand Up @@ -500,7 +510,7 @@ def close(self, write=True):
"""
self._close_background_worker(write=write)

self._db_manager.close_view(self._static_view)
self._db_manager.close_static_view(self._static_view)
self._db_manager.close()
self._logger.info(f"Closed ({self._db_manager.db_type.upper()}) successfully")

Expand Down Expand Up @@ -549,7 +559,7 @@ class LMDBDict(BaseDBDict):
value: int, float, bool, str, list, dict, and np.ndarray,
"""

def __init__(self, path, map_size=1024**3, rebuild=False, **kwargs):
def __init__(self, path: str, map_size=1024**3, rebuild=False, **kwargs):
super().__init__(
"lmdb", path, max_dbs=1, map_size=map_size, rebuild=rebuild, **kwargs
)
Expand All @@ -564,7 +574,7 @@ def keys(self):
lmdb_keys = set(
decode_key(key) for key in cursor.iternext(keys=True, values=False)
)
self._db_manager.close_view(session)
self._db_manager.close_static_view(session)

return list(lmdb_keys.union(buffer_keys) - delete_buffer_set)

Expand Down Expand Up @@ -636,7 +646,7 @@ class LevelDBDict(BaseDBDict):
value: int, float, bool, str, list, dict and np.ndarray,
"""

def __init__(self, path, rebuild=False, **kwargs):
def __init__(self, path: str, rebuild=False, **kwargs):
super().__init__("leveldb", path=path, rebuild=rebuild)

def keys(self):
Expand Down
2 changes: 1 addition & 1 deletion flaxkv/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def new_static_view(self):
else:
raise ValueError(f"Unsupported DB type {self.db_type}.")

def close_view(self, static_view):
def close_static_view(self, static_view):
"""
Closes the provided static view of the database.

Expand Down
14 changes: 14 additions & 0 deletions flaxkv/pack.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2023 K.Y. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import msgpack
import msgspec.msgpack
import numpy as np
Expand Down
Empty file added flaxkv/serve/__init__.py
Empty file.
87 changes: 87 additions & 0 deletions flaxkv/serve/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import msgspec
from litestar import Litestar, MediaType, Request, get, post

from ..pack import encode
from .interface import AttachRequest, DetachRequest, StructSetData
from .manager import DBManager

db_manager = DBManager(root_path="./FLAXKV_DB")


@post(path="/attach")
async def attach(data: AttachRequest) -> dict:
db = db_manager.get(data.db_name)
if db is None or data.rebuild:
db_manager.set_db(
db_name=data.db_name, backend=data.backend, rebuild=data.rebuild
)
return {"success": True}


@post(path="/detach")
async def detach(data: DetachRequest) -> dict:
db = db_manager.detach(db_name=data.db_name)
if db is None:
return {"success": False, "info": "db not found"}
return {"success": True}


@post(path="/set_raw")
async def set_raw(db_name: str, request: Request) -> dict:
print(f"{db_name=}")
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
data = await request.body()
data = msgspec.msgpack.decode(data, type=StructSetData)
db[data.key] = data.value
db.write_immediately(write=True, wait=True)
return {"success": True}


@post("/get_raw", media_type=MediaType.TEXT)
async def get_raw(db_name: str, request: Request) -> bytes:
db = db_manager.get(db_name)
if db is None:
return encode({"success": False, "info": "db not found"})
key = await request.body()
value = db.get(key)
if value is None:
return encode(None)
return db.get(key)


@post("/contains", media_type=MediaType.TEXT)
async def contains(db_name: str, request: Request) -> bytes:
db = db_manager.get(db_name)
if db is None:
return encode({"success": False, "info": "db not found"})
key = await request.body()
is_contain = key in db
return encode(is_contain)


@get("/keys")
async def get_keys(db_name: str) -> dict:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
return {"keys": list(db.keys())}


def on_shutdown():
print("on_shutdown")


app = Litestar(
route_handlers=[
attach,
detach,
set_raw,
get_raw,
contains,
get_keys,
],
on_startup=[lambda: print("on_startup")],
on_shutdown=[on_shutdown],
)
Loading
Loading