From 7ff377f92820748476e796994fd207e1b5dba1d9 Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Tue, 25 Jun 2024 11:19:56 -0700
Subject: [PATCH] LIT: Disable embeddings for TyDi.

PiperOrigin-RevId: 646544804
---
 lit_nlp/examples/tydi/model.py | 56 +++++++++++++++++++---------------
 1 file changed, 31 insertions(+), 25 deletions(-)

diff --git a/lit_nlp/examples/tydi/model.py b/lit_nlp/examples/tydi/model.py
index 7f35f760..55001377 100644
--- a/lit_nlp/examples/tydi/model.py
+++ b/lit_nlp/examples/tydi/model.py
@@ -76,23 +76,26 @@ def predict(self, inputs: Iterable[_JsonDict], **kw) -> Iterable[_JsonDict]:
       total_tokens = self.tokenizer.convert_ids_to_tokens(tokens[0])
       # split by question & context
       slicer_question, slicer_context = self._segment_slicers(total_tokens)
-      # get embeddings
-      embeddings = results.hidden_states[0][0]
-      # gradient
-      gradient = results.hidden_states[-1][0]
+
+      # TODO(b/349177755): Gradients and embeddings are not implemented
+      # correctly. Use lit_nlp/examples/prompt_debugging/transformers_lms.py
+      # code as a reference for how to implement these correctly.
+      # embeddings = results.hidden_states[0][0]
+      # gradient = results.hidden_states[-1][0]
 
       prediction_output.append({
           "generated_text": self.tokenizer.decode(predict_answer_tokens),
           "answers_text": inp["answers_text"],
-          # Embeddings come from the first token of the last layer.
-          "cls_emb": results.hidden_states[-1][:, 0][0],
           "tokens_question": total_tokens[slicer_question],
           "tokens_context": total_tokens[slicer_context],
-          "grad_class": None,
-          "tokens_embs_question": np.asarray(embeddings[slicer_question]),
-          "token_grad_context": np.asarray(embeddings[slicer_context]),
-          "tokens_grad_question": np.asarray(gradient[slicer_question]),
-          "tokens_embs_context": np.asarray(gradient[slicer_context])
+          # TODO(b/349177755): Re-enable these once the embeddings and gradients
+          # are implemented correctly.
+          # Embeddings come from the first token of the last layer.
+          # "cls_emb": results.hidden_states[-1][:, 0][0],
+          # "tokens_embs_question": np.asarray(embeddings[slicer_question]),
+          # "token_grad_context": np.asarray(embeddings[slicer_context]),
+          # "tokens_grad_question": np.asarray(gradient[slicer_question]),
+          # "tokens_embs_context": np.asarray(gradient[slicer_context]),
       })
 
     return prediction_output
@@ -110,20 +113,23 @@ def input_spec(self):
 
   def output_spec(self):
     return {
+        "answers_text": lit_types.MultiSegmentAnnotations(),
         "generated_text": lit_types.GeneratedText(parent="answers_text"),
-        "cls_emb": lit_types.Embeddings(),
-        "tokens_question": lit_types.Tokens(parent="question"),
-        "tokens_embs_question": lit_types.TokenEmbeddings(
-            align="tokens_question"
-        ),
-        "tokens_grad_question": lit_types.TokenGradients(
-            align="tokens_question", grad_for="tokens_embs_question"
-        ),
         "tokens_context": lit_types.Tokens(parent="question"),
-        "tokens_embs_context": lit_types.TokenEmbeddings(
-            align="tokens_context"
-        ),
-        "token_grad_context": lit_types.TokenGradients(
-            align="tokens_context", grad_for="tokens_embs_context"
-        ),
+        "tokens_question": lit_types.Tokens(parent="question"),
+        # TODO(b/349177755): Re-enable these once the embeddings and gradients
+        # are implemented correctly.
+        # "cls_emb": lit_types.Embeddings(),
+        # "tokens_embs_question": lit_types.TokenEmbeddings(
+        #     align="tokens_question"
+        # ),
+        # "tokens_grad_question": lit_types.TokenGradients(
+        #     align="tokens_question", grad_for="tokens_embs_question"
+        # ),
+        # "tokens_embs_context": lit_types.TokenEmbeddings(
+        #     align="tokens_context"
+        # ),
+        # "token_grad_context": lit_types.TokenGradients(
+        #     align="tokens_context", grad_for="tokens_embs_context"
+        # ),
     }