Skip to content

Commit

Permalink
async client support
Browse files Browse the repository at this point in the history
  • Loading branch information
BhaaVast committed Apr 2, 2023
1 parent 87c6ac3 commit 184a766
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 45 deletions.
6 changes: 4 additions & 2 deletions pydastic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ def get_version() -> str:
InvalidModelError,
NotFoundError,
)
from pydastic.model import ESModel
from pydastic.pydastic import PydasticClient, connect
from pydastic.model import ESAsyncModel, ESModel
from pydastic.pydastic import PydasticClient, connect, connect_async
from pydastic.session import Session

__all__ = [
"ESModel",
"ESAsyncModel",
"Session",
"NotFoundError",
"InvalidModelError",
"InvalidElasticsearchResponse",
"PydasticClient",
"connect",
"connect_async",
]
205 changes: 168 additions & 37 deletions pydastic/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from copy import copy
from datetime import datetime
from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union

from elasticsearch import NotFoundError as ElasticNotFoundError
Expand Down Expand Up @@ -124,20 +126,11 @@ def save(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] =
Args:
index (str, optional): Index name
wait_for (bool, optional): Waits for all shards to sync before returning response - useful when writing tests. Defaults to False.
"""
doc = self.dict(exclude={"id"})

# Allow waiting for shards - useful when testing
refresh = "false"
if wait_for:
refresh = "wait_for"

# Use user-provided index if provided (dynamic index support)
if not index:
index = self.Meta.index
res = _client.client.index(index=index, body=doc, id=self.id, refresh=refresh)
self.id = res.get("_id")
Returns:
New document ID
"""
return Save(index=index, wait_for=wait_for, model=self).sync()

