Skip to content

Commit

Permalink
refactor: Simplify SchemaBase.copy (#3543)
Browse files Browse the repository at this point in the history
* refactor: Fix `C901` complexity in `SchemaBase.copy`

The nested functions do not reference `self` and can be split out, much like `.to_dict` -> `_todict`.
Saw this as a chance to add annotations as well. I think this helps illustrate that `.copy` is a noop for anything other than `SchemaBase | dict | list` - which I think might need to be addressed in the future.

* refactor: Further simplify `SchemaBase.copy`

10 lines shorter and is no longer constrained to `list` on the assert.
May be slightly faster on `python<3.11`, which does not have zero-cost exceptions: https://docs.python.org/3.11/whatsnew/3.11.html#misc

* refactor(perf): Merge identical `_shallow_copy` branches

Remembered a related `ruff` rule [FURB145](https://docs.astral.sh/ruff/rules/slice-copy/).
The rule wouldn't have made this fix itself, but it likely would have led someone there

* refactor(perf): Reduce `_deep_copy`

- Initialize an empty set **once** at origin, rather than creating a new list per iteration
  - Renamed to `by_ref` in the new private function, to better describe the operation.
  - No change to public API.
- Define a partial `copy` to reduce repetition
- Use a genexpr for `args`, to avoid unpacking twice
  • Loading branch information
dangotbanned authored Aug 18, 2024
1 parent 635a2ec commit c9f9438
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 92 deletions.
90 changes: 44 additions & 46 deletions altair/utils/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Sequence,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias
Expand Down Expand Up @@ -833,6 +834,38 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]:
return obj is Undefined


@overload
def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ...
@overload
def _shallow_copy(obj: Any) -> Any: ...
def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any:
    """
    Return a one-level copy of ``obj``.

    Parameters
    ----------
    obj
        The value to copy.

    Returns
    -------
    A new object when ``obj`` is a ``SchemaBase`` (via ``obj.copy(deep=False)``)
    or a ``list``/``dict`` (via ``obj.copy()``); otherwise ``obj`` itself,
    i.e. every other type is handed back by reference.
    """
    if isinstance(obj, SchemaBase):
        return obj.copy(deep=False)
    if isinstance(obj, (list, dict)):
        # Both builtins expose the same shallow ``.copy()`` method.
        return obj.copy()
    return obj


@overload
def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ...
@overload
def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ...
def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any:
    """
    Recursively copy ``obj``, descending into schema objects, lists and dicts.

    Parameters
    ----------
    obj
        The value to copy.
    by_ref
        Keys whose values are **not** copied but stored by reference.
        The same set is forwarded to every recursive call, so it applies to
        ``SchemaBase`` keyword arguments and ``dict`` items at any depth.

    Returns
    -------
    A rebuilt ``SchemaBase`` (same class), ``list`` or ``dict``; any other
    type is returned unchanged, by reference.
    """
    recurse = partial(_deep_copy, by_ref=by_ref)

    if isinstance(obj, SchemaBase):
        # Kept lazy on purpose: the generator is only consumed while the
        # constructor call below unpacks it.
        positional = (recurse(arg) for arg in obj._args)
        keyword = {
            name: (value if name in by_ref else recurse(value))
            for name, value in obj._kwds.items()
        }
        # NOTE(review): presumably debug mode is disabled so the rebuilt
        # instance is not re-validated on instantiation - confirm.
        with debug_mode(False):
            return obj.__class__(*positional, **keyword)

    if isinstance(obj, list):
        return [recurse(item) for item in obj]

    if isinstance(obj, dict):
        return {
            key: (value if key in by_ref else recurse(value))
            for key, value in obj.items()
        }

    return obj


class SchemaBase:
"""
Base class for schema wrappers.
Expand Down Expand Up @@ -870,7 +903,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None:
if DEBUG_MODE and self._class_is_valid_at_instantiation:
self.to_dict(validate=True)

def copy( # noqa: C901
def copy(
self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None
) -> Self:
"""
Expand All @@ -887,53 +920,11 @@ def copy( # noqa: C901
A list of keys for which the contents should not be copied, but
only stored by reference.
"""

def _shallow_copy(obj):
if isinstance(obj, SchemaBase):
return obj.copy(deep=False)
elif isinstance(obj, list):
return obj[:]
elif isinstance(obj, dict):
return obj.copy()
else:
return obj

def _deep_copy(obj, ignore: list[str] | None = None):
if ignore is None:
ignore = []
if isinstance(obj, SchemaBase):
args = tuple(_deep_copy(arg) for arg in obj._args)
kwds = {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj._kwds.items()
}
with debug_mode(False):
return obj.__class__(*args, **kwds)
elif isinstance(obj, list):
return [_deep_copy(v, ignore=ignore) for v in obj]
elif isinstance(obj, dict):
return {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj.items()
}
else:
return obj

try:
deep = list(deep) # type: ignore[arg-type]
except TypeError:
deep_is_list = False
else:
deep_is_list = True

if deep and not deep_is_list:
return _deep_copy(self, ignore=ignore)

if deep is True:
return cast("Self", _deep_copy(self, set(ignore) if ignore else set()))
with debug_mode(False):
copy = self.__class__(*self._args, **self._kwds)
if deep_is_list:
# Assert statement is for the benefit of Mypy
assert isinstance(deep, list)
if _is_iterable(deep):
for attr in deep:
copy[attr] = _shallow_copy(copy._get(attr))
return copy
Expand Down Expand Up @@ -1240,6 +1231,13 @@ def __dir__(self) -> list[str]:

# Type variable restricted (via ``bound``) to ``SchemaBase`` and its subclasses.
TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase)

# Constrained type variable: one of ``SchemaBase``, ``Dict[Any, Any]`` or
# ``List[Any]``. The ``_shallow_copy`` / ``_deep_copy`` overloads use it so
# that the annotated return type matches the type of the argument.
_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any])
"""
Types which have an implementation in ``SchemaBase.copy()``.
All other types are returned **by reference**.
"""


def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]:
return isinstance(obj, dict)
Expand Down
90 changes: 44 additions & 46 deletions tools/schemapi/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Sequence,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias
Expand Down Expand Up @@ -831,6 +832,38 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]:
return obj is Undefined


@overload
def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ...
@overload
def _shallow_copy(obj: Any) -> Any: ...
def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any:
    """
    Return a one-level copy of ``obj``.

    Parameters
    ----------
    obj
        The value to copy.

    Returns
    -------
    A new object when ``obj`` is a ``SchemaBase`` (via ``obj.copy(deep=False)``)
    or a ``list``/``dict`` (via ``obj.copy()``); otherwise ``obj`` itself,
    i.e. every other type is handed back by reference.
    """
    if isinstance(obj, SchemaBase):
        return obj.copy(deep=False)
    if isinstance(obj, (list, dict)):
        # Both builtins expose the same shallow ``.copy()`` method.
        return obj.copy()
    return obj


@overload
def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ...
@overload
def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ...
def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any:
    """
    Recursively copy ``obj``, descending into schema objects, lists and dicts.

    Parameters
    ----------
    obj
        The value to copy.
    by_ref
        Keys whose values are **not** copied but stored by reference.
        The same set is forwarded to every recursive call, so it applies to
        ``SchemaBase`` keyword arguments and ``dict`` items at any depth.

    Returns
    -------
    A rebuilt ``SchemaBase`` (same class), ``list`` or ``dict``; any other
    type is returned unchanged, by reference.
    """
    recurse = partial(_deep_copy, by_ref=by_ref)

    if isinstance(obj, SchemaBase):
        # Kept lazy on purpose: the generator is only consumed while the
        # constructor call below unpacks it.
        positional = (recurse(arg) for arg in obj._args)
        keyword = {
            name: (value if name in by_ref else recurse(value))
            for name, value in obj._kwds.items()
        }
        # NOTE(review): presumably debug mode is disabled so the rebuilt
        # instance is not re-validated on instantiation - confirm.
        with debug_mode(False):
            return obj.__class__(*positional, **keyword)

    if isinstance(obj, list):
        return [recurse(item) for item in obj]

    if isinstance(obj, dict):
        return {
            key: (value if key in by_ref else recurse(value))
            for key, value in obj.items()
        }

    return obj


class SchemaBase:
"""
Base class for schema wrappers.
Expand Down Expand Up @@ -868,7 +901,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None:
if DEBUG_MODE and self._class_is_valid_at_instantiation:
self.to_dict(validate=True)

def copy( # noqa: C901
def copy(
self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None
) -> Self:
"""
Expand All @@ -885,53 +918,11 @@ def copy( # noqa: C901
A list of keys for which the contents should not be copied, but
only stored by reference.
"""

def _shallow_copy(obj):
if isinstance(obj, SchemaBase):
return obj.copy(deep=False)
elif isinstance(obj, list):
return obj[:]
elif isinstance(obj, dict):
return obj.copy()
else:
return obj

def _deep_copy(obj, ignore: list[str] | None = None):
if ignore is None:
ignore = []
if isinstance(obj, SchemaBase):
args = tuple(_deep_copy(arg) for arg in obj._args)
kwds = {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj._kwds.items()
}
with debug_mode(False):
return obj.__class__(*args, **kwds)
elif isinstance(obj, list):
return [_deep_copy(v, ignore=ignore) for v in obj]
elif isinstance(obj, dict):
return {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj.items()
}
else:
return obj

try:
deep = list(deep) # type: ignore[arg-type]
except TypeError:
deep_is_list = False
else:
deep_is_list = True

if deep and not deep_is_list:
return _deep_copy(self, ignore=ignore)

if deep is True:
return cast("Self", _deep_copy(self, set(ignore) if ignore else set()))
with debug_mode(False):
copy = self.__class__(*self._args, **self._kwds)
if deep_is_list:
# Assert statement is for the benefit of Mypy
assert isinstance(deep, list)
if _is_iterable(deep):
for attr in deep:
copy[attr] = _shallow_copy(copy._get(attr))
return copy
Expand Down Expand Up @@ -1238,6 +1229,13 @@ def __dir__(self) -> list[str]:

# Type variable restricted (via ``bound``) to ``SchemaBase`` and its subclasses.
TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase)

# Constrained type variable: one of ``SchemaBase``, ``Dict[Any, Any]`` or
# ``List[Any]``. The ``_shallow_copy`` / ``_deep_copy`` overloads use it so
# that the annotated return type matches the type of the argument.
_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any])
"""
Types which have an implementation in ``SchemaBase.copy()``.
All other types are returned **by reference**.
"""


def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]:
return isinstance(obj, dict)
Expand Down

0 comments on commit c9f9438

Please sign in to comment.