Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature proposal: RetryContextManager (for retrying __enter__ and NOT the contents of the with block) #438

Open
rudolfbyker opened this issue Feb 8, 2024 · 0 comments

Comments

@rudolfbyker
Copy link

Goal

A context manager that retries when the `__enter__` method of a wrapped context manager raises an exception, but does NOT retry when the body of the `with` block raises an exception.

Example use case

The context we are trying to set up is very unreliable (in our case, because of COM server black magic), but once it is set up, it works well. We need to retry setting up the context without retrying or squelching errors raised in the body of the `with` block.

Proposed new RetryContextManager class

from types import TracebackType
from typing import (
    TypedDict,
    Type,
    ContextManager,
    TypeVar,
    Optional,
    Callable,
    Generic,
    Union,
    Any,
)

from tenacity import retry, RetryCallState, RetryError
from tenacity.stop import StopBaseT
from tenacity.wait import WaitBaseT
from tenacity.retry import RetryBaseT


class RetryKwargs(TypedDict, total=False):
    """
    Keyword arguments for the `tenacity.retry` decorator, as a typed dictionary.

    Copied from the arguments of `BaseRetrying.__init__` in the `tenacity` library; keep the two in sync when tenacity
    changes. Because the dictionary is declared with `total=False`, every key is optional: only the options that
    should be set need to be given, and `RetryContextManager` forwards them as `retry(**retry_kwargs)`.
    """

    # Function used to sleep between attempts; called with the wait duration (see tenacity docs for units).
    sleep: Callable[[Union[int, float]], None]
    # Stop strategy: decides when to give up retrying (e.g. `stop_after_attempt(3)`).
    stop: StopBaseT
    # Wait strategy: decides how long to wait before the next attempt.
    wait: WaitBaseT
    # Retry strategy: decides whether the outcome of an attempt should be retried.
    retry: RetryBaseT
    # Hook called before each attempt (e.g. `before_log(...)`).
    before: Callable[[RetryCallState], None]
    # Hook called after an attempt (e.g. `after_log(...)`).
    after: Callable[[RetryCallState], None]
    # Hook called before sleeping ahead of the next attempt (e.g. `before_sleep_log(...)`); may be None.
    before_sleep: Optional[Callable[[RetryCallState], None]]
    # If True, the last attempt's own exception is re-raised once retrying stops.
    reraise: bool
    # Exception class raised once retrying stops when `reraise` is not set.
    retry_error_cls: Type[RetryError]
    # Optional callback invoked once retrying stops (see tenacity docs for how its result is used).
    retry_error_callback: Optional[Callable[[RetryCallState], Any]]


# Type of the value returned by the wrapped context manager's `__enter__` (i.e. what `with ... as x` binds).
T = TypeVar("T")


class RetryContextManager(Generic[T]):
    """
    A context manager that retries if the `__enter__` method of a context manager raises an exception, but does NOT
    retry when the contents of the `with` block raise an exception.

    Every attempt calls the `cm` factory to get a fresh context manager, because one whose `__enter__` failed usually
    cannot be entered again (a `@contextmanager` generator, for example, is exhausted once it has raised).
    """

    def __init__(
        self,
        cm: Callable[[], ContextManager[T]],
        retry_kwargs: RetryKwargs,
    ) -> None:
        """
        Create the RetryContextManager.

        Args:
            cm: A callable that returns a new context manager. It is called once per attempt to enter the context.
            retry_kwargs: The arguments to pass to the `retry` decorator (`tenacity.retry`).
        """
        self._cm: Callable[[], ContextManager[T]] = cm
        # The context manager that was successfully entered; None while no context is active.
        self._cm_instance: Optional[ContextManager[T]] = None
        self._retry_kwargs: RetryKwargs = retry_kwargs

    def __enter__(self) -> T:
        """
        Create and enter a context manager, retrying as configured by `retry_kwargs` while its `__enter__` raises.

        Returns:
            The value returned by the `__enter__` of the context manager that was entered successfully.

        Raises:
            Whatever `tenacity.retry` raises once it stops retrying; with `reraise=True` that is the exception raised
            by the last attempt.
        """

        @retry(**self._retry_kwargs)
        def _enter() -> T:
            # Create a new instance of the context manager for this attempt.
            cm_instance = self._cm()

            # If this raises, the exception propagates to `retry`, which decides whether to try again. `__exit__` is
            # deliberately NOT called on the failed instance: per PEP 343 it only runs after `__enter__` succeeded,
            # and calling it anyway breaks context managers whose `__exit__` relies on state set up by `__enter__`
            # (the resulting error would also mask the original one). Cleaning up after a failed `__enter__` is the
            # wrapped context manager's own responsibility.
            managed_resource = cm_instance.__enter__()

            # Success: remember the instance so that `__exit__` can exit it.
            self._cm_instance = cm_instance
            return managed_resource

        return _enter()

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        """
        Exit the context manager entered by `__enter__`. Exceptions raised in the `with` block are never retried.

        Returns:
            The wrapped context manager's `__exit__` result, so that a wrapped context manager which suppresses an
            exception (by returning True) still suppresses it. None if no context is active.
        """
        # Forget the instance before exiting it, so a later (re-)entry starts from a clean state.
        cm_instance, self._cm_instance = self._cm_instance, None
        if cm_instance is None:
            return None
        return cm_instance.__exit__(exc_type, exc_val, exc_tb)

