
Commit

Adding a failing test for dataset instantiation from passage level / flat data.
Mark committed May 1, 2024
1 parent 17bbff4 commit 9239cdd
Showing 6 changed files with 30 additions and 7 deletions.
5 binary files not shown.
37 changes: 30 additions & 7 deletions tests/test_models.py
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import Iterable
import os

import pandas as pd
import pytest
@@ -85,6 +86,18 @@ def test_huggingface_dataset_gst() -> HuggingFaceDataset:
)


@pytest.fixture
def test_huggingface_dataset_cpr_passage_level_flat() -> HuggingFaceDataset:
    """Test HuggingFace dataset with flattened passage level schema."""
    dataset_dir = "tests/test_data/huggingface/cpr_passage_level_flat"
    dataset_files = os.listdir(dataset_dir)
    dataset = HuggingFaceDataset.from_parquet(
        path_or_paths=[os.path.join(dataset_dir, f) for f in dataset_files]
    )
    assert isinstance(dataset, HuggingFaceDataset)
    return dataset


def test_dataset_metadata_df(test_dataset):
    metadata_df = test_dataset.metadata_df

@@ -427,25 +440,35 @@ def test_dataset_from_huggingface_cpr(test_huggingface_dataset_cpr, limit):
    assert len(dataset) == limit


def test_dataset_from_huggingface_gst(test_huggingface_dataset_gst):
def test_dataset_from_huggingface_gst(
    test_huggingface_dataset_gst, test_huggingface_dataset_cpr_passage_level_flat
):
    """Test that a dataset can be created from a HuggingFace dataset."""
    # GST Dataset
    dataset = Dataset(document_model=GSTDocument)._from_huggingface_parquet(
        test_huggingface_dataset_gst
    )

    assert isinstance(dataset, Dataset)
    assert all(isinstance(doc, GSTDocument) for doc in dataset.documents)

    assert any(doc.languages is not None for doc in dataset.documents)

    # Check huggingface dataset has the same number of documents as the dataset
    assert len(dataset) == len({d["document_id"] for d in test_huggingface_dataset_gst})
    unique_document_ids = set(d["document_id"] for d in test_huggingface_dataset_gst)
    assert len(dataset) == len(unique_document_ids)

    # Check huggingface dataset has the same number of text blocks as the dataset
    assert sum(len(doc.text_blocks or []) for doc in dataset.documents) == len(
        test_huggingface_dataset_gst
    dataset_text_blocks_number = sum(
        len(doc.text_blocks or []) for doc in dataset.documents
    )
    assert dataset_text_blocks_number == len(test_huggingface_dataset_gst)

    # CPR Dataset from passage level flat dataset schema
    dataset = Dataset(document_model=CPRDocument)._from_huggingface_parquet(
        test_huggingface_dataset_cpr_passage_level_flat
    )

    assert isinstance(dataset, Dataset)
    assert all(isinstance(doc, CPRDocument) for doc in dataset.documents)


def test_dataset_indexable(test_dataset):
    """Tests that the dataset can be indexed to get documents"""
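
Note on the new assertions: they rely on the relationship between the flat, passage-level HuggingFace dataset and the reconstructed Dataset, namely that each parquet row is one passage tagged with a document_id, and documents are rebuilt by grouping rows on that id. The following is a minimal illustrative sketch of that relationship, not part of the commit; the real reconstruction happens inside Dataset._from_huggingface_parquet, and field names other than "document_id" are assumptions.

from collections import defaultdict

# Illustrative passage-level "flat" rows: one row per passage, each tagged
# with the document it belongs to. The "text" field name is an assumption.
rows = [
    {"document_id": "doc-1", "text": "First passage of doc-1."},
    {"document_id": "doc-1", "text": "Second passage of doc-1."},
    {"document_id": "doc-2", "text": "Only passage of doc-2."},
]

# Group passages by document, mirroring what the test's assertions check.
passages_by_document = defaultdict(list)
for row in rows:
    passages_by_document[row["document_id"]].append(row["text"])

# One reconstructed document per unique document_id ...
assert len(passages_by_document) == len({r["document_id"] for r in rows})
# ... and the text blocks across all documents add up to the number of rows.
assert sum(len(blocks) for blocks in passages_by_document.values()) == len(rows)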
