diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml
index 5282bc9..f4ffbdd 100644
--- a/.github/workflows/docker.yaml
+++ b/.github/workflows/docker.yaml
@@ -1,10 +1,8 @@
 name: Docker Image Deploy to DockerHub

 on:
-  push:
-    branches: [main]
-  pull_request:
-    branches: [main]
+  release:
+    types: [published]

 jobs:
   build:
diff --git a/README.md b/README.md
index 0f8997a..9db5148 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,5 @@
 # YOLO: Official Implementation of YOLOv9, YOLOv7

-> [!CAUTION]
-> We wanted to inform you that the training code for this project is still in progress, and there are two known issues:
->
-> - Slower convergence speed
->
-> We strongly recommend refraining from training the model until version 1.0 is released.
-> However, inference and validation with pre-trained weights on COCO are available and can be used safely.
-
 [![Documentation Status](https://readthedocs.org/projects/yolo-docs/badge/?version=latest)](https://yolo-docs.readthedocs.io/en/latest/?badge=latest)
 ![GitHub License](https://img.shields.io/github/license/WongKinYiu/YOLO)
 ![WIP](https://img.shields.io/badge/status-WIP-orange)
@@ -112,33 +104,6 @@ python yolo/lazy.py task=validation dataset=toy

 Contributions to the YOLO project are welcome! See [CONTRIBUTING](docs/CONTRIBUTING.md) for guidelines on how to contribute.

-### TODO Diagrams
-
-```mermaid
-flowchart TB
-    subgraph Features
-        Taskv7-->Segmentation["#35 Segmentation"]
-        Taskv7-->Classification["#34 Classification"]
-        Taskv9-->Segmentation
-        Taskv9-->Classification
-        Trainv7
-    end
-    subgraph Model
-        MODELv7-->v7-X
-        MODELv7-->v7-E6
-        MODELv7-->v7-E6E
-        MODELv9-->v9-T
-        MODELv9-->v9-S
-        MODELv9-->v9-E
-    end
-    subgraph Bugs
-        Fix-->Fix1["#12 mAP > 1"]
-        Fix-->Fix2["v9 Gradient Bump"]
-        Reply-->Reply1["#39"]
-        Reply-->Reply2["#36"]
-    end
-```
-
 ## Star History

 [![Star History Chart](https://api.star-history.com/svg?repos=WongKinYiu/YOLO&type=Date)](https://star-history.com/#WongKinYiu/YOLO&Date)
diff --git a/yolo/config/config.py b/yolo/config/config.py
index 382d331..9dd85b0 100644
--- a/yolo/config/config.py
+++ b/yolo/config/config.py
@@ -66,6 +66,7 @@ class DataConfig:
 class OptimizerArgs:
     lr: float
     weight_decay: float
+    momentum: float


 @dataclass
diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py
index ee02f51..ca94286 100644
--- a/yolo/utils/bounding_box_utils.py
+++ b/yolo/utils/bounding_box_utils.py
@@ -349,7 +349,8 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):

     def create_auto_anchor(self, model: YOLO, image_size):
         W, H = image_size
-        dummy_input = torch.zeros(1, 3, H, W).to(self.device)
+        # TODO: need accelerate dummy test
+        dummy_input = torch.zeros(1, 3, H, W)
         dummy_output = model(dummy_input)
         strides = []
         for predict_head in dummy_output["Main"]:
diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py
index 3c79915..511ce00 100644
--- a/yolo/utils/model_utils.py
+++ b/yolo/utils/model_utils.py
@@ -8,7 +8,6 @@
 import torch.distributed as dist
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import ListConfig
 from torch import Tensor, no_grad
 from torch.optim import Optimizer
@@ -37,31 +36,31 @@ def lerp(start: float, end: float, step: Union[int, float], total: int = 1):


 class EMA(Callback):
-    def __init__(self, decay: float = 0.9999, tau: float = 500):
+    def __init__(self, decay: float = 0.9999, tau: float = 2000):
         super().__init__()
         logger.info(":chart_with_upwards_trend: Enable Model EMA")
         self.decay = decay
         self.tau = tau
         self.step = 0
+        self.ema_state_dict = None

     def setup(self, trainer, pl_module, stage):
         pl_module.ema = deepcopy(pl_module.model)
-        self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()]
+        self.tau /= trainer.world_size
         for param in pl_module.ema.parameters():
             param.requires_grad = False

     def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
-        for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
-            param.data.copy_(ema_param)
-            trainer.strategy.broadcast(param)
+        if self.ema_state_dict is None:
+            self.ema_state_dict = deepcopy(pl_module.model.state_dict())
+        pl_module.ema.load_state_dict(self.ema_state_dict)

-    @rank_zero_only
     @no_grad()
     def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
         self.step += 1
         decay_factor = self.decay * (1 - exp(-self.step / self.tau))
-        for param, ema_param in zip(pl_module.parameters(), self.ema_parameters):
-            ema_param.data.copy_(lerp(param.detach(), ema_param, decay_factor))
+        for key, param in pl_module.model.state_dict().items():
+            self.ema_state_dict[key] = lerp(param.detach(), self.ema_state_dict[key], decay_factor)


 def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
@@ -77,9 +76,9 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
     conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]

     model_parameters = [
-        {"params": bias_params, "momentum": 0.8, "weight_decay": 0},
-        {"params": conv_params, "momentum": 0.8},
-        {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
+        {"params": bias_params, "momentum": 0.937, "weight_decay": 0},
+        {"params": conv_params, "momentum": 0.937},
+        {"params": norm_params, "momentum": 0.937, "weight_decay": 0},
     ]

     def next_epoch(self, batch_num, epoch_idx):
@@ -89,8 +88,8 @@ def next_epoch(self, batch_num, epoch_idx):
         # 0.937: Start Momentum
         # 0.8  : Normal Momemtum
         # 3    : The warm up epoch num
-        self.min_mom = lerp(0.937, 0.8, max(epoch_idx, 3), 3)
-        self.max_mom = lerp(0.937, 0.8, max(epoch_idx + 1, 3), 3)
+        self.min_mom = lerp(0.937, 0.8, min(epoch_idx, 3), 3)
+        self.max_mom = lerp(0.937, 0.8, min(epoch_idx + 1, 3), 3)
         self.batch_num = batch_num
         self.batch_idx = 0

@@ -100,7 +99,7 @@ def next_batch(self):
         for lr_idx, param_group in enumerate(self.param_groups):
             min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
             param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
-            param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
+            # param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
             lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
         return lr_dict

@@ -125,7 +124,7 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR
         lambda1 = lambda epoch: (epoch + 1) / wepoch if epoch < wepoch else 1
         lambda2 = lambda epoch: 10 - 9 * ((epoch + 1) / wepoch) if epoch < wepoch else 1
         warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda2, lambda1, lambda1])
-        schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
+        schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[wepoch - 1])
     return schedule
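
Note on the EMA rewrite in yolo/utils/model_utils.py above: the callback now keeps a plain state_dict copy and blends it toward the live weights after every training batch. The snippet below is a minimal, standalone sketch of that update rule, not the project's module; it assumes lerp(start, end, step, total=1) returns start + (end - start) * step / total (consistent with how lerp is called in the diff), and EmaTracker is a made-up name for illustration only.

from copy import deepcopy
from math import exp

import torch


def lerp(start, end, step, total=1):
    # Assumed behaviour of the lerp helper used in model_utils.py.
    return start + (end - start) * step / total


class EmaTracker:
    def __init__(self, model: torch.nn.Module, decay: float = 0.9999, tau: float = 2000):
        self.decay, self.tau, self.step = decay, tau, 0
        # Sketch initialises the EMA copy eagerly; the callback above does it lazily.
        self.ema_state_dict = deepcopy(model.state_dict())

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        self.step += 1
        # decay_factor ramps from ~0 toward `decay`, so early updates follow the
        # live weights closely and later updates average them slowly.
        decay_factor = self.decay * (1 - exp(-self.step / self.tau))
        for key, param in model.state_dict().items():
            self.ema_state_dict[key] = lerp(param.detach(), self.ema_state_dict[key], decay_factor)


model = torch.nn.Linear(4, 2)
tracker = EmaTracker(model)
for _ in range(3):
    # ... an optimizer step would normally go here ...
    tracker.update(model)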
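And a quick numeric check of the max()-to-min() momentum warm-up fix in next_epoch, under the same assumed lerp; it only prints the interpolated values and does not touch an optimizer.

def lerp(start, end, step, total=1):
    # Same assumed helper as above.
    return start + (end - start) * step / total


for epoch_idx in range(5):
    # min(..., 3) walks momentum from 0.937 down to 0.8 over the 3 warm-up
    # epochs and then clamps it; the old max(..., 3) started at 0.8 immediately
    # and kept extrapolating past it once epoch_idx exceeded 3.
    min_mom = lerp(0.937, 0.8, min(epoch_idx, 3), 3)
    max_mom = lerp(0.937, 0.8, min(epoch_idx + 1, 3), 3)
    print(epoch_idx, round(min_mom, 4), round(max_mom, 4))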