Skip to content

Commit

Permalink
Fix multi-column segment key translation (whylabs#848)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamie256 authored and Han Wang committed Sep 20, 2022
1 parent eb8f95c commit 82c1461
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
31 changes: 31 additions & 0 deletions python/tests/api/logger/test_segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,34 @@ def test_multi_column_segment() -> None:
count = segment_distribution.n
assert count is not None
assert count == 1


def test_multi_column_segment_serialization_roundtrip(tmp_path: Any) -> None:
    """Write a two-column segmented result to disk and read every file back.

    The frame is segmented on columns "A" and "B" together, the segmented
    result set is written under ``tmp_path``, and each ``*.bin`` file found
    there is deserialized with ``read_v0_to_view``. Every deserialized view
    must be a ``DatasetProfileView`` that still exposes both columns.
    """
    input_rows = 35
    # 7 and 5 are coprime, so the 35 rows yield 35 distinct (A, B) pairs;
    # the file-count assertion below relies on one output file per pair.
    d = {
        "A": [i % 7 for i in range(input_rows)],
        "B": [f"x{i % 5}" for i in range(input_rows)],
    }

    df = pd.DataFrame(data=d)
    segmentation_partition = SegmentationPartition(name="A,B", mapper=ColumnMapperFunction(col_names=["A", "B"]))
    test_segments = {segmentation_partition.name: segmentation_partition}
    results: SegmentedResultSet = why.log(df, schema=DatasetSchema(segments=test_segments))
    results.writer().option(base_dir=tmp_path).write()

    # Let os.path.join insert the separator instead of concatenating "/".
    paths = glob(os.path.join(tmp_path, "*.bin"))
    assert len(paths) == input_rows

    roundtrip_profiles = [read_v0_to_view(file_path) for file_path in paths]
    assert len(roundtrip_profiles) == input_rows

    # glob() returns files in arbitrary order, so a fixed index would pick an
    # arbitrary segment; validate every deserialized profile instead.
    for post_deserialization_view in roundtrip_profiles:
        assert post_deserialization_view is not None
        assert isinstance(post_deserialization_view, DatasetProfileView)

        post_columns = post_deserialization_view.get_columns()
        assert "A" in post_columns
        assert "B" in post_columns
14 changes: 12 additions & 2 deletions python/whylogs/migration/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def _generate_segment_tags_metadata(

segment_tags = []
col_names = partition.mapper.col_names
index = 0
for column_name in col_names:

for index, column_name in enumerate(col_names):
segment_tags.append(SegmentTag(key=_TAG_PREFIX + column_name, value=segment.key[index]))
else:
raise NotImplementedError(
Expand Down Expand Up @@ -138,6 +138,11 @@ def v0_to_v1_view(msg: DatasetProfileMessageV0) -> DatasetProfileView:
}
)

if msg.properties.tags:
logger.info(
f"Found tags in v0 message, ignoring while converting to v1 DatasetProfileView: {msg.properties.tags}"
)

return DatasetProfileView(
columns=columns, dataset_timestamp=dataset_timestamp, creation_timestamp=creation_timestamp
)
Expand Down Expand Up @@ -225,10 +230,15 @@ def _extract_dist_metric(msg: ColumnMessageV0) -> DistributionMetric:
kll_bytes = msg.numbers.histogram
floats_sk = None
doubles_sk: Optional[ds.kll_doubles_sketch] = None
# If this is a V1 serialized message it will be a double kll sketch.
try:
floats_sk = ds.kll_floats_sketch.deserialize(kll_bytes)
except ValueError as e:
logger.info(f"kll encountered old format which threw exception: {e}, attempting kll_doubles deserialization.")
except RuntimeError as e:
logger.warning(
f"kll encountered runtime error in old format which threw exception: {e}, attempting kll_doubles deserialization."
)
if floats_sk is None:
doubles_sk = ds.kll_doubles_sketch.deserialize(kll_bytes)
else:
Expand Down

0 comments on commit 82c1461

Please sign in to comment.