diff --git a/aries_cloudagent/askar/profile.py b/aries_cloudagent/askar/profile.py index 27cec91b7c..b07dcf1f6b 100644 --- a/aries_cloudagent/askar/profile.py +++ b/aries_cloudagent/askar/profile.py @@ -114,7 +114,7 @@ def bind_providers(self): "aries_cloudagent.indy.credx.issuer.IndyCredxIssuer", ref(self) ), ) - injector.bind_provider( + injector.soft_bind_provider( VCHolder, ClassProvider( "aries_cloudagent.storage.vc_holder.askar.AskarVCHolder", diff --git a/aries_cloudagent/config/injector.py b/aries_cloudagent/config/injector.py index 9c47bd1ec5..26130623db 100644 --- a/aries_cloudagent/config/injector.py +++ b/aries_cloudagent/config/injector.py @@ -1,6 +1,6 @@ """Standard Injector implementation.""" -from typing import Mapping, Optional, Type +from typing import Dict, Mapping, Optional, Type from .base import BaseProvider, BaseInjector, InjectionError, InjectType from .provider import InstanceProvider, CachedProvider @@ -18,7 +18,7 @@ def __init__( ): """Initialize an `Injector`.""" self.enforce_typing = enforce_typing - self._providers = {} + self._providers: Dict[Type, BaseProvider] = {} self._settings = Settings(settings) @property @@ -45,6 +45,24 @@ def bind_provider( provider = CachedProvider(provider) self._providers[base_cls] = provider + def soft_bind_instance(self, base_cls: Type[InjectType], instance: InjectType): + """Add a static instance as a soft class binding. + + The binding occurs only if a provider for the same type does not already exist. + """ + if not self.get_provider(base_cls): + self.bind_instance(base_cls, instance) + + def soft_bind_provider( + self, base_cls: Type[InjectType], provider: BaseProvider, *, cache: bool = False + ): + """Add a dynamic instance resolver as a soft class binding. + + The binding occurs only if a provider for the same type does not already exist. + """ + if not self.get_provider(base_cls): + self.bind_provider(base_cls, provider, cache=cache) + def clear_binding(self, base_cls: Type[InjectType]): """Remove a previously-added binding.""" if base_cls in self._providers: diff --git a/aries_cloudagent/config/tests/test_injector.py b/aries_cloudagent/config/tests/test_injector.py index 76da5f7992..3b5023307f 100644 --- a/aries_cloudagent/config/tests/test_injector.py +++ b/aries_cloudagent/config/tests/test_injector.py @@ -70,6 +70,39 @@ def test_inject_provider(self): assert mock_provider.settings[self.test_key] == override_settings[self.test_key] assert mock_provider.injector is self.test_instance + def test_inject_soft_provider_bindings(self): + """Test injecting providers with soft binding.""" + provider = MockProvider(self.test_value) + override = MockProvider("Override") + + self.test_instance.soft_bind_provider(str, provider) + assert self.test_instance.inject(str) == self.test_value + + self.test_instance.clear_binding(str) + # Bound by a plugin on startup, for example + self.test_instance.bind_provider(str, override) + + # Bound later in Profile.bind_providerse + self.test_instance.soft_bind_provider(str, provider) + + # We want the plugin value, not the Profile bound value + assert self.test_instance.inject(str) == "Override" + + def test_inject_soft_instance_bindings(self): + """Test injecting providers with soft binding.""" + self.test_instance.soft_bind_instance(str, self.test_value) + assert self.test_instance.inject(str) == self.test_value + + self.test_instance.clear_binding(str) + # Bound by a plugin on startup, for example + self.test_instance.bind_instance(str, "Override") + + # Bound later in Profile.bind_providerse + self.test_instance.soft_bind_instance(str, self.test_value) + + # We want the plugin value, not the Profile bound value + assert self.test_instance.inject(str) == "Override" + def test_bad_provider(self): """Test empty and invalid provider results.""" self.test_instance.bind_provider(str, MockProvider(None))