diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c00c315ab5..f54f889bae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,15 @@ repos: hooks: + - id: imports + name: imports + entry: python utils/make_all.py + language: python + types: [python] + require_serial: true + pass_filenames: false + files: aiida/.*py + - id: mypy name: mypy entry: mypy diff --git a/aiida/cmdline/__init__.py b/aiida/cmdline/__init__.py index 34a245187e..b997337142 100644 --- a/aiida/cmdline/__init__.py +++ b/aiida/cmdline/__init__.py @@ -7,16 +7,47 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """The command line interface of AiiDA.""" -from .params.arguments import * -from .params.options import * -from .params.types import * -from .utils.decorators import * -from .utils.echo import * +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .params import * +from .utils import * __all__ = ( - params.arguments.__all__ + params.options.__all__ + params.types.__all__ + utils.decorators.__all__ + - utils.echo.__all__ + 'AbsolutePathParamType', + 'CalculationParamType', + 'CodeParamType', + 'ComputerParamType', + 'ConfigOptionParamType', + 'DataParamType', + 'EmailType', + 'EntryPointType', + 'FileOrUrl', + 'GroupParamType', + 'HostnameType', + 'IdentifierParamType', + 'LabelStringType', + 'LazyChoice', + 'MpirunCommandParamType', + 'MultipleValueParamType', + 'NodeParamType', + 'NonEmptyStringParamType', + 'PathOrUrl', + 'PluginParamType', + 'ProcessParamType', + 'ProfileParamType', + 'ShebangParamType', + 'TestModuleParamType', + 'UserParamType', + 'WorkflowParamType', + 'dbenv', + 'format_call_graph', + 'only_if_daemon_running', + 'with_dbenv', ) + +# yapf: enable diff --git a/aiida/cmdline/params/__init__.py b/aiida/cmdline/params/__init__.py index 2776a55f97..c1329cef2a 100644 --- a/aiida/cmdline/params/__init__.py +++ b/aiida/cmdline/params/__init__.py @@ -7,3 +7,42 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Commandline parameters.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .types import * + +__all__ = ( + 'AbsolutePathParamType', + 'CalculationParamType', + 'CodeParamType', + 'ComputerParamType', + 'ConfigOptionParamType', + 'DataParamType', + 'EmailType', + 'EntryPointType', + 'FileOrUrl', + 'GroupParamType', + 'HostnameType', + 'IdentifierParamType', + 'LabelStringType', + 'LazyChoice', + 'MpirunCommandParamType', + 'MultipleValueParamType', + 'NodeParamType', + 'NonEmptyStringParamType', + 'PathOrUrl', + 'PluginParamType', + 'ProcessParamType', + 'ProfileParamType', + 'ShebangParamType', + 'TestModuleParamType', + 'UserParamType', + 'WorkflowParamType', +) + +# yapf: enable diff --git a/aiida/cmdline/params/arguments/__init__.py b/aiida/cmdline/params/arguments/__init__.py index 71bb8c2544..1c8f5543b8 100644 --- a/aiida/cmdline/params/arguments/__init__.py +++ b/aiida/cmdline/params/arguments/__init__.py @@ -10,60 +10,37 @@ # yapf: disable """Module with pre-defined reusable commandline arguments that can be used as `click` decorators.""" -import click +# AUTO-GENERATED -from .. import types -from .overridable import OverridableArgument +# yapf: disable +# pylint: disable=wildcard-import + +from .main import * __all__ = ( - 'PROFILE', 'PROFILES', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'DATUM', 'DATA', - 'GROUP', 'GROUPS', 'NODE', 'NODES', 'PROCESS', 'PROCESSES', 'WORKFLOW', 'WORKFLOWS', 'INPUT_FILE', 'OUTPUT_FILE', - 'LABEL', 'USER', 'CONFIG_OPTION' + 'CALCULATION', + 'CALCULATIONS', + 'CODE', + 'CODES', + 'COMPUTER', + 'COMPUTERS', + 'CONFIG_OPTION', + 'DATA', + 'DATUM', + 'GROUP', + 'GROUPS', + 'INPUT_FILE', + 'LABEL', + 'NODE', + 'NODES', + 'OUTPUT_FILE', + 'PROCESS', + 'PROCESSES', + 'PROFILE', + 'PROFILES', + 'USER', + 'WORKFLOW', + 'WORKFLOWS', ) - -PROFILE = OverridableArgument('profile', type=types.ProfileParamType()) - -PROFILES = OverridableArgument('profiles', type=types.ProfileParamType(), nargs=-1) - -CALCULATION = OverridableArgument('calculation', type=types.CalculationParamType()) - -CALCULATIONS = OverridableArgument('calculations', nargs=-1, type=types.CalculationParamType()) - -CODE = OverridableArgument('code', type=types.CodeParamType()) - -CODES = OverridableArgument('codes', nargs=-1, type=types.CodeParamType()) - -COMPUTER = OverridableArgument('computer', type=types.ComputerParamType()) - -COMPUTERS = OverridableArgument('computers', nargs=-1, type=types.ComputerParamType()) - -DATUM = OverridableArgument('datum', type=types.DataParamType()) - -DATA = OverridableArgument('data', nargs=-1, type=types.DataParamType()) - -GROUP = OverridableArgument('group', type=types.GroupParamType()) - -GROUPS = OverridableArgument('groups', nargs=-1, type=types.GroupParamType()) - -NODE = OverridableArgument('node', type=types.NodeParamType()) - -NODES = OverridableArgument('nodes', nargs=-1, type=types.NodeParamType()) - -PROCESS = OverridableArgument('process', type=types.ProcessParamType()) - -PROCESSES = OverridableArgument('processes', nargs=-1, type=types.ProcessParamType()) - -WORKFLOW = OverridableArgument('workflow', type=types.WorkflowParamType()) - -WORKFLOWS = OverridableArgument('workflows', nargs=-1, type=types.WorkflowParamType()) - -INPUT_FILE = OverridableArgument('input_file', metavar='INPUT_FILE', type=click.Path(exists=True)) - -OUTPUT_FILE = OverridableArgument('output_file', metavar='OUTPUT_FILE', type=click.Path()) - -LABEL = OverridableArgument('label', type=click.STRING) - -USER = OverridableArgument('user', metavar='USER', type=types.UserParamType()) - -CONFIG_OPTION = OverridableArgument('option', type=types.ConfigOptionParamType()) +# yapf: enable diff --git a/aiida/cmdline/params/arguments/main.py b/aiida/cmdline/params/arguments/main.py new file mode 100644 index 0000000000..71bb8c2544 --- /dev/null +++ b/aiida/cmdline/params/arguments/main.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# yapf: disable +"""Module with pre-defined reusable commandline arguments that can be used as `click` decorators.""" + +import click + +from .. import types +from .overridable import OverridableArgument + +__all__ = ( + 'PROFILE', 'PROFILES', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'DATUM', 'DATA', + 'GROUP', 'GROUPS', 'NODE', 'NODES', 'PROCESS', 'PROCESSES', 'WORKFLOW', 'WORKFLOWS', 'INPUT_FILE', 'OUTPUT_FILE', + 'LABEL', 'USER', 'CONFIG_OPTION' +) + + +PROFILE = OverridableArgument('profile', type=types.ProfileParamType()) + +PROFILES = OverridableArgument('profiles', type=types.ProfileParamType(), nargs=-1) + +CALCULATION = OverridableArgument('calculation', type=types.CalculationParamType()) + +CALCULATIONS = OverridableArgument('calculations', nargs=-1, type=types.CalculationParamType()) + +CODE = OverridableArgument('code', type=types.CodeParamType()) + +CODES = OverridableArgument('codes', nargs=-1, type=types.CodeParamType()) + +COMPUTER = OverridableArgument('computer', type=types.ComputerParamType()) + +COMPUTERS = OverridableArgument('computers', nargs=-1, type=types.ComputerParamType()) + +DATUM = OverridableArgument('datum', type=types.DataParamType()) + +DATA = OverridableArgument('data', nargs=-1, type=types.DataParamType()) + +GROUP = OverridableArgument('group', type=types.GroupParamType()) + +GROUPS = OverridableArgument('groups', nargs=-1, type=types.GroupParamType()) + +NODE = OverridableArgument('node', type=types.NodeParamType()) + +NODES = OverridableArgument('nodes', nargs=-1, type=types.NodeParamType()) + +PROCESS = OverridableArgument('process', type=types.ProcessParamType()) + +PROCESSES = OverridableArgument('processes', nargs=-1, type=types.ProcessParamType()) + +WORKFLOW = OverridableArgument('workflow', type=types.WorkflowParamType()) + +WORKFLOWS = OverridableArgument('workflows', nargs=-1, type=types.WorkflowParamType()) + +INPUT_FILE = OverridableArgument('input_file', metavar='INPUT_FILE', type=click.Path(exists=True)) + +OUTPUT_FILE = OverridableArgument('output_file', metavar='OUTPUT_FILE', type=click.Path()) + +LABEL = OverridableArgument('label', type=click.STRING) + +USER = OverridableArgument('user', metavar='USER', type=types.UserParamType()) + +CONFIG_OPTION = OverridableArgument('option', type=types.ConfigOptionParamType()) diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py index accd78c65f..fdc54869f0 100644 --- a/aiida/cmdline/params/options/__init__.py +++ b/aiida/cmdline/params/options/__init__.py @@ -8,590 +8,111 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with pre-defined reusable commandline options that can be used as `click` decorators.""" -import click -from pgsu import DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module -from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA -from aiida.manage.external.rmq import BROKER_DEFAULTS -from ...utils import defaults, echo -from .. import types -from .multivalue import MultipleValueOption -from .overridable import OverridableOption -from .contextualdefault import ContextualDefaultOption -from .config import ConfigFileOption +# AUTO-GENERATED -__all__ = ( - 'graph_traversal_rules', 'PROFILE', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', - 'DATUM', 'DATA', 'GROUP', 'GROUPS', 'NODE', 'NODES', 'FORCE', 'SILENT', 'VISUALIZATION_FORMAT', 'INPUT_FORMAT', - 'EXPORT_FORMAT', 'ARCHIVE_FORMAT', 'NON_INTERACTIVE', 'DRY_RUN', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_LAST_NAME', - 'USER_INSTITUTION', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', 'DB_PORT', 'DB_USERNAME', 'DB_PASSWORD', 'DB_NAME', - 'REPOSITORY_PATH', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PREPEND_TEXT', 'APPEND_TEXT', 'LABEL', - 'DESCRIPTION', 'INPUT_PLUGIN', 'CALC_JOB_STATE', 'PROCESS_STATE', 'PROCESS_LABEL', 'TYPE_STRING', 'EXIT_STATUS', - 'FAILED', 'LIMIT', 'PROJECT', 'ORDER_BY', 'PAST_DAYS', 'OLDER_THAN', 'ALL', 'ALL_STATES', 'ALL_USERS', - 'GROUP_CLEAR', 'RAW', 'HOSTNAME', 'TRANSPORT', 'SCHEDULER', 'USER', 'PORT', 'FREQUENCY', 'VERBOSE', 'TIMEOUT', - 'FORMULA_MODE', 'TRAJECTORY_INDEX', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'DEBUG', 'PRINT_TRACEBACK' -) - -TRAVERSAL_RULE_HELP_STRING = { - 'call_calc_backward': 'CALL links to calculations backwards', - 'call_calc_forward': 'CALL links to calculations forwards', - 'call_work_backward': 'CALL links to workflows backwards', - 'call_work_forward': 'CALL links to workflows forwards', - 'input_calc_backward': 'INPUT links to calculations backwards', - 'input_calc_forward': 'INPUT links to calculations forwards', - 'input_work_backward': 'INPUT links to workflows backwards', - 'input_work_forward': 'INPUT links to workflows forwards', - 'return_backward': 'RETURN links backwards', - 'return_forward': 'RETURN links forwards', - 'create_backward': 'CREATE links backwards', - 'create_forward': 'CREATE links forwards', -} - - -def valid_process_states(): - """Return a list of valid values for the ProcessState enum.""" - from plumpy import ProcessState - return tuple(state.value for state in ProcessState) - - -def valid_calc_job_states(): - """Return a list of valid values for the CalcState enum.""" - from aiida.common.datastructures import CalcJobState - return tuple(state.value for state in CalcJobState) - - -def active_process_states(): - """Return a list of process states that are considered active.""" - from plumpy import ProcessState - return ([ - ProcessState.CREATED.value, - ProcessState.WAITING.value, - ProcessState.RUNNING.value, - ]) - - -def graph_traversal_rules(rules): - """Apply the graph traversal rule options to the command.""" - - def decorator(command): - """Only apply to traversal rules if they are toggleable.""" - for name, traversal_rule in sorted(rules.items(), reverse=True): - if traversal_rule.toggleable: - option_name = name.replace('_', '-') - option_label = '--{option_name}/--no-{option_name}'.format(option_name=option_name) - help_string = f'Whether to expand the node set by following {TRAVERSAL_RULE_HELP_STRING[name]}.' - click.option(option_label, default=traversal_rule.default, show_default=True, help=help_string)(command) - - return command - - return decorator - - -PROFILE = OverridableOption( - '-p', - '--profile', - 'profile', - type=types.ProfileParamType(), - default=defaults.get_default_profile, - help='Execute the command for this profile instead of the default profile.' -) - -CALCULATION = OverridableOption( - '-C', - '--calculation', - 'calculation', - type=types.CalculationParamType(), - help='A single calculation identified by its ID or UUID.' -) - -CALCULATIONS = OverridableOption( - '-C', - '--calculations', - 'calculations', - type=types.CalculationParamType(), - cls=MultipleValueOption, - help='One or multiple calculations identified by their ID or UUID.' -) - -CODE = OverridableOption( - '-X', '--code', 'code', type=types.CodeParamType(), help='A single code identified by its ID, UUID or label.' -) - -CODES = OverridableOption( - '-X', - '--codes', - 'codes', - type=types.CodeParamType(), - cls=MultipleValueOption, - help='One or multiple codes identified by their ID, UUID or label.' -) - -COMPUTER = OverridableOption( - '-Y', - '--computer', - 'computer', - type=types.ComputerParamType(), - help='A single computer identified by its ID, UUID or label.' -) - -COMPUTERS = OverridableOption( - '-Y', - '--computers', - 'computers', - type=types.ComputerParamType(), - cls=MultipleValueOption, - help='One or multiple computers identified by their ID, UUID or label.' -) - -DATUM = OverridableOption( - '-D', '--datum', 'datum', type=types.DataParamType(), help='A single datum identified by its ID, UUID or label.' -) - -DATA = OverridableOption( - '-D', - '--data', - 'data', - type=types.DataParamType(), - cls=MultipleValueOption, - help='One or multiple data identified by their ID, UUID or label.' -) - -GROUP = OverridableOption( - '-G', '--group', 'group', type=types.GroupParamType(), help='A single group identified by its ID, UUID or label.' -) - -GROUPS = OverridableOption( - '-G', - '--groups', - 'groups', - type=types.GroupParamType(), - cls=MultipleValueOption, - help='One or multiple groups identified by their ID, UUID or label.' -) - -NODE = OverridableOption( - '-N', '--node', 'node', type=types.NodeParamType(), help='A single node identified by its ID or UUID.' -) - -NODES = OverridableOption( - '-N', - '--nodes', - 'nodes', - type=types.NodeParamType(), - cls=MultipleValueOption, - help='One or multiple nodes identified by their ID or UUID.' -) - -FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.') +# yapf: disable +# pylint: disable=wildcard-import -SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.') - -VISUALIZATION_FORMAT = OverridableOption( - '-F', '--format', 'fmt', show_default=True, help='Format of the visualized output.' -) - -INPUT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the input file.') - -EXPORT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the exported file.') - -ARCHIVE_FORMAT = OverridableOption( - '-F', - '--archive-format', - type=click.Choice(['zip', 'zip-uncompressed', 'tar.gz']), - default='zip', - show_default=True, - help='The format of the archive file.' -) - -NON_INTERACTIVE = OverridableOption( - '-n', - '--non-interactive', - is_flag=True, - is_eager=True, - help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.' -) - -DRY_RUN = OverridableOption('-n', '--dry-run', is_flag=True, help='Perform a dry run.') - -USER_EMAIL = OverridableOption( - '--email', - 'email', - type=types.EmailType(), - help='Email address associated with the data you generate. The email address is exported along with the data, ' - 'when sharing it.' -) - -USER_FIRST_NAME = OverridableOption( - '--first-name', type=types.NonEmptyStringParamType(), help='First name of the user.' -) - -USER_LAST_NAME = OverridableOption('--last-name', type=types.NonEmptyStringParamType(), help='Last name of the user.') - -USER_INSTITUTION = OverridableOption( - '--institution', type=types.NonEmptyStringParamType(), help='Institution of the user.' -) - -DB_ENGINE = OverridableOption( - '--db-engine', - help='Engine to use to connect to the database.', - default='postgresql_psycopg2', - type=click.Choice(['postgresql_psycopg2']) -) - -DB_BACKEND = OverridableOption( - '--db-backend', - type=click.Choice([BACKEND_DJANGO, BACKEND_SQLA]), - default=BACKEND_DJANGO, - help='Database backend to use.' -) - -DB_HOST = OverridableOption( - '--db-host', - type=types.HostnameType(), - help='Database server host. Leave empty for "peer" authentication.', - default='localhost' -) - -DB_PORT = OverridableOption( - '--db-port', - type=click.INT, - help='Database server port.', - default=DEFAULT_DBINFO['port'], -) - -DB_USERNAME = OverridableOption( - '--db-username', type=types.NonEmptyStringParamType(), help='Name of the database user.' -) - -DB_PASSWORD = OverridableOption( - '--db-password', - type=click.STRING, - help='Password of the database user.', - hide_input=True, -) - -DB_NAME = OverridableOption('--db-name', type=types.NonEmptyStringParamType(), help='Database name.') - -BROKER_PROTOCOL = OverridableOption( - '--broker-protocol', - type=click.Choice(('amqp', 'amqps')), - default=BROKER_DEFAULTS.protocol, - show_default=True, - help='Protocol to use for the message broker.' -) - -BROKER_USERNAME = OverridableOption( - '--broker-username', - type=types.NonEmptyStringParamType(), - default=BROKER_DEFAULTS.username, - show_default=True, - help='Username to use for authentication with the message broker.' -) - -BROKER_PASSWORD = OverridableOption( - '--broker-password', - type=types.NonEmptyStringParamType(), - default=BROKER_DEFAULTS.password, - show_default=True, - help='Password to use for authentication with the message broker.', - hide_input=True, -) - -BROKER_HOST = OverridableOption( - '--broker-host', - type=types.HostnameType(), - default=BROKER_DEFAULTS.host, - show_default=True, - help='Hostname for the message broker.' -) - -BROKER_PORT = OverridableOption( - '--broker-port', - type=click.INT, - default=BROKER_DEFAULTS.port, - show_default=True, - help='Port for the message broker.', -) - -BROKER_VIRTUAL_HOST = OverridableOption( - '--broker-virtual-host', - type=click.types.StringParamType(), - default=BROKER_DEFAULTS.virtual_host, - show_default=True, - help='Name of the virtual host for the message broker without leading forward slash.' -) +from .config import * +from .contextualdefault import * +from .main import * +from .multivalue import * +from .overridable import * -REPOSITORY_PATH = OverridableOption( - '--repository', type=click.Path(file_okay=False), help='Absolute path to the file repository.' -) - -PROFILE_ONLY_CONFIG = OverridableOption( - '--only-config', is_flag=True, default=False, help='Only configure the user and skip creating the database.' -) - -PROFILE_SET_DEFAULT = OverridableOption( - '--set-default', is_flag=True, default=False, help='Set the profile as the new default.' -) - -PREPEND_TEXT = OverridableOption( - '--prepend-text', type=click.STRING, default='', help='Bash script to be executed before an action.' -) - -APPEND_TEXT = OverridableOption( - '--append-text', type=click.STRING, default='', help='Bash script to be executed after an action has completed.' -) - -LABEL = OverridableOption('-L', '--label', type=click.STRING, metavar='LABEL', help='Short name to be used as a label.') - -DESCRIPTION = OverridableOption( - '-D', - '--description', - type=click.STRING, - metavar='DESCRIPTION', - default='', - required=False, - help='A detailed description.' -) - -INPUT_PLUGIN = OverridableOption( - '-P', '--input-plugin', type=types.PluginParamType(group='calculations'), help='Calculation input plugin string.' -) - -CALC_JOB_STATE = OverridableOption( - '-s', - '--calc-job-state', - 'calc_job_state', - type=types.LazyChoice(valid_calc_job_states), - cls=MultipleValueOption, - help='Only include entries with this calculation job state.' -) - -PROCESS_STATE = OverridableOption( - '-S', - '--process-state', - 'process_state', - type=types.LazyChoice(valid_process_states), - cls=MultipleValueOption, - default=active_process_states, - help='Only include entries with this process state.' -) - -PAUSED = OverridableOption('--paused', 'paused', is_flag=True, help='Only include entries that are paused.') - -PROCESS_LABEL = OverridableOption( - '-L', - '--process-label', - 'process_label', - type=click.STRING, - required=False, - help='Only include entries whose process label matches this filter.' -) - -TYPE_STRING = OverridableOption( - '-T', - '--type-string', - 'type_string', - type=click.STRING, - required=False, - help='Only include entries whose type string matches this filter. Can include `_` to match a single arbitrary ' - 'character or `%` to match any number of characters.' -) - -EXIT_STATUS = OverridableOption( - '-E', '--exit-status', 'exit_status', type=click.INT, help='Only include entries with this exit status.' -) - -FAILED = OverridableOption( - '-X', '--failed', 'failed', is_flag=True, default=False, help='Only include entries that have failed.' -) - -LIMIT = OverridableOption( - '-l', '--limit', 'limit', type=click.INT, default=None, help='Limit the number of entries to display.' -) - -PROJECT = OverridableOption( - '-P', '--project', 'project', cls=MultipleValueOption, help='Select the list of entity attributes to project.' -) - -ORDER_BY = OverridableOption( - '-O', - '--order-by', - 'order_by', - type=click.Choice(['id', 'ctime']), - default='ctime', - show_default=True, - help='Order the entries by this attribute.' -) - -ORDER_DIRECTION = OverridableOption( - '-D', - '--order-direction', - 'order_dir', - type=click.Choice(['asc', 'desc']), - default='asc', - show_default=True, - help='List the entries in ascending or descending order' -) - -PAST_DAYS = OverridableOption( - '-p', - '--past-days', - 'past_days', - type=click.INT, - metavar='PAST_DAYS', - help='Only include entries created in the last PAST_DAYS number of days.' -) - -OLDER_THAN = OverridableOption( - '-o', - '--older-than', - 'older_than', - type=click.INT, - metavar='OLDER_THAN', - help='Only include entries created before OLDER_THAN days ago.' -) - -ALL = OverridableOption( - '-a', - '--all', - 'all_entries', - is_flag=True, - default=False, - help='Include all entries, disregarding all other filter options and flags.' -) - -ALL_STATES = OverridableOption('-A', '--all-states', is_flag=True, help='Do not limit to items in running state.') - -ALL_USERS = OverridableOption( - '-A', '--all-users', 'all_users', is_flag=True, default=False, help='Include all entries regardless of the owner.' -) - -GROUP_CLEAR = OverridableOption( - '-c', '--clear', is_flag=True, default=False, help='Remove all the nodes from the group.' -) - -RAW = OverridableOption( - '-r', - '--raw', - 'raw', - is_flag=True, - default=False, - help='Display only raw query results, without any headers or footers.' -) - -HOSTNAME = OverridableOption('-H', '--hostname', type=types.HostnameType(), help='Hostname.') - -TRANSPORT = OverridableOption( - '-T', - '--transport', - type=types.PluginParamType(group='transports'), - required=True, - help="A transport plugin (as listed in 'verdi plugin list aiida.transports')." -) - -SCHEDULER = OverridableOption( - '-S', - '--scheduler', - type=types.PluginParamType(group='schedulers'), - required=True, - help="A scheduler plugin (as listed in 'verdi plugin list aiida.schedulers')." -) - -USER = OverridableOption('-u', '--user', 'user', type=types.UserParamType(), help='Email address of the user.') - -PORT = OverridableOption('-P', '--port', 'port', type=click.INT, help='Port number.') - -FREQUENCY = OverridableOption('-F', '--frequency', 'frequency', type=click.INT) - -VERBOSE = OverridableOption('-v', '--verbose', is_flag=True, default=False, help='Be more verbose in printing output.') - -TIMEOUT = OverridableOption( - '-t', - '--timeout', - type=click.FLOAT, - default=5.0, - show_default=True, - help='Time in seconds to wait for a response before timing out.' -) - -WAIT = OverridableOption( - '--wait/--no-wait', - default=False, - help='Wait for the action to be completed otherwise return as soon as it is scheduled.' -) - -FORMULA_MODE = OverridableOption( - '-f', - '--formula-mode', - type=click.Choice(['hill', 'hill_compact', 'reduce', 'group', 'count', 'count_compact']), - default='hill', - help='Mode for printing the chemical formula.' -) - -TRAJECTORY_INDEX = OverridableOption( - '-i', - '--trajectory-index', - 'trajectory_index', - type=click.INT, - default=None, - help='Specific step of the Trajectory to select.' -) - -WITH_ELEMENTS = OverridableOption( - '-e', - '--with-elements', - 'elements', - type=click.STRING, - cls=MultipleValueOption, - default=None, - help='Only select objects containing these elements.' -) - -WITH_ELEMENTS_EXCLUSIVE = OverridableOption( - '-E', - '--with-elements-exclusive', - 'elements_exclusive', - type=click.STRING, - cls=MultipleValueOption, - default=None, - help='Only select objects containing only these and no other elements.' -) - -CONFIG_FILE = ConfigFileOption( - '--config', - type=types.FileOrUrl(), - help='Load option values from configuration file in yaml format (local path or URL).' -) - -IDENTIFIER = OverridableOption( - '-i', - '--identifier', - 'identifier', - help='The type of identifier used for specifying each node.', - default='pk', - type=click.Choice(['pk', 'uuid']) -) - -DICT_FORMAT = OverridableOption( - '-f', - '--format', - 'fmt', - type=click.Choice(list(echo.VALID_DICT_FORMATS_MAPPING.keys())), - default=list(echo.VALID_DICT_FORMATS_MAPPING.keys())[0], - help='The format of the output data.' -) - -DICT_KEYS = OverridableOption( - '-k', '--keys', type=click.STRING, cls=MultipleValueOption, help='Filter the output by one or more keys.' -) - -DEBUG = OverridableOption( - '--debug', is_flag=True, default=False, help='Show debug messages. Mostly relevant for developers.', hidden=True -) - -PRINT_TRACEBACK = OverridableOption( - '-t', - '--print-traceback', - is_flag=True, - help='Print the full traceback in case an exception is raised.', -) +__all__ = ( + 'ALL', + 'ALL_STATES', + 'ALL_USERS', + 'APPEND_TEXT', + 'ARCHIVE_FORMAT', + 'BROKER_HOST', + 'BROKER_PASSWORD', + 'BROKER_PORT', + 'BROKER_PROTOCOL', + 'BROKER_USERNAME', + 'BROKER_VIRTUAL_HOST', + 'CALCULATION', + 'CALCULATIONS', + 'CALC_JOB_STATE', + 'CODE', + 'CODES', + 'COMPUTER', + 'COMPUTERS', + 'CONFIG_FILE', + 'ConfigFileOption', + 'ContextualDefaultOption', + 'DATA', + 'DATUM', + 'DB_BACKEND', + 'DB_ENGINE', + 'DB_HOST', + 'DB_NAME', + 'DB_PASSWORD', + 'DB_PORT', + 'DB_USERNAME', + 'DEBUG', + 'DESCRIPTION', + 'DICT_FORMAT', + 'DICT_KEYS', + 'DRY_RUN', + 'EXIT_STATUS', + 'EXPORT_FORMAT', + 'FAILED', + 'FORCE', + 'FORMULA_MODE', + 'FREQUENCY', + 'GROUP', + 'GROUPS', + 'GROUP_CLEAR', + 'HOSTNAME', + 'IDENTIFIER', + 'INPUT_FORMAT', + 'INPUT_PLUGIN', + 'LABEL', + 'LIMIT', + 'MultipleValueOption', + 'NODE', + 'NODES', + 'NON_INTERACTIVE', + 'OLDER_THAN', + 'ORDER_BY', + 'ORDER_DIRECTION', + 'OverridableOption', + 'PAST_DAYS', + 'PAUSED', + 'PORT', + 'PREPEND_TEXT', + 'PRINT_TRACEBACK', + 'PROCESS_LABEL', + 'PROCESS_STATE', + 'PROFILE', + 'PROFILE_ONLY_CONFIG', + 'PROFILE_SET_DEFAULT', + 'PROJECT', + 'RAW', + 'REPOSITORY_PATH', + 'SCHEDULER', + 'SILENT', + 'TIMEOUT', + 'TRAJECTORY_INDEX', + 'TRANSPORT', + 'TRAVERSAL_RULE_HELP_STRING', + 'TYPE_STRING', + 'USER', + 'USER_EMAIL', + 'USER_FIRST_NAME', + 'USER_INSTITUTION', + 'USER_LAST_NAME', + 'VERBOSE', + 'VISUALIZATION_FORMAT', + 'WAIT', + 'WITH_ELEMENTS', + 'WITH_ELEMENTS_EXCLUSIVE', + 'active_process_states', + 'graph_traversal_rules', + 'valid_calc_job_states', + 'valid_process_states', +) + +# yapf: enable diff --git a/aiida/cmdline/params/options/config.py b/aiida/cmdline/params/options/config.py index 9ab5d82278..a4c7b61dbc 100644 --- a/aiida/cmdline/params/options/config.py +++ b/aiida/cmdline/params/options/config.py @@ -17,6 +17,8 @@ from .overridable import OverridableOption +__all__ = ('ConfigFileOption',) + def yaml_config_file_provider(handle, cmd_name): # pylint: disable=unused-argument """Read yaml config file from file handle.""" diff --git a/aiida/cmdline/params/options/contextualdefault.py b/aiida/cmdline/params/options/contextualdefault.py index 1642b45127..a851371363 100644 --- a/aiida/cmdline/params/options/contextualdefault.py +++ b/aiida/cmdline/params/options/contextualdefault.py @@ -15,6 +15,8 @@ import click +__all__ = ('ContextualDefaultOption',) + class ContextualDefaultOption(click.Option): """A class that extends click.Option allowing to define a default callable diff --git a/aiida/cmdline/params/options/main.py b/aiida/cmdline/params/options/main.py new file mode 100644 index 0000000000..6df9e34cdb --- /dev/null +++ b/aiida/cmdline/params/options/main.py @@ -0,0 +1,599 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Module with pre-defined reusable commandline options that can be used as `click` decorators.""" +import click +from pgsu import DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module + +from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA +from aiida.manage.external.rmq import BROKER_DEFAULTS +from ...utils import defaults, echo +from .. import types +from .multivalue import MultipleValueOption +from .overridable import OverridableOption +from .config import ConfigFileOption + +__all__ = ( + 'ALL', 'ALL_STATES', 'ALL_USERS', 'APPEND_TEXT', 'ARCHIVE_FORMAT', 'BROKER_HOST', 'BROKER_PASSWORD', 'BROKER_PORT', + 'BROKER_PROTOCOL', 'BROKER_USERNAME', 'BROKER_VIRTUAL_HOST', 'CALCULATION', 'CALCULATIONS', 'CALC_JOB_STATE', + 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'CONFIG_FILE', 'DATA', 'DATUM', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', + 'DB_NAME', 'DB_PASSWORD', 'DB_PORT', 'DB_USERNAME', 'DEBUG', 'DESCRIPTION', 'DICT_FORMAT', 'DICT_KEYS', 'DRY_RUN', + 'EXIT_STATUS', 'EXPORT_FORMAT', 'FAILED', 'FORCE', 'FORMULA_MODE', 'FREQUENCY', 'GROUP', 'GROUPS', 'GROUP_CLEAR', + 'HOSTNAME', 'IDENTIFIER', 'INPUT_FORMAT', 'INPUT_PLUGIN', 'LABEL', 'LIMIT', 'NODE', 'NODES', 'NON_INTERACTIVE', + 'OLDER_THAN', 'ORDER_BY', 'ORDER_DIRECTION', 'PAST_DAYS', 'PAUSED', 'PORT', 'PREPEND_TEXT', 'PRINT_TRACEBACK', + 'PROCESS_LABEL', 'PROCESS_STATE', 'PROFILE', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PROJECT', 'RAW', + 'REPOSITORY_PATH', 'SCHEDULER', 'SILENT', 'TIMEOUT', 'TRAJECTORY_INDEX', 'TRANSPORT', 'TRAVERSAL_RULE_HELP_STRING', + 'TYPE_STRING', 'USER', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_INSTITUTION', 'USER_LAST_NAME', 'VERBOSE', + 'VISUALIZATION_FORMAT', 'WAIT', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'active_process_states', + 'graph_traversal_rules', 'valid_calc_job_states', 'valid_process_states' +) + +TRAVERSAL_RULE_HELP_STRING = { + 'call_calc_backward': 'CALL links to calculations backwards', + 'call_calc_forward': 'CALL links to calculations forwards', + 'call_work_backward': 'CALL links to workflows backwards', + 'call_work_forward': 'CALL links to workflows forwards', + 'input_calc_backward': 'INPUT links to calculations backwards', + 'input_calc_forward': 'INPUT links to calculations forwards', + 'input_work_backward': 'INPUT links to workflows backwards', + 'input_work_forward': 'INPUT links to workflows forwards', + 'return_backward': 'RETURN links backwards', + 'return_forward': 'RETURN links forwards', + 'create_backward': 'CREATE links backwards', + 'create_forward': 'CREATE links forwards', +} + + +def valid_process_states(): + """Return a list of valid values for the ProcessState enum.""" + from plumpy import ProcessState + return tuple(state.value for state in ProcessState) + + +def valid_calc_job_states(): + """Return a list of valid values for the CalcState enum.""" + from aiida.common.datastructures import CalcJobState + return tuple(state.value for state in CalcJobState) + + +def active_process_states(): + """Return a list of process states that are considered active.""" + from plumpy import ProcessState + return ([ + ProcessState.CREATED.value, + ProcessState.WAITING.value, + ProcessState.RUNNING.value, + ]) + + +def graph_traversal_rules(rules): + """Apply the graph traversal rule options to the command.""" + + def decorator(command): + """Only apply to traversal rules if they are toggleable.""" + for name, traversal_rule in sorted(rules.items(), reverse=True): + if traversal_rule.toggleable: + option_name = name.replace('_', '-') + option_label = '--{option_name}/--no-{option_name}'.format(option_name=option_name) + help_string = f'Whether to expand the node set by following {TRAVERSAL_RULE_HELP_STRING[name]}.' + click.option(option_label, default=traversal_rule.default, show_default=True, help=help_string)(command) + + return command + + return decorator + + +PROFILE = OverridableOption( + '-p', + '--profile', + 'profile', + type=types.ProfileParamType(), + default=defaults.get_default_profile, + help='Execute the command for this profile instead of the default profile.' +) + +CALCULATION = OverridableOption( + '-C', + '--calculation', + 'calculation', + type=types.CalculationParamType(), + help='A single calculation identified by its ID or UUID.' +) + +CALCULATIONS = OverridableOption( + '-C', + '--calculations', + 'calculations', + type=types.CalculationParamType(), + cls=MultipleValueOption, + help='One or multiple calculations identified by their ID or UUID.' +) + +CODE = OverridableOption( + '-X', '--code', 'code', type=types.CodeParamType(), help='A single code identified by its ID, UUID or label.' +) + +CODES = OverridableOption( + '-X', + '--codes', + 'codes', + type=types.CodeParamType(), + cls=MultipleValueOption, + help='One or multiple codes identified by their ID, UUID or label.' +) + +COMPUTER = OverridableOption( + '-Y', + '--computer', + 'computer', + type=types.ComputerParamType(), + help='A single computer identified by its ID, UUID or label.' +) + +COMPUTERS = OverridableOption( + '-Y', + '--computers', + 'computers', + type=types.ComputerParamType(), + cls=MultipleValueOption, + help='One or multiple computers identified by their ID, UUID or label.' +) + +DATUM = OverridableOption( + '-D', '--datum', 'datum', type=types.DataParamType(), help='A single datum identified by its ID, UUID or label.' +) + +DATA = OverridableOption( + '-D', + '--data', + 'data', + type=types.DataParamType(), + cls=MultipleValueOption, + help='One or multiple data identified by their ID, UUID or label.' +) + +GROUP = OverridableOption( + '-G', '--group', 'group', type=types.GroupParamType(), help='A single group identified by its ID, UUID or label.' +) + +GROUPS = OverridableOption( + '-G', + '--groups', + 'groups', + type=types.GroupParamType(), + cls=MultipleValueOption, + help='One or multiple groups identified by their ID, UUID or label.' +) + +NODE = OverridableOption( + '-N', '--node', 'node', type=types.NodeParamType(), help='A single node identified by its ID or UUID.' +) + +NODES = OverridableOption( + '-N', + '--nodes', + 'nodes', + type=types.NodeParamType(), + cls=MultipleValueOption, + help='One or multiple nodes identified by their ID or UUID.' +) + +FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.') + +SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.') + +VISUALIZATION_FORMAT = OverridableOption( + '-F', '--format', 'fmt', show_default=True, help='Format of the visualized output.' +) + +INPUT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the input file.') + +EXPORT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the exported file.') + +ARCHIVE_FORMAT = OverridableOption( + '-F', + '--archive-format', + type=click.Choice(['zip', 'zip-uncompressed', 'tar.gz']), + default='zip', + show_default=True, + help='The format of the archive file.' +) + +NON_INTERACTIVE = OverridableOption( + '-n', + '--non-interactive', + is_flag=True, + is_eager=True, + help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.' +) + +DRY_RUN = OverridableOption('-n', '--dry-run', is_flag=True, help='Perform a dry run.') + +USER_EMAIL = OverridableOption( + '--email', + 'email', + type=types.EmailType(), + help='Email address associated with the data you generate. The email address is exported along with the data, ' + 'when sharing it.' +) + +USER_FIRST_NAME = OverridableOption( + '--first-name', type=types.NonEmptyStringParamType(), help='First name of the user.' +) + +USER_LAST_NAME = OverridableOption('--last-name', type=types.NonEmptyStringParamType(), help='Last name of the user.') + +USER_INSTITUTION = OverridableOption( + '--institution', type=types.NonEmptyStringParamType(), help='Institution of the user.' +) + +DB_ENGINE = OverridableOption( + '--db-engine', + help='Engine to use to connect to the database.', + default='postgresql_psycopg2', + type=click.Choice(['postgresql_psycopg2']) +) + +DB_BACKEND = OverridableOption( + '--db-backend', + type=click.Choice([BACKEND_DJANGO, BACKEND_SQLA]), + default=BACKEND_DJANGO, + help='Database backend to use.' +) + +DB_HOST = OverridableOption( + '--db-host', + type=types.HostnameType(), + help='Database server host. Leave empty for "peer" authentication.', + default='localhost' +) + +DB_PORT = OverridableOption( + '--db-port', + type=click.INT, + help='Database server port.', + default=DEFAULT_DBINFO['port'], +) + +DB_USERNAME = OverridableOption( + '--db-username', type=types.NonEmptyStringParamType(), help='Name of the database user.' +) + +DB_PASSWORD = OverridableOption( + '--db-password', + type=click.STRING, + help='Password of the database user.', + hide_input=True, +) + +DB_NAME = OverridableOption('--db-name', type=types.NonEmptyStringParamType(), help='Database name.') + +BROKER_PROTOCOL = OverridableOption( + '--broker-protocol', + type=click.Choice(('amqp', 'amqps')), + default=BROKER_DEFAULTS.protocol, + show_default=True, + help='Protocol to use for the message broker.' +) + +BROKER_USERNAME = OverridableOption( + '--broker-username', + type=types.NonEmptyStringParamType(), + default=BROKER_DEFAULTS.username, + show_default=True, + help='Username to use for authentication with the message broker.' +) + +BROKER_PASSWORD = OverridableOption( + '--broker-password', + type=types.NonEmptyStringParamType(), + default=BROKER_DEFAULTS.password, + show_default=True, + help='Password to use for authentication with the message broker.', + hide_input=True, +) + +BROKER_HOST = OverridableOption( + '--broker-host', + type=types.HostnameType(), + default=BROKER_DEFAULTS.host, + show_default=True, + help='Hostname for the message broker.' +) + +BROKER_PORT = OverridableOption( + '--broker-port', + type=click.INT, + default=BROKER_DEFAULTS.port, + show_default=True, + help='Port for the message broker.', +) + +BROKER_VIRTUAL_HOST = OverridableOption( + '--broker-virtual-host', + type=click.types.StringParamType(), + default=BROKER_DEFAULTS.virtual_host, + show_default=True, + help='Name of the virtual host for the message broker without leading forward slash.' +) + +REPOSITORY_PATH = OverridableOption( + '--repository', type=click.Path(file_okay=False), help='Absolute path to the file repository.' +) + +PROFILE_ONLY_CONFIG = OverridableOption( + '--only-config', is_flag=True, default=False, help='Only configure the user and skip creating the database.' +) + +PROFILE_SET_DEFAULT = OverridableOption( + '--set-default', is_flag=True, default=False, help='Set the profile as the new default.' +) + +PREPEND_TEXT = OverridableOption( + '--prepend-text', type=click.STRING, default='', help='Bash script to be executed before an action.' +) + +APPEND_TEXT = OverridableOption( + '--append-text', type=click.STRING, default='', help='Bash script to be executed after an action has completed.' +) + +LABEL = OverridableOption('-L', '--label', type=click.STRING, metavar='LABEL', help='Short name to be used as a label.') + +DESCRIPTION = OverridableOption( + '-D', + '--description', + type=click.STRING, + metavar='DESCRIPTION', + default='', + required=False, + help='A detailed description.' +) + +INPUT_PLUGIN = OverridableOption( + '-P', '--input-plugin', type=types.PluginParamType(group='calculations'), help='Calculation input plugin string.' +) + +CALC_JOB_STATE = OverridableOption( + '-s', + '--calc-job-state', + 'calc_job_state', + type=types.LazyChoice(valid_calc_job_states), + cls=MultipleValueOption, + help='Only include entries with this calculation job state.' +) + +PROCESS_STATE = OverridableOption( + '-S', + '--process-state', + 'process_state', + type=types.LazyChoice(valid_process_states), + cls=MultipleValueOption, + default=active_process_states, + help='Only include entries with this process state.' +) + +PAUSED = OverridableOption('--paused', 'paused', is_flag=True, help='Only include entries that are paused.') + +PROCESS_LABEL = OverridableOption( + '-L', + '--process-label', + 'process_label', + type=click.STRING, + required=False, + help='Only include entries whose process label matches this filter.' +) + +TYPE_STRING = OverridableOption( + '-T', + '--type-string', + 'type_string', + type=click.STRING, + required=False, + help='Only include entries whose type string matches this filter. Can include `_` to match a single arbitrary ' + 'character or `%` to match any number of characters.' +) + +EXIT_STATUS = OverridableOption( + '-E', '--exit-status', 'exit_status', type=click.INT, help='Only include entries with this exit status.' +) + +FAILED = OverridableOption( + '-X', '--failed', 'failed', is_flag=True, default=False, help='Only include entries that have failed.' +) + +LIMIT = OverridableOption( + '-l', '--limit', 'limit', type=click.INT, default=None, help='Limit the number of entries to display.' +) + +PROJECT = OverridableOption( + '-P', '--project', 'project', cls=MultipleValueOption, help='Select the list of entity attributes to project.' +) + +ORDER_BY = OverridableOption( + '-O', + '--order-by', + 'order_by', + type=click.Choice(['id', 'ctime']), + default='ctime', + show_default=True, + help='Order the entries by this attribute.' +) + +ORDER_DIRECTION = OverridableOption( + '-D', + '--order-direction', + 'order_dir', + type=click.Choice(['asc', 'desc']), + default='asc', + show_default=True, + help='List the entries in ascending or descending order' +) + +PAST_DAYS = OverridableOption( + '-p', + '--past-days', + 'past_days', + type=click.INT, + metavar='PAST_DAYS', + help='Only include entries created in the last PAST_DAYS number of days.' +) + +OLDER_THAN = OverridableOption( + '-o', + '--older-than', + 'older_than', + type=click.INT, + metavar='OLDER_THAN', + help='Only include entries created before OLDER_THAN days ago.' +) + +ALL = OverridableOption( + '-a', + '--all', + 'all_entries', + is_flag=True, + default=False, + help='Include all entries, disregarding all other filter options and flags.' +) + +ALL_STATES = OverridableOption('-A', '--all-states', is_flag=True, help='Do not limit to items in running state.') + +ALL_USERS = OverridableOption( + '-A', '--all-users', 'all_users', is_flag=True, default=False, help='Include all entries regardless of the owner.' +) + +GROUP_CLEAR = OverridableOption( + '-c', '--clear', is_flag=True, default=False, help='Remove all the nodes from the group.' +) + +RAW = OverridableOption( + '-r', + '--raw', + 'raw', + is_flag=True, + default=False, + help='Display only raw query results, without any headers or footers.' +) + +HOSTNAME = OverridableOption('-H', '--hostname', type=types.HostnameType(), help='Hostname.') + +TRANSPORT = OverridableOption( + '-T', + '--transport', + type=types.PluginParamType(group='transports'), + required=True, + help="A transport plugin (as listed in 'verdi plugin list aiida.transports')." +) + +SCHEDULER = OverridableOption( + '-S', + '--scheduler', + type=types.PluginParamType(group='schedulers'), + required=True, + help="A scheduler plugin (as listed in 'verdi plugin list aiida.schedulers')." +) + +USER = OverridableOption('-u', '--user', 'user', type=types.UserParamType(), help='Email address of the user.') + +PORT = OverridableOption('-P', '--port', 'port', type=click.INT, help='Port number.') + +FREQUENCY = OverridableOption('-F', '--frequency', 'frequency', type=click.INT) + +VERBOSE = OverridableOption('-v', '--verbose', is_flag=True, default=False, help='Be more verbose in printing output.') + +TIMEOUT = OverridableOption( + '-t', + '--timeout', + type=click.FLOAT, + default=5.0, + show_default=True, + help='Time in seconds to wait for a response before timing out.' +) + +WAIT = OverridableOption( + '--wait/--no-wait', + default=False, + help='Wait for the action to be completed otherwise return as soon as it is scheduled.' +) + +FORMULA_MODE = OverridableOption( + '-f', + '--formula-mode', + type=click.Choice(['hill', 'hill_compact', 'reduce', 'group', 'count', 'count_compact']), + default='hill', + help='Mode for printing the chemical formula.' +) + +TRAJECTORY_INDEX = OverridableOption( + '-i', + '--trajectory-index', + 'trajectory_index', + type=click.INT, + default=None, + help='Specific step of the Trajectory to select.' +) + +WITH_ELEMENTS = OverridableOption( + '-e', + '--with-elements', + 'elements', + type=click.STRING, + cls=MultipleValueOption, + default=None, + help='Only select objects containing these elements.' +) + +WITH_ELEMENTS_EXCLUSIVE = OverridableOption( + '-E', + '--with-elements-exclusive', + 'elements_exclusive', + type=click.STRING, + cls=MultipleValueOption, + default=None, + help='Only select objects containing only these and no other elements.' +) + +CONFIG_FILE = ConfigFileOption( + '--config', + type=types.FileOrUrl(), + help='Load option values from configuration file in yaml format (local path or URL).' +) + +IDENTIFIER = OverridableOption( + '-i', + '--identifier', + 'identifier', + help='The type of identifier used for specifying each node.', + default='pk', + type=click.Choice(['pk', 'uuid']) +) + +DICT_FORMAT = OverridableOption( + '-f', + '--format', + 'fmt', + type=click.Choice(list(echo.VALID_DICT_FORMATS_MAPPING.keys())), + default=list(echo.VALID_DICT_FORMATS_MAPPING.keys())[0], + help='The format of the output data.' +) + +DICT_KEYS = OverridableOption( + '-k', '--keys', type=click.STRING, cls=MultipleValueOption, help='Filter the output by one or more keys.' +) + +DEBUG = OverridableOption( + '--debug', is_flag=True, default=False, help='Show debug messages. Mostly relevant for developers.', hidden=True +) + +PRINT_TRACEBACK = OverridableOption( + '-t', + '--print-traceback', + is_flag=True, + help='Print the full traceback in case an exception is raised.', +) diff --git a/aiida/cmdline/params/options/multivalue.py b/aiida/cmdline/params/options/multivalue.py index 9b8fa9a3d2..652b982261 100644 --- a/aiida/cmdline/params/options/multivalue.py +++ b/aiida/cmdline/params/options/multivalue.py @@ -15,6 +15,8 @@ from .. import types +__all__ = ('MultipleValueOption',) + def collect_usage_pieces(self, ctx): """Returns all the pieces that go into the usage line and returns it as a list of strings.""" diff --git a/aiida/cmdline/params/options/overridable.py b/aiida/cmdline/params/options/overridable.py index a8f7a183d9..fae2ca0aff 100644 --- a/aiida/cmdline/params/options/overridable.py +++ b/aiida/cmdline/params/options/overridable.py @@ -16,6 +16,8 @@ import click +__all__ = ('OverridableOption',) + class OverridableOption: """ diff --git a/aiida/cmdline/params/types/__init__.py b/aiida/cmdline/params/types/__init__.py index cedb380572..0c3c24ec61 100644 --- a/aiida/cmdline/params/types/__init__.py +++ b/aiida/cmdline/params/types/__init__.py @@ -9,29 +9,57 @@ ########################################################################### """Provides all parameter types.""" -from .calculation import CalculationParamType -from .choice import LazyChoice -from .code import CodeParamType -from .computer import ComputerParamType, ShebangParamType, MpirunCommandParamType -from .config import ConfigOptionParamType -from .data import DataParamType -from .group import GroupParamType -from .identifier import IdentifierParamType -from .multiple import MultipleValueParamType -from .node import NodeParamType -from .process import ProcessParamType -from .strings import (NonEmptyStringParamType, EmailType, HostnameType, EntryPointType, LabelStringType) -from .path import AbsolutePathParamType, PathOrUrl, FileOrUrl -from .plugin import PluginParamType -from .profile import ProfileParamType -from .user import UserParamType -from .test_module import TestModuleParamType -from .workflow import WorkflowParamType +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .calculation import * +from .choice import * +from .code import * +from .computer import * +from .config import * +from .data import * +from .group import * +from .identifier import * +from .multiple import * +from .node import * +from .path import * +from .plugin import * +from .process import * +from .profile import * +from .strings import * +from .test_module import * +from .user import * +from .workflow import * __all__ = ( - 'LazyChoice', 'IdentifierParamType', 'CalculationParamType', 'CodeParamType', 'ComputerParamType', - 'ConfigOptionParamType', 'DataParamType', 'GroupParamType', 'NodeParamType', 'MpirunCommandParamType', - 'MultipleValueParamType', 'NonEmptyStringParamType', 'PluginParamType', 'AbsolutePathParamType', 'ShebangParamType', - 'UserParamType', 'TestModuleParamType', 'ProfileParamType', 'WorkflowParamType', 'ProcessParamType', 'PathOrUrl', - 'FileOrUrl' + 'AbsolutePathParamType', + 'CalculationParamType', + 'CodeParamType', + 'ComputerParamType', + 'ConfigOptionParamType', + 'DataParamType', + 'EmailType', + 'EntryPointType', + 'FileOrUrl', + 'GroupParamType', + 'HostnameType', + 'IdentifierParamType', + 'LabelStringType', + 'LazyChoice', + 'MpirunCommandParamType', + 'MultipleValueParamType', + 'NodeParamType', + 'NonEmptyStringParamType', + 'PathOrUrl', + 'PluginParamType', + 'ProcessParamType', + 'ProfileParamType', + 'ShebangParamType', + 'TestModuleParamType', + 'UserParamType', + 'WorkflowParamType', ) + +# yapf: enable diff --git a/aiida/cmdline/params/types/calculation.py b/aiida/cmdline/params/types/calculation.py index a9dd484b4f..2e4c0d0750 100644 --- a/aiida/cmdline/params/types/calculation.py +++ b/aiida/cmdline/params/types/calculation.py @@ -13,6 +13,8 @@ from .identifier import IdentifierParamType +__all__ = ('CalculationParamType',) + class CalculationParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/choice.py b/aiida/cmdline/params/types/choice.py index 92d5894eb3..2a6a2c2190 100644 --- a/aiida/cmdline/params/types/choice.py +++ b/aiida/cmdline/params/types/choice.py @@ -12,6 +12,8 @@ """ import click +__all__ = ('LazyChoice',) + class LazyChoice(click.ParamType): """ diff --git a/aiida/cmdline/params/types/code.py b/aiida/cmdline/params/types/code.py index da1c6753bc..3e7a2803a7 100644 --- a/aiida/cmdline/params/types/code.py +++ b/aiida/cmdline/params/types/code.py @@ -13,6 +13,8 @@ from aiida.cmdline.utils import decorators from .identifier import IdentifierParamType +__all__ = ('CodeParamType',) + class CodeParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/computer.py b/aiida/cmdline/params/types/computer.py index 8667363ef8..d70fc00d8f 100644 --- a/aiida/cmdline/params/types/computer.py +++ b/aiida/cmdline/params/types/computer.py @@ -16,6 +16,8 @@ from ...utils import decorators from .identifier import IdentifierParamType +__all__ = ('ComputerParamType', 'ShebangParamType', 'MpirunCommandParamType') + class ComputerParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/data.py b/aiida/cmdline/params/types/data.py index 02c896f4b7..742dec10eb 100644 --- a/aiida/cmdline/params/types/data.py +++ b/aiida/cmdline/params/types/data.py @@ -12,6 +12,8 @@ """ from .identifier import IdentifierParamType +__all__ = ('DataParamType',) + class DataParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/group.py b/aiida/cmdline/params/types/group.py index 0645ac6e65..01a31588d5 100644 --- a/aiida/cmdline/params/types/group.py +++ b/aiida/cmdline/params/types/group.py @@ -15,6 +15,8 @@ from .identifier import IdentifierParamType +__all__ = ('GroupParamType',) + class GroupParamType(IdentifierParamType): """The ParamType for identifying Group entities or its subclasses.""" diff --git a/aiida/cmdline/params/types/identifier.py b/aiida/cmdline/params/types/identifier.py index 513ee2a82b..058712a090 100644 --- a/aiida/cmdline/params/types/identifier.py +++ b/aiida/cmdline/params/types/identifier.py @@ -17,6 +17,8 @@ from aiida.cmdline.utils.decorators import with_dbenv from aiida.plugins.entry_point import get_entry_point_from_string +__all__ = ('IdentifierParamType',) + class IdentifierParamType(click.ParamType, ABC): """ diff --git a/aiida/cmdline/params/types/multiple.py b/aiida/cmdline/params/types/multiple.py index 733ce7dcd4..a5d4a9f5b5 100644 --- a/aiida/cmdline/params/types/multiple.py +++ b/aiida/cmdline/params/types/multiple.py @@ -12,6 +12,8 @@ """ import click +__all__ = ('MultipleValueParamType',) + class MultipleValueParamType(click.ParamType): """ diff --git a/aiida/cmdline/params/types/node.py b/aiida/cmdline/params/types/node.py index 568dbf50fd..7642eb22d5 100644 --- a/aiida/cmdline/params/types/node.py +++ b/aiida/cmdline/params/types/node.py @@ -12,6 +12,8 @@ """ from .identifier import IdentifierParamType +__all__ = ('NodeParamType',) + class NodeParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/path.py b/aiida/cmdline/params/types/path.py index daeb5ae115..7dc3a0d9f8 100644 --- a/aiida/cmdline/params/types/path.py +++ b/aiida/cmdline/params/types/path.py @@ -15,6 +15,8 @@ from socket import timeout import click +__all__ = ('AbsolutePathParamType', 'FileOrUrl', 'PathOrUrl') + URL_TIMEOUT_SECONDS = 10 diff --git a/aiida/cmdline/params/types/plugin.py b/aiida/cmdline/params/types/plugin.py index 387e5127a5..588a3b7877 100644 --- a/aiida/cmdline/params/types/plugin.py +++ b/aiida/cmdline/params/types/plugin.py @@ -16,7 +16,9 @@ from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, ENTRY_POINT_GROUP_PREFIX, EntryPointFormat from aiida.plugins.entry_point import format_entry_point_string, get_entry_point_string_format from aiida.plugins.entry_point import get_entry_point, get_entry_points, get_entry_point_groups -from ..types import EntryPointType +from .strings import EntryPointType + +__all__ = ('PluginParamType',) class PluginParamType(EntryPointType): @@ -143,7 +145,7 @@ def complete(self, ctx, incomplete): # pylint: disable=unused-argument """ return [(p, '') for p in self.get_possibilities(incomplete=incomplete)] - def get_missing_message(self, param): + def get_missing_message(self, param): # pylint: disable=unused-argument return 'Possible arguments are:\n\n' + '\n'.join(self.get_valid_arguments()) def get_entry_point_from_string(self, entry_point_string): diff --git a/aiida/cmdline/params/types/process.py b/aiida/cmdline/params/types/process.py index e18ef66bd7..0cbe5abf65 100644 --- a/aiida/cmdline/params/types/process.py +++ b/aiida/cmdline/params/types/process.py @@ -13,6 +13,8 @@ from .identifier import IdentifierParamType +__all__ = ('ProcessParamType',) + class ProcessParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/profile.py b/aiida/cmdline/params/types/profile.py index 61d5737f9c..3cf5449709 100644 --- a/aiida/cmdline/params/types/profile.py +++ b/aiida/cmdline/params/types/profile.py @@ -11,6 +11,8 @@ from .strings import LabelStringType +__all__ = ('ProfileParamType',) + class ProfileParamType(LabelStringType): """The profile parameter type for click.""" diff --git a/aiida/cmdline/params/types/strings.py b/aiida/cmdline/params/types/strings.py index 63abbcd599..9f6ebeb8be 100644 --- a/aiida/cmdline/params/types/strings.py +++ b/aiida/cmdline/params/types/strings.py @@ -14,6 +14,8 @@ import re from click.types import StringParamType +__all__ = ('EmailType', 'EntryPointType', 'HostnameType', 'NonEmptyStringParamType', 'LabelStringType') + class NonEmptyStringParamType(StringParamType): """Parameter whose values have to be string and non-empty.""" diff --git a/aiida/cmdline/params/types/test_module.py b/aiida/cmdline/params/types/test_module.py index d13c80633c..d47dbbef94 100644 --- a/aiida/cmdline/params/types/test_module.py +++ b/aiida/cmdline/params/types/test_module.py @@ -10,6 +10,8 @@ """Test module parameter type for click.""" import click +__all__ = ('TestModuleParamType',) + class TestModuleParamType(click.ParamType): """Parameter type to represent a unittest module. diff --git a/aiida/cmdline/params/types/user.py b/aiida/cmdline/params/types/user.py index 71a1c4eaab..216530b72e 100644 --- a/aiida/cmdline/params/types/user.py +++ b/aiida/cmdline/params/types/user.py @@ -12,6 +12,8 @@ from aiida.cmdline.utils.decorators import with_dbenv +__all__ = ('UserParamType',) + class UserParamType(click.ParamType): """ diff --git a/aiida/cmdline/params/types/workflow.py b/aiida/cmdline/params/types/workflow.py index 0a3fc48b6a..7403ff99f7 100644 --- a/aiida/cmdline/params/types/workflow.py +++ b/aiida/cmdline/params/types/workflow.py @@ -13,6 +13,8 @@ from .identifier import IdentifierParamType +__all__ = ('WorkflowParamType',) + class WorkflowParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/utils/__init__.py b/aiida/cmdline/utils/__init__.py index 2776a55f97..9562427ead 100644 --- a/aiida/cmdline/utils/__init__.py +++ b/aiida/cmdline/utils/__init__.py @@ -7,3 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Commandline utility functions.""" +# AUTO-GENERATED + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .ascii_vis import * +from .decorators import * + +__all__ = ( + 'dbenv', + 'format_call_graph', + 'only_if_daemon_running', + 'with_dbenv', +) + +# yapf: enable diff --git a/aiida/common/__init__.py b/aiida/common/__init__.py index ea59db2024..1143da1861 100644 --- a/aiida/common/__init__.py +++ b/aiida/common/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """ Common data structures, utility classes and functions @@ -15,6 +14,11 @@ """ +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .datastructures import * from .exceptions import * from .extendeddicts import * @@ -23,6 +27,65 @@ from .progress_reporter import * __all__ = ( - datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__ + - progress_reporter.__all__ + 'AIIDA_LOGGER', + 'AiidaException', + 'AttributeDict', + 'CalcInfo', + 'CalcJobState', + 'CodeInfo', + 'CodeRunMode', + 'ConfigurationError', + 'ConfigurationVersionError', + 'ContentNotExistent', + 'DatabaseMigrationError', + 'DbContentError', + 'DefaultFieldsAttributeDict', + 'EntryPointError', + 'FailedError', + 'FeatureDisabled', + 'FeatureNotAvailable', + 'FixedFieldsAttributeDict', + 'GraphTraversalRule', + 'GraphTraversalRules', + 'HashingError', + 'IncompatibleDatabaseSchema', + 'InputValidationError', + 'IntegrityError', + 'InternalError', + 'InvalidEntryPointTypeError', + 'InvalidOperation', + 'LicensingException', + 'LinkType', + 'LoadingEntryPointError', + 'MissingConfigurationError', + 'MissingEntryPointError', + 'ModificationNotAllowed', + 'MultipleEntryPointError', + 'MultipleObjectsError', + 'NotExistent', + 'NotExistentAttributeError', + 'NotExistentKeyError', + 'OutputParsingError', + 'ParsingError', + 'PluginInternalError', + 'ProfileConfigurationError', + 'ProgressReporterAbstract', + 'RemoteOperationError', + 'StashMode', + 'StoringNotAllowed', + 'TQDM_BAR_FORMAT', + 'TestsNotAllowedError', + 'TransportTaskException', + 'UniquenessError', + 'UnsupportedSpeciesError', + 'ValidationError', + 'create_callback', + 'get_progress_reporter', + 'override_log_formatter', + 'override_log_level', + 'set_progress_bar_tqdm', + 'set_progress_reporter', + 'validate_link_label', ) + +# yapf: enable diff --git a/aiida/common/log.py b/aiida/common/log.py index 10a8686fe6..0cc83df08f 100644 --- a/aiida/common/log.py +++ b/aiida/common/log.py @@ -15,8 +15,6 @@ from contextlib import contextmanager from wrapt import decorator -from aiida.manage.configuration import get_config_option - __all__ = ('AIIDA_LOGGER', 'override_log_level', 'override_log_formatter') # Custom logging level, intended specifically for informative log messages reported during WorkChains. @@ -51,74 +49,77 @@ def filter(self, record): # The default logging dictionary for AiiDA that can be used in conjunction # with the config.dictConfig method of python's logging module -LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'verbose': { - 'format': '%(levelname)s %(asctime)s %(module)s %(process)d ' - '%(thread)d %(message)s', - }, - 'halfverbose': { - 'format': '%(asctime)s <%(process)d> %(name)s: [%(levelname)s] %(message)s', - 'datefmt': '%m/%d/%Y %I:%M:%S %p', - }, - }, - 'filters': { - 'testing': { - '()': NotInTestingFilter - } - }, - 'handlers': { - 'console': { - 'level': 'DEBUG', - 'class': 'logging.StreamHandler', - 'formatter': 'halfverbose', - 'filters': ['testing'] - }, - }, - 'loggers': { - 'aiida': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.aiida_loglevel'), - 'propagate': False, +def get_logging_config(): + from aiida.manage.configuration import get_config_option + + return { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'verbose': { + 'format': '%(levelname)s %(asctime)s %(module)s %(process)d ' + '%(thread)d %(message)s', + }, + 'halfverbose': { + 'format': '%(asctime)s <%(process)d> %(name)s: [%(levelname)s] %(message)s', + 'datefmt': '%m/%d/%Y %I:%M:%S %p', + }, }, - 'plumpy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.plumpy_loglevel'), - 'propagate': False, + 'filters': { + 'testing': { + '()': NotInTestingFilter + } }, - 'kiwipy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.kiwipy_loglevel'), - 'propagate': False, + 'handlers': { + 'console': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + 'formatter': 'halfverbose', + 'filters': ['testing'] + }, }, - 'paramiko': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.paramiko_loglevel'), - 'propagate': False, + 'loggers': { + 'aiida': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.aiida_loglevel'), + 'propagate': False, + }, + 'plumpy': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.plumpy_loglevel'), + 'propagate': False, + }, + 'kiwipy': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.kiwipy_loglevel'), + 'propagate': False, + }, + 'paramiko': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.paramiko_loglevel'), + 'propagate': False, + }, + 'alembic': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.alembic_loglevel'), + 'propagate': False, + }, + 'aio_pika': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.aiopika_loglevel'), + 'propagate': False, + }, + 'sqlalchemy': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.sqlalchemy_loglevel'), + 'propagate': False, + 'qualname': 'sqlalchemy.engine', + }, + 'py.warnings': { + 'handlers': ['console'], + }, }, - 'alembic': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.alembic_loglevel'), - 'propagate': False, - }, - 'aio_pika': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.aiopika_loglevel'), - 'propagate': False, - }, - 'sqlalchemy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.sqlalchemy_loglevel'), - 'propagate': False, - 'qualname': 'sqlalchemy.engine', - }, - 'py.warnings': { - 'handlers': ['console'], - }, - }, -} + } def evaluate_logging_configuration(dictionary): @@ -155,9 +156,11 @@ def configure_logging(with_orm=False, daemon=False, daemon_log_file=None): """ from logging.config import dictConfig + from aiida.manage.configuration import get_config_option + # Evaluate the `LOGGING` configuration to resolve the lambdas that will retrieve the correct values based on the - # currently configured profile. Pass a deep copy of `LOGGING` to ensure that the original remains unaltered. - config = evaluate_logging_configuration(copy.deepcopy(LOGGING)) + # currently configured profile. + config = evaluate_logging_configuration(get_logging_config()) daemon_handler_name = 'daemon_log_file' # Add the daemon file handler to all loggers if daemon=True diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 984ff61866..523970b934 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -7,11 +7,69 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin """Module with all the internals that make up the engine of `aiida-core`.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .exceptions import * from .launch import * +from .persistence import * from .processes import * +from .runners import * from .utils import * -__all__ = (launch.__all__ + processes.__all__ + utils.__all__) # type: ignore[name-defined] +__all__ = ( + 'AiiDAPersister', + 'Awaitable', + 'AwaitableAction', + 'AwaitableTarget', + 'BaseRestartWorkChain', + 'CalcJob', + 'CalcJobOutputPort', + 'CalcJobProcessSpec', + 'ExitCode', + 'ExitCodesNamespace', + 'FunctionProcess', + 'InputPort', + 'InterruptableFuture', + 'JobManager', + 'JobsList', + 'ObjectLoader', + 'OutputPort', + 'PORT_NAMESPACE_SEPARATOR', + 'PastException', + 'PortNamespace', + 'Process', + 'ProcessBuilder', + 'ProcessBuilderNamespace', + 'ProcessFuture', + 'ProcessHandlerReport', + 'ProcessSpec', + 'ProcessState', + 'Runner', + 'ToContext', + 'WithNonDb', + 'WithSerialize', + 'WorkChain', + 'append_', + 'assign_', + 'calcfunction', + 'construct_awaitable', + 'get_object_loader', + 'if_', + 'interruptable_task', + 'is_process_function', + 'process_handler', + 'return_', + 'run', + 'run_get_node', + 'run_get_pk', + 'submit', + 'while_', + 'workfunction', +) + +# yapf: enable diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py index b3045dcfd4..a2f81d49b3 100644 --- a/aiida/engine/processes/__init__.py +++ b/aiida/engine/processes/__init__.py @@ -7,18 +7,60 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin """Module for processes and related utilities.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .builder import * from .calcjobs import * from .exit_code import * from .functions import * +from .futures import * from .ports import * from .process import * from .process_spec import * from .workchains import * __all__ = ( - builder.__all__ + calcjobs.__all__ + exit_code.__all__ + functions.__all__ + # type: ignore[name-defined] - ports.__all__ + process.__all__ + process_spec.__all__ + workchains.__all__ # type: ignore[name-defined] + 'Awaitable', + 'AwaitableAction', + 'AwaitableTarget', + 'BaseRestartWorkChain', + 'CalcJob', + 'CalcJobOutputPort', + 'CalcJobProcessSpec', + 'ExitCode', + 'ExitCodesNamespace', + 'FunctionProcess', + 'InputPort', + 'JobManager', + 'JobsList', + 'OutputPort', + 'PORT_NAMESPACE_SEPARATOR', + 'PortNamespace', + 'Process', + 'ProcessBuilder', + 'ProcessBuilderNamespace', + 'ProcessFuture', + 'ProcessHandlerReport', + 'ProcessSpec', + 'ProcessState', + 'ToContext', + 'WithNonDb', + 'WithSerialize', + 'WorkChain', + 'append_', + 'assign_', + 'calcfunction', + 'construct_awaitable', + 'if_', + 'process_handler', + 'return_', + 'while_', + 'workfunction', ) + +# yapf: enable diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py index 57d4777ae7..a91782d092 100644 --- a/aiida/engine/processes/calcjobs/__init__.py +++ b/aiida/engine/processes/calcjobs/__init__.py @@ -7,9 +7,20 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for the `CalcJob` process and related utilities.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .calcjob import * +from .manager import * + +__all__ = ( + 'CalcJob', + 'JobManager', + 'JobsList', +) -__all__ = (calcjob.__all__) # type: ignore[name-defined] +# yapf: enable diff --git a/aiida/engine/processes/workchains/__init__.py b/aiida/engine/processes/workchains/__init__.py index 9b0cf508c9..56b6a94d2d 100644 --- a/aiida/engine/processes/workchains/__init__.py +++ b/aiida/engine/processes/workchains/__init__.py @@ -7,11 +7,34 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for the `WorkChain` process and related utilities.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .awaitable import * from .context import * from .restart import * from .utils import * from .workchain import * -__all__ = (context.__all__ + restart.__all__ + utils.__all__ + workchain.__all__) # type: ignore[name-defined] +__all__ = ( + 'Awaitable', + 'AwaitableAction', + 'AwaitableTarget', + 'BaseRestartWorkChain', + 'ProcessHandlerReport', + 'ToContext', + 'WorkChain', + 'append_', + 'assign_', + 'construct_awaitable', + 'if_', + 'process_handler', + 'return_', + 'while_', +) + +# yapf: enable diff --git a/aiida/manage/__init__.py b/aiida/manage/__init__.py index f25c1d5909..bcb5001a50 100644 --- a/aiida/manage/__init__.py +++ b/aiida/manage/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """ Managing an AiiDA instance: @@ -20,3 +19,50 @@ .. note:: Modules in this sub package may require the database environment to be loaded """ + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .caching import * +from .configuration import * +from .database import * +from .external import * +from .manager import * + +__all__ = ( + 'BROKER_DEFAULTS', + 'CURRENT_CONFIG_VERSION', + 'CommunicationTimeout', + 'Config', + 'ConfigValidationError', + 'DEFAULT_DBINFO', + 'DeliveryFailed', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'Option', + 'Postgres', + 'PostgresConnectionMode', + 'ProcessLauncher', + 'Profile', + 'RemoteException', + 'TABLES_UUID_DEDUPLICATION', + 'check_and_migrate_config', + 'config_needs_migrating', + 'config_schema', + 'deduplicate_uuids', + 'disable_caching', + 'enable_caching', + 'get_current_version', + 'get_duplicate_uuids', + 'get_manager', + 'get_option', + 'get_option_names', + 'get_use_cache', + 'parse_option', + 'reset_manager', + 'verify_uuid_uniqueness', + 'write_database_integrity_violation', +) + +# yapf: enable diff --git a/aiida/manage/configuration/__init__.py b/aiida/manage/configuration/__init__.py index 6860ab2c07..841e0eba7e 100644 --- a/aiida/manage/configuration/__init__.py +++ b/aiida/manage/configuration/__init__.py @@ -7,26 +7,56 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=undefined-variable,wildcard-import,global-statement,redefined-outer-name,cyclic-import """Modules related to the configuration of an AiiDA instance.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .config import * +from .migrations import * +from .options import * +from .profile import * + +__all__ = ( + 'CURRENT_CONFIG_VERSION', + 'Config', + 'ConfigValidationError', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'Option', + 'Profile', + 'check_and_migrate_config', + 'config_needs_migrating', + 'config_schema', + 'get_current_version', + 'get_option', + 'get_option_names', + 'parse_option', +) + +# yapf: enable + +# END AUTO-GENERATED + +# pylint: disable=global-statement,redefined-outer-name,wrong-import-order + +__all__ += ( + 'get_config', 'get_config_option', 'get_config_path', 'get_profile', 'load_documentation_profile', 'load_profile', + 'reset_config', 'reset_profile', 'CONFIG', 'PROFILE', 'BACKEND_UUID' +) + import os import shutil import warnings from aiida.common.warnings import AiidaDeprecationWarning -from .config import * -from .options import * -from .profile import * +from . import options CONFIG = None PROFILE = None BACKEND_UUID = None # This will be set to the UUID of the profile as soon as its corresponding backend is loaded -__all__ = ( - config.__all__ + options.__all__ + profile.__all__ + - ('get_config', 'get_config_option', 'get_config_path', 'load_profile', 'reset_config') -) - def load_profile(profile=None): """Load a profile. diff --git a/aiida/manage/configuration/migrations/__init__.py b/aiida/manage/configuration/migrations/__init__.py index 6b99f32a5d..4aad63827b 100644 --- a/aiida/manage/configuration/migrations/__init__.py +++ b/aiida/manage/configuration/migrations/__init__.py @@ -7,10 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=undefined-variable,wildcard-import """Methods and definitions of migrations for the configuration file of an AiiDA instance.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .migrations import * from .utils import * -__all__ = (migrations.__all__ + utils.__all__) +__all__ = ( + 'CURRENT_CONFIG_VERSION', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'check_and_migrate_config', + 'config_needs_migrating', + 'get_current_version', +) + +# yapf: enable diff --git a/aiida/manage/database/__init__.py b/aiida/manage/database/__init__.py index 2776a55f97..f1222d7e8b 100644 --- a/aiida/manage/database/__init__.py +++ b/aiida/manage/database/__init__.py @@ -7,3 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Management of the database.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .integrity import * + +__all__ = ( + 'TABLES_UUID_DEDUPLICATION', + 'deduplicate_uuids', + 'get_duplicate_uuids', + 'verify_uuid_uniqueness', + 'write_database_integrity_violation', +) + +# yapf: enable diff --git a/aiida/manage/database/integrity/__init__.py b/aiida/manage/database/integrity/__init__.py index 796a9a7213..6bb9ca7f26 100644 --- a/aiida/manage/database/integrity/__init__.py +++ b/aiida/manage/database/integrity/__init__.py @@ -7,46 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name """Methods to validate the database integrity and fix violations.""" -WARNING_BORDER = '*' * 120 +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -def write_database_integrity_violation(results, headers, reason_message, action_message=None): - """Emit a integrity violation warning and write the violating records to a log file in the current directory +from .duplicate_uuid import * +from .utils import * - :param results: a list of tuples representing the violating records - :param headers: a tuple of strings that will be used as a header for the log file. Should have the same length - as each tuple in the results list. - :param reason_message: a human readable message detailing the reason of the integrity violation - :param action_message: an optional human readable message detailing a performed action, if any - """ - # pylint: disable=duplicate-string-formatting-argument - from datetime import datetime - from tabulate import tabulate - from tempfile import NamedTemporaryFile +__all__ = ( + 'TABLES_UUID_DEDUPLICATION', + 'deduplicate_uuids', + 'get_duplicate_uuids', + 'verify_uuid_uniqueness', + 'write_database_integrity_violation', +) - from aiida.cmdline.utils import echo - from aiida.manage import configuration - - if configuration.PROFILE.is_test_profile: - return - - if action_message is None: - action_message = 'nothing' - - with NamedTemporaryFile(prefix='migration-', suffix='.log', dir='.', delete=False, mode='w+') as handle: - echo.echo('') - echo.echo_warning( - '\n{}\nFound one or multiple records that violate the integrity of the database\nViolation reason: {}\n' - 'Performed action: {}\nViolators written to: {}\n{}\n'.format( - WARNING_BORDER, reason_message, action_message, handle.name, WARNING_BORDER - ) - ) - - handle.write(f'# {datetime.utcnow().isoformat()}\n') - handle.write(f'# Violation reason: {reason_message}\n') - handle.write(f'# Performed action: {action_message}\n') - handle.write('\n') - handle.write(tabulate(results, headers)) +# yapf: enable diff --git a/aiida/manage/database/integrity/utils.py b/aiida/manage/database/integrity/utils.py new file mode 100644 index 0000000000..d516a00a8a --- /dev/null +++ b/aiida/manage/database/integrity/utils.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name +"""Methods to validate the database integrity and fix violations.""" +__all__ = ('write_database_integrity_violation',) + +WARNING_BORDER = '*' * 120 + + +def write_database_integrity_violation(results, headers, reason_message, action_message=None): + """Emit a integrity violation warning and write the violating records to a log file in the current directory + + :param results: a list of tuples representing the violating records + :param headers: a tuple of strings that will be used as a header for the log file. Should have the same length + as each tuple in the results list. + :param reason_message: a human readable message detailing the reason of the integrity violation + :param action_message: an optional human readable message detailing a performed action, if any + """ + # pylint: disable=duplicate-string-formatting-argument + from datetime import datetime + from tabulate import tabulate + from tempfile import NamedTemporaryFile + + from aiida.cmdline.utils import echo + from aiida.manage import configuration + + if configuration.PROFILE.is_test_profile: + return + + if action_message is None: + action_message = 'nothing' + + with NamedTemporaryFile(prefix='migration-', suffix='.log', dir='.', delete=False, mode='w+') as handle: + echo.echo('') + echo.echo_warning( + '\n{}\nFound one or multiple records that violate the integrity of the database\nViolation reason: {}\n' + 'Performed action: {}\nViolators written to: {}\n{}\n'.format( + WARNING_BORDER, reason_message, action_message, handle.name, WARNING_BORDER + ) + ) + + handle.write(f'# {datetime.utcnow().isoformat()}\n') + handle.write(f'# Violation reason: {reason_message}\n') + handle.write(f'# Performed action: {action_message}\n') + handle.write('\n') + handle.write(tabulate(results, headers)) diff --git a/aiida/manage/external/__init__.py b/aiida/manage/external/__init__.py index e82b79252b..d82852a0da 100644 --- a/aiida/manage/external/__init__.py +++ b/aiida/manage/external/__init__.py @@ -8,3 +8,24 @@ # For further information please visit http://www.aiida.net # ########################################################################### """User facing APIs to control AiiDA from the verdi cli, scripts or plugins""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .postgres import * +from .rmq import * + +__all__ = ( + 'BROKER_DEFAULTS', + 'CommunicationTimeout', + 'DEFAULT_DBINFO', + 'DeliveryFailed', + 'Postgres', + 'PostgresConnectionMode', + 'ProcessLauncher', + 'RemoteException', +) + +# yapf: enable diff --git a/aiida/manage/tests/__init__.py b/aiida/manage/tests/__init__.py index 24e0d218a7..813f38628c 100644 --- a/aiida/manage/tests/__init__.py +++ b/aiida/manage/tests/__init__.py @@ -9,501 +9,27 @@ ########################################################################### """ Testing infrastructure for easy testing of AiiDA plugins. - """ -import tempfile -import shutil -import os -from contextlib import contextmanager - -from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA -from aiida.common import exceptions -from aiida.manage import configuration -from aiida.manage.configuration.settings import create_instance_directories -from aiida.manage import manager -from aiida.manage.external.postgres import Postgres - -__all__ = ('TestManager', 'TestManagerError', 'ProfileManager', 'TemporaryProfileManager', '_GLOBAL_TEST_MANAGER') - -_DEFAULT_PROFILE_INFO = { - 'name': 'test_profile', - 'email': 'tests@aiida.mail', - 'first_name': 'AiiDA', - 'last_name': 'Plugintest', - 'institution': 'aiidateam', - 'database_engine': 'postgresql_psycopg2', - 'database_backend': 'django', - 'database_username': 'aiida', - 'database_password': 'aiida_pw', - 'database_name': 'aiida_db', - 'repo_dir': 'test_repo', - 'config_dir': '.aiida', - 'root_path': '', - 'broker_protocol': 'amqp', - 'broker_username': 'guest', - 'broker_password': 'guest', - 'broker_host': '127.0.0.1', - 'broker_port': 5672, - 'broker_virtual_host': '' -} - - -class TestManagerError(Exception): - """Raised by TestManager in situations that may lead to inconsistent behaviour.""" - - def __init__(self, msg): - super().__init__() - self.msg = msg - - def __str__(self): - return repr(self.msg) - - -class TestManager: - """ - Test manager for plugin tests. - - Uses either ProfileManager for wrapping an existing profile or TemporaryProfileManager for setting up a complete - temporary AiiDA environment. - - For usage with pytest, see :py:class:`~aiida.manage.tests.pytest_fixtures`. - """ - - def __init__(self): - self._manager = None - - def use_temporary_profile(self, backend=None, pgtest=None): - """Set up Test manager to use temporary AiiDA profile. - - Uses :py:class:`aiida.manage.tests.TemporaryProfileManager` internally. - - :param backend: Backend to use. - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - - """ - if configuration.PROFILE is not None: - raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') - if self._manager is not None: - raise TestManagerError('Profile manager already loaded.') - - mngr = TemporaryProfileManager(backend=backend, pgtest=pgtest) - mngr.create_profile() - self._manager = mngr # don't assign before profile has actually been created! - - def use_profile(self, profile_name): - """Set up Test manager to use existing profile. - - Uses :py:class:`aiida.manage.tests.ProfileManager` internally. - - :param profile_name: Name of existing test profile to use. - """ - if configuration.PROFILE is not None: - raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') - if self._manager is not None: - raise TestManagerError('Profile manager already loaded.') - - self._manager = ProfileManager(profile_name=profile_name) - self._manager.init_db() - - def has_profile_open(self): - return self._manager and self._manager.has_profile_open() - - def reset_db(self): - return self._manager.reset_db() - - def destroy_all(self): - if self._manager: - self._manager.destroy_all() - self._manager = None - - -class ProfileManager: - """ - Wraps existing AiiDA profile. - """ - - def __init__(self, profile_name): - """ - Use an existing profile. - - :param profile_name: Name of the profile to be loaded - """ - from aiida import load_profile - from aiida.backends.testbase import check_if_tests_can_run - - self._profile = None - self._user = None - - try: - self._profile = load_profile(profile_name) - manager.get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access - except Exception: - raise TestManagerError('Unable to load test profile \'{}\'.'.format(profile_name)) - check_if_tests_can_run() - - self._select_db_test_case(backend=self._profile.database_backend) - - def _select_db_test_case(self, backend): - """ - Selects tests case for the correct database backend. - """ - if backend == BACKEND_DJANGO: - from aiida.backends.djsite.db.testbase import DjangoTests - self._test_case = DjangoTests() - elif backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy.testbase import SqlAlchemyTests - from aiida.backends.sqlalchemy import get_scoped_session - - self._test_case = SqlAlchemyTests() - self._test_case.test_session = get_scoped_session() - - def reset_db(self): - self._test_case.clean_db() # will drop all users - manager.reset_manager() - self.init_db() - - def init_db(self): - """Initialise the database state for running of tests. - - Adds default user if necessary. - """ - from aiida.orm import User - from aiida.cmdline.commands.cmd_user import set_default_user - - if not User.objects.get_default(): - user_dict = get_user_dict(_DEFAULT_PROFILE_INFO) - try: - user = User(**user_dict) - user.store() - except exceptions.IntegrityError: - # The user already exists, no problem - user = User.objects.get(**user_dict) - - set_default_user(self._profile, user) - User.objects.reset() # necessary to pick up new default user - - def has_profile_open(self): - return self._profile is not None - - def destroy_all(self): - pass - - -class TemporaryProfileManager(ProfileManager): - """ - Manage the life cycle of a completely separated and temporary AiiDA environment. - - * No profile / database setup required - * Tests run via the TemporaryProfileManager never pollute the user's working environment - - Filesystem: - - * temporary ``.aiida`` configuration folder - * temporary repository folder - - Database: - - * temporary database cluster (via the ``pgtest`` package) - * with ``aiida`` database user - * with ``aiida_db`` database - - AiiDA: - - * configured to use the temporary configuration - * sets up a temporary profile for tests - - All of this happens automatically when using the corresponding tests classes & tests runners (unittest) - or fixtures (pytest). - - Example:: - - tests = TemporaryProfileManager(backend=backend) - tests.create_aiida_db() # set up only the database - tests.create_profile() # set up a profile (creates the db too if necessary) - - # ready for tests - - # run tests 1 - - tests.reset_db() - # database ready for independent tests 2 - - # run tests 2 - - tests.destroy_all() - # everything cleaned up - - """ - - _test_case = None - - def __init__(self, backend=BACKEND_DJANGO, pgtest=None): # pylint: disable=super-init-not-called - """Construct a TemporaryProfileManager - - :param backend: a database backend - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - - """ - from aiida.manage.configuration import settings - - self.dbinfo = {} - self.profile_info = _DEFAULT_PROFILE_INFO - self.profile_info['database_backend'] = backend - self._pgtest = pgtest or {} - - self.pg_cluster = None - self.postgres = None - self._profile = None - self._has_test_db = False - self._backup = { - 'config': configuration.CONFIG, - 'config_dir': settings.AIIDA_CONFIG_FOLDER, - 'profile': configuration.PROFILE, - } - - @property - def profile_dictionary(self): - """Profile parameters. - - Used to set up AiiDA profile from self.profile_info dictionary. - """ - dictionary = { - 'database_engine': self.profile_info.get('database_engine'), - 'database_backend': self.profile_info.get('database_backend'), - 'database_port': self.profile_info.get('database_port'), - 'database_hostname': self.profile_info.get('database_hostname'), - 'database_name': self.profile_info.get('database_name'), - 'database_username': self.profile_info.get('database_username'), - 'database_password': self.profile_info.get('database_password'), - 'broker_protocol': self.profile_info.get('broker_protocol'), - 'broker_username': self.profile_info.get('broker_username'), - 'broker_password': self.profile_info.get('broker_password'), - 'broker_host': self.profile_info.get('broker_host'), - 'broker_port': self.profile_info.get('broker_port'), - 'broker_virtual_host': self.profile_info.get('broker_virtual_host'), - 'repository_uri': f'file://{self.repo}', - } - return dictionary - - def create_db_cluster(self): - """ - Create the database cluster using PGTest. - """ - from pgtest.pgtest import PGTest - - if self.pg_cluster is not None: - raise TestManagerError( - 'Running temporary postgresql cluster detected.Use destroy_all() before creating a new cluster.' - ) - self.pg_cluster = PGTest(**self._pgtest) - self.dbinfo.update(self.pg_cluster.dsn) - - def create_aiida_db(self): - """ - Create the necessary database on the temporary postgres instance. - """ - if configuration.PROFILE is not None: - raise TestManagerError('AiiDA dbenv can not be loaded while creating a tests db environment') - if self.pg_cluster is None: - self.create_db_cluster() - self.postgres = Postgres(interactive=False, quiet=True, dbinfo=self.dbinfo) - # note: not using postgres.create_dbuser_db_safe here since we don't want prompts - self.postgres.create_dbuser(self.profile_info['database_username'], self.profile_info['database_password']) - self.postgres.create_db(self.profile_info['database_username'], self.profile_info['database_name']) - self.dbinfo = self.postgres.dbinfo - self.profile_info['database_hostname'] = self.postgres.host_for_psycopg2 - self.profile_info['database_port'] = self.postgres.port_for_psycopg2 - self._has_test_db = True - - def create_profile(self): - """ - Set AiiDA to use the tests config dir and create a default profile there - - Warning: the AiiDA dbenv must not be loaded when this is called! - """ - from aiida.manage.configuration import settings, load_profile, Profile - - if not self._has_test_db: - self.create_aiida_db() - - if not self.root_dir: - self.root_dir = tempfile.mkdtemp() - configuration.CONFIG = None - settings.AIIDA_CONFIG_FOLDER = self.config_dir - configuration.PROFILE = None - create_instance_directories() - profile_name = self.profile_info['name'] - config = configuration.get_config(create=True) - profile = Profile(profile_name, self.profile_dictionary) - config.add_profile(profile) - config.set_default_profile(profile_name).store() - self._profile = profile - - load_profile(profile_name) - backend = manager.get_manager()._load_backend(schema_check=False) - backend.migrate() - - self._select_db_test_case(backend=self._profile.database_backend) - self.init_db() - - def repo_ok(self): - return bool(self.repo and os.path.isdir(os.path.dirname(self.repo))) - - @property - def repo(self): - return self._return_dir(self.profile_info['repo_dir']) - - def _return_dir(self, dir_path): - """Return a path to a directory from the fs environment""" - if os.path.isabs(dir_path): - return dir_path - return os.path.join(self.root_dir, dir_path) - - @property - def backend(self): - return self.profile_info['backend'] - - @backend.setter - def backend(self, backend): - if self.has_profile_open(): - raise TestManagerError('backend cannot be changed after setting up the environment') - - valid_backends = [BACKEND_DJANGO, BACKEND_SQLA] - if backend not in valid_backends: - raise ValueError(f'invalid backend {backend}, must be one of {valid_backends}') - self.profile_info['backend'] = backend - - @property - def config_dir_ok(self): - return bool(self.config_dir and os.path.isdir(self.config_dir)) - - @property - def config_dir(self): - return self._return_dir(self.profile_info['config_dir']) - - @property - def root_dir(self): - return self.profile_info['root_path'] - - @root_dir.setter - def root_dir(self, root_dir): - self.profile_info['root_path'] = root_dir - - @property - def root_dir_ok(self): - return bool(self.root_dir and os.path.isdir(self.root_dir)) - - def destroy_all(self): - """Remove all traces of the tests run""" - from aiida.manage.configuration import settings - if self.root_dir: - shutil.rmtree(self.root_dir) - self.root_dir = None - if self.pg_cluster: - self.pg_cluster.close() - self.pg_cluster = None - self._has_test_db = False - self._profile = None - self._user = None - - if 'config' in self._backup: - configuration.CONFIG = self._backup['config'] - if 'config_dir' in self._backup: - settings.AIIDA_CONFIG_FOLDER = self._backup['config_dir'] - if 'profile' in self._backup: - configuration.PROFILE = self._backup['profile'] - - def has_profile_open(self): - return self._profile is not None - - -_GLOBAL_TEST_MANAGER = TestManager() - - -@contextmanager -def test_manager(backend=BACKEND_DJANGO, profile_name=None, pgtest=None): - """ Context manager for TestManager objects. - - Sets up temporary AiiDA environment for testing or reuses existing environment, - if `AIIDA_TEST_PROFILE` environment variable is set. - - Example pytest fixture:: - - def aiida_profile(): - with test_manager(backend) as test_mgr: - yield fixture_mgr - - Example unittest test runner:: - - with test_manager(backend) as test_mgr: - # ready for tests - # everything cleaned up - - - :param backend: database backend, either BACKEND_SQLA or BACKEND_DJANGO - :param profile_name: name of test profile to be used or None (to use temporary profile) - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - """ - from aiida.common.utils import Capturing - from aiida.common.log import configure_logging - - try: - if not _GLOBAL_TEST_MANAGER.has_profile_open(): - if profile_name: - _GLOBAL_TEST_MANAGER.use_profile(profile_name=profile_name) - else: - with Capturing(): # capture output of AiiDA DB setup - _GLOBAL_TEST_MANAGER.use_temporary_profile(backend=backend, pgtest=pgtest) - configure_logging(with_orm=True) - yield _GLOBAL_TEST_MANAGER - finally: - _GLOBAL_TEST_MANAGER.destroy_all() - - -def get_test_backend_name(): - """ Read name of database backend from environment variable or the specified test profile. - - Reads database backend ('django' or 'sqlalchemy') from 'AIIDA_TEST_BACKEND' environment variable, - or the backend configured for the 'AIIDA_TEST_PROFILE'. - Defaults to django backend. - - :returns: content of environment variable or `BACKEND_DJANGO` - :raises: ValueError if unknown backend name detected. - :raises: ValueError if both 'AIIDA_TEST_BACKEND' and 'AIIDA_TEST_PROFILE' are set, and the two - backends do not match. - """ - test_profile_name = get_test_profile_name() - backend_env = os.environ.get('AIIDA_TEST_BACKEND', None) - if test_profile_name is not None: - backend_profile = configuration.get_config().get_profile(test_profile_name).database_backend - if backend_env is not None and backend_env != backend_profile: - raise ValueError( - "The backend '{}' read from AIIDA_TEST_BACKEND does not match the backend '{}' " - "of AIIDA_TEST_PROFILE '{}'".format(backend_env, backend_profile, test_profile_name) - ) - backend_res = backend_profile - else: - backend_res = backend_env or BACKEND_DJANGO - - if backend_res in (BACKEND_DJANGO, BACKEND_SQLA): - return backend_res - raise ValueError(f"Unknown backend '{backend_res}' read from AIIDA_TEST_BACKEND environment variable") - -def get_test_profile_name(): - """ Read name of test profile from environment variable. +# AUTO-GENERATED - Reads name of existing test profile 'AIIDA_TEST_PROFILE' environment variable. - If specified, this profile is used for running the tests (instead of setting up a temporary profile). +# yapf: disable +# pylint: disable=wildcard-import - :returns: content of environment variable or `None` - """ - return os.environ.get('AIIDA_TEST_PROFILE', None) +from .main import * +from .unittest_classes import * +__all__ = ( + 'PluginTestCase', + 'ProfileManager', + 'TemporaryProfileManager', + 'TestManager', + 'TestManagerError', + 'TestRunner', + 'get_test_backend_name', + 'get_test_profile_name', + 'get_user_dict', + 'test_manager', +) -def get_user_dict(profile_dict): - """Collect parameters required for creating users.""" - return {k: profile_dict[k] for k in ('email', 'first_name', 'last_name', 'institution')} +# yapf: enable diff --git a/aiida/manage/tests/main.py b/aiida/manage/tests/main.py new file mode 100644 index 0000000000..c3bc9ff19a --- /dev/null +++ b/aiida/manage/tests/main.py @@ -0,0 +1,518 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +""" +Testing infrastructure for easy testing of AiiDA plugins. + +""" +import tempfile +import shutil +import os +from contextlib import contextmanager + +from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA +from aiida.common import exceptions +from aiida.manage import configuration +from aiida.manage.configuration.settings import create_instance_directories +from aiida.manage import manager +from aiida.manage.external.postgres import Postgres + +__all__ = ( + 'get_test_profile_name', + 'get_test_backend_name', + 'get_user_dict', + 'test_manager', + 'TestManager', + 'TestManagerError', + 'ProfileManager', + 'TemporaryProfileManager', +) + +_DEFAULT_PROFILE_INFO = { + 'name': 'test_profile', + 'email': 'tests@aiida.mail', + 'first_name': 'AiiDA', + 'last_name': 'Plugintest', + 'institution': 'aiidateam', + 'database_engine': 'postgresql_psycopg2', + 'database_backend': 'django', + 'database_username': 'aiida', + 'database_password': 'aiida_pw', + 'database_name': 'aiida_db', + 'repo_dir': 'test_repo', + 'config_dir': '.aiida', + 'root_path': '', + 'broker_protocol': 'amqp', + 'broker_username': 'guest', + 'broker_password': 'guest', + 'broker_host': '127.0.0.1', + 'broker_port': 5672, + 'broker_virtual_host': '' +} + + +class TestManagerError(Exception): + """Raised by TestManager in situations that may lead to inconsistent behaviour.""" + + def __init__(self, msg): + super().__init__() + self.msg = msg + + def __str__(self): + return repr(self.msg) + + +class TestManager: + """ + Test manager for plugin tests. + + Uses either ProfileManager for wrapping an existing profile or TemporaryProfileManager for setting up a complete + temporary AiiDA environment. + + For usage with pytest, see :py:class:`~aiida.manage.tests.pytest_fixtures`. + """ + + def __init__(self): + self._manager = None + + def use_temporary_profile(self, backend=None, pgtest=None): + """Set up Test manager to use temporary AiiDA profile. + + Uses :py:class:`aiida.manage.tests.TemporaryProfileManager` internally. + + :param backend: Backend to use. + :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, + e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. + + """ + if configuration.PROFILE is not None: + raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') + if self._manager is not None: + raise TestManagerError('Profile manager already loaded.') + + mngr = TemporaryProfileManager(backend=backend, pgtest=pgtest) + mngr.create_profile() + self._manager = mngr # don't assign before profile has actually been created! + + def use_profile(self, profile_name): + """Set up Test manager to use existing profile. + + Uses :py:class:`aiida.manage.tests.ProfileManager` internally. + + :param profile_name: Name of existing test profile to use. + """ + if configuration.PROFILE is not None: + raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') + if self._manager is not None: + raise TestManagerError('Profile manager already loaded.') + + self._manager = ProfileManager(profile_name=profile_name) + self._manager.init_db() + + def has_profile_open(self): + return self._manager and self._manager.has_profile_open() + + def reset_db(self): + return self._manager.reset_db() + + def destroy_all(self): + if self._manager: + self._manager.destroy_all() + self._manager = None + + +class ProfileManager: + """ + Wraps existing AiiDA profile. + """ + + def __init__(self, profile_name): + """ + Use an existing profile. + + :param profile_name: Name of the profile to be loaded + """ + from aiida import load_profile + from aiida.backends.testbase import check_if_tests_can_run + + self._profile = None + self._user = None + + try: + self._profile = load_profile(profile_name) + manager.get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access + except Exception: + raise TestManagerError('Unable to load test profile \'{}\'.'.format(profile_name)) + check_if_tests_can_run() + + self._select_db_test_case(backend=self._profile.database_backend) + + def _select_db_test_case(self, backend): + """ + Selects tests case for the correct database backend. + """ + if backend == BACKEND_DJANGO: + from aiida.backends.djsite.db.testbase import DjangoTests + self._test_case = DjangoTests() + elif backend == BACKEND_SQLA: + from aiida.backends.sqlalchemy.testbase import SqlAlchemyTests + from aiida.backends.sqlalchemy import get_scoped_session + + self._test_case = SqlAlchemyTests() + self._test_case.test_session = get_scoped_session() + + def reset_db(self): + self._test_case.clean_db() # will drop all users + manager.reset_manager() + self.init_db() + + def init_db(self): + """Initialise the database state for running of tests. + + Adds default user if necessary. + """ + from aiida.orm import User + from aiida.cmdline.commands.cmd_user import set_default_user + + if not User.objects.get_default(): + user_dict = get_user_dict(_DEFAULT_PROFILE_INFO) + try: + user = User(**user_dict) + user.store() + except exceptions.IntegrityError: + # The user already exists, no problem + user = User.objects.get(**user_dict) + + set_default_user(self._profile, user) + User.objects.reset() # necessary to pick up new default user + + def has_profile_open(self): + return self._profile is not None + + def destroy_all(self): + pass + + +class TemporaryProfileManager(ProfileManager): + """ + Manage the life cycle of a completely separated and temporary AiiDA environment. + + * No profile / database setup required + * Tests run via the TemporaryProfileManager never pollute the user's working environment + + Filesystem: + + * temporary ``.aiida`` configuration folder + * temporary repository folder + + Database: + + * temporary database cluster (via the ``pgtest`` package) + * with ``aiida`` database user + * with ``aiida_db`` database + + AiiDA: + + * configured to use the temporary configuration + * sets up a temporary profile for tests + + All of this happens automatically when using the corresponding tests classes & tests runners (unittest) + or fixtures (pytest). + + Example:: + + tests = TemporaryProfileManager(backend=backend) + tests.create_aiida_db() # set up only the database + tests.create_profile() # set up a profile (creates the db too if necessary) + + # ready for tests + + # run tests 1 + + tests.reset_db() + # database ready for independent tests 2 + + # run tests 2 + + tests.destroy_all() + # everything cleaned up + + """ + + _test_case = None + + def __init__(self, backend=BACKEND_DJANGO, pgtest=None): # pylint: disable=super-init-not-called + """Construct a TemporaryProfileManager + + :param backend: a database backend + :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, + e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. + + """ + from aiida.manage.configuration import settings + + self.dbinfo = {} + self.profile_info = _DEFAULT_PROFILE_INFO + self.profile_info['database_backend'] = backend + self._pgtest = pgtest or {} + + self.pg_cluster = None + self.postgres = None + self._profile = None + self._has_test_db = False + self._backup = { + 'config': configuration.CONFIG, + 'config_dir': settings.AIIDA_CONFIG_FOLDER, + 'profile': configuration.PROFILE, + } + + @property + def profile_dictionary(self): + """Profile parameters. + + Used to set up AiiDA profile from self.profile_info dictionary. + """ + dictionary = { + 'database_engine': self.profile_info.get('database_engine'), + 'database_backend': self.profile_info.get('database_backend'), + 'database_port': self.profile_info.get('database_port'), + 'database_hostname': self.profile_info.get('database_hostname'), + 'database_name': self.profile_info.get('database_name'), + 'database_username': self.profile_info.get('database_username'), + 'database_password': self.profile_info.get('database_password'), + 'broker_protocol': self.profile_info.get('broker_protocol'), + 'broker_username': self.profile_info.get('broker_username'), + 'broker_password': self.profile_info.get('broker_password'), + 'broker_host': self.profile_info.get('broker_host'), + 'broker_port': self.profile_info.get('broker_port'), + 'broker_virtual_host': self.profile_info.get('broker_virtual_host'), + 'repository_uri': f'file://{self.repo}', + } + return dictionary + + def create_db_cluster(self): + """ + Create the database cluster using PGTest. + """ + from pgtest.pgtest import PGTest + + if self.pg_cluster is not None: + raise TestManagerError( + 'Running temporary postgresql cluster detected.Use destroy_all() before creating a new cluster.' + ) + self.pg_cluster = PGTest(**self._pgtest) + self.dbinfo.update(self.pg_cluster.dsn) + + def create_aiida_db(self): + """ + Create the necessary database on the temporary postgres instance. + """ + if configuration.PROFILE is not None: + raise TestManagerError('AiiDA dbenv can not be loaded while creating a tests db environment') + if self.pg_cluster is None: + self.create_db_cluster() + self.postgres = Postgres(interactive=False, quiet=True, dbinfo=self.dbinfo) + # note: not using postgres.create_dbuser_db_safe here since we don't want prompts + self.postgres.create_dbuser(self.profile_info['database_username'], self.profile_info['database_password']) + self.postgres.create_db(self.profile_info['database_username'], self.profile_info['database_name']) + self.dbinfo = self.postgres.dbinfo + self.profile_info['database_hostname'] = self.postgres.host_for_psycopg2 + self.profile_info['database_port'] = self.postgres.port_for_psycopg2 + self._has_test_db = True + + def create_profile(self): + """ + Set AiiDA to use the tests config dir and create a default profile there + + Warning: the AiiDA dbenv must not be loaded when this is called! + """ + from aiida.manage.configuration import settings, load_profile, Profile + + if not self._has_test_db: + self.create_aiida_db() + + if not self.root_dir: + self.root_dir = tempfile.mkdtemp() + configuration.CONFIG = None + settings.AIIDA_CONFIG_FOLDER = self.config_dir + configuration.PROFILE = None + create_instance_directories() + profile_name = self.profile_info['name'] + config = configuration.get_config(create=True) + profile = Profile(profile_name, self.profile_dictionary) + config.add_profile(profile) + config.set_default_profile(profile_name).store() + self._profile = profile + + load_profile(profile_name) + backend = manager.get_manager()._load_backend(schema_check=False) + backend.migrate() + + self._select_db_test_case(backend=self._profile.database_backend) + self.init_db() + + def repo_ok(self): + return bool(self.repo and os.path.isdir(os.path.dirname(self.repo))) + + @property + def repo(self): + return self._return_dir(self.profile_info['repo_dir']) + + def _return_dir(self, dir_path): + """Return a path to a directory from the fs environment""" + if os.path.isabs(dir_path): + return dir_path + return os.path.join(self.root_dir, dir_path) + + @property + def backend(self): + return self.profile_info['backend'] + + @backend.setter + def backend(self, backend): + if self.has_profile_open(): + raise TestManagerError('backend cannot be changed after setting up the environment') + + valid_backends = [BACKEND_DJANGO, BACKEND_SQLA] + if backend not in valid_backends: + raise ValueError(f'invalid backend {backend}, must be one of {valid_backends}') + self.profile_info['backend'] = backend + + @property + def config_dir_ok(self): + return bool(self.config_dir and os.path.isdir(self.config_dir)) + + @property + def config_dir(self): + return self._return_dir(self.profile_info['config_dir']) + + @property + def root_dir(self): + return self.profile_info['root_path'] + + @root_dir.setter + def root_dir(self, root_dir): + self.profile_info['root_path'] = root_dir + + @property + def root_dir_ok(self): + return bool(self.root_dir and os.path.isdir(self.root_dir)) + + def destroy_all(self): + """Remove all traces of the tests run""" + from aiida.manage.configuration import settings + if self.root_dir: + shutil.rmtree(self.root_dir) + self.root_dir = None + if self.pg_cluster: + self.pg_cluster.close() + self.pg_cluster = None + self._has_test_db = False + self._profile = None + self._user = None + + if 'config' in self._backup: + configuration.CONFIG = self._backup['config'] + if 'config_dir' in self._backup: + settings.AIIDA_CONFIG_FOLDER = self._backup['config_dir'] + if 'profile' in self._backup: + configuration.PROFILE = self._backup['profile'] + + def has_profile_open(self): + return self._profile is not None + + +_GLOBAL_TEST_MANAGER = TestManager() + + +@contextmanager +def test_manager(backend=BACKEND_DJANGO, profile_name=None, pgtest=None): + """ Context manager for TestManager objects. + + Sets up temporary AiiDA environment for testing or reuses existing environment, + if `AIIDA_TEST_PROFILE` environment variable is set. + + Example pytest fixture:: + + def aiida_profile(): + with test_manager(backend) as test_mgr: + yield fixture_mgr + + Example unittest test runner:: + + with test_manager(backend) as test_mgr: + # ready for tests + # everything cleaned up + + + :param backend: database backend, either BACKEND_SQLA or BACKEND_DJANGO + :param profile_name: name of test profile to be used or None (to use temporary profile) + :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, + e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. + """ + from aiida.common.utils import Capturing + from aiida.common.log import configure_logging + + try: + if not _GLOBAL_TEST_MANAGER.has_profile_open(): + if profile_name: + _GLOBAL_TEST_MANAGER.use_profile(profile_name=profile_name) + else: + with Capturing(): # capture output of AiiDA DB setup + _GLOBAL_TEST_MANAGER.use_temporary_profile(backend=backend, pgtest=pgtest) + configure_logging(with_orm=True) + yield _GLOBAL_TEST_MANAGER + finally: + _GLOBAL_TEST_MANAGER.destroy_all() + + +def get_test_backend_name(): + """ Read name of database backend from environment variable or the specified test profile. + + Reads database backend ('django' or 'sqlalchemy') from 'AIIDA_TEST_BACKEND' environment variable, + or the backend configured for the 'AIIDA_TEST_PROFILE'. + Defaults to django backend. + + :returns: content of environment variable or `BACKEND_DJANGO` + :raises: ValueError if unknown backend name detected. + :raises: ValueError if both 'AIIDA_TEST_BACKEND' and 'AIIDA_TEST_PROFILE' are set, and the two + backends do not match. + """ + test_profile_name = get_test_profile_name() + backend_env = os.environ.get('AIIDA_TEST_BACKEND', None) + if test_profile_name is not None: + backend_profile = configuration.get_config().get_profile(test_profile_name).database_backend + if backend_env is not None and backend_env != backend_profile: + raise ValueError( + "The backend '{}' read from AIIDA_TEST_BACKEND does not match the backend '{}' " + "of AIIDA_TEST_PROFILE '{}'".format(backend_env, backend_profile, test_profile_name) + ) + backend_res = backend_profile + else: + backend_res = backend_env or BACKEND_DJANGO + + if backend_res in (BACKEND_DJANGO, BACKEND_SQLA): + return backend_res + raise ValueError(f"Unknown backend '{backend_res}' read from AIIDA_TEST_BACKEND environment variable") + + +def get_test_profile_name(): + """ Read name of test profile from environment variable. + + Reads name of existing test profile 'AIIDA_TEST_PROFILE' environment variable. + If specified, this profile is used for running the tests (instead of setting up a temporary profile). + + :returns: content of environment variable or `None` + """ + return os.environ.get('AIIDA_TEST_PROFILE', None) + + +def get_user_dict(profile_dict): + """Collect parameters required for creating users.""" + return {k: profile_dict[k] for k in ('email', 'first_name', 'last_name', 'institution')} diff --git a/aiida/manage/tests/unittest_classes.py b/aiida/manage/tests/unittest_classes.py index 751583eb3b..674fb43606 100644 --- a/aiida/manage/tests/unittest_classes.py +++ b/aiida/manage/tests/unittest_classes.py @@ -15,7 +15,7 @@ from aiida.common.warnings import AiidaDeprecationWarning from aiida.manage.manager import get_manager -from . import _GLOBAL_TEST_MANAGER, test_manager, get_test_backend_name, get_test_profile_name +from .main import _GLOBAL_TEST_MANAGER, test_manager, get_test_backend_name, get_test_profile_name __all__ = ('PluginTestCase', 'TestRunner') diff --git a/aiida/orm/__init__.py b/aiida/orm/__init__.py index fa6f66afde..a84f87017f 100644 --- a/aiida/orm/__init__.py +++ b/aiida/orm/__init__.py @@ -7,9 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin,cyclic-import """Main module to expose all orm classes and methods""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .authinfos import * from .comments import * from .computers import * @@ -22,6 +26,87 @@ from .utils import * __all__ = ( - authinfos.__all__ + comments.__all__ + computers.__all__ + entities.__all__ + groups.__all__ + logs.__all__ + - nodes.__all__ + querybuilder.__all__ + users.__all__ + utils.__all__ + 'ASCENDING', + 'AbstractNodeMeta', + 'ArrayData', + 'AttributeManager', + 'AuthInfo', + 'AutoGroup', + 'BandsData', + 'BaseType', + 'Bool', + 'CalcFunctionNode', + 'CalcJobNode', + 'CalcJobResultManager', + 'CalculationEntityLoader', + 'CalculationNode', + 'CifData', + 'Code', + 'CodeEntityLoader', + 'Collection', + 'Comment', + 'Computer', + 'ComputerEntityLoader', + 'DESCENDING', + 'Data', + 'Dict', + 'Entity', + 'EntityAttributesMixin', + 'EntityExtrasMixin', + 'Float', + 'FolderData', + 'Group', + 'GroupEntityLoader', + 'ImportGroup', + 'Int', + 'Kind', + 'KpointsData', + 'LinkManager', + 'LinkPair', + 'LinkTriple', + 'List', + 'Log', + 'Node', + 'NodeEntityLoader', + 'NodeLinksManager', + 'NodeRepositoryMixin', + 'NumericType', + 'OrbitalData', + 'OrderSpecifier', + 'OrmEntityLoader', + 'ProcessNode', + 'ProjectionData', + 'QueryBuilder', + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', + 'SinglefileData', + 'Site', + 'Str', + 'StructureData', + 'TrajectoryData', + 'UpfData', + 'UpfFamily', + 'User', + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', + 'XyData', + 'cif_from_ase', + 'find_bandgap', + 'get_loader', + 'get_query_type_from_type_string', + 'get_type_string_from_class', + 'has_pycifrw', + 'load_code', + 'load_computer', + 'load_entity', + 'load_group', + 'load_node', + 'load_node_class', + 'pycifrw_from_cif', + 'to_aiida_type', + 'validate_link', ) + +# yapf: enable diff --git a/aiida/orm/implementation/__init__.py b/aiida/orm/implementation/__init__.py index 8e2f177b1d..332752718a 100644 --- a/aiida/orm/implementation/__init__.py +++ b/aiida/orm/implementation/__init__.py @@ -8,18 +8,48 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with the implementations of the various backend entities for various database backends.""" -# pylint: disable=wildcard-import,undefined-variable + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .authinfos import * from .backends import * from .comments import * from .computers import * +from .entities import * from .groups import * from .logs import * from .nodes import * from .querybuilder import * from .users import * +from .utils import * __all__ = ( - authinfos.__all__ + backends.__all__ + comments.__all__ + computers.__all__ + groups.__all__ + logs.__all__ + - nodes.__all__ + querybuilder.__all__ + users.__all__ + 'Backend', + 'BackendAuthInfo', + 'BackendAuthInfoCollection', + 'BackendCollection', + 'BackendComment', + 'BackendCommentCollection', + 'BackendComputer', + 'BackendComputerCollection', + 'BackendEntity', + 'BackendEntityAttributesMixin', + 'BackendEntityExtrasMixin', + 'BackendGroup', + 'BackendGroupCollection', + 'BackendLog', + 'BackendLogCollection', + 'BackendNode', + 'BackendNodeCollection', + 'BackendQueryBuilder', + 'BackendUser', + 'BackendUserCollection', + 'EntityType', + 'clean_value', + 'validate_attribute_extra_key', ) + +# yapf: enable diff --git a/aiida/orm/implementation/django/__init__.py b/aiida/orm/implementation/django/__init__.py index 2776a55f97..5089f32237 100644 --- a/aiida/orm/implementation/django/__init__.py +++ b/aiida/orm/implementation/django/__init__.py @@ -7,3 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Implementation of Django backend.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .backend import * +from .convert import * +from .groups import * +from .users import * + +__all__ = ( + 'DjangoBackend', + 'DjangoGroup', + 'DjangoGroupCollection', + 'DjangoUser', + 'DjangoUserCollection', + 'get_backend_entity', +) + +# yapf: enable diff --git a/aiida/orm/implementation/django/comments.py b/aiida/orm/implementation/django/comments.py index 1e6f2b0521..7b58c2a8fb 100644 --- a/aiida/orm/implementation/django/comments.py +++ b/aiida/orm/implementation/django/comments.py @@ -114,7 +114,7 @@ def create(self, node, user, content=None, **kwargs): :param content: the comment content :return: a Comment object associated to the given node and user """ - return DjangoComment(self.backend, node, user, content, **kwargs) + return DjangoComment(self.backend, node, user, content, **kwargs) # pylint: disable=abstract-class-instantiated def delete(self, comment_id): """ diff --git a/aiida/orm/implementation/sql/__init__.py b/aiida/orm/implementation/sql/__init__.py index 3cea3705ad..439cd9ba84 100644 --- a/aiida/orm/implementation/sql/__init__.py +++ b/aiida/orm/implementation/sql/__init__.py @@ -12,3 +12,16 @@ All SQL backends with an ORM should subclass from the classes in this module """ + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .backends import * + +__all__ = ( + 'SqlBackend', +) + +# yapf: enable diff --git a/aiida/orm/implementation/sqlalchemy/__init__.py b/aiida/orm/implementation/sqlalchemy/__init__.py index 2776a55f97..82a9691ef1 100644 --- a/aiida/orm/implementation/sqlalchemy/__init__.py +++ b/aiida/orm/implementation/sqlalchemy/__init__.py @@ -7,3 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Implementation of SQLAlchemy backend.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .backend import * +from .convert import * +from .groups import * +from .users import * + +__all__ = ( + 'SqlaBackend', + 'SqlaGroup', + 'SqlaGroupCollection', + 'SqlaUser', + 'SqlaUserCollection', + 'get_backend_entity', +) + +# yapf: enable diff --git a/aiida/orm/implementation/sqlalchemy/comments.py b/aiida/orm/implementation/sqlalchemy/comments.py index d97df9aea7..d8173cd703 100644 --- a/aiida/orm/implementation/sqlalchemy/comments.py +++ b/aiida/orm/implementation/sqlalchemy/comments.py @@ -112,7 +112,7 @@ def create(self, node, user, content=None, **kwargs): :param content: the comment content :return: a Comment object associated to the given node and user """ - return SqlaComment(self.backend, node, user, content, **kwargs) + return SqlaComment(self.backend, node, user, content, **kwargs) # pylint: disable=abstract-class-instantiated def delete(self, comment_id): """ diff --git a/aiida/orm/nodes/__init__.py b/aiida/orm/nodes/__init__.py index b11c562245..d7498592b4 100644 --- a/aiida/orm/nodes/__init__.py +++ b/aiida/orm/nodes/__init__.py @@ -7,11 +7,60 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module with `Node` sub classes for data and processes.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .data import * -from .process import * from .node import * +from .process import * +from .repository import * + +__all__ = ( + 'ArrayData', + 'BandsData', + 'BaseType', + 'Bool', + 'CalcFunctionNode', + 'CalcJobNode', + 'CalculationNode', + 'CifData', + 'Code', + 'Data', + 'Dict', + 'Float', + 'FolderData', + 'Int', + 'Kind', + 'KpointsData', + 'List', + 'Node', + 'NodeRepositoryMixin', + 'NumericType', + 'OrbitalData', + 'ProcessNode', + 'ProjectionData', + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', + 'SinglefileData', + 'Site', + 'Str', + 'StructureData', + 'TrajectoryData', + 'UpfData', + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', + 'XyData', + 'cif_from_ase', + 'find_bandgap', + 'has_pycifrw', + 'pycifrw_from_cif', + 'to_aiida_type', +) -__all__ = (data.__all__ + process.__all__ + node.__all__) +# yapf: enable diff --git a/aiida/orm/nodes/data/__init__.py b/aiida/orm/nodes/data/__init__.py index 8ed0d10aa4..31292c91eb 100644 --- a/aiida/orm/nodes/data/__init__.py +++ b/aiida/orm/nodes/data/__init__.py @@ -8,27 +8,64 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub classes for data structures.""" -from .array import ArrayData, BandsData, KpointsData, ProjectionData, TrajectoryData, XyData -from .base import BaseType, to_aiida_type -from .bool import Bool -from .cif import CifData -from .code import Code -from .data import Data -from .dict import Dict -from .float import Float -from .folder import FolderData -from .int import Int -from .list import List -from .numeric import NumericType -from .orbital import OrbitalData -from .remote import RemoteData, RemoteStashData, RemoteStashFolderData -from .singlefile import SinglefileData -from .str import Str -from .structure import StructureData -from .upf import UpfData + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .array import * +from .base import * +from .bool import * +from .cif import * +from .code import * +from .data import * +from .dict import * +from .float import * +from .folder import * +from .int import * +from .list import * +from .numeric import * +from .orbital import * +from .remote import * +from .singlefile import * +from .str import * +from .structure import * +from .upf import * __all__ = ( - 'Data', 'BaseType', 'ArrayData', 'BandsData', 'KpointsData', 'ProjectionData', 'TrajectoryData', 'XyData', 'Bool', - 'CifData', 'Code', 'Float', 'FolderData', 'Int', 'List', 'OrbitalData', 'Dict', 'RemoteData', 'RemoteStashData', - 'RemoteStashFolderData', 'SinglefileData', 'Str', 'StructureData', 'UpfData', 'NumericType', 'to_aiida_type' + 'ArrayData', + 'BandsData', + 'BaseType', + 'Bool', + 'CifData', + 'Code', + 'Data', + 'Dict', + 'Float', + 'FolderData', + 'Int', + 'Kind', + 'KpointsData', + 'List', + 'NumericType', + 'OrbitalData', + 'ProjectionData', + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', + 'SinglefileData', + 'Site', + 'Str', + 'StructureData', + 'TrajectoryData', + 'UpfData', + 'XyData', + 'cif_from_ase', + 'find_bandgap', + 'has_pycifrw', + 'pycifrw_from_cif', + 'to_aiida_type', ) + +# yapf: enable diff --git a/aiida/orm/nodes/data/array/__init__.py b/aiida/orm/nodes/data/array/__init__.py index d34d6ad52a..f12feedfbe 100644 --- a/aiida/orm/nodes/data/array/__init__.py +++ b/aiida/orm/nodes/data/array/__init__.py @@ -9,11 +9,26 @@ ########################################################################### """Module with `Node` sub classes for array based data structures.""" -from .array import ArrayData -from .bands import BandsData -from .kpoints import KpointsData -from .projection import ProjectionData -from .trajectory import TrajectoryData -from .xy import XyData +# AUTO-GENERATED -__all__ = ('ArrayData', 'BandsData', 'KpointsData', 'ProjectionData', 'TrajectoryData', 'XyData') +# yapf: disable +# pylint: disable=wildcard-import + +from .array import * +from .bands import * +from .kpoints import * +from .projection import * +from .trajectory import * +from .xy import * + +__all__ = ( + 'ArrayData', + 'BandsData', + 'KpointsData', + 'ProjectionData', + 'TrajectoryData', + 'XyData', + 'find_bandgap', +) + +# yapf: enable diff --git a/aiida/orm/nodes/data/array/array.py b/aiida/orm/nodes/data/array/array.py index 3d1553e4c5..98fe87f8d8 100644 --- a/aiida/orm/nodes/data/array/array.py +++ b/aiida/orm/nodes/data/array/array.py @@ -10,9 +10,10 @@ """ AiiDA ORM data class storing (numpy) arrays """ - from ..data import Data +__all__ = ('ArrayData',) + class ArrayData(Data): """ diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 0559017a4f..4a0a40c590 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -20,6 +20,8 @@ from aiida.common.utils import prettify_labels, join_labels from .kpoints import KpointsData +__all__ = ('BandsData', 'find_bandgap') + def prepare_header_comment(uuid, plot_info, comment_char='#'): """Prepare the header.""" diff --git a/aiida/orm/nodes/data/array/kpoints.py b/aiida/orm/nodes/data/array/kpoints.py index a3aa1630d0..c0957516ab 100644 --- a/aiida/orm/nodes/data/array/kpoints.py +++ b/aiida/orm/nodes/data/array/kpoints.py @@ -16,6 +16,8 @@ from .array import ArrayData +__all__ = ('KpointsData',) + _DEFAULT_EPSILON_LENGTH = 1e-5 _DEFAULT_EPSILON_ANGLE = 1e-5 diff --git a/aiida/orm/nodes/data/array/projection.py b/aiida/orm/nodes/data/array/projection.py index 87b1f8bd08..a27205743c 100644 --- a/aiida/orm/nodes/data/array/projection.py +++ b/aiida/orm/nodes/data/array/projection.py @@ -18,6 +18,8 @@ from .array import ArrayData from .bands import BandsData +__all__ = ('ProjectionData',) + class ProjectionData(OrbitalData, ArrayData): """ diff --git a/aiida/orm/nodes/data/array/trajectory.py b/aiida/orm/nodes/data/array/trajectory.py index c0a0ac2485..575c94133b 100644 --- a/aiida/orm/nodes/data/array/trajectory.py +++ b/aiida/orm/nodes/data/array/trajectory.py @@ -15,6 +15,8 @@ from .array import ArrayData +__all__ = ('TrajectoryData',) + class TrajectoryData(ArrayData): """ diff --git a/aiida/orm/nodes/data/array/xy.py b/aiida/orm/nodes/data/array/xy.py index 48e3a47ae1..3a253074f4 100644 --- a/aiida/orm/nodes/data/array/xy.py +++ b/aiida/orm/nodes/data/array/xy.py @@ -16,6 +16,8 @@ from aiida.common.exceptions import NotExistent from .array import ArrayData +__all__ = ('XyData',) + def check_convert_single_to_tuple(item): """ diff --git a/aiida/orm/nodes/data/cif.py b/aiida/orm/nodes/data/cif.py index 53ad65c098..e5e5339269 100644 --- a/aiida/orm/nodes/data/cif.py +++ b/aiida/orm/nodes/data/cif.py @@ -15,6 +15,8 @@ from .singlefile import SinglefileData +__all__ = ('CifData', 'cif_from_ase', 'has_pycifrw', 'pycifrw_from_cif') + ase_loops = { '_atom_site': [ '_atom_site_label', diff --git a/aiida/orm/nodes/data/remote/__init__.py b/aiida/orm/nodes/data/remote/__init__.py index 2f88d7edbc..ae1b5dbc4f 100644 --- a/aiida/orm/nodes/data/remote/__init__.py +++ b/aiida/orm/nodes/data/remote/__init__.py @@ -1,6 +1,18 @@ # -*- coding: utf-8 -*- """Module with data plugins that represent remote resources and so effectively are symbolic links.""" -from .base import RemoteData -from .stash import RemoteStashData, RemoteStashFolderData -__all__ = ('RemoteData', 'RemoteStashData', 'RemoteStashFolderData') +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .base import * +from .stash import * + +__all__ = ( + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', +) + +# yapf: enable diff --git a/aiida/orm/nodes/data/remote/stash/__init__.py b/aiida/orm/nodes/data/remote/stash/__init__.py index f744240cfc..e06481e842 100644 --- a/aiida/orm/nodes/data/remote/stash/__init__.py +++ b/aiida/orm/nodes/data/remote/stash/__init__.py @@ -1,6 +1,17 @@ # -*- coding: utf-8 -*- """Module with data plugins that represent files of completed calculations jobs that have been stashed.""" -from .base import RemoteStashData -from .folder import RemoteStashFolderData -__all__ = ('RemoteStashData', 'RemoteStashFolderData') +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .base import * +from .folder import * + +__all__ = ( + 'RemoteStashData', + 'RemoteStashFolderData', +) + +# yapf: enable diff --git a/aiida/orm/nodes/process/__init__.py b/aiida/orm/nodes/process/__init__.py index 4a84f892b0..283b14e9b0 100644 --- a/aiida/orm/nodes/process/__init__.py +++ b/aiida/orm/nodes/process/__init__.py @@ -7,11 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module with `Node` sub classes for processes.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .calculation import * from .process import * from .workflow import * -__all__ = (calculation.__all__ + process.__all__ + workflow.__all__) # type: ignore[name-defined] +__all__ = ( + 'CalcFunctionNode', + 'CalcJobNode', + 'CalculationNode', + 'ProcessNode', + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', +) + +# yapf: enable diff --git a/aiida/orm/nodes/process/calculation/__init__.py b/aiida/orm/nodes/process/calculation/__init__.py index 4d6232ba92..21af4e576e 100644 --- a/aiida/orm/nodes/process/calculation/__init__.py +++ b/aiida/orm/nodes/process/calculation/__init__.py @@ -9,8 +9,19 @@ ########################################################################### """Module with `Node` sub classes for calculation processes.""" -from .calculation import CalculationNode -from .calcfunction import CalcFunctionNode -from .calcjob import CalcJobNode +# AUTO-GENERATED -__all__ = ('CalculationNode', 'CalcFunctionNode', 'CalcJobNode') +# yapf: disable +# pylint: disable=wildcard-import + +from .calcfunction import * +from .calcjob import * +from .calculation import * + +__all__ = ( + 'CalcFunctionNode', + 'CalcJobNode', + 'CalculationNode', +) + +# yapf: enable diff --git a/aiida/orm/nodes/process/workflow/__init__.py b/aiida/orm/nodes/process/workflow/__init__.py index b4f210da6f..f4125a4f8f 100644 --- a/aiida/orm/nodes/process/workflow/__init__.py +++ b/aiida/orm/nodes/process/workflow/__init__.py @@ -9,8 +9,19 @@ ########################################################################### """Module with `Node` sub classes for workflow processes.""" -from .workflow import WorkflowNode -from .workchain import WorkChainNode -from .workfunction import WorkFunctionNode +# AUTO-GENERATED -__all__ = ('WorkflowNode', 'WorkChainNode', 'WorkFunctionNode') +# yapf: disable +# pylint: disable=wildcard-import + +from .workchain import * +from .workflow import * +from .workfunction import * + +__all__ = ( + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', +) + +# yapf: enable diff --git a/aiida/orm/utils/__init__.py b/aiida/orm/utils/__init__.py index f703884d0e..16e7b146c1 100644 --- a/aiida/orm/utils/__init__.py +++ b/aiida/orm/utils/__init__.py @@ -9,197 +9,41 @@ ########################################################################### """Utilities related to the ORM.""" -__all__ = ('load_code', 'load_computer', 'load_group', 'load_node') - - -def load_entity( - entity_loader=None, identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True -): - # pylint: disable=too-many-arguments - """ - Load an entity instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Code - :param pk: pk of a Code - :param uuid: uuid of a Code, or the beginning of the uuid - :param label: label of a Code - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :returns: the Code instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Code is found - :raise aiida.common.MultipleObjectsError: if more than one Code was found - """ - from aiida.orm.utils.loaders import OrmEntityLoader, IdentifierType - - if entity_loader is None or not issubclass(entity_loader, OrmEntityLoader): - raise TypeError(f'entity_loader should be a sub class of {type(OrmEntityLoader)}') - - inputs_provided = [value is not None for value in (identifier, pk, uuid, label)].count(True) - - if inputs_provided == 0: - raise ValueError("one of the parameters 'identifier', pk', 'uuid' or 'label' has to be specified") - elif inputs_provided > 1: - raise ValueError("only one of parameters 'identifier', pk', 'uuid' or 'label' has to be specified") - - if pk is not None: - - if not isinstance(pk, int): - raise TypeError('a pk has to be an integer') - - identifier = pk - identifier_type = IdentifierType.ID - - elif uuid is not None: - - if not isinstance(uuid, str): - raise TypeError('uuid has to be a string type') - - identifier = uuid - identifier_type = IdentifierType.UUID - - elif label is not None: - - if not isinstance(label, str): - raise TypeError('label has to be a string type') - - identifier = label - identifier_type = IdentifierType.LABEL - else: - identifier = str(identifier) - identifier_type = None - - return entity_loader.load_entity( - identifier, identifier_type, sub_classes=sub_classes, query_with_dashes=query_with_dashes - ) - - -def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Code instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Code - :param pk: pk of a Code - :param uuid: uuid of a Code, or the beginning of the uuid - :param label: label of a Code - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Code instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Code is found - :raise aiida.common.MultipleObjectsError: if more than one Code was found - """ - from aiida.orm.utils.loaders import CodeEntityLoader - return load_entity( - CodeEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_computer(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Computer instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Computer - :param pk: pk of a Computer - :param uuid: uuid of a Computer, or the beginning of the uuid - :param label: label of a Computer - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Computer instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Computer is found - :raise aiida.common.MultipleObjectsError: if more than one Computer was found - """ - from aiida.orm.utils.loaders import ComputerEntityLoader - return load_entity( - ComputerEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Group instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Group - :param pk: pk of a Group - :param uuid: uuid of a Group, or the beginning of the uuid - :param label: label of a Group - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Group instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Group is found - :raise aiida.common.MultipleObjectsError: if more than one Group was found - """ - from aiida.orm.utils.loaders import GroupEntityLoader - return load_entity( - GroupEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown - simply pass it without a keyword and the loader will attempt to infer the type - - :param identifier: pk (integer) or uuid (string) - :param pk: pk of a node - :param uuid: uuid of a node, or the beginning of the uuid - :param label: label of a Node - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :returns: the node instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Node is found - :raise aiida.common.MultipleObjectsError: if more than one Node was found - """ - from aiida.orm.utils.loaders import NodeEntityLoader - return load_entity( - NodeEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .calcjob import * +from .links import * +from .loaders import * +from .managers import * +from .node import * + +__all__ = ( + 'AbstractNodeMeta', + 'AttributeManager', + 'CalcJobResultManager', + 'CalculationEntityLoader', + 'CodeEntityLoader', + 'ComputerEntityLoader', + 'GroupEntityLoader', + 'LinkManager', + 'LinkPair', + 'LinkTriple', + 'NodeEntityLoader', + 'NodeLinksManager', + 'OrmEntityLoader', + 'get_loader', + 'get_query_type_from_type_string', + 'get_type_string_from_class', + 'load_code', + 'load_computer', + 'load_entity', + 'load_group', + 'load_node', + 'load_node_class', + 'validate_link', +) + +# yapf: enable diff --git a/aiida/orm/utils/loaders.py b/aiida/orm/utils/loaders.py index 0773a7a48c..66d59ad187 100644 --- a/aiida/orm/utils/loaders.py +++ b/aiida/orm/utils/loaders.py @@ -16,11 +16,198 @@ from aiida.orm.querybuilder import QueryBuilder __all__ = ( - 'get_loader', 'OrmEntityLoader', 'CalculationEntityLoader', 'CodeEntityLoader', 'ComputerEntityLoader', - 'GroupEntityLoader', 'NodeEntityLoader' + 'load_code', 'load_computer', 'load_group', 'load_node', 'load_entity', 'get_loader', 'OrmEntityLoader', + 'CalculationEntityLoader', 'CodeEntityLoader', 'ComputerEntityLoader', 'GroupEntityLoader', 'NodeEntityLoader' ) +def load_entity( + entity_loader=None, identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True +): + # pylint: disable=too-many-arguments + """ + Load an entity instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Code + :param pk: pk of a Code + :param uuid: uuid of a Code, or the beginning of the uuid + :param label: label of a Code + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :returns: the Code instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Code is found + :raise aiida.common.MultipleObjectsError: if more than one Code was found + """ + if entity_loader is None or not issubclass(entity_loader, OrmEntityLoader): + raise TypeError(f'entity_loader should be a sub class of {type(OrmEntityLoader)}') + + inputs_provided = [value is not None for value in (identifier, pk, uuid, label)].count(True) + + if inputs_provided == 0: + raise ValueError("one of the parameters 'identifier', pk', 'uuid' or 'label' has to be specified") + elif inputs_provided > 1: + raise ValueError("only one of parameters 'identifier', pk', 'uuid' or 'label' has to be specified") + + if pk is not None: + + if not isinstance(pk, int): + raise TypeError('a pk has to be an integer') + + identifier = pk + identifier_type = IdentifierType.ID + + elif uuid is not None: + + if not isinstance(uuid, str): + raise TypeError('uuid has to be a string type') + + identifier = uuid + identifier_type = IdentifierType.UUID + + elif label is not None: + + if not isinstance(label, str): + raise TypeError('label has to be a string type') + + identifier = label + identifier_type = IdentifierType.LABEL + else: + identifier = str(identifier) + identifier_type = None + + return entity_loader.load_entity( + identifier, identifier_type, sub_classes=sub_classes, query_with_dashes=query_with_dashes + ) + + +def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a Code instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Code + :param pk: pk of a Code + :param uuid: uuid of a Code, or the beginning of the uuid + :param label: label of a Code + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Code instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Code is found + :raise aiida.common.MultipleObjectsError: if more than one Code was found + """ + return load_entity( + CodeEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_computer(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a Computer instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Computer + :param pk: pk of a Computer + :param uuid: uuid of a Computer, or the beginning of the uuid + :param label: label of a Computer + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Computer instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Computer is found + :raise aiida.common.MultipleObjectsError: if more than one Computer was found + """ + return load_entity( + ComputerEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a Group instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Group + :param pk: pk of a Group + :param uuid: uuid of a Group, or the beginning of the uuid + :param label: label of a Group + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Group instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Group is found + :raise aiida.common.MultipleObjectsError: if more than one Group was found + """ + return load_entity( + GroupEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): + """ + Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown + simply pass it without a keyword and the loader will attempt to infer the type + + :param identifier: pk (integer) or uuid (string) + :param pk: pk of a node + :param uuid: uuid of a node, or the beginning of the uuid + :param label: label of a Node + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :returns: the node instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Node is found + :raise aiida.common.MultipleObjectsError: if more than one Node was found + """ + return load_entity( + NodeEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + def get_loader(orm_class): """Return the correct OrmEntityLoader for the given orm class. diff --git a/aiida/parsers/__init__.py b/aiida/parsers/__init__.py index 5f6ee399c0..b3789ed596 100644 --- a/aiida/parsers/__init__.py +++ b/aiida/parsers/__init__.py @@ -7,9 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for classes and utilities to write parsers for calculation jobs.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .parser import * -__all__ = (parser.__all__) +__all__ = ( + 'Parser', +) + +# yapf: enable diff --git a/aiida/plugins/__init__.py b/aiida/plugins/__init__.py index a169084f36..14a89108c0 100644 --- a/aiida/plugins/__init__.py +++ b/aiida/plugins/__init__.py @@ -7,10 +7,31 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Classes and functions to load and interact with plugin classes accessible through defined entry points.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .entry_point import * from .factories import * +from .utils import * + +__all__ = ( + 'BaseFactory', + 'CalculationFactory', + 'DataFactory', + 'DbImporterFactory', + 'GroupFactory', + 'OrbitalFactory', + 'ParserFactory', + 'PluginVersionProvider', + 'SchedulerFactory', + 'TransportFactory', + 'WorkflowFactory', + 'load_entry_point', + 'load_entry_point_from_string', +) -__all__ = (entry_point.__all__ + factories.__all__) +# yapf: enable diff --git a/aiida/repository/__init__.py b/aiida/repository/__init__.py index 6c71f4dbaa..c828ca07f1 100644 --- a/aiida/repository/__init__.py +++ b/aiida/repository/__init__.py @@ -8,9 +8,23 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with resources dealing with the file repository.""" -# pylint: disable=undefined-variable + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .backend import * from .common import * from .repository import * -__all__ = (backend.__all__ + common.__all__ + repository.__all__) # type: ignore[name-defined] +__all__ = ( + 'AbstractRepositoryBackend', + 'DiskObjectStoreRepositoryBackend', + 'File', + 'FileType', + 'Repository', + 'SandboxRepositoryBackend', +) + +# yapf: enable diff --git a/aiida/repository/backend/__init__.py b/aiida/repository/backend/__init__.py index 20a704f865..ea4ab3386f 100644 --- a/aiida/repository/backend/__init__.py +++ b/aiida/repository/backend/__init__.py @@ -1,8 +1,19 @@ # -*- coding: utf-8 -*- -# pylint: disable=undefined-variable """Module for file repository backend implementations.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .abstract import * from .disk_object_store import * from .sandbox import * -__all__ = (abstract.__all__ + disk_object_store.__all__ + sandbox.__all__) # type: ignore[name-defined] +__all__ = ( + 'AbstractRepositoryBackend', + 'DiskObjectStoreRepositoryBackend', + 'SandboxRepositoryBackend', +) + +# yapf: enable diff --git a/aiida/restapi/__init__.py b/aiida/restapi/__init__.py index fc199853df..5cdd575a4a 100644 --- a/aiida/restapi/__init__.py +++ b/aiida/restapi/__init__.py @@ -12,3 +12,7 @@ AiiDA nodes stored in database. The REST API is implemented using Flask RESTFul framework. """ + +# AUTO-GENERATED + +__all__ = () diff --git a/aiida/schedulers/__init__.py b/aiida/schedulers/__init__.py index 2dd0db40f8..5fad6ad78f 100644 --- a/aiida/schedulers/__init__.py +++ b/aiida/schedulers/__init__.py @@ -7,10 +7,27 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for classes and utilities to interact with cluster schedulers.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .datastructures import * from .scheduler import * -__all__ = (datastructures.__all__ + scheduler.__all__) +__all__ = ( + 'JobInfo', + 'JobResource', + 'JobState', + 'JobTemplate', + 'MachineInfo', + 'NodeNumberJobResource', + 'ParEnvJobResource', + 'Scheduler', + 'SchedulerError', + 'SchedulerParsingError', +) + +# yapf: enable diff --git a/aiida/tools/__init__.py b/aiida/tools/__init__.py index ffdf77d6e5..cb4615adb7 100644 --- a/aiida/tools/__init__.py +++ b/aiida/tools/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin """ Tools to operate on AiiDA ORM class instances @@ -21,12 +20,79 @@ """ +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .calculations import * -from .data.array.kpoints import * -from .data.structure import * -from .dbimporters import * +from .data import * from .graph import * +from .groups import * +from .importexport import * +from .visualization import * __all__ = ( - calculations.__all__ + data.array.kpoints.__all__ + data.structure.__all__ + dbimporters.__all__ + graph.__all__ + 'ARCHIVE_READER_LOGGER', + 'ArchiveExportError', + 'ArchiveImportError', + 'ArchiveMetadata', + 'ArchiveMigrationError', + 'ArchiveMigratorAbstract', + 'ArchiveMigratorJsonBase', + 'ArchiveMigratorJsonTar', + 'ArchiveMigratorJsonZip', + 'ArchiveReaderAbstract', + 'ArchiveWriterAbstract', + 'CacheFolder', + 'CalculationTools', + 'CorruptArchive', + 'DELETE_LOGGER', + 'DanglingLinkError', + 'EXPORT_LOGGER', + 'EXPORT_VERSION', + 'ExportFileFormat', + 'ExportImportException', + 'ExportValidationError', + 'Graph', + 'GroupNotFoundError', + 'GroupNotUniqueError', + 'GroupPath', + 'IMPORT_LOGGER', + 'ImportUniquenessError', + 'ImportValidationError', + 'IncompatibleArchiveVersionError', + 'InvalidPath', + 'MIGRATE_LOGGER', + 'MigrationValidationError', + 'NoGroupsInPathError', + 'Orbital', + 'ProgressBarError', + 'ReaderJsonBase', + 'ReaderJsonFolder', + 'ReaderJsonTar', + 'ReaderJsonZip', + 'RealhydrogenOrbital', + 'WriterJsonFolder', + 'WriterJsonTar', + 'WriterJsonZip', + 'default_link_styles', + 'default_node_styles', + 'default_node_sublabels', + 'delete_group_nodes', + 'delete_nodes', + 'detect_archive_type', + 'export', + 'get_explicit_kpoints_path', + 'get_kpoints_path', + 'get_migrator', + 'get_reader', + 'get_writer', + 'import_data', + 'null_callback', + 'pstate_node_styles', + 'spglib_tuple_to_structure', + 'structure_to_spglib_tuple', ) + +# yapf: enable diff --git a/aiida/tools/calculations/__init__.py b/aiida/tools/calculations/__init__.py index 34a0745e5f..7fc43df3e3 100644 --- a/aiida/tools/calculations/__init__.py +++ b/aiida/tools/calculations/__init__.py @@ -7,9 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Calculation tool plugins for Calculation classes.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .base import * -__all__ = (base.__all__) +__all__ = ( + 'CalculationTools', +) + +# yapf: enable diff --git a/aiida/tools/data/__init__.py b/aiida/tools/data/__init__.py index 2776a55f97..fdf843ae12 100644 --- a/aiida/tools/data/__init__.py +++ b/aiida/tools/data/__init__.py @@ -7,3 +7,24 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tool for handling data.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .array import * +from .orbital import * +from .structure import * + +__all__ = ( + 'Orbital', + 'RealhydrogenOrbital', + 'get_explicit_kpoints_path', + 'get_kpoints_path', + 'spglib_tuple_to_structure', + 'structure_to_spglib_tuple', +) + +# yapf: enable diff --git a/aiida/tools/data/array/__init__.py b/aiida/tools/data/array/__init__.py index 2776a55f97..ebb95e693f 100644 --- a/aiida/tools/data/array/__init__.py +++ b/aiida/tools/data/array/__init__.py @@ -7,3 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tools for manipulating array data classes.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .kpoints import * + +__all__ = ( + 'get_explicit_kpoints_path', + 'get_kpoints_path', +) + +# yapf: enable diff --git a/aiida/tools/data/array/kpoints/__init__.py b/aiida/tools/data/array/kpoints/__init__.py index 59c40e53f6..ac536c11a9 100644 --- a/aiida/tools/data/array/kpoints/__init__.py +++ b/aiida/tools/data/array/kpoints/__init__.py @@ -11,231 +11,17 @@ Various utilities to deal with KpointsData instances or create new ones (e.g. band paths, kpoints from a parsed input text file, ...) """ -from aiida.orm import KpointsData, Dict -__all__ = ('get_kpoints_path', 'get_explicit_kpoints_path') +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -def get_kpoints_path(structure, method='seekpath', **kwargs): - """ - Returns a dictionary whose contents depend on the method but includes at least the following keys +from .main import * - * parameters: Dict node +__all__ = ( + 'get_explicit_kpoints_path', + 'get_kpoints_path', +) - The contents of the parameters depends on the method but contains at least the keys - - * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] - * 'path': a list of length-2 tuples, with the labels of the starting - and ending point of each label section - - The 'seekpath' method which is the default also returns the following additional nodes - - * primitive_structure: StructureData with the primitive cell - * conv_structure: StructureData with the conventional cell - - Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure - and not on the input structure that was provided - - :param structure: a StructureData node - :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. - It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have - bugs for certain structure cells - :param kwargs: optional keyword arguments that depend on the selected method - :returns: dictionary as described above in the docstring - """ - if method not in _GET_KPOINTS_PATH_METHODS.keys(): - raise ValueError(f"the method '{method}' is not implemented") - - method = _GET_KPOINTS_PATH_METHODS[method] - - return method(structure, **kwargs) - - -def get_explicit_kpoints_path(structure, method='seekpath', **kwargs): - """ - Returns a dictionary whose contents depend on the method but includes at least the following keys - - * parameters: Dict node - * explicit_kpoints: KpointsData node with explicit kpoints path - - The contents of the parameters depends on the method but contains at least the keys - - * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] - * 'path': a list of length-2 tuples, with the labels of the starting - and ending point of each label section - - The 'seekpath' method which is the default also returns the following additional nodes - - * primitive_structure: StructureData with the primitive cell - * conv_structure: StructureData with the conventional cell - - Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure - and not on the input structure that was provided - - :param structure: a StructureData node - :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. - It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have - bugs for certain structure cells - :param kwargs: optional keyword arguments that depend on the selected method - :returns: dictionary as described above in the docstring - """ - if method not in _GET_EXPLICIT_KPOINTS_PATH_METHODS.keys(): - raise ValueError(f"the method '{method}' is not implemented") - - method = _GET_EXPLICIT_KPOINTS_PATH_METHODS[method] - - return method(structure, **kwargs) - - -def _seekpath_get_kpoints_path(structure, **kwargs): - """ - Call the get_kpoints_path wrapper function for Seekpath - - :param structure: a StructureData node - :param with_time_reversal: if False, and the group has no inversion - symmetry, additional lines are returned - :param recipe: choose the reference publication that defines the special points and paths. - Currently, the following value is implemented: - - - ``hpkot``: HPKOT paper: - Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure - diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). - DOI: 10.1016/j.commatsci.2016.10.015 - :param threshold: the threshold to use to verify if we are in - and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, - in the tI lattice, if ``abs(a-c) < threshold``, a - :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. - Note that depending on the bravais lattice, the meaning of the - threshold is different (angle, length, ...) - :param symprec: the symmetry precision used internally by SPGLIB - :param angle_tolerance: the angle_tolerance used internally by SPGLIB - """ - from aiida.tools.data.array.kpoints import seekpath - - assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' - - recognized_args = ['with_time_reversal', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] - unknown_args = set(kwargs).difference(recognized_args) - - if unknown_args: - raise ValueError(f'unknown arguments {unknown_args}') - - return seekpath.get_kpoints_path(structure, kwargs) - - -def _seekpath_get_explicit_kpoints_path(structure, **kwargs): - """ - Call the get_explicit_kpoints_path wrapper function for Seekpath - - :param structure: a StructureData node - :param with_time_reversal: if False, and the group has no inversion - symmetry, additional lines are returned - :param reference_distance: a reference target distance between neighboring - k-points in the path, in units of 1/ang. The actual value will be as - close as possible to this value, to have an integer number of points in - each path - :param recipe: choose the reference publication that defines the special points and paths. - Currently, the following value is implemented: - - - ``hpkot``: HPKOT paper: - Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure - diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). - DOI: 10.1016/j.commatsci.2016.10.015 - :param threshold: the threshold to use to verify if we are in - and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, - in the tI lattice, if ``abs(a-c) < threshold``, a - :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. - Note that depending on the bravais lattice, the meaning of the - threshold is different (angle, length, ...) - :param symprec: the symmetry precision used internally by SPGLIB - :param angle_tolerance: the angle_tolerance used internally by SPGLIB - """ - from aiida.tools.data.array.kpoints import seekpath - - assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' - - recognized_args = ['with_time_reversal', 'reference_distance', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] - unknown_args = set(kwargs).difference(recognized_args) - - if unknown_args: - raise ValueError(f'unknown arguments {unknown_args}') - - return seekpath.get_explicit_kpoints_path(structure, kwargs) - - -def _legacy_get_kpoints_path(structure, **kwargs): - """ - Call the get_kpoints_path of the legacy implementation - - :param structure: a StructureData node - :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates - :param epsilon_length: threshold on lengths comparison, used to get the bravais lattice info - :param epsilon_angle: threshold on angles comparison, used to get the bravais lattice info - """ - from aiida.tools.data.array.kpoints import legacy - - args_recognized = ['cartesian', 'epsilon_length', 'epsilon_angle'] - args_unknown = set(kwargs).difference(args_recognized) - - if args_unknown: - raise ValueError(f'unknown arguments {args_unknown}') - - point_coords, path, bravais_info = legacy.get_kpoints_path(cell=structure.cell, pbc=structure.pbc, **kwargs) - - parameters = { - 'bravais_info': bravais_info, - 'point_coords': point_coords, - 'path': path, - } - - return {'parameters': Dict(dict=parameters)} - - -def _legacy_get_explicit_kpoints_path(structure, **kwargs): - """ - Call the get_explicit_kpoints_path of the legacy implementation - - :param structure: a StructureData node - :param float kpoint_distance: parameter controlling the distance between kpoints. Distance is - given in crystal coordinates, i.e. the distance is computed in the space of b1, b2, b3. - The distance set will be the closest possible to this value, compatible with the requirement - of putting equispaced points between two special points (since extrema are included). - :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates - :param float epsilon_length: threshold on lengths comparison, used to get the bravais lattice info - :param float epsilon_angle: threshold on angles comparison, used to get the bravais lattice info - """ - from aiida.tools.data.array.kpoints import legacy - - args_recognized = ['value', 'kpoint_distance', 'cartesian', 'epsilon_length', 'epsilon_angle'] - args_unknown = set(kwargs).difference(args_recognized) - - if args_unknown: - raise ValueError(f'unknown arguments {args_unknown}') - - point_coords, path, bravais_info, explicit_kpoints, labels = legacy.get_explicit_kpoints_path( # pylint: disable=unbalanced-tuple-unpacking - cell=structure.cell, pbc=structure.pbc, **kwargs - ) - - kpoints = KpointsData() - kpoints.set_cell(structure.cell) - kpoints.set_kpoints(explicit_kpoints) - kpoints.labels = labels - - parameters = { - 'bravais_info': bravais_info, - 'point_coords': point_coords, - 'path': path, - } - - return {'parameters': Dict(dict=parameters), 'explicit_kpoints': kpoints} - - -_GET_KPOINTS_PATH_METHODS = { - 'legacy': _legacy_get_kpoints_path, - 'seekpath': _seekpath_get_kpoints_path, -} - -_GET_EXPLICIT_KPOINTS_PATH_METHODS = { - 'legacy': _legacy_get_explicit_kpoints_path, - 'seekpath': _seekpath_get_explicit_kpoints_path, -} +# yapf: enable diff --git a/aiida/tools/data/array/kpoints/main.py b/aiida/tools/data/array/kpoints/main.py new file mode 100644 index 0000000000..59c40e53f6 --- /dev/null +++ b/aiida/tools/data/array/kpoints/main.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +""" +Various utilities to deal with KpointsData instances or create new ones +(e.g. band paths, kpoints from a parsed input text file, ...) +""" +from aiida.orm import KpointsData, Dict + +__all__ = ('get_kpoints_path', 'get_explicit_kpoints_path') + + +def get_kpoints_path(structure, method='seekpath', **kwargs): + """ + Returns a dictionary whose contents depend on the method but includes at least the following keys + + * parameters: Dict node + + The contents of the parameters depends on the method but contains at least the keys + + * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] + * 'path': a list of length-2 tuples, with the labels of the starting + and ending point of each label section + + The 'seekpath' method which is the default also returns the following additional nodes + + * primitive_structure: StructureData with the primitive cell + * conv_structure: StructureData with the conventional cell + + Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure + and not on the input structure that was provided + + :param structure: a StructureData node + :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. + It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have + bugs for certain structure cells + :param kwargs: optional keyword arguments that depend on the selected method + :returns: dictionary as described above in the docstring + """ + if method not in _GET_KPOINTS_PATH_METHODS.keys(): + raise ValueError(f"the method '{method}' is not implemented") + + method = _GET_KPOINTS_PATH_METHODS[method] + + return method(structure, **kwargs) + + +def get_explicit_kpoints_path(structure, method='seekpath', **kwargs): + """ + Returns a dictionary whose contents depend on the method but includes at least the following keys + + * parameters: Dict node + * explicit_kpoints: KpointsData node with explicit kpoints path + + The contents of the parameters depends on the method but contains at least the keys + + * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] + * 'path': a list of length-2 tuples, with the labels of the starting + and ending point of each label section + + The 'seekpath' method which is the default also returns the following additional nodes + + * primitive_structure: StructureData with the primitive cell + * conv_structure: StructureData with the conventional cell + + Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure + and not on the input structure that was provided + + :param structure: a StructureData node + :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. + It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have + bugs for certain structure cells + :param kwargs: optional keyword arguments that depend on the selected method + :returns: dictionary as described above in the docstring + """ + if method not in _GET_EXPLICIT_KPOINTS_PATH_METHODS.keys(): + raise ValueError(f"the method '{method}' is not implemented") + + method = _GET_EXPLICIT_KPOINTS_PATH_METHODS[method] + + return method(structure, **kwargs) + + +def _seekpath_get_kpoints_path(structure, **kwargs): + """ + Call the get_kpoints_path wrapper function for Seekpath + + :param structure: a StructureData node + :param with_time_reversal: if False, and the group has no inversion + symmetry, additional lines are returned + :param recipe: choose the reference publication that defines the special points and paths. + Currently, the following value is implemented: + + - ``hpkot``: HPKOT paper: + Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure + diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). + DOI: 10.1016/j.commatsci.2016.10.015 + :param threshold: the threshold to use to verify if we are in + and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, + in the tI lattice, if ``abs(a-c) < threshold``, a + :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. + Note that depending on the bravais lattice, the meaning of the + threshold is different (angle, length, ...) + :param symprec: the symmetry precision used internally by SPGLIB + :param angle_tolerance: the angle_tolerance used internally by SPGLIB + """ + from aiida.tools.data.array.kpoints import seekpath + + assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' + + recognized_args = ['with_time_reversal', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] + unknown_args = set(kwargs).difference(recognized_args) + + if unknown_args: + raise ValueError(f'unknown arguments {unknown_args}') + + return seekpath.get_kpoints_path(structure, kwargs) + + +def _seekpath_get_explicit_kpoints_path(structure, **kwargs): + """ + Call the get_explicit_kpoints_path wrapper function for Seekpath + + :param structure: a StructureData node + :param with_time_reversal: if False, and the group has no inversion + symmetry, additional lines are returned + :param reference_distance: a reference target distance between neighboring + k-points in the path, in units of 1/ang. The actual value will be as + close as possible to this value, to have an integer number of points in + each path + :param recipe: choose the reference publication that defines the special points and paths. + Currently, the following value is implemented: + + - ``hpkot``: HPKOT paper: + Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure + diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). + DOI: 10.1016/j.commatsci.2016.10.015 + :param threshold: the threshold to use to verify if we are in + and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, + in the tI lattice, if ``abs(a-c) < threshold``, a + :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. + Note that depending on the bravais lattice, the meaning of the + threshold is different (angle, length, ...) + :param symprec: the symmetry precision used internally by SPGLIB + :param angle_tolerance: the angle_tolerance used internally by SPGLIB + """ + from aiida.tools.data.array.kpoints import seekpath + + assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' + + recognized_args = ['with_time_reversal', 'reference_distance', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] + unknown_args = set(kwargs).difference(recognized_args) + + if unknown_args: + raise ValueError(f'unknown arguments {unknown_args}') + + return seekpath.get_explicit_kpoints_path(structure, kwargs) + + +def _legacy_get_kpoints_path(structure, **kwargs): + """ + Call the get_kpoints_path of the legacy implementation + + :param structure: a StructureData node + :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates + :param epsilon_length: threshold on lengths comparison, used to get the bravais lattice info + :param epsilon_angle: threshold on angles comparison, used to get the bravais lattice info + """ + from aiida.tools.data.array.kpoints import legacy + + args_recognized = ['cartesian', 'epsilon_length', 'epsilon_angle'] + args_unknown = set(kwargs).difference(args_recognized) + + if args_unknown: + raise ValueError(f'unknown arguments {args_unknown}') + + point_coords, path, bravais_info = legacy.get_kpoints_path(cell=structure.cell, pbc=structure.pbc, **kwargs) + + parameters = { + 'bravais_info': bravais_info, + 'point_coords': point_coords, + 'path': path, + } + + return {'parameters': Dict(dict=parameters)} + + +def _legacy_get_explicit_kpoints_path(structure, **kwargs): + """ + Call the get_explicit_kpoints_path of the legacy implementation + + :param structure: a StructureData node + :param float kpoint_distance: parameter controlling the distance between kpoints. Distance is + given in crystal coordinates, i.e. the distance is computed in the space of b1, b2, b3. + The distance set will be the closest possible to this value, compatible with the requirement + of putting equispaced points between two special points (since extrema are included). + :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates + :param float epsilon_length: threshold on lengths comparison, used to get the bravais lattice info + :param float epsilon_angle: threshold on angles comparison, used to get the bravais lattice info + """ + from aiida.tools.data.array.kpoints import legacy + + args_recognized = ['value', 'kpoint_distance', 'cartesian', 'epsilon_length', 'epsilon_angle'] + args_unknown = set(kwargs).difference(args_recognized) + + if args_unknown: + raise ValueError(f'unknown arguments {args_unknown}') + + point_coords, path, bravais_info, explicit_kpoints, labels = legacy.get_explicit_kpoints_path( # pylint: disable=unbalanced-tuple-unpacking + cell=structure.cell, pbc=structure.pbc, **kwargs + ) + + kpoints = KpointsData() + kpoints.set_cell(structure.cell) + kpoints.set_kpoints(explicit_kpoints) + kpoints.labels = labels + + parameters = { + 'bravais_info': bravais_info, + 'point_coords': point_coords, + 'path': path, + } + + return {'parameters': Dict(dict=parameters), 'explicit_kpoints': kpoints} + + +_GET_KPOINTS_PATH_METHODS = { + 'legacy': _legacy_get_kpoints_path, + 'seekpath': _seekpath_get_kpoints_path, +} + +_GET_EXPLICIT_KPOINTS_PATH_METHODS = { + 'legacy': _legacy_get_explicit_kpoints_path, + 'seekpath': _seekpath_get_explicit_kpoints_path, +} diff --git a/aiida/tools/data/array/kpoints/seekpath.py b/aiida/tools/data/array/kpoints/seekpath.py index 0d4c59b6a0..b12672b02d 100644 --- a/aiida/tools/data/array/kpoints/seekpath.py +++ b/aiida/tools/data/array/kpoints/seekpath.py @@ -12,8 +12,6 @@ from aiida.orm import KpointsData, Dict -__all__ = ('get_explicit_kpoints_path', 'get_kpoints_path') - def get_explicit_kpoints_path(structure, parameters): """ diff --git a/aiida/tools/data/orbital/__init__.py b/aiida/tools/data/orbital/__init__.py index 670b51ba55..2ece04b528 100644 --- a/aiida/tools/data/orbital/__init__.py +++ b/aiida/tools/data/orbital/__init__.py @@ -9,6 +9,17 @@ ########################################################################### """Module for classes and methods that represents molecular orbitals.""" -from .orbital import Orbital +# AUTO-GENERATED -__all__ = ('Orbital',) +# yapf: disable +# pylint: disable=wildcard-import + +from .orbital import * +from .realhydrogen import * + +__all__ = ( + 'Orbital', + 'RealhydrogenOrbital', +) + +# yapf: enable diff --git a/aiida/tools/data/orbital/orbital.py b/aiida/tools/data/orbital/orbital.py index a1ecabb6e1..5c83043e8a 100644 --- a/aiida/tools/data/orbital/orbital.py +++ b/aiida/tools/data/orbital/orbital.py @@ -17,6 +17,8 @@ from aiida.common.exceptions import ValidationError from aiida.plugins.entry_point import get_entry_point_from_class +__all__ = ('Orbital',) + def validate_int(value): """ diff --git a/aiida/tools/data/orbital/realhydrogen.py b/aiida/tools/data/orbital/realhydrogen.py index 84961b48b1..c8116659d8 100644 --- a/aiida/tools/data/orbital/realhydrogen.py +++ b/aiida/tools/data/orbital/realhydrogen.py @@ -15,6 +15,8 @@ from .orbital import Orbital, validate_len3_list_or_none, validate_float_or_none +__all__ = ('RealhydrogenOrbital',) + def validate_l(value): """ diff --git a/aiida/tools/data/structure/__init__.py b/aiida/tools/data/structure.py similarity index 100% rename from aiida/tools/data/structure/__init__.py rename to aiida/tools/data/structure.py diff --git a/aiida/tools/graph/__init__.py b/aiida/tools/graph/__init__.py index c095d1619a..95cffafca3 100644 --- a/aiida/tools/graph/__init__.py +++ b/aiida/tools/graph/__init__.py @@ -7,8 +7,19 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides tools for traversing the provenance graph.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .deletions import * -__all__ = deletions.__all__ +__all__ = ( + 'DELETE_LOGGER', + 'delete_group_nodes', + 'delete_nodes', +) + +# yapf: enable diff --git a/aiida/tools/groups/__init__.py b/aiida/tools/groups/__init__.py index 19e936839b..ab74c839aa 100644 --- a/aiida/tools/groups/__init__.py +++ b/aiida/tools/groups/__init__.py @@ -13,8 +13,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides tools for interacting with AiiDA Groups.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .paths import * -__all__ = paths.__all__ +__all__ = ( + 'GroupNotFoundError', + 'GroupNotUniqueError', + 'GroupPath', + 'InvalidPath', + 'NoGroupsInPathError', +) + +# yapf: enable diff --git a/aiida/tools/importexport/__init__.py b/aiida/tools/importexport/__init__.py index d6d576159f..0d545768a0 100644 --- a/aiida/tools/importexport/__init__.py +++ b/aiida/tools/importexport/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides import/export functionalities. To see history/git blame prior to the move to aiida.tools.importexport, @@ -15,9 +14,58 @@ Functionality: /aiida/orm/importexport.py Tests: /aiida/backends/tests/test_export_and_import.py """ + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .archive import * +from .common import * from .dbexport import * from .dbimport import * -from .common import * -__all__ = (archive.__all__ + dbexport.__all__ + dbimport.__all__ + common.__all__) +__all__ = ( + 'ARCHIVE_READER_LOGGER', + 'ArchiveExportError', + 'ArchiveImportError', + 'ArchiveMetadata', + 'ArchiveMigrationError', + 'ArchiveMigratorAbstract', + 'ArchiveMigratorJsonBase', + 'ArchiveMigratorJsonTar', + 'ArchiveMigratorJsonZip', + 'ArchiveReaderAbstract', + 'ArchiveWriterAbstract', + 'CacheFolder', + 'CorruptArchive', + 'DanglingLinkError', + 'EXPORT_LOGGER', + 'EXPORT_VERSION', + 'ExportFileFormat', + 'ExportImportException', + 'ExportValidationError', + 'IMPORT_LOGGER', + 'ImportUniquenessError', + 'ImportValidationError', + 'IncompatibleArchiveVersionError', + 'MIGRATE_LOGGER', + 'MigrationValidationError', + 'ProgressBarError', + 'ReaderJsonBase', + 'ReaderJsonFolder', + 'ReaderJsonTar', + 'ReaderJsonZip', + 'WriterJsonFolder', + 'WriterJsonTar', + 'WriterJsonZip', + 'detect_archive_type', + 'export', + 'get_migrator', + 'get_reader', + 'get_writer', + 'import_data', + 'null_callback', +) + +# yapf: enable diff --git a/aiida/tools/importexport/archive/__init__.py b/aiida/tools/importexport/archive/__init__.py index ee13f67842..b2dd149d7c 100644 --- a/aiida/tools/importexport/archive/__init__.py +++ b/aiida/tools/importexport/archive/__init__.py @@ -7,13 +7,42 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable # type: ignore """Readers and writers for archive formats, that work independently of a connection to an AiiDA profile.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .common import * from .migrators import * from .readers import * from .writers import * -__all__ = (migrators.__all__ + readers.__all__ + writers.__all__ + common.__all__) +__all__ = ( + 'ARCHIVE_READER_LOGGER', + 'ArchiveMetadata', + 'ArchiveMigratorAbstract', + 'ArchiveMigratorJsonBase', + 'ArchiveMigratorJsonTar', + 'ArchiveMigratorJsonZip', + 'ArchiveReaderAbstract', + 'ArchiveWriterAbstract', + 'CacheFolder', + 'MIGRATE_LOGGER', + 'ReaderJsonBase', + 'ReaderJsonFolder', + 'ReaderJsonTar', + 'ReaderJsonZip', + 'WriterJsonFolder', + 'WriterJsonTar', + 'WriterJsonZip', + 'detect_archive_type', + 'get_migrator', + 'get_reader', + 'get_writer', + 'null_callback', +) + +# yapf: enable diff --git a/aiida/tools/importexport/common/__init__.py b/aiida/tools/importexport/common/__init__.py index 4bdfd23504..f2755eade4 100644 --- a/aiida/tools/importexport/common/__init__.py +++ b/aiida/tools/importexport/common/__init__.py @@ -7,9 +7,30 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Common utility functions, classes, and exceptions""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .config import * from .exceptions import * -__all__ = (config.__all__ + exceptions.__all__) +__all__ = ( + 'ArchiveExportError', + 'ArchiveImportError', + 'ArchiveMigrationError', + 'CorruptArchive', + 'DanglingLinkError', + 'EXPORT_VERSION', + 'ExportImportException', + 'ExportValidationError', + 'ImportUniquenessError', + 'ImportValidationError', + 'IncompatibleArchiveVersionError', + 'MigrationValidationError', + 'ProgressBarError', +) + +# yapf: enable diff --git a/aiida/tools/importexport/dbexport/__init__.py b/aiida/tools/importexport/dbexport/__init__.py index c68a3c242a..22173490f7 100644 --- a/aiida/tools/importexport/dbexport/__init__.py +++ b/aiida/tools/importexport/dbexport/__init__.py @@ -7,607 +7,19 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=fixme,too-many-lines """Provides export functionalities.""" -from collections import defaultdict -import os -import tempfile -from typing import ( - Any, - Callable, - DefaultDict, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Type, - Union, - cast, -) - -from aiida import get_version, orm -from aiida.common.links import GraphTraversalRules -from aiida.common.lang import type_check -from aiida.common.progress_reporter import get_progress_reporter, create_callback -from aiida.tools.importexport.common import exceptions -from aiida.tools.importexport.common.config import ( - COMMENT_ENTITY_NAME, - COMPUTER_ENTITY_NAME, - EXPORT_VERSION, - GROUP_ENTITY_NAME, - LOG_ENTITY_NAME, - NODE_ENTITY_NAME, - ExportFileFormat, - entity_names_to_entities, - file_fields_to_model_fields, - get_all_fields_info, - model_fields_to_file_fields, -) -from aiida.tools.graph.graph_traversers import get_nodes_export, validate_traversal_rules -from aiida.tools.importexport.archive.writers import ArchiveMetadata, ArchiveWriterAbstract, get_writer -from aiida.tools.importexport.dbexport.utils import ( - EXPORT_LOGGER, - check_licenses, - fill_in_query, - serialize_dict, - summary, -) - -__all__ = ('export', 'EXPORT_LOGGER', 'ExportFileFormat') - - -def export( - entities: Optional[Iterable[Any]] = None, - filename: Optional[str] = None, - file_format: Union[str, Type[ArchiveWriterAbstract]] = ExportFileFormat.ZIP, - overwrite: bool = False, - include_comments: bool = True, - include_logs: bool = True, - allowed_licenses: Optional[Union[list, Callable]] = None, - forbidden_licenses: Optional[Union[list, Callable]] = None, - writer_init: Optional[Dict[str, Any]] = None, - batch_size: int = 100, - **traversal_rules: bool, -) -> ArchiveWriterAbstract: - """Export AiiDA data to an archive file. - - Note, the logging level and progress reporter should be set externally, for example:: - - from aiida.common.progress_reporter import set_progress_bar_tqdm - - EXPORT_LOGGER.setLevel('DEBUG') - set_progress_bar_tqdm(leave=True) - export(...) - - :param entities: a list of entity instances; - they can belong to different models/entities. - - :param filename: the filename (possibly including the absolute path) - of the file on which to export. - - :param file_format: 'zip', 'tar.gz' or 'folder' or a specific writer class. - - :param overwrite: if True, overwrite the output file without asking, if it exists. - If False, raise an - :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError` - if the output file already exists. - - :param allowed_licenses: List or function. - If a list, then checks whether all licenses of Data nodes are in the list. If a function, - then calls function for licenses of Data nodes expecting True if license is allowed, False - otherwise. - - :param forbidden_licenses: List or function. If a list, - then checks whether all licenses of Data nodes are in the list. If a function, - then calls function for licenses of Data nodes expecting True if license is allowed, False - otherwise. - - :param include_comments: In-/exclude export of comments for given node(s) in ``entities``. - Default: True, *include* comments in export (as well as relevant users). - - :param include_logs: In-/exclude export of logs for given node(s) in ``entities``. - Default: True, *include* logs in export. - - :param writer_init: Additional key-word arguments to pass to the writer class init - - :param batch_size: batch database query results in sub-collections to reduce memory usage - - :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` - what rule names are toggleable and what the defaults are. - - :returns: a dictionary of data regarding the export process (timings, etc) - - :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: - if there are any internal errors when exporting. - :raises `~aiida.common.exceptions.LicensingException`: - if any node is licensed under forbidden license. - """ - # pylint: disable=too-many-locals,too-many-branches,too-many-statements - - # Backwards-compatibility - entities = cast(Iterable[Any], entities) - filename = cast(str, filename) - - type_check( - entities, - (list, tuple, set), - msg='`entities` must be specified and given as a list of AiiDA entities', - ) - entities = list(entities) - if type_check(filename, str, allow_none=True) is None: - filename = 'export_data.aiida' - - if not overwrite and os.path.exists(filename): - raise exceptions.ArchiveExportError(f"The output file '{filename}' already exists") - - # validate the traversal rules and generate a full set for reporting - validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules) - full_traversal_rules = { - name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items() - } - - # setup the archive writer - writer_init = writer_init or {} - if isinstance(file_format, str): - writer = get_writer(file_format)(filepath=filename, **writer_init) - elif issubclass(file_format, ArchiveWriterAbstract): - writer = file_format(filepath=filename, **writer_init) - else: - raise TypeError('file_format must be a string or ArchiveWriterAbstract class') - - summary( - file_format=writer.file_format_verbose, - export_version=writer.export_version, - outfile=filename, - include_comments=include_comments, - include_logs=include_logs, - traversal_rules=full_traversal_rules - ) - - EXPORT_LOGGER.debug('STARTING EXPORT...') - - all_fields_info, unique_identifiers = get_all_fields_info() - entities_starting_set, given_node_entry_ids = _get_starting_node_ids(entities) - - # Initialize the writer - with writer as writer_context: - - # Iteratively explore the AiiDA graph to find further nodes that should also be exported - with get_progress_reporter()(desc='Traversing provenance via links ...', total=1) as progress: - traverse_output = get_nodes_export(starting_pks=given_node_entry_ids, get_links=True, **traversal_rules) - progress.update() - node_ids_to_be_exported = traverse_output['nodes'] - - EXPORT_LOGGER.debug('WRITING METADATA...') - - writer_context.write_metadata( - ArchiveMetadata( - export_version=EXPORT_VERSION, - aiida_version=get_version(), - unique_identifiers=unique_identifiers, - all_fields_info=all_fields_info, - graph_traversal_rules=traverse_output['rules'], - # Turn sets into lists to be able to export them as JSON metadata. - entities_starting_set={ - entity: list(entity_set) for entity, entity_set in entities_starting_set.items() - }, - include_comments=include_comments, - include_logs=include_logs, - ) - ) - - # Create a mapping of node PK to UUID. - node_pk_2_uuid_mapping: Dict[int, str] = {} - repository_metadata_mapping: Dict[int, dict] = {} - - if node_ids_to_be_exported: - qbuilder = orm.QueryBuilder().append( - orm.Node, - project=('id', 'uuid', 'repository_metadata'), - filters={'id': { - 'in': node_ids_to_be_exported - }}, - ) - for pk, uuid, repository_metadata in qbuilder.iterall(batch_size=batch_size): - node_pk_2_uuid_mapping[pk] = uuid - repository_metadata_mapping[pk] = repository_metadata - - # check that no nodes are being exported with incorrect licensing - _check_node_licenses(node_ids_to_be_exported, allowed_licenses, forbidden_licenses) - - # write the link data - if traverse_output['links'] is not None: - with get_progress_reporter()(total=len(traverse_output['links']), desc='Writing links') as progress: - for link in traverse_output['links']: - progress.update() - writer_context.write_link({ - 'input': node_pk_2_uuid_mapping[link.source_id], - 'output': node_pk_2_uuid_mapping[link.target_id], - 'label': link.link_label, - 'type': link.link_type, - }) - - # generate a list of queries to encapsulate all required entities - entity_queries = _collect_entity_queries( - node_ids_to_be_exported, - entities_starting_set, - node_pk_2_uuid_mapping, - include_comments, - include_logs, - ) - - total_entities = sum(query.count() for query in entity_queries.values()) - - # write all entity data fields - if total_entities: - exported_entity_pks = _write_entity_data( - total_entities=total_entities, - entity_queries=entity_queries, - writer=writer_context, - batch_size=batch_size - ) - else: - exported_entity_pks = defaultdict(set) - EXPORT_LOGGER.info('No entities were found to export') - - # write mappings of groups to the nodes they contain - if exported_entity_pks[GROUP_ENTITY_NAME]: - - EXPORT_LOGGER.debug('Writing group UUID -> [nodes UUIDs]') - - _write_group_mappings( - group_pks=exported_entity_pks[GROUP_ENTITY_NAME], batch_size=batch_size, writer=writer_context - ) - - # copy all required node repositories - if exported_entity_pks[NODE_ENTITY_NAME]: - - _write_node_repositories( - node_pks=exported_entity_pks[NODE_ENTITY_NAME], - repository_metadata_mapping=repository_metadata_mapping, - writer=writer_context - ) - - EXPORT_LOGGER.info('Finalizing Export...') - - # summarize export - export_summary = '\n - '.join(f'{name:<6}: {len(pks)}' for name, pks in exported_entity_pks.items()) - if exported_entity_pks: - EXPORT_LOGGER.info('Exported Entities:\n - ' + export_summary + '\n') - # TODO - # EXPORT_LOGGER.info('Writer Information:\n %s', writer.export_info) - - return writer - - -def _get_starting_node_ids(entities: List[Any]) -> Tuple[DefaultDict[str, Set[str]], Set[int]]: - """Get the starting node UUIDs and PKs - - :param entities: a list of entity instances - - :raises exceptions.ArchiveExportError: - :return: entities_starting_set, given_node_entry_ids - """ - entities_starting_set: DefaultDict[str, Set[str]] = defaultdict(set) - given_node_entry_ids: Set[int] = set() - - # store a list of the actual dbnodes - total = len(entities) + (1 if GROUP_ENTITY_NAME in entities_starting_set else 0) - if not total: - return entities_starting_set, given_node_entry_ids - - for entry in entities: - - if issubclass(entry.__class__, orm.Group): - entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid) - elif issubclass(entry.__class__, orm.Node): - entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid) - given_node_entry_ids.add(entry.pk) - elif issubclass(entry.__class__, orm.Computer): - entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid) - else: - raise exceptions.ArchiveExportError( - f'I was given {entry} ({type(entry)}),' - ' which is not a Node, Computer, or Group instance' - ) - # Add all the nodes contained within the specified groups - if GROUP_ENTITY_NAME in entities_starting_set: +# AUTO-GENERATED - # Use single query instead of given_group.nodes iterator for performance. - qh_groups = ( - orm.QueryBuilder().append( - orm.Group, - filters={ - 'uuid': { - 'in': entities_starting_set[GROUP_ENTITY_NAME] - } - }, - tag='groups', - ).queryhelp - ) - node_query = orm.QueryBuilder(**qh_groups).append(orm.Node, project=['id', 'uuid'], with_group='groups') - node_count = node_query.count() +# yapf: disable +# pylint: disable=wildcard-import - if node_count: - with get_progress_reporter()(desc='Collecting nodes in groups', total=node_count) as progress: +from .main import * - pks, uuids = [], [] - for pk, uuid in node_query.all(): - progress.update() - pks.append(pk) - uuids.append(uuid) - - entities_starting_set[NODE_ENTITY_NAME].update(uuids) - given_node_entry_ids.update(pks) - - return entities_starting_set, given_node_entry_ids - - -def _check_node_licenses( - node_ids_to_be_exported: Set[int], - allowed_licenses: Optional[Union[list, Callable]], - forbidden_licenses: Optional[Union[list, Callable]], -) -> None: - """Check the nodes to be archived for disallowed licences.""" - # TODO (Spyros) To see better! Especially for functional licenses - # Check the licenses of exported data. - if allowed_licenses is not None or forbidden_licenses is not None: - builder = orm.QueryBuilder() - builder.append( - orm.Node, - project=['id', 'attributes.source.license'], - filters={'id': { - 'in': node_ids_to_be_exported - }}, - ) - # Skip those nodes where the license is not set (this is the standard behavior with Django) - node_licenses = [(a, b) for [a, b] in builder.all() if b is not None] - check_licenses(node_licenses, allowed_licenses, forbidden_licenses) - - -def _get_model_fields(entity_name: str) -> List[str]: - """Return a list of fields to retrieve for a particular entity - - :param entity_name: name of database entity, such as Node - - """ - all_fields_info, _ = get_all_fields_info() - project_cols = ['id'] - entity_prop = all_fields_info[entity_name].keys() - # Here we do the necessary renaming of properties - for prop in entity_prop: - # nprop contains the list of projections - nprop = ( - file_fields_to_model_fields[entity_name][prop] if prop in file_fields_to_model_fields[entity_name] else prop - ) - project_cols.append(nprop) - return project_cols - - -def _collect_entity_queries( - node_ids_to_be_exported: Set[int], - entities_starting_set: DefaultDict[str, Set[str]], - node_pk_2_uuid_mapping: Dict[int, str], - include_comments: bool = True, - include_logs: bool = True, -) -> Dict[str, orm.QueryBuilder]: - """Gather partial queries for all entities to export.""" - # pylint: disable=too-many-locals - given_log_entry_ids = set() - given_comment_entry_ids = set() - - total = 2 + (((1 if include_logs else 0) + (1 if include_comments else 0)) if node_ids_to_be_exported else 0) - with get_progress_reporter()(desc='Building entity database queries', total=total) as progress: - - # Logs - if include_logs and node_ids_to_be_exported: - # Get related log(s) - universal for all nodes - builder = orm.QueryBuilder() - builder.append( - orm.Log, - filters={'dbnode_id': { - 'in': node_ids_to_be_exported - }}, - project='uuid', - ) - res = set(builder.all(flat=True)) - given_log_entry_ids.update(res) - - progress.update() - - # Comments - if include_comments and node_ids_to_be_exported: - # Get related log(s) - universal for all nodes - builder = orm.QueryBuilder() - builder.append( - orm.Comment, - filters={'dbnode_id': { - 'in': node_ids_to_be_exported - }}, - project='uuid', - ) - res = set(builder.all(flat=True)) - given_comment_entry_ids.update(res) - - progress.update() - - # Here we get all the columns that we plan to project per entity that we would like to extract - given_entities = set(entities_starting_set.keys()) - if node_ids_to_be_exported: - given_entities.add(NODE_ENTITY_NAME) - if given_log_entry_ids: - given_entities.add(LOG_ENTITY_NAME) - if given_comment_entry_ids: - given_entities.add(COMMENT_ENTITY_NAME) - - progress.update() - - entities_to_add: Dict[str, orm.QueryBuilder] = {} - if not given_entities: - progress.update() - return entities_to_add - - for given_entity in given_entities: - - project_cols = _get_model_fields(given_entity) - - # Getting the ids that correspond to the right entity - entry_uuids_to_add = entities_starting_set.get(given_entity, set()) - if not entry_uuids_to_add: - if given_entity == LOG_ENTITY_NAME: - entry_uuids_to_add = given_log_entry_ids - elif given_entity == COMMENT_ENTITY_NAME: - entry_uuids_to_add = given_comment_entry_ids - elif given_entity == NODE_ENTITY_NAME: - entry_uuids_to_add.update({node_pk_2_uuid_mapping[_] for _ in node_ids_to_be_exported}) - - builder = orm.QueryBuilder() - builder.append( - entity_names_to_entities[given_entity], - filters={'uuid': { - 'in': entry_uuids_to_add - }}, - project=project_cols, - tag=given_entity, - outerjoin=True, - ) - entities_to_add[given_entity] = builder - - progress.update() - - return entities_to_add - - -def _write_entity_data( - total_entities: int, entity_queries: Dict[str, orm.QueryBuilder], writer: ArchiveWriterAbstract, batch_size: int -) -> Dict[str, Set[int]]: - """Iterate through data returned from entity queries, serialize the DB fields, then write to the export.""" - all_fields_info, unique_identifiers = get_all_fields_info() - entity_separator = '_' - - exported_entity_pks: Dict[str, Set[int]] = defaultdict(set) - unsealed_node_pks: Set[int] = set() - - with get_progress_reporter()(total=total_entities, desc='Writing entity data') as progress: - - for entity_name, entity_query in entity_queries.items(): - - foreign_fields = {k: v for k, v in all_fields_info[entity_name].items() if 'requires' in v} - for value in foreign_fields.values(): - ref_model_name = value['requires'] - fill_in_query( - entity_query, - entity_name, - ref_model_name, - [entity_name], - entity_separator, - ) - - for query_results in entity_query.iterdict(batch_size=batch_size): - - progress.update() - - for key, value in query_results.items(): - - pk = value['id'] - - # This is an empty result of an outer join. - # It should not be taken into account. - if pk is None: - continue - - # Get current entity - current_entity = key.split(entity_separator)[-1] - - # don't allow duplication - if pk in exported_entity_pks[current_entity]: - continue - - exported_entity_pks[current_entity].add(pk) - - fields = serialize_dict( - value, - remove_fields=['id'], - rename_fields=model_fields_to_file_fields[current_entity], - ) - - if current_entity == NODE_ENTITY_NAME and fields['node_type'].startswith('process.'): - if fields['attributes'].get('sealed', False) is not True: - unsealed_node_pks.add(pk) - - writer.write_entity_data(current_entity, pk, unique_identifiers[current_entity], fields) - - if unsealed_node_pks: - raise exceptions.ExportValidationError( - 'All ProcessNodes must be sealed before they can be exported. ' - f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed." - ) - - return exported_entity_pks - - -def _write_group_mappings(*, group_pks: Set[int], batch_size: int, writer: ArchiveWriterAbstract): - """Query for node UUIDs in exported groups, and write these these mappings to the archive file.""" - group_uuid_query = orm.QueryBuilder().append( - orm.Group, - filters={ - 'id': { - 'in': list(group_pks) - } - }, - project='uuid', - tag='groups', - ).append(orm.Node, project='uuid', with_group='groups') - - groups_uuid_to_node_uuids = defaultdict(set) - for group_uuid, node_uuid in group_uuid_query.iterall(batch_size=batch_size): - groups_uuid_to_node_uuids[group_uuid].add(node_uuid) - - for group_uuid, node_uuids in groups_uuid_to_node_uuids.items(): - writer.write_group_nodes(group_uuid, list(node_uuids)) - - -def _write_node_repositories( - *, node_pks: Set[int], repository_metadata_mapping: Dict[int, dict], writer: ArchiveWriterAbstract -): - """Write all exported node repositories to the archive file.""" - with get_progress_reporter()(total=len(node_pks), desc='Exporting node repositories: ') as progress: - - with tempfile.TemporaryDirectory() as temp: - from disk_objectstore import Container - from aiida.manage.manager import get_manager - - dirpath = os.path.join(temp, 'container') - container_export = Container(dirpath) - container_export.init_container() - - profile = get_manager().get_profile() - assert profile is not None, 'profile not loaded' - container_profile = profile.get_repository().backend.container - - # This should be done more effectively, starting by not having to load the node. Either the repository - # metadata should be collected earlier when the nodes themselves are already exported or a single separate - # query should be done. - hashkeys = [] - - def collect_hashkeys(objects): - for obj in objects.values(): - hashkey = obj.get('k', None) - if hashkey is not None: - hashkeys.append(hashkey) - subobjects = obj.get('o', None) - if subobjects: - collect_hashkeys(subobjects) - - for pk in node_pks: - progress.set_description_str(f'Exporting node repositories: {pk}', refresh=False) - progress.update() - repository_metadata = repository_metadata_mapping[pk] - collect_hashkeys(repository_metadata.get('o', {})) +__all__ = ( + 'EXPORT_LOGGER', + 'ExportFileFormat', + 'export', +) - callback = create_callback(progress) - container_profile.export(set(hashkeys), container_export, compress=False, callback=callback) - writer.write_repository_container(container_export) +# yapf: enable diff --git a/aiida/tools/importexport/dbexport/main.py b/aiida/tools/importexport/dbexport/main.py new file mode 100644 index 0000000000..c68a3c242a --- /dev/null +++ b/aiida/tools/importexport/dbexport/main.py @@ -0,0 +1,613 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=fixme,too-many-lines +"""Provides export functionalities.""" +from collections import defaultdict +import os +import tempfile +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) + +from aiida import get_version, orm +from aiida.common.links import GraphTraversalRules +from aiida.common.lang import type_check +from aiida.common.progress_reporter import get_progress_reporter, create_callback +from aiida.tools.importexport.common import exceptions +from aiida.tools.importexport.common.config import ( + COMMENT_ENTITY_NAME, + COMPUTER_ENTITY_NAME, + EXPORT_VERSION, + GROUP_ENTITY_NAME, + LOG_ENTITY_NAME, + NODE_ENTITY_NAME, + ExportFileFormat, + entity_names_to_entities, + file_fields_to_model_fields, + get_all_fields_info, + model_fields_to_file_fields, +) +from aiida.tools.graph.graph_traversers import get_nodes_export, validate_traversal_rules +from aiida.tools.importexport.archive.writers import ArchiveMetadata, ArchiveWriterAbstract, get_writer +from aiida.tools.importexport.dbexport.utils import ( + EXPORT_LOGGER, + check_licenses, + fill_in_query, + serialize_dict, + summary, +) + +__all__ = ('export', 'EXPORT_LOGGER', 'ExportFileFormat') + + +def export( + entities: Optional[Iterable[Any]] = None, + filename: Optional[str] = None, + file_format: Union[str, Type[ArchiveWriterAbstract]] = ExportFileFormat.ZIP, + overwrite: bool = False, + include_comments: bool = True, + include_logs: bool = True, + allowed_licenses: Optional[Union[list, Callable]] = None, + forbidden_licenses: Optional[Union[list, Callable]] = None, + writer_init: Optional[Dict[str, Any]] = None, + batch_size: int = 100, + **traversal_rules: bool, +) -> ArchiveWriterAbstract: + """Export AiiDA data to an archive file. + + Note, the logging level and progress reporter should be set externally, for example:: + + from aiida.common.progress_reporter import set_progress_bar_tqdm + + EXPORT_LOGGER.setLevel('DEBUG') + set_progress_bar_tqdm(leave=True) + export(...) + + :param entities: a list of entity instances; + they can belong to different models/entities. + + :param filename: the filename (possibly including the absolute path) + of the file on which to export. + + :param file_format: 'zip', 'tar.gz' or 'folder' or a specific writer class. + + :param overwrite: if True, overwrite the output file without asking, if it exists. + If False, raise an + :py:class:`~aiida.tools.importexport.common.exceptions.ArchiveExportError` + if the output file already exists. + + :param allowed_licenses: List or function. + If a list, then checks whether all licenses of Data nodes are in the list. If a function, + then calls function for licenses of Data nodes expecting True if license is allowed, False + otherwise. + + :param forbidden_licenses: List or function. If a list, + then checks whether all licenses of Data nodes are in the list. If a function, + then calls function for licenses of Data nodes expecting True if license is allowed, False + otherwise. + + :param include_comments: In-/exclude export of comments for given node(s) in ``entities``. + Default: True, *include* comments in export (as well as relevant users). + + :param include_logs: In-/exclude export of logs for given node(s) in ``entities``. + Default: True, *include* logs in export. + + :param writer_init: Additional key-word arguments to pass to the writer class init + + :param batch_size: batch database query results in sub-collections to reduce memory usage + + :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` + what rule names are toggleable and what the defaults are. + + :returns: a dictionary of data regarding the export process (timings, etc) + + :raises `~aiida.tools.importexport.common.exceptions.ArchiveExportError`: + if there are any internal errors when exporting. + :raises `~aiida.common.exceptions.LicensingException`: + if any node is licensed under forbidden license. + """ + # pylint: disable=too-many-locals,too-many-branches,too-many-statements + + # Backwards-compatibility + entities = cast(Iterable[Any], entities) + filename = cast(str, filename) + + type_check( + entities, + (list, tuple, set), + msg='`entities` must be specified and given as a list of AiiDA entities', + ) + entities = list(entities) + if type_check(filename, str, allow_none=True) is None: + filename = 'export_data.aiida' + + if not overwrite and os.path.exists(filename): + raise exceptions.ArchiveExportError(f"The output file '{filename}' already exists") + + # validate the traversal rules and generate a full set for reporting + validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules) + full_traversal_rules = { + name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items() + } + + # setup the archive writer + writer_init = writer_init or {} + if isinstance(file_format, str): + writer = get_writer(file_format)(filepath=filename, **writer_init) + elif issubclass(file_format, ArchiveWriterAbstract): + writer = file_format(filepath=filename, **writer_init) + else: + raise TypeError('file_format must be a string or ArchiveWriterAbstract class') + + summary( + file_format=writer.file_format_verbose, + export_version=writer.export_version, + outfile=filename, + include_comments=include_comments, + include_logs=include_logs, + traversal_rules=full_traversal_rules + ) + + EXPORT_LOGGER.debug('STARTING EXPORT...') + + all_fields_info, unique_identifiers = get_all_fields_info() + entities_starting_set, given_node_entry_ids = _get_starting_node_ids(entities) + + # Initialize the writer + with writer as writer_context: + + # Iteratively explore the AiiDA graph to find further nodes that should also be exported + with get_progress_reporter()(desc='Traversing provenance via links ...', total=1) as progress: + traverse_output = get_nodes_export(starting_pks=given_node_entry_ids, get_links=True, **traversal_rules) + progress.update() + node_ids_to_be_exported = traverse_output['nodes'] + + EXPORT_LOGGER.debug('WRITING METADATA...') + + writer_context.write_metadata( + ArchiveMetadata( + export_version=EXPORT_VERSION, + aiida_version=get_version(), + unique_identifiers=unique_identifiers, + all_fields_info=all_fields_info, + graph_traversal_rules=traverse_output['rules'], + # Turn sets into lists to be able to export them as JSON metadata. + entities_starting_set={ + entity: list(entity_set) for entity, entity_set in entities_starting_set.items() + }, + include_comments=include_comments, + include_logs=include_logs, + ) + ) + + # Create a mapping of node PK to UUID. + node_pk_2_uuid_mapping: Dict[int, str] = {} + repository_metadata_mapping: Dict[int, dict] = {} + + if node_ids_to_be_exported: + qbuilder = orm.QueryBuilder().append( + orm.Node, + project=('id', 'uuid', 'repository_metadata'), + filters={'id': { + 'in': node_ids_to_be_exported + }}, + ) + for pk, uuid, repository_metadata in qbuilder.iterall(batch_size=batch_size): + node_pk_2_uuid_mapping[pk] = uuid + repository_metadata_mapping[pk] = repository_metadata + + # check that no nodes are being exported with incorrect licensing + _check_node_licenses(node_ids_to_be_exported, allowed_licenses, forbidden_licenses) + + # write the link data + if traverse_output['links'] is not None: + with get_progress_reporter()(total=len(traverse_output['links']), desc='Writing links') as progress: + for link in traverse_output['links']: + progress.update() + writer_context.write_link({ + 'input': node_pk_2_uuid_mapping[link.source_id], + 'output': node_pk_2_uuid_mapping[link.target_id], + 'label': link.link_label, + 'type': link.link_type, + }) + + # generate a list of queries to encapsulate all required entities + entity_queries = _collect_entity_queries( + node_ids_to_be_exported, + entities_starting_set, + node_pk_2_uuid_mapping, + include_comments, + include_logs, + ) + + total_entities = sum(query.count() for query in entity_queries.values()) + + # write all entity data fields + if total_entities: + exported_entity_pks = _write_entity_data( + total_entities=total_entities, + entity_queries=entity_queries, + writer=writer_context, + batch_size=batch_size + ) + else: + exported_entity_pks = defaultdict(set) + EXPORT_LOGGER.info('No entities were found to export') + + # write mappings of groups to the nodes they contain + if exported_entity_pks[GROUP_ENTITY_NAME]: + + EXPORT_LOGGER.debug('Writing group UUID -> [nodes UUIDs]') + + _write_group_mappings( + group_pks=exported_entity_pks[GROUP_ENTITY_NAME], batch_size=batch_size, writer=writer_context + ) + + # copy all required node repositories + if exported_entity_pks[NODE_ENTITY_NAME]: + + _write_node_repositories( + node_pks=exported_entity_pks[NODE_ENTITY_NAME], + repository_metadata_mapping=repository_metadata_mapping, + writer=writer_context + ) + + EXPORT_LOGGER.info('Finalizing Export...') + + # summarize export + export_summary = '\n - '.join(f'{name:<6}: {len(pks)}' for name, pks in exported_entity_pks.items()) + if exported_entity_pks: + EXPORT_LOGGER.info('Exported Entities:\n - ' + export_summary + '\n') + # TODO + # EXPORT_LOGGER.info('Writer Information:\n %s', writer.export_info) + + return writer + + +def _get_starting_node_ids(entities: List[Any]) -> Tuple[DefaultDict[str, Set[str]], Set[int]]: + """Get the starting node UUIDs and PKs + + :param entities: a list of entity instances + + :raises exceptions.ArchiveExportError: + :return: entities_starting_set, given_node_entry_ids + """ + entities_starting_set: DefaultDict[str, Set[str]] = defaultdict(set) + given_node_entry_ids: Set[int] = set() + + # store a list of the actual dbnodes + total = len(entities) + (1 if GROUP_ENTITY_NAME in entities_starting_set else 0) + if not total: + return entities_starting_set, given_node_entry_ids + + for entry in entities: + + if issubclass(entry.__class__, orm.Group): + entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid) + elif issubclass(entry.__class__, orm.Node): + entities_starting_set[NODE_ENTITY_NAME].add(entry.uuid) + given_node_entry_ids.add(entry.pk) + elif issubclass(entry.__class__, orm.Computer): + entities_starting_set[COMPUTER_ENTITY_NAME].add(entry.uuid) + else: + raise exceptions.ArchiveExportError( + f'I was given {entry} ({type(entry)}),' + ' which is not a Node, Computer, or Group instance' + ) + + # Add all the nodes contained within the specified groups + if GROUP_ENTITY_NAME in entities_starting_set: + + # Use single query instead of given_group.nodes iterator for performance. + qh_groups = ( + orm.QueryBuilder().append( + orm.Group, + filters={ + 'uuid': { + 'in': entities_starting_set[GROUP_ENTITY_NAME] + } + }, + tag='groups', + ).queryhelp + ) + node_query = orm.QueryBuilder(**qh_groups).append(orm.Node, project=['id', 'uuid'], with_group='groups') + node_count = node_query.count() + + if node_count: + with get_progress_reporter()(desc='Collecting nodes in groups', total=node_count) as progress: + + pks, uuids = [], [] + for pk, uuid in node_query.all(): + progress.update() + pks.append(pk) + uuids.append(uuid) + + entities_starting_set[NODE_ENTITY_NAME].update(uuids) + given_node_entry_ids.update(pks) + + return entities_starting_set, given_node_entry_ids + + +def _check_node_licenses( + node_ids_to_be_exported: Set[int], + allowed_licenses: Optional[Union[list, Callable]], + forbidden_licenses: Optional[Union[list, Callable]], +) -> None: + """Check the nodes to be archived for disallowed licences.""" + # TODO (Spyros) To see better! Especially for functional licenses + # Check the licenses of exported data. + if allowed_licenses is not None or forbidden_licenses is not None: + builder = orm.QueryBuilder() + builder.append( + orm.Node, + project=['id', 'attributes.source.license'], + filters={'id': { + 'in': node_ids_to_be_exported + }}, + ) + # Skip those nodes where the license is not set (this is the standard behavior with Django) + node_licenses = [(a, b) for [a, b] in builder.all() if b is not None] + check_licenses(node_licenses, allowed_licenses, forbidden_licenses) + + +def _get_model_fields(entity_name: str) -> List[str]: + """Return a list of fields to retrieve for a particular entity + + :param entity_name: name of database entity, such as Node + + """ + all_fields_info, _ = get_all_fields_info() + project_cols = ['id'] + entity_prop = all_fields_info[entity_name].keys() + # Here we do the necessary renaming of properties + for prop in entity_prop: + # nprop contains the list of projections + nprop = ( + file_fields_to_model_fields[entity_name][prop] if prop in file_fields_to_model_fields[entity_name] else prop + ) + project_cols.append(nprop) + return project_cols + + +def _collect_entity_queries( + node_ids_to_be_exported: Set[int], + entities_starting_set: DefaultDict[str, Set[str]], + node_pk_2_uuid_mapping: Dict[int, str], + include_comments: bool = True, + include_logs: bool = True, +) -> Dict[str, orm.QueryBuilder]: + """Gather partial queries for all entities to export.""" + # pylint: disable=too-many-locals + given_log_entry_ids = set() + given_comment_entry_ids = set() + + total = 2 + (((1 if include_logs else 0) + (1 if include_comments else 0)) if node_ids_to_be_exported else 0) + with get_progress_reporter()(desc='Building entity database queries', total=total) as progress: + + # Logs + if include_logs and node_ids_to_be_exported: + # Get related log(s) - universal for all nodes + builder = orm.QueryBuilder() + builder.append( + orm.Log, + filters={'dbnode_id': { + 'in': node_ids_to_be_exported + }}, + project='uuid', + ) + res = set(builder.all(flat=True)) + given_log_entry_ids.update(res) + + progress.update() + + # Comments + if include_comments and node_ids_to_be_exported: + # Get related log(s) - universal for all nodes + builder = orm.QueryBuilder() + builder.append( + orm.Comment, + filters={'dbnode_id': { + 'in': node_ids_to_be_exported + }}, + project='uuid', + ) + res = set(builder.all(flat=True)) + given_comment_entry_ids.update(res) + + progress.update() + + # Here we get all the columns that we plan to project per entity that we would like to extract + given_entities = set(entities_starting_set.keys()) + if node_ids_to_be_exported: + given_entities.add(NODE_ENTITY_NAME) + if given_log_entry_ids: + given_entities.add(LOG_ENTITY_NAME) + if given_comment_entry_ids: + given_entities.add(COMMENT_ENTITY_NAME) + + progress.update() + + entities_to_add: Dict[str, orm.QueryBuilder] = {} + if not given_entities: + progress.update() + return entities_to_add + + for given_entity in given_entities: + + project_cols = _get_model_fields(given_entity) + + # Getting the ids that correspond to the right entity + entry_uuids_to_add = entities_starting_set.get(given_entity, set()) + if not entry_uuids_to_add: + if given_entity == LOG_ENTITY_NAME: + entry_uuids_to_add = given_log_entry_ids + elif given_entity == COMMENT_ENTITY_NAME: + entry_uuids_to_add = given_comment_entry_ids + elif given_entity == NODE_ENTITY_NAME: + entry_uuids_to_add.update({node_pk_2_uuid_mapping[_] for _ in node_ids_to_be_exported}) + + builder = orm.QueryBuilder() + builder.append( + entity_names_to_entities[given_entity], + filters={'uuid': { + 'in': entry_uuids_to_add + }}, + project=project_cols, + tag=given_entity, + outerjoin=True, + ) + entities_to_add[given_entity] = builder + + progress.update() + + return entities_to_add + + +def _write_entity_data( + total_entities: int, entity_queries: Dict[str, orm.QueryBuilder], writer: ArchiveWriterAbstract, batch_size: int +) -> Dict[str, Set[int]]: + """Iterate through data returned from entity queries, serialize the DB fields, then write to the export.""" + all_fields_info, unique_identifiers = get_all_fields_info() + entity_separator = '_' + + exported_entity_pks: Dict[str, Set[int]] = defaultdict(set) + unsealed_node_pks: Set[int] = set() + + with get_progress_reporter()(total=total_entities, desc='Writing entity data') as progress: + + for entity_name, entity_query in entity_queries.items(): + + foreign_fields = {k: v for k, v in all_fields_info[entity_name].items() if 'requires' in v} + for value in foreign_fields.values(): + ref_model_name = value['requires'] + fill_in_query( + entity_query, + entity_name, + ref_model_name, + [entity_name], + entity_separator, + ) + + for query_results in entity_query.iterdict(batch_size=batch_size): + + progress.update() + + for key, value in query_results.items(): + + pk = value['id'] + + # This is an empty result of an outer join. + # It should not be taken into account. + if pk is None: + continue + + # Get current entity + current_entity = key.split(entity_separator)[-1] + + # don't allow duplication + if pk in exported_entity_pks[current_entity]: + continue + + exported_entity_pks[current_entity].add(pk) + + fields = serialize_dict( + value, + remove_fields=['id'], + rename_fields=model_fields_to_file_fields[current_entity], + ) + + if current_entity == NODE_ENTITY_NAME and fields['node_type'].startswith('process.'): + if fields['attributes'].get('sealed', False) is not True: + unsealed_node_pks.add(pk) + + writer.write_entity_data(current_entity, pk, unique_identifiers[current_entity], fields) + + if unsealed_node_pks: + raise exceptions.ExportValidationError( + 'All ProcessNodes must be sealed before they can be exported. ' + f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed." + ) + + return exported_entity_pks + + +def _write_group_mappings(*, group_pks: Set[int], batch_size: int, writer: ArchiveWriterAbstract): + """Query for node UUIDs in exported groups, and write these these mappings to the archive file.""" + group_uuid_query = orm.QueryBuilder().append( + orm.Group, + filters={ + 'id': { + 'in': list(group_pks) + } + }, + project='uuid', + tag='groups', + ).append(orm.Node, project='uuid', with_group='groups') + + groups_uuid_to_node_uuids = defaultdict(set) + for group_uuid, node_uuid in group_uuid_query.iterall(batch_size=batch_size): + groups_uuid_to_node_uuids[group_uuid].add(node_uuid) + + for group_uuid, node_uuids in groups_uuid_to_node_uuids.items(): + writer.write_group_nodes(group_uuid, list(node_uuids)) + + +def _write_node_repositories( + *, node_pks: Set[int], repository_metadata_mapping: Dict[int, dict], writer: ArchiveWriterAbstract +): + """Write all exported node repositories to the archive file.""" + with get_progress_reporter()(total=len(node_pks), desc='Exporting node repositories: ') as progress: + + with tempfile.TemporaryDirectory() as temp: + from disk_objectstore import Container + from aiida.manage.manager import get_manager + + dirpath = os.path.join(temp, 'container') + container_export = Container(dirpath) + container_export.init_container() + + profile = get_manager().get_profile() + assert profile is not None, 'profile not loaded' + container_profile = profile.get_repository().backend.container + + # This should be done more effectively, starting by not having to load the node. Either the repository + # metadata should be collected earlier when the nodes themselves are already exported or a single separate + # query should be done. + hashkeys = [] + + def collect_hashkeys(objects): + for obj in objects.values(): + hashkey = obj.get('k', None) + if hashkey is not None: + hashkeys.append(hashkey) + subobjects = obj.get('o', None) + if subobjects: + collect_hashkeys(subobjects) + + for pk in node_pks: + progress.set_description_str(f'Exporting node repositories: {pk}', refresh=False) + progress.update() + repository_metadata = repository_metadata_mapping[pk] + collect_hashkeys(repository_metadata.get('o', {})) + + callback = create_callback(progress) + container_profile.export(set(hashkeys), container_export, compress=False, callback=callback) + writer.write_repository_container(container_export) diff --git a/aiida/tools/importexport/dbimport/__init__.py b/aiida/tools/importexport/dbimport/__init__.py index fa90faa5cc..ad987679f1 100644 --- a/aiida/tools/importexport/dbimport/__init__.py +++ b/aiida/tools/importexport/dbimport/__init__.py @@ -8,79 +8,17 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Provides import functionalities.""" -from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER -__all__ = ('import_data', 'IMPORT_LOGGER') +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -def import_data(in_path, group=None, **kwargs): - """Import exported AiiDA archive to the AiiDA database and repository. +from .main import * - Proxy function for the backend-specific import functions. - If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format - (zip, tar.gz, tar.bz2, ...) and calls the correct function. +__all__ = ( + 'IMPORT_LOGGER', + 'import_data', +) - Note, the logging level and progress reporter should be set externally, for example:: - - from aiida.common.progress_reporter import set_progress_bar_tqdm - - IMPORT_LOGGER.setLevel('DEBUG') - set_progress_bar_tqdm(leave=True) - import_data(...) - - :param in_path: the path to a file or folder that can be imported in AiiDA. - :type in_path: str - - :param group: Group wherein all imported Nodes will be placed. - :type group: :py:class:`~aiida.orm.groups.Group` - - :param extras_mode_existing: 3 letter code that will identify what to do with the extras import. - The first letter acts on extras that are present in the original node and not present in the imported node. - Can be either: - 'k' (keep it) or - 'n' (do not keep it). - The second letter acts on the imported extras that are not present in the original node. - Can be either: - 'c' (create it) or - 'n' (do not create it). - The third letter defines what to do in case of a name collision. - Can be either: - 'l' (leave the old value), - 'u' (update with a new value), - 'd' (delete the extra), or - 'a' (ask what to do if the content is different). - :type extras_mode_existing: str - - :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them - :type extras_mode_new: str - - :param comment_mode: Comment import modes (when same UUIDs are found). - Can be either: - 'newest' (will keep the Comment with the most recent modification time (mtime)) or - 'overwrite' (will overwrite existing Comments with the ones from the import file). - :type comment_mode: str - - :return: New and existing Nodes and Links. - :rtype: dict - - :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when - importing. - """ - from aiida.manage import configuration - from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA - from aiida.tools.importexport.common.exceptions import ArchiveImportError - - backend = configuration.PROFILE.database_backend - - if backend == BACKEND_SQLA: - from aiida.tools.importexport.dbimport.backends.sqla import import_data_sqla - IMPORT_LOGGER.debug('Calling import function import_data_sqla for the %s backend.', backend) - return import_data_sqla(in_path, group=group, **kwargs) - - if backend == BACKEND_DJANGO: - from aiida.tools.importexport.dbimport.backends.django import import_data_dj - IMPORT_LOGGER.debug('Calling import function import_data_dj for the %s backend.', backend) - return import_data_dj(in_path, group=group, **kwargs) - - # else - raise ArchiveImportError(f'Unknown backend: {backend}') +# yapf: enable diff --git a/aiida/tools/importexport/dbimport/main.py b/aiida/tools/importexport/dbimport/main.py new file mode 100644 index 0000000000..fa90faa5cc --- /dev/null +++ b/aiida/tools/importexport/dbimport/main.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Provides import functionalities.""" +from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER + +__all__ = ('import_data', 'IMPORT_LOGGER') + + +def import_data(in_path, group=None, **kwargs): + """Import exported AiiDA archive to the AiiDA database and repository. + + Proxy function for the backend-specific import functions. + If ``in_path`` is a folder, calls extract_tree; otherwise, tries to detect the compression format + (zip, tar.gz, tar.bz2, ...) and calls the correct function. + + Note, the logging level and progress reporter should be set externally, for example:: + + from aiida.common.progress_reporter import set_progress_bar_tqdm + + IMPORT_LOGGER.setLevel('DEBUG') + set_progress_bar_tqdm(leave=True) + import_data(...) + + :param in_path: the path to a file or folder that can be imported in AiiDA. + :type in_path: str + + :param group: Group wherein all imported Nodes will be placed. + :type group: :py:class:`~aiida.orm.groups.Group` + + :param extras_mode_existing: 3 letter code that will identify what to do with the extras import. + The first letter acts on extras that are present in the original node and not present in the imported node. + Can be either: + 'k' (keep it) or + 'n' (do not keep it). + The second letter acts on the imported extras that are not present in the original node. + Can be either: + 'c' (create it) or + 'n' (do not create it). + The third letter defines what to do in case of a name collision. + Can be either: + 'l' (leave the old value), + 'u' (update with a new value), + 'd' (delete the extra), or + 'a' (ask what to do if the content is different). + :type extras_mode_existing: str + + :param extras_mode_new: 'import' to import extras of new nodes or 'none' to ignore them + :type extras_mode_new: str + + :param comment_mode: Comment import modes (when same UUIDs are found). + Can be either: + 'newest' (will keep the Comment with the most recent modification time (mtime)) or + 'overwrite' (will overwrite existing Comments with the ones from the import file). + :type comment_mode: str + + :return: New and existing Nodes and Links. + :rtype: dict + + :raises `~aiida.tools.importexport.common.exceptions.ArchiveImportError`: if there are any internal errors when + importing. + """ + from aiida.manage import configuration + from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA + from aiida.tools.importexport.common.exceptions import ArchiveImportError + + backend = configuration.PROFILE.database_backend + + if backend == BACKEND_SQLA: + from aiida.tools.importexport.dbimport.backends.sqla import import_data_sqla + IMPORT_LOGGER.debug('Calling import function import_data_sqla for the %s backend.', backend) + return import_data_sqla(in_path, group=group, **kwargs) + + if backend == BACKEND_DJANGO: + from aiida.tools.importexport.dbimport.backends.django import import_data_dj + IMPORT_LOGGER.debug('Calling import function import_data_dj for the %s backend.', backend) + return import_data_dj(in_path, group=group, **kwargs) + + # else + raise ArchiveImportError(f'Unknown backend: {backend}') diff --git a/aiida/tools/visualization/__init__.py b/aiida/tools/visualization/__init__.py index e059feb04e..59532c1bbb 100644 --- a/aiida/tools/visualization/__init__.py +++ b/aiida/tools/visualization/__init__.py @@ -7,8 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides tools for visualization of the provenance graph.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .graph import * -__all__ = (graph.__all__) +__all__ = ( + 'Graph', + 'default_link_styles', + 'default_node_styles', + 'default_node_sublabels', + 'pstate_node_styles', +) + +# yapf: enable diff --git a/aiida/transports/__init__.py b/aiida/transports/__init__.py index 813fc6cc69..c1b7e7e3ce 100644 --- a/aiida/transports/__init__.py +++ b/aiida/transports/__init__.py @@ -7,9 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for classes and utilities to define transports to other machines.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .plugins import * from .transport import * -__all__ = (transport.__all__) +__all__ = ( + 'SshTransport', + 'Transport', + 'convert_to_bool', + 'parse_sshconfig', +) + +# yapf: enable diff --git a/aiida/transports/plugins/__init__.py b/aiida/transports/plugins/__init__.py index 2776a55f97..f11aea3b4c 100644 --- a/aiida/transports/plugins/__init__.py +++ b/aiida/transports/plugins/__init__.py @@ -7,3 +7,19 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Plugins for the transport.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .ssh import * + +__all__ = ( + 'SshTransport', + 'convert_to_bool', + 'parse_sshconfig', +) + +# yapf: enable diff --git a/setup.json b/setup.json index 6806c0ca73..f7915cc8ba 100644 --- a/setup.json +++ b/setup.json @@ -95,7 +95,7 @@ ], "pre-commit": [ "astroid<2.5", - "mypy==0.790", + "mypy==0.910", "packaging==20.3", "pre-commit~=2.2", "pylint~=2.5.0", diff --git a/utils/make_all.py b/utils/make_all.py new file mode 100644 index 0000000000..c54a330171 --- /dev/null +++ b/utils/make_all.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# pylint: disable=simplifiable-if-statement,too-many-branches +"""Pre-commit hook to add ``__all__`` imports to ``__init__`` files.""" +import ast +from collections import Counter +from pathlib import Path +from pprint import pprint +import sys +from typing import Dict, List, Optional, Tuple + + +def parse_all(folder_path: str) -> Tuple[dict, dict]: + """Walk through all files in folder, and parse the ``__all__`` variable. + + :return: (all dict, dict of unparsable) + """ + folder_path = Path(folder_path) + all_dict = {} + bad_all = {} + + for path in folder_path.glob('**/*.py'): + + # skip module files + if path.name == '__init__.py': + continue + + # parse the file + parsed = ast.parse(path.read_text(encoding='utf8')) + + # find __all__ assignment + all_token = None + for token in parsed.body: + if not isinstance(token, ast.Assign): + continue + if token.targets and getattr(token.targets[0], 'id', '') == '__all__': + all_token = token + break + + if all_token is None: + bad_all.setdefault('missing', []).append(str(path.relative_to(folder_path))) + continue + + if not isinstance(all_token.value, (ast.List, ast.Tuple)): + bad_all.setdefault('value not list/tuple', []).append(str(path.relative_to(folder_path))) + continue + + if not all(isinstance(el, ast.Str) for el in all_token.value.elts): + bad_all.setdefault('child not strings', []).append(str(path.relative_to(folder_path))) + continue + + names = [n.s for n in all_token.value.elts] + if not names: + continue + + path_dict = all_dict + for part in path.parent.relative_to(folder_path).parts: + path_dict = path_dict.setdefault(part, {}) + path_dict.setdefault(path.name[:-3], {})['__all__'] = names + + return all_dict, bad_all + + +def gather_all(cur_path: List[str], + cur_dict: dict, + skip_children: dict, + all_list: Optional[List[str]] = None) -> List[str]: + """Recursively gather __all__ names.""" + all_list = [] if all_list is None else all_list + for key, val in cur_dict.items(): + if key == '__all__': + all_list.extend(val) + elif key not in skip_children.get('/'.join(cur_path), []): + gather_all(cur_path + [key], val, skip_children, all_list) + return all_list + + +def write_inits(folder_path: str, all_dict: dict, skip_children: Dict[str, List[str]]) -> Dict[str, List[str]]: + """Write __init__.py files for all subfolders. + + :return: folders with non-unique imports + """ + folder_path = Path(folder_path) + non_unique = {} + for path in folder_path.glob('**/__init__.py'): + if path.parent == folder_path: + # skip top level __init__.py + continue + + rel_path = path.parent.relative_to(folder_path).as_posix() + + # get sub_dict for this folder + path_all_dict = all_dict + mod_path = path.parent.relative_to(folder_path).parts + try: + for part in mod_path: + path_all_dict = path_all_dict[part] + except KeyError: + # there is nothing to import + continue + + if '*' in skip_children.get(rel_path, []): + path_all_dict = {} + alls = [] + auto_content = ['', '# AUTO-GENERATED', '', '__all__ = ()', ''] + else: + path_all_dict = { + key: val for key, val in path_all_dict.items() if key not in skip_children.get(rel_path, []) + } + alls = gather_all(list(mod_path), path_all_dict, skip_children) + + # check for non-unique imports + if len(alls + list(path_all_dict)) != len(set(alls + list(path_all_dict))): + non_unique[rel_path] = [k for k, v in Counter(alls + list(path_all_dict)).items() if v > 1] + + auto_content = (['', '# AUTO-GENERATED'] + + ['', '# yapf: disable', '# pylint: disable=wildcard-import', ''] + + [f'from .{mod} import *' for mod in sorted(path_all_dict.keys())] + ['', '__all__ = ('] + + [f' {a!r},' for a in sorted(set(alls))] + [')', '', '# yapf: enable', '']) + + start_content = [] + end_content = [] + in_docstring = in_end_content = False + in_start_content = True + for line in path.read_text(encoding='utf8').splitlines(): + if not in_start_content and line.startswith('# END AUTO-GENERATED'): + in_end_content = True + if in_end_content: + end_content.append(line) + continue + # only use initial comments and docstring + if not (line.startswith('#') or line.startswith('"""') or in_docstring): + in_start_content = False + continue + if line.startswith('"""'): + if (not in_docstring) and not (line.endswith('"""') and not line.strip() == '"""'): + in_docstring = True + else: + in_docstring = False + if in_start_content and not line.startswith('# pylint'): + start_content.append(line) + + new_content = start_content + auto_content + end_content + + path.write_text('\n'.join(new_content).rstrip() + '\n', encoding='utf8') + + return non_unique + + +if __name__ == '__main__': + _folder = Path(__file__).parent.parent.joinpath('aiida') + _skip = { + # skipped since some arguments and options share the same name + 'cmdline/params': ['arguments', 'options'], + # skipped since the module and its method share the same name + 'cmdline/utils': ['echo'], + # skipped since this is for testing only not general use + 'manage': ['tests'], + # skipped since we don't want to expose the implmentation + 'orm': ['implementation'], + # skipped since both implementations share class/function names + 'orm/implementation': ['django', 'sqlalchemy', 'sql'], + # skip all since the module requires extra requirements + 'restapi': ['*'], + } + _all_dict, _bad_all = parse_all(_folder) + _non_unique = write_inits(_folder, _all_dict, _skip) + _bad_all.pop('missing', '') # allow missing __all__ + if _bad_all: + print('unparsable __all__:') + pprint(_bad_all) + if _non_unique: + print('non-unique imports:') + pprint(_non_unique) + if _bad_all or _non_unique: + sys.exit(1)