diff --git a/ludwig/utils/data_utils.py b/ludwig/utils/data_utils.py index 2600bc70739..ad6afb14a3a 100644 --- a/ludwig/utils/data_utils.py +++ b/ludwig/utils/data_utils.py @@ -31,6 +31,7 @@ from pandas.errors import ParserError from sklearn.model_selection import KFold +from ludwig.data.cache.types import CacheableDataset from ludwig.utils.fs_utils import download_h5, open_file, upload_h5 from ludwig.utils.misc_utils import get_from_registry @@ -85,6 +86,7 @@ SAS_FORMATS, SPSS_FORMATS, STATA_FORMATS, + DATAFRAME_FORMATS, ) ) @@ -682,7 +684,9 @@ def clear_data_cache(): def figure_data_format_dataset(dataset): - if isinstance(dataset, pd.DataFrame): + if isinstance(dataset, CacheableDataset): + return figure_data_format_dataset(dataset.unwrap()) + elif isinstance(dataset, pd.DataFrame): return pd.DataFrame elif dd and isinstance(dataset, dd.core.DataFrame): return dd.core.DataFrame diff --git a/tests/ludwig/utils/test_data_utils.py b/tests/ludwig/utils/test_data_utils.py index a0097cf2bd9..68ea8e2a932 100644 --- a/tests/ludwig/utils/test_data_utils.py +++ b/tests/ludwig/utils/test_data_utils.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import dask.dataframe as dd import pandas as pd -from ludwig.utils.data_utils import add_sequence_feature_column, get_abs_path +from ludwig.data.cache.types import CacheableDataframe +from ludwig.utils.data_utils import add_sequence_feature_column, figure_data_format_dataset, get_abs_path def test_add_sequence_feature_column(): @@ -57,3 +59,30 @@ def test_add_sequence_feature_column(): def test_get_abs_path(): assert get_abs_path("a", "b.jpg") == "a/b.jpg" assert get_abs_path(None, "b.jpg") == "b.jpg" + + +def test_figure_data_format_dataset(): + assert figure_data_format_dataset({"a": "b"}) == dict + assert figure_data_format_dataset(pd.DataFrame([1, 2, 3, 4, 5], columns=["x"])) == pd.DataFrame + assert ( + figure_data_format_dataset( + dd.from_pandas(pd.DataFrame([1, 2, 3, 4, 5], columns=["x"]), npartitions=1).reset_index() + ) + == dd.core.DataFrame + ) + assert ( + figure_data_format_dataset( + CacheableDataframe(df=pd.DataFrame([1, 2, 3, 4, 5], columns=["x"]), name="test", checksum="test123") + ) + == pd.DataFrame + ) + assert ( + figure_data_format_dataset( + CacheableDataframe( + df=dd.from_pandas(pd.DataFrame([1, 2, 3, 4, 5], columns=["x"]), npartitions=1).reset_index(), + name="test", + checksum="test123", + ) + ) + == dd.core.DataFrame + )