33import functools
44from typing import (
55 TYPE_CHECKING ,
6- Optional ,
76 overload ,
87)
98
@@ -33,7 +32,6 @@ def take_nd(
3332 arr : np .ndarray ,
3433 indexer ,
3534 axis : int = ...,
36- out : Optional [np .ndarray ] = ...,
3735 fill_value = ...,
3836 allow_fill : bool = ...,
3937) -> np .ndarray :
@@ -45,7 +43,6 @@ def take_nd(
4543 arr : ExtensionArray ,
4644 indexer ,
4745 axis : int = ...,
48- out : Optional [np .ndarray ] = ...,
4946 fill_value = ...,
5047 allow_fill : bool = ...,
5148) -> ArrayLike :
@@ -56,7 +53,6 @@ def take_nd(
5653 arr : ArrayLike ,
5754 indexer ,
5855 axis : int = 0 ,
59- out : Optional [np .ndarray ] = None ,
6056 fill_value = lib .no_default ,
6157 allow_fill : bool = True ,
6258) -> ArrayLike :
@@ -79,10 +75,6 @@ def take_nd(
7975 indices are filed with fill_value
8076 axis : int, default 0
8177 Axis to take from
82- out : ndarray or None, default None
83- Optional output array, must be appropriate type to hold input and
84- fill_value together, if indexer has any -1 value entries; call
85- maybe_promote to determine this type for any fill_value
8678 fill_value : any, default np.nan
8779 Fill value to replace -1 values with
8880 allow_fill : boolean, default True
@@ -104,14 +96,13 @@ def take_nd(
10496 return arr .take (indexer , fill_value = fill_value , allow_fill = allow_fill )
10597
10698 arr = np .asarray (arr )
107- return _take_nd_ndarray (arr , indexer , axis , out , fill_value , allow_fill )
99+ return _take_nd_ndarray (arr , indexer , axis , fill_value , allow_fill )
108100
109101
110102def _take_nd_ndarray (
111103 arr : np .ndarray ,
112104 indexer ,
113105 axis : int ,
114- out : Optional [np .ndarray ],
115106 fill_value ,
116107 allow_fill : bool ,
117108) -> np .ndarray :
@@ -121,8 +112,12 @@ def _take_nd_ndarray(
121112 dtype , fill_value = arr .dtype , arr .dtype .type ()
122113 else :
123114 indexer = ensure_platform_int (indexer )
124- indexer , dtype , fill_value , mask_info = _take_preprocess_indexer_and_fill_value (
125- arr , indexer , out , fill_value , allow_fill
115+
116+ if not allow_fill :
117+ return arr .take (indexer , axis = axis )
118+
119+ dtype , fill_value , mask_info = _take_preprocess_indexer_and_fill_value (
120+ arr , indexer , fill_value
126121 )
127122
128123 flip_order = False
@@ -132,23 +127,20 @@ def _take_nd_ndarray(
132127 if flip_order :
133128 arr = arr .T
134129 axis = arr .ndim - axis - 1
135- if out is not None :
136- out = out .T
137130
138131 # at this point, it's guaranteed that dtype can hold both the arr values
139132 # and the fill_value
140- if out is None :
141- out_shape_ = list (arr .shape )
142- out_shape_ [axis ] = len (indexer )
143- out_shape = tuple (out_shape_ )
144- if arr .flags .f_contiguous and axis == arr .ndim - 1 :
145- # minor tweak that can make an order-of-magnitude difference
146- # for dataframes initialized directly from 2-d ndarrays
147- # (s.t. df.values is c-contiguous and df._mgr.blocks[0] is its
148- # f-contiguous transpose)
149- out = np .empty (out_shape , dtype = dtype , order = "F" )
150- else :
151- out = np .empty (out_shape , dtype = dtype )
133+ out_shape_ = list (arr .shape )
134+ out_shape_ [axis ] = len (indexer )
135+ out_shape = tuple (out_shape_ )
136+ if arr .flags .f_contiguous and axis == arr .ndim - 1 :
137+ # minor tweak that can make an order-of-magnitude difference
138+ # for dataframes initialized directly from 2-d ndarrays
139+ # (s.t. df.values is c-contiguous and df._mgr.blocks[0] is its
140+ # f-contiguous transpose)
141+ out = np .empty (out_shape , dtype = dtype , order = "F" )
142+ else :
143+ out = np .empty (out_shape , dtype = dtype )
152144
153145 func = _get_take_nd_function (
154146 arr .ndim , arr .dtype , out .dtype , axis = axis , mask_info = mask_info
@@ -192,8 +184,8 @@ def take_1d(
192184 if not allow_fill :
193185 return arr .take (indexer )
194186
195- indexer , dtype , fill_value , mask_info = _take_preprocess_indexer_and_fill_value (
196- arr , indexer , None , fill_value , allow_fill
187+ dtype , fill_value , mask_info = _take_preprocess_indexer_and_fill_value (
188+ arr , indexer , fill_value
197189 )
198190
199191 # at this point, it's guaranteed that dtype can hold both the arr values
@@ -517,32 +509,22 @@ def _take_2d_multi_object(
517509def _take_preprocess_indexer_and_fill_value (
518510 arr : np .ndarray ,
519511 indexer : np .ndarray ,
520- out : Optional [np .ndarray ],
521512 fill_value ,
522- allow_fill : bool ,
523513):
524514 mask_info = None
525515
526- if not allow_fill :
527- dtype , fill_value = arr .dtype , arr .dtype .type ()
528- mask_info = None , False
529- else :
530- # check for promotion based on types only (do this first because
531- # it's faster than computing a mask)
532- dtype , fill_value = maybe_promote (arr .dtype , fill_value )
533- if dtype != arr .dtype and (out is None or out .dtype != dtype ):
534- # check if promotion is actually required based on indexer
535- mask = indexer == - 1
536- needs_masking = mask .any ()
537- mask_info = mask , needs_masking
538- if needs_masking :
539- if out is not None and out .dtype != dtype :
540- raise TypeError ("Incompatible type for fill_value" )
541- else :
542- # if not, then depromote, set fill_value to dummy
543- # (it won't be used but we don't want the cython code
544- # to crash when trying to cast it to dtype)
545- dtype , fill_value = arr .dtype , arr .dtype .type ()
546-
547- indexer = ensure_platform_int (indexer )
548- return indexer , dtype , fill_value , mask_info
516+ # check for promotion based on types only (do this first because
517+ # it's faster than computing a mask)
518+ dtype , fill_value = maybe_promote (arr .dtype , fill_value )
519+ if dtype != arr .dtype :
520+ # check if promotion is actually required based on indexer
521+ mask = indexer == - 1
522+ needs_masking = mask .any ()
523+ mask_info = mask , needs_masking
524+ if not needs_masking :
525+ # if not, then depromote, set fill_value to dummy
526+ # (it won't be used but we don't want the cython code
527+ # to crash when trying to cast it to dtype)
528+ dtype , fill_value = arr .dtype , arr .dtype .type ()
529+
530+ return dtype , fill_value , mask_info
0 commit comments