Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple SampleData from AncestorData #779

Merged
merged 1 commit into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def run_infer(
sample_data = tsinfer.SampleData.from_tree_sequence(ts)

if exact_ancestors:
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts)
ancestor_data.finalise()
else:
Expand Down Expand Up @@ -535,7 +537,9 @@ def ancestor_properties_worker(args):
}

if compute_exact:
exact_anc = tsinfer.AncestorData(sample_data)
exact_anc = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
tsinfer.build_simulated_ancestors(sample_data, exact_anc, ts)
exact_anc.finalise()
# Show lengths as a fraction of the total.
Expand Down Expand Up @@ -819,7 +823,9 @@ def sim_true_and_inferred_ancestors(args):
sample_data = generate_samples(ts, args.error)

inferred_anc = tsinfer.generate_ancestors(sample_data, engine=args.engine)
true_anc = tsinfer.AncestorData(sample_data)
true_anc = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trivial, but if you want to shorten sample_data to sd (which we do sometimes anyway in the tests), it will make all this fit on a single line. There are a whole set of other examples like this too.

)
tsinfer.build_simulated_ancestors(sample_data, true_anc, ts)
true_anc.finalise()
return sample_data, true_anc, inferred_anc
Expand Down
72 changes: 49 additions & 23 deletions tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -1951,7 +1951,9 @@ def verify_data_round_trip(self, sample_data, ancestor_data, ancestors):

def test_defaults_no_path(self):
sample_data, ancestors = self.get_example_data(10, 10, 40)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
for _, array in ancestor_data.arrays():
assert array.compressor == formats.DEFAULT_COMPRESSOR
Expand All @@ -1963,7 +1965,9 @@ def test_defaults_with_path(self):
sample_data, ancestors = self.get_example_data(10, 10, 40)
with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir:
filename = os.path.join(tempdir, "ancestors.tmp")
ancestor_data = tsinfer.AncestorData(sample_data, path=filename)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length, path=filename
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
compressor = formats.DEFAULT_COMPRESSOR
for _, array in ancestor_data.arrays():
Expand All @@ -1978,17 +1982,27 @@ def test_bad_max_file_size(self):
for bad_size in ["a", "", -1]:
with pytest.raises(ValueError):
formats.AncestorData(
sample_data, path=filename, max_file_size=bad_size
sample_data.sites_position,
sample_data.sequence_length,
path=filename,
max_file_size=bad_size,
)
for bad_size in [[1, 3], np.array([1, 2])]:
with pytest.raises(TypeError):
formats.AncestorData(
sample_data, path=filename, max_file_size=bad_size
sample_data.sites_position,
sample_data.sequence_length,
path=filename,
max_file_size=bad_size,
)

def test_provenance(self):
sample_data, ancestors = self.get_example_data(10, 10, 40)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
for timestamp, record in sample_data.provenances():
ancestor_data.add_provenance(timestamp, record)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.num_provenances == sample_data.num_provenances + 1

Expand All @@ -2008,7 +2022,11 @@ def test_chunk_size(self):
N = 20
for chunk_size in [1, 2, 3, N - 1, N, N + 1]:
sample_data, ancestors = self.get_example_data(6, 1, num_ancestors=N)
ancestor_data = tsinfer.AncestorData(sample_data, chunk_size=chunk_size)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position,
sample_data.sequence_length,
chunk_size=chunk_size,
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,)
assert ancestor_data.ancestors_focal_sites.chunks == (chunk_size,)
Expand All @@ -2020,7 +2038,9 @@ def test_filename(self):
sample_data, ancestors = self.get_example_data(10, 2, 40)
with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir:
filename = os.path.join(tempdir, "ancestors.tmp")
ancestor_data = tsinfer.AncestorData(sample_data, path=filename)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length, path=filename
)
assert os.path.exists(filename)
assert not os.path.isdir(filename)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
Expand All @@ -2043,7 +2063,10 @@ def test_chunk_size_file_equal(self):
filename = os.path.join(tempdir, f"samples_{chunk_size}.tmp")
files.append(filename)
with tsinfer.AncestorData(
sample_data, path=filename, chunk_size=chunk_size
sample_data.sites_position,
sample_data.sequence_length,
path=filename,
chunk_size=chunk_size,
) as ancestor_data:
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,)
Expand All @@ -2054,7 +2077,9 @@ def test_chunk_size_file_equal(self):

