Skip to content

Commit

Permalink
fix: check that prediction_proba is not None before checking its type (
Browse files Browse the repository at this point in the history
…#61)

* fix: check that prediction_proba is not None before checking its type

* fix: using matched value to populate ValueError text

* fix: removed ModelType in ValueError text

---------

Co-authored-by: lorenzodagostinoradicalbit <lorenzo.dagostino@radicalbit.ai>
  • Loading branch information
lorenzodagostinoradicalbit and lorenzodagostinoradicalbit authored Jul 1, 2024
1 parent 53155be commit 6942317
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions api/app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,19 @@ def validate_target(self) -> Self:
case ModelType.BINARY:
if not is_number(self.target.type):
raise ValueError(
f'target must be a number for a ModelType.BINARY, has been provided [{self.target}]'
f'target must be a number for a {checked_model_type}, has been provided [{self.target}]'
)
return self
case ModelType.MULTI_CLASS:
if not is_number_or_string(self.target.type):
raise ValueError(
f'target must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.target}]'
f'target must be a number or string for a {checked_model_type}, has been provided [{self.target}]'
)
return self
case ModelType.REGRESSION:
if not is_number(self.target.type):
raise ValueError(
f'target must be a number for a ModelType.REGRESSION, has been provided [{self.target}]'
f'target must be a number for a {checked_model_type}, has been provided [{self.target}]'
)
return self
case _:
Expand All @@ -97,31 +97,37 @@ def validate_outputs(self) -> Self:
case ModelType.BINARY:
if not is_number(self.outputs.prediction.type):
raise ValueError(
f'prediction must be a number for a ModelType.BINARY, has been provided [{self.outputs.prediction}]'
f'prediction must be a number for a {checked_model_type}, has been provided [{self.outputs.prediction}]'
)
if not is_optional_float(self.outputs.prediction_proba.type):
if not is_none(self.outputs.prediction_proba) and not is_optional_float(
self.outputs.prediction_proba.type
):
raise ValueError(
f'prediction_proba must be an optional float for a ModelType.BINARY, has been provided [{self.outputs.prediction_proba}]'
f'prediction_proba must be an optional float for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]'
)
return self
case ModelType.MULTI_CLASS:
if not is_number_or_string(self.outputs.prediction.type):
raise ValueError(
f'prediction must be a number or string for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction}]'
f'prediction must be a number or string for a {checked_model_type}, has been provided [{self.outputs.prediction}]'
)
if not is_optional_float(self.outputs.prediction_proba.type):
if not is_none(self.outputs.prediction_proba) and not is_optional_float(
self.outputs.prediction_proba.type
):
raise ValueError(
f'prediction_proba must be an optional float for a ModelType.MULTI_CLASS, has been provided [{self.outputs.prediction_proba}]'
f'prediction_proba must be an optional float for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]'
)
return self
case ModelType.REGRESSION:
if not is_number(self.outputs.prediction.type):
raise ValueError(
f'prediction must be a number for a ModelType.REGRESSION, has been provided [{self.outputs.prediction}]'
f'prediction must be a number for a {checked_model_type}, has been provided [{self.outputs.prediction}]'
)
if not is_none(self.outputs.prediction_proba.type):
if not is_none(self.outputs.prediction_proba) and not is_none(
self.outputs.prediction_proba.type
):
raise ValueError(
f'prediction_proba must be None for a ModelType.REGRESSION, has been provided [{self.outputs.prediction_proba}]'
f'prediction_proba must be None for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]'
)
return self
case _:
Expand Down

0 comments on commit 6942317

Please sign in to comment.