diff --git a/tsinfer/formats.py b/tsinfer/formats.py index e768f6f3..53f6784a 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -2589,86 +2589,79 @@ def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs): else: self._chunk_size_sites = chunk_size_sites - chunks = self._chunk_size - self.data.attrs["sequence_length"] = sequence_length if self.sequence_length <= 0: raise ValueError("Bad samples file: sequence_length cannot be zero or less") # We specify fill_value here due to https://github.com/pydata/xarray/issues/7292 - a = self.data.create_dataset( + self.create_dataset("sample_start", dtype=np.int32) + self.create_dataset("sample_end", dtype=np.int32) + self.create_dataset("sample_time", dtype=np.float64) + self.create_dataset("sample_focal_sites", dtype="array:i4") + + self.create_dataset( "variant_position", data=position, shape=position.shape, chunks=self._chunk_size_sites, - compressor=self._compressor, dtype=np.float64, - fill_value=None, - ) - a.attrs["_ARRAY_DIMENSIONS"] = ["variants"] - a = self.data.create_dataset( - "sample_start", - shape=(0,), - chunks=chunks, - compressor=self._compressor, - dtype=np.int32, - fill_value=None, - ) - a.attrs["_ARRAY_DIMENSIONS"] = ["samples"] - a = self.data.create_dataset( - "sample_end", - shape=(0,), - chunks=chunks, - compressor=self._compressor, - dtype=np.int32, - fill_value=None, - ) - a.attrs["_ARRAY_DIMENSIONS"] = ["samples"] - a = self.data.create_dataset( - "sample_time", - shape=(0,), - chunks=chunks, - compressor=self._compressor, - dtype=np.float64, - fill_value=None, - ) - a.attrs["_ARRAY_DIMENSIONS"] = ["samples"] - a = self.data.create_dataset( - "sample_focal_sites", - shape=(0,), - chunks=chunks, - dtype="array:i4", - compressor=self._compressor, - fill_value=None, + dimensions=["variants"], ) - a.attrs["_ARRAY_DIMENSIONS"] = ["samples"] - a = self.data.create_dataset( + + # We have to include a ploidy dimension sgkit compatibility + a = self.create_dataset( "call_genotype", - # We have to include a ploidy dimension sgkit compatibility + dtype="i1", shape=(self.num_sites, 0, 1), chunks=(self._chunk_size_sites, self._chunk_size, 1), - dtype="i1", - compressor=self._compressor, - fill_value=None, + dimensions=["variants", "samples", "ploidy"], ) - a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples", "ploidy"] a.attrs["mixed_ploidy"] = False - # Sgkit requires this array to be present - a = self.data.create_dataset( + a = self.create_dataset( "call_genotype_mask", + dtype="i1", shape=(self.num_sites, 0, 1), chunks=(self._chunk_size_sites, self._chunk_size, 1), - dtype="i1", - compressor=self._compressor, - fill_value=None, + dimensions=["variants", "samples", "ploidy"], ) - a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples", "ploidy"] # We add this to be identical to sgkit generated arrays a.attrs["dtype"] = "bool" self._alloc_ancestor_writer() + def create_dataset( + self, + name, + *, + data=None, + shape=None, + chunks=None, + dtype=None, + compressor=None, + dimensions=None, + ): + if shape is None: + shape = (0,) + if chunks is None: + chunks = self._chunk_size + if compressor is None: + compressor = self._compressor + if dimensions is None: + dimensions = ["samples"] + + ds = self.data.create_dataset( + name, + data=data, + shape=shape, + chunks=chunks, + dtype=dtype, + compressor=compressor, + fill_value=None, + ) + ds.attrs["_ARRAY_DIMENSIONS"] = dimensions + return ds + def _alloc_ancestor_writer(self): self.ancestor_writer = BufferedItemWriter( { @@ -3147,32 +3140,27 @@ def finalise(self): del self.data["variant_allele"] except KeyError: pass - a = self.data.create_dataset( + self.create_dataset( "variant_allele", data=np.tile(["0", "1"], (self.num_sites, 1)), shape=(self.num_sites, 2), chunks=(self.sites_position.chunks[0], 2), - compressor=self.sites_position.compressor, dtype="U1", - fill_value=None, + compressor=self.sites_position.compressor, + dimensions=["variants", "alleles"], ) - a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "alleles"] try: del self.data["sample_id"] except KeyError: pass - a = self.data.create_dataset( + self.create_dataset( "sample_id", data=[f"tsinf_anc_{i}" for i in range(len(self.ancestors_start))], - shape=(self.num_sites, 2), + shape=(len(self.ancestors_start),), chunks=self.ancestors_start.chunks, compressor=self.ancestors_start.compressor, - # dtype="U", - fill_value=None, ) - a.attrs["_ARRAY_DIMENSIONS"] = ["samples"] - super().finalise() ####################################