Improve error messages #173

Merged · 2 commits · Apr 24, 2024
Changes from 1 commit
14 changes: 8 additions & 6 deletions cellarium/ml/cli.py
@@ -568,19 +568,21 @@ def main(args: ArgsType = None) -> None:
             or the ``model_name`` key if ``args`` is a dictionary or ``Namespace``.
     """
     if isinstance(args, (dict, Namespace)):
-        assert "model_name" in args, "'model_name' key must be specified in args"
+        if "model_name" not in args:
+            raise ValueError("'model_name' key must be specified in args")
         model_name = args.pop("model_name")
     elif isinstance(args, list):
-        assert len(args) > 0, "'model_name' must be specified as the first argument in args"
+        if len(args) == 0:
+            raise ValueError("'model_name' must be specified as the first argument in args")
         model_name = args.pop(0)
     elif args is None:
         args = sys.argv[1:].copy()
-        assert len(args) > 0, "'model_name' must be specified after cellarium-ml"
+        if len(args) == 0:
+            raise ValueError("'model_name' must be specified after cellarium-ml")
         model_name = args.pop(0)
 
-    assert (
-        model_name in REGISTERED_MODELS
-    ), f"'model_name' must be one of {list(REGISTERED_MODELS.keys())}. Got '{model_name}'"
+    if model_name not in REGISTERED_MODELS:
+        raise ValueError(f"'model_name' must be one of {list(REGISTERED_MODELS.keys())}. Got '{model_name}'")
     model_cli = REGISTERED_MODELS[model_name]
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", message="Transforming to str index.")
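For reviewers, a minimal usage sketch (hypothetical call, not part of this diff; the trailing `"fit"` argument is only a placeholder) of how the new check surfaces: an unregistered model name passed to `main` now fails with a `ValueError` listing the registered models instead of a bare `AssertionError`.

```python
# Hypothetical usage sketch, assuming the cellarium-ml package is installed.
from cellarium.ml.cli import main

try:
    main(["not_a_registered_model", "fit"])  # unknown model name
except ValueError as err:
    # e.g. "'model_name' must be one of [...]. Got 'not_a_registered_model'"
    print(err)
```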
6 changes: 5 additions & 1 deletion cellarium/ml/data/dadc_dataset.py
@@ -349,7 +349,11 @@ def __iter__(self):
         # remove tail of data to make it evenly divisible.
         indices = indices[:total_size]
         indices = indices[rank * per_replica : (rank + 1) * per_replica]
-        assert len(indices) == per_replica
+        if len(indices) != per_replica:
+            raise ValueError(
+                f"The number of indices must be equal to the per_replica size. "
+                f"Got {len(indices)} != {per_replica} at rank {rank}."
+            )
 
         yield from (self[indices[i : i + self.batch_size]] for i in range(iter_start, iter_end, self.batch_size))
         # Sets epoch for persistent workers
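A worked sketch of the per-replica split that this error message describes (the numbers are assumed, not taken from the code): the tail of the index array is dropped so every rank receives exactly `per_replica` indices, and the new `ValueError` reports the rank where that invariant breaks.

```python
# Illustrative arithmetic only; names mirror the check above, values are assumed.
import numpy as np

num_replicas, rank, n_obs = 4, 1, 103
per_replica = n_obs // num_replicas              # 25 indices per rank
total_size = per_replica * num_replicas          # 100; tail of 3 indices dropped
indices = np.arange(n_obs)[:total_size]
indices = indices[rank * per_replica : (rank + 1) * per_replica]
print(len(indices) == per_replica)               # True for every rank
```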
38 changes: 26 additions & 12 deletions cellarium/ml/data/distributed_anndata.py
@@ -175,20 +175,29 @@ def __init__(
         if (shard_size is None) and (last_shard_size is not None):
             raise ValueError("If `last_shard_size` is specified then `shard_size` must also be specified.")
         if limits is None:
-            assert shard_size is not None, "If `limits` is `None` then `shard_size` must be specified`"
+            if shard_size is None:
+                raise ValueError("If `limits` is `None` then `shard_size` must be specified.")
             limits = [shard_size * (i + 1) for i in range(len(self.filenames))]
             if last_shard_size is not None:
                 limits[-1] = limits[-1] - shard_size + last_shard_size
         else:
             limits = list(limits)
-        assert len(limits) == len(self.filenames)
+        if len(limits) != len(self.filenames):
+            raise ValueError(
+                f"The number of points in `limits` ({len(limits)}) must match "
+                f"the number of `filenames` ({len(self.filenames)})."
+            )
         # lru cache
         self.cache = LRU(max_cache_size)
         self.max_cache_size = max_cache_size
         self.cache_size_strictly_enforced = cache_size_strictly_enforced
         # schema
         adata0 = self.cache[self.filenames[0]] = read_h5ad_file(self.filenames[0])
-        assert len(adata0) == limits[0]
+        if len(adata0) != limits[0]:
+            raise ValueError(
+                f"The number of cells in the first anndata file ({len(adata0)}) "
+                f"does not match the first limit ({limits[0]})."
+            )
         self.obs_columns_to_validate = obs_columns_to_validate
         self.schema = AnnDataSchema(adata0, obs_columns_to_validate)
         # lazy anndatas
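A small worked example (assumed values, not from the PR) of how `limits` is built in the branch guarded by the first new `ValueError`, including the `last_shard_size` adjustment:

```python
# Assumed values for illustration; mirrors the list comprehension above.
shard_size, last_shard_size, n_files = 10_000, 2_500, 3
limits = [shard_size * (i + 1) for i in range(n_files)]  # [10000, 20000, 30000]
limits[-1] = limits[-1] - shard_size + last_shard_size   # [10000, 20000, 22500]
print(limits)
```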
@@ -199,7 +208,10 @@ def __init__(
         # use filenames as default keys
         if keys is None:
             keys = self.filenames
-        assert len(keys) == len(self.filenames)
+        if len(keys) != len(self.filenames):
+            raise ValueError(
+                f"The number of keys ({len(keys)}) must match the number of `filenames` ({len(self.filenames)})."
+            )
         with lazy_getattr():
             super().__init__(
                 adatas=lazy_adatas,
@@ -238,10 +250,11 @@ def materialize(self, adatas_oidx: list[np.ndarray | None], vidx: Index1D) -> li
         adata_idx_to_oidx = {i: oidx for i, oidx in enumerate(adatas_oidx) if oidx is not None}
         n_adatas = len(adata_idx_to_oidx)
         if self.cache_size_strictly_enforced:
-            assert n_adatas <= self.max_cache_size, (
-                f"Expected the number of anndata files ({n_adatas}) to be "
-                f"no more than the max cache size ({self.max_cache_size})."
-            )
+            if n_adatas > self.max_cache_size:
+                raise ValueError(
+                    f"Expected the number of anndata files ({n_adatas}) to be "
+                    f"no more than the max cache size ({self.max_cache_size})."
+                )
         adatas = [None] * n_adatas
         # first fetch cached anndata files
         # this ensures that they are not popped if they were lru
@@ -375,10 +388,11 @@ def adata(self) -> AnnData:
         # fetch anndata
         adata = read_h5ad_file(self.filename)
         # validate anndata
-        assert self.n_obs == adata.n_obs, (
-            "Expected n_obs for LazyAnnData object and backed anndata to match "
-            f"but found {self.n_obs} and {adata.n_obs}, respectively."
-        )
+        if self.n_obs != adata.n_obs:
+            raise ValueError(
+                "Expected `n_obs` for LazyAnnData object and backed anndata to match "
+                f"but found {self.n_obs} and {adata.n_obs}, respectively."
+            )
         self.schema.validate_anndata(adata)
         # cache anndata
         if len(self.cache) < self.cache.max_size:
9 changes: 6 additions & 3 deletions cellarium/ml/data/fileio.py
@@ -23,7 +23,8 @@ def read_h5ad_gcs(filename: str, storage_client: Client | None = None) -> AnnDat
     Args:
         filename: Path to the data file in Cloud Storage.
     """
-    assert filename.startswith("gs:")
+    if not filename.startswith("gs:"):
+        raise ValueError("The filename must start with 'gs://' protocol name.")
     # parse bucket and blob names from the filename
     filename = re.sub(r"^gs://?", "", filename)
     bucket_name, blob_name = filename.split("/", 1)
@@ -51,7 +52,8 @@ def read_h5ad_url(filename: str) -> AnnData:
     Args:
         filename: URL of the data file.
     """
-    assert any(filename.startswith(scheme) for scheme in url_schemes)
+    if not any(filename.startswith(scheme) for scheme in url_schemes):
+        raise ValueError("The filename must start with 'http:', 'https:', or 'ftp:' protocol name.")
     with urllib.request.urlopen(filename) as response:
         with tempfile.TemporaryFile() as tmp_file:
             shutil.copyfileobj(response, tmp_file)
@@ -65,7 +67,8 @@ def read_h5ad_local(filename: str) -> AnnData:
     Args:
         filename: Path to the local data file.
     """
-    assert filename.startswith("file:")
+    if not filename.startswith("file:"):
+        raise ValueError("The filename must start with 'file:' protocol name.")
     filename = re.sub(r"^file://?", "", filename)
     return read_h5ad(filename)
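A quick sketch of the behaviour change (hypothetical call and path, assuming the package is importable): a filename without the expected scheme now fails with a descriptive `ValueError` instead of an opaque `AssertionError`.

```python
# Hypothetical call illustrating the new validation in read_h5ad_local.
from cellarium.ml.data.fileio import read_h5ad_local

try:
    read_h5ad_local("/tmp/data.h5ad")  # missing the required 'file:' prefix
except ValueError as err:
    print(err)  # "The filename must start with 'file:' protocol name."
```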

8 changes: 4 additions & 4 deletions cellarium/ml/utilities/data.py
@@ -143,16 +143,16 @@ def collate_fn(batch: list[dict[str, np.ndarray]]) -> dict[str, np.ndarray | tor
     keys = batch[0].keys()
     collated_batch = {}
     if len(batch) > 1:
-        assert all(keys == data.keys() for data in batch[1:]), "All dictionaries in the batch must have the same keys."
+        if not all(keys == data.keys() for data in batch[1:]):
+            raise ValueError("All dictionaries in the batch must have the same keys.")
     for key in keys:
         if key == "obs_names":
             collated_batch[key] = np.concatenate([data[key] for data in batch], axis=0)
         elif key in ["var_names", "var_names_g"]:
             # Check that all var_names are the same
             if len(batch) > 1:
-                assert all(
-                    np.array_equal(batch[0][key], data[key]) for data in batch[1:]
-                ), "All dictionaries in the batch must have the same var_names."
+                if not all(np.array_equal(batch[0][key], data[key]) for data in batch[1:]):
+                    raise ValueError("All dictionaries in the batch must have the same var_names.")
             # If so, just take the first one
             collated_batch[key] = batch[0][key]
         else:
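A self-contained sketch (toy batch with made-up keys) of the stricter key check in `collate_fn`:

```python
# Toy batch with mismatched keys; collate_fn now raises ValueError up front.
import numpy as np
from cellarium.ml.utilities.data import collate_fn

batch = [{"x_ng": np.ones((2, 3))}, {"y_n": np.zeros(2)}]
try:
    collate_fn(batch)
except ValueError as err:
    print(err)  # "All dictionaries in the batch must have the same keys."
```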
4 changes: 2 additions & 2 deletions tests/test_distributed_anndata.py
@@ -146,7 +146,7 @@ def test_indexing(
     oidx, n_adatas = row_select
 
     if cache_size_strictly_enforced and (n_adatas > max_cache_size):
-        with pytest.raises(AssertionError, match="Expected the number of anndata files"):
+        with pytest.raises(ValueError, match="Expected the number of anndata files"):
             dat_view = dat[oidx, vidx]
     else:
         adt_view = adt[oidx, vidx]
@@ -201,7 +201,7 @@ def test_indexing_dataset(
     )
 
     if cache_size_strictly_enforced and (n_adatas > max_cache_size):
-        with pytest.raises(AssertionError, match="Expected the number of anndata files"):
+        with pytest.raises(ValueError, match="Expected the number of anndata files"):
             dataset_X = dataset[oidx]["x_ng"]
     else:
         adt_X = adt[oidx].X
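For downstream test suites, a self-contained example (toy check, not from this repository) of the pattern these tests switched to: asserting on `ValueError` plus a message match rather than `AssertionError`.

```python
import pytest

def check_cache_size(n_adatas: int, max_cache_size: int) -> None:
    # Toy stand-in for the strict cache-size check exercised above.
    if n_adatas > max_cache_size:
        raise ValueError(
            f"Expected the number of anndata files ({n_adatas}) to be "
            f"no more than the max cache size ({max_cache_size})."
        )

def test_check_cache_size() -> None:
    with pytest.raises(ValueError, match="Expected the number of anndata files"):
        check_cache_size(5, max_cache_size=2)
```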