Skip to content

Commit

Permalink
Use SyncModule and AsyncModule directly
Browse files Browse the repository at this point in the history
This removes the `client.Module` syntax we had before, because that was
a nightmare to type correctly. This is more or less due to this[1] issue
in mypy. Once we find a good way to make type checkers happy, the syntax
may return. But for now, using SyncModule and AsyncModule as a
superclass will suffice.

1: python/mypy#10962
  • Loading branch information
yrd committed Dec 13, 2021
1 parent a0d68e6 commit 3c62999
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 73 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ from zucker import model, RequestsClient

crm = RequestsClient("https://crm.example.com", "zucker", "password")

class Contact(crm.Module, api_name="Contacts"):
class Contact(model.SyncModule, client=crm, api_name="Contacts"):
lead_source = model.StringField()
phone_work = model.StringField()

Expand Down
4 changes: 2 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ is used like this:
# The data model defined here should match the one on the server - although
# unneeded fields can be omitted.
class Contact(crm.Module, api_name="Contacts"):
class Contact(model.SyncModule, client=crm, api_name="Contacts"):
lead_source = model.StringField()
phone_work = model.StringField()
Expand Down Expand Up @@ -53,7 +53,7 @@ for asynchronous I/O. Here is the same code from above, but instead of using
# as the backend.
crm = AioClient("https://crm.example.com", "zucker", "password")
class Contact(crm.Module, api_name="Contacts"):
class Contact(model.AsyncModule, client=crm, api_name="Contacts"):
lead_source = model.StringField()
phone_work = model.StringField()
Expand Down
27 changes: 14 additions & 13 deletions docs/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ Defining modules
----------------

To define a module, you need to create a subclass of Zucker's
:class:`~zucker.model.module.BaseModule` base class. This is done by using the
``Module`` attribute from the client as a superclass (which automatically points
to the correct module implementation). This type of module is also referred to
as a *bound* module because it is fixed to the client it is initialized with.
Inside your class, define fields with the same name as they appear in the API:
:class:`~zucker.model.module.BaseModule` base class. This is done by using
either :class:`~zucker.model.SyncModule` or :class:`~zucker.model.AsyncModule`
as the superclass, depending on the client implementation. This type of module
is also referred to as a *bound* module because it is fixed to the client it is
initialized with. Inside your class, define fields with the same name as they
appear in the API:

.. code-block:: python
Expand All @@ -41,16 +42,16 @@ Inside your class, define fields with the same name as they appear in the API:
crm = SomeClient(...)
# Here, the 'Contact' module is bound to the 'crm' client:
class Contact(crm.Module, api_name="Contacts"):
class Contact(model.SyncModule, client=crm, api_name="Contacts"):
first_name = model.StringField()
lead_source = model.StringField()
phone_mobile = model.StringField()
phone_work = model.StringField()
email_opt_out = model.BooleanField()
If you in an asynchronous environment, the ``Module`` attribute of the client
will be an async module implementation. Note that it is not possible (and not
supported) to mix synchronous models with asynchronous clients or vice-versa.
If you in an asynchronous environment, use ``model.AsyncModule`` instead. Note
that it is not possible (and not supported) to mix synchronous models with
asynchronous clients or vice-versa.

Extending the model
~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -117,10 +118,10 @@ client yet.
alpha_crm = RequestsClient("https://alpha.example.com", "zucker", "password")
beta_crm = AioClient("https://alpha.example.com", "zucker", "wordpass")
class AlphaContact(alpha_crm.Module, BaseContact, api_name="Contacts"):
class AlphaContact(model.SyncModule, BaseContact, client=alpha_crm, api_name="Contacts"):
pass
class BetaContact(beta_crm.Module, BaseContact, api_name="Contacts"):
class BetaContact(model.AsyncModule, BaseContact, client=beta_crm, api_name="Contacts"):
# This field will only be present on BetaContact instances:
email_opt_out = model.BooleanField()
Expand All @@ -132,8 +133,8 @@ API Reference
-------------

All bound modules define the following API. You don't need to implement the
abstract methods, as their implementation is already provided when you use a
client's ``.Module`` attribute as a superclass.
abstract methods, as their implementation is already provided when you use the
synchronous or asynchronous modules as a superclass.

.. autoclass:: zucker.model.module.BoundModule
:members:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def handle_request(method, path, **kwargs):

assert client.server_info == (server_flavor, server_version, server_build)

class A(client.Module):
class A(model.SyncModule, client=client):
pass

assert "A" in client
Expand Down
12 changes: 9 additions & 3 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

