Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 19 additions & 29 deletions marimo/_ast/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,13 @@
from pathlib import Path
from textwrap import dedent
from tokenize import TokenInfo, tokenize
from typing import (
TYPE_CHECKING,
Any,
Generic,
Optional,
TypeVar,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

from marimo._ast.names import DEFAULT_CELL_NAME, SETUP_CELL_NAME
from marimo._schemas.serialization import (
AppInstantiation,
CellDef,
ClassCell,
FunctionCell,
Header,
NotebookSerialization,
SetupCell,
UnparsableCell,
Violation,
)
from marimo._schemas.serialization import (AppInstantiation, CellDef,
ClassCell, FunctionCell, Header,
NotebookSerialization, SetupCell,
UnparsableCell, Violation)

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down Expand Up @@ -886,15 +871,20 @@ def is_body_cell(node: Node) -> bool:


def _is_setup_call(node: Node) -> bool:
if isinstance(node, ast.Attribute):
return (
isinstance(node.value, ast.Name)
and node.value.id == "app"
and node.attr == "setup"
)
elif isinstance(node, ast.Call):
return _is_setup_call(node.func)
return False
while True:
if isinstance(node, ast.Attribute):
value = node.value
if (
type(value) is ast.Name
and value.id == "app"
and node.attr == "setup"
):
return True
return False
elif isinstance(node, ast.Call):
node = node.func
continue
return False


def is_setup_cell(node: Node) -> bool:
Expand Down
7 changes: 6 additions & 1 deletion marimo/_ast/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def build_stub_fn(
) -> Callable[..., Any]:
# Avoid declaring the function in the global scope, since it may cause
# issues with meta-analysis tools like cxfreeze (see #3828).
PYTEST_BASE = ast_parse(inspect.getsource(_pytest_scaffold))
PYTEST_BASE = _cached_pytest_base()

# We modify the signature of the cell function such that pytest
# does not attempt to use the arguments as fixtures.
Expand Down Expand Up @@ -304,3 +304,8 @@ def process_for_pytest(func: Fn, cell: Cell) -> None:
# Insert the class into the frame.
frame.frame.f_locals[cls.__name__] = cls
break


@functools.lru_cache(maxsize=1)
def _cached_pytest_base() -> ast.Module:
return ast_parse(inspect.getsource(_pytest_scaffold))