Merge pull request #33 from BloodAxe/feature/0.3.0
Feature/0.3.0
BloodAxe authored Jan 17, 2020
2 parents 71a8c2e + 80e1063 commit f7b83ef
Showing 86 changed files with 4,474 additions and 2,499 deletions.
2 changes: 0 additions & 2 deletions .appveyor.yml
@@ -6,8 +6,6 @@ cache:
environment:

matrix:
- PYTHON: 'C:\Python27-x64'
- PYTHON: 'C:\Python35-x64'
- PYTHON: 'C:\Python36-x64'
- PYTHON: 'C:\Python37-x64'

7 changes: 7 additions & 0 deletions CREDITS.md
@@ -0,0 +1,7 @@
This file contains links to repositories whose source code may be partially used in this repository. Mind giving them kudos on GitHub!

1. https://github.com/Cadene/pretrained-models.pytorch
1. https://blog.ceshine.net/post/pytorch-memory-swish/
1. https://github.com/digantamisra98/Mish
1. https://github.com/mapillary/inplace_abn
1. https://github.com/PkuRainBow/OCNet.pytorch
25 changes: 25 additions & 0 deletions black.toml
@@ -0,0 +1,25 @@
# Example configuration for Black.

# NOTE: you have to use single-quoted strings in TOML for regular expressions.
# It's the equivalent of r-strings in Python. Multiline strings are treated as
# verbose regular expressions by Black. Use [ ] to denote a significant space
# character.

[tool.black]
line-length = 119
target-version = ['py36', 'py37', 'py38']
include = '\.pyi?$'
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''
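Note: Black auto-discovers its configuration only from pyproject.toml, so a separately named file such as this one has to be passed explicitly, e.g. black --config black.toml . (assuming Black is installed and invoked from the repository root).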
6 changes: 1 addition & 5 deletions demo/demo_losses.py
@@ -16,18 +16,14 @@ def main():
# "dice_log": L.BinaryDiceLogLoss(),
# "sdice": L.BinarySymmetricDiceLoss(),
# "sdice_log": L.BinarySymmetricDiceLoss(log_loss=True),

"bce+lovasz": L.JointLoss(BCEWithLogitsLoss(), L.BinaryLovaszLoss()),
# "lovasz": L.BinaryLovaszLoss(),
# "bce+jaccard": L.JointLoss(BCEWithLogitsLoss(),
# L.BinaryJaccardLoss(), 1, 0.5),

# "bce+log_jaccard": L.JointLoss(BCEWithLogitsLoss(),
# L.BinaryJaccardLogLoss(), 1, 0.5),

# "bce+log_dice": L.JointLoss(BCEWithLogitsLoss(),
# L.BinaryDiceLogLoss(), 1, 0.5)

# "reduced_focal": L.BinaryFocalLoss(reduced=True)
}

@@ -55,5 +51,5 @@ def main():
f.show()


if __name__ == '__main__':
if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion pytorch_toolbelt/__init__.py
@@ -1,3 +1,3 @@
from __future__ import absolute_import

__version__ = "0.2.2"
__version__ = "0.3.0"
10 changes: 2 additions & 8 deletions pytorch_toolbelt/inference/functional.py
@@ -62,11 +62,7 @@ def pad_image_tensor(image_tensor: Tensor, pad_size: int = 32):
:return: Tuple of the output tensor and pad params. The second element can be used to reverse the pad operation on model output
"""
rows, cols = image_tensor.size(2), image_tensor.size(3)
if (
isinstance(pad_size, Sized)
and isinstance(pad_size, Iterable)
and len(pad_size) == 2
):
if isinstance(pad_size, Sized) and isinstance(pad_size, Iterable) and len(pad_size) == 2:
pad_height, pad_width = [int(val) for val in pad_size]
elif isinstance(pad_size, int):
pad_height = pad_width = pad_size
@@ -109,9 +105,7 @@ def unpad_image_tensor(image_tensor, pad):

def unpad_xyxy_bboxes(bboxes_tensor: torch.Tensor, pad, dim=-1):
pad_left, pad_right, pad_top, pad_btm = pad
pad = torch.tensor(
[pad_left, pad_top, pad_left, pad_top], dtype=bboxes_tensor.dtype
).to(bboxes_tensor.device)
pad = torch.tensor([pad_left, pad_top, pad_left, pad_top], dtype=bboxes_tensor.dtype).to(bboxes_tensor.device)

if dim == -1:
dim = len(bboxes_tensor.size()) - 1
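For context, a minimal round-trip sketch of the two padding helpers above; the model here is a hypothetical stand-in that preserves spatial resolution:

import torch
from torch import nn
from pytorch_toolbelt.inference.functional import pad_image_tensor, unpad_image_tensor

model = nn.Conv2d(3, 1, kernel_size=3, padding=1)  # stand-in; keeps H and W unchanged

x = torch.randn(1, 3, 500, 333)                    # H and W are not multiples of 32
x_padded, pad = pad_image_tensor(x, pad_size=32)   # pads H and W up to multiples of 32
y = unpad_image_tensor(model(x_padded), pad)       # the pad params reverse the padding
assert y.shape[2:] == x.shape[2:]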
80 changes: 24 additions & 56 deletions pytorch_toolbelt/inference/tiles.py
@@ -1,11 +1,11 @@
"""Implementation of tile-based inference allowing to predict huge images that does not fit into GPU memory entirely
in a sliding-window fashion and merging prediction mask back to full-resolution.
"""
import math
from typing import List

import numpy as np
import cv2
import math
import numpy as np
import torch


@@ -28,14 +28,18 @@ def compute_pyramid_patch_weight_loss(width, height) -> np.ndarray:
Dc = np.zeros((width, height))
De = np.zeros((width, height))

for i in range(width):
for j in range(height):
Dc[i, j] = np.sqrt(np.square(i - xc + 0.5) + np.square(j - yc + 0.5))
De_l = np.sqrt(np.square(i - xl + 0.5) + np.square(j - j + 0.5))
De_r = np.sqrt(np.square(i - xr + 0.5) + np.square(j - j + 0.5))
De_b = np.sqrt(np.square(i - i + 0.5) + np.square(j - yb + 0.5))
De_t = np.sqrt(np.square(i - i + 0.5) + np.square(j - yt + 0.5))
De[i, j] = np.min([De_l, De_r, De_b, De_t])
Dcx = np.square(np.arange(width) - xc + 0.5)
Dcy = np.square(np.arange(height) - yc + 0.5)
Dc = np.sqrt(Dcx[np.newaxis].transpose() + Dcy)

De_l = np.square(np.arange(width) - xl + 0.5) + np.square(0.5)
De_r = np.square(np.arange(width) - xr + 0.5) + np.square(0.5)
De_b = np.square(0.5) + np.square(np.arange(height) - yb + 0.5)
De_t = np.square(0.5) + np.square(np.arange(height) - yt + 0.5)

De_x = np.sqrt(np.minimum(De_l, De_r))
De_y = np.sqrt(np.minimum(De_b, De_t))
De = np.minimum(De_x[np.newaxis].transpose(), De_y)

alpha = (width * height) / np.sum(np.divide(De, np.add(Dc, De)))
W = alpha * np.divide(De, np.add(Dc, De))
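By construction, alpha normalizes the weight map so that its mean is exactly 1. A standalone sanity check of that property, reusing the formulas above (the definitions of xc, yc, xl, xr, yb, yt sit outside the visible hunk, so the values below are assumptions):

import numpy as np

width = height = 256
xc, yc = width * 0.5, height * 0.5    # assumed: patch center
xl, xr, yb, yt = 0, width, 0, height  # assumed: patch borders

Dcx = np.square(np.arange(width) - xc + 0.5)
Dcy = np.square(np.arange(height) - yc + 0.5)
Dc = np.sqrt(Dcx[np.newaxis].transpose() + Dcy)      # distance to patch center

De_l = np.square(np.arange(width) - xl + 0.5) + np.square(0.5)
De_r = np.square(np.arange(width) - xr + 0.5) + np.square(0.5)
De_b = np.square(0.5) + np.square(np.arange(height) - yb + 0.5)
De_t = np.square(0.5) + np.square(np.arange(height) - yt + 0.5)
De_x = np.sqrt(np.minimum(De_l, De_r))
De_y = np.sqrt(np.minimum(De_b, De_t))
De = np.minimum(De_x[np.newaxis].transpose(), De_y)  # distance to nearest border

W = (width * height) / np.sum(De / (Dc + De)) * (De / (Dc + De))
assert np.isclose(W.mean(), 1.0)                     # alpha makes sum(W) == width * height
assert W[width // 2, height // 2] > W[0, 0]          # center outweighs corners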
@@ -47,9 +51,7 @@ class ImageSlicer:
Helper class to slice image into tiles and merge them back
"""

def __init__(
self, image_shape, tile_size, tile_step=0, image_margin=0, weight="mean"
):
def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight="mean"):
"""
:param image_shape: Shape of the source image (H, W)
@@ -75,21 +77,14 @@ def __init__(

weights = {"mean": self._mean, "pyramid": self._pyramid}

self.weight = (
weight
if isinstance(weight, np.ndarray)
else weights[weight](self.tile_size)
)
self.weight = weight if isinstance(weight, np.ndarray) else weights[weight](self.tile_size)

if self.tile_step[0] < 1 or self.tile_step[0] > self.tile_size[0]:
raise ValueError()
if self.tile_step[1] < 1 or self.tile_step[1] > self.tile_size[1]:
raise ValueError()

overlap = [
self.tile_size[0] - self.tile_step[0],
self.tile_size[1] - self.tile_step[1],
]
overlap = [self.tile_size[0] - self.tile_step[0], self.tile_size[1] - self.tile_step[1]]

self.margin_left = 0
self.margin_right = 0
@@ -111,14 +106,10 @@ def __init__(
self.margin_bottom = extra_h - self.margin_top

else:
if (self.image_width - overlap[1] + 2 * image_margin) % self.tile_step[
1
] != 0:
if (self.image_width - overlap[1] + 2 * image_margin) % self.tile_step[1] != 0:
raise ValueError()

if (self.image_height - overlap[0] + 2 * image_margin) % self.tile_step[
0
] != 0:
if (self.image_height - overlap[0] + 2 * image_margin) % self.tile_step[0] != 0:
raise ValueError()

self.margin_left = image_margin
@@ -130,32 +121,13 @@ def __init__(
bbox_crops = []

for y in range(
0,
self.image_height
+ self.margin_top
+ self.margin_bottom
- self.tile_size[0]
+ 1,
self.tile_step[0],
0, self.image_height + self.margin_top + self.margin_bottom - self.tile_size[0] + 1, self.tile_step[0]
):
for x in range(
0,
self.image_width
+ self.margin_left
+ self.margin_right
- self.tile_size[1]
+ 1,
self.tile_step[1],
0, self.image_width + self.margin_left + self.margin_right - self.tile_size[1] + 1, self.tile_step[1]
):
crops.append((x, y, self.tile_size[1], self.tile_size[0]))
bbox_crops.append(
(
x - self.margin_left,
y - self.margin_top,
self.tile_size[1],
self.tile_size[0],
)
)
bbox_crops.append((x - self.margin_left, y - self.margin_top, self.tile_size[1], self.tile_size[0]))

self.crops = np.array(crops)
self.bbox_crops = np.array(bbox_crops)
@@ -189,9 +161,7 @@ def split(self, image, border_type=cv2.BORDER_CONSTANT, value=0):

return tiles

def cut_patch(
self, image: np.ndarray, slice_index, border_type=cv2.BORDER_CONSTANT, value=0
):
def cut_patch(self, image: np.ndarray, slice_index, border_type=cv2.BORDER_CONSTANT, value=0):
assert image.shape[0] == self.image_height
assert image.shape[1] == self.image_width

@@ -298,9 +268,7 @@ def integrate_batch(self, batch: torch.Tensor, crop_coords):
:param crop_coords: Corresponding tile crops w.r.t. the original image
"""
if len(batch) != len(crop_coords):
raise ValueError(
"Number of images in batch does not correspond to number of coordinates"
)
raise ValueError("Number of images in batch does not correspond to number of coordinates")

for tile, (x, y, tile_width, tile_height) in zip(batch, crop_coords):
self.image[:, y : y + tile_height, x : x + tile_width] += tile * self.weight
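A minimal usage sketch for the slicing API above. The image and tile sizes are illustrative, and only the slicer is exercised, since the merger class whose integrate_batch method appears in this diff is not shown in full:

import numpy as np
from pytorch_toolbelt.inference.tiles import ImageSlicer

image = np.zeros((1024, 1536, 3), dtype=np.uint8)  # hypothetical H x W x C input
slicer = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid")

tiles = slicer.split(image)             # list of (512, 512, 3) patches
x, y, tile_w, tile_h = slicer.crops[0]  # crop coordinates w.r.t. the padded image
assert tiles[0].shape == (512, 512, 3)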
38 changes: 10 additions & 28 deletions pytorch_toolbelt/inference/tta.py
@@ -71,21 +71,11 @@ def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
center_crop_y = (image_height - crop_height) // 2
center_crop_x = (image_width - crop_width) // 2

crop_cc = image[
...,
center_crop_y : center_crop_y + crop_height,
center_crop_x : center_crop_x + crop_width,
]
crop_cc = image[..., center_crop_y : center_crop_y + crop_height, center_crop_x : center_crop_x + crop_width]
assert crop_cc.size(2) == crop_height
assert crop_cc.size(3) == crop_width

output = (
model(crop_tl)
+ model(crop_tr)
+ model(crop_bl)
+ model(crop_br)
+ model(crop_cc)
)
output = model(crop_tl) + model(crop_tr) + model(crop_bl) + model(crop_br) + model(crop_cc)
one_over_5 = float(1.0 / 5.0)
return output * one_over_5

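A usage sketch for the five-crop TTA helper above; the classifier is a hypothetical stand-in that maps an image batch to class logits:

import torch
from torch import nn
from pytorch_toolbelt.inference.tta import fivecrop_image2label

model = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 10))  # stand-in classifier
image = torch.randn(2, 3, 96, 96)
logits = fivecrop_image2label(model.eval(), image, crop_size=(64, 64))  # mean over 5 crop predictions
assert logits.shape == (2, 10)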
@@ -125,11 +115,7 @@ def tencrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
center_crop_y = (image_height - crop_height) // 2
center_crop_x = (image_width - crop_width) // 2

crop_cc = image[
...,
center_crop_y : center_crop_y + crop_height,
center_crop_x : center_crop_x + crop_width,
]
crop_cc = image[..., center_crop_y : center_crop_y + crop_height, center_crop_x : center_crop_x + crop_width]
assert crop_cc.size(2) == crop_height
assert crop_cc.size(3) == crop_width

@@ -202,11 +188,10 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
output = model(image)

for aug, deaug in zip(
[F.torch_rot90, F.torch_rot180, F.torch_rot270],
[F.torch_rot270, F.torch_rot180, F.torch_rot90],
[F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_rot270, F.torch_rot180, F.torch_rot90]
):
x = deaug(model(aug(image)))
output = output + x
output += x

image = F.torch_transpose(image)

@@ -215,10 +200,11 @@
[F.torch_none, F.torch_rot270, F.torch_rot180, F.torch_rot90],
):
x = deaug(model(aug(image)))
output = output + F.torch_transpose(x)
output += F.torch_transpose(x)

one_over_8 = float(1.0 / 8.0)
return output * one_over_8
output *= one_over_8
return output
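A corresponding sketch for the D4 TTA above. The transforms include 90-degree rotations and transposition, so a square input and a model that preserves spatial size are assumed:

import torch
from torch import nn
from pytorch_toolbelt.inference.tta import d4_image2mask

model = nn.Conv2d(3, 1, kernel_size=3, padding=1)  # stand-in segmentation head
image = torch.randn(2, 3, 64, 64)                  # square, so rotations keep the shape
mask = d4_image2mask(model.eval(), image)          # average over the 8 dihedral transforms
assert mask.shape == (2, 1, 64, 64)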


class TTAWrapper(nn.Module):
@@ -258,13 +244,9 @@ def forward(self, input: Tensor) -> Tensor:

for scale in self.scale_levels:
dst_size = int(h * scale), int(w * scale)
input_scaled = interpolate(
input, dst_size, mode="bilinear", align_corners=False
)
input_scaled = interpolate(input, dst_size, mode="bilinear", align_corners=False)
output_scaled = self.model(input_scaled)
output_scaled = interpolate(
output_scaled, out_size, mode="bilinear", align_corners=False
)
output_scaled = interpolate(output_scaled, out_size, mode="bilinear", align_corners=False)
output += output_scaled

return output / (1 + len(self.scale_levels))
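The same multi-scale averaging, extracted into a standalone function for illustration; the function name is mine, and the body mirrors the forward pass above:

import torch
from torch import nn
from torch.nn.functional import interpolate

def multiscale_predict(model: nn.Module, x: torch.Tensor, scale_levels=(0.5, 2.0)) -> torch.Tensor:
    output = model(x)                          # prediction at the original scale
    h, w = x.size(2), x.size(3)
    out_size = output.size(2), output.size(3)
    for scale in scale_levels:
        dst_size = int(h * scale), int(w * scale)
        scaled = interpolate(x, dst_size, mode="bilinear", align_corners=False)
        output += interpolate(model(scaled), out_size, mode="bilinear", align_corners=False)
    return output / (1 + len(scale_levels))   # average over all scales

# e.g. multiscale_predict(nn.Conv2d(3, 1, 3, padding=1), torch.randn(1, 3, 64, 64))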
6 changes: 4 additions & 2 deletions pytorch_toolbelt/losses/__init__.py
@@ -1,8 +1,10 @@
from __future__ import absolute_import

from .dice import *
from .focal import *
from .jaccard import *
from .dice import *
from .lovasz import *
from .joint_loss import *
from .lovasz import *
from .soft_bce import *
from .soft_ce import *
from .wing_loss import *
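A hedged usage sketch of the re-exported losses; the positional weights follow the JointLoss(first, second, first_weight, second_weight) pattern used in demo_losses.py above, and the target shape convention is assumed to match the logits:

import torch
from torch.nn import BCEWithLogitsLoss
import pytorch_toolbelt.losses as L

criterion = L.JointLoss(BCEWithLogitsLoss(), L.BinaryFocalLoss(), 1.0, 0.5)
logits = torch.randn(4, 1, 64, 64, requires_grad=True)
targets = (torch.rand(4, 1, 64, 64) > 0.5).float()  # assumed: same shape as logits
loss = criterion(logits, targets)
loss.backward()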
(Diff for the remaining changed files not shown.)