Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update type hints to PY39 (and fix various errors and mypy nonsense) #372

Merged
merged 4 commits into from
Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/textacy/augmentation/augmenter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import random
from typing import List, Optional, Sequence, Tuple
from typing import Optional, Sequence

from spacy.tokens import Doc

Expand Down Expand Up @@ -46,8 +46,8 @@ class Augmenter:
The jumps over the lazy odg.

Args:
transforms: Ordered sequence of callables that must take List[:obj:`AugTok`]
as their first positional argument and return another List[:obj:`AugTok`].
transforms: Ordered sequence of callables that must take list[:obj:`AugTok`]
as their first positional argument and return another list[:obj:`AugTok`].

.. note:: Although the particular transforms applied may vary doc-by-doc,
they are applied *in order* as listed here. Since some transforms may
Expand Down Expand Up @@ -112,7 +112,7 @@ def apply_transforms(self, doc: Doc, lang: types.LangLike, **kwargs) -> Doc:

def _validate_transforms(
self, transforms: Sequence[types.AugTransform]
) -> Tuple[types.AugTransform, ...]:
) -> tuple[types.AugTransform, ...]:
transforms = tuple(transforms)
if not transforms:
raise ValueError("at least one transform callable must be specified")
Expand All @@ -123,7 +123,7 @@ def _validate_transforms(

def _validate_num(
self, num: Optional[int | float | Sequence[float]]
) -> int | float | Tuple[float, ...]:
) -> int | float | tuple[float, ...]:
if num is None:
return len(self.tfs)
elif isinstance(num, int) and 0 <= num <= len(self.tfs):
Expand All @@ -142,7 +142,7 @@ def _validate_num(
"or a list of floats of length equal to given transforms"
)

def _get_random_transforms(self) -> List[types.AugTransform]:
def _get_random_transforms(self) -> list[types.AugTransform]:
num = self.num
if isinstance(num, int):
rand_idxs = random.sample(range(len(self.tfs)), min(num, len(self.tfs)))
Expand Down
52 changes: 26 additions & 26 deletions src/textacy/augmentation/transforms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import random
from typing import List, Optional, Set
from typing import Optional

from cytoolz import itertoolz

Expand All @@ -10,11 +10,11 @@


