diff --git a/ultravox/inference/utils.py b/ultravox/inference/utils.py
index 0658dcda..5f04ff8c 100644
--- a/ultravox/inference/utils.py
+++ b/ultravox/inference/utils.py
@@ -10,7 +10,12 @@ def default_device():
 
 
 def default_dtype():
-    return torch.bfloat16 if torch.cuda.is_available() else torch.float32
+    # macOS Sonoma 14 enabled bfloat16 on MPS.
+    return (
+        torch.bfloat16
+        if torch.cuda.is_available() or torch.backends.mps.is_available()
+        else torch.float16
+    )
 
 
 def get_dtype(data_type: str):