diff --git a/CHANGELOG.md b/CHANGELOG.md index eb78efeb4..53ea3c13d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ ## Release notes +### Upcoming +- Added - `dj.Top` restriction ([#1024](https://github.com/datajoint/datajoint-python/issues/1024)) PR [#1084](https://github.com/datajoint/datajoint-python/pull/1084) + ### 0.14.1 -- Jun 02, 2023 - Fixed - Fix altering a part table that uses the "master" keyword - PR [#991](https://github.com/datajoint/datajoint-python/pull/991) - Fixed - `.ipynb` output in tutorials is not visible in dark mode ([#1078](https://github.com/datajoint/datajoint-python/issues/1078)) PR [#1080](https://github.com/datajoint/datajoint-python/pull/1080) diff --git a/LNX-docker-compose.yml b/LNX-docker-compose.yml index 970552860..9c0a95b78 100644 --- a/LNX-docker-compose.yml +++ b/LNX-docker-compose.yml @@ -44,7 +44,7 @@ services: interval: 15s fakeservices.datajoint.io: <<: *net - image: datajoint/nginx:v0.2.5 + image: datajoint/nginx:v0.2.6 environment: - ADD_db_TYPE=DATABASE - ADD_db_ENDPOINT=db:3306 diff --git a/datajoint/__init__.py b/datajoint/__init__.py index b73ade94a..a1b2befd8 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -37,6 +37,7 @@ "Part", "Not", "AndList", + "Top", "U", "Diagram", "Di", @@ -61,7 +62,7 @@ from .schemas import VirtualModule, list_schemas from .table import Table, FreeTable from .user_tables import Manual, Lookup, Imported, Computed, Part -from .expression import Not, AndList, U +from .expression import Not, AndList, U, Top from .diagram import Diagram from .admin import set_password, kill from .blob import MatCell, MatStruct diff --git a/datajoint/condition.py b/datajoint/condition.py index 80786c84c..de6372c6a 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -10,6 +10,8 @@ import pandas import json from .errors import DataJointError +from typing import Union, List +from dataclasses import dataclass JSON_PATTERN = re.compile( r"^(?P\w+)(\.(?P[\w.*\[\]]+))?(:(?P[\w(,\s)]+))?$" @@ -61,6 +63,35 @@ def append(self, restriction): super().append(restriction) +@dataclass +class Top: + """ + A restriction to the top entities of a query. + In SQL, this corresponds to ORDER BY ... LIMIT ... OFFSET + """ + + limit: Union[int, None] = 1 + order_by: Union[str, List[str]] = "KEY" + offset: int = 0 + + def __post_init__(self): + self.order_by = self.order_by or ["KEY"] + self.offset = self.offset or 0 + + if self.limit is not None and not isinstance(self.limit, int): + raise TypeError("Top limit must be an integer") + if not isinstance(self.order_by, (str, collections.abc.Sequence)) or not all( + isinstance(r, str) for r in self.order_by + ): + raise TypeError("Top order_by attributes must all be strings") + if not isinstance(self.offset, int): + raise TypeError("The offset argument must be an integer") + if self.offset and self.limit is None: + self.limit = 999999999999 # arbitrary large number to allow query + if isinstance(self.order_by, str): + self.order_by = [self.order_by] + + class Not: """invert restriction""" diff --git a/datajoint/declare.py b/datajoint/declare.py index 683e34759..c99c541f0 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -443,9 +443,11 @@ def format_attribute(attr): return f"`{attr}`" return f"({attr})" - match = re.match( - r"(?Punique\s+)?index\s*\(\s*(?P.*)\)", line, re.I - ).groupdict() + match = re.match(r"(?Punique\s+)?index\s*\(\s*(?P.*)\)", line, re.I) + if match is None: + raise DataJointError(f'Table definition syntax error in line "{line}"') + match = match.groupdict() + attr_list = re.findall(r"(?:[^,(]|\([^)]*\))+", match["args"]) index_sql.append( "{unique}index ({attrs})".format( diff --git a/datajoint/expression.py b/datajoint/expression.py index 25dd2fe40..cce40e2e6 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -9,6 +9,7 @@ from .preview import preview, repr_html from .condition import ( AndList, + Top, Not, make_condition, assert_join_compatibility, @@ -52,6 +53,7 @@ class QueryExpression: _connection = None _heading = None _support = None + _top = None # If the query will be using distinct _distinct = False @@ -119,17 +121,33 @@ def where_clause(self): else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction) ) + def sorting_clauses(self): + if not self._top: + return "" + clause = ", ".join( + _wrap_attributes( + _flatten_attribute_list(self.primary_key, self._top.order_by) + ) + ) + if clause: + clause = f" ORDER BY {clause}" + if self._top.limit is not None: + clause += f" LIMIT {self._top.limit}{f' OFFSET {self._top.offset}' if self._top.offset else ''}" + + return clause + def make_sql(self, fields=None): """ Make the SQL SELECT statement. :param fields: used to explicitly set the select attributes """ - return "SELECT {distinct}{fields} FROM {from_}{where}".format( + return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format( distinct="DISTINCT " if self._distinct else "", fields=self.heading.as_sql(fields or self.heading.names), from_=self.from_clause(), where=self.where_clause(), + sorting=self.sorting_clauses(), ) # --------- query operators ----------- @@ -187,6 +205,14 @@ def restrict(self, restriction): string, or an AndList. """ attributes = set() + if isinstance(restriction, Top): + result = ( + self.make_subquery() + if self._top and not self._top.__eq__(restriction) + else copy.copy(self) + ) # make subquery to avoid overwriting existing Top + result._top = restriction + return result new_condition = make_condition(self, restriction, attributes) if new_condition is True: return self # restriction has no effect, return the same object @@ -200,8 +226,10 @@ def restrict(self, restriction): pass # all ok # If the new condition uses any new attributes, a subquery is required. # However, Aggregation's HAVING statement works fine with aliased attributes. - need_subquery = isinstance(self, Union) or ( - not isinstance(self, Aggregation) and self.heading.new_attributes + need_subquery = ( + isinstance(self, Union) + or (not isinstance(self, Aggregation) and self.heading.new_attributes) + or self._top ) if need_subquery: result = self.make_subquery() @@ -537,19 +565,20 @@ def tail(self, limit=25, **fetch_kwargs): def __len__(self): """:return: number of elements in the result set e.g. ``len(q1)``.""" - return self.connection.query( + result = self.make_subquery() if self._top else copy.copy(self) + return result.connection.query( "SELECT {select_} FROM {from_}{where}".format( select_=( "count(*)" - if any(self._left) + if any(result._left) else "count(DISTINCT {fields})".format( - fields=self.heading.as_sql( - self.primary_key, include_aliases=False + fields=result.heading.as_sql( + result.primary_key, include_aliases=False ) ) ), - from_=self.from_clause(), - where=self.where_clause(), + from_=result.from_clause(), + where=result.where_clause(), ) ).fetchone()[0] @@ -617,18 +646,12 @@ def __next__(self): # -- move on to next entry. return next(self) - def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): + def cursor(self, as_dict=False): """ See expression.fetch() for input description. :return: query cursor """ - if offset and limit is None: - raise DataJointError("limit is required when offset is set") sql = self.make_sql() - if order_by is not None: - sql += " ORDER BY " + ", ".join(order_by) - if limit is not None: - sql += " LIMIT %d" % limit + (" OFFSET %d" % offset if offset else "") logger.debug(sql) return self.connection.query(sql, as_dict=as_dict) @@ -699,21 +722,24 @@ def make_sql(self, fields=None): fields = self.heading.as_sql(fields or self.heading.names) assert self._grouping_attributes or not self.restriction distinct = set(self.heading.names) == set(self.primary_key) - return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format( - distinct="DISTINCT " if distinct else "", - fields=fields, - from_=self.from_clause(), - where=self.where_clause(), - group_by="" - if not self.primary_key - else ( - " GROUP BY `%s`" % "`,`".join(self._grouping_attributes) - + ( - "" - if not self.restriction - else " HAVING (%s)" % ")AND(".join(self.restriction) - ) - ), + return ( + "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format( + distinct="DISTINCT " if distinct else "", + fields=fields, + from_=self.from_clause(), + where=self.where_clause(), + group_by="" + if not self.primary_key + else ( + " GROUP BY `%s`" % "`,`".join(self._grouping_attributes) + + ( + "" + if not self.restriction + else f" HAVING ({')AND('.join(self.restriction)})" + ) + ), + sorting=self.sorting_clauses(), + ) ) def __len__(self): @@ -772,7 +798,7 @@ def make_sql(self): ): # no secondary attributes: use UNION DISTINCT fields = arg1.primary_key - return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`".format( + return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}`{sorting}".format( sql1=arg1.make_sql() if isinstance(arg1, Union) else arg1.make_sql(fields), @@ -780,6 +806,7 @@ def make_sql(self): if isinstance(arg2, Union) else arg2.make_sql(fields), alias=next(self.__count), + sorting=self.sorting_clauses(), ) # with secondary attributes, use union of left join with antijoin fields = self.heading.names @@ -931,3 +958,25 @@ def aggr(self, group, **named_attributes): ) aggregate = aggr # alias for aggr + + +def _flatten_attribute_list(primary_key, attrs): + """ + :param primary_key: list of attributes in primary key + :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" + :return: generator of attributes where "KEY" is replaced with its component attributes + """ + for a in attrs: + if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): + if primary_key: + yield from primary_key + elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a): + if primary_key: + yield from (q + " DESC" for q in primary_key) + else: + yield a + + +def _wrap_attributes(attr): + for entry in attr: # wrap attribute names in backquotes + yield re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, flags=re.IGNORECASE) diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 750939e5e..49d0b14c0 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -1,20 +1,18 @@ from functools import partial from pathlib import Path -import logging import pandas import itertools -import re import json import numpy as np import uuid import numbers + +from datajoint.condition import Top from . import blob, hash from .errors import DataJointError from .settings import config from .utils import safe_write -logger = logging.getLogger(__name__.split(".")[0]) - class key: """ @@ -119,21 +117,6 @@ def _get(connection, attr, data, squeeze, download_path): ) -def _flatten_attribute_list(primary_key, attrs): - """ - :param primary_key: list of attributes in primary key - :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" - :return: generator of attributes where "KEY" is replaces with its component attributes - """ - for a in attrs: - if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): - yield from primary_key - elif re.match(r"^\s*KEY\s+[Dd][Ee][Ss][Cc]\s*$", a): - yield from (q + " DESC" for q in primary_key) - else: - yield a - - class Fetch: """ A fetch object that handles retrieving elements from the table expression. @@ -174,13 +157,13 @@ def __call__( :param download_path: for fetches that download data, e.g. attachments :return: the contents of the table in the form of a structured numpy.array or a dict list """ - if order_by is not None: - # if 'order_by' passed in a string, make into list - if isinstance(order_by, str): - order_by = [order_by] - # expand "KEY" or "KEY DESC" - order_by = list( - _flatten_attribute_list(self._expression.primary_key, order_by) + if offset or order_by or limit: + self._expression = self._expression.restrict( + Top( + limit, + order_by, + offset, + ) ) attrs_as_dict = as_dict and attrs @@ -212,13 +195,6 @@ def __call__( 'use "array" or "frame"'.format(format) ) - if limit is None and offset is not None: - logger.warning( - "Offset set, but no limit. Setting limit to a large number. " - "Consider setting a limit explicitly." - ) - limit = 8000000000 # just a very large number to effect no limit - get = partial( _get, self._expression.connection, @@ -255,9 +231,7 @@ def __call__( ] ret = return_values[0] if len(attrs) == 1 else return_values else: # fetch all attributes as a numpy.record_array or pandas.DataFrame - cur = self._expression.cursor( - as_dict=as_dict, limit=limit, offset=offset, order_by=order_by - ) + cur = self._expression.cursor(as_dict=as_dict) heading = self._expression.heading if as_dict: ret = [ diff --git a/docs/src/query/operators.md b/docs/src/query/operators.md index 9c9258442..550108c75 100644 --- a/docs/src/query/operators.md +++ b/docs/src/query/operators.md @@ -17,8 +17,9 @@ DataJoint implements a complete algebra of operators on tables: | [aggr](#aggr) | A.aggr(B, ...) | Same as projection with computations based on matching information in B | | [union](#union) | A + B | All unique entities from both A and B | | [universal set](#universal-set)\*| dj.U() | All unique entities from both A and B | +| [top](#top)\*| dj.Top() | The top rows of A -\*While not technically a query operator, it is useful to discuss Universal Set in the +\*While not technically query operators, it is useful to discuss Universal Set and Top in the same context. ??? note "Notes on relational algebra" @@ -218,6 +219,29 @@ The examples below will use the table definitions in [table tiers](../reproduce/ +## Top + +Similar to the universal set operator, the top operator uses `dj.Top` notation. It is used to +restrict a query by the given `limit`, `order_by`, and `offset` parameters: + +```python +Session & dj.Top(limit=10, order_by='session_date') +``` + +The result of this expression returns the first 10 rows of `Session` and sorts them +by their `session_date` in ascending order. + +### `order_by` + +| Example | Description | +|-------------------------------------------|---------------------------------------------------------------------------------| +| `order_by="session_date DESC"` | Sort by `session_date` in *descending* order | +| `order_by="KEY"` | Sort by the primary key | +| `order_by="KEY DESC"` | Sort by the primary key in *descending* order | +| `order_by=["subject_id", "session_date"]` | Sort by `subject_id`, then sort matching `subject_id`s by their `session_date` | + +The default values for `dj.Top` parameters are `limit=1`, `order_by="KEY"`, and `offset=0`. + ## Restriction `&` and `-` operators permit restriction. diff --git a/local-docker-compose.yml b/local-docker-compose.yml index 8b43289d3..62b52ad66 100644 --- a/local-docker-compose.yml +++ b/local-docker-compose.yml @@ -46,7 +46,7 @@ services: interval: 15s fakeservices.datajoint.io: <<: *net - image: datajoint/nginx:v0.2.5 + image: datajoint/nginx:v0.2.6 environment: - ADD_db_TYPE=DATABASE - ADD_db_ENDPOINT=db:3306 diff --git a/tests_old/schema_simple.py b/tests_old/schema_simple.py index 78f64d036..3f0c29b8d 100644 --- a/tests_old/schema_simple.py +++ b/tests_old/schema_simple.py @@ -14,6 +14,24 @@ schema = dj.Schema(PREFIX + "_relational", locals(), connection=dj.conn(**CONN_INFO)) +@schema +class SelectPK(dj.Lookup): + definition = """ # tests sql keyword escaping + id: int + select : int + """ + contents = list(dict(id=i, select=i * j) for i in range(3) for j in range(4, 0, -1)) + + +@schema +class KeyPK(dj.Lookup): + definition = """ # tests sql keyword escaping + id : int + key : int + """ + contents = list(dict(id=i, key=i + j) for i in range(3) for j in range(4, 0, -1)) + + @schema class IJ(dj.Lookup): definition = """ # tests restrictions diff --git a/tests_old/test_declare.py b/tests_old/test_declare.py index 67f532449..bb23be276 100644 --- a/tests_old/test_declare.py +++ b/tests_old/test_declare.py @@ -341,3 +341,12 @@ class WithSuchALongPartNameThatItCrashesMySQL(dj.Part): definition = """ -> (master) """ + + @staticmethod + @raises(dj.DataJointError) + def test_regex_mismatch(): + @schema + class IndexAttribute(dj.Manual): + definition = """ + index: int + """ diff --git a/tests_old/test_fetch.py b/tests_old/test_fetch.py index 684cd4846..af0156c6a 100644 --- a/tests_old/test_fetch.py +++ b/tests_old/test_fetch.py @@ -213,26 +213,6 @@ def test_offset(self): np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different" ) - def test_limit_warning(self): - """Tests whether warning is raised if offset is used without limit.""" - log_capture = io.StringIO() - stream_handler = logging.StreamHandler(log_capture) - log_format = logging.Formatter( - "[%(asctime)s][%(funcName)s][%(levelname)s]: %(message)s" - ) - stream_handler.setFormatter(log_format) - stream_handler.set_name("test_limit_warning") - logger.addHandler(stream_handler) - self.lang.fetch(offset=1) - - log_contents = log_capture.getvalue() - log_capture.close() - - for handler in logger.handlers: # Clean up handler - if handler.name == "test_limit_warning": - logger.removeHandler(handler) - assert "[WARNING]: Offset set, but no limit." in log_contents - def test_len(self): """Tests __len__""" assert_equal( diff --git a/tests_old/test_relational_operand.py b/tests_old/test_relational_operand.py index 0611ab267..3ba6291da 100644 --- a/tests_old/test_relational_operand.py +++ b/tests_old/test_relational_operand.py @@ -11,9 +11,11 @@ raises, assert_set_equal, assert_list_equal, + assert_raises, ) import datajoint as dj +from datajoint.errors import DataJointError from .schema_simple import ( A, B, @@ -23,6 +25,8 @@ L, DataA, DataB, + SelectPK, + KeyPK, TTestUpdate, IJ, JI, @@ -487,6 +491,95 @@ def test_restrictions_by_lists(): ) assert_true(len(w - y) == 0, "incorrect restriction without common attributes") + @staticmethod + def test_restrictions_by_top(): + a = L() & dj.Top() + b = L() & dj.Top(order_by=["cond_in_l", "KEY"]) + x = L() & dj.Top(5, "id_l desc", 4) & "cond_in_l=1" + y = L() & "cond_in_l=1" & dj.Top(5, "id_l desc", 4) + z = ( + L() + & dj.Top(None, order_by="id_l desc") + & "cond_in_l=1" + & dj.Top(5, "id_l desc") + & ("id_l=20", "id_l=16", "id_l=17") + & dj.Top(2, "id_l asc", 1) + ) + assert len(a) == 1 + assert len(b) == 1 + assert len(x) == 1 + assert len(y) == 5 + assert len(z) == 2 + assert a.fetch(as_dict=True) == [ + {"id_l": 0, "cond_in_l": 1}, + ] + assert b.fetch(as_dict=True) == [ + {"id_l": 3, "cond_in_l": 0}, + ] + assert x.fetch(as_dict=True) == [{"id_l": 25, "cond_in_l": 1}] + assert y.fetch(as_dict=True) == [ + {"id_l": 16, "cond_in_l": 1}, + {"id_l": 15, "cond_in_l": 1}, + {"id_l": 11, "cond_in_l": 1}, + {"id_l": 10, "cond_in_l": 1}, + {"id_l": 5, "cond_in_l": 1}, + ] + assert z.fetch(as_dict=True) == [ + {"id_l": 17, "cond_in_l": 1}, + {"id_l": 20, "cond_in_l": 1}, + ] + + @staticmethod + def test_top_restriction_with_keywords(): + select = SelectPK() & dj.Top(limit=9, order_by=["select desc"]) + key = KeyPK() & dj.Top(limit=9, order_by="key desc") + assert select.fetch(as_dict=True) == [ + {"id": 2, "select": 8}, + {"id": 2, "select": 6}, + {"id": 1, "select": 4}, + {"id": 2, "select": 4}, + {"id": 1, "select": 3}, + {"id": 1, "select": 2}, + {"id": 2, "select": 2}, + {"id": 1, "select": 1}, + {"id": 0, "select": 0}, + ] + assert key.fetch(as_dict=True) == [ + {"id": 2, "key": 6}, + {"id": 2, "key": 5}, + {"id": 1, "key": 5}, + {"id": 0, "key": 4}, + {"id": 1, "key": 4}, + {"id": 2, "key": 4}, + {"id": 0, "key": 3}, + {"id": 1, "key": 3}, + {"id": 2, "key": 3}, + ] + + @staticmethod + def test_top_errors(): + with assert_raises(DataJointError) as err1: + L() & ("cond_in_l=1", dj.Top()) + with assert_raises(DataJointError) as err2: + L() & dj.AndList(["cond_in_l=1", dj.Top()]) + with assert_raises(TypeError) as err3: + L() & dj.Top(limit="1") + with assert_raises(TypeError) as err4: + L() & dj.Top(order_by=1) + with assert_raises(TypeError) as err5: + L() & dj.Top(offset="1") + assert ( + "Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" + == str(err1.exception) + ) + assert ( + "Invalid restriction type Top(limit=1, order_by=['KEY'], offset=0)" + == str(err2.exception) + ) + assert "Top limit must be an integer" == str(err3.exception) + assert "Top order_by attributes must all be strings" == str(err4.exception) + assert "The offset argument must be an integer" == str(err5.exception) + @staticmethod def test_datetime(): """Test date retrieval""" diff --git a/tests_old/test_schema.py b/tests_old/test_schema.py index 8ec24fc49..f7a18198e 100644 --- a/tests_old/test_schema.py +++ b/tests_old/test_schema.py @@ -155,6 +155,8 @@ def test_list_tables(): "#website", "profile", "profile__website", + "#select_p_k", + "#key_p_k", ] ) == set(schema_simple.list_tables())