@@ -212,8 +212,6 @@ def concatenate_managers(
212
212
for placement , join_units in concat_plan :
213
213
unit = join_units [0 ]
214
214
blk = unit .block
215
- # Assertion disabled for performance
216
- # assert len(join_units) == len(mgrs_indexers)
217
215
218
216
if len (join_units ) == 1 :
219
217
values = blk .values
@@ -331,20 +329,27 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
331
329
plan : list of (BlockPlacement, JoinUnit) tuples
332
330
333
331
"""
332
+ # Calculate post-reindex shape , save for item axis which will be separate
333
+ # for each block anyway.
334
+ mgr_shape_list = list (mgr .shape )
335
+ mgr_shape = tuple (mgr_shape_list )
334
336
335
337
if mgr .is_single_block :
336
338
blk = mgr .blocks [0 ]
337
- return [(blk .mgr_locs , JoinUnit (blk ))]
339
+ return [(blk .mgr_locs , JoinUnit (blk , mgr_shape ))]
338
340
339
341
blknos = mgr .blknos
340
342
blklocs = mgr .blklocs
341
343
342
344
plan = []
343
345
for blkno , placements in libinternals .get_blkno_placements (blknos , group = False ):
344
346
345
- # Assertions disabled for performance; these should always hold
346
- # assert placements.is_slice_like
347
- # assert blkno != -1
347
+ assert placements .is_slice_like
348
+ assert blkno != - 1
349
+
350
+ shape_list = list (mgr_shape )
351
+ shape_list [0 ] = len (placements )
352
+ shape = tuple (shape_list )
348
353
349
354
blk = mgr .blocks [blkno ]
350
355
ax0_blk_indexer = blklocs [placements .indexer ]
@@ -374,16 +379,19 @@ def _get_mgr_concatenation_plan(mgr: BlockManager):
374
379
375
380
# Assertions disabled for performance
376
381
# assert blk._mgr_locs.as_slice == placements.as_slice
377
- unit = JoinUnit (blk )
382
+ # assert blk.shape[0] == shape[0]
383
+ unit = JoinUnit (blk , shape )
378
384
379
385
plan .append ((placements , unit ))
380
386
381
387
return plan
382
388
383
389
384
390
class JoinUnit :
385
- def __init__ (self , block : Block ) -> None :
391
+ def __init__ (self , block : Block , shape : Shape ) -> None :
392
+ # Passing shape explicitly is required for cases when block is None.
386
393
self .block = block
394
+ self .shape = shape
387
395
388
396
def __repr__ (self ) -> str :
389
397
return f"{ type (self ).__name__ } ({ repr (self .block )} )"
@@ -396,11 +404,22 @@ def is_na(self) -> bool:
396
404
return False
397
405
398
406
def get_reindexed_values (self , empty_dtype : DtypeObj ) -> ArrayLike :
407
+ values : ArrayLike
408
+
399
409
if self .is_na :
400
- return make_na_array (empty_dtype , self .block . shape )
410
+ return make_na_array (empty_dtype , self .shape )
401
411
402
412
else :
403
- return self .block .values
413
+
414
+ if not self .block ._can_consolidate :
415
+ # preserve these for validation in concat_compat
416
+ return self .block .values
417
+
418
+ # No dtype upcasting is done here, it will be performed during
419
+ # concatenation itself.
420
+ values = self .block .values
421
+
422
+ return values
404
423
405
424
406
425
def make_na_array (dtype : DtypeObj , shape : Shape ) -> ArrayLike :
@@ -539,9 +558,6 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
539
558
first = join_units [0 ].block
540
559
if first .dtype .kind == "V" :
541
560
return False
542
- elif len (join_units ) == 1 :
543
- # only use this path when there is something to concatenate
544
- return False
545
561
return (
546
562
# exclude cases where a) ju.block is None or b) we have e.g. Int64+int64
547
563
all (type (ju .block ) is type (first ) for ju in join_units )
@@ -554,8 +570,13 @@ def _is_uniform_join_units(join_units: list[JoinUnit]) -> bool:
554
570
or ju .block .dtype .kind in ["b" , "i" , "u" ]
555
571
for ju in join_units
556
572
)
557
- # this also precludes any blocks with dtype.kind == "V", since
558
- # we excluded that case for `first` above.
573
+ and
574
+ # no blocks that would get missing values (can lead to type upcasts)
575
+ # unless we're an extension dtype.
576
+ all (not ju .is_na or ju .block .is_extension for ju in join_units )
577
+ and
578
+ # only use this path when there is something to concatenate
579
+ len (join_units ) > 1
559
580
)
560
581
561
582
@@ -577,7 +598,10 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
577
598
extra_block = join_unit .block .getitem_block (slice (length , None ))
578
599
join_unit .block = join_unit .block .getitem_block (slice (length ))
579
600
580
- return JoinUnit (block = extra_block )
601
+ extra_shape = (join_unit .shape [0 ] - length ,) + join_unit .shape [1 :]
602
+ join_unit .shape = (length ,) + join_unit .shape [1 :]
603
+
604
+ return JoinUnit (block = extra_block , shape = extra_shape )
581
605
582
606
583
607
def _combine_concat_plans (plans ):
0 commit comments