diff --git a/requirements.txt b/requirements.txt index 72fbdeae..0583b15d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # Core -transformers[torch]>=4.39.3 +transformers[torch]>=4.40.2 bitsandbytes>=0.42.0 peft simple_parsing diff --git a/ultravox/training/train.py b/ultravox/training/train.py index a5e68e62..21991219 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -174,9 +174,7 @@ def main() -> None: logging.info( f"Using dtype and device (world_size): {dtype}, {device} ({world_size})" ) - model.to(device) - model.language_model.to(dtype) - model.multi_modal_projector.to(dtype) + model.to(device=device, dtype=dtype) # TODO: check if the whole model can now be moved to dtype instead # Prepare dataset, subsetting if needed