Skip to content

TYP: Type hints & assert statements #42044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
4 changes: 3 additions & 1 deletion pandas/_typing.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
Dict,
Hashable,
List,
Literal,
Mapping,
Optional,
Sequence,
@@ -37,7 +38,6 @@
# https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
if TYPE_CHECKING:
from typing import (
Literal,
TypedDict,
final,
)
@@ -123,6 +123,8 @@
Frequency = Union[str, "DateOffset"]
Axes = Collection[Any]
RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState]
MergeTypes = Literal["inner", "outer", "left", "right", "cross"]
ConcatTypes = Literal["inner", "outer"]

# dtypes
NpDtype = Union[str, np.dtype]
11 changes: 9 additions & 2 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
Hashable,
Iterator,
List,
Literal,
cast,
)
import warnings
@@ -518,7 +519,10 @@ def apply_multiple(self) -> FrameOrSeriesUnion:
return self.obj.aggregate(self.f, self.axis, *self.args, **self.kwargs)

def normalize_dictlike_arg(
self, how: str, obj: FrameOrSeriesUnion, func: AggFuncTypeDict
self,
how: Literal["apply", "agg", "transform"],
obj: FrameOrSeriesUnion,
func: AggFuncTypeDict,
) -> AggFuncTypeDict:
"""
Handler for dict-like argument.
@@ -527,7 +531,10 @@ def normalize_dictlike_arg(
that a nested renamer is not passed. Also normalizes to all lists
when values consists of a mix of list and non-lists.
"""
assert how in ("apply", "agg", "transform")
if how not in ("apply", "agg", "transform"):
raise ValueError(
"Value for how argument must be one of : apply, agg, transform"
)

