Skip to content

Commit

Permalink
Fix instance stack after exception
Browse files Browse the repository at this point in the history
  • Loading branch information
sealor committed Apr 24, 2024
1 parent 32b1461 commit 35c1b00
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
17 changes: 9 additions & 8 deletions junkie/_junkie.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __getitem__(self, item):
return self._instances_by_name[item]

def __setitem__(self, key, value):
self._instances_by_name = self._instances_by_name_stack.peek()

if key in self._instances_by_name:
raise JunkieError(f'Instance for "{key}" already exists')
if key in self._context:
Expand All @@ -50,6 +52,8 @@ def __iter__(self) -> Iterator:
return self._instances_by_name.__iter__()

def __contains__(self, item) -> bool:
self._instances_by_name = self._instances_by_name_stack.peek()

return self._instances_by_name.__contains__(item)

@contextmanager
Expand All @@ -58,15 +62,12 @@ def inject(self, *names_and_factories: Union[str, Callable]) -> Union[Any, Tuple

with ExitStack() as self._exit_stack:
self._instances_by_name = self._instances_by_name_stack.peek().copy()
self._instances_by_name_stack.push(self._instances_by_name)

if len(names_and_factories) == 1:
yield self._build_instance(names_and_factories[0])
else:
yield self._build_tuple(*names_and_factories)

self._instances_by_name_stack.pop()
self._instances_by_name = self._instances_by_name_stack.peek()
with self._instances_by_name_stack.push_temporarily(self._instances_by_name):
if len(names_and_factories) == 1:
yield self._build_instance(names_and_factories[0])
else:
yield self._build_tuple(*names_and_factories)

def _build_tuple(self, *names_and_factories: Union[str, Callable]) -> Tuple[Any, ...]:
instances = []
Expand Down
13 changes: 13 additions & 0 deletions test/test_bugfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,16 @@ def faulty_handler():
with self.assertRaises(ValueError):
with my_junkie.inject(faulty_handler):
pass

def test_instances_after_exception(self):
def faulty_handler():
raise ValueError()

obj = object()
my_junkie = Junkie({"object": lambda: obj})

with self.assertRaises(ValueError):
with my_junkie.inject("object", faulty_handler):
pass

self.assertNotIn("object", my_junkie)

0 comments on commit 35c1b00

Please sign in to comment.