from zucker import RequestsClient
from zucker.client import SyncClient
from zucker.model import BooleanField, RelatedField, StringField, UnboundModule
from zucker.model import (
BooleanField,
RelatedField,
StringField,
SyncModule,
UnboundModule,
)
from zucker.model.fields.base import MutableField
from zucker.model.module import BaseModule
from zucker.model.view import SyncView
Expand Down Expand Up @@ -63,7 +69,7 @@ def test_boolean_field_values(raw_value):


def test_related_field_initialization(client: SyncClient):
class Demo(client.Module, BaseDemo):
class Demo(SyncModule, BaseDemo, client=client):
pass

with pytest.raises(TypeError):
Expand All @@ -78,7 +84,7 @@ class Demo(client.Module, BaseDemo):


def test_related_view_building(client: SyncClient):
class Demo(client.Module, BaseDemo):
class Demo(SyncModule, BaseDemo, client=client):
pass

field = RelatedField(Demo, "some_link")
Expand Down
14 changes: 7 additions & 7 deletions tests/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class BaseDeno(model.UnboundModule):


def test_bound_init(client: SyncClient):
class Demo(client.Module, BaseDemo):
class Demo(model.SyncModule, BaseDemo, client=client):
pass

# Make sure that records can't be created with the wrong API type.
Expand All @@ -37,7 +37,7 @@ class Demo(client.Module, BaseDemo):


def test_strings(client: SyncClient):
class Demo(client.Module, BaseDemo):
class Demo(model.SyncModule, BaseDemo, client=client):
pass

for func in list[Callable[[object], str]]((repr, str)):
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_new_record_saving(client: SyncClient):
"""New records (without IDs) are saved as expected."""
client.request = MagicMock(return_value={"id": "abc", "foo": "hi", "bar": "hu"})

class Demo(client.Module, BaseDemo):
class Demo(model.SyncModule, BaseDemo, client=client):
pass

record = Demo(foo="hi", bar="hu")
Expand All @@ -98,7 +98,7 @@ def test_existing_record_saving(client: SyncClient):
"""Existing records (that have an ID) are saved as expected."""
client.request = MagicMock(return_value={"id": "abc", "foo": "f00", "bar": "hu"})

class Demo(client.Module, BaseDemo):
class Demo(model.SyncModule, BaseDemo, client=client):
pass

record = Demo(_module="Demo", id="abc", foo="hi", bar="hu")
Expand All @@ -116,7 +116,7 @@ class Demo(client.Module, BaseDemo):


def test_deleting_unsaved(client: SyncClient):
class Demo(client.Module, BaseDemo):
class Demo(model.SyncModule, BaseDemo, client=client):
pass

record = Demo(foo="hi", bar="hu")
Expand All @@ -133,7 +133,7 @@ def test_deleting(client: SyncClient):
# original set)
client.request = MagicMock(return_value={})

class Demo(client.Module, BaseDemo):
class Demo(model.SyncModule, BaseDemo, client=client):
pass

record = Demo(_module="Demo", id="abc", foo="hi", bar="hu")
Expand All @@ -151,7 +151,7 @@ class Demo(client.Module, BaseDemo):
def test_refreshing(client: SyncClient):
client.request = MagicMock(return_value={"id": "abc", "foo": "f00", "bar": "b00"})

class Demo(client.Module, BaseDemo):
class Demo(model.SyncModule, BaseDemo, client=client):
pass

record = Demo(_module="Demo", id="abc", foo="hi", bar="hu")
Expand Down
13 changes: 7 additions & 6 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from zucker import model
from zucker.client import SyncClient
from zucker.filtering import Combinator, FilterSet

Expand Down Expand Up @@ -37,7 +38,7 @@ def fake_client(monkeypatch) -> FakeClient:


def test_init(fake_client: FakeClient):
class Demo(fake_client.Module):
class Demo(model.SyncModule, client=fake_client):
pass

for not_applicable in (28, False):
Expand All @@ -47,7 +48,7 @@ class Demo(fake_client.Module):


def test_query_params_filters(fake_client: FakeClient):
class Demo(fake_client.Module):
class Demo(model.SyncModule, client=fake_client):
pass

