Merge pull request #79 from BloodAxe/feature/painless_sota
Feature/painless sota
BloodAxe authored Oct 20, 2022
2 parents 0d8c697 + 6abace9 commit 896ab13
Showing 21 changed files with 382 additions and 108 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -33,7 +33,7 @@ jobs:
- name: Install dependencies
run: pip install .[${{ matrix.pytorch-toolbelt-version }}]
- name: Install linters
-run: pip install flake8==3.8.4 flake8-docstrings==1.5.0
+run: pip install flake8==5
- name: Run PyTest
run: pytest
- name: Run Flake8
@@ -48,7 +48,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
-python-version: [3.8]
+python-version: [3.8, 3.9]
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -59,6 +59,6 @@
- name: Update pip
run: python -m pip install --upgrade pip
- name: Install Black
-run: pip install black==22.3.0
+run: pip install black==22.10.0
- name: Run Black
run: black --config=black.toml --check .
4 changes: 2 additions & 2 deletions .github/workflows/upload_to_pypi.yml
@@ -8,9 +8,9 @@ jobs:
upload:
runs-on: ubuntu-latest
steps:
-- uses: actions/checkout@v2
+- uses: actions/checkout
- name: Set up Python
-uses: actions/setup-python@v2
+uses: actions/setup-python
with:
python-version: '3.8'
- name: Install dependencies
2 changes: 1 addition & 1 deletion .gitignore
@@ -20,4 +20,4 @@ var/
.pytest_cache/
/tests/tta_eval.csv
/tests/tmp.onnx
-/tests/test_plot_confusion_matrix.png
\ No newline at end of file
+/tests/test_plot_confusion_matrix.png
51 changes: 26 additions & 25 deletions pytorch_toolbelt/datasets/common.py
@@ -6,30 +6,30 @@
"INPUT_INDEX_KEY",
"OUTPUT_EMBEDDINGS_KEY",
"OUTPUT_LOGITS_KEY",
"OUTPUT_MASK_16_KEY",
"OUTPUT_MASK_2_KEY",
"OUTPUT_MASK_32_KEY",
"OUTPUT_MASK_4_KEY",
"OUTPUT_MASK_64_KEY",
"OUTPUT_MASK_8_KEY",
"OUTPUT_MASK_KEY",
"OUTPUT_MASK_KEY_STRIDE_16",
"OUTPUT_MASK_KEY_STRIDE_2",
"OUTPUT_MASK_KEY_STRIDE_32",
"OUTPUT_MASK_KEY_STRIDE_4",
"OUTPUT_MASK_KEY_STRIDE_64",
"OUTPUT_MASK_KEY_STRIDE_8",
"TARGET_CLASS_KEY",
"TARGET_LABELS_KEY",
"TARGET_MASK_16_KEY",
"TARGET_MASK_2_KEY",
"TARGET_MASK_32_KEY",
"TARGET_MASK_4_KEY",
"TARGET_MASK_64_KEY",
"TARGET_MASK_8_KEY",
"TARGET_MASK_KEY",
"TARGET_MASK_KEY_STRIDE_16",
"TARGET_MASK_KEY_STRIDE_2",
"TARGET_MASK_KEY_STRIDE_32",
"TARGET_MASK_KEY_STRIDE_4",
"TARGET_MASK_KEY_STRIDE_64",
"TARGET_MASK_KEY_STRIDE_8",
"TARGET_MASK_WEIGHT_KEY",
"name_for_stride",
"read_image_rgb",
]


def name_for_stride(name, stride: int):
return f"{name}_{stride}"
return f"{name}_STRIDE_{stride}"


INPUT_INDEX_KEY = "INPUT_INDEX_KEY"
@@ -41,20 +41,21 @@ def name_for_stride(name, stride: int):
TARGET_LABELS_KEY = "TARGET_LABELS_KEY"

TARGET_MASK_KEY = "TARGET_MASK_KEY"
-TARGET_MASK_2_KEY = name_for_stride(TARGET_MASK_KEY, 2)
-TARGET_MASK_4_KEY = name_for_stride(TARGET_MASK_KEY, 4)
-TARGET_MASK_8_KEY = name_for_stride(TARGET_MASK_KEY, 8)
-TARGET_MASK_16_KEY = name_for_stride(TARGET_MASK_KEY, 16)
-TARGET_MASK_32_KEY = name_for_stride(TARGET_MASK_KEY, 32)
-TARGET_MASK_64_KEY = name_for_stride(TARGET_MASK_KEY, 64)
+
+TARGET_MASK_KEY_STRIDE_2 = name_for_stride(TARGET_MASK_KEY, 2)
+TARGET_MASK_KEY_STRIDE_4 = name_for_stride(TARGET_MASK_KEY, 4)
+TARGET_MASK_KEY_STRIDE_8 = name_for_stride(TARGET_MASK_KEY, 8)
+TARGET_MASK_KEY_STRIDE_16 = name_for_stride(TARGET_MASK_KEY, 16)
+TARGET_MASK_KEY_STRIDE_32 = name_for_stride(TARGET_MASK_KEY, 32)
+TARGET_MASK_KEY_STRIDE_64 = name_for_stride(TARGET_MASK_KEY, 64)

OUTPUT_MASK_KEY = "OUTPUT_MASK_KEY"
-OUTPUT_MASK_2_KEY = name_for_stride(OUTPUT_MASK_KEY, 2)
-OUTPUT_MASK_4_KEY = name_for_stride(OUTPUT_MASK_KEY, 4)
-OUTPUT_MASK_8_KEY = name_for_stride(OUTPUT_MASK_KEY, 8)
-OUTPUT_MASK_16_KEY = name_for_stride(OUTPUT_MASK_KEY, 16)
-OUTPUT_MASK_32_KEY = name_for_stride(OUTPUT_MASK_KEY, 32)
-OUTPUT_MASK_64_KEY = name_for_stride(OUTPUT_MASK_KEY, 64)
+OUTPUT_MASK_KEY_STRIDE_2 = name_for_stride(OUTPUT_MASK_KEY, 2)
+OUTPUT_MASK_KEY_STRIDE_4 = name_for_stride(OUTPUT_MASK_KEY, 4)
+OUTPUT_MASK_KEY_STRIDE_8 = name_for_stride(OUTPUT_MASK_KEY, 8)
+OUTPUT_MASK_KEY_STRIDE_16 = name_for_stride(OUTPUT_MASK_KEY, 16)
+OUTPUT_MASK_KEY_STRIDE_32 = name_for_stride(OUTPUT_MASK_KEY, 32)
+OUTPUT_MASK_KEY_STRIDE_64 = name_for_stride(OUTPUT_MASK_KEY, 64)

OUTPUT_LOGITS_KEY = "OUTPUT_LOGITS_KEY"
OUTPUT_EMBEDDINGS_KEY = "OUTPUT_EMBEDDINGS_KEY"
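
The rename above is mechanical: name_for_stride now embeds the word STRIDE, so every OUTPUT_MASK_<n>_KEY / TARGET_MASK_<n>_KEY constant becomes OUTPUT_MASK_KEY_STRIDE_<n> / TARGET_MASK_KEY_STRIDE_<n>. A minimal sketch of looking up multi-scale outputs under the new naming (the outputs dict below is hypothetical, not part of this commit):

import torch

from pytorch_toolbelt.datasets.common import OUTPUT_MASK_KEY, name_for_stride

# The helper now yields "<name>_STRIDE_<stride>" instead of "<name>_<stride>".
assert name_for_stride(OUTPUT_MASK_KEY, 4) == "OUTPUT_MASK_KEY_STRIDE_4"

# Hypothetical multi-scale model output keyed by the new constants.
outputs = {
    name_for_stride(OUTPUT_MASK_KEY, s): torch.zeros(1, 1, 256 // s, 256 // s)
    for s in (2, 4, 8, 16, 32)
}
coarse = outputs[name_for_stride(OUTPUT_MASK_KEY, 32)]  # 8x8 logits for a 256 px input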
11 changes: 6 additions & 5 deletions pytorch_toolbelt/datasets/mean_std.py
@@ -5,31 +5,32 @@


class DatasetMeanStdCalculator:
__slots__ = ["global_mean", "global_var", "n_items", "num_channels", "global_max", "global_min"]
__slots__ = ["global_mean", "global_var", "n_items", "num_channels", "global_max", "global_min", "dtype"]

"""
Class to calculate the running mean and std of a dataset. It helps when the whole dataset does not fit entirely in RAM.
"""

-def __init__(self, num_channels: int = 3):
+def __init__(self, num_channels: int = 3, dtype=np.float64):
"""
Create a new instance of DatasetMeanStdCalculator
Args:
num_channels: Number of channels in the image. Default value is 3
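dtype: Data type for the running accumulators. Default value is np.float64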
"""
-super(DatasetMeanStdCalculator, self).__init__()
+super().__init__()
self.num_channels = num_channels
self.global_mean = None
self.global_var = None
self.global_max = None
self.global_min = None
self.n_items = 0
+self.dtype = dtype
self.reset()

def reset(self):
-self.global_mean = np.zeros(self.num_channels, dtype=np.float64)
-self.global_var = np.zeros(self.num_channels, dtype=np.float64)
+self.global_mean = np.zeros(self.num_channels, dtype=self.dtype)
+self.global_var = np.zeros(self.num_channels, dtype=self.dtype)
self.global_max = np.ones_like(self.global_mean) * float("-inf")
self.global_min = np.ones_like(self.global_mean) * float("+inf")
self.n_items = 0
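
A short usage sketch of the widened constructor. The accumulate()/compute() calls are assumed from the class's existing API (those methods are not shown in this hunk):

import numpy as np

from pytorch_toolbelt.datasets.mean_std import DatasetMeanStdCalculator

# float32 accumulators trade a little numerical accuracy for half the memory;
# np.float64 stays the default, so existing callers are unaffected.
calc = DatasetMeanStdCalculator(num_channels=3, dtype=np.float32)

for _ in range(100):
    image = np.random.rand(256, 256, 3)  # stand-in for images streamed from disk
    calc.accumulate(image)

mean, std = calc.compute()
print(mean, std)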
1 change: 1 addition & 0 deletions pytorch_toolbelt/datasets/providers/__init__.py
@@ -0,0 +1 @@
from .inria_aerial import *
194 changes: 194 additions & 0 deletions pytorch_toolbelt/datasets/providers/inria_aerial.py
@@ -0,0 +1,194 @@
import os
import subprocess
import warnings
from pathlib import Path
from typing import Union, Optional, Tuple
import hashlib

import numpy as np
import pandas as pd
import torch
import zipfile

from sklearn.model_selection import GroupKFold

from pytorch_toolbelt.utils import fs


__all__ = ["InriaAerialImageDataset"]


class InriaAerialImageDataset:
"""
python -m pytorch_toolbelt.datasets.providers.inria_aerial inria_dataset
"""

TASK = "binary_segmentation"
METRIC = ""
ORIGIN = "https://project.inria.fr/aerialimagelabeling"
TRAIN_LOCATIONS = ["austin", "chicago", "kitsap", "tyrol-w", "vienna"]
TEST_LOCATIONS = ["bellingham", "bloomington", "innsbruck", "sfo", "tyrol-e"]

urls = {
"https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.001": "17a7d95c78e484328fd8fe5d5afa2b505e04b8db8fceb617819f3c935d1f39ec",
"https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.002": "b505cb223964b157823e88fbd5b0bd041afcbf39427af3ca1ce981ff9f61aff4",
"https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.003": "752916faa67be6fc6693f8559531598fa2798dc01b7d197263e911718038252e",
"https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.004": "b3893e78f92572455fc2c811af560a558d2a57f9b92eff62fa41399b607a6f44",
"https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.005": "a92eb20fdc9911c5ffe3afc514490b8f1e1e5b22301a6fc55d3b4e1624d8033f",
}

@classmethod
def download_and_extract(cls, data_dir: Union[str, Path]) -> bool:
try:
import py7zr
except ImportError:
print("You need to install py7zr to extract the 7z archive: `pip install py7zr`.")
return False

filenames = []
for file_url, file_hash in cls.urls.items():
file_path = os.path.join(data_dir, os.path.basename(file_url))
if not os.path.isfile(file_path) or cls.sha256digest(file_path) != file_hash:
os.makedirs(data_dir, exist_ok=True)
torch.hub.download_url_to_file(file_url, file_path)

filenames.append(file_path)

main_archive = os.path.join(data_dir, "aerialimagelabeling.7z")
with open(main_archive, "ab") as outfile:  # stitch the .7z.001-.005 volumes into one archive
for fname in filenames:
with open(fname, "rb") as infile:
outfile.write(infile.read())

with py7zr.SevenZipFile(main_archive, "r") as archive:
archive.extractall(data_dir)
os.unlink(main_archive)

zip_archive = os.path.join(data_dir, "NEW2-AerialImageDataset.zip")
with zipfile.ZipFile(zip_archive, "r") as zip_ref:
zip_ref.extractall(data_dir)
os.unlink(zip_archive)
return True

@classmethod
def init_from_folder(cls, data_dir: Union[str, Path], download: bool = False):
data_dir = os.path.expanduser(data_dir)

if download:
if not cls.download_and_extract(data_dir):
raise RuntimeError("Download and extract failed")

return cls(os.path.join(data_dir, "AerialImageDataset"))

@classmethod
def sha256digest(cls, filename: str) -> str:
blocksize = 4096
sha = hashlib.sha256()
with open(filename, "rb") as f:
file_buffer = f.read(blocksize)
while len(file_buffer) > 0:
sha.update(file_buffer)
file_buffer = f.read(blocksize)
readable_hash = sha.hexdigest()
return readable_hash

@classmethod
def read_tiff(
cls, image_fname: str, crop_coords: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None
) -> np.ndarray:
import rasterio
from rasterio.windows import Window

window = None
if crop_coords is not None:
(row_start, row_stop), (col_start, col_stop) = crop_coords
window = Window.from_slices((row_start, row_stop), (col_start, col_stop))

if not os.path.isfile(image_fname):
raise FileNotFoundError(image_fname)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

with rasterio.open(image_fname) as f:
image = f.read(window=window)
image = np.moveaxis(image, 0, -1) # CHW->HWC
if image.shape[2] == 1:
image = np.squeeze(image, axis=2)
return image

@classmethod
def compress_prediction_mask(cls, predicted_mask_fname, compressed_mask_fname):
command = (
"gdal_translate --config GDAL_PAM_ENABLED NO -co COMPRESS=CCITTFAX4 -co NBITS=1 "
+ predicted_mask_fname
+ " "
+ compressed_mask_fname
)
subprocess.call(command, shell=True)

def __init__(self, root_dir: str):
self.root_dir = root_dir
self.train_dir = os.path.join(root_dir, "train")
self.test_dir = os.path.join(root_dir, "test")

if not os.path.isdir(self.train_dir):
raise FileNotFoundError(f"Train directory {self.train_dir} does not exist")
if not os.path.isdir(self.test_dir):
raise FileNotFoundError(f"Test directory {self.train_dir} does not exist")

self.train_images = fs.find_images_in_dir(os.path.join(self.train_dir, "images"))
self.train_masks = fs.find_images_in_dir(os.path.join(self.train_dir, "gt"))

if len(self.train_images) != 180 or len(self.train_masks) != 180:
raise RuntimeError("Number of train images and ground-truth masks must be 180")

def get_test_df(self) -> pd.DataFrame:
test_images = fs.find_images_in_dir(os.path.join(self.test_dir, "images"))
df = pd.DataFrame.from_dict({"images": test_images})
df["rows"] = 5000
df["cols"] = 5000
df["location"] = df["images"].apply(lambda x: fs.id_from_fname(x).rstrip("0123456789"))
return df

def get_train_val_split_train_df(self) -> pd.DataFrame:
# For validation, we remove the first five images of every location
# (e.g., austin{1-5}.tif, chicago{1-5}.tif) from the training set.
# That is the validation strategy suggested by the competition host.
valid_locations = []
for loc in self.TRAIN_LOCATIONS:
for i in range(1, 6):
valid_locations.append(f"{loc}{i}")

df = pd.DataFrame.from_dict({"images": self.train_images, "masks": self.train_masks})
df["location_with_index"] = df["images"].apply(lambda x: fs.id_from_fname(x))
df["location"] = df["location_with_index"].apply(lambda x: x.rstrip("0123456789"))
df["split"] = df["location_with_index"].apply(lambda l: "valid" if l in valid_locations else "train")
df["rows"] = 5000
df["cols"] = 5000
return df

def get_kfold_split_train_df(self, num_folds: int = 5) -> pd.DataFrame:
df = pd.DataFrame.from_dict({"images": self.train_images, "masks": self.train_masks})
df["location_with_index"] = df["images"].apply(lambda x: fs.id_from_fname(x))
df["location"] = df["location_with_index"].apply(lambda x: x.rstrip("0123456789"))
df["rows"] = 5000
df["cols"] = 5000
df["fold"] = -1
kfold = GroupKFold(n_splits=num_folds)
for fold, (train_index, test_index) in enumerate(kfold.split(df, df, groups=df["location"])):
df.loc[test_index, "fold"] = fold
return df


def download_and_extract(data_dir):
ds = InriaAerialImageDataset.init_from_folder(data_dir, download=True)
print(ds.get_test_df())
print(ds.get_train_val_split_train_df())
print(ds.get_kfold_split_train_df())


if __name__ == "__main__":
from fire import Fire

Fire(download_and_extract)
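
For reference, a sketch of driving the new provider from Python rather than the CLI entry point above (the data path is hypothetical, and the first run downloads several large archive volumes):

from pytorch_toolbelt.datasets.providers.inria_aerial import InriaAerialImageDataset

ds = InriaAerialImageDataset.init_from_folder("~/data/inria", download=True)

# Fixed split: austin1-5, chicago1-5, ... go to validation, as suggested by the host.
df = ds.get_train_val_split_train_df()
train_df, valid_df = df[df["split"] == "train"], df[df["split"] == "valid"]

# Alternative: 5 folds grouped by location, so tiles of one city never leak across folds.
folds = ds.get_kfold_split_train_df(num_folds=5)

# Windowed read of a 512x512 corner crop, avoiding the full 5000x5000 tile.
crop = InriaAerialImageDataset.read_tiff(
    train_df["images"].iloc[0], crop_coords=((0, 512), (0, 512))
)
print(crop.shape)  # (512, 512, 3)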