From 85dfd330e6c8d2c338577528fc572a90fc35d1fe Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 30 Oct 2024 23:36:12 +0800
Subject: [PATCH] fix bug mentioned in issue #10

---
 src/algorithms/MMFL.py    | 53 +++++++++++++++++++++++----------------
 src/datasets/dataset_L.py |  1 +
 2 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/src/algorithms/MMFL.py b/src/algorithms/MMFL.py
index 4c26c8c..3fd4f77 100644
--- a/src/algorithms/MMFL.py
+++ b/src/algorithms/MMFL.py
@@ -297,27 +297,38 @@ def distill(self, round_n, img_vec, txt_vec, img_num, txt_num, distill_index):
 
         def aggregation(i_vec=img_vec, t_vec=txt_vec, i_num=img_num, t_num=txt_num):
             if self.args.agg_method == "con_w":
-                contrastive_w = []
-                for vec in i_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature]
-                    logits = torch.matmul(vec, self.global_txt_feature.T) # [50000, 50000]
-                    exp_logits = torch.exp(logits)
-                    log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True))
-                    contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1))
-                contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0)
-                for i in range(len(i_vec)):
-                    i_vec[i] = (i_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0)
-                i_vec = torch.sum(torch.cat(i_vec, dim=0), dim=0) # aggregated image vectors
-
-                contrastive_w = []
-                for vec in t_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature]
-                    logits = torch.matmul(vec, self.global_img_feature.T) # [50000, 50000]
-                    exp_logits = torch.exp(logits)
-                    log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True))
-                    contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1))
-                contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0)
-                for i in range(len(t_vec)):
-                    t_vec[i] = (t_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0)
-                t_vec = torch.sum(torch.cat(t_vec, dim=0), dim=0) # aggregated text vectors
+                if i_vec:
+                    num_i_vec = len(i_vec)
+                    contrastive_w = torch.zeros(num_i_vec, 50000)
+                    for i_idx, vec in enumerate(i_vec): # vec: [50000, n_feature], global_txt_feature: [50000, n_feature]
+                        logits = torch.matmul(vec, self.global_txt_feature.T) # [50000, 50000]
+                        exp_logits = torch.exp(logits)
+                        log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True))
+                        contrastive_w[i_idx] = torch.diagonal(log_prob).reshape(-1)
+                        del logits, exp_logits, log_prob # release the [50000, 50000] intermediates
+                        torch.cuda.empty_cache()
+                        gc.collect()
+                    contrastive_w[:num_i_vec] = torch.softmax(contrastive_w[:num_i_vec], dim=0)
+                    for i in range(len(i_vec)):
+                        i_vec[i] = (i_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0)
+                    i_vec = torch.sum(torch.cat(i_vec, dim=0), dim=0) # aggregated image vectors
+
+
+                if t_vec:
+                    num_t_vec = len(t_vec)
+                    contrastive_w = torch.zeros(num_t_vec, 50000)
+                    for t_idx, vec in enumerate(t_vec): # vec: [50000, n_feature], global_img_feature: [50000, n_feature]
+                        logits = torch.matmul(vec, self.global_img_feature.T) # [50000, 50000]
+                        exp_logits = torch.exp(logits)
+                        log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True))
+                        contrastive_w[t_idx] = torch.diagonal(log_prob).reshape(-1)
+                        del logits, exp_logits, log_prob # release the [50000, 50000] intermediates
+                        torch.cuda.empty_cache()
+                        gc.collect()
+                    contrastive_w[:num_t_vec] = torch.softmax(contrastive_w[:num_t_vec], dim=0)
+                    for i in range(len(t_vec)):
+                        t_vec[i] = (t_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0)
+                    t_vec = torch.sum(torch.cat(t_vec, dim=0), dim=0) # aggregated text vectors
             else:
                 raise NotImplementedError
 

diff --git a/src/datasets/dataset_L.py b/src/datasets/dataset_L.py
index 6d3ad2e..28474d3 100644
--- a/src/datasets/dataset_L.py
+++ b/src/datasets/dataset_L.py
@@ -14,6 +14,7 @@
 from torch.utils.data import DataLoader
 from torchtext.data.utils import get_tokenizer
 from torchtext.vocab import build_vocab_from_iterator
+import torchtext.datasets
 from torchvision import transforms
 from tqdm import tqdm
 
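For reference, below is a minimal standalone sketch of the "con_w" contrastive-weight aggregation that the MMFL.py hunk above implements. All names here (aggregate_con_w, client_vecs, global_feature) are illustrative, not part of the repository; torch.logsumexp stands in for the patch's explicit exp/log (same value, numerically more stable), and the weight buffer takes its size and device from the inputs instead of the hardcoded 50000.

    import torch

    def aggregate_con_w(client_vecs, global_feature):
        # client_vecs: list of [N, d] per-client features; global_feature: [N, d]
        # feature of the opposite modality (e.g. the global text feature for i_vec).
        weights = torch.zeros(len(client_vecs), global_feature.size(0),
                              device=global_feature.device)
        for idx, vec in enumerate(client_vecs):
            logits = vec @ global_feature.T  # [N, N] similarity matrix
            # log-softmax over the batch; logsumexp == log(sum(exp(.))) but stable
            log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
            weights[idx] = torch.diagonal(log_prob)  # matching-pair log-probs, [N]
            del logits, log_prob  # drop the [N, N] buffers before the next client
        weights = torch.softmax(weights, dim=0)  # normalize across clients
        # per-sample weighted sum of the client vectors -> [N, d]
        return torch.stack([v * w.unsqueeze(1)
                            for v, w in zip(client_vecs, weights)]).sum(dim=0)

    # usage, mirroring the hunk (hypothetical call site):
    # img_agg = aggregate_con_w(img_vec, global_txt_feature)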