Skip to content

Commit

Permalink
refactor: move provider dict to a ProviderRegistry class
Browse files Browse the repository at this point in the history
Signed-off-by: Federico Bond <federicobond@gmail.com>
  • Loading branch information
federicobond committed Feb 7, 2024
1 parent e81abd0 commit 531831a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 40 deletions.
47 changes: 8 additions & 39 deletions openfeature/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
from openfeature.hook import Hook
from openfeature.provider import FeatureProvider
from openfeature.provider.metadata import Metadata
from openfeature.provider.no_op_provider import NoOpProvider

_provider: FeatureProvider = NoOpProvider()
from openfeature.provider.registry import ProviderRegistry

_evaluation_context = EvaluationContext()

_hooks: typing.List[Hook] = []

_providers: typing.Dict[str, FeatureProvider] = {}
_provider_registry: ProviderRegistry = ProviderRegistry()


def get_client(
Expand All @@ -26,46 +24,18 @@ def get_client(
def set_provider(
provider: FeatureProvider, domain: typing.Optional[str] = None
) -> None:
if provider is None:
raise GeneralError(error_message="No provider")

if domain:
_set_domain_provider(domain, provider)
return

global _provider
if _provider:
_provider.shutdown()
_provider = provider
provider.initialize(_evaluation_context)


def _set_domain_provider(domain: str, provider: FeatureProvider) -> None:
if domain in _providers:
old_provider = _providers[domain]
del _providers[domain]
if old_provider not in _providers.values():
old_provider.shutdown()
if provider not in _providers.values():
provider.initialize(_evaluation_context)
_providers[domain] = provider


def _get_provider(domain: typing.Optional[str] = None) -> FeatureProvider:
global _provider
if domain is None:
return _provider
return _providers.get(domain, _provider)
_provider_registry.set_default_provider(provider)
else:
_provider_registry.set_provider(domain, provider)


def clear_providers() -> None:
for provider in _providers.values():
provider.shutdown()
_providers.clear()
return _provider_registry.clear_providers()


def get_provider_metadata(domain: typing.Optional[str] = None) -> Metadata:
return _get_provider(domain).get_metadata()
return _provider_registry.get_provider(domain).get_metadata()


def get_evaluation_context() -> EvaluationContext:
Expand Down Expand Up @@ -96,5 +66,4 @@ def get_hooks() -> typing.List[Hook]:


def shutdown() -> None:
for provider in {_provider, *_providers.values()}:
provider.shutdown()
_provider_registry.shutdown()
2 changes: 1 addition & 1 deletion openfeature/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(

@property
def provider(self) -> FeatureProvider:
return api._get_provider(domain=self.domain)
return api._provider_registry.get_provider(self.domain)

def get_metadata(self) -> ClientMetadata:
return ClientMetadata(domain=self.domain)
Expand Down
59 changes: 59 additions & 0 deletions openfeature/provider/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import typing

from openfeature.evaluation_context import EvaluationContext
from openfeature.exception import GeneralError
from openfeature.provider import FeatureProvider
from openfeature.provider.no_op_provider import NoOpProvider


class ProviderRegistry:
_default_provider: FeatureProvider
_providers: typing.Dict[str, FeatureProvider]

def __init__(self) -> None:
self._default_provider = NoOpProvider()
self._providers = {}

def set_provider(self, domain: str, provider: FeatureProvider) -> None:
if provider is None:
raise GeneralError(error_message="No provider")
providers = self._providers
if domain in providers:
old_provider = providers[domain]
del providers[domain]
if old_provider not in providers.values():
old_provider.shutdown()
if provider not in providers.values():
provider.initialize(self._get_evaluation_context())
providers[domain] = provider

def get_provider(self, domain: typing.Optional[str]) -> FeatureProvider:
if domain is None:
return self._default_provider
return self._providers.get(domain, self._default_provider)

def set_default_provider(self, provider: FeatureProvider) -> None:
if provider is None:
raise GeneralError(error_message="No provider")
if self._default_provider:
self._default_provider.shutdown()
self._default_provider = provider
provider.initialize(self._get_evaluation_context())

def get_default_provider(self) -> FeatureProvider:
return self._default_provider

def clear_providers(self) -> None:
for provider in self._providers.values():
provider.shutdown()
self._providers.clear()

def shutdown(self) -> None:
for provider in {self._default_provider, *self._providers.values()}:
provider.shutdown()

def _get_evaluation_context(self) -> EvaluationContext:
# imported here to avoid circular imports
from openfeature.api import get_evaluation_context # noqa: PLC0415

return get_evaluation_context()

0 comments on commit 531831a

Please sign in to comment.