Skip to content

Commit

Permalink
TYP: type all arguments with str default values (pandas-dev#48508)
Browse files (browse the repository at this point in the history)
* TYP: type all arguments with str default values

* na_rep: back to str

* na(t)_rep is always a string

* add float for some functions

* and the same for the few float default arguments

* define a few more literal constants

* avoid itertools.cycle mypy error

* revert mistake
  • Loading branch information
twoertwein authored Sep 22, 2022
1 parent a375061 commit 1c51e60
Show file tree
Hide file tree
Showing 53 changed files with 419 additions and 207 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ repos:
|/_testing/
- id: autotyping
name: autotyping
entry: python -m libcst.tool codemod autotyping.AutotypeCommand --none-return --scalar-return --annotate-magics --annotate-imprecise-magics --bool-param
entry: python -m libcst.tool codemod autotyping.AutotypeCommand --none-return --scalar-return --annotate-magics --annotate-imprecise-magics --bool-param --bytes-param --str-param --float-param
types_or: [python, pyi]
files: ^pandas
exclude: ^(pandas/tests|pandas/_version.py|pandas/io/clipboard)
Expand Down
25 changes: 16 additions & 9 deletions pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
set_locale,
)

from pandas._typing import Dtype
from pandas._typing import (
Dtype,
Frequency,
)
from pandas.compat import pa_version_under1p01

from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -401,13 +404,17 @@ def makeFloatIndex(k=10, name=None) -> Float64Index:
return Float64Index(base_idx)


def makeDateIndex(k: int = 10, freq="B", name=None, **kwargs) -> DatetimeIndex:
def makeDateIndex(
k: int = 10, freq: Frequency = "B", name=None, **kwargs
) -> DatetimeIndex:
dt = datetime(2000, 1, 1)
dr = bdate_range(dt, periods=k, freq=freq, name=name)
return DatetimeIndex(dr, name=name, **kwargs)


def makeTimedeltaIndex(k: int = 10, freq="D", name=None, **kwargs) -> TimedeltaIndex:
def makeTimedeltaIndex(
k: int = 10, freq: Frequency = "D", name=None, **kwargs
) -> TimedeltaIndex:
return pd.timedelta_range(start="1 day", periods=k, freq=freq, name=name, **kwargs)


Expand Down Expand Up @@ -484,7 +491,7 @@ def getSeriesData() -> dict[str, Series]:
return {c: Series(np.random.randn(_N), index=index) for c in getCols(_K)}


