Skip to content

Commit

Permalink
WIP - decouple SampleData from AncestorData
Browse files: browse the repository at this point in the history
  • Loading branch information
Ben Jeffery authored and benjeffery committed Nov 28, 2022
1 parent 6118b09 commit 398061c
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 108 deletions.
6 changes: 3 additions & 3 deletions evaluation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -147,7 +147,7 @@ def run_infer(
sample_data = tsinfer.SampleData.from_tree_sequence(ts)

if exact_ancestors:
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(sample_data.sites_position, sample_data.sequence_length)
tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts)
ancestor_data.finalise()
else:
Expand Down Expand Up @@ -535,7 +535,7 @@ def ancestor_properties_worker(args):
}

if compute_exact:
exact_anc = tsinfer.AncestorData(sample_data)
exact_anc = tsinfer.AncestorData(sample_data.sites_position, sample_data.sequence_length)
tsinfer.build_simulated_ancestors(sample_data, exact_anc, ts)
exact_anc.finalise()
# Show lengths as a fraction of the total.
Expand Down Expand Up @@ -819,7 +819,7 @@ def sim_true_and_inferred_ancestors(args):
sample_data = generate_samples(ts, args.error)

inferred_anc = tsinfer.generate_ancestors(sample_data, engine=args.engine)
true_anc = tsinfer.AncestorData(sample_data)
true_anc = tsinfer.AncestorData(sample_data.sites_position, sample_data.sequence_length)
tsinfer.build_simulated_ancestors(sample_data, true_anc, ts)
true_anc.finalise()
return sample_data, true_anc, inferred_anc
Expand Down
72 changes: 49 additions & 23 deletions tests/test_formats.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1951,7 +1951,9 @@ def verify_data_round_trip(self, sample_data, ancestor_data, ancestors):

def test_defaults_no_path(self):
sample_data, ancestors = self.get_example_data(10, 10, 40)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
for _, array in ancestor_data.arrays():
assert array.compressor == formats.DEFAULT_COMPRESSOR
Expand All @@ -1963,7 +1965,9 @@ def test_defaults_with_path(self):
sample_data, ancestors = self.get_example_data(10, 10, 40)
with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir:
filename = os.path.join(tempdir, "ancestors.tmp")
ancestor_data = tsinfer.AncestorData(sample_data, path=filename)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length, path=filename
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
compressor = formats.DEFAULT_COMPRESSOR
for _, array in ancestor_data.arrays():
Expand All @@ -1978,17 +1982,27 @@ def test_bad_max_file_size(self):
for bad_size in ["a", "", -1]:
with pytest.raises(ValueError):
formats.AncestorData(
sample_data, path=filename, max_file_size=bad_size
sample_data.sites_position,
sample_data.sequence_length,
path=filename,
max_file_size=bad_size,
)
for bad_size in [[1, 3], np.array([1, 2])]:
with pytest.raises(TypeError):
formats.AncestorData(
sample_data, path=filename, max_file_size=bad_size
sample_data.sites_position,
sample_data.sequence_length,
path=filename,
max_file_size=bad_size,
)

def test_provenance(self):
sample_data, ancestors = self.get_example_data(10, 10, 40)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
for timestamp, record in sample_data.provenances():
ancestor_data.add_provenance(timestamp, record)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.num_provenances == sample_data.num_provenances + 1

Expand All @@ -2008,7 +2022,11 @@ def test_chunk_size(self):
N = 20
for chunk_size in [1, 2, 3, N - 1, N, N + 1]:
sample_data, ancestors = self.get_example_data(6, 1, num_ancestors=N)
ancestor_data = tsinfer.AncestorData(sample_data, chunk_size=chunk_size)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position,
sample_data.sequence_length,
chunk_size=chunk_size,
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,)
assert ancestor_data.ancestors_focal_sites.chunks == (chunk_size,)
Expand All @@ -2020,7 +2038,9 @@ def test_filename(self):
sample_data, ancestors = self.get_example_data(10, 2, 40)
with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir:
filename = os.path.join(tempdir, "ancestors.tmp")
ancestor_data = tsinfer.AncestorData(sample_data, path=filename)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length, path=filename
)
assert os.path.exists(filename)
assert not os.path.isdir(filename)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
Expand All @@ -2043,7 +2063,10 @@ def test_chunk_size_file_equal(self):
filename = os.path.join(tempdir, f"samples_{chunk_size}.tmp")
files.append(filename)
with tsinfer.AncestorData(
sample_data, path=filename, chunk_size=chunk_size
sample_data.sites_position,
sample_data.sequence_length,
path=filename,
chunk_size=chunk_size,
) as ancestor_data:
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,)
Expand All @@ -2054,7 +2077,9 @@ def test_chunk_size_file_equal(self):

