diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py index 074a47c58111a..5dc70d0fc9e56 100644 --- a/airflow-core/src/airflow/providers_manager.py +++ b/airflow-core/src/airflow/providers_manager.py @@ -38,7 +38,6 @@ from airflow._shared.module_loading import entry_points_with_dist, import_string from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.singleton import Singleton if TYPE_CHECKING: from airflow.cli.cli_config import CLICommand @@ -366,7 +365,7 @@ def wrapped_function(*args: PS.args, **kwargs: PS.kwargs) -> None: return provider_info_cache_decorator -class ProvidersManager(LoggingMixin, metaclass=Singleton): +class ProvidersManager(LoggingMixin): """ Manages all provider distributions. @@ -377,6 +376,12 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton): resource_version = "0" _initialized: bool = False _initialization_stack_trace = None + _instance: ProvidersManager | None = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance @staticmethod def initialized() -> bool: @@ -388,6 +393,10 @@ def initialization_stack_trace() -> str | None: def __init__(self): """Initialize the manager.""" + # skip initialization if already initialized + if self.initialized(): + return + super().__init__() ProvidersManager._initialized = True ProvidersManager._initialization_stack_trace = "".join(traceback.format_stack(inspect.currentframe())) diff --git a/airflow-core/src/airflow/utils/singleton.py b/airflow-core/src/airflow/utils/singleton.py deleted file mode 100644 index cfc97eddbfcfc..0000000000000 --- a/airflow-core/src/airflow/utils/singleton.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from typing import Generic, TypeVar - -T = TypeVar("T") - - -class Singleton(type, Generic[T]): - """Metaclass that allows to implement singleton pattern.""" - - _instances: dict[Singleton[T], T] = {} - - def __call__(cls: Singleton[T], *args, **kwargs) -> T: - if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) - return cls._instances[cls] diff --git a/airflow-core/tests/unit/always/test_providers_manager.py b/airflow-core/tests/unit/always/test_providers_manager.py index e0d686a565bb7..1e579bfdd955d 100644 --- a/airflow-core/tests/unit/always/test_providers_manager.py +++ b/airflow-core/tests/unit/always/test_providers_manager.py @@ -65,6 +65,21 @@ class TestProviderManager: def inject_fixtures(self, caplog, cleanup_providers_manager): self._caplog = caplog + def test_providers_manager_singleton(self): + """Test that ProvidersManager returns the same instance and shares state.""" + pm1 = ProvidersManager() + pm2 = ProvidersManager() + + assert pm1 is pm2 + + # assert their states are same + assert pm1._provider_dict is pm2._provider_dict + assert pm1._hook_provider_dict is pm2._hook_provider_dict + + # update property on one instance and check on another + pm1.resource_version = "updated_version" + assert pm2.resource_version == "updated_version" + def test_providers_are_loaded(self): with self._caplog.at_level(logging.WARNING): self._caplog.clear() diff --git a/airflow-core/tests/unit/utils/test_singleton.py b/airflow-core/tests/unit/utils/test_singleton.py deleted file mode 100644 index 57145fe7b97ba..0000000000000 --- a/airflow-core/tests/unit/utils/test_singleton.py +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from airflow.utils.singleton import Singleton - - -class A(metaclass=Singleton): - pass - - -class Counter(metaclass=Singleton): - """Singleton class that counts how much __init__ and count was called.""" - - counter = 0 - - def __init__(self): - self.counter += 1 - - def count(self): - self.counter += 1 - - -def test_singleton_refers_to_same_instance(): - a, b = A(), A() - assert a is b - - -def test_singleton_after_out_of_context_does_refer_to_same_instance(): - # check if setting something on singleton is preserved after instance goes out of context - def x(): - a = A() - a.a = "a" - - x() - b = A() - assert b.a == "a" - - -def test_singleton_does_not_call_init_second_time(): - # first creation of Counter, check if __init__ is called - c = Counter() - assert c.counter == 1 - - # check if "new instance" calls __init__ - it shouldn't - d = Counter() - assert c.counter == 1 - - # check if incrementing "new instance" increments counter on previous one - d.count() - assert c.counter == 2