Skip to content

Commit 4e74b8d

Browse files
author
Nadav Elyahu
committed
Add bfloat16 to the inference supported dtypes
to allow running inference tasks using bfloat16
1 parent 170b46e commit 4e74b8d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

deepspeed/inference/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def _validate_args(self, mpu, replace_with_kernel_inject):
324324
if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
325325
raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
326326

327-
supported_dtypes = [None, torch.half, torch.int8, torch.float]
327+
supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
328328
if self._config.dtype not in supported_dtypes:
329329
raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
330330

0 commit comments

Comments (0)