Skip to content

Commit

Permalink
Passing backend and modifying check to include dask series
Browse files Browse the repository at this point in the history
  • Loading branch information
arnavgarg1 committed Jul 22, 2022
1 parent 06594a5 commit cd9e82f
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
4 changes: 1 addition & 3 deletions ludwig/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,9 +959,7 @@ def evaluate(

if collect_predictions:
postproc_predictions = convert_predictions(
postproc_predictions,
self.model.output_features,
return_type=return_type,
postproc_predictions, self.model.output_features, return_type=return_type, backend=self.backend
)

for callback in self.callbacks:
Expand Down
4 changes: 2 additions & 2 deletions ludwig/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from ludwig.utils.dataframe_utils import is_dask_df
from ludwig.utils.dataframe_utils import is_dask_series_or_df
from ludwig.utils.types import DataFrame


Expand All @@ -20,7 +20,7 @@ def convert_to_dict(
subgroup = key[len(of_name) + 1 :]

values = predictions[key]
if is_dask_df(values, backend):
if is_dask_series_or_df(values, backend):
values = values.compute()
try:
values = np.stack(values.to_numpy())
Expand Down
4 changes: 2 additions & 2 deletions ludwig/utils/dataframe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ def is_dask_backend(backend: Optional["Backend"]) -> bool: # noqa: F821
return backend is not None and is_dask_lib(backend.df_engine.df_lib)


def is_dask_df(df: DataFrame, backend: Optional["Backend"]) -> bool: # noqa: F821
def is_dask_series_or_df(df: DataFrame, backend: Optional["Backend"]) -> bool: # noqa: F821
if is_dask_backend(backend):
import dask.dataframe as dd

return isinstance(df, dd.DataFrame)
return isinstance(df, dd.Series) or isinstance(df, dd.DataFrame)
return False


Expand Down

0 comments on commit cd9e82f

Please sign in to comment.