Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of the Cross Domain Adaptive Clustering(Domain Adaptation) #44

Merged
merged 4 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ We don't provide detailed documentations for Dassl, unlike another [project](htt
Dassl has implemented the following methods:

- Single-source domain adaptation
- [Cross Domain Adaptive Clustering for Semi Supervised Domain Adaptation (CVPR'21)](https://arxiv.org/pdf/2104.09415.pdf) [[dassl/engine/da/cdac.py](dassl/engine/da/cdac.py)]
- [Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19)](https://arxiv.org/abs/1904.06487) [[dassl/engine/da/mme.py](dassl/engine/da/mme.py)]
- [Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18)](https://arxiv.org/abs/1712.02560https://arxiv.org/abs/1712.02560) [[dassl/engine/da/mcd.py](dassl/engine/da/mcd.py)]
- [Self-ensembling for visual domain adaptation (ICLR'18)](https://arxiv.org/abs/1706.05208) [[dassl/engine/da/self_ensembling.py](dassl/engine/da/self_ensembling.py)]
Expand Down
19 changes: 19 additions & 0 deletions configs/trainers/da/cdac/digit5.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
DATALOADER:
TRAIN_X:
SAMPLER: "RandomSampler"
BATCH_SIZE: 64
TRAIN_U:
SAME_AS_X: False
BATCH_SIZE: 192
TEST:
BATCH_SIZE: 256
K_TRANSFORMS: 2

OPTIM:
NAME: "sgd"
LR: 0.01
MAX_EPOCH: 90

TRAINER:
CDAC:
STRONG_TRANSFORMS: ["randaugment", "normalize"]
19 changes: 19 additions & 0 deletions configs/trainers/da/cdac/domainnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
DATALOADER:
TRAIN_X:
SAMPLER: "RandomDomainSampler"
BATCH_SIZE: 30
TRAIN_U:
SAME_AS_X: False
BATCH_SIZE: 6
TEST:
BATCH_SIZE: 30
K_TRANSFORMS: 2

OPTIM:
NAME: "sgd"
LR: 0.003
MAX_EPOCH: 90

TRAINER:
CDAC:
STRONG_TRANSFORMS: ["randaugment", "normalize"]
20 changes: 20 additions & 0 deletions configs/trainers/da/cdac/mini_domainnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
DATALOADER:
TRAIN_X:
SAMPLER: "RandomDomainSampler"
BATCH_SIZE: 64
TRAIN_U:
SAME_AS_X: False
BATCH_SIZE: 192
TEST:
BATCH_SIZE: 200
K_TRANSFORMS: 2

OPTIM:
NAME: "sgd"
LR: 0.005
MAX_EPOCH: 60
LR_SCHEDULER: "cosine"

TRAINER:
CDAC:
STRONG_TRANSFORMS: ["randaugment", "normalize"]
10 changes: 9 additions & 1 deletion dassl/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,20 @@
# MME
_C.TRAINER.MME = CN()
_C.TRAINER.MME.LMDA = 0.1 # weight for the entropy loss
# CDAC
_C.TRAINER.CDAC = CN()
_C.TRAINER.CDAC.CLASS_LR_MULTI = 10
_C.TRAINER.CDAC.RAMPUP_COEF = 30
_C.TRAINER.CDAC.RAMPUP_ITRS = 1000
_C.TRAINER.CDAC.TOPK_MATCH = 5
_C.TRAINER.CDAC.P_THRESH = 0.95
_C.TRAINER.CDAC.STRONG_TRANSFORMS = ()

# SelfEnsembling
_C.TRAINER.SE = CN()
_C.TRAINER.SE.EMA_ALPHA = 0.999
_C.TRAINER.SE.CONF_THRE = 0.95
_C.TRAINER.SE.RAMPUP = 300

# M3SDA
_C.TRAINER.M3SDA = CN()
_C.TRAINER.M3SDA.LMDA = 0.5 # weight for the moment distance loss
Expand Down
1 change: 1 addition & 0 deletions dassl/engine/da/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .mcd import MCD
from .mme import MME
from .adda import ADDA
from .cdac import CDAC
from .dael import DAEL
from .dann import DANN
from .adabn import AdaBN
Expand Down
277 changes: 277 additions & 0 deletions dassl/engine/da/cdac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
import numpy as np
from functools import partial
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms.transforms import build_transform


def custom_scheduler(iter, max_iter=None, alpha=10, beta=0.75, init_lr=0.001):
"""Custom LR Annealing

https://arxiv.org/pdf/1409.7495.pdf
"""
if max_iter is None:
return init_lr
return (1 + float(iter / max_iter) * alpha)**(-1.0 * beta)


class AAC(nn.Module):

def forward(self, sim_mat, prob_u, prob_us):

P = prob_u.matmul(prob_us.t())

loss = -(
sim_mat * torch.log(P + 1e-7) +
(1.-sim_mat) * torch.log(1. - P + 1e-7)
)
return loss.mean()


class Prototypes(nn.Module):

def __init__(self, fdim, num_classes, temp=0.05):
super().__init__()
self.prototypes = nn.Linear(fdim, num_classes, bias=False)
self.temp = temp
self.revgrad = ReverseGrad()

def forward(self, x, reverse=False):
if reverse:
x = self.revgrad(x)
x = F.normalize(x, p=2, dim=1)
out = self.prototypes(x)
out = out / self.temp
return out


@TRAINER_REGISTRY.register()
class CDAC(TrainerXU):
"""Cross Domain Adaptive Clustering.

https://arxiv.org/pdf/2104.09415.pdf
"""

def __init__(self, cfg):
self.rampup_coef = cfg.TRAINER.CDAC.RAMPUP_COEF
self.rampup_iters = cfg.TRAINER.CDAC.RAMPUP_ITRS
self.lr_multi = cfg.TRAINER.CDAC.CLASS_LR_MULTI
self.topk = cfg.TRAINER.CDAC.TOPK_MATCH
self.p_thresh = cfg.TRAINER.CDAC.P_THRESH
self.aac_criterion = AAC()
super().__init__(cfg)

def check_cfg(self, cfg):
assert len(
cfg.TRAINER.CDAC.STRONG_TRANSFORMS
) > 0, "Strong augmentations are necessary to run CDAC"
assert cfg.DATALOADER.K_TRANSFORMS == 2, "CDAC needs two strong augmentations of the same image."

def build_data_loader(self):

cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.CDAC.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = self.dm.train_loader_x
self.train_loader_u = self.dm.train_loader_u
self.val_loader = self.dm.val_loader
self.test_loader = self.dm.test_loader
self.num_classes = self.dm.num_classes
self.lab2cname = self.dm.lab2cname

def build_model(self):
cfg = self.cfg

# Custom LR Scheduler for CDAC
if self.cfg.TRAIN.COUNT_ITER == "train_x":
self.num_batches = len(self.train_loader_x)
elif self.cfg.TRAIN.COUNT_ITER == "train_u":
self.num_batches = len(self.len_train_loader_u)
elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
self.num_batches = min(
len(self.train_loader_x), len(self.train_loader_u)
)
self.max_iter = self.max_epoch * self.num_batches
print("Max Iterations: %d" % self.max_iter)

print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
custom_lr_F = partial(
custom_scheduler, max_iter=self.max_iter, init_lr=cfg.OPTIM.LR
)
self.sched_F = LambdaLR(self.optim_F, custom_lr_F)
self.register_model("F", self.F, self.optim_F, self.sched_F)

print("Building C")
self.C = Prototypes(self.F.fdim, self.num_classes)
self.C.to(self.device)
print("# params: {:,}".format(count_num_param(self.C)))
self.optim_C = build_optimizer(self.C, cfg.OPTIM)

# Multiply the learning rate of C by lr_multi
for group_param in self.optim_C.param_groups:
group_param['lr'] *= self.lr_multi
custom_lr_C = partial(
custom_scheduler,
max_iter=self.max_iter,
init_lr=cfg.OPTIM.LR * self.lr_multi
)
self.sched_F = LambdaLR(self.optim_C, custom_lr_C)

self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
self.register_model("C", self.C, self.optim_C, self.sched_C)

def assess_y_pred_quality(self, y_pred, y_true, mask):
n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
acc_thre = n_masked_correct / (mask.sum() + 1e-5)
acc_raw = y_pred.eq(y_true).sum() / y_pred.numel() # raw accuracy
keep_rate = mask.sum() / mask.numel()
output = {
"acc_thre": acc_thre,
"acc_raw": acc_raw,
"keep_rate": keep_rate
}
return output

def forward_backward(self, batch_x, batch_u):

current_itr = self.epoch * self.num_batches + self.batch_idx

input_x, label_x, input_u, input_us, input_us2, label_u = self.parse_batch_train(
batch_x, batch_u
)

# Paper Reference Eq. 2 - Supervised Loss

feat_x = self.F(input_x)
logit_x = self.C(feat_x)
loss_x = F.cross_entropy(logit_x, label_x)

self.model_backward_and_update(loss_x)

feat_u = self.F(input_u)
feat_us = self.F(input_us)
feat_us2 = self.F(input_us2)

# Paper Reference Eq.3 - Adversarial Adaptive Loss
logit_u = self.C(feat_u, reverse=True)
logit_us = self.C(feat_us, reverse=True)
prob_u, prob_us = F.softmax(logit_u, dim=1), F.softmax(logit_us, dim=1)

# Get similarity matrix s_ij
sim_mat = self.get_similarity_matrix(feat_u, self.topk, self.device)

aac_loss = (-1. * self.aac_criterion(sim_mat, prob_u, prob_us))

# Paper Reference Eq. 4 - Pseudo label Loss
logit_u = self.C(feat_u)
logit_us = self.C(feat_us)
logit_us2 = self.C(feat_us2)
prob_u, prob_us, prob_us2 = F.softmax(
logit_u, dim=1
), F.softmax(
logit_us, dim=1
), F.softmax(
logit_us2, dim=1
)
prob_u = prob_u.detach()
max_probs, max_idx = torch.max(prob_u, dim=-1)
mask = max_probs.ge(self.p_thresh).float()
p_u_stats = self.assess_y_pred_quality(max_idx, label_u, mask)

pl_loss = (
F.cross_entropy(logit_us2, max_idx, reduction='none') * mask
).mean()

# Paper Reference Eq. 8 - Consistency Loss
cons_multi = self.sigmoid_rampup(
current_itr=current_itr, rampup_itr=self.rampup_iters
) * self.rampup_coef
cons_loss = cons_multi * F.mse_loss(prob_us, prob_us2)

loss_u = aac_loss + pl_loss + cons_loss

self.model_backward_and_update(loss_u)

loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
"aac_loss": aac_loss.item(),
"pl_loss": pl_loss.item(),
"cons_loss": cons_loss.item(),
"p_u_pred_acc": p_u_stats["acc_raw"],
"p_u_pred_acc_thre": p_u_stats["acc_thre"],
"p_u_pred_keep": p_u_stats["keep_rate"]
}

# Update LR after every iteration as mentioned in the paper

self.update_lr()

return loss_summary

def parse_batch_train(self, batch_x, batch_u):

input_x = batch_x["img"][0]
label_x = batch_x["label"]

input_u = batch_u["img"][0]
input_us = batch_u["img2"][0]
input_us2 = batch_u["img2"][1]
label_u = batch_u["label"]

input_x = input_x.to(self.device)
label_x = label_x.to(self.device)

input_u = input_u.to(self.device)
input_us = input_us.to(self.device)
input_us2 = input_us2.to(self.device)
label_u = label_u.to(self.device)

return input_x, label_x, input_u, input_us, input_us2, label_u

def model_inference(self, input):
return self.C(self.F(input))

@staticmethod
def get_similarity_matrix(feat, topk, device):

feat_d = feat.detach()

feat_d = torch.sort(
torch.argsort(feat_d, dim=1, descending=True)[:, :topk], dim=1
)[0]
sim_mat = torch.zeros((feat_d.shape[0], feat_d.shape[0])).to(device)
for row in range(feat_d.shape[0]):
sim_mat[row, torch.all(feat_d == feat_d[row, :], dim=1)] = 1
return sim_mat

@staticmethod
def sigmoid_rampup(current_itr, rampup_itr):
"""Exponential Rampup
https://arxiv.org/abs/1610.02242
"""
if rampup_itr == 0:
return 1.0
else:
var = np.clip(current_itr, 0.0, rampup_itr)
phase = 1.0 - var/rampup_itr
return float(np.exp(-5.0 * phase * phase))