Skip to content

Commit 3f72d8c

Browse files
Revert "REF: remove JoinUnit.shape (#43651)" (#47406)
This reverts commit bb9a985.
1 parent 47494a4 commit 3f72d8c

File tree

1 file changed

+40
-16
lines changed

1 file changed

+40
-16
lines changed

pandas/core/internals/concat.py

+40-16
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,6 @@ def concatenate_managers(
212212
for placement, join_units in concat_plan:
213213
unit = join_units[0]
214214
blk = unit.block
215-
# Assertion disabled for performance
216-
# assert len(join_units) == len(mgrs_indexers)
217215

218216
if len(join_units) == 1:
219217
values = blk.values
@@ -331,20 +329,27 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
331329
plan : list of (BlockPlacement, JoinUnit) tuples
332330
333331
"""
332+
# Calculate post-reindex shape, save for item axis which will be separate
333+
# for each block anyway.
334+
mgr_shape_list = list(mgr.shape)
335+
mgr_shape = tuple(mgr_shape_list)
334336

335337
if mgr.is_single_block:
336338
blk = mgr.blocks[0]
337-
return [(blk.mgr_locs, JoinUnit(blk))]
339+
return [(blk.mgr_locs, JoinUnit(blk, mgr_shape))]
338340

339341
blknos = mgr.blknos
340342
blklocs = mgr.blklocs
341343

342344
plan = []
343345
for blkno, placements in libinternals.get_blkno_placements(blknos, group=False):
344346

345-
# Assertions disabled for performance; these should always hold
346-
# assert placements.is_slice_like
347-
# assert blkno != -1
347+
assert placements.is_slice_like
348+
assert blkno != -1
349+
350+
shape_list = list(mgr_shape)
351+
shape_list[0] = len(placements)
352+
shape = tuple(shape_list)
348353

349354
blk = mgr.blocks[blkno]
350355
ax0_blk_indexer = blklocs[placements.indexer]
@@ -374,16 +379,19 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
374379

375380
# Assertions disabled for performance
376381
# assert blk._mgr_locs.as_slice == placements.as_slice
377-
unit = JoinUnit(blk)
382+
# assert blk.shape[0] == shape[0]
383+
unit = JoinUnit(blk, shape)
378384

379385
plan.append((placements, unit))
380386

381387
return plan
382388

383389

384390
class JoinUnit:
385-
def __init__(self, block: Block) -> None:
391+
def __init__(self, block: Block, shape: Shape) -> None:
392+
# Passing shape explicitly is required for cases when block is None.
386393
self.block = block
394+
self.shape = shape
387395

388396
def __repr__(self) -> str:
389397
return f"{type(self).__name__}({repr(self.block)})"
@@ -396,11 +404,22 @@ def is_na(self) -> bool:
396404
return False
397405

398406
def get_reindexed_values(self, empty_dtype: DtypeObj) -> ArrayLike:
407+
values: ArrayLike
408+
399409
if self.is_na:
400-
return make_na_array(empty_dtype, self.block.shape)
410+
return make_na_array(empty_dtype, self.shape)
401411

402412
else:
403-
return self.block.values
413+
414+
if not self.block._can_consolidate:
415+
# preserve these for validation in concat_compat
416+
return self.block.values
417+
418+
# No dtype upcasting is done here, it will be performed during
419+
# concatenation itself.
420+
values = self.block.values
421+
422+
return values
404423

405424

406425
def make_na_array(dtype: DtypeObj, shape: Shape) -> ArrayLike:
@@ -539,9 +558,6 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
539558
first = join_units[0].block
540559
if first.dtype.kind == "V":
541560
return False
542-
elif len(join_units) == 1:
543-
# only use this path when there is something to concatenate
544-
return False
545561
return (
546562
# exclude cases where a) ju.block is None or b) we have e.g. Int64+int64
547563
all(type(ju.block) is type(first) for ju in join_units)
@@ -554,8 +570,13 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
554570
or ju.block.dtype.kind in ["b", "i", "u"]
555571
for ju in join_units
556572
)
557-
# this also precludes any blocks with dtype.kind == "V", since
558-
# we excluded that case for `first` above.
573+
and
574+
# no blocks that would get missing values (can lead to type upcasts)
575+
# unless we're an extension dtype.
576+
all(not ju.is_na or ju.block.is_extension for ju in join_units)
577+
and
578+
# only use this path when there is something to concatenate
579+
len(join_units) > 1
559580
)
560581

561582

@@ -577,7 +598,10 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
577598
extra_block = join_unit.block.getitem_block(slice(length, None))
578599
join_unit.block = join_unit.block.getitem_block(slice(length))
579600

580-
return JoinUnit(block=extra_block)
601+
extra_shape = (join_unit.shape[0] - length,) + join_unit.shape[1:]
602+
join_unit.shape = (length,) + join_unit.shape[1:]
603+
604+
return JoinUnit(block=extra_block, shape=extra_shape)
581605

582606

583607
def _combine_concat_plans(plans):

0 commit comments

Comments
 (0)