Skip to content

Commit

Permalink
added tests for export_files
Browse files Browse the repository at this point in the history
  • Loading branch information
ilongin committed Jul 16, 2024
1 parent 2285554 commit bb6fb50
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import datetime
import math
import os
from collections.abc import Generator, Iterator
from pathlib import Path

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -45,6 +47,16 @@ class MyNested(BaseModel):
]


def _create_local_bucket(tmp_dir, bucket_name, files, data):
    """Build a local directory that stands in for a storage bucket.

    Creates ``tmp_dir / bucket_name`` (including missing parents; an
    already existing bucket directory is an error) and then writes ``data``
    to every relative path listed in ``files``, creating intermediate
    directories on demand.

    Args:
        tmp_dir: Base directory as a ``pathlib.Path``.
        bucket_name: Name of the bucket directory to create under ``tmp_dir``.
        files: Iterable of paths relative to the bucket root.
        data: Bytes written verbatim into each file.
    """
    root = tmp_dir / bucket_name
    root.mkdir(parents=True)
    for relative_name in files:
        target = root / relative_name
        # Nested entries such as "dir1/dir2/b.json" need their parents first.
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(data)


def test_pandas_conversion(catalog):
df = pd.DataFrame(DF_DATA)
df1 = DataChain.from_pandas(df)
Expand Down Expand Up @@ -837,3 +849,35 @@ def test_parse_tabular_object_name(tmp_dir, catalog):
df.to_parquet(path)
dc = DataChain.from_storage(path.as_uri()).parse_tabular(object_name="name")
assert "name.first_name" in dc.to_pandas().columns


@pytest.mark.parametrize("strategy", ["fullpath", "filename"])
def test_export_files(tmp_dir, catalog, strategy):
    """Export a local bucket with each strategy and check the exported bytes.

    A two-file bucket is created under ``tmp_dir``, exported into
    ``tmp_dir / "output"`` and every expected destination file is read back
    and compared with the original binary payload.

    Args:
        tmp_dir: Temporary directory fixture (absolute ``pathlib.Path``).
        catalog: Catalog fixture; requested for its setup, not used directly.
        strategy: Export placement strategy under test.
    """
    data = b"some\x00data\x00is\x48\x65\x6c\x57\x6f\x72\x6c\x64\xff\xffheRe"
    bucket_name = "mybucket"
    files = ["dir1/a.json", "dir1/dir2/b.json"]

    _create_local_bucket(tmp_dir, bucket_name, files, data)

    output_dir = tmp_dir / "output"
    df = DataChain.from_storage((tmp_dir / bucket_name).as_uri())
    df.export_files(output_dir, strategy=strategy)

    for file_path in files:
        if strategy == "filename":
            path = Path(file_path).name
        else:
            source_path = tmp_dir / bucket_name / file_path
            # The source path is absolute, and joining an absolute path onto
            # ``output_dir`` would discard ``output_dir`` entirely, making the
            # check below read the *source* file and pass vacuously. Strip
            # the anchor so the path is resolved beneath the output directory.
            # NOTE(review): assumes "fullpath" mirrors the complete source
            # path under the output directory, which is what the original
            # expression was aiming for -- confirm against export_files.
            path = source_path.relative_to(source_path.anchor)
        with open(output_dir / path, "rb") as f:
            assert f.read() == data


def test_export_files_filename_strategy_not_unique_files(tmp_dir, catalog):
    """Check that the "filename" strategy raises on clashing base names.

    The bucket holds two files that are both called ``a.json`` but live in
    different directories; exporting them with ``strategy="filename"`` is
    expected to raise ``ValueError``.
    """
    payload = b"some\x00data\x00is\x48\x65\x6c\x57\x6f\x72\x6c\x64\xff\xffheRe"
    bucket = "mybucket"
    # Same base name in two different directories.
    clashing_files = ["dir1/a.json", "dir1/dir2/a.json"]
    _create_local_bucket(tmp_dir, bucket, clashing_files, payload)

    bucket_uri = (tmp_dir / bucket).as_uri()
    chain = DataChain.from_storage(bucket_uri)

    with pytest.raises(ValueError):
        chain.export_files(tmp_dir / "output", strategy="filename")

0 comments on commit bb6fb50

Please sign in to comment.