
Commit

Cast LLMEncoder output to torch.float32, freeze final layer at init. (#3900)
jeffkinnison authored Jan 19, 2024
1 parent 495a6bf commit 27c6079
Showing 2 changed files with 21 additions and 8 deletions.
17 changes: 9 additions & 8 deletions ludwig/encoders/text_encoders.py
@@ -2425,6 +2425,14 @@ def __init__(self, encoder_config: LLMEncoderConfig = None, **kwargs):

         clear_data_cache()

+        # Because we use the last hidden state as encoder output rather than the logits, the final module of the model
+        # has input pass through but no gradient update in the backward pass. This can lead to a DDP error. Freezing
+        # the module prevents this from happening. This is done at initialization to prevent "unused parameters" errors
+        # from happening when the encoder is used before `prepare_for_training` is called, for example during batch
+        # size tuning.
+        out_module = list(self.model.modules())[-1]
+        out_module.requires_grad_(requires_grad=False)
+
     @staticmethod
     def get_schema_cls() -> Type[BaseEncoderConfig]:
         return LLMEncoderConfig
@@ -2459,13 +2467,6 @@ def prepare_for_training(self):
         self.prepare_for_quantized_training()
         self.initialize_adapter()

-        # Because we use the last hidden state as encoder output rather than the logits, the final module of the model
-        # has input pass through but no gradient update in the backward pass. This can lead to a DDP error. Freezing
-        # the module prevents this from happening.
-        if not self.config.adapter:
-            out_module = list(self.model.modules())[-1]
-            out_module.requires_grad_(requires_grad=False)
-
     def prepare_for_quantized_training(self):
         from peft import prepare_model_for_kbit_training

@@ -2479,7 +2480,7 @@ def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None):
         # Get the hidden state of the last layer and return it as the text encoding
         model_outputs = self.model(input_ids=inputs, output_hidden_states=True).hidden_states[-1]

-        return {ENCODER_OUTPUT: model_outputs}
+        return {ENCODER_OUTPUT: model_outputs.type(torch.float32)}

     def _save_to_state_dict(self, destination: Dict, prefix: str, keep_vars: bool):
         # This is called by `torch.nn.Module.state_dict()` under the hood. `state_dict()` does additional work to
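The two changes above amount to: freeze whatever parameters sit in the model's final submodule, so DDP never encounters parameters that receive no gradient, and upcast the returned hidden states to torch.float32. A minimal, self-contained sketch of the same pattern, using a toy nn.Sequential as a stand-in for the real LLM (illustration only, not Ludwig code):

import torch
import torch.nn as nn

# Toy stand-in for self.model; downstream code only consumes the first layer's output.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

# Freeze the final submodule so its parameters never show up as "unused" under DDP.
out_module = list(model.modules())[-1]
out_module.requires_grad_(requires_grad=False)
assert all(not p.requires_grad for p in out_module.parameters())

# Upcast the encoder output, mirroring model_outputs.type(torch.float32) above.
hidden = torch.randn(2, 8, 16, dtype=torch.float16)  # e.g. hidden states from a half-precision model
encoded = hidden.type(torch.float32)
assert encoded.dtype == torch.float32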
12 changes: 12 additions & 0 deletions tests/ludwig/encoders/test_llm_encoders.py
@@ -88,6 +88,10 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config):
         assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
         assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size])

+        # The final layer must not be trainable because it is not used
+        last_module = list(encoder.model.modules())[-1]
+        assert all(not p.requires_grad for p in last_module.parameters())
+
         # Test that max sequence length falls back to the context length when too large
         context_len = get_context_len(model_config)
         cl_config = copy.deepcopy(encoder_config)
@@ -104,6 +108,10 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config):
         assert encoder.input_shape == torch.Size([context_len])
         assert encoder.output_shape == torch.Size([context_len, model_config.hidden_size])

+        # The final layer must not be trainable because it is not used
+        last_module = list(encoder.model.modules())[-1]
+        assert all(not p.requires_grad for p in last_module.parameters())
+
     @pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
     def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, model_config):
         from peft import PeftModel
@@ -120,6 +128,10 @@ def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, model_config):
         assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
         assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size])

+        # The final layer must not be trainable because it is not used
+        last_module = list(encoder.model.modules())[-1]
+        assert all(not p.requires_grad for p in last_module.parameters())
+
     @pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
     def test_prepare_for_training(self, encoder_config: LLMEncoderConfig, adapter: str):
         from peft import PeftModel
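To exercise just the updated checks without running the full suite, a pytest invocation along these lines should work (the -k expression is an assumption based on the test names shown above; test_init_with_adapter is matched by the same substring):

import pytest

# Run only the LLM encoder tests touched by this commit.
pytest.main(["-q", "tests/ludwig/encoders/test_llm_encoders.py", "-k", "test_init"])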