def test_add_ancestor_errors(self):
sample_data, ancestors = self.get_example_data(22, 16, 30)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
num_sites = ancestor_data.num_sites
haplotype = np.zeros(num_sites, dtype=np.int8)
ancestor_data.add_ancestor(
Expand Down Expand Up @@ -2091,15 +2116,6 @@ def test_add_ancestor_errors(self):
focal_sites=[],
haplotype=np.zeros(num_sites + 1, dtype=np.int8),
)
# Haplotypes must be < num_alleles
with pytest.raises(ValueError):
ancestor_data.add_ancestor(
start=0,
end=num_sites,
time=1,
focal_sites=[],
haplotype=np.zeros(num_sites, dtype=np.int8) + 2,
)
# focal sites must be within start:end
with pytest.raises(ValueError):
ancestor_data.add_ancestor(
Expand All @@ -2125,7 +2141,9 @@ def test_add_ancestor_errors(self):

def test_iterator(self):
sample_data, ancestors = self.get_example_data(6, 10, 10)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
assert ancestor_data.num_ancestors > 1
assert ancestor_data.num_ancestors == len(ancestors)
Expand All @@ -2137,7 +2155,9 @@ def test_iterator(self):

def test_equals(self):
sample_data, ancestors = self.get_example_data(6, 1, 2)
with tsinfer.AncestorData(sample_data) as ancestor_data:
with tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
) as ancestor_data:
num_sites = ancestor_data.num_sites
haplotype = np.ones(num_sites, dtype=np.int8)
ancestor_data.add_ancestor(
Expand All @@ -2151,7 +2171,9 @@ def test_equals(self):

def test_accessor(self):
sample_data, ancestors = self.get_example_data(6, 10, 10)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
self.verify_data_round_trip(sample_data, ancestor_data, ancestors)
for i, new_ancestor in enumerate(ancestor_data.ancestors()):
assert new_ancestor == ancestor_data.ancestor(i)
Expand All @@ -2178,7 +2200,9 @@ def test_zero_sequence_length(self):

def test_bad_insert_proxy_samples(self):
sample_data, ancestor_haps = self.get_example_data(10, 10, 40)
ancestors = tsinfer.AncestorData(sample_data)
ancestors = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
with pytest.raises(ValueError, match="not finalised"):
ancestors.insert_proxy_samples(sample_data)
self.verify_data_round_trip(sample_data, ancestors, ancestor_haps)
Expand Down Expand Up @@ -2458,7 +2482,9 @@ def test_multiple_focal_sites(self):
sample_data.add_site(5, [1, 0, 0, 1])
sample_data.add_site(10, [1, 1, 1, 0])
sample_data.finalise()
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
ancestor_data.add_ancestor(0, 3, 0.6666, [0, 2], [1, 1, 0])
ancestor_data.add_ancestor(1, 2, 0.333, [1], [1])
ancestor_data.finalise()
Expand Down
53 changes: 24 additions & 29 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,12 @@ def get_random_data_example(num_samples, num_sites, seed=42, num_states=2):


class TestUnfinalisedErrors:
def make_ancestor_data_unfinalised(self, path=None):
with tsinfer.SampleData(path=path, sequence_length=2) as sample_data:
sample_data.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"])
with pytest.raises(ValueError):
tsinfer.AncestorData(sample_data)
if path is not None:
sample_data.close()

def match_ancestors_ancestors_unfinalised(self, path=None):
with tsinfer.SampleData(sequence_length=2) as sample_data:
sample_data.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"])
with tsinfer.AncestorData(sample_data, path=path) as ancestor_data:
with tsinfer.AncestorData(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, sd rather than sample_data might make it all more readable (fewer lines)

sample_data.sites_position, sample_data.sequence_length, path=path
) as ancestor_data:
ancestor_data.add_ancestor(
start=0,
end=1,
Expand All @@ -75,14 +69,6 @@ def match_ancestors_ancestors_unfinalised(self, path=None):
if path is not None:
ancestor_data.close()

def test_make_ancestor_data(self):
self.make_ancestor_data_unfinalised()

def test_make_ancestor_data_file(self):
with tempfile.TemporaryDirectory(prefix="tsinf_inference_test") as tempdir:
filename = os.path.join(tempdir, "samples.tmp")
self.make_ancestor_data_unfinalised(filename)

def test_match_ancestors_ancestors(self):
self.match_ancestors_ancestors_unfinalised()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we test match_ancestors_ancestors_unfinalised with the path argument anywhere?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, on line 78.


Expand Down Expand Up @@ -615,7 +601,9 @@ def verify_data_round_trip(
)

num_alleles = sample_data.num_alleles()
with tsinfer.AncestorData(sample_data) as ancestor_data:
with tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
) as ancestor_data:
t = np.sum(num_alleles) + 1
for j in range(sample_data.num_sites):
for allele in range(num_alleles[j] - 1):
Expand Down Expand Up @@ -1451,7 +1439,9 @@ def verify_inserted_ancestors(self, ts):
with tsinfer.SampleData(sequence_length=ts.sequence_length) as sample_data:
for v in ts.variants():
sample_data.add_site(v.position, v.genotypes, v.alleles)
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts)
ancestor_data.finalise()

