|  | 
| 13 | 13 | import numpy as np | 
| 14 | 14 | 
 | 
| 15 | 15 | import pandas._libs.lib as lib | 
| 16 |  | -from pandas._typing import FrameOrSeriesUnion | 
|  | 16 | +from pandas._typing import ( | 
|  | 17 | +    DtypeObj, | 
|  | 18 | +    FrameOrSeriesUnion, | 
|  | 19 | +) | 
| 17 | 20 | from pandas.util._decorators import Appender | 
| 18 | 21 | 
 | 
| 19 | 22 | from pandas.core.dtypes.common import ( | 
| @@ -209,8 +212,12 @@ def _validate(data): | 
| 209 | 212 |         # see _libs/lib.pyx for list of inferred types | 
| 210 | 213 |         allowed_types = ["string", "empty", "bytes", "mixed", "mixed-integer"] | 
| 211 | 214 | 
 | 
| 212 |  | -        values = getattr(data, "values", data)  # Series / Index | 
| 213 |  | -        values = getattr(values, "categories", values)  # categorical / normal | 
|  | 215 | +        # TODO: avoid kludge for tests.extension.test_numpy | 
|  | 216 | +        from pandas.core.internals.managers import _extract_array | 
|  | 217 | + | 
|  | 218 | +        data = _extract_array(data) | 
|  | 219 | + | 
|  | 220 | +        values = getattr(data, "categories", data)  # categorical / normal | 
| 214 | 221 | 
 | 
| 215 | 222 |         inferred_dtype = lib.infer_dtype(values, skipna=True) | 
| 216 | 223 | 
 | 
| @@ -242,6 +249,7 @@ def _wrap_result( | 
| 242 | 249 |         expand: bool | None = None, | 
| 243 | 250 |         fill_value=np.nan, | 
| 244 | 251 |         returns_string=True, | 
|  | 252 | +        returns_bool: bool = False, | 
| 245 | 253 |     ): | 
| 246 | 254 |         from pandas import ( | 
| 247 | 255 |             Index, | 
| @@ -319,19 +327,25 @@ def cons_row(x): | 
| 319 | 327 |         else: | 
| 320 | 328 |             index = self._orig.index | 
| 321 | 329 |             # This is a mess. | 
| 322 |  | -            dtype: str | None | 
| 323 |  | -            if self._is_string and returns_string: | 
| 324 |  | -                dtype = self._orig.dtype | 
|  | 330 | +            dtype: DtypeObj | str | None | 
|  | 331 | +            vdtype = getattr(result, "dtype", None) | 
|  | 332 | +            if self._is_string: | 
|  | 333 | +                if is_bool_dtype(vdtype): | 
|  | 334 | +                    dtype = result.dtype | 
|  | 335 | +                elif returns_string: | 
|  | 336 | +                    dtype = self._orig.dtype | 
|  | 337 | +                else: | 
|  | 338 | +                    dtype = vdtype | 
| 325 | 339 |             else: | 
| 326 |  | -                dtype = None | 
|  | 340 | +                dtype = vdtype | 
| 327 | 341 | 
 | 
| 328 | 342 |             if expand: | 
| 329 | 343 |                 cons = self._orig._constructor_expanddim | 
| 330 | 344 |                 result = cons(result, columns=name, index=index, dtype=dtype) | 
| 331 | 345 |             else: | 
| 332 | 346 |                 # Must be a Series | 
| 333 | 347 |                 cons = self._orig._constructor | 
| 334 |  | -                result = cons(result, name=name, index=index) | 
|  | 348 | +                result = cons(result, name=name, index=index, dtype=dtype) | 
| 335 | 349 |             result = result.__finalize__(self._orig, method="str") | 
| 336 | 350 |             if name is not None and result.ndim == 1: | 
| 337 | 351 |                 # __finalize__ might copy over the original name, but we may | 
| @@ -369,7 +383,7 @@ def _get_series_list(self, others): | 
| 369 | 383 |         if isinstance(others, ABCSeries): | 
| 370 | 384 |             return [others] | 
| 371 | 385 |         elif isinstance(others, ABCIndex): | 
| 372 |  | -            return [Series(others._values, index=idx)] | 
|  | 386 | +            return [Series(others._values, index=idx, dtype=others.dtype)] | 
| 373 | 387 |         elif isinstance(others, ABCDataFrame): | 
| 374 | 388 |             return [others[x] for x in others] | 
| 375 | 389 |         elif isinstance(others, np.ndarray) and others.ndim == 2: | 
| @@ -547,7 +561,7 @@ def cat(self, others=None, sep=None, na_rep=None, join="left"): | 
| 547 | 561 |             sep = "" | 
| 548 | 562 | 
 | 
| 549 | 563 |         if isinstance(self._orig, ABCIndex): | 
| 550 |  | -            data = Series(self._orig, index=self._orig) | 
|  | 564 | +            data = Series(self._orig, index=self._orig, dtype=self._orig.dtype) | 
| 551 | 565 |         else:  # Series | 
| 552 | 566 |             data = self._orig | 
| 553 | 567 | 
 | 
|  | 
0 commit comments