Commit 439e3d1

Merge pull request #89 from BloodAxe/develop

0.6.2

BloodAxe authored Dec 25, 2022
2 parents b5abf25 + 4fbc91e

Showing 15 changed files with 471 additions and 43 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -8,7 +8,7 @@ jobs:
     strategy:
       matrix:
         operating-system: [ubuntu-latest, windows-latest, macos-latest]
-        python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
+        python-version: ['3.7', '3.8', '3.9', '3.10']
         pytorch-toolbelt-version: [tests]
       fail-fast: false
     steps:
@@ -40,7 +40,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
+        python-version: ['3.7', '3.8', '3.9', '3.10']
     steps:
       - name: Checkout
         uses: actions/checkout@v2
@@ -53,4 +53,4 @@ jobs:
       - name: Install Black
         run: pip install black==22.8.0
       - name: Run Black
-        run: black --config=black.toml --check .
+        run: black --config=pyproject.toml --check .
File renamed without changes.
2 changes: 1 addition & 1 deletion pytorch_toolbelt/__init__.py
@@ -1,3 +1,3 @@
 from __future__ import absolute_import

-__version__ = "0.6.1"
+__version__ = "0.6.2"
39 changes: 38 additions & 1 deletion pytorch_toolbelt/inference/functional.py
@@ -9,7 +9,9 @@
 __all__ = [
     "geometric_mean",
     "harmonic_mean",
+    "harmonic1p_mean",
     "logodd_mean",
+    "log1p_mean",
     "pad_image_tensor",
     "torch_fliplr",
     "torch_flipud",
@@ -229,7 +231,7 @@ def geometric_mean(x: Tensor, dim: int) -> Tensor:
 def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
     """
     Compute harmonic mean along given dimension.
     This implementation assumes values are in the range (0, 1) (probabilities).
     Args:
         x: Input tensor of arbitrary shape
         dim: Dimension to reduce
@@ -243,6 +245,23 @@ def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
     return x


+def harmonic1p_mean(x: Tensor, dim: int) -> Tensor:
+    """
+    Compute the harmonic mean of (x + 1) along given dimension, then subtract 1.
+    Args:
+        x: Input tensor of arbitrary shape
+        dim: Dimension to reduce
+    Returns:
+        Tensor
+    """
+    x = torch.reciprocal(x + 1)
+    x = torch.mean(x, dim=dim)
+    x = torch.reciprocal(x) - 1
+    return x
+
+
 def logodd_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
     """
     Compute log-odd mean along given dimension.
@@ -261,3 +280,21 @@ def logodd_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
     x = torch.mean(x, dim=dim)
     x = torch.exp(x) / (1 + torch.exp(x))
     return x
+
+
+def log1p_mean(x: Tensor, dim: int) -> Tensor:
+    """
+    Compute the average of log(x + 1) along given dimension, then apply exp and subtract 1.
+    Requires all inputs to be non-negative.
+    Args:
+        x: Input tensor of arbitrary shape
+        dim: Dimension to reduce
+    Returns:
+        Tensor
+    """
+    x = torch.log1p(x)
+    x = torch.mean(x, dim=dim)
+    x = torch.exp(x) - 1
+    return x
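
The two new reductions extend the existing family of TTA averaging functions. A minimal usage sketch, with made-up prediction values (not part of this commit), assuming only torch and this module are available:

    import torch
    from pytorch_toolbelt.inference import functional as F

    # Three hypothetical TTA predictions for the same two outputs, stacked along dim 0
    preds = torch.tensor([[0.9, 0.1], [0.8, 0.2], [0.7, 0.3]])

    # harmonic1p_mean: harmonic mean of (x + 1), minus 1; the +1 shift tolerates zeros
    print(F.harmonic1p_mean(preds, dim=0))

    # log1p_mean: exp(mean(log(x + 1))) - 1; valid for any non-negative input
    print(F.log1p_mean(preds, dim=0))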
4 changes: 4 additions & 0 deletions pytorch_toolbelt/inference/tta.py
@@ -79,8 +79,12 @@ def _deaugment_averaging(x: Tensor, reduction: MaybeStrOrCallable) -> Tensor:
         x = F.geometric_mean(x, dim=0)
     elif reduction in {"hmean", "harmonic_mean"}:
         x = F.harmonic_mean(x, dim=0)
+    elif reduction in {"harmonic1p"}:
+        x = F.harmonic1p_mean(x, dim=0)
     elif reduction == "logodd":
         x = F.logodd_mean(x, dim=0)
+    elif reduction == "log1p":
+        x = F.log1p_mean(x, dim=0)
     elif callable(reduction):
         x = reduction(x, dim=0)
     elif reduction in {None, "None", "none"}:
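
These strings become valid values for the reduction argument used across the TTA helpers. A short sketch calling the dispatch helper directly; note that _deaugment_averaging is a private function, used here only to illustrate the mapping:

    import torch
    from pytorch_toolbelt.inference.tta import _deaugment_averaging

    # Stack of 4 hypothetical augmented predictions: (num_augmentations, batch, classes)
    x = torch.rand(4, 2, 3)

    out_log1p = _deaugment_averaging(x, reduction="log1p")
    out_h1p = _deaugment_averaging(x, reduction="harmonic1p")
    print(out_log1p.shape, out_h1p.shape)  # dim 0 is reduced: torch.Size([2, 3]) twice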
1 change: 1 addition & 0 deletions pytorch_toolbelt/losses/__init__.py
@@ -13,3 +13,4 @@
 from .soft_ce import *
 from .soft_f1 import *
 from .wing_loss import *
+from .logcosh import *
20 changes: 20 additions & 0 deletions pytorch_toolbelt/losses/functional.py
@@ -12,6 +12,7 @@
     "soft_jaccard_score",
     "soft_dice_score",
     "wing_loss",
+    "log_cosh_loss",
 ]


@@ -298,3 +299,22 @@ def label_smoothed_nll_loss(
     eps_i = epsilon / lprobs.size(dim)
     loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
     return loss
+
+
+def log_cosh_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+    """
+    Numerically stable log-cosh loss implementation.
+    Reference: https://datascience.stackexchange.com/questions/96271/logcoshloss-on-pytorch
+    Args:
+        y_pred: Predicted values of arbitrary shape
+        y_true: Target values of the same shape as y_pred
+    Returns:
+        Scalar tensor with the mean log-cosh of the residuals
+    """
+
+    def _log_cosh(x: torch.Tensor) -> torch.Tensor:
+        return x + torch.nn.functional.softplus(-2.0 * x) - math.log(2.0)
+
+    return torch.mean(_log_cosh(y_pred - y_true))
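
The helper relies on the identity log(cosh(x)) = x + softplus(-2x) - log(2), which sidesteps the overflow of cosh for large residuals. A small sanity-check sketch with illustrative values:

    import math
    import torch

    x = torch.tensor([0.5, 100.0])

    naive = torch.log(torch.cosh(x))  # cosh(100.0) overflows float32 -> inf
    stable = x + torch.nn.functional.softplus(-2.0 * x) - math.log(2.0)

    print(naive)   # tensor([0.1201, inf])
    print(stable)  # tensor([0.1201, 99.3069]); log(cosh(100)) ~= 100 - log(2)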
13 changes: 13 additions & 0 deletions pytorch_toolbelt/losses/logcosh.py
@@ -0,0 +1,13 @@
+import torch
+from pytorch_toolbelt.losses.functional import log_cosh_loss
+from torch import nn
+
+__all__ = ["LogCoshLoss"]
+
+
+class LogCoshLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        return log_cosh_loss(y_pred, y_true)
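
Since LogCoshLoss is a thin nn.Module wrapper, it drops into a training step like any other criterion. A minimal regression sketch; the model, data, and optimizer settings are made up for illustration:

    import torch
    from torch import nn
    from pytorch_toolbelt.losses import LogCoshLoss

    model = nn.Linear(10, 1)  # toy regression model
    criterion = LogCoshLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    x, y = torch.randn(16, 10), torch.randn(16, 1)
    loss = criterion(model(x), y)  # roughly L2 near zero, L1 for large residuals
    loss.backward()
    optimizer.step()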
7 changes: 4 additions & 3 deletions pytorch_toolbelt/modules/encoders/timm/common.py
@@ -11,14 +11,16 @@


 class GenericTimmEncoder(EncoderModule):
-    def __init__(self, timm_encoder: Union[nn.Module, str], layers: List[int] = None, pretrained=True):
+    def __init__(self, timm_encoder: Union[nn.Module, str], layers: List[int] = None, pretrained=True, **kwargs):
         strides = []
         channels = []
         default_layers = []
         if isinstance(timm_encoder, str):
             import timm.models.factory

-            timm_encoder = timm.models.factory.create_model(timm_encoder, features_only=True, pretrained=pretrained)
+            timm_encoder = timm.models.factory.create_model(
+                timm_encoder, features_only=True, pretrained=pretrained, **kwargs
+            )

         for i, fi in enumerate(timm_encoder.feature_info):
             strides.append(fi["reduction"])
@@ -61,7 +63,6 @@ def make_n_channel_input_std_conv(conv: nn.Module, in_channels: int, mode="auto"
             dilation=kwargs.get("dilation", conv.dilation),
             groups=kwargs.get("groups", conv.groups),
             bias=kwargs.get("bias", conv.bias is not None),
-            eps=kwargs.get("eps", conv.eps),
         )

         w = conv.weight
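
With the **kwargs passthrough, any extra keyword accepted by timm's create_model can now be forwarded from the encoder constructor. A hedged sketch, assuming timm is installed, that GenericTimmEncoder is re-exported from pytorch_toolbelt.modules.encoders.timm, and that the chosen architecture accepts drop_path_rate (as many timm models do):

    from pytorch_toolbelt.modules.encoders.timm import GenericTimmEncoder

    # drop_path_rate is a hypothetical example kwarg, forwarded verbatim to timm's create_model
    encoder = GenericTimmEncoder("resnet50", pretrained=False, drop_path_rate=0.1)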
21 changes: 14 additions & 7 deletions pytorch_toolbelt/modules/upsample.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Optional, List

 import torch
@@ -97,20 +98,26 @@ class BilinearAdditiveUpsample2d(nn.Module):
     https://arxiv.org/abs/1707.05847
     """

-    def __init__(self, in_channels: int, scale_factor: int = 2, n: int = 4):
+    def __init__(self, in_channels: int, scale_factor: int = 2, n=None):
         super().__init__()
-        if in_channels % n != 0:
-            raise ValueError(f"Number of input channels ({in_channels})must be divisable by n ({n})")
+        if n is not None:
+            warnings.warn(
+                "Argument n has been deprecated and will be removed in a future release. "
+                "It is computed automatically and does not need to be specified explicitly."
+            )
+
+        self.n = 2**scale_factor
+
+        if in_channels % self.n != 0:
+            raise ValueError(f"Number of input channels ({in_channels}) must be divisible by n ({self.n})")

         self.in_channels = in_channels
-        self.out_channels = in_channels // n
+        self.out_channels = in_channels // self.n
         self.upsample = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
-        self.n = n

     def forward(self, x: Tensor) -> Tensor:  # skipcq: PYL-W0221
         x = self.upsample(x)
         n, c, h, w = x.size()
-        x = x.reshape(n, c // self.n, self.n, h, w).mean(2)
+        x = x.reshape(n, self.out_channels, self.n, h, w).mean(2)
         return x


@@ -135,7 +142,7 @@ def __init__(self, in_channels, scale_factor=2, n=4):
         self.conv = nn.ConvTranspose2d(
             in_channels, in_channels // n, kernel_size=3, padding=1, stride=scale_factor, output_padding=1
         )
-        self.residual = BilinearAdditiveUpsample2d(in_channels, scale_factor=scale_factor, n=n)
+        self.residual = BilinearAdditiveUpsample2d(in_channels, scale_factor=scale_factor)
         self.init_weights()

     def forward(self, x: Tensor) -> Tensor:  # skipcq: PYL-W0221
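
With n now derived as 2**scale_factor, the output channel count follows directly from the scale factor. A shape-check sketch of the new behavior (tensor sizes are illustrative):

    import torch
    from pytorch_toolbelt.modules.upsample import BilinearAdditiveUpsample2d

    up = BilinearAdditiveUpsample2d(in_channels=16, scale_factor=2)  # n = 2**2 = 4
    y = up(torch.randn(1, 16, 32, 32))
    print(y.shape)  # torch.Size([1, 4, 64, 64]): channels divided by 4, spatial x2

    # Passing n explicitly still works but now triggers the deprecation warning
    BilinearAdditiveUpsample2d(in_channels=16, scale_factor=2, n=4)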