diff --git a/src/tlo/lm.py b/src/tlo/lm.py index e099714850..a7538a1a7a 100644 --- a/src/tlo/lm.py +++ b/src/tlo/lm.py @@ -385,7 +385,7 @@ def predict( rng: Optional[np.random.RandomState] = None, squeeze_single_row_output=True, **kwargs - ) -> pd.Series: + ) -> Union[pd.Series, np.bool_]: """Evaluate linear model output for a given set of input data. :param df: The input ``DataFrame`` containing the input data to evaluate the @@ -396,7 +396,8 @@ def predict( output directly returned. :param squeeze_single_row_output: If ``rng`` argument is not ``None`` and this argument is set to ``True``, the output for a ``df`` input with a single-row - will be a scalar boolean value rather than a boolean ``Series``. + will be a scalar boolean value rather than a boolean ``Series``, if set to + ``False``, the output will always be a ``Series``. :param **kwargs: Values for any external variables included in model predictors. """