diff --git a/marimo/_ast/parse.py b/marimo/_ast/parse.py index e23622e86b7..d61a4137adf 100644 --- a/marimo/_ast/parse.py +++ b/marimo/_ast/parse.py @@ -8,28 +8,13 @@ from pathlib import Path from textwrap import dedent from tokenize import TokenInfo, tokenize -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Optional, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from marimo._ast.names import DEFAULT_CELL_NAME, SETUP_CELL_NAME -from marimo._schemas.serialization import ( - AppInstantiation, - CellDef, - ClassCell, - FunctionCell, - Header, - NotebookSerialization, - SetupCell, - UnparsableCell, - Violation, -) +from marimo._schemas.serialization import (AppInstantiation, CellDef, + ClassCell, FunctionCell, Header, + NotebookSerialization, SetupCell, + UnparsableCell, Violation) if TYPE_CHECKING: from collections.abc import Iterator @@ -886,15 +871,20 @@ def is_body_cell(node: Node) -> bool: def _is_setup_call(node: Node) -> bool: - if isinstance(node, ast.Attribute): - return ( - isinstance(node.value, ast.Name) - and node.value.id == "app" - and node.attr == "setup" - ) - elif isinstance(node, ast.Call): - return _is_setup_call(node.func) - return False + while True: + if isinstance(node, ast.Attribute): + value = node.value + if ( + type(value) is ast.Name + and value.id == "app" + and node.attr == "setup" + ): + return True + return False + elif isinstance(node, ast.Call): + node = node.func + continue + return False def is_setup_cell(node: Node) -> bool: diff --git a/marimo/_ast/pytest.py b/marimo/_ast/pytest.py index ee2515f533e..30c3dc570b5 100644 --- a/marimo/_ast/pytest.py +++ b/marimo/_ast/pytest.py @@ -44,7 +44,7 @@ def build_stub_fn( ) -> Callable[..., Any]: # Avoid declaring the function in the global scope, since it may cause # issues with meta-analysis tools like cxfreeze (see #3828). - PYTEST_BASE = ast_parse(inspect.getsource(_pytest_scaffold)) + PYTEST_BASE = _cached_pytest_base() # We modify the signature of the cell function such that pytest # does not attempt to use the arguments as fixtures. @@ -304,3 +304,8 @@ def process_for_pytest(func: Fn, cell: Cell) -> None: # Insert the class into the frame. frame.frame.f_locals[cls.__name__] = cls break + + +@functools.lru_cache(maxsize=1) +def _cached_pytest_base() -> ast.Module: + return ast_parse(inspect.getsource(_pytest_scaffold))