Store ancestor data in sgkit format

benjeffery committed Nov 29, 2022
1 parent 2630169 commit 3662ae8
Showing 5 changed files with 206 additions and 88 deletions.
15 changes: 4 additions & 11 deletions evaluation.py
@@ -1213,24 +1213,17 @@ def run_ancestor_quality(args):
olap_end = estim_positions[olap_end_estim]
olap_end_exact = np.searchsorted(exact_anc.sites_position[:], olap_end)

# ancestors_haplotype[x] contains a vector of inferred sites only
# between ancestors_start[x] and ancestors_end[x]. To match it to the genome-wide
# mask we need to account for the offset before masking out non-shared sites
offset1 = exact_anc.ancestors_start[:][exact_index]
offset2 = estim_anc.ancestors_start[:][estim_index]

exact_full_hap = exact_anc.ancestors_haplotype[:][exact_index]
exact_full_hap = exact_anc.ancestors_full_haplotype[:, exact_index, 0]
# slice the full haplotype to include only the overlapping region
exact_olap = exact_full_hap[
(olap_start_exact - offset1) : (olap_end_exact - offset1)
]
exact_olap = exact_full_hap[olap_start_exact:olap_end_exact]
# make a 1/0 array with only the comparable sites
exact_comp = exact_olap[exact_sites_mask[olap_start_exact:olap_end_exact]]

estim_full_hap = estim_anc.ancestors_haplotype[estim_index]
estim_olap = estim_full_hap[
(olap_start_estim - offset2) : (olap_end_estim - offset2)
]
estim_full_hap = estim_anc.ancestors_full_haplotype[:, estim_index, 0]
estim_olap = estim_full_hap[olap_start_estim:olap_end_estim]
small_estim_mask = estim_sites_mask[olap_start_estim:olap_end_estim]
estim_comp = estim_olap[small_estim_mask]

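The edit above removes the per-ancestor offset arithmetic: the old ancestors_haplotype[j] stored only the sites between ancestors_start[j] and ancestors_end[j], so an overlap window had to be shifted by the ancestor's start before slicing, whereas the new ancestors_full_haplotype covers every site and is sliced directly in genome-wide site coordinates. A minimal sketch of the two access patterns (an illustration only, not part of this commit; it assumes an AncestorData object produced by this branch):

def overlap_old(anc_data, j, olap_start, olap_end):
    # Old layout: ancestors_haplotype[j] covers only [start[j], end[j]),
    # so the genome-wide window must be shifted by the ancestor's offset.
    offset = anc_data.ancestors_start[:][j]
    return anc_data.ancestors_haplotype[:][j][olap_start - offset : olap_end - offset]

def overlap_new(anc_data, j, olap_start, olap_end):
    # New layout (this commit): a (sites, ancestors, ploidy) array indexed in
    # genome-wide site coordinates, so no offset correction is needed.
    return anc_data.ancestors_full_haplotype[olap_start:olap_end, j, 0]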
66 changes: 44 additions & 22 deletions tests/test_formats.py
@@ -1929,23 +1929,27 @@ def verify_data_round_trip(self, sample_data, ancestor_data, ancestors):
stored_start = ancestor_data.ancestors_start[:]
stored_end = ancestor_data.ancestors_end[:]
stored_time = ancestor_data.ancestors_time[:]
stored_ancestors = ancestor_data.ancestors_haplotype[:]
# Remove the ploidy dimension
stored_ancestors = ancestor_data.ancestors_full_haplotype[:, :, 0]
stored_focal_sites = ancestor_data.ancestors_focal_sites[:]
stored_length = ancestor_data.ancestors_length[:]
for j, (start, end, t, focal_sites, haplotype) in enumerate(ancestors):
for j, (start, end, t, focal_sites, full_haplotype) in enumerate(ancestors):
assert stored_start[j] == start
assert stored_end[j] == end
assert stored_time[j] == t
assert np.array_equal(stored_focal_sites[j], focal_sites)
assert np.array_equal(stored_ancestors[j], haplotype[start:end])
assert np.array_equal(stored_ancestors[:, j], full_haplotype)
assert np.array_equal(ancestors_list[j], haplotype[start:end])
pos = list(ancestor_data.sites_position[:]) + [ancestor_data.sequence_length]
for j, anc in enumerate(ancestor_data.ancestors()):
assert stored_start[j] == anc.start
assert stored_end[j] == anc.end
assert stored_time[j] == anc.time
assert np.array_equal(stored_focal_sites[j], anc.focal_sites)
assert np.array_equal(stored_ancestors[j], anc.haplotype)
assert np.array_equal(stored_ancestors[:, j], anc.full_haplotype)
assert np.array_equal(
stored_ancestors[anc.start : anc.end, j], anc.haplotype
)
length = pos[anc.end] - pos[anc.start]
assert stored_length[j] == length

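As the updated round-trip checks show, ancestors_full_haplotype is three-dimensional — (num_sites, num_ancestors, ploidy) — with a trailing ploidy axis of length 1 that the test strips via [:, :, 0], and each ancestor's defined window full_haplotype[start:end] still equals the old per-ancestor haplotype. A short sketch of that invariant (an illustration, not part of the test suite), assuming a finalised AncestorData from this branch:

import numpy as np

def check_full_haplotype_layout(ancestor_data):
    # Drop the trailing ploidy dimension, as verify_data_round_trip does above.
    full = ancestor_data.ancestors_full_haplotype[:, :, 0]
    assert full.shape == (ancestor_data.num_sites, ancestor_data.num_ancestors)
    for j, anc in enumerate(ancestor_data.ancestors()):
        # The old-style haplotype is the [start, end) window of column j.
        assert np.array_equal(full[anc.start : anc.end, j], anc.haplotype)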
@@ -2021,18 +2025,31 @@ def test_provenance(self):
def test_chunk_size(self):
N = 20
for chunk_size in [1, 2, 3, N - 1, N, N + 1]:
sample_data, ancestors = self.get_example_data(6, 1, num_ancestors=N)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position,
sample_data.sequence_length,
chunk_size=chunk_size,
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,)
assert ancestor_data.ancestors_focal_sites.chunks == (chunk_size,)
assert ancestor_data.ancestors_start.chunks == (chunk_size,)
assert ancestor_data.ancestors_end.chunks == (chunk_size,)
assert ancestor_data.ancestors_time.chunks == (chunk_size,)
for chunk_size_sites in [None, 1, 2, 3, N - 1, N, N + 1]:
sample_data, ancestors = self.get_example_data(6, 1, num_ancestors=N)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position,
sample_data.sequence_length,
chunk_size=chunk_size,
chunk_size_sites=chunk_size_sites,
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
if chunk_size_sites is None:
assert ancestor_data.ancestors_full_haplotype.chunks == (
16384,
chunk_size,
1,
)
else:
assert ancestor_data.ancestors_full_haplotype.chunks == (
chunk_size_sites,
chunk_size,
1,
)
assert ancestor_data.ancestors_focal_sites.chunks == (chunk_size,)
assert ancestor_data.ancestors_start.chunks == (chunk_size,)
assert ancestor_data.ancestors_end.chunks == (chunk_size,)
assert ancestor_data.ancestors_time.chunks == (chunk_size,)

def test_filename(self):
sample_data, ancestors = self.get_example_data(10, 2, 40)
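In the chunk-size test above, chunk_size still sets the chunking over the ancestors dimension, while the new chunk_size_sites argument controls chunking along the sites dimension of ancestors_full_haplotype and appears to default to 16384 when omitted. A rough usage sketch (the toy sample data here is hypothetical, not from the test suite):

import tsinfer

with tsinfer.SampleData(sequence_length=100) as sample_data:
    for pos in range(10, 100, 10):
        sample_data.add_site(pos, genotypes=[0, 1, 1, 0])

ancestor_data = tsinfer.AncestorData(
    sample_data.sites_position,
    sample_data.sequence_length,
    chunk_size=2,  # chunking over the ancestors dimension
    chunk_size_sites=3,  # chunking over the sites dimension (new in this commit)
)
# Expected zarr chunking of ancestors_full_haplotype: (3, 2, 1).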
@@ -2069,7 +2086,11 @@ def test_chunk_size_file_equal(self):
chunk_size=chunk_size,
) as ancestor_data:
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,)
assert ancestor_data.ancestors_full_haplotype.chunks == (
16384,
chunk_size,
1,
)
# Now reload the files and check they are equal
with formats.AncestorData.load(files[0]) as file0:
with formats.AncestorData.load(files[1]) as file1:
@@ -2263,7 +2284,7 @@ def test_insert_proxy_1_sample(self):
inserted = -1
self.assert_ancestor_full_span(ancestors_extra, [inserted])
assert np.array_equal(
ancestors_extra.ancestors_haplotype[inserted],
ancestors_extra.ancestors_full_haplotype[:, inserted, 0],
sample_data.sites_genotypes[:, i][used_sites],
)

@@ -2304,7 +2325,7 @@ def test_insert_proxy_time_historical_samples(self):
assert ancestors.num_ancestors + 1 == ancestors_extra.num_ancestors
self.assert_ancestor_full_span(ancestors_extra, [-1])
assert np.array_equal(
ancestors_extra.ancestors_haplotype[-1], G[:, 9][used_sites]
ancestors_extra.ancestors_full_haplotype[:, -1, 0], G[:, 9][used_sites]
)
assert np.array_equal(
ancestors_extra.ancestors_time[-1], historical_sample_time + epsilon
@@ -2320,14 +2341,14 @@ def test_insert_proxy_time_historical_samples(self):
self.assert_ancestor_full_span(ancestors_extra, inserted)
# Older sample
assert np.array_equal(
ancestors_extra.ancestors_haplotype[-2], G[:, 9][used_sites]
ancestors_extra.ancestors_full_haplotype[:, -2, 0], G[:, 9][used_sites]
)
assert np.array_equal(
ancestors_extra.ancestors_time[-2], historical_sample_time + epsilon
)
# Younger sample
assert np.array_equal(
ancestors_extra.ancestors_haplotype[-1], G[:, 0][used_sites]
ancestors_extra.ancestors_full_haplotype[:, -1, 0], G[:, 0][used_sites]
)
assert np.array_equal(ancestors_extra.ancestors_time[-1], epsilon)

@@ -2490,7 +2511,8 @@ def test_multiple_focal_sites(self):
ancestor_data.finalise()
trunc_anc = ancestor_data.truncate_ancestors(0.3, 0.4, 1)
assert np.array_equal(
trunc_anc.ancestors_haplotype[-1], ancestor_data.ancestors_haplotype[-1]
trunc_anc.ancestors_full_haplotype[-1],
ancestor_data.ancestors_full_haplotype[-1],
)


22 changes: 13 additions & 9 deletions tests/test_inference.py
@@ -1411,7 +1411,7 @@ def test_all_missing_at_adjacent_site(self):
site_0_anc = site_0_anc[0]
# Sites 0 and 2 should share the same ancestor
assert np.all(adp.ancestors_focal_sites[:][site_0_anc] == [0, 2])
focal_site_0_haplotype = adp.ancestors_haplotype[:][site_0_anc]
focal_site_0_haplotype = adp.ancestors_full_haplotype[:, site_0_anc, 0]
# High freq sites with all missing data (e.g. for sites 1 & 3 in the ancestral
# haplotype focussed on sites 0 & 2) should default to tskit.MISSING_DATA
expected_hap_focal_site_0 = [1, u, 1, u, 1]
@@ -1452,9 +1452,11 @@ def verify_inserted_ancestors(self, ts):
)
start = ancestor_data.ancestors_start[:]
end = ancestor_data.ancestors_end[:]
ancestors = ancestor_data.ancestors_haplotype[:]
ancestors = ancestor_data.ancestors_full_haplotype[:]
for j in range(ancestor_data.num_ancestors):
A[start[j] : end[j], j] = ancestors[j]
A[start[j] : end[j], j] = ancestors[start[j] : end[j], j, 0]
assert np.all(ancestors[0 : start[j], j, 0] == tskit.MISSING_DATA)
assert np.all(ancestors[end[j] :, j, 0] == tskit.MISSING_DATA)
for engine in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]:
ancestors_ts = tsinfer.match_ancestors(
sample_data, ancestor_data, engine=engine
@@ -1569,14 +1571,14 @@ def get_simulated_example(self, ts):
return sample_data, ancestor_data

def verify_ancestors(self, sample_data, ancestor_data):
ancestors = ancestor_data.ancestors_haplotype[:]
ancestors = ancestor_data.ancestors_full_haplotype[:]
position = sample_data.sites_position[:]
start = ancestor_data.ancestors_start[:]
end = ancestor_data.ancestors_end[:]
times = ancestor_data.ancestors_time[:]
focal_sites = ancestor_data.ancestors_focal_sites[:]

assert ancestor_data.num_ancestors == ancestors.shape[0]
assert ancestor_data.num_ancestors == ancestors.shape[1]
assert ancestor_data.num_ancestors == times.shape[0]
assert ancestor_data.num_ancestors == start.shape[0]
assert ancestor_data.num_ancestors == end.shape[0]
@@ -1586,14 +1588,16 @@ def verify_ancestors(self, sample_data, ancestor_data):
assert start[0] == 0
assert end[0] == ancestor_data.num_sites
assert list(focal_sites[0]) == []
assert np.all(ancestors[0] == 0)
assert np.all(ancestors[:, 0] == 0)

used_sites = []
for j in range(ancestor_data.num_ancestors):
a = ancestors[j]
assert a.shape[0] == end[j] - start[j]
a = ancestors[:, j, 0]
assert a.shape[0] == ancestor_data.num_sites
assert np.all(a[0 : start[j]] == tskit.MISSING_DATA)
assert np.all(a[end[j] :] == tskit.MISSING_DATA)
h = np.zeros(ancestor_data.num_sites, dtype=np.uint8)
h[start[j] : end[j]] = a
h[start[j] : end[j]] = a[start[j] : end[j]]
assert np.all(h[start[j] : end[j]] != tskit.MISSING_DATA)
assert np.all(h[focal_sites[j]] == 1)
used_sites.extend(focal_sites[j])
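The extra assertions here pin down the padding convention: outside an ancestor's [start, end) interval, its column of ancestors_full_haplotype is filled with tskit.MISSING_DATA. That makes it possible to recover an ancestor's defined interval from the padded column alone, as in this sketch (an illustration only; it assumes the defined region is contiguous, which the checks above imply):

import numpy as np
import tskit

def defined_interval(full_hap_column):
    # Indices of sites that are not MISSING_DATA padding; the first and
    # last of these bound the ancestor's [start, end) interval.
    defined = np.flatnonzero(full_hap_column != tskit.MISSING_DATA)
    return defined[0], defined[-1] + 1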
37 changes: 37 additions & 0 deletions tests/test_sgkit.py
@@ -0,0 +1,37 @@
import os
import tempfile

import lmdb
import sgkit
import zarr

import tsinfer
from tsinfer import exceptions


def open_lmbd_readonly(path):
    # We set the map_size here because LMDB will map 1TB of virtual memory if
    # we don't, making it hard to figure out how much memory we're actually
    # using.
    map_size = None
    try:
        map_size = os.path.getsize(path)
    except OSError as e:
        raise exceptions.FileFormatError(str(e)) from e
    try:
        store = zarr.LMDBStore(
            path, map_size=map_size, readonly=True, subdir=False, lock=False
        )
    except lmdb.InvalidError as e:
        raise exceptions.FileFormatError(f"Unknown file format:{str(e)}") from e
    except lmdb.Error as e:
        raise exceptions.FileFormatError(str(e)) from e
    return store


def test_ancestor_compat(small_sd_fixture):
    with tempfile.TemporaryDirectory(prefix="tsi_eval") as tmpdir:
        f = f"{tmpdir}/test.ancestors"
        tsinfer.generate_ancestors(small_sd_fixture, path=f)
        store = open_lmbd_readonly(f)
        sgkit.load_dataset(store)
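This new test exercises the point of the change: the .ancestors file written by generate_ancestors is now a zarr store that sgkit can open directly (small_sd_fixture is a pytest fixture defined elsewhere in the test suite). A minimal end-to-end sketch with hypothetical toy data:

import tempfile

import sgkit
import zarr

import tsinfer

with tsinfer.SampleData(sequence_length=6) as sample_data:
    sample_data.add_site(2, genotypes=[0, 1, 1, 0])
    sample_data.add_site(4, genotypes=[0, 0, 1, 1])

with tempfile.TemporaryDirectory() as tmpdir:
    path = f"{tmpdir}/test.ancestors"
    tsinfer.generate_ancestors(sample_data, path=path)
    # The ancestors file is an LMDB-backed zarr store, so it can be handed
    # straight to sgkit without going through the tsinfer API.
    store = zarr.LMDBStore(path, readonly=True, subdir=False, lock=False)
    ds = sgkit.load_dataset(store)
    print(list(ds.data_vars))
    store.close()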