Skip to content

Commit

Permalink
Add an axis attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
gtca committed Apr 15, 2022
1 parent 29366a4 commit 422d5c9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
12 changes: 12 additions & 0 deletions mudata/_core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def _write_h5mu(file: h5py.File, mdata: MuData, write_data=True, **kwargs):
write_attribute(file, "obsmap", mdata.obsmap, dataset_kwargs=kwargs)
write_attribute(file, "varmap", mdata.varmap, dataset_kwargs=kwargs)

attrs = file.attrs
attrs["axis"] = mdata.axis

mod = file.require_group("mod")
for k, v in mdata.mod.items():
group = mod.require_group(k)
Expand Down Expand Up @@ -142,6 +145,9 @@ def write_zarr(
write_attribute(file, "obsmap", mdata.obsmap, dataset_kwargs=kwargs)
write_attribute(file, "varmap", mdata.varmap, dataset_kwargs=kwargs)

attrs = file.attrs
attrs["axis"] = mdata.axis

mod = file.require_group("mod")
for k, v in mdata.mod.items():
group = mod.require_group(k)
Expand Down Expand Up @@ -177,6 +183,9 @@ def write_zarr(
attrs["encoder"] = "mudata"
attrs["encoder-version"] = __version__

mod_attrs = mod.attrs
mod_attrs["mod-order"] = list(mdata.mod.keys())

attrs = file.attrs
attrs["encoding-type"] = "MuData"
attrs["encoding-version"] = __mudataversion__
Expand Down Expand Up @@ -418,6 +427,9 @@ def read_h5mu(filename: PathLike, backed: Union[str, bool, None] = None):
else:
d[k] = read_attribute(f[k])

if "axis" in f.attrs:
d["axis"] = f.attrs["axis"]

mu = MuData._init_from_dict_(**d)
mu.file = manager
return mu
Expand Down
21 changes: 19 additions & 2 deletions mudata/_core/mudata.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def __init__(
self._obsmap = MuAxisArrays(self, 0, kwargs.get("obsmap", {}))
self._varmap = MuAxisArrays(self, 1, kwargs.get("varmap", {}))

self._axis = kwargs.get("axis") or 0

# Restore proper .obs and .var
self.update()

Expand All @@ -151,6 +153,8 @@ def __init__(
self._varp = PairwiseArrays(self, 1, dict())
self._varmap = MuAxisArrays(self, 1, dict())

self._axis = 0

self.update()

def _init_common(self):
Expand Down Expand Up @@ -210,6 +214,7 @@ def _init_as_view(self, mudata_ref: "MuData", index):
self.is_view = True
self.file = mudata_ref.file
self._mudata_ref = mudata_ref
self._axis = mudata_ref._axis

def _init_as_actual(self, data: "MuData"):
self._init_common()
Expand All @@ -223,6 +228,7 @@ def _init_as_actual(self, data: "MuData"):
self._varp = PairwiseArrays(self, 1, convert_to_dict(data.varp))
self._varmap = MuAxisArrays(self, 1, convert_to_dict(data.varmap))
self.uns = data.uns
self._axis = data._axis

@classmethod
def _init_from_dict_(
Expand All @@ -237,6 +243,7 @@ def _init_from_dict_(
varp: Optional[Union[np.ndarray, Mapping[str, Sequence[Any]]]] = None,
obsmap: Optional[Mapping[str, Sequence[int]]] = None,
varmap: Optional[Mapping[str, Sequence[int]]] = None,
axis: anndata.compat.Literal[0, 1] = 0,
):

return cls(
Expand All @@ -250,6 +257,7 @@ def _init_from_dict_(
varp=varp,
obsmap=obsmap,
varmap=varmap,
axis=axis,
)

def _check_duplicated_attr_names(self, attr: str):
Expand Down Expand Up @@ -308,6 +316,7 @@ def copy(self, filename: Optional[PathLike] = None) -> "MuData":
self.varp.copy(),
self.obsmap.copy(),
self.varmap.copy(),
self.axis,
)
else:
if filename is None:
Expand Down Expand Up @@ -787,7 +796,7 @@ def update_obs(self):
"""
Update .obs slot of MuData with the newest .obs data from all the modalities
"""
self._update_attr("obs", axis=1)
self._update_attr("obs", axis=1, join_common=bool(True * self.axis == 1))

def obs_names_make_unique(self):
"""
Expand Down Expand Up @@ -879,7 +888,7 @@ def update_var(self):
"""
Update .var slot of MuData with the newest .var data from all the modalities
"""
self._update_attr("var", axis=0, join_common=True)
self._update_attr("var", axis=0, join_common=bool(True * self.axis == 0))

def var_names_make_unique(self):
"""
Expand Down Expand Up @@ -1049,6 +1058,14 @@ def update(self):
self.update_var()
self.update_obs()

@property
def axis(self) -> int:
"""
MuData axis
"""
return self._axis


def write_h5mu(self, filename: Optional[str] = None, **kwargs):
"""
Write MuData object to an HDF5 file
Expand Down

0 comments on commit 422d5c9

Please sign in to comment.