From 351684c340d1ef36cf58a7a0457cf0ea032cb0bb Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 10 Aug 2021 21:23:52 +0200 Subject: [PATCH 01/11] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20REFACTOR:=20Make=20`?= =?UTF-8?q?=5F=5Fall=5F=5F`=20explicit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .pre-commit-config.yaml | 9 + aiida/cmdline/__init__.py | 25 +- aiida/cmdline/params/__init__.py | 35 + aiida/cmdline/params/arguments/__init__.py | 80 +-- aiida/cmdline/params/arguments/main.py | 69 ++ aiida/cmdline/params/options/__init__.py | 659 ++---------------- aiida/cmdline/params/options/main.py | 597 ++++++++++++++++ aiida/cmdline/params/types/__init__.py | 71 +- aiida/cmdline/params/types/calculation.py | 2 + aiida/cmdline/params/types/choice.py | 2 + aiida/cmdline/params/types/code.py | 1 + aiida/cmdline/params/types/computer.py | 1 + aiida/cmdline/params/types/data.py | 1 + aiida/cmdline/params/types/group.py | 1 + aiida/cmdline/params/types/identifier.py | 1 + aiida/cmdline/params/types/multiple.py | 1 + aiida/cmdline/params/types/node.py | 1 + aiida/cmdline/params/types/path.py | 2 + aiida/cmdline/params/types/plugin.py | 1 + aiida/cmdline/params/types/process.py | 1 + aiida/cmdline/params/types/profile.py | 1 + aiida/cmdline/params/types/strings.py | 1 + aiida/cmdline/params/types/test_module.py | 1 + aiida/cmdline/params/types/user.py | 1 + aiida/cmdline/params/types/workflow.py | 1 + aiida/cmdline/utils/__init__.py | 23 + aiida/common/__init__.py | 66 +- aiida/engine/__init__.py | 59 +- aiida/engine/processes/__init__.py | 45 +- aiida/engine/processes/calcjobs/__init__.py | 12 +- aiida/engine/processes/workchains/__init__.py | 24 +- aiida/manage/__init__.py | 56 +- aiida/manage/configuration/__init__.py | 283 +------- aiida/manage/configuration/main.py | 267 +++++++ .../configuration/migrations/__init__.py | 13 +- aiida/manage/database/__init__.py | 14 + aiida/manage/database/integrity/__init__.py | 51 +- aiida/manage/database/integrity/utils.py | 53 ++ aiida/manage/external/__init__.py | 18 + aiida/manage/tests/__init__.py | 501 +------------ aiida/orm/__init__.py | 83 ++- aiida/orm/implementation/__init__.py | 33 +- aiida/orm/implementation/django/__init__.py | 18 + aiida/orm/implementation/sql/__init__.py | 10 + .../orm/implementation/sqlalchemy/__init__.py | 18 + aiida/orm/nodes/__init__.py | 48 +- aiida/orm/nodes/data/__init__.py | 71 +- aiida/orm/nodes/data/array/__init__.py | 26 +- aiida/orm/nodes/data/array/array.py | 3 +- aiida/orm/nodes/data/array/bands.py | 2 + aiida/orm/nodes/data/array/kpoints.py | 2 + aiida/orm/nodes/data/array/projection.py | 2 + aiida/orm/nodes/data/array/trajectory.py | 1 + aiida/orm/nodes/data/array/xy.py | 2 + aiida/orm/nodes/data/remote/__init__.py | 15 +- aiida/orm/nodes/data/remote/stash/__init__.py | 14 +- aiida/orm/nodes/process/__init__.py | 15 +- .../orm/nodes/process/calculation/__init__.py | 16 +- aiida/orm/nodes/process/workflow/__init__.py | 16 +- aiida/orm/utils/__init__.py | 229 +----- aiida/orm/utils/load_funcs.py | 205 ++++++ aiida/parsers/__init__.py | 9 +- aiida/plugins/__init__.py | 22 +- aiida/repository/__init__.py | 15 +- aiida/repository/backend/__init__.py | 12 +- aiida/restapi/__init__.py | 11 + aiida/schedulers/__init__.py | 18 +- aiida/tools/__init__.py | 73 +- aiida/tools/calculations/__init__.py | 9 +- aiida/tools/data/__init__.py | 17 + aiida/tools/data/array/__init__.py | 11 + aiida/tools/data/array/kpoints/__init__.py | 233 +------ aiida/tools/data/array/kpoints/main.py | 241 +++++++ aiida/tools/data/array/kpoints/seekpath.py | 2 - aiida/tools/data/orbital/__init__.py | 12 +- aiida/tools/data/orbital/orbital.py | 2 + aiida/tools/data/orbital/realhydrogen.py | 2 + .../{structure/__init__.py => structure.py} | 0 aiida/tools/graph/__init__.py | 12 +- aiida/tools/groups/__init__.py | 14 +- aiida/tools/importexport/__init__.py | 51 +- aiida/tools/importexport/archive/__init__.py | 30 +- aiida/tools/importexport/common/__init__.py | 22 +- aiida/tools/importexport/dbexport/__init__.py | 609 +--------------- aiida/tools/importexport/dbexport/main.py | 613 ++++++++++++++++ aiida/tools/importexport/dbimport/__init__.py | 81 +-- aiida/tools/importexport/dbimport/main.py | 86 +++ aiida/tools/visualization/__init__.py | 14 +- aiida/transports/__init__.py | 13 +- aiida/transports/plugins/__init__.py | 12 + utils/make_all.py | 142 ++++ 91 files changed, 3615 insertions(+), 2652 deletions(-) create mode 100644 aiida/cmdline/params/arguments/main.py create mode 100644 aiida/cmdline/params/options/main.py create mode 100644 aiida/manage/configuration/main.py create mode 100644 aiida/manage/database/integrity/utils.py create mode 100644 aiida/orm/utils/load_funcs.py create mode 100644 aiida/tools/data/array/kpoints/main.py rename aiida/tools/data/{structure/__init__.py => structure.py} (100%) create mode 100644 aiida/tools/importexport/dbexport/main.py create mode 100644 aiida/tools/importexport/dbimport/main.py create mode 100644 utils/make_all.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d400a83f3a..6c228fe232 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,6 +39,15 @@ repos: hooks: + - id: imports + name: imports + entry: utils/make_all.py + language: python + types: [python] + require_serial: true + pass_filenames: false + files: aiida/.*py + - id: mypy name: mypy entry: mypy diff --git a/aiida/cmdline/__init__.py b/aiida/cmdline/__init__.py index 34a245187e..f4af371243 100644 --- a/aiida/cmdline/__init__.py +++ b/aiida/cmdline/__init__.py @@ -7,16 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """The command line interface of AiiDA.""" -from .params.arguments import * -from .params.options import * -from .params.types import * -from .utils.decorators import * -from .utils.echo import * +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .utils import * __all__ = ( - params.arguments.__all__ + params.options.__all__ + params.types.__all__ + utils.decorators.__all__ + - utils.echo.__all__ + 'dbenv', + 'echo', + 'echo_critical', + 'echo_dictionary', + 'echo_error', + 'echo_highlight', + 'echo_info', + 'echo_success', + 'echo_warning', + 'format_call_graph', + 'only_if_daemon_running', + 'with_dbenv', ) diff --git a/aiida/cmdline/params/__init__.py b/aiida/cmdline/params/__init__.py index 2776a55f97..509af66b4a 100644 --- a/aiida/cmdline/params/__init__.py +++ b/aiida/cmdline/params/__init__.py @@ -7,3 +7,38 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .types import * + +__all__ = ( + 'AbsolutePathParamType', + 'CalculationParamType', + 'CodeParamType', + 'ComputerParamType', + 'ConfigOptionParamType', + 'DataParamType', + 'EmailType', + 'EntryPointType', + 'FileOrUrl', + 'GroupParamType', + 'HostnameType', + 'IdentifierParamType', + 'LabelStringType', + 'LazyChoice', + 'MpirunCommandParamType', + 'MultipleValueParamType', + 'NodeParamType', + 'NonEmptyStringParamType', + 'PathOrUrl', + 'PluginParamType', + 'ProcessParamType', + 'ProfileParamType', + 'ShebangParamType', + 'TestModuleParamType', + 'UserParamType', + 'WorkflowParamType', +) diff --git a/aiida/cmdline/params/arguments/__init__.py b/aiida/cmdline/params/arguments/__init__.py index 71bb8c2544..83f5413ed9 100644 --- a/aiida/cmdline/params/arguments/__init__.py +++ b/aiida/cmdline/params/arguments/__init__.py @@ -10,60 +10,34 @@ # yapf: disable """Module with pre-defined reusable commandline arguments that can be used as `click` decorators.""" -import click +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -from .. import types -from .overridable import OverridableArgument +from .main import * __all__ = ( - 'PROFILE', 'PROFILES', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'DATUM', 'DATA', - 'GROUP', 'GROUPS', 'NODE', 'NODES', 'PROCESS', 'PROCESSES', 'WORKFLOW', 'WORKFLOWS', 'INPUT_FILE', 'OUTPUT_FILE', - 'LABEL', 'USER', 'CONFIG_OPTION' + 'CALCULATION', + 'CALCULATIONS', + 'CODE', + 'CODES', + 'COMPUTER', + 'COMPUTERS', + 'CONFIG_OPTION', + 'DATA', + 'DATUM', + 'GROUP', + 'GROUPS', + 'INPUT_FILE', + 'LABEL', + 'NODE', + 'NODES', + 'OUTPUT_FILE', + 'PROCESS', + 'PROCESSES', + 'PROFILE', + 'PROFILES', + 'USER', + 'WORKFLOW', + 'WORKFLOWS', ) - - -PROFILE = OverridableArgument('profile', type=types.ProfileParamType()) - -PROFILES = OverridableArgument('profiles', type=types.ProfileParamType(), nargs=-1) - -CALCULATION = OverridableArgument('calculation', type=types.CalculationParamType()) - -CALCULATIONS = OverridableArgument('calculations', nargs=-1, type=types.CalculationParamType()) - -CODE = OverridableArgument('code', type=types.CodeParamType()) - -CODES = OverridableArgument('codes', nargs=-1, type=types.CodeParamType()) - -COMPUTER = OverridableArgument('computer', type=types.ComputerParamType()) - -COMPUTERS = OverridableArgument('computers', nargs=-1, type=types.ComputerParamType()) - -DATUM = OverridableArgument('datum', type=types.DataParamType()) - -DATA = OverridableArgument('data', nargs=-1, type=types.DataParamType()) - -GROUP = OverridableArgument('group', type=types.GroupParamType()) - -GROUPS = OverridableArgument('groups', nargs=-1, type=types.GroupParamType()) - -NODE = OverridableArgument('node', type=types.NodeParamType()) - -NODES = OverridableArgument('nodes', nargs=-1, type=types.NodeParamType()) - -PROCESS = OverridableArgument('process', type=types.ProcessParamType()) - -PROCESSES = OverridableArgument('processes', nargs=-1, type=types.ProcessParamType()) - -WORKFLOW = OverridableArgument('workflow', type=types.WorkflowParamType()) - -WORKFLOWS = OverridableArgument('workflows', nargs=-1, type=types.WorkflowParamType()) - -INPUT_FILE = OverridableArgument('input_file', metavar='INPUT_FILE', type=click.Path(exists=True)) - -OUTPUT_FILE = OverridableArgument('output_file', metavar='OUTPUT_FILE', type=click.Path()) - -LABEL = OverridableArgument('label', type=click.STRING) - -USER = OverridableArgument('user', metavar='USER', type=types.UserParamType()) - -CONFIG_OPTION = OverridableArgument('option', type=types.ConfigOptionParamType()) diff --git a/aiida/cmdline/params/arguments/main.py b/aiida/cmdline/params/arguments/main.py new file mode 100644 index 0000000000..71bb8c2544 --- /dev/null +++ b/aiida/cmdline/params/arguments/main.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# yapf: disable +"""Module with pre-defined reusable commandline arguments that can be used as `click` decorators.""" + +import click + +from .. import types +from .overridable import OverridableArgument + +__all__ = ( + 'PROFILE', 'PROFILES', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'DATUM', 'DATA', + 'GROUP', 'GROUPS', 'NODE', 'NODES', 'PROCESS', 'PROCESSES', 'WORKFLOW', 'WORKFLOWS', 'INPUT_FILE', 'OUTPUT_FILE', + 'LABEL', 'USER', 'CONFIG_OPTION' +) + + +PROFILE = OverridableArgument('profile', type=types.ProfileParamType()) + +PROFILES = OverridableArgument('profiles', type=types.ProfileParamType(), nargs=-1) + +CALCULATION = OverridableArgument('calculation', type=types.CalculationParamType()) + +CALCULATIONS = OverridableArgument('calculations', nargs=-1, type=types.CalculationParamType()) + +CODE = OverridableArgument('code', type=types.CodeParamType()) + +CODES = OverridableArgument('codes', nargs=-1, type=types.CodeParamType()) + +COMPUTER = OverridableArgument('computer', type=types.ComputerParamType()) + +COMPUTERS = OverridableArgument('computers', nargs=-1, type=types.ComputerParamType()) + +DATUM = OverridableArgument('datum', type=types.DataParamType()) + +DATA = OverridableArgument('data', nargs=-1, type=types.DataParamType()) + +GROUP = OverridableArgument('group', type=types.GroupParamType()) + +GROUPS = OverridableArgument('groups', nargs=-1, type=types.GroupParamType()) + +NODE = OverridableArgument('node', type=types.NodeParamType()) + +NODES = OverridableArgument('nodes', nargs=-1, type=types.NodeParamType()) + +PROCESS = OverridableArgument('process', type=types.ProcessParamType()) + +PROCESSES = OverridableArgument('processes', nargs=-1, type=types.ProcessParamType()) + +WORKFLOW = OverridableArgument('workflow', type=types.WorkflowParamType()) + +WORKFLOWS = OverridableArgument('workflows', nargs=-1, type=types.WorkflowParamType()) + +INPUT_FILE = OverridableArgument('input_file', metavar='INPUT_FILE', type=click.Path(exists=True)) + +OUTPUT_FILE = OverridableArgument('output_file', metavar='OUTPUT_FILE', type=click.Path()) + +LABEL = OverridableArgument('label', type=click.STRING) + +USER = OverridableArgument('user', metavar='USER', type=types.UserParamType()) + +CONFIG_OPTION = OverridableArgument('option', type=types.ConfigOptionParamType()) diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py index accd78c65f..74d15b39c0 100644 --- a/aiida/cmdline/params/options/__init__.py +++ b/aiida/cmdline/params/options/__init__.py @@ -8,590 +8,83 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with pre-defined reusable commandline options that can be used as `click` decorators.""" -import click -from pgsu import DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module -from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA -from aiida.manage.external.rmq import BROKER_DEFAULTS -from ...utils import defaults, echo -from .. import types -from .multivalue import MultipleValueOption -from .overridable import OverridableOption -from .contextualdefault import ContextualDefaultOption -from .config import ConfigFileOption +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -__all__ = ( - 'graph_traversal_rules', 'PROFILE', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', - 'DATUM', 'DATA', 'GROUP', 'GROUPS', 'NODE', 'NODES', 'FORCE', 'SILENT', 'VISUALIZATION_FORMAT', 'INPUT_FORMAT', - 'EXPORT_FORMAT', 'ARCHIVE_FORMAT', 'NON_INTERACTIVE', 'DRY_RUN', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_LAST_NAME', - 'USER_INSTITUTION', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', 'DB_PORT', 'DB_USERNAME', 'DB_PASSWORD', 'DB_NAME', - 'REPOSITORY_PATH', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PREPEND_TEXT', 'APPEND_TEXT', 'LABEL', - 'DESCRIPTION', 'INPUT_PLUGIN', 'CALC_JOB_STATE', 'PROCESS_STATE', 'PROCESS_LABEL', 'TYPE_STRING', 'EXIT_STATUS', - 'FAILED', 'LIMIT', 'PROJECT', 'ORDER_BY', 'PAST_DAYS', 'OLDER_THAN', 'ALL', 'ALL_STATES', 'ALL_USERS', - 'GROUP_CLEAR', 'RAW', 'HOSTNAME', 'TRANSPORT', 'SCHEDULER', 'USER', 'PORT', 'FREQUENCY', 'VERBOSE', 'TIMEOUT', - 'FORMULA_MODE', 'TRAJECTORY_INDEX', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'DEBUG', 'PRINT_TRACEBACK' -) - -TRAVERSAL_RULE_HELP_STRING = { - 'call_calc_backward': 'CALL links to calculations backwards', - 'call_calc_forward': 'CALL links to calculations forwards', - 'call_work_backward': 'CALL links to workflows backwards', - 'call_work_forward': 'CALL links to workflows forwards', - 'input_calc_backward': 'INPUT links to calculations backwards', - 'input_calc_forward': 'INPUT links to calculations forwards', - 'input_work_backward': 'INPUT links to workflows backwards', - 'input_work_forward': 'INPUT links to workflows forwards', - 'return_backward': 'RETURN links backwards', - 'return_forward': 'RETURN links forwards', - 'create_backward': 'CREATE links backwards', - 'create_forward': 'CREATE links forwards', -} - - -def valid_process_states(): - """Return a list of valid values for the ProcessState enum.""" - from plumpy import ProcessState - return tuple(state.value for state in ProcessState) - - -def valid_calc_job_states(): - """Return a list of valid values for the CalcState enum.""" - from aiida.common.datastructures import CalcJobState - return tuple(state.value for state in CalcJobState) - - -def active_process_states(): - """Return a list of process states that are considered active.""" - from plumpy import ProcessState - return ([ - ProcessState.CREATED.value, - ProcessState.WAITING.value, - ProcessState.RUNNING.value, - ]) - - -def graph_traversal_rules(rules): - """Apply the graph traversal rule options to the command.""" - - def decorator(command): - """Only apply to traversal rules if they are toggleable.""" - for name, traversal_rule in sorted(rules.items(), reverse=True): - if traversal_rule.toggleable: - option_name = name.replace('_', '-') - option_label = '--{option_name}/--no-{option_name}'.format(option_name=option_name) - help_string = f'Whether to expand the node set by following {TRAVERSAL_RULE_HELP_STRING[name]}.' - click.option(option_label, default=traversal_rule.default, show_default=True, help=help_string)(command) - - return command - - return decorator - - -PROFILE = OverridableOption( - '-p', - '--profile', - 'profile', - type=types.ProfileParamType(), - default=defaults.get_default_profile, - help='Execute the command for this profile instead of the default profile.' -) - -CALCULATION = OverridableOption( - '-C', - '--calculation', - 'calculation', - type=types.CalculationParamType(), - help='A single calculation identified by its ID or UUID.' -) - -CALCULATIONS = OverridableOption( - '-C', - '--calculations', - 'calculations', - type=types.CalculationParamType(), - cls=MultipleValueOption, - help='One or multiple calculations identified by their ID or UUID.' -) - -CODE = OverridableOption( - '-X', '--code', 'code', type=types.CodeParamType(), help='A single code identified by its ID, UUID or label.' -) - -CODES = OverridableOption( - '-X', - '--codes', - 'codes', - type=types.CodeParamType(), - cls=MultipleValueOption, - help='One or multiple codes identified by their ID, UUID or label.' -) - -COMPUTER = OverridableOption( - '-Y', - '--computer', - 'computer', - type=types.ComputerParamType(), - help='A single computer identified by its ID, UUID or label.' -) - -COMPUTERS = OverridableOption( - '-Y', - '--computers', - 'computers', - type=types.ComputerParamType(), - cls=MultipleValueOption, - help='One or multiple computers identified by their ID, UUID or label.' -) - -DATUM = OverridableOption( - '-D', '--datum', 'datum', type=types.DataParamType(), help='A single datum identified by its ID, UUID or label.' -) - -DATA = OverridableOption( - '-D', - '--data', - 'data', - type=types.DataParamType(), - cls=MultipleValueOption, - help='One or multiple data identified by their ID, UUID or label.' -) - -GROUP = OverridableOption( - '-G', '--group', 'group', type=types.GroupParamType(), help='A single group identified by its ID, UUID or label.' -) - -GROUPS = OverridableOption( - '-G', - '--groups', - 'groups', - type=types.GroupParamType(), - cls=MultipleValueOption, - help='One or multiple groups identified by their ID, UUID or label.' -) - -NODE = OverridableOption( - '-N', '--node', 'node', type=types.NodeParamType(), help='A single node identified by its ID or UUID.' -) - -NODES = OverridableOption( - '-N', - '--nodes', - 'nodes', - type=types.NodeParamType(), - cls=MultipleValueOption, - help='One or multiple nodes identified by their ID or UUID.' -) - -FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.') +from .main import * -SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.') - -VISUALIZATION_FORMAT = OverridableOption( - '-F', '--format', 'fmt', show_default=True, help='Format of the visualized output.' -) - -INPUT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the input file.') - -EXPORT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the exported file.') - -ARCHIVE_FORMAT = OverridableOption( - '-F', - '--archive-format', - type=click.Choice(['zip', 'zip-uncompressed', 'tar.gz']), - default='zip', - show_default=True, - help='The format of the archive file.' -) - -NON_INTERACTIVE = OverridableOption( - '-n', - '--non-interactive', - is_flag=True, - is_eager=True, - help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.' -) - -DRY_RUN = OverridableOption('-n', '--dry-run', is_flag=True, help='Perform a dry run.') - -USER_EMAIL = OverridableOption( - '--email', - 'email', - type=types.EmailType(), - help='Email address associated with the data you generate. The email address is exported along with the data, ' - 'when sharing it.' -) - -USER_FIRST_NAME = OverridableOption( - '--first-name', type=types.NonEmptyStringParamType(), help='First name of the user.' -) - -USER_LAST_NAME = OverridableOption('--last-name', type=types.NonEmptyStringParamType(), help='Last name of the user.') - -USER_INSTITUTION = OverridableOption( - '--institution', type=types.NonEmptyStringParamType(), help='Institution of the user.' -) - -DB_ENGINE = OverridableOption( - '--db-engine', - help='Engine to use to connect to the database.', - default='postgresql_psycopg2', - type=click.Choice(['postgresql_psycopg2']) -) - -DB_BACKEND = OverridableOption( - '--db-backend', - type=click.Choice([BACKEND_DJANGO, BACKEND_SQLA]), - default=BACKEND_DJANGO, - help='Database backend to use.' -) - -DB_HOST = OverridableOption( - '--db-host', - type=types.HostnameType(), - help='Database server host. Leave empty for "peer" authentication.', - default='localhost' -) - -DB_PORT = OverridableOption( - '--db-port', - type=click.INT, - help='Database server port.', - default=DEFAULT_DBINFO['port'], -) - -DB_USERNAME = OverridableOption( - '--db-username', type=types.NonEmptyStringParamType(), help='Name of the database user.' -) - -DB_PASSWORD = OverridableOption( - '--db-password', - type=click.STRING, - help='Password of the database user.', - hide_input=True, -) - -DB_NAME = OverridableOption('--db-name', type=types.NonEmptyStringParamType(), help='Database name.') - -BROKER_PROTOCOL = OverridableOption( - '--broker-protocol', - type=click.Choice(('amqp', 'amqps')), - default=BROKER_DEFAULTS.protocol, - show_default=True, - help='Protocol to use for the message broker.' -) - -BROKER_USERNAME = OverridableOption( - '--broker-username', - type=types.NonEmptyStringParamType(), - default=BROKER_DEFAULTS.username, - show_default=True, - help='Username to use for authentication with the message broker.' -) - -BROKER_PASSWORD = OverridableOption( - '--broker-password', - type=types.NonEmptyStringParamType(), - default=BROKER_DEFAULTS.password, - show_default=True, - help='Password to use for authentication with the message broker.', - hide_input=True, -) - -BROKER_HOST = OverridableOption( - '--broker-host', - type=types.HostnameType(), - default=BROKER_DEFAULTS.host, - show_default=True, - help='Hostname for the message broker.' -) - -BROKER_PORT = OverridableOption( - '--broker-port', - type=click.INT, - default=BROKER_DEFAULTS.port, - show_default=True, - help='Port for the message broker.', -) - -BROKER_VIRTUAL_HOST = OverridableOption( - '--broker-virtual-host', - type=click.types.StringParamType(), - default=BROKER_DEFAULTS.virtual_host, - show_default=True, - help='Name of the virtual host for the message broker without leading forward slash.' -) - -REPOSITORY_PATH = OverridableOption( - '--repository', type=click.Path(file_okay=False), help='Absolute path to the file repository.' -) - -PROFILE_ONLY_CONFIG = OverridableOption( - '--only-config', is_flag=True, default=False, help='Only configure the user and skip creating the database.' -) - -PROFILE_SET_DEFAULT = OverridableOption( - '--set-default', is_flag=True, default=False, help='Set the profile as the new default.' -) - -PREPEND_TEXT = OverridableOption( - '--prepend-text', type=click.STRING, default='', help='Bash script to be executed before an action.' -) - -APPEND_TEXT = OverridableOption( - '--append-text', type=click.STRING, default='', help='Bash script to be executed after an action has completed.' -) - -LABEL = OverridableOption('-L', '--label', type=click.STRING, metavar='LABEL', help='Short name to be used as a label.') - -DESCRIPTION = OverridableOption( - '-D', - '--description', - type=click.STRING, - metavar='DESCRIPTION', - default='', - required=False, - help='A detailed description.' -) - -INPUT_PLUGIN = OverridableOption( - '-P', '--input-plugin', type=types.PluginParamType(group='calculations'), help='Calculation input plugin string.' -) - -CALC_JOB_STATE = OverridableOption( - '-s', - '--calc-job-state', - 'calc_job_state', - type=types.LazyChoice(valid_calc_job_states), - cls=MultipleValueOption, - help='Only include entries with this calculation job state.' -) - -PROCESS_STATE = OverridableOption( - '-S', - '--process-state', - 'process_state', - type=types.LazyChoice(valid_process_states), - cls=MultipleValueOption, - default=active_process_states, - help='Only include entries with this process state.' -) - -PAUSED = OverridableOption('--paused', 'paused', is_flag=True, help='Only include entries that are paused.') - -PROCESS_LABEL = OverridableOption( - '-L', - '--process-label', - 'process_label', - type=click.STRING, - required=False, - help='Only include entries whose process label matches this filter.' -) - -TYPE_STRING = OverridableOption( - '-T', - '--type-string', - 'type_string', - type=click.STRING, - required=False, - help='Only include entries whose type string matches this filter. Can include `_` to match a single arbitrary ' - 'character or `%` to match any number of characters.' -) - -EXIT_STATUS = OverridableOption( - '-E', '--exit-status', 'exit_status', type=click.INT, help='Only include entries with this exit status.' -) - -FAILED = OverridableOption( - '-X', '--failed', 'failed', is_flag=True, default=False, help='Only include entries that have failed.' -) - -LIMIT = OverridableOption( - '-l', '--limit', 'limit', type=click.INT, default=None, help='Limit the number of entries to display.' -) - -PROJECT = OverridableOption( - '-P', '--project', 'project', cls=MultipleValueOption, help='Select the list of entity attributes to project.' -) - -ORDER_BY = OverridableOption( - '-O', - '--order-by', - 'order_by', - type=click.Choice(['id', 'ctime']), - default='ctime', - show_default=True, - help='Order the entries by this attribute.' -) - -ORDER_DIRECTION = OverridableOption( - '-D', - '--order-direction', - 'order_dir', - type=click.Choice(['asc', 'desc']), - default='asc', - show_default=True, - help='List the entries in ascending or descending order' -) - -PAST_DAYS = OverridableOption( - '-p', - '--past-days', - 'past_days', - type=click.INT, - metavar='PAST_DAYS', - help='Only include entries created in the last PAST_DAYS number of days.' -) - -OLDER_THAN = OverridableOption( - '-o', - '--older-than', - 'older_than', - type=click.INT, - metavar='OLDER_THAN', - help='Only include entries created before OLDER_THAN days ago.' -) - -ALL = OverridableOption( - '-a', - '--all', - 'all_entries', - is_flag=True, - default=False, - help='Include all entries, disregarding all other filter options and flags.' -) - -ALL_STATES = OverridableOption('-A', '--all-states', is_flag=True, help='Do not limit to items in running state.') - -ALL_USERS = OverridableOption( - '-A', '--all-users', 'all_users', is_flag=True, default=False, help='Include all entries regardless of the owner.' -) - -GROUP_CLEAR = OverridableOption( - '-c', '--clear', is_flag=True, default=False, help='Remove all the nodes from the group.' -) - -RAW = OverridableOption( - '-r', - '--raw', - 'raw', - is_flag=True, - default=False, - help='Display only raw query results, without any headers or footers.' -) - -HOSTNAME = OverridableOption('-H', '--hostname', type=types.HostnameType(), help='Hostname.') - -TRANSPORT = OverridableOption( - '-T', - '--transport', - type=types.PluginParamType(group='transports'), - required=True, - help="A transport plugin (as listed in 'verdi plugin list aiida.transports')." -) - -SCHEDULER = OverridableOption( - '-S', - '--scheduler', - type=types.PluginParamType(group='schedulers'), - required=True, - help="A scheduler plugin (as listed in 'verdi plugin list aiida.schedulers')." -) - -USER = OverridableOption('-u', '--user', 'user', type=types.UserParamType(), help='Email address of the user.') - -PORT = OverridableOption('-P', '--port', 'port', type=click.INT, help='Port number.') - -FREQUENCY = OverridableOption('-F', '--frequency', 'frequency', type=click.INT) - -VERBOSE = OverridableOption('-v', '--verbose', is_flag=True, default=False, help='Be more verbose in printing output.') - -TIMEOUT = OverridableOption( - '-t', - '--timeout', - type=click.FLOAT, - default=5.0, - show_default=True, - help='Time in seconds to wait for a response before timing out.' -) - -WAIT = OverridableOption( - '--wait/--no-wait', - default=False, - help='Wait for the action to be completed otherwise return as soon as it is scheduled.' -) - -FORMULA_MODE = OverridableOption( - '-f', - '--formula-mode', - type=click.Choice(['hill', 'hill_compact', 'reduce', 'group', 'count', 'count_compact']), - default='hill', - help='Mode for printing the chemical formula.' -) - -TRAJECTORY_INDEX = OverridableOption( - '-i', - '--trajectory-index', - 'trajectory_index', - type=click.INT, - default=None, - help='Specific step of the Trajectory to select.' -) - -WITH_ELEMENTS = OverridableOption( - '-e', - '--with-elements', - 'elements', - type=click.STRING, - cls=MultipleValueOption, - default=None, - help='Only select objects containing these elements.' -) - -WITH_ELEMENTS_EXCLUSIVE = OverridableOption( - '-E', - '--with-elements-exclusive', - 'elements_exclusive', - type=click.STRING, - cls=MultipleValueOption, - default=None, - help='Only select objects containing only these and no other elements.' -) - -CONFIG_FILE = ConfigFileOption( - '--config', - type=types.FileOrUrl(), - help='Load option values from configuration file in yaml format (local path or URL).' -) - -IDENTIFIER = OverridableOption( - '-i', - '--identifier', - 'identifier', - help='The type of identifier used for specifying each node.', - default='pk', - type=click.Choice(['pk', 'uuid']) -) - -DICT_FORMAT = OverridableOption( - '-f', - '--format', - 'fmt', - type=click.Choice(list(echo.VALID_DICT_FORMATS_MAPPING.keys())), - default=list(echo.VALID_DICT_FORMATS_MAPPING.keys())[0], - help='The format of the output data.' -) - -DICT_KEYS = OverridableOption( - '-k', '--keys', type=click.STRING, cls=MultipleValueOption, help='Filter the output by one or more keys.' -) - -DEBUG = OverridableOption( - '--debug', is_flag=True, default=False, help='Show debug messages. Mostly relevant for developers.', hidden=True -) - -PRINT_TRACEBACK = OverridableOption( - '-t', - '--print-traceback', - is_flag=True, - help='Print the full traceback in case an exception is raised.', +__all__ = ( + 'ALL', + 'ALL_STATES', + 'ALL_USERS', + 'APPEND_TEXT', + 'ARCHIVE_FORMAT', + 'CALCULATION', + 'CALCULATIONS', + 'CALC_JOB_STATE', + 'CODE', + 'CODES', + 'COMPUTER', + 'COMPUTERS', + 'DATA', + 'DATUM', + 'DB_BACKEND', + 'DB_ENGINE', + 'DB_HOST', + 'DB_NAME', + 'DB_PASSWORD', + 'DB_PORT', + 'DB_USERNAME', + 'DEBUG', + 'DESCRIPTION', + 'DRY_RUN', + 'EXIT_STATUS', + 'EXPORT_FORMAT', + 'FAILED', + 'FORCE', + 'FORMULA_MODE', + 'FREQUENCY', + 'GROUP', + 'GROUPS', + 'GROUP_CLEAR', + 'HOSTNAME', + 'INPUT_FORMAT', + 'INPUT_PLUGIN', + 'LABEL', + 'LIMIT', + 'NODE', + 'NODES', + 'NON_INTERACTIVE', + 'OLDER_THAN', + 'ORDER_BY', + 'PAST_DAYS', + 'PORT', + 'PREPEND_TEXT', + 'PRINT_TRACEBACK', + 'PROCESS_LABEL', + 'PROCESS_STATE', + 'PROFILE', + 'PROFILE_ONLY_CONFIG', + 'PROFILE_SET_DEFAULT', + 'PROJECT', + 'RAW', + 'REPOSITORY_PATH', + 'SCHEDULER', + 'SILENT', + 'TIMEOUT', + 'TRAJECTORY_INDEX', + 'TRANSPORT', + 'TYPE_STRING', + 'USER', + 'USER_EMAIL', + 'USER_FIRST_NAME', + 'USER_INSTITUTION', + 'USER_LAST_NAME', + 'VERBOSE', + 'VISUALIZATION_FORMAT', + 'WITH_ELEMENTS', + 'WITH_ELEMENTS_EXCLUSIVE', + 'graph_traversal_rules', ) diff --git a/aiida/cmdline/params/options/main.py b/aiida/cmdline/params/options/main.py new file mode 100644 index 0000000000..accd78c65f --- /dev/null +++ b/aiida/cmdline/params/options/main.py @@ -0,0 +1,597 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Module with pre-defined reusable commandline options that can be used as `click` decorators.""" +import click +from pgsu import DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module + +from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA +from aiida.manage.external.rmq import BROKER_DEFAULTS +from ...utils import defaults, echo +from .. import types +from .multivalue import MultipleValueOption +from .overridable import OverridableOption +from .contextualdefault import ContextualDefaultOption +from .config import ConfigFileOption + +__all__ = ( + 'graph_traversal_rules', 'PROFILE', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', + 'DATUM', 'DATA', 'GROUP', 'GROUPS', 'NODE', 'NODES', 'FORCE', 'SILENT', 'VISUALIZATION_FORMAT', 'INPUT_FORMAT', + 'EXPORT_FORMAT', 'ARCHIVE_FORMAT', 'NON_INTERACTIVE', 'DRY_RUN', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_LAST_NAME', + 'USER_INSTITUTION', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', 'DB_PORT', 'DB_USERNAME', 'DB_PASSWORD', 'DB_NAME', + 'REPOSITORY_PATH', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PREPEND_TEXT', 'APPEND_TEXT', 'LABEL', + 'DESCRIPTION', 'INPUT_PLUGIN', 'CALC_JOB_STATE', 'PROCESS_STATE', 'PROCESS_LABEL', 'TYPE_STRING', 'EXIT_STATUS', + 'FAILED', 'LIMIT', 'PROJECT', 'ORDER_BY', 'PAST_DAYS', 'OLDER_THAN', 'ALL', 'ALL_STATES', 'ALL_USERS', + 'GROUP_CLEAR', 'RAW', 'HOSTNAME', 'TRANSPORT', 'SCHEDULER', 'USER', 'PORT', 'FREQUENCY', 'VERBOSE', 'TIMEOUT', + 'FORMULA_MODE', 'TRAJECTORY_INDEX', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'DEBUG', 'PRINT_TRACEBACK' +) + +TRAVERSAL_RULE_HELP_STRING = { + 'call_calc_backward': 'CALL links to calculations backwards', + 'call_calc_forward': 'CALL links to calculations forwards', + 'call_work_backward': 'CALL links to workflows backwards', + 'call_work_forward': 'CALL links to workflows forwards', + 'input_calc_backward': 'INPUT links to calculations backwards', + 'input_calc_forward': 'INPUT links to calculations forwards', + 'input_work_backward': 'INPUT links to workflows backwards', + 'input_work_forward': 'INPUT links to workflows forwards', + 'return_backward': 'RETURN links backwards', + 'return_forward': 'RETURN links forwards', + 'create_backward': 'CREATE links backwards', + 'create_forward': 'CREATE links forwards', +} + + +def valid_process_states(): + """Return a list of valid values for the ProcessState enum.""" + from plumpy import ProcessState + return tuple(state.value for state in ProcessState) + + +def valid_calc_job_states(): + """Return a list of valid values for the CalcState enum.""" + from aiida.common.datastructures import CalcJobState + return tuple(state.value for state in CalcJobState) + + +def active_process_states(): + """Return a list of process states that are considered active.""" + from plumpy import ProcessState + return ([ + ProcessState.CREATED.value, + ProcessState.WAITING.value, + ProcessState.RUNNING.value, + ]) + + +def graph_traversal_rules(rules): + """Apply the graph traversal rule options to the command.""" + + def decorator(command): + """Only apply to traversal rules if they are toggleable.""" + for name, traversal_rule in sorted(rules.items(), reverse=True): + if traversal_rule.toggleable: + option_name = name.replace('_', '-') + option_label = '--{option_name}/--no-{option_name}'.format(option_name=option_name) + help_string = f'Whether to expand the node set by following {TRAVERSAL_RULE_HELP_STRING[name]}.' + click.option(option_label, default=traversal_rule.default, show_default=True, help=help_string)(command) + + return command + + return decorator + + +PROFILE = OverridableOption( + '-p', + '--profile', + 'profile', + type=types.ProfileParamType(), + default=defaults.get_default_profile, + help='Execute the command for this profile instead of the default profile.' +) + +CALCULATION = OverridableOption( + '-C', + '--calculation', + 'calculation', + type=types.CalculationParamType(), + help='A single calculation identified by its ID or UUID.' +) + +CALCULATIONS = OverridableOption( + '-C', + '--calculations', + 'calculations', + type=types.CalculationParamType(), + cls=MultipleValueOption, + help='One or multiple calculations identified by their ID or UUID.' +) + +CODE = OverridableOption( + '-X', '--code', 'code', type=types.CodeParamType(), help='A single code identified by its ID, UUID or label.' +) + +CODES = OverridableOption( + '-X', + '--codes', + 'codes', + type=types.CodeParamType(), + cls=MultipleValueOption, + help='One or multiple codes identified by their ID, UUID or label.' +) + +COMPUTER = OverridableOption( + '-Y', + '--computer', + 'computer', + type=types.ComputerParamType(), + help='A single computer identified by its ID, UUID or label.' +) + +COMPUTERS = OverridableOption( + '-Y', + '--computers', + 'computers', + type=types.ComputerParamType(), + cls=MultipleValueOption, + help='One or multiple computers identified by their ID, UUID or label.' +) + +DATUM = OverridableOption( + '-D', '--datum', 'datum', type=types.DataParamType(), help='A single datum identified by its ID, UUID or label.' +) + +DATA = OverridableOption( + '-D', + '--data', + 'data', + type=types.DataParamType(), + cls=MultipleValueOption, + help='One or multiple data identified by their ID, UUID or label.' +) + +GROUP = OverridableOption( + '-G', '--group', 'group', type=types.GroupParamType(), help='A single group identified by its ID, UUID or label.' +) + +GROUPS = OverridableOption( + '-G', + '--groups', + 'groups', + type=types.GroupParamType(), + cls=MultipleValueOption, + help='One or multiple groups identified by their ID, UUID or label.' +) + +NODE = OverridableOption( + '-N', '--node', 'node', type=types.NodeParamType(), help='A single node identified by its ID or UUID.' +) + +NODES = OverridableOption( + '-N', + '--nodes', + 'nodes', + type=types.NodeParamType(), + cls=MultipleValueOption, + help='One or multiple nodes identified by their ID or UUID.' +) + +FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.') + +SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.') + +VISUALIZATION_FORMAT = OverridableOption( + '-F', '--format', 'fmt', show_default=True, help='Format of the visualized output.' +) + +INPUT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the input file.') + +EXPORT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the exported file.') + +ARCHIVE_FORMAT = OverridableOption( + '-F', + '--archive-format', + type=click.Choice(['zip', 'zip-uncompressed', 'tar.gz']), + default='zip', + show_default=True, + help='The format of the archive file.' +) + +NON_INTERACTIVE = OverridableOption( + '-n', + '--non-interactive', + is_flag=True, + is_eager=True, + help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.' +) + +DRY_RUN = OverridableOption('-n', '--dry-run', is_flag=True, help='Perform a dry run.') + +USER_EMAIL = OverridableOption( + '--email', + 'email', + type=types.EmailType(), + help='Email address associated with the data you generate. The email address is exported along with the data, ' + 'when sharing it.' +) + +USER_FIRST_NAME = OverridableOption( + '--first-name', type=types.NonEmptyStringParamType(), help='First name of the user.' +) + +USER_LAST_NAME = OverridableOption('--last-name', type=types.NonEmptyStringParamType(), help='Last name of the user.') + +USER_INSTITUTION = OverridableOption( + '--institution', type=types.NonEmptyStringParamType(), help='Institution of the user.' +) + +DB_ENGINE = OverridableOption( + '--db-engine', + help='Engine to use to connect to the database.', + default='postgresql_psycopg2', + type=click.Choice(['postgresql_psycopg2']) +) + +DB_BACKEND = OverridableOption( + '--db-backend', + type=click.Choice([BACKEND_DJANGO, BACKEND_SQLA]), + default=BACKEND_DJANGO, + help='Database backend to use.' +) + +DB_HOST = OverridableOption( + '--db-host', + type=types.HostnameType(), + help='Database server host. Leave empty for "peer" authentication.', + default='localhost' +) + +DB_PORT = OverridableOption( + '--db-port', + type=click.INT, + help='Database server port.', + default=DEFAULT_DBINFO['port'], +) + +DB_USERNAME = OverridableOption( + '--db-username', type=types.NonEmptyStringParamType(), help='Name of the database user.' +) + +DB_PASSWORD = OverridableOption( + '--db-password', + type=click.STRING, + help='Password of the database user.', + hide_input=True, +) + +DB_NAME = OverridableOption('--db-name', type=types.NonEmptyStringParamType(), help='Database name.') + +BROKER_PROTOCOL = OverridableOption( + '--broker-protocol', + type=click.Choice(('amqp', 'amqps')), + default=BROKER_DEFAULTS.protocol, + show_default=True, + help='Protocol to use for the message broker.' +) + +BROKER_USERNAME = OverridableOption( + '--broker-username', + type=types.NonEmptyStringParamType(), + default=BROKER_DEFAULTS.username, + show_default=True, + help='Username to use for authentication with the message broker.' +) + +BROKER_PASSWORD = OverridableOption( + '--broker-password', + type=types.NonEmptyStringParamType(), + default=BROKER_DEFAULTS.password, + show_default=True, + help='Password to use for authentication with the message broker.', + hide_input=True, +) + +BROKER_HOST = OverridableOption( + '--broker-host', + type=types.HostnameType(), + default=BROKER_DEFAULTS.host, + show_default=True, + help='Hostname for the message broker.' +) + +BROKER_PORT = OverridableOption( + '--broker-port', + type=click.INT, + default=BROKER_DEFAULTS.port, + show_default=True, + help='Port for the message broker.', +) + +BROKER_VIRTUAL_HOST = OverridableOption( + '--broker-virtual-host', + type=click.types.StringParamType(), + default=BROKER_DEFAULTS.virtual_host, + show_default=True, + help='Name of the virtual host for the message broker without leading forward slash.' +) + +REPOSITORY_PATH = OverridableOption( + '--repository', type=click.Path(file_okay=False), help='Absolute path to the file repository.' +) + +PROFILE_ONLY_CONFIG = OverridableOption( + '--only-config', is_flag=True, default=False, help='Only configure the user and skip creating the database.' +) + +PROFILE_SET_DEFAULT = OverridableOption( + '--set-default', is_flag=True, default=False, help='Set the profile as the new default.' +) + +PREPEND_TEXT = OverridableOption( + '--prepend-text', type=click.STRING, default='', help='Bash script to be executed before an action.' +) + +APPEND_TEXT = OverridableOption( + '--append-text', type=click.STRING, default='', help='Bash script to be executed after an action has completed.' +) + +LABEL = OverridableOption('-L', '--label', type=click.STRING, metavar='LABEL', help='Short name to be used as a label.') + +DESCRIPTION = OverridableOption( + '-D', + '--description', + type=click.STRING, + metavar='DESCRIPTION', + default='', + required=False, + help='A detailed description.' +) + +INPUT_PLUGIN = OverridableOption( + '-P', '--input-plugin', type=types.PluginParamType(group='calculations'), help='Calculation input plugin string.' +) + +CALC_JOB_STATE = OverridableOption( + '-s', + '--calc-job-state', + 'calc_job_state', + type=types.LazyChoice(valid_calc_job_states), + cls=MultipleValueOption, + help='Only include entries with this calculation job state.' +) + +PROCESS_STATE = OverridableOption( + '-S', + '--process-state', + 'process_state', + type=types.LazyChoice(valid_process_states), + cls=MultipleValueOption, + default=active_process_states, + help='Only include entries with this process state.' +) + +PAUSED = OverridableOption('--paused', 'paused', is_flag=True, help='Only include entries that are paused.') + +PROCESS_LABEL = OverridableOption( + '-L', + '--process-label', + 'process_label', + type=click.STRING, + required=False, + help='Only include entries whose process label matches this filter.' +) + +TYPE_STRING = OverridableOption( + '-T', + '--type-string', + 'type_string', + type=click.STRING, + required=False, + help='Only include entries whose type string matches this filter. Can include `_` to match a single arbitrary ' + 'character or `%` to match any number of characters.' +) + +EXIT_STATUS = OverridableOption( + '-E', '--exit-status', 'exit_status', type=click.INT, help='Only include entries with this exit status.' +) + +FAILED = OverridableOption( + '-X', '--failed', 'failed', is_flag=True, default=False, help='Only include entries that have failed.' +) + +LIMIT = OverridableOption( + '-l', '--limit', 'limit', type=click.INT, default=None, help='Limit the number of entries to display.' +) + +PROJECT = OverridableOption( + '-P', '--project', 'project', cls=MultipleValueOption, help='Select the list of entity attributes to project.' +) + +ORDER_BY = OverridableOption( + '-O', + '--order-by', + 'order_by', + type=click.Choice(['id', 'ctime']), + default='ctime', + show_default=True, + help='Order the entries by this attribute.' +) + +ORDER_DIRECTION = OverridableOption( + '-D', + '--order-direction', + 'order_dir', + type=click.Choice(['asc', 'desc']), + default='asc', + show_default=True, + help='List the entries in ascending or descending order' +) + +PAST_DAYS = OverridableOption( + '-p', + '--past-days', + 'past_days', + type=click.INT, + metavar='PAST_DAYS', + help='Only include entries created in the last PAST_DAYS number of days.' +) + +OLDER_THAN = OverridableOption( + '-o', + '--older-than', + 'older_than', + type=click.INT, + metavar='OLDER_THAN', + help='Only include entries created before OLDER_THAN days ago.' +) + +ALL = OverridableOption( + '-a', + '--all', + 'all_entries', + is_flag=True, + default=False, + help='Include all entries, disregarding all other filter options and flags.' +) + +ALL_STATES = OverridableOption('-A', '--all-states', is_flag=True, help='Do not limit to items in running state.') + +ALL_USERS = OverridableOption( + '-A', '--all-users', 'all_users', is_flag=True, default=False, help='Include all entries regardless of the owner.' +) + +GROUP_CLEAR = OverridableOption( + '-c', '--clear', is_flag=True, default=False, help='Remove all the nodes from the group.' +) + +RAW = OverridableOption( + '-r', + '--raw', + 'raw', + is_flag=True, + default=False, + help='Display only raw query results, without any headers or footers.' +) + +HOSTNAME = OverridableOption('-H', '--hostname', type=types.HostnameType(), help='Hostname.') + +TRANSPORT = OverridableOption( + '-T', + '--transport', + type=types.PluginParamType(group='transports'), + required=True, + help="A transport plugin (as listed in 'verdi plugin list aiida.transports')." +) + +SCHEDULER = OverridableOption( + '-S', + '--scheduler', + type=types.PluginParamType(group='schedulers'), + required=True, + help="A scheduler plugin (as listed in 'verdi plugin list aiida.schedulers')." +) + +USER = OverridableOption('-u', '--user', 'user', type=types.UserParamType(), help='Email address of the user.') + +PORT = OverridableOption('-P', '--port', 'port', type=click.INT, help='Port number.') + +FREQUENCY = OverridableOption('-F', '--frequency', 'frequency', type=click.INT) + +VERBOSE = OverridableOption('-v', '--verbose', is_flag=True, default=False, help='Be more verbose in printing output.') + +TIMEOUT = OverridableOption( + '-t', + '--timeout', + type=click.FLOAT, + default=5.0, + show_default=True, + help='Time in seconds to wait for a response before timing out.' +) + +WAIT = OverridableOption( + '--wait/--no-wait', + default=False, + help='Wait for the action to be completed otherwise return as soon as it is scheduled.' +) + +FORMULA_MODE = OverridableOption( + '-f', + '--formula-mode', + type=click.Choice(['hill', 'hill_compact', 'reduce', 'group', 'count', 'count_compact']), + default='hill', + help='Mode for printing the chemical formula.' +) + +TRAJECTORY_INDEX = OverridableOption( + '-i', + '--trajectory-index', + 'trajectory_index', + type=click.INT, + default=None, + help='Specific step of the Trajectory to select.' +) + +WITH_ELEMENTS = OverridableOption( + '-e', + '--with-elements', + 'elements', + type=click.STRING, + cls=MultipleValueOption, + default=None, + help='Only select objects containing these elements.' +) + +WITH_ELEMENTS_EXCLUSIVE = OverridableOption( + '-E', + '--with-elements-exclusive', + 'elements_exclusive', + type=click.STRING, + cls=MultipleValueOption, + default=None, + help='Only select objects containing only these and no other elements.' +) + +CONFIG_FILE = ConfigFileOption( + '--config', + type=types.FileOrUrl(), + help='Load option values from configuration file in yaml format (local path or URL).' +) + +IDENTIFIER = OverridableOption( + '-i', + '--identifier', + 'identifier', + help='The type of identifier used for specifying each node.', + default='pk', + type=click.Choice(['pk', 'uuid']) +) + +DICT_FORMAT = OverridableOption( + '-f', + '--format', + 'fmt', + type=click.Choice(list(echo.VALID_DICT_FORMATS_MAPPING.keys())), + default=list(echo.VALID_DICT_FORMATS_MAPPING.keys())[0], + help='The format of the output data.' +) + +DICT_KEYS = OverridableOption( + '-k', '--keys', type=click.STRING, cls=MultipleValueOption, help='Filter the output by one or more keys.' +) + +DEBUG = OverridableOption( + '--debug', is_flag=True, default=False, help='Show debug messages. Mostly relevant for developers.', hidden=True +) + +PRINT_TRACEBACK = OverridableOption( + '-t', + '--print-traceback', + is_flag=True, + help='Print the full traceback in case an exception is raised.', +) diff --git a/aiida/cmdline/params/types/__init__.py b/aiida/cmdline/params/types/__init__.py index cedb380572..9ee58ad940 100644 --- a/aiida/cmdline/params/types/__init__.py +++ b/aiida/cmdline/params/types/__init__.py @@ -9,29 +9,54 @@ ########################################################################### """Provides all parameter types.""" -from .calculation import CalculationParamType -from .choice import LazyChoice -from .code import CodeParamType -from .computer import ComputerParamType, ShebangParamType, MpirunCommandParamType -from .config import ConfigOptionParamType -from .data import DataParamType -from .group import GroupParamType -from .identifier import IdentifierParamType -from .multiple import MultipleValueParamType -from .node import NodeParamType -from .process import ProcessParamType -from .strings import (NonEmptyStringParamType, EmailType, HostnameType, EntryPointType, LabelStringType) -from .path import AbsolutePathParamType, PathOrUrl, FileOrUrl -from .plugin import PluginParamType -from .profile import ProfileParamType -from .user import UserParamType -from .test_module import TestModuleParamType -from .workflow import WorkflowParamType +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .calculation import * +from .choice import * +from .code import * +from .computer import * +from .config import * +from .data import * +from .group import * +from .identifier import * +from .multiple import * +from .node import * +from .path import * +from .plugin import * +from .process import * +from .profile import * +from .strings import * +from .test_module import * +from .user import * +from .workflow import * __all__ = ( - 'LazyChoice', 'IdentifierParamType', 'CalculationParamType', 'CodeParamType', 'ComputerParamType', - 'ConfigOptionParamType', 'DataParamType', 'GroupParamType', 'NodeParamType', 'MpirunCommandParamType', - 'MultipleValueParamType', 'NonEmptyStringParamType', 'PluginParamType', 'AbsolutePathParamType', 'ShebangParamType', - 'UserParamType', 'TestModuleParamType', 'ProfileParamType', 'WorkflowParamType', 'ProcessParamType', 'PathOrUrl', - 'FileOrUrl' + 'AbsolutePathParamType', + 'CalculationParamType', + 'CodeParamType', + 'ComputerParamType', + 'ConfigOptionParamType', + 'DataParamType', + 'EmailType', + 'EntryPointType', + 'FileOrUrl', + 'GroupParamType', + 'HostnameType', + 'IdentifierParamType', + 'LabelStringType', + 'LazyChoice', + 'MpirunCommandParamType', + 'MultipleValueParamType', + 'NodeParamType', + 'NonEmptyStringParamType', + 'PathOrUrl', + 'PluginParamType', + 'ProcessParamType', + 'ProfileParamType', + 'ShebangParamType', + 'TestModuleParamType', + 'UserParamType', + 'WorkflowParamType', ) diff --git a/aiida/cmdline/params/types/calculation.py b/aiida/cmdline/params/types/calculation.py index a9dd484b4f..2e4c0d0750 100644 --- a/aiida/cmdline/params/types/calculation.py +++ b/aiida/cmdline/params/types/calculation.py @@ -13,6 +13,8 @@ from .identifier import IdentifierParamType +__all__ = ('CalculationParamType',) + class CalculationParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/choice.py b/aiida/cmdline/params/types/choice.py index 92d5894eb3..827b19020a 100644 --- a/aiida/cmdline/params/types/choice.py +++ b/aiida/cmdline/params/types/choice.py @@ -12,6 +12,8 @@ """ import click +__all__ = ('LazyChoice', ) + class LazyChoice(click.ParamType): """ diff --git a/aiida/cmdline/params/types/code.py b/aiida/cmdline/params/types/code.py index da1c6753bc..ecbdabbd68 100644 --- a/aiida/cmdline/params/types/code.py +++ b/aiida/cmdline/params/types/code.py @@ -13,6 +13,7 @@ from aiida.cmdline.utils import decorators from .identifier import IdentifierParamType +__all__ = ('CodeParamType',) class CodeParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/computer.py b/aiida/cmdline/params/types/computer.py index 8667363ef8..5e80e33100 100644 --- a/aiida/cmdline/params/types/computer.py +++ b/aiida/cmdline/params/types/computer.py @@ -16,6 +16,7 @@ from ...utils import decorators from .identifier import IdentifierParamType +__all__ = ('ComputerParamType', 'ShebangParamType', 'MpirunCommandParamType') class ComputerParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/data.py b/aiida/cmdline/params/types/data.py index 02c896f4b7..1ce0faf557 100644 --- a/aiida/cmdline/params/types/data.py +++ b/aiida/cmdline/params/types/data.py @@ -12,6 +12,7 @@ """ from .identifier import IdentifierParamType +__all__ = ('DataParamType',) class DataParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/group.py b/aiida/cmdline/params/types/group.py index 0645ac6e65..d09582cd31 100644 --- a/aiida/cmdline/params/types/group.py +++ b/aiida/cmdline/params/types/group.py @@ -15,6 +15,7 @@ from .identifier import IdentifierParamType +__all__ = ('GroupParamType',) class GroupParamType(IdentifierParamType): """The ParamType for identifying Group entities or its subclasses.""" diff --git a/aiida/cmdline/params/types/identifier.py b/aiida/cmdline/params/types/identifier.py index 513ee2a82b..3d8a735f4d 100644 --- a/aiida/cmdline/params/types/identifier.py +++ b/aiida/cmdline/params/types/identifier.py @@ -17,6 +17,7 @@ from aiida.cmdline.utils.decorators import with_dbenv from aiida.plugins.entry_point import get_entry_point_from_string +__all__ = ('IdentifierParamType',) class IdentifierParamType(click.ParamType, ABC): """ diff --git a/aiida/cmdline/params/types/multiple.py b/aiida/cmdline/params/types/multiple.py index 733ce7dcd4..29a142be7a 100644 --- a/aiida/cmdline/params/types/multiple.py +++ b/aiida/cmdline/params/types/multiple.py @@ -12,6 +12,7 @@ """ import click +__all__ = ('MultipleValueParamType', ) class MultipleValueParamType(click.ParamType): """ diff --git a/aiida/cmdline/params/types/node.py b/aiida/cmdline/params/types/node.py index 568dbf50fd..6e6b8f585b 100644 --- a/aiida/cmdline/params/types/node.py +++ b/aiida/cmdline/params/types/node.py @@ -12,6 +12,7 @@ """ from .identifier import IdentifierParamType +__all__ = ('NodeParamType', ) class NodeParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/path.py b/aiida/cmdline/params/types/path.py index daeb5ae115..7dc3a0d9f8 100644 --- a/aiida/cmdline/params/types/path.py +++ b/aiida/cmdline/params/types/path.py @@ -15,6 +15,8 @@ from socket import timeout import click +__all__ = ('AbsolutePathParamType', 'FileOrUrl', 'PathOrUrl') + URL_TIMEOUT_SECONDS = 10 diff --git a/aiida/cmdline/params/types/plugin.py b/aiida/cmdline/params/types/plugin.py index 387e5127a5..e56f4c2c8d 100644 --- a/aiida/cmdline/params/types/plugin.py +++ b/aiida/cmdline/params/types/plugin.py @@ -18,6 +18,7 @@ from aiida.plugins.entry_point import get_entry_point, get_entry_points, get_entry_point_groups from ..types import EntryPointType +__all__ = ('PluginParamType',) class PluginParamType(EntryPointType): """ diff --git a/aiida/cmdline/params/types/process.py b/aiida/cmdline/params/types/process.py index e18ef66bd7..6933cb42e8 100644 --- a/aiida/cmdline/params/types/process.py +++ b/aiida/cmdline/params/types/process.py @@ -13,6 +13,7 @@ from .identifier import IdentifierParamType +__all__ = ('ProcessParamType',) class ProcessParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/profile.py b/aiida/cmdline/params/types/profile.py index 61d5737f9c..831948f382 100644 --- a/aiida/cmdline/params/types/profile.py +++ b/aiida/cmdline/params/types/profile.py @@ -11,6 +11,7 @@ from .strings import LabelStringType +__all__ = ('ProfileParamType',) class ProfileParamType(LabelStringType): """The profile parameter type for click.""" diff --git a/aiida/cmdline/params/types/strings.py b/aiida/cmdline/params/types/strings.py index 63abbcd599..f35e763971 100644 --- a/aiida/cmdline/params/types/strings.py +++ b/aiida/cmdline/params/types/strings.py @@ -14,6 +14,7 @@ import re from click.types import StringParamType +__all__ = ('EmailType', 'EntryPointType', 'HostnameType', 'NonEmptyStringParamType', 'LabelStringType') class NonEmptyStringParamType(StringParamType): """Parameter whose values have to be string and non-empty.""" diff --git a/aiida/cmdline/params/types/test_module.py b/aiida/cmdline/params/types/test_module.py index d13c80633c..de985afd8b 100644 --- a/aiida/cmdline/params/types/test_module.py +++ b/aiida/cmdline/params/types/test_module.py @@ -10,6 +10,7 @@ """Test module parameter type for click.""" import click +__all__ = ('TestModuleParamType',) class TestModuleParamType(click.ParamType): """Parameter type to represent a unittest module. diff --git a/aiida/cmdline/params/types/user.py b/aiida/cmdline/params/types/user.py index 71a1c4eaab..df770e37a5 100644 --- a/aiida/cmdline/params/types/user.py +++ b/aiida/cmdline/params/types/user.py @@ -12,6 +12,7 @@ from aiida.cmdline.utils.decorators import with_dbenv +__all__ = ('UserParamType',) class UserParamType(click.ParamType): """ diff --git a/aiida/cmdline/params/types/workflow.py b/aiida/cmdline/params/types/workflow.py index 0a3fc48b6a..d23a7442fd 100644 --- a/aiida/cmdline/params/types/workflow.py +++ b/aiida/cmdline/params/types/workflow.py @@ -13,6 +13,7 @@ from .identifier import IdentifierParamType +__all__ = ('WorkflowParamType',) class WorkflowParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/utils/__init__.py b/aiida/cmdline/utils/__init__.py index 2776a55f97..f88b062b68 100644 --- a/aiida/cmdline/utils/__init__.py +++ b/aiida/cmdline/utils/__init__.py @@ -7,3 +7,26 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .ascii_vis import * +from .decorators import * +from .echo import * + +__all__ = ( + 'dbenv', + 'echo', + 'echo_critical', + 'echo_dictionary', + 'echo_error', + 'echo_highlight', + 'echo_info', + 'echo_success', + 'echo_warning', + 'format_call_graph', + 'only_if_daemon_running', + 'with_dbenv', +) diff --git a/aiida/common/__init__.py b/aiida/common/__init__.py index ea59db2024..880a702298 100644 --- a/aiida/common/__init__.py +++ b/aiida/common/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """ Common data structures, utility classes and functions @@ -15,6 +14,10 @@ """ +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .datastructures import * from .exceptions import * from .extendeddicts import * @@ -23,6 +26,63 @@ from .progress_reporter import * __all__ = ( - datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__ + - progress_reporter.__all__ + 'AIIDA_LOGGER', + 'AiidaException', + 'AttributeDict', + 'CalcInfo', + 'CalcJobState', + 'CodeInfo', + 'CodeRunMode', + 'ConfigurationError', + 'ConfigurationVersionError', + 'ContentNotExistent', + 'DatabaseMigrationError', + 'DbContentError', + 'DefaultFieldsAttributeDict', + 'EntryPointError', + 'FailedError', + 'FeatureDisabled', + 'FeatureNotAvailable', + 'FixedFieldsAttributeDict', + 'GraphTraversalRule', + 'GraphTraversalRules', + 'HashingError', + 'IncompatibleDatabaseSchema', + 'InputValidationError', + 'IntegrityError', + 'InternalError', + 'InvalidEntryPointTypeError', + 'InvalidOperation', + 'LicensingException', + 'LinkType', + 'LoadingEntryPointError', + 'MissingConfigurationError', + 'MissingEntryPointError', + 'ModificationNotAllowed', + 'MultipleEntryPointError', + 'MultipleObjectsError', + 'NotExistent', + 'NotExistentAttributeError', + 'NotExistentKeyError', + 'OutputParsingError', + 'ParsingError', + 'PluginInternalError', + 'ProfileConfigurationError', + 'ProgressReporterAbstract', + 'RemoteOperationError', + 'StashMode', + 'StoringNotAllowed', + 'TQDM_BAR_FORMAT', + 'TestsNotAllowedError', + 'TransportTaskException', + 'UniquenessError', + 'UnsupportedSpeciesError', + 'ValidationError', + 'create_callback', + 'get_progress_reporter', + 'override_log_formatter', + 'override_log_level', + 'set_progress_bar_tqdm', + 'set_progress_reporter', + 'validate_link_label', ) diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 984ff61866..356645d65b 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -7,11 +7,66 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin """Module with all the internals that make up the engine of `aiida-core`.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .exceptions import * from .launch import * +from .persistence import * from .processes import * +from .runners import * from .utils import * -__all__ = (launch.__all__ + processes.__all__ + utils.__all__) # type: ignore[name-defined] +__all__ = ( + 'AiiDAPersister', + 'Awaitable', + 'AwaitableAction', + 'AwaitableTarget', + 'BaseRestartWorkChain', + 'CalcJob', + 'CalcJobOutputPort', + 'CalcJobProcessSpec', + 'ExitCode', + 'ExitCodesNamespace', + 'FunctionProcess', + 'InputPort', + 'InterruptableFuture', + 'JobManager', + 'JobsList', + 'ObjectLoader', + 'OutputPort', + 'PORT_NAMESPACE_SEPARATOR', + 'PastException', + 'PortNamespace', + 'Process', + 'ProcessBuilder', + 'ProcessBuilderNamespace', + 'ProcessFuture', + 'ProcessHandlerReport', + 'ProcessSpec', + 'ProcessState', + 'Runner', + 'ToContext', + 'WithNonDb', + 'WithSerialize', + 'WorkChain', + 'append_', + 'assign_', + 'calcfunction', + 'construct_awaitable', + 'get_object_loader', + 'if_', + 'interruptable_task', + 'is_process_function', + 'process_handler', + 'return_', + 'run', + 'run_get_node', + 'run_get_pk', + 'submit', + 'while_', + 'workfunction', +) diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py index b3045dcfd4..9a1930eb76 100644 --- a/aiida/engine/processes/__init__.py +++ b/aiida/engine/processes/__init__.py @@ -7,18 +7,57 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin """Module for processes and related utilities.""" + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .builder import * from .calcjobs import * from .exit_code import * from .functions import * +from .futures import * from .ports import * from .process import * from .process_spec import * from .workchains import * __all__ = ( - builder.__all__ + calcjobs.__all__ + exit_code.__all__ + functions.__all__ + # type: ignore[name-defined] - ports.__all__ + process.__all__ + process_spec.__all__ + workchains.__all__ # type: ignore[name-defined] + 'Awaitable', + 'AwaitableAction', + 'AwaitableTarget', + 'BaseRestartWorkChain', + 'CalcJob', + 'CalcJobOutputPort', + 'CalcJobProcessSpec', + 'ExitCode', + 'ExitCodesNamespace', + 'FunctionProcess', + 'InputPort', + 'JobManager', + 'JobsList', + 'OutputPort', + 'PORT_NAMESPACE_SEPARATOR', + 'PortNamespace', + 'Process', + 'ProcessBuilder', + 'ProcessBuilderNamespace', + 'ProcessFuture', + 'ProcessHandlerReport', + 'ProcessSpec', + 'ProcessState', + 'ToContext', + 'WithNonDb', + 'WithSerialize', + 'WorkChain', + 'append_', + 'assign_', + 'calcfunction', + 'construct_awaitable', + 'if_', + 'process_handler', + 'return_', + 'while_', + 'workfunction', ) diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py index 57d4777ae7..bce38909b3 100644 --- a/aiida/engine/processes/calcjobs/__init__.py +++ b/aiida/engine/processes/calcjobs/__init__.py @@ -7,9 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for the `CalcJob` process and related utilities.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .calcjob import * +from .manager import * -__all__ = (calcjob.__all__) # type: ignore[name-defined] +__all__ = ( + 'CalcJob', + 'JobManager', + 'JobsList', +) diff --git a/aiida/engine/processes/workchains/__init__.py b/aiida/engine/processes/workchains/__init__.py index 9b0cf508c9..b11d3fc147 100644 --- a/aiida/engine/processes/workchains/__init__.py +++ b/aiida/engine/processes/workchains/__init__.py @@ -7,11 +7,31 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for the `WorkChain` process and related utilities.""" + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .awaitable import * from .context import * from .restart import * from .utils import * from .workchain import * -__all__ = (context.__all__ + restart.__all__ + utils.__all__ + workchain.__all__) # type: ignore[name-defined] +__all__ = ( + 'Awaitable', + 'AwaitableAction', + 'AwaitableTarget', + 'BaseRestartWorkChain', + 'ProcessHandlerReport', + 'ToContext', + 'WorkChain', + 'append_', + 'assign_', + 'construct_awaitable', + 'if_', + 'process_handler', + 'return_', + 'while_', +) diff --git a/aiida/manage/__init__.py b/aiida/manage/__init__.py index f25c1d5909..e8a920fb86 100644 --- a/aiida/manage/__init__.py +++ b/aiida/manage/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """ Managing an AiiDA instance: @@ -20,3 +19,58 @@ .. note:: Modules in this sub package may require the database environment to be loaded """ + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .caching import * +from .configuration import * +from .database import * +from .external import * +from .manager import * +from .tests import * + +__all__ = ( + 'BACKEND_UUID', + 'BROKER_DEFAULTS', + 'CONFIG', + 'CURRENT_CONFIG_VERSION', + 'CommunicationTimeout', + 'Config', + 'ConfigValidationError', + 'DEFAULT_DBINFO', + 'DeliveryFailed', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'Option', + 'PROFILE', + 'PluginTestCase', + 'Postgres', + 'PostgresConnectionMode', + 'ProcessLauncher', + 'Profile', + 'RemoteException', + 'TABLES_UUID_DEDUPLICATION', + 'TestRunner', + 'check_and_migrate_config', + 'config_needs_migrating', + 'config_schema', + 'deduplicate_uuids', + 'disable_caching', + 'enable_caching', + 'get_config', + 'get_config_option', + 'get_config_path', + 'get_current_version', + 'get_duplicate_uuids', + 'get_manager', + 'get_option', + 'get_option_names', + 'get_use_cache', + 'load_profile', + 'parse_option', + 'reset_config', + 'reset_manager', + 'verify_uuid_uniqueness', + 'write_database_integrity_violation', +) diff --git a/aiida/manage/configuration/__init__.py b/aiida/manage/configuration/__init__.py index 6860ab2c07..fb148b11b0 100644 --- a/aiida/manage/configuration/__init__.py +++ b/aiida/manage/configuration/__init__.py @@ -7,267 +7,38 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=undefined-variable,wildcard-import,global-statement,redefined-outer-name,cyclic-import """Modules related to the configuration of an AiiDA instance.""" -import os -import shutil -import warnings -from aiida.common.warnings import AiidaDeprecationWarning +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .config import * +from .main import * +from .migrations import * from .options import * from .profile import * -CONFIG = None -PROFILE = None -BACKEND_UUID = None # This will be set to the UUID of the profile as soon as its corresponding backend is loaded - __all__ = ( - config.__all__ + options.__all__ + profile.__all__ + - ('get_config', 'get_config_option', 'get_config_path', 'load_profile', 'reset_config') + 'BACKEND_UUID', + 'CONFIG', + 'CURRENT_CONFIG_VERSION', + 'Config', + 'ConfigValidationError', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'Option', + 'PROFILE', + 'Profile', + 'check_and_migrate_config', + 'config_needs_migrating', + 'config_schema', + 'get_config', + 'get_config_option', + 'get_config_path', + 'get_current_version', + 'get_option', + 'get_option_names', + 'load_profile', + 'parse_option', + 'reset_config', ) - - -def load_profile(profile=None): - """Load a profile. - - .. note:: if a profile is already loaded and no explicit profile is specified, nothing will be done - - :param profile: the name of the profile to load, by default will use the one marked as default in the config - :type profile: str - - :return: the loaded `Profile` instance - :rtype: :class:`~aiida.manage.configuration.Profile` - :raises `aiida.common.exceptions.InvalidOperation`: if the backend of another profile has already been loaded - """ - from aiida.common import InvalidOperation - from aiida.common.log import configure_logging - - global PROFILE - global BACKEND_UUID - - # If a profile is loaded and the specified profile name is None or that of the currently loaded, do nothing - if PROFILE and (profile is None or PROFILE.name is profile): - return PROFILE - - PROFILE = get_config().get_profile(profile) - - if BACKEND_UUID is not None and BACKEND_UUID != PROFILE.uuid: - # Once the switching of profiles with different backends becomes possible, the backend has to be reset properly - raise InvalidOperation('cannot switch profile because backend of another profile is already loaded') - - # Reconfigure the logging to make sure that profile specific logging configuration options are taken into account. - # Note that we do not configure with `with_orm=True` because that will force the backend to be loaded. This should - # instead be done lazily in `Manager._load_backend`. - configure_logging() - - return PROFILE - - -def get_config_path(): - """Returns path to .aiida configuration directory.""" - from .settings import AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME - - return os.path.join(AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME) - - -def load_config(create=False): - """Instantiate Config object representing an AiiDA configuration file. - - Warning: Contrary to :func:`~aiida.manage.configuration.get_config`, this function is uncached and will always - create a new Config object. You may want to call :func:`~aiida.manage.configuration.get_config` instead. - - :param create: if True, will create the configuration file if it does not already exist - :type create: bool - - :return: the config - :rtype: :class:`~aiida.manage.configuration.config.Config` - :raises aiida.common.MissingConfigurationError: if the configuration file could not be found and create=False - """ - from aiida.common import exceptions - from .config import Config - - filepath = get_config_path() - - if not os.path.isfile(filepath) and not create: - raise exceptions.MissingConfigurationError(f'configuration file {filepath} does not exist') - - try: - config = Config.from_file(filepath) - except ValueError as exc: - raise exceptions.ConfigurationError(f'configuration file {filepath} contains invalid JSON') from exc - - _merge_deprecated_cache_yaml(config, filepath) - - return config - - -def _merge_deprecated_cache_yaml(config, filepath): - """Merge the deprecated cache_config.yml into the config.""" - from aiida.common import timezone - cache_path = os.path.join(os.path.dirname(filepath), 'cache_config.yml') - if not os.path.exists(cache_path): - return - - cache_path_backup = None - # Keep generating a new backup filename based on the current time until it does not exist - while not cache_path_backup or os.path.isfile(cache_path_backup): - cache_path_backup = f"{cache_path}.{timezone.now().strftime('%Y%m%d-%H%M%S.%f')}" - - warnings.warn( - 'cache_config.yml use is deprecated and support will be removed in `v3.0`. Merging into config.json and ' - f'moving to: {cache_path_backup}', AiidaDeprecationWarning - ) - import yaml - with open(cache_path, 'r', encoding='utf8') as handle: - cache_config = yaml.safe_load(handle) - for profile_name, data in cache_config.items(): - if profile_name not in config.profile_names: - warnings.warn(f"Profile '{profile_name}' from cache_config.yml not in config.json, skipping", UserWarning) - continue - for key, option_name in [('default', 'caching.default_enabled'), ('enabled', 'caching.enabled_for'), - ('disabled', 'caching.disabled_for')]: - if key in data: - value = data[key] - # in case of empty key - value = [] if value is None and key != 'default' else value - config.set_option(option_name, value, scope=profile_name) - config.store() - shutil.move(cache_path, cache_path_backup) - - -def get_profile(): - """Return the currently loaded profile. - - :return: the globally loaded `Profile` instance or `None` - :rtype: :class:`~aiida.manage.configuration.Profile` - """ - global PROFILE - return PROFILE - - -def reset_profile(): - """Reset the globally loaded profile. - - .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean - weird unknown side-effects may occur that end up corrupting or destroying data. - """ - global PROFILE - global BACKEND_UUID - PROFILE = None - BACKEND_UUID = None - - -def reset_config(): - """Reset the globally loaded config. - - .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean - weird unknown side-effects may occur that end up corrupting or destroying data. - """ - global CONFIG - CONFIG = None - - -def get_config(create=False): - """Return the current configuration. - - If the configuration has not been loaded yet - * the configuration is loaded using ``load_config`` - * the global `CONFIG` variable is set - * the configuration object is returned - - Note: This function will except if no configuration file can be found. Only call this function, if you need - information from the configuration file. - - :param create: if True, will create the configuration file if it does not already exist - :type create: bool - - :return: the config - :rtype: :class:`~aiida.manage.configuration.config.Config` - :raises aiida.common.ConfigurationError: if the configuration file could not be found, read or deserialized - """ - global CONFIG - - if not CONFIG: - CONFIG = load_config(create=create) - - if CONFIG.get_option('warnings.showdeprecations'): - # If the user does not want to get AiiDA deprecation warnings, we disable them - this can be achieved with:: - # verdi config warnings.showdeprecations False - # Note that the AiidaDeprecationWarning does NOT inherit from DeprecationWarning - warnings.simplefilter('default', AiidaDeprecationWarning) # pylint: disable=no-member - # This should default to 'once', i.e. once per different message - else: - warnings.simplefilter('ignore', AiidaDeprecationWarning) # pylint: disable=no-member - - return CONFIG - - -def get_config_option(option_name): - """Return the value for the given configuration option. - - This function will attempt to load the value of the option as defined for the current profile or otherwise as - defined configuration wide. If no configuration is yet loaded, this function will fall back on the default that may - be defined for the option itself. This is useful for options that need to be defined at loading time of AiiDA when - no configuration is yet loaded or may not even yet exist. In cases where one expects a profile to be loaded, - preference should be given to retrieving the option through the Config instance and its `get_option` method. - - :param option_name: the name of the configuration option - :type option_name: str - - :return: option value as specified for the profile/configuration if loaded, otherwise option default - """ - from aiida.common import exceptions - - option = options.get_option(option_name) - - try: - config = get_config(create=True) - except exceptions.ConfigurationError: - value = option.default if option.default is not options.NO_DEFAULT else None - else: - if config.current_profile: - # Try to get the option for the profile, but do not return the option default - value_profile = config.get_option(option_name, scope=config.current_profile.name, default=False) - else: - value_profile = None - - # Value is the profile value if defined or otherwise the global value, which will be None if not set - value = value_profile if value_profile else config.get_option(option_name) - - return value - - -def load_documentation_profile(): - """Load a dummy profile just for the purposes of being able to build the documentation. - - The building of the documentation will require importing the `aiida` package and some code will try to access the - loaded configuration and profile, which if not done will except. On top of that, Django will raise an exception if - the database models are loaded before its settings are loaded. This also is taken care of by loading a Django - profile and loading the corresponding backend. Calling this function will perform all these requirements allowing - the documentation to be built without having to install and configure AiiDA nor having an actual database present. - """ - import tempfile - from aiida.manage.manager import get_manager - from .config import Config - from .profile import Profile - - global PROFILE - global CONFIG - - with tempfile.NamedTemporaryFile() as handle: - profile_name = 'readthedocs' - profile = { - 'AIIDADB_ENGINE': 'postgresql_psycopg2', - 'AIIDADB_BACKEND': 'django', - 'AIIDADB_PORT': 5432, - 'AIIDADB_HOST': 'localhost', - 'AIIDADB_NAME': 'aiidadb', - 'AIIDADB_PASS': 'aiidadb', - 'AIIDADB_USER': 'aiida', - 'AIIDADB_REPOSITORY_URI': 'file:///dev/null', - } - config = {'default_profile': profile_name, 'profiles': {profile_name: profile}} - PROFILE = Profile(profile_name, profile, from_config=True) - CONFIG = Config(handle.name, config) - get_manager()._load_backend(schema_check=False, repository_check=False) # pylint: disable=protected-access diff --git a/aiida/manage/configuration/main.py b/aiida/manage/configuration/main.py new file mode 100644 index 0000000000..f27f589c3a --- /dev/null +++ b/aiida/manage/configuration/main.py @@ -0,0 +1,267 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=undefined-variable,wildcard-import,global-statement,redefined-outer-name,cyclic-import +"""Modules related to the configuration of an AiiDA instance.""" +import os +import shutil +import warnings + +from aiida.common.warnings import AiidaDeprecationWarning + +CONFIG = None +PROFILE = None +BACKEND_UUID = None # This will be set to the UUID of the profile as soon as its corresponding backend is loaded + +__all__ = ('get_config', 'get_config_option', 'get_config_path', 'load_profile', 'reset_config', 'CONFIG', 'PROFILE', 'BACKEND_UUID') + + +def load_profile(profile=None): + """Load a profile. + + .. note:: if a profile is already loaded and no explicit profile is specified, nothing will be done + + :param profile: the name of the profile to load, by default will use the one marked as default in the config + :type profile: str + + :return: the loaded `Profile` instance + :rtype: :class:`~aiida.manage.configuration.Profile` + :raises `aiida.common.exceptions.InvalidOperation`: if the backend of another profile has already been loaded + """ + from aiida.common import InvalidOperation + from aiida.common.log import configure_logging + + global PROFILE + global BACKEND_UUID + + # If a profile is loaded and the specified profile name is None or that of the currently loaded, do nothing + if PROFILE and (profile is None or PROFILE.name is profile): + return PROFILE + + PROFILE = get_config().get_profile(profile) + + if BACKEND_UUID is not None and BACKEND_UUID != PROFILE.uuid: + # Once the switching of profiles with different backends becomes possible, the backend has to be reset properly + raise InvalidOperation('cannot switch profile because backend of another profile is already loaded') + + # Reconfigure the logging to make sure that profile specific logging configuration options are taken into account. + # Note that we do not configure with `with_orm=True` because that will force the backend to be loaded. This should + # instead be done lazily in `Manager._load_backend`. + configure_logging() + + return PROFILE + + +def get_config_path(): + """Returns path to .aiida configuration directory.""" + from .settings import AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME + + return os.path.join(AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME) + + +def load_config(create=False): + """Instantiate Config object representing an AiiDA configuration file. + + Warning: Contrary to :func:`~aiida.manage.configuration.get_config`, this function is uncached and will always + create a new Config object. You may want to call :func:`~aiida.manage.configuration.get_config` instead. + + :param create: if True, will create the configuration file if it does not already exist + :type create: bool + + :return: the config + :rtype: :class:`~aiida.manage.configuration.config.Config` + :raises aiida.common.MissingConfigurationError: if the configuration file could not be found and create=False + """ + from aiida.common import exceptions + from .config import Config + + filepath = get_config_path() + + if not os.path.isfile(filepath) and not create: + raise exceptions.MissingConfigurationError(f'configuration file {filepath} does not exist') + + try: + config = Config.from_file(filepath) + except ValueError as exc: + raise exceptions.ConfigurationError(f'configuration file {filepath} contains invalid JSON') from exc + + _merge_deprecated_cache_yaml(config, filepath) + + return config + + +def _merge_deprecated_cache_yaml(config, filepath): + """Merge the deprecated cache_config.yml into the config.""" + from aiida.common import timezone + cache_path = os.path.join(os.path.dirname(filepath), 'cache_config.yml') + if not os.path.exists(cache_path): + return + + cache_path_backup = None + # Keep generating a new backup filename based on the current time until it does not exist + while not cache_path_backup or os.path.isfile(cache_path_backup): + cache_path_backup = f"{cache_path}.{timezone.now().strftime('%Y%m%d-%H%M%S.%f')}" + + warnings.warn( + 'cache_config.yml use is deprecated and support will be removed in `v3.0`. Merging into config.json and ' + f'moving to: {cache_path_backup}', AiidaDeprecationWarning + ) + import yaml + with open(cache_path, 'r', encoding='utf8') as handle: + cache_config = yaml.safe_load(handle) + for profile_name, data in cache_config.items(): + if profile_name not in config.profile_names: + warnings.warn(f"Profile '{profile_name}' from cache_config.yml not in config.json, skipping", UserWarning) + continue + for key, option_name in [('default', 'caching.default_enabled'), ('enabled', 'caching.enabled_for'), + ('disabled', 'caching.disabled_for')]: + if key in data: + value = data[key] + # in case of empty key + value = [] if value is None and key != 'default' else value + config.set_option(option_name, value, scope=profile_name) + config.store() + shutil.move(cache_path, cache_path_backup) + + +def get_profile(): + """Return the currently loaded profile. + + :return: the globally loaded `Profile` instance or `None` + :rtype: :class:`~aiida.manage.configuration.Profile` + """ + global PROFILE + return PROFILE + + +def reset_profile(): + """Reset the globally loaded profile. + + .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean + weird unknown side-effects may occur that end up corrupting or destroying data. + """ + global PROFILE + global BACKEND_UUID + PROFILE = None + BACKEND_UUID = None + + +def reset_config(): + """Reset the globally loaded config. + + .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean + weird unknown side-effects may occur that end up corrupting or destroying data. + """ + global CONFIG + CONFIG = None + + +def get_config(create=False): + """Return the current configuration. + + If the configuration has not been loaded yet + * the configuration is loaded using ``load_config`` + * the global `CONFIG` variable is set + * the configuration object is returned + + Note: This function will except if no configuration file can be found. Only call this function, if you need + information from the configuration file. + + :param create: if True, will create the configuration file if it does not already exist + :type create: bool + + :return: the config + :rtype: :class:`~aiida.manage.configuration.config.Config` + :raises aiida.common.ConfigurationError: if the configuration file could not be found, read or deserialized + """ + global CONFIG + + if not CONFIG: + CONFIG = load_config(create=create) + + if CONFIG.get_option('warnings.showdeprecations'): + # If the user does not want to get AiiDA deprecation warnings, we disable them - this can be achieved with:: + # verdi config warnings.showdeprecations False + # Note that the AiidaDeprecationWarning does NOT inherit from DeprecationWarning + warnings.simplefilter('default', AiidaDeprecationWarning) # pylint: disable=no-member + # This should default to 'once', i.e. once per different message + else: + warnings.simplefilter('ignore', AiidaDeprecationWarning) # pylint: disable=no-member + + return CONFIG + + +def get_config_option(option_name): + """Return the value for the given configuration option. + + This function will attempt to load the value of the option as defined for the current profile or otherwise as + defined configuration wide. If no configuration is yet loaded, this function will fall back on the default that may + be defined for the option itself. This is useful for options that need to be defined at loading time of AiiDA when + no configuration is yet loaded or may not even yet exist. In cases where one expects a profile to be loaded, + preference should be given to retrieving the option through the Config instance and its `get_option` method. + + :param option_name: the name of the configuration option + :type option_name: str + + :return: option value as specified for the profile/configuration if loaded, otherwise option default + """ + from aiida.common import exceptions + + option = options.get_option(option_name) + + try: + config = get_config(create=True) + except exceptions.ConfigurationError: + value = option.default if option.default is not options.NO_DEFAULT else None + else: + if config.current_profile: + # Try to get the option for the profile, but do not return the option default + value_profile = config.get_option(option_name, scope=config.current_profile.name, default=False) + else: + value_profile = None + + # Value is the profile value if defined or otherwise the global value, which will be None if not set + value = value_profile if value_profile else config.get_option(option_name) + + return value + + +def load_documentation_profile(): + """Load a dummy profile just for the purposes of being able to build the documentation. + + The building of the documentation will require importing the `aiida` package and some code will try to access the + loaded configuration and profile, which if not done will except. On top of that, Django will raise an exception if + the database models are loaded before its settings are loaded. This also is taken care of by loading a Django + profile and loading the corresponding backend. Calling this function will perform all these requirements allowing + the documentation to be built without having to install and configure AiiDA nor having an actual database present. + """ + import tempfile + from aiida.manage.manager import get_manager + from .config import Config + from .profile import Profile + + global PROFILE + global CONFIG + + with tempfile.NamedTemporaryFile() as handle: + profile_name = 'readthedocs' + profile = { + 'AIIDADB_ENGINE': 'postgresql_psycopg2', + 'AIIDADB_BACKEND': 'django', + 'AIIDADB_PORT': 5432, + 'AIIDADB_HOST': 'localhost', + 'AIIDADB_NAME': 'aiidadb', + 'AIIDADB_PASS': 'aiidadb', + 'AIIDADB_USER': 'aiida', + 'AIIDADB_REPOSITORY_URI': 'file:///dev/null', + } + config = {'default_profile': profile_name, 'profiles': {profile_name: profile}} + PROFILE = Profile(profile_name, profile, from_config=True) + CONFIG = Config(handle.name, config) + get_manager()._load_backend(schema_check=False, repository_check=False) # pylint: disable=protected-access diff --git a/aiida/manage/configuration/migrations/__init__.py b/aiida/manage/configuration/migrations/__init__.py index 6b99f32a5d..2170fcdcce 100644 --- a/aiida/manage/configuration/migrations/__init__.py +++ b/aiida/manage/configuration/migrations/__init__.py @@ -7,10 +7,19 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=undefined-variable,wildcard-import """Methods and definitions of migrations for the configuration file of an AiiDA instance.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .migrations import * from .utils import * -__all__ = (migrations.__all__ + utils.__all__) +__all__ = ( + 'CURRENT_CONFIG_VERSION', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'check_and_migrate_config', + 'config_needs_migrating', + 'get_current_version', +) diff --git a/aiida/manage/database/__init__.py b/aiida/manage/database/__init__.py index 2776a55f97..7cb1bdb138 100644 --- a/aiida/manage/database/__init__.py +++ b/aiida/manage/database/__init__.py @@ -7,3 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .integrity import * + +__all__ = ( + 'TABLES_UUID_DEDUPLICATION', + 'deduplicate_uuids', + 'get_duplicate_uuids', + 'verify_uuid_uniqueness', + 'write_database_integrity_violation', +) diff --git a/aiida/manage/database/integrity/__init__.py b/aiida/manage/database/integrity/__init__.py index 796a9a7213..4bc2bb0f47 100644 --- a/aiida/manage/database/integrity/__init__.py +++ b/aiida/manage/database/integrity/__init__.py @@ -7,46 +7,19 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name """Methods to validate the database integrity and fix violations.""" -WARNING_BORDER = '*' * 120 +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import +from .duplicate_uuid import * +from .utils import * -def write_database_integrity_violation(results, headers, reason_message, action_message=None): - """Emit a integrity violation warning and write the violating records to a log file in the current directory - - :param results: a list of tuples representing the violating records - :param headers: a tuple of strings that will be used as a header for the log file. Should have the same length - as each tuple in the results list. - :param reason_message: a human readable message detailing the reason of the integrity violation - :param action_message: an optional human readable message detailing a performed action, if any - """ - # pylint: disable=duplicate-string-formatting-argument - from datetime import datetime - from tabulate import tabulate - from tempfile import NamedTemporaryFile - - from aiida.cmdline.utils import echo - from aiida.manage import configuration - - if configuration.PROFILE.is_test_profile: - return - - if action_message is None: - action_message = 'nothing' - - with NamedTemporaryFile(prefix='migration-', suffix='.log', dir='.', delete=False, mode='w+') as handle: - echo.echo('') - echo.echo_warning( - '\n{}\nFound one or multiple records that violate the integrity of the database\nViolation reason: {}\n' - 'Performed action: {}\nViolators written to: {}\n{}\n'.format( - WARNING_BORDER, reason_message, action_message, handle.name, WARNING_BORDER - ) - ) - - handle.write(f'# {datetime.utcnow().isoformat()}\n') - handle.write(f'# Violation reason: {reason_message}\n') - handle.write(f'# Performed action: {action_message}\n') - handle.write('\n') - handle.write(tabulate(results, headers)) +__all__ = ( + 'TABLES_UUID_DEDUPLICATION', + 'deduplicate_uuids', + 'get_duplicate_uuids', + 'verify_uuid_uniqueness', + 'write_database_integrity_violation', +) diff --git a/aiida/manage/database/integrity/utils.py b/aiida/manage/database/integrity/utils.py new file mode 100644 index 0000000000..d516a00a8a --- /dev/null +++ b/aiida/manage/database/integrity/utils.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name +"""Methods to validate the database integrity and fix violations.""" +__all__ = ('write_database_integrity_violation',) + +WARNING_BORDER = '*' * 120 + + +def write_database_integrity_violation(results, headers, reason_message, action_message=None): + """Emit a integrity violation warning and write the violating records to a log file in the current directory + + :param results: a list of tuples representing the violating records + :param headers: a tuple of strings that will be used as a header for the log file. Should have the same length + as each tuple in the results list. + :param reason_message: a human readable message detailing the reason of the integrity violation + :param action_message: an optional human readable message detailing a performed action, if any + """ + # pylint: disable=duplicate-string-formatting-argument + from datetime import datetime + from tabulate import tabulate + from tempfile import NamedTemporaryFile + + from aiida.cmdline.utils import echo + from aiida.manage import configuration + + if configuration.PROFILE.is_test_profile: + return + + if action_message is None: + action_message = 'nothing' + + with NamedTemporaryFile(prefix='migration-', suffix='.log', dir='.', delete=False, mode='w+') as handle: + echo.echo('') + echo.echo_warning( + '\n{}\nFound one or multiple records that violate the integrity of the database\nViolation reason: {}\n' + 'Performed action: {}\nViolators written to: {}\n{}\n'.format( + WARNING_BORDER, reason_message, action_message, handle.name, WARNING_BORDER + ) + ) + + handle.write(f'# {datetime.utcnow().isoformat()}\n') + handle.write(f'# Violation reason: {reason_message}\n') + handle.write(f'# Performed action: {action_message}\n') + handle.write('\n') + handle.write(tabulate(results, headers)) diff --git a/aiida/manage/external/__init__.py b/aiida/manage/external/__init__.py index e82b79252b..10bf08219a 100644 --- a/aiida/manage/external/__init__.py +++ b/aiida/manage/external/__init__.py @@ -8,3 +8,21 @@ # For further information please visit http://www.aiida.net # ########################################################################### """User facing APIs to control AiiDA from the verdi cli, scripts or plugins""" + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .postgres import * +from .rmq import * + +__all__ = ( + 'BROKER_DEFAULTS', + 'CommunicationTimeout', + 'DEFAULT_DBINFO', + 'DeliveryFailed', + 'Postgres', + 'PostgresConnectionMode', + 'ProcessLauncher', + 'RemoteException', +) diff --git a/aiida/manage/tests/__init__.py b/aiida/manage/tests/__init__.py index 24e0d218a7..7d5fb9195a 100644 --- a/aiida/manage/tests/__init__.py +++ b/aiida/manage/tests/__init__.py @@ -11,499 +11,14 @@ Testing infrastructure for easy testing of AiiDA plugins. """ -import tempfile -import shutil -import os -from contextlib import contextmanager -from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA -from aiida.common import exceptions -from aiida.manage import configuration -from aiida.manage.configuration.settings import create_instance_directories -from aiida.manage import manager -from aiida.manage.external.postgres import Postgres +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -__all__ = ('TestManager', 'TestManagerError', 'ProfileManager', 'TemporaryProfileManager', '_GLOBAL_TEST_MANAGER') +from .unittest_classes import * -_DEFAULT_PROFILE_INFO = { - 'name': 'test_profile', - 'email': 'tests@aiida.mail', - 'first_name': 'AiiDA', - 'last_name': 'Plugintest', - 'institution': 'aiidateam', - 'database_engine': 'postgresql_psycopg2', - 'database_backend': 'django', - 'database_username': 'aiida', - 'database_password': 'aiida_pw', - 'database_name': 'aiida_db', - 'repo_dir': 'test_repo', - 'config_dir': '.aiida', - 'root_path': '', - 'broker_protocol': 'amqp', - 'broker_username': 'guest', - 'broker_password': 'guest', - 'broker_host': '127.0.0.1', - 'broker_port': 5672, - 'broker_virtual_host': '' -} - - -class TestManagerError(Exception): - """Raised by TestManager in situations that may lead to inconsistent behaviour.""" - - def __init__(self, msg): - super().__init__() - self.msg = msg - - def __str__(self): - return repr(self.msg) - - -class TestManager: - """ - Test manager for plugin tests. - - Uses either ProfileManager for wrapping an existing profile or TemporaryProfileManager for setting up a complete - temporary AiiDA environment. - - For usage with pytest, see :py:class:`~aiida.manage.tests.pytest_fixtures`. - """ - - def __init__(self): - self._manager = None - - def use_temporary_profile(self, backend=None, pgtest=None): - """Set up Test manager to use temporary AiiDA profile. - - Uses :py:class:`aiida.manage.tests.TemporaryProfileManager` internally. - - :param backend: Backend to use. - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - - """ - if configuration.PROFILE is not None: - raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') - if self._manager is not None: - raise TestManagerError('Profile manager already loaded.') - - mngr = TemporaryProfileManager(backend=backend, pgtest=pgtest) - mngr.create_profile() - self._manager = mngr # don't assign before profile has actually been created! - - def use_profile(self, profile_name): - """Set up Test manager to use existing profile. - - Uses :py:class:`aiida.manage.tests.ProfileManager` internally. - - :param profile_name: Name of existing test profile to use. - """ - if configuration.PROFILE is not None: - raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') - if self._manager is not None: - raise TestManagerError('Profile manager already loaded.') - - self._manager = ProfileManager(profile_name=profile_name) - self._manager.init_db() - - def has_profile_open(self): - return self._manager and self._manager.has_profile_open() - - def reset_db(self): - return self._manager.reset_db() - - def destroy_all(self): - if self._manager: - self._manager.destroy_all() - self._manager = None - - -class ProfileManager: - """ - Wraps existing AiiDA profile. - """ - - def __init__(self, profile_name): - """ - Use an existing profile. - - :param profile_name: Name of the profile to be loaded - """ - from aiida import load_profile - from aiida.backends.testbase import check_if_tests_can_run - - self._profile = None - self._user = None - - try: - self._profile = load_profile(profile_name) - manager.get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access - except Exception: - raise TestManagerError('Unable to load test profile \'{}\'.'.format(profile_name)) - check_if_tests_can_run() - - self._select_db_test_case(backend=self._profile.database_backend) - - def _select_db_test_case(self, backend): - """ - Selects tests case for the correct database backend. - """ - if backend == BACKEND_DJANGO: - from aiida.backends.djsite.db.testbase import DjangoTests - self._test_case = DjangoTests() - elif backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy.testbase import SqlAlchemyTests - from aiida.backends.sqlalchemy import get_scoped_session - - self._test_case = SqlAlchemyTests() - self._test_case.test_session = get_scoped_session() - - def reset_db(self): - self._test_case.clean_db() # will drop all users - manager.reset_manager() - self.init_db() - - def init_db(self): - """Initialise the database state for running of tests. - - Adds default user if necessary. - """ - from aiida.orm import User - from aiida.cmdline.commands.cmd_user import set_default_user - - if not User.objects.get_default(): - user_dict = get_user_dict(_DEFAULT_PROFILE_INFO) - try: - user = User(**user_dict) - user.store() - except exceptions.IntegrityError: - # The user already exists, no problem - user = User.objects.get(**user_dict) - - set_default_user(self._profile, user) - User.objects.reset() # necessary to pick up new default user - - def has_profile_open(self): - return self._profile is not None - - def destroy_all(self): - pass - - -class TemporaryProfileManager(ProfileManager): - """ - Manage the life cycle of a completely separated and temporary AiiDA environment. - - * No profile / database setup required - * Tests run via the TemporaryProfileManager never pollute the user's working environment - - Filesystem: - - * temporary ``.aiida`` configuration folder - * temporary repository folder - - Database: - - * temporary database cluster (via the ``pgtest`` package) - * with ``aiida`` database user - * with ``aiida_db`` database - - AiiDA: - - * configured to use the temporary configuration - * sets up a temporary profile for tests - - All of this happens automatically when using the corresponding tests classes & tests runners (unittest) - or fixtures (pytest). - - Example:: - - tests = TemporaryProfileManager(backend=backend) - tests.create_aiida_db() # set up only the database - tests.create_profile() # set up a profile (creates the db too if necessary) - - # ready for tests - - # run tests 1 - - tests.reset_db() - # database ready for independent tests 2 - - # run tests 2 - - tests.destroy_all() - # everything cleaned up - - """ - - _test_case = None - - def __init__(self, backend=BACKEND_DJANGO, pgtest=None): # pylint: disable=super-init-not-called - """Construct a TemporaryProfileManager - - :param backend: a database backend - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - - """ - from aiida.manage.configuration import settings - - self.dbinfo = {} - self.profile_info = _DEFAULT_PROFILE_INFO - self.profile_info['database_backend'] = backend - self._pgtest = pgtest or {} - - self.pg_cluster = None - self.postgres = None - self._profile = None - self._has_test_db = False - self._backup = { - 'config': configuration.CONFIG, - 'config_dir': settings.AIIDA_CONFIG_FOLDER, - 'profile': configuration.PROFILE, - } - - @property - def profile_dictionary(self): - """Profile parameters. - - Used to set up AiiDA profile from self.profile_info dictionary. - """ - dictionary = { - 'database_engine': self.profile_info.get('database_engine'), - 'database_backend': self.profile_info.get('database_backend'), - 'database_port': self.profile_info.get('database_port'), - 'database_hostname': self.profile_info.get('database_hostname'), - 'database_name': self.profile_info.get('database_name'), - 'database_username': self.profile_info.get('database_username'), - 'database_password': self.profile_info.get('database_password'), - 'broker_protocol': self.profile_info.get('broker_protocol'), - 'broker_username': self.profile_info.get('broker_username'), - 'broker_password': self.profile_info.get('broker_password'), - 'broker_host': self.profile_info.get('broker_host'), - 'broker_port': self.profile_info.get('broker_port'), - 'broker_virtual_host': self.profile_info.get('broker_virtual_host'), - 'repository_uri': f'file://{self.repo}', - } - return dictionary - - def create_db_cluster(self): - """ - Create the database cluster using PGTest. - """ - from pgtest.pgtest import PGTest - - if self.pg_cluster is not None: - raise TestManagerError( - 'Running temporary postgresql cluster detected.Use destroy_all() before creating a new cluster.' - ) - self.pg_cluster = PGTest(**self._pgtest) - self.dbinfo.update(self.pg_cluster.dsn) - - def create_aiida_db(self): - """ - Create the necessary database on the temporary postgres instance. - """ - if configuration.PROFILE is not None: - raise TestManagerError('AiiDA dbenv can not be loaded while creating a tests db environment') - if self.pg_cluster is None: - self.create_db_cluster() - self.postgres = Postgres(interactive=False, quiet=True, dbinfo=self.dbinfo) - # note: not using postgres.create_dbuser_db_safe here since we don't want prompts - self.postgres.create_dbuser(self.profile_info['database_username'], self.profile_info['database_password']) - self.postgres.create_db(self.profile_info['database_username'], self.profile_info['database_name']) - self.dbinfo = self.postgres.dbinfo - self.profile_info['database_hostname'] = self.postgres.host_for_psycopg2 - self.profile_info['database_port'] = self.postgres.port_for_psycopg2 - self._has_test_db = True - - def create_profile(self): - """ - Set AiiDA to use the tests config dir and create a default profile there - - Warning: the AiiDA dbenv must not be loaded when this is called! - """ - from aiida.manage.configuration import settings, load_profile, Profile - - if not self._has_test_db: - self.create_aiida_db() - - if not self.root_dir: - self.root_dir = tempfile.mkdtemp() - configuration.CONFIG = None - settings.AIIDA_CONFIG_FOLDER = self.config_dir - configuration.PROFILE = None - create_instance_directories() - profile_name = self.profile_info['name'] - config = configuration.get_config(create=True) - profile = Profile(profile_name, self.profile_dictionary) - config.add_profile(profile) - config.set_default_profile(profile_name).store() - self._profile = profile - - load_profile(profile_name) - backend = manager.get_manager()._load_backend(schema_check=False) - backend.migrate() - - self._select_db_test_case(backend=self._profile.database_backend) - self.init_db() - - def repo_ok(self): - return bool(self.repo and os.path.isdir(os.path.dirname(self.repo))) - - @property - def repo(self): - return self._return_dir(self.profile_info['repo_dir']) - - def _return_dir(self, dir_path): - """Return a path to a directory from the fs environment""" - if os.path.isabs(dir_path): - return dir_path - return os.path.join(self.root_dir, dir_path) - - @property - def backend(self): - return self.profile_info['backend'] - - @backend.setter - def backend(self, backend): - if self.has_profile_open(): - raise TestManagerError('backend cannot be changed after setting up the environment') - - valid_backends = [BACKEND_DJANGO, BACKEND_SQLA] - if backend not in valid_backends: - raise ValueError(f'invalid backend {backend}, must be one of {valid_backends}') - self.profile_info['backend'] = backend - - @property - def config_dir_ok(self): - return bool(self.config_dir and os.path.isdir(self.config_dir)) - - @property - def config_dir(self): - return self._return_dir(self.profile_info['config_dir']) - - @property - def root_dir(self): - return self.profile_info['root_path'] - - @root_dir.setter - def root_dir(self, root_dir): - self.profile_info['root_path'] = root_dir - - @property - def root_dir_ok(self): - return bool(self.root_dir and os.path.isdir(self.root_dir)) - - def destroy_all(self): - """Remove all traces of the tests run""" - from aiida.manage.configuration import settings - if self.root_dir: - shutil.rmtree(self.root_dir) - self.root_dir = None - if self.pg_cluster: - self.pg_cluster.close() - self.pg_cluster = None - self._has_test_db = False - self._profile = None - self._user = None - - if 'config' in self._backup: - configuration.CONFIG = self._backup['config'] - if 'config_dir' in self._backup: - settings.AIIDA_CONFIG_FOLDER = self._backup['config_dir'] - if 'profile' in self._backup: - configuration.PROFILE = self._backup['profile'] - - def has_profile_open(self): - return self._profile is not None - - -_GLOBAL_TEST_MANAGER = TestManager() - - -@contextmanager -def test_manager(backend=BACKEND_DJANGO, profile_name=None, pgtest=None): - """ Context manager for TestManager objects. - - Sets up temporary AiiDA environment for testing or reuses existing environment, - if `AIIDA_TEST_PROFILE` environment variable is set. - - Example pytest fixture:: - - def aiida_profile(): - with test_manager(backend) as test_mgr: - yield fixture_mgr - - Example unittest test runner:: - - with test_manager(backend) as test_mgr: - # ready for tests - # everything cleaned up - - - :param backend: database backend, either BACKEND_SQLA or BACKEND_DJANGO - :param profile_name: name of test profile to be used or None (to use temporary profile) - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - """ - from aiida.common.utils import Capturing - from aiida.common.log import configure_logging - - try: - if not _GLOBAL_TEST_MANAGER.has_profile_open(): - if profile_name: - _GLOBAL_TEST_MANAGER.use_profile(profile_name=profile_name) - else: - with Capturing(): # capture output of AiiDA DB setup - _GLOBAL_TEST_MANAGER.use_temporary_profile(backend=backend, pgtest=pgtest) - configure_logging(with_orm=True) - yield _GLOBAL_TEST_MANAGER - finally: - _GLOBAL_TEST_MANAGER.destroy_all() - - -def get_test_backend_name(): - """ Read name of database backend from environment variable or the specified test profile. - - Reads database backend ('django' or 'sqlalchemy') from 'AIIDA_TEST_BACKEND' environment variable, - or the backend configured for the 'AIIDA_TEST_PROFILE'. - Defaults to django backend. - - :returns: content of environment variable or `BACKEND_DJANGO` - :raises: ValueError if unknown backend name detected. - :raises: ValueError if both 'AIIDA_TEST_BACKEND' and 'AIIDA_TEST_PROFILE' are set, and the two - backends do not match. - """ - test_profile_name = get_test_profile_name() - backend_env = os.environ.get('AIIDA_TEST_BACKEND', None) - if test_profile_name is not None: - backend_profile = configuration.get_config().get_profile(test_profile_name).database_backend - if backend_env is not None and backend_env != backend_profile: - raise ValueError( - "The backend '{}' read from AIIDA_TEST_BACKEND does not match the backend '{}' " - "of AIIDA_TEST_PROFILE '{}'".format(backend_env, backend_profile, test_profile_name) - ) - backend_res = backend_profile - else: - backend_res = backend_env or BACKEND_DJANGO - - if backend_res in (BACKEND_DJANGO, BACKEND_SQLA): - return backend_res - raise ValueError(f"Unknown backend '{backend_res}' read from AIIDA_TEST_BACKEND environment variable") - - -def get_test_profile_name(): - """ Read name of test profile from environment variable. - - Reads name of existing test profile 'AIIDA_TEST_PROFILE' environment variable. - If specified, this profile is used for running the tests (instead of setting up a temporary profile). - - :returns: content of environment variable or `None` - """ - return os.environ.get('AIIDA_TEST_PROFILE', None) - - -def get_user_dict(profile_dict): - """Collect parameters required for creating users.""" - return {k: profile_dict[k] for k in ('email', 'first_name', 'last_name', 'institution')} +__all__ = ( + 'PluginTestCase', + 'TestRunner', +) diff --git a/aiida/orm/__init__.py b/aiida/orm/__init__.py index fa6f66afde..b3aab44ca0 100644 --- a/aiida/orm/__init__.py +++ b/aiida/orm/__init__.py @@ -7,9 +7,12 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin,cyclic-import """Main module to expose all orm classes and methods""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .authinfos import * from .comments import * from .computers import * @@ -22,6 +25,80 @@ from .utils import * __all__ = ( - authinfos.__all__ + comments.__all__ + computers.__all__ + entities.__all__ + groups.__all__ + logs.__all__ + - nodes.__all__ + querybuilder.__all__ + users.__all__ + utils.__all__ + 'ASCENDING', + 'AbstractNodeMeta', + 'ArrayData', + 'AttributeManager', + 'AuthInfo', + 'AutoGroup', + 'BandsData', + 'BaseType', + 'Bool', + 'CalcFunctionNode', + 'CalcJobNode', + 'CalcJobResultManager', + 'CalculationEntityLoader', + 'CalculationNode', + 'Code', + 'CodeEntityLoader', + 'Collection', + 'Comment', + 'Computer', + 'ComputerEntityLoader', + 'DESCENDING', + 'Data', + 'Dict', + 'Entity', + 'EntityAttributesMixin', + 'EntityExtrasMixin', + 'Float', + 'FolderData', + 'Group', + 'GroupEntityLoader', + 'ImportGroup', + 'Int', + 'Kind', + 'KpointsData', + 'LinkManager', + 'LinkPair', + 'LinkTriple', + 'List', + 'Log', + 'Node', + 'NodeEntityLoader', + 'NodeLinksManager', + 'NodeRepositoryMixin', + 'NumericType', + 'OrbitalData', + 'OrderSpecifier', + 'OrmEntityLoader', + 'ProcessNode', + 'ProjectionData', + 'QueryBuilder', + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', + 'SinglefileData', + 'Site', + 'Str', + 'StructureData', + 'TrajectoryData', + 'UpfData', + 'UpfFamily', + 'User', + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', + 'XyData', + 'find_bandgap', + 'get_loader', + 'get_query_type_from_type_string', + 'get_type_string_from_class', + 'load_code', + 'load_computer', + 'load_group', + 'load_node', + 'load_node_class', + 'to_aiida_type', + 'validate_link', ) diff --git a/aiida/orm/implementation/__init__.py b/aiida/orm/implementation/__init__.py index 8e2f177b1d..6ee6d1bbe0 100644 --- a/aiida/orm/implementation/__init__.py +++ b/aiida/orm/implementation/__init__.py @@ -8,18 +8,45 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with the implementations of the various backend entities for various database backends.""" -# pylint: disable=wildcard-import,undefined-variable + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .authinfos import * from .backends import * from .comments import * from .computers import * +from .entities import * from .groups import * from .logs import * from .nodes import * from .querybuilder import * from .users import * +from .utils import * __all__ = ( - authinfos.__all__ + backends.__all__ + comments.__all__ + computers.__all__ + groups.__all__ + logs.__all__ + - nodes.__all__ + querybuilder.__all__ + users.__all__ + 'Backend', + 'BackendAuthInfo', + 'BackendAuthInfoCollection', + 'BackendCollection', + 'BackendComment', + 'BackendCommentCollection', + 'BackendComputer', + 'BackendComputerCollection', + 'BackendEntity', + 'BackendEntityAttributesMixin', + 'BackendEntityExtrasMixin', + 'BackendGroup', + 'BackendGroupCollection', + 'BackendLog', + 'BackendLogCollection', + 'BackendNode', + 'BackendNodeCollection', + 'BackendQueryBuilder', + 'BackendUser', + 'BackendUserCollection', + 'EntityType', + 'clean_value', + 'validate_attribute_extra_key', ) diff --git a/aiida/orm/implementation/django/__init__.py b/aiida/orm/implementation/django/__init__.py index 2776a55f97..6957af28c6 100644 --- a/aiida/orm/implementation/django/__init__.py +++ b/aiida/orm/implementation/django/__init__.py @@ -7,3 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .backend import * +from .convert import * +from .groups import * +from .users import * + +__all__ = ( + 'DjangoBackend', + 'DjangoGroup', + 'DjangoGroupCollection', + 'DjangoUser', + 'DjangoUserCollection', + 'get_backend_entity', +) diff --git a/aiida/orm/implementation/sql/__init__.py b/aiida/orm/implementation/sql/__init__.py index 3cea3705ad..d5b115fd4e 100644 --- a/aiida/orm/implementation/sql/__init__.py +++ b/aiida/orm/implementation/sql/__init__.py @@ -12,3 +12,13 @@ All SQL backends with an ORM should subclass from the classes in this module """ + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .backends import * + +__all__ = ( + 'SqlBackend', +) diff --git a/aiida/orm/implementation/sqlalchemy/__init__.py b/aiida/orm/implementation/sqlalchemy/__init__.py index 2776a55f97..b836c7c844 100644 --- a/aiida/orm/implementation/sqlalchemy/__init__.py +++ b/aiida/orm/implementation/sqlalchemy/__init__.py @@ -7,3 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .backend import * +from .convert import * +from .groups import * +from .users import * + +__all__ = ( + 'SqlaBackend', + 'SqlaGroup', + 'SqlaGroupCollection', + 'SqlaUser', + 'SqlaUserCollection', + 'get_backend_entity', +) diff --git a/aiida/orm/nodes/__init__.py b/aiida/orm/nodes/__init__.py index b11c562245..f15d59e7da 100644 --- a/aiida/orm/nodes/__init__.py +++ b/aiida/orm/nodes/__init__.py @@ -7,11 +7,53 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module with `Node` sub classes for data and processes.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .data import * -from .process import * from .node import * +from .process import * +from .repository import * -__all__ = (data.__all__ + process.__all__ + node.__all__) +__all__ = ( + 'ArrayData', + 'BandsData', + 'BaseType', + 'Bool', + 'CalcFunctionNode', + 'CalcJobNode', + 'CalculationNode', + 'Code', + 'Data', + 'Dict', + 'Float', + 'FolderData', + 'Int', + 'Kind', + 'KpointsData', + 'List', + 'Node', + 'NodeRepositoryMixin', + 'NumericType', + 'OrbitalData', + 'ProcessNode', + 'ProjectionData', + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', + 'SinglefileData', + 'Site', + 'Str', + 'StructureData', + 'TrajectoryData', + 'UpfData', + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', + 'XyData', + 'find_bandgap', + 'to_aiida_type', +) diff --git a/aiida/orm/nodes/data/__init__.py b/aiida/orm/nodes/data/__init__.py index 8ed0d10aa4..5cb9339df4 100644 --- a/aiida/orm/nodes/data/__init__.py +++ b/aiida/orm/nodes/data/__init__.py @@ -8,27 +8,56 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub classes for data structures.""" -from .array import ArrayData, BandsData, KpointsData, ProjectionData, TrajectoryData, XyData -from .base import BaseType, to_aiida_type -from .bool import Bool -from .cif import CifData -from .code import Code -from .data import Data -from .dict import Dict -from .float import Float -from .folder import FolderData -from .int import Int -from .list import List -from .numeric import NumericType -from .orbital import OrbitalData -from .remote import RemoteData, RemoteStashData, RemoteStashFolderData -from .singlefile import SinglefileData -from .str import Str -from .structure import StructureData -from .upf import UpfData + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .array import * +from .base import * +from .bool import * +from .code import * +from .data import * +from .dict import * +from .float import * +from .folder import * +from .int import * +from .list import * +from .numeric import * +from .orbital import * +from .remote import * +from .singlefile import * +from .str import * +from .structure import * +from .upf import * __all__ = ( - 'Data', 'BaseType', 'ArrayData', 'BandsData', 'KpointsData', 'ProjectionData', 'TrajectoryData', 'XyData', 'Bool', - 'CifData', 'Code', 'Float', 'FolderData', 'Int', 'List', 'OrbitalData', 'Dict', 'RemoteData', 'RemoteStashData', - 'RemoteStashFolderData', 'SinglefileData', 'Str', 'StructureData', 'UpfData', 'NumericType', 'to_aiida_type' + 'ArrayData', + 'BandsData', + 'BaseType', + 'Bool', + 'Code', + 'Data', + 'Dict', + 'Float', + 'FolderData', + 'Int', + 'Kind', + 'KpointsData', + 'List', + 'NumericType', + 'OrbitalData', + 'ProjectionData', + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', + 'SinglefileData', + 'Site', + 'Str', + 'StructureData', + 'TrajectoryData', + 'UpfData', + 'XyData', + 'find_bandgap', + 'to_aiida_type', ) diff --git a/aiida/orm/nodes/data/array/__init__.py b/aiida/orm/nodes/data/array/__init__.py index d34d6ad52a..8cdfec92b9 100644 --- a/aiida/orm/nodes/data/array/__init__.py +++ b/aiida/orm/nodes/data/array/__init__.py @@ -9,11 +9,23 @@ ########################################################################### """Module with `Node` sub classes for array based data structures.""" -from .array import ArrayData -from .bands import BandsData -from .kpoints import KpointsData -from .projection import ProjectionData -from .trajectory import TrajectoryData -from .xy import XyData +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -__all__ = ('ArrayData', 'BandsData', 'KpointsData', 'ProjectionData', 'TrajectoryData', 'XyData') +from .array import * +from .bands import * +from .kpoints import * +from .projection import * +from .trajectory import * +from .xy import * + +__all__ = ( + 'ArrayData', + 'BandsData', + 'KpointsData', + 'ProjectionData', + 'TrajectoryData', + 'XyData', + 'find_bandgap', +) diff --git a/aiida/orm/nodes/data/array/array.py b/aiida/orm/nodes/data/array/array.py index 3d1553e4c5..98fe87f8d8 100644 --- a/aiida/orm/nodes/data/array/array.py +++ b/aiida/orm/nodes/data/array/array.py @@ -10,9 +10,10 @@ """ AiiDA ORM data class storing (numpy) arrays """ - from ..data import Data +__all__ = ('ArrayData',) + class ArrayData(Data): """ diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 0559017a4f..4a0a40c590 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -20,6 +20,8 @@ from aiida.common.utils import prettify_labels, join_labels from .kpoints import KpointsData +__all__ = ('BandsData', 'find_bandgap') + def prepare_header_comment(uuid, plot_info, comment_char='#'): """Prepare the header.""" diff --git a/aiida/orm/nodes/data/array/kpoints.py b/aiida/orm/nodes/data/array/kpoints.py index a3aa1630d0..c0957516ab 100644 --- a/aiida/orm/nodes/data/array/kpoints.py +++ b/aiida/orm/nodes/data/array/kpoints.py @@ -16,6 +16,8 @@ from .array import ArrayData +__all__ = ('KpointsData',) + _DEFAULT_EPSILON_LENGTH = 1e-5 _DEFAULT_EPSILON_ANGLE = 1e-5 diff --git a/aiida/orm/nodes/data/array/projection.py b/aiida/orm/nodes/data/array/projection.py index 87b1f8bd08..a27205743c 100644 --- a/aiida/orm/nodes/data/array/projection.py +++ b/aiida/orm/nodes/data/array/projection.py @@ -18,6 +18,8 @@ from .array import ArrayData from .bands import BandsData +__all__ = ('ProjectionData',) + class ProjectionData(OrbitalData, ArrayData): """ diff --git a/aiida/orm/nodes/data/array/trajectory.py b/aiida/orm/nodes/data/array/trajectory.py index c0a0ac2485..9edd934d59 100644 --- a/aiida/orm/nodes/data/array/trajectory.py +++ b/aiida/orm/nodes/data/array/trajectory.py @@ -15,6 +15,7 @@ from .array import ArrayData +__all__ = ('TrajectoryData',) class TrajectoryData(ArrayData): """ diff --git a/aiida/orm/nodes/data/array/xy.py b/aiida/orm/nodes/data/array/xy.py index 48e3a47ae1..3a253074f4 100644 --- a/aiida/orm/nodes/data/array/xy.py +++ b/aiida/orm/nodes/data/array/xy.py @@ -16,6 +16,8 @@ from aiida.common.exceptions import NotExistent from .array import ArrayData +__all__ = ('XyData',) + def check_convert_single_to_tuple(item): """ diff --git a/aiida/orm/nodes/data/remote/__init__.py b/aiida/orm/nodes/data/remote/__init__.py index 2f88d7edbc..a57cdcf2f5 100644 --- a/aiida/orm/nodes/data/remote/__init__.py +++ b/aiida/orm/nodes/data/remote/__init__.py @@ -1,6 +1,15 @@ # -*- coding: utf-8 -*- """Module with data plugins that represent remote resources and so effectively are symbolic links.""" -from .base import RemoteData -from .stash import RemoteStashData, RemoteStashFolderData -__all__ = ('RemoteData', 'RemoteStashData', 'RemoteStashFolderData') +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .base import * +from .stash import * + +__all__ = ( + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', +) diff --git a/aiida/orm/nodes/data/remote/stash/__init__.py b/aiida/orm/nodes/data/remote/stash/__init__.py index f744240cfc..74b714dca9 100644 --- a/aiida/orm/nodes/data/remote/stash/__init__.py +++ b/aiida/orm/nodes/data/remote/stash/__init__.py @@ -1,6 +1,14 @@ # -*- coding: utf-8 -*- """Module with data plugins that represent files of completed calculations jobs that have been stashed.""" -from .base import RemoteStashData -from .folder import RemoteStashFolderData -__all__ = ('RemoteStashData', 'RemoteStashFolderData') +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .base import * +from .folder import * + +__all__ = ( + 'RemoteStashData', + 'RemoteStashFolderData', +) diff --git a/aiida/orm/nodes/process/__init__.py b/aiida/orm/nodes/process/__init__.py index 4a84f892b0..2e698cd3bc 100644 --- a/aiida/orm/nodes/process/__init__.py +++ b/aiida/orm/nodes/process/__init__.py @@ -7,11 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module with `Node` sub classes for processes.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .calculation import * from .process import * from .workflow import * -__all__ = (calculation.__all__ + process.__all__ + workflow.__all__) # type: ignore[name-defined] +__all__ = ( + 'CalcFunctionNode', + 'CalcJobNode', + 'CalculationNode', + 'ProcessNode', + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', +) diff --git a/aiida/orm/nodes/process/calculation/__init__.py b/aiida/orm/nodes/process/calculation/__init__.py index 4d6232ba92..f7ce163fc7 100644 --- a/aiida/orm/nodes/process/calculation/__init__.py +++ b/aiida/orm/nodes/process/calculation/__init__.py @@ -9,8 +9,16 @@ ########################################################################### """Module with `Node` sub classes for calculation processes.""" -from .calculation import CalculationNode -from .calcfunction import CalcFunctionNode -from .calcjob import CalcJobNode +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -__all__ = ('CalculationNode', 'CalcFunctionNode', 'CalcJobNode') +from .calcfunction import * +from .calcjob import * +from .calculation import * + +__all__ = ( + 'CalcFunctionNode', + 'CalcJobNode', + 'CalculationNode', +) diff --git a/aiida/orm/nodes/process/workflow/__init__.py b/aiida/orm/nodes/process/workflow/__init__.py index b4f210da6f..2f13f3af94 100644 --- a/aiida/orm/nodes/process/workflow/__init__.py +++ b/aiida/orm/nodes/process/workflow/__init__.py @@ -9,8 +9,16 @@ ########################################################################### """Module with `Node` sub classes for workflow processes.""" -from .workflow import WorkflowNode -from .workchain import WorkChainNode -from .workfunction import WorkFunctionNode +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -__all__ = ('WorkflowNode', 'WorkChainNode', 'WorkFunctionNode') +from .workchain import * +from .workflow import * +from .workfunction import * + +__all__ = ( + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', +) diff --git a/aiida/orm/utils/__init__.py b/aiida/orm/utils/__init__.py index f703884d0e..bbe2bc6282 100644 --- a/aiida/orm/utils/__init__.py +++ b/aiida/orm/utils/__init__.py @@ -9,197 +9,38 @@ ########################################################################### """Utilities related to the ORM.""" -__all__ = ('load_code', 'load_computer', 'load_group', 'load_node') - - -def load_entity( - entity_loader=None, identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True -): - # pylint: disable=too-many-arguments - """ - Load an entity instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Code - :param pk: pk of a Code - :param uuid: uuid of a Code, or the beginning of the uuid - :param label: label of a Code - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :returns: the Code instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Code is found - :raise aiida.common.MultipleObjectsError: if more than one Code was found - """ - from aiida.orm.utils.loaders import OrmEntityLoader, IdentifierType - - if entity_loader is None or not issubclass(entity_loader, OrmEntityLoader): - raise TypeError(f'entity_loader should be a sub class of {type(OrmEntityLoader)}') - - inputs_provided = [value is not None for value in (identifier, pk, uuid, label)].count(True) - - if inputs_provided == 0: - raise ValueError("one of the parameters 'identifier', pk', 'uuid' or 'label' has to be specified") - elif inputs_provided > 1: - raise ValueError("only one of parameters 'identifier', pk', 'uuid' or 'label' has to be specified") - - if pk is not None: - - if not isinstance(pk, int): - raise TypeError('a pk has to be an integer') - - identifier = pk - identifier_type = IdentifierType.ID - - elif uuid is not None: - - if not isinstance(uuid, str): - raise TypeError('uuid has to be a string type') - - identifier = uuid - identifier_type = IdentifierType.UUID - - elif label is not None: - - if not isinstance(label, str): - raise TypeError('label has to be a string type') - - identifier = label - identifier_type = IdentifierType.LABEL - else: - identifier = str(identifier) - identifier_type = None - - return entity_loader.load_entity( - identifier, identifier_type, sub_classes=sub_classes, query_with_dashes=query_with_dashes - ) - - -def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Code instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Code - :param pk: pk of a Code - :param uuid: uuid of a Code, or the beginning of the uuid - :param label: label of a Code - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Code instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Code is found - :raise aiida.common.MultipleObjectsError: if more than one Code was found - """ - from aiida.orm.utils.loaders import CodeEntityLoader - return load_entity( - CodeEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_computer(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Computer instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Computer - :param pk: pk of a Computer - :param uuid: uuid of a Computer, or the beginning of the uuid - :param label: label of a Computer - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Computer instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Computer is found - :raise aiida.common.MultipleObjectsError: if more than one Computer was found - """ - from aiida.orm.utils.loaders import ComputerEntityLoader - return load_entity( - ComputerEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Group instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Group - :param pk: pk of a Group - :param uuid: uuid of a Group, or the beginning of the uuid - :param label: label of a Group - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Group instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Group is found - :raise aiida.common.MultipleObjectsError: if more than one Group was found - """ - from aiida.orm.utils.loaders import GroupEntityLoader - return load_entity( - GroupEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown - simply pass it without a keyword and the loader will attempt to infer the type - - :param identifier: pk (integer) or uuid (string) - :param pk: pk of a node - :param uuid: uuid of a node, or the beginning of the uuid - :param label: label of a Node - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :returns: the node instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Node is found - :raise aiida.common.MultipleObjectsError: if more than one Node was found - """ - from aiida.orm.utils.loaders import NodeEntityLoader - return load_entity( - NodeEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .calcjob import * +from .links import * +from .load_funcs import * +from .loaders import * +from .managers import * +from .node import * + +__all__ = ( + 'AbstractNodeMeta', + 'AttributeManager', + 'CalcJobResultManager', + 'CalculationEntityLoader', + 'CodeEntityLoader', + 'ComputerEntityLoader', + 'GroupEntityLoader', + 'LinkManager', + 'LinkPair', + 'LinkTriple', + 'NodeEntityLoader', + 'NodeLinksManager', + 'OrmEntityLoader', + 'get_loader', + 'get_query_type_from_type_string', + 'get_type_string_from_class', + 'load_code', + 'load_computer', + 'load_group', + 'load_node', + 'load_node_class', + 'validate_link', +) diff --git a/aiida/orm/utils/load_funcs.py b/aiida/orm/utils/load_funcs.py new file mode 100644 index 0000000000..f703884d0e --- /dev/null +++ b/aiida/orm/utils/load_funcs.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Utilities related to the ORM.""" + +__all__ = ('load_code', 'load_computer', 'load_group', 'load_node') + + +def load_entity( + entity_loader=None, identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True +): + # pylint: disable=too-many-arguments + """ + Load an entity instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Code + :param pk: pk of a Code + :param uuid: uuid of a Code, or the beginning of the uuid + :param label: label of a Code + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :returns: the Code instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Code is found + :raise aiida.common.MultipleObjectsError: if more than one Code was found + """ + from aiida.orm.utils.loaders import OrmEntityLoader, IdentifierType + + if entity_loader is None or not issubclass(entity_loader, OrmEntityLoader): + raise TypeError(f'entity_loader should be a sub class of {type(OrmEntityLoader)}') + + inputs_provided = [value is not None for value in (identifier, pk, uuid, label)].count(True) + + if inputs_provided == 0: + raise ValueError("one of the parameters 'identifier', pk', 'uuid' or 'label' has to be specified") + elif inputs_provided > 1: + raise ValueError("only one of parameters 'identifier', pk', 'uuid' or 'label' has to be specified") + + if pk is not None: + + if not isinstance(pk, int): + raise TypeError('a pk has to be an integer') + + identifier = pk + identifier_type = IdentifierType.ID + + elif uuid is not None: + + if not isinstance(uuid, str): + raise TypeError('uuid has to be a string type') + + identifier = uuid + identifier_type = IdentifierType.UUID + + elif label is not None: + + if not isinstance(label, str): + raise TypeError('label has to be a string type') + + identifier = label + identifier_type = IdentifierType.LABEL + else: + identifier = str(identifier) + identifier_type = None + + return entity_loader.load_entity( + identifier, identifier_type, sub_classes=sub_classes, query_with_dashes=query_with_dashes + ) + + +def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a Code instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Code + :param pk: pk of a Code + :param uuid: uuid of a Code, or the beginning of the uuid + :param label: label of a Code + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Code instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Code is found + :raise aiida.common.MultipleObjectsError: if more than one Code was found + """ + from aiida.orm.utils.loaders import CodeEntityLoader + return load_entity( + CodeEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_computer(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a Computer instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Computer + :param pk: pk of a Computer + :param uuid: uuid of a Computer, or the beginning of the uuid + :param label: label of a Computer + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Computer instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Computer is found + :raise aiida.common.MultipleObjectsError: if more than one Computer was found + """ + from aiida.orm.utils.loaders import ComputerEntityLoader + return load_entity( + ComputerEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a Group instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Group + :param pk: pk of a Group + :param uuid: uuid of a Group, or the beginning of the uuid + :param label: label of a Group + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Group instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Group is found + :raise aiida.common.MultipleObjectsError: if more than one Group was found + """ + from aiida.orm.utils.loaders import GroupEntityLoader + return load_entity( + GroupEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown + simply pass it without a keyword and the loader will attempt to infer the type + + :param identifier: pk (integer) or uuid (string) + :param pk: pk of a node + :param uuid: uuid of a node, or the beginning of the uuid + :param label: label of a Node + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :returns: the node instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Node is found + :raise aiida.common.MultipleObjectsError: if more than one Node was found + """ + from aiida.orm.utils.loaders import NodeEntityLoader + return load_entity( + NodeEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) diff --git a/aiida/parsers/__init__.py b/aiida/parsers/__init__.py index 5f6ee399c0..9198ba3a42 100644 --- a/aiida/parsers/__init__.py +++ b/aiida/parsers/__init__.py @@ -7,9 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for classes and utilities to write parsers for calculation jobs.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .parser import * -__all__ = (parser.__all__) +__all__ = ( + 'Parser', +) diff --git a/aiida/plugins/__init__.py b/aiida/plugins/__init__.py index a169084f36..349d07e35a 100644 --- a/aiida/plugins/__init__.py +++ b/aiida/plugins/__init__.py @@ -7,10 +7,28 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Classes and functions to load and interact with plugin classes accessible through defined entry points.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .entry_point import * from .factories import * +from .utils import * -__all__ = (entry_point.__all__ + factories.__all__) +__all__ = ( + 'BaseFactory', + 'CalculationFactory', + 'DataFactory', + 'DbImporterFactory', + 'GroupFactory', + 'OrbitalFactory', + 'ParserFactory', + 'PluginVersionProvider', + 'SchedulerFactory', + 'TransportFactory', + 'WorkflowFactory', + 'load_entry_point', + 'load_entry_point_from_string', +) diff --git a/aiida/repository/__init__.py b/aiida/repository/__init__.py index 6c71f4dbaa..91876a3342 100644 --- a/aiida/repository/__init__.py +++ b/aiida/repository/__init__.py @@ -8,9 +8,20 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with resources dealing with the file repository.""" -# pylint: disable=undefined-variable + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .backend import * from .common import * from .repository import * -__all__ = (backend.__all__ + common.__all__ + repository.__all__) # type: ignore[name-defined] +__all__ = ( + 'AbstractRepositoryBackend', + 'DiskObjectStoreRepositoryBackend', + 'File', + 'FileType', + 'Repository', + 'SandboxRepositoryBackend', +) diff --git a/aiida/repository/backend/__init__.py b/aiida/repository/backend/__init__.py index 20a704f865..4d890c6dae 100644 --- a/aiida/repository/backend/__init__.py +++ b/aiida/repository/backend/__init__.py @@ -1,8 +1,16 @@ # -*- coding: utf-8 -*- -# pylint: disable=undefined-variable """Module for file repository backend implementations.""" + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .abstract import * from .disk_object_store import * from .sandbox import * -__all__ = (abstract.__all__ + disk_object_store.__all__ + sandbox.__all__) # type: ignore[name-defined] +__all__ = ( + 'AbstractRepositoryBackend', + 'DiskObjectStoreRepositoryBackend', + 'SandboxRepositoryBackend', +) diff --git a/aiida/restapi/__init__.py b/aiida/restapi/__init__.py index fc199853df..4b340e2689 100644 --- a/aiida/restapi/__init__.py +++ b/aiida/restapi/__init__.py @@ -12,3 +12,14 @@ AiiDA nodes stored in database. The REST API is implemented using Flask RESTFul framework. """ + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .run_api import * + +__all__ = ( + 'configure_api', + 'run_api', +) diff --git a/aiida/schedulers/__init__.py b/aiida/schedulers/__init__.py index 2dd0db40f8..fd52173e18 100644 --- a/aiida/schedulers/__init__.py +++ b/aiida/schedulers/__init__.py @@ -7,10 +7,24 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for classes and utilities to interact with cluster schedulers.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .datastructures import * from .scheduler import * -__all__ = (datastructures.__all__ + scheduler.__all__) +__all__ = ( + 'JobInfo', + 'JobResource', + 'JobState', + 'JobTemplate', + 'MachineInfo', + 'NodeNumberJobResource', + 'ParEnvJobResource', + 'Scheduler', + 'SchedulerError', + 'SchedulerParsingError', +) diff --git a/aiida/tools/__init__.py b/aiida/tools/__init__.py index ffdf77d6e5..27dad0e172 100644 --- a/aiida/tools/__init__.py +++ b/aiida/tools/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin """ Tools to operate on AiiDA ORM class instances @@ -21,12 +20,76 @@ """ +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .calculations import * -from .data.array.kpoints import * -from .data.structure import * -from .dbimporters import * +from .data import * from .graph import * +from .groups import * +from .importexport import * +from .visualization import * __all__ = ( - calculations.__all__ + data.array.kpoints.__all__ + data.structure.__all__ + dbimporters.__all__ + graph.__all__ + 'ARCHIVE_READER_LOGGER', + 'ArchiveExportError', + 'ArchiveImportError', + 'ArchiveMetadata', + 'ArchiveMigrationError', + 'ArchiveMigratorAbstract', + 'ArchiveMigratorJsonBase', + 'ArchiveMigratorJsonTar', + 'ArchiveMigratorJsonZip', + 'ArchiveReaderAbstract', + 'ArchiveWriterAbstract', + 'CacheFolder', + 'CalculationTools', + 'CorruptArchive', + 'DELETE_LOGGER', + 'DanglingLinkError', + 'EXPORT_LOGGER', + 'EXPORT_VERSION', + 'ExportFileFormat', + 'ExportImportException', + 'ExportValidationError', + 'Graph', + 'GroupNotFoundError', + 'GroupNotUniqueError', + 'GroupPath', + 'IMPORT_LOGGER', + 'ImportUniquenessError', + 'ImportValidationError', + 'IncompatibleArchiveVersionError', + 'InvalidPath', + 'MIGRATE_LOGGER', + 'MigrationValidationError', + 'NoGroupsInPathError', + 'Orbital', + 'ProgressBarError', + 'ReaderJsonBase', + 'ReaderJsonFolder', + 'ReaderJsonTar', + 'ReaderJsonZip', + 'RealhydrogenOrbital', + 'WriterJsonFolder', + 'WriterJsonTar', + 'WriterJsonZip', + 'default_link_styles', + 'default_node_styles', + 'default_node_sublabels', + 'delete_group_nodes', + 'delete_nodes', + 'detect_archive_type', + 'export', + 'get_explicit_kpoints_path', + 'get_kpoints_path', + 'get_migrator', + 'get_reader', + 'get_writer', + 'import_data', + 'null_callback', + 'pstate_node_styles', + 'spglib_tuple_to_structure', + 'structure_to_spglib_tuple', ) diff --git a/aiida/tools/calculations/__init__.py b/aiida/tools/calculations/__init__.py index 34a0745e5f..10aa462662 100644 --- a/aiida/tools/calculations/__init__.py +++ b/aiida/tools/calculations/__init__.py @@ -7,9 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Calculation tool plugins for Calculation classes.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .base import * -__all__ = (base.__all__) +__all__ = ( + 'CalculationTools', +) diff --git a/aiida/tools/data/__init__.py b/aiida/tools/data/__init__.py index 2776a55f97..16ea3bdd91 100644 --- a/aiida/tools/data/__init__.py +++ b/aiida/tools/data/__init__.py @@ -7,3 +7,20 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .array import * +from .orbital import * +from .structure import * + +__all__ = ( + 'Orbital', + 'RealhydrogenOrbital', + 'get_explicit_kpoints_path', + 'get_kpoints_path', + 'spglib_tuple_to_structure', + 'structure_to_spglib_tuple', +) diff --git a/aiida/tools/data/array/__init__.py b/aiida/tools/data/array/__init__.py index 2776a55f97..3b67a12982 100644 --- a/aiida/tools/data/array/__init__.py +++ b/aiida/tools/data/array/__init__.py @@ -7,3 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .kpoints import * + +__all__ = ( + 'get_explicit_kpoints_path', + 'get_kpoints_path', +) diff --git a/aiida/tools/data/array/kpoints/__init__.py b/aiida/tools/data/array/kpoints/__init__.py index 59c40e53f6..7916aad1d1 100644 --- a/aiida/tools/data/array/kpoints/__init__.py +++ b/aiida/tools/data/array/kpoints/__init__.py @@ -11,231 +11,14 @@ Various utilities to deal with KpointsData instances or create new ones (e.g. band paths, kpoints from a parsed input text file, ...) """ -from aiida.orm import KpointsData, Dict -__all__ = ('get_kpoints_path', 'get_explicit_kpoints_path') +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import +from .main import * -def get_kpoints_path(structure, method='seekpath', **kwargs): - """ - Returns a dictionary whose contents depend on the method but includes at least the following keys - - * parameters: Dict node - - The contents of the parameters depends on the method but contains at least the keys - - * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] - * 'path': a list of length-2 tuples, with the labels of the starting - and ending point of each label section - - The 'seekpath' method which is the default also returns the following additional nodes - - * primitive_structure: StructureData with the primitive cell - * conv_structure: StructureData with the conventional cell - - Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure - and not on the input structure that was provided - - :param structure: a StructureData node - :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. - It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have - bugs for certain structure cells - :param kwargs: optional keyword arguments that depend on the selected method - :returns: dictionary as described above in the docstring - """ - if method not in _GET_KPOINTS_PATH_METHODS.keys(): - raise ValueError(f"the method '{method}' is not implemented") - - method = _GET_KPOINTS_PATH_METHODS[method] - - return method(structure, **kwargs) - - -def get_explicit_kpoints_path(structure, method='seekpath', **kwargs): - """ - Returns a dictionary whose contents depend on the method but includes at least the following keys - - * parameters: Dict node - * explicit_kpoints: KpointsData node with explicit kpoints path - - The contents of the parameters depends on the method but contains at least the keys - - * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] - * 'path': a list of length-2 tuples, with the labels of the starting - and ending point of each label section - - The 'seekpath' method which is the default also returns the following additional nodes - - * primitive_structure: StructureData with the primitive cell - * conv_structure: StructureData with the conventional cell - - Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure - and not on the input structure that was provided - - :param structure: a StructureData node - :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. - It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have - bugs for certain structure cells - :param kwargs: optional keyword arguments that depend on the selected method - :returns: dictionary as described above in the docstring - """ - if method not in _GET_EXPLICIT_KPOINTS_PATH_METHODS.keys(): - raise ValueError(f"the method '{method}' is not implemented") - - method = _GET_EXPLICIT_KPOINTS_PATH_METHODS[method] - - return method(structure, **kwargs) - - -def _seekpath_get_kpoints_path(structure, **kwargs): - """ - Call the get_kpoints_path wrapper function for Seekpath - - :param structure: a StructureData node - :param with_time_reversal: if False, and the group has no inversion - symmetry, additional lines are returned - :param recipe: choose the reference publication that defines the special points and paths. - Currently, the following value is implemented: - - - ``hpkot``: HPKOT paper: - Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure - diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). - DOI: 10.1016/j.commatsci.2016.10.015 - :param threshold: the threshold to use to verify if we are in - and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, - in the tI lattice, if ``abs(a-c) < threshold``, a - :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. - Note that depending on the bravais lattice, the meaning of the - threshold is different (angle, length, ...) - :param symprec: the symmetry precision used internally by SPGLIB - :param angle_tolerance: the angle_tolerance used internally by SPGLIB - """ - from aiida.tools.data.array.kpoints import seekpath - - assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' - - recognized_args = ['with_time_reversal', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] - unknown_args = set(kwargs).difference(recognized_args) - - if unknown_args: - raise ValueError(f'unknown arguments {unknown_args}') - - return seekpath.get_kpoints_path(structure, kwargs) - - -def _seekpath_get_explicit_kpoints_path(structure, **kwargs): - """ - Call the get_explicit_kpoints_path wrapper function for Seekpath - - :param structure: a StructureData node - :param with_time_reversal: if False, and the group has no inversion - symmetry, additional lines are returned - :param reference_distance: a reference target distance between neighboring - k-points in the path, in units of 1/ang. The actual value will be as - close as possible to this value, to have an integer number of points in - each path - :param recipe: choose the reference publication that defines the special points and paths. - Currently, the following value is implemented: - - - ``hpkot``: HPKOT paper: - Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure - diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). - DOI: 10.1016/j.commatsci.2016.10.015 - :param threshold: the threshold to use to verify if we are in - and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, - in the tI lattice, if ``abs(a-c) < threshold``, a - :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. - Note that depending on the bravais lattice, the meaning of the - threshold is different (angle, length, ...) - :param symprec: the symmetry precision used internally by SPGLIB - :param angle_tolerance: the angle_tolerance used internally by SPGLIB - """ - from aiida.tools.data.array.kpoints import seekpath - - assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' - - recognized_args = ['with_time_reversal', 'reference_distance', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] - unknown_args = set(kwargs).difference(recognized_args) - - if unknown_args: - raise ValueError(f'unknown arguments {unknown_args}') - - return seekpath.get_explicit_kpoints_path(structure, kwargs) - - -def _legacy_get_kpoints_path(structure, **kwargs): - """ - Call the get_kpoints_path of the legacy implementation - - :param structure: a StructureData node - :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates - :param epsilon_length: threshold on lengths comparison, used to get the bravais lattice info - :param epsilon_angle: threshold on angles comparison, used to get the bravais lattice info - """ - from aiida.tools.data.array.kpoints import legacy - - args_recognized = ['cartesian', 'epsilon_length', 'epsilon_angle'] - args_unknown = set(kwargs).difference(args_recognized) - - if args_unknown: - raise ValueError(f'unknown arguments {args_unknown}') - - point_coords, path, bravais_info = legacy.get_kpoints_path(cell=structure.cell, pbc=structure.pbc, **kwargs) - - parameters = { - 'bravais_info': bravais_info, - 'point_coords': point_coords, - 'path': path, - } - - return {'parameters': Dict(dict=parameters)} - - -def _legacy_get_explicit_kpoints_path(structure, **kwargs): - """ - Call the get_explicit_kpoints_path of the legacy implementation - - :param structure: a StructureData node - :param float kpoint_distance: parameter controlling the distance between kpoints. Distance is - given in crystal coordinates, i.e. the distance is computed in the space of b1, b2, b3. - The distance set will be the closest possible to this value, compatible with the requirement - of putting equispaced points between two special points (since extrema are included). - :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates - :param float epsilon_length: threshold on lengths comparison, used to get the bravais lattice info - :param float epsilon_angle: threshold on angles comparison, used to get the bravais lattice info - """ - from aiida.tools.data.array.kpoints import legacy - - args_recognized = ['value', 'kpoint_distance', 'cartesian', 'epsilon_length', 'epsilon_angle'] - args_unknown = set(kwargs).difference(args_recognized) - - if args_unknown: - raise ValueError(f'unknown arguments {args_unknown}') - - point_coords, path, bravais_info, explicit_kpoints, labels = legacy.get_explicit_kpoints_path( # pylint: disable=unbalanced-tuple-unpacking - cell=structure.cell, pbc=structure.pbc, **kwargs - ) - - kpoints = KpointsData() - kpoints.set_cell(structure.cell) - kpoints.set_kpoints(explicit_kpoints) - kpoints.labels = labels - - parameters = { - 'bravais_info': bravais_info, - 'point_coords': point_coords, - 'path': path, - } - - return {'parameters': Dict(dict=parameters), 'explicit_kpoints': kpoints} - - -_GET_KPOINTS_PATH_METHODS = { - 'legacy': _legacy_get_kpoints_path, - 'seekpath': _seekpath_get_kpoints_path, -} - -_GET_EXPLICIT_KPOINTS_PATH_METHODS = { - 'legacy': _legacy_get_explicit_kpoints_path, - 'seekpath': _seekpath_get_explicit_kpoints_path, -} +__all__ = ( + 'get_explicit_kpoints_path', + 'get_kpoints_path', +) diff --git a/aiida/tools/data/array/kpoints/main.py b/aiida/tools/data/array/kpoints/main.py new file mode 100644 index 0000000000..59c40e53f6 --- /dev/null +++ b/aiida/tools/data/array/kpoints/main.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +""" +Various utilities to deal with KpointsData instances or create new ones +(e.g. band paths, kpoints from a parsed input text file, ...) +""" +from aiida.orm import KpointsData, Dict + +__all__ = ('get_kpoints_path', 'get_explicit_kpoints_path') + + +def get_kpoints_path(structure, method='seekpath', **kwargs): + """ + Returns a dictionary whose contents depend on the method but includes at least the following keys + + * parameters: Dict node + + The contents of the parameters depends on the method but contains at least the keys + + * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] + * 'path': a list of length-2 tuples, with the labels of the starting + and ending point of each label section + + The 'seekpath' method which is the default also returns the following additional nodes + + * primitive_structure: StructureData with the primitive cell + * conv_structure: StructureData with the conventional cell + + Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure + and not on the input structure that was provided + + :param structure: a StructureData node + :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. + It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have + bugs for certain structure cells + :param kwargs: optional keyword arguments that depend on the selected method + :returns: dictionary as described above in the docstring + """ + if method not in _GET_KPOINTS_PATH_METHODS.keys(): + raise ValueError(f"the method '{method}' is not implemented") + + method = _GET_KPOINTS_PATH_METHODS[method] + + return method(structure, **kwargs) + + +def get_explicit_kpoints_path(structure, method='seekpath', **kwargs): + """ + Returns a dictionary whose contents depend on the method but includes at least the following keys + + * parameters: Dict node + * explicit_kpoints: KpointsData node with explicit kpoints path + + The contents of the parameters depends on the method but contains at least the keys + + * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] + * 'path': a list of length-2 tuples, with the labels of the starting + and ending point of each label section + + The 'seekpath' method which is the default also returns the following additional nodes + + * primitive_structure: StructureData with the primitive cell + * conv_structure: StructureData with the conventional cell + + Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure + and not on the input structure that was provided + + :param structure: a StructureData node + :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. + It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have + bugs for certain structure cells + :param kwargs: optional keyword arguments that depend on the selected method + :returns: dictionary as described above in the docstring + """ + if method not in _GET_EXPLICIT_KPOINTS_PATH_METHODS.keys(): + raise ValueError(f"the method '{method}' is not implemented") + + method = _GET_EXPLICIT_KPOINTS_PATH_METHODS[method] + + return method(structure, **kwargs) + + +def _seekpath_get_kpoints_path(structure, **kwargs): + """ + Call the get_kpoints_path wrapper function for Seekpath + + :param structure: a StructureData node + :param with_time_reversal: if False, and the group has no inversion + symmetry, additional lines are returned + :param recipe: choose the reference publication that defines the special points and paths. + Currently, the following value is implemented: + + - ``hpkot``: HPKOT paper: + Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure + diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). + DOI: 10.1016/j.commatsci.2016.10.015 + :param threshold: the threshold to use to verify if we are in + and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, + in the tI lattice, if ``abs(a-c) < threshold``, a + :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. + Note that depending on the bravais lattice, the meaning of the + threshold is different (angle, length, ...) + :param symprec: the symmetry precision used internally by SPGLIB + :param angle_tolerance: the angle_tolerance used internally by SPGLIB + """ + from aiida.tools.data.array.kpoints import seekpath + + assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' + + recognized_args = ['with_time_reversal', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] + unknown_args = set(kwargs).difference(recognized_args) + + if unknown_args: + raise ValueError(f'unknown arguments {unknown_args}') + + return seekpath.get_kpoints_path(structure, kwargs) + + +def _seekpath_get_explicit_kpoints_path(structure, **kwargs): + """ + Call the get_explicit_kpoints_path wrapper function for Seekpath + + :param structure: a StructureData node + :param with_time_reversal: if False, and the group has no inversion + symmetry, additional lines are returned + :param reference_distance: a reference target distance between neighboring + k-points in the path, in units of 1/ang. The actual value will be as + close as possible to this value, to have an integer number of points in + each path + :param recipe: choose the reference publication that defines the special points and paths. + Currently, the following value is implemented: + + - ``hpkot``: HPKOT paper: + Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure + diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). + DOI: 10.1016/j.commatsci.2016.10.015 + :param threshold: the threshold to use to verify if we are in + and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, + in the tI lattice, if ``abs(a-c) < threshold``, a + :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. + Note that depending on the bravais lattice, the meaning of the + threshold is different (angle, length, ...) + :param symprec: the symmetry precision used internally by SPGLIB + :param angle_tolerance: the angle_tolerance used internally by SPGLIB + """ + from aiida.tools.data.array.kpoints import seekpath + + assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' + + recognized_args = ['with_time_reversal', 'reference_distance', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] + unknown_args = set(kwargs).difference(recognized_args) + + if unknown_args: + raise ValueError(f'unknown arguments {unknown_args}') + + return seekpath.get_explicit_kpoints_path(structure, kwargs) + + +def _legacy_get_kpoints_path(structure, **kwargs): + """ + Call the get_kpoints_path of the legacy implementation + + :param structure: a StructureData node + :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates + :param epsilon_length: threshold on lengths comparison, used to get the bravais lattice info + :param epsilon_angle: threshold on angles comparison, used to get the bravais lattice info + """ + from aiida.tools.data.array.kpoints import legacy + + args_recognized = ['cartesian', 'epsilon_length', 'epsilon_angle'] + args_unknown = set(kwargs).difference(args_recognized) + + if args_unknown: + raise ValueError(f'unknown arguments {args_unknown}') + + point_coords, path, bravais_info = legacy.get_kpoints_path(cell=structure.cell, pbc=structure.pbc, **kwargs) + + parameters = { + 'bravais_info': bravais_info, + 'point_coords': point_coords, + 'path': path, + } + + return {'parameters': Dict(dict=parameters)} + + +def _legacy_get_explicit_kpoints_path(structure, **kwargs): + """ + Call the get_explicit_kpoints_path of the legacy implementation + + :param structure: a StructureData node + :param float kpoint_distance: parameter controlling the distance between kpoints. Distance is + given in crystal coordinates, i.e. the distance is computed in the space of b1, b2, b3. + The distance set will be the closest possible to this value, compatible with the requirement + of putting equispaced points between two special points (since extrema are included). + :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates + :param float epsilon_length: threshold on lengths comparison, used to get the bravais lattice info + :param float epsilon_angle: threshold on angles comparison, used to get the bravais lattice info + """ + from aiida.tools.data.array.kpoints import legacy + + args_recognized = ['value', 'kpoint_distance', 'cartesian', 'epsilon_length', 'epsilon_angle'] + args_unknown = set(kwargs).difference(args_recognized) + + if args_unknown: + raise ValueError(f'unknown arguments {args_unknown}') + + point_coords, path, bravais_info, explicit_kpoints, labels = legacy.get_explicit_kpoints_path( # pylint: disable=unbalanced-tuple-unpacking + cell=structure.cell, pbc=structure.pbc, **kwargs + ) + + kpoints = KpointsData() + kpoints.set_cell(structure.cell) + kpoints.set_kpoints(explicit_kpoints) + kpoints.labels = labels + + parameters = { + 'bravais_info': bravais_info, + 'point_coords': point_coords, + 'path': path, + } + + return {'parameters': Dict(dict=parameters), 'explicit_kpoints': kpoints} + + +_GET_KPOINTS_PATH_METHODS = { + 'legacy': _legacy_get_kpoints_path, + 'seekpath': _seekpath_get_kpoints_path, +} + +_GET_EXPLICIT_KPOINTS_PATH_METHODS = { + 'legacy': _legacy_get_explicit_kpoints_path, + 'seekpath': _seekpath_get_explicit_kpoints_path, +} diff --git a/aiida/tools/data/array/kpoints/seekpath.py b/aiida/tools/data/array/kpoints/seekpath.py index 0d4c59b6a0..b12672b02d 100644 --- a/aiida/tools/data/array/kpoints/seekpath.py +++ b/aiida/tools/data/array/kpoints/seekpath.py @@ -12,8 +12,6 @@ from aiida.orm import KpointsData, Dict -__all__ = ('get_explicit_kpoints_path', 'get_kpoints_path') - def get_explicit_kpoints_path(structure, parameters): """ diff --git a/aiida/tools/data/orbital/__init__.py b/aiida/tools/data/orbital/__init__.py index 670b51ba55..ba172980e2 100644 --- a/aiida/tools/data/orbital/__init__.py +++ b/aiida/tools/data/orbital/__init__.py @@ -9,6 +9,14 @@ ########################################################################### """Module for classes and methods that represents molecular orbitals.""" -from .orbital import Orbital +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -__all__ = ('Orbital',) +from .orbital import * +from .realhydrogen import * + +__all__ = ( + 'Orbital', + 'RealhydrogenOrbital', +) diff --git a/aiida/tools/data/orbital/orbital.py b/aiida/tools/data/orbital/orbital.py index a1ecabb6e1..5c83043e8a 100644 --- a/aiida/tools/data/orbital/orbital.py +++ b/aiida/tools/data/orbital/orbital.py @@ -17,6 +17,8 @@ from aiida.common.exceptions import ValidationError from aiida.plugins.entry_point import get_entry_point_from_class +__all__ = ('Orbital',) + def validate_int(value): """ diff --git a/aiida/tools/data/orbital/realhydrogen.py b/aiida/tools/data/orbital/realhydrogen.py index 84961b48b1..14b601f0ba 100644 --- a/aiida/tools/data/orbital/realhydrogen.py +++ b/aiida/tools/data/orbital/realhydrogen.py @@ -16,6 +16,8 @@ from .orbital import Orbital, validate_len3_list_or_none, validate_float_or_none +__all__ = ('RealhydrogenOrbital',) + def validate_l(value): """ Validate the value of the angular momentum diff --git a/aiida/tools/data/structure/__init__.py b/aiida/tools/data/structure.py similarity index 100% rename from aiida/tools/data/structure/__init__.py rename to aiida/tools/data/structure.py diff --git a/aiida/tools/graph/__init__.py b/aiida/tools/graph/__init__.py index c095d1619a..36de7bcc94 100644 --- a/aiida/tools/graph/__init__.py +++ b/aiida/tools/graph/__init__.py @@ -7,8 +7,16 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides tools for traversing the provenance graph.""" + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .deletions import * -__all__ = deletions.__all__ +__all__ = ( + 'DELETE_LOGGER', + 'delete_group_nodes', + 'delete_nodes', +) diff --git a/aiida/tools/groups/__init__.py b/aiida/tools/groups/__init__.py index 19e936839b..b7ad08bbcb 100644 --- a/aiida/tools/groups/__init__.py +++ b/aiida/tools/groups/__init__.py @@ -13,8 +13,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides tools for interacting with AiiDA Groups.""" + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .paths import * -__all__ = paths.__all__ +__all__ = ( + 'GroupNotFoundError', + 'GroupNotUniqueError', + 'GroupPath', + 'InvalidPath', + 'NoGroupsInPathError', +) diff --git a/aiida/tools/importexport/__init__.py b/aiida/tools/importexport/__init__.py index d6d576159f..90a2d9e431 100644 --- a/aiida/tools/importexport/__init__.py +++ b/aiida/tools/importexport/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides import/export functionalities. To see history/git blame prior to the move to aiida.tools.importexport, @@ -15,9 +14,55 @@ Functionality: /aiida/orm/importexport.py Tests: /aiida/backends/tests/test_export_and_import.py """ + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .archive import * +from .common import * from .dbexport import * from .dbimport import * -from .common import * -__all__ = (archive.__all__ + dbexport.__all__ + dbimport.__all__ + common.__all__) +__all__ = ( + 'ARCHIVE_READER_LOGGER', + 'ArchiveExportError', + 'ArchiveImportError', + 'ArchiveMetadata', + 'ArchiveMigrationError', + 'ArchiveMigratorAbstract', + 'ArchiveMigratorJsonBase', + 'ArchiveMigratorJsonTar', + 'ArchiveMigratorJsonZip', + 'ArchiveReaderAbstract', + 'ArchiveWriterAbstract', + 'CacheFolder', + 'CorruptArchive', + 'DanglingLinkError', + 'EXPORT_LOGGER', + 'EXPORT_VERSION', + 'ExportFileFormat', + 'ExportImportException', + 'ExportValidationError', + 'IMPORT_LOGGER', + 'ImportUniquenessError', + 'ImportValidationError', + 'IncompatibleArchiveVersionError', + 'MIGRATE_LOGGER', + 'MigrationValidationError', + 'ProgressBarError', + 'ReaderJsonBase', + 'ReaderJsonFolder', + 'ReaderJsonTar', + 'ReaderJsonZip', + 'WriterJsonFolder', + 'WriterJsonTar', + 'WriterJsonZip', + 'detect_archive_type', + 'export', + 'get_migrator', + 'get_reader', + 'get_writer', + 'import_data', + 'null_callback', +) diff --git a/aiida/tools/importexport/archive/__init__.py b/aiida/tools/importexport/archive/__init__.py index ee13f67842..51dd0c35d6 100644 --- a/aiida/tools/importexport/archive/__init__.py +++ b/aiida/tools/importexport/archive/__init__.py @@ -7,13 +7,39 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable # type: ignore """Readers and writers for archive formats, that work independently of a connection to an AiiDA profile.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .common import * from .migrators import * from .readers import * from .writers import * -__all__ = (migrators.__all__ + readers.__all__ + writers.__all__ + common.__all__) +__all__ = ( + 'ARCHIVE_READER_LOGGER', + 'ArchiveMetadata', + 'ArchiveMigratorAbstract', + 'ArchiveMigratorJsonBase', + 'ArchiveMigratorJsonTar', + 'ArchiveMigratorJsonZip', + 'ArchiveReaderAbstract', + 'ArchiveWriterAbstract', + 'CacheFolder', + 'MIGRATE_LOGGER', + 'ReaderJsonBase', + 'ReaderJsonFolder', + 'ReaderJsonTar', + 'ReaderJsonZip', + 'WriterJsonFolder', + 'WriterJsonTar', + 'WriterJsonZip', + 'detect_archive_type', + 'get_migrator', + 'get_reader', + 'get_writer', + 'null_callback', +) diff --git a/aiida/tools/importexport/common/__init__.py b/aiida/tools/importexport/common/__init__.py index 4bdfd23504..df4813add8 100644 --- a/aiida/tools/importexport/common/__init__.py +++ b/aiida/tools/importexport/common/__init__.py @@ -7,9 +7,27 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Common utility functions, classes, and exceptions""" + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .config import * from .exceptions import * -__all__ = (config.__all__ + exceptions.__all__) +__all__ = ( + 'ArchiveExportError', + 'ArchiveImportError', + 'ArchiveMigrationError', + 'CorruptArchive', + 'DanglingLinkError', + 'EXPORT_VERSION', + 'ExportImportException', + 'ExportValidationError', + 'ImportUniquenessError', + 'ImportValidationError', + 'IncompatibleArchiveVersionError', + 'MigrationValidationError', + 'ProgressBarError', +) diff --git a/aiida/tools/importexport/dbexport/__init__.py b/aiida/tools/importexport/dbexport/__init__.py index c68a3c242a..b2b13f57bf 100644 --- a/aiida/tools/importexport/dbexport/__init__.py +++ b/aiida/tools/importexport/dbexport/__init__.py @@ -7,607 +7,16 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=fixme,too-many-lines """Provides export functionalities.""" -from collections import defaultdict -import os -import tempfile -from typing import ( - Any, - Callable, - DefaultDict, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Type, - Union, - cast, -) - -from aiida import get_version, orm -from aiida.common.links import GraphTraversalRules -from aiida.common.lang import type_check -from aiida.common.progress_reporter import get_progress_reporter, create_callback -from aiida.tools.importexport.common import exceptions -from aiida.tools.importexport.common.config import ( - COMMENT_ENTITY_NAME, - COMPUTER_ENTITY_NAME, - EXPORT_VERSION, - GROUP_ENTITY_NAME, - LOG_ENTITY_NAME, - NODE_ENTITY_NAME, - ExportFileFormat, - entity_names_to_entities, - file_fields_to_model_fields, - get_all_fields_info, - model_fields_to_file_fields, -) -from aiida.tools.graph.graph_traversers import get_nodes_export, validate_traversal_rules -from aiida.tools.importexport.archive.writers import ArchiveMetadata, ArchiveWriterAbstract, get_writer -from aiida.tools.importexport.dbexport.utils import ( - EXPORT_LOGGER, - check_licenses, - fill_in_query, - serialize_dict, - summary, -) - -__all__ = ('export', 'EXPORT_LOGGER', 'ExportFileFormat') - - -def export( - entities: Optional[Iterable[Any]] = None, - filename: Optional[str] = None, - file_format: Union[str, Type[ArchiveWriterAbstract]] = ExportFileFormat.ZIP, - overwrite: bool = False, - include_comments: bool = True, - include_logs: bool = True, - allowed_licenses: Optional[Union[list, Callable]] = None, - forbidden_licenses: Optional[Union[list, Callable]] = None, - writer_init: Optional[Dict[str, Any]] = None, - batch_size: int = 100, - **traversal_rules: bool, -) -> ArchiveWriterAbstract: - """Export AiiDA data to an archive file. - - Note, the logging level and progress reporter should be set externally, for example:: - - from aiida.common.progress_reporter import set_progress_bar_tqdm - - EXPORT_LOGGER.setLevel('DEBUG') - set_progress_bar_tqdm(leave=True) - export(...) - - :param entities: a list of entity instances; - they can belong to different models/entities. - - :param filename: the filename (possibly including the absolute path) - of the file on which to export. - - :param file_format: 'zip', 'tar.gz' or 'folder' or a specific writer class. - - :param overwrite: if True, overwrite the output file without asking, if it exists. - If False, raise an - :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError` - if the output file already exists. - - :param allowed_licenses: List or function. - If a list, then checks whether all licenses of Data nodes are in the list. If a function, - then calls function for licenses of Data nodes expecting True if license is allowed, False - otherwise. - - :param forbidden_licenses: List or function. If a list, - then checks whether all licenses of Data nodes are in the list. If a function, - then calls function for licenses of Data nodes expecting True if license is allowed, False - otherwise. - - :param include_comments: In-/exclude export of comments for given node(s) in ``entities``. - Default: True, *include* comments in export (as well as relevant users). - - :param include_logs: In-/exclude export of logs for given node(s) in ``entities``. - Default: True, *include* logs in export. - - :param writer_init: Additional key-word arguments to pass to the writer class init - - :param batch_size: batch database query results in sub-collections to reduce memory usage - - :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` - what rule names are toggleable and what the defaults are. - - :returns: a dictionary of data regarding the export process (timings, etc) - - :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: - if there are any internal errors when exporting. - :raises `~aiida.common.exceptions.LicensingException`: - if any node is licensed under forbidden license. - """ - # pylint: disable=too-many-locals,too-many-branches,too-many-statements - - # Backwards-compatibility - entities = cast(Iterable[Any], entities) - filename = cast(str, filename) - - type_check( - entities, - (list, tuple, set), - msg='`entities` must be specified and given as a list of AiiDA entities', - ) - entities = list(entities) - if type_check(filename, str, allow_none=True) is None: - filename = 'export_data.aiida' - - if not overwrite and os.path.exists(filename): - raise exceptions.ArchiveExportError(f"The output file '{filename}' already exists") - - # validate the traversal rules and generate a full set for reporting - validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules) - full_traversal_rules = { - name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items() - } - - # setup the archive writer - writer_init = writer_init or {} - if isinstance(file_format, str): - writer = get_writer(file_format)(filepath=filename, **writer_init) - elif issubclass(file_format, ArchiveWriterAbstract): - writer = file_format(filepath=filename, **writer_init) - else: - raise TypeError('file_format must be a string or ArchiveWriterAbstract class') - - summary( - file_format=writer.file_format_verbose, - export_version=writer.export_version, - outfile=filename, - include_comments=include_comments, - include_logs=include_logs, - traversal_rules=full_traversal_rules - ) - - EXPORT_LOGGER.debug('STARTING EXPORT...') - - all_fields_info, unique_identifiers = get_all_fields_info() - entities_starting_set, given_node_entry_ids = _get_starting_node_ids(entities) - - # Initialize the writer - with writer as writer_context: - - # Iteratively explore the AiiDA graph to find further nodes that should also be exported - with get_progress_reporter()(desc='Traversing provenance via links ...', total=1) as progress: - traverse_output = get_nodes_export(starting_pks=given_node_entry_ids, get_links=True, **traversal_rules) - progress.update() - node_ids_to_be_exported = traverse_output['nodes'] - - EXPORT_LOGGER.debug('WRITING METADATA...') - - writer_context.write_metadata( - ArchiveMetadata( - export_version=EXPORT_VERSION, - aiida_version=get_version(), - unique_identifiers=unique_identifiers, - all_fields_info=all_fields_info, - graph_traversal_rules=traverse_output['rules'], - # Turn sets into lists to be able to export them as JSON metadata. - entities_starting_set={ - entity: list(entity_set) for entity, entity_set in entities_starting_set.items() - }, - include_comments=include_comments, - include_logs=include_logs, - ) - ) - - # Create a mapping of node PK to UUID. - node_pk_2_uuid_mapping: Dict[int, str] = {} - repository_metadata_mapping: Dict[int, dict] = {} - - if node_ids_to_be_exported: - qbuilder = orm.QueryBuilder().append( - orm.Node, - project=('id', 'uuid', 'repository_metadata'), - filters={'id': { - 'in': node_ids_to_be_exported - }}, - ) - for pk, uuid, repository_metadata in qbuilder.iterall(batch_size=batch_size): - node_pk_2_uuid_mapping[pk] = uuid - repository_metadata_mapping[pk] = repository_metadata - - # check that no nodes are being exported with incorrect licensing - _check_node_licenses(node_ids_to_be_exported, allowed_licenses, forbidden_licenses) - - # write the link data - if traverse_output['links'] is not None: - with get_progress_reporter()(total=len(traverse_output['links']), desc='Writing links') as progress: - for link in traverse_output['links']: - progress.update() - writer_context.write_link({ - 'input': node_pk_2_uuid_mapping[link.source_id], - 'output': node_pk_2_uuid_mapping[link.target_id], - 'label': link.link_label, - 'type': link.link_type, - }) - - # generate a list of queries to encapsulate all required entities - entity_queries = _collect_entity_queries( - node_ids_to_be_exported, - entities_starting_set, - node_pk_2_uuid_mapping, - include_comments, - include_logs, - ) - - total_entities = sum(query.count() for query in entity_queries.values()) - - # write all entity data fields - if total_entities: - exported_entity_pks = _write_entity_data( - total_entities=total_entities, - entity_queries=entity_queries, - writer=writer_context, - batch_size=batch_size - ) - else: - exported_entity_pks = defaultdict(set) - EXPORT_LOGGER.info('No entities were found to export') - - # write mappings of groups to the nodes they contain - if exported_entity_pks[GROUP_ENTITY_NAME]: - - EXPORT_LOGGER.debug('Writing group UUID -> [nodes UUIDs]') - - _write_group_mappings( - group_pks=exported_entity_pks[GROUP_ENTITY_NAME], batch_size=batch_size, writer=writer_context - ) - - # copy all required node repositories - if exported_entity_pks[NODE_ENTITY_NAME]: - - _write_node_repositories( - node_pks=exported_entity_pks[NODE_ENTITY_NAME], - repository_metadata_mapping=repository_metadata_mapping, - writer=writer_context - ) - - EXPORT_LOGGER.info('Finalizing Export...') - - # summarize export - export_summary = '\n - '.join(f'{name:<6}: {len(pks)}' for name, pks in exported_entity_pks.items()) - if exported_entity_pks: - EXPORT_LOGGER.info('Exported Entities:\n - ' + export_summary + '\n') - # TODO - # EXPORT_LOGGER.info('Writer Information:\n %s', writer.export_info) - - return writer - - -def _get_starting_node_ids(entities: List[Any]) -> Tuple[DefaultDict[str, Set[str]], Set[int]]: - """Get the starting node UUIDs and PKs - - :param entities: a list of entity instances - - :raises exceptions.ArchiveExportError: - :return: entities_starting_set, given_node_entry_ids - """ - entities_starting_set: DefaultDict[str, Set[str]] = defaultdict(set) - given_node_entry_ids: Set[int] = set() - - # store a list of the actual dbnodes - total = len(entities) + (1 if GROUP_ENTITY_NAME in entities_starting_set else 0) - if not total: - return entities_starting_set, given_node_entry_ids - - for entry in entities: - - if issubclass(entry.__class__, orm.Group): - entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid) - elif issubclass(entry.__class__, orm.Node): - entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid) - given_node_entry_ids.add(entry.pk) - elif issubclass(entry.__class__, orm.Computer): - entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid) - else: - raise exceptions.ArchiveExportError( - f'I was given {entry} ({type(entry)}),' - ' which is not a Node, Computer, or Group instance' - ) - - # Add all the nodes contained within the specified groups - if GROUP_ENTITY_NAME in entities_starting_set: - # Use single query instead of given_group.nodes iterator for performance. - qh_groups = ( - orm.QueryBuilder().append( - orm.Group, - filters={ - 'uuid': { - 'in': entities_starting_set[GROUP_ENTITY_NAME] - } - }, - tag='groups', - ).queryhelp - ) - node_query = orm.QueryBuilder(**qh_groups).append(orm.Node, project=['id', 'uuid'], with_group='groups') - node_count = node_query.count() +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import - if node_count: - with get_progress_reporter()(desc='Collecting nodes in groups', total=node_count) as progress: +from .main import * - pks, uuids = [], [] - for pk, uuid in node_query.all(): - progress.update() - pks.append(pk) - uuids.append(uuid) - - entities_starting_set[NODE_ENTITY_NAME].update(uuids) - given_node_entry_ids.update(pks) - - return entities_starting_set, given_node_entry_ids - - -def _check_node_licenses( - node_ids_to_be_exported: Set[int], - allowed_licenses: Optional[Union[list, Callable]], - forbidden_licenses: Optional[Union[list, Callable]], -) -> None: - """Check the nodes to be archived for disallowed licences.""" - # TODO (Spyros) To see better! Especially for functional licenses - # Check the licenses of exported data. - if allowed_licenses is not None or forbidden_licenses is not None: - builder = orm.QueryBuilder() - builder.append( - orm.Node, - project=['id', 'attributes.source.license'], - filters={'id': { - 'in': node_ids_to_be_exported - }}, - ) - # Skip those nodes where the license is not set (this is the standard behavior with Django) - node_licenses = [(a, b) for [a, b] in builder.all() if b is not None] - check_licenses(node_licenses, allowed_licenses, forbidden_licenses) - - -def _get_model_fields(entity_name: str) -> List[str]: - """Return a list of fields to retrieve for a particular entity - - :param entity_name: name of database entity, such as Node - - """ - all_fields_info, _ = get_all_fields_info() - project_cols = ['id'] - entity_prop = all_fields_info[entity_name].keys() - # Here we do the necessary renaming of properties - for prop in entity_prop: - # nprop contains the list of projections - nprop = ( - file_fields_to_model_fields[entity_name][prop] if prop in file_fields_to_model_fields[entity_name] else prop - ) - project_cols.append(nprop) - return project_cols - - -def _collect_entity_queries( - node_ids_to_be_exported: Set[int], - entities_starting_set: DefaultDict[str, Set[str]], - node_pk_2_uuid_mapping: Dict[int, str], - include_comments: bool = True, - include_logs: bool = True, -) -> Dict[str, orm.QueryBuilder]: - """Gather partial queries for all entities to export.""" - # pylint: disable=too-many-locals - given_log_entry_ids = set() - given_comment_entry_ids = set() - - total = 2 + (((1 if include_logs else 0) + (1 if include_comments else 0)) if node_ids_to_be_exported else 0) - with get_progress_reporter()(desc='Building entity database queries', total=total) as progress: - - # Logs - if include_logs and node_ids_to_be_exported: - # Get related log(s) - universal for all nodes - builder = orm.QueryBuilder() - builder.append( - orm.Log, - filters={'dbnode_id': { - 'in': node_ids_to_be_exported - }}, - project='uuid', - ) - res = set(builder.all(flat=True)) - given_log_entry_ids.update(res) - - progress.update() - - # Comments - if include_comments and node_ids_to_be_exported: - # Get related log(s) - universal for all nodes - builder = orm.QueryBuilder() - builder.append( - orm.Comment, - filters={'dbnode_id': { - 'in': node_ids_to_be_exported - }}, - project='uuid', - ) - res = set(builder.all(flat=True)) - given_comment_entry_ids.update(res) - - progress.update() - - # Here we get all the columns that we plan to project per entity that we would like to extract - given_entities = set(entities_starting_set.keys()) - if node_ids_to_be_exported: - given_entities.add(NODE_ENTITY_NAME) - if given_log_entry_ids: - given_entities.add(LOG_ENTITY_NAME) - if given_comment_entry_ids: - given_entities.add(COMMENT_ENTITY_NAME) - - progress.update() - - entities_to_add: Dict[str, orm.QueryBuilder] = {} - if not given_entities: - progress.update() - return entities_to_add - - for given_entity in given_entities: - - project_cols = _get_model_fields(given_entity) - - # Getting the ids that correspond to the right entity - entry_uuids_to_add = entities_starting_set.get(given_entity, set()) - if not entry_uuids_to_add: - if given_entity == LOG_ENTITY_NAME: - entry_uuids_to_add = given_log_entry_ids - elif given_entity == COMMENT_ENTITY_NAME: - entry_uuids_to_add = given_comment_entry_ids - elif given_entity == NODE_ENTITY_NAME: - entry_uuids_to_add.update({node_pk_2_uuid_mapping[_] for _ in node_ids_to_be_exported}) - - builder = orm.QueryBuilder() - builder.append( - entity_names_to_entities[given_entity], - filters={'uuid': { - 'in': entry_uuids_to_add - }}, - project=project_cols, - tag=given_entity, - outerjoin=True, - ) - entities_to_add[given_entity] = builder - - progress.update() - - return entities_to_add - - -def _write_entity_data( - total_entities: int, entity_queries: Dict[str, orm.QueryBuilder], writer: ArchiveWriterAbstract, batch_size: int -) -> Dict[str, Set[int]]: - """Iterate through data returned from entity queries, serialize the DB fields, then write to the export.""" - all_fields_info, unique_identifiers = get_all_fields_info() - entity_separator = '_' - - exported_entity_pks: Dict[str, Set[int]] = defaultdict(set) - unsealed_node_pks: Set[int] = set() - - with get_progress_reporter()(total=total_entities, desc='Writing entity data') as progress: - - for entity_name, entity_query in entity_queries.items(): - - foreign_fields = {k: v for k, v in all_fields_info[entity_name].items() if 'requires' in v} - for value in foreign_fields.values(): - ref_model_name = value['requires'] - fill_in_query( - entity_query, - entity_name, - ref_model_name, - [entity_name], - entity_separator, - ) - - for query_results in entity_query.iterdict(batch_size=batch_size): - - progress.update() - - for key, value in query_results.items(): - - pk = value['id'] - - # This is an empty result of an outer join. - # It should not be taken into account. - if pk is None: - continue - - # Get current entity - current_entity = key.split(entity_separator)[-1] - - # don't allow duplication - if pk in exported_entity_pks[current_entity]: - continue - - exported_entity_pks[current_entity].add(pk) - - fields = serialize_dict( - value, - remove_fields=['id'], - rename_fields=model_fields_to_file_fields[current_entity], - ) - - if current_entity == NODE_ENTITY_NAME and fields['node_type'].startswith('process.'): - if fields['attributes'].get('sealed', False) is not True: - unsealed_node_pks.add(pk) - - writer.write_entity_data(current_entity, pk, unique_identifiers[current_entity], fields) - - if unsealed_node_pks: - raise exceptions.ExportValidationError( - 'All ProcessNodes must be sealed before they can be exported. ' - f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed." - ) - - return exported_entity_pks - - -def _write_group_mappings(*, group_pks: Set[int], batch_size: int, writer: ArchiveWriterAbstract): - """Query for node UUIDs in exported groups, and write these these mappings to the archive file.""" - group_uuid_query = orm.QueryBuilder().append( - orm.Group, - filters={ - 'id': { - 'in': list(group_pks) - } - }, - project='uuid', - tag='groups', - ).append(orm.Node, project='uuid', with_group='groups') - - groups_uuid_to_node_uuids = defaultdict(set) - for group_uuid, node_uuid in group_uuid_query.iterall(batch_size=batch_size): - groups_uuid_to_node_uuids[group_uuid].add(node_uuid) - - for group_uuid, node_uuids in groups_uuid_to_node_uuids.items(): - writer.write_group_nodes(group_uuid, list(node_uuids)) - - -def _write_node_repositories( - *, node_pks: Set[int], repository_metadata_mapping: Dict[int, dict], writer: ArchiveWriterAbstract -): - """Write all exported node repositories to the archive file.""" - with get_progress_reporter()(total=len(node_pks), desc='Exporting node repositories: ') as progress: - - with tempfile.TemporaryDirectory() as temp: - from disk_objectstore import Container - from aiida.manage.manager import get_manager - - dirpath = os.path.join(temp, 'container') - container_export = Container(dirpath) - container_export.init_container() - - profile = get_manager().get_profile() - assert profile is not None, 'profile not loaded' - container_profile = profile.get_repository().backend.container - - # This should be done more effectively, starting by not having to load the node. Either the repository - # metadata should be collected earlier when the nodes themselves are already exported or a single separate - # query should be done. - hashkeys = [] - - def collect_hashkeys(objects): - for obj in objects.values(): - hashkey = obj.get('k', None) - if hashkey is not None: - hashkeys.append(hashkey) - subobjects = obj.get('o', None) - if subobjects: - collect_hashkeys(subobjects) - - for pk in node_pks: - progress.set_description_str(f'Exporting node repositories: {pk}', refresh=False) - progress.update() - repository_metadata = repository_metadata_mapping[pk] - collect_hashkeys(repository_metadata.get('o', {})) - - callback = create_callback(progress) - container_profile.export(set(hashkeys), container_export, compress=False, callback=callback) - writer.write_repository_container(container_export) +__all__ = ( + 'EXPORT_LOGGER', + 'ExportFileFormat', + 'export', +) diff --git a/aiida/tools/importexport/dbexport/main.py b/aiida/tools/importexport/dbexport/main.py new file mode 100644 index 0000000000..c68a3c242a --- /dev/null +++ b/aiida/tools/importexport/dbexport/main.py @@ -0,0 +1,613 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=fixme,too-many-lines +"""Provides export functionalities.""" +from collections import defaultdict +import os +import tempfile +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) + +from aiida import get_version, orm +from aiida.common.links import GraphTraversalRules +from aiida.common.lang import type_check +from aiida.common.progress_reporter import get_progress_reporter, create_callback +from aiida.tools.importexport.common import exceptions +from aiida.tools.importexport.common.config import ( + COMMENT_ENTITY_NAME, + COMPUTER_ENTITY_NAME, + EXPORT_VERSION, + GROUP_ENTITY_NAME, + LOG_ENTITY_NAME, + NODE_ENTITY_NAME, + ExportFileFormat, + entity_names_to_entities, + file_fields_to_model_fields, + get_all_fields_info, + model_fields_to_file_fields, +) +from aiida.tools.graph.graph_traversers import get_nodes_export, validate_traversal_rules +from aiida.tools.importexport.archive.writers import ArchiveMetadata, ArchiveWriterAbstract, get_writer +from aiida.tools.importexport.dbexport.utils import ( + EXPORT_LOGGER, + check_licenses, + fill_in_query, + serialize_dict, + summary, +) + +__all__ = ('export', 'EXPORT_LOGGER', 'ExportFileFormat') + + +def export( + entities: Optional[Iterable[Any]] = None, + filename: Optional[str] = None, + file_format: Union[str, Type[ArchiveWriterAbstract]] = ExportFileFormat.ZIP, + overwrite: bool = False, + include_comments: bool = True, + include_logs: bool = True, + allowed_licenses: Optional[Union[list, Callable]] = None, + forbidden_licenses: Optional[Union[list, Callable]] = None, + writer_init: Optional[Dict[str, Any]] = None, + batch_size: int = 100, + **traversal_rules: bool, +) -> ArchiveWriterAbstract: + """Export AiiDA data to an archive file. + + Note, the logging level and progress reporter should be set externally, for example:: + + from aiida.common.progress_reporter import set_progress_bar_tqdm + + EXPORT_LOGGER.setLevel('DEBUG') + set_progress_bar_tqdm(leave=True) + export(...) + + :param entities: a list of entity instances; + they can belong to different models/entities. + + :param filename: the filename (possibly including the absolute path) + of the file on which to export. + + :param file_format: 'zip', 'tar.gz' or 'folder' or a specific writer class. + + :param overwrite: if True, overwrite the output file without asking, if it exists. + If False, raise an + :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError` + if the output file already exists. + + :param allowed_licenses: List or function. + If a list, then checks whether all licenses of Data nodes are in the list. If a function, + then calls function for licenses of Data nodes expecting True if license is allowed, False + otherwise. + + :param forbidden_licenses: List or function. If a list, + then checks whether all licenses of Data nodes are in the list. If a function, + then calls function for licenses of Data nodes expecting True if license is allowed, False + otherwise. + + :param include_comments: In-/exclude export of comments for given node(s) in ``entities``. + Default: True, *include* comments in export (as well as relevant users). + + :param include_logs: In-/exclude export of logs for given node(s) in ``entities``. + Default: True, *include* logs in export. + + :param writer_init: Additional key-word arguments to pass to the writer class init + + :param batch_size: batch database query results in sub-collections to reduce memory usage + + :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` + what rule names are toggleable and what the defaults are. + + :returns: a dictionary of data regarding the export process (timings, etc) + + :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: + if there are any internal errors when exporting. + :raises `~aiida.common.exceptions.LicensingException`: + if any node is licensed under forbidden license. + """ + # pylint: disable=too-many-locals,too-many-branches,too-many-statements + + # Backwards-compatibility + entities = cast(Iterable[Any], entities) + filename = cast(str, filename) + + type_check( + entities, + (list, tuple, set), + msg='`entities` must be specified and given as a list of AiiDA entities', + ) + entities = list(entities) + if type_check(filename, str, allow_none=True) is None: + filename = 'export_data.aiida' + + if not overwrite and os.path.exists(filename): + raise exceptions.ArchiveExportError(f"The output file '{filename}' already exists") + + # validate the traversal rules and generate a full set for reporting + validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules) + full_traversal_rules = { + name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items() + } + + # setup the archive writer + writer_init = writer_init or {} + if isinstance(file_format, str): + writer = get_writer(file_format)(filepath=filename, **writer_init) + elif issubclass(file_format, ArchiveWriterAbstract): + writer = file_format(filepath=filename, **writer_init) + else: + raise TypeError('file_format must be a string or ArchiveWriterAbstract class') + + summary( + file_format=writer.file_format_verbose, + export_version=writer.export_version, + outfile=filename, + include_comments=include_comments, + include_logs=include_logs, + traversal_rules=full_traversal_rules + ) + + EXPORT_LOGGER.debug('STARTING EXPORT...') + + all_fields_info, unique_identifiers = get_all_fields_info() + entities_starting_set, given_node_entry_ids = _get_starting_node_ids(entities) + + # Initialize the writer + with writer as writer_context: + + # Iteratively explore the AiiDA graph to find further nodes that should also be exported + with get_progress_reporter()(desc='Traversing provenance via links ...', total=1) as progress: + traverse_output = get_nodes_export(starting_pks=given_node_entry_ids, get_links=True, **traversal_rules) + progress.update() + node_ids_to_be_exported = traverse_output['nodes'] + + EXPORT_LOGGER.debug('WRITING METADATA...') + + writer_context.write_metadata( + ArchiveMetadata( + export_version=EXPORT_VERSION, + aiida_version=get_version(), + unique_identifiers=unique_identifiers, + all_fields_info=all_fields_info, + graph_traversal_rules=traverse_output['rules'], + # Turn sets into lists to be able to export them as JSON metadata. + entities_starting_set={ + entity: list(entity_set) for entity, entity_set in entities_starting_set.items() + }, + include_comments=include_comments, + include_logs=include_logs, + ) + ) + + # Create a mapping of node PK to UUID. + node_pk_2_uuid_mapping: Dict[int, str] = {} + repository_metadata_mapping: Dict[int, dict] = {} + + if node_ids_to_be_exported: + qbuilder = orm.QueryBuilder().append( + orm.Node, + project=('id', 'uuid', 'repository_metadata'), + filters={'id': { + 'in': node_ids_to_be_exported + }}, + ) + for pk, uuid, repository_metadata in qbuilder.iterall(batch_size=batch_size): + node_pk_2_uuid_mapping[pk] = uuid + repository_metadata_mapping[pk] = repository_metadata + + # check that no nodes are being exported with incorrect licensing + _check_node_licenses(node_ids_to_be_exported, allowed_licenses, forbidden_licenses) + + # write the link data + if traverse_output['links'] is not None: + with get_progress_reporter()(total=len(traverse_output['links']), desc='Writing links') as progress: + for link in traverse_output['links']: + progress.update() + writer_context.write_link({ + 'input': node_pk_2_uuid_mapping[link.source_id], + 'output': node_pk_2_uuid_mapping[link.target_id], + 'label': link.link_label, + 'type': link.link_type, + }) + + # generate a list of queries to encapsulate all required entities + entity_queries = _collect_entity_queries( + node_ids_to_be_exported, + entities_starting_set, + node_pk_2_uuid_mapping, + include_comments, + include_logs, + ) + + total_entities = sum(query.count() for query in entity_queries.values()) + + # write all entity data fields + if total_entities: + exported_entity_pks = _write_entity_data( + total_entities=total_entities, + entity_queries=entity_queries, + writer=writer_context, + batch_size=batch_size + ) + else: + exported_entity_pks = defaultdict(set) + EXPORT_LOGGER.info('No entities were found to export') + + # write mappings of groups to the nodes they contain + if exported_entity_pks[GROUP_ENTITY_NAME]: + + EXPORT_LOGGER.debug('Writing group UUID -> [nodes UUIDs]') + + _write_group_mappings( + group_pks=exported_entity_pks[GROUP_ENTITY_NAME], batch_size=batch_size, writer=writer_context + ) + + # copy all required node repositories + if exported_entity_pks[NODE_ENTITY_NAME]: + + _write_node_repositories( + node_pks=exported_entity_pks[NODE_ENTITY_NAME], + repository_metadata_mapping=repository_metadata_mapping, + writer=writer_context + ) + + EXPORT_LOGGER.info('Finalizing Export...') + + # summarize export + export_summary = '\n - '.join(f'{name:<6}: {len(pks)}' for name, pks in exported_entity_pks.items()) + if exported_entity_pks: + EXPORT_LOGGER.info('Exported Entities:\n - ' + export_summary + '\n') + # TODO + # EXPORT_LOGGER.info('Writer Information:\n %s', writer.export_info) + + return writer + + +def _get_starting_node_ids(entities: List[Any]) -> Tuple[DefaultDict[str, Set[str]], Set[int]]: + """Get the starting node UUIDs and PKs + + :param entities: a list of entity instances + + :raises exceptions.ArchiveExportError: + :return: entities_starting_set, given_node_entry_ids + """ + entities_starting_set: DefaultDict[str, Set[str]] = defaultdict(set) + given_node_entry_ids: Set[int] = set() + + # store a list of the actual dbnodes + total = len(entities) + (1 if GROUP_ENTITY_NAME in entities_starting_set else 0) + if not total: + return entities_starting_set, given_node_entry_ids + + for entry in entities: + + if issubclass(entry.__class__, orm.Group): + entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid) + elif issubclass(entry.__class__, orm.Node): + entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid) + given_node_entry_ids.add(entry.pk) + elif issubclass(entry.__class__, orm.Computer): + entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid) + else: + raise exceptions.ArchiveExportError( + f'I was given {entry} ({type(entry)}),' + ' which is not a Node, Computer, or Group instance' + ) + + # Add all the nodes contained within the specified groups + if GROUP_ENTITY_NAME in entities_starting_set: + + # Use single query instead of given_group.nodes iterator for performance. + qh_groups = ( + orm.QueryBuilder().append( + orm.Group, + filters={ + 'uuid': { + 'in': entities_starting_set[GROUP_ENTITY_NAME] + } + }, + tag='groups', + ).queryhelp + ) + node_query = orm.QueryBuilder(**qh_groups).append(orm.Node, project=['id', 'uuid'], with_group='groups') + node_count = node_query.count() + + if node_count: + with get_progress_reporter()(desc='Collecting nodes in groups', total=node_count) as progress: + + pks, uuids = [], [] + for pk, uuid in node_query.all(): + progress.update() + pks.append(pk) + uuids.append(uuid) + + entities_starting_set[NODE_ENTITY_NAME].update(uuids) + given_node_entry_ids.update(pks) + + return entities_starting_set, given_node_entry_ids + + +def _check_node_licenses( + node_ids_to_be_exported: Set[int], + allowed_licenses: Optional[Union[list, Callable]], + forbidden_licenses: Optional[Union[list, Callable]], +) -> None: + """Check the nodes to be archived for disallowed licences.""" + # TODO (Spyros) To see better! Especially for functional licenses + # Check the licenses of exported data. + if allowed_licenses is not None or forbidden_licenses is not None: + builder = orm.QueryBuilder() + builder.append( + orm.Node, + project=['id', 'attributes.source.license'], + filters={'id': { + 'in': node_ids_to_be_exported + }}, + ) + # Skip those nodes where the license is not set (this is the standard behavior with Django) + node_licenses = [(a, b) for [a, b] in builder.all() if b is not None] + check_licenses(node_licenses, allowed_licenses, forbidden_licenses) + + +def _get_model_fields(entity_name: str) -> List[str]: + """Return a list of fields to retrieve for a particular entity + + :param entity_name: name of database entity, such as Node + + """ + all_fields_info, _ = get_all_fields_info() + project_cols = ['id'] + entity_prop = all_fields_info[entity_name].keys() + # Here we do the necessary renaming of properties + for prop in entity_prop: + # nprop contains the list of projections + nprop = ( + file_fields_to_model_fields[entity_name][prop] if prop in file_fields_to_model_fields[entity_name] else prop + ) + project_cols.append(nprop) + return project_cols + + +def _collect_entity_queries( + node_ids_to_be_exported: Set[int], + entities_starting_set: DefaultDict[str, Set[str]], + node_pk_2_uuid_mapping: Dict[int, str], + include_comments: bool = True, + include_logs: bool = True, +) -> Dict[str, orm.QueryBuilder]: + """Gather partial queries for all entities to export.""" + # pylint: disable=too-many-locals + given_log_entry_ids = set() + given_comment_entry_ids = set() + + total = 2 + (((1 if include_logs else 0) + (1 if include_comments else 0)) if node_ids_to_be_exported else 0) + with get_progress_reporter()(desc='Building entity database queries', total=total) as progress: + + # Logs + if include_logs and node_ids_to_be_exported: + # Get related log(s) - universal for all nodes + builder = orm.QueryBuilder() + builder.append( + orm.Log, + filters={'dbnode_id': { + 'in': node_ids_to_be_exported + }}, + project='uuid', + ) + res = set(builder.all(flat=True)) + given_log_entry_ids.update(res) + + progress.update() + + # Comments + if include_comments and node_ids_to_be_exported: + # Get related log(s) - universal for all nodes + builder = orm.QueryBuilder() + builder.append( + orm.Comment, + filters={'dbnode_id': { + 'in': node_ids_to_be_exported + }}, + project='uuid', + ) + res = set(builder.all(flat=True)) + given_comment_entry_ids.update(res) + + progress.update() + + # Here we get all the columns that we plan to project per entity that we would like to extract + given_entities = set(entities_starting_set.keys()) + if node_ids_to_be_exported: + given_entities.add(NODE_ENTITY_NAME) + if given_log_entry_ids: + given_entities.add(LOG_ENTITY_NAME) + if given_comment_entry_ids: + given_entities.add(COMMENT_ENTITY_NAME) + + progress.update() + + entities_to_add: Dict[str, orm.QueryBuilder] = {} + if not given_entities: + progress.update() + return entities_to_add + + for given_entity in given_entities: + + project_cols = _get_model_fields(given_entity) + + # Getting the ids that correspond to the right entity + entry_uuids_to_add = entities_starting_set.get(given_entity, set()) + if not entry_uuids_to_add: + if given_entity == LOG_ENTITY_NAME: + entry_uuids_to_add = given_log_entry_ids + elif given_entity == COMMENT_ENTITY_NAME: + entry_uuids_to_add = given_comment_entry_ids + elif given_entity == NODE_ENTITY_NAME: + entry_uuids_to_add.update({node_pk_2_uuid_mapping[_] for _ in node_ids_to_be_exported}) + + builder = orm.QueryBuilder() + builder.append( + entity_names_to_entities[given_entity], + filters={'uuid': { + 'in': entry_uuids_to_add + }}, + project=project_cols, + tag=given_entity, + outerjoin=True, + ) + entities_to_add[given_entity] = builder + + progress.update() + + return entities_to_add + + +def _write_entity_data( + total_entities: int, entity_queries: Dict[str, orm.QueryBuilder], writer: ArchiveWriterAbstract, batch_size: int +) -> Dict[str, Set[int]]: + """Iterate through data returned from entity queries, serialize the DB fields, then write to the export.""" + all_fields_info, unique_identifiers = get_all_fields_info() + entity_separator = '_' + + exported_entity_pks: Dict[str, Set[int]] = defaultdict(set) + unsealed_node_pks: Set[int] = set() + + with get_progress_reporter()(total=total_entities, desc='Writing entity data') as progress: + + for entity_name, entity_query in entity_queries.items(): + + foreign_fields = {k: v for k, v in all_fields_info[entity_name].items() if 'requires' in v} + for value in foreign_fields.values(): + ref_model_name = value['requires'] + fill_in_query( + entity_query, + entity_name, + ref_model_name, + [entity_name], + entity_separator, + ) + + for query_results in entity_query.iterdict(batch_size=batch_size): + + progress.update() + + for key, value in query_results.items(): + + pk = value['id'] + + # This is an empty result of an outer join. + # It should not be taken into account. + if pk is None: + continue + + # Get current entity + current_entity = key.split(entity_separator)[-1] + + # don't allow duplication + if pk in exported_entity_pks[current_entity]: + continue + + exported_entity_pks[current_entity].add(pk) + + fields = serialize_dict( + value, + remove_fields=['id'], + rename_fields=model_fields_to_file_fields[current_entity], + ) + + if current_entity == NODE_ENTITY_NAME and fields['node_type'].startswith('process.'): + if fields['attributes'].get('sealed', False) is not True: + unsealed_node_pks.add(pk) + + writer.write_entity_data(current_entity, pk, unique_identifiers[current_entity], fields) + + if unsealed_node_pks: + raise exceptions.ExportValidationError( + 'All ProcessNodes must be sealed before they can be exported. ' + f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed." + ) + + return exported_entity_pks + + +def _write_group_mappings(*, group_pks: Set[int], batch_size: int, writer: ArchiveWriterAbstract): + """Query for node UUIDs in exported groups, and write these these mappings to the archive file.""" + group_uuid_query = orm.QueryBuilder().append( + orm.Group, + filters={ + 'id': { + 'in': list(group_pks) + } + }, + project='uuid', + tag='groups', + ).append(orm.Node, project='uuid', with_group='groups') + + groups_uuid_to_node_uuids = defaultdict(set) + for group_uuid, node_uuid in group_uuid_query.iterall(batch_size=batch_size): + groups_uuid_to_node_uuids[group_uuid].add(node_uuid) + + for group_uuid, node_uuids in groups_uuid_to_node_uuids.items(): + writer.write_group_nodes(group_uuid, list(node_uuids)) + + +def _write_node_repositories( + *, node_pks: Set[int], repository_metadata_mapping: Dict[int, dict], writer: ArchiveWriterAbstract +): + """Write all exported node repositories to the archive file.""" + with get_progress_reporter()(total=len(node_pks), desc='Exporting node repositories: ') as progress: + + with tempfile.TemporaryDirectory() as temp: + from disk_objectstore import Container + from aiida.manage.manager import get_manager + + dirpath = os.path.join(temp, 'container') + container_export = Container(dirpath) + container_export.init_container() + + profile = get_manager().get_profile() + assert profile is not None, 'profile not loaded' + container_profile = profile.get_repository().backend.container + + # This should be done more effectively, starting by not having to load the node. Either the repository + # metadata should be collected earlier when the nodes themselves are already exported or a single separate + # query should be done. + hashkeys = [] + + def collect_hashkeys(objects): + for obj in objects.values(): + hashkey = obj.get('k', None) + if hashkey is not None: + hashkeys.append(hashkey) + subobjects = obj.get('o', None) + if subobjects: + collect_hashkeys(subobjects) + + for pk in node_pks: + progress.set_description_str(f'Exporting node repositories: {pk}', refresh=False) + progress.update() + repository_metadata = repository_metadata_mapping[pk] + collect_hashkeys(repository_metadata.get('o', {})) + + callback = create_callback(progress) + container_profile.export(set(hashkeys), container_export, compress=False, callback=callback) + writer.write_repository_container(container_export) diff --git a/aiida/tools/importexport/dbimport/__init__.py b/aiida/tools/importexport/dbimport/__init__.py index fa90faa5cc..1d3f712e24 100644 --- a/aiida/tools/importexport/dbimport/__init__.py +++ b/aiida/tools/importexport/dbimport/__init__.py @@ -8,79 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Provides import functionalities.""" -from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER -__all__ = ('import_data', 'IMPORT_LOGGER') +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import +from .main import * -def import_data(in_path, group=None, **kwargs): - """Import exported AiiDA archive to the AiiDA database and repository. - - Proxy function for the backend-specific import functions. - If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format - (zip, tar.gz, tar.bz2, ...) and calls the correct function. - - Note, the logging level and progress reporter should be set externally, for example:: - - from aiida.common.progress_reporter import set_progress_bar_tqdm - - IMPORT_LOGGER.setLevel('DEBUG') - set_progress_bar_tqdm(leave=True) - import_data(...) - - :param in_path: the path to a file or folder that can be imported in AiiDA. - :type in_path: str - - :param group: Group wherein all imported Nodes will be placed. - :type group: :py:class:`~aiida.orm.groups.Group` - - :param extras_mode_existing: 3 letter code that will identify what to do with the extras import. - The first letter acts on extras that are present in the original node and not present in the imported node. - Can be either: - 'k' (keep it) or - 'n' (do not keep it). - The second letter acts on the imported extras that are not present in the original node. - Can be either: - 'c' (create it) or - 'n' (do not create it). - The third letter defines what to do in case of a name collision. - Can be either: - 'l' (leave the old value), - 'u' (update with a new value), - 'd' (delete the extra), or - 'a' (ask what to do if the content is different). - :type extras_mode_existing: str - - :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them - :type extras_mode_new: str - - :param comment_mode: Comment import modes (when same UUIDs are found). - Can be either: - 'newest' (will keep the Comment with the most recent modification time (mtime)) or - 'overwrite' (will overwrite existing Comments with the ones from the import file). - :type comment_mode: str - - :return: New and existing Nodes and Links. - :rtype: dict - - :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when - importing. - """ - from aiida.manage import configuration - from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA - from aiida.tools.importexport.common.exceptions import ArchiveImportError - - backend = configuration.PROFILE.database_backend - - if backend == BACKEND_SQLA: - from aiida.tools.importexport.dbimport.backends.sqla import import_data_sqla - IMPORT_LOGGER.debug('Calling import function import_data_sqla for the %s backend.', backend) - return import_data_sqla(in_path, group=group, **kwargs) - - if backend == BACKEND_DJANGO: - from aiida.tools.importexport.dbimport.backends.django import import_data_dj - IMPORT_LOGGER.debug('Calling import function import_data_dj for the %s backend.', backend) - return import_data_dj(in_path, group=group, **kwargs) - - # else - raise ArchiveImportError(f'Unknown backend: {backend}') +__all__ = ( + 'IMPORT_LOGGER', + 'import_data', +) diff --git a/aiida/tools/importexport/dbimport/main.py b/aiida/tools/importexport/dbimport/main.py new file mode 100644 index 0000000000..fa90faa5cc --- /dev/null +++ b/aiida/tools/importexport/dbimport/main.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Provides import functionalities.""" +from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER + +__all__ = ('import_data', 'IMPORT_LOGGER') + + +def import_data(in_path, group=None, **kwargs): + """Import exported AiiDA archive to the AiiDA database and repository. + + Proxy function for the backend-specific import functions. + If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format + (zip, tar.gz, tar.bz2, ...) and calls the correct function. + + Note, the logging level and progress reporter should be set externally, for example:: + + from aiida.common.progress_reporter import set_progress_bar_tqdm + + IMPORT_LOGGER.setLevel('DEBUG') + set_progress_bar_tqdm(leave=True) + import_data(...) + + :param in_path: the path to a file or folder that can be imported in AiiDA. + :type in_path: str + + :param group: Group wherein all imported Nodes will be placed. + :type group: :py:class:`~aiida.orm.groups.Group` + + :param extras_mode_existing: 3 letter code that will identify what to do with the extras import. + The first letter acts on extras that are present in the original node and not present in the imported node. + Can be either: + 'k' (keep it) or + 'n' (do not keep it). + The second letter acts on the imported extras that are not present in the original node. + Can be either: + 'c' (create it) or + 'n' (do not create it). + The third letter defines what to do in case of a name collision. + Can be either: + 'l' (leave the old value), + 'u' (update with a new value), + 'd' (delete the extra), or + 'a' (ask what to do if the content is different). + :type extras_mode_existing: str + + :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them + :type extras_mode_new: str + + :param comment_mode: Comment import modes (when same UUIDs are found). + Can be either: + 'newest' (will keep the Comment with the most recent modification time (mtime)) or + 'overwrite' (will overwrite existing Comments with the ones from the import file). + :type comment_mode: str + + :return: New and existing Nodes and Links. + :rtype: dict + + :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when + importing. + """ + from aiida.manage import configuration + from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA + from aiida.tools.importexport.common.exceptions import ArchiveImportError + + backend = configuration.PROFILE.database_backend + + if backend == BACKEND_SQLA: + from aiida.tools.importexport.dbimport.backends.sqla import import_data_sqla + IMPORT_LOGGER.debug('Calling import function import_data_sqla for the %s backend.', backend) + return import_data_sqla(in_path, group=group, **kwargs) + + if backend == BACKEND_DJANGO: + from aiida.tools.importexport.dbimport.backends.django import import_data_dj + IMPORT_LOGGER.debug('Calling import function import_data_dj for the %s backend.', backend) + return import_data_dj(in_path, group=group, **kwargs) + + # else + raise ArchiveImportError(f'Unknown backend: {backend}') diff --git a/aiida/tools/visualization/__init__.py b/aiida/tools/visualization/__init__.py index e059feb04e..d198ea67f9 100644 --- a/aiida/tools/visualization/__init__.py +++ b/aiida/tools/visualization/__init__.py @@ -7,8 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides tools for visualization of the provenance graph.""" + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + from .graph import * -__all__ = (graph.__all__) +__all__ = ( + 'Graph', + 'default_link_styles', + 'default_node_styles', + 'default_node_sublabels', + 'pstate_node_styles', +) diff --git a/aiida/transports/__init__.py b/aiida/transports/__init__.py index 813fc6cc69..0b9b968f18 100644 --- a/aiida/transports/__init__.py +++ b/aiida/transports/__init__.py @@ -7,9 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for classes and utilities to define transports to other machines.""" +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .plugins import * from .transport import * -__all__ = (transport.__all__) +__all__ = ( + 'SshTransport', + 'Transport', + 'convert_to_bool', + 'parse_sshconfig', +) diff --git a/aiida/transports/plugins/__init__.py b/aiida/transports/plugins/__init__.py index 2776a55f97..efa2d9ce7f 100644 --- a/aiida/transports/plugins/__init__.py +++ b/aiida/transports/plugins/__init__.py @@ -7,3 +7,15 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### + +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import + +from .ssh import * + +__all__ = ( + 'SshTransport', + 'convert_to_bool', + 'parse_sshconfig', +) diff --git a/utils/make_all.py b/utils/make_all.py new file mode 100644 index 0000000000..c9f3b44514 --- /dev/null +++ b/utils/make_all.py @@ -0,0 +1,142 @@ +import ast +from collections import Counter +from pathlib import Path +from pprint import pprint +import sys +from typing import List, Optional, Tuple + + +def parse_all(folder_path: str) -> Tuple[dict, dict]: + """Walk through all files in folder, and parse the ``__all__`` variable. + + :return: (all dict, dict of unparsable) + """ + folder_path = Path(folder_path) + all_dict = {} + bad_all = {} + + for path in folder_path.glob("**/*.py"): + + # skip module files + if path.name == "__init__.py": + continue + + # parse the file + parsed = ast.parse(path.read_text(encoding="utf8")) + + # find __all__ assignment + all_token = None + for token in parsed.body: + if not isinstance(token, ast.Assign): + continue + if token.targets and getattr(token.targets[0], "id", "") == "__all__": + all_token = token + break + + if all_token is None: + bad_all.setdefault("missing", []).append(str(path.relative_to(folder_path))) + continue + + if not isinstance(all_token.value, (ast.List, ast.Tuple)): + bad_all.setdefault("vaue not list/tuple", []).append(str(path.relative_to(folder_path))) + continue + + if not all(isinstance(el, ast.Str) for el in all_token.value.elts): + bad_all.setdefault("child not strings", []).append(str(path.relative_to(folder_path))) + continue + + names = [n.s for n in all_token.value.elts] + if not names: + continue + + path_dict = all_dict + for part in path.parent.relative_to(folder_path).parts: + path_dict = path_dict.setdefault(part, {}) + path_dict.setdefault(path.name[:-3], {})["__all__"] = names + + return all_dict, bad_all + + +def gather_all(all_dict: dict, all_list: Optional[List[str]] = None) -> List[str]: + """Recursively gather __all__ names.""" + all_list = [] if all_list is None else all_list + for key, val in all_dict.items(): + if key == "__all__": + all_list.extend(val) + else: + gather_all(val, all_list) + return all_list + + +def write_inits(folder_path: str, all_dict: dict, skip_children: dict) -> dict: + """Write __init__.py files for all subfolders. + + :return: folders with non-unique imports + """ + folder_path = Path(folder_path) + non_unique = {} + for path in folder_path.glob("**/__init__.py"): + if path.parent == folder_path: + # skip top level __init__.py + continue + rel_path = path.parent.relative_to(folder_path).as_posix() + # get sub_dict for this folder + path_all_dict = all_dict + try: + for part in path.parent.relative_to(folder_path).parts: + path_all_dict = path_all_dict[part] + except KeyError: + continue + path_all_dict = {key: val for key, val in path_all_dict.items() if key not in skip_children.get(rel_path, [])} + alls = gather_all(path_all_dict) + if not alls: + continue + # check for non-unique imports + if len(alls) != len(set(alls)): + non_unique[rel_path] = [k for k, v in Counter(alls).items() if v > 1] + content = (["", "# AUTO-GENERATED"] + ["# yapf: disable", "# pylint: disable=wildcard-import", ""] + + [f"from .{mod} import *" for mod in sorted(path_all_dict.keys())] + ["", "__all__ = ("] + + [f" {a!r}," for a in sorted(set(alls))] + [")", ""]) + + new_content = [] + in_docstring = False + for line in path.read_text(encoding="utf8").splitlines(): + # only use initial comments and docstring + if not (line.startswith("#") or line.startswith('"""') or in_docstring): + break + if line.startswith('"""'): + if (not in_docstring) and not (line.endswith('"""') and not line.strip() == '"""'): + in_docstring = True + else: + in_docstring = False + if not line.startswith("# pylint"): + new_content.append(line) + + # could warn if overwriting any non-autogenerated content + + new_content.extend(content) + + path.write_text("\n".join(new_content), encoding="utf8") + + return non_unique + + +if __name__ == "__main__": + _folder = Path(__file__).parent.parent.joinpath("aiida") + _skip = { + "orm": "implementation", + "orm/implementation": ["django", "sqlalchemy", "sql"], + "cmdline": ["params"], + "cmdline/params": ["arguments", "options"] + } + _all_dict, _bad_all = parse_all(_folder) + _non_unique = write_inits(_folder, _all_dict, _skip) + _bad_all.pop("missing", "") # allow missing __all__ + if _bad_all: + print("unparsable __all__:") + pprint(_bad_all) + if _non_unique: + print("non-unique imports:") + pprint(_non_unique) + if _bad_all or _non_unique: + sys.exit(1) From 5a1ca687e7835d9896ed36299971c468fae80c6d Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 00:02:31 +0200 Subject: [PATCH 02/11] improve --- .pre-commit-config.yaml | 2 +- aiida/cmdline/__init__.py | 38 +- aiida/cmdline/params/__init__.py | 3 + aiida/cmdline/params/arguments/__init__.py | 3 + aiida/cmdline/params/options/__init__.py | 28 + aiida/cmdline/params/options/config.py | 2 + .../params/options/contextualdefault.py | 2 + aiida/cmdline/params/options/main.py | 22 +- aiida/cmdline/params/options/multivalue.py | 2 + aiida/cmdline/params/options/overridable.py | 2 + aiida/cmdline/params/types/__init__.py | 3 + aiida/cmdline/params/types/choice.py | 2 +- aiida/cmdline/params/types/code.py | 1 + aiida/cmdline/params/types/computer.py | 1 + aiida/cmdline/params/types/data.py | 1 + aiida/cmdline/params/types/group.py | 1 + aiida/cmdline/params/types/identifier.py | 1 + aiida/cmdline/params/types/multiple.py | 3 +- aiida/cmdline/params/types/node.py | 3 +- aiida/cmdline/params/types/plugin.py | 5 +- aiida/cmdline/params/types/process.py | 1 + aiida/cmdline/params/types/profile.py | 1 + aiida/cmdline/params/types/strings.py | 1 + aiida/cmdline/params/types/test_module.py | 1 + aiida/cmdline/params/types/user.py | 1 + aiida/cmdline/params/types/workflow.py | 1 + aiida/cmdline/utils/__init__.py | 14 +- aiida/cmdline/utils/echo.py | 2 +- aiida/common/__init__.py | 3 + aiida/common/log.py | 137 ++--- aiida/engine/__init__.py | 3 + aiida/engine/processes/__init__.py | 3 + aiida/engine/processes/calcjobs/__init__.py | 3 + aiida/engine/processes/workchains/__init__.py | 3 + aiida/manage/__init__.py | 14 +- aiida/manage/configuration/__init__.py | 275 +++++++++- aiida/manage/configuration/main.py | 267 --------- .../configuration/migrations/__init__.py | 3 + aiida/manage/database/__init__.py | 3 + aiida/manage/database/integrity/__init__.py | 3 + aiida/manage/external/__init__.py | 3 + aiida/manage/tests/__init__.py | 13 + aiida/manage/tests/main.py | 519 ++++++++++++++++++ aiida/manage/tests/unittest_classes.py | 2 +- aiida/orm/__init__.py | 7 + aiida/orm/implementation/__init__.py | 3 + aiida/orm/implementation/django/__init__.py | 3 + aiida/orm/implementation/sql/__init__.py | 3 + .../orm/implementation/sqlalchemy/__init__.py | 3 + aiida/orm/nodes/__init__.py | 7 + aiida/orm/nodes/data/__init__.py | 8 + aiida/orm/nodes/data/array/__init__.py | 3 + aiida/orm/nodes/data/array/trajectory.py | 1 + aiida/orm/nodes/data/cif.py | 2 + aiida/orm/nodes/data/remote/__init__.py | 3 + aiida/orm/nodes/data/remote/stash/__init__.py | 3 + aiida/orm/nodes/process/__init__.py | 3 + .../orm/nodes/process/calculation/__init__.py | 3 + aiida/orm/nodes/process/workflow/__init__.py | 3 + aiida/orm/utils/__init__.py | 3 + aiida/parsers/__init__.py | 3 + aiida/plugins/__init__.py | 3 + aiida/repository/__init__.py | 3 + aiida/repository/backend/__init__.py | 3 + aiida/restapi/__init__.py | 3 + aiida/schedulers/__init__.py | 3 + aiida/tools/__init__.py | 3 + aiida/tools/calculations/__init__.py | 3 + aiida/tools/data/__init__.py | 3 + aiida/tools/data/array/__init__.py | 4 + aiida/tools/data/array/kpoints/__init__.py | 3 + aiida/tools/data/orbital/__init__.py | 3 + aiida/tools/data/orbital/realhydrogen.py | 2 +- aiida/tools/graph/__init__.py | 3 + aiida/tools/groups/__init__.py | 3 + aiida/tools/importexport/__init__.py | 3 + aiida/tools/importexport/archive/__init__.py | 3 + aiida/tools/importexport/common/__init__.py | 3 + aiida/tools/importexport/dbexport/__init__.py | 3 + aiida/tools/importexport/dbimport/__init__.py | 3 + aiida/tools/visualization/__init__.py | 3 + aiida/transports/__init__.py | 3 + aiida/transports/plugins/__init__.py | 3 + utils/make_all.py | 113 ++-- 84 files changed, 1206 insertions(+), 435 deletions(-) delete mode 100644 aiida/manage/configuration/main.py create mode 100644 aiida/manage/tests/main.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c228fe232..e51e1f8d65 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: - id: imports name: imports - entry: utils/make_all.py + entry: python utils/make_all.py language: python types: [python] require_serial: true diff --git a/aiida/cmdline/__init__.py b/aiida/cmdline/__init__.py index f4af371243..b997337142 100644 --- a/aiida/cmdline/__init__.py +++ b/aiida/cmdline/__init__.py @@ -10,22 +10,44 @@ """The command line interface of AiiDA.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import +from .params import * from .utils import * __all__ = ( + 'AbsolutePathParamType', + 'CalculationParamType', + 'CodeParamType', + 'ComputerParamType', + 'ConfigOptionParamType', + 'DataParamType', + 'EmailType', + 'EntryPointType', + 'FileOrUrl', + 'GroupParamType', + 'HostnameType', + 'IdentifierParamType', + 'LabelStringType', + 'LazyChoice', + 'MpirunCommandParamType', + 'MultipleValueParamType', + 'NodeParamType', + 'NonEmptyStringParamType', + 'PathOrUrl', + 'PluginParamType', + 'ProcessParamType', + 'ProfileParamType', + 'ShebangParamType', + 'TestModuleParamType', + 'UserParamType', + 'WorkflowParamType', 'dbenv', - 'echo', - 'echo_critical', - 'echo_dictionary', - 'echo_error', - 'echo_highlight', - 'echo_info', - 'echo_success', - 'echo_warning', 'format_call_graph', 'only_if_daemon_running', 'with_dbenv', ) + +# yapf: enable diff --git a/aiida/cmdline/params/__init__.py b/aiida/cmdline/params/__init__.py index 509af66b4a..5dc9b918a3 100644 --- a/aiida/cmdline/params/__init__.py +++ b/aiida/cmdline/params/__init__.py @@ -9,6 +9,7 @@ ########################################################################### # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -42,3 +43,5 @@ 'UserParamType', 'WorkflowParamType', ) + +# yapf: enable diff --git a/aiida/cmdline/params/arguments/__init__.py b/aiida/cmdline/params/arguments/__init__.py index 83f5413ed9..1c8f5543b8 100644 --- a/aiida/cmdline/params/arguments/__init__.py +++ b/aiida/cmdline/params/arguments/__init__.py @@ -11,6 +11,7 @@ """Module with pre-defined reusable commandline arguments that can be used as `click` decorators.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -41,3 +42,5 @@ 'WORKFLOW', 'WORKFLOWS', ) + +# yapf: enable diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py index 74d15b39c0..fdc54869f0 100644 --- a/aiida/cmdline/params/options/__init__.py +++ b/aiida/cmdline/params/options/__init__.py @@ -10,10 +10,15 @@ """Module with pre-defined reusable commandline options that can be used as `click` decorators.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import +from .config import * +from .contextualdefault import * from .main import * +from .multivalue import * +from .overridable import * __all__ = ( 'ALL', @@ -21,6 +26,12 @@ 'ALL_USERS', 'APPEND_TEXT', 'ARCHIVE_FORMAT', + 'BROKER_HOST', + 'BROKER_PASSWORD', + 'BROKER_PORT', + 'BROKER_PROTOCOL', + 'BROKER_USERNAME', + 'BROKER_VIRTUAL_HOST', 'CALCULATION', 'CALCULATIONS', 'CALC_JOB_STATE', @@ -28,6 +39,9 @@ 'CODES', 'COMPUTER', 'COMPUTERS', + 'CONFIG_FILE', + 'ConfigFileOption', + 'ContextualDefaultOption', 'DATA', 'DATUM', 'DB_BACKEND', @@ -39,6 +53,8 @@ 'DB_USERNAME', 'DEBUG', 'DESCRIPTION', + 'DICT_FORMAT', + 'DICT_KEYS', 'DRY_RUN', 'EXIT_STATUS', 'EXPORT_FORMAT', @@ -50,16 +66,21 @@ 'GROUPS', 'GROUP_CLEAR', 'HOSTNAME', + 'IDENTIFIER', 'INPUT_FORMAT', 'INPUT_PLUGIN', 'LABEL', 'LIMIT', + 'MultipleValueOption', 'NODE', 'NODES', 'NON_INTERACTIVE', 'OLDER_THAN', 'ORDER_BY', + 'ORDER_DIRECTION', + 'OverridableOption', 'PAST_DAYS', + 'PAUSED', 'PORT', 'PREPEND_TEXT', 'PRINT_TRACEBACK', @@ -76,6 +97,7 @@ 'TIMEOUT', 'TRAJECTORY_INDEX', 'TRANSPORT', + 'TRAVERSAL_RULE_HELP_STRING', 'TYPE_STRING', 'USER', 'USER_EMAIL', @@ -84,7 +106,13 @@ 'USER_LAST_NAME', 'VERBOSE', 'VISUALIZATION_FORMAT', + 'WAIT', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', + 'active_process_states', 'graph_traversal_rules', + 'valid_calc_job_states', + 'valid_process_states', ) + +# yapf: enable diff --git a/aiida/cmdline/params/options/config.py b/aiida/cmdline/params/options/config.py index 9ab5d82278..a4c7b61dbc 100644 --- a/aiida/cmdline/params/options/config.py +++ b/aiida/cmdline/params/options/config.py @@ -17,6 +17,8 @@ from .overridable import OverridableOption +__all__ = ('ConfigFileOption',) + def yaml_config_file_provider(handle, cmd_name): # pylint: disable=unused-argument """Read yaml config file from file handle.""" diff --git a/aiida/cmdline/params/options/contextualdefault.py b/aiida/cmdline/params/options/contextualdefault.py index 1642b45127..a851371363 100644 --- a/aiida/cmdline/params/options/contextualdefault.py +++ b/aiida/cmdline/params/options/contextualdefault.py @@ -15,6 +15,8 @@ import click +__all__ = ('ContextualDefaultOption',) + class ContextualDefaultOption(click.Option): """A class that extends click.Option allowing to define a default callable diff --git a/aiida/cmdline/params/options/main.py b/aiida/cmdline/params/options/main.py index accd78c65f..6df9e34cdb 100644 --- a/aiida/cmdline/params/options/main.py +++ b/aiida/cmdline/params/options/main.py @@ -17,19 +17,21 @@ from .. import types from .multivalue import MultipleValueOption from .overridable import OverridableOption -from .contextualdefault import ContextualDefaultOption from .config import ConfigFileOption __all__ = ( - 'graph_traversal_rules', 'PROFILE', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', - 'DATUM', 'DATA', 'GROUP', 'GROUPS', 'NODE', 'NODES', 'FORCE', 'SILENT', 'VISUALIZATION_FORMAT', 'INPUT_FORMAT', - 'EXPORT_FORMAT', 'ARCHIVE_FORMAT', 'NON_INTERACTIVE', 'DRY_RUN', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_LAST_NAME', - 'USER_INSTITUTION', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', 'DB_PORT', 'DB_USERNAME', 'DB_PASSWORD', 'DB_NAME', - 'REPOSITORY_PATH', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PREPEND_TEXT', 'APPEND_TEXT', 'LABEL', - 'DESCRIPTION', 'INPUT_PLUGIN', 'CALC_JOB_STATE', 'PROCESS_STATE', 'PROCESS_LABEL', 'TYPE_STRING', 'EXIT_STATUS', - 'FAILED', 'LIMIT', 'PROJECT', 'ORDER_BY', 'PAST_DAYS', 'OLDER_THAN', 'ALL', 'ALL_STATES', 'ALL_USERS', - 'GROUP_CLEAR', 'RAW', 'HOSTNAME', 'TRANSPORT', 'SCHEDULER', 'USER', 'PORT', 'FREQUENCY', 'VERBOSE', 'TIMEOUT', - 'FORMULA_MODE', 'TRAJECTORY_INDEX', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'DEBUG', 'PRINT_TRACEBACK' + 'ALL', 'ALL_STATES', 'ALL_USERS', 'APPEND_TEXT', 'ARCHIVE_FORMAT', 'BROKER_HOST', 'BROKER_PASSWORD', 'BROKER_PORT', + 'BROKER_PROTOCOL', 'BROKER_USERNAME', 'BROKER_VIRTUAL_HOST', 'CALCULATION', 'CALCULATIONS', 'CALC_JOB_STATE', + 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'CONFIG_FILE', 'DATA', 'DATUM', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', + 'DB_NAME', 'DB_PASSWORD', 'DB_PORT', 'DB_USERNAME', 'DEBUG', 'DESCRIPTION', 'DICT_FORMAT', 'DICT_KEYS', 'DRY_RUN', + 'EXIT_STATUS', 'EXPORT_FORMAT', 'FAILED', 'FORCE', 'FORMULA_MODE', 'FREQUENCY', 'GROUP', 'GROUPS', 'GROUP_CLEAR', + 'HOSTNAME', 'IDENTIFIER', 'INPUT_FORMAT', 'INPUT_PLUGIN', 'LABEL', 'LIMIT', 'NODE', 'NODES', 'NON_INTERACTIVE', + 'OLDER_THAN', 'ORDER_BY', 'ORDER_DIRECTION', 'PAST_DAYS', 'PAUSED', 'PORT', 'PREPEND_TEXT', 'PRINT_TRACEBACK', + 'PROCESS_LABEL', 'PROCESS_STATE', 'PROFILE', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PROJECT', 'RAW', + 'REPOSITORY_PATH', 'SCHEDULER', 'SILENT', 'TIMEOUT', 'TRAJECTORY_INDEX', 'TRANSPORT', 'TRAVERSAL_RULE_HELP_STRING', + 'TYPE_STRING', 'USER', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_INSTITUTION', 'USER_LAST_NAME', 'VERBOSE', + 'VISUALIZATION_FORMAT', 'WAIT', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'active_process_states', + 'graph_traversal_rules', 'valid_calc_job_states', 'valid_process_states' ) TRAVERSAL_RULE_HELP_STRING = { diff --git a/aiida/cmdline/params/options/multivalue.py b/aiida/cmdline/params/options/multivalue.py index 9b8fa9a3d2..652b982261 100644 --- a/aiida/cmdline/params/options/multivalue.py +++ b/aiida/cmdline/params/options/multivalue.py @@ -15,6 +15,8 @@ from .. import types +__all__ = ('MultipleValueOption',) + def collect_usage_pieces(self, ctx): """Returns all the pieces that go into the usage line and returns it as a list of strings.""" diff --git a/aiida/cmdline/params/options/overridable.py b/aiida/cmdline/params/options/overridable.py index a8f7a183d9..fae2ca0aff 100644 --- a/aiida/cmdline/params/options/overridable.py +++ b/aiida/cmdline/params/options/overridable.py @@ -16,6 +16,8 @@ import click +__all__ = ('OverridableOption',) + class OverridableOption: """ diff --git a/aiida/cmdline/params/types/__init__.py b/aiida/cmdline/params/types/__init__.py index 9ee58ad940..0c3c24ec61 100644 --- a/aiida/cmdline/params/types/__init__.py +++ b/aiida/cmdline/params/types/__init__.py @@ -10,6 +10,7 @@ """Provides all parameter types.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -60,3 +61,5 @@ 'UserParamType', 'WorkflowParamType', ) + +# yapf: enable diff --git a/aiida/cmdline/params/types/choice.py b/aiida/cmdline/params/types/choice.py index 827b19020a..2a6a2c2190 100644 --- a/aiida/cmdline/params/types/choice.py +++ b/aiida/cmdline/params/types/choice.py @@ -12,7 +12,7 @@ """ import click -__all__ = ('LazyChoice', ) +__all__ = ('LazyChoice',) class LazyChoice(click.ParamType): diff --git a/aiida/cmdline/params/types/code.py b/aiida/cmdline/params/types/code.py index ecbdabbd68..3e7a2803a7 100644 --- a/aiida/cmdline/params/types/code.py +++ b/aiida/cmdline/params/types/code.py @@ -15,6 +15,7 @@ __all__ = ('CodeParamType',) + class CodeParamType(IdentifierParamType): """ The ParamType for identifying Code entities or its subclasses diff --git a/aiida/cmdline/params/types/computer.py b/aiida/cmdline/params/types/computer.py index 5e80e33100..d70fc00d8f 100644 --- a/aiida/cmdline/params/types/computer.py +++ b/aiida/cmdline/params/types/computer.py @@ -18,6 +18,7 @@ __all__ = ('ComputerParamType', 'ShebangParamType', 'MpirunCommandParamType') + class ComputerParamType(IdentifierParamType): """ The ParamType for identifying Computer entities or its subclasses diff --git a/aiida/cmdline/params/types/data.py b/aiida/cmdline/params/types/data.py index 1ce0faf557..742dec10eb 100644 --- a/aiida/cmdline/params/types/data.py +++ b/aiida/cmdline/params/types/data.py @@ -14,6 +14,7 @@ __all__ = ('DataParamType',) + class DataParamType(IdentifierParamType): """ The ParamType for identifying Data entities or its subclasses diff --git a/aiida/cmdline/params/types/group.py b/aiida/cmdline/params/types/group.py index d09582cd31..01a31588d5 100644 --- a/aiida/cmdline/params/types/group.py +++ b/aiida/cmdline/params/types/group.py @@ -17,6 +17,7 @@ __all__ = ('GroupParamType',) + class GroupParamType(IdentifierParamType): """The ParamType for identifying Group entities or its subclasses.""" diff --git a/aiida/cmdline/params/types/identifier.py b/aiida/cmdline/params/types/identifier.py index 3d8a735f4d..058712a090 100644 --- a/aiida/cmdline/params/types/identifier.py +++ b/aiida/cmdline/params/types/identifier.py @@ -19,6 +19,7 @@ __all__ = ('IdentifierParamType',) + class IdentifierParamType(click.ParamType, ABC): """ An extension of click.ParamType for a generic identifier parameter. In AiiDA, orm entities can often be diff --git a/aiida/cmdline/params/types/multiple.py b/aiida/cmdline/params/types/multiple.py index 29a142be7a..a5d4a9f5b5 100644 --- a/aiida/cmdline/params/types/multiple.py +++ b/aiida/cmdline/params/types/multiple.py @@ -12,7 +12,8 @@ """ import click -__all__ = ('MultipleValueParamType', ) +__all__ = ('MultipleValueParamType',) + class MultipleValueParamType(click.ParamType): """ diff --git a/aiida/cmdline/params/types/node.py b/aiida/cmdline/params/types/node.py index 6e6b8f585b..7642eb22d5 100644 --- a/aiida/cmdline/params/types/node.py +++ b/aiida/cmdline/params/types/node.py @@ -12,7 +12,8 @@ """ from .identifier import IdentifierParamType -__all__ = ('NodeParamType', ) +__all__ = ('NodeParamType',) + class NodeParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/plugin.py b/aiida/cmdline/params/types/plugin.py index e56f4c2c8d..588a3b7877 100644 --- a/aiida/cmdline/params/types/plugin.py +++ b/aiida/cmdline/params/types/plugin.py @@ -16,10 +16,11 @@ from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, ENTRY_POINT_GROUP_PREFIX, EntryPointFormat from aiida.plugins.entry_point import format_entry_point_string, get_entry_point_string_format from aiida.plugins.entry_point import get_entry_point, get_entry_points, get_entry_point_groups -from ..types import EntryPointType +from .strings import EntryPointType __all__ = ('PluginParamType',) + class PluginParamType(EntryPointType): """ AiiDA Plugin name parameter type. @@ -144,7 +145,7 @@ def complete(self, ctx, incomplete): # pylint: disable=unused-argument """ return [(p, '') for p in self.get_possibilities(incomplete=incomplete)] - def get_missing_message(self, param): + def get_missing_message(self, param): # pylint: disable=unused-argument return 'Possible arguments are:\n\n' + '\n'.join(self.get_valid_arguments()) def get_entry_point_from_string(self, entry_point_string): diff --git a/aiida/cmdline/params/types/process.py b/aiida/cmdline/params/types/process.py index 6933cb42e8..0cbe5abf65 100644 --- a/aiida/cmdline/params/types/process.py +++ b/aiida/cmdline/params/types/process.py @@ -15,6 +15,7 @@ __all__ = ('ProcessParamType',) + class ProcessParamType(IdentifierParamType): """ The ParamType for identifying ProcessNode entities or its subclasses diff --git a/aiida/cmdline/params/types/profile.py b/aiida/cmdline/params/types/profile.py index 831948f382..3cf5449709 100644 --- a/aiida/cmdline/params/types/profile.py +++ b/aiida/cmdline/params/types/profile.py @@ -13,6 +13,7 @@ __all__ = ('ProfileParamType',) + class ProfileParamType(LabelStringType): """The profile parameter type for click.""" diff --git a/aiida/cmdline/params/types/strings.py b/aiida/cmdline/params/types/strings.py index f35e763971..9f6ebeb8be 100644 --- a/aiida/cmdline/params/types/strings.py +++ b/aiida/cmdline/params/types/strings.py @@ -16,6 +16,7 @@ __all__ = ('EmailType', 'EntryPointType', 'HostnameType', 'NonEmptyStringParamType', 'LabelStringType') + class NonEmptyStringParamType(StringParamType): """Parameter whose values have to be string and non-empty.""" name = 'nonemptystring' diff --git a/aiida/cmdline/params/types/test_module.py b/aiida/cmdline/params/types/test_module.py index de985afd8b..d47dbbef94 100644 --- a/aiida/cmdline/params/types/test_module.py +++ b/aiida/cmdline/params/types/test_module.py @@ -12,6 +12,7 @@ __all__ = ('TestModuleParamType',) + class TestModuleParamType(click.ParamType): """Parameter type to represent a unittest module. diff --git a/aiida/cmdline/params/types/user.py b/aiida/cmdline/params/types/user.py index df770e37a5..216530b72e 100644 --- a/aiida/cmdline/params/types/user.py +++ b/aiida/cmdline/params/types/user.py @@ -14,6 +14,7 @@ __all__ = ('UserParamType',) + class UserParamType(click.ParamType): """ The user parameter type for click. Can get or create a user. diff --git a/aiida/cmdline/params/types/workflow.py b/aiida/cmdline/params/types/workflow.py index d23a7442fd..7403ff99f7 100644 --- a/aiida/cmdline/params/types/workflow.py +++ b/aiida/cmdline/params/types/workflow.py @@ -15,6 +15,7 @@ __all__ = ('WorkflowParamType',) + class WorkflowParamType(IdentifierParamType): """ The ParamType for identifying WorkflowNode entities or its subclasses diff --git a/aiida/cmdline/utils/__init__.py b/aiida/cmdline/utils/__init__.py index f88b062b68..9562427ead 100644 --- a/aiida/cmdline/utils/__init__.py +++ b/aiida/cmdline/utils/__init__.py @@ -7,26 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Commandline utility functions.""" +# AUTO-GENERATED # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import from .ascii_vis import * from .decorators import * -from .echo import * __all__ = ( 'dbenv', - 'echo', - 'echo_critical', - 'echo_dictionary', - 'echo_error', - 'echo_highlight', - 'echo_info', - 'echo_success', - 'echo_warning', 'format_call_graph', 'only_if_daemon_running', 'with_dbenv', ) + +# yapf: enable diff --git a/aiida/cmdline/utils/echo.py b/aiida/cmdline/utils/echo.py index 248a01a2db..bb511baf74 100644 --- a/aiida/cmdline/utils/echo.py +++ b/aiida/cmdline/utils/echo.py @@ -18,7 +18,7 @@ __all__ = ( 'echo', 'echo_info', 'echo_success', 'echo_warning', 'echo_error', 'echo_critical', 'echo_highlight', - 'echo_dictionary' + 'echo_dictionary', 'VALID_DICT_FORMATS_MAPPING' ) diff --git a/aiida/common/__init__.py b/aiida/common/__init__.py index 880a702298..1143da1861 100644 --- a/aiida/common/__init__.py +++ b/aiida/common/__init__.py @@ -15,6 +15,7 @@ """ # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -86,3 +87,5 @@ 'set_progress_reporter', 'validate_link_label', ) + +# yapf: enable diff --git a/aiida/common/log.py b/aiida/common/log.py index 10a8686fe6..98baa6a9ab 100644 --- a/aiida/common/log.py +++ b/aiida/common/log.py @@ -15,8 +15,6 @@ from contextlib import contextmanager from wrapt import decorator -from aiida.manage.configuration import get_config_option - __all__ = ('AIIDA_LOGGER', 'override_log_level', 'override_log_formatter') # Custom logging level, intended specifically for informative log messages reported during WorkChains. @@ -51,74 +49,77 @@ def filter(self, record): # The default logging dictionary for AiiDA that can be used in conjunction # with the config.dictConfig method of python's logging module -LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'verbose': { - 'format': '%(levelname)s %(asctime)s %(module)s %(process)d ' - '%(thread)d %(message)s', - }, - 'halfverbose': { - 'format': '%(asctime)s <%(process)d> %(name)s: [%(levelname)s] %(message)s', - 'datefmt': '%m/%d/%Y %I:%M:%S %p', - }, - }, - 'filters': { - 'testing': { - '()': NotInTestingFilter - } - }, - 'handlers': { - 'console': { - 'level': 'DEBUG', - 'class': 'logging.StreamHandler', - 'formatter': 'halfverbose', - 'filters': ['testing'] - }, - }, - 'loggers': { - 'aiida': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.aiida_loglevel'), - 'propagate': False, +def get_logging_config(): + from aiida.manage.configuration import get_config_option + + return { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'verbose': { + 'format': '%(levelname)s %(asctime)s %(module)s %(process)d ' + '%(thread)d %(message)s', + }, + 'halfverbose': { + 'format': '%(asctime)s <%(process)d> %(name)s: [%(levelname)s] %(message)s', + 'datefmt': '%m/%d/%Y %I:%M:%S %p', + }, }, - 'plumpy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.plumpy_loglevel'), - 'propagate': False, + 'filters': { + 'testing': { + '()': NotInTestingFilter + } }, - 'kiwipy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.kiwipy_loglevel'), - 'propagate': False, + 'handlers': { + 'console': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + 'formatter': 'halfverbose', + 'filters': ['testing'] + }, }, - 'paramiko': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.paramiko_loglevel'), - 'propagate': False, + 'loggers': { + 'aiida': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.aiida_loglevel'), + 'propagate': False, + }, + 'plumpy': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.plumpy_loglevel'), + 'propagate': False, + }, + 'kiwipy': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.kiwipy_loglevel'), + 'propagate': False, + }, + 'paramiko': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.paramiko_loglevel'), + 'propagate': False, + }, + 'alembic': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.alembic_loglevel'), + 'propagate': False, + }, + 'aio_pika': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.aiopika_loglevel'), + 'propagate': False, + }, + 'sqlalchemy': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.sqlalchemy_loglevel'), + 'propagate': False, + 'qualname': 'sqlalchemy.engine', + }, + 'py.warnings': { + 'handlers': ['console'], + }, }, - 'alembic': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.alembic_loglevel'), - 'propagate': False, - }, - 'aio_pika': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.aiopika_loglevel'), - 'propagate': False, - }, - 'sqlalchemy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.sqlalchemy_loglevel'), - 'propagate': False, - 'qualname': 'sqlalchemy.engine', - }, - 'py.warnings': { - 'handlers': ['console'], - }, - }, -} + } def evaluate_logging_configuration(dictionary): @@ -155,9 +156,11 @@ def configure_logging(with_orm=False, daemon=False, daemon_log_file=None): """ from logging.config import dictConfig + from aiida.manage.configuration import get_config_option + # Evaluate the `LOGGING` configuration to resolve the lambdas that will retrieve the correct values based on the # currently configured profile. Pass a deep copy of `LOGGING` to ensure that the original remains unaltered. - config = evaluate_logging_configuration(copy.deepcopy(LOGGING)) + config = evaluate_logging_configuration(get_logging_config()) daemon_handler_name = 'daemon_log_file' # Add the daemon file handler to all loggers if daemon=True diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 356645d65b..523970b934 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -10,6 +10,7 @@ """Module with all the internals that make up the engine of `aiida-core`.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -70,3 +71,5 @@ 'while_', 'workfunction', ) + +# yapf: enable diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py index 9a1930eb76..a2f81d49b3 100644 --- a/aiida/engine/processes/__init__.py +++ b/aiida/engine/processes/__init__.py @@ -10,6 +10,7 @@ """Module for processes and related utilities.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -61,3 +62,5 @@ 'while_', 'workfunction', ) + +# yapf: enable diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py index bce38909b3..a91782d092 100644 --- a/aiida/engine/processes/calcjobs/__init__.py +++ b/aiida/engine/processes/calcjobs/__init__.py @@ -10,6 +10,7 @@ """Module for the `CalcJob` process and related utilities.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -21,3 +22,5 @@ 'JobManager', 'JobsList', ) + +# yapf: enable diff --git a/aiida/engine/processes/workchains/__init__.py b/aiida/engine/processes/workchains/__init__.py index b11d3fc147..56b6a94d2d 100644 --- a/aiida/engine/processes/workchains/__init__.py +++ b/aiida/engine/processes/workchains/__init__.py @@ -10,6 +10,7 @@ """Module for the `WorkChain` process and related utilities.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -35,3 +36,5 @@ 'return_', 'while_', ) + +# yapf: enable diff --git a/aiida/manage/__init__.py b/aiida/manage/__init__.py index e8a920fb86..bcb5001a50 100644 --- a/aiida/manage/__init__.py +++ b/aiida/manage/__init__.py @@ -21,6 +21,7 @@ """ # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -29,12 +30,9 @@ from .database import * from .external import * from .manager import * -from .tests import * __all__ = ( - 'BACKEND_UUID', 'BROKER_DEFAULTS', - 'CONFIG', 'CURRENT_CONFIG_VERSION', 'CommunicationTimeout', 'Config', @@ -43,34 +41,28 @@ 'DeliveryFailed', 'OLDEST_COMPATIBLE_CONFIG_VERSION', 'Option', - 'PROFILE', - 'PluginTestCase', 'Postgres', 'PostgresConnectionMode', 'ProcessLauncher', 'Profile', 'RemoteException', 'TABLES_UUID_DEDUPLICATION', - 'TestRunner', 'check_and_migrate_config', 'config_needs_migrating', 'config_schema', 'deduplicate_uuids', 'disable_caching', 'enable_caching', - 'get_config', - 'get_config_option', - 'get_config_path', 'get_current_version', 'get_duplicate_uuids', 'get_manager', 'get_option', 'get_option_names', 'get_use_cache', - 'load_profile', 'parse_option', - 'reset_config', 'reset_manager', 'verify_uuid_uniqueness', 'write_database_integrity_violation', ) + +# yapf: enable diff --git a/aiida/manage/configuration/__init__.py b/aiida/manage/configuration/__init__.py index fb148b11b0..3f81aff446 100644 --- a/aiida/manage/configuration/__init__.py +++ b/aiida/manage/configuration/__init__.py @@ -10,35 +10,292 @@ """Modules related to the configuration of an AiiDA instance.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import from .config import * -from .main import * from .migrations import * from .options import * from .profile import * __all__ = ( - 'BACKEND_UUID', - 'CONFIG', 'CURRENT_CONFIG_VERSION', 'Config', 'ConfigValidationError', 'OLDEST_COMPATIBLE_CONFIG_VERSION', 'Option', - 'PROFILE', 'Profile', 'check_and_migrate_config', 'config_needs_migrating', 'config_schema', - 'get_config', - 'get_config_option', - 'get_config_path', 'get_current_version', 'get_option', 'get_option_names', - 'load_profile', 'parse_option', - 'reset_config', ) + +# yapf: enable + +# END AUTO-GENERATED + +__all__ += ( + 'get_config', 'get_config_option', 'get_config_path', 'get_profile', 'load_documentation_profile', 'load_profile', + 'reset_config', 'reset_profile', 'CONFIG', 'PROFILE', 'BACKEND_UUID' +) + +import os +import shutil +import warnings + +from aiida.common.warnings import AiidaDeprecationWarning +from . import options + +CONFIG = None +PROFILE = None +BACKEND_UUID = None # This will be set to the UUID of the profile as soon as its corresponding backend is loaded + + +def load_profile(profile=None): + """Load a profile. + + .. note:: if a profile is already loaded and no explicit profile is specified, nothing will be done + + :param profile: the name of the profile to load, by default will use the one marked as default in the config + :type profile: str + + :return: the loaded `Profile` instance + :rtype: :class:`~aiida.manage.configuration.Profile` + :raises `aiida.common.exceptions.InvalidOperation`: if the backend of another profile has already been loaded + """ + from aiida.common import InvalidOperation + from aiida.common.log import configure_logging + + global PROFILE + global BACKEND_UUID + + # If a profile is loaded and the specified profile name is None or that of the currently loaded, do nothing + if PROFILE and (profile is None or PROFILE.name is profile): + return PROFILE + + PROFILE = get_config().get_profile(profile) + + if BACKEND_UUID is not None and BACKEND_UUID != PROFILE.uuid: + # Once the switching of profiles with different backends becomes possible, the backend has to be reset properly + raise InvalidOperation('cannot switch profile because backend of another profile is already loaded') + + # Reconfigure the logging to make sure that profile specific logging configuration options are taken into account. + # Note that we do not configure with `with_orm=True` because that will force the backend to be loaded. This should + # instead be done lazily in `Manager._load_backend`. + configure_logging() + + return PROFILE + + +def get_config_path(): + """Returns path to .aiida configuration directory.""" + from .settings import AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME + + return os.path.join(AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME) + + +def load_config(create=False): + """Instantiate Config object representing an AiiDA configuration file. + + Warning: Contrary to :func:`~aiida.manage.configuration.get_config`, this function is uncached and will always + create a new Config object. You may want to call :func:`~aiida.manage.configuration.get_config` instead. + + :param create: if True, will create the configuration file if it does not already exist + :type create: bool + + :return: the config + :rtype: :class:`~aiida.manage.configuration.config.Config` + :raises aiida.common.MissingConfigurationError: if the configuration file could not be found and create=False + """ + from aiida.common import exceptions + from .config import Config + + filepath = get_config_path() + + if not os.path.isfile(filepath) and not create: + raise exceptions.MissingConfigurationError(f'configuration file {filepath} does not exist') + + try: + config = Config.from_file(filepath) + except ValueError as exc: + raise exceptions.ConfigurationError(f'configuration file {filepath} contains invalid JSON') from exc + + _merge_deprecated_cache_yaml(config, filepath) + + return config + + +def _merge_deprecated_cache_yaml(config, filepath): + """Merge the deprecated cache_config.yml into the config.""" + from aiida.common import timezone + cache_path = os.path.join(os.path.dirname(filepath), 'cache_config.yml') + if not os.path.exists(cache_path): + return + + cache_path_backup = None + # Keep generating a new backup filename based on the current time until it does not exist + while not cache_path_backup or os.path.isfile(cache_path_backup): + cache_path_backup = f"{cache_path}.{timezone.now().strftime('%Y%m%d-%H%M%S.%f')}" + + warnings.warn( + 'cache_config.yml use is deprecated and support will be removed in `v3.0`. Merging into config.json and ' + f'moving to: {cache_path_backup}', AiidaDeprecationWarning + ) + import yaml + with open(cache_path, 'r', encoding='utf8') as handle: + cache_config = yaml.safe_load(handle) + for profile_name, data in cache_config.items(): + if profile_name not in config.profile_names: + warnings.warn(f"Profile '{profile_name}' from cache_config.yml not in config.json, skipping", UserWarning) + continue + for key, option_name in [('default', 'caching.default_enabled'), ('enabled', 'caching.enabled_for'), + ('disabled', 'caching.disabled_for')]: + if key in data: + value = data[key] + # in case of empty key + value = [] if value is None and key != 'default' else value + config.set_option(option_name, value, scope=profile_name) + config.store() + shutil.move(cache_path, cache_path_backup) + + +def get_profile(): + """Return the currently loaded profile. + + :return: the globally loaded `Profile` instance or `None` + :rtype: :class:`~aiida.manage.configuration.Profile` + """ + global PROFILE + return PROFILE + + +def reset_profile(): + """Reset the globally loaded profile. + + .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean + weird unknown side-effects may occur that end up corrupting or destroying data. + """ + global PROFILE + global BACKEND_UUID + PROFILE = None + BACKEND_UUID = None + + +def reset_config(): + """Reset the globally loaded config. + + .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean + weird unknown side-effects may occur that end up corrupting or destroying data. + """ + global CONFIG + CONFIG = None + + +def get_config(create=False): + """Return the current configuration. + + If the configuration has not been loaded yet + * the configuration is loaded using ``load_config`` + * the global `CONFIG` variable is set + * the configuration object is returned + + Note: This function will except if no configuration file can be found. Only call this function, if you need + information from the configuration file. + + :param create: if True, will create the configuration file if it does not already exist + :type create: bool + + :return: the config + :rtype: :class:`~aiida.manage.configuration.config.Config` + :raises aiida.common.ConfigurationError: if the configuration file could not be found, read or deserialized + """ + global CONFIG + + if not CONFIG: + CONFIG = load_config(create=create) + + if CONFIG.get_option('warnings.showdeprecations'): + # If the user does not want to get AiiDA deprecation warnings, we disable them - this can be achieved with:: + # verdi config warnings.showdeprecations False + # Note that the AiidaDeprecationWarning does NOT inherit from DeprecationWarning + warnings.simplefilter('default', AiidaDeprecationWarning) # pylint: disable=no-member + # This should default to 'once', i.e. once per different message + else: + warnings.simplefilter('ignore', AiidaDeprecationWarning) # pylint: disable=no-member + + return CONFIG + + +def get_config_option(option_name): + """Return the value for the given configuration option. + + This function will attempt to load the value of the option as defined for the current profile or otherwise as + defined configuration wide. If no configuration is yet loaded, this function will fall back on the default that may + be defined for the option itself. This is useful for options that need to be defined at loading time of AiiDA when + no configuration is yet loaded or may not even yet exist. In cases where one expects a profile to be loaded, + preference should be given to retrieving the option through the Config instance and its `get_option` method. + + :param option_name: the name of the configuration option + :type option_name: str + + :return: option value as specified for the profile/configuration if loaded, otherwise option default + """ + from aiida.common import exceptions + + option = options.get_option(option_name) + + try: + config = get_config(create=True) + except exceptions.ConfigurationError: + value = option.default if option.default is not options.NO_DEFAULT else None + else: + if config.current_profile: + # Try to get the option for the profile, but do not return the option default + value_profile = config.get_option(option_name, scope=config.current_profile.name, default=False) + else: + value_profile = None + + # Value is the profile value if defined or otherwise the global value, which will be None if not set + value = value_profile if value_profile else config.get_option(option_name) + + return value + + +def load_documentation_profile(): + """Load a dummy profile just for the purposes of being able to build the documentation. + + The building of the documentation will require importing the `aiida` package and some code will try to access the + loaded configuration and profile, which if not done will except. On top of that, Django will raise an exception if + the database models are loaded before its settings are loaded. This also is taken care of by loading a Django + profile and loading the corresponding backend. Calling this function will perform all these requirements allowing + the documentation to be built without having to install and configure AiiDA nor having an actual database present. + """ + import tempfile + from aiida.manage.manager import get_manager + from .config import Config + from .profile import Profile + + global PROFILE + global CONFIG + + with tempfile.NamedTemporaryFile() as handle: + profile_name = 'readthedocs' + profile = { + 'AIIDADB_ENGINE': 'postgresql_psycopg2', + 'AIIDADB_BACKEND': 'django', + 'AIIDADB_PORT': 5432, + 'AIIDADB_HOST': 'localhost', + 'AIIDADB_NAME': 'aiidadb', + 'AIIDADB_PASS': 'aiidadb', + 'AIIDADB_USER': 'aiida', + 'AIIDADB_REPOSITORY_URI': 'file:///dev/null', + } + config = {'default_profile': profile_name, 'profiles': {profile_name: profile}} + PROFILE = Profile(profile_name, profile, from_config=True) + CONFIG = Config(handle.name, config) + get_manager()._load_backend(schema_check=False, repository_check=False) # pylint: disable=protected-access diff --git a/aiida/manage/configuration/main.py b/aiida/manage/configuration/main.py deleted file mode 100644 index f27f589c3a..0000000000 --- a/aiida/manage/configuration/main.py +++ /dev/null @@ -1,267 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=undefined-variable,wildcard-import,global-statement,redefined-outer-name,cyclic-import -"""Modules related to the configuration of an AiiDA instance.""" -import os -import shutil -import warnings - -from aiida.common.warnings import AiidaDeprecationWarning - -CONFIG = None -PROFILE = None -BACKEND_UUID = None # This will be set to the UUID of the profile as soon as its corresponding backend is loaded - -__all__ = ('get_config', 'get_config_option', 'get_config_path', 'load_profile', 'reset_config', 'CONFIG', 'PROFILE', 'BACKEND_UUID') - - -def load_profile(profile=None): - """Load a profile. - - .. note:: if a profile is already loaded and no explicit profile is specified, nothing will be done - - :param profile: the name of the profile to load, by default will use the one marked as default in the config - :type profile: str - - :return: the loaded `Profile` instance - :rtype: :class:`~aiida.manage.configuration.Profile` - :raises `aiida.common.exceptions.InvalidOperation`: if the backend of another profile has already been loaded - """ - from aiida.common import InvalidOperation - from aiida.common.log import configure_logging - - global PROFILE - global BACKEND_UUID - - # If a profile is loaded and the specified profile name is None or that of the currently loaded, do nothing - if PROFILE and (profile is None or PROFILE.name is profile): - return PROFILE - - PROFILE = get_config().get_profile(profile) - - if BACKEND_UUID is not None and BACKEND_UUID != PROFILE.uuid: - # Once the switching of profiles with different backends becomes possible, the backend has to be reset properly - raise InvalidOperation('cannot switch profile because backend of another profile is already loaded') - - # Reconfigure the logging to make sure that profile specific logging configuration options are taken into account. - # Note that we do not configure with `with_orm=True` because that will force the backend to be loaded. This should - # instead be done lazily in `Manager._load_backend`. - configure_logging() - - return PROFILE - - -def get_config_path(): - """Returns path to .aiida configuration directory.""" - from .settings import AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME - - return os.path.join(AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME) - - -def load_config(create=False): - """Instantiate Config object representing an AiiDA configuration file. - - Warning: Contrary to :func:`~aiida.manage.configuration.get_config`, this function is uncached and will always - create a new Config object. You may want to call :func:`~aiida.manage.configuration.get_config` instead. - - :param create: if True, will create the configuration file if it does not already exist - :type create: bool - - :return: the config - :rtype: :class:`~aiida.manage.configuration.config.Config` - :raises aiida.common.MissingConfigurationError: if the configuration file could not be found and create=False - """ - from aiida.common import exceptions - from .config import Config - - filepath = get_config_path() - - if not os.path.isfile(filepath) and not create: - raise exceptions.MissingConfigurationError(f'configuration file {filepath} does not exist') - - try: - config = Config.from_file(filepath) - except ValueError as exc: - raise exceptions.ConfigurationError(f'configuration file {filepath} contains invalid JSON') from exc - - _merge_deprecated_cache_yaml(config, filepath) - - return config - - -def _merge_deprecated_cache_yaml(config, filepath): - """Merge the deprecated cache_config.yml into the config.""" - from aiida.common import timezone - cache_path = os.path.join(os.path.dirname(filepath), 'cache_config.yml') - if not os.path.exists(cache_path): - return - - cache_path_backup = None - # Keep generating a new backup filename based on the current time until it does not exist - while not cache_path_backup or os.path.isfile(cache_path_backup): - cache_path_backup = f"{cache_path}.{timezone.now().strftime('%Y%m%d-%H%M%S.%f')}" - - warnings.warn( - 'cache_config.yml use is deprecated and support will be removed in `v3.0`. Merging into config.json and ' - f'moving to: {cache_path_backup}', AiidaDeprecationWarning - ) - import yaml - with open(cache_path, 'r', encoding='utf8') as handle: - cache_config = yaml.safe_load(handle) - for profile_name, data in cache_config.items(): - if profile_name not in config.profile_names: - warnings.warn(f"Profile '{profile_name}' from cache_config.yml not in config.json, skipping", UserWarning) - continue - for key, option_name in [('default', 'caching.default_enabled'), ('enabled', 'caching.enabled_for'), - ('disabled', 'caching.disabled_for')]: - if key in data: - value = data[key] - # in case of empty key - value = [] if value is None and key != 'default' else value - config.set_option(option_name, value, scope=profile_name) - config.store() - shutil.move(cache_path, cache_path_backup) - - -def get_profile(): - """Return the currently loaded profile. - - :return: the globally loaded `Profile` instance or `None` - :rtype: :class:`~aiida.manage.configuration.Profile` - """ - global PROFILE - return PROFILE - - -def reset_profile(): - """Reset the globally loaded profile. - - .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean - weird unknown side-effects may occur that end up corrupting or destroying data. - """ - global PROFILE - global BACKEND_UUID - PROFILE = None - BACKEND_UUID = None - - -def reset_config(): - """Reset the globally loaded config. - - .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean - weird unknown side-effects may occur that end up corrupting or destroying data. - """ - global CONFIG - CONFIG = None - - -def get_config(create=False): - """Return the current configuration. - - If the configuration has not been loaded yet - * the configuration is loaded using ``load_config`` - * the global `CONFIG` variable is set - * the configuration object is returned - - Note: This function will except if no configuration file can be found. Only call this function, if you need - information from the configuration file. - - :param create: if True, will create the configuration file if it does not already exist - :type create: bool - - :return: the config - :rtype: :class:`~aiida.manage.configuration.config.Config` - :raises aiida.common.ConfigurationError: if the configuration file could not be found, read or deserialized - """ - global CONFIG - - if not CONFIG: - CONFIG = load_config(create=create) - - if CONFIG.get_option('warnings.showdeprecations'): - # If the user does not want to get AiiDA deprecation warnings, we disable them - this can be achieved with:: - # verdi config warnings.showdeprecations False - # Note that the AiidaDeprecationWarning does NOT inherit from DeprecationWarning - warnings.simplefilter('default', AiidaDeprecationWarning) # pylint: disable=no-member - # This should default to 'once', i.e. once per different message - else: - warnings.simplefilter('ignore', AiidaDeprecationWarning) # pylint: disable=no-member - - return CONFIG - - -def get_config_option(option_name): - """Return the value for the given configuration option. - - This function will attempt to load the value of the option as defined for the current profile or otherwise as - defined configuration wide. If no configuration is yet loaded, this function will fall back on the default that may - be defined for the option itself. This is useful for options that need to be defined at loading time of AiiDA when - no configuration is yet loaded or may not even yet exist. In cases where one expects a profile to be loaded, - preference should be given to retrieving the option through the Config instance and its `get_option` method. - - :param option_name: the name of the configuration option - :type option_name: str - - :return: option value as specified for the profile/configuration if loaded, otherwise option default - """ - from aiida.common import exceptions - - option = options.get_option(option_name) - - try: - config = get_config(create=True) - except exceptions.ConfigurationError: - value = option.default if option.default is not options.NO_DEFAULT else None - else: - if config.current_profile: - # Try to get the option for the profile, but do not return the option default - value_profile = config.get_option(option_name, scope=config.current_profile.name, default=False) - else: - value_profile = None - - # Value is the profile value if defined or otherwise the global value, which will be None if not set - value = value_profile if value_profile else config.get_option(option_name) - - return value - - -def load_documentation_profile(): - """Load a dummy profile just for the purposes of being able to build the documentation. - - The building of the documentation will require importing the `aiida` package and some code will try to access the - loaded configuration and profile, which if not done will except. On top of that, Django will raise an exception if - the database models are loaded before its settings are loaded. This also is taken care of by loading a Django - profile and loading the corresponding backend. Calling this function will perform all these requirements allowing - the documentation to be built without having to install and configure AiiDA nor having an actual database present. - """ - import tempfile - from aiida.manage.manager import get_manager - from .config import Config - from .profile import Profile - - global PROFILE - global CONFIG - - with tempfile.NamedTemporaryFile() as handle: - profile_name = 'readthedocs' - profile = { - 'AIIDADB_ENGINE': 'postgresql_psycopg2', - 'AIIDADB_BACKEND': 'django', - 'AIIDADB_PORT': 5432, - 'AIIDADB_HOST': 'localhost', - 'AIIDADB_NAME': 'aiidadb', - 'AIIDADB_PASS': 'aiidadb', - 'AIIDADB_USER': 'aiida', - 'AIIDADB_REPOSITORY_URI': 'file:///dev/null', - } - config = {'default_profile': profile_name, 'profiles': {profile_name: profile}} - PROFILE = Profile(profile_name, profile, from_config=True) - CONFIG = Config(handle.name, config) - get_manager()._load_backend(schema_check=False, repository_check=False) # pylint: disable=protected-access diff --git a/aiida/manage/configuration/migrations/__init__.py b/aiida/manage/configuration/migrations/__init__.py index 2170fcdcce..4aad63827b 100644 --- a/aiida/manage/configuration/migrations/__init__.py +++ b/aiida/manage/configuration/migrations/__init__.py @@ -10,6 +10,7 @@ """Methods and definitions of migrations for the configuration file of an AiiDA instance.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -23,3 +24,5 @@ 'config_needs_migrating', 'get_current_version', ) + +# yapf: enable diff --git a/aiida/manage/database/__init__.py b/aiida/manage/database/__init__.py index 7cb1bdb138..e6131da0b1 100644 --- a/aiida/manage/database/__init__.py +++ b/aiida/manage/database/__init__.py @@ -9,6 +9,7 @@ ########################################################################### # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -21,3 +22,5 @@ 'verify_uuid_uniqueness', 'write_database_integrity_violation', ) + +# yapf: enable diff --git a/aiida/manage/database/integrity/__init__.py b/aiida/manage/database/integrity/__init__.py index 4bc2bb0f47..6bb9ca7f26 100644 --- a/aiida/manage/database/integrity/__init__.py +++ b/aiida/manage/database/integrity/__init__.py @@ -10,6 +10,7 @@ """Methods to validate the database integrity and fix violations.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -23,3 +24,5 @@ 'verify_uuid_uniqueness', 'write_database_integrity_violation', ) + +# yapf: enable diff --git a/aiida/manage/external/__init__.py b/aiida/manage/external/__init__.py index 10bf08219a..d82852a0da 100644 --- a/aiida/manage/external/__init__.py +++ b/aiida/manage/external/__init__.py @@ -10,6 +10,7 @@ """User facing APIs to control AiiDA from the verdi cli, scripts or plugins""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -26,3 +27,5 @@ 'ProcessLauncher', 'RemoteException', ) + +# yapf: enable diff --git a/aiida/manage/tests/__init__.py b/aiida/manage/tests/__init__.py index 7d5fb9195a..f60a6c3c7f 100644 --- a/aiida/manage/tests/__init__.py +++ b/aiida/manage/tests/__init__.py @@ -13,12 +13,25 @@ """ # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import +from .main import * from .unittest_classes import * __all__ = ( 'PluginTestCase', + 'ProfileManager', + 'TemporaryProfileManager', + 'TestManager', + 'TestManagerError', 'TestRunner', + '_GLOBAL_TEST_MANAGER', + 'get_test_backend_name', + 'get_test_profile_name', + 'get_user_dict', + 'test_manager', ) + +# yapf: enable diff --git a/aiida/manage/tests/main.py b/aiida/manage/tests/main.py new file mode 100644 index 0000000000..c6a333670f --- /dev/null +++ b/aiida/manage/tests/main.py @@ -0,0 +1,519 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +""" +Testing infrastructure for easy testing of AiiDA plugins. + +""" +import tempfile +import shutil +import os +from contextlib import contextmanager + +from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA +from aiida.common import exceptions +from aiida.manage import configuration +from aiida.manage.configuration.settings import create_instance_directories +from aiida.manage import manager +from aiida.manage.external.postgres import Postgres + +__all__ = ( + '_GLOBAL_TEST_MANAGER', + 'get_test_profile_name', + 'get_test_backend_name', + 'get_user_dict', + 'test_manager', + 'TestManager', + 'TestManagerError', + 'ProfileManager', + 'TemporaryProfileManager', +) + +_DEFAULT_PROFILE_INFO = { + 'name': 'test_profile', + 'email': 'tests@aiida.mail', + 'first_name': 'AiiDA', + 'last_name': 'Plugintest', + 'institution': 'aiidateam', + 'database_engine': 'postgresql_psycopg2', + 'database_backend': 'django', + 'database_username': 'aiida', + 'database_password': 'aiida_pw', + 'database_name': 'aiida_db', + 'repo_dir': 'test_repo', + 'config_dir': '.aiida', + 'root_path': '', + 'broker_protocol': 'amqp', + 'broker_username': 'guest', + 'broker_password': 'guest', + 'broker_host': '127.0.0.1', + 'broker_port': 5672, + 'broker_virtual_host': '' +} + + +class TestManagerError(Exception): + """Raised by TestManager in situations that may lead to inconsistent behaviour.""" + + def __init__(self, msg): + super().__init__() + self.msg = msg + + def __str__(self): + return repr(self.msg) + + +class TestManager: + """ + Test manager for plugin tests. + + Uses either ProfileManager for wrapping an existing profile or TemporaryProfileManager for setting up a complete + temporary AiiDA environment. + + For usage with pytest, see :py:class:`~aiida.manage.tests.pytest_fixtures`. + """ + + def __init__(self): + self._manager = None + + def use_temporary_profile(self, backend=None, pgtest=None): + """Set up Test manager to use temporary AiiDA profile. + + Uses :py:class:`aiida.manage.tests.TemporaryProfileManager` internally. + + :param backend: Backend to use. + :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, + e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. + + """ + if configuration.PROFILE is not None: + raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') + if self._manager is not None: + raise TestManagerError('Profile manager already loaded.') + + mngr = TemporaryProfileManager(backend=backend, pgtest=pgtest) + mngr.create_profile() + self._manager = mngr # don't assign before profile has actually been created! + + def use_profile(self, profile_name): + """Set up Test manager to use existing profile. + + Uses :py:class:`aiida.manage.tests.ProfileManager` internally. + + :param profile_name: Name of existing test profile to use. + """ + if configuration.PROFILE is not None: + raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') + if self._manager is not None: + raise TestManagerError('Profile manager already loaded.') + + self._manager = ProfileManager(profile_name=profile_name) + self._manager.init_db() + + def has_profile_open(self): + return self._manager and self._manager.has_profile_open() + + def reset_db(self): + return self._manager.reset_db() + + def destroy_all(self): + if self._manager: + self._manager.destroy_all() + self._manager = None + + +class ProfileManager: + """ + Wraps existing AiiDA profile. + """ + + def __init__(self, profile_name): + """ + Use an existing profile. + + :param profile_name: Name of the profile to be loaded + """ + from aiida import load_profile + from aiida.backends.testbase import check_if_tests_can_run + + self._profile = None + self._user = None + + try: + self._profile = load_profile(profile_name) + manager.get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access + except Exception: + raise TestManagerError('Unable to load test profile \'{}\'.'.format(profile_name)) + check_if_tests_can_run() + + self._select_db_test_case(backend=self._profile.database_backend) + + def _select_db_test_case(self, backend): + """ + Selects tests case for the correct database backend. + """ + if backend == BACKEND_DJANGO: + from aiida.backends.djsite.db.testbase import DjangoTests + self._test_case = DjangoTests() + elif backend == BACKEND_SQLA: + from aiida.backends.sqlalchemy.testbase import SqlAlchemyTests + from aiida.backends.sqlalchemy import get_scoped_session + + self._test_case = SqlAlchemyTests() + self._test_case.test_session = get_scoped_session() + + def reset_db(self): + self._test_case.clean_db() # will drop all users + manager.reset_manager() + self.init_db() + + def init_db(self): + """Initialise the database state for running of tests. + + Adds default user if necessary. + """ + from aiida.orm import User + from aiida.cmdline.commands.cmd_user import set_default_user + + if not User.objects.get_default(): + user_dict = get_user_dict(_DEFAULT_PROFILE_INFO) + try: + user = User(**user_dict) + user.store() + except exceptions.IntegrityError: + # The user already exists, no problem + user = User.objects.get(**user_dict) + + set_default_user(self._profile, user) + User.objects.reset() # necessary to pick up new default user + + def has_profile_open(self): + return self._profile is not None + + def destroy_all(self): + pass + + +class TemporaryProfileManager(ProfileManager): + """ + Manage the life cycle of a completely separated and temporary AiiDA environment. + + * No profile / database setup required + * Tests run via the TemporaryProfileManager never pollute the user's working environment + + Filesystem: + + * temporary ``.aiida`` configuration folder + * temporary repository folder + + Database: + + * temporary database cluster (via the ``pgtest`` package) + * with ``aiida`` database user + * with ``aiida_db`` database + + AiiDA: + + * configured to use the temporary configuration + * sets up a temporary profile for tests + + All of this happens automatically when using the corresponding tests classes & tests runners (unittest) + or fixtures (pytest). + + Example:: + + tests = TemporaryProfileManager(backend=backend) + tests.create_aiida_db() # set up only the database + tests.create_profile() # set up a profile (creates the db too if necessary) + + # ready for tests + + # run tests 1 + + tests.reset_db() + # database ready for independent tests 2 + + # run tests 2 + + tests.destroy_all() + # everything cleaned up + + """ + + _test_case = None + + def __init__(self, backend=BACKEND_DJANGO, pgtest=None): # pylint: disable=super-init-not-called + """Construct a TemporaryProfileManager + + :param backend: a database backend + :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, + e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. + + """ + from aiida.manage.configuration import settings + + self.dbinfo = {} + self.profile_info = _DEFAULT_PROFILE_INFO + self.profile_info['database_backend'] = backend + self._pgtest = pgtest or {} + + self.pg_cluster = None + self.postgres = None + self._profile = None + self._has_test_db = False + self._backup = { + 'config': configuration.CONFIG, + 'config_dir': settings.AIIDA_CONFIG_FOLDER, + 'profile': configuration.PROFILE, + } + + @property + def profile_dictionary(self): + """Profile parameters. + + Used to set up AiiDA profile from self.profile_info dictionary. + """ + dictionary = { + 'database_engine': self.profile_info.get('database_engine'), + 'database_backend': self.profile_info.get('database_backend'), + 'database_port': self.profile_info.get('database_port'), + 'database_hostname': self.profile_info.get('database_hostname'), + 'database_name': self.profile_info.get('database_name'), + 'database_username': self.profile_info.get('database_username'), + 'database_password': self.profile_info.get('database_password'), + 'broker_protocol': self.profile_info.get('broker_protocol'), + 'broker_username': self.profile_info.get('broker_username'), + 'broker_password': self.profile_info.get('broker_password'), + 'broker_host': self.profile_info.get('broker_host'), + 'broker_port': self.profile_info.get('broker_port'), + 'broker_virtual_host': self.profile_info.get('broker_virtual_host'), + 'repository_uri': f'file://{self.repo}', + } + return dictionary + + def create_db_cluster(self): + """ + Create the database cluster using PGTest. + """ + from pgtest.pgtest import PGTest + + if self.pg_cluster is not None: + raise TestManagerError( + 'Running temporary postgresql cluster detected.Use destroy_all() before creating a new cluster.' + ) + self.pg_cluster = PGTest(**self._pgtest) + self.dbinfo.update(self.pg_cluster.dsn) + + def create_aiida_db(self): + """ + Create the necessary database on the temporary postgres instance. + """ + if configuration.PROFILE is not None: + raise TestManagerError('AiiDA dbenv can not be loaded while creating a tests db environment') + if self.pg_cluster is None: + self.create_db_cluster() + self.postgres = Postgres(interactive=False, quiet=True, dbinfo=self.dbinfo) + # note: not using postgres.create_dbuser_db_safe here since we don't want prompts + self.postgres.create_dbuser(self.profile_info['database_username'], self.profile_info['database_password']) + self.postgres.create_db(self.profile_info['database_username'], self.profile_info['database_name']) + self.dbinfo = self.postgres.dbinfo + self.profile_info['database_hostname'] = self.postgres.host_for_psycopg2 + self.profile_info['database_port'] = self.postgres.port_for_psycopg2 + self._has_test_db = True + + def create_profile(self): + """ + Set AiiDA to use the tests config dir and create a default profile there + + Warning: the AiiDA dbenv must not be loaded when this is called! + """ + from aiida.manage.configuration import settings, load_profile, Profile + + if not self._has_test_db: + self.create_aiida_db() + + if not self.root_dir: + self.root_dir = tempfile.mkdtemp() + configuration.CONFIG = None + settings.AIIDA_CONFIG_FOLDER = self.config_dir + configuration.PROFILE = None + create_instance_directories() + profile_name = self.profile_info['name'] + config = configuration.get_config(create=True) + profile = Profile(profile_name, self.profile_dictionary) + config.add_profile(profile) + config.set_default_profile(profile_name).store() + self._profile = profile + + load_profile(profile_name) + backend = manager.get_manager()._load_backend(schema_check=False) + backend.migrate() + + self._select_db_test_case(backend=self._profile.database_backend) + self.init_db() + + def repo_ok(self): + return bool(self.repo and os.path.isdir(os.path.dirname(self.repo))) + + @property + def repo(self): + return self._return_dir(self.profile_info['repo_dir']) + + def _return_dir(self, dir_path): + """Return a path to a directory from the fs environment""" + if os.path.isabs(dir_path): + return dir_path + return os.path.join(self.root_dir, dir_path) + + @property + def backend(self): + return self.profile_info['backend'] + + @backend.setter + def backend(self, backend): + if self.has_profile_open(): + raise TestManagerError('backend cannot be changed after setting up the environment') + + valid_backends = [BACKEND_DJANGO, BACKEND_SQLA] + if backend not in valid_backends: + raise ValueError(f'invalid backend {backend}, must be one of {valid_backends}') + self.profile_info['backend'] = backend + + @property + def config_dir_ok(self): + return bool(self.config_dir and os.path.isdir(self.config_dir)) + + @property + def config_dir(self): + return self._return_dir(self.profile_info['config_dir']) + + @property + def root_dir(self): + return self.profile_info['root_path'] + + @root_dir.setter + def root_dir(self, root_dir): + self.profile_info['root_path'] = root_dir + + @property + def root_dir_ok(self): + return bool(self.root_dir and os.path.isdir(self.root_dir)) + + def destroy_all(self): + """Remove all traces of the tests run""" + from aiida.manage.configuration import settings + if self.root_dir: + shutil.rmtree(self.root_dir) + self.root_dir = None + if self.pg_cluster: + self.pg_cluster.close() + self.pg_cluster = None + self._has_test_db = False + self._profile = None + self._user = None + + if 'config' in self._backup: + configuration.CONFIG = self._backup['config'] + if 'config_dir' in self._backup: + settings.AIIDA_CONFIG_FOLDER = self._backup['config_dir'] + if 'profile' in self._backup: + configuration.PROFILE = self._backup['profile'] + + def has_profile_open(self): + return self._profile is not None + + +_GLOBAL_TEST_MANAGER = TestManager() + + +@contextmanager +def test_manager(backend=BACKEND_DJANGO, profile_name=None, pgtest=None): + """ Context manager for TestManager objects. + + Sets up temporary AiiDA environment for testing or reuses existing environment, + if `AIIDA_TEST_PROFILE` environment variable is set. + + Example pytest fixture:: + + def aiida_profile(): + with test_manager(backend) as test_mgr: + yield fixture_mgr + + Example unittest test runner:: + + with test_manager(backend) as test_mgr: + # ready for tests + # everything cleaned up + + + :param backend: database backend, either BACKEND_SQLA or BACKEND_DJANGO + :param profile_name: name of test profile to be used or None (to use temporary profile) + :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, + e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. + """ + from aiida.common.utils import Capturing + from aiida.common.log import configure_logging + + try: + if not _GLOBAL_TEST_MANAGER.has_profile_open(): + if profile_name: + _GLOBAL_TEST_MANAGER.use_profile(profile_name=profile_name) + else: + with Capturing(): # capture output of AiiDA DB setup + _GLOBAL_TEST_MANAGER.use_temporary_profile(backend=backend, pgtest=pgtest) + configure_logging(with_orm=True) + yield _GLOBAL_TEST_MANAGER + finally: + _GLOBAL_TEST_MANAGER.destroy_all() + + +def get_test_backend_name(): + """ Read name of database backend from environment variable or the specified test profile. + + Reads database backend ('django' or 'sqlalchemy') from 'AIIDA_TEST_BACKEND' environment variable, + or the backend configured for the 'AIIDA_TEST_PROFILE'. + Defaults to django backend. + + :returns: content of environment variable or `BACKEND_DJANGO` + :raises: ValueError if unknown backend name detected. + :raises: ValueError if both 'AIIDA_TEST_BACKEND' and 'AIIDA_TEST_PROFILE' are set, and the two + backends do not match. + """ + test_profile_name = get_test_profile_name() + backend_env = os.environ.get('AIIDA_TEST_BACKEND', None) + if test_profile_name is not None: + backend_profile = configuration.get_config().get_profile(test_profile_name).database_backend + if backend_env is not None and backend_env != backend_profile: + raise ValueError( + "The backend '{}' read from AIIDA_TEST_BACKEND does not match the backend '{}' " + "of AIIDA_TEST_PROFILE '{}'".format(backend_env, backend_profile, test_profile_name) + ) + backend_res = backend_profile + else: + backend_res = backend_env or BACKEND_DJANGO + + if backend_res in (BACKEND_DJANGO, BACKEND_SQLA): + return backend_res + raise ValueError(f"Unknown backend '{backend_res}' read from AIIDA_TEST_BACKEND environment variable") + + +def get_test_profile_name(): + """ Read name of test profile from environment variable. + + Reads name of existing test profile 'AIIDA_TEST_PROFILE' environment variable. + If specified, this profile is used for running the tests (instead of setting up a temporary profile). + + :returns: content of environment variable or `None` + """ + return os.environ.get('AIIDA_TEST_PROFILE', None) + + +def get_user_dict(profile_dict): + """Collect parameters required for creating users.""" + return {k: profile_dict[k] for k in ('email', 'first_name', 'last_name', 'institution')} diff --git a/aiida/manage/tests/unittest_classes.py b/aiida/manage/tests/unittest_classes.py index 751583eb3b..674fb43606 100644 --- a/aiida/manage/tests/unittest_classes.py +++ b/aiida/manage/tests/unittest_classes.py @@ -15,7 +15,7 @@ from aiida.common.warnings import AiidaDeprecationWarning from aiida.manage.manager import get_manager -from . import _GLOBAL_TEST_MANAGER, test_manager, get_test_backend_name, get_test_profile_name +from .main import _GLOBAL_TEST_MANAGER, test_manager, get_test_backend_name, get_test_profile_name __all__ = ('PluginTestCase', 'TestRunner') diff --git a/aiida/orm/__init__.py b/aiida/orm/__init__.py index b3aab44ca0..e09549c7c8 100644 --- a/aiida/orm/__init__.py +++ b/aiida/orm/__init__.py @@ -10,6 +10,7 @@ """Main module to expose all orm classes and methods""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -39,6 +40,7 @@ 'CalcJobResultManager', 'CalculationEntityLoader', 'CalculationNode', + 'CifData', 'Code', 'CodeEntityLoader', 'Collection', @@ -90,15 +92,20 @@ 'WorkFunctionNode', 'WorkflowNode', 'XyData', + 'cif_from_ase', 'find_bandgap', 'get_loader', 'get_query_type_from_type_string', 'get_type_string_from_class', + 'has_pycifrw', 'load_code', 'load_computer', 'load_group', 'load_node', 'load_node_class', + 'pycifrw_from_cif', 'to_aiida_type', 'validate_link', ) + +# yapf: enable diff --git a/aiida/orm/implementation/__init__.py b/aiida/orm/implementation/__init__.py index 6ee6d1bbe0..332752718a 100644 --- a/aiida/orm/implementation/__init__.py +++ b/aiida/orm/implementation/__init__.py @@ -10,6 +10,7 @@ """Module with the implementations of the various backend entities for various database backends.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -50,3 +51,5 @@ 'clean_value', 'validate_attribute_extra_key', ) + +# yapf: enable diff --git a/aiida/orm/implementation/django/__init__.py b/aiida/orm/implementation/django/__init__.py index 6957af28c6..919c6a84ba 100644 --- a/aiida/orm/implementation/django/__init__.py +++ b/aiida/orm/implementation/django/__init__.py @@ -9,6 +9,7 @@ ########################################################################### # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -25,3 +26,5 @@ 'DjangoUserCollection', 'get_backend_entity', ) + +# yapf: enable diff --git a/aiida/orm/implementation/sql/__init__.py b/aiida/orm/implementation/sql/__init__.py index d5b115fd4e..439cd9ba84 100644 --- a/aiida/orm/implementation/sql/__init__.py +++ b/aiida/orm/implementation/sql/__init__.py @@ -14,6 +14,7 @@ """ # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -22,3 +23,5 @@ __all__ = ( 'SqlBackend', ) + +# yapf: enable diff --git a/aiida/orm/implementation/sqlalchemy/__init__.py b/aiida/orm/implementation/sqlalchemy/__init__.py index b836c7c844..f1ba8458df 100644 --- a/aiida/orm/implementation/sqlalchemy/__init__.py +++ b/aiida/orm/implementation/sqlalchemy/__init__.py @@ -9,6 +9,7 @@ ########################################################################### # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -25,3 +26,5 @@ 'SqlaUserCollection', 'get_backend_entity', ) + +# yapf: enable diff --git a/aiida/orm/nodes/__init__.py b/aiida/orm/nodes/__init__.py index f15d59e7da..d7498592b4 100644 --- a/aiida/orm/nodes/__init__.py +++ b/aiida/orm/nodes/__init__.py @@ -10,6 +10,7 @@ """Module with `Node` sub classes for data and processes.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -26,6 +27,7 @@ 'CalcFunctionNode', 'CalcJobNode', 'CalculationNode', + 'CifData', 'Code', 'Data', 'Dict', @@ -54,6 +56,11 @@ 'WorkFunctionNode', 'WorkflowNode', 'XyData', + 'cif_from_ase', 'find_bandgap', + 'has_pycifrw', + 'pycifrw_from_cif', 'to_aiida_type', ) + +# yapf: enable diff --git a/aiida/orm/nodes/data/__init__.py b/aiida/orm/nodes/data/__init__.py index 5cb9339df4..31292c91eb 100644 --- a/aiida/orm/nodes/data/__init__.py +++ b/aiida/orm/nodes/data/__init__.py @@ -10,12 +10,14 @@ """Module with `Node` sub classes for data structures.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import from .array import * from .base import * from .bool import * +from .cif import * from .code import * from .data import * from .dict import * @@ -36,6 +38,7 @@ 'BandsData', 'BaseType', 'Bool', + 'CifData', 'Code', 'Data', 'Dict', @@ -58,6 +61,11 @@ 'TrajectoryData', 'UpfData', 'XyData', + 'cif_from_ase', 'find_bandgap', + 'has_pycifrw', + 'pycifrw_from_cif', 'to_aiida_type', ) + +# yapf: enable diff --git a/aiida/orm/nodes/data/array/__init__.py b/aiida/orm/nodes/data/array/__init__.py index 8cdfec92b9..f12feedfbe 100644 --- a/aiida/orm/nodes/data/array/__init__.py +++ b/aiida/orm/nodes/data/array/__init__.py @@ -10,6 +10,7 @@ """Module with `Node` sub classes for array based data structures.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -29,3 +30,5 @@ 'XyData', 'find_bandgap', ) + +# yapf: enable diff --git a/aiida/orm/nodes/data/array/trajectory.py b/aiida/orm/nodes/data/array/trajectory.py index 9edd934d59..575c94133b 100644 --- a/aiida/orm/nodes/data/array/trajectory.py +++ b/aiida/orm/nodes/data/array/trajectory.py @@ -17,6 +17,7 @@ __all__ = ('TrajectoryData',) + class TrajectoryData(ArrayData): """ Stores a trajectory (a sequence of crystal structures with timestamps, and diff --git a/aiida/orm/nodes/data/cif.py b/aiida/orm/nodes/data/cif.py index 53ad65c098..e5e5339269 100644 --- a/aiida/orm/nodes/data/cif.py +++ b/aiida/orm/nodes/data/cif.py @@ -15,6 +15,8 @@ from .singlefile import SinglefileData +__all__ = ('CifData', 'cif_from_ase', 'has_pycifrw', 'pycifrw_from_cif') + ase_loops = { '_atom_site': [ '_atom_site_label', diff --git a/aiida/orm/nodes/data/remote/__init__.py b/aiida/orm/nodes/data/remote/__init__.py index a57cdcf2f5..ae1b5dbc4f 100644 --- a/aiida/orm/nodes/data/remote/__init__.py +++ b/aiida/orm/nodes/data/remote/__init__.py @@ -2,6 +2,7 @@ """Module with data plugins that represent remote resources and so effectively are symbolic links.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -13,3 +14,5 @@ 'RemoteStashData', 'RemoteStashFolderData', ) + +# yapf: enable diff --git a/aiida/orm/nodes/data/remote/stash/__init__.py b/aiida/orm/nodes/data/remote/stash/__init__.py index 74b714dca9..e06481e842 100644 --- a/aiida/orm/nodes/data/remote/stash/__init__.py +++ b/aiida/orm/nodes/data/remote/stash/__init__.py @@ -2,6 +2,7 @@ """Module with data plugins that represent files of completed calculations jobs that have been stashed.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -12,3 +13,5 @@ 'RemoteStashData', 'RemoteStashFolderData', ) + +# yapf: enable diff --git a/aiida/orm/nodes/process/__init__.py b/aiida/orm/nodes/process/__init__.py index 2e698cd3bc..283b14e9b0 100644 --- a/aiida/orm/nodes/process/__init__.py +++ b/aiida/orm/nodes/process/__init__.py @@ -10,6 +10,7 @@ """Module with `Node` sub classes for processes.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -26,3 +27,5 @@ 'WorkFunctionNode', 'WorkflowNode', ) + +# yapf: enable diff --git a/aiida/orm/nodes/process/calculation/__init__.py b/aiida/orm/nodes/process/calculation/__init__.py index f7ce163fc7..21af4e576e 100644 --- a/aiida/orm/nodes/process/calculation/__init__.py +++ b/aiida/orm/nodes/process/calculation/__init__.py @@ -10,6 +10,7 @@ """Module with `Node` sub classes for calculation processes.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -22,3 +23,5 @@ 'CalcJobNode', 'CalculationNode', ) + +# yapf: enable diff --git a/aiida/orm/nodes/process/workflow/__init__.py b/aiida/orm/nodes/process/workflow/__init__.py index 2f13f3af94..f4125a4f8f 100644 --- a/aiida/orm/nodes/process/workflow/__init__.py +++ b/aiida/orm/nodes/process/workflow/__init__.py @@ -10,6 +10,7 @@ """Module with `Node` sub classes for workflow processes.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -22,3 +23,5 @@ 'WorkFunctionNode', 'WorkflowNode', ) + +# yapf: enable diff --git a/aiida/orm/utils/__init__.py b/aiida/orm/utils/__init__.py index bbe2bc6282..7b049553cb 100644 --- a/aiida/orm/utils/__init__.py +++ b/aiida/orm/utils/__init__.py @@ -10,6 +10,7 @@ """Utilities related to the ORM.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -44,3 +45,5 @@ 'load_node_class', 'validate_link', ) + +# yapf: enable diff --git a/aiida/parsers/__init__.py b/aiida/parsers/__init__.py index 9198ba3a42..b3789ed596 100644 --- a/aiida/parsers/__init__.py +++ b/aiida/parsers/__init__.py @@ -10,6 +10,7 @@ """Module for classes and utilities to write parsers for calculation jobs.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -18,3 +19,5 @@ __all__ = ( 'Parser', ) + +# yapf: enable diff --git a/aiida/plugins/__init__.py b/aiida/plugins/__init__.py index 349d07e35a..14a89108c0 100644 --- a/aiida/plugins/__init__.py +++ b/aiida/plugins/__init__.py @@ -10,6 +10,7 @@ """Classes and functions to load and interact with plugin classes accessible through defined entry points.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -32,3 +33,5 @@ 'load_entry_point', 'load_entry_point_from_string', ) + +# yapf: enable diff --git a/aiida/repository/__init__.py b/aiida/repository/__init__.py index 91876a3342..c828ca07f1 100644 --- a/aiida/repository/__init__.py +++ b/aiida/repository/__init__.py @@ -10,6 +10,7 @@ """Module with resources dealing with the file repository.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -25,3 +26,5 @@ 'Repository', 'SandboxRepositoryBackend', ) + +# yapf: enable diff --git a/aiida/repository/backend/__init__.py b/aiida/repository/backend/__init__.py index 4d890c6dae..ea4ab3386f 100644 --- a/aiida/repository/backend/__init__.py +++ b/aiida/repository/backend/__init__.py @@ -2,6 +2,7 @@ """Module for file repository backend implementations.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -14,3 +15,5 @@ 'DiskObjectStoreRepositoryBackend', 'SandboxRepositoryBackend', ) + +# yapf: enable diff --git a/aiida/restapi/__init__.py b/aiida/restapi/__init__.py index 4b340e2689..90d8726344 100644 --- a/aiida/restapi/__init__.py +++ b/aiida/restapi/__init__.py @@ -14,6 +14,7 @@ """ # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -23,3 +24,5 @@ 'configure_api', 'run_api', ) + +# yapf: enable diff --git a/aiida/schedulers/__init__.py b/aiida/schedulers/__init__.py index fd52173e18..5fad6ad78f 100644 --- a/aiida/schedulers/__init__.py +++ b/aiida/schedulers/__init__.py @@ -10,6 +10,7 @@ """Module for classes and utilities to interact with cluster schedulers.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -28,3 +29,5 @@ 'SchedulerError', 'SchedulerParsingError', ) + +# yapf: enable diff --git a/aiida/tools/__init__.py b/aiida/tools/__init__.py index 27dad0e172..cb4615adb7 100644 --- a/aiida/tools/__init__.py +++ b/aiida/tools/__init__.py @@ -21,6 +21,7 @@ """ # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -93,3 +94,5 @@ 'spglib_tuple_to_structure', 'structure_to_spglib_tuple', ) + +# yapf: enable diff --git a/aiida/tools/calculations/__init__.py b/aiida/tools/calculations/__init__.py index 10aa462662..7fc43df3e3 100644 --- a/aiida/tools/calculations/__init__.py +++ b/aiida/tools/calculations/__init__.py @@ -10,6 +10,7 @@ """Calculation tool plugins for Calculation classes.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -18,3 +19,5 @@ __all__ = ( 'CalculationTools', ) + +# yapf: enable diff --git a/aiida/tools/data/__init__.py b/aiida/tools/data/__init__.py index 16ea3bdd91..fb200bb489 100644 --- a/aiida/tools/data/__init__.py +++ b/aiida/tools/data/__init__.py @@ -9,6 +9,7 @@ ########################################################################### # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -24,3 +25,5 @@ 'spglib_tuple_to_structure', 'structure_to_spglib_tuple', ) + +# yapf: enable diff --git a/aiida/tools/data/array/__init__.py b/aiida/tools/data/array/__init__.py index 3b67a12982..ebb95e693f 100644 --- a/aiida/tools/data/array/__init__.py +++ b/aiida/tools/data/array/__init__.py @@ -7,8 +7,10 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tools for manipulating array data classes.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -18,3 +20,5 @@ 'get_explicit_kpoints_path', 'get_kpoints_path', ) + +# yapf: enable diff --git a/aiida/tools/data/array/kpoints/__init__.py b/aiida/tools/data/array/kpoints/__init__.py index 7916aad1d1..ac536c11a9 100644 --- a/aiida/tools/data/array/kpoints/__init__.py +++ b/aiida/tools/data/array/kpoints/__init__.py @@ -13,6 +13,7 @@ """ # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -22,3 +23,5 @@ 'get_explicit_kpoints_path', 'get_kpoints_path', ) + +# yapf: enable diff --git a/aiida/tools/data/orbital/__init__.py b/aiida/tools/data/orbital/__init__.py index ba172980e2..2ece04b528 100644 --- a/aiida/tools/data/orbital/__init__.py +++ b/aiida/tools/data/orbital/__init__.py @@ -10,6 +10,7 @@ """Module for classes and methods that represents molecular orbitals.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -20,3 +21,5 @@ 'Orbital', 'RealhydrogenOrbital', ) + +# yapf: enable diff --git a/aiida/tools/data/orbital/realhydrogen.py b/aiida/tools/data/orbital/realhydrogen.py index 14b601f0ba..c8116659d8 100644 --- a/aiida/tools/data/orbital/realhydrogen.py +++ b/aiida/tools/data/orbital/realhydrogen.py @@ -15,9 +15,9 @@ from .orbital import Orbital, validate_len3_list_or_none, validate_float_or_none - __all__ = ('RealhydrogenOrbital',) + def validate_l(value): """ Validate the value of the angular momentum diff --git a/aiida/tools/graph/__init__.py b/aiida/tools/graph/__init__.py index 36de7bcc94..95cffafca3 100644 --- a/aiida/tools/graph/__init__.py +++ b/aiida/tools/graph/__init__.py @@ -10,6 +10,7 @@ """Provides tools for traversing the provenance graph.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -20,3 +21,5 @@ 'delete_group_nodes', 'delete_nodes', ) + +# yapf: enable diff --git a/aiida/tools/groups/__init__.py b/aiida/tools/groups/__init__.py index b7ad08bbcb..ab74c839aa 100644 --- a/aiida/tools/groups/__init__.py +++ b/aiida/tools/groups/__init__.py @@ -16,6 +16,7 @@ """Provides tools for interacting with AiiDA Groups.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -28,3 +29,5 @@ 'InvalidPath', 'NoGroupsInPathError', ) + +# yapf: enable diff --git a/aiida/tools/importexport/__init__.py b/aiida/tools/importexport/__init__.py index 90a2d9e431..0d545768a0 100644 --- a/aiida/tools/importexport/__init__.py +++ b/aiida/tools/importexport/__init__.py @@ -16,6 +16,7 @@ """ # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -66,3 +67,5 @@ 'import_data', 'null_callback', ) + +# yapf: enable diff --git a/aiida/tools/importexport/archive/__init__.py b/aiida/tools/importexport/archive/__init__.py index 51dd0c35d6..b2dd149d7c 100644 --- a/aiida/tools/importexport/archive/__init__.py +++ b/aiida/tools/importexport/archive/__init__.py @@ -11,6 +11,7 @@ """Readers and writers for archive formats, that work independently of a connection to an AiiDA profile.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -43,3 +44,5 @@ 'get_writer', 'null_callback', ) + +# yapf: enable diff --git a/aiida/tools/importexport/common/__init__.py b/aiida/tools/importexport/common/__init__.py index df4813add8..f2755eade4 100644 --- a/aiida/tools/importexport/common/__init__.py +++ b/aiida/tools/importexport/common/__init__.py @@ -10,6 +10,7 @@ """Common utility functions, classes, and exceptions""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -31,3 +32,5 @@ 'MigrationValidationError', 'ProgressBarError', ) + +# yapf: enable diff --git a/aiida/tools/importexport/dbexport/__init__.py b/aiida/tools/importexport/dbexport/__init__.py index b2b13f57bf..22173490f7 100644 --- a/aiida/tools/importexport/dbexport/__init__.py +++ b/aiida/tools/importexport/dbexport/__init__.py @@ -10,6 +10,7 @@ """Provides export functionalities.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -20,3 +21,5 @@ 'ExportFileFormat', 'export', ) + +# yapf: enable diff --git a/aiida/tools/importexport/dbimport/__init__.py b/aiida/tools/importexport/dbimport/__init__.py index 1d3f712e24..ad987679f1 100644 --- a/aiida/tools/importexport/dbimport/__init__.py +++ b/aiida/tools/importexport/dbimport/__init__.py @@ -10,6 +10,7 @@ """Provides import functionalities.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -19,3 +20,5 @@ 'IMPORT_LOGGER', 'import_data', ) + +# yapf: enable diff --git a/aiida/tools/visualization/__init__.py b/aiida/tools/visualization/__init__.py index d198ea67f9..59532c1bbb 100644 --- a/aiida/tools/visualization/__init__.py +++ b/aiida/tools/visualization/__init__.py @@ -10,6 +10,7 @@ """Provides tools for visualization of the provenance graph.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -22,3 +23,5 @@ 'default_node_sublabels', 'pstate_node_styles', ) + +# yapf: enable diff --git a/aiida/transports/__init__.py b/aiida/transports/__init__.py index 0b9b968f18..c1b7e7e3ce 100644 --- a/aiida/transports/__init__.py +++ b/aiida/transports/__init__.py @@ -10,6 +10,7 @@ """Module for classes and utilities to define transports to other machines.""" # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -22,3 +23,5 @@ 'convert_to_bool', 'parse_sshconfig', ) + +# yapf: enable diff --git a/aiida/transports/plugins/__init__.py b/aiida/transports/plugins/__init__.py index efa2d9ce7f..e578131b6c 100644 --- a/aiida/transports/plugins/__init__.py +++ b/aiida/transports/plugins/__init__.py @@ -9,6 +9,7 @@ ########################################################################### # AUTO-GENERATED + # yapf: disable # pylint: disable=wildcard-import @@ -19,3 +20,5 @@ 'convert_to_bool', 'parse_sshconfig', ) + +# yapf: enable diff --git a/utils/make_all.py b/utils/make_all.py index c9f3b44514..c4d0657cc1 100644 --- a/utils/make_all.py +++ b/utils/make_all.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""Pre-commit hook to add ``__all__`` installs to ``__init__`` files.""" import ast from collections import Counter from pathlib import Path @@ -8,41 +11,41 @@ def parse_all(folder_path: str) -> Tuple[dict, dict]: """Walk through all files in folder, and parse the ``__all__`` variable. - + :return: (all dict, dict of unparsable) """ folder_path = Path(folder_path) all_dict = {} bad_all = {} - for path in folder_path.glob("**/*.py"): + for path in folder_path.glob('**/*.py'): # skip module files - if path.name == "__init__.py": + if path.name == '__init__.py': continue # parse the file - parsed = ast.parse(path.read_text(encoding="utf8")) + parsed = ast.parse(path.read_text(encoding='utf8')) # find __all__ assignment all_token = None for token in parsed.body: if not isinstance(token, ast.Assign): continue - if token.targets and getattr(token.targets[0], "id", "") == "__all__": + if token.targets and getattr(token.targets[0], 'id', '') == '__all__': all_token = token break if all_token is None: - bad_all.setdefault("missing", []).append(str(path.relative_to(folder_path))) + bad_all.setdefault('missing', []).append(str(path.relative_to(folder_path))) continue if not isinstance(all_token.value, (ast.List, ast.Tuple)): - bad_all.setdefault("vaue not list/tuple", []).append(str(path.relative_to(folder_path))) + bad_all.setdefault('vaue not list/tuple', []).append(str(path.relative_to(folder_path))) continue if not all(isinstance(el, ast.Str) for el in all_token.value.elts): - bad_all.setdefault("child not strings", []).append(str(path.relative_to(folder_path))) + bad_all.setdefault('child not strings', []).append(str(path.relative_to(folder_path))) continue names = [n.s for n in all_token.value.elts] @@ -52,91 +55,111 @@ def parse_all(folder_path: str) -> Tuple[dict, dict]: path_dict = all_dict for part in path.parent.relative_to(folder_path).parts: path_dict = path_dict.setdefault(part, {}) - path_dict.setdefault(path.name[:-3], {})["__all__"] = names + path_dict.setdefault(path.name[:-3], {})['__all__'] = names return all_dict, bad_all -def gather_all(all_dict: dict, all_list: Optional[List[str]] = None) -> List[str]: +def gather_all(cur_path: List[str], + cur_dict: dict, + skip_children: dict, + all_list: Optional[List[str]] = None) -> List[str]: """Recursively gather __all__ names.""" all_list = [] if all_list is None else all_list - for key, val in all_dict.items(): - if key == "__all__": + for key, val in cur_dict.items(): + if key == '__all__': all_list.extend(val) - else: - gather_all(val, all_list) + elif key not in skip_children.get("/".join(cur_path), []): + gather_all(cur_path + [key], val, skip_children, all_list) return all_list def write_inits(folder_path: str, all_dict: dict, skip_children: dict) -> dict: """Write __init__.py files for all subfolders. - + :return: folders with non-unique imports """ folder_path = Path(folder_path) non_unique = {} - for path in folder_path.glob("**/__init__.py"): + for path in folder_path.glob('**/__init__.py'): if path.parent == folder_path: # skip top level __init__.py continue + rel_path = path.parent.relative_to(folder_path).as_posix() + # get sub_dict for this folder path_all_dict = all_dict + mod_path = path.parent.relative_to(folder_path).parts try: - for part in path.parent.relative_to(folder_path).parts: + for part in mod_path: path_all_dict = path_all_dict[part] except KeyError: + # there is nothing to import continue + path_all_dict = {key: val for key, val in path_all_dict.items() if key not in skip_children.get(rel_path, [])} - alls = gather_all(path_all_dict) - if not alls: - continue + alls = gather_all(list(mod_path), path_all_dict, skip_children) + # check for non-unique imports - if len(alls) != len(set(alls)): - non_unique[rel_path] = [k for k, v in Counter(alls).items() if v > 1] - content = (["", "# AUTO-GENERATED"] + ["# yapf: disable", "# pylint: disable=wildcard-import", ""] + - [f"from .{mod} import *" for mod in sorted(path_all_dict.keys())] + ["", "__all__ = ("] + - [f" {a!r}," for a in sorted(set(alls))] + [")", ""]) - - new_content = [] - in_docstring = False - for line in path.read_text(encoding="utf8").splitlines(): + if len(alls + list(path_all_dict)) != len(set(alls + list(path_all_dict))): + non_unique[rel_path] = [k for k, v in Counter(alls + list(path_all_dict)).items() if v > 1] + + auto_content = (['', '# AUTO-GENERATED'] + ['', '# yapf: disable', '# pylint: disable=wildcard-import', ''] + + [f'from .{mod} import *' for mod in sorted(path_all_dict.keys())] + ['', '__all__ = ('] + + [f' {a!r},' for a in sorted(set(alls))] + [')', '', '# yapf: enable', '']) + + start_content = [] + end_content = [] + in_docstring = in_end_content = False + in_start_content = True + for line in path.read_text(encoding='utf8').splitlines(): + if not in_start_content and line.startswith('# END AUTO-GENERATED'): + in_end_content = True + if in_end_content: + end_content.append(line) + continue # only use initial comments and docstring - if not (line.startswith("#") or line.startswith('"""') or in_docstring): - break + if not (line.startswith('#') or line.startswith('"""') or in_docstring): + in_start_content = False + continue if line.startswith('"""'): if (not in_docstring) and not (line.endswith('"""') and not line.strip() == '"""'): in_docstring = True else: in_docstring = False - if not line.startswith("# pylint"): - new_content.append(line) + if in_start_content and not line.startswith('# pylint'): + start_content.append(line) - # could warn if overwriting any non-autogenerated content + new_content = start_content + auto_content + end_content - new_content.extend(content) + if not alls: + # there is nothing to import + continue - path.write_text("\n".join(new_content), encoding="utf8") + path.write_text('\n'.join(new_content).rstrip() + "\n", encoding='utf8') return non_unique -if __name__ == "__main__": - _folder = Path(__file__).parent.parent.joinpath("aiida") +if __name__ == '__main__': + _folder = Path(__file__).parent.parent.joinpath('aiida') _skip = { - "orm": "implementation", - "orm/implementation": ["django", "sqlalchemy", "sql"], - "cmdline": ["params"], - "cmdline/params": ["arguments", "options"] + 'cmdline/params': ['arguments', 'options'], + 'cmdline/utils': ['echo'], + 'manage': ['tests'], + 'orm': 'implementation', + 'orm/implementation': ['django', 'sqlalchemy', 'sql'], + 'restapi': ['run_api'], } _all_dict, _bad_all = parse_all(_folder) _non_unique = write_inits(_folder, _all_dict, _skip) - _bad_all.pop("missing", "") # allow missing __all__ + _bad_all.pop('missing', '') # allow missing __all__ if _bad_all: - print("unparsable __all__:") + print('unparsable __all__:') pprint(_bad_all) if _non_unique: - print("non-unique imports:") + print('non-unique imports:') pprint(_non_unique) if _bad_all or _non_unique: sys.exit(1) From 9d2e4d72d5fb0208fc2f80184396f97a854daad2 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 00:26:06 +0200 Subject: [PATCH 03/11] pre-commit fixes --- aiida/cmdline/params/__init__.py | 1 + aiida/manage/configuration/__init__.py | 2 ++ aiida/manage/database/__init__.py | 1 + aiida/manage/tests/__init__.py | 2 -- aiida/manage/tests/main.py | 1 - aiida/orm/implementation/django/__init__.py | 1 + aiida/orm/implementation/django/comments.py | 2 +- aiida/orm/implementation/sqlalchemy/__init__.py | 1 + aiida/orm/implementation/sqlalchemy/comments.py | 2 +- aiida/tools/data/__init__.py | 1 + aiida/transports/plugins/__init__.py | 1 + utils/make_all.py | 9 +++++---- 12 files changed, 15 insertions(+), 9 deletions(-) diff --git a/aiida/cmdline/params/__init__.py b/aiida/cmdline/params/__init__.py index 5dc9b918a3..c1329cef2a 100644 --- a/aiida/cmdline/params/__init__.py +++ b/aiida/cmdline/params/__init__.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Commandline parameters.""" # AUTO-GENERATED diff --git a/aiida/manage/configuration/__init__.py b/aiida/manage/configuration/__init__.py index 3f81aff446..841e0eba7e 100644 --- a/aiida/manage/configuration/__init__.py +++ b/aiida/manage/configuration/__init__.py @@ -39,6 +39,8 @@ # END AUTO-GENERATED +# pylint: disable=global-statement,redefined-outer-name,wrong-import-order + __all__ += ( 'get_config', 'get_config_option', 'get_config_path', 'get_profile', 'load_documentation_profile', 'load_profile', 'reset_config', 'reset_profile', 'CONFIG', 'PROFILE', 'BACKEND_UUID' diff --git a/aiida/manage/database/__init__.py b/aiida/manage/database/__init__.py index e6131da0b1..f1222d7e8b 100644 --- a/aiida/manage/database/__init__.py +++ b/aiida/manage/database/__init__.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Management of the database.""" # AUTO-GENERATED diff --git a/aiida/manage/tests/__init__.py b/aiida/manage/tests/__init__.py index f60a6c3c7f..813f38628c 100644 --- a/aiida/manage/tests/__init__.py +++ b/aiida/manage/tests/__init__.py @@ -9,7 +9,6 @@ ########################################################################### """ Testing infrastructure for easy testing of AiiDA plugins. - """ # AUTO-GENERATED @@ -27,7 +26,6 @@ 'TestManager', 'TestManagerError', 'TestRunner', - '_GLOBAL_TEST_MANAGER', 'get_test_backend_name', 'get_test_profile_name', 'get_user_dict', diff --git a/aiida/manage/tests/main.py b/aiida/manage/tests/main.py index c6a333670f..c3bc9ff19a 100644 --- a/aiida/manage/tests/main.py +++ b/aiida/manage/tests/main.py @@ -24,7 +24,6 @@ from aiida.manage.external.postgres import Postgres __all__ = ( - '_GLOBAL_TEST_MANAGER', 'get_test_profile_name', 'get_test_backend_name', 'get_user_dict', diff --git a/aiida/orm/implementation/django/__init__.py b/aiida/orm/implementation/django/__init__.py index 919c6a84ba..5089f32237 100644 --- a/aiida/orm/implementation/django/__init__.py +++ b/aiida/orm/implementation/django/__init__.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Implementation of Django backend.""" # AUTO-GENERATED diff --git a/aiida/orm/implementation/django/comments.py b/aiida/orm/implementation/django/comments.py index 1e6f2b0521..7b58c2a8fb 100644 --- a/aiida/orm/implementation/django/comments.py +++ b/aiida/orm/implementation/django/comments.py @@ -114,7 +114,7 @@ def create(self, node, user, content=None, **kwargs): :param content: the comment content :return: a Comment object associated to the given node and user """ - return DjangoComment(self.backend, node, user, content, **kwargs) + return DjangoComment(self.backend, node, user, content, **kwargs) # pylint: disable=abstract-class-instantiated def delete(self, comment_id): """ diff --git a/aiida/orm/implementation/sqlalchemy/__init__.py b/aiida/orm/implementation/sqlalchemy/__init__.py index f1ba8458df..82a9691ef1 100644 --- a/aiida/orm/implementation/sqlalchemy/__init__.py +++ b/aiida/orm/implementation/sqlalchemy/__init__.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Implementation of SQLAlchemy backend.""" # AUTO-GENERATED diff --git a/aiida/orm/implementation/sqlalchemy/comments.py b/aiida/orm/implementation/sqlalchemy/comments.py index d97df9aea7..d8173cd703 100644 --- a/aiida/orm/implementation/sqlalchemy/comments.py +++ b/aiida/orm/implementation/sqlalchemy/comments.py @@ -112,7 +112,7 @@ def create(self, node, user, content=None, **kwargs): :param content: the comment content :return: a Comment object associated to the given node and user """ - return SqlaComment(self.backend, node, user, content, **kwargs) + return SqlaComment(self.backend, node, user, content, **kwargs) # pylint: disable=abstract-class-instantiated def delete(self, comment_id): """ diff --git a/aiida/tools/data/__init__.py b/aiida/tools/data/__init__.py index fb200bb489..fdf843ae12 100644 --- a/aiida/tools/data/__init__.py +++ b/aiida/tools/data/__init__.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tool for handling data.""" # AUTO-GENERATED diff --git a/aiida/transports/plugins/__init__.py b/aiida/transports/plugins/__init__.py index e578131b6c..f11aea3b4c 100644 --- a/aiida/transports/plugins/__init__.py +++ b/aiida/transports/plugins/__init__.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Plugins for the transport.""" # AUTO-GENERATED diff --git a/utils/make_all.py b/utils/make_all.py index c4d0657cc1..072484a065 100644 --- a/utils/make_all.py +++ b/utils/make_all.py @@ -1,5 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +# pylint: disable=simplifiable-if-statement,too-many-branches """Pre-commit hook to add ``__all__`` installs to ``__init__`` files.""" import ast from collections import Counter @@ -69,7 +70,7 @@ def gather_all(cur_path: List[str], for key, val in cur_dict.items(): if key == '__all__': all_list.extend(val) - elif key not in skip_children.get("/".join(cur_path), []): + elif key not in skip_children.get('/'.join(cur_path), []): gather_all(cur_path + [key], val, skip_children, all_list) return all_list @@ -106,8 +107,8 @@ def write_inits(folder_path: str, all_dict: dict, skip_children: dict) -> dict: non_unique[rel_path] = [k for k, v in Counter(alls + list(path_all_dict)).items() if v > 1] auto_content = (['', '# AUTO-GENERATED'] + ['', '# yapf: disable', '# pylint: disable=wildcard-import', ''] + - [f'from .{mod} import *' for mod in sorted(path_all_dict.keys())] + ['', '__all__ = ('] + - [f' {a!r},' for a in sorted(set(alls))] + [')', '', '# yapf: enable', '']) + [f'from .{mod} import *' for mod in sorted(path_all_dict.keys())] + ['', '__all__ = ('] + + [f' {a!r},' for a in sorted(set(alls))] + [')', '', '# yapf: enable', '']) start_content = [] end_content = [] @@ -137,7 +138,7 @@ def write_inits(folder_path: str, all_dict: dict, skip_children: dict) -> dict: # there is nothing to import continue - path.write_text('\n'.join(new_content).rstrip() + "\n", encoding='utf8') + path.write_text('\n'.join(new_content).rstrip() + '\n', encoding='utf8') return non_unique From d5f41419901fb57fed5c0d81a92dcfa2c1825b71 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 01:09:21 +0200 Subject: [PATCH 04/11] add `load_entity` to __all__ --- aiida/orm/__init__.py | 1 + aiida/orm/utils/__init__.py | 1 + aiida/orm/utils/load_funcs.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/aiida/orm/__init__.py b/aiida/orm/__init__.py index e09549c7c8..a84f87017f 100644 --- a/aiida/orm/__init__.py +++ b/aiida/orm/__init__.py @@ -100,6 +100,7 @@ 'has_pycifrw', 'load_code', 'load_computer', + 'load_entity', 'load_group', 'load_node', 'load_node_class', diff --git a/aiida/orm/utils/__init__.py b/aiida/orm/utils/__init__.py index 7b049553cb..1d91cee652 100644 --- a/aiida/orm/utils/__init__.py +++ b/aiida/orm/utils/__init__.py @@ -40,6 +40,7 @@ 'get_type_string_from_class', 'load_code', 'load_computer', + 'load_entity', 'load_group', 'load_node', 'load_node_class', diff --git a/aiida/orm/utils/load_funcs.py b/aiida/orm/utils/load_funcs.py index f703884d0e..a1ce7d94cf 100644 --- a/aiida/orm/utils/load_funcs.py +++ b/aiida/orm/utils/load_funcs.py @@ -9,7 +9,7 @@ ########################################################################### """Utilities related to the ORM.""" -__all__ = ('load_code', 'load_computer', 'load_group', 'load_node') +__all__ = ('load_code', 'load_computer', 'load_group', 'load_node', 'load_entity') def load_entity( From 1462dff1e12df67b813cc36ff86d3b531d795271 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 01:14:54 +0200 Subject: [PATCH 05/11] Update __init__.py --- aiida/restapi/__init__.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/aiida/restapi/__init__.py b/aiida/restapi/__init__.py index 90d8726344..fc199853df 100644 --- a/aiida/restapi/__init__.py +++ b/aiida/restapi/__init__.py @@ -12,17 +12,3 @@ AiiDA nodes stored in database. The REST API is implemented using Flask RESTFul framework. """ - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .run_api import * - -__all__ = ( - 'configure_api', - 'run_api', -) - -# yapf: enable From 0e4d6368308ecba057cb4421e8eed2591b28f441 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 01:39:58 +0200 Subject: [PATCH 06/11] allow `*` skip --- aiida/restapi/__init__.py | 4 ++++ utils/make_all.py | 42 +++++++++++++++++++++------------------ 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/aiida/restapi/__init__.py b/aiida/restapi/__init__.py index fc199853df..5cdd575a4a 100644 --- a/aiida/restapi/__init__.py +++ b/aiida/restapi/__init__.py @@ -12,3 +12,7 @@ AiiDA nodes stored in database. The REST API is implemented using Flask RESTFul framework. """ + +# AUTO-GENERATED + +__all__ = () diff --git a/utils/make_all.py b/utils/make_all.py index 072484a065..8eaf351dc5 100644 --- a/utils/make_all.py +++ b/utils/make_all.py @@ -7,7 +7,7 @@ from pathlib import Path from pprint import pprint import sys -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple def parse_all(folder_path: str) -> Tuple[dict, dict]: @@ -42,7 +42,7 @@ def parse_all(folder_path: str) -> Tuple[dict, dict]: continue if not isinstance(all_token.value, (ast.List, ast.Tuple)): - bad_all.setdefault('vaue not list/tuple', []).append(str(path.relative_to(folder_path))) + bad_all.setdefault('value not list/tuple', []).append(str(path.relative_to(folder_path))) continue if not all(isinstance(el, ast.Str) for el in all_token.value.elts): @@ -75,7 +75,7 @@ def gather_all(cur_path: List[str], return all_list -def write_inits(folder_path: str, all_dict: dict, skip_children: dict) -> dict: +def write_inits(folder_path: str, all_dict: dict, skip_children: Dict[str, List[str]]) -> Dict[str, List[str]]: """Write __init__.py files for all subfolders. :return: folders with non-unique imports @@ -99,16 +99,24 @@ def write_inits(folder_path: str, all_dict: dict, skip_children: dict) -> dict: # there is nothing to import continue - path_all_dict = {key: val for key, val in path_all_dict.items() if key not in skip_children.get(rel_path, [])} - alls = gather_all(list(mod_path), path_all_dict, skip_children) - - # check for non-unique imports - if len(alls + list(path_all_dict)) != len(set(alls + list(path_all_dict))): - non_unique[rel_path] = [k for k, v in Counter(alls + list(path_all_dict)).items() if v > 1] - - auto_content = (['', '# AUTO-GENERATED'] + ['', '# yapf: disable', '# pylint: disable=wildcard-import', ''] + - [f'from .{mod} import *' for mod in sorted(path_all_dict.keys())] + ['', '__all__ = ('] + - [f' {a!r},' for a in sorted(set(alls))] + [')', '', '# yapf: enable', '']) + if '*' in skip_children.get(rel_path, []): + path_all_dict = {} + alls = [] + auto_content = ['', '# AUTO-GENERATED', '', '__all__ = ()', ''] + else: + path_all_dict = { + key: val for key, val in path_all_dict.items() if key not in skip_children.get(rel_path, []) + } + alls = gather_all(list(mod_path), path_all_dict, skip_children) + + # check for non-unique imports + if len(alls + list(path_all_dict)) != len(set(alls + list(path_all_dict))): + non_unique[rel_path] = [k for k, v in Counter(alls + list(path_all_dict)).items() if v > 1] + + auto_content = (['', '# AUTO-GENERATED'] + + ['', '# yapf: disable', '# pylint: disable=wildcard-import', ''] + + [f'from .{mod} import *' for mod in sorted(path_all_dict.keys())] + ['', '__all__ = ('] + + [f' {a!r},' for a in sorted(set(alls))] + [')', '', '# yapf: enable', '']) start_content = [] end_content = [] @@ -134,10 +142,6 @@ def write_inits(folder_path: str, all_dict: dict, skip_children: dict) -> dict: new_content = start_content + auto_content + end_content - if not alls: - # there is nothing to import - continue - path.write_text('\n'.join(new_content).rstrip() + '\n', encoding='utf8') return non_unique @@ -149,9 +153,9 @@ def write_inits(folder_path: str, all_dict: dict, skip_children: dict) -> dict: 'cmdline/params': ['arguments', 'options'], 'cmdline/utils': ['echo'], 'manage': ['tests'], - 'orm': 'implementation', + 'orm': ['implementation'], 'orm/implementation': ['django', 'sqlalchemy', 'sql'], - 'restapi': ['run_api'], + 'restapi': ['*'], } _all_dict, _bad_all = parse_all(_folder) _non_unique = write_inits(_folder, _all_dict, _skip) From 1a55259f5e46445e245a5b692b2cb37093980d02 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 01:59:59 +0200 Subject: [PATCH 07/11] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20UPDATE:=20mypy=20v0.?= =?UTF-8?q?910?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.json b/setup.json index 319f15d843..b4ce76146b 100644 --- a/setup.json +++ b/setup.json @@ -95,7 +95,7 @@ ], "pre-commit": [ "astroid<2.5", - "mypy==0.790", + "mypy==0.910", "packaging==20.3", "pre-commit~=2.2", "pylint~=2.5.0", From 797b7f64a6d144d57ccd842cb8217736ec9227dd Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 11:46:10 +0200 Subject: [PATCH 08/11] Apply suggestions from code review Co-authored-by: Leopold Talirz --- utils/make_all.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/utils/make_all.py b/utils/make_all.py index 8eaf351dc5..c54a330171 100644 --- a/utils/make_all.py +++ b/utils/make_all.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # pylint: disable=simplifiable-if-statement,too-many-branches -"""Pre-commit hook to add ``__all__`` installs to ``__init__`` files.""" +"""Pre-commit hook to add ``__all__`` imports to ``__init__`` files.""" import ast from collections import Counter from pathlib import Path @@ -150,11 +150,17 @@ def write_inits(folder_path: str, all_dict: dict, skip_children: Dict[str, List[ if __name__ == '__main__': _folder = Path(__file__).parent.parent.joinpath('aiida') _skip = { + # skipped since some arguments and options share the same name 'cmdline/params': ['arguments', 'options'], + # skipped since the module and its method share the same name 'cmdline/utils': ['echo'], + # skipped since this is for testing only not general use 'manage': ['tests'], + # skipped since we don't want to expose the implmentation 'orm': ['implementation'], + # skipped since both implementations share class/function names 'orm/implementation': ['django', 'sqlalchemy', 'sql'], + # skip all since the module requires extra requirements 'restapi': ['*'], } _all_dict, _bad_all = parse_all(_folder) From fce7f62c2b4426c1db06c0005a14519f5de70925 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 13:44:22 +0200 Subject: [PATCH 09/11] Update aiida/common/log.py --- aiida/common/log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiida/common/log.py b/aiida/common/log.py index 98baa6a9ab..0cc83df08f 100644 --- a/aiida/common/log.py +++ b/aiida/common/log.py @@ -159,7 +159,7 @@ def configure_logging(with_orm=False, daemon=False, daemon_log_file=None): from aiida.manage.configuration import get_config_option # Evaluate the `LOGGING` configuration to resolve the lambdas that will retrieve the correct values based on the - # currently configured profile. Pass a deep copy of `LOGGING` to ensure that the original remains unaltered. + # currently configured profile. config = evaluate_logging_configuration(get_logging_config()) daemon_handler_name = 'daemon_log_file' From 70868bf00d56dc16edb96461997bc8e9cddd4a52 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 16:39:16 +0200 Subject: [PATCH 10/11] Move loaders --- aiida/orm/utils/__init__.py | 1 - aiida/orm/utils/load_funcs.py | 205 ---------------------------------- aiida/orm/utils/loaders.py | 191 ++++++++++++++++++++++++++++++- 3 files changed, 189 insertions(+), 208 deletions(-) delete mode 100644 aiida/orm/utils/load_funcs.py diff --git a/aiida/orm/utils/__init__.py b/aiida/orm/utils/__init__.py index 1d91cee652..16e7b146c1 100644 --- a/aiida/orm/utils/__init__.py +++ b/aiida/orm/utils/__init__.py @@ -16,7 +16,6 @@ from .calcjob import * from .links import * -from .load_funcs import * from .loaders import * from .managers import * from .node import * diff --git a/aiida/orm/utils/load_funcs.py b/aiida/orm/utils/load_funcs.py deleted file mode 100644 index a1ce7d94cf..0000000000 --- a/aiida/orm/utils/load_funcs.py +++ /dev/null @@ -1,205 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities related to the ORM.""" - -__all__ = ('load_code', 'load_computer', 'load_group', 'load_node', 'load_entity') - - -def load_entity( - entity_loader=None, identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True -): - # pylint: disable=too-many-arguments - """ - Load an entity instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Code - :param pk: pk of a Code - :param uuid: uuid of a Code, or the beginning of the uuid - :param label: label of a Code - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :returns: the Code instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Code is found - :raise aiida.common.MultipleObjectsError: if more than one Code was found - """ - from aiida.orm.utils.loaders import OrmEntityLoader, IdentifierType - - if entity_loader is None or not issubclass(entity_loader, OrmEntityLoader): - raise TypeError(f'entity_loader should be a sub class of {type(OrmEntityLoader)}') - - inputs_provided = [value is not None for value in (identifier, pk, uuid, label)].count(True) - - if inputs_provided == 0: - raise ValueError("one of the parameters 'identifier', pk', 'uuid' or 'label' has to be specified") - elif inputs_provided > 1: - raise ValueError("only one of parameters 'identifier', pk', 'uuid' or 'label' has to be specified") - - if pk is not None: - - if not isinstance(pk, int): - raise TypeError('a pk has to be an integer') - - identifier = pk - identifier_type = IdentifierType.ID - - elif uuid is not None: - - if not isinstance(uuid, str): - raise TypeError('uuid has to be a string type') - - identifier = uuid - identifier_type = IdentifierType.UUID - - elif label is not None: - - if not isinstance(label, str): - raise TypeError('label has to be a string type') - - identifier = label - identifier_type = IdentifierType.LABEL - else: - identifier = str(identifier) - identifier_type = None - - return entity_loader.load_entity( - identifier, identifier_type, sub_classes=sub_classes, query_with_dashes=query_with_dashes - ) - - -def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Code instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Code - :param pk: pk of a Code - :param uuid: uuid of a Code, or the beginning of the uuid - :param label: label of a Code - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Code instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Code is found - :raise aiida.common.MultipleObjectsError: if more than one Code was found - """ - from aiida.orm.utils.loaders import CodeEntityLoader - return load_entity( - CodeEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_computer(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Computer instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Computer - :param pk: pk of a Computer - :param uuid: uuid of a Computer, or the beginning of the uuid - :param label: label of a Computer - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Computer instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Computer is found - :raise aiida.common.MultipleObjectsError: if more than one Computer was found - """ - from aiida.orm.utils.loaders import ComputerEntityLoader - return load_entity( - ComputerEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Group instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Group - :param pk: pk of a Group - :param uuid: uuid of a Group, or the beginning of the uuid - :param label: label of a Group - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Group instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Group is found - :raise aiida.common.MultipleObjectsError: if more than one Group was found - """ - from aiida.orm.utils.loaders import GroupEntityLoader - return load_entity( - GroupEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown - simply pass it without a keyword and the loader will attempt to infer the type - - :param identifier: pk (integer) or uuid (string) - :param pk: pk of a node - :param uuid: uuid of a node, or the beginning of the uuid - :param label: label of a Node - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :returns: the node instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Node is found - :raise aiida.common.MultipleObjectsError: if more than one Node was found - """ - from aiida.orm.utils.loaders import NodeEntityLoader - return load_entity( - NodeEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) diff --git a/aiida/orm/utils/loaders.py b/aiida/orm/utils/loaders.py index 0773a7a48c..66d59ad187 100644 --- a/aiida/orm/utils/loaders.py +++ b/aiida/orm/utils/loaders.py @@ -16,11 +16,198 @@ from aiida.orm.querybuilder import QueryBuilder __all__ = ( - 'get_loader', 'OrmEntityLoader', 'CalculationEntityLoader', 'CodeEntityLoader', 'ComputerEntityLoader', - 'GroupEntityLoader', 'NodeEntityLoader' + 'load_code', 'load_computer', 'load_group', 'load_node', 'load_entity', 'get_loader', 'OrmEntityLoader', + 'CalculationEntityLoader', 'CodeEntityLoader', 'ComputerEntityLoader', 'GroupEntityLoader', 'NodeEntityLoader' ) +def load_entity( + entity_loader=None, identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True +): + # pylint: disable=too-many-arguments + """ + Load an entity instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Code + :param pk: pk of a Code + :param uuid: uuid of a Code, or the beginning of the uuid + :param label: label of a Code + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :returns: the Code instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Code is found + :raise aiida.common.MultipleObjectsError: if more than one Code was found + """ + if entity_loader is None or not issubclass(entity_loader, OrmEntityLoader): + raise TypeError(f'entity_loader should be a sub class of {type(OrmEntityLoader)}') + + inputs_provided = [value is not None for value in (identifier, pk, uuid, label)].count(True) + + if inputs_provided == 0: + raise ValueError("one of the parameters 'identifier', pk', 'uuid' or 'label' has to be specified") + elif inputs_provided > 1: + raise ValueError("only one of parameters 'identifier', pk', 'uuid' or 'label' has to be specified") + + if pk is not None: + + if not isinstance(pk, int): + raise TypeError('a pk has to be an integer') + + identifier = pk + identifier_type = IdentifierType.ID + + elif uuid is not None: + + if not isinstance(uuid, str): + raise TypeError('uuid has to be a string type') + + identifier = uuid + identifier_type = IdentifierType.UUID + + elif label is not None: + + if not isinstance(label, str): + raise TypeError('label has to be a string type') + + identifier = label + identifier_type = IdentifierType.LABEL + else: + identifier = str(identifier) + identifier_type = None + + return entity_loader.load_entity( + identifier, identifier_type, sub_classes=sub_classes, query_with_dashes=query_with_dashes + ) + + +def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a Code instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Code + :param pk: pk of a Code + :param uuid: uuid of a Code, or the beginning of the uuid + :param label: label of a Code + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Code instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Code is found + :raise aiida.common.MultipleObjectsError: if more than one Code was found + """ + return load_entity( + CodeEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_computer(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a Computer instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Computer + :param pk: pk of a Computer + :param uuid: uuid of a Computer, or the beginning of the uuid + :param label: label of a Computer + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Computer instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Computer is found + :raise aiida.common.MultipleObjectsError: if more than one Computer was found + """ + return load_entity( + ComputerEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a Group instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Group + :param pk: pk of a Group + :param uuid: uuid of a Group, or the beginning of the uuid + :param label: label of a Group + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Group instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Group is found + :raise aiida.common.MultipleObjectsError: if more than one Group was found + """ + return load_entity( + GroupEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown + simply pass it without a keyword and the loader will attempt to infer the type + + :param identifier: pk (integer) or uuid (string) + :param pk: pk of a node + :param uuid: uuid of a node, or the beginning of the uuid + :param label: label of a Node + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :returns: the node instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Node is found + :raise aiida.common.MultipleObjectsError: if more than one Node was found + """ + return load_entity( + NodeEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + def get_loader(orm_class): """Return the correct OrmEntityLoader for the given orm class. From 7e2dc6a68f0a49ae201af7a18af8d72482320ca5 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 16:40:41 +0200 Subject: [PATCH 11/11] remove VALID_DICT_FORMATS_MAPPING --- aiida/cmdline/utils/echo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiida/cmdline/utils/echo.py b/aiida/cmdline/utils/echo.py index bb511baf74..248a01a2db 100644 --- a/aiida/cmdline/utils/echo.py +++ b/aiida/cmdline/utils/echo.py @@ -18,7 +18,7 @@ __all__ = ( 'echo', 'echo_info', 'echo_success', 'echo_warning', 'echo_error', 'echo_critical', 'echo_highlight', - 'echo_dictionary', 'VALID_DICT_FORMATS_MAPPING' + 'echo_dictionary' )