Store ancestor data in sgkit format
benjeffery committed Nov 30, 2022
1 parent b4432e1 commit 1e767f8
Showing 6 changed files with 238 additions and 115 deletions.
15 changes: 4 additions & 11 deletions evaluation.py
@@ -1213,24 +1213,17 @@ def run_ancestor_quality(args):
olap_end = estim_positions[olap_end_estim]
olap_end_exact = np.searchsorted(exact_anc.sites_position[:], olap_end)

# ancestors_haplotype[x] contains a vector of inferred sites only
# between ancestors_start[x] and ancestors_end[x]. To match it to the genome-wide
# mask we need to account for the offset before masking out non-shared sites
offset1 = exact_anc.ancestors_start[:][exact_index]
offset2 = estim_anc.ancestors_start[:][estim_index]

exact_full_hap = exact_anc.ancestors_haplotype[:][exact_index]
exact_full_hap = exact_anc.ancestors_full_haplotype[:, exact_index, 0]
# slice the full haplotype to include only the overlapping region
exact_olap = exact_full_hap[
(olap_start_exact - offset1) : (olap_end_exact - offset1)
]
exact_olap = exact_full_hap[olap_start_exact:olap_end_exact]
# make a 1/0 array with only the comparable sites
exact_comp = exact_olap[exact_sites_mask[olap_start_exact:olap_end_exact]]

estim_full_hap = estim_anc.ancestors_haplotype[estim_index]
estim_olap = estim_full_hap[
(olap_start_estim - offset2) : (olap_end_estim - offset2)
]
estim_full_hap = estim_anc.ancestors_full_haplotype[:, estim_index, 0]
estim_olap = estim_full_hap[olap_start_estim:olap_end_estim]
small_estim_mask = estim_sites_mask[olap_start_estim:olap_end_estim]
estim_comp = estim_olap[small_estim_mask]

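The evaluation.py change above works because the new ancestors_full_haplotype array is indexed as [site, ancestor, ploidy] across the full site range, so overlap slicing no longer needs the per-ancestor start offsets. A minimal NumPy sketch of the two indexing styles, using toy values rather than a real ancestors file (variable names here are illustrative only):

import numpy as np

MISSING = -1  # stand-in for tskit.MISSING_DATA

num_sites, start, end = 10, 3, 8
olap_start, olap_end = 4, 7
# Old layout: each ancestor stored only its own sites, so genome-wide indices
# had to be shifted by the ancestor's start offset before slicing.
old_hap = np.ones(end - start, dtype=np.int8)
old_olap = old_hap[(olap_start - start) : (olap_end - start)]
# New layout: one (num_sites, num_ancestors, ploidy) array padded with
# MISSING outside [start, end), so genome-wide indices are used directly.
full_hap = np.full((num_sites, 1, 1), MISSING, dtype=np.int8)
full_hap[start:end, 0, 0] = 1
new_olap = full_hap[olap_start:olap_end, 0, 0]
assert np.array_equal(old_olap, new_olap)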
66 changes: 44 additions & 22 deletions tests/test_formats.py
@@ -1929,23 +1929,27 @@ def verify_data_round_trip(self, sample_data, ancestor_data, ancestors):
stored_start = ancestor_data.ancestors_start[:]
stored_end = ancestor_data.ancestors_end[:]
stored_time = ancestor_data.ancestors_time[:]
stored_ancestors = ancestor_data.ancestors_haplotype[:]
# Remove the ploidy dimension
stored_ancestors = ancestor_data.ancestors_full_haplotype[:, :, 0]
stored_focal_sites = ancestor_data.ancestors_focal_sites[:]
stored_length = ancestor_data.ancestors_length[:]
for j, (start, end, t, focal_sites, haplotype) in enumerate(ancestors):
for j, (start, end, t, focal_sites, full_haplotype) in enumerate(ancestors):
assert stored_start[j] == start
assert stored_end[j] == end
assert stored_time[j] == t
assert np.array_equal(stored_focal_sites[j], focal_sites)
assert np.array_equal(stored_ancestors[j], haplotype[start:end])
assert np.array_equal(stored_ancestors[:, j], full_haplotype)
assert np.array_equal(ancestors_list[j], haplotype[start:end])
pos = list(ancestor_data.sites_position[:]) + [ancestor_data.sequence_length]
for j, anc in enumerate(ancestor_data.ancestors()):
assert stored_start[j] == anc.start
assert stored_end[j] == anc.end
assert stored_time[j] == anc.time
assert np.array_equal(stored_focal_sites[j], anc.focal_sites)
assert np.array_equal(stored_ancestors[j], anc.haplotype)
assert np.array_equal(stored_ancestors[:, j], anc.full_haplotype)
assert np.array_equal(
stored_ancestors[anc.start : anc.end, j], anc.haplotype
)
length = pos[anc.end] - pos[anc.start]
assert stored_length[j] == length

@@ -2021,18 +2025,31 @@ def test_provenance(self):
def test_chunk_size(self):
N = 20
for chunk_size in [1, 2, 3, N - 1, N, N + 1]:
sample_data, ancestors = self.get_example_data(6, 1, num_ancestors=N)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position,
sample_data.sequence_length,
chunk_size=chunk_size,
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,)
assert ancestor_data.ancestors_focal_sites.chunks == (chunk_size,)
assert ancestor_data.ancestors_start.chunks == (chunk_size,)
assert ancestor_data.ancestors_end.chunks == (chunk_size,)
assert ancestor_data.ancestors_time.chunks == (chunk_size,)
for chunk_size_sites in [None, 1, 2, 3, N - 1, N, N + 1]:
sample_data, ancestors = self.get_example_data(6, 1, num_ancestors=N)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position,
sample_data.sequence_length,
chunk_size=chunk_size,
chunk_size_sites=chunk_size_sites,
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
if chunk_size_sites is None:
assert ancestor_data.ancestors_full_haplotype.chunks == (
16384,
chunk_size,
1,
)
else:
assert ancestor_data.ancestors_full_haplotype.chunks == (
chunk_size_sites,
chunk_size,
1,
)
assert ancestor_data.ancestors_focal_sites.chunks == (chunk_size,)
assert ancestor_data.ancestors_start.chunks == (chunk_size,)
assert ancestor_data.ancestors_end.chunks == (chunk_size,)
assert ancestor_data.ancestors_time.chunks == (chunk_size,)

def test_filename(self):
sample_data, ancestors = self.get_example_data(10, 2, 40)
@@ -2069,7 +2086,11 @@ def test_chunk_size_file_equal(self):
chunk_size=chunk_size,
) as ancestor_data:
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,)
assert ancestor_data.ancestors_full_haplotype.chunks == (
16384,
chunk_size,
1,
)
# Now reload the files and check they are equal
with formats.AncestorData.load(files[0]) as file0:
with formats.AncestorData.load(files[1]) as file1:
@@ -2263,7 +2284,7 @@ def test_insert_proxy_1_sample(self):
inserted = -1
self.assert_ancestor_full_span(ancestors_extra, [inserted])
assert np.array_equal(
ancestors_extra.ancestors_haplotype[inserted],
ancestors_extra.ancestors_full_haplotype[:, inserted, 0],
sample_data.sites_genotypes[:, i][used_sites],
)

@@ -2304,7 +2325,7 @@ def test_insert_proxy_time_historical_samples(self):
assert ancestors.num_ancestors + 1 == ancestors_extra.num_ancestors
self.assert_ancestor_full_span(ancestors_extra, [-1])
assert np.array_equal(
ancestors_extra.ancestors_haplotype[-1], G[:, 9][used_sites]
ancestors_extra.ancestors_full_haplotype[:, -1, 0], G[:, 9][used_sites]
)
assert np.array_equal(
ancestors_extra.ancestors_time[-1], historical_sample_time + epsilon
@@ -2320,14 +2341,14 @@ def test_insert_proxy_time_historical_samples(self):
self.assert_ancestor_full_span(ancestors_extra, inserted)
# Older sample
assert np.array_equal(
ancestors_extra.ancestors_haplotype[-2], G[:, 9][used_sites]
ancestors_extra.ancestors_full_haplotype[:, -2, 0], G[:, 9][used_sites]
)
assert np.array_equal(
ancestors_extra.ancestors_time[-2], historical_sample_time + epsilon
)
# Younger sample
assert np.array_equal(
ancestors_extra.ancestors_haplotype[-1], G[:, 0][used_sites]
ancestors_extra.ancestors_full_haplotype[:, -1, 0], G[:, 0][used_sites]
)
assert np.array_equal(ancestors_extra.ancestors_time[-1], epsilon)

@@ -2490,7 +2511,8 @@ def test_multiple_focal_sites(self):
ancestor_data.finalise()
trunc_anc = ancestor_data.truncate_ancestors(0.3, 0.4, 1)
assert np.array_equal(
trunc_anc.ancestors_haplotype[-1], ancestor_data.ancestors_haplotype[-1]
trunc_anc.ancestors_full_haplotype[-1],
ancestor_data.ancestors_full_haplotype[-1],
)


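The chunking assertions above indicate that ancestors_full_haplotype is chunked in two dimensions: chunk_size still governs the ancestor dimension, while the new chunk_size_sites argument governs the site dimension and falls back to 16384 when left as None. A small zarr sketch of that layout, with toy shapes rather than tsinfer's actual storage code:

import zarr

num_sites, num_ancestors, ploidy = 20_000, 20, 1
chunk_size, chunk_size_sites = 3, 16384  # chunk_size_sites=None falls back to 16384
z = zarr.zeros(
    (num_sites, num_ancestors, ploidy),
    chunks=(chunk_size_sites, chunk_size, ploidy),
    dtype="i1",
)
# The chunk shape is what the tests assert on: sites first, then ancestors,
# then ploidy.
assert z.chunks == (16384, 3, 1)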
22 changes: 13 additions & 9 deletions tests/test_inference.py
@@ -1411,7 +1411,7 @@ def test_all_missing_at_adjacent_site(self):
site_0_anc = site_0_anc[0]
# Sites 0 and 2 should share the same ancestor
assert np.all(adp.ancestors_focal_sites[:][site_0_anc] == [0, 2])
focal_site_0_haplotype = adp.ancestors_haplotype[:][site_0_anc]
focal_site_0_haplotype = adp.ancestors_full_haplotype[:, site_0_anc, 0]
# High freq sites with all missing data (e.g. for sites 1 & 3 in the ancestral
# haplotype focussed on sites 0 & 2) should default to tskit.MISSING_DATA
expected_hap_focal_site_0 = [1, u, 1, u, 1]
@@ -1452,9 +1452,11 @@ def verify_inserted_ancestors(self, ts):
)
start = ancestor_data.ancestors_start[:]
end = ancestor_data.ancestors_end[:]
ancestors = ancestor_data.ancestors_haplotype[:]
ancestors = ancestor_data.ancestors_full_haplotype[:]
for j in range(ancestor_data.num_ancestors):
A[start[j] : end[j], j] = ancestors[j]
A[start[j] : end[j], j] = ancestors[start[j] : end[j], j, 0]
assert np.all(ancestors[0 : start[j], j, 0] == tskit.MISSING_DATA)
assert np.all(ancestors[end[j] :, j, 0] == tskit.MISSING_DATA)
for engine in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]:
ancestors_ts = tsinfer.match_ancestors(
sample_data, ancestor_data, engine=engine
@@ -1569,14 +1571,14 @@ def get_simulated_example(self, ts):
return sample_data, ancestor_data

def verify_ancestors(self, sample_data, ancestor_data):
ancestors = ancestor_data.ancestors_haplotype[:]
ancestors = ancestor_data.ancestors_full_haplotype[:]
position = sample_data.sites_position[:]
start = ancestor_data.ancestors_start[:]
end = ancestor_data.ancestors_end[:]
times = ancestor_data.ancestors_time[:]
focal_sites = ancestor_data.ancestors_focal_sites[:]

assert ancestor_data.num_ancestors == ancestors.shape[0]
assert ancestor_data.num_ancestors == ancestors.shape[1]
assert ancestor_data.num_ancestors == times.shape[0]
assert ancestor_data.num_ancestors == start.shape[0]
assert ancestor_data.num_ancestors == end.shape[0]
@@ -1586,14 +1588,16 @@ def verify_ancestors(self, sample_data, ancestor_data):
assert start[0] == 0
assert end[0] == ancestor_data.num_sites
assert list(focal_sites[0]) == []
assert np.all(ancestors[0] == 0)
assert np.all(ancestors[:, 0] == 0)

used_sites = []
for j in range(ancestor_data.num_ancestors):
a = ancestors[j]
assert a.shape[0] == end[j] - start[j]
a = ancestors[:, j, 0]
assert a.shape[0] == ancestor_data.num_sites
assert np.all(a[0 : start[j]] == tskit.MISSING_DATA)
assert np.all(a[end[j] :] == tskit.MISSING_DATA)
h = np.zeros(ancestor_data.num_sites, dtype=np.uint8)
h[start[j] : end[j]] = a
h[start[j] : end[j]] = a[start[j] : end[j]]
assert np.all(h[start[j] : end[j]] != tskit.MISSING_DATA)
assert np.all(h[focal_sites[j]] == 1)
used_sites.extend(focal_sites[j])
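As the updated verify_ancestors assertions above show, each full haplotype now spans every site, with tskit.MISSING_DATA filling positions outside the ancestor's [start, end) interval, and the old per-ancestor haplotype is simply the [start, end) slice. A short sketch of that relationship; the helper name is illustrative and not part of tsinfer:

import numpy as np
import tskit

def old_style_haplotype(full_haplotype, start, end, ancestor, ploid=0):
    # Recover the [start, end) vector that ancestors_haplotype used to store.
    a = full_haplotype[:, ancestor, ploid]
    assert np.all(a[:start] == tskit.MISSING_DATA)
    assert np.all(a[end:] == tskit.MISSING_DATA)
    return a[start:end]

# Toy example: one ancestor covering sites 2..4 of 6.
full = np.full((6, 1, 1), tskit.MISSING_DATA, dtype=np.int8)
full[2:5, 0, 0] = [1, 0, 1]
assert np.array_equal(old_style_haplotype(full, 2, 5, 0), [1, 0, 1])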
36 changes: 35 additions & 1 deletion tests/test_sgkit.py
@@ -19,17 +19,43 @@
"""
Tests for the data files.
"""
import os
import sys
import tempfile

import lmdb
import msprime
import numpy as np
import pytest
import sgkit
import zarr

import tsinfer
from tsinfer import exceptions


def open_lmdb_readonly(path):
# We set the map_size here because LMDB will map 1TB of virtual memory if
# we don't, making it hard to figure out how much memory we're actually
# using.
map_size = None
try:
map_size = os.path.getsize(path)
except OSError as e:
raise exceptions.FileFormatError(str(e)) from e
try:
store = zarr.LMDBStore(
path, map_size=map_size, readonly=True, subdir=False, lock=False
)
except lmdb.InvalidError as e:
raise exceptions.FileFormatError(f"Unknown file format:{str(e)}") from e
except lmdb.Error as e:
raise exceptions.FileFormatError(str(e)) from e
return store


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_sgkit_dataset(tmp_path):
def test_sgkit_sampledata(tmp_path):
import sgkit.io.vcf

ts = msprime.sim_ancestry(
@@ -48,3 +74,11 @@ def test_sgkit_dataset(tmp_path):
samples = tsinfer.SgkitSampleData(tmp_path / "data.zarr")
inf_ts = tsinfer.infer(samples)
assert np.array_equal(ts.genotype_matrix(), inf_ts.genotype_matrix())


def test_sgkit_ancestor(small_sd_fixture):
with tempfile.TemporaryDirectory(prefix="tsi_eval") as tmpdir:
f = f"{tmpdir}/test.ancestors"
tsinfer.generate_ancestors(small_sd_fixture, path=f)
store = open_lmdb_readonly(f)
sgkit.load_dataset(store)
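test_sgkit_ancestor only checks that sgkit.load_dataset accepts the LMDB-backed ancestors file. The return value is an ordinary xarray Dataset, so a follow-up inspection could look like the sketch below (not part of the commit; no particular variable or dimension names are assumed):

# Continuing from the store opened in test_sgkit_ancestor above.
ds = sgkit.load_dataset(store)
print(list(ds.data_vars))  # array names stored for the ancestors
print(dict(ds.sizes))      # dimension sizes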