Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix v0.3.1 bug #350

Merged
merged 12 commits into from
Apr 11, 2024
Merged
7 changes: 6 additions & 1 deletion .github/workflows/miniwdl_check.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
name: 'validate WDL'
on: [pull_request]

on:
pull_request:
branches: [ master, dev ]

env:
MINIWDL_VERSION: 1.8.0

jobs:
miniwdl-check:
runs-on: ubuntu-latest
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/run_packaging_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

name: 'packaging'

on: pull_request
on:
pull_request:
branches: [ master, dev ]

jobs:
build:
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/run_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

name: 'pytest'

on: pull_request
on:
pull_request:
branches: [ master, dev ]

jobs:
build:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dist/
*.csv
*.npz
*.tar.gz
data/
26 changes: 26 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,32 @@ The current release contains the following modules. More modules will be added i

Please refer to `the documentation <https://cellbender.readthedocs.io/en/latest/>`_ for a quick start tutorial.

WARNING:

The release tagged v0.3.1 included a bug which caused output count matrices to be incorrect. `The bug
<https://github.com/broadinstitute/CellBender/blame/e2fb5977cb187cb4b12172c9f77ed556bca92cb0/cellbender/remove_background/estimation.py#L241>`_,
introduced in `#303 <https://github.com/broadinstitute/CellBender/pull/303>`_, compromised output denoised count matrices
(due to an integer overflow) and would often show up as negative entries in the output count matrices. The bug also existed on
the master branch until `#347 <https://github.com/broadinstitute/CellBender/pull/347>`_.

