Skip to content

Commit

Permalink
Correct fix for dtype validation in DeepSpeedInferenceConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
adk9 committed May 28, 2024
1 parent 62ca5f2 commit 4cb7ac3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 26 deletions.
45 changes: 20 additions & 25 deletions deepspeed/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,17 @@


class DtypeEnum(Enum):
    """Enumeration of the data types DeepSpeed inference supports.

    Each member's value is a tuple whose FIRST entry is the canonical
    ``torch.dtype``; the remaining entries are accepted string aliases
    (e.g. "fp16", "torch.float16", "half").  Use ``member.value[0]`` to
    obtain the torch dtype, and :meth:`from_str` to resolve an alias.
    """

    fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
    fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")
    bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat")
    int8 = (torch.int8, "torch.int8", "int8")

    @classmethod
    def from_str(cls, value: str) -> "DtypeEnum":
        """Return the enum member whose alias tuple contains ``value``.

        Raises:
            ValueError: if ``value`` does not match any known alias.
        """
        for dtype in cls:
            if value in dtype.value:
                return dtype
        raise ValueError(f"'{value}' is not a valid DtypeEnum")


class MoETypeEnum(str, Enum):
Expand Down Expand Up @@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
`(attention_output projection, transformer output projection)`
"""

dtype: DtypeEnum = torch.float16
dtype: torch.dtype = torch.float16
"""
Desired model data type, will convert model to this type.
Supported target types: `torch.half`, `torch.int8`, `torch.float`
Expand Down Expand Up @@ -303,6 +290,14 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
"new_param": "moe.type"
})

@field_validator("dtype", mode="before")
def validate_dtype(cls, field_value, values):
    """Coerce the ``dtype`` field to a canonical ``torch.dtype``.

    Runs before pydantic's own validation (``mode="before"``).  Accepts:
    a string alias (e.g. "fp16", "torch.float16") resolved through
    ``DtypeEnum``; a ``DtypeEnum`` member, for backward compatibility
    with the previous ``dtype: DtypeEnum`` field type; or a
    ``torch.dtype``, passed through unchanged.

    Raises:
        TypeError: if ``field_value`` is none of the accepted types.
    """
    if isinstance(field_value, str):
        # First tuple entry of the enum value is the canonical torch dtype.
        return DtypeEnum.from_str(field_value).value[0]
    if isinstance(field_value, DtypeEnum):
        # Backward compat: callers that passed DtypeEnum members when the
        # field was typed as DtypeEnum must keep working.
        return field_value.value[0]
    if isinstance(field_value, torch.dtype):
        return field_value
    raise TypeError(f"Invalid type for dtype: {type(field_value)}")

@field_validator("moe")
def moe_backward_compat(cls, field_value, values):
if isinstance(field_value, bool):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/inference/test_inference_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_kwargs_and_config(self):
assert engine._config.dtype == kwargs["dtype"]

def test_json_config(self, tmpdir):
config = {"replace_with_kernel_inject": True, "dtype": torch.float32}
config = {"replace_with_kernel_inject": True, "dtype": "torch.float32"}
config_json = create_config_from_dict(tmpdir, config)

engine = deepspeed.init_inference(torch.nn.Module(), config=config_json)
Expand Down

0 comments on commit 4cb7ac3

Please sign in to comment.