Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: CoW - correctly track references for chained operations #48996

Merged
12 changes: 9 additions & 3 deletions pandas/_libs/internals.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,9 @@ cdef class BlockManager:
public bint _known_consolidated, _is_consolidated
public ndarray _blknos, _blklocs
public list refs
public object parent

def __cinit__(self, blocks=None, axes=None, refs=None, verify_integrity=True):
def __cinit__(self, blocks=None, axes=None, refs=None, parent=None, verify_integrity=True):
# None as defaults for unpickling GH#42345
if blocks is None:
# This adds 1-2 microseconds to DataFrame(np.array([]))
Expand All @@ -690,6 +691,7 @@ cdef class BlockManager:
self.blocks = blocks
self.axes = axes.copy() # copy to make sure we are not remotely-mutable
self.refs = refs
self.parent = parent

# Populate known_consolidated, blknos, and blklocs lazily
self._known_consolidated = False
Expand Down Expand Up @@ -805,7 +807,9 @@ cdef class BlockManager:
nrefs.append(weakref.ref(blk))

new_axes = [self.axes[0], self.axes[1]._getitem_slice(slobj)]
mgr = type(self)(tuple(nbs), new_axes, nrefs, verify_integrity=False)
mgr = type(self)(
tuple(nbs), new_axes, nrefs, parent=self, verify_integrity=False
)

# We can avoid having to rebuild blklocs/blknos
blklocs = self._blklocs
Expand All @@ -827,4 +831,6 @@ cdef class BlockManager:
new_axes = list(self.axes)
new_axes[axis] = new_axes[axis]._getitem_slice(slobj)

