diff --git a/README.md b/README.md
index caba3bf..fb687a4 100644
--- a/README.md
+++ b/README.md
@@ -61,6 +61,7 @@ We don't provide detailed documentations for Dassl, unlike another [project](htt
 Dassl has implemented the following methods:
 
 - Single-source domain adaptation
+  - [Cross-Domain Adaptive Clustering for Semi-Supervised Domain Adaptation (CVPR'21)](https://arxiv.org/pdf/2104.09415.pdf) [[dassl/engine/da/cdac.py](dassl/engine/da/cdac.py)]
   - [Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19)](https://arxiv.org/abs/1904.06487) [[dassl/engine/da/mme.py](dassl/engine/da/mme.py)]
   - [Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18)](https://arxiv.org/abs/1712.02560) [[dassl/engine/da/mcd.py](dassl/engine/da/mcd.py)]
   - [Self-ensembling for visual domain adaptation (ICLR'18)](https://arxiv.org/abs/1706.05208) [[dassl/engine/da/self_ensembling.py](dassl/engine/da/self_ensembling.py)]
diff --git a/configs/trainers/da/cdac/digit5.yaml b/configs/trainers/da/cdac/digit5.yaml
new file mode 100644
index 0000000..65d82a5
--- /dev/null
+++ b/configs/trainers/da/cdac/digit5.yaml
@@ -0,0 +1,19 @@
+DATALOADER:
+  TRAIN_X:
+    SAMPLER: "RandomSampler"
+    BATCH_SIZE: 64
+  TRAIN_U:
+    SAME_AS_X: False
+    BATCH_SIZE: 192
+  TEST:
+    BATCH_SIZE: 256
+  K_TRANSFORMS: 2
+
+OPTIM:
+  NAME: "sgd"
+  LR: 0.01
+  MAX_EPOCH: 90
+
+TRAINER:
+  CDAC:
+    STRONG_TRANSFORMS: ["randaugment", "normalize"]
\ No newline at end of file
diff --git a/configs/trainers/da/cdac/domainnet.yaml b/configs/trainers/da/cdac/domainnet.yaml
new file mode 100644
index 0000000..fa29e87
--- /dev/null
+++ b/configs/trainers/da/cdac/domainnet.yaml
@@ -0,0 +1,19 @@
+DATALOADER:
+  TRAIN_X:
+    SAMPLER: "RandomDomainSampler"
+    BATCH_SIZE: 30
+  TRAIN_U:
+    SAME_AS_X: False
+    BATCH_SIZE: 6
+  TEST:
+    BATCH_SIZE: 30
+  K_TRANSFORMS: 2
+
+OPTIM:
+  NAME: "sgd"
+  LR: 0.003
+  MAX_EPOCH: 90
+
+TRAINER:
+  CDAC:
+    STRONG_TRANSFORMS: ["randaugment", "normalize"]
\ No newline at end of file
diff --git a/configs/trainers/da/cdac/mini_domainnet.yaml b/configs/trainers/da/cdac/mini_domainnet.yaml
new file mode 100644
index 0000000..1e9baa7
--- /dev/null
+++ b/configs/trainers/da/cdac/mini_domainnet.yaml
@@ -0,0 +1,20 @@
+DATALOADER:
+  TRAIN_X:
+    SAMPLER: "RandomDomainSampler"
+    BATCH_SIZE: 64
+  TRAIN_U:
+    SAME_AS_X: False
+    BATCH_SIZE: 192
+  TEST:
+    BATCH_SIZE: 200
+  K_TRANSFORMS: 2
+
+OPTIM:
+  NAME: "sgd"
+  LR: 0.005
+  MAX_EPOCH: 60
+  LR_SCHEDULER: "cosine"
+
+TRAINER:
+  CDAC:
+    STRONG_TRANSFORMS: ["randaugment", "normalize"]
\ No newline at end of file
diff --git a/dassl/config/defaults.py b/dassl/config/defaults.py
index ba8494c..ff17afa 100644
--- a/dassl/config/defaults.py
+++ b/dassl/config/defaults.py
@@ -222,12 +222,20 @@
 # MME
 _C.TRAINER.MME = CN()
 _C.TRAINER.MME.LMDA = 0.1  # weight for the entropy loss
+# CDAC
+_C.TRAINER.CDAC = CN()
+_C.TRAINER.CDAC.CLASS_LR_MULTI = 10
+_C.TRAINER.CDAC.RAMPUP_COEF = 30
+_C.TRAINER.CDAC.RAMPUP_ITRS = 1000
+_C.TRAINER.CDAC.TOPK_MATCH = 5
+_C.TRAINER.CDAC.P_THRESH = 0.95
+_C.TRAINER.CDAC.STRONG_TRANSFORMS = ()
+
 # SelfEnsembling
 _C.TRAINER.SE = CN()
 _C.TRAINER.SE.EMA_ALPHA = 0.999
 _C.TRAINER.SE.CONF_THRE = 0.95
 _C.TRAINER.SE.RAMPUP = 300
-
 # M3SDA
 _C.TRAINER.M3SDA = CN()
 _C.TRAINER.M3SDA.LMDA = 0.5  # weight for the moment distance loss
diff --git a/dassl/engine/da/__init__.py b/dassl/engine/da/__init__.py
index 2d46469..e940734 100644
--- a/dassl/engine/da/__init__.py
+++ b/dassl/engine/da/__init__.py
@@ -1,6 +1,7 @@
 from .mcd import MCD
 from .mme import MME
 from .adda import ADDA
+from .cdac import CDAC
 from .dael import DAEL
 from .dann import DANN
 from .adabn import AdaBN
diff --git a/dassl/engine/da/cdac.py b/dassl/engine/da/cdac.py
new file mode 100644
index 0000000..f470ceb
--- /dev/null
+++ b/dassl/engine/da/cdac.py
@@ -0,0 +1,277 @@
+import numpy as np
+from functools import partial
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.optim.lr_scheduler import LambdaLR
+
+from dassl.data import DataManager
+from dassl.optim import build_optimizer, build_lr_scheduler
+from dassl.utils import count_num_param
+from dassl.engine import TRAINER_REGISTRY, TrainerXU
+from dassl.metrics import compute_accuracy
+from dassl.modeling.ops import ReverseGrad
+from dassl.engine.trainer import SimpleNet
+from dassl.data.transforms.transforms import build_transform
+
+
+def custom_scheduler(iter, max_iter=None, alpha=10, beta=0.75, init_lr=0.001):
+    """Custom LR Annealing
+
+    https://arxiv.org/pdf/1409.7495.pdf
+    """
+    if max_iter is None:
+        return init_lr
+    return (1 + float(iter / max_iter) * alpha)**(-1.0 * beta)
+
+
+class AAC(nn.Module):
+
+    def forward(self, sim_mat, prob_u, prob_us):
+
+        P = prob_u.matmul(prob_us.t())
+
+        loss = -(
+            sim_mat * torch.log(P + 1e-7) +
+            (1.-sim_mat) * torch.log(1. - P + 1e-7)
+        )
+        return loss.mean()
+
+
+class Prototypes(nn.Module):
+
+    def __init__(self, fdim, num_classes, temp=0.05):
+        super().__init__()
+        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
+        self.temp = temp
+        self.revgrad = ReverseGrad()
+
+    def forward(self, x, reverse=False):
+        if reverse:
+            x = self.revgrad(x)
+        x = F.normalize(x, p=2, dim=1)
+        out = self.prototypes(x)
+        out = out / self.temp
+        return out
+
+
+@TRAINER_REGISTRY.register()
+class CDAC(TrainerXU):
+    """Cross Domain Adaptive Clustering.
+
+    https://arxiv.org/pdf/2104.09415.pdf
+    """
+
+    def __init__(self, cfg):
+        self.rampup_coef = cfg.TRAINER.CDAC.RAMPUP_COEF
+        self.rampup_iters = cfg.TRAINER.CDAC.RAMPUP_ITRS
+        self.lr_multi = cfg.TRAINER.CDAC.CLASS_LR_MULTI
+        self.topk = cfg.TRAINER.CDAC.TOPK_MATCH
+        self.p_thresh = cfg.TRAINER.CDAC.P_THRESH
+        self.aac_criterion = AAC()
+        super().__init__(cfg)
+
+    def check_cfg(self, cfg):
+        assert len(
+            cfg.TRAINER.CDAC.STRONG_TRANSFORMS
+        ) > 0, "Strong augmentations are necessary to run CDAC"
+        assert cfg.DATALOADER.K_TRANSFORMS == 2, "CDAC needs two strong augmentations of the same image."
+
+    def build_data_loader(self):
+
+        cfg = self.cfg
+        tfm_train = build_transform(cfg, is_train=True)
+        custom_tfm_train = [tfm_train]
+        choices = cfg.TRAINER.CDAC.STRONG_TRANSFORMS
+        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
+        custom_tfm_train += [tfm_train_strong]
+        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
+        self.train_loader_x = self.dm.train_loader_x
+        self.train_loader_u = self.dm.train_loader_u
+        self.val_loader = self.dm.val_loader
+        self.test_loader = self.dm.test_loader
+        self.num_classes = self.dm.num_classes
+        self.lab2cname = self.dm.lab2cname
+
+    def build_model(self):
+        cfg = self.cfg
+
+        # Custom LR Scheduler for CDAC
+        if self.cfg.TRAIN.COUNT_ITER == "train_x":
+            self.num_batches = len(self.train_loader_x)
+        elif self.cfg.TRAIN.COUNT_ITER == "train_u":
+            self.num_batches = len(self.train_loader_u)
+        elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
+            self.num_batches = min(
+                len(self.train_loader_x), len(self.train_loader_u)
+            )
+        self.max_iter = self.max_epoch * self.num_batches
+        print("Max Iterations: %d" % self.max_iter)
+
+        print("Building F")
+        self.F = SimpleNet(cfg, cfg.MODEL, 0)
+        self.F.to(self.device)
+        print("# params: {:,}".format(count_num_param(self.F)))
+        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
+        custom_lr_F = partial(
+            custom_scheduler, max_iter=self.max_iter, init_lr=cfg.OPTIM.LR
+        )
+        self.sched_F = LambdaLR(self.optim_F, custom_lr_F)
+        self.register_model("F", self.F, self.optim_F, self.sched_F)
+
+        print("Building C")
+        self.C = Prototypes(self.F.fdim, self.num_classes)
+        self.C.to(self.device)
+        print("# params: {:,}".format(count_num_param(self.C)))
+        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
+
+        # Multiply the learning rate of C by lr_multi
+        for group_param in self.optim_C.param_groups:
+            group_param['lr'] *= self.lr_multi
+        custom_lr_C = partial(
+            custom_scheduler,
+            max_iter=self.max_iter,
+            init_lr=cfg.OPTIM.LR * self.lr_multi
+        )
+        self.sched_C = LambdaLR(self.optim_C, custom_lr_C)
+        self.register_model("C", self.C, self.optim_C, self.sched_C)
+
+    def assess_y_pred_quality(self, y_pred, y_true, mask):
+        n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
+        acc_thre = n_masked_correct / (mask.sum() + 1e-5)
+        acc_raw = y_pred.eq(y_true).sum() / y_pred.numel()  # raw accuracy
+        keep_rate = mask.sum() / mask.numel()
+        output = {
+            "acc_thre": acc_thre,
+            "acc_raw": acc_raw,
+            "keep_rate": keep_rate
+        }
+        return output
+
+    def forward_backward(self, batch_x, batch_u):
+
+        current_itr = self.epoch * self.num_batches + self.batch_idx
+
+        input_x, label_x, input_u, input_us, input_us2, label_u = self.parse_batch_train(
+            batch_x, batch_u
+        )
+
+        # Paper Reference Eq. 2 - Supervised Loss
+
+        feat_x = self.F(input_x)
+        logit_x = self.C(feat_x)
+        loss_x = F.cross_entropy(logit_x, label_x)
+
+        self.model_backward_and_update(loss_x)
+
+        feat_u = self.F(input_u)
+        feat_us = self.F(input_us)
+        feat_us2 = self.F(input_us2)
+
+        # Paper Reference Eq. 3 - Adversarial Adaptive Loss
+        logit_u = self.C(feat_u, reverse=True)
+        logit_us = self.C(feat_us, reverse=True)
+        prob_u, prob_us = F.softmax(logit_u, dim=1), F.softmax(logit_us, dim=1)
+
+        # Get similarity matrix s_ij
+        sim_mat = self.get_similarity_matrix(feat_u, self.topk, self.device)
+
+        aac_loss = (-1. * self.aac_criterion(sim_mat, prob_u, prob_us))
+
+        # Paper Reference Eq. 4 - Pseudo label Loss
+        logit_u = self.C(feat_u)
+        logit_us = self.C(feat_us)
+        logit_us2 = self.C(feat_us2)
+        prob_u, prob_us, prob_us2 = F.softmax(
+            logit_u, dim=1
+        ), F.softmax(
+            logit_us, dim=1
+        ), F.softmax(
+            logit_us2, dim=1
+        )
+        prob_u = prob_u.detach()
+        max_probs, max_idx = torch.max(prob_u, dim=-1)
+        mask = max_probs.ge(self.p_thresh).float()
+        p_u_stats = self.assess_y_pred_quality(max_idx, label_u, mask)
+
+        pl_loss = (
+            F.cross_entropy(logit_us2, max_idx, reduction='none') * mask
+        ).mean()
+
+        # Paper Reference Eq. 8 - Consistency Loss
+        cons_multi = self.sigmoid_rampup(
+            current_itr=current_itr, rampup_itr=self.rampup_iters
+        ) * self.rampup_coef
+        cons_loss = cons_multi * F.mse_loss(prob_us, prob_us2)
+
+        loss_u = aac_loss + pl_loss + cons_loss
+
+        self.model_backward_and_update(loss_u)
+
+        loss_summary = {
+            "loss_x": loss_x.item(),
+            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
+            "loss_u": loss_u.item(),
+            "aac_loss": aac_loss.item(),
+            "pl_loss": pl_loss.item(),
+            "cons_loss": cons_loss.item(),
+            "p_u_pred_acc": p_u_stats["acc_raw"],
+            "p_u_pred_acc_thre": p_u_stats["acc_thre"],
+            "p_u_pred_keep": p_u_stats["keep_rate"]
+        }
+
+        # Update LR after every iteration as mentioned in the paper
+
+        self.update_lr()
+
+        return loss_summary
+
+    def parse_batch_train(self, batch_x, batch_u):
+
+        input_x = batch_x["img"][0]
+        label_x = batch_x["label"]
+
+        input_u = batch_u["img"][0]
+        input_us = batch_u["img2"][0]
+        input_us2 = batch_u["img2"][1]
+        label_u = batch_u["label"]
+
+        input_x = input_x.to(self.device)
+        label_x = label_x.to(self.device)
+
+        input_u = input_u.to(self.device)
+        input_us = input_us.to(self.device)
+        input_us2 = input_us2.to(self.device)
+        label_u = label_u.to(self.device)
+
+        return input_x, label_x, input_u, input_us, input_us2, label_u
+
+    def model_inference(self, input):
+        return self.C(self.F(input))
+
+    @staticmethod
+    def get_similarity_matrix(feat, topk, device):
+
+        feat_d = feat.detach()
+
+        feat_d = torch.sort(
+            torch.argsort(feat_d, dim=1, descending=True)[:, :topk], dim=1
+        )[0]
+        sim_mat = torch.zeros((feat_d.shape[0], feat_d.shape[0])).to(device)
+        for row in range(feat_d.shape[0]):
+            sim_mat[row, torch.all(feat_d == feat_d[row, :], dim=1)] = 1
+        return sim_mat
+
+    @staticmethod
+    def sigmoid_rampup(current_itr, rampup_itr):
+        """Exponential Rampup
+        https://arxiv.org/abs/1610.02242
+        """
+        if rampup_itr == 0:
+            return 1.0
+        else:
+            var = np.clip(current_itr, 0.0, rampup_itr)
+            phase = 1.0 - var/rampup_itr
+            return float(np.exp(-5.0 * phase * phase))
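For context, the sketch below shows how the trainer and configs added in this diff could be driven through Dassl's standard config/trainer API (`get_cfg_default` and `build_trainer`). It is a minimal sketch, not part of this change: the dataset config path, dataset root, domain names, and output directory are illustrative assumptions.

    # Hypothetical usage sketch; dataset config path, root, and domains are assumptions.
    from dassl.config import get_cfg_default
    from dassl.engine import build_trainer

    cfg = get_cfg_default()
    cfg.merge_from_file("configs/trainers/da/cdac/mini_domainnet.yaml")  # CDAC hyper-parameters from this diff
    cfg.merge_from_file("configs/datasets/da/mini_domainnet.yaml")  # dataset config (assumed path)
    cfg.TRAINER.NAME = "CDAC"  # resolved via TRAINER_REGISTRY to the trainer registered above
    cfg.DATASET.ROOT = "/path/to/data"
    cfg.DATASET.SOURCE_DOMAINS = ["clipart"]
    cfg.DATASET.TARGET_DOMAINS = ["painting"]
    cfg.OUTPUT_DIR = "output/cdac_mini_domainnet"
    cfg.freeze()

    trainer = build_trainer(cfg)  # builds F, C, optimizers, and the custom LR schedulers
    trainer.train()

The trainer configs above set DATALOADER.K_TRANSFORMS to 2 because parse_batch_train consumes two strongly augmented views of each unlabeled image (batch_u["img2"][0] and batch_u["img2"][1]), which check_cfg enforces.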