Expand Down Expand Up @@ -1539,12 +1529,9 @@ def test_bad_focal_sites(self):
(tsinfer.C_ENGINE, _tsinfer.LibraryError),
(tsinfer.PY_ENGINE, ValueError),
]:
with tsinfer.formats.AncestorData(sample_data) as ancestor_data:
g = np.zeros(2, dtype=np.int8)
h = np.zeros(1, dtype=np.int8)
generator = tsinfer.AncestorsGenerator(
sample_data, ancestor_data, engine=engine
)
g = np.zeros(2, dtype=np.int8)
h = np.zeros(1, dtype=np.int8)
generator = tsinfer.AncestorsGenerator(sample_data, None, {}, engine=engine)
generator.ancestor_builder.add_site(1, g)
with pytest.raises(error):
generator.ancestor_builder.make_ancestor([0], h)
Expand Down Expand Up @@ -2118,7 +2105,9 @@ def infer(self, ts, engine, path_compression=False, precision=None):
sample_data.add_site(v.site.position, v.genotypes, v.alleles)
sample_data.finalise()

ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts)
ancestor_data.finalise()
ancestors_ts = tsinfer.match_ancestors(
Expand Down Expand Up @@ -2326,7 +2315,9 @@ def test_easy_case(self):
for j in range(num_sites):
sample_data.add_site(j, [0, 1, 1])
sample_data.finalise()
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)

ancestor_data.add_ancestor( # ID 0
start=0, end=6, focal_sites=[], time=5, haplotype=[0, 0, 0, 0, 0, 0]
Expand Down Expand Up @@ -2372,7 +2363,9 @@ def test_partial_overlap(self):
for j in range(num_sites):
sample_data.add_site(j, [0, 1, 1])
sample_data.finalise()
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)

ancestor_data.add_ancestor( # ID 0
start=0, end=7, focal_sites=[], time=5, haplotype=[0, 0, 0, 0, 0, 0, 0]
Expand Down Expand Up @@ -2417,7 +2410,9 @@ def test_edge_overlap_bug(self):
with tsinfer.SampleData() as sample_data:
for j in range(num_sites):
sample_data.add_site(j, [0, 1, 1])
ancestor_data = tsinfer.AncestorData(sample_data)
ancestor_data = tsinfer.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)

ancestor_data.add_ancestor( # ID 0
start=0,
Expand Down
4 changes: 3 additions & 1 deletion tsinfer/eval_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,9 @@ def run_perfect_inference(
logger.info("Using provided tree sequence")
ancestors_ts = make_ancestors_ts(ts, remove_leaves=True)
else:
ancestor_data = formats.AncestorData(sample_data)
ancestor_data = formats.AncestorData(
sample_data.sites_position, sample_data.sequence_length
)
build_simulated_ancestors(
sample_data, ancestor_data, ts, time_chunking=time_chunking
)
Expand Down
Loading