From 67416f320be66bdd1ff5db90217b31a6a59681c2 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Fri, 13 Sep 2024 10:52:08 -0400 Subject: [PATCH 1/4] read arrow files from cache --- src/datachain/lib/arrow.py | 5 ++++- src/datachain/lib/file.py | 8 ++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/datachain/lib/arrow.py b/src/datachain/lib/arrow.py index 5a845d629..8edca57f4 100644 --- a/src/datachain/lib/arrow.py +++ b/src/datachain/lib/arrow.py @@ -46,7 +46,10 @@ def __init__( self.kwargs = kwargs def process(self, file: File): - if self.nrows: + if file._caching_enabled: + path = file.get_local_path(download=True) + ds = dataset(path, schema=self.input_schema, **self.kwargs) + elif self.nrows: path = _nrows_file(file, self.nrows) ds = dataset(path, schema=self.input_schema, **self.kwargs) else: diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py index 4818bf439..bd305a53b 100644 --- a/src/datachain/lib/file.py +++ b/src/datachain/lib/file.py @@ -257,14 +257,18 @@ def get_uid(self) -> UniqueId: dump = self.model_dump() return UniqueId(*(dump[k] for k in self._unique_id_keys)) - def get_local_path(self) -> Optional[str]: + def get_local_path(self, download: bool = False) -> Optional[str]: """Returns path to a file in a local cache. Return None if file is not cached. Throws an exception if cache is not setup.""" if self._catalog is None: raise RuntimeError( "cannot resolve local file path because catalog is not setup" ) - return self._catalog.cache.get_path(self.get_uid()) + uid = self.get_uid() + if download: + client = self._catalog.get_client(self.source) + client.download(uid, callback=self._download_cb) + return self._catalog.cache.get_path(uid) def get_file_suffix(self): """Returns last part of file name with `.`.""" From 1f82d3894937784b88c08f6cdfcd94752a508129 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Fri, 13 Sep 2024 13:44:05 -0400 Subject: [PATCH 2/4] add test --- tests/unit/lib/test_file.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unit/lib/test_file.py b/tests/unit/lib/test_file.py index ca6d4599c..5ea07ecef 100644 --- a/tests/unit/lib/test_file.py +++ b/tests/unit/lib/test_file.py @@ -387,3 +387,18 @@ def test_resolve_function(): assert result == "resolved_file" mock_file.resolve.assert_called_once() + + +def test_get_local_path(tmp_path, catalog): + file_name = "myfile" + data = b"some\x00data\x00is\x48\x65\x6c\x57\x6f\x72\x6c\x64\xff\xffheRe" + + file_path = tmp_path / file_name + with open(file_path, "wb") as fd: + fd.write(data) + + file = File(path=file_name, source=f"file://{tmp_path}") + file._set_stream(catalog) + + assert file.get_local_path() is None + assert file.get_local_path(download=True) is not None From 916e86a4bdc85a8d874a2d36700aad100ca07ca2 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Fri, 13 Sep 2024 14:12:29 -0400 Subject: [PATCH 3/4] test arrow generator works with caching --- tests/unit/lib/test_arrow.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/lib/test_arrow.py b/tests/unit/lib/test_arrow.py index e9693c2e8..bb12298bd 100644 --- a/tests/unit/lib/test_arrow.py +++ b/tests/unit/lib/test_arrow.py @@ -15,7 +15,8 @@ from datachain.lib.file import File, IndexedFile -def test_arrow_generator(tmp_path, catalog): +@pytest.mark.parametrize("cache", [True, False]) +def test_arrow_generator(tmp_path, catalog, cache): ids = [12345, 67890, 34, 0xF0123] texts = ["28", "22", "we", "hello world"] df = pd.DataFrame({"id": ids, "text": texts}) @@ -24,7 +25,7 @@ def test_arrow_generator(tmp_path, catalog): pq_path = tmp_path / name df.to_parquet(pq_path) stream = File(path=pq_path.as_posix(), source="file:///") - stream._set_stream(catalog, caching_enabled=False) + stream._set_stream(catalog, caching_enabled=cache) func = ArrowGenerator() objs = list(func.process(stream)) From c0f038fa1200f5a8bf510d7a94187092cbcdef06 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Fri, 13 Sep 2024 14:49:41 -0400 Subject: [PATCH 4/4] try to fix windows tests --- tests/unit/lib/test_arrow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/lib/test_arrow.py b/tests/unit/lib/test_arrow.py index bb12298bd..260e5af5d 100644 --- a/tests/unit/lib/test_arrow.py +++ b/tests/unit/lib/test_arrow.py @@ -24,7 +24,7 @@ def test_arrow_generator(tmp_path, catalog, cache): name = "111.parquet" pq_path = tmp_path / name df.to_parquet(pq_path) - stream = File(path=pq_path.as_posix(), source="file:///") + stream = File(path=pq_path.as_posix(), source="file://") stream._set_stream(catalog, caching_enabled=cache) func = ArrowGenerator() @@ -47,7 +47,7 @@ def test_arrow_generator_no_source(tmp_path, catalog): name = "111.parquet" pq_path = tmp_path / name df.to_parquet(pq_path) - stream = File(path=pq_path.as_posix(), source="file:///") + stream = File(path=pq_path.as_posix(), source="file://") stream._set_stream(catalog, caching_enabled=False) func = ArrowGenerator(source=False) @@ -68,7 +68,7 @@ def test_arrow_generator_output_schema(tmp_path, catalog): name = "111.parquet" pq_path = tmp_path / name pq.write_table(table, pq_path) - stream = File(path=pq_path.as_posix(), source="file:///") + stream = File(path=pq_path.as_posix(), source="file://") stream._set_stream(catalog, caching_enabled=False) output_schema = dict_to_data_model("", schema_to_output(table.schema))