@classmethod
def get(cls: Type[M], id: str, extra_fields: Optional[bool] = False, index: Optional[str] = None) -> M:
Expand All @@ -154,26 +147,11 @@ def get(cls: Type[M], id: str, extra_fields: Optional[bool] = False, index: Opti
Raises:
NotFoundError: Returned if document not found
"""
source_includes = None
if not extra_fields:
fields: dict = copy(vars(cls).get("__fields__"))
fields.pop("id", None)
source_includes = list(fields.keys())

# Use user-provided index if provided (dynamic index support)
if not index:
index = cls.Meta.index

try:
res = _client.client.get(index=index, id=id, _source_includes=source_includes)
return Get(model=cls, id_=id, index=index, extra_fields=extra_fields).sync()
except ElasticNotFoundError:
raise NotFoundError(f"document with id {id} not found")

model = cls.from_es(res)
model.id = id

return model

def delete(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False):
"""Deletes document from elasticsearch.
Expand All @@ -187,17 +165,170 @@ def delete(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool]
"""
if not self.id:
raise ValueError("id missing from object")
try:
Delete(index=index, model=self).sync()
except ElasticNotFoundError:
raise NotFoundError(f"document with id {id} not found")


class ESAsyncModel(ESModel):
    """Variant of ESModel whose save/get/delete methods are coroutines.

    Each method delegates to the matching operation object (Save, Get, Delete)
    and awaits its ``asyncio()`` runner, so the configured client must return
    awaitables (an async Elasticsearch client).
    """

    class Meta:
        @property
        def index(self) -> str:
            """Elasticsearch index name associated with this model class"""
            raise NotImplementedError

    async def save(self: M, index: Optional[str] = None, wait_for: Optional[bool] = False):
        """Indexes document into elasticsearch.

        If document already exists, existing document will be updated as per native elasticsearch index operation.
        If model instance includes an 'id' property, this will be used as the elasticsearch _id.
        If no 'id' is provided, then document will be indexed and elasticsearch will generate a suitable id
        that will be populated on the model instance.

        Args:
            index (str, optional): Index name
            wait_for (bool, optional): Waits for all shards to sync before returning response - useful when writing tests. Defaults to False.

        Returns:
            New document ID
        """
        return await Save(index=index, wait_for=wait_for, model=self).asyncio()

    @classmethod
    async def get(cls: Type[M], id: str, extra_fields: Optional[bool] = False, index: Optional[str] = None) -> M:
        """Fetches document and returns ESAsyncModel instance populated with properties.

        Args:
            id (str): Document id
            extra_fields (bool, Optional): Include fields found in elasticsearch but not part of the model definition
            index (str, optional): Index name

        Returns:
            ESAsyncModel

        Raises:
            NotFoundError: Returned if document not found
        """
        get = Get(model=cls, id_=id, extra_fields=extra_fields, index=index)
        try:
            return await get.asyncio()
        except ElasticNotFoundError as err:
            raise NotFoundError(f"document with id {id} not found") from err

    async def delete(self: M, index: Optional[str] = None, wait_for: Optional[bool] = False):
        """Deletes document from elasticsearch.

        Args:
            index (str, optional): Index name
            wait_for (bool, optional): Waits for all shards to sync before returning response - useful when writing tests. Defaults to False.

        Raises:
            NotFoundError: Returned if document not found
            ValueError: Returned when id attribute missing from instance
        """
        if not self.id:
            raise ValueError("id missing from object")
        try:
            await Delete(index=index, model=self, wait_for=wait_for).asyncio()
        except ElasticNotFoundError as err:
            # Report the instance id; a bare ``id`` here would format the builtin function.
            raise NotFoundError(f"document with id {self.id} not found") from err


class BaseESOperation(ABC):
    """Abstract base for a single Elasticsearch operation: save, get or delete.

    Each subclass sets ``INDEX_METHOD_NAME`` to the name of the Elasticsearch
    client method to run, supplies that call's keyword arguments through
    ``kwargs`` and converts the raw response into a result in ``post``.
    The initializer (``__init__``) performs the pre-operation preparation.
    """

    # Name of the Elasticsearch client method invoked by the concrete operation.
    INDEX_METHOD_NAME: Optional[str] = None

    def __init__(self, model: "Union[M, Type[M]]", index: Optional[str], wait_for: Optional[bool] = None):
        """Bind the client method and remember the target model and index.

        Args:
            model: Model class (for get) or model instance (for save/delete).
            index: Index name; when empty the model's ``Meta.index`` is used.
            wait_for: When not None, sets ``self.refresh`` to "wait_for" or "false".

        Raises:
            AttributeError: If the subclass did not set ``INDEX_METHOD_NAME``.
        """
        if self.INDEX_METHOD_NAME is None:
            raise AttributeError(f"Must set INDEX_METHOD_NAME variable for class {self.__class__.__name__}")
        self._es_client_func = getattr(_client.client, self.INDEX_METHOD_NAME)

        self.model = model
        # Use user-provided index if provided (dynamic index support),
        # otherwise fall back to the index declared on the model's Meta class.
        self.index = index or model.Meta.index
        if wait_for is not None:
            # Allow waiting for shards - useful when testing
            self.refresh = "wait_for" if wait_for else "false"

    # ``property`` must be the outer decorator: abstractmethod cannot flag an
    # already-built property object, so the reverse order fails at class creation.
    @property
    @abstractmethod
    def kwargs(self) -> Dict[str, Any]:
        """Keyword arguments passed to the Elasticsearch client method."""
        raise NotImplementedError()

    @abstractmethod
    def post(self, *args, **kwargs):
        """Post-operation activities; receives the client response as ``es_result``."""
        raise NotImplementedError()

    def sync(self):
        """Run in blocking mode and return the result of ``post``."""
        es_result = self._es_callable()
        return self.post(es_result=es_result)

    async def asyncio(self):
        """Run in async mode: await the client call, then return the result of ``post``."""
        es_result = await self._es_callable()
        return self.post(es_result=es_result)

    @property
    def _es_callable(self) -> Callable[..., Any]:
        """The bound client method with this operation's kwargs pre-applied."""
        return partial(self._es_client_func, **self.kwargs)


class Get(BaseESOperation):
    """Fetch one document by id and build a model instance from the response."""

    INDEX_METHOD_NAME = "get"

    def __init__(self, model: Type[M], id_: str, index: Optional[str] = None, extra_fields: Optional[bool] = False):
        """Prepare the lookup.

        Args:
            model: Model class used to build the result.
            id_: Document id.
            index: Index name.
            extra_fields: When falsy, only the model's declared fields (minus "id") are requested.
        """
        super().__init__(index=index, model=model)
        self.id = id_

        if extra_fields:
            # No source filter: return whatever is stored in the document.
            self.source_includes = None
        else:
            # Work on a copy so the model's own field registry is not mutated.
            declared: dict = copy(vars(model).get("__fields__"))
            declared.pop("id", None)
            self.source_includes = list(declared.keys())

    @property
    def kwargs(self):
        """Arguments for the client's ``get`` call."""
        return {"index": self.index, "id": self.id, "_source_includes": self.source_includes}

    def post(self, es_result: dict) -> M:
        """Build the model from the response and stamp it with the requested id."""
        instance = self.model.from_es(es_result)
        instance.id = self.id
        return instance


class Save(BaseESOperation):
    """Index a model instance as an Elasticsearch document."""

    INDEX_METHOD_NAME = "index"

    def __init__(self, model: M, index: Optional[str] = None, wait_for: Optional[bool] = False):
        """Prepare the document to index.

        Args:
            model: Model instance to save.
            index: Index name.
            wait_for: Waits for all shards to sync before returning response. Defaults to False.
        """
        super().__init__(model=model, index=index, wait_for=wait_for)
        # The id is sent separately as the document id, so keep it out of the body.
        self.doc = self.model.dict(exclude={"id"})

    @property
    def kwargs(self):
        """Arguments for the client's ``index`` call."""
        # The document goes in ``body`` (as the pre-refactor save() did);
        # ``doc`` is not a parameter of the index API.
        return dict(index=self.index, body=self.doc, id=self.model.id, refresh=self.refresh)

    def post(self, es_result: dict) -> str:
        """Store the id reported by Elasticsearch on the model and return it."""
        self.model.id = es_result.get("_id")
        return self.model.id


class Delete(BaseESOperation):
    """Delete the document that corresponds to a model instance."""

    INDEX_METHOD_NAME = "delete"

    def __init__(self, model: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False):
        """Forward the model, index and refresh preference to the base operation."""
        super().__init__(model=model, index=index, wait_for=wait_for)

    @property
    def kwargs(self):
        """Arguments for the client's ``delete`` call."""
        return {"index": self.index, "id": self.model.id, "refresh": self.refresh}

    def post(self, es_result: dict) -> None:
        """Nothing to extract from the delete response."""
        return None
15 changes: 11 additions & 4 deletions pydastic/pydastic.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import typing as t

from elasticsearch import Elasticsearch
from elasticsearch import AsyncElasticsearch, Elasticsearch

ElasticClasses = t.Union[Elasticsearch, AsyncElasticsearch]


class PydasticClient:
    """Holder for the Elasticsearch client (sync or async) used by the models.

    Reading ``client`` before one has been stored raises AttributeError
    instead of silently handing back None.
    """

    client: t.Optional[ElasticClasses] = None

    def __getattribute__(self, __name: str) -> t.Any:
        """Return the attribute, refusing to return an unset ``client``."""
        value = object.__getattribute__(self, __name)
        if __name == "client" and value is None:
            raise AttributeError("client not initialized - make sure to call Pydastic.connect() or Pydastic.connect_async()")
        return value

    def _set_client(self, client: ElasticClasses):
        """Store ``client`` on this holder and return it."""
        # Written through object.__setattr__, mirroring the guarded read path above.
        object.__setattr__(self, "client", client)
        return client

Expand All @@ -35,3 +37,8 @@ def wrapped(*args, **kwargs):
# Stub only: the body is intentionally empty. The decorator is given the
# Elasticsearch class; presumably it supplies the real behavior and
# signature - confirm against copy_signature.
@copy_signature(Elasticsearch)
def connect(*args, **kwargs):
    ...


# Stub only: the body is intentionally empty. The decorator is given the
# AsyncElasticsearch class; presumably it supplies the real behavior and
# signature - confirm against copy_signature.
@copy_signature(AsyncElasticsearch)
def connect_async(*args, **kwargs):
    ...
11 changes: 11 additions & 0 deletions tests/car.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Optional

from pydastic import ESAsyncModel


class Car(ESAsyncModel):
    """Async test model with a required ``model`` name and an optional ``year``."""

    model: str
    year: Optional[int]

    class Meta:
        # Index name used for Car documents.
        index = "car"
19 changes: 17 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import pytest
from elasticsearch import Elasticsearch
from car import Car
from elasticsearch import AsyncElasticsearch, Elasticsearch
from user import User

from pydastic.pydastic import _client, connect
from pydastic.pydastic import _client, connect, connect_async


@pytest.fixture()
async def async_es() -> AsyncElasticsearch:
    """Connect the async client to the local node, wipe all documents and return the client."""
    connect_async(hosts="http://localhost:9200")
    client = _client.client
    # Start every test from an empty cluster: delete every document in every index.
    match_everything = {"query": {"match_all": {}}}
    await client.delete_by_query(index="_all", body=match_everything, wait_for_completion=True, refresh=True)
    return client


@pytest.fixture()
Expand All @@ -17,3 +25,10 @@ def user(es: Elasticsearch) -> User:
user = User(name="John", phone="123456")
user.save(wait_for=True)
return user


@pytest.fixture()
async def car(async_es: AsyncElasticsearch) -> Car:
    """Return a Car that has already been saved through the async client.

    Depends on ``async_es`` so the async client is connected before saving.
    """
    # ``model`` is Car's required field.
    car = Car(model="Seat", year=2023)
    # Car.save is a coroutine; it has to be awaited for the document to be indexed.
    await car.save(wait_for=True)
    return car
43 changes: 43 additions & 0 deletions tests/test_model_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from car import Car
from elasticsearch import AsyncElasticsearch

from pydastic import ESAsyncModel

pytestmark = pytest.mark.asyncio


def test_model_definition_yields_error_without_meta_class():
    """Declaring an async model with no Meta class is expected to raise NotImplementedError."""
    with pytest.raises(NotImplementedError):

        class AsyncUser(ESAsyncModel):
            pass


def test_model_definition_yields_error_without_index():
    """Declaring an async model whose Meta has no index is expected to raise NotImplementedError."""
    with pytest.raises(NotImplementedError):

        class AsyncUser(ESAsyncModel):
            class Meta:
                pass


@pytest.mark.asyncio
async def test_model_async_save_without_connection_raises_attribute_error():
    """Saving before any client is connected must raise AttributeError."""
    car = Car(model="Fiat")
    with pytest.raises(AttributeError):
        # save() is a coroutine: without the await its body never runs and nothing is raised.
        await car.save(wait_for=True)


@pytest.mark.asyncio
async def test_model_async_save(async_es: AsyncElasticsearch):
    """A saved model gets an id and is stored with exactly its own fields."""
    car = Car(model="Fiat")
    await car.save(wait_for=True)
    assert car.id is not None

    stored = await async_es.get(index=car.Meta.index, id=car.id)
    assert stored["found"]

    # Check that fields match exactly
    expected_source = car.to_es()
    assert stored["_source"] == expected_source

0 comments on commit 184a766

Please sign in to comment.