assert Demo.find(
Expand Down Expand Up @@ -93,15 +94,15 @@ def build_filter() -> dict:


def test_len(fake_client: FakeClient):
class Demo(fake_client.Module):
class Demo(model.SyncModule, client=fake_client):
pass

fake_client.set_data("get", "Demo/count", {"record_count": 43})
assert len(Demo.find()) == 43


def test_getting_id(fake_client: FakeClient):
class Demo(fake_client.Module):
class Demo(model.SyncModule, client=fake_client):
pass

view = Demo.find()
Expand All @@ -127,7 +128,7 @@ def callback(given_method, given_url, params):


def test_getting_index(fake_client: FakeClient):
class Demo(fake_client.Module):
class Demo(model.SyncModule, client=fake_client):
pass

def handle(method, url, params):
Expand All @@ -151,7 +152,7 @@ def handle(method, url, params):


def test_iterating_and_slices(fake_client: FakeClient):
class Demo(fake_client.Module):
class Demo(model.SyncModule, client=fake_client):
pass

record_data = [
Expand Down
2 changes: 1 addition & 1 deletion tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_crud(
skipped all together.
"""

class Lead(live_client.Module, BaseLead, api_name="Leads"):
class Lead(model.SyncModule, BaseLead, client=live_client, api_name="Leads"):
pass

lead_1 = Lead(first_name=first_name, last_name=last_name)
Expand Down
34 changes: 0 additions & 34 deletions zucker/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,6 @@ def __init__(

self._metadata: MutableJsonMapping = {}

self.Module = self._build_module_class()

@abc.abstractmethod
def _build_module_class(self) -> Type[BoundModule]:
...

def __contains__(self, item: Union[str, Type[BoundModule]]) -> bool:
"""Check if the server supports a given module name."""
from zucker.model.module import BoundModule
Expand Down Expand Up @@ -287,20 +281,6 @@ def request(


class SyncClient(BaseClient, abc.ABC):
Module: Type[SyncModule]

def _build_module_class(self) -> Type[SyncModule]:
from zucker.model.module import SyncModule

client = self

class BoundSyncModule(SyncModule):
@classmethod
def get_client(self) -> SyncClient:
return client

return BoundSyncModule

def close(self) -> None:
pass

Expand Down Expand Up @@ -328,20 +308,6 @@ def fetch_metadata(self, *types: str):


class AsyncClient(BaseClient):
Module: Type[AsyncModule]

def _build_module_class(self) -> Type[AsyncModule]:
from zucker.model.module import AsyncModule

client = self

class BoundAsyncModule(AsyncModule):
@classmethod
def get_client(self) -> AsyncClient:
return client

return BoundAsyncModule

async def close(self) -> None:
pass

Expand Down
2 changes: 1 addition & 1 deletion zucker/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .fields.registry import field_for_type
from .fields.relationships import *
from .fields.scalars import *
from .module import UnboundModule
from .module import AsyncModule, SyncModule, UnboundModule
26 changes: 22 additions & 4 deletions zucker/model/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,25 +179,39 @@ class BoundModule(Generic[ClientType], BaseModule, abc.ABC):
saved, refreshed and deleted.
"""

_CLIENT_TYPE: ClassVar[Type[ClientType]]
_client: ClassVar[ClientType]
_api_name: ClassVar[str]

# Cache object that allows direct access to records by their ID. References here are
# held weakly - that means that once the corresponding view is garbage collected,
# entries in this cache may no longer be accessible.
_record_cache: ClassVar[MutableMapping[str, BoundModule[ClientType]]]

def __init_subclass__(cls, api_name: Optional[str] = None, **kwargs):
cls._api_name = api_name or cls.__name__
cls._record_cache = WeakValueDictionary()
def __init_subclass__(
cls,
client: Optional[ClientType] = None,
api_name: Optional[str] = None,
**kwargs,
):
if not abc.ABC in cls.__bases__:
if not isinstance(client, cls._CLIENT_TYPE):
raise TypeError(
f"expecting client of type {cls._CLIENT_TYPE!r}, got "
f"{type(client)!r}"
)
cls._client = client
cls._api_name = api_name or cls.__name__
cls._record_cache = WeakValueDictionary()

@classmethod
@abc.abstractmethod
def get_client(cls) -> ClientType:
"""Client instance bounded to this module.
This client will be used for all server-side communication. Further, any caches
are scoped to this client.
"""
return cls._client

@property
def client(self) -> ClientType:
Expand Down Expand Up @@ -367,6 +381,8 @@ def get_by_id(


class SyncModule(BoundModule[SyncClient], abc.ABC):
_CLIENT_TYPE = SyncClient

@classmethod
def find(
cls: Type[SyncSelf], *filters: Union[JsonMapping, GenericFilter]
Expand Down Expand Up @@ -395,6 +411,8 @@ def get_by_id(cls: Type[SyncSelf], key: str) -> Optional[SyncSelf]:


class AsyncModule(BoundModule[AsyncClient], abc.ABC):
_CLIENT_TYPE = AsyncClient

@classmethod
def find(
cls: Type[AsyncSelf], *filters: Union[JsonMapping, GenericFilter]
Expand Down

0 comments on commit 3c62999

Please sign in to comment.