7 changes: 5 additions & 2 deletions configs/dann_config.py
@@ -2,6 +2,7 @@
 LOSS_NEED_INTERMEDIATE_LAYERS = False
 UNK_VALUE = -100 # torch default
 IS_UNSUPERVISED = True
+LOG_PATH = "_log/exp4"

 GRADIENT_REVERSAL_LAYER_ALPHA = 1.0
 FREZE_BACKBONE_FEATURES = True
@@ -18,14 +19,16 @@

 CLASSES_CNT = 31
 MODEL_BACKBONE = "alexnet" # alexnet resnet50 vanilla_dann
-DOMAIN_HEAD = "vanilla_dann"
+DOMAIN_HEAD = "vanilla_dann" # "vanilla_dann", "dropout_dann", "mnist_dann"
 BACKBONE_PRETRAINED = True
-NEED_ADAPTATION_BLOCK = True # ="True" only for alexnet, ="False" for other types
+ALEXNET_NEED_ADAPTATION_BLOCK = True # ="True" only for alexnet, ="False" for other types
+ALEXNET_USE_DROPOUT_IN_CLASS_HEAD_AFTER_ADAPTATION_BLOCK = True # used only if ALEXNET_NEED_ADAPTATION_BLOCK == True
 BLOCKS_WITH_SMALLER_LR = 2 # ="2" only for alexnet, ="0" for other types
 IMAGE_SIZE = 224
 DATASET = "office-31"
 SOURCE_DOMAIN = "amazon"
 TARGET_DOMAIN = "webcam"
+RESNET50_USE_DROPOUT_IN_CLASS_HEAD = True

 # CLASSES_CNT = 10
 # MODEL_BACKBONE = "mnist_dann"
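Note: the rename NEED_ADAPTATION_BLOCK -> ALEXNET_NEED_ADAPTATION_BLOCK makes the alexnet-only scope explicit, and the two new dropout switches control the heavier class heads added below. A minimal sanity check in this spirit (the check_config helper is hypothetical, not part of this PR):

import configs.dann_config as dann_config

def check_config():
    # Hypothetical guard based on the comments above: the adaptation block
    # and the smaller-LR blocks are alexnet-only settings.
    if dann_config.MODEL_BACKBONE != "alexnet":
        assert not dann_config.ALEXNET_NEED_ADAPTATION_BLOCK, \
            "ALEXNET_NEED_ADAPTATION_BLOCK is alexnet-only"
        assert dann_config.BLOCKS_WITH_SMALLER_LR == 0, \
            "BLOCKS_WITH_SMALLER_LR should be 0 for non-alexnet backbones"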
43 changes: 23 additions & 20 deletions example.py
@@ -16,41 +16,44 @@


 if __name__ == '__main__':
-    train_gen_s, val_gen_s, test_gen_s = create_data_generators(dann_config.DATASET,
-                                                                dann_config.SOURCE_DOMAIN,
-                                                                batch_size=dann_config.BATCH_SIZE,
-                                                                infinite_train=True,
-                                                                image_size=dann_config.IMAGE_SIZE,
-                                                                num_workers=dann_config.NUM_WORKERS,
-                                                                device=device)
+    train_gen_s, _, _ = create_data_generators(dann_config.DATASET,
+                                               dann_config.SOURCE_DOMAIN,
+                                               batch_size=dann_config.BATCH_SIZE,
+                                               infinite_train=True,
+                                               image_size=dann_config.IMAGE_SIZE,
+                                               num_workers=dann_config.NUM_WORKERS,
+                                               device=device,
+                                               split_ratios=[1.0, 0., 0.])

-    train_gen_t, val_gen_t, test_gen_t = create_data_generators(dann_config.DATASET,
-                                                                dann_config.TARGET_DOMAIN,
-                                                                batch_size=dann_config.BATCH_SIZE,
-                                                                infinite_train=True,
-                                                                image_size=dann_config.IMAGE_SIZE,
-                                                                num_workers=dann_config.NUM_WORKERS,
-                                                                device=device)
+    train_gen_t, _, _ = create_data_generators(dann_config.DATASET,
+                                               dann_config.TARGET_DOMAIN,
+                                               batch_size=dann_config.BATCH_SIZE,
+                                               infinite_train=True,
+                                               image_size=dann_config.IMAGE_SIZE,
+                                               num_workers=dann_config.NUM_WORKERS,
+                                               device=device,
+                                               split_ratios=[1.0, 0., 0.])

     model = DANNModel().to(device)
     print(model)
     acc = AccuracyScoreFromLogits()

     scheduler = LRSchedulerSGD(blocks_with_smaller_lr=dann_config.BLOCKS_WITH_SMALLER_LR)
     tr = Trainer(model, loss_DANN)
     tr.fit(train_gen_s, train_gen_t,
            n_epochs=dann_config.N_EPOCHS,
-           validation_data=[val_gen_s, val_gen_t],
+           validation_data=[train_gen_s, train_gen_t],
            metrics=[acc],
            steps_per_epoch=dann_config.STEPS_PER_EPOCH,
            val_freq=dann_config.VAL_FREQ,
            opt='sgd',
            opt_kwargs={'lr': 0.01, 'momentum': 0.9},
            lr_scheduler=scheduler,
-           callbacks=[print_callback(watch=["loss", "domain_loss", "val_loss",
-                                            "val_domain_loss", 'trg_metrics', 'src_metrics']),
+           callbacks=[print_callback(watch=["loss", "domain_loss",# "val_loss", "val_domain_loss",
+                                            'trg_metrics', 'src_metrics']),
                       ModelSaver('DANN', dann_config.SAVE_MODEL_FREQ),
-                      WandbCallback(),
-                      HistorySaver('log_with_sgd', dann_config.VAL_FREQ, path='_log/DANN_Resnet_sgd',
-                                   extra_losses={'domain_loss': ['domain_loss', 'val_domain_loss'],
+                      #WandbCallback(),
+                      HistorySaver('log_with_sgd', dann_config.VAL_FREQ, path=dann_config.LOG_PATH,
+                                   extra_losses={'domain_loss': ['domain_loss'],#, 'val_domain_loss'],
                                                  'train_domain_loss': ['domain_loss_on_src', 'domain_loss_on_trg']})])
     wandb.join()
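Note: with split_ratios=[1.0, 0., 0.] the full source and target sets go into the training generators, and the same generators are then reused as validation_data, so the reported metrics are computed on training data rather than a held-out split. A sketch of the assumed split_ratios semantics (split_by_ratios is illustrative, not the repo's create_data_generators):

def split_by_ratios(items, ratios=(1.0, 0.0, 0.0)):
    # Assumed semantics: fractions of the dataset routed to (train, val, test);
    # [1.0, 0., 0.] sends everything to train.
    assert abs(sum(ratios) - 1.0) < 1e-6
    train_end = int(len(items) * ratios[0])
    val_end = train_end + int(len(items) * ratios[1])
    return items[:train_end], items[train_end:val_end], items[val_end:]

train, val, test = split_by_ratios(list(range(10)))  # -> [0..9], [], []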
8 changes: 7 additions & 1 deletion models/backbone_models.py
@@ -97,7 +97,13 @@ def get_resnet50():
         param.requires_grad = False

     pooling = model.avgpool
-    classifier = nn.Sequential(nn.Linear(2048, dann_config.CLASSES_CNT))
+    if dann_config.RESNET50_USE_DROPOUT_IN_CLASS_HEAD:
+        classifier = nn.Sequential(nn.Linear(2048, 1024),
+                                   nn.ReLU(),
+                                   nn.Dropout(0.5),
+                                   nn.Linear(1024, dann_config.CLASSES_CNT))
+    else:
+        classifier = nn.Sequential(nn.Linear(2048, dann_config.CLASSES_CNT))
     classifier_layer_ids = [0]
     pooling_ftrs = 2048
     pooling_output_side = 1
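Note: with RESNET50_USE_DROPOUT_IN_CLASS_HEAD enabled, the single-layer ResNet-50 head becomes a 2048 -> 1024 -> CLASSES_CNT MLP with dropout. A standalone shape check of the two heads this diff switches between (a sketch; values mirror the Office-31 config):

import torch
import torch.nn as nn

CLASSES_CNT = 31  # Office-31, per configs/dann_config.py

head_dropout = nn.Sequential(nn.Linear(2048, 1024),
                             nn.ReLU(),
                             nn.Dropout(0.5),
                             nn.Linear(1024, CLASSES_CNT))
head_plain = nn.Sequential(nn.Linear(2048, CLASSES_CNT))

x = torch.randn(4, 2048)      # stand-in for pooled ResNet-50 features
print(head_dropout(x).shape)  # torch.Size([4, 31])
print(head_plain(x).shape)    # torch.Size([4, 31])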
16 changes: 13 additions & 3 deletions models/models.py
@@ -19,15 +19,25 @@ def __init__(self):
         self.features, self.pooling, self.class_classifier, \
             domain_input_len, self.classifier_before_domain_cnt = backbone_models.get_backbone_model()

-        if dann_config.NEED_ADAPTATION_BLOCK:
+        if dann_config.ALEXNET_NEED_ADAPTATION_BLOCK:
             self.adaptation_block = nn.Sequential(
                 nn.ReLU(),
                 nn.Linear(domain_input_len, 2048),
                 nn.ReLU(inplace=True),
             )
             domain_input_len = 2048
             classifier_start_output_len = self.class_classifier[self.classifier_before_domain_cnt][-1].out_features
-            self.class_classifier[self.classifier_before_domain_cnt][-1] = nn.Linear(2048, classifier_start_output_len)
+            if dann_config.ALEXNET_USE_DROPOUT_IN_CLASS_HEAD_AFTER_ADAPTATION_BLOCK:
+                self.class_classifier[self.classifier_before_domain_cnt][-1] = nn.Sequential(
+                    nn.Linear(2048, 2048),
+                    nn.ReLU(),
+                    nn.Dropout(0.5),
+                    nn.Linear(2048, 1024),
+                    nn.ReLU(),
+                    nn.Dropout(0.5),
+                    nn.Linear(1024, classifier_start_output_len))
+            else:
+                self.class_classifier[self.classifier_before_domain_cnt][-1] = nn.Linear(2048, classifier_start_output_len)

         self.domain_classifier = domain_heads.get_domain_head(domain_input_len)

@@ -48,7 +58,7 @@ def forward(self, input_data, rev_grad_alpha=dann_config.GRADIENT_REVERSAL_LAYER_ALPHA):
                 output_classifier = self.class_classifier[i](output_classifier)
             classifier_layers_outputs.append(output_classifier)

-        if dann_config.NEED_ADAPTATION_BLOCK:
+        if dann_config.ALEXNET_NEED_ADAPTATION_BLOCK:
             output_classifier = self.adaptation_block(output_classifier)

         reversed_features = blocks.GradientReversalLayer.apply(output_classifier, rev_grad_alpha)
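Note: when the adaptation block is active, the classifier tail after it is now optionally a deeper dropout MLP (2048 -> 2048 -> 1024 -> out) instead of a single nn.Linear. Since nn.Dropout is inactive in eval() mode, inference-time predictions are unaffected. A minimal shape check of the new tail (sketch; the output width is an example value):

import torch
import torch.nn as nn

classifier_start_output_len = 31  # example; taken from the replaced head's out_features
tail = nn.Sequential(
    nn.Linear(2048, 2048), nn.ReLU(), nn.Dropout(0.5),
    nn.Linear(2048, 1024), nn.ReLU(), nn.Dropout(0.5),
    nn.Linear(1024, classifier_start_output_len))

tail.eval()  # dropout disabled, matching predict-time behavior
print(tail(torch.randn(2, 2048)).shape)  # torch.Size([2, 31])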
44 changes: 27 additions & 17 deletions trainer/trainer.py
@@ -49,12 +49,12 @@ def fit(self, src_data, trg_data, n_epochs=1000, steps_per_epoch=100, val_freq=1
         elif opt == 'sgd':
             parameters = self.model.parameters()
             if hasattr(self.model, "adaptation_block"):
-                parameters = [{ "params": self.model.features.parameters(), "lr": 0.1 * opt_kwargs["lr"] },
-                              { "params": self.model.class_classifier[:-1].parameters(), "lr": 0.1 * opt_kwargs["lr"] },
-                              { "params": self.model.class_classifier[-1].parameters() },
-                              { "params": self.model.domain_classifier.parameters() },
-                              { "params": self.model.adaptation_block.parameters() },
-                              ]
+                parameters = [{"params": self.model.features.parameters(), "lr": 0.1 * opt_kwargs["lr"]},
+                              {"params": self.model.class_classifier[:-1].parameters(), "lr": 0.1 * opt_kwargs["lr"]},
+                              {"params": self.model.class_classifier[-1].parameters()},
+                              {"params": self.model.domain_classifier.parameters()},
+                              {"params": self.model.adaptation_block.parameters()},
+                              ]
             opt = torch.optim.SGD(parameters, **opt_kwargs)
         else:
             raise NotImplementedError
@@ -63,31 +63,35 @@
             src_val_data, trg_val_data = validation_data

         for self.epoch in range(self.epoch, n_epochs):
+            print(f"Starting epoch {self.epoch}/{n_epochs}")
             self.loss_logger.reset_history()
+            print(f"Starting training")
             for step, (src_batch, trg_batch) in enumerate(zip(src_data, trg_data)):
                 if step == steps_per_epoch:
                     break
                 self.train_on_batch(src_batch, trg_batch, opt)

             # validation
             src_metrics = None
             trg_metrics = None
             if self.epoch % val_freq == 0 and validation_data is not None:
                 self.model.eval()

                 # calculating metrics on validation
+                print(f"Starting metrics calculation")
                 if metrics is not None:
                     if src_val_data is not None:
-                        src_metrics = self.score(src_val_data, metrics)
+                        src_metrics = self.score(src_val_data, metrics, len(src_val_data))
                     if trg_val_data is not None:
-                        trg_metrics = self.score(trg_val_data, metrics)
-
-                # calculating loss on validation
-                if src_val_data is not None and trg_val_data is not None:
-                    for val_step, (src_batch, trg_batch) in enumerate(zip(src_val_data, trg_val_data)):
-                        loss, loss_info = self.calc_loss(src_batch, trg_batch)
-                        self.loss_logger.store(prefix="val", loss=loss.data.cpu().item(), **loss_info)
+                        trg_metrics = self.score(trg_val_data, metrics, len(trg_val_data))
+
+                # calculating loss on validation
+                #commented - not working with training on ALL source and target data
+                #print(f"Starting loss on validation calculation")
+                #if src_val_data is not None and trg_val_data is not None:
+                #    for val_step, (src_batch, trg_batch) in enumerate(zip(src_val_data, trg_val_data)):
+                #        loss, loss_info = self.calc_loss(src_batch, trg_batch)
+                #        self.loss_logger.store(prefix="val", loss=loss.data.cpu().item(), **loss_info)

             if callbacks is not None:
                 epoch_log = dict(**self.loss_logger.get_info())
                 if src_metrics is not None:
@@ -100,15 +104,21 @@
             if lr_scheduler:
                 lr_scheduler.step(opt, self.epoch, n_epochs)

-    def score(self, data, metrics):
+    def score(self, data, metrics, steps_num):
         for metric in metrics:
             metric.reset()

+        # not to iterate infinitely
+        cur_step = 0
+
         data.reload_iterator()
         for images, true_classes in data:
             pred_classes = self.model.predict(images)
             for metric in metrics:
                 metric(true_classes, pred_classes)
+            cur_step += 1
+            if cur_step == steps_num:
+                break
         data.reload_iterator()
         return {metric.name: metric.score for metric in metrics}
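Note: score() now takes steps_num because the generators can be infinite (infinite_train=True), so iteration must be capped explicitly; the call sites pass len(src_val_data), which assumes the generators define __len__. The same guard in a compact form (illustrative only, not a suggested change):

import itertools

def take(iterable, steps_num):
    # Cap a possibly endless iterator at steps_num items.
    return itertools.islice(iterable, steps_num)

endless = itertools.cycle([("images", "labels")])  # stand-in for an infinite loader
print(sum(1 for _ in take(endless, 5)))  # 5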
11 changes: 8 additions & 3 deletions utils/callbacks.py
@@ -5,11 +5,15 @@

 def simple_callback(model, epoch_log, current_epoch, total_epoch):
     train_loss = epoch_log['loss']
-    val_loss = epoch_log['val_loss']
+    if 'val_loss' in epoch_log:
+        val_loss = epoch_log['val_loss']
     trg_metrics = epoch_log['trg_metrics']
     src_metrics = epoch_log['src_metrics']
     message_head = f'Epoch {current_epoch+1}/{total_epoch}\n'
-    message_loss = 'loss: {:<10}\t val_loss: {:<10}\t'.format(train_loss, val_loss)
+    if 'val_loss' in epoch_log:
+        message_loss = 'loss: {:<10}\t val_loss: {:<10}\t'.format(train_loss, val_loss)
+    else:
+        message_loss = 'loss: {:<10}\t'.format(train_loss)
     message_src_metrics = ' '.join(['val_src_{}: {:<10}\t'.format(k, v) for k, v in src_metrics.items()])
     message_trg_metrics = ' '.join(['val_trg_{}: {:<10}\t'.format(k, v) for k, v in trg_metrics.items()])
     print(message_head + message_loss + message_src_metrics + message_trg_metrics)
@@ -115,7 +119,8 @@ def _save_to_json(self, data, name=None):

     def __call__(self, model, epoch_log, current_epoch, total_epoch):
         if current_epoch % self.val_freq == 0:
-            self.loss_history['val_loss'].append(epoch_log['val_loss'])
+            if "val_loss" in epoch_log:
+                self.loss_history['val_loss'].append(epoch_log['val_loss'])

             for metric in epoch_log['trg_metrics']:
                 self.trg_metrics_history[metric].append(epoch_log['trg_metrics'][metric])
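Note: because validation loss is no longer computed, 'val_loss' can be absent from epoch_log, and both callbacks now guard the lookup. In simple_callback the two separate 'val_loss' checks could be folded into one incremental build (a sketch of an equivalent, not a required change):

def format_loss(epoch_log):
    # Single guard instead of two: append val_loss only when present.
    msg = 'loss: {:<10}\t'.format(epoch_log['loss'])
    if 'val_loss' in epoch_log:
        msg += 'val_loss: {:<10}\t'.format(epoch_log['val_loss'])
    return msg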