Skip to content

Commit

Permalink
return hidden states in output of model wrapper
Browse files (browse the repository at this point in the history)
  • Loading branch information
jeswan committed Jan 19, 2021
1 parent 816496b commit c897a62
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions jiant/proj/main/modeling/taskmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer):

def forward(self, batch, task, tokenizer, compute_loss: bool = False):
encoder_output = get_output_from_encoder_and_batch(
encoder=self.encoder, batch=batch, output_hidden_states=True)
encoder=self.encoder, batch=batch, output_hidden_states=True
)
# A tuple of layers of hidden states
hidden_states = take_one(encoder_output.other)
layer_hidden_states = hidden_states[self.layer]
Expand Down Expand Up @@ -355,7 +356,7 @@ def get_output_from_standard_transformer_models(
attention_mask=input_mask,
output_hidden_states=output_hidden_states,
)
return output.pooler_output, output.last_hidden_state, output
return output.pooler_output, output.last_hidden_state, output.hidden_states


def get_output_from_bart_models(encoder, input_ids, input_mask, output_hidden_states=False):
Expand All @@ -368,20 +369,18 @@ def get_output_from_bart_models(encoder, input_ids, input_mask, output_hidden_st
output = encoder(
input_ids=input_ids, attention_mask=input_mask, output_hidden_states=output_hidden_states,
)
dec_last = output.last_hidden_state
dec_all = output.decoder_hidden_states
enc_last = output.encoder_last_hidden_state
enc_all = output.encoder_hidden_states

unpooled = output

other = (enc_all + dec_all,)
hidden_states = (enc_all + dec_all,)

bsize, slen = input_ids.shape
batch_idx = torch.arange(bsize).to(input_ids.device)
# Get last non-pad index
pooled = unpooled[batch_idx, slen - input_ids.eq(encoder.config.pad_token_id).sum(1) - 1]
return pooled, unpooled, other
return pooled, unpooled, hidden_states


def get_output_from_electra(encoder, input_ids, segment_ids, input_mask):
Expand All @@ -391,9 +390,9 @@ def get_output_from_electra(encoder, input_ids, segment_ids, input_mask):
attention_mask=input_mask,
output_hidden_states=output_hidden_states,
)
unpooled = output.last_hidden_state
unpooled = output.hidden_states
pooled = unpooled[:, 0, :]
return pooled, unpooled, output
return pooled, unpooled, output.hidden_states


def compute_mlm_loss(logits, masked_lm_labels):
Expand Down

0 comments on commit c897a62

Please sign in to comment.