From 502071ed7829a7a9e0ceaef535e31aede23d6a62 Mon Sep 17 00:00:00 2001
From: nguyen-brat
Date: Mon, 30 Oct 2023 01:25:56 +0700
Subject: [PATCH] remove the positive_classify_linear head and its logits

---
 model/claim_verification/joint_cross_encoder/model.py   | 4 +---
 model/claim_verification/joint_cross_encoder/trainer.py | 4 ++--
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/model/claim_verification/joint_cross_encoder/model.py b/model/claim_verification/joint_cross_encoder/model.py
index c2b0a24..1fcb84c 100644
--- a/model/claim_verification/joint_cross_encoder/model.py
+++ b/model/claim_verification/joint_cross_encoder/model.py
@@ -151,12 +151,10 @@ def __init__(self,config:JointCrossEncoderConfig,):
             out_features=config.nclass,
             head_num=self.feature_extractor.extractor_config.num_attention_heads,
         )
-        self.positive_classify_linear = nn.Linear(in_features=self.feature_extractor.extractor_config.hidden_size, out_features=1)
 
     def forward(self, fact, is_positive):
         fact_embed = self.feature_extractor(fact)
         fact_embed = torch.reshape(fact_embed, shape=[-1, self.config.nins] + list(fact_embed.shape[1:])) # batch_size, num_evident, dim
-        positive_logits = self.positive_classify_linear(fact_embed).squeeze() # batch_size, n_evidents
 
         multi_evident_output = fact_embed
         for evident_aggrerator in self.evident_aggrerators:
@@ -166,7 +164,7 @@ def forward(self, fact, is_positive):
 
         single_evident_output = fact_embed[torch.arange(fact_embed.shape[0]).tolist(), is_positive.tolist(), :] # is positive is a 1-d tensor of id of positive sample (real sample) in every batch sample
         single_evident_logits = self.single_evident_linear(single_evident_output) # batch_size, n_labels
-        return multi_evident_logits, single_evident_logits, positive_logits
+        return multi_evident_logits, single_evident_logits
 
     def predict(
         self,
diff --git a/model/claim_verification/joint_cross_encoder/trainer.py b/model/claim_verification/joint_cross_encoder/trainer.py
index 9a86971..33dc93b 100644
--- a/model/claim_verification/joint_cross_encoder/trainer.py
+++ b/model/claim_verification/joint_cross_encoder/trainer.py
@@ -56,7 +56,7 @@ def __init__(
             lora_alpha=32,
             lora_dropout=0.1,
             target_modules='feature_extractor.*.query_key_value|feature_extractor.*.dense|evident_aggrerators.*.out_proj',
-            modules_to_save=['aggerator', 'single_evident_linear', 'positive_classify_linear']
+            modules_to_save=['aggerator', 'single_evident_linear']
         )
         self.model = get_peft_model(self.model, peft_config)
         print('*********************')
@@ -187,7 +187,7 @@ def __call__(
         for fact_claims_ids, labels, is_positive, is_positive_ohot in tqdm(train_dataloader, desc="Iteration", smoothing=0.05, disable=not show_progress_bar):
             optimizer.zero_grad()
             with torch.cuda.amp.autocast(): ###########
-                multi_evident_logits, single_evident_logits, positive_logits = self.model(fact_claims_ids, is_positive)
+                multi_evident_logits, single_evident_logits = self.model(fact_claims_ids, is_positive)
                 multi_evident_loss_value = multi_loss_fct(multi_evident_logits, labels)
                 single_evident_loss_value = multi_loss_fct(single_evident_logits, labels)
                 # is_positive_loss_value = binary_loss_fct(positive_logits, is_positive_ohot)