From d9e8ada4c50f15ad3cba0dfb0c246e3ae33b64c7 Mon Sep 17 00:00:00 2001 From: David Tulga <3924980+dtulga@users.noreply.github.com> Date: Wed, 28 Aug 2024 11:02:46 -0700 Subject: [PATCH] Fixing get_file_signals for custom types --- src/datachain/catalog/catalog.py | 11 +---------- tests/func/test_catalog.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 849d8b49e..7f583012c 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -1560,17 +1560,8 @@ def get_file_signals( version = self.get_dataset(dataset_name).get_version(dataset_version) file_signals_values = {} - file_schemas = {} - # TODO: To remove after we properly fix deserialization - for signal, type_name in version.feature_schema.items(): - from datachain.lib.model_store import ModelStore - type_name_parsed, v = ModelStore.parse_name_version(type_name) - fr = ModelStore.get(type_name_parsed, v) - if fr and issubclass(fr, File): - file_schemas[signal] = type_name - - schema = SignalSchema.deserialize(file_schemas) + schema = SignalSchema.deserialize(version.feature_schema) for file_signals in schema.get_signals(File): prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER file_signals_values[file_signals] = { diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py index 4581add72..bcd088e0d 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -1151,6 +1151,36 @@ def test_get_file_signals(cloud_test_catalog, dogs_dataset): } +def test_get_file_signals_with_custom_types(cloud_test_catalog, dogs_dataset): + catalog = cloud_test_catalog.catalog + catalog.metastore.update_dataset_version( + dogs_dataset, + 1, + feature_schema={ + "name": "str", + "age": "str", + "f1": "File@v1", + "f2": "File@v1", + "_custom_types": { + "File@v1": {"source": "str", "name": "str"}, + }, + }, + ) + row = { + "name": "Jon", + "age": 25, + "f1__source": "s3://first_bucket", + "f1__name": "image1.jpg", + "f2__source": "s3://second_bucket", + "f2__name": "image2.jpg", + } + + assert catalog.get_file_signals(dogs_dataset.name, 1, row) == { + "source": "s3://first_bucket", + "name": "image1.jpg", + } + + def test_get_file_signals_no_signals(cloud_test_catalog, dogs_dataset): catalog = cloud_test_catalog.catalog catalog.metastore.update_dataset_version(