def makeTimeSeries(nper=None, freq="B", name=None) -> Series:
def makeTimeSeries(nper=None, freq: Frequency = "B", name=None) -> Series:
if nper is None:
nper = _N
return Series(
Expand All @@ -498,7 +505,7 @@ def makePeriodSeries(nper=None, name=None) -> Series:
return Series(np.random.randn(nper), index=makePeriodIndex(nper), name=name)


def getTimeSeriesData(nper=None, freq="B") -> dict[str, Series]:
def getTimeSeriesData(nper=None, freq: Frequency = "B") -> dict[str, Series]:
return {c: makeTimeSeries(nper, freq) for c in getCols(_K)}


Expand All @@ -507,7 +514,7 @@ def getPeriodData(nper=None) -> dict[str, Series]:


# make frame
def makeTimeDataFrame(nper=None, freq="B") -> DataFrame:
def makeTimeDataFrame(nper=None, freq: Frequency = "B") -> DataFrame:
data = getTimeSeriesData(nper, freq)
return DataFrame(data)

Expand Down Expand Up @@ -542,7 +549,7 @@ def makePeriodFrame(nper=None) -> DataFrame:
def makeCustomIndex(
nentries,
nlevels,
prefix="#",
prefix: str = "#",
names: bool | str | list[str] | None = False,
ndupe_l=None,
idx_type=None,
Expand Down Expand Up @@ -760,7 +767,7 @@ def makeCustomDataframe(
return DataFrame(data, index, columns, dtype=dtype)


def _create_missing_idx(nrows, ncols, density, random_state=None):
def _create_missing_idx(nrows, ncols, density: float, random_state=None):
if random_state is None:
random_state = np.random
else:
Expand All @@ -787,7 +794,7 @@ def _gen_unique_rand(rng, _extra_size):
return i.tolist(), j.tolist()


def makeMissingDataframe(density=0.9, random_state=None) -> DataFrame:
def makeMissingDataframe(density: float = 0.9, random_state=None) -> DataFrame:
df = makeDataFrame()
i, j = _create_missing_idx(*df.shape, density=density, random_state=random_state)
df.values[i, j] = np.nan
Expand Down
4 changes: 2 additions & 2 deletions pandas/_testing/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def dec(f):
@optional_args # type: ignore[misc]
def network(
t,
url="https://www.google.com",
url: str = "https://www.google.com",
raise_on_error: bool = False,
check_before_test: bool = False,
error_classes=None,
Expand Down Expand Up @@ -369,7 +369,7 @@ def round_trip_localpath(writer, reader, path: str | None = None):
return obj


def write_to_compressed(compression, path, data, dest="test"):
def write_to_compressed(compression, path, data, dest: str = "test"):
"""
Write data to a compressed file.
Expand Down
4 changes: 3 additions & 1 deletion pandas/_testing/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np

from pandas._typing import NpDtype


def randbool(size=(), p: float = 0.5):
return np.random.rand(*size) <= p
Expand All @@ -14,7 +16,7 @@ def randbool(size=(), p: float = 0.5):
)


def rands_array(nchars, size, dtype="O", replace: bool = True) -> np.ndarray:
def rands_array(nchars, size, dtype: NpDtype = "O", replace: bool = True) -> np.ndarray:
"""
Generate an array of byte strings.
"""
Expand Down
22 changes: 12 additions & 10 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def assert_index_equal(
"""
__tracebackhide__ = True

def _check_types(left, right, obj="Index") -> None:
def _check_types(left, right, obj: str = "Index") -> None:
if not exact:
return

Expand Down Expand Up @@ -429,7 +429,9 @@ def _get_ilevel_values(index, level):
assert_categorical_equal(left._values, right._values, obj=f"{obj} category")


def assert_class_equal(left, right, exact: bool | str = True, obj="Input") -> None:
def assert_class_equal(
left, right, exact: bool | str = True, obj: str = "Input"
) -> None:
"""
Checks classes are equal.
"""
Expand Down Expand Up @@ -527,7 +529,7 @@ def assert_categorical_equal(
right,
check_dtype: bool = True,
check_category_order: bool = True,
obj="Categorical",
obj: str = "Categorical",
) -> None:
"""
Test that Categoricals are equivalent.
Expand Down Expand Up @@ -584,7 +586,7 @@ def assert_categorical_equal(


def assert_interval_array_equal(
left, right, exact="equiv", obj="IntervalArray"
left, right, exact: bool | Literal["equiv"] = "equiv", obj: str = "IntervalArray"
) -> None:
"""
Test that two IntervalArrays are equivalent.
Expand Down Expand Up @@ -614,15 +616,15 @@ def assert_interval_array_equal(
assert_attr_equal("closed", left, right, obj=obj)


def assert_period_array_equal(left, right, obj="PeriodArray") -> None:
def assert_period_array_equal(left, right, obj: str = "PeriodArray") -> None:
_check_isinstance(left, right, PeriodArray)

assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
assert_attr_equal("freq", left, right, obj=obj)


def assert_datetime_array_equal(
left, right, obj="DatetimeArray", check_freq: bool = True
left, right, obj: str = "DatetimeArray", check_freq: bool = True
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, DatetimeArray)
Expand All @@ -634,7 +636,7 @@ def assert_datetime_array_equal(


def assert_timedelta_array_equal(
left, right, obj="TimedeltaArray", check_freq: bool = True
left, right, obj: str = "TimedeltaArray", check_freq: bool = True
) -> None:
__tracebackhide__ = True
_check_isinstance(left, right, TimedeltaArray)
Expand Down Expand Up @@ -693,7 +695,7 @@ def assert_numpy_array_equal(
check_dtype: bool | Literal["equiv"] = True,
err_msg=None,
check_same=None,
obj="numpy array",
obj: str = "numpy array",
index_values=None,
) -> None:
"""
Expand Down Expand Up @@ -887,7 +889,7 @@ def assert_series_equal(
check_flags: bool = True,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
obj="Series",
obj: str = "Series",
*,
check_index: bool = True,
check_like: bool = False,
Expand Down Expand Up @@ -1157,7 +1159,7 @@ def assert_frame_equal(
check_flags: bool = True,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
obj="DataFrame",
obj: str = "DataFrame",
) -> None:
"""
Check that left and right DataFrame are equal.
Expand Down
14 changes: 14 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,17 @@ def closed(self) -> bool:

# dropna
AnyAll = Literal["any", "all"]

MatplotlibColor = Union[str, Sequence[float]]
TimeGrouperOrigin = Union[
"Timestamp", Literal["epoch", "start", "start_day", "end", "end_day"]
]
TimeAmbiguous = Union[Literal["infer", "NaT", "raise"], "npt.NDArray[np.bool_]"]
TimeNonexistent = Union[
Literal["shift_forward", "shift_backward", "NaT", "raise"], timedelta
]
DropKeep = Literal["first", "last", False]
CorrelationMethod = Union[
Literal["pearson", "kendall", "spearman"], Callable[[np.ndarray, np.ndarray], float]
]
AlignJoin = Literal["outer", "inner", "left", "right"]
3 changes: 2 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pandas._typing import (
Dtype,
PositionalIndexer,
SortKind,
TakeIndexer,
npt,
)
Expand Down Expand Up @@ -472,7 +473,7 @@ def isna(self) -> npt.NDArray[np.bool_]:
def argsort(
self,
ascending: bool = True,
kind: str = "quicksort",
kind: SortKind = "quicksort",
na_position: str = "last",
*args,
**kwargs,
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ScalarIndexer,
SequenceIndexer,
Shape,
SortKind,
TakeIndexer,
npt,
)
Expand Down Expand Up @@ -671,7 +672,7 @@ def _values_for_argsort(self) -> np.ndarray:
def argsort(
self,
ascending: bool = True,
kind: str = "quicksort",
kind: SortKind = "quicksort",
na_position: str = "last",
*args,
**kwargs,
Expand Down
9 changes: 6 additions & 3 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
NpDtype,
Ordered,
Shape,
SortKind,
npt,
type_t,
)
Expand Down Expand Up @@ -1827,7 +1828,7 @@ def check_for_ordered(self, op) -> None:
# error: Signature of "argsort" incompatible with supertype "ExtensionArray"
@deprecate_nonkeyword_arguments(version=None, allowed_args=["self"])
def argsort( # type: ignore[override]
self, ascending: bool = True, kind="quicksort", **kwargs
self, ascending: bool = True, kind: SortKind = "quicksort", **kwargs
):
"""
Return the indices that would sort the Categorical.
Expand Down Expand Up @@ -2200,7 +2201,9 @@ def _repr_footer(self) -> str:
info = self._repr_categories_info()
return f"Length: {len(self)}\n{info}"

def _get_repr(self, length: bool = True, na_rep="NaN", footer: bool = True) -> str:
def _get_repr(
self, length: bool = True, na_rep: str = "NaN", footer: bool = True
) -> str:
from pandas.io.formats import format as fmt

formatter = fmt.CategoricalFormatter(
Expand Down Expand Up @@ -2716,7 +2719,7 @@ def _str_map(
result = PandasArray(categories.to_numpy())._str_map(f, na_value, dtype)
return take_nd(result, codes, fill_value=na_value)

def _str_get_dummies(self, sep="|"):
def _str_get_dummies(self, sep: str = "|"):
# sep may not be in categories. Just bail on this.
from pandas.core.arrays import PandasArray

Expand Down
27 changes: 22 additions & 5 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
PositionalIndexerTuple,
ScalarIndexer,
SequenceIndexer,
TimeAmbiguous,
TimeNonexistent,
npt,
)
from pandas.compat.numpy import function as nv
Expand Down Expand Up @@ -308,7 +310,7 @@ def asi8(self) -> npt.NDArray[np.int64]:
# Rendering Methods

def _format_native_types(
self, *, na_rep="NaT", date_format=None
self, *, na_rep: str | float = "NaT", date_format=None
) -> npt.NDArray[np.object_]:
"""
Helper method for astype when converting to strings.
Expand Down Expand Up @@ -556,7 +558,7 @@ def _concat_same_type(
new_obj._freq = new_freq
return new_obj

def copy(self: DatetimeLikeArrayT, order="C") -> DatetimeLikeArrayT:
def copy(self: DatetimeLikeArrayT, order: str = "C") -> DatetimeLikeArrayT:
# error: Unexpected keyword argument "order" for "copy"
new_obj = super().copy(order=order) # type: ignore[call-arg]
new_obj._freq = self.freq
Expand Down Expand Up @@ -2085,15 +2087,30 @@ def _round(self, freq, mode, ambiguous, nonexistent):
return self._simple_new(result, dtype=self.dtype)

@Appender((_round_doc + _round_example).format(op="round"))
def round(self, freq, ambiguous="raise", nonexistent="raise"):
def round(
self,
freq,
ambiguous: TimeAmbiguous = "raise",
nonexistent: TimeNonexistent = "raise",
):
return self._round(freq, RoundTo.NEAREST_HALF_EVEN, ambiguous, nonexistent)

@Appender((_round_doc + _floor_example).format(op="floor"))
def floor(self, freq, ambiguous="raise", nonexistent="raise"):
def floor(
self,
freq,
ambiguous: TimeAmbiguous = "raise",
nonexistent: TimeNonexistent = "raise",
):
return self._round(freq, RoundTo.MINUS_INFTY, ambiguous, nonexistent)

@Appender((_round_doc + _ceil_example).format(op="ceil"))
def ceil(self, freq, ambiguous="raise", nonexistent="raise"):
def ceil(
self,
freq,
ambiguous: TimeAmbiguous = "raise",
nonexistent: TimeNonexistent = "raise",
):
return self._round(freq, RoundTo.PLUS_INFTY, ambiguous, nonexistent)

# --------------------------------------------------------------
Expand Down
Loading

0 comments on commit 1c51e60

Please sign in to comment.