def test_add_ancestor_errors(self):
sample_data, ancestors = self.get_example_data(22, 16, 30)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
num_sites = ancestor_data.num_sites
haplotype = np.zeros(num_sites, dtype=np.int8)
ancestor_data.add_ancestor(
Expand Down Expand Up @@ -2091,15 +2116,6 @@ def test_add_ancestor_errors(self):
focal_sites=[],
haplotype=np.zeros(num_sites + 1, dtype=np.int8),
)
# Haplotypes must be < num_alleles
with pytest.raises(ValueError):
ancestor_data.add_ancestor(
start=0,
end=num_sites,
time=1,
focal_sites=[],
haplotype=np.zeros(num_sites, dtype=np.int8) + 2,
)
# focal sites must be within start:end
with pytest.raises(ValueError):
ancestor_data.add_ancestor(
Expand All @@ -2125,7 +2141,9 @@ def test_add_ancestor_errors(self):

def test_iterator(self):
sample_data, ancestors = self.get_example_data(6, 10, 10)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.num_ancestors > 1
assert ancestor_data.num_ancestors == len(ancestors)
Expand All @@ -2137,7 +2155,9 @@ def test_iterator(self):

def test_equals(self):
sample_data, ancestors = self.get_example_data(6, 1, 2)
with tsinfer.AncestorData(sample_data) as ancestor_data:
with tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
) as ancestor_data:
num_sites = ancestor_data.num_sites
haplotype = np.ones(num_sites, dtype=np.int8)
ancestor_data.add_ancestor(
Expand All @@ -2151,7 +2171,9 @@ def test_equals(self):

def test_accessor(self):
sample_data, ancestors = self.get_example_data(6, 10, 10)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
for i, new_ancestor in enumerate(ancestor_data.ancestors()):
assert new_ancestor == ancestor_data.ancestor(i)
Expand All @@ -2178,7 +2200,9 @@ def test_zero_sequence_length(self):

def test_bad_insert_proxy_samples(self):
sample_data, ancestor_haps = self.get_example_data(10, 10, 40)
ancestors = tsinfer.AncestorData(sample_data)
ancestors = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
with pytest.raises(ValueError, match="not finalised"):
ancestors.insert_proxy_samples(sample_data)
self.verify_data_round_trip(sample_data, ancestors, ancestor_haps)
Expand Down Expand Up @@ -2458,7 +2482,9 @@ def test_multiple_focal_sites(self):
sample_data.add_site(5, [1, 0, 0, 1])
sample_data.add_site(10, [1, 1, 1, 0])
sample_data.finalise()
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
ancestor_data.add_ancestor(0, 3, 0.6666, [0, 2], [1, 1, 0])
ancestor_data.add_ancestor(1, 2, 0.333, [1], [1])
ancestor_data.finalise()
Expand Down
53 changes: 24 additions & 29 deletions tests/test_inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,18 +50,12 @@ def get_random_data_example(num_samples, num_sites, seed=42, num_states=2):


class TestUnfinalisedErrors:
def make_ancestor_data_unfinalised(self, path=None):
with tsinfer.SampleData(path=path, sequence_length=2) as sample_data:
sample_data.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"])
with pytest.raises(ValueError):
tsinfer.AncestorData(sample_data)
if path is not None:
sample_data.close()

def match_ancestors_ancestors_unfinalised(self, path=None):
with tsinfer.SampleData(sequence_length=2) as sample_data:
sample_data.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"])
with tsinfer.AncestorData(sample_data, path=path) as ancestor_data:
with tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length, path=path
) as ancestor_data:
ancestor_data.add_ancestor(
start=0,
end=1,
Expand All @@ -75,14 +69,6 @@ def match_ancestors_ancestors_unfinalised(self, path=None):
if path is not None:
ancestor_data.close()

def test_make_ancestor_data(self):
self.make_ancestor_data_unfinalised()

def test_make_ancestor_data_file(self):
with tempfile.TemporaryDirectory(prefix="tsinf_inference_test") as tempdir:
filename = os.path.join(tempdir, "samples.tmp")
self.make_ancestor_data_unfinalised(filename)

def test_match_ancestors_ancestors(self):
self.match_ancestors_ancestors_unfinalised()

Expand Down Expand Up @@ -615,7 +601,9 @@ def verify_data_round_trip(
)

num_alleles = sample_data.num_alleles()
with tsinfer.AncestorData(sample_data) as ancestor_data:
with tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
) as ancestor_data:
t = np.sum(num_alleles) + 1
for j in range(sample_data.num_sites):
for allele in range(num_alleles[j] - 1):
Expand Down Expand Up @@ -1451,7 +1439,9 @@ def verify_inserted_ancestors(self, ts):
with tsinfer.SampleData(sequence_length=ts.sequence_length) as sample_data:
for v in ts.variants():
sample_data.add_site(v.position, v.genotypes, v.alleles)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts)
ancestor_data.finalise()

