⚡️ Speed up function handle_error by 103%
#620
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 103% (1.03x) speedup for
handle_errorinmarimo/_server/errors.py⏱️ Runtime :
2.74 milliseconds→1.35 milliseconds(best of186runs)📝 Explanation and details
The optimized code achieves a 102% speedup (2.74ms → 1.35ms runtime) through several targeted optimizations that reduce redundant operations and attribute lookups:
Key Optimizations
1. Session ID Caching in AppState
_cached_session_idattribute to cache the parsed session ID after first lookupget_current_session_id()is called multiple times2. Attribute Access Optimization
self.request.headersin local variablehdrsto avoid repeated attribute traversalgetattr()for session_manager access instead of direct attribute access3. Exception Handling Consolidation
isinstance()checks with a single tuple-based type checkisinstance()calls for each exception typetype(response) in (ValidationError, NotImplementedError, TypeError, Exception)reduces this to one type lookup + one tuple membership teststr(response)to avoid duplicate string conversion (profiler shows 528str()calls)Performance Impact
The error handling optimization is the primary driver of the speedup. The profiler reveals that exception handling dominates execution time (54.3% spent on generic Exception JSONResponse creation). By reducing redundant type checks and string conversions, the optimized version processes exceptions much more efficiently.
Throughput improvement of 3.9% (105,073 → 109,182 ops/sec) demonstrates better sustained performance under load, particularly beneficial for error-heavy workloads where the consolidated exception handling path significantly reduces overhead.
The optimizations are especially effective for test cases involving mixed exception types and high-volume concurrent error handling, where the reduced isinstance() calls and cached string conversions compound into substantial performance gains.
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import asyncio # used to run async functions
import msgspec
import pytest # used for our unit tests
from marimo._server.errors import handle_error
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse
Mocks for dependencies
class DummyLogger:
def init(self):
self.logged = []
self.warned = []
def error(self, msg):
self.logged.append(("error", msg))
def warning(self, msg):
self.warned.append(("warning", msg))
class DummyHeaders(dict):
def get(self, key, default=None):
return super().get(key, default)
class DummyRequest:
def init(self, headers=None, app=None):
self.headers = DummyHeaders(headers or {})
self.app = app or DummyApp()
class DummyApp:
def init(self):
self.state = DummyState()
class DummyState:
pass
Minimal MarimoHTTPException mock
class MarimoHTTPException(Exception):
def init(self, status_code, detail, headers=None):
self.status_code = status_code
self.detail = detail
self.headers = headers or {}
Minimal SessionMode mock
class SessionMode:
EDIT = "EDIT"
RUN = "RUN"
Minimal ConsumerId mock
class ConsumerId(str):
pass
Minimal MissingPackageAlert mock
class MissingPackageAlert:
def init(self, packages, isolated):
self.packages = packages
self.isolated = isolated
from marimo._server.errors import handle_error
----------------- UNIT TESTS -----------------
1. Basic Test Cases
@pytest.mark.asyncio
async def test_handle_error_basic_starlette_http_exception():
"""Test StarletteHTTPException returns correct JSONResponse"""
req = DummyRequest()
exc = StarletteHTTPException(status_code=404, detail="Not Found")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_basic_marimo_http_exception():
"""Test MarimoHTTPException returns correct JSONResponse"""
req = DummyRequest()
exc = MarimoHTTPException(status_code=500, detail="Internal Error")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_basic_module_not_found_error():
"""Test ModuleNotFoundError returns correct JSONResponse"""
req = DummyRequest(headers={"Marimo-Session-Id": "abc"})
exc = ModuleNotFoundError("No module named 'foobar'")
exc.name = "foobar"
exc.msg = "No module named 'foobar'"
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_basic_validation_error():
"""Test msgspec.ValidationError returns correct JSONResponse"""
req = DummyRequest()
exc = msgspec.ValidationError("invalid input")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_basic_not_implemented_error():
"""Test NotImplementedError returns correct JSONResponse"""
req = DummyRequest()
exc = NotImplementedError("not supported yet")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_basic_type_error():
"""Test TypeError returns correct JSONResponse"""
req = DummyRequest()
exc = TypeError("bad type")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_basic_generic_exception():
"""Test generic Exception returns correct JSONResponse"""
req = DummyRequest()
exc = Exception("generic error")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_basic_passthrough():
"""Test non-exception response is returned as-is"""
req = DummyRequest()
resp = await handle_error(req, "hello world")
2. Edge Test Cases
@pytest.mark.asyncio
async def test_handle_error_starlette_http_exception_403():
"""Test StarletteHTTPException 403 is mapped to 401 with WWW-Authenticate header"""
req = DummyRequest()
exc = StarletteHTTPException(status_code=403, detail="Forbidden", headers={"X-Test": "foo"})
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_module_not_found_error_no_name():
"""Test ModuleNotFoundError with no name attribute returns generic 500"""
req = DummyRequest()
exc = ModuleNotFoundError("No module named 'foobar'")
# no exc.name attribute
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_module_not_found_error_session_none():
"""Test ModuleNotFoundError when session_id is None (no session header)"""
req = DummyRequest(headers={})
exc = ModuleNotFoundError("No module named 'baz'")
exc.name = "baz"
exc.msg = "No module named 'baz'"
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_module_not_found_error_session_object_none():
"""Test ModuleNotFoundError when session object is None"""
req = DummyRequest(headers={"Marimo-Session-Id": "notfound"})
exc = ModuleNotFoundError("No module named 'qux'")
exc.name = "qux"
exc.msg = "No module named 'qux'"
# SessionManager returns None for unknown session
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_concurrent_exceptions():
"""Test concurrent execution of multiple error types"""
req = DummyRequest()
excs = [
StarletteHTTPException(status_code=404, detail="Not Found"),
MarimoHTTPException(status_code=500, detail="Internal Error"),
NotImplementedError("not supported yet"),
TypeError("bad type"),
Exception("generic error"),
"hello world"
]
results = await asyncio.gather(*(handle_error(req, exc) for exc in excs))
@pytest.mark.asyncio
async def test_handle_error_edge_msgspec_validation_error_custom_message():
"""Test msgspec.ValidationError with custom message"""
req = DummyRequest()
exc = msgspec.ValidationError("custom validation error")
resp = await handle_error(req, exc)
3. Large Scale Test Cases
@pytest.mark.asyncio
async def test_handle_error_large_scale_concurrent_starlette_http_exceptions():
"""Test many concurrent StarletteHTTPExceptions"""
req = DummyRequest()
excs = [StarletteHTTPException(status_code=400 + i, detail=f"Error {i}") for i in range(10)]
results = await asyncio.gather(*(handle_error(req, exc) for exc in excs))
for i, resp in enumerate(results):
pass
@pytest.mark.asyncio
async def test_handle_error_large_scale_mixed_exceptions():
"""Test many concurrent mixed exceptions"""
req = DummyRequest()
excs = []
for i in range(10):
if i % 2 == 0:
excs.append(StarletteHTTPException(status_code=404, detail=f"Not Found {i}"))
else:
excs.append(MarimoHTTPException(status_code=500, detail=f"Internal Error {i}"))
results = await asyncio.gather(*(handle_error(req, exc) for exc in excs))
for i, resp in enumerate(results):
if i % 2 == 0:
pass
else:
pass
4. Throughput Test Cases
@pytest.mark.asyncio
async def test_handle_error_throughput_small_load():
"""Throughput test: small load of concurrent error handling"""
req = DummyRequest()
excs = [StarletteHTTPException(status_code=400, detail="Small Load")] * 5
results = await asyncio.gather(*(handle_error(req, exc) for exc in excs))
for resp in results:
pass
@pytest.mark.asyncio
async def test_handle_error_throughput_medium_load():
"""Throughput test: medium load of concurrent error handling"""
req = DummyRequest()
excs = [MarimoHTTPException(status_code=500, detail="Medium Load")] * 50
results = await asyncio.gather(*(handle_error(req, exc) for exc in excs))
for resp in results:
pass
@pytest.mark.asyncio
async def test_handle_error_throughput_large_load():
"""Throughput test: large load of concurrent error handling"""
req = DummyRequest()
excs = [TypeError("Large Load")] * 100
results = await asyncio.gather(*(handle_error(req, exc) for exc in excs))
for resp in results:
pass
@pytest.mark.asyncio
async def test_handle_error_throughput_mixed_load():
"""Throughput test: mixed error types under concurrent load"""
req = DummyRequest()
excs = []
for i in range(25):
excs.append(StarletteHTTPException(status_code=404, detail=f"Not Found {i}"))
excs.append(MarimoHTTPException(status_code=500, detail=f"Internal Error {i}"))
excs.append(TypeError(f"Type Error {i}"))
excs.append(Exception(f"Generic Error {i}"))
results = await asyncio.gather(*(handle_error(req, exc) for exc in excs))
for i in range(0, len(results), 4):
pass
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import asyncio # used to run async functions
import pytest # used for our unit tests
from marimo._server.errors import handle_error
Mocks and helpers for dependencies
class DummyLogger:
def init(self):
self.logged_errors = []
self.logged_warnings = []
def error(self, msg):
self.logged_errors.append(str(msg))
def warning(self, msg):
self.logged_warnings.append(str(msg))
class DummyJSONResponse:
def init(self, content, status_code=200, headers=None):
self.content = content
self.status_code = status_code
self.headers = headers or {}
Dummy HTTPException and MarimoHTTPException
class DummyHTTPException(Exception):
def init(self, status_code, detail, headers=None):
self.status_code = status_code
self.detail = detail
self.headers = headers or {}
class DummyMarimoHTTPException(Exception):
def init(self, status_code, detail):
self.status_code = status_code
self.detail = detail
Dummy msgspec.ValidationError
class DummyValidationError(Exception):
pass
Dummy AppState, Session, SessionMode, ConsumerId, MissingPackageAlert
class DummySessionMode:
EDIT = "EDIT"
RUN = "RUN"
class DummyConsumerId(str):
pass
class DummySession:
pass
class DummyAppState:
def init(self, request):
self.request = request
self.mode = DummySessionMode.EDIT
self.session_manager = self
def get_current_session_id(self):
return "dummy-session-id"
def get_current_session(self):
return DummySession()
class DummyMissingPackageAlert:
def init(self, packages, isolated):
self.packages = packages
self.isolated = isolated
Patchable references
dummy_logger = DummyLogger()
def dummy_is_client_error(status_code):
return 400 <= status_code < 500
def dummy_send_message_to_consumer(session, operation, consumer_id):
# Just record that this was called
dummy_send_message_to_consumer.called.append((session, operation, consumer_id))
dummy_send_message_to_consumer.called = []
def dummy_is_python_isolated():
return True
from marimo._server.errors import handle_error
------------------ UNIT TESTS ------------------
1. Basic Test Cases
@pytest.mark.asyncio
async def test_handle_error_http_exception_basic():
"""Test basic HTTPException handling."""
req = object()
exc = DummyHTTPException(404, "Not Found")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_http_exception_403_to_401():
"""Test HTTPException 403 is mapped to 401 with WWW-Authenticate header."""
req = object()
exc = DummyHTTPException(403, "Forbidden")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_marimo_http_exception_client_error():
"""Test MarimoHTTPException with client error status code."""
req = object()
exc = DummyMarimoHTTPException(400, "Bad Request")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_marimo_http_exception_server_error():
"""Test MarimoHTTPException with server error status code logs error."""
req = object()
exc = DummyMarimoHTTPException(500, "Internal Server Error")
# Clear logger
dummy_logger.logged_errors.clear()
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_validation_error():
"""Test ValidationError returns 400."""
req = object()
exc = DummyValidationError("invalid")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_not_implemented_error():
"""Test NotImplementedError returns 501."""
req = object()
exc = NotImplementedError("not implemented")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_type_error():
"""Test TypeError returns 500."""
req = object()
exc = TypeError("type error")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_generic_exception():
"""Test generic Exception returns 500."""
req = object()
exc = Exception("generic error")
resp = await handle_error(req, exc)
@pytest.mark.asyncio
async def test_handle_error_non_exception_response():
"""Test non-exception response is returned as-is."""
req = object()
resp_obj = {"foo": "bar"}
resp = await handle_error(req, resp_obj)
2. Edge Test Cases
@pytest.mark.asyncio
async def test_handle_error_concurrent_execution():
"""Test concurrent execution of handle_error with different exceptions."""
req = object()
excs = [
DummyHTTPException(404, "Not Found"),
DummyMarimoHTTPException(500, "Internal Error"),
DummyValidationError("bad"),
Exception("err"),
{"foo": "bar"},
]
results = await asyncio.gather(*(handle_error(req, e) for e in excs))
3. Large Scale Test Cases
@pytest.mark.asyncio
async def test_handle_error_large_scale_concurrent():
"""Test large scale concurrent execution."""
req = object()
exc_list = [DummyHTTPException(404, f"Not Found {i}") for i in range(50)]
results = await asyncio.gather(*(handle_error(req, e) for e in exc_list))
for i, resp in enumerate(results):
pass
@pytest.mark.asyncio
async def test_handle_error_large_scale_mixed_exceptions():
"""Test large scale mixed exceptions concurrently."""
req = object()
excs = []
for i in range(25):
excs.append(DummyHTTPException(400+i%5, f"HTTP {i}"))
excs.append(DummyMarimoHTTPException(500, f"Marimo {i}"))
excs.append(DummyValidationError(f"Val {i}"))
results = await asyncio.gather(*(handle_error(req, e) for e in excs))
for idx, resp in enumerate(results):
# Check status codes are correct
if idx % 3 == 0:
pass
elif idx % 3 == 1:
pass
else:
pass
4. Throughput Test Cases
@pytest.mark.asyncio
async def test_handle_error_throughput_small_load():
"""Test throughput under small load."""
req = object()
excs = [DummyHTTPException(404, "Not Found"), DummyMarimoHTTPException(500, "Internal Error")]
results = await asyncio.gather(*(handle_error(req, e) for e in excs))
@pytest.mark.asyncio
async def test_handle_error_throughput_medium_load():
"""Test throughput under medium load (50 mixed)."""
req = object()
excs = []
for i in range(25):
excs.append(DummyHTTPException(400+i%5, f"HTTP {i}"))
excs.append(DummyMarimoHTTPException(500, f"Marimo {i}"))
results = await asyncio.gather(*(handle_error(req, e) for e in excs))
@pytest.mark.asyncio
async def test_handle_error_throughput_high_volume():
"""Test throughput under high volume (100 mixed)."""
req = object()
excs = []
for i in range(50):
excs.append(DummyHTTPException(400+i%5, f"HTTP {i}"))
excs.append(DummyMarimoHTTPException(500, f"Marimo {i}"))
results = await asyncio.gather(*(handle_error(req, e) for e in excs))
# Check distribution
http_count = sum(r.status_code < 500 for r in results)
marimo_count = sum(r.status_code == 500 for r in results)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes
git checkout codeflash/optimize-handle_error-mhvlp40wand push.