Skip to content

Commit

Permalink
fix annotations for Python >= 3.8 (#5301)
Browse files Browse the repository at this point in the history
* run mypy on Python 3.9

* appease mypy

* Revert "run mypy on Python 3.9"

This reverts commit b935c83.
  • Loading branch information
pmeier authored Jan 28, 2022
1 parent 460e1bd commit 8e874ff
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def write_version_file():
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")

requirements = [
"typing_extensions",
"numpy",
"requests",
pytorch_dep,
Expand Down
4 changes: 3 additions & 1 deletion torchvision/datasets/stl10.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os.path
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, cast

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -65,10 +65,12 @@ def __init__(
self.labels: Optional[np.ndarray]
if self.split == "train":
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.labels = cast(np.ndarray, self.labels)
self.__load_folds(folds)

elif self.split == "train+unlabeled":
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.labels = cast(np.ndarray, self.labels)
self.__load_folds(folds)
unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
self.data = np.concatenate((self.data, unlabeled_data))
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch
from PIL import Image, ImageOps, ImageEnhance
from typing_extensions import Literal

try:
import accimage
Expand Down Expand Up @@ -130,7 +131,7 @@ def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: str = "constant",
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:

if not _is_pil_image(img):
Expand Down Expand Up @@ -189,7 +190,7 @@ def pad(
if img.mode == "P":
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
img = Image.fromarray(img)
img.putpalette(palette)
return img
Expand Down

0 comments on commit 8e874ff

Please sign in to comment.