For now, we recommend using either v0.3.0 or the master branch (after #347) until v0.3.2 is released, and then using v0.3.2.

Outputs generated with v0.3.1 (or the master branch between #303 and #347) can be salvaged by making use of the
checkpoint file, which is not compromised. The following command will re-run the (inexpensive, CPU-only)
estimation of the output count matrix using the saved posterior in the checkpoint file. Note the use of the new
input argument ``--force-use-checkpoint`` which will allow use of a checkpoint file produced by a different CellBender version:

.. code-block:: console

(cellbender) $ cellbender remove-background \
--input my_raw_count_matrix_file.h5 \
--output my_cellbender_output_file.h5 \
--checkpoint path/to/ckpt.tar.gz \
--force-use-checkpoint

where ``path/to/ckpt.tar.gz`` is the path to the checkpoint file generated by the original run. Ensure that you pair up the right
``--input`` with the right ``--checkpoint``.

Installation and Usage
----------------------

Expand Down
11 changes: 10 additions & 1 deletion cellbender/remove_background/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,19 @@ def add_subparser_args(subparsers: argparse) -> argparse:
subparser.add_argument("--checkpoint", nargs=None, type=str,
dest='input_checkpoint_tarball',
required=False, default=consts.CHECKPOINT_FILE_NAME,
help="Checkpoint tarball produced by the same version "
help="Checkpoint tarball produced by v0.3.0+ "
"of CellBender remove-background. If present, "
"and the workflow hashes match, training will "
"restart from this checkpoint.")
subparser.add_argument("--force-use-checkpoint",
dest='force_use_checkpoint', action="store_true",
help="Normally, checkpoints can only be used if the CellBender "
"code and certain input args match exactly. This flag allows you "
"to bypass this requirement. An example use would be to create a new output "
"using a checkpoint from a run of v0.3.1, a redacted version with "
"faulty output count matrices. If you use this flag, "
"ensure that the input file and the checkpoint match, because "
"CellBender will not check.")
subparser.add_argument("--expected-cells", nargs=None, type=int,
default=None,
dest="expected_cell_count",
Expand Down
36 changes: 23 additions & 13 deletions cellbender/remove_background/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def save_checkpoint(filebase: str,

def load_checkpoint(filebase: Optional[str],
tarball_name: str = consts.CHECKPOINT_FILE_NAME,
force_device: Optional[str] = None)\
force_device: Optional[str] = None,
force_use_checkpoint: bool = False)\
-> Dict[str, Union['RemoveBackgroundPyroModel', pyro.optim.PyroOptim, DataLoader, bool]]:
"""Load checkpoint and prepare a RemoveBackgroundPyroModel and optimizer."""

Expand All @@ -163,6 +164,7 @@ def load_checkpoint(filebase: Optional[str],
tarball_name=tarball_name,
to_load=['model', 'optim', 'param_store', 'dataloader', 'args', 'random_state'],
force_device=force_device,
force_use_checkpoint=force_use_checkpoint,
)
out.update({'loaded': True})
logger.info(f'Loaded partially-trained checkpoint from {tarball_name}')
Expand All @@ -172,7 +174,8 @@ def load_checkpoint(filebase: Optional[str],
def load_from_checkpoint(filebase: Optional[str],
tarball_name: str = consts.CHECKPOINT_FILE_NAME,
to_load: List[str] = ['model'],
force_device: Optional[str] = None) -> Dict:
force_device: Optional[str] = None,
force_use_checkpoint: bool = False) -> Dict:
"""Load specific files from a checkpoint tarball."""

load_kwargs = {}
Expand All @@ -192,19 +195,24 @@ def load_from_checkpoint(filebase: Optional[str],
else:
# no tarball loaded, so do not continue trying to load files
raise FileNotFoundError

# See if files have a hash matching input filebase.
if filebase is not None:

# If posterior is present, do not require run hash to match: will pick up
# after training and run estimation from existing posterior.
# This smoothly allows re-runs (including for problematic v0.3.1)
logger.debug(f'force_use_checkpoint: {force_use_checkpoint}')
if force_use_checkpoint or (filebase is None):
filebase = (glob.glob(os.path.join(tmp_dir, '*_model.torch'))[0]
.replace('_model.torch', ''))
logger.debug(f'Accepting any file hash, so loading {filebase}*')

else:
# See if files have a hash matching input filebase.
basename = os.path.basename(filebase)
filebase = os.path.join(tmp_dir, basename)
logger.debug(f'Looking for files with base name matching {filebase}*')
if not os.path.exists(filebase + '_model.torch'):
logger.info('Workflow hash does not match that of checkpoint.')
raise ValueError('Workflow hash does not match that of checkpoint.')
else:
filebase = (glob.glob(os.path.join(tmp_dir, '*_model.torch'))[0]
.replace('_model.torch', ''))
logger.debug(f'Accepting any file hash, so loading {filebase}*')
raise ValueError('Workflow hash does not match that of checkpoint.')

out = {}

Expand Down Expand Up @@ -265,9 +273,10 @@ def load_from_checkpoint(filebase: Optional[str],
return out


def attempt_load_checkpoint(filebase: str,
def attempt_load_checkpoint(filebase: Optional[str],
tarball_name: str = consts.CHECKPOINT_FILE_NAME,
force_device: Optional[str] = None)\
force_device: Optional[str] = None,
force_use_checkpoint: bool = False)\
-> Dict[str, Union['RemoveBackgroundPyroModel', pyro.optim.PyroOptim, DataLoader, bool]]:
"""Load checkpoint and prepare a RemoveBackgroundPyroModel and optimizer,
or return the inputs if loading fails."""
Expand All @@ -276,7 +285,8 @@ def attempt_load_checkpoint(filebase: str,
logger.debug('Attempting to load checkpoint from ' + tarball_name)
return load_checkpoint(filebase=filebase,
tarball_name=tarball_name,
force_device=force_device)
force_device=force_device,
force_use_checkpoint=force_use_checkpoint)

except FileNotFoundError:
logger.debug('No tarball found')
Expand Down
10 changes: 7 additions & 3 deletions cellbender/remove_background/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

logger = logging.getLogger('cellbender')

N_CELLS_DATATYPE = np.int32
N_GENES_DATATYPE = np.int32
COUNT_DATATYPE = np.int32


class EstimationMethod(ABC):
"""Base class for estimation of noise counts, given a posterior."""
Expand Down Expand Up @@ -52,7 +56,7 @@ def _estimation_array_to_csr(self,
data: np.ndarray,
m: np.ndarray,
noise_offsets: Optional[Dict[int, int]],
dtype=np.int64) -> sp.csr_matrix:
dtype=COUNT_DATATYPE) -> sp.csr_matrix:
"""Say you have point estimates for each count matrix element (data) and
you have the 'm'-indices for each value (m). This returns a CSR matrix
that has the shape of the count matrix, where duplicate entries have
Expand Down Expand Up @@ -218,7 +222,7 @@ def _estimation_array_to_csr(index_converter,
data: np.ndarray,
m: np.ndarray,
noise_offsets: Optional[Dict[int, int]],
dtype=np.int) -> sp.csr_matrix:
dtype=COUNT_DATATYPE) -> sp.csr_matrix:
"""Say you have point estimates for each count matrix element (data) and
you have the 'm'-indices for each value (m). This returns a CSR matrix
that has the shape of the count matrix, where duplicate entries have
Expand All @@ -238,7 +242,7 @@ def _estimation_array_to_csr(index_converter,
row, col = index_converter.get_ng_indices(m_inds=m)
if noise_offsets is not None:
data = data + np.array([noise_offsets.get(i, 0) for i in m])
coo = sp.coo_matrix((data.astype(dtype), (row.astype(np.uint64), col.astype(np.uint8))),
coo = sp.coo_matrix((data.astype(dtype), (row.astype(N_CELLS_DATATYPE), col.astype(N_GENES_DATATYPE))),
shape=index_converter.matrix_shape, dtype=dtype)
coo.sum_duplicates()
return coo.tocsr()
Expand Down
6 changes: 4 additions & 2 deletions cellbender/remove_background/posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,15 @@ def _do_posterior_regularization(posterior: Posterior):
try:
ckpt_posterior = load_from_checkpoint(tarball_name=args.input_checkpoint_tarball,
filebase=args.checkpoint_filename,
to_load=['posterior'])
to_load=['posterior'],
force_use_checkpoint=args.force_use_checkpoint)
except ValueError:
# input checkpoint tarball was not a match for this workflow
# but we still may have saved a new tarball
ckpt_posterior = load_from_checkpoint(tarball_name=consts.CHECKPOINT_FILE_NAME,
filebase=args.checkpoint_filename,
to_load=['posterior'])
to_load=['posterior'],
force_use_checkpoint=args.force_use_checkpoint)
if os.path.exists(ckpt_posterior.get('posterior_file', 'does_not_exist')):
# Load posterior if it was saved in the checkpoint.
posterior.load(file=ckpt_posterior['posterior_file'])
Expand Down
6 changes: 5 additions & 1 deletion cellbender/remove_background/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ def compute_output_denoised_counts_reports_metrics(posterior: Posterior,
posterior.latents_map['p'],
)

# Failsafe to ensure no negative counts.
assert np.all(denoised_counts.data >= 0), 'Negative count matrix entries in output'

# TODO: correct cell probabilities so that any zero-count droplet becomes "empty"

# Save denoised count matrix.
Expand Down Expand Up @@ -627,7 +630,8 @@ def run_inference(dataset_obj: SingleCellRNACountsDataset,
# Attempt to load from a previously-saved checkpoint.
ckpt = attempt_load_checkpoint(filebase=checkpoint_filename,
tarball_name=args.input_checkpoint_tarball,
force_device='cuda:0' if args.use_cuda else 'cpu')
force_device='cuda:0' if args.use_cuda else 'cpu',
force_use_checkpoint=args.force_use_checkpoint)
ckpt_loaded = ckpt['loaded'] # True if a checkpoint was loaded successfully

if ckpt_loaded:
Expand Down
1 change: 1 addition & 0 deletions cellbender/remove_background/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def test_save_and_load_cellbender_checkpoint(tmpdir_factory, cuda, scheduler):
args.constant_learning_rate = not scheduler
args.debug = False
args.input_checkpoint_tarball = 'none'
args.force_use_checkpoint = False

create_random_state_blank_slate(0)
pyro.clear_param_store()
Expand Down
36 changes: 33 additions & 3 deletions cellbender/remove_background/tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import torch

from cellbender.remove_background.estimation import Mean, MAP, \
SingleSample, ThresholdCDF, MultipleChoiceKnapsack, pandas_grouped_apply
SingleSample, ThresholdCDF, MultipleChoiceKnapsack, pandas_grouped_apply, _estimation_array_to_csr, COUNT_DATATYPE
from cellbender.remove_background.posterior import IndexConverter, \
dense_to_sparse_op_torch, log_prob_sparse_to_dense
from cellbender.remove_background.tests.conftest import sparse_matrix_equal

from typing import Dict, Union

Expand Down Expand Up @@ -92,11 +93,15 @@ def test_mean_massive_m(log_prob_coo):
new_shape = (coo.shape[0] + greater_than_max_int32, coo.shape[1])
new_coo = sp.coo_matrix((coo.data, (new_row, coo.col)),
shape=new_shape)
print(f'original COO shape: {coo.shape}')
print(f'new COO shape: {new_coo.shape}')
print(f'new row minimum value: {new_coo.row.min()}')
print(f'new row maximum value: {new_coo.row.max()}')
offset_dict = {k + greater_than_max_int32: v for k, v in log_prob_coo['offsets'].items()}

# this is just a shim
converter = IndexConverter(total_n_cells=2,
total_n_genes=new_coo.shape[0])
converter = IndexConverter(total_n_cells=new_coo.shape[0],
total_n_genes=new_coo.shape[1])

# set up and estimate
estimator = Mean(index_converter=converter)
Expand Down Expand Up @@ -379,3 +384,28 @@ def test_parallel_pandas_grouped_apply(fun):

np.testing.assert_array_equal(reg['m'], parallel['m'])
np.testing.assert_array_equal(reg['result'], parallel['result'])


def test_estimation_array_to_csr():

larger_than_uint16 = 2**16 + 1

converter = IndexConverter(total_n_cells=larger_than_uint16,
total_n_genes=larger_than_uint16)
m = larger_than_uint16 + np.arange(-10, 10)
data = np.random.rand(len(m)) * -10
noise_offsets = None

output_csr = _estimation_array_to_csr(index_converter=converter, data=data, m=m, noise_offsets=noise_offsets, dtype=COUNT_DATATYPE)

# reimplementation here with totally permissive datatypes
cell_and_gene_dtype = np.float64
row, col = converter.get_ng_indices(m_inds=m)
if noise_offsets is not None:
data = data + np.array([noise_offsets.get(i, 0) for i in m])
coo = sp.coo_matrix((data.astype(COUNT_DATATYPE), (row.astype(cell_and_gene_dtype), col.astype(cell_and_gene_dtype))),
shape=converter.matrix_shape, dtype=COUNT_DATATYPE)
coo.sum_duplicates()
truth_csr = coo.tocsr()

assert sparse_matrix_equal(output_csr, truth_csr)
3 changes: 3 additions & 0 deletions cellbender/remove_background/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,6 @@ def test_full_run(tmpdir_factory, h5_v3_file, cuda):
adata_cell_barcodes = adata.obs_names[adata.obs['cell_probability'] > consts.CELL_PROB_CUTOFF]
assert set(cell_barcodes) == set(adata_cell_barcodes), \
'Cell barcodes in h5 are different from those in CSV file'

# ensure there are no negative count matrix entries in the output
assert np.all(adata.X.data >= 0), 'Negative count matrix entries in output'
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ jupyter
jupyter_contrib_nbextensions
notebook<7.0.0
nbconvert<7.0.0
lxml_html_clean
psutil
Loading