Skip to content

Commit

Permalink
Modified figure_data_format_dataset to work with recently introduced …
Browse files Browse the repository at this point in the history
…CacheableDataFrame (#1880)
  • Loading branch information
Animesh Kumar authored Apr 6, 2022
1 parent 6e3bc1a commit 3da5cb6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
6 changes: 5 additions & 1 deletion ludwig/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pandas.errors import ParserError
from sklearn.model_selection import KFold

from ludwig.data.cache.types import CacheableDataset
from ludwig.utils.fs_utils import download_h5, open_file, upload_h5
from ludwig.utils.misc_utils import get_from_registry

Expand Down Expand Up @@ -85,6 +86,7 @@
SAS_FORMATS,
SPSS_FORMATS,
STATA_FORMATS,
DATAFRAME_FORMATS,
)
)

Expand Down Expand Up @@ -682,7 +684,9 @@ def clear_data_cache():


def figure_data_format_dataset(dataset):
if isinstance(dataset, pd.DataFrame):
if isinstance(dataset, CacheableDataset):
return figure_data_format_dataset(dataset.unwrap())
elif isinstance(dataset, pd.DataFrame):
return pd.DataFrame
elif dd and isinstance(dataset, dd.core.DataFrame):
return dd.core.DataFrame
Expand Down
31 changes: 30 additions & 1 deletion tests/ludwig/utils/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dask.dataframe as dd
import pandas as pd

from ludwig.utils.data_utils import add_sequence_feature_column, get_abs_path
from ludwig.data.cache.types import CacheableDataframe
from ludwig.utils.data_utils import add_sequence_feature_column, figure_data_format_dataset, get_abs_path


def test_add_sequence_feature_column():
Expand Down Expand Up @@ -57,3 +59,30 @@ def test_add_sequence_feature_column():
def test_get_abs_path():
    """Check ``get_abs_path`` against a base directory and against a ``None`` base."""
    # (base directory, file name, expected result) — checked in this order.
    cases = (
        ("a", "b.jpg", "a/b.jpg"),
        (None, "b.jpg", "b.jpg"),
    )
    for base_dir, file_name, expected in cases:
        assert get_abs_path(base_dir, file_name) == expected


def test_figure_data_format_dataset():
    """Verify the format type ``figure_data_format_dataset`` reports per input kind.

    The assertions cover a plain dict, a pandas DataFrame, a dask DataFrame,
    and a ``CacheableDataframe`` wrapping each of the two DataFrame kinds; the
    wrapped cases are expected to report the type of the underlying frame.
    """

    def _pandas_frame():
        # Build a fresh frame on every call so no assertion shares state with another.
        return pd.DataFrame([1, 2, 3, 4, 5], columns=["x"])

    def _dask_frame():
        return dd.from_pandas(_pandas_frame(), npartitions=1).reset_index()

    def _cached(frame):
        return CacheableDataframe(df=frame, name="test", checksum="test123")

    # Plain Python dict.
    assert figure_data_format_dataset({"a": "b"}) == dict

    # Bare pandas and dask frames.
    pandas_format = figure_data_format_dataset(_pandas_frame())
    assert pandas_format == pd.DataFrame

    dask_format = figure_data_format_dataset(_dask_frame())
    assert dask_format == dd.core.DataFrame

    # Cache-wrapped frames should resolve to the wrapped frame's type.
    cached_pandas_format = figure_data_format_dataset(_cached(_pandas_frame()))
    assert cached_pandas_format == pd.DataFrame

    cached_dask_format = figure_data_format_dataset(_cached(_dask_frame()))
    assert cached_dask_format == dd.core.DataFrame

0 comments on commit 3da5cb6

Please sign in to comment.