Expand Down Expand Up @@ -1539,12 +1529,9 @@ def test_bad_focal_sites(self):
(tsinfer.C_ENGINE, _tsinfer.LibraryError),
(tsinfer.PY_ENGINE, ValueError),
]:
with tsinfer.formats.AncestorData(sample_data) as ancestor_data:
g = np.zeros(2, dtype=np.int8)
h = np.zeros(1, dtype=np.int8)
generator = tsinfer.AncestorsGenerator(
sample_data, ancestor_data, engine=engine
)
g = np.zeros(2, dtype=np.int8)
h = np.zeros(1, dtype=np.int8)
generator = tsinfer.AncestorsGenerator(sample_data, None, {}, engine=engine)
generator.ancestor_builder.add_site(1, g)
with pytest.raises(error):
generator.ancestor_builder.make_ancestor([0], h)
Expand Down Expand Up @@ -2118,7 +2105,9 @@ def infer(self, ts, engine, path_compression=False, precision=None):
sample_data.add_site(v.site.position, v.genotypes, v.alleles)
sample_data.finalise()

ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts)
ancestor_data.finalise()
ancestors_ts = tsinfer.match_ancestors(
Expand Down Expand Up @@ -2326,7 +2315,9 @@ def test_easy_case(self):
for j in range(num_sites):
sample_data.add_site(j, [0, 1, 1])
sample_data.finalise()
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)

ancestor_data.add_ancestor( # ID 0
start=0, end=6, focal_sites=[], time=5, haplotype=[0, 0, 0, 0, 0, 0]
Expand Down Expand Up @@ -2372,7 +2363,9 @@ def test_partial_overlap(self):
for j in range(num_sites):
sample_data.add_site(j, [0, 1, 1])
sample_data.finalise()
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)

ancestor_data.add_ancestor( # ID 0
start=0, end=7, focal_sites=[], time=5, haplotype=[0, 0, 0, 0, 0, 0, 0]
Expand Down Expand Up @@ -2417,7 +2410,9 @@ def test_edge_overlap_bug(self):
with tsinfer.SampleData() as sample_data:
for j in range(num_sites):
sample_data.add_site(j, [0, 1, 1])
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)

ancestor_data.add_ancestor( # ID 0
start=0,
Expand Down
2 changes: 1 addition & 1 deletion tsinfer/eval_util.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -695,7 +695,7 @@ def run_perfect_inference(
logger.info("Using provided tree sequence")
ancestors_ts = make_ancestors_ts(ts, remove_leaves=True)
else:
ancestor_data = formats.AncestorData(sample_data)
ancestor_data = formats.AncestorData(sample_data.sites_position, sample_data.sequence_length)
build_simulated_ancestors(
sample_data, ancestor_data, ts, time_chunking=time_chunking
)
Expand Down
Loading

0 comments on commit 398061c

Please sign in to comment.