def substitute_word_synonyms(
aug_toks: List[types.AugTok],
aug_toks: list[types.AugTok],
*,
num: int | float = 1,
pos: Optional[str | Set[str]] = None,
) -> List[types.AugTok]:
pos: Optional[str | set[str]] = None,
) -> list[types.AugTok]:
"""
Randomly substitute words for which synonyms are available
with a randomly selected synonym,
Expand Down Expand Up @@ -64,11 +64,11 @@ def substitute_word_synonyms(


def insert_word_synonyms(
aug_toks: List[types.AugTok],
aug_toks: list[types.AugTok],
*,
num: int | float = 1,
pos: Optional[str | Set[str]] = None,
) -> List[types.AugTok]:
pos: Optional[str | set[str]] = None,
) -> list[types.AugTok]:
"""
Randomly insert random synonyms of tokens for which synonyms are available,
up to ``num`` times or with a probability of ``num``.
Expand Down Expand Up @@ -106,7 +106,7 @@ def insert_word_synonyms(
return aug_toks[:]

rand_aug_toks = iter(rand_aug_toks)
new_aug_toks: List[types.AugTok] = []
new_aug_toks: list[types.AugTok] = []
# NOTE: https://github.com/python/mypy/issues/5492
padded_pairs = itertoolz.sliding_window(2, [None] + aug_toks) # type: ignore
for idx, (prev_tok, curr_tok) in enumerate(padded_pairs):
Expand Down Expand Up @@ -140,11 +140,11 @@ def insert_word_synonyms(


def swap_words(
aug_toks: List[types.AugTok],
aug_toks: list[types.AugTok],
*,
num: int | float = 1,
pos: Optional[str | Set[str]] = None,
) -> List[types.AugTok]:
pos: Optional[str | set[str]] = None,
) -> list[types.AugTok]:
"""
Randomly swap the positions of two *adjacent* words,
up to ``num`` times or with a probability of ``num``.
Expand Down Expand Up @@ -209,11 +209,11 @@ def swap_words(


def delete_words(
aug_toks: List[types.AugTok],
aug_toks: list[types.AugTok],
*,
num: int | float = 1,
pos: Optional[str | Set[str]] = None,
) -> List[types.AugTok]:
pos: Optional[str | set[str]] = None,
) -> list[types.AugTok]:
"""
Randomly delete words,
up to ``num`` times or with a probability of ``num``.
Expand Down Expand Up @@ -243,7 +243,7 @@ def delete_words(
if not rand_idxs:
return aug_toks[:]

new_aug_toks: List[types.AugTok] = []
new_aug_toks: list[types.AugTok] = []
# NOTE: https://github.com/python/mypy/issues/5492
padded_triplets = itertoolz.sliding_window(
3, [None] + aug_toks + [None] # type: ignore
Expand All @@ -266,11 +266,11 @@ def delete_words(


def substitute_chars(
aug_toks: List[types.AugTok],
aug_toks: list[types.AugTok],
*,
num: int | float = 1,
lang: Optional[str] = None,
) -> List[types.AugTok]:
) -> list[types.AugTok]:
"""
Randomly substitute a single character in randomly-selected words with another,
up to ``num`` times or with a probability of ``num``.
Expand Down Expand Up @@ -332,11 +332,11 @@ def substitute_chars(


def insert_chars(
aug_toks: List[types.AugTok],
aug_toks: list[types.AugTok],
*,
num: int | float = 1,
lang: Optional[str] = None,
) -> List[types.AugTok]:
) -> list[types.AugTok]:
"""
Randomly insert a character into randomly-selected words,
up to ``num`` times or with a probability of ``num``.
Expand Down Expand Up @@ -398,8 +398,8 @@ def insert_chars(


def swap_chars(
aug_toks: List[types.AugTok], *, num: int | float = 1
) -> List[types.AugTok]:
aug_toks: list[types.AugTok], *, num: int | float = 1
) -> list[types.AugTok]:
"""
Randomly swap two *adjacent* characters in randomly-selected words,
up to ``num`` times or with a probability of ``num``.
Expand Down Expand Up @@ -443,8 +443,8 @@ def swap_chars(


def delete_chars(
aug_toks: List[types.AugTok], *, num: int | float = 1
) -> List[types.AugTok]:
aug_toks: list[types.AugTok], *, num: int | float = 1
) -> list[types.AugTok]:
"""
Randomly delete a character in randomly-selected words,
up to ``num`` times or with a probability of ``num``.
Expand Down Expand Up @@ -493,18 +493,18 @@ def delete_chars(
def _validate_aug_toks(aug_toks):
if not (isinstance(aug_toks, list) and isinstance(aug_toks[0], types.AugTok)):
raise TypeError(
errors.type_invalid_msg("aug_toks", type(aug_toks), List[types.AugTok])
errors.type_invalid_msg("aug_toks", type(aug_toks), list[types.AugTok])
)


def _select_random_candidates(cands, num):
"""
Args:
cands (List[obj])
cands (list[obj])
num (int or float)

Returns:
List[obj]
list[obj]
"""
if isinstance(num, int) and num >= 0:
rand_cands = random.sample(cands, min(num, len(cands)))
Expand Down
13 changes: 8 additions & 5 deletions src/textacy/augmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import itertools
import string
from typing import Iterable, List, Tuple
from typing import Iterable

from cachetools import cached
from cachetools.keys import hashkey
Expand All @@ -17,7 +17,7 @@
udhr = datasets.UDHR()


def to_aug_toks(doclike: types.DocLike) -> List[types.AugTok]:
def to_aug_toks(doclike: types.DocLike) -> list[types.AugTok]:
"""
Transform a spaCy ``Doc`` or ``Span`` into a list of ``AugTok`` objects,
suitable for use in data augmentation transform functions.
Expand All @@ -27,7 +27,7 @@ def to_aug_toks(doclike: types.DocLike) -> List[types.AugTok]:
errors.type_invalid_msg("spacy_obj", type(doclike), types.DocLike)
)
lang = doclike.vocab.lang
toks_syns: Iterable[List[str]]
toks_syns: Iterable[list[str]]
if concept_net.filepath is None or lang not in concept_net.synonyms:
toks_syns = ([] for _ in doclike)
else:
Expand All @@ -50,7 +50,7 @@ def to_aug_toks(doclike: types.DocLike) -> List[types.AugTok]:


@cached(cache.LRU_CACHE, key=functools.partial(hashkey, "char_weights"))
def get_char_weights(lang: str) -> List[Tuple[str, int]]:
def get_char_weights(lang: str) -> list[tuple[str, int]]:
"""
Get lang-specific character weights for use in certain data augmentation transforms,
based on texts in :class:`textacy.datasets.UDHR`.
Expand All @@ -65,7 +65,10 @@ def get_char_weights(lang: str) -> List[Tuple[str, int]]:
try:
char_weights = list(
collections.Counter(
char for text in udhr.texts(lang=lang) for char in text if char.isalnum()
char
for text in udhr.texts(lang=lang)
for char in text
if char.isalnum()
).items()
)
except ValueError:
Expand Down
17 changes: 12 additions & 5 deletions src/textacy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

LOGGER = logging.getLogger(__name__)


def _get_size(obj, seen=None):
"""
Recursively find the actual size of an object, in bytes.
Expand Down Expand Up @@ -41,17 +42,23 @@ def _get_size(obj, seen=None):
try:
size += sum((_get_size(i, seen) for i in obj))
except TypeError:
LOGGER.warning("Unable to get size of %r. This may lead to incorrect sizes. Please report this error.", obj)
if hasattr(obj, "__slots__"): # can have __slots__ with __dict__
size += sum(_get_size(getattr(obj, s), seen) for s in obj.__slots__ if hasattr(obj, s))
LOGGER.warning(
"Unable to get size of %r. This may lead to incorrect sizes. Please report this error.",
obj,
)
if hasattr(obj, "__slots__"): # can have __slots__ with __dict__
size += sum(
_get_size(getattr(obj, s), seen) for s in obj.__slots__ if hasattr(obj, s)
)

return size


LRU_CACHE = LRUCache(
LRU_CACHE: LRUCache = LRUCache(
int(os.environ.get("TEXTACY_MAX_CACHE_SIZE", 2147483648)), getsizeof=_get_size
)
""":class:`cachetools.LRUCache`: Least Recently Used (LRU) cache for loaded data.
"""
Least Recently Used (LRU) cache for loaded data.

The max cache size may be set by the `TEXTACY_MAX_CACHE_SIZE` environment variable,
where the value must be an integer (in bytes). Otherwise, the max size is 2GB.
Expand Down
24 changes: 14 additions & 10 deletions src/textacy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
"""
import pathlib
import re
from typing import Dict, Pattern, Set
from typing import Pattern


DEFAULT_DATA_DIR: pathlib.Path = pathlib.Path(__file__).parent.resolve() / "data"

NUMERIC_ENT_TYPES: Set[str] = {
NUMERIC_ENT_TYPES: set[str] = {
"ORDINAL",
"CARDINAL",
"MONEY",
Expand All @@ -17,11 +17,11 @@
"TIME",
"DATE",
}
SUBJ_DEPS: Set[str] = {"agent", "csubj", "csubjpass", "expl", "nsubj", "nsubjpass"}
OBJ_DEPS: Set[str] = {"attr", "dobj", "dative", "oprd"}
AUX_DEPS: Set[str] = {"aux", "auxpass", "neg"}
SUBJ_DEPS: set[str] = {"agent", "csubj", "csubjpass", "expl", "nsubj", "nsubjpass"}
OBJ_DEPS: set[str] = {"attr", "dobj", "dative", "oprd"}
AUX_DEPS: set[str] = {"aux", "auxpass", "neg"}

REPORTING_VERBS: Dict[str, Set[str]] = {
REPORTING_VERBS: dict[str, set[str]] = {
"en": {
"according",
"accuse",
Expand Down Expand Up @@ -125,7 +125,7 @@
},
}

UD_V2_MORPH_LABELS: Set[str] = {
UD_V2_MORPH_LABELS: set[str] = {
"Abbr",
"Animacy",
"Aspect",
Expand Down Expand Up @@ -158,10 +158,12 @@
Source: https://universaldependencies.org/u/feat/index.html
"""

MATCHER_VALID_OPS: Set[str] = {"!", "+", "?", "*"}
MATCHER_VALID_OPS: set[str] = {"!", "+", "?", "*"}

RE_MATCHER_TOKPAT_DELIM: Pattern = re.compile(r"\s+")
RE_MATCHER_SPECIAL_VAL: Pattern = re.compile(r"^(int|bool)\([^: ]+\)$", flags=re.UNICODE)
RE_MATCHER_SPECIAL_VAL: Pattern = re.compile(
r"^(int|bool)\([^: ]+\)$", flags=re.UNICODE
)

RE_ACRONYM: Pattern = re.compile(
r"(?:^|(?<=\W))"
Expand All @@ -181,7 +183,9 @@
RE_DANGLING_PARENS_TERM: Pattern = re.compile(
r"(?:\s|^)(\()\s{1,2}(.*?)\s{1,2}(\))(?:\s|$)", flags=re.UNICODE
)
RE_LEAD_TAIL_CRUFT_TERM: Pattern = re.compile(r"^[^\w(-]+|[^\w).!?]+$", flags=re.UNICODE)
RE_LEAD_TAIL_CRUFT_TERM: Pattern = re.compile(
r"^[^\w(-]+|[^\w).!?]+$", flags=re.UNICODE
)
RE_LEAD_HYPHEN_TERM: Pattern = re.compile(r"^-([^\W\d_])", flags=re.UNICODE)
RE_NEG_DIGIT_TERM: Pattern = re.compile(r"(-) (\d)", flags=re.UNICODE)
RE_WEIRD_HYPHEN_SPACE_TERM: Pattern = re.compile(
Expand Down
Loading