From ede283e0b0893022fba8da53813c18099e94134b Mon Sep 17 00:00:00 2001 From: "Jason.Yu" Date: Tue, 19 Oct 2021 23:49:53 +0800 Subject: [PATCH 1/2] `BaseRestartWorkChain`: clearify the log report when output ports not compatible (#5149) The `BaseRestartWorkChain` output ports required are reported to missing in two conditions. 1) when the nested `CalcJob` do not correctly giving the outputs. 2) when the base work chain does not correctly expose the output ports, which is not considered here. It causes issue https://github.com/aiidateam/aiida-quantumespresso/issues/723 These shold be two different scenario but we report at one log. Would be better to separate this two situation but that will complicate the code logic here (need to check the actual outputs of embed process and then raise for the first half report), I don't think it worth to do for this. --- aiida/engine/processes/workchains/restart.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 4aa562a4d9..3c8ca8d47d 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -311,7 +311,10 @@ def results(self) -> Optional['ExitCode']: output = exposed_outputs[name] except KeyError: if port.required: - self.report(f"required output '{name}' was not an output of {self.ctx.process_name}<{node.pk}>") + self.report( + f'required output \'{name}\' was not an output of {self.ctx.process_name}<{node.pk}> ' + f'(or an incorrect class/output is being exposed).' + ) else: self.out(name, output) From fc0b4587df59416bf7667ba3795f57a83267a5e9 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 20 Oct 2021 04:02:07 +0200 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Entity=20Collecti?= =?UTF-8?q?on=20typing=20(#5183)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR makes static types available for `Entity.objects -> Collection`. The `Collection` subclasses are now defined outside of the `Entity` subclasses, using `Generic`s to dynamically define the entity subclass their methods return. A number of fixes were also made, to ensure the correct backend is used in queries. --- .pre-commit-config.yaml | 1 + aiida/common/datastructures.py | 36 ---- aiida/common/lang.py | 7 +- aiida/orm/authinfos.py | 33 ++-- aiida/orm/comments.py | 71 ++++--- aiida/orm/computers.py | 69 ++++--- aiida/orm/entities.py | 272 ++++++++++++--------------- aiida/orm/groups.py | 73 +++---- aiida/orm/implementation/comments.py | 4 +- aiida/orm/implementation/logs.py | 10 +- aiida/orm/logs.py | 163 ++++++++-------- aiida/orm/nodes/data/remote/base.py | 2 +- aiida/orm/nodes/node.py | 96 +++++----- aiida/orm/users.py | 106 ++++++----- docs/source/howto/data.rst | 2 +- docs/source/nitpick-exceptions | 24 ++- tests/orm/test_groups.py | 4 +- 17 files changed, 490 insertions(+), 483 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 810c81fd80..9ccf4c40fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,6 +81,7 @@ repos: aiida/orm/implementation/sqlalchemy/backend.py| aiida/orm/implementation/querybuilder.py| aiida/orm/implementation/sqlalchemy/querybuilder/.*py| + aiida/orm/entities.py| aiida/orm/nodes/data/jsonable.py| aiida/orm/nodes/node.py| aiida/orm/nodes/process/.*py| diff --git a/aiida/common/datastructures.py b/aiida/common/datastructures.py index 672aea27a0..a9cccc3d27 100644 --- a/aiida/common/datastructures.py +++ b/aiida/common/datastructures.py @@ -165,39 +165,3 @@ class CodeRunMode(IntEnum): SERIAL = 0 PARALLEL = 1 - - -class LazyStore: - """ - A container that provides a mapping to objects based on a key, if the object is not - found in the container when it is retrieved it will created using a provided factory - method - """ - - def __init__(self): - self._store = {} - - def get(self, key, factory): - """ - Get a value in the store based on the key, if it doesn't exist it will be created - using the factory method and returned - - :param key: the key of the object to get - :param factory: the factory used to create the object if necessary - :return: the object - """ - try: - return self._store[key] - except KeyError: - obj = factory() - self._store[key] = obj - return obj - - def pop(self, key): - """ - Pop an object from the store based on the given key - - :param key: the object key - :return: the object that was popped - """ - return self._store.pop(key) diff --git a/aiida/common/lang.py b/aiida/common/lang.py index 37f977b8b9..63b2d05afb 100644 --- a/aiida/common/lang.py +++ b/aiida/common/lang.py @@ -11,7 +11,7 @@ import functools import inspect import keyword -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Callable, Generic, Type, TypeVar def isidentifier(identifier): @@ -77,6 +77,7 @@ def wrapped_fn(self, *args, **kwargs): # pylint: disable=missing-docstring override = override_decorator(check=False) # pylint: disable=invalid-name ReturnType = TypeVar('ReturnType') +SelfType = TypeVar('SelfType') class classproperty(Generic[ReturnType]): # pylint: disable=invalid-name @@ -88,8 +89,8 @@ class classproperty(Generic[ReturnType]): # pylint: disable=invalid-name instance as its first argument). """ - def __init__(self, getter: Callable[[Any], ReturnType]) -> None: + def __init__(self, getter: Callable[[Type[SelfType]], ReturnType]) -> None: self.getter = getter - def __get__(self, instance: Any, owner: Any) -> ReturnType: + def __get__(self, instance: Any, owner: Type[SelfType]) -> ReturnType: return self.getter(owner) diff --git a/aiida/orm/authinfos.py b/aiida/orm/authinfos.py index 5a478c5fd0..49f03416bb 100644 --- a/aiida/orm/authinfos.py +++ b/aiida/orm/authinfos.py @@ -8,8 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the `AuthInfo` ORM class.""" +from typing import Type from aiida.common import exceptions +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from aiida.plugins import TransportFactory @@ -18,22 +20,33 @@ __all__ = ('AuthInfo',) +class AuthInfoCollection(entities.Collection['AuthInfo']): + """The collection of `AuthInfo` entries.""" + + @staticmethod + def _entity_base_cls() -> Type['AuthInfo']: + return AuthInfo + + def delete(self, pk: int) -> None: + """Delete an entry from the collection. + + :param pk: the pk of the entry to delete + """ + self._backend.authinfos.delete(pk) + + class AuthInfo(entities.Entity): """ORM class that models the authorization information that allows a `User` to connect to a `Computer`.""" - class Collection(entities.Collection): - """The collection of `AuthInfo` entries.""" + Collection = AuthInfoCollection - def delete(self, pk): - """Delete an entry from the collection. - - :param pk: the pk of the entry to delete - """ - self._backend.authinfos.delete(pk) + @classproperty + def objects(cls) -> AuthInfoCollection: # pylint: disable=no-self-argument + return AuthInfoCollection.get_cached(cls, get_manager().get_backend()) PROPERTY_WORKDIR = 'workdir' - def __init__(self, computer, user, backend=None): + def __init__(self, computer, user, backend=None) -> None: """Create an `AuthInfo` instance for the given computer and user. :param computer: a `Computer` instance @@ -41,8 +54,6 @@ def __init__(self, computer, user, backend=None): :param user: a `User` instance :type user: :class:`aiida.orm.User` - - :rtype: :class:`aiida.orm.authinfos.AuthInfo` """ backend = backend or get_manager().get_backend() model = backend.authinfos.create(computer=computer.backend_entity, user=user.backend_entity) diff --git a/aiida/orm/comments.py b/aiida/orm/comments.py index d4ed7eaace..08a04ace7a 100644 --- a/aiida/orm/comments.py +++ b/aiida/orm/comments.py @@ -8,7 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Comment objects and functions""" +from typing import List, Type +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from . import entities, users @@ -16,46 +18,55 @@ __all__ = ('Comment',) -class Comment(entities.Entity): - """Base class to map a DbComment that represents a comment attached to a certain Node.""" +class CommentCollection(entities.Collection['Comment']): + """The collection of Comment entries.""" + + @staticmethod + def _entity_base_cls() -> Type['Comment']: + return Comment + + def delete(self, pk: int) -> None: + """ + Remove a Comment from the collection with the given id + + :param pk: the id of the comment to delete + + :raises TypeError: if ``comment_id`` is not an `int` + :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found + """ + self._backend.comments.delete(pk) - class Collection(entities.Collection): - """The collection of Comment entries.""" + def delete_all(self) -> None: + """ + Delete all Comments from the Collection - def delete(self, comment_id): - """ - Remove a Comment from the collection with the given id + :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted + """ + self._backend.comments.delete_all() - :param comment_id: the id of the comment to delete - :type comment_id: int + def delete_many(self, filters) -> List[int]: + """ + Delete Comments from the Collection based on ``filters`` - :raises TypeError: if ``comment_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found - """ - self._backend.comments.delete(comment_id) + :param filters: similar to QueryBuilder filter + :type filters: dict - def delete_all(self): - """ - Delete all Comments from the Collection + :return: (former) ``PK`` s of deleted Comments - :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted - """ - self._backend.comments.delete_all() + :raises TypeError: if ``filters`` is not a `dict` + :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty + """ + return self._backend.comments.delete_many(filters) - def delete_many(self, filters): - """ - Delete Comments from the Collection based on ``filters`` - :param filters: similar to QueryBuilder filter - :type filters: dict +class Comment(entities.Entity): + """Base class to map a DbComment that represents a comment attached to a certain Node.""" - :return: (former) ``PK`` s of deleted Comments - :rtype: list + Collection = CommentCollection - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - self._backend.comments.delete_many(filters) + @classproperty + def objects(cls) -> CommentCollection: # pylint: disable=no-self-argument + return CommentCollection.get_cached(cls, get_manager().get_backend()) def __init__(self, node, user, content=None, backend=None): """ diff --git a/aiida/orm/computers.py b/aiida/orm/computers.py index 0dc32b0012..f95f362602 100644 --- a/aiida/orm/computers.py +++ b/aiida/orm/computers.py @@ -10,8 +10,10 @@ """Module for Computer entities""" import logging import os +from typing import List, Optional, Tuple, Type from aiida.common import exceptions +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from aiida.orm.implementation import Backend from aiida.plugins import SchedulerFactory, TransportFactory @@ -21,6 +23,40 @@ __all__ = ('Computer',) +class ComputerCollection(entities.Collection['Computer']): + """The collection of Computer entries.""" + + @staticmethod + def _entity_base_cls() -> Type['Computer']: + return Computer + + def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple['Computer', bool]: + """ + Try to retrieve a Computer from the DB with the given arguments; + create (and store) a new Computer if such a Computer was not present yet. + + :param label: computer label + + :return: (computer, created) where computer is the computer (new or existing, + in any case already stored) and created is a boolean saying + """ + if not label: + raise ValueError('Computer label must be provided') + + try: + return False, self.get(label=label) + except exceptions.NotExistent: + return True, Computer(backend=self.backend, label=label, **kwargs) + + def list_labels(self) -> List[str]: + """Return a list with all the labels of the computers in the DB.""" + return self._backend.computers.list_names() + + def delete(self, pk: int) -> None: + """Delete the computer with the given id""" + return self._backend.computers.delete(pk) + + class Computer(entities.Entity): """ Computer entity. @@ -34,36 +70,11 @@ class Computer(entities.Entity): PROPERTY_WORKDIR = 'workdir' PROPERTY_SHEBANG = 'shebang' - class Collection(entities.Collection): - """The collection of Computer entries.""" - - def get_or_create(self, label=None, **kwargs): - """ - Try to retrieve a Computer from the DB with the given arguments; - create (and store) a new Computer if such a Computer was not present yet. - - :param label: computer label - :type label: str - - :return: (computer, created) where computer is the computer (new or existing, - in any case already stored) and created is a boolean saying - :rtype: (:class:`aiida.orm.Computer`, bool) - """ - if not label: - raise ValueError('Computer label must be provided') - - try: - return False, self.get(label=label) - except exceptions.NotExistent: - return True, Computer(backend=self.backend, label=label, **kwargs) - - def list_labels(self): - """Return a list with all the labels of the computers in the DB.""" - return self._backend.computers.list_names() + Collection = ComputerCollection - def delete(self, id): # pylint: disable=redefined-builtin,invalid-name - """Delete the computer with the given id""" - return self._backend.computers.delete(id) + @classproperty + def objects(cls) -> ComputerCollection: # pylint: disable=no-self-argument + return ComputerCollection.get_cached(cls, get_manager().get_backend()) def __init__( # pylint: disable=too-many-arguments self, diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index 19477a8671..77c068e403 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -11,19 +11,30 @@ import abc import copy from enum import Enum -import typing +from functools import lru_cache +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, TypeVar, cast from plumpy.base.utils import call_with_super_check, super_check -from aiida.common import datastructures, exceptions +from aiida.common import exceptions from aiida.common.lang import classproperty, type_check from aiida.manage.manager import get_manager +try: + from typing import Protocol +except ImportError: # Python <3.8 doesn't have `Protocol` in the stdlib + from typing_extensions import Protocol # type: ignore[misc] + +if TYPE_CHECKING: + from aiida.orm.implementation import Backend, BackendEntity + from aiida.orm.querybuilder import FilterType, OrderByType, QueryBuilder + __all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin') -EntityType = typing.TypeVar('EntityType') # pylint: disable=invalid-name +CollectionType = TypeVar('CollectionType', bound='Collection') +EntityType = TypeVar('EntityType', bound='Entity') -_NO_DEFAULT = tuple() +_NO_DEFAULT: Any = tuple() class EntityTypes(Enum): @@ -39,184 +50,138 @@ class EntityTypes(Enum): GROUP_NODE = 'group_node' -class Collection(typing.Generic[EntityType]): - """Container class that represents the collection of objects of a particular type.""" +class Collection(abc.ABC, Generic[EntityType]): + """Container class that represents the collection of objects of a particular entity type.""" - # A store for any backend specific collections that already exist - _COLLECTIONS = datastructures.LazyStore() + @staticmethod + @abc.abstractmethod + def _entity_base_cls() -> Type[EntityType]: + """The allowed entity class or subclasses thereof.""" @classmethod - def get_collection(cls, entity_type, backend): - """ - Get the collection for a given entity type and backend instance - - :param entity_type: the entity type e.g. User, Computer, etc - :type entity_type: :class:`aiida.orm.Entity` + @lru_cache(maxsize=100) + def get_cached(cls, entity_class: Type[EntityType], backend: 'Backend'): + """Get the cached collection instance for the given entity class and backend. :param backend: the backend instance to get the collection for - :type backend: :class:`aiida.orm.implementation.Backend` - - :return: a new collection with the new backend - :rtype: :class:`aiida.orm.Collection` """ - # Lazily get the collection i.e. create only if we haven't done so yet - return cls._COLLECTIONS.get((entity_type, backend), lambda: entity_type.Collection(backend, entity_type)) + from aiida.orm.implementation import Backend + type_check(backend, Backend) + return cls(entity_class, backend=backend) - def __init__(self, backend, entity_class): + def __init__(self, entity_class: Type[EntityType], backend: Optional['Backend'] = None) -> None: """ Construct a new entity collection. - :param backend: the backend instance to get the collection for - :type backend: :class:`aiida.orm.implementation.Backend` - :param entity_class: the entity type e.g. User, Computer, etc - :type entity_class: :class:`aiida.orm.Entity` - + :param backend: the backend instance to get the collection for, or use the default """ - assert issubclass(entity_class, Entity), 'Must provide an entity type' + from aiida.orm.implementation import Backend + type_check(backend, Backend, allow_none=True) + assert issubclass(entity_class, self._entity_base_cls()) self._backend = backend or get_manager().get_backend() self._entity_type = entity_class - def __call__(self, backend): - """ Create a new objects collection using a new backend. - - :param backend: the backend instance to get the collection for - :type backend: :class:`aiida.orm.implementation.Backend` - - :return: a new collection with the new backend - :rtype: :class:`aiida.orm.Collection` - """ + def __call__(self: CollectionType, backend: 'Backend') -> CollectionType: + """Get or create a cached collection using a new backend.""" if backend is self._backend: - # Special case if they actually want the same collection return self - - return self.get_collection(self.entity_type, backend) + return self.get_cached(self.entity_type, backend=backend) # type: ignore @property - def backend(self): - """Return the backend. - - :return: the backend instance of this collection - :rtype: :class:`aiida.orm.implementation.Backend` - """ - return self._backend + def entity_type(self) -> Type[EntityType]: + """The entity type for this instance.""" + return self._entity_type @property - def entity_type(self): - """The entity type. - - :rtype: :class:`aiida.orm.Entity` - """ - return self._entity_type + def backend(self) -> 'Backend': + """Return the backend.""" + return self._backend - def query(self, filters=None, order_by=None, limit=None, offset=None): - """ - Get a query builder for the objects of this collection + def query( + self, + filters: Optional[Dict[str, 'FilterType']] = None, + order_by: Optional['OrderByType'] = None, + limit: Optional[int] = None, + offset: Optional[int] = None + ) -> 'QueryBuilder': + """Get a query builder for the objects of this collection. :param filters: the keyword value pair filters to match - :type filters: dict - :param order_by: a list of (key, direction) pairs specifying the sort order - :type order_by: list - :param limit: the maximum number of results to return - :type limit: int - :param offset: number of initial results to be skipped - :type offset: int - - :return: a new query builder instance - :rtype: :class:`aiida.orm.QueryBuilder` """ from . import querybuilder filters = filters or {} order_by = {self.entity_type: order_by} if order_by else {} - query = querybuilder.QueryBuilder(limit=limit, offset=offset) + query = querybuilder.QueryBuilder(backend=self._backend, limit=limit, offset=offset) query.append(self.entity_type, project='*', filters=filters) query.order_by([order_by]) return query - def get(self, **filters): - """ - Get a single collection entry that matches the filter criteria + def get(self, **filters: Any) -> EntityType: + """Get a single collection entry that matches the filter criteria. :param filters: the filters identifying the object to get - :type filters: dict :return: the entry """ res = self.query(filters=filters) return res.one()[0] - def find(self, filters=None, order_by=None, limit=None): - """ - Find collection entries matching the filter criteria + def find( + self, + filters: Optional[Dict[str, 'FilterType']] = None, + order_by: Optional['OrderByType'] = None, + limit: Optional[int] = None + ) -> List[EntityType]: + """Find collection entries matching the filter criteria. :param filters: the keyword value pair filters to match - :type filters: dict - :param order_by: a list of (key, direction) pairs specifying the sort order - :type order_by: list - :param limit: the maximum number of results to return - :type limit: int :return: a list of resulting matches - :rtype: list """ query = self.query(filters=filters, order_by=order_by, limit=limit) - return query.all(flat=True) + return cast(List[EntityType], query.all(flat=True)) - def all(self): - """ - Get all entities in this collection + def all(self) -> List[EntityType]: + """Get all entities in this collection. :return: A list of all entities - :rtype: list """ - return self.query().all(flat=True) # pylint: disable=no-member + return cast(List[EntityType], self.query().all(flat=True)) # pylint: disable=no-member - def count(self, filters=None): - """Count entities in this collection according to criteria + def count(self, filters: Optional[Dict[str, 'FilterType']] = None) -> int: + """Count entities in this collection according to criteria. :param filters: the keyword value pair filters to match - :type filters: dict :return: The number of entities found using the supplied criteria - :rtype: int """ return self.query(filters=filters).count() -class Entity: +class Entity(abc.ABC): """An AiiDA entity""" - _objects = None - - # Define our collection type - Collection = Collection - @classproperty - def objects(cls, backend=None): # pylint: disable=no-self-argument - """ - Get a collection for objects of this type. - - :param backend: the optional backend to use (otherwise use default) - :type backend: :class:`aiida.orm.implementation.Backend` + @abc.abstractmethod + def objects(cls: EntityType) -> Collection[EntityType]: # pylint: disable=no-self-argument,disable=no-self-use + """Get a collection for objects of this type, with the default backend. :return: an object that can be used to access entities of this type - :rtype: :class:`aiida.orm.Collection` """ - backend = backend or get_manager().get_backend() - return cls.Collection.get_collection(cls, backend) @classmethod def get(cls, **kwargs): return cls.objects.get(**kwargs) # pylint: disable=no-member @classmethod - def from_backend_entity(cls, backend_entity): + def from_backend_entity(cls: Type[EntityType], backend_entity: 'BackendEntity') -> EntityType: """ Construct an entity from a backend entity instance @@ -232,30 +197,28 @@ def from_backend_entity(cls, backend_entity): call_with_super_check(entity.initialize) return entity - def __init__(self, backend_entity): + def __init__(self, backend_entity: 'BackendEntity') -> None: """ :param backend_entity: the backend model supporting this entity - :type backend_entity: :class:`aiida.orm.implementation.entities.BackendEntity` """ self._backend_entity = backend_entity call_with_super_check(self.initialize) - def init_from_backend(self, backend_entity): + def init_from_backend(self, backend_entity: 'BackendEntity') -> None: """ :param backend_entity: the backend model supporting this entity - :type backend_entity: :class:`aiida.orm.implementation.entities.BackendEntity` """ self._backend_entity = backend_entity @super_check - def initialize(self): + def initialize(self) -> None: """Initialize instance attributes. This will be called after the constructor is called or an entity is created from an existing backend entity. """ @property - def id(self): # pylint: disable=invalid-name + def id(self) -> int: # pylint: disable=invalid-name """Return the id for this entity. This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. @@ -265,7 +228,7 @@ def id(self): # pylint: disable=invalid-name return self._backend_entity.id @property - def pk(self): + def pk(self) -> int: """Return the primary key for this entity. This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. @@ -285,43 +248,44 @@ def uuid(self): """ return self._backend_entity.uuid - def store(self): + def store(self: EntityType) -> EntityType: """Store the entity.""" self._backend_entity.store() return self @property - def is_stored(self): - """Return whether the entity is stored. - - :return: boolean, True if stored, False otherwise - :rtype: bool - """ + def is_stored(self) -> bool: + """Return whether the entity is stored.""" return self._backend_entity.is_stored @property - def backend(self): - """ - Get the backend for this entity - :return: the backend instance - """ + def backend(self) -> 'Backend': + """Get the backend for this entity""" return self._backend_entity.backend @property - def backend_entity(self): - """ - Get the implementing class for this object - - :return: the class model - """ + def backend_entity(self) -> 'BackendEntity': + """Get the implementing class for this object""" return self._backend_entity -class EntityAttributesMixin(abc.ABC): +class EntityProtocol(Protocol): + """Protocol for attributes required by Entity mixins.""" + + @property + def backend_entity(self) -> 'BackendEntity': + ... + + @property + def is_stored(self) -> bool: + ... + + +class EntityAttributesMixin: """Mixin class that adds all methods for the attributes column to an entity.""" @property - def attributes(self): + def attributes(self: EntityProtocol) -> Dict[str, Any]: """Return the complete attributes dictionary. .. warning:: While the entity is unstored, this will return references of the attributes on the database model, @@ -341,7 +305,7 @@ def attributes(self): return attributes - def get_attribute(self, key, default=_NO_DEFAULT): + def get_attribute(self: EntityProtocol, key: str, default=_NO_DEFAULT) -> Any: """Return the value of an attribute. .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, @@ -367,7 +331,7 @@ def get_attribute(self, key, default=_NO_DEFAULT): return attribute - def get_attribute_many(self, keys): + def get_attribute_many(self: EntityProtocol, keys: List[str]) -> List[Any]: """Return the values of multiple attributes. .. warning:: While the entity is unstored, this will return references of the attributes on the database model, @@ -389,7 +353,7 @@ def get_attribute_many(self, keys): return attributes - def set_attribute(self, key, value): + def set_attribute(self: EntityProtocol, key: str, value: Any) -> None: """Set an attribute to the given value. :param key: name of the attribute @@ -402,7 +366,7 @@ def set_attribute(self, key, value): self.backend_entity.set_attribute(key, value) - def set_attribute_many(self, attributes): + def set_attribute_many(self: EntityProtocol, attributes: Dict[str, Any]) -> None: """Set multiple attributes. .. note:: This will override any existing attributes that are present in the new dictionary. @@ -416,7 +380,7 @@ def set_attribute_many(self, attributes): self.backend_entity.set_attribute_many(attributes) - def reset_attributes(self, attributes): + def reset_attributes(self: EntityProtocol, attributes: Dict[str, Any]) -> None: """Reset the attributes. .. note:: This will completely clear any existing attributes and replace them with the new dictionary. @@ -430,7 +394,7 @@ def reset_attributes(self, attributes): self.backend_entity.reset_attributes(attributes) - def delete_attribute(self, key): + def delete_attribute(self: EntityProtocol, key: str) -> None: """Delete an attribute. :param key: name of the attribute @@ -442,7 +406,7 @@ def delete_attribute(self, key): self.backend_entity.delete_attribute(key) - def delete_attribute_many(self, keys): + def delete_attribute_many(self: EntityProtocol, keys: List[str]) -> None: """Delete multiple attributes. :param keys: names of the attributes to delete @@ -454,21 +418,21 @@ def delete_attribute_many(self, keys): self.backend_entity.delete_attribute_many(keys) - def clear_attributes(self): + def clear_attributes(self: EntityProtocol) -> None: """Delete all attributes.""" if self.is_stored: raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') self.backend_entity.clear_attributes() - def attributes_items(self): + def attributes_items(self: EntityProtocol): """Return an iterator over the attributes. :return: an iterator with attribute key value pairs """ return self.backend_entity.attributes_items() - def attributes_keys(self): + def attributes_keys(self: EntityProtocol): """Return an iterator over the attribute keys. :return: an iterator with attribute keys @@ -476,11 +440,11 @@ def attributes_keys(self): return self.backend_entity.attributes_keys() -class EntityExtrasMixin(abc.ABC): +class EntityExtrasMixin: """Mixin class that adds all methods for the extras column to an entity.""" @property - def extras(self): + def extras(self: EntityProtocol) -> Dict[str, Any]: """Return the complete extras dictionary. .. warning:: While the entity is unstored, this will return references of the extras on the database model, @@ -500,7 +464,7 @@ def extras(self): return extras - def get_extra(self, key, default=_NO_DEFAULT): + def get_extra(self: EntityProtocol, key: str, default: Any = _NO_DEFAULT) -> Any: """Return the value of an extra. .. warning:: While the entity is unstored, this will return a reference of the extra on the database model, @@ -526,7 +490,7 @@ def get_extra(self, key, default=_NO_DEFAULT): return extra - def get_extra_many(self, keys): + def get_extra_many(self: EntityProtocol, keys: List[str]) -> List[Any]: """Return the values of multiple extras. .. warning:: While the entity is unstored, this will return references of the extras on the database model, @@ -548,7 +512,7 @@ def get_extra_many(self, keys): return extras - def set_extra(self, key, value): + def set_extra(self: EntityProtocol, key: str, value: Any) -> None: """Set an extra to the given value. :param key: name of the extra @@ -557,7 +521,7 @@ def set_extra(self, key, value): """ self.backend_entity.set_extra(key, value) - def set_extra_many(self, extras): + def set_extra_many(self: EntityProtocol, extras: Dict[str, Any]) -> None: """Set multiple extras. .. note:: This will override any existing extras that are present in the new dictionary. @@ -567,7 +531,7 @@ def set_extra_many(self, extras): """ self.backend_entity.set_extra_many(extras) - def reset_extras(self, extras): + def reset_extras(self: EntityProtocol, extras: Dict[str, Any]) -> None: """Reset the extras. .. note:: This will completely clear any existing extras and replace them with the new dictionary. @@ -577,7 +541,7 @@ def reset_extras(self, extras): """ self.backend_entity.reset_extras(extras) - def delete_extra(self, key): + def delete_extra(self: EntityProtocol, key: str) -> None: """Delete an extra. :param key: name of the extra @@ -585,7 +549,7 @@ def delete_extra(self, key): """ self.backend_entity.delete_extra(key) - def delete_extra_many(self, keys): + def delete_extra_many(self: EntityProtocol, keys: List[str]) -> None: """Delete multiple extras. :param keys: names of the extras to delete @@ -593,18 +557,18 @@ def delete_extra_many(self, keys): """ self.backend_entity.delete_extra_many(keys) - def clear_extras(self): + def clear_extras(self: EntityProtocol) -> None: """Delete all extras.""" self.backend_entity.clear_extras() - def extras_items(self): + def extras_items(self: EntityProtocol): """Return an iterator over the extras. :return: an iterator with extra key value pairs """ return self.backend_entity.extras_items() - def extras_keys(self): + def extras_keys(self: EntityProtocol): """Return an iterator over the extra keys. :return: an iterator with extra keys diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index c61aca0bca..cddf71db51 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -9,11 +9,11 @@ ########################################################################### """AiiDA Group entites""" from abc import ABCMeta -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Tuple, Type import warnings from aiida.common import exceptions -from aiida.common.lang import type_check +from aiida.common.lang import classproperty, type_check from aiida.manage.manager import get_manager from . import convert, entities, users @@ -63,47 +63,56 @@ def __new__(cls, name, bases, namespace, **kwargs): return newcls -class Group(entities.Entity, entities.EntityExtrasMixin, metaclass=GroupMeta): - """An AiiDA ORM implementation of group of nodes.""" +class GroupCollection(entities.Collection['Group']): + """Collection of Groups""" - # added by metaclass - _type_string = ClassVar[Optional[str]] + @staticmethod + def _entity_base_cls() -> Type['Group']: + return Group + + def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple['Group', bool]: + """ + Try to retrieve a group from the DB with the given arguments; + create (and store) a new group if such a group was not present yet. + + :param label: group label + + :return: (group, created) where group is the group (new or existing, + in any case already stored) and created is a boolean saying + """ + if not label: + raise ValueError('Group label must be provided') - class Collection(entities.Collection): - """Collection of Groups""" + res = self.find(filters={'label': label}) - def get_or_create(self, label=None, **kwargs): - """ - Try to retrieve a group from the DB with the given arguments; - create (and store) a new group if such a group was not present yet. + if not res: + return self.entity_type(label, backend=self.backend, **kwargs).store(), True - :param label: group label - :type label: str + if len(res) > 1: + raise exceptions.MultipleObjectsError('More than one groups found in the database') - :return: (group, created) where group is the group (new or existing, - in any case already stored) and created is a boolean saying - :rtype: (:class:`aiida.orm.Group`, bool) - """ - if not label: - raise ValueError('Group label must be provided') + return res[0], False - res = self.find(filters={'label': label}) + def delete(self, pk: int) -> None: + """ + Delete a group - if not res: - return self.entity_type(label, backend=self.backend, **kwargs).store(), True + :param pk: the id of the group to delete + """ + self._backend.groups.delete(pk) - if len(res) > 1: - raise exceptions.MultipleObjectsError('More than one groups found in the database') - return res[0], False +class Group(entities.Entity, entities.EntityExtrasMixin, metaclass=GroupMeta): + """An AiiDA ORM implementation of group of nodes.""" + + # added by metaclass + _type_string = ClassVar[Optional[str]] - def delete(self, id): # pylint: disable=invalid-name, redefined-builtin - """ - Delete a group + Collection = GroupCollection - :param id: the id of the group to delete - """ - self._backend.groups.delete(id) + @classproperty + def objects(cls) -> GroupCollection: # pylint: disable=no-self-argument + return GroupCollection.get_cached(cls, get_manager().get_backend()) def __init__(self, label=None, user=None, description='', type_string=None, backend=None): """ diff --git a/aiida/orm/implementation/comments.py b/aiida/orm/implementation/comments.py index 432dd36e02..50ec1273c2 100644 --- a/aiida/orm/implementation/comments.py +++ b/aiida/orm/implementation/comments.py @@ -8,8 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for comment backend classes.""" - import abc +from typing import List from .entities import BackendCollection, BackendEntity @@ -98,7 +98,7 @@ def delete_all(self): """ @abc.abstractmethod - def delete_many(self, filters): + def delete_many(self, filters) -> List[int]: """ Delete Comments based on ``filters`` diff --git a/aiida/orm/implementation/logs.py b/aiida/orm/implementation/logs.py index 172b0dcaa5..ef53c90fad 100644 --- a/aiida/orm/implementation/logs.py +++ b/aiida/orm/implementation/logs.py @@ -9,6 +9,7 @@ ########################################################################### """Backend group module""" import abc +from typing import List from .entities import BackendCollection, BackendEntity @@ -97,19 +98,18 @@ class BackendLogCollection(BackendCollection[BackendLog]): ENTITY_CLASS = BackendLog @abc.abstractmethod - def delete(self, log_id): + def delete(self, log_id: int) -> None: """ Remove a Log entry from the collection with the given id :param log_id: id of the Log to delete - :type log_id: int :raises TypeError: if ``log_id`` is not an `int` :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``log_id`` is not found """ @abc.abstractmethod - def delete_all(self): + def delete_all(self) -> None: """ Delete all Log entries. @@ -117,15 +117,13 @@ def delete_all(self): """ @abc.abstractmethod - def delete_many(self, filters): + def delete_many(self, filters) -> List[int]: """ Delete Logs based on ``filters`` :param filters: similar to QueryBuilder filter - :type filters: dict :return: (former) ``PK`` s of deleted Logs - :rtype: list :raises TypeError: if ``filters`` is not a `dict` :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty diff --git a/aiida/orm/logs.py b/aiida/orm/logs.py index 9f5d83eb8a..36b6c20ba2 100644 --- a/aiida/orm/logs.py +++ b/aiida/orm/logs.py @@ -8,12 +8,19 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for orm logging abstract classes""" +import logging +from typing import TYPE_CHECKING, List, Optional, Type from aiida.common import timezone +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from . import entities +if TYPE_CHECKING: + from aiida.orm import Node + from aiida.orm.querybuilder import FilterType, OrderByType + __all__ = ('Log', 'OrderSpecifier', 'ASCENDING', 'DESCENDING') ASCENDING = 'asc' @@ -24,114 +31,108 @@ def OrderSpecifier(field, direction): # pylint: disable=invalid-name return {field: direction} -class Log(entities.Entity): +class LogCollection(entities.Collection['Log']): """ - An AiiDA Log entity. Corresponds to a logged message against a particular AiiDA node. + This class represents the collection of logs and can be used to create + and retrieve logs. """ - class Collection(entities.Collection): - """ - This class represents the collection of logs and can be used to create - and retrieve logs. - """ + @staticmethod + def _entity_base_cls() -> Type['Log']: + return Log - @staticmethod - def create_entry_from_record(record): - """ - Helper function to create a log entry from a record created as by the python logging library + def create_entry_from_record(self, record: logging.LogRecord) -> 'Log': + """Helper function to create a log entry from a record created as by the python logging library - :param record: The record created by the logging module - :type record: :class:`logging.LogRecord` + :param record: The record created by the logging module + :return: A stored log instance + """ + from datetime import datetime + + dbnode_id = record.__dict__.get('dbnode_id', None) - :return: An object implementing the log entry interface - :rtype: :class:`aiida.orm.logs.Log` - """ - from datetime import datetime + # Do not store if dbnode_id is not set + if dbnode_id is None: + return None - dbnode_id = record.__dict__.get('dbnode_id', None) + metadata = dict(record.__dict__) - # Do not store if dbnode_id is not set - if dbnode_id is None: - return None + # If an `exc_info` is present, the log message was an exception, so format the full traceback + try: + import traceback + exc_info = metadata.pop('exc_info') + message = ''.join(traceback.format_exception(*exc_info)) + except (TypeError, KeyError): + message = record.getMessage() - metadata = dict(record.__dict__) + # Stringify the content of `args` if they exist in the metadata to ensure serializability + for key in ['args']: + if key in metadata: + metadata[key] = str(metadata[key]) - # If an `exc_info` is present, the log message was an exception, so format the full traceback - try: - import traceback - exc_info = metadata.pop('exc_info') - message = ''.join(traceback.format_exception(*exc_info)) - except (TypeError, KeyError): - message = record.getMessage() + return Log( + time=timezone.make_aware(datetime.fromtimestamp(record.created)), + loggername=record.name, + levelname=record.levelname, + dbnode_id=dbnode_id, + message=message, + metadata=metadata, + backend=self.backend + ) - # Stringify the content of `args` if they exist in the metadata to ensure serializability - for key in ['args']: - if key in metadata: - metadata[key] = str(metadata[key]) + def get_logs_for(self, entity: 'Node', order_by: Optional['OrderByType'] = None) -> List['Log']: + """Get all the log messages for a given node and optionally sort - return Log( - time=timezone.make_aware(datetime.fromtimestamp(record.created)), - loggername=record.name, - levelname=record.levelname, - dbnode_id=dbnode_id, - message=message, - metadata=metadata - ) + :param entity: the entity to get logs for + :param order_by: a list of (key, direction) pairs specifying the sort order - def get_logs_for(self, entity, order_by=None): - """ - Get all the log messages for a given entity and optionally sort + :return: the list of log entries + """ + from . import nodes - :param entity: the entity to get logs for - :type entity: :class:`aiida.orm.Entity` + if not isinstance(entity, nodes.Node): + raise Exception('Only node logs are stored') - :param order_by: a list of (key, direction) pairs specifying the sort order - :type order_by: list + return self.find({'dbnode_id': entity.pk}, order_by=order_by) - :return: the list of log entries - :rtype: list - """ - from . import nodes + def delete(self, pk: int) -> None: + """Remove a Log entry from the collection with the given id - if not isinstance(entity, nodes.Node): - raise Exception('Only node logs are stored') + :param pk: id of the Log to delete - return self.find({'dbnode_id': entity.pk}, order_by=order_by) + :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``pk`` is not found + """ + return self._backend.logs.delete(pk) - def delete(self, log_id): - """ - Remove a Log entry from the collection with the given id + def delete_all(self) -> None: + """Delete all Logs in the collection - :param log_id: id of the Log to delete - :type log_id: int + :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted + """ + return self._backend.logs.delete_all() - :raises TypeError: if ``log_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``log_id`` is not found - """ - self._backend.logs.delete(log_id) + def delete_many(self, filters: 'FilterType') -> List[int]: + """Delete Logs based on ``filters`` - def delete_all(self): - """ - Delete all Logs in the collection + :param filters: filters to pass to the QueryBuilder + :return: (former) ``PK`` s of deleted Logs - :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted - """ - self._backend.logs.delete_all() + :raises TypeError: if ``filters`` is not a `dict` + :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty + """ + return self._backend.logs.delete_many(filters) - def delete_many(self, filters): - """ - Delete Logs based on ``filters`` - :param filters: similar to QueryBuilder filter - :type filters: dict +class Log(entities.Entity): + """ + An AiiDA Log entity. Corresponds to a logged message against a particular AiiDA node. + """ - :return: (former) ``PK`` s of deleted Logs - :rtype: list + Collection = LogCollection - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - self._backend.logs.delete_many(filters) + @classproperty + def objects(cls) -> LogCollection: # pylint: disable=no-self-argument + return LogCollection.get_cached(cls, get_manager().get_backend()) def __init__(self, time, loggername, levelname, dbnode_id, message='', metadata=None, backend=None): # pylint: disable=too-many-arguments """Construct a new log diff --git a/aiida/orm/nodes/data/remote/base.py b/aiida/orm/nodes/data/remote/base.py index 8de116de26..27d5773563 100644 --- a/aiida/orm/nodes/data/remote/base.py +++ b/aiida/orm/nodes/data/remote/base.py @@ -177,4 +177,4 @@ def _validate(self): raise ValidationError('Remote computer not set.') def get_authinfo(self): - return AuthInfo.objects.get(dbcomputer=self.computer, aiidauser=self.user) + return AuthInfo.objects(self.backend).get(dbcomputer=self.computer, aiidauser=self.user) diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index ac71a1b668..66ab7b14fa 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -14,7 +14,21 @@ import importlib from logging import Logger import typing -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Generic, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) from uuid import UUID from aiida.common import exceptions @@ -41,7 +55,33 @@ __all__ = ('Node',) -_NO_DEFAULT = tuple() # type: ignore[var-annotated] +NodeType = TypeVar('NodeType', bound='Node') + + +class NodeCollection(EntityCollection[NodeType], Generic[NodeType]): + """The collection of nodes.""" + + @staticmethod + def _entity_base_cls() -> Type['Node']: + return Node + + def delete(self, pk: int) -> None: + """Delete a `Node` from the collection with the given id + + :param pk: the node id + """ + node = self.get(id=pk) + + if not node.is_stored: + return + + if node.get_incoming().all(): + raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has incoming links') + + if node.get_outgoing().all(): + raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has outgoing links') + + self._backend.nodes.delete(pk) class Node(Entity, NodeRepositoryMixin, EntityAttributesMixin, EntityExtrasMixin, metaclass=AbstractNodeMeta): @@ -65,27 +105,6 @@ class Node(Entity, NodeRepositoryMixin, EntityAttributesMixin, EntityExtrasMixin _plugin_type_string: ClassVar[str] _query_type_string: ClassVar[str] - class Collection(EntityCollection): - """The collection of nodes.""" - - def delete(self, node_id: int) -> None: - """Delete a `Node` from the collection with the given id - - :param node_id: the node id - """ - node = self.get(id=node_id) - - if not node.is_stored: - return - - if node.get_incoming().all(): - raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has incoming links') - - if node.get_outgoing().all(): - raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has outgoing links') - - self._backend.nodes.delete(node_id) - # This will be set by the metaclass call _logger: Optional[Logger] = None @@ -106,10 +125,11 @@ def delete(self, node_id: int) -> None: # These are to be initialized in the `initialization` method _incoming_cache: Optional[List[LinkTriple]] = None - @classmethod - def from_backend_entity(cls, backend_entity: 'BackendNode') -> 'Node': - entity = super().from_backend_entity(backend_entity) - return entity + Collection = NodeCollection + + @classproperty + def objects(cls: Type[NodeType]) -> NodeCollection[NodeType]: # pylint: disable=no-self-argument + return NodeCollection.get_cached(cls, get_manager().get_backend()) # type: ignore[arg-type] def __init__( self, @@ -304,11 +324,7 @@ def repository_metadata(self, value): @property def computer(self) -> Optional[Computer]: - """Return the computer of this node. - - :return: the computer or None - :rtype: `Computer` or None - """ + """Return the computer of this node.""" if self.backend_entity.computer: return Computer.from_backend_entity(self.backend_entity.computer) @@ -332,11 +348,7 @@ def computer(self, computer: Optional[Computer]) -> None: @property def user(self) -> User: - """Return the user of this node. - - :return: the user - :rtype: `User` - """ + """Return the user of this node.""" return User.from_backend_entity(self.backend_entity.user) @user.setter @@ -374,7 +386,7 @@ def add_comment(self, content: str, user: Optional[User] = None) -> Comment: :param user: the user to associate with the comment, will use default if not supplied :return: the newly created comment """ - user = user or User.objects.get_default() + user = user or User.objects(self.backend).get_default() return Comment(node=self, user=user, content=content).store() def get_comment(self, identifier: int) -> Comment: @@ -385,14 +397,14 @@ def get_comment(self, identifier: int) -> Comment: :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment :return: the comment """ - return Comment.objects.get(dbnode_id=self.pk, id=identifier) + return Comment.objects(self.backend).get(dbnode_id=self.pk, id=identifier) def get_comments(self) -> List[Comment]: """Return a sorted list of comments for this node. :return: the list of comments, sorted by pk """ - return Comment.objects.find(filters={'dbnode_id': self.pk}, order_by=[{'id': 'asc'}]) + return Comment.objects(self.backend).find(filters={'dbnode_id': self.pk}, order_by=[{'id': 'asc'}]) def update_comment(self, identifier: int, content: str) -> None: """Update the content of an existing comment. @@ -402,7 +414,7 @@ def update_comment(self, identifier: int, content: str) -> None: :raise aiida.common.NotExistent: if the comment with the given id does not exist :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment """ - comment = Comment.objects.get(dbnode_id=self.pk, id=identifier) + comment = Comment.objects(self.backend).get(dbnode_id=self.pk, id=identifier) comment.set_content(content) def remove_comment(self, identifier: int) -> None: # pylint: disable=no-self-use @@ -410,7 +422,7 @@ def remove_comment(self, identifier: int) -> None: # pylint: disable=no-self-us :param identifier: the comment pk """ - Comment.objects.delete(identifier) + Comment.objects(self.backend).delete(identifier) def add_incoming(self, source: 'Node', link_type: LinkType, link_label: str) -> None: """Add a link of the given type from a given node to ourself. diff --git a/aiida/orm/users.py b/aiida/orm/users.py index 4c55bcad6f..d00d889388 100644 --- a/aiida/orm/users.py +++ b/aiida/orm/users.py @@ -8,68 +8,78 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the ORM user class.""" +from typing import TYPE_CHECKING, Optional, Tuple, Type + from aiida.common import exceptions +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from . import entities +if TYPE_CHECKING: + from aiida.orm.implementation import Backend + __all__ = ('User',) -class User(entities.Entity): - """AiiDA User""" +class UserCollection(entities.Collection['User']): + """The collection of users stored in a backend.""" - class Collection(entities.Collection): - """The collection of users stored in a backend.""" + UNDEFINED = 'UNDEFINED' + _default_user: Optional['User'] = None - UNDEFINED = 'UNDEFINED' - _default_user = None # type: aiida.orm.User + @staticmethod + def _entity_base_cls() -> Type['User']: + return User - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._default_user = self.UNDEFINED + def __init__(self, entity_class: Type['User'], backend: Optional['Backend'] = None) -> None: + super().__init__(entity_class=entity_class, backend=backend) + self._default_user = self.UNDEFINED - def get_or_create(self, email, **kwargs): - """ - Get the existing user with a given email address or create an unstored one + def get_or_create(self, email: str, **kwargs) -> Tuple[bool, 'User']: + """Get the existing user with a given email address or create an unstored one + + :param kwargs: The properties of the user to get or create + :return: The corresponding user object + :raises: :class:`aiida.common.exceptions.MultipleObjectsError`, + :class:`aiida.common.exceptions.NotExistent` + """ + try: + return False, self.get(email=email) + except exceptions.NotExistent: + return True, User(backend=self.backend, email=email, **kwargs) + + def get_default(self) -> 'User': + """Get the current default user""" + if self._default_user is self.UNDEFINED: + from aiida.manage.configuration import get_profile + profile = get_profile() + email = profile.default_user + if not email: + self._default_user = None - :param kwargs: The properties of the user to get or create - :return: The corresponding user object - :rtype: :class:`aiida.orm.User` - :raises: :class:`aiida.common.exceptions.MultipleObjectsError`, - :class:`aiida.common.exceptions.NotExistent` - """ try: - return False, self.get(email=email) - except exceptions.NotExistent: - return True, User(backend=self.backend, email=email, **kwargs) - - def get_default(self): - """ - Get the current default user - - :return: The default user - :rtype: :class:`aiida.orm.User` - """ - if self._default_user is self.UNDEFINED: - from aiida.manage.configuration import get_profile - profile = get_profile() - email = profile.default_user - if not email: - self._default_user = None - - try: - self._default_user = self.get(email=email) - except (exceptions.MultipleObjectsError, exceptions.NotExistent): - self._default_user = None - - return self._default_user - - def reset(self): - """ - Reset internal caches (default user). - """ - self._default_user = self.UNDEFINED + self._default_user = self.get(email=email) + except (exceptions.MultipleObjectsError, exceptions.NotExistent): + self._default_user = None + + return self._default_user + + def reset(self) -> None: + """ + Reset internal caches (default user). + """ + self._default_user = self.UNDEFINED + + +class User(entities.Entity): + """AiiDA User""" + + Collection = UserCollection + + @classproperty + def objects(cls) -> UserCollection: # pylint: disable=no-self-argument + return UserCollection.get_cached(cls, get_manager().get_backend()) REQUIRED_FIELDS = ['first_name', 'last_name', 'institution'] diff --git a/docs/source/howto/data.rst b/docs/source/howto/data.rst index d11cbd62c7..bb152a1fc6 100644 --- a/docs/source/howto/data.rst +++ b/docs/source/howto/data.rst @@ -576,7 +576,7 @@ This includes, notably: * *Node extras*: These can be deleted using :py:meth:`~aiida.orm.entities.EntityExtrasMixin.delete_extra` and :py:meth:`~aiida.orm.entities.EntityExtrasMixin.delete_extra_many` methods. * *Node comments*: These can be removed using :py:meth:`~aiida.orm.nodes.node.Node.remove_comment`. -* *Groups*: These can be deleted using :py:meth:`Group.objects.delete() `. +* *Groups*: These can be deleted using :py:meth:`Group.objects.delete() `. This command will only delete the group, not the nodes contained in the group. Completely deleting an AiiDA profile diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index b4670ae5f3..e03be84f3c 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -21,11 +21,6 @@ py:class builtins.dict # typing py:class AbstractContextManager py:class asyncio.events.AbstractEventLoop -py:class EntityType -py:class EntityTypes -py:class ModelType -py:class TransactionType -py:class ReturnType py:class function py:class IO py:class traceback @@ -48,6 +43,7 @@ py:class aiida.orm.utils.links.LinkQuadruple py:class aiida.tools.importexport.dbexport.ExportReport py:class aiida.tools.importexport.dbexport.ArchiveData py:class aiida.tools.groups.paths.WalkNodeResult +py:meth aiida.orm.groups.GroupCollection.delete py:class Backend py:class BackendEntity @@ -56,16 +52,22 @@ py:class AuthInfo py:class CalcJob py:class CalcJobImporter py:class CalcJobNode +py:class CollectionType py:class Data py:class DbImporter +py:class EntityType +py:class EntityTypes py:class ExitCode py:class File +py:class FilterType py:class FolderData py:class Group py:class JobInfo py:class JobState +py:class ModelType py:class Node py:class Orbital +py:class OrderByType py:class Parser py:class PersistenceError py:class Process @@ -74,14 +76,26 @@ py:class ProcessNode py:class ProcessSpec py:class Port py:class PortNamespace +py:class QueryBuilder py:class Repository +py:class ReturnType py:class Runner py:class Scheduler +py:class SelfType +py:class TransactionType py:class Transport py:class TransportQueue py:class WorkChain py:class WorkChainSpec +py:class AuthInfoCollection +py:class CommentCollection +py:class ComputerCollection +py:class GroupCollection +py:class LogCollection +py:class NodeCollection +py:class UserCollection + py:class kiwipy.communications.Communicator py:class plumpy.process_states.State py:class plumpy.workchains._If diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index 3e52f4d0f7..3e9c5e1397 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -345,7 +345,7 @@ def test_loading_unregistered(): # Removing it as other methods might get a warning instead group_pk = group.pk del group - orm.Group.objects.delete(id=group_pk) + orm.Group.objects.delete(pk=group_pk) @staticmethod def test_querying(): @@ -367,7 +367,7 @@ def test_querying(): # Removing it as other methods might get a warning instead group_pk = group.pk del group - orm.Group.objects.delete(id=group_pk) + orm.Group.objects.delete(pk=group_pk) @staticmethod def test_querying_node_subclasses():