Skip to content

Commit

Permalink
Factor out common create dataset args
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery authored and Ben Jeffery committed Jan 17, 2023
1 parent 38c6442 commit b1f980f
Showing 1 changed file with 52 additions and 64 deletions.
116 changes: 52 additions & 64 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2589,86 +2589,79 @@ def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs):
else:
self._chunk_size_sites = chunk_size_sites

chunks = self._chunk_size

self.data.attrs["sequence_length"] = sequence_length
if self.sequence_length <= 0:
raise ValueError("Bad samples file: sequence_length cannot be zero or less")

# We specify fill_value here due to https://github.com/pydata/xarray/issues/7292
a = self.data.create_dataset(
self.create_dataset("sample_start", dtype=np.int32)
self.create_dataset("sample_end", dtype=np.int32)
self.create_dataset("sample_time", dtype=np.float64)
self.create_dataset("sample_focal_sites", dtype="array:i4")

self.create_dataset(
"variant_position",
data=position,
shape=position.shape,
chunks=self._chunk_size_sites,
compressor=self._compressor,
dtype=np.float64,
fill_value=None,
)
a.attrs["_ARRAY_DIMENSIONS"] = ["variants"]
a = self.data.create_dataset(
"sample_start",
shape=(0,),
chunks=chunks,
compressor=self._compressor,
dtype=np.int32,
fill_value=None,
)
a.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
a = self.data.create_dataset(
"sample_end",
shape=(0,),
chunks=chunks,
compressor=self._compressor,
dtype=np.int32,
fill_value=None,
)
a.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
a = self.data.create_dataset(
"sample_time",
shape=(0,),
chunks=chunks,
compressor=self._compressor,
dtype=np.float64,
fill_value=None,
)
a.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
a = self.data.create_dataset(
"sample_focal_sites",
shape=(0,),
chunks=chunks,
dtype="array:i4",
compressor=self._compressor,
fill_value=None,
dimensions=["variants"],
)
a.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
a = self.data.create_dataset(

# We have to include a ploidy dimension for sgkit compatibility
a = self.create_dataset(
"call_genotype",
# We have to include a ploidy dimension sgkit compatibility
dtype="i1",
shape=(self.num_sites, 0, 1),
chunks=(self._chunk_size_sites, self._chunk_size, 1),
dtype="i1",
compressor=self._compressor,
fill_value=None,
dimensions=["variants", "samples", "ploidy"],
)
a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples", "ploidy"]
a.attrs["mixed_ploidy"] = False

# Sgkit requires this array to be present
a = self.data.create_dataset(
a = self.create_dataset(
"call_genotype_mask",
dtype="i1",
shape=(self.num_sites, 0, 1),
chunks=(self._chunk_size_sites, self._chunk_size, 1),
dtype="i1",
compressor=self._compressor,
fill_value=None,
dimensions=["variants", "samples", "ploidy"],
)
a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples", "ploidy"]
# We add this to be identical to sgkit generated arrays
a.attrs["dtype"] = "bool"

self._alloc_ancestor_writer()

def create_dataset(
    self,
    name,
    *,
    data=None,
    shape=None,
    chunks=None,
    dtype=None,
    compressor=None,
    dimensions=None,
):
    """Create a zarr dataset on ``self.data`` with the common defaults used
    by this class, and tag it with xarray dimension names.

    Unspecified arguments fall back to the per-sample buffer conventions:
    ``shape=(0,)``, ``chunks=self._chunk_size``,
    ``compressor=self._compressor`` and ``dimensions=["samples"]``.
    Returns the created dataset.
    """
    # fill_value is passed explicitly due to
    # https://github.com/pydata/xarray/issues/7292
    dataset = self.data.create_dataset(
        name,
        data=data,
        shape=(0,) if shape is None else shape,
        chunks=self._chunk_size if chunks is None else chunks,
        dtype=dtype,
        compressor=self._compressor if compressor is None else compressor,
        fill_value=None,
    )
    dataset.attrs["_ARRAY_DIMENSIONS"] = (
        ["samples"] if dimensions is None else dimensions
    )
    return dataset

def _alloc_ancestor_writer(self):
self.ancestor_writer = BufferedItemWriter(
{
Expand Down Expand Up @@ -3147,32 +3140,27 @@ def finalise(self):
del self.data["variant_allele"]
except KeyError:
pass
a = self.data.create_dataset(
self.create_dataset(
"variant_allele",
data=np.tile(["0", "1"], (self.num_sites, 1)),
shape=(self.num_sites, 2),
chunks=(self.sites_position.chunks[0], 2),
compressor=self.sites_position.compressor,
dtype="U1",
fill_value=None,
compressor=self.sites_position.compressor,
dimensions=["variants", "alleles"],
)
a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "alleles"]

try:
del self.data["sample_id"]
except KeyError:
pass
a = self.data.create_dataset(
self.create_dataset(
"sample_id",
data=[f"tsinf_anc_{i}" for i in range(len(self.ancestors_start))],
shape=(self.num_sites, 2),
shape=(len(self.ancestors_start),),
chunks=self.ancestors_start.chunks,
compressor=self.ancestors_start.compressor,
# dtype="U",
fill_value=None,
)
a.attrs["_ARRAY_DIMENSIONS"] = ["samples"]

super().finalise()

####################################
Expand Down

0 comments on commit b1f980f

Please sign in to comment.