diff --git a/python/tests/api/logger/test_segments.py b/python/tests/api/logger/test_segments.py
index b9a76afed3..6be7669f3a 100644
--- a/python/tests/api/logger/test_segments.py
+++ b/python/tests/api/logger/test_segments.py
@@ -250,3 +250,34 @@ def test_multi_column_segment() -> None:
     count = segment_distribution.n
     assert count is not None
     assert count == 1
+
+
+def test_multi_column_segment_serialization_roundtrip(tmp_path: Any) -> None:
+    input_rows = 35
+    d = {
+        "A": [i % 7 for i in range(input_rows)],
+        "B": [f"x{str(i%5)}" for i in range(input_rows)],
+    }
+
+    df = pd.DataFrame(data=d)
+    segmentation_partition = SegmentationPartition(name="A,B", mapper=ColumnMapperFunction(col_names=["A", "B"]))
+    test_segments = {segmentation_partition.name: segmentation_partition}
+    results: SegmentedResultSet = why.log(df, schema=DatasetSchema(segments=test_segments))
+    results.writer().option(base_dir=tmp_path).write()
+
+    paths = glob(os.path.join(tmp_path) + "/*.bin")
+    assert len(paths) == input_rows
+    roundtrip_profiles = []
+    for file_path in paths:
+        roundtrip_profiles.append(read_v0_to_view(file_path))
+    assert len(roundtrip_profiles) == input_rows
+    print(roundtrip_profiles)
+    print(roundtrip_profiles[15])
+
+    post_deserialization_view = roundtrip_profiles[15]
+    assert post_deserialization_view is not None
+    assert isinstance(post_deserialization_view, DatasetProfileView)
+
+    post_columns = post_deserialization_view.get_columns()
+    assert "A" in post_columns.keys()
+    assert "B" in post_columns.keys()
diff --git a/python/whylogs/migration/converters.py b/python/whylogs/migration/converters.py
index d36e0b7cb1..8d15e70aa1 100644
--- a/python/whylogs/migration/converters.py
+++ b/python/whylogs/migration/converters.py
@@ -76,8 +76,8 @@ def _generate_segment_tags_metadata(
 
         segment_tags = []
         col_names = partition.mapper.col_names
-        index = 0
-        for column_name in col_names:
+
+        for index, column_name in enumerate(col_names):
             segment_tags.append(SegmentTag(key=_TAG_PREFIX + column_name, value=segment.key[index]))
     else:
         raise NotImplementedError(
@@ -138,6 +138,11 @@
         }
     )
 
+    if msg.properties.tags:
+        logger.info(
+            f"Found tags in v0 message, ignoring while converting to v1 DatasetProfileView: {msg.properties.tags}"
+        )
+
     return DatasetProfileView(
         columns=columns, dataset_timestamp=dataset_timestamp, creation_timestamp=creation_timestamp
     )
@@ -225,10 +230,15 @@ def _extract_dist_metric(msg: ColumnMessageV0) -> DistributionMetric:
     kll_bytes = msg.numbers.histogram
     floats_sk = None
     doubles_sk: Optional[ds.kll_doubles_sketch] = None
+    # If this is a V1 serialized message it will be a double kll sketch.
     try:
         floats_sk = ds.kll_floats_sketch.deserialize(kll_bytes)
     except ValueError as e:
         logger.info(f"kll encountered old format which threw exception: {e}, attempting kll_doubles deserialization.")
+    except RuntimeError as e:
+        logger.warning(
+            f"kll encountered runtime error in old format which threw exception: {e}, attempting kll_doubles deserialization."
+        )
     if floats_sk is None:
         doubles_sk = ds.kll_doubles_sketch.deserialize(kll_bytes)
     else: