From 98d8d77f2cb51d766514db50d00b6a36ca5fa0b9 Mon Sep 17 00:00:00 2001
From: Parag
Date: Fri, 10 Jan 2025 15:41:31 +0530
Subject: [PATCH] Manage the torch.amp import to stay compatible with all
 PyTorch versions

---
 train.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index fec0a239c0a..e24194be12a 100644
--- a/train.py
+++ b/train.py
@@ -36,6 +36,12 @@
 import torch.nn as nn
 import yaml
 from torch.optim import lr_scheduler
+
+try:
+    import torch.amp as amp
+except ImportError:
+    import torch.cuda.amp as amp
+
 from tqdm import tqdm
 
 FILE = Path(__file__).resolve()
@@ -94,6 +100,7 @@
     torch_distributed_zero_first,
 )
 
+
 LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv("RANK", -1))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
@@ -221,7 +228,7 @@ def train(hyp, opt, device, callbacks):
         LOGGER.info(f"Transferred {len(csd)}/{len(model.state_dict())} items from {weights}")  # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create
-    amp = check_amp(model)  # check AMP
+    use_amp = check_amp(model)  # check AMP
 
     # Freeze
     freeze = [f"model.{x}." for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
@@ -238,7 +245,7 @@ def train(hyp, opt, device, callbacks):
 
     # Batch size
     if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
-        batch_size = check_train_batch_size(model, imgsz, amp)
+        batch_size = check_train_batch_size(model, imgsz, use_amp)
         loggers.on_params_update({"batch_size": batch_size})
 
     # Optimizer
@@ -352,7 +359,8 @@ def lf(x):
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
-    scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    # scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    scaler = amp.GradScaler(enabled=use_amp)
     stopper, stop = EarlyStopping(patience=opt.patience), False
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run("on_train_start")
@@ -409,7 +417,8 @@ def lf(x):
                 imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
 
             # Forward
-            with torch.cuda.amp.autocast(amp):
+            # with torch.cuda.amp.autocast(amp):
+            with amp.autocast(enabled=use_amp, device_type=device.type):
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
             if RANK != -1:
@@ -458,7 +467,7 @@ def lf(x):
             data_dict,
             batch_size=batch_size // WORLD_SIZE * 2,
             imgsz=imgsz,
-            half=amp,
+            half=use_amp,
             model=ema.ema,
             single_cls=single_cls,
             dataloader=val_loader,
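
Note on the fallback path: the older torch.cuda.amp.autocast accepts no
device_type keyword, so the new `with amp.autocast(enabled=use_amp,
device_type=device.type)` call only works when the modern torch.amp module
is the one that was imported; on the torch.cuda.amp fallback it would raise
a TypeError. Likewise, torch.amp.GradScaler only exists in recent PyTorch
releases (around 2.3), so a try/except at import time alone does not
guarantee that attribute is present. A minimal sketch of a shim that also
absorbs the call-signature differences, using hypothetical helper names
make_scaler and autocast_ctx that are not part of this patch, could look
like this:

try:
    # Newer PyTorch: device-agnostic AMP. GradScaler lives here only in
    # recent releases, so this import also fails (and falls back) on
    # versions where torch.amp exists but GradScaler does not.
    from torch.amp import GradScaler, autocast

    def make_scaler(enabled):
        return GradScaler(enabled=enabled)

    def autocast_ctx(enabled, device_type):
        # torch.amp.autocast requires a device_type ("cuda", "cpu", ...)
        return autocast(device_type=device_type, enabled=enabled)

except ImportError:
    # Older PyTorch: CUDA-only AMP API.
    from torch.cuda.amp import GradScaler, autocast

    def make_scaler(enabled):
        return GradScaler(enabled=enabled)

    def autocast_ctx(enabled, device_type):
        # The old autocast has no device_type parameter; the argument is
        # accepted here only for interface parity and is ignored.
        return autocast(enabled=enabled)

train.py would then use scaler = make_scaler(use_amp) and
`with autocast_ctx(use_amp, device.type):` at the call sites touched above,
behaving the same under any of the three AMP API generations.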