Unit tests

import unittest
from contextlib import contextmanager
from logging import getLogger, DEBUG
from queue import Queue
from typing import Generator

from tenacity import stop_after_attempt, before_log, after_log, before_sleep_log, RetryContextManager, RetryKwargs


def create_test_context_manager(
    *,
    n_failures: int,
    thing_to_yield: int,
    history: "Queue[str]",
    retry_kwargs: RetryKwargs,
) -> RetryContextManager[int]:
    """
    Create a (possibly unreliable) context manager, wrapped in a `RetryContextManager`, for testing.

    The first `n_failures` attempts to enter the context raise `RuntimeError("<attempt> failed")`; later attempts
    yield `thing_to_yield`. Every attempt records "<attempt> entering", then "<attempt> raising" or
    "<attempt> yielding", and finally "<attempt> exiting" in `history`, where <attempt> counts from 0.

    Args:
        n_failures: The number of times to fail before succeeding.
        thing_to_yield: The thing to yield when the context manager succeeds.
        history: A queue to put the history of the context manager into. This is used for test assertions.
        retry_kwargs: The arguments to pass to the `retry` decorator.

    Returns:
        A `RetryContextManager` wrapping the unreliable context manager factory.
    """
    # Number of `cm()` instances that have finished so far; shared by all attempts through the closure.
    n_tries = 0

    @contextmanager
    def cm() -> Generator[int, None, None]:
        nonlocal n_tries
        history.put(f"{n_tries} entering")

        try:
            if n_tries < n_failures:
                history.put(f"{n_tries} raising")
                raise RuntimeError(f"{n_tries} failed")

            history.put(f"{n_tries} yielding")
            yield thing_to_yield
        finally:
            # Runs whether entering fails or the `with` block exits, so `n_tries` advances once per `cm()` instance.
            history.put(f"{n_tries} exiting")
            n_tries += 1

    return RetryContextManager(
        cm=cm,
        retry_kwargs=retry_kwargs,
    )