# Can't use func.values(); wouldn't work for a Series
if (
7 changes: 5 additions & 2 deletions pandas/core/arrays/_ranges.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@
"""
from __future__ import annotations

from typing import Literal

import numpy as np

from pandas._libs.lib import i8max
@@ -75,7 +77,7 @@ def generate_regular_range(


def _generate_range_overflow_safe(
endpoint: int, periods: int, stride: int, side: str = "start"
endpoint: int, periods: int, stride: int, side: Literal["start", "end"] = "start"
) -> int:
"""
Calculate the second endpoint for passing to np.arange, checking
@@ -142,13 +144,14 @@ def _generate_range_overflow_safe(


def _generate_range_overflow_safe_signed(
endpoint: int, periods: int, stride: int, side: str
endpoint: int, periods: int, stride: int, side: Literal["start", "end"]
) -> int:
"""
A special case for _generate_range_overflow_safe where `periods * stride`
can be calculated without overflowing int64 bounds.
"""
assert side in ["start", "end"]

if side == "end":
stride *= -1

26 changes: 15 additions & 11 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@
to_offset,
tzconversion,
)
from pandas._typing import Dtype
from pandas.errors import PerformanceWarning

from pandas.core.dtypes.cast import astype_dt64_to_dt64tz
@@ -1967,12 +1968,12 @@ def sequence_to_datetimes(

def sequence_to_dt64ns(
data,
dtype=None,
copy=False,
tz=None,
dayfirst=False,
yearfirst=False,
ambiguous="raise",
dtype: Dtype | None = None,
copy: bool = False,
tz: tzinfo | str = None,
dayfirst: bool = False,
yearfirst: bool = False,
ambiguous: str | bool = "raise",
*,
allow_object: bool = False,
allow_mixed: bool = False,
@@ -2126,10 +2127,10 @@ def sequence_to_dt64ns(

def objects_to_datetime64ns(
data: np.ndarray,
dayfirst,
yearfirst,
utc=False,
errors="raise",
dayfirst: bool,
yearfirst: bool,
utc: bool = False,
errors: Literal["raise", "coerce", "ignore"] = "raise",
require_iso8601: bool = False,
allow_object: bool = False,
allow_mixed: bool = False,
@@ -2164,7 +2165,10 @@ def objects_to_datetime64ns(
------
ValueError : if data cannot be converted to datetimes
"""
assert errors in ["raise", "ignore", "coerce"]
if errors not in ["raise", "ignore", "coerce"]:
raise ValueError(
"Value for errors argument must be one of: raise, coerce, ignore"
)

# if str-dtype, convert
data = np.array(data, copy=False, dtype=np.object_)
2 changes: 1 addition & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
@@ -9191,7 +9191,7 @@ def merge(
sort: bool = False,
suffixes: Suffixes = ("_x", "_y"),
copy: bool = True,
indicator: bool = False,
indicator: bool | str = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure that `indicator` should also accept a str here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per `_MergeOperation`, you can optionally supply a string to use as the indicator column name; otherwise, when `indicator` is True, the column is given the default name `"_merge"`:

if isinstance(self.indicator, str):
self.indicator_name = self.indicator
elif isinstance(self.indicator, bool):
self.indicator_name = "_merge" if self.indicator else None

Maybe I missed this, but is there a "valid column name" definition that would be more specific than str?

validate: str | None = None,
) -> DataFrame:
from pandas.core.reshape.merge import merge
9 changes: 7 additions & 2 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
@@ -5691,7 +5691,9 @@ def _validate_indexer(self, form: str_t, key, kind: str_t):
if key is not None and not is_integer(key):
raise self._invalid_indexer(form, key)

def _maybe_cast_slice_bound(self, label, side: str_t, kind=no_default):
def _maybe_cast_slice_bound(
self, label, side: str_t, kind: Literal["loc", "getitem"] = no_default
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this annotation needs to be amended to include `no_default`, since that is the parameter's default value?

):
"""
This function should be overloaded in subclasses that allow non-trivial
casting on label-slice bounds, e.g. datetime-like indices allowing
@@ -5755,7 +5757,10 @@ def get_slice_bound(self, label, side: str_t, kind=None) -> int:
int
Index of label.
"""
assert kind in ["loc", "getitem", None]
if kind not in ["loc", "getitem", None]:
raise ValueError(
"Value for kind argument must be one of: loc, getitem or None"
)

if side not in ("left", "right"):
raise ValueError(
3 changes: 2 additions & 1 deletion pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
@@ -1200,7 +1200,8 @@ def where(self, other, cond, errors="raise") -> list[Block]:
assert cond.ndim == self.ndim
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))

assert errors in ["raise", "ignore"]
if errors not in ["raise", "ignore"]:
raise ValueError("Value for errors argument must be one of: raise, ignore")
transpose = self.ndim == 2

values = self.values
6 changes: 4 additions & 2 deletions pandas/core/missing.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
from typing import (
TYPE_CHECKING,
Any,
Literal,
cast,
)

@@ -164,7 +165,7 @@ def clean_interp_method(method: str, index: Index, **kwargs) -> str:
return method


def find_valid_index(values, *, how: str) -> int | None:
def find_valid_index(values, *, how: Literal["first", "last"]) -> int | None:
"""
Retrieves the index of the first valid value.

@@ -178,7 +179,8 @@ def find_valid_index(values, *, how: str) -> int | None:
-------
int or None
"""
assert how in ["first", "last"]
if how not in ["first", "last"]:
raise ValueError("Value for how argument must be one of : first, last")

if len(values) == 0: # early stop
return None
16 changes: 9 additions & 7 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,9 @@
DtypeObj,
FrameOrSeries,
IndexLabel,
MergeTypes,
Suffixes,
TimedeltaConvertibleTypes,
)
from pandas.errors import MergeError
from pandas.util._decorators import (
@@ -92,7 +94,7 @@
def merge(
left: DataFrame | Series,
right: DataFrame | Series,
how: str = "inner",
how: MergeTypes = "inner",
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
@@ -101,7 +103,7 @@ def merge(
sort: bool = False,
suffixes: Suffixes = ("_x", "_y"),
copy: bool = True,
indicator: bool = False,
indicator: bool | str = False,
validate: str | None = None,
) -> DataFrame:
op = _MergeOperation(
@@ -331,11 +333,11 @@ def merge_asof(
right_on: IndexLabel | None = None,
left_index: bool = False,
right_index: bool = False,
by=None,
left_by=None,
right_by=None,
by: IndexLabel | None = None,
left_by: Hashable | None = None,
right_by: Hashable | None = None,
suffixes: Suffixes = ("_x", "_y"),
tolerance=None,
tolerance: None | TimedeltaConvertibleTypes = None,
allow_exact_matches: bool = True,
direction: str = "backward",
) -> DataFrame:
@@ -622,7 +624,7 @@ def __init__(
sort: bool = True,
suffixes: Suffixes = ("_x", "_y"),
copy: bool = True,
indicator: bool = False,
indicator: bool | str = False,
validate: str | None = None,
):
_left = _validate_operand(left)
3 changes: 2 additions & 1 deletion pandas/io/excel/_util.py
Original file line number Diff line number Diff line change
@@ -59,7 +59,8 @@ def get_default_engine(ext, mode="reader"):
"xls": "xlwt",
"ods": "odf",
}
assert mode in ["reader", "writer"]
if mode not in ["reader", "writer"]:
raise ValueError('File mode must be either "reader" or "writer".')
if mode == "writer":
# Prefer xlsxwriter over openpyxl if installed
xlsxwriter = import_optional_dependency("xlsxwriter", errors="warn")
6 changes: 3 additions & 3 deletions pandas/tseries/frequencies.py
Original file line number Diff line number Diff line change
@@ -445,7 +445,7 @@ def _maybe_add_count(base: str, count: float) -> str:
# Frequency comparison


def is_subperiod(source, target) -> bool:
def is_subperiod(source: str | DateOffset, target: str | DateOffset) -> bool:
"""
Returns True if downsampling is possible between source and target
frequencies
@@ -501,7 +501,7 @@ def is_subperiod(source, target) -> bool:
return False


def is_superperiod(source, target) -> bool:
def is_superperiod(source: str | DateOffset, target: str | DateOffset) -> bool:
"""
Returns True if upsampling is possible between source and target
frequencies
@@ -559,7 +559,7 @@ def is_superperiod(source, target) -> bool:
return False


def _maybe_coerce_freq(code) -> str:
def _maybe_coerce_freq(code: str | DateOffset) -> str:
"""we might need to coerce a code to a rule_code
and uppercase it

Loading