@@ -1291,7 +1291,7 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> List[Blo
12911291
12921292 return [self .make_block (new_values )]
12931293
1294- def where (self , other , cond , errors = "raise" , axis : int = 0 ) -> List [Block ]:
1294+ def where (self , other , cond , errors = "raise" ) -> List [Block ]:
12951295 """
12961296 evaluate the block; return result block(s) from the result
12971297
@@ -1302,14 +1302,14 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
13021302 errors : str, {'raise', 'ignore'}, default 'raise'
13031303 - ``raise`` : allow exceptions to be raised
13041304 - ``ignore`` : suppress exceptions. On error return original object
1305- axis : int, default 0
13061305
13071306 Returns
13081307 -------
13091308 List[Block]
13101309 """
13111310 import pandas .core .computation .expressions as expressions
13121311
1312+ assert cond .ndim == self .ndim
13131313 assert not isinstance (other , (ABCIndex , ABCSeries , ABCDataFrame ))
13141314
13151315 assert errors in ["raise" , "ignore" ]
@@ -1322,7 +1322,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
13221322
13231323 icond , noop = validate_putmask (values , ~ cond )
13241324
1325- if is_valid_na_for_dtype (other , self .dtype ) and not self .is_object :
1325+ if is_valid_na_for_dtype (other , self .dtype ) and self .dtype != _dtype_obj :
13261326 other = self .fill_value
13271327
13281328 if noop :
@@ -1335,7 +1335,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
13351335 # we cannot coerce, return a compat dtype
13361336 # we are explicitly ignoring errors
13371337 block = self .coerce_to_target_dtype (other )
1338- blocks = block .where (orig_other , cond , errors = errors , axis = axis )
1338+ blocks = block .where (orig_other , cond , errors = errors )
13391339 return self ._maybe_downcast (blocks , "infer" )
13401340
13411341 # error: Argument 1 to "setitem_datetimelike_compat" has incompatible type
@@ -1364,7 +1364,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
13641364 cond = ~ icond
13651365 axis = cond .ndim - 1
13661366 cond = cond .swapaxes (axis , 0 )
1367- mask = np . array ([ cond [ i ] .all () for i in range ( cond . shape [ 0 ])], dtype = bool )
1367+ mask = cond .all (axis = 1 )
13681368
13691369 result_blocks : List [Block ] = []
13701370 for m in [mask , ~ mask ]:
@@ -1670,7 +1670,7 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> List[Blo
16701670 new_values = self .values .shift (periods = periods , fill_value = fill_value )
16711671 return [self .make_block_same_class (new_values )]
16721672
1673- def where (self , other , cond , errors = "raise" , axis : int = 0 ) -> List [Block ]:
1673+ def where (self , other , cond , errors = "raise" ) -> List [Block ]:
16741674
16751675 cond = extract_bool_array (cond )
16761676 assert not isinstance (other , (ABCIndex , ABCSeries , ABCDataFrame ))
@@ -1828,7 +1828,7 @@ def putmask(self, mask, new) -> List[Block]:
18281828 arr .T .putmask (mask , new )
18291829 return [self ]
18301830
1831- def where (self , other , cond , errors = "raise" , axis : int = 0 ) -> List [Block ]:
1831+ def where (self , other , cond , errors = "raise" ) -> List [Block ]:
18321832 # TODO(EA2D): reshape unnecessary with 2D EAs
18331833 arr = self .array_values ().reshape (self .shape )
18341834
@@ -1837,7 +1837,7 @@ def where(self, other, cond, errors="raise", axis: int = 0) -> List[Block]:
18371837 try :
18381838 res_values = arr .T .where (cond , other ).T
18391839 except (ValueError , TypeError ):
1840- return super ().where (other , cond , errors = errors , axis = axis )
1840+ return super ().where (other , cond , errors = errors )
18411841
18421842 # TODO(EA2D): reshape not needed with 2D EAs
18431843 res_values = res_values .reshape (self .values .shape )
0 commit comments