From be42ff0fee5f5da5909e2ffba62e4ddf08a13905 Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Mon, 18 Jan 2021 21:31:43 -0800 Subject: [PATCH] return hidden states in output of model wrapper --- jiant/proj/main/modeling/taskmodels.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py index 68ad48b42..2354d7db7 100644 --- a/jiant/proj/main/modeling/taskmodels.py +++ b/jiant/proj/main/modeling/taskmodels.py @@ -234,7 +234,8 @@ def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer): def forward(self, batch, task, tokenizer, compute_loss: bool = False): encoder_output = get_output_from_encoder_and_batch( - encoder=self.encoder, batch=batch, output_hidden_states=True) + encoder=self.encoder, batch=batch, output_hidden_states=True + ) # A tuple of layers of hidden states hidden_states = take_one(encoder_output.other) layer_hidden_states = hidden_states[self.layer] @@ -355,7 +356,7 @@ def get_output_from_standard_transformer_models( attention_mask=input_mask, output_hidden_states=output_hidden_states, ) - return output.pooler_output, output.last_hidden_state, output + return output.pooler_output, output.last_hidden_state, output.hidden_states def get_output_from_bart_models(encoder, input_ids, input_mask, output_hidden_states=False): @@ -375,17 +376,17 @@ def get_output_from_bart_models(encoder, input_ids, input_mask, output_hidden_st unpooled = output - other = (enc_all + dec_all,) + hidden_states = (enc_all + dec_all,) bsize, slen = input_ids.shape batch_idx = torch.arange(bsize).to(input_ids.device) # Get last non-pad index pooled = unpooled[batch_idx, slen - input_ids.eq(encoder.config.pad_token_id).sum(1) - 1] - return pooled, unpooled, other + return pooled, unpooled, hidden_states def get_output_from_electra(encoder, input_ids, segment_ids, input_mask): - output = encoder( + hidden_states = encoder( input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, @@ -393,7 +394,7 @@ def get_output_from_electra(encoder, input_ids, segment_ids, input_mask): ) unpooled = output.last_hidden_state pooled = unpooled[:, 0, :] - return pooled, unpooled, output + return pooled, unpooled, hidden_states def compute_mlm_loss(logits, masked_lm_labels):