diff --git a/.github/workflows/ci-code.yml b/.github/workflows/ci-code.yml index 3a5ab9df30..ada25552d8 100644 --- a/.github/workflows/ci-code.yml +++ b/.github/workflows/ci-code.yml @@ -90,7 +90,7 @@ jobs: - name: Upgrade pip and setuptools run: | - pip install --upgrade pip + pip install --upgrade pip setuptools pip --version - name: Build pymatgen with compatible numpy diff --git a/aiida/cmdline/commands/cmd_code.py b/aiida/cmdline/commands/cmd_code.py index f90741ef32..d78da2d444 100644 --- a/aiida/cmdline/commands/cmd_code.py +++ b/aiida/cmdline/commands/cmd_code.py @@ -89,6 +89,9 @@ def setup_code(ctx, non_interactive, **kwargs): else: kwargs['code_type'] = CodeBuilder.CodeType.STORE_AND_UPLOAD + # Convert entry point to its name + kwargs['input_plugin'] = kwargs['input_plugin'].name + code_builder = CodeBuilder(**kwargs) try: @@ -160,6 +163,9 @@ def code_duplicate(ctx, code, non_interactive, **kwargs): if kwargs.pop('hide_original'): code.hide() + # Convert entry point to its name + kwargs['input_plugin'] = kwargs['input_plugin'].name + code_builder = ctx.code_builder for key, value in kwargs.items(): setattr(code_builder, key, value) diff --git a/aiida/orm/utils/builders/code.py b/aiida/orm/utils/builders/code.py index b2f28fc072..492aed0d77 100644 --- a/aiida/orm/utils/builders/code.py +++ b/aiida/orm/utils/builders/code.py @@ -11,8 +11,6 @@ import enum import os -import importlib_metadata - from aiida.cmdline.utils.decorators import with_dbenv from aiida.common.utils import ErrorAccumulator @@ -155,9 +153,6 @@ def _set_code_attr(self, key, value): Checks compatibility with other code attributes. """ - if key == 'input_plugin' and isinstance(value, importlib_metadata.EntryPoint): - value = value.name - if key == 'description' and value is None: value = '' diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 6715e85e81..18187da4af 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -78,10 +78,11 @@ def CalculationFactory(entry_point_name: str, load: bool = True) -> Optional[Uni entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (CalcJob, calcfunction) - if ( - isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, CalcJob)) or - (is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode) - ): + if not load: + return entry_point + + if ((isclass(entry_point) and issubclass(entry_point, CalcJob)) or + (is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode)): # type: ignore[union-attr] return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) @@ -100,6 +101,9 @@ def CalcJobImporterFactory(entry_point_name: str, load: bool = True) -> Optional entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (CalcJobImporter,) + if not load: + return entry_point + if isclass(entry_point) and issubclass(entry_point, CalcJobImporter): return entry_point # type: ignore[return-value] @@ -120,7 +124,10 @@ def DataFactory(entry_point_name: str, load: bool = True) -> Optional[Union[Entr entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Data,) - if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Data)): + if not load: + return entry_point + + if isclass(entry_point) and issubclass(entry_point, Data): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) @@ -140,7 +147,10 @@ def DbImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Unio entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (DbImporter,) - if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, DbImporter)): + if not load: + return entry_point + + if isclass(entry_point) and issubclass(entry_point, DbImporter): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) @@ -160,7 +170,10 @@ def GroupFactory(entry_point_name: str, load: bool = True) -> Optional[Union[Ent entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Group,) - if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Group)): + if not load: + return entry_point + + if isclass(entry_point) and issubclass(entry_point, Group): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) @@ -180,7 +193,10 @@ def OrbitalFactory(entry_point_name: str, load: bool = True) -> Optional[Union[E entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Orbital,) - if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Orbital)): + if not load: + return entry_point + + if isclass(entry_point) and issubclass(entry_point, Orbital): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) @@ -200,7 +216,10 @@ def ParserFactory(entry_point_name: str, load: bool = True) -> Optional[Union[En entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Parser,) - if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Parser)): + if not load: + return entry_point + + if isclass(entry_point) and issubclass(entry_point, Parser): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) @@ -220,7 +239,10 @@ def SchedulerFactory(entry_point_name: str, load: bool = True) -> Optional[Union entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Scheduler,) - if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Scheduler)): + if not load: + return entry_point + + if isclass(entry_point) and issubclass(entry_point, Scheduler): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) @@ -239,7 +261,10 @@ def TransportFactory(entry_point_name: str, load: bool = True) -> Optional[Union entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Transport,) - if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Transport)): + if not load: + return entry_point + + if isclass(entry_point) and issubclass(entry_point, Transport): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) @@ -260,10 +285,11 @@ def WorkflowFactory(entry_point_name: str, load: bool = True) -> Optional[Union[ entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (WorkChain, workfunction) - if ( - isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, WorkChain)) or - (is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode) - ): + if not load: + return entry_point + + if ((isclass(entry_point) and issubclass(entry_point, WorkChain)) or + (is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode)): # type: ignore[union-attr] return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) diff --git a/tests/plugins/test_entry_point.py b/tests/plugins/test_entry_point.py index cc7d2c463f..2326756073 100644 --- a/tests/plugins/test_entry_point.py +++ b/tests/plugins/test_entry_point.py @@ -11,7 +11,7 @@ import pytest from aiida.common.warnings import AiidaDeprecationWarning -from aiida.plugins.entry_point import EntryPoint, get_entry_point, validate_registered_entry_points +from aiida.plugins.entry_point import get_entry_point, validate_registered_entry_points def test_validate_registered_entry_points(): @@ -42,6 +42,4 @@ def test_get_entry_point_deprecated(group, name): warning = f'The entry point `{name}` is deprecated. Please replace it with `core.{name}`.' with pytest.warns(AiidaDeprecationWarning, match=warning): - entry_point = get_entry_point(group, name) - - assert isinstance(entry_point, EntryPoint) + get_entry_point(group, name)