Unable to retrieve hidden_states #52

Open

vsoesanto opened this issue Apr 6, 2022 · 2 comments

@vsoesanto
I converted a locally saved T5 checkpoint to ONNX using fastT5:

>>> from fastT5 import export_and_get_onnx_model
>>> from transformers import AutoTokenizer

>>> model_name = "path/to/checkpoint"
>>> model = export_and_get_onnx_model(model_name)

I tested it for inference:

>>> tokenizer = AutoTokenizer.from_pretrained(model_name)

>>> token = tokenizer(input_terms, max_length=512 * 2, padding=True, truncation=True, return_tensors='pt')

>>> out = model.generate(input_ids=token['input_ids'].to('cpu'),
                            attention_mask=token['attention_mask'].to('cpu'),
                            return_dict_in_generate=True,
                            max_length=512 * 2,
                            num_beams=1,
                            output_scores=True,
                            output_hidden_states=True)

>>> out.encoder_hidden_states
>>> out.decoder_hidden_states
(None,
 None,
 None,
 ...

>>> out
GreedySearchEncoderDecoderOutput(sequences=tensor([[  0, 119, 114, 102, 108, 111, 108, 125, 120, 112, 100, 101,  35,  53, ...
...
), encoder_attentions=None, encoder_hidden_states=None, decoder_attentions=None, cross_attentions=None, decoder_hidden_states=(None, None, None, ..., None))

The hidden states are all None.

Is there any way that I can retrieve the hidden states for both encoder and decoder?

@vsoesanto
Author

@Ki6an I also tried adding "output_hidden_states": True to the ONNX model's config, and made sure this argument is passed in the model.generate() call, but still no luck. Any idea how I can retrieve the encoder/decoder hidden states?

@Ki6an
Owner

Ki6an commented May 19, 2022

Sorry for the late reply.

You can easily get the hidden states of the encoder just by sending the input_ids and attention_mask to the encoder, as shown below:

...
model = export_and_get_onnx_model(model_name)
encoder = model.encoder

hidden_state = encoder(input_ids, attention_mask)
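
For context, a minimal end-to-end sketch of that encoder call (assuming the fastT5 encoder wrapper takes input_ids and attention_mask positionally and returns the encoder's last hidden state, as in the snippet above; the checkpoint path and input text are placeholders):

from fastT5 import export_and_get_onnx_model
from transformers import AutoTokenizer

model_name = "path/to/checkpoint"   # placeholder checkpoint path
model = export_and_get_onnx_model(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# tokenize one input and run only the ONNX encoder
token = tokenizer("some input text", return_tensors="pt")
encoder_hidden_state = model.encoder(token["input_ids"], token["attention_mask"])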

But for the decoder, you need to make a lot more changes. You can start by making changes here:

def forward(self, input_ids, attention_mask, encoder_hidden_states):
    decoder_output = self.decoder(
        input_ids=input_ids,
        encoder_attention_mask=attention_mask,
        encoder_hidden_states=encoder_hidden_states,
    )
    return (
        self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),
        decoder_output[1],
    )

decoder_output[0] is the last_hidden_state; make the forward return that value as well:

     return ( 
         decoder_output[0],
         self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)), 
         decoder_output[1], 
     ) 

Also, make the same changes in the other decoder wrapper.

Then, retrieve those values from the ORT session here and here.

Finally, pass those values here:

return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)

as decoder_hidden_states=.
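
A rough sketch of that last step (assuming the extra output pulled from the ONNX Runtime session is the decoder's last hidden state; the name decoder_last_hidden_state is illustrative, not fastT5's actual variable):

from transformers.modeling_outputs import Seq2SeqLMOutput

# Seq2SeqLMOutput.decoder_hidden_states expects a tuple of tensors,
# so wrap the single last_hidden_state retrieved from the ORT session
# in a one-element tuple.
return Seq2SeqLMOutput(
    logits=logits,
    past_key_values=past_key_values,
    decoder_hidden_states=(decoder_last_hidden_state,),
)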
