diff --git a/configs/dann_config.py b/configs/dann_config.py
index f1791a8..033215f 100644
--- a/configs/dann_config.py
+++ b/configs/dann_config.py
@@ -2,6 +2,7 @@
 LOSS_NEED_INTERMEDIATE_LAYERS = False
 UNK_VALUE = -100 # torch default
 IS_UNSUPERVISED = True
+LOG_PATH = "_log/exp4"
 GRADIENT_REVERSAL_LAYER_ALPHA = 1.0
 FREZE_BACKBONE_FEATURES = True

@@ -18,14 +19,16 @@
 CLASSES_CNT = 31
 MODEL_BACKBONE = "alexnet" # alexnet resnet50 vanilla_dann
-DOMAIN_HEAD = "vanilla_dann"
+DOMAIN_HEAD = "vanilla_dann" # "vanilla_dann", "dropout_dann", "mnist_dann"
 BACKBONE_PRETRAINED = True
-NEED_ADAPTATION_BLOCK = True # ="True" only for alexnet, ="False" for other types
+ALEXNET_NEED_ADAPTATION_BLOCK = True # ="True" only for alexnet, ="False" for other types
+ALEXNET_USE_DROPOUT_IN_CLASS_HEAD_AFTER_ADAPTATION_BLOCK = True # used only if ALEXNET_NEED_ADAPTATION_BLOCK == True
 BLOCKS_WITH_SMALLER_LR = 2 # ="2" only for alexnet, ="0" for other types
 IMAGE_SIZE = 224
 DATASET = "office-31"
 SOURCE_DOMAIN = "amazon"
 TARGET_DOMAIN = "webcam"
+RESNET50_USE_DROPOUT_IN_CLASS_HEAD = True

 # CLASSES_CNT = 10
 # MODEL_BACKBONE = "mnist_dann"
diff --git a/example.py b/example.py
index 62dab53..fbdadf1 100644
--- a/example.py
+++ b/example.py
@@ -16,41 +16,44 @@ if __name__ == '__main__':
-    train_gen_s, val_gen_s, test_gen_s = create_data_generators(dann_config.DATASET,
-                                                                dann_config.SOURCE_DOMAIN,
-                                                                batch_size=dann_config.BATCH_SIZE,
-                                                                infinite_train=True,
-                                                                image_size=dann_config.IMAGE_SIZE,
-                                                                num_workers=dann_config.NUM_WORKERS,
-                                                                device=device)
+    train_gen_s, _, _ = create_data_generators(dann_config.DATASET,
+                                               dann_config.SOURCE_DOMAIN,
+                                               batch_size=dann_config.BATCH_SIZE,
+                                               infinite_train=True,
+                                               image_size=dann_config.IMAGE_SIZE,
+                                               num_workers=dann_config.NUM_WORKERS,
+                                               device=device,
+                                               split_ratios=[1.0, 0., 0.])

-    train_gen_t, val_gen_t, test_gen_t = create_data_generators(dann_config.DATASET,
-                                                                dann_config.TARGET_DOMAIN,
-                                                                batch_size=dann_config.BATCH_SIZE,
-                                                                infinite_train=True,
-                                                                image_size=dann_config.IMAGE_SIZE,
-                                                                num_workers=dann_config.NUM_WORKERS,
-                                                                device=device)
+    train_gen_t, _, _ = create_data_generators(dann_config.DATASET,
+                                               dann_config.TARGET_DOMAIN,
+                                               batch_size=dann_config.BATCH_SIZE,
+                                               infinite_train=True,
+                                               image_size=dann_config.IMAGE_SIZE,
+                                               num_workers=dann_config.NUM_WORKERS,
+                                               device=device,
+                                               split_ratios=[1.0, 0., 0.])

     model = DANNModel().to(device)
+    print(model)
     acc = AccuracyScoreFromLogits()
     scheduler = LRSchedulerSGD(blocks_with_smaller_lr=dann_config.BLOCKS_WITH_SMALLER_LR)

     tr = Trainer(model, loss_DANN)
     tr.fit(train_gen_s,
            train_gen_t,
            n_epochs=dann_config.N_EPOCHS,
-           validation_data=[val_gen_s, val_gen_t],
+           validation_data=[train_gen_s, train_gen_t],
            metrics=[acc],
            steps_per_epoch=dann_config.STEPS_PER_EPOCH,
            val_freq=dann_config.VAL_FREQ,
            opt='sgd',
            opt_kwargs={'lr': 0.01, 'momentum': 0.9},
            lr_scheduler=scheduler,
-           callbacks=[print_callback(watch=["loss", "domain_loss", "val_loss",
-                                            "val_domain_loss", 'trg_metrics', 'src_metrics']),
+           callbacks=[print_callback(watch=["loss", "domain_loss",  # "val_loss", "val_domain_loss",
+                                            'trg_metrics', 'src_metrics']),
                       ModelSaver('DANN', dann_config.SAVE_MODEL_FREQ),
-                      WandbCallback(),
-                      HistorySaver('log_with_sgd', dann_config.VAL_FREQ, path='_log/DANN_Resnet_sgd',
-                                   extra_losses={'domain_loss': ['domain_loss', 'val_domain_loss'],
+                      # WandbCallback(),
+                      HistorySaver('log_with_sgd', dann_config.VAL_FREQ, path=dann_config.LOG_PATH,
+                                   extra_losses={'domain_loss': ['domain_loss'],  # , 'val_domain_loss'
                                                  'train_domain_loss': ['domain_loss_on_src',
                                                                        'domain_loss_on_trg']})])

     wandb.join()
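Note on the split_ratios=[1.0, 0., 0.] argument introduced above: the intent is to route every image of a domain into the train generator and leave the val/test generators empty, which is why validation_data now reuses the train generators. A minimal sketch of what such ratio-based splitting typically looks like (the helper split_indices is hypothetical, for illustration only; the internals of create_data_generators may differ):

import random

def split_indices(n_items, split_ratios, seed=0):
    # Partition dataset indices into train/val/test according to the ratios.
    # With split_ratios=[1.0, 0., 0.] every index lands in the train split.
    assert abs(sum(split_ratios) - 1.0) < 1e-6, "ratios must sum to 1"
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    train_end = int(round(split_ratios[0] * n_items))
    val_end = train_end + int(round(split_ratios[1] * n_items))
    return indices[:train_end], indices[train_end:val_end], indices[val_end:]

# split_indices(5, [1.0, 0., 0.]) puts all five indices in the train split
# and returns empty lists for val and test.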
diff --git a/models/backbone_models.py b/models/backbone_models.py
index a718b6e..ba585bf 100644
--- a/models/backbone_models.py
+++ b/models/backbone_models.py
@@ -97,7 +97,13 @@ def get_resnet50():
         param.requires_grad = False

     pooling = model.avgpool
-    classifier = nn.Sequential(nn.Linear(2048, dann_config.CLASSES_CNT))
+    if dann_config.RESNET50_USE_DROPOUT_IN_CLASS_HEAD:
+        classifier = nn.Sequential(nn.Linear(2048, 1024),
+                                   nn.ReLU(),
+                                   nn.Dropout(0.5),
+                                   nn.Linear(1024, dann_config.CLASSES_CNT))
+    else:
+        classifier = nn.Sequential(nn.Linear(2048, dann_config.CLASSES_CNT))
     classifier_layer_ids = [0]
     pooling_ftrs = 2048
     pooling_output_side = 1
diff --git a/models/models.py b/models/models.py
index 05437c5..d3a55c5 100644
--- a/models/models.py
+++ b/models/models.py
@@ -19,7 +19,7 @@ def __init__(self):
         self.features, self.pooling, self.class_classifier, \
             domain_input_len, self.classifier_before_domain_cnt = backbone_models.get_backbone_model()

-        if dann_config.NEED_ADAPTATION_BLOCK:
+        if dann_config.ALEXNET_NEED_ADAPTATION_BLOCK:
             self.adaptation_block = nn.Sequential(
                 nn.ReLU(),
                 nn.Linear(domain_input_len, 2048),
@@ -27,7 +27,17 @@ def __init__(self):
             )
             domain_input_len = 2048
             classifier_start_output_len = self.class_classifier[self.classifier_before_domain_cnt][-1].out_features
-            self.class_classifier[self.classifier_before_domain_cnt][-1] = nn.Linear(2048, classifier_start_output_len)
+            if dann_config.ALEXNET_USE_DROPOUT_IN_CLASS_HEAD_AFTER_ADAPTATION_BLOCK:
+                self.class_classifier[self.classifier_before_domain_cnt][-1] = nn.Sequential(
+                    nn.Linear(2048, 2048),
+                    nn.ReLU(),
+                    nn.Dropout(0.5),
+                    nn.Linear(2048, 1024),
+                    nn.ReLU(),
+                    nn.Dropout(0.5),
+                    nn.Linear(1024, classifier_start_output_len))
+            else:
+                self.class_classifier[self.classifier_before_domain_cnt][-1] = nn.Linear(2048, classifier_start_output_len)

         self.domain_classifier = domain_heads.get_domain_head(domain_input_len)

@@ -48,7 +58,7 @@ def forward(self, input_data, rev_grad_alpha=dann_config.GRADIENT_REVERSAL_LAYER
             output_classifier = self.class_classifier[i](output_classifier)
             classifier_layers_outputs.append(output_classifier)

-        if dann_config.NEED_ADAPTATION_BLOCK:
+        if dann_config.ALEXNET_NEED_ADAPTATION_BLOCK:
             output_classifier = self.adaptation_block(output_classifier)

         reversed_features = blocks.GradientReversalLayer.apply(output_classifier, rev_grad_alpha)
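For readers following the forward hunk: blocks.GradientReversalLayer.apply(output_classifier, rev_grad_alpha) is the standard DANN trick of reversing gradients between the feature extractor and the domain head. A minimal sketch of the usual implementation (the repository's blocks module is not shown in this diff, so its details may differ):

import torch

class GradientReversalLayer(torch.autograd.Function):
    # Identity in the forward pass; scales the gradient by -alpha in the
    # backward pass, so features are trained to fool the domain classifier.

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient flows to alpha; the feature gradient is reversed.
        return -ctx.alpha * grad_output, None

# usage: reversed_features = GradientReversalLayer.apply(features, 1.0)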
diff --git a/trainer/trainer.py b/trainer/trainer.py
index be332e0..149bef8 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -49,12 +49,12 @@ def fit(self, src_data, trg_data, n_epochs=1000, steps_per_epoch=100, val_freq=1
         elif opt == 'sgd':
             parameters = self.model.parameters()
             if hasattr(self.model, "adaptation_block"):
-                parameters = [{ "params": self.model.features.parameters(), "lr": 0.1 * opt_kwargs["lr"] },
-                              { "params": self.model.class_classifier[:-1].parameters(), "lr": 0.1 * opt_kwargs["lr"] },
-                              { "params": self.model.class_classifier[-1].parameters() },
-                              { "params": self.model.domain_classifier.parameters() },
-                              { "params": self.model.adaptation_block.parameters() },
-                              ]
+                parameters = [{"params": self.model.features.parameters(), "lr": 0.1 * opt_kwargs["lr"]},
+                              {"params": self.model.class_classifier[:-1].parameters(), "lr": 0.1 * opt_kwargs["lr"]},
+                              {"params": self.model.class_classifier[-1].parameters()},
+                              {"params": self.model.domain_classifier.parameters()},
+                              {"params": self.model.adaptation_block.parameters()},
+                              ]
             opt = torch.optim.SGD(parameters, **opt_kwargs)
         else:
             raise NotImplementedError
@@ -63,12 +63,13 @@ def fit(self, src_data, trg_data, n_epochs=1000, steps_per_epoch=100, val_freq=1
             src_val_data, trg_val_data = validation_data

         for self.epoch in range(self.epoch, n_epochs):
+            print(f"Starting epoch {self.epoch}/{n_epochs}")
             self.loss_logger.reset_history()
+            print("Starting training")
             for step, (src_batch, trg_batch) in enumerate(zip(src_data, trg_data)):
                 if step == steps_per_epoch:
                     break
                 self.train_on_batch(src_batch, trg_batch, opt)
-
             # validation
             src_metrics = None
             trg_metrics = None
@@ -76,18 +77,21 @@ def fit(self, src_data, trg_data, n_epochs=1000, steps_per_epoch=100, val_freq=1
                 self.model.eval()

                 # calculating metrics on validation
+                print("Starting metrics calculation")
                 if metrics is not None:
                     if src_val_data is not None:
-                        src_metrics = self.score(src_val_data, metrics)
+                        src_metrics = self.score(src_val_data, metrics, len(src_val_data))
                     if trg_val_data is not None:
-                        trg_metrics = self.score(trg_val_data, metrics)
-
-                # calculating loss on validation
-                if src_val_data is not None and trg_val_data is not None:
-                    for val_step, (src_batch, trg_batch) in enumerate(zip(src_val_data, trg_val_data)):
-                        loss, loss_info = self.calc_loss(src_batch, trg_batch)
-                        self.loss_logger.store(prefix="val", loss=loss.data.cpu().item(), **loss_info)
+                        trg_metrics = self.score(trg_val_data, metrics, len(trg_val_data))
+                # calculating loss on validation
+                # commented out - does not work when training on ALL source and target data
+                # print("Starting loss on validation calculation")
+                # if src_val_data is not None and trg_val_data is not None:
+                #     for val_step, (src_batch, trg_batch) in enumerate(zip(src_val_data, trg_val_data)):
+                #         loss, loss_info = self.calc_loss(src_batch, trg_batch)
+                #         self.loss_logger.store(prefix="val", loss=loss.data.cpu().item(), **loss_info)
+
             if callbacks is not None:
                 epoch_log = dict(**self.loss_logger.get_info())
                 if src_metrics is not None:
@@ -100,15 +104,21 @@ def fit(self, src_data, trg_data, n_epochs=1000, steps_per_epoch=100, val_freq=1
             if lr_scheduler:
                 lr_scheduler.step(opt, self.epoch, n_epochs)

-    def score(self, data, metrics):
+    def score(self, data, metrics, steps_num):
         for metric in metrics:
             metric.reset()
-
+
+        # avoid iterating infinitely over an endless generator
+        cur_step = 0
+        data.reload_iterator()
         for images, true_classes in data:
             pred_classes = self.model.predict(images)
             for metric in metrics:
                 metric(true_classes, pred_classes)
+            cur_step += 1
+            if cur_step == steps_num:
+                break
         data.reload_iterator()
         return {metric.name: metric.score for metric in metrics}
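The parameter-group list in the 'sgd' branch gives the pretrained feature blocks a 10x smaller learning rate than the freshly initialized heads. A self-contained illustration of the same PyTorch pattern (the two-layer model here is a toy stand-in, not the repository's):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))  # toy "backbone" + "head"

base_lr = 0.01
opt = torch.optim.SGD(
    [{"params": model[0].parameters(), "lr": 0.1 * base_lr},  # pretrained part
     {"params": model[1].parameters()}],                      # new head: default lr
    lr=base_lr, momentum=0.9)

print([group["lr"] for group in opt.param_groups])  # [0.001, 0.01]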
diff --git a/utils/callbacks.py b/utils/callbacks.py
index 4291ccb..56afe6c 100644
--- a/utils/callbacks.py
+++ b/utils/callbacks.py
@@ -5,11 +5,15 @@
 def simple_callback(model, epoch_log, current_epoch, total_epoch):
     train_loss = epoch_log['loss']
-    val_loss = epoch_log['val_loss']
+    if 'val_loss' in epoch_log:
+        val_loss = epoch_log['val_loss']
     trg_metrics = epoch_log['trg_metrics']
     src_metrics = epoch_log['src_metrics']

     message_head = f'Epoch {current_epoch+1}/{total_epoch}\n'
-    message_loss = 'loss: {:<10}\t val_loss: {:<10}\t'.format(train_loss, val_loss)
+    if 'val_loss' in epoch_log:
+        message_loss = 'loss: {:<10}\t val_loss: {:<10}\t'.format(train_loss, val_loss)
+    else:
+        message_loss = 'loss: {:<10}\t'.format(train_loss)
     message_src_metrics = ' '.join(['val_src_{}: {:<10}\t'.format(k, v) for k, v in src_metrics.items()])
     message_trg_metrics = ' '.join(['val_trg_{}: {:<10}\t'.format(k, v) for k, v in trg_metrics.items()])
     print(message_head + message_loss + message_src_metrics + message_trg_metrics)
@@ -115,7 +119,8 @@ def _save_to_json(self, data, name=None):

     def __call__(self, model, epoch_log, current_epoch, total_epoch):
         if current_epoch % self.val_freq == 0:
-            self.loss_history['val_loss'].append(epoch_log['val_loss'])
+            if "val_loss" in epoch_log:
+                self.loss_history['val_loss'].append(epoch_log['val_loss'])
             for metric in epoch_log['trg_metrics']:
                 self.trg_metrics_history[metric].append(epoch_log['trg_metrics'][metric])
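Because validation loss is no longer computed when training on all source and target data, epoch_log may now arrive without a 'val_loss' key; the guards above handle exactly that. A quick demonstration of the guarded formatting path (the epoch_log dict is made up for illustration):

epoch_log = {'loss': 0.42}  # hypothetical log for one epoch; no 'val_loss' entry

train_loss = epoch_log['loss']
if 'val_loss' in epoch_log:
    message_loss = 'loss: {:<10}\t val_loss: {:<10}\t'.format(train_loss, epoch_log['val_loss'])
else:
    message_loss = 'loss: {:<10}\t'.format(train_loss)

print(message_loss)  # prints "loss: 0.42      " with no val_loss column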