From 398061c9fa83eccf7f967ebfcd48b791a8394a57 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 23 Nov 2022 16:07:19 +0000 Subject: [PATCH] WIP - decouple SampleData from AncestorData --- evaluation.py | 6 ++-- tests/test_formats.py | 72 ++++++++++++++++++++++++++++------------- tests/test_inference.py | 53 ++++++++++++++---------------- tsinfer/eval_util.py | 2 +- tsinfer/formats.py | 49 ++++++++-------------------- tsinfer/inference.py | 43 +++++++++++++++--------- visualisation.py | 2 +- 7 files changed, 119 insertions(+), 108 deletions(-) diff --git a/evaluation.py b/evaluation.py index 472b3bff..35d3c17c 100644 --- a/evaluation.py +++ b/evaluation.py @@ -147,7 +147,7 @@ def run_infer( sample_data = tsinfer.SampleData.from_tree_sequence(ts) if exact_ancestors: - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData(sample_data.sites_position, sample_data.sequence_length) tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts) ancestor_data.finalise() else: @@ -535,7 +535,7 @@ def ancestor_properties_worker(args): } if compute_exact: - exact_anc = tsinfer.AncestorData(sample_data) + exact_anc = tsinfer.AncestorData(sample_data.sites_position, sample_data.sequence_length) tsinfer.build_simulated_ancestors(sample_data, exact_anc, ts) exact_anc.finalise() # Show lengths as a fraction of the total. @@ -819,7 +819,7 @@ def sim_true_and_inferred_ancestors(args): sample_data = generate_samples(ts, args.error) inferred_anc = tsinfer.generate_ancestors(sample_data, engine=args.engine) - true_anc = tsinfer.AncestorData(sample_data) + true_anc = tsinfer.AncestorData(sample_data.sites_position, sample_data.sequence_length) tsinfer.build_simulated_ancestors(sample_data, true_anc, ts) true_anc.finalise() return sample_data, true_anc, inferred_anc diff --git a/tests/test_formats.py b/tests/test_formats.py index e9b21220..45baaed7 100644 --- a/tests/test_formats.py +++ b/tests/test_formats.py @@ -1951,7 +1951,9 @@ def verify_data_round_trip(self, sample_data, ancestor_data, ancestors): def test_defaults_no_path(self): sample_data, ancestors = self.get_example_data(10, 10, 40) - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) self.verify_data_round_trip(sample_data, ancestor_data, ancestors) for _, array in ancestor_data.arrays(): assert array.compressor == formats.DEFAULT_COMPRESSOR @@ -1963,7 +1965,9 @@ def test_defaults_with_path(self): sample_data, ancestors = self.get_example_data(10, 10, 40) with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir: filename = os.path.join(tempdir, "ancestors.tmp") - ancestor_data = tsinfer.AncestorData(sample_data, path=filename) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length, path=filename + ) self.verify_data_round_trip(sample_data, ancestor_data, ancestors) compressor = formats.DEFAULT_COMPRESSOR for _, array in ancestor_data.arrays(): @@ -1978,17 +1982,27 @@ def test_bad_max_file_size(self): for bad_size in ["a", "", -1]: with pytest.raises(ValueError): formats.AncestorData( - sample_data, path=filename, max_file_size=bad_size + sample_data.sites_position, + sample_data.sequence_length, + path=filename, + max_file_size=bad_size, ) for bad_size in [[1, 3], np.array([1, 2])]: with pytest.raises(TypeError): formats.AncestorData( - sample_data, path=filename, max_file_size=bad_size + sample_data.sites_position, + sample_data.sequence_length, + path=filename, + max_file_size=bad_size, ) def test_provenance(self): sample_data, ancestors = self.get_example_data(10, 10, 40) - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) + for timestamp, record in sample_data.provenances(): + ancestor_data.add_provenance(timestamp, record) self.verify_data_round_trip(sample_data, ancestor_data, ancestors) assert ancestor_data.num_provenances == sample_data.num_provenances + 1 @@ -2008,7 +2022,11 @@ def test_chunk_size(self): N = 20 for chunk_size in [1, 2, 3, N - 1, N, N + 1]: sample_data, ancestors = self.get_example_data(6, 1, num_ancestors=N) - ancestor_data = tsinfer.AncestorData(sample_data, chunk_size=chunk_size) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, + sample_data.sequence_length, + chunk_size=chunk_size, + ) self.verify_data_round_trip(sample_data, ancestor_data, ancestors) assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,) assert ancestor_data.ancestors_focal_sites.chunks == (chunk_size,) @@ -2020,7 +2038,9 @@ def test_filename(self): sample_data, ancestors = self.get_example_data(10, 2, 40) with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir: filename = os.path.join(tempdir, "ancestors.tmp") - ancestor_data = tsinfer.AncestorData(sample_data, path=filename) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length, path=filename + ) assert os.path.exists(filename) assert not os.path.isdir(filename) self.verify_data_round_trip(sample_data, ancestor_data, ancestors) @@ -2043,7 +2063,10 @@ def test_chunk_size_file_equal(self): filename = os.path.join(tempdir, f"samples_{chunk_size}.tmp") files.append(filename) with tsinfer.AncestorData( - sample_data, path=filename, chunk_size=chunk_size + sample_data.sites_position, + sample_data.sequence_length, + path=filename, + chunk_size=chunk_size, ) as ancestor_data: self.verify_data_round_trip(sample_data, ancestor_data, ancestors) assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,) @@ -2054,7 +2077,9 @@ def test_chunk_size_file_equal(self): def test_add_ancestor_errors(self): sample_data, ancestors = self.get_example_data(22, 16, 30) - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) num_sites = ancestor_data.num_sites haplotype = np.zeros(num_sites, dtype=np.int8) ancestor_data.add_ancestor( @@ -2091,15 +2116,6 @@ def test_add_ancestor_errors(self): focal_sites=[], haplotype=np.zeros(num_sites + 1, dtype=np.int8), ) - # Haplotypes must be < num_alleles - with pytest.raises(ValueError): - ancestor_data.add_ancestor( - start=0, - end=num_sites, - time=1, - focal_sites=[], - haplotype=np.zeros(num_sites, dtype=np.int8) + 2, - ) # focal sites must be within start:end with pytest.raises(ValueError): ancestor_data.add_ancestor( @@ -2125,7 +2141,9 @@ def test_add_ancestor_errors(self): def test_iterator(self): sample_data, ancestors = self.get_example_data(6, 10, 10) - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) self.verify_data_round_trip(sample_data, ancestor_data, ancestors) assert ancestor_data.num_ancestors > 1 assert ancestor_data.num_ancestors == len(ancestors) @@ -2137,7 +2155,9 @@ def test_iterator(self): def test_equals(self): sample_data, ancestors = self.get_example_data(6, 1, 2) - with tsinfer.AncestorData(sample_data) as ancestor_data: + with tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) as ancestor_data: num_sites = ancestor_data.num_sites haplotype = np.ones(num_sites, dtype=np.int8) ancestor_data.add_ancestor( @@ -2151,7 +2171,9 @@ def test_equals(self): def test_accessor(self): sample_data, ancestors = self.get_example_data(6, 10, 10) - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) self.verify_data_round_trip(sample_data, ancestor_data, ancestors) for i, new_ancestor in enumerate(ancestor_data.ancestors()): assert new_ancestor == ancestor_data.ancestor(i) @@ -2178,7 +2200,9 @@ def test_zero_sequence_length(self): def test_bad_insert_proxy_samples(self): sample_data, ancestor_haps = self.get_example_data(10, 10, 40) - ancestors = tsinfer.AncestorData(sample_data) + ancestors = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) with pytest.raises(ValueError, match="not finalised"): ancestors.insert_proxy_samples(sample_data) self.verify_data_round_trip(sample_data, ancestors, ancestor_haps) @@ -2458,7 +2482,9 @@ def test_multiple_focal_sites(self): sample_data.add_site(5, [1, 0, 0, 1]) sample_data.add_site(10, [1, 1, 1, 0]) sample_data.finalise() - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) ancestor_data.add_ancestor(0, 3, 0.6666, [0, 2], [1, 1, 0]) ancestor_data.add_ancestor(1, 2, 0.333, [1], [1]) ancestor_data.finalise() diff --git a/tests/test_inference.py b/tests/test_inference.py index feb437f5..e2b869db 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -50,18 +50,12 @@ def get_random_data_example(num_samples, num_sites, seed=42, num_states=2): class TestUnfinalisedErrors: - def make_ancestor_data_unfinalised(self, path=None): - with tsinfer.SampleData(path=path, sequence_length=2) as sample_data: - sample_data.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"]) - with pytest.raises(ValueError): - tsinfer.AncestorData(sample_data) - if path is not None: - sample_data.close() - def match_ancestors_ancestors_unfinalised(self, path=None): with tsinfer.SampleData(sequence_length=2) as sample_data: sample_data.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"]) - with tsinfer.AncestorData(sample_data, path=path) as ancestor_data: + with tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length, path=path + ) as ancestor_data: ancestor_data.add_ancestor( start=0, end=1, @@ -75,14 +69,6 @@ def match_ancestors_ancestors_unfinalised(self, path=None): if path is not None: ancestor_data.close() - def test_make_ancestor_data(self): - self.make_ancestor_data_unfinalised() - - def test_make_ancestor_data_file(self): - with tempfile.TemporaryDirectory(prefix="tsinf_inference_test") as tempdir: - filename = os.path.join(tempdir, "samples.tmp") - self.make_ancestor_data_unfinalised(filename) - def test_match_ancestors_ancestors(self): self.match_ancestors_ancestors_unfinalised() @@ -615,7 +601,9 @@ def verify_data_round_trip( ) num_alleles = sample_data.num_alleles() - with tsinfer.AncestorData(sample_data) as ancestor_data: + with tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) as ancestor_data: t = np.sum(num_alleles) + 1 for j in range(sample_data.num_sites): for allele in range(num_alleles[j] - 1): @@ -1451,7 +1439,9 @@ def verify_inserted_ancestors(self, ts): with tsinfer.SampleData(sequence_length=ts.sequence_length) as sample_data: for v in ts.variants(): sample_data.add_site(v.position, v.genotypes, v.alleles) - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts) ancestor_data.finalise() @@ -1539,12 +1529,9 @@ def test_bad_focal_sites(self): (tsinfer.C_ENGINE, _tsinfer.LibraryError), (tsinfer.PY_ENGINE, ValueError), ]: - with tsinfer.formats.AncestorData(sample_data) as ancestor_data: - g = np.zeros(2, dtype=np.int8) - h = np.zeros(1, dtype=np.int8) - generator = tsinfer.AncestorsGenerator( - sample_data, ancestor_data, engine=engine - ) + g = np.zeros(2, dtype=np.int8) + h = np.zeros(1, dtype=np.int8) + generator = tsinfer.AncestorsGenerator(sample_data, None, {}, engine=engine) generator.ancestor_builder.add_site(1, g) with pytest.raises(error): generator.ancestor_builder.make_ancestor([0], h) @@ -2118,7 +2105,9 @@ def infer(self, ts, engine, path_compression=False, precision=None): sample_data.add_site(v.site.position, v.genotypes, v.alleles) sample_data.finalise() - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) tsinfer.build_simulated_ancestors(sample_data, ancestor_data, ts) ancestor_data.finalise() ancestors_ts = tsinfer.match_ancestors( @@ -2326,7 +2315,9 @@ def test_easy_case(self): for j in range(num_sites): sample_data.add_site(j, [0, 1, 1]) sample_data.finalise() - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) ancestor_data.add_ancestor( # ID 0 start=0, end=6, focal_sites=[], time=5, haplotype=[0, 0, 0, 0, 0, 0] @@ -2372,7 +2363,9 @@ def test_partial_overlap(self): for j in range(num_sites): sample_data.add_site(j, [0, 1, 1]) sample_data.finalise() - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) ancestor_data.add_ancestor( # ID 0 start=0, end=7, focal_sites=[], time=5, haplotype=[0, 0, 0, 0, 0, 0, 0] @@ -2417,7 +2410,9 @@ def test_edge_overlap_bug(self): with tsinfer.SampleData() as sample_data: for j in range(num_sites): sample_data.add_site(j, [0, 1, 1]) - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData( + sample_data.sites_position, sample_data.sequence_length + ) ancestor_data.add_ancestor( # ID 0 start=0, diff --git a/tsinfer/eval_util.py b/tsinfer/eval_util.py index 85f84d62..ac2af1f7 100644 --- a/tsinfer/eval_util.py +++ b/tsinfer/eval_util.py @@ -695,7 +695,7 @@ def run_perfect_inference( logger.info("Using provided tree sequence") ancestors_ts = make_ancestors_ts(ts, remove_leaves=True) else: - ancestor_data = formats.AncestorData(sample_data) + ancestor_data = formats.AncestorData(sample_data.sites_position, sample_data.sequence_length) build_simulated_ancestors( sample_data, ancestor_data, ts, time_chunking=time_chunking ) diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 39ff201a..a64e6e04 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -2242,7 +2242,7 @@ def __eq__(self, other): class AncestorData(DataContainer): """ - AncestorData(sample_data, *, path=None, num_flush_threads=0, compressor=None, \ + AncestorData(position, *, path=None, num_flush_threads=0, compressor=None, \ chunk_size=1024, max_file_size=None) Class representing the stored ancestor data produced by @@ -2253,8 +2253,8 @@ class AncestorData(DataContainer): See the documentation for :class:`SampleData` for a discussion of the ``max_file_size`` parameter. - :param SampleData sample_data: The :class:`.SampleData` instance - that this ancestor data file was generated from. + :param array-like position: Integer array of the site positions of the ancestors. + All values should be >0 and the array should be monotonically increasing. :param str path: The path of the file to store the ancestor data. If None, the information is stored in memory and not persistent. :param int num_flush_threads: The number of background threads to use @@ -2277,20 +2277,15 @@ class AncestorData(DataContainer): FORMAT_NAME = "tsinfer-ancestor-data" FORMAT_VERSION = (3, 0) - def __init__(self, sample_data, **kwargs): + def __init__(self, position, sequence_length, **kwargs): super().__init__(**kwargs) - sample_data._check_finalised() - self.sample_data = sample_data - if self.sample_data.sequence_length == 0: - raise ValueError("Bad samples file: sequence_length cannot be zero") - self.data.attrs["sequence_length"] = self.sample_data.sequence_length self._last_time = 0 - chunks = self._chunk_size - # By default all sites in the sample data file are used. - self._num_alleles = self.sample_data.num_alleles() - position = self.sample_data.sites_position[:] + self.data.attrs["sequence_length"] = sequence_length + if self.sequence_length == 0: + raise ValueError("Bad samples file: sequence_length cannot be zero") + self.data.create_dataset( "sites/position", data=position, @@ -2336,10 +2331,6 @@ def __init__(self, sample_data, **kwargs): self._alloc_ancestor_writer() - # Add in the provenance trail from the sample_data file. - for timestamp, record in sample_data.provenances(): - self.add_provenance(timestamp, record) - def _alloc_ancestor_writer(self): self.ancestor_writer = BufferedItemWriter( { @@ -2577,8 +2568,11 @@ def insert_proxy_samples( h[1] for h in sample_data.haplotypes(samples=sample_ids, sites=used_sites) ] - with AncestorData(sample_data, **kwargs) as other: - other.set_inference_sites(used_sites) + with AncestorData( + sample_data.sites_position[:][used_sites], + sample_data.sequence_length, + **kwargs, + ) as other: mutated_sites = set() # To check if mutations have ocurred yet ancestors_iter = self.ancestors() ancestor = next(ancestors_iter, None) @@ -2747,21 +2741,6 @@ def truncate_ancestors( # Write mode (building and editing) #################################### - def set_inference_sites(self, site_ids): - """ - Sets the sites used for inference (i.e., the sites at which ancestor haplotypes - are defined) to the specified list of site IDs. This must be a subset of the - sites in the sample data file, and the IDs must be in increasing order. - - This must be called before the first call to :meth:`.add_ancestor`. - """ - self._check_build_mode() - position = self.sample_data.sites_position[:][site_ids] - array = self.data["sites/position"] - array.resize(position.shape) - array[:] = position - self._num_alleles = self.sample_data.num_alleles(site_ids) - def add_ancestor(self, start, end, time, focal_sites, haplotype): """ Adds an ancestor with the specified haplotype, with ancestral material over the @@ -2782,8 +2761,6 @@ def add_ancestor(self, start, end, time, focal_sites, haplotype): raise ValueError("start must be < end") if haplotype.shape != (end - start,): raise ValueError("haplotypes incorrect shape.") - if np.any(haplotype >= self._num_alleles[start:end]): - raise ValueError("haplotype values must be < num_alleles.") if np.any(focal_sites < start) or np.any(focal_sites >= end): raise ValueError("focal sites must be between start and end") if time <= 0: diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 26301c02..8f589c45 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -393,18 +393,21 @@ def generate_ancestors( "specified times with times-as-frequencies. To explicitly set an undefined" "time for a site, permanently excluding it from inference, set it to np.nan." ) - with formats.AncestorData(sample_data, path=path, **kwargs) as ancestor_data: - generator = AncestorsGenerator( - sample_data, - ancestor_data, - num_threads=num_threads, - engine=engine, - progress_monitor=progress_monitor, - ) - generator.add_sites(exclude_positions) - generator.run() - if record_provenance: - ancestor_data.record_provenance("generate_ancestors") + generator = AncestorsGenerator( + sample_data, + ancestor_data_path=path, + ancestor_data_kwargs=kwargs, + num_threads=num_threads, + engine=engine, + progress_monitor=progress_monitor, + ) + generator.add_sites(exclude_positions) + ancestor_data = generator.run() + for timestamp, record in sample_data.provenances(): + ancestor_data.add_provenance(timestamp, record) + if record_provenance: + ancestor_data.record_provenance("generate_ancestors") + ancestor_data.finalise() return ancestor_data @@ -840,18 +843,21 @@ class AncestorsGenerator: def __init__( self, sample_data, - ancestor_data, + ancestor_data_path, + ancestor_data_kwargs, num_threads=0, engine=constants.C_ENGINE, progress_monitor=None, ): self.sample_data = sample_data - self.ancestor_data = ancestor_data + self.ancestor_data_path = ancestor_data_path + self.ancestor_data_kwargs = ancestor_data_kwargs self.progress_monitor = _get_progress_monitor( progress_monitor, generate_ancestors=True ) self.max_sites = sample_data.num_sites self.num_sites = 0 + self.inference_site_ids = [] self.num_samples = sample_data.num_samples self.num_threads = num_threads if engine == constants.C_ENGINE: @@ -922,7 +928,7 @@ def add_sites(self, exclude_positions=None): self.num_sites += 1 progress.update() progress.close() - self.ancestor_data.set_inference_sites(inference_site_id) + self.inference_site_ids = inference_site_id logger.info("Finished adding sites") def _run_synchronous(self, progress): @@ -1022,6 +1028,12 @@ def run(self): for t, _ in reversed(self.descriptors): if t not in self.timepoint_to_epoch: self.timepoint_to_epoch[t] = len(self.timepoint_to_epoch) + 1 + self.ancestor_data = formats.AncestorData( + self.sample_data.sites_position[:][self.inference_site_ids], + self.sample_data.sequence_length, + path=self.ancestor_data_path, + **self.ancestor_data_kwargs, + ) if self.num_ancestors > 0: logger.info(f"Starting build for {self.num_ancestors} ancestors") progress = self.progress_monitor.get("ga_generate", self.num_ancestors) @@ -1053,6 +1065,7 @@ def run(self): self._run_threaded(progress) progress.close() logger.info("Finished building ancestors") + return self.ancestor_data class Matcher: diff --git a/visualisation.py b/visualisation.py index 7209074d..46044e02 100644 --- a/visualisation.py +++ b/visualisation.py @@ -431,7 +431,7 @@ def visualise( sample_data = tsinfer.SampleData.from_tree_sequence(ts) if perfect_ancestors: - ancestor_data = tsinfer.AncestorData(sample_data) + ancestor_data = tsinfer.AncestorData(sample_data.sites_position, sample_data.sequence_length) tsinfer.build_simulated_ancestors( sample_data, ancestor_data, ts, time_chunking=time_chunking )