diff --git a/quantecon/__init__.py b/quantecon/__init__.py index d3f9a12be..21125a5e2 100644 --- a/quantecon/__init__.py +++ b/quantecon/__init__.py @@ -52,4 +52,4 @@ #<- from ._rank_nullspace import rank_est, nullspace from ._robustlq import RBLQ -from .util import searchsorted, fetch_nb_dependencies, tic, tac, toc +from .util import searchsorted, fetch_nb_dependencies, tic, tac, toc, Timer diff --git a/quantecon/util/__init__.py b/quantecon/util/__init__.py index 01db9d367..453e76c9a 100644 --- a/quantecon/util/__init__.py +++ b/quantecon/util/__init__.py @@ -6,4 +6,4 @@ from .array import searchsorted from .notebooks import fetch_nb_dependencies from .random import check_random_state, rng_integers -from .timing import tic, tac, toc, loop_timer +from .timing import tic, tac, toc, loop_timer, Timer diff --git a/quantecon/util/tests/test_timing.py b/quantecon/util/tests/test_timing.py index db0e1bc78..6d454072a 100644 --- a/quantecon/util/tests/test_timing.py +++ b/quantecon/util/tests/test_timing.py @@ -5,7 +5,7 @@ import time from numpy.testing import assert_allclose, assert_ -from quantecon.util import tic, tac, toc, loop_timer +from quantecon.util import tic, tac, toc, loop_timer, Timer class TestTicTacToc: @@ -56,3 +56,94 @@ def test_function_two_arg(n, a): for (average_time, average_of_best) in [test_one_arg, test_two_arg]: assert_(average_time >= average_of_best) + + +class TestTimer: + def setup_method(self): + self.sleep_time = 0.1 + + def test_basic_timer(self): + """Test basic Timer context manager functionality.""" + timer = Timer(silent=True) + + with timer: + time.sleep(self.sleep_time) + + # Check that elapsed time was recorded + assert timer.elapsed is not None + assert_allclose(timer.elapsed, self.sleep_time, atol=0.05, rtol=2) + + def test_timer_return_value(self): + """Test that Timer returns self for variable assignment.""" + with Timer(silent=True) as timer: + time.sleep(self.sleep_time) + + assert timer.elapsed is not None + assert_allclose(timer.elapsed, self.sleep_time, atol=0.05, rtol=2) + + def test_timer_units(self): + """Test different time units.""" + # Test seconds (default) + with Timer(silent=True) as timer_sec: + time.sleep(self.sleep_time) + expected_sec = self.sleep_time + assert_allclose(timer_sec.elapsed, expected_sec, atol=0.05, rtol=2) + + # Timer always stores elapsed time in seconds regardless of display unit + with Timer(unit="milliseconds", silent=True) as timer_ms: + time.sleep(self.sleep_time) + assert_allclose(timer_ms.elapsed, expected_sec, atol=0.05, rtol=2) + + with Timer(unit="microseconds", silent=True) as timer_us: + time.sleep(self.sleep_time) + assert_allclose(timer_us.elapsed, expected_sec, atol=0.05, rtol=2) + + def test_invalid_unit(self): + """Test that invalid units raise ValueError.""" + try: + Timer(unit="invalid") + assert False, "Should have raised ValueError" + except ValueError as e: + assert "unit must be one of" in str(e) + + def test_timer_precision(self): + """Test that precision parameter is accepted (output format tested manually).""" + # Just verify it doesn't crash with different precision values + with Timer(precision=0, silent=True) as timer0: + time.sleep(self.sleep_time) + with Timer(precision=6, silent=True) as timer6: + time.sleep(self.sleep_time) + + assert timer0.elapsed is not None + assert timer6.elapsed is not None + + def test_timer_message(self): + """Test custom message parameter (output format tested manually).""" + with Timer(message="Test operation", silent=True) as timer: + time.sleep(self.sleep_time) + + assert timer.elapsed is not None + + def test_timer_silent_mode(self): + """Test silent mode suppresses output.""" + # This mainly tests that silent=True doesn't crash + # Output suppression is hard to test automatically + with Timer(silent=True) as timer: + time.sleep(self.sleep_time) + + assert timer.elapsed is not None + + def test_timer_exception_handling(self): + """Test that Timer works correctly even when exceptions occur.""" + timer = Timer(silent=True) + + try: + with timer: + time.sleep(self.sleep_time) + raise ValueError("Test exception") + except ValueError: + pass # Expected + + # Timer should still record elapsed time + assert timer.elapsed is not None + assert_allclose(timer.elapsed, self.sleep_time, atol=0.05, rtol=2) diff --git a/quantecon/util/timing.py b/quantecon/util/timing.py index 032273e90..a6b4015e3 100644 --- a/quantecon/util/timing.py +++ b/quantecon/util/timing.py @@ -175,6 +175,98 @@ def loop_timer(self, n, function, args=None, verbose=True, digits=2, __timer__ = __Timer__() +class Timer: + """ + A context manager for timing code execution. + + This provides a modern context manager approach to timing, allowing + patterns like `with Timer():` instead of manual tic/toc calls. + + Parameters + ---------- + message : str, optional(default="") + Custom message to display with timing results. + precision : int, optional(default=2) + Number of decimal places to display for seconds. + unit : str, optional(default="seconds") + Unit to display timing in. Options: "seconds", "milliseconds", "microseconds" + silent : bool, optional(default=False) + If True, suppress printing of timing results. + + Attributes + ---------- + elapsed : float + The elapsed time in seconds. Available after exiting the context. + + Examples + -------- + Basic usage: + >>> with Timer(): + ... # some code + ... pass + 0.00 seconds elapsed + + With custom message and precision: + >>> with Timer("Computing results", precision=4): + ... # some code + ... pass + Computing results: 0.0001 seconds elapsed + + Store elapsed time for comparison: + >>> timer = Timer(silent=True) + >>> with timer: + ... # some code + ... pass + >>> print(f"Method took {timer.elapsed:.6f} seconds") + Method took 0.000123 seconds + """ + + def __init__(self, message="", precision=2, unit="seconds", silent=False): + self.message = message + self.precision = precision + self.unit = unit.lower() + self.silent = silent + self.elapsed = None + self._start_time = None + + # Validate unit + valid_units = ["seconds", "milliseconds", "microseconds"] + if self.unit not in valid_units: + raise ValueError(f"unit must be one of {valid_units}") + + def __enter__(self): + self._start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + end_time = time.time() + self.elapsed = end_time - self._start_time + + if not self.silent: + self._print_elapsed() + + def _print_elapsed(self): + """Print the elapsed time with appropriate formatting.""" + # Convert to requested unit + if self.unit == "milliseconds": + elapsed_display = self.elapsed * 1000 + unit_str = "ms" + elif self.unit == "microseconds": + elapsed_display = self.elapsed * 1000000 + unit_str = "μs" + else: # seconds + elapsed_display = self.elapsed + unit_str = "seconds" + + # Format the message + if self.message: + prefix = f"{self.message}: " + else: + prefix = "" + + print(f"{prefix}{elapsed_display:.{self.precision}f} {unit_str} elapsed") + + def tic(): return __timer__.tic()