Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyen-brat committed Oct 29, 2023
1 parent d1b4a64 commit 502071e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
4 changes: 1 addition & 3 deletions model/claim_verification/joint_cross_encoder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,10 @@ def __init__(self,config:JointCrossEncoderConfig,):
out_features=config.nclass,
head_num=self.feature_extractor.extractor_config.num_attention_heads,
)
self.positive_classify_linear = nn.Linear(in_features=self.feature_extractor.extractor_config.hidden_size, out_features=1)

def forward(self, fact, is_positive):
fact_embed = self.feature_extractor(fact)
fact_embed = torch.reshape(fact_embed, shape=[-1, self.config.nins] + list(fact_embed.shape[1:])) # batch_size, num_evident, dim
positive_logits = self.positive_classify_linear(fact_embed).squeeze() # batch_size, n_evidents

multi_evident_output = fact_embed
for evident_aggrerator in self.evident_aggrerators:
Expand All @@ -166,7 +164,7 @@ def forward(self, fact, is_positive):
single_evident_output = fact_embed[torch.arange(fact_embed.shape[0]).tolist(), is_positive.tolist(), :] # is positive is a 1-d tensor of id of positive sample (real sample) in every batch sample
single_evident_logits = self.single_evident_linear(single_evident_output) # batch_size, n_labels

return multi_evident_logits, single_evident_logits, positive_logits
return multi_evident_logits, single_evident_logits

def predict(
self,
Expand Down
4 changes: 2 additions & 2 deletions model/claim_verification/joint_cross_encoder/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
lora_alpha=32,
lora_dropout=0.1,
target_modules='feature_extractor.*.query_key_value|feature_extractor.*.dense|evident_aggrerators.*.out_proj',
modules_to_save=['aggerator', 'single_evident_linear', 'positive_classify_linear']
modules_to_save=['aggerator', 'single_evident_linear']
)
self.model = get_peft_model(self.model, peft_config)
print('*********************')
Expand Down Expand Up @@ -187,7 +187,7 @@ def __call__(
for fact_claims_ids, labels, is_positive, is_positive_ohot in tqdm(train_dataloader, desc="Iteration", smoothing=0.05, disable=not show_progress_bar):
optimizer.zero_grad()
with torch.cuda.amp.autocast(): ###########
multi_evident_logits, single_evident_logits, positive_logits = self.model(fact_claims_ids, is_positive)
multi_evident_logits, single_evident_logits = self.model(fact_claims_ids, is_positive)
multi_evident_loss_value = multi_loss_fct(multi_evident_logits, labels)
single_evident_loss_value = multi_loss_fct(single_evident_logits, labels)
# is_positive_loss_value = binary_loss_fct(positive_logits, is_positive_ohot)
Expand Down

0 comments on commit 502071e

Please sign in to comment.