From 95591d6b2b19563300b3923e47472975ec6d5d83 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 28 Jul 2022 19:24:13 +0200 Subject: [PATCH 1/5] Add type annotations to public facing APIs --- .editorconfig | 9 +- .pre-commit-config.yaml | 8 + CHANGELOG.md | 2 + bin/make-unasync | 3 + docs/source/conf.py | 3 + neo4j/__init__.py | 5 + neo4j/_async/driver.py | 351 +++-- neo4j/_async/io/_bolt.py | 8 +- neo4j/_async/io/_bolt3.py | 7 +- neo4j/_async/work/__init__.py | 4 +- neo4j/_async/work/result.py | 87 +- neo4j/_async/work/session.py | 79 +- neo4j/_async/work/transaction.py | 51 +- neo4j/_async/work/workspace.py | 4 +- neo4j/_async_compat/network/_bolt_socket.py | 16 +- .../_codec/hydration/v2/hydration_handler.py | 4 +- neo4j/_codec/hydration/v2/temporal.py | 4 +- neo4j/_codec/packstream/v1/__init__.py | 5 - neo4j/_conf.py | 8 - neo4j/_data.py | 53 +- neo4j/_meta.py | 25 +- neo4j/_spatial/__init__.py | 71 +- neo4j/_sync/driver.py | 352 +++-- neo4j/_sync/io/_bolt.py | 8 +- neo4j/_sync/io/_bolt3.py | 7 +- neo4j/_sync/work/__init__.py | 4 +- neo4j/_sync/work/result.py | 87 +- neo4j/_sync/work/session.py | 79 +- neo4j/_sync/work/transaction.py | 51 +- neo4j/_sync/work/workspace.py | 4 +- neo4j/addressing.py | 77 +- neo4j/api.py | 132 +- neo4j/debug.py | 39 +- neo4j/exceptions.py | 78 +- neo4j/graph/__init__.py | 160 ++- neo4j/packstream.py | 4 - neo4j/py.typed | 0 neo4j/spatial/__init__.py | 16 - neo4j/time/__init__.py | 1182 +++++++++-------- neo4j/time/_arithmetic.py | 13 +- neo4j/work/query.py | 27 +- neo4j/work/summary.py | 68 +- requirements-dev.txt | 2 + setup.cfg | 5 + tests/env.py | 14 +- tests/unit/async_/work/test_result.py | 18 +- .../common/spatial/test_cartesian_point.py | 10 +- tests/unit/common/spatial/test_point.py | 35 +- tests/unit/common/spatial/test_wgs84_point.py | 8 +- tests/unit/common/test_addressing.py | 41 +- tests/unit/common/test_api.py | 108 +- tests/unit/common/test_conf.py | 2 - tests/unit/common/test_debug.py | 37 +- tests/unit/common/test_import_neo4j.py | 2 + tests/unit/common/test_record.py | 79 +- tests/unit/common/test_security.py | 14 +- tests/unit/common/test_types.py | 2 +- tests/unit/common/time/test_date.py | 213 +-- tests/unit/common/time/test_datetime.py | 132 +- tests/unit/common/time/test_duration.py | 166 ++- tests/unit/common/time/test_time.py | 83 +- tests/unit/sync/work/test_result.py | 18 +- 62 files changed, 2432 insertions(+), 1752 deletions(-) create mode 100644 neo4j/py.typed diff --git a/.editorconfig b/.editorconfig index 823ad91d..f075225d 100644 --- a/.editorconfig +++ b/.editorconfig @@ -16,7 +16,12 @@ trim_trailing_whitespace = true [*.bat] end_of_line = crlf -[*.py] -max_line_length = 79 +[*{.py,pyi}] indent_style = space indent_size = 4 + +[*.py] +max_line_length = 79 + +[*.pyi] +max_line_length = 130 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c4a7f3d6..009fe61d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,3 +34,11 @@ repos: entry: bin/make-unasync language: system files: "^(neo4j/_async|tests/(unit|integration)/async_|testkitbackend/_async)/.*" + - id: mypy + name: mypy static type check + entry: mypy + args: [ --show-error-codes, neo4j, tests, testkitbackend ] + 'types_or': [ python, pyi ] + language: system + pass_filenames: false + require_serial: true diff --git a/CHANGELOG.md b/CHANGELOG.md index 12caefe0..3c6036f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -118,6 +118,8 @@ be used by client code. `Record` should be imported directly from `neo4j` instead. `neo4j.data.DataHydrator` and `neo4j.data.DataDeydrator` have been removed without replacement. +- Removed undocumented config options that had no effect: + `protocol_version` and `init_size`. ## Version 4.4 diff --git a/bin/make-unasync b/bin/make-unasync index d99a51a9..f1b99465 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -104,6 +104,9 @@ class CustomRule(unasync.Rule): def __init__(self, *args, **kwargs): super(CustomRule, self).__init__(*args, **kwargs) self.out_files = [] + # it's not pretty, but it works + # typing.Awaitable[...] -> typing.Union[...] + self.token_replacements["Awaitable"] = "Union" def _unasync_tokens(self, tokens): # copy from unasync to hook into string handling diff --git a/docs/source/conf.py b/docs/source/conf.py index d6cfe9fe..256d51f7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -115,6 +115,9 @@ # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True +# Don't include type hints in function signatures +autodoc_typehints = "description" + # -- Options for HTML output ---------------------------------------------- diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 8da6854d..df8590be 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -108,6 +108,7 @@ "BoltDriver", "Bookmark", "Bookmarks", + "Config", "custom_auth", "DEFAULT_DATABASE", "Driver", @@ -117,8 +118,10 @@ "IPv4Address", "IPv6Address", "kerberos_auth", + "log", "ManagedTransaction", "Neo4jDriver", + "PoolConfig", "Query", "READ_ACCESS", "Record", @@ -126,6 +129,7 @@ "ResultSummary", "ServerInfo", "Session", + "SessionConfig", "SummaryCounters", "Transaction", "TRUST_ALL_CERTIFICATES", @@ -135,6 +139,7 @@ "TrustSystemCAs", "unit_of_work", "Version", + "WorkspaceConfig", "WRITE_ACCESS", ] diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index bec20108..377c506d 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -16,6 +16,16 @@ # limitations under the License. +from __future__ import annotations + +import typing as t + + +if t.TYPE_CHECKING: + import typing_extensions as te + + import ssl + from .._async_compat.util import AsyncUtil from .._conf import ( Config, @@ -33,125 +43,175 @@ ) from ..addressing import Address from ..api import ( + Auth, + Bookmarks, + DRIVER_BOLT, + DRIVER_NEO4J, + parse_neo4j_uri, + parse_routing_context, READ_ACCESS, + SECURITY_TYPE_SECURE, + SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, + ServerInfo, TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + URI_SCHEME_BOLT, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J, + URI_SCHEME_NEO4J_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, ) +from .work import AsyncSession class AsyncGraphDatabase: """Accessor for :class:`neo4j.Driver` construction. """ - @classmethod - @AsyncUtil.experimental_async( - "neo4j async is in experimental phase. It might be removed or changed " - "at any time (including patch releases)." - ) - def driver(cls, uri, *, auth=None, **config): - """Create a driver. - - :param uri: the connection URI for the driver, see :ref:`async-uri-ref` for available URIs. - :param auth: the authentication details, see :ref:`auth-ref` for available authentication details. - :param config: driver configuration key-word arguments, see :ref:`async-driver-configuration-ref` for available key-word arguments. - - :rtype: AsyncNeo4jDriver or AsyncBoltDriver - """ - - from ..api import ( - DRIVER_BOLT, - DRIVER_NEO4j, - parse_neo4j_uri, - parse_routing_context, - SECURITY_TYPE_SECURE, - SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_BOLT, - URI_SCHEME_BOLT_SECURE, - URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_NEO4J, - URI_SCHEME_NEO4J_SECURE, - URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + if t.TYPE_CHECKING: + + @classmethod + def driver( + cls, + uri: str, + *, + auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = ..., + max_connection_lifetime: float = ..., + max_connection_pool_size: int = ..., + connection_timeout: float = ..., + update_routing_table_timeout: float = ..., + trust: t.Union[ + te.Literal["TRUST_ALL_CERTIFICATES"], + te.Literal["TRUST_SYSTEM_CA_SIGNED_CERTIFICATES"] + ] = ..., + resolver: t.Union[ + t.Callable[[Address], t.Iterable[Address]], + t.Callable[[Address], t.Awaitable[t.Iterable[Address]]], + ] = ..., + encrypted: bool = ..., + trusted_certificates: TrustStore = ..., + ssl_context: ssl.SSLContext = ..., + user_agent: str = ..., + keep_alive: bool = ..., + + # undocumented/unsupported options + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ... + ) -> AsyncDriver: + ... + + else: + + @classmethod + @AsyncUtil.experimental_async( + "neo4j async is in experimental phase. It might be removed or " + "changed at any time (including patch releases)." ) + def driver( + cls, + uri: str, + *, + auth: t.Union[t.Tuple[t.Any, t.Any], Auth] = None, + **config # TODO: type config + ) -> AsyncDriver: + """Create a driver. + + :param uri: the connection URI for the driver, see :ref:`async-uri-ref` for available URIs. + :param auth: the authentication details, see :ref:`auth-ref` for available authentication details. + :param config: driver configuration key-word arguments, see :ref:`async-driver-configuration-ref` for available key-word arguments. + """ + + driver_type, security_type, parsed = parse_neo4j_uri(uri) + + # TODO: 6.0 remove "trust" config option + if "trust" in config.keys(): + if config["trust"] not in ( + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES + ): + from neo4j.exceptions import ConfigurationError + raise ConfigurationError( + "The config setting `trust` values are {!r}" + .format( + [ + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + ] + ) + ) - driver_type, security_type, parsed = parse_neo4j_uri(uri) + if ("trusted_certificates" in config.keys() + and not isinstance(config["trusted_certificates"], + TrustStore)): + raise ConnectionError( + "The config setting `trusted_certificates` must be of " + "type neo4j.TrustAll, neo4j.TrustCustomCAs, or" + "neo4j.TrustSystemCAs but was {}".format( + type(config["trusted_certificates"]) + ) + ) - # TODO: 6.0 remove "trust" config option - if "trust" in config.keys(): - if config["trust"] not in (TRUST_ALL_CERTIFICATES, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES): + if (security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] + and ("encrypted" in config.keys() + or "trust" in config.keys() + or "trusted_certificates" in config.keys() + or "ssl_context" in config.keys())): from neo4j.exceptions import ConfigurationError + + # TODO: 6.0 remove "trust" from error message raise ConfigurationError( - "The config setting `trust` values are {!r}" + 'The config settings "encrypted", "trust", ' + '"trusted_certificates", and "ssl_context" can only be ' + "used with the URI schemes {!r}. Use the other URI " + "schemes {!r} for setting encryption settings." .format( [ - TRUST_ALL_CERTIFICATES, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + URI_SCHEME_BOLT, + URI_SCHEME_NEO4J, + ], + [ + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J_SECURE, ] ) ) - if ("trusted_certificates" in config.keys() - and not isinstance(config["trusted_certificates"], - TrustStore)): - raise ConnectionError( - "The config setting `trusted_certificates` must be of type " - "neo4j.TrustAll, neo4j.TrustCustomCAs, or" - "neo4j.TrustSystemCAs but was {}".format( - type(config["trusted_certificates"]) - ) - ) - - if (security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] - and ("encrypted" in config.keys() - or "trust" in config.keys() - or "trusted_certificates" in config.keys() - or "ssl_context" in config.keys())): - from neo4j.exceptions import ConfigurationError - - # TODO: 6.0 remove "trust" from error message - raise ConfigurationError( - 'The config settings "encrypted", "trust", ' - '"trusted_certificates", and "ssl_context" can only be used ' - "with the URI schemes {!r}. Use the other URI schemes {!r} " - "for setting encryption settings." - .format( - [ - URI_SCHEME_BOLT, - URI_SCHEME_NEO4J, - ], - [ - URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_BOLT_SECURE, - URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_NEO4J_SECURE, - ] - ) - ) - - if security_type == SECURITY_TYPE_SECURE: - config["encrypted"] = True - elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE: - config["encrypted"] = True - config["trusted_certificates"] = TrustAll() - - if driver_type == DRIVER_BOLT: - if parse_routing_context(parsed.query): - deprecation_warn( - "Creating a direct driver (`bolt://` scheme) with routing " - "context (URI parameters) is deprecated. They will be " - "ignored. This will raise an error in a future release. " - 'Given URI "{}"'.format(uri), - stack_level=2 - ) - # TODO: 6.0 - raise instead of warning - # raise ValueError( - # 'Routing parameters are not supported with scheme ' - # '"bolt". Given URI "{}".'.format(uri) - # ) - return cls.bolt_driver(parsed.netloc, auth=auth, **config) - elif driver_type == DRIVER_NEO4j: + if security_type == SECURITY_TYPE_SECURE: + config["encrypted"] = True + elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE: + config["encrypted"] = True + config["trusted_certificates"] = TrustAll() + + assert driver_type in (DRIVER_BOLT, DRIVER_NEO4J) + if driver_type == DRIVER_BOLT: + if parse_routing_context(parsed.query): + deprecation_warn( + "Creating a direct driver (`bolt://` scheme) with " + "routing context (URI parameters) is deprecated. They " + "will be ignored. This will raise an error in a " + 'future release. Given URI "{}"'.format(uri), + stack_level=2 + ) + # TODO: 6.0 - raise instead of warning + # raise ValueError( + # 'Routing parameters are not supported with scheme ' + # '"bolt". Given URI "{}".'.format(uri) + # ) + return cls.bolt_driver(parsed.netloc, auth=auth, **config) + # else driver_type == DRIVER_NEO4J routing_context = parse_routing_context(parsed.query) - return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config) + return cls.neo4j_driver(parsed.netloc, auth=auth, + routing_context=routing_context, **config) @classmethod def bolt_driver(cls, target, *, auth=None, **config): @@ -243,7 +303,7 @@ class AsyncDriver: """ #: Connection pool - _pool = None + _pool: t.Any = None #: Flag if the driver has been closed _closed = False @@ -254,7 +314,7 @@ def __init__(self, pool, default_workspace_config): self._pool = pool self._default_workspace_config = default_workspace_config - async def __aenter__(self): + async def __aenter__(self) -> AsyncDriver: return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -276,31 +336,49 @@ def __del__(self): self.close() @property - def encrypted(self): - """Indicate whether the driver was configured to use encryption. - - :rtype: bool""" + def encrypted(self) -> bool: + """Indicate whether the driver was configured to use encryption.""" return bool(self._pool.pool_config.encrypted) - def session(self, **config): - """Create a session, see :ref:`async-session-construction-ref` - - :param config: session configuration key-word arguments, - see :ref:`async-session-configuration-ref` for available key-word - arguments. - - :returns: new :class:`neo4j.AsyncSession` object - """ - raise NotImplementedError - - async def close(self): + if t.TYPE_CHECKING: + + def session( + self, + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., + ) -> AsyncSession: + ... + + else: + + def session(self, **config) -> AsyncSession: + """Create a session, see :ref:`async-session-construction-ref` + + :param config: session configuration key-word arguments, + see :ref:`async-session-configuration-ref` for available + key-word arguments. + + :returns: new :class:`neo4j.AsyncSession` object + """ + raise NotImplementedError + + async def close(self) -> None: """ Shut down, closing any open connections in the pool. """ await self._pool.close() self._closed = True # TODO: 6.0 - remove config argument - async def verify_connectivity(self, **config): + async def verify_connectivity(self, **config) -> None: """Verify that the driver can establish a connection to the server. This verifies if the driver can establish a reading connection to a @@ -337,7 +415,7 @@ async def verify_connectivity(self, **config): async with self.session(**config) as session: await session._get_server_info() - async def get_server_info(self, **config): + async def get_server_info(self, **config) -> ServerInfo: """Get information about the connected Neo4j server. Try to establish a working read connection to the remote server or a @@ -378,11 +456,10 @@ async def get_server_info(self, **config): return await session._get_server_info() @experimental("Feature support query, based on Bolt protocol version and Neo4j server version will change in the future.") - async def supports_multi_db(self): + async def supports_multi_db(self) -> bool: """ Check if the server or cluster supports multi-databases. :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. - :rtype: bool .. note:: Feature support query, based on Bolt Protocol Version and Neo4j @@ -390,6 +467,7 @@ async def supports_multi_db(self): """ async with self.session() as session: await session._connect(READ_ACCESS) + assert session._connection return session._connection.supports_multiple_databases @@ -426,17 +504,20 @@ def __init__(self, pool, default_workspace_config): AsyncDriver.__init__(self, pool, default_workspace_config) self._default_workspace_config = default_workspace_config - def session(self, **config): - """ - :param config: The values that can be specified are found in :class: `neo4j.SessionConfig` + if not t.TYPE_CHECKING: - :return: - :rtype: :class: `neo4j.AsyncSession` - """ - from .work import AsyncSession - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return AsyncSession(self._pool, session_config) + def session(self, **config) -> AsyncSession: + """ + :param config: The values that can be specified are found in + :class: `neo4j.SessionConfig` + + :return: + :rtype: :class: `neo4j.AsyncSession` + """ + session_config = SessionConfig(self._default_workspace_config, + config) + SessionConfig.consume(config) # Consume the config + return AsyncSession(self._pool, session_config) class AsyncNeo4jDriver(_Routing, AsyncDriver): @@ -462,8 +543,10 @@ def __init__(self, pool, default_workspace_config): _Routing.__init__(self, pool.get_default_database_initial_router_addresses()) AsyncDriver.__init__(self, pool, default_workspace_config) - def session(self, **config): - from .work import AsyncSession - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return AsyncSession(self._pool, session_config) + if not t.TYPE_CHECKING: + + def session(self, **config) -> AsyncSession: + session_config = SessionConfig(self._default_workspace_config, + config) + SessionConfig.consume(config) # Consume the config + return AsyncSession(self._pool, session_config) diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py index 7f31f32d..aba4ceb8 100644 --- a/neo4j/_async/io/_bolt.py +++ b/neo4j/_async/io/_bolt.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + import abc import asyncio from collections import deque @@ -23,14 +25,12 @@ from time import perf_counter from ..._async_compat.network import AsyncBoltSocket -from ..._async_compat.util import AsyncUtil from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 from ..._conf import PoolConfig from ..._exceptions import ( BoltError, BoltHandshakeError, - SocketDeadlineExceeded, ) from ..._meta import get_user_agent from ...addressing import Address @@ -75,7 +75,7 @@ class AsyncBolt: MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" - PROTOCOL_VERSION = None + PROTOCOL_VERSION: Version = None # type: ignore[assignment] # flag if connection needs RESET to go back to READY state is_reset = False @@ -776,4 +776,4 @@ def is_idle_for(self, timeout): return perf_counter() - self.idle_since > timeout -AsyncBoltSocket.Bolt = AsyncBolt +AsyncBoltSocket.Bolt = AsyncBolt # type: ignore diff --git a/neo4j/_async/io/_bolt3.py b/neo4j/_async/io/_bolt3.py index 653f2a10..072f7417 100644 --- a/neo4j/_async/io/_bolt3.py +++ b/neo4j/_async/io/_bolt3.py @@ -20,11 +20,7 @@ from logging import getLogger from ssl import SSLSocket -from ..._async_compat.util import AsyncUtil -from ..._exceptions import ( - BoltError, - BoltProtocolError, -) +from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, Version, @@ -32,7 +28,6 @@ from ...exceptions import ( ConfigurationError, DatabaseUnavailable, - DriverError, ForbiddenOnReadOnlyDatabase, Neo4jError, NotALeader, diff --git a/neo4j/_async/work/__init__.py b/neo4j/_async/work/__init__.py index 02099c42..264e6cb9 100644 --- a/neo4j/_async/work/__init__.py +++ b/neo4j/_async/work/__init__.py @@ -24,13 +24,15 @@ from .transaction import ( AsyncManagedTransaction, AsyncTransaction, + AsyncTransactionBase, ) __all__ = [ "AsyncResult", "AsyncSession", - "AsyncTransaction", "AsyncManagedTransaction", + "AsyncTransaction", + "AsyncTransactionBase", "AsyncWorkspace", ] diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py index 0740cf58..9ab443a3 100644 --- a/neo4j/_async/work/result.py +++ b/neo4j/_async/work/result.py @@ -16,9 +16,14 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from collections import deque from warnings import warn +import typing_extensions as te + from ..._async_compat.util import AsyncUtil from ..._codec.hydration import BrokenHydrationObject from ..._data import ( @@ -38,6 +43,16 @@ from ..io import ConnectionErrorHandler +if t.TYPE_CHECKING: + import pandas # type: ignore[import] + + from ...graph import Graph + + +_T = t.TypeVar("_T") +_T_ResultKey: te.TypeAlias = t.Union[int, str] + + _RESULT_OUT_OF_SCOPE_ERROR = ( "The result is out of scope. The associated transaction " "has been closed. Results can only be used while the " @@ -214,8 +229,9 @@ def on_success(summary_metadata): ) self._streaming = True - async def __aiter__(self): + async def __aiter__(self) -> t.AsyncIterator[Record]: """Iterator returning Records. + :returns: Record, it is an immutable ordered collection of key-value pairs. :rtype: :class:`neo4j.Record` """ @@ -237,7 +253,7 @@ async def __aiter__(self): if self._consumed: raise ResultConsumedError(self, _RESULT_CONSUMED_ERROR) - async def __anext__(self): + async def __anext__(self) -> Record: return await self.__aiter__().__anext__() async def _attach(self): @@ -297,7 +313,7 @@ def _obtain_summary(self): return self._summary - def keys(self): + def keys(self) -> t.Tuple[str, ...]: """The keys for the records in this result. :returns: tuple of key names @@ -321,7 +337,7 @@ async def _tx_end(self): await self._exhaust() self._out_of_scope = True - async def consume(self): + async def consume(self) -> ResultSummary: """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`. Example:: @@ -378,7 +394,17 @@ async def get_two_tx(tx): self._consumed = True return summary - async def single(self, strict=False): + @t.overload + async def single( + self, strict: te.Literal[False] = False + ) -> t.Optional[Record]: + ... + + @t.overload + async def single(self, strict: te.Literal[True]) -> Record: + ... + + async def single(self, strict: bool = False) -> t.Optional[Record]: """Obtain the next and only remaining record or None. Calling this method always exhausts the result. @@ -391,9 +417,7 @@ async def single(self, strict=False): instead of returning None if there is more than one record or warning if there are more than 1 record. :const:`False` by default. - :type strict: bool - :returns: the next :class:`neo4j.Record` or :const:`None` if none remain :warns: if more than one record is available :raises ResultNotSingleError: @@ -402,6 +426,8 @@ async def single(self, strict=False): was obtained has been closed or the Result has been explicitly consumed. + :returns: the next :class:`neo4j.Record` or :const:`None` if none remain + .. versionchanged:: 5.0 Added ``strict`` parameter. .. versionchanged:: 5.0 @@ -433,11 +459,10 @@ async def single(self, strict=False): ) return buffer.popleft() - async def fetch(self, n): + async def fetch(self, n: int) -> t.List[Record]: """Obtain up to n records from this result. :param n: the maximum number of records to fetch. - :type n: int :returns: list of :class:`neo4j.Record` @@ -453,7 +478,7 @@ async def fetch(self, n): for _ in range(min(n, len(self._record_buffer))) ] - async def peek(self): + async def peek(self) -> t.Optional[Record]: """Obtain the next record from this result without consuming it. This leaves the record in the buffer for further processing. @@ -470,20 +495,20 @@ async def peek(self): await self._buffer(1) if self._record_buffer: return self._record_buffer[0] + return None - async def graph(self): + async def graph(self) -> Graph: """Return a :class:`neo4j.graph.Graph` instance containing all the graph objects in the result. After calling this method, the result becomes detached, buffering all remaining records. - :returns: a result graph - :rtype: :class:`neo4j.graph.Graph` + **This is experimental.** (See :ref:`filter-warnings-ref`) :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. - **This is experimental.** (See :ref:`filter-warnings-ref`) + :returns: a result graph .. versionchanged:: 5.0 Can raise :exc:`ResultConsumedError`. @@ -491,7 +516,9 @@ async def graph(self): await self._buffer_all() return self._hydration_scope.get_graph() - async def value(self, key=0, default=None): + async def value( + self, key: _T_ResultKey = 0, default: object = None + ) -> t.List[t.Any]: """Helper function that return the remainder of the result as a list of values. See :class:`neo4j.Record.value` @@ -499,38 +526,38 @@ async def value(self, key=0, default=None): :param key: field to return for each remaining record. Obtain a single value from the record by index or key. :param default: default value, used if the index of key is unavailable - :returns: list of individual values - :rtype: list - :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + :returns: list of individual values + .. versionchanged:: 5.0 Can raise :exc:`ResultConsumedError`. """ return [record.value(key, default) async for record in self] - async def values(self, *keys): + async def values( + self, *keys: _T_ResultKey + ) -> t.List[t.List[t.Any]]: """Helper function that return the remainder of the result as a list of values lists. See :class:`neo4j.Record.values` :param keys: fields to return for each remaining record. Optionally filtering to include only certain values by index or key. - :returns: list of values lists - :rtype: list - :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + :returns: list of values lists + .. versionchanged:: 5.0 Can raise :exc:`ResultConsumedError`. """ return [record.values(*keys) async for record in self] - async def data(self, *keys): + async def data(self, *keys: _T_ResultKey) -> t.List[t.Any]: """Helper function that return the remainder of the result as a list of dictionaries. See :class:`neo4j.Record.data` @@ -551,7 +578,11 @@ async def data(self, *keys): @experimental("pandas support is experimental and might be changed or " "removed in future versions") - async def to_df(self, expand=False, parse_dates=False): + async def to_df( + self, + expand: bool = False, + parse_dates: bool = False + ) -> pandas.DataFrame: r"""Convert (the rest of) the result to a pandas DataFrame. This method is only available if the `pandas` library is installed. @@ -627,14 +658,11 @@ async def to_df(self, expand=False, parse_dates=False): :const:`dict` keys and variable names that contain ``.`` or ``\`` will be escaped with a backslash (``\.`` and ``\\`` respectively). - :type expand: bool :param parse_dates: If :const:`True`, columns that excluvively contain :class:`time.DateTime` objects, :class:`time.Date` objects, or :const:`None`, will be converted to :class:`pandas.Timestamp`. - :type parse_dates: bool - :rtype: :py:class:`pandas.DataFrame` :raises ImportError: if `pandas` library is not available. :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly @@ -644,7 +672,7 @@ async def to_df(self, expand=False, parse_dates=False): ``pandas`` support might be changed or removed in future versions without warning. (See :ref:`filter-warnings-ref`) """ - import pandas as pd + import pandas as pd # type: ignore[import] if not expand: df = pd.DataFrame(await self.values(), columns=self._keys) @@ -691,7 +719,7 @@ async def to_df(self, expand=False, parse_dates=False): ) return df - def closed(self): + def closed(self) -> bool: """Return True if the result has been closed. When a result gets consumed :meth:`consume` or the transaction that @@ -702,7 +730,6 @@ def closed(self): will raise a :exc:`ResultConsumedError` when called. :returns: whether the result is closed. - :rtype: bool .. versionadded:: 5.0 """ diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py index ee4b73ac..bb26aaf9 100644 --- a/neo4j/_async/work/session.py +++ b/neo4j/_async/work/session.py @@ -16,10 +16,22 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from logging import getLogger from random import random from time import perf_counter + +if t.TYPE_CHECKING: + import typing_extensions as te + + from ..io import AsyncBolt + + _R = t.TypeVar("_R") + _P = te.ParamSpec("_P") + from ..._async_compat import async_sleep from ..._conf import SessionConfig from ..._meta import ( @@ -73,10 +85,11 @@ class AsyncSession(AsyncWorkspace): """ # The current connection. - _connection = None + _connection: t.Optional[AsyncBolt] = None - # The current :class:`.Transaction` instance, if any. - _transaction = None + # The current transaction instance, if any. + _transaction: t.Union[AsyncTransaction, AsyncManagedTransaction, None] = \ + None # The current auto-transaction result, if any. _auto_result = None @@ -89,7 +102,7 @@ def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) self._bookmarks = self._prepare_bookmarks(session_config.bookmarks) - async def __aenter__(self): + async def __aenter__(self) -> AsyncSession: return self async def __aexit__(self, exception_type, exception_value, traceback): @@ -139,7 +152,7 @@ async def _get_server_info(self): await self._disconnect() return server_info - async def close(self): + async def close(self) -> None: """Close the session. This will release any borrowed resources, such as connections, and will @@ -183,7 +196,12 @@ async def close(self): self._state_failed = False self._closed = True - async def run(self, query, parameters=None, **kwargs): + async def run( + self, + query: t.Union[str, Query], + parameters: t.Dict[str, t.Any] = None, + **kwargs: t.Any + ) -> AsyncResult: """Run a Cypher query within an auto-commit transaction. The query is sent and the result header received @@ -202,12 +220,10 @@ async def run(self, query, parameters=None, **kwargs): For more usage details, see :meth:`.AsyncTransaction.run`. :param query: cypher query - :type query: str, neo4j.Query :param parameters: dictionary of parameters - :type parameters: dict :param kwargs: additional keyword parameters + :returns: a new :class:`neo4j.AsyncResult` object - :rtype: AsyncResult """ if not query: raise ValueError("Cannot run an empty query") @@ -224,8 +240,6 @@ async def run(self, query, parameters=None, **kwargs): if not self._connection: await self._connect(self._config.default_access_mode) cx = self._connection - protocol_version = cx.PROTOCOL_VERSION - server_info = cx.server_info self._auto_result = AsyncResult( cx, self._config.fetch_size, self._result_closed, @@ -243,7 +257,7 @@ async def run(self, query, parameters=None, **kwargs): "`last_bookmark` has been deprecated in favor of `last_bookmarks`. " "This method can lead to unexpected behaviour." ) - async def last_bookmark(self): + async def last_bookmark(self) -> t.Optional[str]: """Return the bookmark received following the last completed transaction. Note: For auto-transactions (:meth:`Session.run`), this will trigger @@ -258,7 +272,6 @@ async def last_bookmark(self): Use :meth:`last_bookmarks` instead. :returns: last bookmark - :rtype: str or None """ # The set of bookmarks to be passed into the next transaction. @@ -273,7 +286,7 @@ async def last_bookmark(self): return self._bookmarks[-1] return None - async def last_bookmarks(self): + async def last_bookmarks(self) -> Bookmarks: """Return most recent bookmarks of the session. Bookmarks can be used to causally chain sessions. For example, @@ -300,7 +313,6 @@ async def last_bookmarks(self): :meth:`Result.consume` for the current result. :returns: the session's last known bookmarks - :rtype: Bookmarks """ # The set of bookmarks to be passed into the next transaction. @@ -338,7 +350,11 @@ async def _open_transaction( self._bookmarks, access_mode, metadata, timeout ) - async def begin_transaction(self, metadata=None, timeout=None): + async def begin_transaction( + self, + metadata: t.Dict[str, t.Any] = None, + timeout: float = None + ) -> AsyncTransaction: """ Begin a new unmanaged transaction. Creates a new :class:`.AsyncTransaction` within this session. At most one transaction may exist in a session at any point in time. To maintain multiple concurrent transactions, use multiple concurrent sessions. @@ -350,7 +366,6 @@ async def begin_transaction(self, metadata=None, timeout=None): Specified metadata will be attached to the executing transaction and visible in the output of ``dbms.listQueries`` and ``dbms.listTransactions`` procedures. It will also get logged to the ``query.log``. This functionality makes it easier to tag transactions and is equivalent to ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference. - :type metadata: dict :param timeout: the transaction timeout in seconds. @@ -358,12 +373,10 @@ async def begin_transaction(self, metadata=None, timeout=None): This functionality allows to limit query/transaction execution time. Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting. Value should not represent a duration of zero or negative duration. - :type timeout: int - - :returns: A new transaction instance. - :rtype: AsyncTransaction :raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open. + + :returns: A new transaction instance. """ # TODO: Implement TransactionConfig consumption @@ -371,7 +384,9 @@ async def begin_transaction(self, metadata=None, timeout=None): await self._auto_result.consume() if self._transaction: - raise TransactionError("Explicit transaction already open") + raise TransactionError( + self._transaction, "Explicit transaction already open" + ) await self._open_transaction( tx_cls=AsyncTransaction, @@ -379,7 +394,7 @@ async def begin_transaction(self, metadata=None, timeout=None): timeout=timeout ) - return self._transaction + return t.cast(AsyncTransaction, self._transaction) async def _run_transaction( self, access_mode, transaction_function, *args, **kwargs @@ -438,7 +453,13 @@ async def _run_transaction( else: raise ServiceUnavailable("Transaction failed") - async def read_transaction(self, transaction_function, *args, **kwargs): + async def read_transaction( + self, + transaction_function: t.Callable[ + te.Concatenate[AsyncManagedTransaction, _P], t.Awaitable[_R] + ], + *args: _P.args, **kwargs: _P.kwargs + ) -> _R: """Execute a unit of work in a managed read transaction. .. note:: @@ -487,13 +508,20 @@ async def get_two_tx(tx): :class:`.AsyncTransaction`. :param args: arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` + :return: a result as returned by the given unit of work """ return await self._run_transaction( READ_ACCESS, transaction_function, *args, **kwargs ) - async def write_transaction(self, transaction_function, *args, **kwargs): + async def write_transaction( + self, + transaction_function: t.Callable[ + te.Concatenate[AsyncManagedTransaction, _P], t.Awaitable[_R] + ], + *args: _P.args, **kwargs: _P.kwargs + ) -> _R: """Execute a unit of work in a managed write transaction. .. note:: @@ -522,6 +550,7 @@ async def create_node_tx(tx, name): :class:`.AsyncTransaction`. :param args: key word arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` + :return: a result as returned by the given unit of work """ return await self._run_transaction( diff --git a/neo4j/_async/work/transaction.py b/neo4j/_async/work/transaction.py index 57000d12..84e5ca65 100644 --- a/neo4j/_async/work/transaction.py +++ b/neo4j/_async/work/transaction.py @@ -16,6 +16,9 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from functools import wraps from ..._async_compat.util import AsyncUtil @@ -25,10 +28,14 @@ from .result import AsyncResult -__all__ = ("AsyncTransaction", "AsyncManagedTransaction") +__all__ = ( + "AsyncManagedTransaction", + "AsyncTransaction", + "AsyncTransactionBase", +) -class _AsyncTransactionBase: +class AsyncTransactionBase: def __init__(self, connection, fetch_size, on_closed, on_error): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( @@ -75,7 +82,12 @@ async def _consume_results(self): await result._tx_end() self._results = [] - async def run(self, query, parameters=None, **kwparameters): + async def run( + self, + query: str, + parameters: t.Dict[str, t.Any] = None, + **kwparameters: t.Any + ) -> AsyncResult: """ Run a Cypher query within the context of this transaction. Cypher is typically expressed as a query template plus a @@ -95,15 +107,12 @@ async def run(self, query, parameters=None, **kwparameters): :class:`list` properties must be homogenous. :param query: cypher query - :type query: str :param parameters: dictionary of parameters - :type parameters: dict :param kwparameters: additional keyword parameters - :returns: a new :class:`neo4j.AsyncResult` object - :rtype: :class:`neo4j.AsyncResult` - :raise TransactionError: if the transaction is already closed + + :returns: a new :class:`neo4j.AsyncResult` object """ if isinstance(query, Query): raise ValueError("Query object is only supported for session.run") @@ -194,7 +203,7 @@ def _closed(self): return self._closed_flag -class AsyncTransaction(_AsyncTransactionBase): +class AsyncTransaction(AsyncTransactionBase): """ Container for multiple Cypher queries to be executed within a single context. :class:`AsyncTransaction` objects can be used as a context managers (:py:const:`async with` block) where the transaction is committed @@ -205,32 +214,32 @@ class AsyncTransaction(_AsyncTransactionBase): """ - @wraps(_AsyncTransactionBase._enter) - async def __aenter__(self): + @wraps(AsyncTransactionBase._enter) + async def __aenter__(self) -> AsyncTransaction: return await self._enter() - @wraps(_AsyncTransactionBase._exit) + @wraps(AsyncTransactionBase._exit) async def __aexit__(self, exception_type, exception_value, traceback): await self._exit(exception_type, exception_value, traceback) - @wraps(_AsyncTransactionBase._commit) - async def commit(self): + @wraps(AsyncTransactionBase._commit) + async def commit(self) -> None: return await self._commit() - @wraps(_AsyncTransactionBase._rollback) - async def rollback(self): + @wraps(AsyncTransactionBase._rollback) + async def rollback(self) -> None: return await self._rollback() - @wraps(_AsyncTransactionBase._close) - async def close(self): + @wraps(AsyncTransactionBase._close) + async def close(self) -> None: return await self._close() - @wraps(_AsyncTransactionBase._closed) - def closed(self): + @wraps(AsyncTransactionBase._closed) + def closed(self) -> bool: return self._closed() -class AsyncManagedTransaction(_AsyncTransactionBase): +class AsyncManagedTransaction(AsyncTransactionBase): """Transaction object provided to transaction functions. Inside a transaction function, the driver is responsible for managing diff --git a/neo4j/_async/work/workspace.py b/neo4j/_async/work/workspace.py index 9c589db5..2d3a8268 100644 --- a/neo4j/_async/work/workspace.py +++ b/neo4j/_async/work/workspace.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + import asyncio from ..._conf import WorkspaceConfig @@ -64,7 +66,7 @@ def __del__(self): except (OSError, ServiceUnavailable, SessionExpired): pass - async def __aenter__(self): + async def __aenter__(self) -> AsyncWorkspace: return self async def __aexit__(self, exc_type, exc_value, traceback): diff --git a/neo4j/_async_compat/network/_bolt_socket.py b/neo4j/_async_compat/network/_bolt_socket.py index 475536bc..d7e274e2 100644 --- a/neo4j/_async_compat/network/_bolt_socket.py +++ b/neo4j/_async_compat/network/_bolt_socket.py @@ -16,11 +16,13 @@ # limitations under the License. +from __future__ import annotations + import asyncio import logging import selectors -import socket import struct +import typing as t from socket import ( AF_INET, AF_INET6, @@ -34,8 +36,10 @@ CertificateError, HAS_SNI, SSLError, + SSLSocket, ) -from time import perf_counter + +import typing_extensions as te from ... import addressing from ..._deadline import Deadline @@ -68,7 +72,7 @@ def _sanitize_deadline(deadline): class AsyncBoltSocket: - Bolt = None + Bolt: te.Final[t.Type] = None # type: ignore[assignment] def __init__(self, reader, protocol, writer): self._reader = reader # type: asyncio.StreamReader @@ -372,7 +376,7 @@ async def connect(cls, address, *, timeout, custom_resolver, ssl_context, class BoltSocket: - Bolt = None + Bolt: te.Final[t.Type] = None # type: ignore[assignment] def __init__(self, socket_: socket): self._socket = socket_ @@ -383,12 +387,12 @@ def _socket(self): return self.__socket @_socket.setter - def _socket(self, socket_: socket): + def _socket(self, socket_: t.Union[socket, SSLSocket]): self.__socket = socket_ self.getsockname = socket_.getsockname self.getpeername = socket_.getpeername if hasattr(socket, "getpeercert"): - self.getpeercert = socket_.getpeercert + self.getpeercert = socket_.getpeercert # type: ignore elif hasattr(self, "getpeercert"): del self.getpeercert self.gettimeout = socket_.gettimeout diff --git a/neo4j/_codec/hydration/v2/hydration_handler.py b/neo4j/_codec/hydration/v2/hydration_handler.py index 092201a0..167fab99 100644 --- a/neo4j/_codec/hydration/v2/hydration_handler.py +++ b/neo4j/_codec/hydration/v2/hydration_handler.py @@ -18,10 +18,10 @@ from ..v1.hydration_handler import * from ..v1.hydration_handler import _GraphHydrator -from . import temporal +from . import temporal # type: ignore[no-redef] -class HydrationHandler(HydrationHandlerABC): +class HydrationHandler(HydrationHandlerABC): # type: ignore[no-redef] def __init__(self): super().__init__() self._created_scope = False diff --git a/neo4j/_codec/hydration/v2/temporal.py b/neo4j/_codec/hydration/v2/temporal.py index 4741ce9a..ad602eb5 100644 --- a/neo4j/_codec/hydration/v2/temporal.py +++ b/neo4j/_codec/hydration/v2/temporal.py @@ -19,7 +19,7 @@ from ..v1.temporal import * -def hydrate_datetime(seconds, nanoseconds, tz=None): +def hydrate_datetime(seconds, nanoseconds, tz=None): # type: ignore[no-redef] """ Hydrator for `DateTime` and `LocalDateTime` values. :param seconds: @@ -47,7 +47,7 @@ def hydrate_datetime(seconds, nanoseconds, tz=None): return t.as_timezone(zone) -def dehydrate_datetime(value): +def dehydrate_datetime(value): # type: ignore[no-redef] """ Dehydrator for `datetime` values. :param value: diff --git a/neo4j/_codec/packstream/v1/__init__.py b/neo4j/_codec/packstream/v1/__init__.py index 0c74b687..d2f9caf4 100644 --- a/neo4j/_codec/packstream/v1/__init__.py +++ b/neo4j/_codec/packstream/v1/__init__.py @@ -32,11 +32,6 @@ UNPACKED_UINT_8 = {bytes(bytearray([x])): x for x in range(0x100)} UNPACKED_UINT_16 = {struct_pack(">H", x): x for x in range(0x10000)} -UNPACKED_MARKERS = {b"\xC0": None, b"\xC2": False, b"\xC3": True} -UNPACKED_MARKERS.update({bytes(bytearray([z])): z for z in range(0x00, 0x80)}) -UNPACKED_MARKERS.update({bytes(bytearray([z + 256])): z for z in range(-0x10, 0x00)}) - - INT64_MIN = -(2 ** 63) INT64_MAX = 2 ** 63 diff --git a/neo4j/_conf.py b/neo4j/_conf.py index a3e290cc..0912efe8 100644 --- a/neo4j/_conf.py +++ b/neo4j/_conf.py @@ -335,14 +335,6 @@ class PoolConfig(Config): user_agent = get_user_agent() # Specify the client agent name. - #: Protocol Version (Python Driver Specific) - protocol_version = None # Version(4, 0) - # Specify a specific Bolt Protocol Version - - #: Initial Connection Pool Size (Python Driver Specific) - init_size = 1 # The other drivers do not seed from the start. - # This will seed the pool with the specified number of connections. - #: Socket Keep Alive (Python and .NET Driver Specific) keep_alive = True # Specify whether TCP keep-alive should be enabled. diff --git a/neo4j/_data.py b/neo4j/_data.py index 84690573..6d08ddad 100644 --- a/neo4j/_data.py +++ b/neo4j/_data.py @@ -16,6 +16,9 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from abc import ( ABCMeta, abstractmethod, @@ -39,6 +42,11 @@ ) +if t.TYPE_CHECKING: + _T = t.TypeVar("_T") + _T_K = t.Union[int, str] + + class Record(tuple, Mapping): """ A :class:`.Record` is an immutable ordered collection of key-value pairs. It is generally closer to a :py:class:`namedtuple` than to a @@ -46,7 +54,7 @@ class Record(tuple, Mapping): yield values rather than keys. """ - __keys = None + __keys: t.Tuple[str] def __new__(cls, iterable=()): keys = [] @@ -69,17 +77,16 @@ def _super_getitem_single(self, index): raise self._broken_record_error(index) from value.error return value - def __repr__(self): + def __repr__(self) -> str: return "<%s %s>" % ( self.__class__.__name__, " ".join("%s=%r" % (field, value) for field, value in zip(self.__keys, super().__iter__())) ) - def __str__(self): - return self.__repr__() + __str__ = __repr__ - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """ In order to be flexible regarding comparison, the equality rules for a record permit comparison with any other Sequence or Mapping. @@ -89,27 +96,32 @@ def __eq__(self, other): compare_as_sequence = isinstance(other, Sequence) compare_as_mapping = isinstance(other, Mapping) if compare_as_sequence and compare_as_mapping: + other = t.cast(t.Mapping, other) return list(self) == list(other) and dict(self) == dict(other) elif compare_as_sequence: + other = t.cast(t.Sequence, other) return list(self) == list(other) elif compare_as_mapping: + other = t.cast(t.Mapping, other) return dict(self) == dict(other) else: return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self): return reduce(xor_operator, map(hash, self.items())) - def __iter__(self): + def __iter__(self) -> t.Iterator[t.Any]: for i, v in enumerate(super().__iter__()): if isinstance(v, BrokenHydrationObject): raise self._broken_record_error(i) from v.error yield v - def __getitem__(self, key): + def __getitem__( # type: ignore[override] + self, key: t.Union[_T_K, slice] + ) -> t.Any: if isinstance(key, slice): keys = self.__keys[key] values = super().__getitem__(key) @@ -129,12 +141,13 @@ def __getslice__(self, start, stop): values = tuple(self)[key] return self.__class__(zip(keys, values)) - def get(self, key, default=None): + def get(self, key: str, default: object = None) -> t.Any: """ Obtain a value from the record by key, returning a default value if the key does not exist. :param key: a key :param default: default value + :return: a value """ try: @@ -146,12 +159,12 @@ def get(self, key, default=None): else: return default - def index(self, key): + def index(self, key: _T_K) -> int: # type: ignore[override] """ Return the index of the given item. :param key: a key + :return: index - :rtype: int """ if isinstance(key, int): if 0 <= key < len(self.__keys): @@ -165,13 +178,14 @@ def index(self, key): else: raise TypeError(key) - def value(self, key=0, default=None): + def value(self, key: _T_K = 0, default: object = None) -> t.Any: """ Obtain a single value from the record by index or key. If no index or key is specified, the first value is returned. If the specified item does not exist, the default value is returned. :param key: an index or key :param default: default value + :return: a single value """ try: @@ -181,24 +195,24 @@ def value(self, key=0, default=None): else: return self[index] - def keys(self): + def keys(self) -> t.List[str]: # type: ignore[override] """ Return the keys of the record. :return: list of key names """ return list(self.__keys) - def values(self, *keys): + def values(self, *keys: _T_K) -> t.List[t.Any]: # type: ignore[override] """ Return the values of the record, optionally filtering to include only certain values by index or key. :param keys: indexes or keys of the items to include; if none are provided, all values will be included + :return: list of values - :rtype: list """ if keys: - d = [] + d: t.List[t.Any] = [] for key in keys: try: i = self.index(key) @@ -213,7 +227,6 @@ def items(self, *keys): """ Return the fields of the record as a list of key and value tuples :return: a list of value tuples - :rtype: list """ if keys: d = [] @@ -228,7 +241,7 @@ def items(self, *keys): return list((self.__keys[i], self._super_getitem_single(i)) for i in range(len(self))) - def data(self, *keys): + def data(self, *keys: _T_K) -> t.Dict[str, t.Any]: """ Return the keys and values of this record as a dictionary, optionally including only certain values by index or key. Keys provided in the items that are not in the record will be @@ -237,8 +250,10 @@ def data(self, *keys): :param keys: indexes or keys of the items to include; if none are provided, all values will be included - :return: dictionary of values, keyed by field name + :raises: :exc:`IndexError` if an out-of-bounds index is specified + + :return: dictionary of values, keyed by field name """ return RecordExporter().transform(dict(self.items(*keys))) diff --git a/neo4j/_meta.py b/neo4j/_meta.py index 307001e3..382f4050 100644 --- a/neo4j/_meta.py +++ b/neo4j/_meta.py @@ -43,7 +43,18 @@ def deprecation_warn(message, stack_level=1): warn(message, category=DeprecationWarning, stacklevel=stack_level + 1) -def deprecated(message): +from typing import ( + Callable, + cast, + TypeVar, +) + + +T = TypeVar("T") +FuncT = TypeVar("FuncT", bound=Callable[..., object]) + + +def deprecated(message: str) -> Callable[[FuncT], FuncT]: """ Decorator for deprecating functions and methods. :: @@ -53,25 +64,31 @@ def foo(x): pass """ - def decorator(f): + def decorator(f: FuncT) -> FuncT: if asyncio.iscoroutinefunction(f): @wraps(f) async def inner(*args, **kwargs): deprecation_warn(message, stack_level=2) return await f(*args, **kwargs) - return inner + return cast(FuncT, inner) else: @wraps(f) def inner(*args, **kwargs): deprecation_warn(message, stack_level=2) return f(*args, **kwargs) - return inner + return cast(FuncT, inner) return decorator +def deprecated_property(message: str): + def decorator(f): + return property(deprecated(message)(f)) + return cast(property, decorator) + + class ExperimentalWarning(Warning): """ Base class for warnings about experimental features. """ diff --git a/neo4j/_spatial/__init__.py b/neo4j/_spatial/__init__.py index 3c84a0b0..baab4539 100644 --- a/neo4j/_spatial/__init__.py +++ b/neo4j/_spatial/__init__.py @@ -20,16 +20,18 @@ This module defines _spatial data types. """ +from __future__ import annotations +import typing as t from threading import Lock # SRID to subclass mappings -srid_table = {} +srid_table: t.Dict[int, t.Tuple[t.Type[Point], int]] = {} srid_table_lock = Lock() -class Point(tuple): +class Point(t.Tuple[float, ...]): """Base-class for _spatial data. A point within a geometric space. This type is generally used via its @@ -39,35 +41,52 @@ class Point(tuple): :param iterable: An iterable of coordinates. All items will be converted to :class:`float`. + :type iterable: Iterable[float] """ #: The SRID (_spatial reference identifier) of the _spatial data. #: A number that identifies the coordinate system the _spatial type is to be #: interpreted in. - #: - #: :type: int - srid = None - def __new__(cls, iterable): + srid: t.Optional[int] + + @property + def x(self) -> float: + ... + + @property + def y(self) -> float: + ... + + @property + def z(self) -> float: + ... + + def __new__(cls, iterable: t.Iterable[float]) -> Point: return tuple.__new__(cls, map(float, iterable)) - def __repr__(self): + def __repr__(self) -> str: return "POINT(%s)" % " ".join(map(str, self)) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: try: - return type(self) is type(other) and tuple(self) == tuple(other) + return (type(self) is type(other) + and tuple(self) == tuple(t.cast(Point, other))) except (AttributeError, TypeError): return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self): return hash(type(self)) ^ hash(tuple(self)) -def point_type(name, fields, srid_map): +def point_type( + name: str, + fields: t.Tuple[str, str, str], + srid_map: t.Dict[int, int] +) -> t.Type[Point]: """ Dynamically create a Point subclass. """ @@ -90,17 +109,33 @@ def accessor(self, i=index, f=subclass_field): for field_alias in {subclass_field, "xyz"[index]}: attributes[field_alias] = property(accessor) - cls = type(name, (Point,), attributes) + cls = t.cast(t.Type[Point], type(name, (Point,), attributes)) with srid_table_lock: - for dim, srid in srid_map.items(): - srid_table[srid] = (cls, dim) + for dim, srid_ in srid_map.items(): + srid_table[srid_] = (cls, dim) return cls # Point subclass definitions -CartesianPoint = point_type("CartesianPoint", ["x", "y", "z"], - {2: 7203, 3: 9157}) -WGS84Point = point_type("WGS84Point", ["longitude", "latitude", "height"], - {2: 4326, 3: 4979}) +if t.TYPE_CHECKING: + class CartesianPoint(Point): + ... +else: + CartesianPoint = point_type("CartesianPoint", ("x", "y", "z"), + {2: 7203, 3: 9157}) + +if t.TYPE_CHECKING: + class WGS84Point(Point): + @property + def longitude(self) -> float: ... + + @property + def latitude(self) -> float: ... + + @property + def height(self) -> float: ... +else: + WGS84Point = point_type("WGS84Point", ("longitude", "latitude", "height"), + {2: 4326, 3: 4979}) diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index e67ff7c0..de685dc5 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -16,6 +16,16 @@ # limitations under the License. +from __future__ import annotations + +import typing as t + + +if t.TYPE_CHECKING: + import typing_extensions as te + + import ssl + from .._async_compat.util import Util from .._conf import ( Config, @@ -33,125 +43,176 @@ ) from ..addressing import Address from ..api import ( + Auth, + Bookmarks, + DRIVER_BOLT, + DRIVER_NEO4J, + parse_neo4j_uri, + parse_routing_context, READ_ACCESS, + SECURITY_TYPE_SECURE, + SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, + ServerInfo, TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + URI_SCHEME_BOLT, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J, + URI_SCHEME_NEO4J_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, ) +from .work import Session class GraphDatabase: """Accessor for :class:`neo4j.Driver` construction. """ - @classmethod - @Util.experimental_async( - "neo4j is in experimental phase. It might be removed or changed " - "at any time (including patch releases)." - ) - def driver(cls, uri, *, auth=None, **config): - """Create a driver. - - :param uri: the connection URI for the driver, see :ref:`uri-ref` for available URIs. - :param auth: the authentication details, see :ref:`auth-ref` for available authentication details. - :param config: driver configuration key-word arguments, see :ref:`driver-configuration-ref` for available key-word arguments. - - :rtype: Neo4jDriver or BoltDriver - """ - - from ..api import ( - DRIVER_BOLT, - DRIVER_NEO4j, - parse_neo4j_uri, - parse_routing_context, - SECURITY_TYPE_SECURE, - SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_BOLT, - URI_SCHEME_BOLT_SECURE, - URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_NEO4J, - URI_SCHEME_NEO4J_SECURE, - URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + if t.TYPE_CHECKING: + + @classmethod + def driver( + cls, + uri: str, + *, + auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = ..., + max_connection_lifetime: float = ..., + max_connection_pool_size: int = ..., + connection_timeout: float = ..., + update_routing_table_timeout: float = ..., + trust: t.Union[ + te.Literal["TRUST_ALL_CERTIFICATES"], + te.Literal["TRUST_SYSTEM_CA_SIGNED_CERTIFICATES"] + ] = ..., + resolver: t.Union[ + t.Callable[[Address], t.Iterable[Address]], + t.Callable[[Address], t.Union[t.Iterable[Address]]], + ] = ..., + encrypted: bool = ..., + trusted_certificates: TrustStore = ..., + ssl_context: ssl.SSLContext = ..., + user_agent: str = ..., + keep_alive: bool = ..., + + # undocumented/unsupported options + # might be removed/changed without warning, even in patch versions + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ... + ) -> Driver: + ... + + else: + + @classmethod + @Util.experimental_async( + "neo4j is in experimental phase. It might be removed or " + "changed at any time (including patch releases)." ) + def driver( + cls, + uri: str, + *, + auth: t.Union[t.Tuple[t.Any, t.Any], Auth] = None, + **config # TODO: type config + ) -> Driver: + """Create a driver. + + :param uri: the connection URI for the driver, see :ref:`uri-ref` for available URIs. + :param auth: the authentication details, see :ref:`auth-ref` for available authentication details. + :param config: driver configuration key-word arguments, see :ref:`driver-configuration-ref` for available key-word arguments. + """ + + driver_type, security_type, parsed = parse_neo4j_uri(uri) + + # TODO: 6.0 remove "trust" config option + if "trust" in config.keys(): + if config["trust"] not in ( + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES + ): + from neo4j.exceptions import ConfigurationError + raise ConfigurationError( + "The config setting `trust` values are {!r}" + .format( + [ + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + ] + ) + ) - driver_type, security_type, parsed = parse_neo4j_uri(uri) + if ("trusted_certificates" in config.keys() + and not isinstance(config["trusted_certificates"], + TrustStore)): + raise ConnectionError( + "The config setting `trusted_certificates` must be of " + "type neo4j.TrustAll, neo4j.TrustCustomCAs, or" + "neo4j.TrustSystemCAs but was {}".format( + type(config["trusted_certificates"]) + ) + ) - # TODO: 6.0 remove "trust" config option - if "trust" in config.keys(): - if config["trust"] not in (TRUST_ALL_CERTIFICATES, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES): + if (security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] + and ("encrypted" in config.keys() + or "trust" in config.keys() + or "trusted_certificates" in config.keys() + or "ssl_context" in config.keys())): from neo4j.exceptions import ConfigurationError + + # TODO: 6.0 remove "trust" from error message raise ConfigurationError( - "The config setting `trust` values are {!r}" + 'The config settings "encrypted", "trust", ' + '"trusted_certificates", and "ssl_context" can only be ' + "used with the URI schemes {!r}. Use the other URI " + "schemes {!r} for setting encryption settings." .format( [ - TRUST_ALL_CERTIFICATES, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + URI_SCHEME_BOLT, + URI_SCHEME_NEO4J, + ], + [ + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J_SECURE, ] ) ) - if ("trusted_certificates" in config.keys() - and not isinstance(config["trusted_certificates"], - TrustStore)): - raise ConnectionError( - "The config setting `trusted_certificates` must be of type " - "neo4j.TrustAll, neo4j.TrustCustomCAs, or" - "neo4j.TrustSystemCAs but was {}".format( - type(config["trusted_certificates"]) - ) - ) - - if (security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] - and ("encrypted" in config.keys() - or "trust" in config.keys() - or "trusted_certificates" in config.keys() - or "ssl_context" in config.keys())): - from neo4j.exceptions import ConfigurationError - - # TODO: 6.0 remove "trust" from error message - raise ConfigurationError( - 'The config settings "encrypted", "trust", ' - '"trusted_certificates", and "ssl_context" can only be used ' - "with the URI schemes {!r}. Use the other URI schemes {!r} " - "for setting encryption settings." - .format( - [ - URI_SCHEME_BOLT, - URI_SCHEME_NEO4J, - ], - [ - URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_BOLT_SECURE, - URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_NEO4J_SECURE, - ] - ) - ) - - if security_type == SECURITY_TYPE_SECURE: - config["encrypted"] = True - elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE: - config["encrypted"] = True - config["trusted_certificates"] = TrustAll() - - if driver_type == DRIVER_BOLT: - if parse_routing_context(parsed.query): - deprecation_warn( - "Creating a direct driver (`bolt://` scheme) with routing " - "context (URI parameters) is deprecated. They will be " - "ignored. This will raise an error in a future release. " - 'Given URI "{}"'.format(uri), - stack_level=2 - ) - # TODO: 6.0 - raise instead of warning - # raise ValueError( - # 'Routing parameters are not supported with scheme ' - # '"bolt". Given URI "{}".'.format(uri) - # ) - return cls.bolt_driver(parsed.netloc, auth=auth, **config) - elif driver_type == DRIVER_NEO4j: + if security_type == SECURITY_TYPE_SECURE: + config["encrypted"] = True + elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE: + config["encrypted"] = True + config["trusted_certificates"] = TrustAll() + + assert driver_type in (DRIVER_BOLT, DRIVER_NEO4J) + if driver_type == DRIVER_BOLT: + if parse_routing_context(parsed.query): + deprecation_warn( + "Creating a direct driver (`bolt://` scheme) with " + "routing context (URI parameters) is deprecated. They " + "will be ignored. This will raise an error in a " + 'future release. Given URI "{}"'.format(uri), + stack_level=2 + ) + # TODO: 6.0 - raise instead of warning + # raise ValueError( + # 'Routing parameters are not supported with scheme ' + # '"bolt". Given URI "{}".'.format(uri) + # ) + return cls.bolt_driver(parsed.netloc, auth=auth, **config) + # else driver_type == DRIVER_NEO4J routing_context = parse_routing_context(parsed.query) - return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config) + return cls.neo4j_driver(parsed.netloc, auth=auth, + routing_context=routing_context, **config) @classmethod def bolt_driver(cls, target, *, auth=None, **config): @@ -243,7 +304,7 @@ class Driver: """ #: Connection pool - _pool = None + _pool: t.Any = None #: Flag if the driver has been closed _closed = False @@ -254,7 +315,7 @@ def __init__(self, pool, default_workspace_config): self._pool = pool self._default_workspace_config = default_workspace_config - def __enter__(self): + def __enter__(self) -> Driver: return self def __exit__(self, exc_type, exc_value, traceback): @@ -276,31 +337,49 @@ def __del__(self): self.close() @property - def encrypted(self): - """Indicate whether the driver was configured to use encryption. - - :rtype: bool""" + def encrypted(self) -> bool: + """Indicate whether the driver was configured to use encryption.""" return bool(self._pool.pool_config.encrypted) - def session(self, **config): - """Create a session, see :ref:`session-construction-ref` - - :param config: session configuration key-word arguments, - see :ref:`session-configuration-ref` for available key-word - arguments. - - :returns: new :class:`neo4j.Session` object - """ - raise NotImplementedError - - def close(self): + if t.TYPE_CHECKING: + + def session( + self, + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., + ) -> Session: + ... + + else: + + def session(self, **config) -> Session: + """Create a session, see :ref:`session-construction-ref` + + :param config: session configuration key-word arguments, + see :ref:`session-configuration-ref` for available + key-word arguments. + + :returns: new :class:`neo4j.Session` object + """ + raise NotImplementedError + + def close(self) -> None: """ Shut down, closing any open connections in the pool. """ self._pool.close() self._closed = True # TODO: 6.0 - remove config argument - def verify_connectivity(self, **config): + def verify_connectivity(self, **config) -> None: """Verify that the driver can establish a connection to the server. This verifies if the driver can establish a reading connection to a @@ -337,7 +416,7 @@ def verify_connectivity(self, **config): with self.session(**config) as session: session._get_server_info() - def get_server_info(self, **config): + def get_server_info(self, **config) -> ServerInfo: """Get information about the connected Neo4j server. Try to establish a working read connection to the remote server or a @@ -378,11 +457,10 @@ def get_server_info(self, **config): return session._get_server_info() @experimental("Feature support query, based on Bolt protocol version and Neo4j server version will change in the future.") - def supports_multi_db(self): + def supports_multi_db(self) -> bool: """ Check if the server or cluster supports multi-databases. :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. - :rtype: bool .. note:: Feature support query, based on Bolt Protocol Version and Neo4j @@ -390,6 +468,7 @@ def supports_multi_db(self): """ with self.session() as session: session._connect(READ_ACCESS) + assert session._connection return session._connection.supports_multiple_databases @@ -426,17 +505,20 @@ def __init__(self, pool, default_workspace_config): Driver.__init__(self, pool, default_workspace_config) self._default_workspace_config = default_workspace_config - def session(self, **config): - """ - :param config: The values that can be specified are found in :class: `neo4j.SessionConfig` + if not t.TYPE_CHECKING: - :return: - :rtype: :class: `neo4j.Session` - """ - from .work import Session - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return Session(self._pool, session_config) + def session(self, **config) -> Session: + """ + :param config: The values that can be specified are found in + :class: `neo4j.SessionConfig` + + :return: + :rtype: :class: `neo4j.Session` + """ + session_config = SessionConfig(self._default_workspace_config, + config) + SessionConfig.consume(config) # Consume the config + return Session(self._pool, session_config) class Neo4jDriver(_Routing, Driver): @@ -462,8 +544,10 @@ def __init__(self, pool, default_workspace_config): _Routing.__init__(self, pool.get_default_database_initial_router_addresses()) Driver.__init__(self, pool, default_workspace_config) - def session(self, **config): - from .work import Session - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return Session(self._pool, session_config) + if not t.TYPE_CHECKING: + + def session(self, **config) -> Session: + session_config = SessionConfig(self._default_workspace_config, + config) + SessionConfig.consume(config) # Consume the config + return Session(self._pool, session_config) diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py index b7f9ecd8..697fbf86 100644 --- a/neo4j/_sync/io/_bolt.py +++ b/neo4j/_sync/io/_bolt.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + import abc import asyncio from collections import deque @@ -23,14 +25,12 @@ from time import perf_counter from ..._async_compat.network import BoltSocket -from ..._async_compat.util import Util from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 from ..._conf import PoolConfig from ..._exceptions import ( BoltError, BoltHandshakeError, - SocketDeadlineExceeded, ) from ..._meta import get_user_agent from ...addressing import Address @@ -75,7 +75,7 @@ class Bolt: MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" - PROTOCOL_VERSION = None + PROTOCOL_VERSION: Version = None # type: ignore[assignment] # flag if connection needs RESET to go back to READY state is_reset = False @@ -776,4 +776,4 @@ def is_idle_for(self, timeout): return perf_counter() - self.idle_since > timeout -BoltSocket.Bolt = Bolt +BoltSocket.Bolt = Bolt # type: ignore diff --git a/neo4j/_sync/io/_bolt3.py b/neo4j/_sync/io/_bolt3.py index 1a169f71..2db6d561 100644 --- a/neo4j/_sync/io/_bolt3.py +++ b/neo4j/_sync/io/_bolt3.py @@ -20,11 +20,7 @@ from logging import getLogger from ssl import SSLSocket -from ..._async_compat.util import Util -from ..._exceptions import ( - BoltError, - BoltProtocolError, -) +from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, Version, @@ -32,7 +28,6 @@ from ...exceptions import ( ConfigurationError, DatabaseUnavailable, - DriverError, ForbiddenOnReadOnlyDatabase, Neo4jError, NotALeader, diff --git a/neo4j/_sync/work/__init__.py b/neo4j/_sync/work/__init__.py index 1ab2ffad..92a1af0d 100644 --- a/neo4j/_sync/work/__init__.py +++ b/neo4j/_sync/work/__init__.py @@ -24,13 +24,15 @@ from .transaction import ( ManagedTransaction, Transaction, + TransactionBase, ) __all__ = [ "Result", "Session", - "Transaction", "ManagedTransaction", + "Transaction", + "TransactionBase", "Workspace", ] diff --git a/neo4j/_sync/work/result.py b/neo4j/_sync/work/result.py index bd9e5683..450262ca 100644 --- a/neo4j/_sync/work/result.py +++ b/neo4j/_sync/work/result.py @@ -16,9 +16,14 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from collections import deque from warnings import warn +import typing_extensions as te + from ..._async_compat.util import Util from ..._codec.hydration import BrokenHydrationObject from ..._data import ( @@ -38,6 +43,16 @@ from ..io import ConnectionErrorHandler +if t.TYPE_CHECKING: + import pandas # type: ignore[import] + + from ...graph import Graph + + +_T = t.TypeVar("_T") +_T_ResultKey: te.TypeAlias = t.Union[int, str] + + _RESULT_OUT_OF_SCOPE_ERROR = ( "The result is out of scope. The associated transaction " "has been closed. Results can only be used while the " @@ -214,8 +229,9 @@ def on_success(summary_metadata): ) self._streaming = True - def __iter__(self): + def __iter__(self) -> t.Iterator[Record]: """Iterator returning Records. + :returns: Record, it is an immutable ordered collection of key-value pairs. :rtype: :class:`neo4j.Record` """ @@ -237,7 +253,7 @@ def __iter__(self): if self._consumed: raise ResultConsumedError(self, _RESULT_CONSUMED_ERROR) - def __next__(self): + def __next__(self) -> Record: return self.__iter__().__next__() def _attach(self): @@ -297,7 +313,7 @@ def _obtain_summary(self): return self._summary - def keys(self): + def keys(self) -> t.Tuple[str, ...]: """The keys for the records in this result. :returns: tuple of key names @@ -321,7 +337,7 @@ def _tx_end(self): self._exhaust() self._out_of_scope = True - def consume(self): + def consume(self) -> ResultSummary: """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`. Example:: @@ -378,7 +394,17 @@ def get_two_tx(tx): self._consumed = True return summary - def single(self, strict=False): + @t.overload + def single( + self, strict: te.Literal[False] = False + ) -> t.Optional[Record]: + ... + + @t.overload + def single(self, strict: te.Literal[True]) -> Record: + ... + + def single(self, strict: bool = False) -> t.Optional[Record]: """Obtain the next and only remaining record or None. Calling this method always exhausts the result. @@ -391,9 +417,7 @@ def single(self, strict=False): instead of returning None if there is more than one record or warning if there are more than 1 record. :const:`False` by default. - :type strict: bool - :returns: the next :class:`neo4j.Record` or :const:`None` if none remain :warns: if more than one record is available :raises ResultNotSingleError: @@ -402,6 +426,8 @@ def single(self, strict=False): was obtained has been closed or the Result has been explicitly consumed. + :returns: the next :class:`neo4j.Record` or :const:`None` if none remain + .. versionchanged:: 5.0 Added ``strict`` parameter. .. versionchanged:: 5.0 @@ -433,11 +459,10 @@ def single(self, strict=False): ) return buffer.popleft() - def fetch(self, n): + def fetch(self, n: int) -> t.List[Record]: """Obtain up to n records from this result. :param n: the maximum number of records to fetch. - :type n: int :returns: list of :class:`neo4j.Record` @@ -453,7 +478,7 @@ def fetch(self, n): for _ in range(min(n, len(self._record_buffer))) ] - def peek(self): + def peek(self) -> t.Optional[Record]: """Obtain the next record from this result without consuming it. This leaves the record in the buffer for further processing. @@ -470,20 +495,20 @@ def peek(self): self._buffer(1) if self._record_buffer: return self._record_buffer[0] + return None - def graph(self): + def graph(self) -> Graph: """Return a :class:`neo4j.graph.Graph` instance containing all the graph objects in the result. After calling this method, the result becomes detached, buffering all remaining records. - :returns: a result graph - :rtype: :class:`neo4j.graph.Graph` + **This is experimental.** (See :ref:`filter-warnings-ref`) :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. - **This is experimental.** (See :ref:`filter-warnings-ref`) + :returns: a result graph .. versionchanged:: 5.0 Can raise :exc:`ResultConsumedError`. @@ -491,7 +516,9 @@ def graph(self): self._buffer_all() return self._hydration_scope.get_graph() - def value(self, key=0, default=None): + def value( + self, key: _T_ResultKey = 0, default: object = None + ) -> t.List[t.Any]: """Helper function that return the remainder of the result as a list of values. See :class:`neo4j.Record.value` @@ -499,38 +526,38 @@ def value(self, key=0, default=None): :param key: field to return for each remaining record. Obtain a single value from the record by index or key. :param default: default value, used if the index of key is unavailable - :returns: list of individual values - :rtype: list - :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + :returns: list of individual values + .. versionchanged:: 5.0 Can raise :exc:`ResultConsumedError`. """ return [record.value(key, default) for record in self] - def values(self, *keys): + def values( + self, *keys: _T_ResultKey + ) -> t.List[t.List[t.Any]]: """Helper function that return the remainder of the result as a list of values lists. See :class:`neo4j.Record.values` :param keys: fields to return for each remaining record. Optionally filtering to include only certain values by index or key. - :returns: list of values lists - :rtype: list - :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + :returns: list of values lists + .. versionchanged:: 5.0 Can raise :exc:`ResultConsumedError`. """ return [record.values(*keys) for record in self] - def data(self, *keys): + def data(self, *keys: _T_ResultKey) -> t.List[t.Any]: """Helper function that return the remainder of the result as a list of dictionaries. See :class:`neo4j.Record.data` @@ -551,7 +578,11 @@ def data(self, *keys): @experimental("pandas support is experimental and might be changed or " "removed in future versions") - def to_df(self, expand=False, parse_dates=False): + def to_df( + self, + expand: bool = False, + parse_dates: bool = False + ) -> pandas.DataFrame: r"""Convert (the rest of) the result to a pandas DataFrame. This method is only available if the `pandas` library is installed. @@ -627,14 +658,11 @@ def to_df(self, expand=False, parse_dates=False): :const:`dict` keys and variable names that contain ``.`` or ``\`` will be escaped with a backslash (``\.`` and ``\\`` respectively). - :type expand: bool :param parse_dates: If :const:`True`, columns that excluvively contain :class:`time.DateTime` objects, :class:`time.Date` objects, or :const:`None`, will be converted to :class:`pandas.Timestamp`. - :type parse_dates: bool - :rtype: :py:class:`pandas.DataFrame` :raises ImportError: if `pandas` library is not available. :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly @@ -644,7 +672,7 @@ def to_df(self, expand=False, parse_dates=False): ``pandas`` support might be changed or removed in future versions without warning. (See :ref:`filter-warnings-ref`) """ - import pandas as pd + import pandas as pd # type: ignore[import] if not expand: df = pd.DataFrame(self.values(), columns=self._keys) @@ -691,7 +719,7 @@ def to_df(self, expand=False, parse_dates=False): ) return df - def closed(self): + def closed(self) -> bool: """Return True if the result has been closed. When a result gets consumed :meth:`consume` or the transaction that @@ -702,7 +730,6 @@ def closed(self): will raise a :exc:`ResultConsumedError` when called. :returns: whether the result is closed. - :rtype: bool .. versionadded:: 5.0 """ diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py index 87e103d6..6158d2b8 100644 --- a/neo4j/_sync/work/session.py +++ b/neo4j/_sync/work/session.py @@ -16,10 +16,22 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from logging import getLogger from random import random from time import perf_counter + +if t.TYPE_CHECKING: + import typing_extensions as te + + from ..io import Bolt + + _R = t.TypeVar("_R") + _P = te.ParamSpec("_P") + from ..._async_compat import sleep from ..._conf import SessionConfig from ..._meta import ( @@ -73,10 +85,11 @@ class Session(Workspace): """ # The current connection. - _connection = None + _connection: t.Optional[Bolt] = None - # The current :class:`.Transaction` instance, if any. - _transaction = None + # The current transaction instance, if any. + _transaction: t.Union[Transaction, ManagedTransaction, None] = \ + None # The current auto-transaction result, if any. _auto_result = None @@ -89,7 +102,7 @@ def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) self._bookmarks = self._prepare_bookmarks(session_config.bookmarks) - def __enter__(self): + def __enter__(self) -> Session: return self def __exit__(self, exception_type, exception_value, traceback): @@ -139,7 +152,7 @@ def _get_server_info(self): self._disconnect() return server_info - def close(self): + def close(self) -> None: """Close the session. This will release any borrowed resources, such as connections, and will @@ -183,7 +196,12 @@ def close(self): self._state_failed = False self._closed = True - def run(self, query, parameters=None, **kwargs): + def run( + self, + query: t.Union[str, Query], + parameters: t.Dict[str, t.Any] = None, + **kwargs: t.Any + ) -> Result: """Run a Cypher query within an auto-commit transaction. The query is sent and the result header received @@ -202,12 +220,10 @@ def run(self, query, parameters=None, **kwargs): For more usage details, see :meth:`.Transaction.run`. :param query: cypher query - :type query: str, neo4j.Query :param parameters: dictionary of parameters - :type parameters: dict :param kwargs: additional keyword parameters + :returns: a new :class:`neo4j.Result` object - :rtype: Result """ if not query: raise ValueError("Cannot run an empty query") @@ -224,8 +240,6 @@ def run(self, query, parameters=None, **kwargs): if not self._connection: self._connect(self._config.default_access_mode) cx = self._connection - protocol_version = cx.PROTOCOL_VERSION - server_info = cx.server_info self._auto_result = Result( cx, self._config.fetch_size, self._result_closed, @@ -243,7 +257,7 @@ def run(self, query, parameters=None, **kwargs): "`last_bookmark` has been deprecated in favor of `last_bookmarks`. " "This method can lead to unexpected behaviour." ) - def last_bookmark(self): + def last_bookmark(self) -> t.Optional[str]: """Return the bookmark received following the last completed transaction. Note: For auto-transactions (:meth:`Session.run`), this will trigger @@ -258,7 +272,6 @@ def last_bookmark(self): Use :meth:`last_bookmarks` instead. :returns: last bookmark - :rtype: str or None """ # The set of bookmarks to be passed into the next transaction. @@ -273,7 +286,7 @@ def last_bookmark(self): return self._bookmarks[-1] return None - def last_bookmarks(self): + def last_bookmarks(self) -> Bookmarks: """Return most recent bookmarks of the session. Bookmarks can be used to causally chain sessions. For example, @@ -300,7 +313,6 @@ def last_bookmarks(self): :meth:`Result.consume` for the current result. :returns: the session's last known bookmarks - :rtype: Bookmarks """ # The set of bookmarks to be passed into the next transaction. @@ -338,7 +350,11 @@ def _open_transaction( self._bookmarks, access_mode, metadata, timeout ) - def begin_transaction(self, metadata=None, timeout=None): + def begin_transaction( + self, + metadata: t.Dict[str, t.Any] = None, + timeout: float = None + ) -> Transaction: """ Begin a new unmanaged transaction. Creates a new :class:`.Transaction` within this session. At most one transaction may exist in a session at any point in time. To maintain multiple concurrent transactions, use multiple concurrent sessions. @@ -350,7 +366,6 @@ def begin_transaction(self, metadata=None, timeout=None): Specified metadata will be attached to the executing transaction and visible in the output of ``dbms.listQueries`` and ``dbms.listTransactions`` procedures. It will also get logged to the ``query.log``. This functionality makes it easier to tag transactions and is equivalent to ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference. - :type metadata: dict :param timeout: the transaction timeout in seconds. @@ -358,12 +373,10 @@ def begin_transaction(self, metadata=None, timeout=None): This functionality allows to limit query/transaction execution time. Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting. Value should not represent a duration of zero or negative duration. - :type timeout: int - - :returns: A new transaction instance. - :rtype: Transaction :raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open. + + :returns: A new transaction instance. """ # TODO: Implement TransactionConfig consumption @@ -371,7 +384,9 @@ def begin_transaction(self, metadata=None, timeout=None): self._auto_result.consume() if self._transaction: - raise TransactionError("Explicit transaction already open") + raise TransactionError( + self._transaction, "Explicit transaction already open" + ) self._open_transaction( tx_cls=Transaction, @@ -379,7 +394,7 @@ def begin_transaction(self, metadata=None, timeout=None): timeout=timeout ) - return self._transaction + return t.cast(Transaction, self._transaction) def _run_transaction( self, access_mode, transaction_function, *args, **kwargs @@ -438,7 +453,13 @@ def _run_transaction( else: raise ServiceUnavailable("Transaction failed") - def read_transaction(self, transaction_function, *args, **kwargs): + def read_transaction( + self, + transaction_function: t.Callable[ + te.Concatenate[ManagedTransaction, _P], t.Union[_R] + ], + *args: _P.args, **kwargs: _P.kwargs + ) -> _R: """Execute a unit of work in a managed read transaction. .. note:: @@ -487,13 +508,20 @@ def get_two_tx(tx): :class:`.Transaction`. :param args: arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` + :return: a result as returned by the given unit of work """ return self._run_transaction( READ_ACCESS, transaction_function, *args, **kwargs ) - def write_transaction(self, transaction_function, *args, **kwargs): + def write_transaction( + self, + transaction_function: t.Callable[ + te.Concatenate[ManagedTransaction, _P], t.Union[_R] + ], + *args: _P.args, **kwargs: _P.kwargs + ) -> _R: """Execute a unit of work in a managed write transaction. .. note:: @@ -522,6 +550,7 @@ def create_node_tx(tx, name): :class:`.Transaction`. :param args: key word arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` + :return: a result as returned by the given unit of work """ return self._run_transaction( diff --git a/neo4j/_sync/work/transaction.py b/neo4j/_sync/work/transaction.py index a834f00c..c291ee3c 100644 --- a/neo4j/_sync/work/transaction.py +++ b/neo4j/_sync/work/transaction.py @@ -16,6 +16,9 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from functools import wraps from ..._async_compat.util import Util @@ -25,10 +28,14 @@ from .result import Result -__all__ = ("Transaction", "ManagedTransaction") +__all__ = ( + "ManagedTransaction", + "Transaction", + "TransactionBase", +) -class _TransactionBase: +class TransactionBase: def __init__(self, connection, fetch_size, on_closed, on_error): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( @@ -75,7 +82,12 @@ def _consume_results(self): result._tx_end() self._results = [] - def run(self, query, parameters=None, **kwparameters): + def run( + self, + query: str, + parameters: t.Dict[str, t.Any] = None, + **kwparameters: t.Any + ) -> Result: """ Run a Cypher query within the context of this transaction. Cypher is typically expressed as a query template plus a @@ -95,15 +107,12 @@ def run(self, query, parameters=None, **kwparameters): :class:`list` properties must be homogenous. :param query: cypher query - :type query: str :param parameters: dictionary of parameters - :type parameters: dict :param kwparameters: additional keyword parameters - :returns: a new :class:`neo4j.Result` object - :rtype: :class:`neo4j.Result` - :raise TransactionError: if the transaction is already closed + + :returns: a new :class:`neo4j.Result` object """ if isinstance(query, Query): raise ValueError("Query object is only supported for session.run") @@ -194,7 +203,7 @@ def _closed(self): return self._closed_flag -class Transaction(_TransactionBase): +class Transaction(TransactionBase): """ Container for multiple Cypher queries to be executed within a single context. :class:`Transaction` objects can be used as a context managers (:py:const:`with` block) where the transaction is committed @@ -205,32 +214,32 @@ class Transaction(_TransactionBase): """ - @wraps(_TransactionBase._enter) - def __enter__(self): + @wraps(TransactionBase._enter) + def __enter__(self) -> Transaction: return self._enter() - @wraps(_TransactionBase._exit) + @wraps(TransactionBase._exit) def __exit__(self, exception_type, exception_value, traceback): self._exit(exception_type, exception_value, traceback) - @wraps(_TransactionBase._commit) - def commit(self): + @wraps(TransactionBase._commit) + def commit(self) -> None: return self._commit() - @wraps(_TransactionBase._rollback) - def rollback(self): + @wraps(TransactionBase._rollback) + def rollback(self) -> None: return self._rollback() - @wraps(_TransactionBase._close) - def close(self): + @wraps(TransactionBase._close) + def close(self) -> None: return self._close() - @wraps(_TransactionBase._closed) - def closed(self): + @wraps(TransactionBase._closed) + def closed(self) -> bool: return self._closed() -class ManagedTransaction(_TransactionBase): +class ManagedTransaction(TransactionBase): """Transaction object provided to transaction functions. Inside a transaction function, the driver is responsible for managing diff --git a/neo4j/_sync/work/workspace.py b/neo4j/_sync/work/workspace.py index c10fc912..3a04f6cf 100644 --- a/neo4j/_sync/work/workspace.py +++ b/neo4j/_sync/work/workspace.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + import asyncio from ..._conf import WorkspaceConfig @@ -64,7 +66,7 @@ def __del__(self): except (OSError, ServiceUnavailable, SessionExpired): pass - def __enter__(self): + def __enter__(self) -> Workspace: return self def __exit__(self, exc_type, exc_value, traceback): diff --git a/neo4j/addressing.py b/neo4j/addressing.py index 58ef8dcc..0f0e2d22 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -16,22 +16,39 @@ # limitations under the License. +from __future__ import annotations + import logging +import typing as t from socket import ( + AddressFamily, AF_INET, AF_INET6, getservbyname, ) +import typing_extensions as te + log = logging.getLogger("neo4j") -class _AddressMeta(type(tuple)): +_T = t.TypeVar("_T") + + +class _WithPeerName(te.Protocol): + def getpeername(self) -> tuple: ... + + +assert type(tuple) is type - def __init__(self, *args, **kwargs): - self._ipv4_cls = None - self._ipv6_cls = None + +class _AddressMeta(type): + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + cls._ipv4_cls = None + cls._ipv6_cls = None def _subclass_by_family(self, family): subclasses = [ @@ -63,16 +80,25 @@ def ipv6_cls(self): class Address(tuple, metaclass=_AddressMeta): @classmethod - def from_socket(cls, socket): + def from_socket( + cls: t.Type[_T_Address], + socket: _WithPeerName + ) -> _T_Address: address = socket.getpeername() return cls(address) @classmethod - def parse(cls, s, default_host=None, default_port=None): + def parse( + cls: t.Type[_T_Address], + s: str, + default_host: str = None, + default_port: int = None + ) -> _T_Address: if not isinstance(s, str): raise TypeError("Address.parse requires a string argument") if s.startswith("["): # IPv6 + port: t.Union[str, int] host, _, port = s[1:].rpartition("]") port = port.lstrip(":") try: @@ -92,16 +118,21 @@ def parse(cls, s, default_host=None, default_port=None): port or default_port or 0)) @classmethod - def parse_list(cls, *s, default_host=None, default_port=None): + def parse_list( + cls: t.Type[_T_Address], + *s: str, + default_host: str = None, + default_port: int = None + ) -> t.List[_T_Address]: """ Parse a string containing one or more socket addresses, each separated by whitespace. """ if not all(isinstance(s0, str) for s0 in s): raise TypeError("Address.parse_list requires a string argument") - return [Address.parse(a, default_host, default_port) + return [cls.parse(a, default_host, default_port) for a in " ".join(s).split()] - def __new__(cls, iterable): + def __new__(cls, iterable: t.Collection) -> Address: if isinstance(iterable, cls): return iterable n_parts = len(iterable) @@ -116,29 +147,29 @@ def __new__(cls, iterable): return inst #: Address family (AF_INET or AF_INET6) - family = None + family: t.Optional[AddressFamily] = None def __repr__(self): return "{}({!r})".format(self.__class__.__name__, tuple(self)) @property - def host_name(self): + def host_name(self) -> str: return self[0] @property - def host(self): + def host(self) -> str: return self[0] @property - def port(self): + def port(self) -> int: return self[1] @property - def unresolved(self): + def unresolved(self) -> Address: return self @property - def port_number(self): + def port_number(self) -> int: try: return getservbyname(self[1]) except (OSError, TypeError): @@ -150,11 +181,14 @@ def port_number(self): raise type(e)("Unknown port value %r" % self[1]) +_T_Address = t.TypeVar("_T_Address", bound=Address) + + class IPv4Address(Address): family = AF_INET - def __str__(self): + def __str__(self) -> str: return "{}:{}".format(*self) @@ -162,22 +196,25 @@ class IPv6Address(Address): family = AF_INET6 - def __str__(self): + def __str__(self) -> str: return "[{}]:{}".format(*self) class ResolvedAddress(Address): + _host_name: str + @property - def host_name(self): + def host_name(self) -> str: return self._host_name @property - def unresolved(self): + def unresolved(self) -> Address: return super().__new__(Address, (self._host_name, *self[1:])) - def __new__(cls, iterable, host_name=None): + def __new__(cls, iterable, *, host_name: str) -> ResolvedAddress: new = super().__new__(cls, iterable) + new = t.cast(ResolvedAddress, new) new._host_name = host_name return new diff --git a/neo4j/api.py b/neo4j/api.py index 7930d1d4..7db50c90 100644 --- a/neo4j/api.py +++ b/neo4j/api.py @@ -18,42 +18,51 @@ """ Base classes and helpers. """ +from __future__ import annotations +import typing as t from urllib.parse import ( parse_qs, urlparse, ) + +if t.TYPE_CHECKING: + import typing_extensions as te + from .addressing import Address + from ._meta import deprecated from .exceptions import ConfigurationError -READ_ACCESS = "READ" -WRITE_ACCESS = "WRITE" +READ_ACCESS: te.Final[str] = "READ" +WRITE_ACCESS: te.Final[str] = "WRITE" -DRIVER_BOLT = "DRIVER_BOLT" -DRIVER_NEO4j = "DRIVER_NEO4J" +DRIVER_BOLT: te.Final[str] = "DRIVER_BOLT" +DRIVER_NEO4J: te.Final[str] = "DRIVER_NEO4J" -SECURITY_TYPE_NOT_SECURE = "SECURITY_TYPE_NOT_SECURE" -SECURITY_TYPE_SELF_SIGNED_CERTIFICATE = "SECURITY_TYPE_SELF_SIGNED_CERTIFICATE" -SECURITY_TYPE_SECURE = "SECURITY_TYPE_SECURE" +SECURITY_TYPE_NOT_SECURE: te.Final[str] = "SECURITY_TYPE_NOT_SECURE" +SECURITY_TYPE_SELF_SIGNED_CERTIFICATE: te.Final[str] = \ + "SECURITY_TYPE_SELF_SIGNED_CERTIFICATE" +SECURITY_TYPE_SECURE: te.Final[str] = "SECURITY_TYPE_SECURE" -URI_SCHEME_BOLT = "bolt" -URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE = "bolt+ssc" -URI_SCHEME_BOLT_SECURE = "bolt+s" +URI_SCHEME_BOLT: te.Final[str] = "bolt" +URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE: te.Final[str] = "bolt+ssc" +URI_SCHEME_BOLT_SECURE: te.Final[str] = "bolt+s" -URI_SCHEME_NEO4J = "neo4j" -URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE = "neo4j+ssc" -URI_SCHEME_NEO4J_SECURE = "neo4j+s" +URI_SCHEME_NEO4J: te.Final[str] = "neo4j" +URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE: te.Final[str] = "neo4j+ssc" +URI_SCHEME_NEO4J_SECURE: te.Final[str] = "neo4j+s" -URI_SCHEME_BOLT_ROUTING = "bolt+routing" +URI_SCHEME_BOLT_ROUTING: te.Final[str] = "bolt+routing" # TODO: 6.0 - remove TRUST constants -TRUST_SYSTEM_CA_SIGNED_CERTIFICATES = "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES" # Default -TRUST_ALL_CERTIFICATES = "TRUST_ALL_CERTIFICATES" +TRUST_SYSTEM_CA_SIGNED_CERTIFICATES: te.Final[str] = \ + "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES" # Default +TRUST_ALL_CERTIFICATES: te.Final[str] = "TRUST_ALL_CERTIFICATES" -SYSTEM_DATABASE = "system" -DEFAULT_DATABASE = None # Must be a non string hashable value +SYSTEM_DATABASE: te.Final[str] = "system" +DEFAULT_DATABASE: te.Final[None] = None # Must be a non string hashable value # TODO: This class is not tested @@ -62,19 +71,21 @@ class Auth: :param scheme: specifies the type of authentication, examples: "basic", "kerberos" - :type scheme: str :param principal: specifies who is being authenticated - :type principal: str or None :param credentials: authenticates the principal - :type credentials: str or None :param realm: specifies the authentication provider - :type realm: str or None :param parameters: extra key word parameters passed along to the authentication provider - :type parameters: Dict[str, Any] """ - def __init__(self, scheme, principal, credentials, realm=None, **parameters): + def __init__( + self, + scheme: t.Optional[str], + principal: t.Optional[str], + credentials: t.Optional[str], + realm: str = None, + **parameters: t.Any + ) -> None: self.scheme = scheme # Neo4j servers pre 4.4 require the principal field to always be # present. Therefore, we transmit it even if it's an empty sting. @@ -92,75 +103,67 @@ def __init__(self, scheme, principal, credentials, realm=None, **parameters): AuthToken = Auth -def basic_auth(user, password, realm=None): +def basic_auth(user: str, password: str, realm: str = None) -> Auth: """Generate a basic auth token for a given user and password. This will set the scheme to "basic" for the auth token. :param user: user name, this will set the - :type user: str :param password: current password, this will set the credentials - :type password: str :param realm: specifies the authentication provider - :type realm: str or None :return: auth token for use with :meth:`GraphDatabase.driver` or :meth:`AsyncGraphDatabase.driver` - :rtype: :class:`neo4j.Auth` """ return Auth("basic", user, password, realm) -def kerberos_auth(base64_encoded_ticket): +def kerberos_auth(base64_encoded_ticket: str) -> Auth: """Generate a kerberos auth token with the base64 encoded ticket. This will set the scheme to "kerberos" for the auth token. :param base64_encoded_ticket: a base64 encoded service ticket, this will set the credentials - :type base64_encoded_ticket: str :return: auth token for use with :meth:`GraphDatabase.driver` or :meth:`AsyncGraphDatabase.driver` - :rtype: :class:`neo4j.Auth` """ return Auth("kerberos", "", base64_encoded_ticket) -def bearer_auth(base64_encoded_token): +def bearer_auth(base64_encoded_token: str) -> Auth: """Generate an auth token for Single-Sign-On providers. This will set the scheme to "bearer" for the auth token. :param base64_encoded_token: a base64 encoded authentication token generated by a Single-Sign-On provider. - :type base64_encoded_token: str :return: auth token for use with :meth:`GraphDatabase.driver` or :meth:`AsyncGraphDatabase.driver` - :rtype: :class:`neo4j.Auth` """ return Auth("bearer", None, base64_encoded_token) -def custom_auth(principal, credentials, realm, scheme, **parameters): +def custom_auth( + principal: t.Optional[str], + credentials: t.Optional[str], + realm: t.Optional[str], + scheme: t.Optional[str], + **parameters: t.Any +) -> Auth: """Generate a custom auth token. :param principal: specifies who is being authenticated - :type principal: str or None :param credentials: authenticates the principal - :type credentials: str or None :param realm: specifies the authentication provider - :type realm: str or None :param scheme: specifies the type of authentication - :type scheme: str or None :param parameters: extra key word parameters passed along to the authentication provider - :type parameters: Dict[str, Any] :return: auth token for use with :meth:`GraphDatabase.driver` or :meth:`AsyncGraphDatabase.driver` - :rtype: :class:`neo4j.Auth` """ return Auth(scheme, principal, credentials, realm, **parameters) @@ -177,7 +180,7 @@ class Bookmark: """ @deprecated("Use the `Bookmarks`` class instead.") - def __init__(self, *values): + def __init__(self, *values: str) -> None: if values: bookmarks = [] for ix in values: @@ -191,20 +194,19 @@ def __init__(self, *values): else: self._values = frozenset() - def __repr__(self): + def __repr__(self) -> str: """ :return: repr string with sorted values """ return "".format(", ".join(["'{}'".format(ix) for ix in sorted(self._values)])) - def __bool__(self): + def __bool__(self) -> bool: return bool(self._values) @property - def values(self): + def values(self) -> frozenset: """ :return: immutable list of bookmark string values - :rtype: frozenset """ return self._values @@ -224,7 +226,7 @@ class Bookmarks: def __init__(self): self._raw_values = frozenset() - def __repr__(self): + def __repr__(self) -> str: """ :return: repr string with sorted values """ @@ -232,10 +234,10 @@ def __repr__(self): ", ".join(map(repr, sorted(self._raw_values))) ) - def __bool__(self): + def __bool__(self) -> bool: return bool(self._raw_values) - def __add__(self, other): + def __add__(self, other: Bookmarks) -> Bookmarks: if isinstance(other, Bookmarks): if not other: return self @@ -245,7 +247,7 @@ def __add__(self, other): return NotImplemented @property - def raw_values(self): + def raw_values(self) -> t.FrozenSet[str]: """The raw bookmark values. You should not need to access them unless you want to serialize @@ -257,7 +259,7 @@ def raw_values(self): return self._raw_values @classmethod - def from_raw_values(cls, values): + def from_raw_values(cls, values: t.Iterable[str]) -> Bookmarks: """Create a Bookmarks object from a list of raw bookmark string values. You should not need to use this method unless you want to deserialize @@ -285,19 +287,19 @@ class ServerInfo: """ Represents a package of information relating to a Neo4j server. """ - def __init__(self, address, protocol_version): + def __init__(self, address: Address, protocol_version: Version): self._address = address self._protocol_version = protocol_version - self._metadata = {} + self._metadata: dict = {} @property - def address(self): + def address(self) -> Address: """ Network address of the remote server. """ return self._address @property - def protocol_version(self): + def protocol_version(self) -> Version: """ Bolt protocol version with which the remote server communicates. This is returned as a :class:`.Version` object, which itself extends a simple 2-tuple of @@ -306,13 +308,13 @@ def protocol_version(self): return self._protocol_version @property - def agent(self): + def agent(self) -> str: """ Server agent string by which the remote server identifies itself. """ - return self._metadata.get("server") + return str(self._metadata.get("server")) - @property + @property # type: ignore @deprecated("The connection id is considered internal information " "and will no longer be exposed in future versions.") def connection_id(self): @@ -320,7 +322,7 @@ def connection_id(self): """ return self._metadata.get("connection_id") - def update(self, metadata): + def update(self, metadata: dict) -> None: """ Update server information with extra metadata. This is typically drawn from the metadata received after successful connection initialisation. @@ -339,7 +341,7 @@ def __repr__(self): def __str__(self): return ".".join(map(str, self)) - def to_bytes(self): + def to_bytes(self) -> bytes: b = bytearray(4) for i, v in enumerate(self): if not 0 <= i < 2: @@ -352,7 +354,7 @@ def to_bytes(self): return bytes(b) @classmethod - def from_bytes(cls, b): + def from_bytes(cls, b: bytes) -> Version: b = bytearray(b) if len(b) != 4: raise ValueError("Byte representation must be exactly four bytes") @@ -382,13 +384,13 @@ def parse_neo4j_uri(uri): driver_type = DRIVER_BOLT security_type = SECURITY_TYPE_SECURE elif parsed.scheme == URI_SCHEME_NEO4J: - driver_type = DRIVER_NEO4j + driver_type = DRIVER_NEO4J security_type = SECURITY_TYPE_NOT_SECURE elif parsed.scheme == URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE: - driver_type = DRIVER_NEO4j + driver_type = DRIVER_NEO4J security_type = SECURITY_TYPE_SELF_SIGNED_CERTIFICATE elif parsed.scheme == URI_SCHEME_NEO4J_SECURE: - driver_type = DRIVER_NEO4j + driver_type = DRIVER_NEO4J security_type = SECURITY_TYPE_SECURE else: raise ConfigurationError("URI scheme {!r} is not supported. Supported URI schemes are {}. Examples: bolt://host[:port] or neo4j://host[:port][?routing_context]".format( diff --git a/neo4j/debug.py b/neo4j/debug.py index ea960056..2a4d37a0 100644 --- a/neo4j/debug.py +++ b/neo4j/debug.py @@ -16,6 +16,9 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from logging import ( CRITICAL, DEBUG, @@ -71,36 +74,37 @@ class Watcher: enable logging for all threads. :param logger_names: Names of loggers to watch. - :type logger_names: str :param default_level: Default minimum log level to show. The level can be overridden by setting the level a level when calling :meth:`.watch`. - :type default_level: int :param default_out: Default output stream for all loggers. The level can be overridden by setting the level a level when calling :meth:`.watch`. - :type default_out: stream or file-like object :param colour: Whether the log levels should be indicated with ANSI colour codes. - :type colour: bool """ - def __init__(self, *logger_names, default_level=DEBUG, default_out=stderr, - colour=False): + def __init__( + self, + *logger_names: str, + default_level: int = DEBUG, + default_out: t.TextIO = stderr, + colour: bool = False + ) -> None: super(Watcher, self).__init__() self.logger_names = logger_names self._loggers = [getLogger(name) for name in self.logger_names] self.default_level = default_level self.default_out = default_out - self._handlers = {} + self._handlers: t.Dict[str, StreamHandler] = {} - format = "%(threadName)s(%(thread)d) %(asctime)s %(message)s" + format_ = "%(threadName)s(%(thread)d) %(asctime)s %(message)s" if not colour: - format = "[%(levelname)-8s] " + format + format_ = "[%(levelname)-8s] " + format_ formatter_cls = ColourFormatter if colour else Formatter - self.formatter = formatter_cls(format) + self.formatter = formatter_cls(format_) - def __enter__(self): + def __enter__(self) -> Watcher: """Enable logging for all loggers.""" self.watch() return self @@ -109,15 +113,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): """Disable logging for all loggers.""" self.stop() - def watch(self, level=None, out=None): + def watch(self, level: int = None, out: t.TextIO = None): """Enable logging for all loggers. :param level: Minimum log level to show. If :const:`None`, the ``default_level`` is used. - :type level: int or :const:`None` :param out: Output stream for all loggers. If :const:`None`, the ``default_out`` is used. - :type out: stream or file-like object or :const:`None` """ if level is None: level = self.default_level @@ -131,7 +133,7 @@ def watch(self, level=None, out=None): logger.addHandler(handler) logger.setLevel(level) - def stop(self): + def stop(self) -> None: """Disable logging for all loggers.""" for logger in self._loggers: try: @@ -140,7 +142,12 @@ def stop(self): pass -def watch(*logger_names, level=DEBUG, out=stderr, colour=False): +def watch( + *logger_names: str, + level: int = DEBUG, + out: t.TextIO = stderr, + colour: bool = False +) -> Watcher: """Quick wrapper for using :class:`.Watcher`. Create a Wathcer with the given configuration, enable watching and return diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index 15392c5e..7c72af69 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -67,15 +67,38 @@ """ +from __future__ import annotations + +import typing as t + + +if t.TYPE_CHECKING: + import typing_extensions as te + + from ._async.work import ( + AsyncResult, + AsyncSession, + AsyncTransactionBase, + ) + from ._sync.work import ( + Result, + Session, + TransactionBase, + ) + + _T_Transaction = t.Union[AsyncTransactionBase, TransactionBase] + _T_Result = t.Union[AsyncResult, Result] + _T_Session = t.Union[AsyncSession, Session] + from ._meta import deprecated -CLASSIFICATION_CLIENT = "ClientError" -CLASSIFICATION_TRANSIENT = "TransientError" -CLASSIFICATION_DATABASE = "DatabaseError" +CLASSIFICATION_CLIENT: te.Final[str] = "ClientError" +CLASSIFICATION_TRANSIENT: te.Final[str] = "TransientError" +CLASSIFICATION_DATABASE: te.Final[str] = "DatabaseError" -ERROR_REWRITE_MAP = { +ERROR_REWRITE_MAP: t.Dict[str, t.Tuple[str, t.Optional[str]]] = { # This error can be retried ed. The driver just needs to re-authenticate # with the same credentials. "Neo.ClientError.Security.AuthorizationExpired": ( @@ -114,7 +137,9 @@ class Neo4jError(Exception): metadata = None @classmethod - def hydrate(cls, message=None, code=None, **metadata): + def hydrate( + cls, message: str = None, code: str = None, **metadata: t.Any + ) -> Neo4jError: message = message or "An unknown error occurred" code = code or "Neo.DatabaseError.General.UnknownError" try: @@ -167,14 +192,13 @@ def _extract_error_class(cls, classification, code): "Neo4jError.is_retriable is deprecated and will be removed in a " "future version. Please use Neo4jError.is_retryable instead." ) - def is_retriable(self): + def is_retriable(self) -> bool: """Whether the error is retryable. See :meth:`.is_retryable`. :return: :const:`True` if the error is retryable, :const:`False` otherwise. - :rtype: bool .. deprecated:: 5.0 This method will be removed in a future version. @@ -182,7 +206,7 @@ def is_retriable(self): """ return self.is_retryable() - def is_retryable(self): + def is_retryable(self) -> bool: """Whether the error is retryable. Indicates whether a transaction that yielded this error makes sense to @@ -191,14 +215,13 @@ def is_retryable(self): :return: :const:`True` if the error is retryable, :const:`False` otherwise. - :rtype: bool """ return False def invalidates_all_connections(self): return self.code == "Neo.ClientError.Security.AuthorizationExpired" - def is_fatal_during_discovery(self): + def is_fatal_during_discovery(self) -> bool: # checks if the code is an error that is caused by the client. In this # case the driver should fail fast during discovery. if not isinstance(self.code, str): @@ -221,7 +244,7 @@ def __str__(self): class ClientError(Neo4jError): """ The Client sent a bad request - changing the request might yield a successful outcome. """ - def __str__(self): + def __str__(self) -> str: return super().__str__() @@ -274,7 +297,7 @@ class TransientError(Neo4jError): """ The database cannot service the request right now, retrying later might yield a successful outcome. """ - def is_retryable(self): + def is_retryable(self) -> bool: return True @@ -296,7 +319,7 @@ class ForbiddenOnReadOnlyDatabase(TransientError): """ -client_errors = { +client_errors: t.Dict[str, t.Type[Neo4jError]] = { # ConstraintError "Neo.ClientError.Schema.ConstraintValidationFailed": ConstraintError, @@ -332,7 +355,7 @@ class ForbiddenOnReadOnlyDatabase(TransientError): "Neo.ClientError.Cluster.NotALeader": NotALeader, } -transient_errors = { +transient_errors: t.Dict[str, t.Type[Neo4jError]] = { # DatabaseUnavailableError "Neo.TransientError.General.DatabaseUnavailable": DatabaseUnavailable @@ -343,7 +366,7 @@ class ForbiddenOnReadOnlyDatabase(TransientError): class DriverError(Exception): """ Raised when the Driver raises an error. """ - def is_retryable(self): + def is_retryable(self) -> bool: """Whether the error is retryable. Indicates whether a transaction that yielded this error makes sense to @@ -352,7 +375,6 @@ def is_retryable(self): :return: :const:`True` if the error is retryable, :const:`False` otherwise. - :rtype: bool """ return False @@ -362,7 +384,10 @@ class TransactionError(DriverError): """ Raised when an error occurs while using a transaction. """ - def __init__(self, transaction, *args, **kwargs): + def __init__( + self, transaction: _T_Transaction, + *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.transaction = transaction @@ -372,7 +397,10 @@ class TransactionNestingError(TransactionError): """ Raised when transactions are nested incorrectly. """ - def __init__(self, transaction, *args, **kwargs): + def __init__( + self, transaction: _T_Transaction, + *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.transaction = transaction @@ -381,7 +409,9 @@ def __init__(self, transaction, *args, **kwargs): class ResultError(DriverError): """Raised when an error occurs while using a result object.""" - def __init__(self, result, *args, **kwargs): + def __init__( + self, result: _T_Result, *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.result = result @@ -411,10 +441,12 @@ class SessionExpired(DriverError): the purpose described by its original parameters. """ - def __init__(self, session, *args, **kwargs): + def __init__( + self, session: _T_Session, *args, **kwargs + ) -> None: super().__init__(session, *args, **kwargs) - def is_retryable(self): + def is_retryable(self) -> bool: return True @@ -426,7 +458,7 @@ class ServiceUnavailable(DriverError): failure of a database service that the driver is unable to route around. """ - def is_retryable(self): + def is_retryable(self) -> bool: return True @@ -458,7 +490,7 @@ class IncompleteCommit(ServiceUnavailable): successfully or not. """ - def is_retryable(self): + def is_retryable(self) -> bool: return False diff --git a/neo4j/graph/__init__.py b/neo4j/graph/__init__.py index 0939c54c..7409165b 100644 --- a/neo4j/graph/__init__.py +++ b/neo4j/graph/__init__.py @@ -21,6 +21,17 @@ """ +from __future__ import annotations + +import typing as t +from collections.abc import Mapping + +from .._meta import ( + deprecated, + deprecation_warn, +) + + __all__ = [ "Graph", "Node", @@ -29,12 +40,7 @@ ] -from collections.abc import Mapping - -from .._meta import ( - deprecated, - deprecation_warn, -) +_T = t.TypeVar("_T") class Graph: @@ -42,46 +48,56 @@ class Graph: :class:`.Node` and :class:`.Relationship` instances. """ - def __init__(self): - self._nodes = {} - self._legacy_nodes = {} # TODO: 6.0 - remove - self._relationships = {} - self._legacy_relationships = {} # TODO: 6.0 - remove - self._relationship_types = {} + def __init__(self) -> None: + self._nodes: t.Dict[str, Node] = {} + self._legacy_nodes: t.Dict[int, Node] = {} # TODO: 6.0 - remove + self._relationships: t.Dict[str, Relationship] = {} + # TODO: 6.0 - remove + self._legacy_relationships: t.Dict[int, Relationship] = {} + self._relationship_types: t.Dict[str, t.Type[Relationship]] = {} self._node_set_view = EntitySetView(self._nodes, self._legacy_nodes) self._relationship_set_view = EntitySetView(self._relationships, self._legacy_relationships) @property - def nodes(self): + def nodes(self) -> EntitySetView[Node]: """ Access a set view of the nodes in this graph. """ return self._node_set_view @property - def relationships(self): + def relationships(self) -> EntitySetView[Relationship]: """ Access a set view of the relationships in this graph. """ return self._relationship_set_view - def relationship_type(self, name): + def relationship_type(self, name: str) -> t.Type[Relationship]: """ Obtain a :class:`.Relationship` subclass for a given relationship type name. """ try: cls = self._relationship_types[name] except KeyError: - cls = self._relationship_types[name] = type(str(name), (Relationship,), {}) + cls = self._relationship_types[name] = t.cast( + t.Type[Relationship], + type(str(name), (Relationship,), {}) + ) return cls -class Entity(Mapping): +class Entity(t.Mapping[str, t.Any]): """ Base class for :class:`.Node` and :class:`.Relationship` that provides :class:`.Graph` membership and property containment functionality. """ - def __init__(self, graph, element_id, id_, properties): + def __init__( + self, + graph: Graph, + element_id: str, + id_: int, + properties: t.Optional[t.Dict[str, t.Any]] + ) -> None: self._graph = graph self._element_id = element_id self._id = id_ @@ -89,7 +105,7 @@ def __init__(self, graph, element_id, id_, properties): k: v for k, v in (properties or {}).items() if v is not None } - def __eq__(self, other): + def __eq__(self, other: t.Any) -> bool: try: return (type(self) == type(other) and self.graph == other.graph @@ -97,33 +113,33 @@ def __eq__(self, other): except AttributeError: return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self): return hash(self._element_id) - def __len__(self): + def __len__(self) -> int: return len(self._properties) - def __getitem__(self, name): + def __getitem__(self, name: str) -> t.Any: return self._properties.get(name) - def __contains__(self, name): + def __contains__(self, name: object) -> bool: return name in self._properties - def __iter__(self): + def __iter__(self) -> t.Iterator[str]: return iter(self._properties) @property - def graph(self): + def graph(self) -> Graph: """ The :class:`.Graph` to which this entity belongs. """ return self._graph - @property + @property # type: ignore @deprecated("`id` is deprecated, use `element_id` instead") - def id(self): + def id(self) -> int: """The legacy identity of this entity in its container :class:`.Graph`. Depending on the version of the server this entity was retrieved from, @@ -135,13 +151,11 @@ def id(self): .. deprecated:: 5.0 Use :attr:`.element_id` instead. - - :rtype: int """ return self._id @property - def element_id(self): + def element_id(self) -> str: """The identity of this entity in its container :class:`.Graph`. .. Warning:: @@ -149,41 +163,43 @@ def element_id(self): queries. Don't rely on it for cross-query computations. .. versionadded:: 5.0 - - :rtype: str """ return self._element_id - def get(self, name, default=None): + def get(self, name: str, default: object = None) -> t.Any: """ Get a property value by name, optionally with a default. """ return self._properties.get(name, default) - def keys(self): + def keys(self) -> t.KeysView[str]: """ Return an iterable of all property names. """ return self._properties.keys() - def values(self): + def values(self) -> t.ValuesView[t.Any]: """ Return an iterable of all property values. """ return self._properties.values() - def items(self): + def items(self) -> t.ItemsView[str, t.Any]: """ Return an iterable of all property name-value pairs. """ return self._properties.items() -class EntitySetView(Mapping): +class EntitySetView(Mapping, t.Generic[_T]): """ View of a set of :class:`.Entity` instances within a :class:`.Graph`. """ - def __init__(self, entity_dict, legacy_entity_dict): + def __init__( + self, + entity_dict: t.Dict[str, _T], + legacy_entity_dict: t.Dict[int, _T], + ) -> None: self._entity_dict = entity_dict self._legacy_entity_dict = legacy_entity_dict # TODO: 6.0 - remove - def __getitem__(self, e_id): + def __getitem__(self, e_id: t.Union[int, str]) -> _T: # TODO: 6.0 - remove this compatibility shim if isinstance(e_id, (int, float, complex)): deprecation_warn( @@ -193,10 +209,10 @@ def __getitem__(self, e_id): return self._legacy_entity_dict[e_id] return self._entity_dict[e_id] - def __len__(self): + def __len__(self) -> int: return len(self._entity_dict) - def __iter__(self): + def __iter__(self) -> t.Iterator[_T]: return iter(self._entity_dict.values()) @@ -204,17 +220,23 @@ class Node(Entity): """ Self-contained graph node. """ - def __init__(self, graph, element_id, id_, n_labels=None, - properties=None): + def __init__( + self, + graph: Graph, + element_id: str, + id_: int, + n_labels: t.Iterable[str] = None, + properties: t.Dict[str, t.Any] = None + ) -> None: Entity.__init__(self, graph, element_id, id_, properties) self._labels = frozenset(n_labels or ()) - def __repr__(self): + def __repr__(self) -> str: return (f"") @property - def labels(self): + def labels(self) -> t.FrozenSet[str]: """ The set of labels attached to this node. """ return self._labels @@ -224,36 +246,42 @@ class Relationship(Entity): """ Self-contained graph relationship. """ - def __init__(self, graph, element_id, id_, properties): + def __init__( + self, + graph: Graph, + element_id: str, + id_: int, + properties: t.Dict[str, t.Any], + ) -> None: Entity.__init__(self, graph, element_id, id_, properties) - self._start_node = None - self._end_node = None + self._start_node: t.Optional[Node] = None + self._end_node: t.Optional[Node] = None - def __repr__(self): + def __repr__(self) -> str: return (f"") @property - def nodes(self): + def nodes(self) -> t.Tuple[t.Optional[Node], t.Optional[Node]]: """ The pair of nodes which this relationship connects. """ return self._start_node, self._end_node @property - def start_node(self): + def start_node(self) -> t.Optional[Node]: """ The start node of this relationship. """ return self._start_node @property - def end_node(self): + def end_node(self) -> t.Optional[Node]: """ The end node of this relationship. """ return self._end_node @property - def type(self): + def type(self) -> str: """ The type name of this relationship. This is functionally equivalent to ``type(relationship).__name__``. """ @@ -264,31 +292,31 @@ class Path: """ Self-contained graph path. """ - def __init__(self, start_node, *relationships): + def __init__(self, start_node: Node, *relationships: Relationship) -> None: assert isinstance(start_node, Node) nodes = [start_node] for i, relationship in enumerate(relationships, start=1): assert isinstance(relationship, Relationship) if relationship.start_node == nodes[-1]: - nodes.append(relationship.end_node) + nodes.append(t.cast(Node, relationship.end_node)) elif relationship.end_node == nodes[-1]: - nodes.append(relationship.start_node) + nodes.append(t.cast(Node, relationship.start_node)) else: raise ValueError("Relationship %d does not connect to the last node" % i) self._nodes = tuple(nodes) self._relationships = relationships - def __repr__(self): + def __repr__(self) -> str: return "" % \ (self.start_node, self.end_node, len(self)) - def __eq__(self, other): + def __eq__(self, other: t.Any) -> bool: try: return self.start_node == other.start_node and self.relationships == other.relationships except AttributeError: return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self): @@ -297,38 +325,38 @@ def __hash__(self): value ^= hash(relationship) return value - def __len__(self): + def __len__(self) -> int: return len(self._relationships) - def __iter__(self): + def __iter__(self) -> t.Iterator[Relationship]: return iter(self._relationships) @property - def graph(self): + def graph(self) -> Graph: """ The :class:`.Graph` to which this path belongs. """ return self._nodes[0].graph @property - def nodes(self): + def nodes(self) -> t.Tuple[Node, ...]: """ The sequence of :class:`.Node` objects in this path. """ return self._nodes @property - def start_node(self): + def start_node(self) -> Node: """ The first :class:`.Node` in this path. """ return self._nodes[0] @property - def end_node(self): + def end_node(self) -> Node: """ The last :class:`.Node` in this path. """ return self._nodes[-1] @property - def relationships(self): + def relationships(self) -> t.Tuple[Relationship, ...]: """ The sequence of :class:`.Relationship` objects in this path. """ return self._relationships diff --git a/neo4j/packstream.py b/neo4j/packstream.py index 041b644f..515c6be1 100644 --- a/neo4j/packstream.py +++ b/neo4j/packstream.py @@ -27,7 +27,6 @@ Packer, Structure, UnpackableBuffer, - UNPACKED_MARKERS, UNPACKED_UINT_8, UNPACKED_UINT_16, Unpacker, @@ -40,9 +39,6 @@ "PACKED_UINT_16", "UNPACKED_UINT_8", "UNPACKED_UINT_16", - "UNPACKED_MARKERS", - "UNPACKED_MARKERS", - "UNPACKED_MARKERS", "INT64_MIN", "INT64_MAX", "Structure", diff --git a/neo4j/py.typed b/neo4j/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/neo4j/spatial/__init__.py b/neo4j/spatial/__init__.py index f530d3da..5a88b147 100644 --- a/neo4j/spatial/__init__.py +++ b/neo4j/spatial/__init__.py @@ -72,22 +72,6 @@ def dehydrate_point(value): return _hydration.dehydrate_point(value) -# TODO: 6.0 remove -@deprecated( - "hydrate_point is considered an internal function and will be removed in " - "a future version" -) -@wraps(_hydration.dehydrate_point) -def dehydrate_point(value): - """ Dehydrator for Point data. - - :param value: - :type value: Point - :return: - """ - return _hydration.dehydrate_point(value) - - # TODO: 6.0 remove @deprecated( "point_type is considered an internal function and will be removed in " diff --git a/neo4j/time/__init__.py b/neo4j/time/__init__.py index d30d705e..4dea3655 100644 --- a/neo4j/time/__init__.py +++ b/neo4j/time/__init__.py @@ -22,12 +22,16 @@ """ +from __future__ import annotations + +import typing as t from datetime import ( date, datetime, time, timedelta, timezone, + tzinfo as _tzinfo, ) from functools import total_ordering from re import compile as re_compile @@ -37,6 +41,8 @@ struct_time, ) +import typing_extensions as te + from ._arithmetic import ( nano_add, nano_div, @@ -73,12 +79,12 @@ #: The smallest year number allowed in a :class:`.Date` or :class:`.DateTime` #: object to be compatible with :class:`datetime.date` and #: :class:`datetime.datetime`. -MIN_YEAR = 1 +MIN_YEAR: te.Final[int] = 1 #: The largest year number allowed in a :class:`.Date` or :class:`.DateTime` #: object to be compatible with :class:`datetime.date` and #: :class:`datetime.datetime`. -MAX_YEAR = 9999 +MAX_YEAR: te.Final[int] = 9999 DATE_ISO_PATTERN = re_compile(r"^(\d{4})-(\d{2})-(\d{2})$") TIME_ISO_PATTERN = re_compile( @@ -195,7 +201,7 @@ class ClockTime(tuple): the ``timespec`` struct in C. """ - def __new__(cls, seconds=0, nanoseconds=0): + def __new__(cls, seconds: float = 0, nanoseconds: int = 0) -> ClockTime: seconds, nanoseconds = divmod( int(NANO_SECONDS * seconds) + int(nanoseconds), NANO_SECONDS ) @@ -314,7 +320,16 @@ def utc_time(self): raise NotImplementedError("No clock implementation selected") -class Duration(tuple): +if t.TYPE_CHECKING: + # make typechecker believe that Duration subclasses datetime.timedelta + # https://github.com/python/typeshed/issues/8409#issuecomment-1197704527 + duration_base_class = timedelta +else: + duration_base_class = object + + +class Duration(t.Tuple[int, int, int, int], # type: ignore[misc] + duration_base_class): """A difference between two points in time. A :class:`.Duration` represents the difference between two points in time. @@ -343,39 +358,40 @@ class Duration(tuple): inclusive. :param years: will be added times 12 to `months` - :type years: float :param months: will be truncated to :class:`int` (`int(months)`) - :type months: float :param weeks: will be added times 7 to `days` - :type weeks: float :param days: will be truncated to :class:`int` (`int(days)`) - :type days: float :param hours: will be added times 3,600,000,000,000 to `nanoseconds` - :type hours: float :param minutes: will be added times 60,000,000,000 to `nanoseconds` - :type minutes: float :param seconds: will be added times 1,000,000,000 to `nanoseconds`` - :type seconds: float :param milliseconds: will be added times 1,000,000 to `nanoseconds` - :type microseconds: float :param microseconds: will be added times 1,000 to `nanoseconds` - :type milliseconds: float :param nanoseconds: will be truncated to :class:`int` (`int(nanoseconds)`) - :type nanoseconds: float :raises ValueError: the components exceed the limits as described above. """ # i64: i64:i64: i32 - min = None + min: te.Final[Duration] = None # type: ignore """The lowest duration value possible.""" - max = None + max: te.Final[Duration] = None # type: ignore """The highest duration value possible.""" - def __new__(cls, years=0, months=0, weeks=0, days=0, hours=0, minutes=0, - seconds=0, milliseconds=0, microseconds=0, nanoseconds=0): + def __new__( + cls, + years: float = 0, + months: float = 0, + weeks: float = 0, + days: float = 0, + hours: float = 0, + minutes: float = 0, + seconds: float = 0, + milliseconds: float = 0, + microseconds: float = 0, + nanoseconds: float = 0 + ) -> Duration: mo = int(12 * years + months) if mo < MIN_INT64 or mo > MAX_INT64: raise ValueError("Months value out of range") @@ -393,20 +409,19 @@ def __new__(cls, years=0, months=0, weeks=0, days=0, hours=0, minutes=0, - (1 if ns < 0 else 0)) if avg_total_seconds < MIN_INT64 or avg_total_seconds > MAX_INT64: raise ValueError("Duration value out of range: %r", - cls.__repr__((mo, d, s, ns))) + tuple.__repr__((mo, d, s, ns))) return tuple.__new__(cls, (mo, d, s, ns)) - def __bool__(self): + def __bool__(self) -> bool: """Falsy if all primary instance attributes are.""" return any(map(bool, self)) __nonzero__ = __bool__ - def __add__(self, other): - """Add a :class:`.Duration` or :class:`datetime.timedelta`. - - :rtype: Duration - """ + def __add__( # type: ignore[override] + self, other: t.Union[Duration, timedelta] + ) -> Duration: + """Add a :class:`.Duration` or :class:`datetime.timedelta`.""" if isinstance(other, Duration): return Duration( months=self[0] + int(other.months), @@ -422,11 +437,8 @@ def __add__(self, other): ) return NotImplemented - def __sub__(self, other): - """Subtract a :class:`.Duration` or :class:`datetime.timedelta`. - - :rtype: Duration - """ + def __sub__(self, other: t.Union[Duration, timedelta]) -> Duration: + """Subtract a :class:`.Duration` or :class:`datetime.timedelta`.""" if isinstance(other, Duration): return Duration( months=self[0] - int(other.months), @@ -443,7 +455,7 @@ def __sub__(self, other): ) return NotImplemented - def __mul__(self, other): + def __mul__(self, other: float) -> Duration: # type: ignore[override] """Multiply by an :class:`int` or :class:`float`. The operation is performed element-wise on @@ -454,8 +466,6 @@ def __mul__(self, other): * seconds and all sub-second units go into nanoseconds. Each element will be rounded to the nearest integer (.5 towards even). - - :rtype: Duration """ if isinstance(other, (int, float)): return Duration( @@ -468,7 +478,7 @@ def __mul__(self, other): ) return NotImplemented - def __floordiv__(self, other): + def __floordiv__(self, other: int) -> Duration: # type: ignore[override] """Integer division by an :class:`int`. The operation is performed element-wise on @@ -479,8 +489,6 @@ def __floordiv__(self, other): * seconds and all sub-second units go into nanoseconds. Each element will be rounded towards -inf. - - :rtype: Duration """ if isinstance(other, int): return Duration( @@ -489,7 +497,7 @@ def __floordiv__(self, other): ) return NotImplemented - def __mod__(self, other): + def __mod__(self, other: int) -> Duration: # type: ignore[override] """Modulo operation by an :class:`int`. The operation is performed element-wise on @@ -498,8 +506,6 @@ def __mod__(self, other): * years go into months, * weeks go into days, * seconds and all sub-second units go into nanoseconds. - - :rtype: Duration """ if isinstance(other, int): return Duration( @@ -508,18 +514,18 @@ def __mod__(self, other): ) return NotImplemented - def __divmod__(self, other): + def __divmod__( # type: ignore[override] + self, other: int + ) -> t.Tuple[Duration, Duration]: """Division and modulo operation by an :class:`int`. See :meth:`__floordiv__` and :meth:`__mod__`. - - :rtype: (Duration, Duration) """ if isinstance(other, int): return self.__floordiv__(other), self.__mod__(other) return NotImplemented - def __truediv__(self, other): + def __truediv__(self, other: float) -> Duration: # type: ignore[override] """Division by an :class:`int` or :class:`float`. The operation is performed element-wise on @@ -530,8 +536,6 @@ def __truediv__(self, other): * seconds and all sub-second units go into nanoseconds. Each element will be rounded to the nearest integer (.5 towards even). - - :rtype: Duration """ if isinstance(other, (int, float)): return Duration( @@ -544,37 +548,37 @@ def __truediv__(self, other): ) return NotImplemented - def __pos__(self): + def __pos__(self) -> Duration: """""" return self - def __neg__(self): + def __neg__(self) -> Duration: """""" return Duration(months=-self[0], days=-self[1], seconds=-self[2], nanoseconds=-self[3]) - def __abs__(self): + def __abs__(self) -> Duration: """""" return Duration(months=abs(self[0]), days=abs(self[1]), seconds=abs(self[2]), nanoseconds=abs(self[3])) - def __repr__(self): + def __repr__(self) -> str: """""" return "Duration(months=%r, days=%r, seconds=%r, nanoseconds=%r)" % self - def __str__(self): + def __str__(self) -> str: """""" return self.iso_format() - def __copy__(self): + def __copy__(self) -> Duration: return self.__new__(self.__class__, months=self[0], days=self[1], seconds=self[2], nanoseconds=self[3]) - def __deepcopy__(self, memodict={}): + def __deepcopy__(self, memo) -> Duration: return self.__copy__() @classmethod - def from_iso_format(cls, s): + def from_iso_format(cls, s: str) -> Duration: """Parse a ISO formatted duration string. Accepted formats (all lowercase letters are placeholders): @@ -596,9 +600,6 @@ def from_iso_format(cls, s): 100 minutes. :param s: String to parse - :type s: str - - :rtype: Duration :raises ValueError: if the string does not match the required format. """ @@ -620,13 +621,10 @@ def from_iso_format(cls, s): fromisoformat = from_iso_format - def iso_format(self, sep="T"): + def iso_format(self, sep: str = "T") -> str: """Return the :class:`Duration` as ISO formatted string. :param sep: the separator before the time components. - :type sep: str - - :rtype: str """ parts = [] hours, minutes, seconds, nanoseconds = \ @@ -665,62 +663,64 @@ def iso_format(self, sep="T"): return "PT0S" @property - def months(self): - """The months of the :class:`Duration`. - - :type: int - """ + def months(self) -> int: + """The months of the :class:`Duration`.""" return self[0] @property - def days(self): - """The days of the :class:`Duration`. - - :type: int - """ + def days(self) -> int: + """The days of the :class:`Duration`.""" return self[1] @property - def seconds(self): - """The seconds of the :class:`Duration`. - - :type: int - """ + def seconds(self) -> int: + """The seconds of the :class:`Duration`.""" return self[2] @property - def nanoseconds(self): - """The nanoseconds of the :class:`Duration`. - - :type: int - """ + def nanoseconds(self) -> int: + """The nanoseconds of the :class:`Duration`.""" return self[3] @property - def years_months_days(self): - """ + def years_months_days(self) -> t.Tuple[int, int, int]: + """Months and days components as a 3-tuple. - :return: + t.Tuple of years, months and days. """ years, months = symmetric_divmod(self[0], 12) return years, months, self[1] @property - def hours_minutes_seconds_nanoseconds(self): - """ A 4-tuple of (hours, minutes, seconds, nanoseconds). + def hours_minutes_seconds_nanoseconds(self) -> t.Tuple[int, int, int, int]: + """Seconds and nanoseconds components as a 4-tuple. - :type: (int, int, int, int) + t.Tuple of hours, minutes, seconds and nanoseconds. """ minutes, seconds = symmetric_divmod(self[2], 60) hours, minutes = symmetric_divmod(minutes, 60) return hours, minutes, seconds, self[3] -Duration.min = Duration(seconds=MIN_INT64, nanoseconds=0) -Duration.max = Duration(seconds=MAX_INT64, nanoseconds=999999999) +Duration.min = Duration( # type: ignore + seconds=MIN_INT64, nanoseconds=0 +) + +Duration.max = Duration( # type: ignore + seconds=MAX_INT64, + nanoseconds=999999999 +) + + +if t.TYPE_CHECKING: + # make typechecker believe that Date subclasses datetime.date + # https://github.com/python/typeshed/issues/8409#issuecomment-1197704527 + date_base_class = date +else: + date_base_class = object -class Date(metaclass=DateType): +class Date(date_base_class, metaclass=DateType): """Idealized date representation. A :class:`.Date` object represents a date (year, month, and day) in the @@ -758,7 +758,7 @@ class Date(metaclass=DateType): # CONSTRUCTOR # - def __new__(cls, year, month, day): + def __new__(cls, year: int, month: int, day: int) -> Date: if year == month == day == 0: return ZeroDate year, month, day = _normalize_day(year, month, day) @@ -766,7 +766,7 @@ def __new__(cls, year, month, day): return cls.__new(ordinal, year, month, day) @classmethod - def __new(cls, ordinal, year, month, day): + def __new(cls, ordinal: int, year: int, month: int, day: int) -> Date: instance = object.__new__(cls) instance.__ordinal = int(ordinal) instance.__year = int(year) @@ -774,32 +774,13 @@ def __new(cls, ordinal, year, month, day): instance.__day = int(day) return instance - def __getattr__(self, name): - """ Map standard library attribute names to local attribute names, - for compatibility. - """ - try: - return { - "isocalendar": self.iso_calendar, - "isoformat": self.iso_format, - "isoweekday": self.iso_weekday, - "strftime": self.__format__, - "toordinal": self.to_ordinal, - "timetuple": self.time_tuple, - }[name] - except KeyError: - raise AttributeError("Date has no attribute %r" % name) - # CLASS METHODS # @classmethod - def today(cls, tz=None): + def today(cls, tz: _tzinfo = None) -> Date: """Get the current date. :param tz: timezone or None to get the local :class:`.Date`. - :type tz: datetime.tzinfo or None - - :rtype: Date :raises OverflowError: if the timestamp is out of the range of values supported by the platform C localtime() function. It’s common for @@ -815,23 +796,16 @@ def today(cls, tz=None): ) @classmethod - def utc_today(cls): - """Get the current date as UTC local date. - - :rtype: Date - """ + def utc_today(cls) -> Date: + """Get the current date as UTC local date.""" return cls.from_clock_time(Clock().utc_time(), UnixEpoch) @classmethod - def from_timestamp(cls, timestamp, tz=None): + def from_timestamp(cls, timestamp: float, tz: _tzinfo = None) -> Date: """:class:`.Date` from a time stamp (seconds since unix epoch). :param timestamp: the unix timestamp (seconds since unix epoch). - :type timestamp: float :param tz: timezone. Set to None to create a local :class:`.Date`. - :type tz: datetime.tzinfo or None - - :rtype: Date :raises OverflowError: if the timestamp is out of the range of values supported by the platform C localtime() function. It’s common for @@ -840,17 +814,15 @@ def from_timestamp(cls, timestamp, tz=None): return cls.from_native(datetime.fromtimestamp(timestamp, tz)) @classmethod - def utc_from_timestamp(cls, timestamp): + def utc_from_timestamp(cls, timestamp: float) -> Date: """:class:`.Date` from a time stamp (seconds since unix epoch). - Returns the `Date` as local date `Date` in UTC. - - :rtype: Date + :returns: the `Date` as local date `Date` in UTC. """ return cls.from_clock_time((timestamp, 0), UnixEpoch) @classmethod - def from_ordinal(cls, ordinal): + def from_ordinal(cls, ordinal: int) -> Date: """ The :class:`.Date` that corresponds to the proleptic Gregorian ordinal. @@ -860,8 +832,6 @@ def from_ordinal(cls, ordinal): transformation is :meth:`.to_ordinal`. The ordinal 0 has a special semantic and will return :attr:`ZeroDate`. - :rtype: Date - :raises ValueError: if the ordinal is outside the range [0, 3652059] (both values included). """ @@ -899,40 +869,33 @@ def from_ordinal(cls, ordinal): return cls.__new(ordinal, year, month, day) @classmethod - def parse(cls, s): + def parse(cls, s: str) -> Date: """Parse a string to produce a :class:`.Date`. Accepted formats: 'Y-M-D' :param s: the string to be parsed. - :type s: str - - :rtype: Date :raises ValueError: if the string could not be parsed. """ try: - numbers = map(int, s.split("-")) + numbers = list(map(int, s.split("-"))) except (ValueError, AttributeError): raise ValueError("Date string must be in format YYYY-MM-DD") else: - numbers = list(numbers) if len(numbers) == 3: return cls(*numbers) raise ValueError("Date string must be in format YYYY-MM-DD") @classmethod - def from_iso_format(cls, s): + def from_iso_format(cls, s: str) -> Date: """Parse a ISO formatted Date string. Accepted formats: 'YYYY-MM-DD' :param s: the string to be parsed. - :type s: str - - :rtype: Date :raises ValueError: if the string could not be parsed. """ @@ -945,27 +908,24 @@ def from_iso_format(cls, s): raise ValueError("Date string must be in format YYYY-MM-DD") @classmethod - def from_native(cls, d): + def from_native(cls, d: date) -> Date: """Convert from a native Python `datetime.date` value. :param d: the date to convert. - :type d: datetime.date - - :rtype: Date """ return Date.from_ordinal(d.toordinal()) @classmethod - def from_clock_time(cls, clock_time, epoch): + def from_clock_time( + cls, + clock_time: t.Union[ClockTime, t.Tuple[float, int]], + epoch: DateTime + ) -> Date: """Convert from a ClockTime relative to a given epoch. :param clock_time: the clock time as :class:`.ClockTime` or as tuple of (seconds, nanoseconds) - :type clock_time: ClockTime or (float, int) :param epoch: the epoch to which `clock_time` is relative - :type epoch: DateTime - - :rtype: Date """ try: clock_time = ClockTime(*clock_time) @@ -976,13 +936,10 @@ def from_clock_time(cls, clock_time, epoch): return Date.from_ordinal(ordinal + epoch.date().to_ordinal()) @classmethod - def is_leap_year(cls, year): + def is_leap_year(cls, year: int) -> bool: """Indicates whether or not `year` is a leap year. :param year: the year to look up - :type year: int - - :rtype: bool :raises ValueError: if `year` is out of range: :attr:`MIN_YEAR` <= year <= :attr:`MAX_YEAR` @@ -992,13 +949,10 @@ def is_leap_year(cls, year): return IS_LEAP_YEAR[year] @classmethod - def days_in_year(cls, year): + def days_in_year(cls, year: int) -> int: """Return the number of days in `year`. :param year: the year to look up - :type year: int - - :rtype: int :raises ValueError: if `year` is out of range: :attr:`MIN_YEAR` <= year <= :attr:`MAX_YEAR` @@ -1008,15 +962,11 @@ def days_in_year(cls, year): return DAYS_IN_YEAR[year] @classmethod - def days_in_month(cls, year, month): + def days_in_month(cls, year: int, month: int) -> int: """Return the number of days in `month` of `year`. :param year: the year to look up - :type year: int :param year: the month to look up - :type year: int - - :rtype: int :raises ValueError: if `year` or `month` is out of range: :attr:`MIN_YEAR` <= year <= :attr:`MAX_YEAR`; @@ -1036,15 +986,34 @@ def __calc_ordinal(cls, year, month, day): # long-hand pure Python algorithm could return date(year, month, day).toordinal() + # CLASS METHOD ALIASES # + + if t.TYPE_CHECKING: + @classmethod + def fromisoformat(cls, s: str) -> Date: + ... + + @classmethod + def fromordinal(cls, ordinal: int) -> Date: + ... + + @classmethod + def fromtimestamp(cls, timestamp: float, tz: _tzinfo = None) -> Date: + ... + + @classmethod + def utcfromtimestamp(cls, timestamp: float) -> Date: + ... + # CLASS ATTRIBUTES # - min = None + min: te.Final[Date] = None # type: ignore """The earliest date value possible.""" - max = None + max: te.Final[Date] = None # type: ignore """The latest date value possible.""" - resolution = None + resolution: te.Final[Duration] = None # type: ignore """The minimum resolution supported.""" # INSTANCE ATTRIBUTES # @@ -1058,7 +1027,7 @@ def __calc_ordinal(cls, year, month, day): __day = 0 @property - def year(self): + def year(self) -> int: """The year of the date. :type: int @@ -1066,7 +1035,7 @@ def year(self): return self.__year @property - def month(self): + def month(self) -> int: """The month of the date. :type: int @@ -1074,7 +1043,7 @@ def month(self): return self.__month @property - def day(self): + def day(self) -> int: """The day of the date. :type: int @@ -1086,20 +1055,16 @@ def day(self): return self.days_in_month(self.__year, self.__month) + self.__day + 1 @property - def year_month_day(self): + def year_month_day(self) -> t.Tuple[int, int, int]: """3-tuple of (year, month, day) describing the date. - - :rtype: (int, int, int) """ return self.year, self.month, self.day @property - def year_week_day(self): + def year_week_day(self) -> t.Tuple[int, int, int]: """3-tuple of (year, week_of_year, day_of_week) describing the date. `day_of_week` will be 1 for Monday and 7 for Sunday. - - :rtype: (int, int, int) """ ordinal = self.__ordinal year = self.__year @@ -1126,13 +1091,11 @@ def iso_week_1(y): day_of_week(ordinal)) @property - def year_day(self): + def year_day(self) -> t.Tuple[int, int]: """2-tuple of (year, day_of_the_year) describing the date. This is the number of the day relative to the start of the year, with `1 Jan` corresponding to `1`. - - :rtype: (int, int) """ return (self.__year, self.toordinal() - Date(self.__year, 1, 1).toordinal() + 1) @@ -1143,45 +1106,43 @@ def __hash__(self): """""" return hash(self.toordinal()) - def __eq__(self, other): - """`==` comparison with :class:`.Date` or :class:`datetime.date`.""" + def __eq__(self, other: object) -> bool: + """``==`` comparison with :class:`.Date` or :class:`datetime.date`.""" if isinstance(other, (Date, date)): return self.toordinal() == other.toordinal() return False - def __ne__(self, other): - """`!=` comparison with :class:`.Date` or :class:`datetime.date`.""" + def __ne__(self, other: object) -> bool: + """``!=`` comparison with :class:`.Date` or :class:`datetime.date`.""" return not self.__eq__(other) - def __lt__(self, other): - """`<` comparison with :class:`.Date` or :class:`datetime.date`.""" + def __lt__(self, other: t.Union[Date, date]) -> bool: + """``<`` comparison with :class:`.Date` or :class:`datetime.date`.""" if isinstance(other, (Date, date)): return self.toordinal() < other.toordinal() raise TypeError("'<' not supported between instances of 'Date' and %r" % type(other).__name__) - def __le__(self, other): - """`<=` comparison with :class:`.Date` or :class:`datetime.date`.""" + def __le__(self, other: t.Union[Date, date]) -> bool: + """``<=`` comparison with :class:`.Date` or :class:`datetime.date`.""" if isinstance(other, (Date, date)): return self.toordinal() <= other.toordinal() raise TypeError("'<=' not supported between instances of 'Date' and %r" % type(other).__name__) - def __ge__(self, other): - """`>=` comparison with :class:`.Date` or :class:`datetime.date`.""" + def __ge__(self, other: t.Union[Date, date]) -> bool: + """``>=`` comparison with :class:`.Date` or :class:`datetime.date`.""" if isinstance(other, (Date, date)): return self.toordinal() >= other.toordinal() raise TypeError("'>=' not supported between instances of 'Date' and %r" % type(other).__name__) - def __gt__(self, other): - """`>` comparison with :class:`.Date` or :class:`datetime.date`.""" + def __gt__(self, other: t.Union[Date, date]) -> bool: + """``>`` comparison with :class:`.Date` or :class:`datetime.date`.""" if isinstance(other, (Date, date)): return self.toordinal() > other.toordinal() raise TypeError("'>' not supported between instances of 'Date' and %r" % type(other).__name__) - def __add__(self, other): + def __add__(self, other: Duration) -> Date: # type: ignore[override] """Add a :class:`.Duration`. - :rtype: Date - :raises ValueError: if the added duration has a time component. """ def add_months(d, months): @@ -1224,6 +1185,14 @@ def add_days(d, days): return new_date return NotImplemented + @t.overload # type: ignore[override] + def __sub__(self, other: t.Union[Date, date]) -> Duration: + ... + + @t.overload + def __sub__(self, other: Duration) -> Date: + ... + def __sub__(self, other): """Subtract a :class:`.Date` or :class:`.Duration`. @@ -1241,102 +1210,97 @@ def __sub__(self, other): except TypeError: return NotImplemented - def __copy__(self): + def __copy__(self) -> Date: return self.__new(self.__ordinal, self.__year, self.__month, self.__day) - def __deepcopy__(self, *args, **kwargs): + def __deepcopy__(self, *args, **kwargs) -> Date: return self.__copy__() # INSTANCE METHODS # - def replace(self, **kwargs): - """Return a :class:`.Date` with one or more components replaced. - :Keyword Arguments: - * **year** (`int`): overwrite the year - - default: `self.year` - * **month** (`int`): overwrite the month - - default: `self.month` - * **day** (`int`): overwrite the day - - default: `self.day` - """ - return Date(kwargs.get("year", self.__year), - kwargs.get("month", self.__month), - kwargs.get("day", self.__day)) + if t.TYPE_CHECKING: - def time_tuple(self): - """Convert the date to :class:`time.struct_time`. + def replace( + self, + year: int = ..., + month: int = ..., + day: int = ..., + **kwargs: object + ) -> Date: + ... - :rtype: time.struct_time - """ + else: + + def replace(self, **kwargs) -> Date: + """Return a :class:`.Date` with one or more components replaced. + + :Keyword Arguments: + * **year** (`int`): overwrite the year - + default: `self.year` + * **month** (`int`): overwrite the month - + default: `self.month` + * **day** (`int`): overwrite the day - + default: `self.day` + """ + return Date(kwargs.get("year", self.__year), + kwargs.get("month", self.__month), + kwargs.get("day", self.__day)) + + def time_tuple(self) -> struct_time: + """Convert the date to :class:`time.struct_time`.""" _, _, day_of_week = self.year_week_day _, day_of_year = self.year_day return struct_time((self.year, self.month, self.day, 0, 0, 0, day_of_week - 1, day_of_year, -1)) - def to_ordinal(self): + def to_ordinal(self) -> int: """The date's proleptic Gregorian ordinal. The corresponding class method for the reverse ordinal-to-date transformation is :meth:`.Date.from_ordinal`. - - :rtype: int """ return self.__ordinal - def to_clock_time(self, epoch): + def to_clock_time(self, epoch: t.Union[Date, DateTime]) -> ClockTime: """Convert the date to :class:`ClockTime` relative to `epoch`. :param epoch: the epoch to which the date is relative - :type epoch: Date - - :rtype: ClockTime """ try: return ClockTime(86400 * (self.to_ordinal() - epoch.to_ordinal())) except AttributeError: raise TypeError("Epoch has no ordinal value") - def to_native(self): + def to_native(self) -> date: """Convert to a native Python :class:`datetime.date` value. - - :rtype: datetime.date """ return date.fromordinal(self.to_ordinal()) - def weekday(self): - """The day of the week where Monday is 0 and Sunday is 6. - - :rtype: int - """ + def weekday(self) -> int: + """The day of the week where Monday is 0 and Sunday is 6.""" return self.year_week_day[2] - 1 - def iso_weekday(self): - """The day of the week where Monday is 1 and Sunday is 7. - - :rtype: int - """ + def iso_weekday(self) -> int: + """The day of the week where Monday is 1 and Sunday is 7.""" return self.year_week_day[2] - def iso_calendar(self): + def iso_calendar(self) -> t.Tuple[int, int, int]: """Alias for :attr:`.year_week_day`""" return self.year_week_day - def iso_format(self): - """Return the :class:`.Date` as ISO formatted string. - - :rtype: str - """ + def iso_format(self) -> str: + """Return the :class:`.Date` as ISO formatted string.""" if self.__ordinal == 0: return "0000-00-00" return "%04d-%02d-%02d" % self.year_month_day - def __repr__(self): + def __repr__(self) -> str: """""" if self.__ordinal == 0: return "neo4j.time.ZeroDate" return "neo4j.time.Date(%r, %r, %r)" % self.year_month_day - def __str__(self): + def __str__(self) -> str: """""" return self.iso_format() @@ -1344,17 +1308,51 @@ def __format__(self, format_spec): """""" raise NotImplementedError() + # INSTANCE METHOD ALIASES # + + def __getattr__(self, name): + """ Map standard library attribute names to local attribute names, + for compatibility. + """ + try: + return { + "isocalendar": self.iso_calendar, + "isoformat": self.iso_format, + "isoweekday": self.iso_weekday, + "strftime": self.__format__, + "toordinal": self.to_ordinal, + "timetuple": self.time_tuple, + }[name] + except KeyError: + raise AttributeError("Date has no attribute %r" % name) + + if t.TYPE_CHECKING: + isocalendar = iso_calendar + isoformat = iso_format + isoweekday = iso_weekday + strftime = __format__ + toordinal = to_ordinal + timetuple = time_tuple -Date.min = Date.from_ordinal(1) -Date.max = Date.from_ordinal(3652059) -Date.resolution = Duration(days=1) + +Date.min = Date.from_ordinal(1) # type: ignore +Date.max = Date.from_ordinal(3652059) # type: ignore +Date.resolution = Duration(days=1) # type: ignore #: A :class:`neo4j.time.Date` instance set to `0000-00-00`. #: This has an ordinal value of `0`. ZeroDate = object.__new__(Date) -class Time(metaclass=TimeType): +if t.TYPE_CHECKING: + # make typechecker believe that Time subclasses datetime.time + # https://github.com/python/typeshed/issues/8409#issuecomment-1197704527 + time_base_class = time +else: + time_base_class = object + + +class Time(time_base_class, metaclass=TimeType): """Time of day. The :class:`.Time` class is a nanosecond-precision drop-in replacement for @@ -1372,23 +1370,25 @@ class Time(metaclass=TimeType): Local times are represented by :class:`.Time` with no ``tzinfo``. :param hour: the hour of the time. Must be in range 0 <= hour < 24. - :type hour: int :param minute: the minute of the time. Must be in range 0 <= minute < 60. - :type minute: int :param second: the second of the time. Must be in range 0 <= second < 60. - :type second: int :param nanosecond: the nanosecond of the time. Must be in range 0 <= nanosecond < 999999999. - :type nanosecond: int :param tzinfo: timezone or None to get a local :class:`.Time`. - :type tzinfo: datetime.tzinfo or None :raises ValueError: if one of the parameters is out of range. """ # CONSTRUCTOR # - def __new__(cls, hour=0, minute=0, second=0, nanosecond=0, tzinfo=None): + def __new__( + cls, + hour: int = 0, + minute: int = 0, + second: int = 0, + nanosecond: int = 0, + tzinfo: _tzinfo = None + ) -> Time: hour, minute, second, nanosecond = cls.__normalize_nanosecond( hour, minute, second, nanosecond ) @@ -1409,27 +1409,13 @@ def __new(cls, ticks, hour, minute, second, nanosecond, tzinfo): instance.__tzinfo = tzinfo return instance - def __getattr__(self, name): - """Map standard library attribute names to local attribute names, - for compatibility. - """ - try: - return { - "isoformat": self.iso_format, - "utcoffset": self.utc_offset, - }[name] - except KeyError: - raise AttributeError("Date has no attribute %r" % name) - # CLASS METHODS # @classmethod - def now(cls, tz=None): + def now(cls, tz: _tzinfo = None) -> Time: """Get the current time. :param tz: optional timezone - :type tz: datetime.tzinfo - :rtype: Time :raises OverflowError: if the timestamp is out of the range of values supported by the platform C localtime() function. It’s common for @@ -1445,15 +1431,12 @@ def now(cls, tz=None): ) @classmethod - def utc_now(cls): - """Get the current time as UTC local time. - - :rtype: Time - """ + def utc_now(cls) -> Time: + """Get the current time as UTC local time.""" return cls.from_clock_time(Clock().utc_time(), UnixEpoch) @classmethod - def from_iso_format(cls, s): + def from_iso_format(cls, s: str) -> Time: """Parse a ISO formatted time string. Accepted formats: @@ -1474,13 +1457,10 @@ def from_iso_format(cls, s): Seconds and sub-seconds are ignored. :param s: String to parse - :type s: str - - :rtype: Time :raises ValueError: if the string does not match the required format. """ - from pytz import FixedOffset + from pytz import FixedOffset # type: ignore m = TIME_ISO_PATTERN.match(s) if m: hour = int(m.group(1)) @@ -1506,15 +1486,11 @@ def from_iso_format(cls, s): raise ValueError("Time string is not in ISO format") @classmethod - def from_ticks(cls, ticks, tz=None): + def from_ticks(cls, ticks: int, tz: _tzinfo = None) -> Time: """Create a time from ticks (nanoseconds since midnight). :param ticks: nanoseconds since midnight - :type ticks: int :param tz: optional timezone - :type tz: datetime.tzinfo - - :rtype: Time :raises ValueError: if ticks is out of bounds (0 <= ticks < 86400000000000) @@ -1529,19 +1505,20 @@ def from_ticks(cls, ticks, tz=None): raise ValueError("Ticks out of range (0..86400000000000)") @classmethod - def from_native(cls, t): + def from_native(cls, t: time) -> Time: """Convert from a native Python :class:`datetime.time` value. :param t: time to convert from - :type t: datetime.time - - :rtype: Time """ nanosecond = t.microsecond * 1000 return Time(t.hour, t.minute, t.second, nanosecond, t.tzinfo) @classmethod - def from_clock_time(cls, clock_time, epoch): + def from_clock_time( + cls, + clock_time: t.Union[ClockTime, t.Tuple[float, int]], + epoch: DateTime + ) -> Time: """Convert from a :class:`.ClockTime` relative to a given epoch. This method, in contrast to most others of this package, assumes days of @@ -1549,11 +1526,7 @@ def from_clock_time(cls, clock_time, epoch): :param clock_time: the clock time as :class:`.ClockTime` or as tuple of (seconds, nanoseconds) - :type clock_time: ClockTime or (float, int) :param epoch: the epoch to which `clock_time` is relative - :type epoch: DateTime - - :rtype: Time """ clock_time = ClockTime(*clock_time) ts = clock_time.seconds % 86400 @@ -1591,15 +1564,26 @@ def __normalize_nanosecond(cls, hour, minute, second, nanosecond): return hour, minute, second, nanosecond raise ValueError("Nanosecond out of range (0..%s)" % (NANO_SECONDS - 1)) + # CLASS METHOD ALIASES # + + if t.TYPE_CHECKING: + @classmethod + def from_iso_format(cls, s: str) -> Time: + ... + + @classmethod + def utc_now(cls) -> Time: + ... + # CLASS ATTRIBUTES # - min = None + min: te.Final[Time] = None # type: ignore """The earliest time value possible.""" - max = None + max: te.Final[Time] = None # type: ignore """The latest time value possible.""" - resolution = None + resolution: te.Final[Duration] = None # type: ignore """The minimum resolution supported.""" # INSTANCE ATTRIBUTES # @@ -1617,62 +1601,43 @@ def __normalize_nanosecond(cls, hour, minute, second, nanosecond): __tzinfo = None @property - def ticks(self): - """The total number of nanoseconds since midnight. - - :type: int - """ + def ticks(self) -> int: + """The total number of nanoseconds since midnight.""" return self.__ticks @property - def hour(self): - """The hours of the time. - - :type: int - """ + def hour(self) -> int: + """The hours of the time.""" return self.__hour @property - def minute(self): - """The minutes of the time. - - :type: int - """ + def minute(self) -> int: + """The minutes of the time.""" return self.__minute @property - def second(self): - """The seconds of the time. - - :type: int - """ + def second(self) -> int: + """The seconds of the time.""" return self.__second @property - def nanosecond(self): - """The nanoseconds of the time. - - :type: int - """ + def nanosecond(self) -> int: + """The nanoseconds of the time.""" return self.__nanosecond @property - def hour_minute_second_nanosecond(self): - """The time as a tuple of (hour, minute, second, nanosecond). - - :type: (int, int, int, int)""" + def hour_minute_second_nanosecond(self) -> t.Tuple[int, int, int, int]: + """The time as a tuple of (hour, minute, second, nanosecond).""" return self.__hour, self.__minute, self.__second, self.__nanosecond @property - def tzinfo(self): - """The timezone of this time. - - :type: datetime.tzinfo or None""" + def tzinfo(self) -> t.Optional[_tzinfo]: + """The timezone of this time.""" return self.__tzinfo # OPERATIONS # - def _get_both_normalized_ticks(self, other, strict=True): + def _get_both_normalized_ticks(self, other: object, strict=True): if (isinstance(other, (time, Time)) and ((self.utc_offset() is None) ^ (other.utcoffset() is None))): @@ -1681,6 +1646,7 @@ def _get_both_normalized_ticks(self, other, strict=True): "times") else: return None, None + other_ticks: int if isinstance(other, Time): other_ticks = other.__ticks elif isinstance(other, time): @@ -1690,13 +1656,14 @@ def _get_both_normalized_ticks(self, other, strict=True): + 1000 * other.microsecond) else: return None, None - utc_offset = other.utcoffset() + assert isinstance(other, (Time, time)) + utc_offset: t.Optional[timedelta] = other.utcoffset() if utc_offset is not None: - other_ticks -= utc_offset.total_seconds() * NANO_SECONDS + other_ticks -= int(utc_offset.total_seconds() * NANO_SECONDS) self_ticks = self.__ticks utc_offset = self.utc_offset() if utc_offset is not None: - self_ticks -= utc_offset.total_seconds() * NANO_SECONDS + self_ticks -= int(utc_offset.total_seconds() * NANO_SECONDS) return self_ticks, other_ticks def __hash__(self): @@ -1708,7 +1675,7 @@ def __hash__(self): self_ticks -= self.utc_offset().total_seconds() * NANO_SECONDS return hash(self_ticks) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """`==` comparison with :class:`.Time` or :class:`datetime.time`.""" self_ticks, other_ticks = self._get_both_normalized_ticks(other, strict=False) @@ -1716,69 +1683,82 @@ def __eq__(self, other): return False return self_ticks == other_ticks - def __ne__(self, other): + def __ne__(self, other: object) -> bool: """`!=` comparison with :class:`.Time` or :class:`datetime.time`.""" return not self.__eq__(other) - def __lt__(self, other): + def __lt__(self, other: t.Union[Time, time]) -> bool: """`<` comparison with :class:`.Time` or :class:`datetime.time`.""" self_ticks, other_ticks = self._get_both_normalized_ticks(other) if self_ticks is None: return NotImplemented return self_ticks < other_ticks - def __le__(self, other): + def __le__(self, other: t.Union[Time, time]) -> bool: """`<=` comparison with :class:`.Time` or :class:`datetime.time`.""" self_ticks, other_ticks = self._get_both_normalized_ticks(other) if self_ticks is None: return NotImplemented return self_ticks <= other_ticks - def __ge__(self, other): + def __ge__(self, other: t.Union[Time, time]) -> bool: """`>=` comparison with :class:`.Time` or :class:`datetime.time`.""" self_ticks, other_ticks = self._get_both_normalized_ticks(other) if self_ticks is None: return NotImplemented return self_ticks >= other_ticks - def __gt__(self, other): + def __gt__(self, other: t.Union[Time, time]) -> bool: """`>` comparison with :class:`.Time` or :class:`datetime.time`.""" self_ticks, other_ticks = self._get_both_normalized_ticks(other) if self_ticks is None: return NotImplemented return self_ticks > other_ticks - def __copy__(self): + def __copy__(self) -> Time: return self.__new(self.__ticks, self.__hour, self.__minute, self.__second, self.__nanosecond, self.__tzinfo) - def __deepcopy__(self, *args, **kwargs): + def __deepcopy__(self, *args, **kwargs) -> Time: return self.__copy__() # INSTANCE METHODS # - def replace(self, **kwargs): - """Return a :class:`.Time` with one or more components replaced. - - :Keyword Arguments: - * **hour** (`int`): overwrite the hour - - default: `self.hour` - * **minute** (`int`): overwrite the minute - - default: `self.minute` - * **second** (`int`): overwrite the second - - default: `int(self.second)` - * **nanosecond** (`int`): overwrite the nanosecond - - default: `self.nanosecond` - * **tzinfo** (`datetime.tzinfo` or `None`): overwrite the timezone - - default: `self.tzinfo` - - :rtype: Time - """ - return Time(hour=kwargs.get("hour", self.__hour), - minute=kwargs.get("minute", self.__minute), - second=kwargs.get("second", self.__second), - nanosecond=kwargs.get("nanosecond", self.__nanosecond), - tzinfo=kwargs.get("tzinfo", self.__tzinfo)) + if t.TYPE_CHECKING: + + def replace( # type: ignore[override] + self, + hour: int = ..., + minute: int = ..., + second: int = ..., + nanosecond: int = ..., + tzinfo: t.Optional[_tzinfo] = ..., + **kwargs: object + ) -> Time: + ... + + else: + + def replace(self, **kwargs) -> Time: + """Return a :class:`.Time` with one or more components replaced. + + :Keyword Arguments: + * **hour** (`int`): overwrite the hour - + default: `self.hour` + * **minute** (`int`): overwrite the minute - + default: `self.minute` + * **second** (`int`): overwrite the second - + default: `int(self.second)` + * **nanosecond** (`int`): overwrite the nanosecond - + default: `self.nanosecond` + * **tzinfo** (`datetime.tzinfo` or `None`): + overwrite the timezone - default: `self.tzinfo` + """ + return Time(hour=kwargs.get("hour", self.__hour), + minute=kwargs.get("minute", self.__minute), + second=kwargs.get("second", self.__second), + nanosecond=kwargs.get("nanosecond", self.__nanosecond), + tzinfo=kwargs.get("tzinfo", self.__tzinfo)) def _utc_offset(self, dt=None): if self.tzinfo is None: @@ -1800,12 +1780,11 @@ def _utc_offset(self, dt=None): return value raise TypeError("utcoffset must be a timedelta") - def utc_offset(self): + def utc_offset(self) -> t.Optional[timedelta]: """Return the UTC offset of this time. :return: None if this is a local time (:attr:`.tzinfo` is None), else returns `self.tzinfo.utcoffset(self)`. - :rtype: datetime.timedelta :raises ValueError: if `self.tzinfo.utcoffset(self)` is not None and a :class:`timedelta` with a magnitude greater equal 1 day or that is @@ -1815,12 +1794,11 @@ def utc_offset(self): """ return self._utc_offset() - def dst(self): + def dst(self) -> t.Optional[timedelta]: """Get the daylight saving time adjustment (DST). :return: None if this is a local time (:attr:`.tzinfo` is None), else returns `self.tzinfo.dst(self)`. - :rtype: datetime.timedelta :raises ValueError: if `self.tzinfo.dst(self)` is not None and a :class:`timedelta` with a magnitude greater equal 1 day or that is @@ -1831,11 +1809,11 @@ def dst(self): if self.tzinfo is None: return None try: - value = self.tzinfo.dst(self) + value = self.tzinfo.dst(self) # type: ignore except TypeError: # For timezone implementations not compatible with the custom # datetime implementations, we can't do better than this. - value = self.tzinfo.dst(self.to_native()) + value = self.tzinfo.dst(self.to_native()) # type: ignore if value is None: return None if isinstance(value, timedelta): @@ -1846,56 +1824,46 @@ def dst(self): return value raise TypeError("dst must be a timedelta") - def tzname(self): + def tzname(self) -> t.Optional[str]: """Get the name of the :class:`.Time`'s timezone. :returns: None if the time is local (i.e., has no timezone), else return `self.tzinfo.tzname(self)` - - :rtype: str or None """ if self.tzinfo is None: return None try: - return self.tzinfo.tzname(self) + return self.tzinfo.tzname(self) # type: ignore except TypeError: # For timezone implementations not compatible with the custom # datetime implementations, we can't do better than this. - return self.tzinfo.tzname(self.to_native()) - - def to_clock_time(self): - """Convert to :class:`.ClockTime`. + return self.tzinfo.tzname(self.to_native()) # type: ignore - :rtype: ClockTime - """ + def to_clock_time(self) -> ClockTime: + """Convert to :class:`.ClockTime`.""" seconds, nanoseconds = divmod(self.ticks, NANO_SECONDS) return ClockTime(seconds, nanoseconds) - def to_native(self): + def to_native(self) -> time: """Convert to a native Python `datetime.time` value. - This conversion is lossy as the native time implementation only supports - a resolution of microseconds instead of nanoseconds. - - :rtype: datetime.time + This conversion is lossy as the native time implementation only + supports a resolution of microseconds instead of nanoseconds. """ h, m, s, ns = self.hour_minute_second_nanosecond µs = round_half_to_even(ns / 1000) tz = self.tzinfo return time(h, m, s, µs, tz) - def iso_format(self): - """Return the :class:`.Time` as ISO formatted string. - - :rtype: str - """ + def iso_format(self) -> str: + """Return the :class:`.Time` as ISO formatted string.""" s = "%02d:%02d:%02d.%09d" % self.hour_minute_second_nanosecond offset = self.utc_offset() if offset is not None: s += "%+03d:%02d" % divmod(offset.total_seconds() // 60, 60) return s - def __repr__(self): + def __repr__(self) -> str: """""" if self.tzinfo is None: return "neo4j.time.Time(%r, %r, %r, %r)" % \ @@ -1904,7 +1872,7 @@ def __repr__(self): return "neo4j.time.Time(%r, %r, %r, %r, tzinfo=%r)" % \ (self.hour_minute_second_nanosecond + (self.tzinfo,)) - def __str__(self): + def __str__(self) -> str: """""" return self.iso_format() @@ -1912,22 +1880,56 @@ def __format__(self, format_spec): """""" raise NotImplementedError() + # INSTANCE METHOD ALIASES # + + def __getattr__(self, name): + """Map standard library attribute names to local attribute names, + for compatibility. + """ + try: + return { + "isoformat": self.iso_format, + "utcoffset": self.utc_offset, + }[name] + except KeyError: + raise AttributeError("Date has no attribute %r" % name) + + if t.TYPE_CHECKING: + def isoformat(self) -> str: # type: ignore[override] + ... + + utcoffset = utc_offset + -Time.min = Time(hour=0, minute=0, second=0, nanosecond=0) -Time.max = Time(hour=23, minute=59, second=59, nanosecond=999999999) -Time.resolution = Duration(nanoseconds=1) +Time.min = Time( # type: ignore + hour=0, minute=0, second=0, nanosecond=0 +) +Time.max = Time( # type: ignore + hour=23, minute=59, second=59, nanosecond=999999999 +) +Time.resolution = Duration( # type: ignore + nanoseconds=1 +) #: A :class:`.Time` instance set to `00:00:00`. #: This has a :attr:`.ticks` value of `0`. -Midnight = Time.min +Midnight: te.Final[Time] = Time.min #: A :class:`.Time` instance set to `12:00:00`. #: This has a :attr:`.ticks` value of `43200000000000`. -Midday = Time(hour=12) +Midday: te.Final[Time] = Time(hour=12) + + +if t.TYPE_CHECKING: + # make typechecker believe that DateTime subclasses datetime.datetime + # https://github.com/python/typeshed/issues/8409#issuecomment-1197704527 + date_time_base_class = datetime +else: + date_time_base_class = object @total_ordering -class DateTime(metaclass=DateTimeType): +class DateTime(date_time_base_class, metaclass=DateTimeType): """A point in time represented as a date and a time. The :class:`.DateTime` class is a nanosecond-precision drop-in replacement @@ -1956,42 +1958,32 @@ class DateTime(metaclass=DateTimeType): 56.789123456 """ + __date: Date + __time: Time + # CONSTRUCTOR # - def __new__(cls, year, month, day, hour=0, minute=0, second=0, nanosecond=0, - tzinfo=None): + def __new__( + cls, + year: int, + month: int, + day: int, + hour: int = 0, + minute: int = 0, + second: int = 0, + nanosecond: int = 0, + tzinfo: _tzinfo = None + ) -> DateTime: return cls.combine(Date(year, month, day), Time(hour, minute, second, nanosecond, tzinfo)) - def __getattr__(self, name): - """ Map standard library attribute names to local attribute names, - for compatibility. - """ - try: - return { - "astimezone": self.as_timezone, - "isocalendar": self.iso_calendar, - "isoformat": self.iso_format, - "isoweekday": self.iso_weekday, - "strftime": self.__format__, - "toordinal": self.to_ordinal, - "timetuple": self.time_tuple, - "utcoffset": self.utc_offset, - "utctimetuple": self.utc_time_tuple, - }[name] - except KeyError: - raise AttributeError("DateTime has no attribute %r" % name) - # CLASS METHODS # @classmethod - def now(cls, tz=None): + def now(cls, tz: _tzinfo = None) -> DateTime: """Get the current date and time. :param tz: timezone. Set to None to create a local :class:`.DateTime`. - :type tz: datetime.tzinfo` or None - - :rtype: DateTime :raises OverflowError: if the timestamp is out of the range of values supported by the platform C localtime() function. It’s common for @@ -2001,9 +1993,11 @@ def now(cls, tz=None): return cls.from_clock_time(Clock().local_time(), UnixEpoch) else: try: - return tz.fromutc(cls.from_clock_time( - Clock().utc_time(), UnixEpoch - ).replace(tzinfo=tz)) + return tz.fromutc( # type: ignore + cls.from_clock_time( # type: ignore + Clock().utc_time(), UnixEpoch + ).replace(tzinfo=tz) + ) except TypeError: # For timezone implementations not compatible with the custom # datetime implementations, we can't do better than this. @@ -2020,21 +2014,15 @@ def now(cls, tz=None): ) @classmethod - def utc_now(cls): - """Get the current date and time in UTC - - :rtype: DateTime - """ + def utc_now(cls) -> DateTime: + """Get the current date and time in UTC.""" return cls.from_clock_time(Clock().utc_time(), UnixEpoch) @classmethod - def from_iso_format(cls, s): + def from_iso_format(cls, s) -> DateTime: """Parse a ISO formatted date with time string. :param s: String to parse - :type s: str - - :rtype: Time :raises ValueError: if the string does not match the ISO format. """ @@ -2045,15 +2033,11 @@ def from_iso_format(cls, s): raise ValueError("DateTime string is not in ISO format") @classmethod - def from_timestamp(cls, timestamp, tz=None): + def from_timestamp(cls, timestamp: float, tz: _tzinfo = None) -> DateTime: """:class:`.DateTime` from a time stamp (seconds since unix epoch). :param timestamp: the unix timestamp (seconds since unix epoch). - :type timestamp: float :param tz: timezone. Set to None to create a local :class:`.DateTime`. - :type tz: datetime.tzinfo or None - - :rtype: DateTime :raises OverflowError: if the timestamp is out of the range of values supported by the platform C localtime() function. It’s common for @@ -2070,35 +2054,29 @@ def from_timestamp(cls, timestamp, tz=None): ) @classmethod - def utc_from_timestamp(cls, timestamp): + def utc_from_timestamp(cls, timestamp: float) -> DateTime: """:class:`.DateTime` from a time stamp (seconds since unix epoch). Returns the `DateTime` as local date `DateTime` in UTC. - - :rtype: DateTime """ return cls.from_clock_time((timestamp, 0), UnixEpoch) @classmethod - def from_ordinal(cls, ordinal): + def from_ordinal(cls, ordinal: int) -> DateTime: """:class:`.DateTime` from an ordinal. For more info about ordinals see :meth:`.Date.from_ordinal`. - - :rtype: DateTime """ return cls.combine(Date.from_ordinal(ordinal), Midnight) @classmethod - def combine(cls, date, time): + def combine( # type: ignore[override] + cls, date: Date, time: Time + ) -> DateTime: """Combine a :class:`.Date` and a :class:`.Time` to a :class:`DateTime`. :param date: the date - :type date: Date :param time: the time - :type time: Time - - :rtype: DateTime :raises AssertionError: if the parameter types don't match. """ @@ -2114,27 +2092,25 @@ def parse(cls, date_string, format): raise NotImplementedError() @classmethod - def from_native(cls, dt): + def from_native(cls, dt: datetime) -> DateTime: """Convert from a native Python :class:`datetime.datetime` value. :param dt: the datetime to convert - :type dt: datetime.datetime - - :rtype: DateTime """ - return cls.combine(Date.from_native(dt.date()), Time.from_native(dt.timetz())) + return cls.combine(Date.from_native(dt.date()), + Time.from_native(dt.timetz())) @classmethod - def from_clock_time(cls, clock_time, epoch): + def from_clock_time( + cls, + clock_time: t.Union[ClockTime, t.Tuple[float, int]], + epoch: DateTime + ) -> DateTime: """Convert from a :class:`ClockTime` relative to a given epoch. :param clock_time: the clock time as :class:`.ClockTime` or as tuple of (seconds, nanoseconds) - :type clock_time: ClockTime or (float, int) :param epoch: the epoch to which `clock_time` is relative - :type epoch: DateTime - - :rtype: DateTime :raises ValueError: if `clock_time` is invalid. """ @@ -2151,21 +2127,58 @@ def from_clock_time(cls, clock_time, epoch): time_ = Time.from_ticks(ticks) return cls.combine(date_, time_) + # CLASS METHOD ALIASES # + + if t.TYPE_CHECKING: + @classmethod + def fromisoformat(cls, s) -> DateTime: + ... + + @classmethod + def fromordinal(cls, ordinal: int) -> DateTime: + ... + + @classmethod + def fromtimestamp( + cls, timestamp: float, tz: _tzinfo = None + ) -> DateTime: + ... + + # alias of parse + @classmethod + def strptime(cls, date_string, format): + ... + + # alias of now + @classmethod + def today(cls, tz: _tzinfo = None) -> DateTime: + ... + + @classmethod + def utcfromtimestamp(cls, timestamp: float) -> DateTime: + ... + + @classmethod + def utcnow(cls) -> DateTime: + ... + + + # CLASS ATTRIBUTES # - min = None + min: te.Final[DateTime] = None # type: ignore """The earliest date time value possible.""" - max = None + max: te.Final[DateTime] = None # type: ignore """The latest date time value possible.""" - resolution = None + resolution: te.Final[Duration] = None # type: ignore """The minimum resolution supported.""" # INSTANCE ATTRIBUTES # @property - def year(self): + def year(self) -> int: """The year of the :class:`.DateTime`. See :attr:`.Date.year`. @@ -2173,77 +2186,77 @@ def year(self): return self.__date.year @property - def month(self): + def month(self) -> int: """The year of the :class:`.DateTime`. See :attr:`.Date.year`.""" return self.__date.month @property - def day(self): + def day(self) -> int: """The day of the :class:`.DateTime`'s date. See :attr:`.Date.day`.""" return self.__date.day @property - def year_month_day(self): + def year_month_day(self) -> t.Tuple[int, int, int]: """The year_month_day of the :class:`.DateTime`'s date. See :attr:`.Date.year_month_day`.""" return self.__date.year_month_day @property - def year_week_day(self): + def year_week_day(self) -> t.Tuple[int, int, int]: """The year_week_day of the :class:`.DateTime`'s date. See :attr:`.Date.year_week_day`.""" return self.__date.year_week_day @property - def year_day(self): + def year_day(self) -> t.Tuple[int, int]: """The year_day of the :class:`.DateTime`'s date. See :attr:`.Date.year_day`.""" return self.__date.year_day @property - def hour(self): + def hour(self) -> int: """The hour of the :class:`.DateTime`'s time. See :attr:`.Time.hour`.""" return self.__time.hour @property - def minute(self): + def minute(self) -> int: """The minute of the :class:`.DateTime`'s time. See :attr:`.Time.minute`.""" return self.__time.minute @property - def second(self): + def second(self) -> int: """The second of the :class:`.DateTime`'s time. See :attr:`.Time.second`.""" return self.__time.second @property - def nanosecond(self): + def nanosecond(self) -> int: """The nanosecond of the :class:`.DateTime`'s time. See :attr:`.Time.nanosecond`.""" return self.__time.nanosecond @property - def tzinfo(self): + def tzinfo(self) -> t.Optional[_tzinfo]: """The tzinfo of the :class:`.DateTime`'s time. See :attr:`.Time.tzinfo`.""" return self.__time.tzinfo @property - def hour_minute_second_nanosecond(self): + def hour_minute_second_nanosecond(self) -> t.Tuple[int, int, int, int]: """The hour_minute_second_nanosecond of the :class:`.DateTime`'s time. See :attr:`.Time.hour_minute_second_nanosecond`.""" @@ -2285,9 +2298,9 @@ def __hash__(self): self_norm -= utc_offset return hash(self_norm.date()) ^ hash(self_norm.time()) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """ - `==` comparison with :class:`.DateTime` or :class:`datetime.datetime`. + ``==`` comparison with :class:`.DateTime` or :class:`datetime.datetime`. """ if not isinstance(other, (datetime, DateTime)): return NotImplemented @@ -2298,15 +2311,17 @@ def __eq__(self, other): return False return self_norm == other_norm - def __ne__(self, other): + def __ne__(self, other: object) -> bool: """ - `!=` comparison with :class:`.DateTime` or :class:`datetime.datetime`. + ``!=`` comparison with :class:`.DateTime` or :class:`datetime.datetime`. """ return not self.__eq__(other) - def __lt__(self, other): + def __lt__( # type: ignore[override] + self, other: datetime + ) -> bool: """ - `<` comparison with :class:`.DateTime` or :class:`datetime.datetime`. + ``<`` comparison with :class:`.DateTime` or :class:`datetime.datetime`. """ if not isinstance(other, (datetime, DateTime)): return NotImplemented @@ -2318,9 +2333,11 @@ def __lt__(self, other): return (self_norm.date() < other_norm.date() or self_norm.time() < other_norm.time()) - def __le__(self, other): + def __le__( # type: ignore[override] + self, other: t.Union[datetime, DateTime] + ) -> bool: """ - `<=` comparison with :class:`.DateTime` or :class:`datetime.datetime`. + ``<=`` comparison with :class:`.DateTime` or :class:`datetime.datetime`. """ if not isinstance(other, (datetime, DateTime)): return NotImplemented @@ -2331,9 +2348,11 @@ def __le__(self, other): self_norm, other_norm = self._get_both_normalized(other) return self_norm <= other_norm - def __ge__(self, other): + def __ge__( # type: ignore[override] + self, other: t.Union[datetime, DateTime] + ) -> bool: """ - `>=` comparison with :class:`.DateTime` or :class:`datetime.datetime`. + ``>=`` comparison with :class:`.DateTime` or :class:`datetime.datetime`. """ if not isinstance(other, (datetime, DateTime)): return NotImplemented @@ -2344,9 +2363,11 @@ def __ge__(self, other): self_norm, other_norm = self._get_both_normalized(other) return self_norm >= other_norm - def __gt__(self, other): + def __gt__( # type: ignore[override] + self, other: t.Union[datetime, DateTime] + ) -> bool: """ - `>` comparison with :class:`.DateTime` or :class:`datetime.datetime`. + ``>`` comparison with :class:`.DateTime` or :class:`datetime.datetime`. """ if not isinstance(other, (datetime, DateTime)): return NotImplemented @@ -2358,11 +2379,16 @@ def __gt__(self, other): return (self_norm.date() > other_norm.date() or self_norm.time() > other_norm.time()) - def __add__(self, other): - """Add a :class:`datetime.timedelta`. - - :rtype: DateTime - """ + def __add__(self, other: t.Union[timedelta, Duration]) -> DateTime: + """Add a :class:`datetime.timedelta`.""" + if isinstance(other, Duration): + t = (self.to_clock_time() + + ClockTime(other.seconds, other.nanoseconds)) + days, seconds = symmetric_divmod(t.seconds, 86400) + date_ = self.date() + Duration(months=other.months, + days=days + other.days) + time_ = Time.from_ticks(seconds * NANO_SECONDS + t.nanoseconds) + return self.combine(date_, time_).replace(tzinfo=self.tzinfo) if isinstance(other, timedelta): t = (self.to_clock_time() + ClockTime(86400 * other.days + other.seconds, @@ -2373,24 +2399,31 @@ def __add__(self, other): seconds * NANO_SECONDS + t.nanoseconds )) return self.combine(date_, time_).replace(tzinfo=self.tzinfo) - if isinstance(other, Duration): - t = (self.to_clock_time() - + ClockTime(other.seconds, other.nanoseconds)) - days, seconds = symmetric_divmod(t.seconds, 86400) - date_ = self.date() + Duration(months=other.months, - days=days + other.days) - time_ = Time.from_ticks(seconds * NANO_SECONDS + t.nanoseconds) - return self.combine(date_, time_).replace(tzinfo=self.tzinfo) return NotImplemented + @t.overload # type: ignore[override] + def __sub__(self, other: DateTime) -> Duration: + ... + + @t.overload + def __sub__(self, other: datetime) -> timedelta: + ... + + @t.overload + def __sub__(self, other: t.Union[Duration, timedelta]) -> DateTime: + ... + def __sub__(self, other): - """Subtract a datetime or a timedelta. + """Subtract a datetime/DateTime or a timedelta/Duration. + + Subtracting a :class:`.DateTime` yields the duration between the two + as a :class:`.Duration`. - Supported :class:`.DateTime` (returns :class:`.Duration`), - :class:`datetime.datetime` (returns :class:`datetime.timedelta`), and - :class:`datetime.timedelta` (returns :class:`.DateTime`). + Subtracting a :class:`datetime.datetime` yields the duration between + the two as a :class:`datetime.timedelta`. - :rtype: Duration or datetime.timedelta or DateTime + Subtracting a :class:`datetime.timedelta` or a :class:`.Duration` + yields the :class:`.DateTime` that's the given duration away. """ if isinstance(other, DateTime): self_month_ordinal = 12 * (self.year - 1) + self.month @@ -2415,62 +2448,69 @@ def __sub__(self, other): return self.__add__(-other) return NotImplemented - def __copy__(self): + def __copy__(self) -> DateTime: return self.combine(self.__date, self.__time) - def __deepcopy__(self, *args, **kwargs): + def __deepcopy__(self, memo) -> DateTime: return self.__copy__() # INSTANCE METHODS # - def date(self): - """The date - - :rtype: Date - """ + def date(self) -> Date: + """The date.""" return self.__date - def time(self): - """The time without timezone info - - :rtype: Time - """ + def time(self) -> Time: + """The time without timezone info.""" return self.__time.replace(tzinfo=None) - def timetz(self): - """The time with timezone info - - :rtype: Time - """ + def timetz(self) -> Time: + """The time with timezone info.""" return self.__time - def replace(self, **kwargs): - """Return a :class:`.DateTime` with one or more components replaced. + if t.TYPE_CHECKING: + + def replace( # type: ignore[override] + self, + year: int = ..., + month: int = ..., + day: int = ..., + hour: int = ..., + minute: int = ..., + second: int = ..., + nanosecond: int = ..., + tzinfo: t.Optional[_tzinfo] = ..., + **kwargs: object + ) -> DateTime: + ... - See :meth:`.Date.replace` and :meth:`.Time.replace` for available - arguments. + else: - :rtype: DateTime - """ - date_ = self.__date.replace(**kwargs) - time_ = self.__time.replace(**kwargs) - return self.combine(date_, time_) + def replace(self, **kwargs) -> DateTime: + """Return a ``DateTime`` with one or more components replaced. + + See :meth:`.Date.replace` and :meth:`.Time.replace` for available + arguments. + """ + date_ = self.__date.replace(**kwargs) + time_ = self.__time.replace(**kwargs) + return self.combine(date_, time_) - def as_timezone(self, tz): + def as_timezone(self, tz: _tzinfo) -> DateTime: """Convert this :class:`.DateTime` to another timezone. :param tz: the new timezone - :type tz: datetime.tzinfo or None - :return: the same object if `tz` is None. Else, a new :class:`.DateTime` - that's the same point in time but in a different timezone. - :rtype: DateTime + :return: the same object if ``tz`` is :const:``None``. + Else, a new :class:`.DateTime` that's the same point in time but in + a different timezone. """ if self.tzinfo is None: return self - utc = (self - self.utc_offset()).replace(tzinfo=tz) + offset = t.cast(timedelta, self.utcoffset()) + utc = (self - offset).replace(tzinfo=tz) try: - return tz.fromutc(utc) + return tz.fromutc(utc) # type: ignore except TypeError: # For timezone implementations not compatible with the custom # datetime implementations, we can't do better than this. @@ -2482,7 +2522,7 @@ def as_timezone(self, tz): + self.nanosecond % 1000) ) - def utc_offset(self): + def utc_offset(self) -> t.Optional[timedelta]: """Get the date times utc offset. See :meth:`.Time.utc_offset`. @@ -2490,14 +2530,14 @@ def utc_offset(self): return self.__time._utc_offset(self) - def dst(self): + def dst(self) -> t.Optional[timedelta]: """Get the daylight saving time adjustment (DST). See :meth:`.Time.dst`. """ return self.__time.dst() - def tzname(self): + def tzname(self) -> t.Optional[str]: """Get the timezone name. See :meth:`.Time.tzname`. @@ -2510,18 +2550,15 @@ def time_tuple(self): def utc_time_tuple(self): raise NotImplementedError() - def to_ordinal(self): + def to_ordinal(self) -> int: """Get the ordinal of the :class:`.DateTime`'s date. See :meth:`.Date.to_ordinal` """ return self.__date.to_ordinal() - def to_clock_time(self): - """Convert to :class:`.ClockTime`. - - :rtype: ClockTime - """ + def to_clock_time(self) -> ClockTime: + """Convert to :class:`.ClockTime`.""" total_seconds = 0 for year in range(1, self.year): total_seconds += 86400 * DAYS_IN_YEAR[year] @@ -2531,13 +2568,11 @@ def to_clock_time(self): seconds, nanoseconds = divmod(self.__time.ticks, NANO_SECONDS) return ClockTime(total_seconds + seconds, nanoseconds) - def to_native(self): + def to_native(self) -> datetime: """Convert to a native Python :class:`datetime.datetime` value. This conversion is lossy as the native time implementation only supports a resolution of microseconds instead of nanoseconds. - - :rtype: datetime.datetime """ y, mo, d = self.year_month_day h, m, s, ns = self.hour_minute_second_nanosecond @@ -2545,28 +2580,28 @@ def to_native(self): tz = self.tzinfo return datetime(y, mo, d, h, m, s, ms, tz) - def weekday(self): + def weekday(self) -> int: """Get the weekday. See :meth:`.Date.weekday` """ return self.__date.weekday() - def iso_weekday(self): + def iso_weekday(self) -> int: """Get the ISO weekday. See :meth:`.Date.iso_weekday` """ return self.__date.iso_weekday() - def iso_calendar(self): + def iso_calendar(self) -> t.Tuple[int, int, int]: """Get date as ISO tuple. See :meth:`.Date.iso_calendar` """ return self.__date.iso_calendar() - def iso_format(self, sep="T"): + def iso_format(self, sep: str = "T") -> str: """Return the :class:`.DateTime` as ISO formatted string. This method joins `self.date().iso_format()` (see @@ -2574,9 +2609,6 @@ def iso_format(self, sep="T"): :meth:`.Time.iso_format`) with `sep` in between. :param sep: the separator between the formatted date and time. - :type sep: str - - :rtype: str """ s = "%s%s%s" % (self.date().iso_format(), sep, self.timetz().iso_format()) @@ -2590,8 +2622,9 @@ def iso_format(self, sep="T"): s += "%+03d:%02d" % divmod(offset.total_seconds() // 60, 60) return s - def __repr__(self): + def __repr__(self) -> str: """""" + fields: tuple if self.tzinfo is None: fields = (*self.year_month_day, *self.hour_minute_second_nanosecond) @@ -2602,7 +2635,7 @@ def __repr__(self): return ("neo4j.time.DateTime(%r, %r, %r, %r, %r, %r, %r, tzinfo=%r)" % fields) - def __str__(self): + def __str__(self) -> str: """""" return self.iso_format() @@ -2610,10 +2643,49 @@ def __format__(self, format_spec): """""" raise NotImplementedError() + # INSTANCE METHOD ALIASES # + + def __getattr__(self, name): + """ Map standard library attribute names to local attribute names, + for compatibility. + """ + try: + return { + "astimezone": self.as_timezone, + "isocalendar": self.iso_calendar, + "isoformat": self.iso_format, + "isoweekday": self.iso_weekday, + "strftime": self.__format__, + "toordinal": self.to_ordinal, + "timetuple": self.time_tuple, + "utcoffset": self.utc_offset, + "utctimetuple": self.utc_time_tuple, + }[name] + except KeyError: + raise AttributeError("DateTime has no attribute %r" % name) + + if t.TYPE_CHECKING: + def astimezone( # type: ignore[override] + self, tz: _tzinfo + ) -> DateTime: + ... + + isocalendar = iso_calendar + + def iso_format(self, sep: str = "T") -> str: # type: ignore[override] + ... + + isoweekday = iso_weekday + strftime = __format__ + toordinal = to_ordinal + timetuple = time_tuple + utcoffset = utc_offset + utctimetuple = utc_time_tuple + -DateTime.min = DateTime.combine(Date.min, Time.min) -DateTime.max = DateTime.combine(Date.max, Time.max) -DateTime.resolution = Time.resolution +DateTime.min = DateTime.combine(Date.min, Time.min) # type: ignore +DateTime.max = DateTime.combine(Date.max, Time.max) # type: ignore +DateTime.resolution = Time.resolution # type: ignore #: A :class:`.DateTime` instance set to `0000-00-00T00:00:00`. #: This has a :class:`.Date` component equal to :attr:`ZeroDate` and a diff --git a/neo4j/time/_arithmetic.py b/neo4j/time/_arithmetic.py index 93bfe8ed..7e8cd76e 100644 --- a/neo4j/time/_arithmetic.py +++ b/neo4j/time/_arithmetic.py @@ -16,6 +16,12 @@ # limitations under the License. +from typing import ( + Tuple, + TypeVar, +) + + __all__ = [ "nano_add", "nano_div", @@ -82,7 +88,12 @@ def nano_divmod(x, y): return int(q), number(r / 1000000000) -def symmetric_divmod(dividend, divisor): +_T_dividend = TypeVar("_T_dividend", int, float) + + +def symmetric_divmod( + dividend: _T_dividend, divisor: float +) -> Tuple[int, _T_dividend]: number = type(dividend) if dividend >= 0: quotient, remainder = divmod(dividend, divisor) diff --git a/neo4j/work/query.py b/neo4j/work/query.py index bc861ad1..6dd6f8f3 100644 --- a/neo4j/work/query.py +++ b/neo4j/work/query.py @@ -16,27 +16,40 @@ # limitations under the License. +from __future__ import annotations + +import typing as t + + +if t.TYPE_CHECKING: + _T = t.TypeVar("_T") + + class Query: """ Create a new query. :param text: The query text. - :type text: str :param metadata: metadata attached to the query. - :type metadata: dict :param timeout: seconds. - :type timeout: float or :const:`None` """ - def __init__(self, text, metadata=None, timeout=None): + def __init__( + self, + text: str, + metadata: t.Dict[str, t.Any] = None, + timeout: float = None + ) -> None: self.text = text self.metadata = metadata self.timeout = timeout - def __str__(self): + def __str__(self) -> str: return str(self.text) -def unit_of_work(metadata=None, timeout=None): +def unit_of_work( + metadata: t.Dict[str, t.Any] = None, timeout: float = None +) -> t.Callable[[_T], _T]: """This function is a decorator for transaction functions that allows extra control over how the transaction is carried out. For example, a timeout may be applied:: @@ -54,7 +67,6 @@ def count_people_tx(tx): Specified metadata will be attached to the executing transaction and visible in the output of ``dbms.listQueries`` and ``dbms.listTransactions`` procedures. It will also get logged to the ``query.log``. This functionality makes it easier to tag transactions and is equivalent to ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference. - :type metadata: dict :param timeout: the transaction timeout in seconds. @@ -64,7 +76,6 @@ def count_people_tx(tx): Value should not represent a negative duration. A zero duration will make the transaction execute indefinitely. None will use the default timeout configured in the database. - :type timeout: float or :const:`None` """ def wrapper(f): diff --git a/neo4j/work/summary.py b/neo4j/work/summary.py index a7ce2792..7e613f5b 100644 --- a/neo4j/work/summary.py +++ b/neo4j/work/summary.py @@ -16,7 +16,17 @@ # limitations under the License. +from __future__ import annotations + +import typing as t + + +if t.TYPE_CHECKING: + import typing_extensions as te + from .._exceptions import BoltProtocolError +from ..addressing import Address +from ..api import ServerInfo BOLT_VERSION_1 = 1 @@ -32,47 +42,47 @@ class ResultSummary: """ #: A :class:`neo4j.ServerInfo` instance. Provides some basic information of the server where the result is obtained from. - server = None + server: ServerInfo #: The database name where this summary is obtained from. - database = None + database: t.Optional[str] #: The query that was executed to produce this result. - query = None + query: t.Optional[str] #: Dictionary of parameters passed with the statement. - parameters = None + parameters: t.Optional[t.Dict[str, t.Any]] #: A string that describes the type of query # ``'r'`` = read-only, ``'rw'`` = read/write, ``'w'`` = write-onlye, # ``'s'`` = schema. - query_type = None + query_type: t.Union[te.Literal["r", "rw", "w", "s"], None] #: A :class:`neo4j.SummaryCounters` instance. Counters for operations the query triggered. - counters = None + counters: SummaryCounters #: Dictionary that describes how the database will execute the query. - plan = None + plan: t.Optional[dict] #: Dictionary that describes how the database executed the query. - profile = None + profile: t.Optional[dict] #: The time it took for the server to have the result available. (milliseconds) - result_available_after = None + result_available_after: t.Optional[int] #: The time it took for the server to consume the result. (milliseconds) - result_consumed_after = None + result_consumed_after: t.Optional[int] #: A list of Dictionaries containing notification information. #: Notifications provide extra information for a user executing a statement. #: They can be warnings about problematic queries or other valuable information that can be #: presented in a client. #: Unlike failures or errors, notifications do not affect the execution of a statement. - notifications = None + notifications: t.Optional[t.List[dict]] - def __init__(self, address, **metadata): + def __init__(self, address: Address, **metadata: t.Any) -> None: self.metadata = metadata - self.server = metadata.get("server") + self.server = metadata["server"] self.database = metadata.get("db") self.query = metadata.get("query") self.parameters = metadata.get("parameters") @@ -101,45 +111,45 @@ class SummaryCounters: """ #: - nodes_created = 0 + nodes_created: int = 0 #: - nodes_deleted = 0 + nodes_deleted: int = 0 #: - relationships_created = 0 + relationships_created: int = 0 #: - relationships_deleted = 0 + relationships_deleted: int = 0 #: - properties_set = 0 + properties_set: int = 0 #: - labels_added = 0 + labels_added: int = 0 #: - labels_removed = 0 + labels_removed: int = 0 #: - indexes_added = 0 + indexes_added: int = 0 #: - indexes_removed = 0 + indexes_removed: int = 0 #: - constraints_added = 0 + constraints_added: int = 0 #: - constraints_removed = 0 + constraints_removed: int = 0 #: - system_updates = 0 + system_updates: int = 0 _contains_updates = None _contains_system_updates = None - def __init__(self, statistics): + def __init__(self, statistics) -> None: key_to_attr_name = { "nodes-created": "nodes_created", "nodes-deleted": "nodes_deleted", @@ -161,11 +171,11 @@ def __init__(self, statistics): if attr_name: setattr(self, attr_name, value) - def __repr__(self): + def __repr__(self) -> str: return repr(vars(self)) @property - def contains_updates(self): + def contains_updates(self) -> bool: """True if any of the counters except for system_updates, are greater than 0. Otherwise False.""" if self._contains_updates is not None: @@ -180,7 +190,7 @@ def contains_updates(self): ) @property - def contains_system_updates(self): + def contains_system_updates(self) -> bool: """True if the system database was updated, otherwise False.""" if self._contains_system_updates is not None: return self._contains_system_updates diff --git a/requirements-dev.txt b/requirements-dev.txt index cc40eecd..3d4596f5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,8 @@ unasync>=0.5.0 pre-commit>=2.15.0 isort>=5.10.0 +mypy>=0.971 +types-pytz>=2022.1.2 # needed for running tests -r tests/requirements.txt diff --git a/setup.cfg b/setup.cfg index cf2d2220..a40b05ed 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,3 +19,8 @@ use_parentheses=true [tool:pytest] mock_use_standalone_module = true asyncio_mode = auto + +[mypy] + +[mypy-pandas.*] +ignore_missing_imports = True diff --git a/tests/env.py b/tests/env.py index 8e4b077e..11267899 100644 --- a/tests/env.py +++ b/tests/env.py @@ -18,6 +18,8 @@ import abc import sys +import types +import typing as t from os import environ @@ -28,7 +30,7 @@ def eval(self): class _LazyEvalEnv(_LazyEval): - def __init__(self, env_key, type_=str, default=...): + def __init__(self, env_key, type_: t.Type = str, default=...): self.env_key = env_key self.type_ = type_ self.default = default @@ -41,7 +43,7 @@ def eval(self): value = environ[self.env_key] except KeyError as e: raise Exception( - f"Missing environemnt variable {self.env_key}" + f"Missing environment variable {self.env_key}" ) from e if self.type_ is bool: return value.lower() in ("yes", "y", "1", "on", "true") @@ -59,19 +61,19 @@ def eval(self): class _Module: def __init__(self, module): - self._moudle = module + self._module = module def __getattr__(self, item): - val = getattr(self._moudle, item) + val = getattr(self._module, item) if isinstance(val, _LazyEval): val = val.eval() - setattr(self._moudle, item, val) + setattr(self._module, item, val) return val _module = _Module(sys.modules[__name__]) -sys.modules[__name__] = _module +sys.modules[__name__] = _module # type: ignore[assignment] NEO4J_HOST = _LazyEvalEnv("TEST_NEO4J_HOST") diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index 32b5cdcb..178b0cc6 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -829,12 +829,12 @@ async def test_to_df(keys, values, types, instances, test_default_expand): ( ["n"], list(zip(( - Node(None, "00", 0, ["LABEL_A"], - {"a": 1, "b": 2, "d": 1}), - Node(None, "02", 2, ["LABEL_B"], - {"a": 1, "c": 1.2, "d": 2}), - Node(None, "01", 1, ["LABEL_A", "LABEL_B"], - {"a": [1, "a"], "d": 3}), + Node(None, # type: ignore[arg-type] + "00", 0, ["LABEL_A"], {"a": 1, "b": 2, "d": 1}), + Node(None, # type: ignore[arg-type] + "02", 2, ["LABEL_B"], {"a": 1, "c": 1.2, "d": 2}), + Node(None, # type: ignore[arg-type] + "01", 1, ["LABEL_A", "LABEL_B"], {"a": [1, "a"], "d": 3}), ))), [ "n().element_id", "n().labels", "n().prop.a", "n().prop.b", @@ -1009,8 +1009,8 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["mixed"], [ - [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),], - [neo4j_time.Date(2222, 2, 22),], + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], + [neo4j_time.Date(2222, 2, 22)], [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) @@ -1105,7 +1105,7 @@ async def test_broken_hydration(nested): assert len(record_out) == 2 assert record_out[0] == "foobar" with pytest.raises(BrokenRecordError) as exc: - record_out[1] + _ = record_out[1] cause = exc.value.__cause__ assert isinstance(cause, ValueError) assert repr(b"a") in str(cause) diff --git a/tests/unit/common/spatial/test_cartesian_point.py b/tests/unit/common/spatial/test_cartesian_point.py index 8e22f1a9..5e378d33 100644 --- a/tests/unit/common/spatial/test_cartesian_point.py +++ b/tests/unit/common/spatial/test_cartesian_point.py @@ -16,14 +16,16 @@ # limitations under the License. +from __future__ import annotations + import pytest from neo4j.spatial import CartesianPoint -class CartesianPointTestCase: +class TestCartesianPoint: - def test_alias_3d(self): + def test_alias_3d(self) -> None: x, y, z = 3.2, 4.0, -1.2 p = CartesianPoint((x, y, z)) assert hasattr(p, "x") @@ -33,7 +35,7 @@ def test_alias_3d(self): assert hasattr(p, "z") assert p.z == z - def test_alias_2d(self): + def test_alias_2d(self) -> None: x, y = 3.2, 4.0 p = CartesianPoint((x, y)) assert hasattr(p, "x") @@ -41,4 +43,4 @@ def test_alias_2d(self): assert hasattr(p, "y") assert p.y == y with pytest.raises(AttributeError): - p.z + _ = p.z diff --git a/tests/unit/common/spatial/test_point.py b/tests/unit/common/spatial/test_point.py index 5ede9571..26ef7083 100644 --- a/tests/unit/common/spatial/test_point.py +++ b/tests/unit/common/spatial/test_point.py @@ -16,6 +16,10 @@ # limitations under the License. +from __future__ import annotations + +import typing as t + import pytest from neo4j._spatial import ( @@ -24,27 +28,36 @@ ) -class PointTestCase: +class TestPoint: - @pytest.mark.parametrize("argument", ("a", "b"), ({"x": 1.0, "y": 2.0})) - def test_wrong_type_arguments(self, argument): + @pytest.mark.parametrize("argument", ( + ("a", "b"), {"x": 1.0, "y": 2.0} + )) + def test_wrong_type_arguments(self, argument) -> None: with pytest.raises(ValueError): Point(argument) - @pytest.mark.parametrize("argument", (1, 2), (1.2, 2.1)) - def test_number_arguments(self, argument): + @pytest.mark.parametrize("argument", ( + (1, 2), (1.2, 2.1) + )) + def test_number_arguments(self, argument: t.Iterable[float]) -> None: + print(argument) p = Point(argument) assert tuple(p) == argument - def test_immutable_coordinates(self): - MyPoint = point_type("MyPoint", ["x", "y"], {2: 1234}) + def test_immutable_coordinates(self) -> None: + MyPoint = point_type("MyPoint", ("x", "y", "z"), {2: 1234, 3: 5678}) coordinates = (.1, 0) p = MyPoint(coordinates) with pytest.raises(AttributeError): - p.x = 2.0 + p.x = 2.0 # type: ignore[misc] with pytest.raises(AttributeError): - p.y = 2.0 + p.y = 2.0 # type: ignore[misc] + with pytest.raises(AttributeError): + p.z = 2.0 # type: ignore[misc] + with pytest.raises(TypeError): + p[0] = 2.0 # type: ignore[index] with pytest.raises(TypeError): - p[0] = 2.0 + p[1] = 2.0 # type: ignore[index] with pytest.raises(TypeError): - p[1] = 2.0 + p[2] = 2.0 # type: ignore[index] diff --git a/tests/unit/common/spatial/test_wgs84_point.py b/tests/unit/common/spatial/test_wgs84_point.py index aad72ab6..b15d90ec 100644 --- a/tests/unit/common/spatial/test_wgs84_point.py +++ b/tests/unit/common/spatial/test_wgs84_point.py @@ -16,14 +16,16 @@ # limitations under the License. +from __future__ import annotations + import pytest from neo4j.spatial import WGS84Point -class WGS84PointTestCase: +class TestWGS84Point: - def test_alias_3d(self): + def test_alias_3d(self) -> None: x, y, z = 3.2, 4.0, -1.2 p = WGS84Point((x, y, z)) @@ -42,7 +44,7 @@ def test_alias_3d(self): assert hasattr(p, "z") assert p.z == z - def test_alias_2d(self): + def test_alias_2d(self) -> None: x, y = 3.2, 4.0 p = WGS84Point((x, y)) diff --git a/tests/unit/common/test_addressing.py b/tests/unit/common/test_addressing.py index 99b730f3..f4d5bc00 100644 --- a/tests/unit/common/test_addressing.py +++ b/tests/unit/common/test_addressing.py @@ -16,6 +16,9 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from socket import ( AF_INET, AF_INET6, @@ -51,7 +54,9 @@ (Address(("::1", 7687, 1, 2)), {"family": AF_INET6, "host": "::1", "port": 7687, "str": "[::1]:7687", "repr": "IPv6Address(('::1', 7687, 1, 2))"}), ] ) -def test_address_initialization(test_input, expected): +def test_address_initialization( + test_input: t.Union[tuple, Address], expected: dict +) -> None: address = Address(test_input) assert address.family == expected["family"] assert address.host == expected["host"] @@ -67,7 +72,9 @@ def test_address_initialization(test_input, expected): Address(("127.0.0.1", 7687, 1, 2)), ] ) -def test_address_init_with_address_object_returns_same_instance(test_input): +def test_address_init_with_address_object_returns_same_instance( + test_input: Address +) -> None: address = Address(test_input) assert address is test_input assert id(address) == id(test_input) @@ -82,9 +89,11 @@ def test_address_init_with_address_object_returns_same_instance(test_input): (("[::1]", 7687, 0, 0, 0), ValueError), ] ) -def test_address_initialization_with_incorrect_input(test_input, expected): +def test_address_initialization_with_incorrect_input( + test_input: tuple, expected +) -> None: with pytest.raises(expected): - address = Address(test_input) + _ = Address(test_input) @pytest.mark.parametrize( @@ -94,15 +103,15 @@ def test_address_initialization_with_incorrect_input(test_input, expected): (mock_socket_ipv6, ("[::1]", 7687, 0, 0)) ] ) -def test_address_from_socket(test_input, expected): - +def test_address_from_socket(test_input: mock.Mock, expected: tuple) -> None: + _ = Address.from_socket(mock_socket_ipv4) address = Address.from_socket(test_input) assert address == expected -def test_address_from_socket_with_none(): +def test_address_from_socket_with_none() -> None: with pytest.raises(AttributeError): - address = Address.from_socket(None) + _ = Address.from_socket(None) # type: ignore[arg-type] @pytest.mark.parametrize( @@ -117,7 +126,7 @@ def test_address_from_socket_with_none(): (" ", (" ", 0)), ] ) -def test_address_parse_with_ipv4(test_input, expected): +def test_address_parse_with_ipv4(test_input: str, expected: tuple) -> None: parsed = Address.parse(test_input) assert parsed == expected @@ -131,7 +140,7 @@ def test_address_parse_with_ipv4(test_input, expected): ("[::1]", ("::1", 0, 0, 0)), ] ) -def test_address_should_parse_ipv6(test_input, expected): +def test_address_should_parse_ipv6(test_input: str, expected: tuple) -> None: parsed = Address.parse(test_input) assert parsed == expected @@ -146,9 +155,9 @@ def test_address_should_parse_ipv6(test_input, expected): (Address(("127.0.0.1", 7687)), TypeError), ] ) -def test_address_parse_with_invalid_input(test_input, expected): +def test_address_parse_with_invalid_input(test_input, expected) -> None: with pytest.raises(expected): - parsed = Address.parse(test_input) + _ = Address.parse(test_input) @pytest.mark.parametrize( @@ -160,7 +169,7 @@ def test_address_parse_with_invalid_input(test_input, expected): (("localhost:7687 localhost:7687", "[::1]:7687"), 3), ] ) -def test_address_parse_list(test_input, expected): +def test_address_parse_list(test_input: tuple, expected: int) -> None: addresses = Address.parse_list(*test_input) assert len(addresses) == expected @@ -175,6 +184,8 @@ def test_address_parse_list(test_input, expected): (("localhost:7687", Address(("127.0.0.1", 7687))), TypeError), ] ) -def test_address_parse_list_with_invalid_input(test_input, expected): +def test_address_parse_list_with_invalid_input( + test_input: tuple, expected +) -> None: with pytest.raises(TypeError): - addresses = Address.parse_list(*test_input) + _ = Address.parse_list(*test_input) diff --git a/tests/unit/common/test_api.py b/tests/unit/common/test_api.py index a0f83679..d484d10c 100644 --- a/tests/unit/common/test_api.py +++ b/tests/unit/common/test_api.py @@ -16,8 +16,10 @@ # limitations under the License. +from __future__ import annotations + import itertools -from contextlib import contextmanager +import typing as t import pytest @@ -29,12 +31,12 @@ not_ascii = "♥O◘♦♥O◘♦" -def test_bookmark_is_deprecated(): +def test_bookmark_is_deprecated() -> None: with pytest.deprecated_call(): neo4j.Bookmark() -def test_bookmark_initialization_with_no_values(): +def test_bookmark_initialization_with_no_values() -> None: with pytest.deprecated_call(): bookmark = neo4j.Bookmark() assert bookmark.values == frozenset() @@ -52,7 +54,9 @@ def test_bookmark_initialization_with_no_values(): ((None, "bookmark1", None, "bookmark2", None, None, "bookmark3"), frozenset({"bookmark1", "bookmark2", "bookmark3"}), True, ""), ] ) -def test_bookmark_initialization_with_values_none(test_input, expected_values, expected_bool, expected_repr): +def test_bookmark_initialization_with_values_none( + test_input, expected_values, expected_bool, expected_repr +) -> None: with pytest.deprecated_call(): bookmark = neo4j.Bookmark(*test_input) assert bookmark.values == expected_values @@ -70,7 +74,9 @@ def test_bookmark_initialization_with_values_none(test_input, expected_values, e (("", "bookmark1", "", "bookmark2", "", "", "bookmark3"), frozenset({"bookmark1", "bookmark2", "bookmark3"}), True, ""), ] ) -def test_bookmark_initialization_with_values_empty_string(test_input, expected_values, expected_bool, expected_repr): +def test_bookmark_initialization_with_values_empty_string( + test_input, expected_values, expected_bool, expected_repr +) -> None: with pytest.deprecated_call(): bookmark = neo4j.Bookmark(*test_input) assert bookmark.values == expected_values @@ -86,7 +92,9 @@ def test_bookmark_initialization_with_values_empty_string(test_input, expected_v (standard_ascii, frozenset(standard_ascii), True, "".format(values="', '".join(standard_ascii))) ] ) -def test_bookmark_initialization_with_valid_strings(test_input, expected_values, expected_bool, expected_repr): +def test_bookmark_initialization_with_valid_strings( + test_input, expected_values, expected_bool, expected_repr +) -> None: with pytest.deprecated_call(): bookmark = neo4j.Bookmark(*test_input) assert bookmark.values == expected_values @@ -94,35 +102,31 @@ def test_bookmark_initialization_with_valid_strings(test_input, expected_values, assert repr(bookmark) == expected_repr -@pytest.mark.parametrize( - "test_input, expected", +_bm_input_mark = pytest.mark.parametrize( + ("test_input", "expected"), [ ((not_ascii,), ValueError), (("", not_ascii,), ValueError), (("bookmark1", chr(129),), ValueError), ] ) -@pytest.mark.parametrize(("method", "deprecated", "splat_args"), ( - (neo4j.Bookmark, True, True), - (neo4j.Bookmarks.from_raw_values, False, False), -)) + + +@_bm_input_mark def test_bookmark_initialization_with_invalid_strings( - test_input, expected, method, deprecated, splat_args -): - @contextmanager - def deprecation_assertion(): - if deprecated: - with pytest.warns(DeprecationWarning): - yield - else: - yield + test_input: t.Tuple[str], expected +) -> None: + with pytest.raises(expected): + with pytest.warns(DeprecationWarning): + neo4j.Bookmark(*test_input) + +@_bm_input_mark +def test_bookmarks_initialization_with_invalid_strings( + test_input: t.Tuple[str], expected +) -> None: with pytest.raises(expected): - with deprecation_assertion(): - if splat_args: - method(*test_input) - else: - method(test_input) + neo4j.Bookmarks.from_raw_values(test_input) @pytest.mark.parametrize("test_as_generator", [True, False]) @@ -136,7 +140,7 @@ def deprecation_assertion(): ("bookmark1",), (), )) -def test_bookmarks_raw_values(test_as_generator, values): +def test_bookmarks_raw_values(test_as_generator, values) -> None: expected = frozenset(values) if test_as_generator: values = (v for v in values) @@ -160,7 +164,7 @@ def test_bookmarks_raw_values(test_as_generator, values): ((["bookmark1", "bookmark2"],), TypeError), ((not_ascii,), ValueError), )) -def test_bookmarks_invalid_raw_values(values, exc_type): +def test_bookmarks_invalid_raw_values(values, exc_type) -> None: with pytest.raises(exc_type): neo4j.Bookmarks().from_raw_values(values) @@ -171,7 +175,7 @@ def test_bookmarks_invalid_raw_values(values, exc_type): (("bm42",), ""), ((), ""), )) -def test_bookmarks_repr(values, expected_repr): +def test_bookmarks_repr(values, expected_repr) -> None: bookmarks = neo4j.Bookmarks().from_raw_values(values) assert repr(bookmarks) == expected_repr @@ -188,7 +192,7 @@ def test_bookmarks_repr(values, expected_repr): 2 )) )) -def test_bookmarks_combination(values1, values2): +def test_bookmarks_combination(values1, values2) -> None: bookmarks1 = neo4j.Bookmarks().from_raw_values(values1) bookmarks2 = neo4j.Bookmarks().from_raw_values(values2) bookmarks3 = bookmarks1 + bookmarks2 @@ -209,7 +213,9 @@ def test_bookmarks_combination(values1, values2): ((3, 0, 0, 0), "3.0.0.0", "Version(3, 0, 0, 0)"), ] ) -def test_version_initialization(test_input, expected_str, expected_repr): +def test_version_initialization( + test_input, expected_str, expected_repr +) -> None: version = neo4j.Version(*test_input) assert str(version) == expected_str assert repr(version) == expected_repr @@ -225,7 +231,9 @@ def test_version_initialization(test_input, expected_str, expected_repr): (bytearray([0, 0, 254, 254]), "254.254", "Version(254, 254)"), ] ) -def test_version_from_bytes_with_valid_bolt_version_handshake(test_input, expected_str, expected_repr): +def test_version_from_bytes_with_valid_bolt_version_handshake( + test_input, expected_str, expected_repr +) -> None: version = neo4j.Version.from_bytes(test_input) assert str(version) == expected_str assert repr(version) == expected_repr @@ -241,9 +249,11 @@ def test_version_from_bytes_with_valid_bolt_version_handshake(test_input, expect (bytearray([1, 1, 0, 0]), ValueError), ] ) -def test_version_from_bytes_with_not_valid_bolt_version_handshake(test_input, expected): +def test_version_from_bytes_with_not_valid_bolt_version_handshake( + test_input, expected +) -> None: with pytest.raises(expected): - version = neo4j.Version.from_bytes(test_input) + _ = neo4j.Version.from_bytes(test_input) @pytest.mark.parametrize( @@ -258,12 +268,14 @@ def test_version_from_bytes_with_not_valid_bolt_version_handshake(test_input, ex ((255, 255), bytearray([0, 0, 255, 255])), ] ) -def test_version_to_bytes_with_valid_bolt_version(test_input, expected): +def test_version_to_bytes_with_valid_bolt_version( + test_input, expected +) -> None: version = neo4j.Version(*test_input) assert version.to_bytes() == expected -def test_serverinfo_initialization(): +def test_serverinfo_initialization() -> None: from neo4j.addressing import Address @@ -273,7 +285,6 @@ def test_serverinfo_initialization(): server_info = neo4j.ServerInfo(address, version) assert server_info.address is address assert server_info.protocol_version is version - assert server_info.agent is None with pytest.warns(DeprecationWarning): assert server_info.connection_id is None @@ -287,8 +298,9 @@ def test_serverinfo_initialization(): ] ) @pytest.mark.parametrize("protocol_version", ((3, 0), (4, 3), (42, 1337))) -def test_serverinfo_with_metadata(test_input, expected_agent, - protocol_version): +def test_serverinfo_with_metadata( + test_input, expected_agent, protocol_version +) -> None: from neo4j.addressing import Address address = Address(("bolt://localhost", 7687)) @@ -308,18 +320,20 @@ def test_serverinfo_with_metadata(test_input, expected_agent, ("bolt://localhost:7676", neo4j.api.DRIVER_BOLT, neo4j.api.SECURITY_TYPE_NOT_SECURE, None), ("bolt+ssc://localhost:7676", neo4j.api.DRIVER_BOLT, neo4j.api.SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, None), ("bolt+s://localhost:7676", neo4j.api.DRIVER_BOLT, neo4j.api.SECURITY_TYPE_SECURE, None), - ("neo4j://localhost:7676", neo4j.api.DRIVER_NEO4j, neo4j.api.SECURITY_TYPE_NOT_SECURE, None), - ("neo4j+ssc://localhost:7676", neo4j.api.DRIVER_NEO4j, neo4j.api.SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, None), - ("neo4j+s://localhost:7676", neo4j.api.DRIVER_NEO4j, neo4j.api.SECURITY_TYPE_SECURE, None), + ("neo4j://localhost:7676", neo4j.api.DRIVER_NEO4J, neo4j.api.SECURITY_TYPE_NOT_SECURE, None), + ("neo4j+ssc://localhost:7676", neo4j.api.DRIVER_NEO4J, neo4j.api.SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, None), + ("neo4j+s://localhost:7676", neo4j.api.DRIVER_NEO4J, neo4j.api.SECURITY_TYPE_SECURE, None), ("undefined://localhost:7676", None, None, ConfigurationError), ("localhost:7676", None, None, ConfigurationError), ("://localhost:7676", None, None, ConfigurationError), - ("bolt+routing://localhost:7676", neo4j.api.DRIVER_NEO4j, neo4j.api.SECURITY_TYPE_NOT_SECURE, ConfigurationError), + ("bolt+routing://localhost:7676", neo4j.api.DRIVER_NEO4J, neo4j.api.SECURITY_TYPE_NOT_SECURE, ConfigurationError), ("bolt://username@localhost:7676", None, None, ConfigurationError), ("bolt://username:password@localhost:7676", None, None, ConfigurationError), ] ) -def test_uri_scheme(test_input, expected_driver_type, expected_security_type, expected_error): +def test_uri_scheme( + test_input, expected_driver_type, expected_security_type, expected_error +) -> None: if expected_error: with pytest.raises(expected_error): neo4j.api.parse_neo4j_uri(test_input) @@ -329,16 +343,16 @@ def test_uri_scheme(test_input, expected_driver_type, expected_security_type, ex assert security_type == expected_security_type -def test_parse_routing_context(): +def test_parse_routing_context() -> None: context = neo4j.api.parse_routing_context(query="name=molly&color=white") assert context == {"name": "molly", "color": "white"} -def test_parse_routing_context_should_error_when_value_missing(): +def test_parse_routing_context_should_error_when_value_missing() -> None: with pytest.raises(ConfigurationError): neo4j.api.parse_routing_context("name=&color=white") -def test_parse_routing_context_should_error_when_key_duplicate(): +def test_parse_routing_context_should_error_when_key_duplicate() -> None: with pytest.raises(ConfigurationError): neo4j.api.parse_routing_context("name=molly&name=white") diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index 390f2cab..684dbaa3 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -46,11 +46,9 @@ test_pool_config = { "connection_timeout": 30.0, "update_routing_table_timeout": 90.0, - "init_size": 1, "keep_alive": True, "max_connection_lifetime": 3600, "max_connection_pool_size": 100, - "protocol_version": None, "resolver": None, "encrypted": False, "user_agent": "test", diff --git a/tests/unit/common/test_debug.py b/tests/unit/common/test_debug.py index f4045546..ece76885 100644 --- a/tests/unit/common/test_debug.py +++ b/tests/unit/common/test_debug.py @@ -16,17 +16,26 @@ # limitations under the License. +from __future__ import annotations + import io import logging import sys +import typing as t import pytest +import typing_extensions as te from neo4j import debug as neo4j_debug +class _TSetupMockProtocol(te.Protocol): + def __call__(self, *args: str) -> t.Sequence[t.Any]: + ... + + @pytest.fixture -def add_handler_mocker(mocker): +def add_handler_mocker(mocker) -> _TSetupMockProtocol: def setup_mock(*logger_names): loggers = [logging.getLogger(name) for name in logger_names] for logger in loggers: @@ -38,7 +47,7 @@ def setup_mock(*logger_names): return setup_mock -def test_watch_returns_watcher(add_handler_mocker): +def test_watch_returns_watcher(add_handler_mocker) -> None: logger_name = "neo4j" add_handler_mocker(logger_name) watcher = neo4j_debug.watch(logger_name) @@ -47,14 +56,14 @@ def test_watch_returns_watcher(add_handler_mocker): @pytest.mark.parametrize("logger_names", (("neo4j",), ("foobar",), ("neo4j", "foobar"))) -def test_watch_enables_logging(logger_names, add_handler_mocker): +def test_watch_enables_logging(logger_names, add_handler_mocker) -> None: loggers = add_handler_mocker(*logger_names) neo4j_debug.watch(*logger_names) for logger in loggers: logger.addHandler.assert_called_once() -def test_watcher_watch_adds_logger(add_handler_mocker): +def test_watcher_watch_adds_logger(add_handler_mocker) -> None: logger_name = "neo4j" logger = add_handler_mocker(logger_name)[0] watcher = neo4j_debug.Watcher(logger_name) @@ -64,7 +73,7 @@ def test_watcher_watch_adds_logger(add_handler_mocker): logger.addHandler.assert_called_once() -def test_watcher_stop_removes_logger(add_handler_mocker): +def test_watcher_stop_removes_logger(add_handler_mocker) -> None: logger_name = "neo4j" logger = add_handler_mocker(logger_name)[0] watcher = neo4j_debug.Watcher(logger_name) @@ -77,9 +86,9 @@ def test_watcher_stop_removes_logger(add_handler_mocker): logger.removeHandler.assert_called_once_with(handler) -def test_watcher_context_manager(mocker): +def test_watcher_context_manager(mocker) -> None: logger_name = "neo4j" - watcher = neo4j_debug.Watcher(logger_name) + watcher: t.Any = neo4j_debug.Watcher(logger_name) watcher.watch = mocker.Mock() watcher.stop = mocker.Mock() @@ -100,8 +109,9 @@ def test_watcher_context_manager(mocker): (None, 1, 1), ) ) -def test_watcher_level(add_handler_mocker, default_level, level, - expected_level): +def test_watcher_level( + add_handler_mocker, default_level, level, expected_level +) -> None: logger_name = "neo4j" logger = add_handler_mocker(logger_name)[0] kwargs = {} @@ -131,8 +141,9 @@ def test_watcher_level(add_handler_mocker, default_level, level, (None, custom_log_out, custom_log_out), ) ) -def test_watcher_out(add_handler_mocker, default_out, out, - expected_out): +def test_watcher_out( + add_handler_mocker, default_out, out, expected_out +) -> None: logger_name = "neo4j" logger = add_handler_mocker(logger_name)[0] kwargs = {} @@ -150,7 +161,7 @@ def test_watcher_out(add_handler_mocker, default_out, out, @pytest.mark.parametrize("colour", (True, False)) -def test_watcher_colour(add_handler_mocker, colour): +def test_watcher_colour(add_handler_mocker, colour) -> None: logger_name = "neo4j" logger = add_handler_mocker(logger_name)[0] watcher = neo4j_debug.Watcher(logger_name, colour=colour) @@ -166,7 +177,7 @@ def test_watcher_colour(add_handler_mocker, colour): @pytest.mark.parametrize("colour", (True, False)) -def test_watcher_format(add_handler_mocker, colour): +def test_watcher_format(add_handler_mocker, colour) -> None: logger_name = "neo4j" logger = add_handler_mocker(logger_name)[0] watcher = neo4j_debug.Watcher(logger_name, colour=colour) diff --git a/tests/unit/common/test_import_neo4j.py b/tests/unit/common/test_import_neo4j.py index aa97bea2..edda3117 100644 --- a/tests/unit/common/test_import_neo4j.py +++ b/tests/unit/common/test_import_neo4j.py @@ -14,6 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + import pytest diff --git a/tests/unit/common/test_record.py b/tests/unit/common/test_record.py index 26a258eb..b37f2261 100644 --- a/tests/unit/common/test_record.py +++ b/tests/unit/common/test_record.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + import traceback import pytest @@ -30,7 +32,7 @@ # python -m pytest -s -v tests/unit/test_record.py -def test_record_equality(): +def test_record_equality() -> None: record1 = Record(zip(["name", "empire"], ["Nigel", "The British Empire"])) record2 = Record(zip(["name", "empire"], ["Nigel", "The British Empire"])) record3 = Record(zip(["name", "empire"], ["Stefan", "Das Deutschland"])) @@ -39,7 +41,7 @@ def test_record_equality(): assert record2 != record3 -def test_record_hashing(): +def test_record_hashing() -> None: record1 = Record(zip(["name", "empire"], ["Nigel", "The British Empire"])) record2 = Record(zip(["name", "empire"], ["Nigel", "The British Empire"])) record3 = Record(zip(["name", "empire"], ["Stefan", "Das Deutschland"])) @@ -48,32 +50,32 @@ def test_record_hashing(): assert hash(record2) != hash(record3) -def test_record_iter(): +def test_record_iter() -> None: a_record = Record(zip(["name", "empire"], ["Nigel", "The British Empire"])) assert list(a_record.__iter__()) == ["Nigel", "The British Empire"] -def test_record_as_dict(): +def test_record_as_dict() -> None: a_record = Record(zip(["name", "empire"], ["Nigel", "The British Empire"])) assert dict(a_record) == {"name": "Nigel", "empire": "The British Empire"} -def test_record_as_list(): +def test_record_as_list() -> None: a_record = Record(zip(["name", "empire"], ["Nigel", "The British Empire"])) assert list(a_record) == ["Nigel", "The British Empire"] -def test_record_len(): +def test_record_len() -> None: a_record = Record(zip(["name", "empire"], ["Nigel", "The British Empire"])) assert len(a_record) == 2 -def test_record_repr(): +def test_record_repr() -> None: a_record = Record(zip(["name", "empire"], ["Nigel", "The British Empire"])) assert repr(a_record) == "" -def test_record_data(): +def test_record_data() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r.data() == {"name": "Alice", "age": 33, "married": True} assert r.data("name") == {"name": "Alice"} @@ -86,12 +88,12 @@ def test_record_data(): _ = r.data(1, 0, 999) -def test_record_keys(): +def test_record_keys() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r.keys() == ["name", "age", "married"] -def test_record_values(): +def test_record_values() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r.values() == ["Alice", 33, True] assert r.values("name") == ["Alice"] @@ -104,7 +106,7 @@ def test_record_values(): _ = r.values(1, 0, 999) -def test_record_items(): +def test_record_items() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r.items() == [("name", "Alice"), ("age", 33), ("married", True)] assert r.items("name") == [("name", "Alice")] @@ -117,7 +119,7 @@ def test_record_items(): _ = r.items(1, 0, 999) -def test_record_index(): +def test_record_index() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r.index("name") == 0 assert r.index("age") == 1 @@ -130,10 +132,10 @@ def test_record_index(): with pytest.raises(IndexError): _ = r.index(3) with pytest.raises(TypeError): - _ = r.index(None) + _ = r.index(None) # type: ignore[arg-type] -def test_record_value(): +def test_record_value() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r.value() == "Alice" assert r.value("name") == "Alice" @@ -147,10 +149,10 @@ def test_record_value(): assert r.value(3) is None assert r.value(3, 6) == 6 with pytest.raises(TypeError): - _ = r.value(None) + _ = r.value(None) # type: ignore[arg-type] -def test_record_value_kwargs(): +def test_record_value_kwargs() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r.value() == "Alice" assert r.value(key="name") == "Alice" @@ -165,60 +167,60 @@ def test_record_value_kwargs(): assert r.value(key=3, default=6) == 6 -def test_record_contains(): +def test_record_contains() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert "Alice" in r assert 33 in r assert True in r assert 7.5 not in r with pytest.raises(TypeError): - _ = r.index(None) + _ = r.index(None) # type: ignore[arg-type] -def test_record_from_dict(): +def test_record_from_dict() -> None: r = Record({"name": "Alice", "age": 33}) assert r["name"] == "Alice" assert r["age"] == 33 -def test_record_get_slice(): +def test_record_get_slice() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert Record(zip(["name", "age"], ["Alice", 33])) == r[0:2] -def test_record_get_by_index(): +def test_record_get_by_index() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r[0] == "Alice" -def test_record_get_by_name(): +def test_record_get_by_name() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r["name"] == "Alice" -def test_record_get_by_out_of_bounds_index(): +def test_record_get_by_out_of_bounds_index() -> None: r = Record(zip(["name", "age", "married"], ["Alice", 33, True])) assert r[9] is None -def test_record_get_item(): +def test_record_get_item() -> None: r = Record(zip(["x", "y"], ["foo", "bar"])) assert r["x"] == "foo" assert r["y"] == "bar" with pytest.raises(KeyError): _ = r["z"] with pytest.raises(TypeError): - _ = r[object()] + _ = r[object()] # type: ignore[index] @pytest.mark.parametrize("len_", (0, 1, 2, 42)) -def test_record_len(len_): +def test_record_len_generic(len_) -> None: r = Record(("key_%i" % i, "val_%i" % i) for i in range(len_)) assert len(r) == len_ @pytest.mark.parametrize("len_", range(3)) -def test_record_repr(len_): +def test_record_repr_generic(len_) -> None: r = Record(("key_%i" % i, "val_%i" % i) for i in range(len_)) assert repr(r) @@ -275,16 +277,21 @@ def test_record_repr(len_): {"x": {"one": 1, "two": 2}} ), ( - zip(["a"], [Node("graph", "42", 42, "Person", {"name": "Alice"})]), + zip( + ["a"], + [Node( + None, # type: ignore[arg-type] + "42", 42, "Person", {"name": "Alice"} + )]), (), {"a": {"name": "Alice"}} ), )) -def test_data(raw, keys, serialized): +def test_data(raw, keys, serialized) -> None: assert Record(raw).data(*keys) == serialized -def test_data_relationship(): +def test_data_relationship() -> None: hydration_scope = HydrationHandler().new_hydration_scope() gh = hydration_scope._graph_hydrator alice = gh.hydrate_node(1, {"Person"}, {"name": "Alice", "age": 33}) @@ -303,7 +310,7 @@ def test_data_relationship(): } -def test_data_unbound_relationship(): +def test_data_unbound_relationship() -> None: hydration_scope = HydrationHandler().new_hydration_scope() gh = hydration_scope._graph_hydrator some_one_knows_some_one = gh.hydrate_relationship( @@ -314,7 +321,7 @@ def test_data_unbound_relationship(): @pytest.mark.parametrize("cyclic", (True, False)) -def test_data_path(cyclic): +def test_data_path(cyclic) -> None: hydration_scope = HydrationHandler().new_hydration_scope() gh = hydration_scope._graph_hydrator alice = gh.hydrate_node(1, {"Person"}, {"name": "Alice", "age": 33}) @@ -385,7 +392,7 @@ def test_data_path(cyclic): (lambda r: r.index(1), False), (lambda r: r.index(2), False), )) -def test_record_with_error(accessor, should_raise): +def test_record_with_error(accessor, should_raise) -> None: class TestException(Exception): pass @@ -401,6 +408,6 @@ class TestException(Exception): return with pytest.raises(BrokenRecordError) as raised: accessor(r) - raised = raised.value - assert raised.__cause__ is exc - assert list(traceback.walk_tb(raised.__cause__.__traceback__)) == frames + exc_value = raised.value + assert exc_value.__cause__ is exc + assert list(traceback.walk_tb(exc_value.__cause__.__traceback__)) == frames diff --git a/tests/unit/common/test_security.py b/tests/unit/common/test_security.py index e4e80491..d1243741 100644 --- a/tests/unit/common/test_security.py +++ b/tests/unit/common/test_security.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + from neo4j.api import ( basic_auth, bearer_auth, @@ -27,7 +29,7 @@ # python -m pytest -s -v tests/unit/test_security.py -def test_should_generate_kerberos_auth_token_correctly(): +def test_should_generate_kerberos_auth_token_correctly() -> None: auth = kerberos_auth("I am a base64 service ticket") assert auth.scheme == "kerberos" assert auth.principal == "" @@ -37,7 +39,7 @@ def test_should_generate_kerberos_auth_token_correctly(): assert not hasattr(auth, "parameters") -def test_should_generate_bearer_auth_token_correctly(): +def test_should_generate_bearer_auth_token_correctly() -> None: auth = bearer_auth("I am a base64 SSO ticket") assert auth.scheme == "bearer" assert auth.credentials == "I am a base64 SSO ticket" @@ -47,7 +49,7 @@ def test_should_generate_bearer_auth_token_correctly(): assert not hasattr(auth, "parameters") -def test_should_generate_basic_auth_without_realm_correctly(): +def test_should_generate_basic_auth_without_realm_correctly() -> None: auth = basic_auth("molly", "meoooow") assert auth.scheme == "basic" assert auth.principal == "molly" @@ -56,7 +58,7 @@ def test_should_generate_basic_auth_without_realm_correctly(): assert not hasattr(auth, "parameters") -def test_should_generate_base_auth_with_realm_correctly(): +def test_should_generate_base_auth_with_realm_correctly() -> None: auth = basic_auth("molly", "meoooow", "cat_cafe") assert auth.scheme == "basic" assert auth.principal == "molly" @@ -65,7 +67,7 @@ def test_should_generate_base_auth_with_realm_correctly(): assert not hasattr(auth, "parameters") -def test_should_generate_base_auth_with_keyword_realm_correctly(): +def test_should_generate_base_auth_with_keyword_realm_correctly() -> None: auth = basic_auth("molly", "meoooow", realm="cat_cafe") assert auth.scheme == "basic" assert auth.principal == "molly" @@ -74,7 +76,7 @@ def test_should_generate_base_auth_with_keyword_realm_correctly(): assert not hasattr(auth, "parameters") -def test_should_generate_custom_auth_correctly(): +def test_should_generate_custom_auth_correctly() -> None: auth = custom_auth("molly", "meoooow", "cat_cafe", "cat", age="1", color="white") assert auth.scheme == "cat" assert auth.principal == "molly" diff --git a/tests/unit/common/test_types.py b/tests/unit/common/test_types.py index 22a6fcf9..71ca4454 100644 --- a/tests/unit/common/test_types.py +++ b/tests/unit/common/test_types.py @@ -95,7 +95,7 @@ def test_node_with_null_properties(): (*n1, *n2) for n1, n2 in product( ( - (g, id_, element_id, props) + (g, id_, element_id, props) # type: ignore for g in (0, 1) for id_, element_id in ( (1, "1"), diff --git a/tests/unit/common/time/test_date.py b/tests/unit/common/time/test_date.py index f7d9c0c5..d826d016 100644 --- a/tests/unit/common/time/test_date.py +++ b/tests/unit/common/time/test_date.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + import copy import datetime from datetime import date @@ -38,12 +40,12 @@ class TestDate: - def test_bad_attribute(self): + def test_bad_attribute(self) -> None: d = Date(2000, 1, 1) with pytest.raises(AttributeError): _ = d.x - def test_zero_date(self): + def test_zero_date(self) -> None: d = Date(0, 0, 0) assert d.year_month_day == (0, 0, 0) assert d.year == 0 @@ -51,7 +53,7 @@ def test_zero_date(self): assert d.day == 0 assert d is ZeroDate - def test_zero_ordinal(self): + def test_zero_ordinal(self) -> None: d = Date.from_ordinal(0) assert d.year_month_day == (0, 0, 0) assert d.year == 0 @@ -59,35 +61,35 @@ def test_zero_ordinal(self): assert d.day == 0 assert d is ZeroDate - def test_ordinal_at_start_of_1970(self): + def test_ordinal_at_start_of_1970(self) -> None: d = Date.from_ordinal(719163) assert d.year_month_day == (1970, 1, 1) assert d.year == 1970 assert d.month == 1 assert d.day == 1 - def test_ordinal_at_end_of_1969(self): + def test_ordinal_at_end_of_1969(self) -> None: d = Date.from_ordinal(719162) assert d.year_month_day == (1969, 12, 31) assert d.year == 1969 assert d.month == 12 assert d.day == 31 - def test_ordinal_at_start_of_2018(self): + def test_ordinal_at_start_of_2018(self) -> None: d = Date.from_ordinal(736695) assert d.year_month_day == (2018, 1, 1) assert d.year == 2018 assert d.month == 1 assert d.day == 1 - def test_ordinal_at_end_of_2017(self): + def test_ordinal_at_end_of_2017(self) -> None: d = Date.from_ordinal(736694) assert d.year_month_day == (2017, 12, 31) assert d.year == 2017 assert d.month == 12 assert d.day == 31 - def test_all_positive_days_of_month_for_31_day_month(self): + def test_all_positive_days_of_month_for_31_day_month(self) -> None: for day in range(1, 32): t = Date(1976, 1, day) assert t.year_month_day == (1976, 1, day) @@ -97,7 +99,7 @@ def test_all_positive_days_of_month_for_31_day_month(self): with pytest.raises(ValueError): _ = Date(1976, 1, 32) - def test_all_positive_days_of_month_for_30_day_month(self): + def test_all_positive_days_of_month_for_30_day_month(self) -> None: for day in range(1, 31): t = Date(1976, 6, day) assert t.year_month_day == (1976, 6, day) @@ -107,7 +109,7 @@ def test_all_positive_days_of_month_for_30_day_month(self): with pytest.raises(ValueError): _ = Date(1976, 6, 31) - def test_all_positive_days_of_month_for_29_day_month(self): + def test_all_positive_days_of_month_for_29_day_month(self) -> None: for day in range(1, 30): t = Date(1976, 2, day) assert t.year_month_day == (1976, 2, day) @@ -117,7 +119,7 @@ def test_all_positive_days_of_month_for_29_day_month(self): with pytest.raises(ValueError): _ = Date(1976, 2, 30) - def test_all_positive_days_of_month_for_28_day_month(self): + def test_all_positive_days_of_month_for_28_day_month(self) -> None: for day in range(1, 29): t = Date(1977, 2, day) assert t.year_month_day == (1977, 2, day) @@ -127,129 +129,128 @@ def test_all_positive_days_of_month_for_28_day_month(self): with pytest.raises(ValueError): _ = Date(1977, 2, 29) - def test_last_but_2_day_for_31_day_month(self): + def test_last_but_2_day_for_31_day_month(self) -> None: t = Date(1976, 1, -3) assert t.year_month_day == (1976, 1, 29) assert t.year == 1976 assert t.month == 1 assert t.day == 29 - def test_last_but_1_day_for_31_day_month(self): + def test_last_but_1_day_for_31_day_month(self) -> None: t = Date(1976, 1, -2) assert t.year_month_day == (1976, 1, 30) assert t.year == 1976 assert t.month == 1 assert t.day == 30 - def test_last_day_for_31_day_month(self): + def test_last_day_for_31_day_month(self) -> None: t = Date(1976, 1, -1) assert t.year_month_day == (1976, 1, 31) assert t.year == 1976 assert t.month == 1 assert t.day == 31 - def test_last_but_1_day_for_30_day_month(self): + def test_last_but_1_day_for_30_day_month(self) -> None: t = Date(1976, 6, -2) assert t.year_month_day == (1976, 6, 29) assert t.year == 1976 assert t.month == 6 assert t.day == 29 - def test_last_day_for_30_day_month(self): + def test_last_day_for_30_day_month(self) -> None: t = Date(1976, 6, -1) assert t.year_month_day == (1976, 6, 30) assert t.year == 1976 assert t.month == 6 assert t.day == 30 - def test_day_28_for_29_day_month(self): + def test_day_28_for_29_day_month(self) -> None: t = Date(1976, 2, 28) assert t.year_month_day == (1976, 2, 28) assert t.year == 1976 assert t.month == 2 assert t.day == 28 - def test_last_day_for_29_day_month(self): + def test_last_day_for_29_day_month(self) -> None: t = Date(1976, 2, -1) assert t.year_month_day == (1976, 2, 29) assert t.year == 1976 assert t.month == 2 assert t.day == 29 - def test_last_day_for_28_day_month(self): + def test_last_day_for_28_day_month(self) -> None: t = Date(1977, 2, -1) assert t.year_month_day == (1977, 2, 28) assert t.year == 1977 assert t.month == 2 assert t.day == 28 - def test_cannot_use_year_lower_than_one(self): + def test_cannot_use_year_lower_than_one(self) -> None: with pytest.raises(ValueError): _ = Date(0, 2, 1) - def test_cannot_use_year_higher_than_9999(self): + def test_cannot_use_year_higher_than_9999(self) -> None: with pytest.raises(ValueError): _ = Date(10000, 2, 1) - def test_from_timestamp_without_tz(self): + def test_from_timestamp_without_tz(self) -> None: d = Date.from_timestamp(0) assert d == Date(1970, 1, 1) - def test_from_timestamp_with_tz(self): + def test_from_timestamp_with_tz(self) -> None: d = Date.from_timestamp(0, tz=timezone_eastern) assert d == Date(1969, 12, 31) - def test_utc_from_timestamp(self): + def test_utc_from_timestamp(self) -> None: d = Date.utc_from_timestamp(0) assert d == Date(1970, 1, 1) - def test_from_ordinal(self): + def test_from_ordinal(self) -> None: d = Date.from_ordinal(1) assert d == Date(1, 1, 1) - def test_parse(self): + def test_parse(self) -> None: d = Date.parse("2018-04-30") assert d == Date(2018, 4, 30) - def test_bad_parse_1(self): + def test_bad_parse_1(self) -> None: with pytest.raises(ValueError): _ = Date.parse("30 April 2018") - def test_bad_parse_2(self): + def test_bad_parse_2(self) -> None: with pytest.raises(ValueError): _ = Date.parse("2018-04") - def test_bad_parse_3(self): + def test_bad_parse_3(self) -> None: with pytest.raises(ValueError): - _ = Date.parse(object()) + _ = Date.parse(object()) # type: ignore[arg-type] - def test_replace(self): + def test_replace(self) -> None: d1 = Date(2018, 4, 30) d2 = d1.replace(year=2017) assert d2 == Date(2017, 4, 30) - def test_from_clock_time(self): + def test_from_clock_time(self) -> None: d = Date.from_clock_time((0, 0), epoch=UnixEpoch) assert d == Date(1970, 1, 1) - def test_bad_from_clock_time(self): + def test_bad_from_clock_time(self) -> None: with pytest.raises(ValueError): - _ = Date.from_clock_time(object(), None) - - def test_is_leap_year(self): + _ = Date.from_clock_time(object(), None) # type: ignore[arg-type] + def test_is_leap_year(self) -> None: assert Date.is_leap_year(2000) assert not Date.is_leap_year(2001) - def test_days_in_year(self): + def test_days_in_year(self) -> None: assert Date.days_in_year(2000) == 366 assert Date.days_in_year(2001) == 365 - def test_days_in_month(self): + def test_days_in_month(self) -> None: assert Date.days_in_month(2000, 1) == 31 assert Date.days_in_month(2000, 2) == 29 assert Date.days_in_month(2001, 2) == 28 - def test_instance_attributes(self): + def test_instance_attributes(self) -> None: d = Date(2018, 4, 30) assert d.year == 2018 assert d.month == 4 @@ -258,27 +259,27 @@ def test_instance_attributes(self): assert d.year_week_day == (2018, 18, 1) assert d.year_day == (2018, 120) - def test_can_add_years(self): + def test_can_add_years(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(years=2) assert d2 == Date(1978, 6, 13) - def test_can_add_negative_years(self): + def test_can_add_negative_years(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(years=-2) assert d2 == Date(1974, 6, 13) - def test_can_add_years_and_months(self): + def test_can_add_years_and_months(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(years=2, months=3) assert d2 == Date(1978, 9, 13) - def test_can_add_negative_years_and_months(self): + def test_can_add_negative_years_and_months(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(years=-2, months=-3) assert d2 == Date(1974, 3, 13) - def test_can_retain_offset_from_end_of_month(self): + def test_can_retain_offset_from_end_of_month(self) -> None: d = Date(1976, 1, -1) assert d == Date(1976, 1, 31) d += Duration(months=1) @@ -292,178 +293,188 @@ def test_can_retain_offset_from_end_of_month(self): d += Duration(months=1) assert d == Date(1976, 6, 30) - def test_can_roll_over_end_of_year(self): + def test_can_roll_over_end_of_year(self) -> None: d = Date(1976, 12, 1) assert d == Date(1976, 12, 1) d += Duration(months=1) assert d == Date(1977, 1, 1) - def test_can_add_months_and_days(self): + def test_can_add_months_and_days(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(months=1, days=1) assert d2 == Date(1976, 7, 14) - def test_can_add_months_then_days(self): + def test_can_add_months_then_days(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(months=1) + Duration(days=1) assert d2 == Date(1976, 7, 14) - def test_cannot_add_seconds(self): + def test_cannot_add_seconds(self) -> None: d1 = Date(1976, 6, 13) with pytest.raises(ValueError): _ = d1 + Duration(seconds=1) - def test_adding_empty_duration_returns_self(self): + def test_adding_empty_duration_returns_self(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration() assert d1 is d2 - def test_adding_object(self): + def test_adding_object(self) -> None: d1 = Date(1976, 6, 13) with pytest.raises(TypeError): - _ = d1 + object() + _ = d1 + object() # type: ignore[operator] - def test_can_add_days_then_months(self): + def test_can_add_days_then_months(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(days=1) + Duration(months=1) assert d2 == Date(1976, 7, 14) - def test_can_add_months_and_days_for_last_day_of_short_month(self): + def test_can_add_months_and_days_for_last_day_of_short_month(self) -> None: d1 = Date(1976, 6, 30) d2 = d1 + Duration(months=1, days=1) assert d2 == Date(1976, 8, 1) - def test_can_add_months_then_days_for_last_day_of_short_month(self): + def test_can_add_months_then_days_for_last_day_of_short_month( + self + ) -> None: d1 = Date(1976, 6, 30) d2 = d1 + Duration(months=1) + Duration(days=1) assert d2 == Date(1976, 8, 1) - def test_can_add_days_then_months_for_last_day_of_short_month(self): + def test_can_add_days_then_months_for_last_day_of_short_month( + self + ) -> None: d1 = Date(1976, 6, 30) d2 = d1 + Duration(days=1) + Duration(months=1) assert d2 == Date(1976, 8, 1) - def test_can_add_months_and_days_for_last_day_of_long_month(self): + def test_can_add_months_and_days_for_last_day_of_long_month(self) -> None: d1 = Date(1976, 1, 31) d2 = d1 + Duration(months=1, days=1) assert d2 == Date(1976, 3, 1) - def test_can_add_months_then_days_for_last_day_of_long_month(self): + def test_can_add_months_then_days_for_last_day_of_long_month(self) -> None: d1 = Date(1976, 1, 31) d2 = d1 + Duration(months=1) + Duration(days=1) assert d2 == Date(1976, 3, 1) - def test_can_add_days_then_months_for_last_day_of_long_month(self): + def test_can_add_days_then_months_for_last_day_of_long_month(self) -> None: d1 = Date(1976, 1, 31) d2 = d1 + Duration(days=1) + Duration(months=1) assert d2 == Date(1976, 3, 1) - def test_can_add_negative_months_and_days(self): + def test_can_add_negative_months_and_days(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(months=-1, days=-1) assert d2 == Date(1976, 5, 12) - def test_can_add_negative_months_then_days(self): + def test_can_add_negative_months_then_days(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(months=-1) + Duration(days=-1) assert d2 == Date(1976, 5, 12) - def test_can_add_negative_days_then_months(self): + def test_can_add_negative_days_then_months(self) -> None: d1 = Date(1976, 6, 13) d2 = d1 + Duration(days=-1) + Duration(months=-1) assert d2 == Date(1976, 5, 12) - def test_can_add_negative_months_and_days_for_first_day_of_month(self): + def test_can_add_negative_months_and_days_for_first_day_of_month( + self + ) -> None: d1 = Date(1976, 6, 1) d2 = d1 + Duration(months=-1, days=-1) assert d2 == Date(1976, 4, 30) - def test_can_add_negative_months_then_days_for_first_day_of_month(self): + def test_can_add_negative_months_then_days_for_first_day_of_month( + self + ) -> None: d1 = Date(1976, 6, 1) d2 = d1 + Duration(months=-1) + Duration(days=-1) assert d2 == Date(1976, 4, 30) - def test_can_add_negative_days_then_months_for_last_day_of_month(self): + def test_can_add_negative_days_then_months_for_last_day_of_month( + self + ) -> None: d1 = Date(1976, 6, 1) d2 = d1 + Duration(days=-1) + Duration(months=-1) assert d2 == Date(1976, 4, 30) - def test_can_add_negative_month_for_last_day_of_long_month(self): + def test_can_add_negative_month_for_last_day_of_long_month(self) -> None: d1 = Date(1976, 5, 31) d2 = d1 + Duration(months=-1) assert d2 == Date(1976, 4, 30) - def test_can_add_negative_month_for_january(self): + def test_can_add_negative_month_for_january(self) -> None: d1 = Date(1976, 1, 31) d2 = d1 + Duration(months=-1) assert d2 == Date(1975, 12, 31) - def test_subtract_date(self): + def test_subtract_date(self) -> None: new_year = Date(2000, 1, 1) christmas = Date(1999, 12, 25) assert new_year - christmas == Duration(days=7) - def test_subtract_duration(self): + def test_subtract_duration(self) -> None: new_year = Date(2000, 1, 1) christmas = Date(1999, 12, 25) assert new_year - Duration(days=7) == christmas - def test_subtract_object(self): + def test_subtract_object(self) -> None: new_year = Date(2000, 1, 1) with pytest.raises(TypeError): - _ = new_year - object() + _ = new_year - object() # type: ignore[operator] - def test_date_less_than(self): + def test_date_less_than(self) -> None: new_year = Date(2000, 1, 1) christmas = Date(1999, 12, 25) assert christmas < new_year - def test_date_less_than_object(self): + def test_date_less_than_object(self) -> None: d = Date(2000, 1, 1) with pytest.raises(TypeError): - _ = d < object() + _ = d < object() # type: ignore[operator] - def test_date_less_than_or_equal_to(self): + def test_date_less_than_or_equal_to(self) -> None: new_year = Date(2000, 1, 1) christmas = Date(1999, 12, 25) assert christmas <= new_year - def test_date_less_than_or_equal_to_object(self): + def test_date_less_than_or_equal_to_object(self) -> None: d = Date(2000, 1, 1) with pytest.raises(TypeError): - _ = d <= object() + _ = d <= object() # type: ignore[operator] - def test_date_greater_than_or_equal_to(self): + def test_date_greater_than_or_equal_to(self) -> None: new_year = Date(2000, 1, 1) christmas = Date(1999, 12, 25) assert new_year >= christmas - def test_date_greater_than_or_equal_to_object(self): + def test_date_greater_than_or_equal_to_object(self) -> None: d = Date(2000, 1, 1) with pytest.raises(TypeError): - _ = d >= object() + _ = d >= object() # type: ignore[operator] - def test_date_greater_than(self): + def test_date_greater_than(self) -> None: new_year = Date(2000, 1, 1) christmas = Date(1999, 12, 25) assert new_year > christmas - def test_date_greater_than_object(self): + def test_date_greater_than_object(self) -> None: d = Date(2000, 1, 1) with pytest.raises(TypeError): - _ = d > object() + _ = d > object() # type: ignore[operator] - def test_date_equal(self): + def test_date_equal(self) -> None: d1 = Date(2000, 1, 1) d2 = Date(2000, 1, 1) assert d1 == d2 - def test_date_not_equal(self): + def test_date_not_equal(self) -> None: d1 = Date(2000, 1, 1) d2 = Date(2000, 1, 2) assert d1 != d2 - def test_date_not_equal_to_object(self): + def test_date_not_equal_to_object(self) -> None: d1 = Date(2000, 1, 1) assert d1 != object() @@ -471,73 +482,73 @@ def test_date_not_equal_to_object(self): Date(2001, 1, 1).to_ordinal(), Date(2008, 1, 1).to_ordinal(), )) - def test_year_week_day(self, ordinal): + def test_year_week_day(self, ordinal) -> None: assert Date.from_ordinal(ordinal).iso_calendar() \ == date.fromordinal(ordinal).isocalendar() - def test_time_tuple(self): + def test_time_tuple(self) -> None: d = Date(2018, 4, 30) expected = struct_time((2018, 4, 30, 0, 0, 0, 0, 120, -1)) assert d.time_tuple() == expected - def test_to_clock_time(self): + def test_to_clock_time(self) -> None: d = Date(2018, 4, 30) assert d.to_clock_time(UnixEpoch) == (1525046400, 0) assert d.to_clock_time(d) == (0, 0) with pytest.raises(TypeError): - _ = d.to_clock_time(object()) + _ = d.to_clock_time(object()) # type: ignore[arg-type] - def test_weekday(self): + def test_weekday(self) -> None: d = Date(2018, 4, 30) assert d.weekday() == 0 - def test_iso_weekday(self): + def test_iso_weekday(self) -> None: d = Date(2018, 4, 30) assert d.iso_weekday() == 1 - def test_str(self): + def test_str(self) -> None: assert str(Date(2018, 4, 30)) == "2018-04-30" assert str(Date(0, 0, 0)) == "0000-00-00" - def test_repr(self): + def test_repr(self) -> None: assert repr(Date(2018, 4, 30)) == "neo4j.time.Date(2018, 4, 30)" assert repr(Date(0, 0, 0)) == "neo4j.time.ZeroDate" - def test_format(self): + def test_format(self) -> None: d = Date(2018, 4, 30) with pytest.raises(NotImplementedError): _ = d.__format__("") - def test_from_native(self): + def test_from_native(self) -> None: native = date(2018, 10, 1) d = Date.from_native(native) assert d.year == native.year assert d.month == native.month assert d.day == native.day - def test_to_native(self): + def test_to_native(self) -> None: d = Date(2018, 10, 1) native = d.to_native() assert d.year == native.year assert d.month == native.month assert d.day == native.day - def test_iso_format(self): + def test_iso_format(self) -> None: d = Date(2018, 10, 1) assert "2018-10-01" == d.iso_format() - def test_from_iso_format(self): + def test_from_iso_format(self) -> None: expected = Date(2018, 10, 1) actual = Date.from_iso_format("2018-10-01") assert expected == actual - def test_date_copy(self): + def test_date_copy(self) -> None: d = Date(2010, 10, 1) d2 = copy.copy(d) assert d is not d2 assert d == d2 - def test_date_deep_copy(self): + def test_date_deep_copy(self) -> None: d = Date(2010, 10, 1) d2 = copy.deepcopy(d) assert d is not d2 @@ -558,7 +569,7 @@ def test_date_deep_copy(self): (datetime.timezone(datetime.timedelta(hours=12)), (1970, 1, 2)), )) -def test_today(tz, expected): +def test_today(tz, expected) -> None: d = Date.today(tz=tz) assert isinstance(d, Date) assert d.year_month_day == expected diff --git a/tests/unit/common/time/test_datetime.py b/tests/unit/common/time/test_datetime.py index aad56bbc..462e6931 100644 --- a/tests/unit/common/time/test_datetime.py +++ b/tests/unit/common/time/test_datetime.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + import copy import itertools import operator @@ -61,7 +63,7 @@ class TestDateTime: @pytest.mark.parametrize("args", ( (0, 0, 0), (0, 0, 0, 0, 0, 0, 0) )) - def test_zero(self, args): + def test_zero(self, args) -> None: t = DateTime(*args) assert t.year == 0 assert t.month == 0 @@ -71,7 +73,7 @@ def test_zero(self, args): assert t.second == 0 assert t.nanosecond == 0 - def test_non_zero_naive(self): + def test_non_zero_naive(self) -> None: t = DateTime(2018, 4, 26, 23, 0, 17, 914390409) assert t.year == 2018 assert t.month == 4 @@ -81,49 +83,49 @@ def test_non_zero_naive(self): assert t.second == 17 assert t.nanosecond == 914390409 - def test_year_lower_bound(self): + def test_year_lower_bound(self) -> None: with pytest.raises(ValueError): _ = DateTime(MIN_YEAR - 1, 1, 1, 0, 0, 0) - def test_year_upper_bound(self): + def test_year_upper_bound(self) -> None: with pytest.raises(ValueError): _ = DateTime(MAX_YEAR + 1, 1, 1, 0, 0, 0) - def test_month_lower_bound(self): + def test_month_lower_bound(self) -> None: with pytest.raises(ValueError): _ = DateTime(2000, 0, 1, 0, 0, 0) - def test_month_upper_bound(self): + def test_month_upper_bound(self) -> None: with pytest.raises(ValueError): _ = DateTime(2000, 13, 1, 0, 0, 0) - def test_day_zero(self): + def test_day_zero(self) -> None: with pytest.raises(ValueError): _ = DateTime(2000, 1, 0, 0, 0, 0) - def test_day_30_of_29_day_month(self): + def test_day_30_of_29_day_month(self) -> None: with pytest.raises(ValueError): _ = DateTime(2000, 2, 30, 0, 0, 0) - def test_day_32_of_31_day_month(self): + def test_day_32_of_31_day_month(self) -> None: with pytest.raises(ValueError): _ = DateTime(2000, 3, 32, 0, 0, 0) - def test_day_31_of_30_day_month(self): + def test_day_31_of_30_day_month(self) -> None: with pytest.raises(ValueError): _ = DateTime(2000, 4, 31, 0, 0, 0) - def test_day_29_of_28_day_month(self): + def test_day_29_of_28_day_month(self) -> None: with pytest.raises(ValueError): _ = DateTime(1999, 2, 29, 0, 0, 0) - def test_last_day_of_month(self): + def test_last_day_of_month(self) -> None: t = DateTime(2000, 1, -1, 0, 0, 0) assert t.year == 2000 assert t.month == 1 assert t.day == 31 - def test_today(self): + def test_today(self) -> None: t = DateTime.today() assert t.year == 1970 assert t.month == 1 @@ -133,7 +135,7 @@ def test_today(self): assert t.second == 56 assert t.nanosecond == 789000001 - def test_now_without_tz(self): + def test_now_without_tz(self) -> None: t = DateTime.now() assert t.year == 1970 assert t.month == 1 @@ -144,7 +146,7 @@ def test_now_without_tz(self): assert t.nanosecond == 789000001 assert t.tzinfo is None - def test_now_with_tz(self): + def test_now_with_tz(self) -> None: t = DateTime.now(timezone_us_eastern) assert t.year == 1970 assert t.month == 1 @@ -157,7 +159,7 @@ def test_now_with_tz(self): assert t.dst() == timedelta() assert t.tzname() == "EST" - def test_now_with_utc_tz(self): + def test_now_with_utc_tz(self) -> None: t = DateTime.now(timezone_utc) assert t.year == 1970 assert t.month == 1 @@ -170,7 +172,7 @@ def test_now_with_utc_tz(self): assert t.dst() == timedelta() assert t.tzname() == "UTC" - def test_utc_now(self): + def test_utc_now(self) -> None: t = DateTime.utc_now() assert t.year == 1970 assert t.month == 1 @@ -189,17 +191,17 @@ def test_utc_now(self): (datetime_timezone(timedelta(hours=1)), (1970, 1, 1, 1, 0, 0, 0)), (timezone_us_eastern, (1969, 12, 31, 19, 0, 0, 0)), )) - def test_from_timestamp(self, tz, expected): + def test_from_timestamp(self, tz, expected) -> None: t = DateTime.from_timestamp(0, tz=tz) assert t.year_month_day == expected[:3] assert t.hour_minute_second_nanosecond == expected[3:] assert str(t.tzinfo) == str(tz) - def test_from_overflowing_timestamp(self): + def test_from_overflowing_timestamp(self) -> None: with pytest.raises(ValueError): _ = DateTime.from_timestamp(999999999999999999) - def test_from_timestamp_with_tz(self): + def test_from_timestamp_with_tz(self) -> None: t = DateTime.from_timestamp(0, timezone_us_eastern) assert t.year == 1969 assert t.month == 12 @@ -213,20 +215,20 @@ def test_from_timestamp_with_tz(self): assert t.tzname() == "EST" @pytest.mark.parametrize("seconds_args", seconds_options(17, 914390409)) - def test_conversion_to_t(self, seconds_args): + def test_conversion_to_t(self, seconds_args) -> None: dt = DateTime(2018, 4, 26, 23, 0, *seconds_args) t = dt.to_clock_time() assert t, ClockTime(63660380417 == 914390409) @pytest.mark.parametrize("seconds_args1", seconds_options(17, 914390409)) @pytest.mark.parametrize("seconds_args2", seconds_options(17, 914390409)) - def test_add_timedelta(self, seconds_args1, seconds_args2): + def test_add_timedelta(self, seconds_args1, seconds_args2) -> None: dt1 = DateTime(2018, 4, 26, 23, 0, *seconds_args1) delta = timedelta(days=1) dt2 = dt1 + delta assert dt2, DateTime(2018, 4, 27, 23, 0 == seconds_args2) - def test_subtract_datetime_1(self): + def test_subtract_datetime_1(self) -> None: dt1 = DateTime(2018, 4, 26, 23, 0, 17, 914390409) dt2 = DateTime(2018, 1, 1, 0, 0, 0) t = dt1 - dt2 @@ -235,37 +237,41 @@ def test_subtract_datetime_1(self): assert t == Duration(months=3, days=25, hours=23, seconds=17, nanoseconds=914390409) - def test_subtract_datetime_2(self): + def test_subtract_datetime_2(self) -> None: dt1 = DateTime(2018, 4, 1, 23, 0, 17, 914390409) - dt2 = DateTime(2018, 1, 26, 0, 0, 0.0) + dt2 = DateTime(2018, 1, 26, 0, 0, 0) t = dt1 - dt2 assert t == Duration(months=3, days=-25, hours=23, seconds=17.914390409) assert t == Duration(months=3, days=-25, hours=23, seconds=17, nanoseconds=914390409) - def test_subtract_native_datetime_1(self): + def test_subtract_native_datetime_1(self) -> None: dt1 = DateTime(2018, 4, 26, 23, 0, 17, 914390409) dt2 = datetime(2018, 1, 1, 0, 0, 0) t = dt1 - dt2 assert t == timedelta(days=115, hours=23, seconds=17.914390409) - def test_subtract_native_datetime_2(self): + def test_subtract_native_datetime_2(self) -> None: dt1 = DateTime(2018, 4, 1, 23, 0, 17, 914390409) dt2 = datetime(2018, 1, 26, 0, 0, 0) t = dt1 - dt2 assert t == timedelta(days=65, hours=23, seconds=17.914390409) - def test_normalization(self): - ndt1 = timezone_us_eastern.normalize(DateTime(2018, 4, 27, 23, 0, 17, tzinfo=timezone_us_eastern)) - ndt2 = timezone_us_eastern.normalize(datetime(2018, 4, 27, 23, 0, 17, tzinfo=timezone_us_eastern)) + def test_normalization(self) -> None: + ndt1 = timezone_us_eastern.normalize( + DateTime(2018, 4, 27, 23, 0, 17, tzinfo=timezone_us_eastern) + ) + ndt2 = timezone_us_eastern.normalize( + datetime(2018, 4, 27, 23, 0, 17, tzinfo=timezone_us_eastern) + ) assert ndt1 == ndt2 - def test_localization(self): + def test_localization(self) -> None: ldt1 = timezone_us_eastern.localize(datetime(2018, 4, 27, 23, 0, 17)) ldt2 = timezone_us_eastern.localize(DateTime(2018, 4, 27, 23, 0, 17)) assert ldt1 == ldt2 - def test_from_native(self): + def test_from_native(self) -> None: native = datetime(2018, 10, 1, 12, 34, 56, 789123) dt = DateTime.from_native(native) assert dt.year == native.year @@ -276,8 +282,8 @@ def test_from_native(self): assert dt.second == native.second assert dt.nanosecond == native.microsecond * 1000 - def test_to_native(self): - dt = DateTime(2018, 10, 1, 12, 34, 56.789123456) + def test_to_native(self) -> None: + dt = DateTime(2018, 10, 1, 12, 34, 56, 789123456) native = dt.to_native() assert dt.year == native.year assert dt.month == native.month @@ -285,7 +291,7 @@ def test_to_native(self): assert dt.hour == native.hour assert dt.minute == native.minute assert dt.second == native.second - assert dt.nanosecond == native.microsecond * 1000 + assert dt.nanosecond // 1000 == native.microsecond @pytest.mark.parametrize(("dt", "expected"), ( ( @@ -337,77 +343,77 @@ def test_to_native(self): "2018-10-01T12:34:56.789123+00:00" ), )) - def test_iso_format(self, dt, expected): + def test_iso_format(self, dt, expected) -> None: assert dt.isoformat() == expected - def test_from_iso_format_hour_only(self): + def test_from_iso_format_hour_only(self) -> None: expected = DateTime(2018, 10, 1, 12, 0, 0) actual = DateTime.from_iso_format("2018-10-01T12") assert expected == actual - def test_from_iso_format_hour_and_minute(self): + def test_from_iso_format_hour_and_minute(self) -> None: expected = DateTime(2018, 10, 1, 12, 34, 0) actual = DateTime.from_iso_format("2018-10-01T12:34") assert expected == actual - def test_from_iso_format_hour_minute_second(self): + def test_from_iso_format_hour_minute_second(self) -> None: expected = DateTime(2018, 10, 1, 12, 34, 56) actual = DateTime.from_iso_format("2018-10-01T12:34:56") assert expected == actual - def test_from_iso_format_hour_minute_second_milliseconds(self): + def test_from_iso_format_hour_minute_second_milliseconds(self) -> None: expected = DateTime(2018, 10, 1, 12, 34, 56, 123000000) actual = DateTime.from_iso_format("2018-10-01T12:34:56.123") assert expected == actual - def test_from_iso_format_hour_minute_second_microseconds(self): + def test_from_iso_format_hour_minute_second_microseconds(self) -> None: expected = DateTime(2018, 10, 1, 12, 34, 56, 123456000) actual = DateTime.from_iso_format("2018-10-01T12:34:56.123456") assert expected == actual - def test_from_iso_format_hour_minute_second_nanosecond(self): + def test_from_iso_format_hour_minute_second_nanosecond(self) -> None: expected = DateTime(2018, 10, 1, 12, 34, 56, 123456789) actual = DateTime.from_iso_format("2018-10-01T12:34:56.123456789") assert expected == actual - def test_from_iso_format_with_positive_tz(self): + def test_from_iso_format_with_positive_tz(self) -> None: expected = DateTime(2018, 10, 1, 12, 34, 56, 123456789, tzinfo=FixedOffset(754)) actual = DateTime.from_iso_format("2018-10-01T12:34:56.123456789+12:34") assert expected == actual - def test_from_iso_format_with_negative_tz(self): + def test_from_iso_format_with_negative_tz(self) -> None: expected = DateTime(2018, 10, 1, 12, 34, 56, 123456789, tzinfo=FixedOffset(-754)) actual = DateTime.from_iso_format("2018-10-01T12:34:56.123456789-12:34") assert expected == actual - def test_from_iso_format_with_positive_long_tz(self): + def test_from_iso_format_with_positive_long_tz(self) -> None: expected = DateTime(2018, 10, 1, 12, 34, 56, 123456789, tzinfo=FixedOffset(754)) actual = DateTime.from_iso_format("2018-10-01T12:34:56.123456789+12:34:56.123456") assert expected == actual - def test_from_iso_format_with_negative_long_tz(self): + def test_from_iso_format_with_negative_long_tz(self) -> None: expected = DateTime(2018, 10, 1, 12, 34, 56, 123456789, tzinfo=FixedOffset(-754)) actual = DateTime.from_iso_format("2018-10-01T12:34:56.123456789-12:34:56.123456") assert expected == actual - def test_datetime_copy(self): + def test_datetime_copy(self) -> None: d = DateTime(2010, 10, 1, 10, 0, 10) d2 = copy.copy(d) assert d is not d2 assert d == d2 - def test_datetime_deep_copy(self): + def test_datetime_deep_copy(self) -> None: d = DateTime(2010, 10, 1, 10, 0, 12) d2 = copy.deepcopy(d) assert d is not d2 assert d == d2 -def test_iso_format_with_time_zone_case_1(): +def test_iso_format_with_time_zone_case_1() -> None: # python -m pytest tests/unit/time/test_datetime.py -s -v -k test_iso_format_with_time_zone_case_1 expected = DateTime(2019, 10, 30, 7, 54, 2, 129790999, tzinfo=timezone_utc) assert expected.iso_format() == "2019-10-30T07:54:02.129790999+00:00" @@ -416,14 +422,14 @@ def test_iso_format_with_time_zone_case_1(): assert expected == actual -def test_iso_format_with_time_zone_case_2(): +def test_iso_format_with_time_zone_case_2() -> None: # python -m pytest tests/unit/time/test_datetime.py -s -v -k test_iso_format_with_time_zone_case_2 expected = DateTime.from_iso_format("2019-10-30T07:54:02.129790999+01:00") assert expected.tzinfo == FixedOffset(60) assert expected.iso_format() == "2019-10-30T07:54:02.129790999+01:00" -def test_to_native_case_1(): +def test_to_native_case_1() -> None: # python -m pytest tests/unit/time/test_datetime.py -s -v -k test_to_native_case_1 dt = DateTime.from_iso_format("2019-10-30T12:34:56.789123456") native = dt.to_native() @@ -434,7 +440,7 @@ def test_to_native_case_1(): assert native.isoformat() == "2019-10-30T12:34:56.789123" -def test_to_native_case_2(): +def test_to_native_case_2() -> None: # python -m pytest tests/unit/time/test_datetime.py -s -v -k test_to_native_case_2 dt = DateTime.from_iso_format("2019-10-30T12:34:56.789123456+00:00") native = dt.to_native() @@ -445,7 +451,7 @@ def test_to_native_case_2(): assert native.isoformat() == "2019-10-30T12:34:56.789123+00:00" -def test_to_native_case_3(): +def test_to_native_case_3() -> None: # python -m pytest tests/unit/time/test_datetime.py -s -v -k test_to_native_case_3 timestamp = "2021-04-06T00:00:00.500006+00:00" neo4j_datetime = DateTime.from_iso_format(timestamp) @@ -457,7 +463,7 @@ def test_to_native_case_3(): assert native_from_neo4j == native_from_datetime -def test_from_native_case_1(): +def test_from_native_case_1() -> None: # python -m pytest tests/unit/time/test_datetime.py -s -v -k test_from_native_case_1 native = datetime(2018, 10, 1, 12, 34, 56, 789123) dt = DateTime.from_native(native) @@ -471,7 +477,7 @@ def test_from_native_case_1(): assert dt.tzinfo is None -def test_from_native_case_2(): +def test_from_native_case_2() -> None: # python -m pytest tests/unit/time/test_datetime.py -s -v -k test_from_native_case_2 native = datetime(2018, 10, 1, 12, 34, 56, 789123, FixedOffset(0)) dt = DateTime.from_native(native) @@ -486,7 +492,7 @@ def test_from_native_case_2(): @pytest.mark.parametrize("datetime_cls", (DateTime, datetime)) -def test_transition_to_summertime(datetime_cls): +def test_transition_to_summertime(datetime_cls) -> None: dt = datetime_cls(2022, 3, 27, 1, 30) dt = timezone_berlin.localize(dt) assert dt.utcoffset() == timedelta(hours=1) @@ -513,7 +519,7 @@ def test_transition_to_summertime(datetime_cls): @pytest.mark.parametrize("tz", ( timezone_berlin, datetime_timezone(timedelta(hours=-1)) )) -def test_transition_to_summertime_in_utc_space(datetime_cls, utc_impl, tz): +def test_transition_to_summertime_in_utc_space(datetime_cls, utc_impl, tz) -> None: if datetime_cls == DateTime: dt = datetime_cls(2022, 3, 27, 1, 30, 1, 123456789) else: @@ -608,7 +614,7 @@ def test_transition_to_summertime_in_utc_space(datetime_cls, utc_impl, tz): ), )) -def test_equality(dt1, dt2): +def test_equality(dt1, dt2) -> None: assert dt1 == dt2 assert dt2 == dt1 assert dt1 <= dt2 @@ -687,7 +693,7 @@ def test_equality(dt1, dt2): DateTime(2022, 11, 25, 12, 34, 56, 789123456, FixedOffset(0)) ), )) -def test_inequality(dt1, dt2): +def test_inequality(dt1, dt2) -> None: assert dt1 != dt2 assert dt2 != dt1 @@ -718,7 +724,7 @@ def test_inequality(dt1, dt2): repeat=2 ) ) -def test_hashed_equality(dt1, dt2): +def test_hashed_equality(dt1, dt2) -> None: if dt1 == dt2: s = {dt1} assert dt1 in s @@ -751,7 +757,7 @@ def test_hashed_equality(dt1, dt2): @pytest.mark.parametrize("op", ( operator.lt, operator.le, operator.gt, operator.ge, )) -def test_comparison_with_only_one_naive_fails(dt1, dt2, tz, op): +def test_comparison_with_only_one_naive_fails(dt1, dt2, tz, op) -> None: dt1 = dt1.replace(tzinfo=tz) with pytest.raises(TypeError, match="naive"): op(dt1, dt2) @@ -774,7 +780,7 @@ def test_comparison_with_only_one_naive_fails(dt1, dt2, tz, op): @pytest.mark.parametrize("op", ( operator.lt, operator.le, operator.gt, operator.ge, )) -def test_comparison_with_one_naive_and_not_fixed_tz(dt1, dt2, tz, op): +def test_comparison_with_one_naive_and_not_fixed_tz(dt1, dt2, tz, op) -> None: dt1tz = tz.localize(dt1) with pytest.raises(TypeError, match="naive"): op(dt1tz, dt2) @@ -855,7 +861,7 @@ def test_comparison_with_one_naive_and_not_fixed_tz(dt1, dt2, tz, op): DateTime(2022, 11, 25, 12, 34, 56, 789123001, FixedOffset(-1)), ), )) -def test_comparison(dt1, dt2): +def test_comparison(dt1, dt2) -> None: assert dt1 < dt2 assert not dt2 < dt1 assert dt1 <= dt2 diff --git a/tests/unit/common/time/test_duration.py b/tests/unit/common/time/test_duration.py index 8a608f9c..69149a70 100644 --- a/tests/unit/common/time/test_duration.py +++ b/tests/unit/common/time/test_duration.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + import copy from datetime import timedelta @@ -27,7 +29,7 @@ class TestDuration: - def test_zero(self): + def test_zero(self) -> None: d = Duration() assert d.months == 0 assert d.days == 0 @@ -37,7 +39,7 @@ def test_zero(self): assert d.hours_minutes_seconds_nanoseconds == (0, 0, 0, 0) assert not bool(d) - def test_years_only(self): + def test_years_only(self) -> None: d = Duration(years=2) assert d.months == 24 assert d.days == 0 @@ -47,7 +49,7 @@ def test_years_only(self): assert d.hours_minutes_seconds_nanoseconds == (0, 0, 0, 0) assert bool(d) - def test_months_only(self): + def test_months_only(self) -> None: d = Duration(months=20) assert d.months == 20 assert d.days == 0 @@ -57,11 +59,11 @@ def test_months_only(self): assert d.hours_minutes_seconds_nanoseconds == (0, 0, 0, 0) assert bool(d) - def test_months_out_of_range(self): + def test_months_out_of_range(self) -> None: with pytest.raises(ValueError): _ = Duration(months=(2**64)) - def test_weeks_only(self): + def test_weeks_only(self) -> None: d = Duration(weeks=4) assert d.months == 0 assert d.days == 28 @@ -71,7 +73,7 @@ def test_weeks_only(self): assert d.hours_minutes_seconds_nanoseconds == (0, 0, 0, 0) assert bool(d) - def test_days_only(self): + def test_days_only(self) -> None: d = Duration(days=40) assert d.months == 0 assert d.days == 40 @@ -81,11 +83,11 @@ def test_days_only(self): assert d.hours_minutes_seconds_nanoseconds == (0, 0, 0, 0) assert bool(d) - def test_days_out_of_range(self): + def test_days_out_of_range(self) -> None: with pytest.raises(ValueError): _ = Duration(days=(2**64)) - def test_hours_only(self): + def test_hours_only(self) -> None: d = Duration(hours=10) assert d.months == 0 assert d.days == 0 @@ -95,7 +97,7 @@ def test_hours_only(self): assert d.hours_minutes_seconds_nanoseconds == (10, 0, 0, 0) assert bool(d) - def test_minutes_only(self): + def test_minutes_only(self) -> None: d = Duration(minutes=90.5) assert d.months == 0 assert d.days == 0 @@ -105,7 +107,7 @@ def test_minutes_only(self): assert d.hours_minutes_seconds_nanoseconds == (1, 30, 30, 0) assert bool(d) - def test_seconds_only(self): + def test_seconds_only(self) -> None: d = Duration(seconds=123, nanoseconds=456000000) assert d.months == 0 assert d.days == 0 @@ -115,11 +117,11 @@ def test_seconds_only(self): assert d.hours_minutes_seconds_nanoseconds == (0, 2, 3, 456000000) assert bool(d) - def test_seconds_out_of_range(self): + def test_seconds_out_of_range(self) -> None: with pytest.raises(ValueError): _ = Duration(seconds=(2**64)) - def test_milliseconds_only(self): + def test_milliseconds_only(self) -> None: d = Duration(milliseconds=1234.567) assert d.months == 0 assert d.days == 0 @@ -129,7 +131,7 @@ def test_milliseconds_only(self): assert d.hours_minutes_seconds_nanoseconds == (0, 0, 1, 234567000) assert bool(d) - def test_microseconds_only(self): + def test_microseconds_only(self) -> None: d = Duration(microseconds=1234.567) assert d.months == 0 assert d.days == 0 @@ -139,7 +141,7 @@ def test_microseconds_only(self): assert d.hours_minutes_seconds_nanoseconds == (0, 0, 0, 1234567) assert bool(d) - def test_nanoseconds_only(self): + def test_nanoseconds_only(self) -> None: d = Duration(nanoseconds=1234.567) assert d.months == 0 assert d.days == 0 @@ -149,27 +151,28 @@ def test_nanoseconds_only(self): assert d.hours_minutes_seconds_nanoseconds == (0, 0, 0, 1234) assert bool(d) - def test_can_combine_years_months(self): + def test_can_combine_years_months(self) -> None: t = Duration(years=5, months=3) assert t.months == 63 - def test_can_combine_weeks_and_days(self): + def test_can_combine_weeks_and_days(self) -> None: t = Duration(weeks=5, days=3) assert t.days == 38 - def test_can_combine_hours_minutes_seconds(self): + def test_can_combine_hours_minutes_seconds(self) -> None: t = Duration(hours=5, minutes=4, seconds=3) assert t.seconds == 18243 - def test_can_combine_seconds_and_nanoseconds(self): + def test_can_combine_seconds_and_nanoseconds(self) -> None: t = Duration(seconds=123.456, nanoseconds=321000000) assert t.seconds == 123 assert t.nanoseconds == 777000000 assert t == Duration(seconds=123, nanoseconds=777000000) assert t == Duration(seconds=123.777) - def test_full_positive(self): - d = Duration(years=1, months=2, days=3, hours=4, minutes=5, seconds=6.789) + def test_full_positive(self) -> None: + d = Duration(years=1, months=2, days=3, hours=4, minutes=5, + seconds=6.789) assert d.months == 14 assert d.days == 3 assert d.seconds == 14706 @@ -178,8 +181,9 @@ def test_full_positive(self): assert d.hours_minutes_seconds_nanoseconds == (4, 5, 6, 789000000) assert bool(d) - def test_full_negative(self): - d = Duration(years=-1, months=-2, days=-3, hours=-4, minutes=-5, seconds=-6.789) + def test_full_negative(self) -> None: + d = Duration(years=-1, months=-2, days=-3, hours=-4, minutes=-5, + seconds=-6.789) assert d.months == -14 assert d.days == -3 assert d.seconds == -14706 @@ -188,8 +192,9 @@ def test_full_negative(self): assert d.hours_minutes_seconds_nanoseconds == (-4, -5, -6, -789000000) assert bool(d) - def test_negative_positive(self): - d = Duration(years=-1, months=-2, days=3, hours=-4, minutes=-5, seconds=-6.789) + def test_negative_positive(self) -> None: + d = Duration(years=-1, months=-2, days=3, hours=-4, minutes=-5, + seconds=-6.789) assert d.months == -14 assert d.days == 3 assert d.seconds == -14706 @@ -197,8 +202,9 @@ def test_negative_positive(self): assert d.years_months_days == (-1, -2, 3) assert d.hours_minutes_seconds_nanoseconds == (-4, -5, -6, -789000000) - def test_positive_negative(self): - d = Duration(years=1, months=2, days=-3, hours=4, minutes=5, seconds=6.789) + def test_positive_negative(self) -> None: + d = Duration(years=1, months=2, days=-3, hours=4, minutes=5, + seconds=6.789) assert d.months == 14 assert d.days == -3 assert d.seconds == 14706 @@ -206,84 +212,90 @@ def test_positive_negative(self): assert d.years_months_days == (1, 2, -3) assert d.hours_minutes_seconds_nanoseconds == (4, 5, 6, 789000000) - def test_add_duration(self): + def test_add_duration(self) -> None: d1 = Duration(months=2, days=3, seconds=5, nanoseconds=700000000) d2 = Duration(months=7, days=5, seconds=3, nanoseconds=200000000) d3 = Duration(months=9, days=8, seconds=8, nanoseconds=900000000) assert d1 + d2 == d3 - def test_add_timedelta(self): + def test_add_timedelta(self) -> None: d1 = Duration(months=2, days=3, seconds=5, nanoseconds=700000000) td = timedelta(days=5, seconds=3.2) d3 = Duration(months=2, days=8, seconds=8, nanoseconds=900000000) assert d1 + td == d3 - def test_add_object(self): + def test_add_object(self) -> None: with pytest.raises(TypeError): - _ = Duration(months=2, days=3, seconds=5.7) + object() + _ = (Duration(months=2, days=3, seconds=5.7) + + object()) # type: ignore[operator] - def test_subtract_duration(self): + def test_subtract_duration(self) -> None: d1 = Duration(months=2, days=3, seconds=5, nanoseconds=700000000) d2 = Duration(months=7, days=5, seconds=3, nanoseconds=200000000) d3 = Duration(months=-5, days=-2, seconds=2, nanoseconds=500000000) assert d1 - d2 == d3 - def test_subtract_timedelta(self): + def test_subtract_timedelta(self) -> None: d1 = Duration(months=2, days=3, seconds=5, nanoseconds=700000000) td = timedelta(days=5, seconds=3.2) d3 = Duration(months=2, days=-2, seconds=2, nanoseconds=500000000) assert d1 - td == d3 - def test_subtract_object(self): + def test_subtract_object(self) -> None: with pytest.raises(TypeError): - _ = Duration(months=2, days=3, seconds=5.7) - object() + _ = (Duration(months=2, days=3, seconds=5.7) + - object()) # type: ignore[operator] - def test_multiplication_by_int(self): + def test_multiplication_by_int(self) -> None: d1 = Duration(months=2, days=3, seconds=5.7) i = 11 assert d1 * i == Duration(months=22, days=33, seconds=62.7) - def test_multiplication_by_float(self): + def test_multiplication_by_float(self) -> None: d1 = Duration(months=2, days=3, seconds=5.7) f = 5.5 assert d1 * f == Duration(months=11, days=16, seconds=31.35) - def test_multiplication_by_object(self): + def test_multiplication_by_object(self) -> None: with pytest.raises(TypeError): - _ = Duration(months=2, days=3, seconds=5.7) * object() + _ = (Duration(months=2, days=3, seconds=5.7) + * object()) # type: ignore[operator] @pytest.mark.parametrize("ns", (0, 1)) - def test_floor_division_by_int(self, ns): + def test_floor_division_by_int(self, ns) -> None: d1 = Duration(months=11, days=33, seconds=55.77, nanoseconds=ns) i = 2 assert d1 // i == Duration(months=5, days=16, seconds=27, nanoseconds=885000000) - def test_floor_division_by_object(self): + def test_floor_division_by_object(self) -> None: with pytest.raises(TypeError): - _ = Duration(months=2, days=3, seconds=5.7) // object() + _ = (Duration(months=2, days=3, seconds=5.7) + // object()) # type: ignore[operator] @pytest.mark.parametrize("ns", (0, 1)) - def test_modulus_by_int(self, ns): + def test_modulus_by_int(self, ns) -> None: d1 = Duration(months=11, days=33, seconds=55.77, nanoseconds=ns) i = 2 assert d1 % i == Duration(months=1, days=1, nanoseconds=ns) - def test_modulus_by_object(self): + def test_modulus_by_object(self) -> None: with pytest.raises(TypeError): - _ = Duration(months=2, days=3, seconds=5.7) % object() + _ = (Duration(months=2, days=3, seconds=5.7) + % object()) # type: ignore[operator] @pytest.mark.parametrize("ns", (0, 1)) - def test_floor_division_and_modulus_by_int(self, ns): + def test_floor_division_and_modulus_by_int(self, ns) -> None: d1 = Duration(months=11, days=33, seconds=55.77, nanoseconds=ns) i = 2 assert divmod(d1, i) == (Duration(months=5, days=16, seconds=27, nanoseconds=885000000), Duration(months=1, days=1, nanoseconds=ns)) - def test_floor_division_and_modulus_by_object(self): + def test_floor_division_and_modulus_by_object(self) -> None: with pytest.raises(TypeError): - _ = divmod(Duration(months=2, days=3, seconds=5.7), object()) + _ = divmod(Duration(months=2, days=3, seconds=5.7), + object()) # type: ignore[operator] @pytest.mark.parametrize( ("year", "month", "day"), @@ -297,39 +309,40 @@ def test_floor_division_and_modulus_by_object(self): @pytest.mark.parametrize("ns", range(4)) @pytest.mark.parametrize("divisor", (*range(1, 3), 1_000_000_000)) def test_div_mod_is_well_defined(self, year, month, day, second, ns, - divisor): + divisor) -> None: d1 = Duration(years=year, months=month, days=day, seconds=second, nanoseconds=ns) fraction, rest = divmod(d1, divisor) assert d1 == fraction * divisor + rest - def test_true_division_by_int(self): + def test_true_division_by_int(self) -> None: d1 = Duration(months=11, days=33, seconds=55.77) i = 2 assert d1 / i == Duration(months=6, days=16, seconds=27.885) - def test_true_division_by_float(self): + def test_true_division_by_float(self) -> None: d1 = Duration(months=11, days=33, seconds=55.77) f = 2.5 assert d1 / f == Duration(months=4, days=13, seconds=22.308) - def test_true_division_by_object(self): + def test_true_division_by_object(self) -> None: with pytest.raises(TypeError): - _ = Duration(months=2, days=3, seconds=5.7) / object() + _ = (Duration(months=2, days=3, seconds=5.7) + / object()) # type: ignore[operator] - def test_unary_plus(self): + def test_unary_plus(self) -> None: d = Duration(months=11, days=33, seconds=55.77) assert +d == Duration(months=11, days=33, seconds=55.77) - def test_unary_minus(self): + def test_unary_minus(self) -> None: d = Duration(months=11, days=33, seconds=55.77) assert -d == Duration(months=-11, days=-33, seconds=-55.77) - def test_absolute(self): + def test_absolute(self) -> None: d = Duration(months=-11, days=-33, seconds=-55.77) assert abs(d) == Duration(months=11, days=33, seconds=55.77) - def test_str(self): + def test_str(self) -> None: assert str(Duration()) == "PT0S" assert str(Duration(years=1, months=2)) == "P1Y2M" assert str(Duration(years=-1, months=2)) == "P-10M" @@ -341,36 +354,39 @@ def test_str(self): assert str(Duration(seconds=-0.123456789)) == "PT-0.123456789S" assert str(Duration(seconds=-2, nanoseconds=1)) == "PT-1.999999999S" - def test_repr(self): + def test_repr(self) -> None: d = Duration(months=2, days=3, seconds=5.7) - assert repr(d) == "Duration(months=2, days=3, seconds=5, nanoseconds=700000000)" + assert repr(d) == ( + "Duration(months=2, days=3, seconds=5, nanoseconds=700000000)" + ) - def test_iso_format(self): + def test_iso_format(self) -> None: assert Duration().iso_format() == "PT0S" assert Duration(years=1, months=2).iso_format() == "P1Y2M" assert Duration(years=-1, months=2).iso_format() == "P-10M" assert Duration(months=-13).iso_format() == "P-1Y-1M" - assert Duration(months=2, days=3, seconds=5.7).iso_format() == "P2M3DT5.7S" + assert (Duration(months=2, days=3, seconds=5.7).iso_format() + == "P2M3DT5.7S") assert Duration(hours=12, minutes=34).iso_format() == "PT12H34M" assert Duration(seconds=59).iso_format() == "PT59S" assert Duration(seconds=0.123456789).iso_format() == "PT0.123456789S" assert Duration(seconds=-0.123456789).iso_format() == "PT-0.123456789S" - def test_copy(self): + def test_copy(self) -> None: d = Duration(years=1, months=2, days=3, hours=4, minutes=5, seconds=6, milliseconds=7, microseconds=8, nanoseconds=9) d2 = copy.copy(d) assert d is not d2 assert d == d2 - def test_deep_copy(self): + def test_deep_copy(self) -> None: d = Duration(years=1, months=2, days=3, hours=4, minutes=5, seconds=6, milliseconds=7, microseconds=8, nanoseconds=9) d2 = copy.deepcopy(d) assert d is not d2 assert d == d2 - def test_from_iso_format(self): + def test_from_iso_format(self) -> None: assert Duration() == Duration.from_iso_format("PT0S") assert Duration( hours=12, minutes=34, seconds=56.789 @@ -393,7 +409,7 @@ def test_from_iso_format(self): @pytest.mark.parametrize("with_day", (True, False)) @pytest.mark.parametrize("with_month", (True, False)) @pytest.mark.parametrize("only_ns", (True, False)) - def test_minimal_value(self, with_day, with_month, only_ns): + def test_minimal_value(self, with_day, with_month, only_ns) -> None: seconds = (time.MIN_INT64 + with_month * time.AVERAGE_SECONDS_IN_MONTH + with_day * time.AVERAGE_SECONDS_IN_DAY) @@ -414,7 +430,7 @@ def test_minimal_value(self, with_day, with_month, only_ns): (-1, 0, 0, 0), )) def test_negative_overflow_value(self, with_day, with_month, only_ns, - overflow): + overflow) -> None: seconds = (time.MIN_INT64 + with_month * time.AVERAGE_SECONDS_IN_MONTH + with_day * time.AVERAGE_SECONDS_IN_DAY) @@ -438,7 +454,7 @@ def test_negative_overflow_value(self, with_day, with_month, only_ns, ("days", time.AVERAGE_SECONDS_IN_DAY), ("months", time.AVERAGE_SECONDS_IN_MONTH), )) - def test_minimal_value_only_secondary_field(self, field, module): + def test_minimal_value_only_secondary_field(self, field, module) -> None: kwargs = { field: (time.MIN_INT64 // module - (time.MIN_INT64 % module == 0) @@ -450,7 +466,9 @@ def test_minimal_value_only_secondary_field(self, field, module): ("days", time.AVERAGE_SECONDS_IN_DAY), ("months", time.AVERAGE_SECONDS_IN_MONTH), )) - def test_negative_overflow_value_only_secondary_field(self, field, module): + def test_negative_overflow_value_only_secondary_field( + self, field, module + ) -> None: kwargs = { field: (time.MIN_INT64 // module - (time.MIN_INT64 % module == 0)) @@ -458,7 +476,7 @@ def test_negative_overflow_value_only_secondary_field(self, field, module): with pytest.raises(ValueError): Duration(**kwargs) - def test_negative_overflow_duration_addition(self): + def test_negative_overflow_duration_addition(self) -> None: min_ = Duration.min ns = Duration(nanoseconds=1) with pytest.raises(ValueError): @@ -468,7 +486,7 @@ def test_negative_overflow_duration_addition(self): @pytest.mark.parametrize("with_day", (True, False)) @pytest.mark.parametrize("with_month", (True, False)) @pytest.mark.parametrize("only_ns", (True, False)) - def test_maximal_value(self, with_day, with_month, only_ns): + def test_maximal_value(self, with_day, with_month, only_ns) -> None: seconds = (time.MAX_INT64 - with_month * time.AVERAGE_SECONDS_IN_MONTH - with_day * time.AVERAGE_SECONDS_IN_DAY) @@ -489,7 +507,7 @@ def test_maximal_value(self, with_day, with_month, only_ns): (1, 0, 0, 0), )) def test_positive_overflow_value(self, with_day, with_month, only_ns, - overflow): + overflow) -> None: seconds = (time.MAX_INT64 - with_month * time.AVERAGE_SECONDS_IN_MONTH - with_day * time.AVERAGE_SECONDS_IN_DAY) @@ -513,7 +531,7 @@ def test_positive_overflow_value(self, with_day, with_month, only_ns, ("days", time.AVERAGE_SECONDS_IN_DAY), ("months", time.AVERAGE_SECONDS_IN_MONTH), )) - def test_maximal_value_only_secondary_field(self, field, module): + def test_maximal_value_only_secondary_field(self, field, module) -> None: kwargs = { field: time.MAX_INT64 // module } @@ -523,14 +541,16 @@ def test_maximal_value_only_secondary_field(self, field, module): ("days", time.AVERAGE_SECONDS_IN_DAY), ("months", time.AVERAGE_SECONDS_IN_MONTH), )) - def test_positive_overflow_value_only_secondary_field(self, field, module): + def test_positive_overflow_value_only_secondary_field( + self, field, module + ) -> None: kwargs = { field: time.MAX_INT64 // module + 1 } with pytest.raises(ValueError): Duration(**kwargs) - def test_positive_overflow_duration_addition(self): + def test_positive_overflow_duration_addition(self) -> None: max_ = Duration.max ns = Duration(nanoseconds=1) with pytest.raises(ValueError): diff --git a/tests/unit/common/time/test_time.py b/tests/unit/common/time/test_time.py index 422a8fa0..e94dd3b7 100644 --- a/tests/unit/common/time/test_time.py +++ b/tests/unit/common/time/test_time.py @@ -16,6 +16,8 @@ # limitations under the License. +from __future__ import annotations + import itertools import operator from datetime import ( @@ -45,12 +47,12 @@ class TestTime: - def test_bad_attribute(self): + def test_bad_attribute(self) -> None: t = Time(12, 34, 56, 789000000) with pytest.raises(AttributeError): _ = t.x - def test_simple_time(self): + def test_simple_time(self) -> None: t = Time(12, 34, 56, 789000000) assert t.hour_minute_second_nanosecond == (12, 34, 56, 789000000) assert t.ticks == 45296789000000 @@ -59,7 +61,7 @@ def test_simple_time(self): assert t.second == 56 assert t.nanosecond == 789000000 - def test_midnight(self): + def test_midnight(self) -> None: t = Time(0, 0, 0) assert t.hour_minute_second_nanosecond == (0, 0, 0, 0) assert t.ticks == 0 @@ -68,7 +70,7 @@ def test_midnight(self): assert t.second == 0 assert t.nanosecond == 0 - def test_nanosecond_precision(self): + def test_nanosecond_precision(self) -> None: t = Time(12, 34, 56, 789123456) assert t.hour_minute_second_nanosecond == (12, 34, 56, 789123456) assert t.ticks == 45296789123456 @@ -77,7 +79,7 @@ def test_nanosecond_precision(self): assert t.second == 56 assert t.nanosecond == 789123456 - def test_str(self): + def test_str(self) -> None: t = Time(12, 34, 56, 789123456) assert str(t) == "12:34:56.789123456" @@ -89,19 +91,19 @@ def test_str(self): (datetime_timezone(timedelta(hours=1)), (13, 34, 56, 789000001)), (timezone_us_eastern, (7, 34, 56, 789000001)), )) - def test_now(self, tz, expected): + def test_now(self, tz, expected) -> None: t = Time.now(tz=tz) assert isinstance(t, Time) assert t.hour_minute_second_nanosecond == expected assert str(t.tzinfo) == str(tz) - def test_utc_now(self): + def test_utc_now(self) -> None: t = Time.utc_now() assert isinstance(t, Time) assert t.hour_minute_second_nanosecond == (12, 34, 56, 789000001) assert t.tzinfo is None - def test_from_native(self): + def test_from_native(self) -> None: native = time(12, 34, 56, 789123) t = Time.from_native(native) assert t.hour == native.hour @@ -109,7 +111,7 @@ def test_from_native(self): assert t.second == native.second assert t.nanosecond == native.microsecond * 1000 - def test_to_native(self): + def test_to_native(self) -> None: t = Time(12, 34, 56, 789123456) native = t.to_native() assert t.hour == native.hour @@ -167,75 +169,72 @@ def test_to_native(self): "12:34:56.123456+01:23" ), )) - def test_iso_format(self, t, expected): + def test_iso_format(self, t, expected) -> None: assert t.isoformat() == expected - def test_from_iso_format_hour_only(self): + def test_from_iso_format_hour_only(self) -> None: expected = Time(12, 0, 0) actual = Time.from_iso_format("12") assert expected == actual - def test_from_iso_format_hour_and_minute(self): + def test_from_iso_format_hour_and_minute(self) -> None: expected = Time(12, 34, 0) actual = Time.from_iso_format("12:34") assert expected == actual - def test_from_iso_format_hour_minute_second(self): + def test_from_iso_format_hour_minute_second(self) -> None: expected = Time(12, 34, 56) actual = Time.from_iso_format("12:34:56") assert expected == actual - def test_from_iso_format_hour_minute_second_milliseconds(self): + def test_from_iso_format_hour_minute_second_milliseconds(self) -> None: expected = Time(12, 34, 56, 123000000) actual = Time.from_iso_format("12:34:56.123") assert expected == actual - def test_from_iso_format_hour_minute_second_microseconds(self): + def test_from_iso_format_hour_minute_second_microseconds(self) -> None: expected = Time(12, 34, 56, 123456000) actual = Time.from_iso_format("12:34:56.123456") assert expected == actual - def test_from_iso_format_hour_minute_second_nanosecond(self): + def test_from_iso_format_hour_minute_second_nanosecond(self) -> None: expected = Time(12, 34, 56, 123456789) actual = Time.from_iso_format("12:34:56.123456789") assert expected == actual - def test_from_iso_format_with_positive_tz(self): + def test_from_iso_format_with_positive_tz(self) -> None: expected = Time(12, 34, 56, 123456789, tzinfo=FixedOffset(754)) actual = Time.from_iso_format("12:34:56.123456789+12:34") assert expected == actual - def test_from_iso_format_with_negative_tz(self): + def test_from_iso_format_with_negative_tz(self) -> None: expected = Time(12, 34, 56, 123456789, tzinfo=FixedOffset(-754)) actual = Time.from_iso_format("12:34:56.123456789-12:34") assert expected == actual - def test_from_iso_format_with_positive_long_tz(self): + def test_from_iso_format_with_positive_long_tz(self) -> None: expected = Time(12, 34, 56, 123456789, tzinfo=FixedOffset(754)) actual = Time.from_iso_format("12:34:56.123456789+12:34:56.123456") assert expected == actual - def test_from_iso_format_with_negative_long_tz(self): + def test_from_iso_format_with_negative_long_tz(self) -> None: expected = Time(12, 34, 56, 123456789, tzinfo=FixedOffset(-754)) actual = Time.from_iso_format("12:34:56.123456789-12:34:56.123456") assert expected == actual - def test_from_iso_format_with_hour_only_tz(self): + def test_from_iso_format_with_hour_only_tz(self) -> None: expected = Time(12, 34, 56, 123456789, tzinfo=FixedOffset(120)) actual = Time.from_iso_format("12:34:56.123456789+02:00") assert expected == actual - def test_utc_offset_fixed(self): - expected = Time(12, 34, 56, 123456789, tzinfo=FixedOffset(-754)) - actual = -754 * 60 - assert expected.utc_offset().total_seconds() == actual - - def test_utc_offset_variable(self): + def test_utc_offset_fixed(self) -> None: expected = Time(12, 34, 56, 123456789, tzinfo=FixedOffset(-754)) actual = -754 * 60 - assert expected.utc_offset().total_seconds() == actual + offset = expected.utc_offset() + assert offset is not None + assert offset.total_seconds() == actual - def test_iso_format_with_time_zone_case_1(self): + def test_iso_format_with_time_zone_case_1(self) -> None: # python -m pytest tests/unit/time/test_time.py -s -v -k test_iso_format_with_time_zone_case_1 expected = Time(7, 54, 2, 129790999, tzinfo=timezone_utc) assert expected.iso_format() == "07:54:02.129790999+00:00" @@ -243,13 +242,13 @@ def test_iso_format_with_time_zone_case_1(self): actual = Time.from_iso_format("07:54:02.129790999+00:00") assert expected == actual - def test_iso_format_with_time_zone_case_2(self): + def test_iso_format_with_time_zone_case_2(self) -> None: # python -m pytest tests/unit/time/test_time.py -s -v -k test_iso_format_with_time_zone_case_2 expected = Time.from_iso_format("07:54:02.129790999+01:00") assert expected.tzinfo == FixedOffset(60) assert expected.iso_format() == "07:54:02.129790999+01:00" - def test_to_native_case_1(self): + def test_to_native_case_1(self) -> None: # python -m pytest tests/unit/time/test_time.py -s -v -k test_to_native_case_1 t = Time(12, 34, 56, 789123456) native = t.to_native() @@ -260,7 +259,7 @@ def test_to_native_case_1(self): assert native.tzinfo is None assert native.isoformat() == "12:34:56.789123" - def test_to_native_case_2(self): + def test_to_native_case_2(self) -> None: # python -m pytest tests/unit/time/test_time.py -s -v -k test_to_native_case_2 t = Time(12, 34, 56, 789123456, tzinfo=timezone_utc) native = t.to_native() @@ -271,7 +270,7 @@ def test_to_native_case_2(self): assert native.tzinfo == FixedOffset(0) assert native.isoformat() == "12:34:56.789123+00:00" - def test_from_native_case_1(self): + def test_from_native_case_1(self) -> None: # python -m pytest tests/unit/time/test_time.py -s -v -k test_from_native_case_1 native = time(12, 34, 56, 789123) t = Time.from_native(native) @@ -281,7 +280,7 @@ def test_from_native_case_1(self): assert t.nanosecond == native.microsecond * 1000 assert t.tzinfo is None - def test_from_native_case_2(self): + def test_from_native_case_2(self) -> None: # python -m pytest tests/unit/time/test_time.py -s -v -k test_from_native_case_2 native = time(12, 34, 56, 789123, FixedOffset(0)) t = Time.from_native(native) @@ -329,7 +328,7 @@ def test_from_native_case_2(self): Time(12, 34, 56, 789123123, FixedOffset(0)) ), )) - def test_equality(self, t1, t2): + def test_equality(self, t1, t2) -> None: assert t1 == t2 assert t2 == t1 assert t1 <= t2 @@ -380,7 +379,7 @@ def test_equality(self, t1, t2): Time(12, 34, 56, 789123456, FixedOffset(0)) ), )) - def test_inequality(self, t1, t2): + def test_inequality(self, t1, t2) -> None: assert t1 != t2 assert t2 != t1 @@ -406,7 +405,7 @@ def test_inequality(self, t1, t2): repeat=2 ) ) - def test_hashed_equality(self, t1, t2): + def test_hashed_equality(self, t1, t2) -> None: if t1 == t2: s = {t1} assert t1 in s @@ -438,7 +437,9 @@ def test_hashed_equality(self, t1, t2): @pytest.mark.parametrize("op", ( operator.lt, operator.le, operator.gt, operator.ge, )) - def test_comparison_with_only_one_naive_fails(self, t1, t2, tz, op): + def test_comparison_with_only_one_naive_fails( + self, t1, t2, tz, op + ) -> None: t1 = t1.replace(tzinfo=tz) with pytest.raises(TypeError, match="naive"): op(t1, t2) @@ -460,7 +461,9 @@ def test_comparison_with_only_one_naive_fails(self, t1, t2, tz, op): @pytest.mark.parametrize("op", ( operator.lt, operator.le, operator.gt, operator.ge, )) - def test_comparison_with_one_naive_and_not_fixed_tz(self, t1, t2, tz, op): + def test_comparison_with_one_naive_and_not_fixed_tz( + self, t1, t2, tz, op + ) -> None: t1tz = t1.replace(tzinfo=tz) res = op(t1tz, t2) expected = op(t1, t2) @@ -505,7 +508,7 @@ def test_comparison_with_one_naive_and_not_fixed_tz(self, t1, t2, tz, op): Time(12, 34, 56, 789123001, FixedOffset(-1)), ), )) - def test_comparison(self, t1, t2): + def test_comparison(self, t1, t2) -> None: assert t1 < t2 assert not t2 < t1 assert t1 <= t2 diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index c9eb9165..cc8434e7 100644 --- a/tests/unit/sync/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -829,12 +829,12 @@ def test_to_df(keys, values, types, instances, test_default_expand): ( ["n"], list(zip(( - Node(None, "00", 0, ["LABEL_A"], - {"a": 1, "b": 2, "d": 1}), - Node(None, "02", 2, ["LABEL_B"], - {"a": 1, "c": 1.2, "d": 2}), - Node(None, "01", 1, ["LABEL_A", "LABEL_B"], - {"a": [1, "a"], "d": 3}), + Node(None, # type: ignore[arg-type] + "00", 0, ["LABEL_A"], {"a": 1, "b": 2, "d": 1}), + Node(None, # type: ignore[arg-type] + "02", 2, ["LABEL_B"], {"a": 1, "c": 1.2, "d": 2}), + Node(None, # type: ignore[arg-type] + "01", 1, ["LABEL_A", "LABEL_B"], {"a": [1, "a"], "d": 3}), ))), [ "n().element_id", "n().labels", "n().prop.a", "n().prop.b", @@ -1009,8 +1009,8 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["mixed"], [ - [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),], - [neo4j_time.Date(2222, 2, 22),], + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], + [neo4j_time.Date(2222, 2, 22)], [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) @@ -1105,7 +1105,7 @@ def test_broken_hydration(nested): assert len(record_out) == 2 assert record_out[0] == "foobar" with pytest.raises(BrokenRecordError) as exc: - record_out[1] + _ = record_out[1] cause = exc.value.__cause__ assert isinstance(cause, ValueError) assert repr(b"a") in str(cause) From a0f05739e2d4aec898cc709be36d915fd6d3a4a2 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 1 Aug 2022 10:22:24 +0200 Subject: [PATCH 2/5] Fix typing_extensions imports and docs --- docs/source/api.rst | 62 +++++++++++---------- neo4j/_async/work/result.py | 6 +- neo4j/_async_compat/network/_bolt_socket.py | 4 +- neo4j/_data.py | 5 +- neo4j/_sync/driver.py | 1 - neo4j/_sync/work/result.py | 6 +- neo4j/addressing.py | 10 +++- neo4j/exceptions.py | 44 +++++++-------- neo4j/time/__init__.py | 11 ++-- requirements-dev.txt | 1 + tests/unit/common/test_debug.py | 13 +++-- 11 files changed, 90 insertions(+), 73 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 76b354e8..5c3fde6f 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1263,43 +1263,44 @@ Server-side errors * :class:`neo4j.exceptions.ForbiddenOnReadOnlyDatabase` -.. autoclass:: neo4j.exceptions.Neo4jError +.. autoexception:: neo4j.exceptions.Neo4jError() + :show-inheritance: :members: message, code, is_retriable, is_retryable -.. autoclass:: neo4j.exceptions.ClientError +.. autoexception:: neo4j.exceptions.ClientError() :show-inheritance: -.. autoclass:: neo4j.exceptions.CypherSyntaxError +.. autoexception:: neo4j.exceptions.CypherSyntaxError() :show-inheritance: -.. autoclass:: neo4j.exceptions.CypherTypeError +.. autoexception:: neo4j.exceptions.CypherTypeError() :show-inheritance: -.. autoclass:: neo4j.exceptions.ConstraintError +.. autoexception:: neo4j.exceptions.ConstraintError() :show-inheritance: -.. autoclass:: neo4j.exceptions.AuthError +.. autoexception:: neo4j.exceptions.AuthError() :show-inheritance: -.. autoclass:: neo4j.exceptions.TokenExpired +.. autoexception:: neo4j.exceptions.TokenExpired() :show-inheritance: -.. autoclass:: neo4j.exceptions.Forbidden +.. autoexception:: neo4j.exceptions.Forbidden() :show-inheritance: -.. autoclass:: neo4j.exceptions.DatabaseError +.. autoexception:: neo4j.exceptions.DatabaseError() :show-inheritance: -.. autoclass:: neo4j.exceptions.TransientError +.. autoexception:: neo4j.exceptions.TransientError() :show-inheritance: -.. autoclass:: neo4j.exceptions.DatabaseUnavailable +.. autoexception:: neo4j.exceptions.DatabaseUnavailable() :show-inheritance: -.. autoclass:: neo4j.exceptions.NotALeader +.. autoexception:: neo4j.exceptions.NotALeader() :show-inheritance: -.. autoclass:: neo4j.exceptions.ForbiddenOnReadOnlyDatabase +.. autoexception:: neo4j.exceptions.ForbiddenOnReadOnlyDatabase() :show-inheritance: @@ -1344,52 +1345,55 @@ Client-side errors * :class:`neo4j.exceptions.CertificateConfigurationError` -.. autoclass:: neo4j.exceptions.DriverError +.. autoexception:: neo4j.exceptions.DriverError() + :show-inheritance: :members: is_retryable -.. autoclass:: neo4j.exceptions.TransactionError +.. autoexception:: neo4j.exceptions.TransactionError() :show-inheritance: + :members: transaction -.. autoclass:: neo4j.exceptions.TransactionNestingError +.. autoexception:: neo4j.exceptions.TransactionNestingError() :show-inheritance: -.. autoclass:: neo4j.exceptions.ResultError +.. autoexception:: neo4j.exceptions.ResultError() :show-inheritance: + :members: result -.. autoclass:: neo4j.exceptions.ResultConsumedError +.. autoexception:: neo4j.exceptions.ResultConsumedError() :show-inheritance: -.. autoclass:: neo4j.exceptions.ResultNotSingleError +.. autoexception:: neo4j.exceptions.ResultNotSingleError() :show-inheritance: -.. autoclass:: neo4j.exceptions.BrokenRecordError +.. autoexception:: neo4j.exceptions.BrokenRecordError() :show-inheritance: -.. autoclass:: neo4j.exceptions.SessionExpired +.. autoexception:: neo4j.exceptions.SessionExpired() :show-inheritance: -.. autoclass:: neo4j.exceptions.ServiceUnavailable +.. autoexception:: neo4j.exceptions.ServiceUnavailable() :show-inheritance: -.. autoclass:: neo4j.exceptions.RoutingServiceUnavailable +.. autoexception:: neo4j.exceptions.RoutingServiceUnavailable() :show-inheritance: -.. autoclass:: neo4j.exceptions.WriteServiceUnavailable +.. autoexception:: neo4j.exceptions.WriteServiceUnavailable() :show-inheritance: -.. autoclass:: neo4j.exceptions.ReadServiceUnavailable +.. autoexception:: neo4j.exceptions.ReadServiceUnavailable() :show-inheritance: -.. autoclass:: neo4j.exceptions.IncompleteCommit +.. autoexception:: neo4j.exceptions.IncompleteCommit() :show-inheritance: -.. autoclass:: neo4j.exceptions.ConfigurationError +.. autoexception:: neo4j.exceptions.ConfigurationError() :show-inheritance: -.. autoclass:: neo4j.exceptions.AuthConfigurationError +.. autoexception:: neo4j.exceptions.AuthConfigurationError() :show-inheritance: -.. autoclass:: neo4j.exceptions.CertificateConfigurationError +.. autoexception:: neo4j.exceptions.CertificateConfigurationError() :show-inheritance: diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py index 9ab443a3..5e870585 100644 --- a/neo4j/_async/work/result.py +++ b/neo4j/_async/work/result.py @@ -22,7 +22,9 @@ from collections import deque from warnings import warn -import typing_extensions as te + +if t.TYPE_CHECKING: + import typing_extensions as te from ..._async_compat.util import AsyncUtil from ..._codec.hydration import BrokenHydrationObject @@ -50,7 +52,7 @@ _T = t.TypeVar("_T") -_T_ResultKey: te.TypeAlias = t.Union[int, str] +_T_ResultKey = t.Union[int, str] _RESULT_OUT_OF_SCOPE_ERROR = ( diff --git a/neo4j/_async_compat/network/_bolt_socket.py b/neo4j/_async_compat/network/_bolt_socket.py index d7e274e2..81d05e44 100644 --- a/neo4j/_async_compat/network/_bolt_socket.py +++ b/neo4j/_async_compat/network/_bolt_socket.py @@ -39,7 +39,9 @@ SSLSocket, ) -import typing_extensions as te + +if t.TYPE_CHECKING: + import typing_extensions as te from ... import addressing from ..._deadline import Deadline diff --git a/neo4j/_data.py b/neo4j/_data.py index 6d08ddad..fe679ce0 100644 --- a/neo4j/_data.py +++ b/neo4j/_data.py @@ -42,9 +42,8 @@ ) -if t.TYPE_CHECKING: - _T = t.TypeVar("_T") - _T_K = t.Union[int, str] +_T = t.TypeVar("_T") +_T_K = t.Union[int, str] class Record(tuple, Mapping): diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index de685dc5..8da43a14 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -96,7 +96,6 @@ def driver( keep_alive: bool = ..., # undocumented/unsupported options - # might be removed/changed without warning, even in patch versions session_connection_timeout: float = ..., connection_acquisition_timeout: float = ..., max_transaction_retry_time: float = ..., diff --git a/neo4j/_sync/work/result.py b/neo4j/_sync/work/result.py index 450262ca..f1b7cc90 100644 --- a/neo4j/_sync/work/result.py +++ b/neo4j/_sync/work/result.py @@ -22,7 +22,9 @@ from collections import deque from warnings import warn -import typing_extensions as te + +if t.TYPE_CHECKING: + import typing_extensions as te from ..._async_compat.util import Util from ..._codec.hydration import BrokenHydrationObject @@ -50,7 +52,7 @@ _T = t.TypeVar("_T") -_T_ResultKey: te.TypeAlias = t.Union[int, str] +_T_ResultKey = t.Union[int, str] _RESULT_OUT_OF_SCOPE_ERROR = ( diff --git a/neo4j/addressing.py b/neo4j/addressing.py index 0f0e2d22..6cebf798 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -27,7 +27,9 @@ getservbyname, ) -import typing_extensions as te + +if t.TYPE_CHECKING: + import typing_extensions as te log = logging.getLogger("neo4j") @@ -36,8 +38,10 @@ _T = t.TypeVar("_T") -class _WithPeerName(te.Protocol): - def getpeername(self) -> tuple: ... +if t.TYPE_CHECKING: + + class _WithPeerName(te.Protocol): + def getpeername(self) -> tuple: ... assert type(tuple) is type diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index 7c72af69..81d64332 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -76,19 +76,28 @@ import typing_extensions as te from ._async.work import ( + AsyncManagedTransaction, AsyncResult, AsyncSession, - AsyncTransactionBase, + AsyncTransaction, ) from ._sync.work import ( + ManagedTransaction, Result, Session, - TransactionBase, + Transaction, ) - _T_Transaction = t.Union[AsyncTransactionBase, TransactionBase] + _T_Transaction = t.Union[AsyncManagedTransaction, AsyncTransaction, + ManagedTransaction, Transaction] _T_Result = t.Union[AsyncResult, Result] _T_Session = t.Union[AsyncSession, Session] +else: + _T_Transaction = t.Union["AsyncManagedTransaction", "AsyncTransaction", + "ManagedTransaction", "Transaction"] + _T_Result = t.Union["AsyncResult", "Result"] + _T_Session = t.Union["AsyncSession", "Session"] + from ._meta import deprecated @@ -384,12 +393,11 @@ class TransactionError(DriverError): """ Raised when an error occurs while using a transaction. """ - def __init__( - self, transaction: _T_Transaction, - *args, **kwargs - ) -> None: + transaction: _T_Transaction + + def __init__(self, transaction_, *args, **kwargs): super().__init__(*args, **kwargs) - self.transaction = transaction + self.transaction = transaction_ # DriverError > TransactionNestingError @@ -397,23 +405,16 @@ class TransactionNestingError(TransactionError): """ Raised when transactions are nested incorrectly. """ - def __init__( - self, transaction: _T_Transaction, - *args, **kwargs - ) -> None: - super().__init__(*args, **kwargs) - self.transaction = transaction - # DriverError > ResultError class ResultError(DriverError): """Raised when an error occurs while using a result object.""" - def __init__( - self, result: _T_Result, *args, **kwargs - ) -> None: + result: _T_Result + + def __init__(self, result_, *args, **kwargs): super().__init__(*args, **kwargs) - self.result = result + self.result = result_ # DriverError > ResultError > ResultConsumedError @@ -441,11 +442,6 @@ class SessionExpired(DriverError): the purpose described by its original parameters. """ - def __init__( - self, session: _T_Session, *args, **kwargs - ) -> None: - super().__init__(session, *args, **kwargs) - def is_retryable(self) -> bool: return True diff --git a/neo4j/time/__init__.py b/neo4j/time/__init__.py index 4dea3655..848bce32 100644 --- a/neo4j/time/__init__.py +++ b/neo4j/time/__init__.py @@ -41,7 +41,9 @@ struct_time, ) -import typing_extensions as te + +if t.TYPE_CHECKING: + import typing_extensions as te from ._arithmetic import ( nano_add, @@ -1218,7 +1220,6 @@ def __deepcopy__(self, *args, **kwargs) -> Date: # INSTANCE METHODS # - if t.TYPE_CHECKING: def replace( @@ -1567,6 +1568,7 @@ def __normalize_nanosecond(cls, hour, minute, second, nanosecond): # CLASS METHOD ALIASES # if t.TYPE_CHECKING: + @classmethod def from_iso_format(cls, s: str) -> Time: ... @@ -1895,6 +1897,7 @@ def __getattr__(self, name): raise AttributeError("Date has no attribute %r" % name) if t.TYPE_CHECKING: + def isoformat(self) -> str: # type: ignore[override] ... @@ -2130,6 +2133,7 @@ def from_clock_time( # CLASS METHOD ALIASES # if t.TYPE_CHECKING: + @classmethod def fromisoformat(cls, s) -> DateTime: ... @@ -2162,8 +2166,6 @@ def utcfromtimestamp(cls, timestamp: float) -> DateTime: def utcnow(cls) -> DateTime: ... - - # CLASS ATTRIBUTES # min: te.Final[DateTime] = None # type: ignore @@ -2665,6 +2667,7 @@ def __getattr__(self, name): raise AttributeError("DateTime has no attribute %r" % name) if t.TYPE_CHECKING: + def astimezone( # type: ignore[override] self, tz: _tzinfo ) -> DateTime: diff --git a/requirements-dev.txt b/requirements-dev.txt index 3d4596f5..b9845fe5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,6 +3,7 @@ unasync>=0.5.0 pre-commit>=2.15.0 isort>=5.10.0 mypy>=0.971 +typing-extensions>=4.3.0 types-pytz>=2022.1.2 # needed for running tests diff --git a/tests/unit/common/test_debug.py b/tests/unit/common/test_debug.py index ece76885..b5350436 100644 --- a/tests/unit/common/test_debug.py +++ b/tests/unit/common/test_debug.py @@ -24,14 +24,19 @@ import typing as t import pytest -import typing_extensions as te + + +if t.TYPE_CHECKING: + import typing_extensions as te from neo4j import debug as neo4j_debug -class _TSetupMockProtocol(te.Protocol): - def __call__(self, *args: str) -> t.Sequence[t.Any]: - ... +if t.TYPE_CHECKING: + + class _TSetupMockProtocol(te.Protocol): + def __call__(self, *args: str) -> t.Sequence[t.Any]: + ... @pytest.fixture From a2b0ef051806e05cb2b2cf5c63a31e318dd51c69 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 1 Aug 2022 12:32:21 +0200 Subject: [PATCH 3/5] Add type hints for `verify_connectivity` and `get_server_info` --- neo4j/_async/driver.py | 186 ++++++++++++++++++++++++++--------------- neo4j/_sync/driver.py | 186 ++++++++++++++++++++++++++--------------- 2 files changed, 236 insertions(+), 136 deletions(-) diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index 377c506d..0130ccb6 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -96,6 +96,7 @@ def driver( keep_alive: bool = ..., # undocumented/unsupported options + # they may be change or removed any time without prior notice session_connection_timeout: float = ..., connection_acquisition_timeout: float = ..., max_transaction_retry_time: float = ..., @@ -347,14 +348,17 @@ def session( session_connection_timeout: float = ..., connection_acquisition_timeout: float = ..., max_transaction_retry_time: float = ..., - initial_retry_delay: float = ..., - retry_delay_multiplier: float = ..., - retry_delay_jitter_factor: float = ..., database: t.Optional[str] = ..., fetch_size: int = ..., impersonated_user: t.Optional[str] = ..., bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., default_access_mode: str = ..., + + # undocumented/unsupported options + # they may be change or removed any time without prior notice + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., ) -> AsyncSession: ... @@ -377,83 +381,129 @@ async def close(self) -> None: await self._pool.close() self._closed = True - # TODO: 6.0 - remove config argument - async def verify_connectivity(self, **config) -> None: - """Verify that the driver can establish a connection to the server. + if t.TYPE_CHECKING: - This verifies if the driver can establish a reading connection to a - remote server or a cluster. Some data will be exchanged. + async def verify_connectivity( + self, + # all arguments are experimental + # they may be change or removed any time without prior notice + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., - .. note:: - Even if this method raises an exception, the driver still needs to - be closed via :meth:`close` to free up all resources. + # undocumented/unsupported options + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., + ) -> None: + ... - :param config: accepts the same configuration key-word arguments as - :meth:`session`. + else: - .. warning:: - All configuration key-word arguments are experimental. - They might be changed or removed in any future version without - prior notice. + # TODO: 6.0 - remove config argument + async def verify_connectivity(self, **config) -> None: + """Verify that the driver can establish a connection to the server. - :raises DriverError: if the driver cannot connect to the remote. - Use the exception to further understand the cause of the - connectivity problem. + This verifies if the driver can establish a reading connection to a + remote server or a cluster. Some data will be exchanged. - .. versionchanged:: 5.0 - The undocumented return value has been removed. - If you need information about the remote server, use - :meth:`get_server_info` instead. - """ - if config: - experimental_warn( - "All configuration key-word arguments to " - "verify_connectivity() are experimental. They might be " - "changed or removed in any future version without prior " - "notice." - ) - async with self.session(**config) as session: - await session._get_server_info() - - async def get_server_info(self, **config) -> ServerInfo: - """Get information about the connected Neo4j server. - - Try to establish a working read connection to the remote server or a - member of a cluster and exchange some data. Then return the contacted - server's information. - - In a cluster, there is no guarantee about which server will be - contacted. + .. note:: + Even if this method raises an exception, the driver still needs + to be closed via :meth:`close` to free up all resources. - .. note:: - Even if this method raises an exception, the driver still needs to - be closed via :meth:`close` to free up all resources. + :param config: accepts the same configuration key-word arguments as + :meth:`session`. + + .. warning:: + All configuration key-word arguments are experimental. + They might be changed or removed in any future version + without prior notice. + + :raises DriverError: if the driver cannot connect to the remote. + Use the exception to further understand the cause of the + connectivity problem. - :param config: accepts the same configuration key-word arguments as - :meth:`session`. + .. versionchanged:: 5.0 + The undocumented return value has been removed. + If you need information about the remote server, use + :meth:`get_server_info` instead. + """ + if config: + experimental_warn( + "All configuration key-word arguments to " + "verify_connectivity() are experimental. They might be " + "changed or removed in any future version without prior " + "notice." + ) + async with self.session(**config) as session: + await session._get_server_info() - .. warning:: - All configuration key-word arguments are experimental. - They might be changed or removed in any future version without - prior notice. + if t.TYPE_CHECKING: - :rtype: ServerInfo + async def get_server_info( + self, + # all arguments are experimental + # they may be change or removed any time without prior notice + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., - :raises DriverError: if the driver cannot connect to the remote. - Use the exception to further understand the cause of the - connectivity problem. + # undocumented/unsupported options + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., + ) -> ServerInfo: + ... - .. versionadded:: 5.0 - """ - if config: - experimental_warn( - "All configuration key-word arguments to " - "verify_connectivity() are experimental. They might be " - "changed or removed in any future version without prior " - "notice." - ) - async with self.session(**config) as session: - return await session._get_server_info() + else: + + async def get_server_info(self, **config) -> ServerInfo: + """Get information about the connected Neo4j server. + + Try to establish a working read connection to the remote server or + a member of a cluster and exchange some data. Then return the + contacted server's information. + + In a cluster, there is no guarantee about which server will be + contacted. + + .. note:: + Even if this method raises an exception, the driver still needs + to be closed via :meth:`close` to free up all resources. + + :param config: accepts the same configuration key-word arguments as + :meth:`session`. + + .. warning:: + All configuration key-word arguments are experimental. + They might be changed or removed in any future version + without prior notice. + + :raises DriverError: if the driver cannot connect to the remote. + Use the exception to further understand the cause of the + connectivity problem. + + .. versionadded:: 5.0 + """ + if config: + experimental_warn( + "All configuration key-word arguments to " + "verify_connectivity() are experimental. They might be " + "changed or removed in any future version without prior " + "notice." + ) + async with self.session(**config) as session: + return await session._get_server_info() @experimental("Feature support query, based on Bolt protocol version and Neo4j server version will change in the future.") async def supports_multi_db(self) -> bool: diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index 8da43a14..40414aa3 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -96,6 +96,7 @@ def driver( keep_alive: bool = ..., # undocumented/unsupported options + # they may be change or removed any time without prior notice session_connection_timeout: float = ..., connection_acquisition_timeout: float = ..., max_transaction_retry_time: float = ..., @@ -347,14 +348,17 @@ def session( session_connection_timeout: float = ..., connection_acquisition_timeout: float = ..., max_transaction_retry_time: float = ..., - initial_retry_delay: float = ..., - retry_delay_multiplier: float = ..., - retry_delay_jitter_factor: float = ..., database: t.Optional[str] = ..., fetch_size: int = ..., impersonated_user: t.Optional[str] = ..., bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., default_access_mode: str = ..., + + # undocumented/unsupported options + # they may be change or removed any time without prior notice + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., ) -> Session: ... @@ -377,83 +381,129 @@ def close(self) -> None: self._pool.close() self._closed = True - # TODO: 6.0 - remove config argument - def verify_connectivity(self, **config) -> None: - """Verify that the driver can establish a connection to the server. + if t.TYPE_CHECKING: - This verifies if the driver can establish a reading connection to a - remote server or a cluster. Some data will be exchanged. + def verify_connectivity( + self, + # all arguments are experimental + # they may be change or removed any time without prior notice + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., - .. note:: - Even if this method raises an exception, the driver still needs to - be closed via :meth:`close` to free up all resources. + # undocumented/unsupported options + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., + ) -> None: + ... - :param config: accepts the same configuration key-word arguments as - :meth:`session`. + else: - .. warning:: - All configuration key-word arguments are experimental. - They might be changed or removed in any future version without - prior notice. + # TODO: 6.0 - remove config argument + def verify_connectivity(self, **config) -> None: + """Verify that the driver can establish a connection to the server. - :raises DriverError: if the driver cannot connect to the remote. - Use the exception to further understand the cause of the - connectivity problem. + This verifies if the driver can establish a reading connection to a + remote server or a cluster. Some data will be exchanged. - .. versionchanged:: 5.0 - The undocumented return value has been removed. - If you need information about the remote server, use - :meth:`get_server_info` instead. - """ - if config: - experimental_warn( - "All configuration key-word arguments to " - "verify_connectivity() are experimental. They might be " - "changed or removed in any future version without prior " - "notice." - ) - with self.session(**config) as session: - session._get_server_info() - - def get_server_info(self, **config) -> ServerInfo: - """Get information about the connected Neo4j server. - - Try to establish a working read connection to the remote server or a - member of a cluster and exchange some data. Then return the contacted - server's information. - - In a cluster, there is no guarantee about which server will be - contacted. + .. note:: + Even if this method raises an exception, the driver still needs + to be closed via :meth:`close` to free up all resources. - .. note:: - Even if this method raises an exception, the driver still needs to - be closed via :meth:`close` to free up all resources. + :param config: accepts the same configuration key-word arguments as + :meth:`session`. + + .. warning:: + All configuration key-word arguments are experimental. + They might be changed or removed in any future version + without prior notice. + + :raises DriverError: if the driver cannot connect to the remote. + Use the exception to further understand the cause of the + connectivity problem. - :param config: accepts the same configuration key-word arguments as - :meth:`session`. + .. versionchanged:: 5.0 + The undocumented return value has been removed. + If you need information about the remote server, use + :meth:`get_server_info` instead. + """ + if config: + experimental_warn( + "All configuration key-word arguments to " + "verify_connectivity() are experimental. They might be " + "changed or removed in any future version without prior " + "notice." + ) + with self.session(**config) as session: + session._get_server_info() - .. warning:: - All configuration key-word arguments are experimental. - They might be changed or removed in any future version without - prior notice. + if t.TYPE_CHECKING: - :rtype: ServerInfo + def get_server_info( + self, + # all arguments are experimental + # they may be change or removed any time without prior notice + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., - :raises DriverError: if the driver cannot connect to the remote. - Use the exception to further understand the cause of the - connectivity problem. + # undocumented/unsupported options + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ..., + ) -> ServerInfo: + ... - .. versionadded:: 5.0 - """ - if config: - experimental_warn( - "All configuration key-word arguments to " - "verify_connectivity() are experimental. They might be " - "changed or removed in any future version without prior " - "notice." - ) - with self.session(**config) as session: - return session._get_server_info() + else: + + def get_server_info(self, **config) -> ServerInfo: + """Get information about the connected Neo4j server. + + Try to establish a working read connection to the remote server or + a member of a cluster and exchange some data. Then return the + contacted server's information. + + In a cluster, there is no guarantee about which server will be + contacted. + + .. note:: + Even if this method raises an exception, the driver still needs + to be closed via :meth:`close` to free up all resources. + + :param config: accepts the same configuration key-word arguments as + :meth:`session`. + + .. warning:: + All configuration key-word arguments are experimental. + They might be changed or removed in any future version + without prior notice. + + :raises DriverError: if the driver cannot connect to the remote. + Use the exception to further understand the cause of the + connectivity problem. + + .. versionadded:: 5.0 + """ + if config: + experimental_warn( + "All configuration key-word arguments to " + "verify_connectivity() are experimental. They might be " + "changed or removed in any future version without prior " + "notice." + ) + with self.session(**config) as session: + return session._get_server_info() @experimental("Feature support query, based on Bolt protocol version and Neo4j server version will change in the future.") def supports_multi_db(self) -> bool: From 0652a59aaf796ad4fcf504277c3959cf0ec5dc2e Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Tue, 2 Aug 2022 11:58:07 +0200 Subject: [PATCH 4/5] Fix type hints for Python 3.9+ --- neo4j/time/__init__.py | 11 ++++++++--- tests/unit/common/test_types.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/neo4j/time/__init__.py b/neo4j/time/__init__.py index 848bce32..cd4232e4 100644 --- a/neo4j/time/__init__.py +++ b/neo4j/time/__init__.py @@ -1328,7 +1328,9 @@ def __getattr__(self, name): raise AttributeError("Date has no attribute %r" % name) if t.TYPE_CHECKING: - isocalendar = iso_calendar + def iso_calendar(self) -> t.Tuple[int, int, int]: + ... + isoformat = iso_format isoweekday = iso_weekday strftime = __format__ @@ -2668,12 +2670,15 @@ def __getattr__(self, name): if t.TYPE_CHECKING: - def astimezone( # type: ignore[override] + def astimezone( # type: ignore[override] self, tz: _tzinfo ) -> DateTime: ... - isocalendar = iso_calendar + def isocalendar( # type: ignore[override] + self + ) -> t.Tuple[int, int, int]: + ... def iso_format(self, sep: str = "T") -> str: # type: ignore[override] ... diff --git a/tests/unit/common/test_types.py b/tests/unit/common/test_types.py index 71ca4454..f314178e 100644 --- a/tests/unit/common/test_types.py +++ b/tests/unit/common/test_types.py @@ -94,7 +94,7 @@ def test_node_with_null_properties(): "g2", "id2", "eid2", "props2"), ( (*n1, *n2) for n1, n2 in product( - ( + ( # type: ignore (g, id_, element_id, props) # type: ignore for g in (0, 1) for id_, element_id in ( From 9f2155ed5593ab2b587994479f5d2caf855c1155 Mon Sep 17 00:00:00 2001 From: Antonio Barcelos Date: Tue, 2 Aug 2022 12:13:14 +0200 Subject: [PATCH 5/5] Clean-up #TODO comment Signed-off-by: Rouven Bauer --- neo4j/_async/driver.py | 8 +------- neo4j/_sync/driver.py | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index 0130ccb6..6107f17d 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -116,13 +116,7 @@ def driver( "neo4j async is in experimental phase. It might be removed or " "changed at any time (including patch releases)." ) - def driver( - cls, - uri: str, - *, - auth: t.Union[t.Tuple[t.Any, t.Any], Auth] = None, - **config # TODO: type config - ) -> AsyncDriver: + def driver(cls, uri, *, auth=None, **config) -> AsyncDriver: """Create a driver. :param uri: the connection URI for the driver, see :ref:`async-uri-ref` for available URIs. diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index 40414aa3..5834b46b 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -116,13 +116,7 @@ def driver( "neo4j is in experimental phase. It might be removed or " "changed at any time (including patch releases)." ) - def driver( - cls, - uri: str, - *, - auth: t.Union[t.Tuple[t.Any, t.Any], Auth] = None, - **config # TODO: type config - ) -> Driver: + def driver(cls, uri, *, auth=None, **config) -> Driver: """Create a driver. :param uri: the connection URI for the driver, see :ref:`uri-ref` for available URIs.