Skip to content

Commit

Permalink
Test alleles are unique
Browse files Browse the repository at this point in the history
Also fix tests to create valid datasets with a given seed
  • Loading branch information
hyanwong committed Sep 4, 2024
1 parent a080e8f commit 3421907
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 7 deletions.
50 changes: 43 additions & 7 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,9 +675,28 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):


class TestVariantDataErrors:
def simulate_genotype_call_dataset(self, *args, **kwargs):
    """
    Simulate an sgkit genotype-call dataset in which no site repeats an allele.

    This is a hack around a bug in sgkit where duplicate alleles are created.
    Any row of ``variant_allele`` containing a repeated value is overwritten
    with the leading entries of a fixed list of distinct alleles. Since this
    is just for testing, it doesn't need to be efficient.

    :param args: Positional arguments passed straight through to
        ``sgkit.simulate_genotype_call_dataset``.
    :param kwargs: Keyword arguments passed straight through to
        ``sgkit.simulate_genotype_call_dataset``. If ``seed`` is absent,
        a seed of 123 is used.
    :return: The simulated dataset, with its ``variant_allele`` variable
        reassigned to the de-duplicated array.
    :raises ValueError: If a row that needs repairing holds more alleles
        than there are distinct replacement alleles available.
    """
    # Force a seed unless the caller chose one
    kwargs.setdefault("seed", 123)
    ds = sgkit.simulate_genotype_call_dataset(*args, **kwargs)
    variant_alleles = ds["variant_allele"].values
    allowed_alleles = np.array(
        ["A", "T", "C", "G", "N"], dtype=variant_alleles.dtype
    )
    for row, alleles in enumerate(variant_alleles):
        if len(set(alleles)) == len(alleles):
            # Already unique: leave the simulated alleles untouched
            continue
        if len(alleles) > len(allowed_alleles):
            # Fail with a clear message rather than a NumPy broadcast error
            raise ValueError(
                f"Cannot build {len(alleles)} unique alleles from only "
                f"{len(allowed_alleles)} allowed values"
            )
        # Just use a set that we know is unique
        variant_alleles[row] = allowed_alleles[: len(alleles)]
    ds["variant_allele"] = ds["variant_allele"].dims, variant_alleles
    return ds

def test_missing_phase(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
sgkit.save_dataset(ds, path)
with pytest.raises(
ValueError, match="The call_genotype_phased array is missing"
Expand All @@ -686,7 +705,7 @@ def test_missing_phase(self, tmp_path):

def test_phased(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
ds["call_genotype_phased"] = (
ds["call_genotype"].dims,
np.ones(ds["call_genotype"].shape, dtype=bool),
Expand All @@ -701,7 +720,7 @@ def test_phased(self, tmp_path):
def test_ploidy1_missing_phase(self, tmp_path):
path = tmp_path / "data.zarr"
# Ploidy==1 is always ok
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["variant_ancestral_allele"] = (
ds["variant_position"].dims,
np.array(["A", "C", "G"], dtype="S1"),
Expand All @@ -711,7 +730,7 @@ def test_ploidy1_missing_phase(self, tmp_path):

def test_ploidy1_unphased(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["call_genotype_phased"] = (
ds["call_genotype"].dims,
np.zeros(ds["call_genotype"].shape, dtype=bool),
Expand All @@ -725,23 +744,23 @@ def test_ploidy1_unphased(self, tmp_path):

def test_duplicate_positions(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["variant_position"][2] = ds["variant_position"][1]
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_bad_order_positions(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["variant_position"][0] = ds["variant_position"][2] - 0.5
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_empty_alleles_not_at_end(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["variant_allele"] = (
ds["variant_allele"].dims,
np.array([["", "A", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
Expand All @@ -754,3 +773,20 @@ def test_empty_alleles_not_at_end(self, tmp_path):
samples = tsinfer.VariantData(path, "variant_ancestral_allele")
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
tsinfer.infer(samples)

def test_unique_alleles(self, tmp_path):
    """Check that a site listing the same allele twice raises a ValueError."""
    # The last row (site index 2) contains "A" twice
    allele_table = np.array(
        [["A", "C", "T"], ["A", "C", ""], ["A", "A", ""]], dtype="S1"
    )
    ancestral_states = np.array(["A", "A", "A"], dtype="S1")

    ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
    ds["variant_allele"] = (ds["variant_allele"].dims, allele_table)
    ds["variant_ancestral_allele"] = (["variants"], ancestral_states)

    path = tmp_path / "data.zarr"
    sgkit.save_dataset(ds, path)

    expected_message = "Duplicate allele values provided at site 2"
    with pytest.raises(ValueError, match=expected_message):
        tsinfer.VariantData(path, "variant_ancestral_allele")
6 changes: 6 additions & 0 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2398,6 +2398,12 @@ def __init__(
f"The ancestral allele {ancestral_allele} was not"
f" found in the dataset."
)
for i, (alleles, num_alleles) in enumerate(
zip(self.sites_alleles, self.num_alleles())
):
if len(set(alleles) - {b"", "", None}) != num_alleles:
raise ValueError(f"Duplicate allele values provided at site {i}")

self._sites_ancestral_allele = self._sites_ancestral_allele.astype(str)
unknown_alleles = collections.Counter()
converted = np.zeros(self.num_sites, dtype=np.int8)
Expand Down

0 comments on commit 3421907

Please sign in to comment.