From fc0b4587df59416bf7667ba3795f57a83267a5e9 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 20 Oct 2021 04:02:07 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Entity=20Collection?= =?UTF-8?q?=20typing=20(#5183)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR makes static types available for `Entity.objects -> Collection`. The `Collection` subclasses are now defined outside of the `Entity` subclasses, using `Generic`s to dynamically define the entity subclass their methods return. A number of fixes were also made, to ensure the correct backend is used in queries. --- .pre-commit-config.yaml | 1 + aiida/common/datastructures.py | 36 ---- aiida/common/lang.py | 7 +- aiida/orm/authinfos.py | 33 ++-- aiida/orm/comments.py | 71 ++++--- aiida/orm/computers.py | 69 ++++--- aiida/orm/entities.py | 272 ++++++++++++--------------- aiida/orm/groups.py | 73 +++---- aiida/orm/implementation/comments.py | 4 +- aiida/orm/implementation/logs.py | 10 +- aiida/orm/logs.py | 163 ++++++++-------- aiida/orm/nodes/data/remote/base.py | 2 +- aiida/orm/nodes/node.py | 96 +++++----- aiida/orm/users.py | 106 ++++++----- docs/source/howto/data.rst | 2 +- docs/source/nitpick-exceptions | 24 ++- tests/orm/test_groups.py | 4 +- 17 files changed, 490 insertions(+), 483 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 810c81fd80..9ccf4c40fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,6 +81,7 @@ repos: aiida/orm/implementation/sqlalchemy/backend.py| aiida/orm/implementation/querybuilder.py| aiida/orm/implementation/sqlalchemy/querybuilder/.*py| + aiida/orm/entities.py| aiida/orm/nodes/data/jsonable.py| aiida/orm/nodes/node.py| aiida/orm/nodes/process/.*py| diff --git a/aiida/common/datastructures.py b/aiida/common/datastructures.py index 672aea27a0..a9cccc3d27 100644 --- a/aiida/common/datastructures.py +++ b/aiida/common/datastructures.py @@ -165,39 +165,3 @@ class CodeRunMode(IntEnum): SERIAL = 0 PARALLEL = 1 - - -class LazyStore: - """ - A container that provides a mapping to objects based on a key, if the object is not - found in the container when it is retrieved it will created using a provided factory - method - """ - - def __init__(self): - self._store = {} - - def get(self, key, factory): - """ - Get a value in the store based on the key, if it doesn't exist it will be created - using the factory method and returned - - :param key: the key of the object to get - :param factory: the factory used to create the object if necessary - :return: the object - """ - try: - return self._store[key] - except KeyError: - obj = factory() - self._store[key] = obj - return obj - - def pop(self, key): - """ - Pop an object from the store based on the given key - - :param key: the object key - :return: the object that was popped - """ - return self._store.pop(key) diff --git a/aiida/common/lang.py b/aiida/common/lang.py index 37f977b8b9..63b2d05afb 100644 --- a/aiida/common/lang.py +++ b/aiida/common/lang.py @@ -11,7 +11,7 @@ import functools import inspect import keyword -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Callable, Generic, Type, TypeVar def isidentifier(identifier): @@ -77,6 +77,7 @@ def wrapped_fn(self, *args, **kwargs): # pylint: disable=missing-docstring override = override_decorator(check=False) # pylint: disable=invalid-name ReturnType = TypeVar('ReturnType') +SelfType = TypeVar('SelfType') class classproperty(Generic[ReturnType]): # pylint: disable=invalid-name @@ -88,8 +89,8 @@ class classproperty(Generic[ReturnType]): # pylint: disable=invalid-name instance as its first argument). """ - def __init__(self, getter: Callable[[Any], ReturnType]) -> None: + def __init__(self, getter: Callable[[Type[SelfType]], ReturnType]) -> None: self.getter = getter - def __get__(self, instance: Any, owner: Any) -> ReturnType: + def __get__(self, instance: Any, owner: Type[SelfType]) -> ReturnType: return self.getter(owner) diff --git a/aiida/orm/authinfos.py b/aiida/orm/authinfos.py index 5a478c5fd0..49f03416bb 100644 --- a/aiida/orm/authinfos.py +++ b/aiida/orm/authinfos.py @@ -8,8 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the `AuthInfo` ORM class.""" +from typing import Type from aiida.common import exceptions +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from aiida.plugins import TransportFactory @@ -18,22 +20,33 @@ __all__ = ('AuthInfo',) +class AuthInfoCollection(entities.Collection['AuthInfo']): + """The collection of `AuthInfo` entries.""" + + @staticmethod + def _entity_base_cls() -> Type['AuthInfo']: + return AuthInfo + + def delete(self, pk: int) -> None: + """Delete an entry from the collection. + + :param pk: the pk of the entry to delete + """ + self._backend.authinfos.delete(pk) + + class AuthInfo(entities.Entity): """ORM class that models the authorization information that allows a `User` to connect to a `Computer`.""" - class Collection(entities.Collection): - """The collection of `AuthInfo` entries.""" + Collection = AuthInfoCollection - def delete(self, pk): - """Delete an entry from the collection. - - :param pk: the pk of the entry to delete - """ - self._backend.authinfos.delete(pk) + @classproperty + def objects(cls) -> AuthInfoCollection: # pylint: disable=no-self-argument + return AuthInfoCollection.get_cached(cls, get_manager().get_backend()) PROPERTY_WORKDIR = 'workdir' - def __init__(self, computer, user, backend=None): + def __init__(self, computer, user, backend=None) -> None: """Create an `AuthInfo` instance for the given computer and user. :param computer: a `Computer` instance @@ -41,8 +54,6 @@ def __init__(self, computer, user, backend=None): :param user: a `User` instance :type user: :class:`aiida.orm.User` - - :rtype: :class:`aiida.orm.authinfos.AuthInfo` """ backend = backend or get_manager().get_backend() model = backend.authinfos.create(computer=computer.backend_entity, user=user.backend_entity) diff --git a/aiida/orm/comments.py b/aiida/orm/comments.py index d4ed7eaace..08a04ace7a 100644 --- a/aiida/orm/comments.py +++ b/aiida/orm/comments.py @@ -8,7 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Comment objects and functions""" +from typing import List, Type +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from . import entities, users @@ -16,46 +18,55 @@ __all__ = ('Comment',) -class Comment(entities.Entity): - """Base class to map a DbComment that represents a comment attached to a certain Node.""" +class CommentCollection(entities.Collection['Comment']): + """The collection of Comment entries.""" + + @staticmethod + def _entity_base_cls() -> Type['Comment']: + return Comment + + def delete(self, pk: int) -> None: + """ + Remove a Comment from the collection with the given id + + :param pk: the id of the comment to delete + + :raises TypeError: if ``comment_id`` is not an `int` + :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found + """ + self._backend.comments.delete(pk) - class Collection(entities.Collection): - """The collection of Comment entries.""" + def delete_all(self) -> None: + """ + Delete all Comments from the Collection - def delete(self, comment_id): - """ - Remove a Comment from the collection with the given id + :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted + """ + self._backend.comments.delete_all() - :param comment_id: the id of the comment to delete - :type comment_id: int + def delete_many(self, filters) -> List[int]: + """ + Delete Comments from the Collection based on ``filters`` - :raises TypeError: if ``comment_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found - """ - self._backend.comments.delete(comment_id) + :param filters: similar to QueryBuilder filter + :type filters: dict - def delete_all(self): - """ - Delete all Comments from the Collection + :return: (former) ``PK`` s of deleted Comments - :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted - """ - self._backend.comments.delete_all() + :raises TypeError: if ``filters`` is not a `dict` + :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty + """ + return self._backend.comments.delete_many(filters) - def delete_many(self, filters): - """ - Delete Comments from the Collection based on ``filters`` - :param filters: similar to QueryBuilder filter - :type filters: dict +class Comment(entities.Entity): + """Base class to map a DbComment that represents a comment attached to a certain Node.""" - :return: (former) ``PK`` s of deleted Comments - :rtype: list + Collection = CommentCollection - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - self._backend.comments.delete_many(filters) + @classproperty + def objects(cls) -> CommentCollection: # pylint: disable=no-self-argument + return CommentCollection.get_cached(cls, get_manager().get_backend()) def __init__(self, node, user, content=None, backend=None): """ diff --git a/aiida/orm/computers.py b/aiida/orm/computers.py index 0dc32b0012..f95f362602 100644 --- a/aiida/orm/computers.py +++ b/aiida/orm/computers.py @@ -10,8 +10,10 @@ """Module for Computer entities""" import logging import os +from typing import List, Optional, Tuple, Type from aiida.common import exceptions +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from aiida.orm.implementation import Backend from aiida.plugins import SchedulerFactory, TransportFactory @@ -21,6 +23,40 @@ __all__ = ('Computer',) +class ComputerCollection(entities.Collection['Computer']): + """The collection of Computer entries.""" + + @staticmethod + def _entity_base_cls() -> Type['Computer']: + return Computer + + def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple['Computer', bool]: + """ + Try to retrieve a Computer from the DB with the given arguments; + create (and store) a new Computer if such a Computer was not present yet. + + :param label: computer label + + :return: (computer, created) where computer is the computer (new or existing, + in any case already stored) and created is a boolean saying + """ + if not label: + raise ValueError('Computer label must be provided') + + try: + return False, self.get(label=label) + except exceptions.NotExistent: + return True, Computer(backend=self.backend, label=label, **kwargs) + + def list_labels(self) -> List[str]: + """Return a list with all the labels of the computers in the DB.""" + return self._backend.computers.list_names() + + def delete(self, pk: int) -> None: + """Delete the computer with the given id""" + return self._backend.computers.delete(pk) + + class Computer(entities.Entity): """ Computer entity. @@ -34,36 +70,11 @@ class Computer(entities.Entity): PROPERTY_WORKDIR = 'workdir' PROPERTY_SHEBANG = 'shebang' - class Collection(entities.Collection): - """The collection of Computer entries.""" - - def get_or_create(self, label=None, **kwargs): - """ - Try to retrieve a Computer from the DB with the given arguments; - create (and store) a new Computer if such a Computer was not present yet. - - :param label: computer label - :type label: str - - :return: (computer, created) where computer is the computer (new or existing, - in any case already stored) and created is a boolean saying - :rtype: (:class:`aiida.orm.Computer`, bool) - """ - if not label: - raise ValueError('Computer label must be provided') - - try: - return False, self.get(label=label) - except exceptions.NotExistent: - return True, Computer(backend=self.backend, label=label, **kwargs) - - def list_labels(self): - """Return a list with all the labels of the computers in the DB.""" - return self._backend.computers.list_names() + Collection = ComputerCollection - def delete(self, id): # pylint: disable=redefined-builtin,invalid-name - """Delete the computer with the given id""" - return self._backend.computers.delete(id) + @classproperty + def objects(cls) -> ComputerCollection: # pylint: disable=no-self-argument + return ComputerCollection.get_cached(cls, get_manager().get_backend()) def __init__( # pylint: disable=too-many-arguments self, diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index 19477a8671..77c068e403 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -11,19 +11,30 @@ import abc import copy from enum import Enum -import typing +from functools import lru_cache +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, TypeVar, cast from plumpy.base.utils import call_with_super_check, super_check -from aiida.common import datastructures, exceptions +from aiida.common import exceptions from aiida.common.lang import classproperty, type_check from aiida.manage.manager import get_manager +try: + from typing import Protocol +except ImportError: # Python <3.8 doesn't have `Protocol` in the stdlib + from typing_extensions import Protocol # type: ignore[misc] + +if TYPE_CHECKING: + from aiida.orm.implementation import Backend, BackendEntity + from aiida.orm.querybuilder import FilterType, OrderByType, QueryBuilder + __all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin') -EntityType = typing.TypeVar('EntityType') # pylint: disable=invalid-name +CollectionType = TypeVar('CollectionType', bound='Collection') +EntityType = TypeVar('EntityType', bound='Entity') -_NO_DEFAULT = tuple() +_NO_DEFAULT: Any = tuple() class EntityTypes(Enum): @@ -39,184 +50,138 @@ class EntityTypes(Enum): GROUP_NODE = 'group_node' -class Collection(typing.Generic[EntityType]): - """Container class that represents the collection of objects of a particular type.""" +class Collection(abc.ABC, Generic[EntityType]): + """Container class that represents the collection of objects of a particular entity type.""" - # A store for any backend specific collections that already exist - _COLLECTIONS = datastructures.LazyStore() + @staticmethod + @abc.abstractmethod + def _entity_base_cls() -> Type[EntityType]: + """The allowed entity class or subclasses thereof.""" @classmethod - def get_collection(cls, entity_type, backend): - """ - Get the collection for a given entity type and backend instance - - :param entity_type: the entity type e.g. User, Computer, etc - :type entity_type: :class:`aiida.orm.Entity` + @lru_cache(maxsize=100) + def get_cached(cls, entity_class: Type[EntityType], backend: 'Backend'): + """Get the cached collection instance for the given entity class and backend. :param backend: the backend instance to get the collection for - :type backend: :class:`aiida.orm.implementation.Backend` - - :return: a new collection with the new backend - :rtype: :class:`aiida.orm.Collection` """ - # Lazily get the collection i.e. create only if we haven't done so yet - return cls._COLLECTIONS.get((entity_type, backend), lambda: entity_type.Collection(backend, entity_type)) + from aiida.orm.implementation import Backend + type_check(backend, Backend) + return cls(entity_class, backend=backend) - def __init__(self, backend, entity_class): + def __init__(self, entity_class: Type[EntityType], backend: Optional['Backend'] = None) -> None: """ Construct a new entity collection. - :param backend: the backend instance to get the collection for - :type backend: :class:`aiida.orm.implementation.Backend` - :param entity_class: the entity type e.g. User, Computer, etc - :type entity_class: :class:`aiida.orm.Entity` - + :param backend: the backend instance to get the collection for, or use the default """ - assert issubclass(entity_class, Entity), 'Must provide an entity type' + from aiida.orm.implementation import Backend + type_check(backend, Backend, allow_none=True) + assert issubclass(entity_class, self._entity_base_cls()) self._backend = backend or get_manager().get_backend() self._entity_type = entity_class - def __call__(self, backend): - """ Create a new objects collection using a new backend. - - :param backend: the backend instance to get the collection for - :type backend: :class:`aiida.orm.implementation.Backend` - - :return: a new collection with the new backend - :rtype: :class:`aiida.orm.Collection` - """ + def __call__(self: CollectionType, backend: 'Backend') -> CollectionType: + """Get or create a cached collection using a new backend.""" if backend is self._backend: - # Special case if they actually want the same collection return self - - return self.get_collection(self.entity_type, backend) + return self.get_cached(self.entity_type, backend=backend) # type: ignore @property - def backend(self): - """Return the backend. - - :return: the backend instance of this collection - :rtype: :class:`aiida.orm.implementation.Backend` - """ - return self._backend + def entity_type(self) -> Type[EntityType]: + """The entity type for this instance.""" + return self._entity_type @property - def entity_type(self): - """The entity type. - - :rtype: :class:`aiida.orm.Entity` - """ - return self._entity_type + def backend(self) -> 'Backend': + """Return the backend.""" + return self._backend - def query(self, filters=None, order_by=None, limit=None, offset=None): - """ - Get a query builder for the objects of this collection + def query( + self, + filters: Optional[Dict[str, 'FilterType']] = None, + order_by: Optional['OrderByType'] = None, + limit: Optional[int] = None, + offset: Optional[int] = None + ) -> 'QueryBuilder': + """Get a query builder for the objects of this collection. :param filters: the keyword value pair filters to match - :type filters: dict - :param order_by: a list of (key, direction) pairs specifying the sort order - :type order_by: list - :param limit: the maximum number of results to return - :type limit: int - :param offset: number of initial results to be skipped - :type offset: int - - :return: a new query builder instance - :rtype: :class:`aiida.orm.QueryBuilder` """ from . import querybuilder filters = filters or {} order_by = {self.entity_type: order_by} if order_by else {} - query = querybuilder.QueryBuilder(limit=limit, offset=offset) + query = querybuilder.QueryBuilder(backend=self._backend, limit=limit, offset=offset) query.append(self.entity_type, project='*', filters=filters) query.order_by([order_by]) return query - def get(self, **filters): - """ - Get a single collection entry that matches the filter criteria + def get(self, **filters: Any) -> EntityType: + """Get a single collection entry that matches the filter criteria. :param filters: the filters identifying the object to get - :type filters: dict :return: the entry """ res = self.query(filters=filters) return res.one()[0] - def find(self, filters=None, order_by=None, limit=None): - """ - Find collection entries matching the filter criteria + def find( + self, + filters: Optional[Dict[str, 'FilterType']] = None, + order_by: Optional['OrderByType'] = None, + limit: Optional[int] = None + ) -> List[EntityType]: + """Find collection entries matching the filter criteria. :param filters: the keyword value pair filters to match - :type filters: dict - :param order_by: a list of (key, direction) pairs specifying the sort order - :type order_by: list - :param limit: the maximum number of results to return - :type limit: int :return: a list of resulting matches - :rtype: list """ query = self.query(filters=filters, order_by=order_by, limit=limit) - return query.all(flat=True) + return cast(List[EntityType], query.all(flat=True)) - def all(self): - """ - Get all entities in this collection + def all(self) -> List[EntityType]: + """Get all entities in this collection. :return: A list of all entities - :rtype: list """ - return self.query().all(flat=True) # pylint: disable=no-member + return cast(List[EntityType], self.query().all(flat=True)) # pylint: disable=no-member - def count(self, filters=None): - """Count entities in this collection according to criteria + def count(self, filters: Optional[Dict[str, 'FilterType']] = None) -> int: + """Count entities in this collection according to criteria. :param filters: the keyword value pair filters to match - :type filters: dict :return: The number of entities found using the supplied criteria - :rtype: int """ return self.query(filters=filters).count() -class Entity: +class Entity(abc.ABC): """An AiiDA entity""" - _objects = None - - # Define our collection type - Collection = Collection - @classproperty - def objects(cls, backend=None): # pylint: disable=no-self-argument - """ - Get a collection for objects of this type. - - :param backend: the optional backend to use (otherwise use default) - :type backend: :class:`aiida.orm.implementation.Backend` + @abc.abstractmethod + def objects(cls: EntityType) -> Collection[EntityType]: # pylint: disable=no-self-argument,disable=no-self-use + """Get a collection for objects of this type, with the default backend. :return: an object that can be used to access entities of this type - :rtype: :class:`aiida.orm.Collection` """ - backend = backend or get_manager().get_backend() - return cls.Collection.get_collection(cls, backend) @classmethod def get(cls, **kwargs): return cls.objects.get(**kwargs) # pylint: disable=no-member @classmethod - def from_backend_entity(cls, backend_entity): + def from_backend_entity(cls: Type[EntityType], backend_entity: 'BackendEntity') -> EntityType: """ Construct an entity from a backend entity instance @@ -232,30 +197,28 @@ def from_backend_entity(cls, backend_entity): call_with_super_check(entity.initialize) return entity - def __init__(self, backend_entity): + def __init__(self, backend_entity: 'BackendEntity') -> None: """ :param backend_entity: the backend model supporting this entity - :type backend_entity: :class:`aiida.orm.implementation.entities.BackendEntity` """ self._backend_entity = backend_entity call_with_super_check(self.initialize) - def init_from_backend(self, backend_entity): + def init_from_backend(self, backend_entity: 'BackendEntity') -> None: """ :param backend_entity: the backend model supporting this entity - :type backend_entity: :class:`aiida.orm.implementation.entities.BackendEntity` """ self._backend_entity = backend_entity @super_check - def initialize(self): + def initialize(self) -> None: """Initialize instance attributes. This will be called after the constructor is called or an entity is created from an existing backend entity. """ @property - def id(self): # pylint: disable=invalid-name + def id(self) -> int: # pylint: disable=invalid-name """Return the id for this entity. This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. @@ -265,7 +228,7 @@ def id(self): # pylint: disable=invalid-name return self._backend_entity.id @property - def pk(self): + def pk(self) -> int: """Return the primary key for this entity. This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. @@ -285,43 +248,44 @@ def uuid(self): """ return self._backend_entity.uuid - def store(self): + def store(self: EntityType) -> EntityType: """Store the entity.""" self._backend_entity.store() return self @property - def is_stored(self): - """Return whether the entity is stored. - - :return: boolean, True if stored, False otherwise - :rtype: bool - """ + def is_stored(self) -> bool: + """Return whether the entity is stored.""" return self._backend_entity.is_stored @property - def backend(self): - """ - Get the backend for this entity - :return: the backend instance - """ + def backend(self) -> 'Backend': + """Get the backend for this entity""" return self._backend_entity.backend @property - def backend_entity(self): - """ - Get the implementing class for this object - - :return: the class model - """ + def backend_entity(self) -> 'BackendEntity': + """Get the implementing class for this object""" return self._backend_entity -class EntityAttributesMixin(abc.ABC): +class EntityProtocol(Protocol): + """Protocol for attributes required by Entity mixins.""" + + @property + def backend_entity(self) -> 'BackendEntity': + ... + + @property + def is_stored(self) -> bool: + ... + + +class EntityAttributesMixin: """Mixin class that adds all methods for the attributes column to an entity.""" @property - def attributes(self): + def attributes(self: EntityProtocol) -> Dict[str, Any]: """Return the complete attributes dictionary. .. warning:: While the entity is unstored, this will return references of the attributes on the database model, @@ -341,7 +305,7 @@ def attributes(self): return attributes - def get_attribute(self, key, default=_NO_DEFAULT): + def get_attribute(self: EntityProtocol, key: str, default=_NO_DEFAULT) -> Any: """Return the value of an attribute. .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, @@ -367,7 +331,7 @@ def get_attribute(self, key, default=_NO_DEFAULT): return attribute - def get_attribute_many(self, keys): + def get_attribute_many(self: EntityProtocol, keys: List[str]) -> List[Any]: """Return the values of multiple attributes. .. warning:: While the entity is unstored, this will return references of the attributes on the database model, @@ -389,7 +353,7 @@ def get_attribute_many(self, keys): return attributes - def set_attribute(self, key, value): + def set_attribute(self: EntityProtocol, key: str, value: Any) -> None: """Set an attribute to the given value. :param key: name of the attribute @@ -402,7 +366,7 @@ def set_attribute(self, key, value): self.backend_entity.set_attribute(key, value) - def set_attribute_many(self, attributes): + def set_attribute_many(self: EntityProtocol, attributes: Dict[str, Any]) -> None: """Set multiple attributes. .. note:: This will override any existing attributes that are present in the new dictionary. @@ -416,7 +380,7 @@ def set_attribute_many(self, attributes): self.backend_entity.set_attribute_many(attributes) - def reset_attributes(self, attributes): + def reset_attributes(self: EntityProtocol, attributes: Dict[str, Any]) -> None: """Reset the attributes. .. note:: This will completely clear any existing attributes and replace them with the new dictionary. @@ -430,7 +394,7 @@ def reset_attributes(self, attributes): self.backend_entity.reset_attributes(attributes) - def delete_attribute(self, key): + def delete_attribute(self: EntityProtocol, key: str) -> None: """Delete an attribute. :param key: name of the attribute @@ -442,7 +406,7 @@ def delete_attribute(self, key): self.backend_entity.delete_attribute(key) - def delete_attribute_many(self, keys): + def delete_attribute_many(self: EntityProtocol, keys: List[str]) -> None: """Delete multiple attributes. :param keys: names of the attributes to delete @@ -454,21 +418,21 @@ def delete_attribute_many(self, keys): self.backend_entity.delete_attribute_many(keys) - def clear_attributes(self): + def clear_attributes(self: EntityProtocol) -> None: """Delete all attributes.""" if self.is_stored: raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') self.backend_entity.clear_attributes() - def attributes_items(self): + def attributes_items(self: EntityProtocol): """Return an iterator over the attributes. :return: an iterator with attribute key value pairs """ return self.backend_entity.attributes_items() - def attributes_keys(self): + def attributes_keys(self: EntityProtocol): """Return an iterator over the attribute keys. :return: an iterator with attribute keys @@ -476,11 +440,11 @@ def attributes_keys(self): return self.backend_entity.attributes_keys() -class EntityExtrasMixin(abc.ABC): +class EntityExtrasMixin: """Mixin class that adds all methods for the extras column to an entity.""" @property - def extras(self): + def extras(self: EntityProtocol) -> Dict[str, Any]: """Return the complete extras dictionary. .. warning:: While the entity is unstored, this will return references of the extras on the database model, @@ -500,7 +464,7 @@ def extras(self): return extras - def get_extra(self, key, default=_NO_DEFAULT): + def get_extra(self: EntityProtocol, key: str, default: Any = _NO_DEFAULT) -> Any: """Return the value of an extra. .. warning:: While the entity is unstored, this will return a reference of the extra on the database model, @@ -526,7 +490,7 @@ def get_extra(self, key, default=_NO_DEFAULT): return extra - def get_extra_many(self, keys): + def get_extra_many(self: EntityProtocol, keys: List[str]) -> List[Any]: """Return the values of multiple extras. .. warning:: While the entity is unstored, this will return references of the extras on the database model, @@ -548,7 +512,7 @@ def get_extra_many(self, keys): return extras - def set_extra(self, key, value): + def set_extra(self: EntityProtocol, key: str, value: Any) -> None: """Set an extra to the given value. :param key: name of the extra @@ -557,7 +521,7 @@ def set_extra(self, key, value): """ self.backend_entity.set_extra(key, value) - def set_extra_many(self, extras): + def set_extra_many(self: EntityProtocol, extras: Dict[str, Any]) -> None: """Set multiple extras. .. note:: This will override any existing extras that are present in the new dictionary. @@ -567,7 +531,7 @@ def set_extra_many(self, extras): """ self.backend_entity.set_extra_many(extras) - def reset_extras(self, extras): + def reset_extras(self: EntityProtocol, extras: Dict[str, Any]) -> None: """Reset the extras. .. note:: This will completely clear any existing extras and replace them with the new dictionary. @@ -577,7 +541,7 @@ def reset_extras(self, extras): """ self.backend_entity.reset_extras(extras) - def delete_extra(self, key): + def delete_extra(self: EntityProtocol, key: str) -> None: """Delete an extra. :param key: name of the extra @@ -585,7 +549,7 @@ def delete_extra(self, key): """ self.backend_entity.delete_extra(key) - def delete_extra_many(self, keys): + def delete_extra_many(self: EntityProtocol, keys: List[str]) -> None: """Delete multiple extras. :param keys: names of the extras to delete @@ -593,18 +557,18 @@ def delete_extra_many(self, keys): """ self.backend_entity.delete_extra_many(keys) - def clear_extras(self): + def clear_extras(self: EntityProtocol) -> None: """Delete all extras.""" self.backend_entity.clear_extras() - def extras_items(self): + def extras_items(self: EntityProtocol): """Return an iterator over the extras. :return: an iterator with extra key value pairs """ return self.backend_entity.extras_items() - def extras_keys(self): + def extras_keys(self: EntityProtocol): """Return an iterator over the extra keys. :return: an iterator with extra keys diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index c61aca0bca..cddf71db51 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -9,11 +9,11 @@ ########################################################################### """AiiDA Group entites""" from abc import ABCMeta -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Tuple, Type import warnings from aiida.common import exceptions -from aiida.common.lang import type_check +from aiida.common.lang import classproperty, type_check from aiida.manage.manager import get_manager from . import convert, entities, users @@ -63,47 +63,56 @@ def __new__(cls, name, bases, namespace, **kwargs): return newcls -class Group(entities.Entity, entities.EntityExtrasMixin, metaclass=GroupMeta): - """An AiiDA ORM implementation of group of nodes.""" +class GroupCollection(entities.Collection['Group']): + """Collection of Groups""" - # added by metaclass - _type_string = ClassVar[Optional[str]] + @staticmethod + def _entity_base_cls() -> Type['Group']: + return Group + + def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple['Group', bool]: + """ + Try to retrieve a group from the DB with the given arguments; + create (and store) a new group if such a group was not present yet. + + :param label: group label + + :return: (group, created) where group is the group (new or existing, + in any case already stored) and created is a boolean saying + """ + if not label: + raise ValueError('Group label must be provided') - class Collection(entities.Collection): - """Collection of Groups""" + res = self.find(filters={'label': label}) - def get_or_create(self, label=None, **kwargs): - """ - Try to retrieve a group from the DB with the given arguments; - create (and store) a new group if such a group was not present yet. + if not res: + return self.entity_type(label, backend=self.backend, **kwargs).store(), True - :param label: group label - :type label: str + if len(res) > 1: + raise exceptions.MultipleObjectsError('More than one groups found in the database') - :return: (group, created) where group is the group (new or existing, - in any case already stored) and created is a boolean saying - :rtype: (:class:`aiida.orm.Group`, bool) - """ - if not label: - raise ValueError('Group label must be provided') + return res[0], False - res = self.find(filters={'label': label}) + def delete(self, pk: int) -> None: + """ + Delete a group - if not res: - return self.entity_type(label, backend=self.backend, **kwargs).store(), True + :param pk: the id of the group to delete + """ + self._backend.groups.delete(pk) - if len(res) > 1: - raise exceptions.MultipleObjectsError('More than one groups found in the database') - return res[0], False +class Group(entities.Entity, entities.EntityExtrasMixin, metaclass=GroupMeta): + """An AiiDA ORM implementation of group of nodes.""" + + # added by metaclass + _type_string = ClassVar[Optional[str]] - def delete(self, id): # pylint: disable=invalid-name, redefined-builtin - """ - Delete a group + Collection = GroupCollection - :param id: the id of the group to delete - """ - self._backend.groups.delete(id) + @classproperty + def objects(cls) -> GroupCollection: # pylint: disable=no-self-argument + return GroupCollection.get_cached(cls, get_manager().get_backend()) def __init__(self, label=None, user=None, description='', type_string=None, backend=None): """ diff --git a/aiida/orm/implementation/comments.py b/aiida/orm/implementation/comments.py index 432dd36e02..50ec1273c2 100644 --- a/aiida/orm/implementation/comments.py +++ b/aiida/orm/implementation/comments.py @@ -8,8 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for comment backend classes.""" - import abc +from typing import List from .entities import BackendCollection, BackendEntity @@ -98,7 +98,7 @@ def delete_all(self): """ @abc.abstractmethod - def delete_many(self, filters): + def delete_many(self, filters) -> List[int]: """ Delete Comments based on ``filters`` diff --git a/aiida/orm/implementation/logs.py b/aiida/orm/implementation/logs.py index 172b0dcaa5..ef53c90fad 100644 --- a/aiida/orm/implementation/logs.py +++ b/aiida/orm/implementation/logs.py @@ -9,6 +9,7 @@ ########################################################################### """Backend group module""" import abc +from typing import List from .entities import BackendCollection, BackendEntity @@ -97,19 +98,18 @@ class BackendLogCollection(BackendCollection[BackendLog]): ENTITY_CLASS = BackendLog @abc.abstractmethod - def delete(self, log_id): + def delete(self, log_id: int) -> None: """ Remove a Log entry from the collection with the given id :param log_id: id of the Log to delete - :type log_id: int :raises TypeError: if ``log_id`` is not an `int` :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``log_id`` is not found """ @abc.abstractmethod - def delete_all(self): + def delete_all(self) -> None: """ Delete all Log entries. @@ -117,15 +117,13 @@ def delete_all(self): """ @abc.abstractmethod - def delete_many(self, filters): + def delete_many(self, filters) -> List[int]: """ Delete Logs based on ``filters`` :param filters: similar to QueryBuilder filter - :type filters: dict :return: (former) ``PK`` s of deleted Logs - :rtype: list :raises TypeError: if ``filters`` is not a `dict` :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty diff --git a/aiida/orm/logs.py b/aiida/orm/logs.py index 9f5d83eb8a..36b6c20ba2 100644 --- a/aiida/orm/logs.py +++ b/aiida/orm/logs.py @@ -8,12 +8,19 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for orm logging abstract classes""" +import logging +from typing import TYPE_CHECKING, List, Optional, Type from aiida.common import timezone +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from . import entities +if TYPE_CHECKING: + from aiida.orm import Node + from aiida.orm.querybuilder import FilterType, OrderByType + __all__ = ('Log', 'OrderSpecifier', 'ASCENDING', 'DESCENDING') ASCENDING = 'asc' @@ -24,114 +31,108 @@ def OrderSpecifier(field, direction): # pylint: disable=invalid-name return {field: direction} -class Log(entities.Entity): +class LogCollection(entities.Collection['Log']): """ - An AiiDA Log entity. Corresponds to a logged message against a particular AiiDA node. + This class represents the collection of logs and can be used to create + and retrieve logs. """ - class Collection(entities.Collection): - """ - This class represents the collection of logs and can be used to create - and retrieve logs. - """ + @staticmethod + def _entity_base_cls() -> Type['Log']: + return Log - @staticmethod - def create_entry_from_record(record): - """ - Helper function to create a log entry from a record created as by the python logging library + def create_entry_from_record(self, record: logging.LogRecord) -> 'Log': + """Helper function to create a log entry from a record created as by the python logging library - :param record: The record created by the logging module - :type record: :class:`logging.LogRecord` + :param record: The record created by the logging module + :return: A stored log instance + """ + from datetime import datetime + + dbnode_id = record.__dict__.get('dbnode_id', None) - :return: An object implementing the log entry interface - :rtype: :class:`aiida.orm.logs.Log` - """ - from datetime import datetime + # Do not store if dbnode_id is not set + if dbnode_id is None: + return None - dbnode_id = record.__dict__.get('dbnode_id', None) + metadata = dict(record.__dict__) - # Do not store if dbnode_id is not set - if dbnode_id is None: - return None + # If an `exc_info` is present, the log message was an exception, so format the full traceback + try: + import traceback + exc_info = metadata.pop('exc_info') + message = ''.join(traceback.format_exception(*exc_info)) + except (TypeError, KeyError): + message = record.getMessage() - metadata = dict(record.__dict__) + # Stringify the content of `args` if they exist in the metadata to ensure serializability + for key in ['args']: + if key in metadata: + metadata[key] = str(metadata[key]) - # If an `exc_info` is present, the log message was an exception, so format the full traceback - try: - import traceback - exc_info = metadata.pop('exc_info') - message = ''.join(traceback.format_exception(*exc_info)) - except (TypeError, KeyError): - message = record.getMessage() + return Log( + time=timezone.make_aware(datetime.fromtimestamp(record.created)), + loggername=record.name, + levelname=record.levelname, + dbnode_id=dbnode_id, + message=message, + metadata=metadata, + backend=self.backend + ) - # Stringify the content of `args` if they exist in the metadata to ensure serializability - for key in ['args']: - if key in metadata: - metadata[key] = str(metadata[key]) + def get_logs_for(self, entity: 'Node', order_by: Optional['OrderByType'] = None) -> List['Log']: + """Get all the log messages for a given node and optionally sort - return Log( - time=timezone.make_aware(datetime.fromtimestamp(record.created)), - loggername=record.name, - levelname=record.levelname, - dbnode_id=dbnode_id, - message=message, - metadata=metadata - ) + :param entity: the entity to get logs for + :param order_by: a list of (key, direction) pairs specifying the sort order - def get_logs_for(self, entity, order_by=None): - """ - Get all the log messages for a given entity and optionally sort + :return: the list of log entries + """ + from . import nodes - :param entity: the entity to get logs for - :type entity: :class:`aiida.orm.Entity` + if not isinstance(entity, nodes.Node): + raise Exception('Only node logs are stored') - :param order_by: a list of (key, direction) pairs specifying the sort order - :type order_by: list + return self.find({'dbnode_id': entity.pk}, order_by=order_by) - :return: the list of log entries - :rtype: list - """ - from . import nodes + def delete(self, pk: int) -> None: + """Remove a Log entry from the collection with the given id - if not isinstance(entity, nodes.Node): - raise Exception('Only node logs are stored') + :param pk: id of the Log to delete - return self.find({'dbnode_id': entity.pk}, order_by=order_by) + :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``pk`` is not found + """ + return self._backend.logs.delete(pk) - def delete(self, log_id): - """ - Remove a Log entry from the collection with the given id + def delete_all(self) -> None: + """Delete all Logs in the collection - :param log_id: id of the Log to delete - :type log_id: int + :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted + """ + return self._backend.logs.delete_all() - :raises TypeError: if ``log_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``log_id`` is not found - """ - self._backend.logs.delete(log_id) + def delete_many(self, filters: 'FilterType') -> List[int]: + """Delete Logs based on ``filters`` - def delete_all(self): - """ - Delete all Logs in the collection + :param filters: filters to pass to the QueryBuilder + :return: (former) ``PK`` s of deleted Logs - :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted - """ - self._backend.logs.delete_all() + :raises TypeError: if ``filters`` is not a `dict` + :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty + """ + return self._backend.logs.delete_many(filters) - def delete_many(self, filters): - """ - Delete Logs based on ``filters`` - :param filters: similar to QueryBuilder filter - :type filters: dict +class Log(entities.Entity): + """ + An AiiDA Log entity. Corresponds to a logged message against a particular AiiDA node. + """ - :return: (former) ``PK`` s of deleted Logs - :rtype: list + Collection = LogCollection - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - self._backend.logs.delete_many(filters) + @classproperty + def objects(cls) -> LogCollection: # pylint: disable=no-self-argument + return LogCollection.get_cached(cls, get_manager().get_backend()) def __init__(self, time, loggername, levelname, dbnode_id, message='', metadata=None, backend=None): # pylint: disable=too-many-arguments """Construct a new log diff --git a/aiida/orm/nodes/data/remote/base.py b/aiida/orm/nodes/data/remote/base.py index 8de116de26..27d5773563 100644 --- a/aiida/orm/nodes/data/remote/base.py +++ b/aiida/orm/nodes/data/remote/base.py @@ -177,4 +177,4 @@ def _validate(self): raise ValidationError('Remote computer not set.') def get_authinfo(self): - return AuthInfo.objects.get(dbcomputer=self.computer, aiidauser=self.user) + return AuthInfo.objects(self.backend).get(dbcomputer=self.computer, aiidauser=self.user) diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index ac71a1b668..66ab7b14fa 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -14,7 +14,21 @@ import importlib from logging import Logger import typing -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Generic, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) from uuid import UUID from aiida.common import exceptions @@ -41,7 +55,33 @@ __all__ = ('Node',) -_NO_DEFAULT = tuple() # type: ignore[var-annotated] +NodeType = TypeVar('NodeType', bound='Node') + + +class NodeCollection(EntityCollection[NodeType], Generic[NodeType]): + """The collection of nodes.""" + + @staticmethod + def _entity_base_cls() -> Type['Node']: + return Node + + def delete(self, pk: int) -> None: + """Delete a `Node` from the collection with the given id + + :param pk: the node id + """ + node = self.get(id=pk) + + if not node.is_stored: + return + + if node.get_incoming().all(): + raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has incoming links') + + if node.get_outgoing().all(): + raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has outgoing links') + + self._backend.nodes.delete(pk) class Node(Entity, NodeRepositoryMixin, EntityAttributesMixin, EntityExtrasMixin, metaclass=AbstractNodeMeta): @@ -65,27 +105,6 @@ class Node(Entity, NodeRepositoryMixin, EntityAttributesMixin, EntityExtrasMixin _plugin_type_string: ClassVar[str] _query_type_string: ClassVar[str] - class Collection(EntityCollection): - """The collection of nodes.""" - - def delete(self, node_id: int) -> None: - """Delete a `Node` from the collection with the given id - - :param node_id: the node id - """ - node = self.get(id=node_id) - - if not node.is_stored: - return - - if node.get_incoming().all(): - raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has incoming links') - - if node.get_outgoing().all(): - raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has outgoing links') - - self._backend.nodes.delete(node_id) - # This will be set by the metaclass call _logger: Optional[Logger] = None @@ -106,10 +125,11 @@ def delete(self, node_id: int) -> None: # These are to be initialized in the `initialization` method _incoming_cache: Optional[List[LinkTriple]] = None - @classmethod - def from_backend_entity(cls, backend_entity: 'BackendNode') -> 'Node': - entity = super().from_backend_entity(backend_entity) - return entity + Collection = NodeCollection + + @classproperty + def objects(cls: Type[NodeType]) -> NodeCollection[NodeType]: # pylint: disable=no-self-argument + return NodeCollection.get_cached(cls, get_manager().get_backend()) # type: ignore[arg-type] def __init__( self, @@ -304,11 +324,7 @@ def repository_metadata(self, value): @property def computer(self) -> Optional[Computer]: - """Return the computer of this node. - - :return: the computer or None - :rtype: `Computer` or None - """ + """Return the computer of this node.""" if self.backend_entity.computer: return Computer.from_backend_entity(self.backend_entity.computer) @@ -332,11 +348,7 @@ def computer(self, computer: Optional[Computer]) -> None: @property def user(self) -> User: - """Return the user of this node. - - :return: the user - :rtype: `User` - """ + """Return the user of this node.""" return User.from_backend_entity(self.backend_entity.user) @user.setter @@ -374,7 +386,7 @@ def add_comment(self, content: str, user: Optional[User] = None) -> Comment: :param user: the user to associate with the comment, will use default if not supplied :return: the newly created comment """ - user = user or User.objects.get_default() + user = user or User.objects(self.backend).get_default() return Comment(node=self, user=user, content=content).store() def get_comment(self, identifier: int) -> Comment: @@ -385,14 +397,14 @@ def get_comment(self, identifier: int) -> Comment: :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment :return: the comment """ - return Comment.objects.get(dbnode_id=self.pk, id=identifier) + return Comment.objects(self.backend).get(dbnode_id=self.pk, id=identifier) def get_comments(self) -> List[Comment]: """Return a sorted list of comments for this node. :return: the list of comments, sorted by pk """ - return Comment.objects.find(filters={'dbnode_id': self.pk}, order_by=[{'id': 'asc'}]) + return Comment.objects(self.backend).find(filters={'dbnode_id': self.pk}, order_by=[{'id': 'asc'}]) def update_comment(self, identifier: int, content: str) -> None: """Update the content of an existing comment. @@ -402,7 +414,7 @@ def update_comment(self, identifier: int, content: str) -> None: :raise aiida.common.NotExistent: if the comment with the given id does not exist :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment """ - comment = Comment.objects.get(dbnode_id=self.pk, id=identifier) + comment = Comment.objects(self.backend).get(dbnode_id=self.pk, id=identifier) comment.set_content(content) def remove_comment(self, identifier: int) -> None: # pylint: disable=no-self-use @@ -410,7 +422,7 @@ def remove_comment(self, identifier: int) -> None: # pylint: disable=no-self-us :param identifier: the comment pk """ - Comment.objects.delete(identifier) + Comment.objects(self.backend).delete(identifier) def add_incoming(self, source: 'Node', link_type: LinkType, link_label: str) -> None: """Add a link of the given type from a given node to ourself. diff --git a/aiida/orm/users.py b/aiida/orm/users.py index 4c55bcad6f..d00d889388 100644 --- a/aiida/orm/users.py +++ b/aiida/orm/users.py @@ -8,68 +8,78 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the ORM user class.""" +from typing import TYPE_CHECKING, Optional, Tuple, Type + from aiida.common import exceptions +from aiida.common.lang import classproperty from aiida.manage.manager import get_manager from . import entities +if TYPE_CHECKING: + from aiida.orm.implementation import Backend + __all__ = ('User',) -class User(entities.Entity): - """AiiDA User""" +class UserCollection(entities.Collection['User']): + """The collection of users stored in a backend.""" - class Collection(entities.Collection): - """The collection of users stored in a backend.""" + UNDEFINED = 'UNDEFINED' + _default_user: Optional['User'] = None - UNDEFINED = 'UNDEFINED' - _default_user = None # type: aiida.orm.User + @staticmethod + def _entity_base_cls() -> Type['User']: + return User - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._default_user = self.UNDEFINED + def __init__(self, entity_class: Type['User'], backend: Optional['Backend'] = None) -> None: + super().__init__(entity_class=entity_class, backend=backend) + self._default_user = self.UNDEFINED - def get_or_create(self, email, **kwargs): - """ - Get the existing user with a given email address or create an unstored one + def get_or_create(self, email: str, **kwargs) -> Tuple[bool, 'User']: + """Get the existing user with a given email address or create an unstored one + + :param kwargs: The properties of the user to get or create + :return: The corresponding user object + :raises: :class:`aiida.common.exceptions.MultipleObjectsError`, + :class:`aiida.common.exceptions.NotExistent` + """ + try: + return False, self.get(email=email) + except exceptions.NotExistent: + return True, User(backend=self.backend, email=email, **kwargs) + + def get_default(self) -> 'User': + """Get the current default user""" + if self._default_user is self.UNDEFINED: + from aiida.manage.configuration import get_profile + profile = get_profile() + email = profile.default_user + if not email: + self._default_user = None - :param kwargs: The properties of the user to get or create - :return: The corresponding user object - :rtype: :class:`aiida.orm.User` - :raises: :class:`aiida.common.exceptions.MultipleObjectsError`, - :class:`aiida.common.exceptions.NotExistent` - """ try: - return False, self.get(email=email) - except exceptions.NotExistent: - return True, User(backend=self.backend, email=email, **kwargs) - - def get_default(self): - """ - Get the current default user - - :return: The default user - :rtype: :class:`aiida.orm.User` - """ - if self._default_user is self.UNDEFINED: - from aiida.manage.configuration import get_profile - profile = get_profile() - email = profile.default_user - if not email: - self._default_user = None - - try: - self._default_user = self.get(email=email) - except (exceptions.MultipleObjectsError, exceptions.NotExistent): - self._default_user = None - - return self._default_user - - def reset(self): - """ - Reset internal caches (default user). - """ - self._default_user = self.UNDEFINED + self._default_user = self.get(email=email) + except (exceptions.MultipleObjectsError, exceptions.NotExistent): + self._default_user = None + + return self._default_user + + def reset(self) -> None: + """ + Reset internal caches (default user). + """ + self._default_user = self.UNDEFINED + + +class User(entities.Entity): + """AiiDA User""" + + Collection = UserCollection + + @classproperty + def objects(cls) -> UserCollection: # pylint: disable=no-self-argument + return UserCollection.get_cached(cls, get_manager().get_backend()) REQUIRED_FIELDS = ['first_name', 'last_name', 'institution'] diff --git a/docs/source/howto/data.rst b/docs/source/howto/data.rst index d11cbd62c7..bb152a1fc6 100644 --- a/docs/source/howto/data.rst +++ b/docs/source/howto/data.rst @@ -576,7 +576,7 @@ This includes, notably: * *Node extras*: These can be deleted using :py:meth:`~aiida.orm.entities.EntityExtrasMixin.delete_extra` and :py:meth:`~aiida.orm.entities.EntityExtrasMixin.delete_extra_many` methods. * *Node comments*: These can be removed using :py:meth:`~aiida.orm.nodes.node.Node.remove_comment`. -* *Groups*: These can be deleted using :py:meth:`Group.objects.delete() `. +* *Groups*: These can be deleted using :py:meth:`Group.objects.delete() `. This command will only delete the group, not the nodes contained in the group. Completely deleting an AiiDA profile diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index b4670ae5f3..e03be84f3c 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -21,11 +21,6 @@ py:class builtins.dict # typing py:class AbstractContextManager py:class asyncio.events.AbstractEventLoop -py:class EntityType -py:class EntityTypes -py:class ModelType -py:class TransactionType -py:class ReturnType py:class function py:class IO py:class traceback @@ -48,6 +43,7 @@ py:class aiida.orm.utils.links.LinkQuadruple py:class aiida.tools.importexport.dbexport.ExportReport py:class aiida.tools.importexport.dbexport.ArchiveData py:class aiida.tools.groups.paths.WalkNodeResult +py:meth aiida.orm.groups.GroupCollection.delete py:class Backend py:class BackendEntity @@ -56,16 +52,22 @@ py:class AuthInfo py:class CalcJob py:class CalcJobImporter py:class CalcJobNode +py:class CollectionType py:class Data py:class DbImporter +py:class EntityType +py:class EntityTypes py:class ExitCode py:class File +py:class FilterType py:class FolderData py:class Group py:class JobInfo py:class JobState +py:class ModelType py:class Node py:class Orbital +py:class OrderByType py:class Parser py:class PersistenceError py:class Process @@ -74,14 +76,26 @@ py:class ProcessNode py:class ProcessSpec py:class Port py:class PortNamespace +py:class QueryBuilder py:class Repository +py:class ReturnType py:class Runner py:class Scheduler +py:class SelfType +py:class TransactionType py:class Transport py:class TransportQueue py:class WorkChain py:class WorkChainSpec +py:class AuthInfoCollection +py:class CommentCollection +py:class ComputerCollection +py:class GroupCollection +py:class LogCollection +py:class NodeCollection +py:class UserCollection + py:class kiwipy.communications.Communicator py:class plumpy.process_states.State py:class plumpy.workchains._If diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index 3e52f4d0f7..3e9c5e1397 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -345,7 +345,7 @@ def test_loading_unregistered(): # Removing it as other methods might get a warning instead group_pk = group.pk del group - orm.Group.objects.delete(id=group_pk) + orm.Group.objects.delete(pk=group_pk) @staticmethod def test_querying(): @@ -367,7 +367,7 @@ def test_querying(): # Removing it as other methods might get a warning instead group_pk = group.pk del group - orm.Group.objects.delete(id=group_pk) + orm.Group.objects.delete(pk=group_pk) @staticmethod def test_querying_node_subclasses():