Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revision models.detection.yolo #851

Merged
merged 32 commits into from
May 20, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
cf1646c
remove under_review decorators
redleaf-kim Aug 2, 2022
fecf88c
add yolo cfg with giou & update related test function
redleaf-kim Aug 2, 2022
b5abc8f
add serveral yolo config & layers function test
redleaf-kim Aug 2, 2022
198ebc1
Merge branch 'Lightning-AI:master' into yolo_review
redleaf-kim Aug 2, 2022
08f17f7
remove unused import & variable
redleaf-kim Aug 3, 2022
db2601a
add type hints
redleaf-kim Aug 9, 2022
fe38bb7
remove and merge duplicated test
redleaf-kim Aug 9, 2022
7da9d4a
improve readability
redleaf-kim Aug 9, 2022
8c3ed4e
Merge remote-tracking branch 'origin/yolo_review' into yolo_review
redleaf-kim Aug 9, 2022
8f69419
Merge branch 'master' into yolo_review
otaj Aug 12, 2022
b25a864
Merge branch 'master' into yolo_review
otaj Sep 15, 2022
9ff86ab
Merge branch 'master' into yolo_review
otaj Sep 16, 2022
17fab64
add catch_warning fixture
Sep 19, 2022
353f119
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2022
a3445ac
fix pytest error; indexing argument will be required to pass in upcom…
redleaf-kim Sep 19, 2022
189346c
fix pytest catch_warnings; MisconfigurationException error
redleaf-kim Sep 19, 2022
a1d97b6
fix pytest error
redleaf-kim Sep 19, 2022
d5b5fb9
Merge remote-tracking branch 'origin/yolo_review' into yolo_review
redleaf-kim Sep 19, 2022
0b4eca4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2022
b52ab5b
Merge branch 'master' into yolo_review
Borda Sep 19, 2022
eb9930e
Fix most obvious CI failings
Sep 19, 2022
fdf38fb
fix test with a missing warning
Sep 19, 2022
a42cdec
resolve accidentally introduced errors
Sep 21, 2022
d534cfa
add catch_warnings
Oct 11, 2022
57c9baf
Merge branch 'master' into yolo_review
Oct 11, 2022
55059f5
Merge branch 'master' into yolo_review
Borda Oct 27, 2022
bf5b360
Apply suggestions from code review
Borda Mar 28, 2023
3ed65fd
Merge branch 'master' into yolo_review
Borda Mar 28, 2023
bd23c27
update mergify team
Borda May 19, 2023
41d2749
Merge branch 'master' into yolo_review
Borda May 19, 2023
0d1c4e7
Merge branch 'master' into yolo_review
Borda May 19, 2023
d758939
Merge branch 'master' into yolo_review
mergify[bot] May 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions pl_bolts/models/detection/yolo/yolo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.models.detection.yolo import yolo_layers
from pl_bolts.utils.stability import under_review


@under_review()
class YOLOConfiguration:
"""This class can be used to parse the configuration files of the Darknet YOLOv4 implementation.

Expand Down Expand Up @@ -149,7 +147,6 @@ def convert(key, value):
return sections


@under_review()
def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]:
"""Calls one of the ``_create_<layertype>(config, num_inputs)`` functions to create a PyTorch module from the
layer config.
Expand All @@ -173,8 +170,7 @@ def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]:
return create_func[config["type"]](config, num_inputs)