class TestRetryContextManager(unittest.TestCase):
    """
    Tests for `RetryContextManager`: entering the context is retried, but the body of the `with` block is not.
    """

    def setUp(self) -> None:
        self.logger = getLogger("test")
        # tenacity's `before_log`/`after_log`/`before_sleep_log` name the retried function "<module>.<qualname>".
        # Derive it from the class rather than hard-coding a module name, so the expected log lines stay correct
        # wherever `RetryContextManager` is defined (these tests import it from `tenacity`).
        self.enter_name = (
            f"{RetryContextManager.__module__}.{RetryContextManager.__enter__.__qualname__}.<locals>._enter"
        )

    def _create(self, *, n_failures: int, history: "Queue[str]") -> RetryContextManager[int]:
        """Create the test context manager with the retry settings shared by all tests (3 attempts, DEBUG logs)."""
        return create_test_context_manager(
            n_failures=n_failures,
            thing_to_yield=1,
            history=history,
            retry_kwargs=RetryKwargs(
                stop=stop_after_attempt(3),
                reraise=True,
                before=before_log(logger=self.logger, log_level=DEBUG),
                after=after_log(logger=self.logger, log_level=DEBUG),
                before_sleep=before_sleep_log(logger=self.logger, log_level=DEBUG),
            ),
        )

    def _starting(self, ordinal: str) -> str:
        """Expected `before_log` line for the given attempt ("1st", "2nd", ...)."""
        return f"DEBUG:test:Starting call to '{self.enter_name}', this is the {ordinal} time calling it."

    def _finished(self, ordinal: str) -> str:
        """Expected `after_log` line for the given failed attempt, with the elapsed time normalized."""
        return (
            f"DEBUG:test:Finished call to '{self.enter_name}' after <elapsed>(s), "
            f"this was the {ordinal} time calling it."
        )

    def _retrying(self, error: str) -> str:
        """Expected `before_sleep_log` line after an attempt failed with `RuntimeError(error)`."""
        return f"DEBUG:test:Retrying {self.enter_name} in 0.0 seconds as it raised RuntimeError: {error}."

    @staticmethod
    def _normalized(output: "list[str]") -> "list[str]":
        """Replace the elapsed time in `after_log` lines (e.g. "after 0.000(s)"), which depends on machine speed."""
        import re

        return [re.sub(r"after \d+\.\d+\(s\)", "after <elapsed>(s)", line) for line in output]

    def test_no_retries_necessary(self) -> None:
        history: Queue[str] = Queue()
        with self.assertLogs(logger=self.logger, level=DEBUG) as logs:
            self.logger.info("Before the context")

            with self._create(n_failures=0, history=history) as value:
                self.logger.info("Inside the context")
                self.assertEqual(value, 1)

            self.logger.info("After the context")

        self.assertEqual(["0 entering", "0 yielding", "0 exiting"], list(history.queue))
        self.assertEqual(
            [
                "INFO:test:Before the context",
                self._starting("1st"),
                "INFO:test:Inside the context",
                "INFO:test:After the context",
            ],
            self._normalized(logs.output),
        )

    def test_retry_then_succeed(self) -> None:
        history: Queue[str] = Queue()
        with self.assertLogs(logger=self.logger, level=DEBUG) as logs:
            self.logger.info("Before the context")

            with self._create(n_failures=2, history=history) as value:
                self.logger.info("Inside the context")
                self.assertEqual(value, 1)

            self.logger.info("After the context")

        self.assertEqual(
            [
                "0 entering",
                "0 raising",
                "0 exiting",
                "1 entering",
                "1 raising",
                "1 exiting",
                "2 entering",
                "2 yielding",
                "2 exiting",
            ],
            list(history.queue),
        )
        self.assertEqual(
            [
                "INFO:test:Before the context",
                self._starting("1st"),
                self._finished("1st"),
                self._retrying("0 failed"),
                self._starting("2nd"),
                self._finished("2nd"),
                self._retrying("1 failed"),
                self._starting("3rd"),
                "INFO:test:Inside the context",
                "INFO:test:After the context",
            ],
            self._normalized(logs.output),
        )

    def test_retry_then_give_up(self) -> None:
        history: Queue[str] = Queue()
        with self.assertLogs(logger=self.logger, level=DEBUG) as logs:
            self.logger.info("Before the context")

            # With `reraise=True`, the error of the last (3rd) attempt is raised instead of a `RetryError`.
            with self.assertRaisesRegex(RuntimeError, "2 failed"):
                with self._create(n_failures=5, history=history):
                    self.logger.info("Inside the context")

            self.logger.info("After the context")

        self.assertEqual(
            [
                "0 entering",
                "0 raising",
                "0 exiting",
                "1 entering",
                "1 raising",
                "1 exiting",
                "2 entering",
                "2 raising",
                "2 exiting",
            ],
            list(history.queue),
        )
        self.assertEqual(
            [
                "INFO:test:Before the context",
                self._starting("1st"),
                self._finished("1st"),
                self._retrying("0 failed"),
                self._starting("2nd"),
                self._finished("2nd"),
                self._retrying("1 failed"),
                self._starting("3rd"),
                self._finished("3rd"),
                "INFO:test:After the context",
            ],
            self._normalized(logs.output),
        )

    def test_with_body_does_not_cause_retry(self) -> None:
        history: Queue[str] = Queue()
        with self.assertLogs(logger=self.logger, level=DEBUG) as logs:
            self.logger.info("Before the context")

            # Match the message, so that a failure to *enter* (also a RuntimeError) cannot satisfy this assertion.
            with self.assertRaisesRegex(RuntimeError, "This should not cause a retry"):
                with self._create(n_failures=0, history=history):
                    self.logger.info("Inside the context")
                    raise RuntimeError("This should not cause a retry")

            self.logger.info("After the context")

        self.assertEqual(["0 entering", "0 yielding", "0 exiting"], list(history.queue))
        self.assertEqual(
            [
                "INFO:test:Before the context",
                self._starting("1st"),
                "INFO:test:Inside the context",
                "INFO:test:After the context",
            ],
            self._normalized(logs.output),
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant