diff --git a/pytorchyolo/models.py b/pytorchyolo/models.py index 1eca59f637..092597ffb3 100644 --- a/pytorchyolo/models.py +++ b/pytorchyolo/models.py @@ -114,6 +114,7 @@ def forward(self, x): x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) return x + class Mish(nn.Module): """ The MISH activation function (https://github.com/digantamisra98/Mish) """ @@ -123,6 +124,7 @@ def __init__(self): def forward(self, x): return x * torch.tanh(F.softplus(x)) + class YOLOLayer(nn.Module): """Detection layer""" @@ -152,7 +154,7 @@ def forward(self, x, img_size): self.grid = self._make_grid(nx, ny).to(x.device) x[..., 0:2] = (x[..., 0:2].sigmoid() + self.grid) * stride # xy - x[..., 2:4] = torch.exp(x[..., 2:4]) * self.anchor_grid # wh + x[..., 2:4] = torch.exp(x[..., 2:4]) * self.anchor_grid # wh x[..., 4:] = x[..., 4:].sigmoid() x = x.view(bs, -1, self.no) @@ -186,7 +188,7 @@ def forward(self, x): combined_outputs = torch.cat([layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1) group_size = combined_outputs.shape[1] // int(module_def.get("groups", 1)) group_id = int(module_def.get("group_id", 0)) - x = combined_outputs[:, group_size * group_id : group_size * (group_id + 1)] # Slice groupings used by yolo v4 + x = combined_outputs[:, group_size * group_id:group_size * (group_id + 1)] # Slice groupings used by yolo v4 elif module_def["type"] == "shortcut": layer_i = int(module_def["from"]) x = layer_outputs[-1] + layer_outputs[layer_i] diff --git a/pytorchyolo/test.py b/pytorchyolo/test.py index f0bc7439af..9737baf8e4 100755 --- a/pytorchyolo/test.py +++ b/pytorchyolo/test.py @@ -13,7 +13,7 @@ from torch.autograd import Variable from pytorchyolo.models import load_model -from pytorchyolo.utils.utils import load_classes, ap_per_class, get_batch_statistics, non_max_suppression, to_cpu, xywh2xyxy, print_environment_info +from pytorchyolo.utils.utils import load_classes, ap_per_class, get_batch_statistics, non_max_suppression, xywh2xyxy, print_environment_info from pytorchyolo.utils.datasets import ListDataset from pytorchyolo.utils.transforms import DEFAULT_TRANSFORMS from pytorchyolo.utils.parse_config import parse_data_config diff --git a/pytorchyolo/train.py b/pytorchyolo/train.py index 2a1ae02549..a4aab979b0 100755 --- a/pytorchyolo/train.py +++ b/pytorchyolo/train.py @@ -15,7 +15,7 @@ from pytorchyolo.utils.utils import to_cpu, load_classes, print_environment_info, provide_determinism, worker_seed_set from pytorchyolo.utils.datasets import ListDataset from pytorchyolo.utils.augmentations import AUGMENTATION_TRANSFORMS -#from pytorchyolo.utils.transforms import DEFAULT_TRANSFORMS +# from pytorchyolo.utils.transforms import DEFAULT_TRANSFORMS from pytorchyolo.utils.parse_config import parse_data_config from pytorchyolo.utils.loss import compute_loss from pytorchyolo.test import _evaluate, _create_validation_data_loader diff --git a/pytorchyolo/utils/loss.py b/pytorchyolo/utils/loss.py index 1f1d259c33..ba0dfee652 100644 --- a/pytorchyolo/utils/loss.py +++ b/pytorchyolo/utils/loss.py @@ -113,7 +113,7 @@ def compute_loss(predictions, targets, model): # Classification of the objectness the sequel # Calculate the BCE loss between the on the fly generated target and the network prediction - lobj += BCEobj(layer_predictions[..., 4], tobj) # obj loss + lobj += BCEobj(layer_predictions[..., 4], tobj) # obj loss lbox *= 0.05 lobj *= 1.0 diff --git a/pytorchyolo/utils/utils.py b/pytorchyolo/utils/utils.py index bce7a900b8..f39dad00dd 100644 --- a/pytorchyolo/utils/utils.py +++ b/pytorchyolo/utils/utils.py @@ -20,6 +20,7 @@ def provide_determinism(seed=42): torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True + def worker_seed_set(worker_id): # See for details of numpy: # https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 @@ -390,6 +391,7 @@ def print_environment_info(): # Print commit hash if possible try: - print(f"Current Commit Hash: {subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], stderr=subprocess.DEVNULL).decode('ascii').strip()}") + commit_hash = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], stderr=subprocess.DEVNULL).decode('ascii').strip() + print(f"Current Commit Hash: {commit_hash}") except (subprocess.CalledProcessError, FileNotFoundError): print("No git or repo found")