From 8caf81bc40032a8e884ec7fe42ea68a8b806f06f Mon Sep 17 00:00:00 2001
From: Justin
Date: Tue, 20 Dec 2022 11:15:33 -0600
Subject: [PATCH] Explicitly set max sequence length for the roberta encoder,
 fix output shape computation, and add unit test. (#2861)

---
 ludwig/encoders/text_encoders.py             | 3 ++-
 tests/ludwig/encoders/test_text_encoders.py  | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/ludwig/encoders/text_encoders.py b/ludwig/encoders/text_encoders.py
index dee8588be7f..274c6a1455f 100644
--- a/ludwig/encoders/text_encoders.py
+++ b/ludwig/encoders/text_encoders.py
@@ -899,6 +899,7 @@ def __init__(
         )
         transformer = RobertaModel(config)
         self.transformer = FreezeModule(transformer, frozen=not trainable)
+        self.max_sequence_length = max_sequence_length
         self.reduce_output = reduce_output
         if not self.reduce_output == "cls_pooled":
             self.reduce_sequence = SequenceReducer(reduce_mode=reduce_output)
@@ -930,7 +931,7 @@ def input_shape(self) -> torch.Size:
     @property
     def output_shape(self) -> torch.Size:
         if self.reduce_output is None:
-            return torch.Size([self.max_sequence_length, self.transformer.module.config.hidden_size])
+            return torch.Size([self.max_sequence_length - 2, self.transformer.module.config.hidden_size])
         return torch.Size([self.transformer.module.config.hidden_size])
 
     @property
diff --git a/tests/ludwig/encoders/test_text_encoders.py b/tests/ludwig/encoders/test_text_encoders.py
index 650b321b6c3..790ae1a920d 100644
--- a/tests/ludwig/encoders/test_text_encoders.py
+++ b/tests/ludwig/encoders/test_text_encoders.py
@@ -65,7 +65,7 @@ def test_gpt_encoder(use_pretrained: bool, reduce_output: str, max_sequence_leng
 
 
 @pytest.mark.parametrize("use_pretrained", [False])
-@pytest.mark.parametrize("reduce_output", ["cls_pooled", "sum"])
+@pytest.mark.parametrize("reduce_output", ["cls_pooled", "sum", None])
 @pytest.mark.parametrize("max_sequence_length", [20])
 def test_roberta_encoder(use_pretrained: bool, reduce_output: str, max_sequence_length: int):
     roberta_encoder = text_encoders.RoBERTaEncoder(
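
Note on the fix: the first hunk stores max_sequence_length on the module so the
output_shape property can reference it, and the second hunk subtracts 2 because,
with reduce_output=None, the encoder returns the full hidden-state sequence after
slicing off the leading <s> and trailing </s> special tokens. Below is a minimal
sketch of how the new None parametrization exercises the corrected shape,
assuming the test body follows the same pattern as the other encoder tests in
this file; the torch.rand input cast and the "encoder_output" key are assumptions
carried over from those tests, not shown in the truncated hunk above.

    import torch

    from ludwig.encoders import text_encoders

    max_sequence_length = 20
    encoder = text_encoders.RoBERTaEncoder(
        use_pretrained=False,       # randomly initialized RobertaModel from config
        reduce_output=None,         # return the full token sequence, unreduced
        max_sequence_length=max_sequence_length,
    )

    # The fixed property accounts for the two special tokens stripped in forward().
    assert encoder.output_shape[0] == max_sequence_length - 2

    # Token ids in the encoder's expected integer dtype, batch size 2 (assumed
    # test pattern: rand floats cast to ints, as in the other encoder tests).
    inputs = torch.rand((2, max_sequence_length)).type(encoder.input_dtype)
    outputs = encoder(inputs)["encoder_output"]

    # Before the fix this failed: the encoder emitted [2, 18, hidden] while
    # output_shape declared [20, hidden].
    assert outputs.shape[1:] == encoder.output_shape

Storing max_sequence_length explicitly also fixes the reduce_output=None path
outright: output_shape referenced self.max_sequence_length, which was never set
on this encoder before the first hunk.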