return type(self)(tuple(new_blocks), new_axes, new_refs, verify_integrity=False)
return type(self)(
tuple(new_blocks), new_axes, new_refs, parent=self, verify_integrity=False
)
54 changes: 42 additions & 12 deletions pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import pandas.core.algorithms as algos
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.arrays.sparse import SparseDtype
import pandas.core.common as com
from pandas.core.construction import (
ensure_wrapped_if_datetimelike,
extract_array,
Expand Down Expand Up @@ -148,6 +149,7 @@ class BaseBlockManager(DataManager):
blocks: tuple[Block, ...]
axes: list[Index]
refs: list[weakref.ref | None] | None
parent: object

@property
def ndim(self) -> int:
Expand All @@ -165,6 +167,7 @@ def from_blocks(
blocks: list[Block],
axes: list[Index],
refs: list[weakref.ref | None] | None = None,
parent: object = None,
) -> T:
raise NotImplementedError

Expand Down Expand Up @@ -605,7 +608,10 @@ def _combine(
axes[-1] = index
axes[0] = self.items.take(indexer)

return type(self).from_blocks(new_blocks, axes, new_refs)
# TODO cover this one
return type(self).from_blocks(
new_blocks, axes, new_refs, parent=None if copy else self
)

@property
def nblocks(self) -> int:
Expand Down Expand Up @@ -648,11 +654,14 @@ def copy_func(ax):
new_refs: list[weakref.ref | None] | None
if deep:
new_refs = None
parent = None
else:
new_refs = [weakref.ref(blk) for blk in self.blocks]
parent = self

res.axes = new_axes
res.refs = new_refs
res.parent = parent

if self.ndim > 1:
# Avoid needing to re-compute these
Expand Down Expand Up @@ -744,6 +753,7 @@ def reindex_indexer(
only_slice=only_slice,
use_na_proxy=use_na_proxy,
)
parent = None if com.all_none(*new_refs) else self
else:
new_blocks = [
blk.take_nd(
Expand All @@ -756,11 +766,12 @@ def reindex_indexer(
for blk in self.blocks
]
new_refs = None
parent = None

new_axes = list(self.axes)
new_axes[axis] = new_axis

new_mgr = type(self).from_blocks(new_blocks, new_axes, new_refs)
new_mgr = type(self).from_blocks(new_blocks, new_axes, new_refs, parent=parent)
if axis == 1:
# We can avoid the need to rebuild these
new_mgr._blknos = self.blknos.copy()
Expand Down Expand Up @@ -995,6 +1006,7 @@ def __init__(
blocks: Sequence[Block],
axes: Sequence[Index],
refs: list[weakref.ref | None] | None = None,
parent: object = None,
verify_integrity: bool = True,
) -> None:

Expand Down Expand Up @@ -1059,11 +1071,13 @@ def from_blocks(
blocks: list[Block],
axes: list[Index],
refs: list[weakref.ref | None] | None = None,
parent: object = None,
) -> BlockManager:
"""
Constructor for BlockManager and SingleBlockManager with same signature.
"""
return cls(blocks, axes, refs, verify_integrity=False)
parent = parent if _using_copy_on_write() else None
return cls(blocks, axes, refs, parent, verify_integrity=False)

# ----------------------------------------------------------------
# Indexing
Expand All @@ -1085,7 +1099,7 @@ def fast_xs(self, loc: int) -> SingleBlockManager:
block = new_block(result, placement=slice(0, len(result)), ndim=1)
# in the case of a single block, the new block is a view
ref = weakref.ref(self.blocks[0])
return SingleBlockManager(block, self.axes[0], [ref])
return SingleBlockManager(block, self.axes[0], [ref], parent=self)

dtype = interleaved_dtype([blk.dtype for blk in self.blocks])

Expand Down Expand Up @@ -1119,7 +1133,7 @@ def fast_xs(self, loc: int) -> SingleBlockManager:
block = new_block(result, placement=slice(0, len(result)), ndim=1)
return SingleBlockManager(block, self.axes[0])

def iget(self, i: int) -> SingleBlockManager:
def iget(self, i: int, track_ref: bool = True) -> SingleBlockManager:
"""
Return the data as a SingleBlockManager.
"""
Expand All @@ -1129,7 +1143,9 @@ def iget(self, i: int) -> SingleBlockManager:
# shortcut for selecting a single-dim from a 2-dim BM
bp = BlockPlacement(slice(0, len(values)))
nb = type(block)(values, placement=bp, ndim=1)
return SingleBlockManager(nb, self.axes[1], [weakref.ref(block)])
ref = weakref.ref(block) if track_ref else None
parent = self if track_ref else None
return SingleBlockManager(nb, self.axes[1], [ref], parent)

def iget_values(self, i: int) -> ArrayLike:
"""
Expand Down Expand Up @@ -1371,7 +1387,9 @@ def column_setitem(self, loc: int, idx: int | slice | np.ndarray, value) -> None
self.blocks = tuple(blocks)
self._clear_reference_block(blkno)

col_mgr = self.iget(loc)
# this manager is only created temporarily to mutate the values in place
# so don't track references, otherwise the `setitem` would perform CoW again
col_mgr = self.iget(loc, track_ref=False)
new_mgr = col_mgr.setitem((idx,), value)
self.iset(loc, new_mgr._block.values, inplace=True)

Expand Down Expand Up @@ -1469,7 +1487,9 @@ def idelete(self, indexer) -> BlockManager:
nbs, new_refs = self._slice_take_blocks_ax0(taker, only_slice=True)
new_columns = self.items[~is_deleted]
axes = [new_columns, self.axes[1]]
return type(self)(tuple(nbs), axes, new_refs, verify_integrity=False)
# TODO this might not be needed (can a delete ever be done in a chained manner?)
parent = None if com.all_none(*new_refs) else self
return type(self)(tuple(nbs), axes, new_refs, parent, verify_integrity=False)

# ----------------------------------------------------------------
# Block-wise Operation
Expand Down Expand Up @@ -1875,6 +1895,7 @@ def __init__(
block: Block,
axis: Index,
refs: list[weakref.ref | None] | None = None,
parent: object = None,
verify_integrity: bool = False,
fastpath=lib.no_default,
) -> None:
Expand All @@ -1893,13 +1914,15 @@ def __init__(
self.axes = [axis]
self.blocks = (block,)
self.refs = refs
self.parent = parent if _using_copy_on_write() else None

@classmethod
def from_blocks(
cls,
blocks: list[Block],
axes: list[Index],
refs: list[weakref.ref | None] | None = None,
parent: object = None,
) -> SingleBlockManager:
"""
Constructor for BlockManager and SingleBlockManager with same signature.
Expand All @@ -1908,7 +1931,7 @@ def from_blocks(
assert len(axes) == 1
if refs is not None:
assert len(refs) == 1
return cls(blocks[0], axes[0], refs, verify_integrity=False)
return cls(blocks[0], axes[0], refs, parent, verify_integrity=False)

@classmethod
def from_array(cls, array: ArrayLike, index: Index) -> SingleBlockManager:
Expand All @@ -1928,7 +1951,10 @@ def to_2d_mgr(self, columns: Index) -> BlockManager:
new_blk = type(blk)(arr, placement=bp, ndim=2)
axes = [columns, self.axes[0]]
refs: list[weakref.ref | None] = [weakref.ref(blk)]
return BlockManager([new_blk], axes=axes, refs=refs, verify_integrity=False)
parent = self if _using_copy_on_write() else None
return BlockManager(
[new_blk], axes=axes, refs=refs, parent=parent, verify_integrity=False
)

def _has_no_reference(self, i: int = 0) -> bool:
"""
Expand Down Expand Up @@ -2010,7 +2036,7 @@ def getitem_mgr(self, indexer: slice | npt.NDArray[np.bool_]) -> SingleBlockMana
new_idx = self.index[indexer]
# TODO(CoW) in theory only need to track reference if new_array is a view
ref = weakref.ref(blk)
return type(self)(block, new_idx, [ref])
return type(self)(block, new_idx, [ref], parent=self)

def get_slice(self, slobj: slice, axis: AxisInt = 0) -> SingleBlockManager:
# Assertion disabled for performance
Expand All @@ -2023,7 +2049,9 @@ def get_slice(self, slobj: slice, axis: AxisInt = 0) -> SingleBlockManager:
bp = BlockPlacement(slice(0, len(array)))
block = type(blk)(array, placement=bp, ndim=1)
new_index = self.index._getitem_slice(slobj)
return type(self)(block, new_index, [weakref.ref(blk)])
# TODO this method is only used in groupby SeriesSplitter at the moment,
# so passing refs / parent is not yet covered by the tests
return type(self)(block, new_index, [weakref.ref(blk)], parent=self)

@property
def index(self) -> Index:
Expand Down Expand Up @@ -2070,6 +2098,7 @@ def setitem_inplace(self, indexer, value) -> None:
if _using_copy_on_write() and not self._has_no_reference(0):
self.blocks = (self._block.copy(),)
self.refs = None
self.parent = None
self._cache.clear()

super().setitem_inplace(indexer, value)
Expand All @@ -2086,6 +2115,7 @@ def idelete(self, indexer) -> SingleBlockManager:
self._cache.clear()
# clear reference since delete always results in a new array
self.refs = None
self.parent = None
return self

def fast_xs(self, loc):
Expand Down
Loading