Skip to content

Commit

Permalink
Add an optional filter to snapshot for frames to include and fix un…
Browse files Browse the repository at this point in the history
…serializable values (#1085)
  • Loading branch information
leowrites authored Sep 22, 2024
1 parent 1078ec7 commit bbc67ba
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

### ✨ Enhancements

- Added `include` filter to `snapshot`

### 💫 New checkers

### 🐛 Bug fixes

- Fixed issue where `snapshot` errors on unserializable values
- Fixed issue within `Snapshot.py` where the `memory_viz_version` parameter was not respected

### 🔧 Internal changes
Expand Down
25 changes: 18 additions & 7 deletions python_ta/debug/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import inspect
import json
import logging
import re
import shutil
import subprocess
import sys
from types import FrameType
from typing import Any, Optional
from typing import Any, Iterable, Optional

from packaging.version import Version, parse

Expand Down Expand Up @@ -40,6 +41,7 @@ def snapshot(
save: bool = False,
memory_viz_args: Optional[list[str]] = None,
memory_viz_version: str = "latest",
include: Optional[Iterable[str | re.Pattern]] = None,
):
"""Capture a snapshot of local variables from the current and outer stack frames
where the 'snapshot' function is called. Returns a list of dictionaries,
Expand All @@ -51,16 +53,20 @@ def snapshot(
For details on the MemoryViz CLI, see https://www.cs.toronto.edu/~david/memory-viz/docs/cli.
memory_viz_version can be used to dictate version, with a default of the latest version.
Note that this function is compatible only with MemoryViz version 0.3.1 and above.
include can be used to specify a collection of function names, either as strings or regular expressions,
whose variables will be captured. By default, all variables in all functions will be captured if no `include`
argument is provided.
"""
variables = []
frame = inspect.currentframe().f_back

while frame:
if frame.f_code.co_name != "<module>":
variables.append({frame.f_code.co_name: frame.f_locals})
else:
global_vars = get_filtered_global_variables(frame)
variables.append(global_vars)
if include is None or any(re.search(regex, frame.f_code.co_name) for regex in include):
if frame.f_code.co_name != "<module>":
variables.append({frame.f_code.co_name: frame.f_locals})
else:
global_vars = get_filtered_global_variables(frame)
variables.append(global_vars)

frame = frame.f_back

Expand Down Expand Up @@ -154,10 +160,15 @@ def process_value(val: Any) -> int:
"value": attr_ids,
}
else: # Handle primitives and other types
try:
json.dumps(val)
jsonable_val = val
except TypeError:
jsonable_val = repr(val)
value_entry = {
"type": type(val).__name__,
"id": value_id_diagram,
"value": val,
"value": jsonable_val,
}

value_entries.append(value_entry)
Expand Down
95 changes: 95 additions & 0 deletions tests/test_debug/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
Test suite for snapshot functions
"""

from __future__ import annotations

import json
import os
import pathlib
import re
import subprocess
import sys
from typing import Iterable, Optional

from python_ta.debug.snapshot import snapshot, snapshot_to_json

Expand Down Expand Up @@ -642,3 +647,93 @@ def test_snapshot_save_stdout():
expected_svg = expected_svg_file.read()

assert result.stdout == expected_svg


def func_with_include(include: Optional[Iterable[str | re.Pattern]] = None) -> list[dict]:
test_var1a = "David is cool!"
test_var2a = "Students Developing Software"
return snapshot(include=include)


def func_with_include_nested(include: Optional[Iterable[str | re.Pattern]] = None) -> list[dict]:
test_var1b = {"SDS_coolest_project": "PyTA"}
test_var2b = ("Leo", "tester")
return func_with_include(include=include)


def func_with_unserializable_objects() -> list[dict]:
path = pathlib.PosixPath("some path")
vars_in_curr_func = [snapshot()[0]]
processed_result = snapshot_to_json(vars_in_curr_func)
json.dumps(processed_result)
return processed_result


def test_snapshot_only_includes_function_self():
result = func_with_include(include=("func_with_include",))
assert result == [
{
"func_with_include": {
"include": ("func_with_include",),
"test_var1a": "David is cool!",
"test_var2a": "Students Developing Software",
}
}
]


def test_snapshot_includes_multiple_functions():
result = func_with_include_nested(
include=(
"func_with_include",
"func_with_include_nested",
)
)
assert result == [
{
"func_with_include": {
"include": (
"func_with_include",
"func_with_include_nested",
),
"test_var1a": "David is cool!",
"test_var2a": "Students Developing Software",
},
},
{
"func_with_include_nested": {
"include": (
"func_with_include",
"func_with_include_nested",
),
"test_var1b": {"SDS_coolest_project": "PyTA"},
"test_var2b": ("Leo", "tester"),
},
},
]


def test_snapshot_only_includes_specified_function():
result = func_with_include_nested(include=("func_with_include_nested",))
assert result == [
{
"func_with_include_nested": {
"include": ("func_with_include_nested",),
"test_var1b": {"SDS_coolest_project": "PyTA"},
"test_var2b": ("Leo", "tester"),
},
}
]


def test_snapshot_serializes_unserializable_value():
result = func_with_unserializable_objects()
assert result == [
{
"id": None,
"name": "func_with_unserializable_objects",
"type": ".frame",
"value": {"path": 1},
},
{"id": 1, "type": "PosixPath", "value": repr(pathlib.PosixPath("some path"))},
]

0 comments on commit bbc67ba

Please sign in to comment.