Can we use pretrained models with iresnet backbone? #2675

Open
ZubairKhan001 opened this issue Oct 25, 2024 · 3 comments

Comments

@ZubairKhan001

It seems the pretrained ms1mv3 ArcFace models can't be used to initialize training on a smaller custom dataset. Is that correct?

def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model
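
Since the constructor refuses pretrained=True, the only option seems to be building the model with pretrained=False and loading a downloaded checkpoint manually, roughly like this (assuming the ms1mv3 checkpoint is a plain backbone state dict; the path below is just a placeholder):

import torch
from backbones.iresnet import iresnet50

backbone = iresnet50(pretrained=False)
# Placeholder path for the downloaded ms1mv3 ArcFace r50 weights
state_dict = torch.load("path/to/ms1mv3_arcface_r50_backbone.pth", map_location="cpu")
backbone.load_state_dict(state_dict)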

@nttstar
Collaborator

nttstar commented Oct 25, 2024

Yes, why not?

@ZubairKhan001 ZubairKhan001 changed the title Can we use pretrained models with iresnet backnot? Can we use pretrained models with iresnet backbone? Oct 26, 2024
@ZubairKhan001
Author

Please, how can we do this in the train_v2.py file in arcface_torch? It only has a --resume configuration option.

@sahasraa

Have you found a workaround for using a pretrained model yet? I'm currently using something like this:

import argparse
import logging
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from backbones.iresnet import iresnet50  # Replace with your backbone module
from losses import CombinedMarginLoss
from partial_fc_v2 import PartialFC_V2
from utils.utils_logging import AverageMeter, init_logging
from utils.utils_callbacks import CallBackVerification
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast

def train(args):
    # Initialize logging
    os.makedirs(args.output_dir, exist_ok=True)
    init_logging(args.rank, args.output_dir)

    # PartialFC_V2 and CallBackVerification both query torch.distributed,
    # so initialize a process group even for a single-GPU run
    if not dist.is_initialized():
        dist.init_process_group(
            backend="nccl",
            init_method="tcp://127.0.0.1:12584",
            rank=args.rank,
            world_size=args.world_size,
        )
    torch.cuda.set_device(args.local_rank)

    # Data preparation
    transform = transforms.Compose([
        transforms.Resize((112, 112)),  # Match model input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    train_dataset = datasets.ImageFolder(args.data_dir, transform=transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Initialize backbone
    backbone = iresnet50(pretrained=False).cuda()

    # Load pretrained weights
    if args.pretrained:
        checkpoint = torch.load(args.pretrained, map_location="cuda")
        state_dict = checkpoint if "state_dict" not in checkpoint else checkpoint["state_dict"]
        # Strip `module.` prefix if present (for DDP checkpoints)
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        backbone.load_state_dict(state_dict, strict=False)
        logging.info(f"Loaded pretrained weights from {args.pretrained}")

    # Wrap model in DDP if distributed training
    if args.world_size > 1:
        backbone = nn.parallel.DistributedDataParallel(
            backbone, device_ids=[args.local_rank], find_unused_parameters=True
        )

    # Define loss and optimizer. CombinedMarginLoss only applies the margins
    # to class logits, so pair it with PartialFC_V2, which holds the class
    # weights and computes the cross-entropy (as train_v2.py does)
    embedding_size = 512
    num_classes = len(train_dataset.classes)
    margin_loss = CombinedMarginLoss(64, 1.0, 0.5, 0.0)
    module_partial_fc = PartialFC_V2(
        margin_loss, embedding_size, num_classes, sample_rate=1.0, fp16=False
    )
    module_partial_fc.train().cuda()
    optimizer = AdamW(
        [{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
        lr=args.lr,
        weight_decay=1e-4,
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = GradScaler()

    # Verification callback
    callback_verification = CallBackVerification(
        val_targets=args.val_targets, rec_prefix=args.data_dir
    )
    loss_meter = AverageMeter()

    # Training loop
    for epoch in range(args.epochs):
        backbone.train()
        for step, (images, labels) in enumerate(train_loader):
            images, labels = images.cuda(non_blocking=True), labels.cuda(non_blocking=True)

            with autocast():
                embeddings = backbone(images)
                loss = module_partial_fc(embeddings, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_meter.update(loss.item(), images.size(0))

            if step % args.log_interval == 0:
                logging.info(f"Epoch [{epoch}/{args.epochs}], Step [{step}], Loss: {loss_meter.avg:.4f}")

        scheduler.step()

        # Verification after each epoch
        if args.rank == 0:
            callback_verification(epoch, backbone)

        # Save checkpoint
        if args.rank == 0:
            torch.save(
                {
                    "epoch": epoch + 1,
                    "state_dict": backbone.state_dict(),
                    "optimizer": optimizer.state_dict(),
                },
                os.path.join(args.output_dir, f"checkpoint_epoch_{epoch}.pth"),
            )

    logging.info("Training completed.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train iresnet with pretrained weights")
    parser.add_argument("--data_dir", type=str, required=True, help="Path to training dataset")
    parser.add_argument("--output_dir", type=str, default="./output", help="Path to save logs and models")
    parser.add_argument("--pretrained", type=str, default=None, help="Path to pretrained model weights")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
    parser.add_argument("--epochs", type=int, default=25, help="Number of epochs to train")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of data loader workers")
    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
    parser.add_argument("--rank", type=int, default=0, help="Rank for distributed training")
    parser.add_argument("--world_size", type=int, default=1, help="World size for distributed training")
    parser.add_argument("--local_rank", type=int, default=0, help="Local rank for distributed training")
    parser.add_argument("--val_targets", type=list, default=["lfw", "cfp_fp"], help="Validation datasets")

    args = parser.parse_args()
    train(args)
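
If it helps, a possible invocation looks like this (the script name is whatever you save the file as; paths are placeholders):

python train_finetune.py \
    --data_dir /path/to/custom_dataset \
    --pretrained /path/to/ms1mv3_arcface_r50_backbone.pth \
    --batch_size 64 --epochs 25 --lr 1e-4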
