From 34219071b0732d4db087e35031f3463548ceecd4 Mon Sep 17 00:00:00 2001
From: Yan Wong
Date: Wed, 4 Sep 2024 13:17:28 +0100
Subject: [PATCH] Test alleles are unique

Also fix tests to create valid datasets with a given seed
---
 tests/test_variantdata.py | 50 +++++++++++++++++++++++++++++++++------
 tsinfer/formats.py        |  6 +++++
 2 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/tests/test_variantdata.py b/tests/test_variantdata.py
index d639f2a2..4e3b2f81 100644
--- a/tests/test_variantdata.py
+++ b/tests/test_variantdata.py
@@ -675,9 +675,28 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
 
 
 class TestVariantDataErrors:
+    def simulate_genotype_call_dataset(self, *args, **kwargs):
+        # hack around bug in sgkit where duplicate alleles are created
+        # since this is just for testing, it doesn't need to be efficient
+        # Force a seed:
+        if "seed" not in kwargs:
+            kwargs["seed"] = 123
+        ds = sgkit.simulate_genotype_call_dataset(*args, **kwargs)
+        variant_alleles = ds["variant_allele"].values
+        allowed_alleles = np.array(
+            ["A", "T", "C", "G", "N"], dtype=variant_alleles.dtype
+        )
+        for row in range(len(variant_alleles)):
+            alleles = variant_alleles[row]
+            if len(set(alleles)) != len(alleles):
+                # Just use a set that we know is unique
+                variant_alleles[row] = allowed_alleles[0 : len(alleles)]
+        ds["variant_allele"] = ds["variant_allele"].dims, variant_alleles
+        return ds
+
     def test_missing_phase(self, tmp_path):
         path = tmp_path / "data.zarr"
-        ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
+        ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
         sgkit.save_dataset(ds, path)
         with pytest.raises(
             ValueError, match="The call_genotype_phased array is missing"
@@ -686,7 +705,7 @@ def test_missing_phase(self, tmp_path):
 
     def test_phased(self, tmp_path):
         path = tmp_path / "data.zarr"
-        ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
+        ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
         ds["call_genotype_phased"] = (
             ds["call_genotype"].dims,
             np.ones(ds["call_genotype"].shape, dtype=bool),
@@ -701,7 +720,7 @@ def test_phased(self, tmp_path):
     def test_ploidy1_missing_phase(self, tmp_path):
         path = tmp_path / "data.zarr"
         # Ploidy==1 is always ok
-        ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
+        ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
         ds["variant_ancestral_allele"] = (
             ds["variant_position"].dims,
             np.array(["A", "C", "G"], dtype="S1"),
@@ -711,7 +730,7 @@ def test_ploidy1_missing_phase(self, tmp_path):
 
     def test_ploidy1_unphased(self, tmp_path):
         path = tmp_path / "data.zarr"
-        ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
+        ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
         ds["call_genotype_phased"] = (
             ds["call_genotype"].dims,
             np.zeros(ds["call_genotype"].shape, dtype=bool),
@@ -725,7 +744,7 @@ def test_ploidy1_unphased(self, tmp_path):
 
     def test_duplicate_positions(self, tmp_path):
         path = tmp_path / "data.zarr"
-        ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
+        ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
         ds["variant_position"][2] = ds["variant_position"][1]
         sgkit.save_dataset(ds, path)
         with pytest.raises(ValueError, match="duplicate or out-of-order values"):
@@ -733,7 +752,7 @@ def test_duplicate_positions(self, tmp_path):
 
     def test_bad_order_positions(self, tmp_path):
         path = tmp_path / "data.zarr"
-        ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
+        ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
         ds["variant_position"][0] = ds["variant_position"][2] - 0.5
         sgkit.save_dataset(ds, path)
         with pytest.raises(ValueError, match="duplicate or out-of-order values"):
@@ -741,7 +760,7 @@ def test_bad_order_positions(self, tmp_path):
 
     def test_empty_alleles_not_at_end(self, tmp_path):
         path = tmp_path / "data.zarr"
-        ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
+        ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
         ds["variant_allele"] = (
             ds["variant_allele"].dims,
             np.array([["", "A", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
@@ -754,3 +773,20 @@ def test_empty_alleles_not_at_end(self, tmp_path):
         samples = tsinfer.VariantData(path, "variant_ancestral_allele")
         with pytest.raises(ValueError, match="Empty alleles must be at the end"):
             tsinfer.infer(samples)
+
+    def test_unique_alleles(self, tmp_path):
+        path = tmp_path / "data.zarr"
+        ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
+        ds["variant_allele"] = (
+            ds["variant_allele"].dims,
+            np.array([["A", "C", "T"], ["A", "C", ""], ["A", "A", ""]], dtype="S1"),
+        )
+        ds["variant_ancestral_allele"] = (
+            ["variants"],
+            np.array(["A", "A", "A"], dtype="S1"),
+        )
+        sgkit.save_dataset(ds, path)
+        with pytest.raises(
+            ValueError, match="Duplicate allele values provided at site 2"
+        ):
+            tsinfer.VariantData(path, "variant_ancestral_allele")
diff --git a/tsinfer/formats.py b/tsinfer/formats.py
index 26283a77..7338640c 100644
--- a/tsinfer/formats.py
+++ b/tsinfer/formats.py
@@ -2398,6 +2398,12 @@ def __init__(
                     f"The ancestral allele {ancestral_allele} was not"
                     f" found in the dataset."
                 )
+        for i, (alleles, num_alleles) in enumerate(
+            zip(self.sites_alleles, self.num_alleles())
+        ):
+            if len(set(alleles) - {b"", "", None}) != num_alleles:
+                raise ValueError(f"Duplicate allele values provided at site {i}")
+
         self._sites_ancestral_allele = self._sites_ancestral_allele.astype(str)
         unknown_alleles = collections.Counter()
         converted = np.zeros(self.num_sites, dtype=np.int8)