
Commit 7ff377f
LIT: Disable embeddings for TyDi.
PiperOrigin-RevId: 646544804
RyanMullins authored and LIT team committed Jun 25, 2024
1 parent b14e3b1 commit 7ff377f
Showing 1 changed file with 31 additions and 25 deletions: lit_nlp/examples/tydi/model.py
@@ -76,23 +76,26 @@ def predict(self, inputs: Iterable[_JsonDict], **kw) -> Iterable[_JsonDict]:
total_tokens = self.tokenizer.convert_ids_to_tokens(tokens[0])
# split by question & context
slicer_question, slicer_context = self._segment_slicers(total_tokens)
# get embeddings
embeddings = results.hidden_states[0][0]
# gradient
gradient = results.hidden_states[-1][0]

# TODO(b/349177755): Gradients and embeddings are not implemented
# correctly. Use lit_nlp/examples/prompt_debugging/transformers_lms.py
# code as a reference for how to implement these correctly.
# embeddings = results.hidden_states[0][0]
# gradient = results.hidden_states[-1][0]

prediction_output.append({
"generated_text": self.tokenizer.decode(predict_answer_tokens),
"answers_text": inp["answers_text"],
# Embeddings come from the first token of the last layer.
"cls_emb": results.hidden_states[-1][:, 0][0],
"tokens_question": total_tokens[slicer_question],
"tokens_context": total_tokens[slicer_context],
"grad_class": None,
"tokens_embs_question": np.asarray(embeddings[slicer_question]),
"token_grad_context": np.asarray(embeddings[slicer_context]),
"tokens_grad_question": np.asarray(gradient[slicer_question]),
"tokens_embs_context": np.asarray(gradient[slicer_context])
# TODO(b/349177755): Re-enable these once the embeddings and gradients
# are implemented correctly.
# Embeddings come from the first token of the last layer.
# "cls_emb": results.hidden_states[-1][:, 0][0],
# "tokens_embs_question": np.asarray(embeddings[slicer_question]),
# "token_grad_context": np.asarray(embeddings[slicer_context]),
# "tokens_grad_question": np.asarray(gradient[slicer_question]),
# "tokens_embs_context": np.asarray(gradient[slicer_context]),
})

return prediction_output
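
For context on the TODO above: in Hugging Face Transformers, results.hidden_states holds layer activations, not gradients. hidden_states[0] is the embedding-layer output and hidden_states[-1] is the last layer's hidden states, so the removed code was labeling an activation tensor as a "gradient". Below is a minimal, illustrative sketch (not the LIT implementation, and not taken from transformers_lms.py) of one way to obtain true token-level embeddings and gradients with a PyTorch backend; the model name, classification head, and example text are placeholder assumptions:

# Illustrative sketch only; the model, head, and input text are placeholders.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

encoded = tokenizer("Where did TyDi QA collect its data?", return_tensors="pt")

# Look up the input embeddings explicitly and make them a leaf tensor so that
# autograd populates .grad for them after backward().
embs = model.get_input_embeddings()(encoded["input_ids"]).detach()
embs.requires_grad_(True)

out = model(inputs_embeds=embs, attention_mask=encoded["attention_mask"])
out.logits.max().backward()  # backprop from the top logit

token_embeddings = embs[0].detach().numpy()      # [num_tokens, hidden_dim]
token_gradients = embs.grad[0].detach().numpy()  # [num_tokens, hidden_dim]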
@@ -110,20 +113,23 @@ def input_spec(self):

def output_spec(self):
return {
"answers_text": lit_types.MultiSegmentAnnotations(),
"generated_text": lit_types.GeneratedText(parent="answers_text"),
"cls_emb": lit_types.Embeddings(),
"tokens_question": lit_types.Tokens(parent="question"),
"tokens_embs_question": lit_types.TokenEmbeddings(
align="tokens_question"
),
"tokens_grad_question": lit_types.TokenGradients(
align="tokens_question", grad_for="tokens_embs_question"
),
"tokens_context": lit_types.Tokens(parent="question"),
"tokens_embs_context": lit_types.TokenEmbeddings(
align="tokens_context"
),
"token_grad_context": lit_types.TokenGradients(
align="tokens_context", grad_for="tokens_embs_context"
),
"tokens_question": lit_types.Tokens(parent="question"),
# TODO(b/349177755): Re-enable these once the embeddings and gradients
# are implemented correctly.
# "cls_emb": lit_types.Embeddings(),
# "tokens_embs_question": lit_types.TokenEmbeddings(
# align="tokens_question"
# ),
# "tokens_grad_question": lit_types.TokenGradients(
# align="tokens_question", grad_for="tokens_embs_question"
# ),
# "tokens_embs_context": lit_types.TokenEmbeddings(
# align="tokens_context"
# ),
# "token_grad_context": lit_types.TokenGradients(
# align="tokens_context", grad_for="tokens_embs_context"
# ),
}

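With the embedding and gradient fields commented out, the output spec above still exposes generated text, answer annotations, and question/context tokens, so the model remains servable in a standard LIT demo. A hypothetical usage sketch follows; the TyDiModel class name, its constructor arguments, and the empty dataset dict are assumptions for illustration, not defined by this commit:

# Hypothetical usage sketch; class and constructor details are assumptions.
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.examples.tydi import model as tydi_model

models = {"tydi": tydi_model.TyDiModel()}  # constructor args omitted for brevity
datasets = {}  # a TyDi dataset wrapper would normally be registered here
lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
lit_demo.serve()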