@under_review()
def _create_convolutional(config, num_inputs):
def _create_convolutional(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
module = nn.Sequential()

batch_normalize = config.get("batch_normalize", False)
Expand Down Expand Up @@ -210,15 +206,13 @@ def _create_convolutional(config, num_inputs):
return module, config["filters"]


@under_review()
def _create_maxpool(config, num_inputs):
def _create_maxpool(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
padding = (config["size"] - 1) // 2
module = nn.MaxPool2d(config["size"], config["stride"], padding)
return module, num_inputs[-1]


@under_review()
def _create_route(config, num_inputs):
def _create_route(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
num_chunks = config.get("groups", 1)
chunk_idx = config.get("group_id", 0)

Expand All @@ -234,20 +228,17 @@ def _create_route(config, num_inputs):
return module, num_outputs


@under_review()
def _create_shortcut(config, num_inputs):
def _create_shortcut(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
module = yolo_layers.ShortcutLayer(config["from"])
return module, num_inputs[-1]


@under_review()
def _create_upsample(config, num_inputs):
def _create_upsample(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
module = nn.Upsample(scale_factor=config["stride"], mode="nearest")
return module, num_inputs[-1]


@under_review()
def _create_yolo(config, num_inputs):
def _create_yolo(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
# The "anchors" list alternates width and height.
anchor_dims = config["anchors"]
anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) for i in range(0, len(anchor_dims), 2)]
Expand All @@ -264,8 +255,10 @@ def _create_yolo(config, num_inputs):
overlap_loss_func = yolo_layers.SELoss()
elif overlap_loss_name == "giou":
overlap_loss_func = yolo_layers.GIoULoss()
else:
elif overlap_loss_name == "iou":
overlap_loss_func = yolo_layers.IoULoss()
else:
raise ValueError("Unknown overlap loss: " + overlap_loss_name)

module = yolo_layers.DetectionLayer(
num_classes=config["classes"],
Expand Down
25 changes: 9 additions & 16 deletions pl_bolts/models/detection/yolo/yolo_layers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor, nn

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils import _TORCH_MESHGRID_REQUIRES_INDEXING, _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -21,7 +20,6 @@
warn_missing_pkg("torchvision")


@under_review()
def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor:
"""Converts box center points and sizes to corner coordinates.

Expand All @@ -38,7 +36,6 @@ def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor:
return torch.cat((top_left, bottom_right), -1)


@under_review()
def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor:
"""Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at
the same coordinates.
Expand All @@ -61,7 +58,6 @@ def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor:
return inter / union


@under_review()
class SELoss(nn.MSELoss):
def __init__(self):
super().__init__(reduction="none")
Expand All @@ -70,13 +66,11 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
return super().forward(inputs, target).sum(1)


@under_review()
class IoULoss(nn.Module):
def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
return 1.0 - box_iou(inputs, target).diagonal()


@under_review()
class GIoULoss(nn.Module):
def __init__(self) -> None:
super().__init__()
Expand All @@ -89,7 +83,6 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
return 1.0 - generalized_box_iou(inputs, target).diagonal()


@under_review()
class DetectionLayer(nn.Module):
"""A YOLO detection layer.

Expand Down Expand Up @@ -263,7 +256,10 @@ def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor:

x_range = torch.arange(width, device=xy.device)
y_range = torch.arange(height, device=xy.device)
grid_y, grid_x = torch.meshgrid(y_range, x_range)
if _TORCH_MESHGRID_REQUIRES_INDEXING:
grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij")
else:
grid_y, grid_x = torch.meshgrid(y_range, x_range)
offset = torch.stack((grid_x, grid_y), -1) # [height, width, 2]
offset = offset.unsqueeze(2) # [height, width, 1, 2]

Expand Down Expand Up @@ -468,15 +464,13 @@ def _calculate_losses(
return losses, hits


@under_review()
class Mish(nn.Module):
"""Mish activation."""

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return x * torch.tanh(nn.functional.softplus(x))


@under_review()
class RouteLayer(nn.Module):
"""Route layer concatenates the output (or part of it) from given layers."""

Expand All @@ -492,12 +486,11 @@ def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) ->
self.num_chunks = num_chunks
self.chunk_idx = chunk_idx

def forward(self, x, outputs):
def forward(self, x, outputs: List[Union[Tensor, None]]) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with this 'x' here!

chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] for layer in self.source_layers]
return torch.cat(chunks, dim=1)


@under_review()
class ShortcutLayer(nn.Module):
"""Shortcut layer adds a residual connection from the source layer."""

Expand All @@ -510,5 +503,5 @@ def __init__(self, source_layer: int) -> None:
super().__init__()
self.source_layer = source_layer

def forward(self, x, outputs):
def forward(self, x, outputs: List[Union[Tensor, None]]) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could review why we are passing 'x' to the forward method and doing nothing with it. Seems to be just to keep with the format...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luca-medeiros you're right. RouteLayer and ShortcutLayer do not use x (the output from the previous layer) and calling them is anyway handled as a special case, so the x can be dropped.

return outputs[-1] + outputs[self.source_layer]
10 changes: 3 additions & 7 deletions pl_bolts/models/detection/yolo/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -23,7 +22,6 @@
log = logging.getLogger(__name__)


@under_review()
class YOLO(LightningModule):
"""PyTorch Lightning implementation of YOLOv3 and YOLOv4.

Expand Down Expand Up @@ -179,7 +177,7 @@ def forward(
)
for layer_idx, layer_hits in enumerate(hits):
hit_rate = torch.true_divide(layer_hits, total_hits) if total_hits > 0 else 1.0
self.log(f"layer_{layer_idx}_hit_rate", hit_rate, sync_dist=False)
self.log(f"layer_{layer_idx}_hit_rate", hit_rate, sync_dist=False, batch_size=images.size(0))

def total_loss(loss_name):
"""Returns the sum of the loss over detection layers."""
Expand Down Expand Up @@ -233,8 +231,8 @@ def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], b
total_loss = torch.stack(tuple(losses.values())).sum()

for name, value in losses.items():
self.log(f"val/{name}_loss", value, sync_dist=True)
self.log("val/total_loss", total_loss, sync_dist=True)
self.log(f"val/{name}_loss", value, sync_dist=True, batch_size=images.size(0))
self.log("val/total_loss", total_loss, sync_dist=True, batch_size=images.size(0))

def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int):
"""Evaluates a batch of data from the test set.
Expand Down Expand Up @@ -455,7 +453,6 @@ def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Te
return {"boxes": out_boxes, "scores": out_scores, "classprobs": out_classprobs, "labels": out_labels}


@under_review()
class Resize:
"""Rescales the image and target to given dimensions.

Expand Down Expand Up @@ -486,7 +483,6 @@ def __call__(self, image: Tensor, target: Dict[str, Any]):
return image, target


@under_review()
def run_cli():
from argparse import ArgumentParser

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore

_NATIVE_AMP_AVAILABLE: bool = module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")

_TORCHVISION_AVAILABLE: bool = module_available("torchvision")
_GYM_AVAILABLE: bool = module_available("gym")
_SKLEARN_AVAILABLE: bool = module_available("sklearn")
Expand All @@ -20,6 +19,7 @@
_PL_GREATER_EQUAL_1_4_5 = compare_version("pytorch_lightning", operator.ge, "1.4.5")
_TORCH_ORT_AVAILABLE = module_available("torch_ort")
_TORCH_MAX_VERSION_SPARSEML = compare_version("torch", operator.lt, "1.11.0")
_TORCH_MESHGRID_REQUIRES_INDEXING = compare_version("torch", operator.ge, "1.10.0")
_SPARSEML_AVAILABLE = module_available("sparseml") and _PL_GREATER_EQUAL_1_4_5 and _TORCH_MAX_VERSION_SPARSEML

__all__ = ["BatchGradientVerification"]
81 changes: 81 additions & 0 deletions tests/data/yolo_giou.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
[net]
width=256
height=256
channels=3

[convolutional]
batch_normalize=1
filters=8
size=3
stride=1
pad=1
activation=leaky

[route]
layers=-1
groups=2
group_id=1

[maxpool]
size=2
stride=2

[convolutional]
batch_normalize=1
filters=2
size=1
stride=1
pad=1
activation=mish

[convolutional]
batch_normalize=1
filters=4
size=3
stride=1
pad=1
activation=mish

[shortcut]
from=-3
activation=linear

[convolutional]
size=1
stride=1
pad=1
filters=14
activation=linear

[yolo]
mask=2,3
anchors=1,2, 3,4, 5,6, 9,10
classes=2
iou_loss=giou
scale_x_y=1.05
cls_normalizer=1.0
iou_normalizer=0.07
ignore_thresh=0.7

[route]
layers = -4

[upsample]
stride=2

[convolutional]
size=1
stride=1
pad=1
filters=14
activation=linear

[yolo]
mask=0,1
anchors=1,2, 3,4, 5,6, 9,10
classes=2
iou_loss=giou
scale_x_y=1.05
cls_normalizer=1.0
iou_normalizer=0.07
ignore_thresh=0.7
Loading