From 2ef436f53da231d93ed508c3fa1522105b369443 Mon Sep 17 00:00:00 2001 From: Jeremy Howard Date: Tue, 22 Oct 2024 17:02:21 +1000 Subject: [PATCH] fixes #34 --- fastlite/core.py | 3 ++- fastlite/kw.py | 12 ++++++------ nbs/00_core.ipynb | 3 ++- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/fastlite/core.py b/fastlite/core.py index 66d9e34..79b6fe9 100644 --- a/fastlite/core.py +++ b/fastlite/core.py @@ -118,10 +118,11 @@ def __call__( select:str = "*", # Comma-separated list of columns to select with_pk:bool=False, # Return tuple of (pk,row)? as_cls:bool=True, # Convert returned dict to stored dataclass? + xtra:dict|None=None, # Extra constraints **kwargs)->list: "Shortcut for `rows_where` or `pks_and_rows_where`, depending on `with_pk`" f = getattr(self, 'pks_and_rows_where' if with_pk else 'rows_where') - xtra = getattr(self, 'xtra_id', {}) + if not xtra: xtra = getattr(self, 'xtra_id', {}) if xtra: xw = ' and '.join(f"[{k}] = {v!r}" for k,v in xtra.items()) where = f'{xw} and {where}' if where else xw diff --git a/fastlite/kw.py b/fastlite/kw.py index e0cada1..34e2266 100644 --- a/fastlite/kw.py +++ b/fastlite/kw.py @@ -46,10 +46,10 @@ def ids_and_rows_where( yield row.pop('__rid'), row @patch -def get(self:Table, pk_values: list|tuple|str|int, as_cls:bool=True)->Any: +def get(self:Table, pk_values: list|tuple|str|int, as_cls:bool=True, xtra:dict|None=None)->Any: if not isinstance(pk_values, (list, tuple)): pk_values = [pk_values] last_pk = pk_values[0] if len(self.pks) == 1 else pk_values - xtra = getattr(self, 'xtra_id', {}) + if not xtra: xtra = getattr(self, 'xtra_id', {}) vals = list(pk_values) + list(xtra.values()) pks = self.pks + list(xtra.keys()) if len(pks)!=len(vals): raise NotFoundError(f"Need {len(pks)} pk") @@ -118,10 +118,10 @@ def _process_row(row): return {k:(v.value if isinstance(v, Enum) else v) for k,v @patch def update(self:Table, updates: dict|None=None, pk_values: list|tuple|str|int|float|None=None, - alter: bool=False, conversions: dict|None=None, **kwargs): + alter: bool=False, conversions: dict|None=None, xtra:dict|None=None, **kwargs): if not updates: updates={} updates = _process_row(updates) - xtra = getattr(self, 'xtra_id', {}) + if not xtra: xtra = getattr(self, 'xtra_id', {}) updates = {**updates, **kwargs, **xtra} if pk_values is None: pk_values = [updates[o] for o in self.pks] self._orig_update(pk_values, updates=updates, alter=alter, conversions=conversions) @@ -143,9 +143,9 @@ def insert_all( conversions: Union[Dict[str, str], Default, None]=DEFAULT, columns: Union[Dict[str, Any], Default, None]=DEFAULT, strict: opt_bool=DEFAULT, - upsert:bool=False, analyze:bool=False, + upsert:bool=False, analyze:bool=False, xtra:dict|None=None, **kwargs) -> Table: - xtra = getattr(self,'xtra_id',{}) + if not xtra: xtra = getattr(self,'xtra_id',{}) records = [_process_row(o) for o in records] records = [{**o, **xtra} for o in records] return self._orig_insert_all( diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index 24b0fa1..a96859e 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -530,10 +530,11 @@ " select:str = \"*\", # Comma-separated list of columns to select\n", " with_pk:bool=False, # Return tuple of (pk,row)?\n", " as_cls:bool=True, # Convert returned dict to stored dataclass?\n", + " xtra:dict|None=None, # Extra constraints\n", " **kwargs)->list:\n", " \"Shortcut for `rows_where` or `pks_and_rows_where`, depending on `with_pk`\"\n", " f = getattr(self, 'pks_and_rows_where' if with_pk else 'rows_where')\n", - " xtra = getattr(self, 'xtra_id', {})\n", + " if not xtra: xtra = getattr(self, 'xtra_id', {})\n", " if xtra:\n", " xw = ' and '.join(f\"[{k}] = {v!r}\" for k,v in xtra.items())\n", " where = f'{xw} and {where}' if where else xw\n",