diff --git a/cellbender/remove_background/estimation.py b/cellbender/remove_background/estimation.py
index 2443724..371a7fa 100644
--- a/cellbender/remove_background/estimation.py
+++ b/cellbender/remove_background/estimation.py
@@ -463,11 +463,9 @@ def estimate_noise(self,
         if use_multiple_processes:
 
             logger.info('Dividing dataset into chunks of genes')
-            chunk_logic_list = list(
-                self._gene_chunk_iterator(
-                    noise_log_prob_coo=noise_log_prob_coo,
-                    n_chunks=n_chunks,
-                )
+            chunk_logic_list = self._gene_chunk_iterator(
+                noise_log_prob_coo=noise_log_prob_coo,
+                n_chunks=n_chunks,
             )
 
             logger.info('Computing the output in asynchronous chunks in parallel...')
@@ -538,10 +536,11 @@ def estimate_noise(self,
     def _gene_chunk_iterator(self,
                              noise_log_prob_coo: sp.coo_matrix,
                              n_chunks: int) \
-            -> Generator[np.ndarray, None, None]:
-        """Yields chunks of the posterior that can be treated as independent,
-        from the standpoint of MCKP count estimation.  That is, they contain all
-        matrix entries for any genes they include.
+            -> List[np.ndarray]:
+        """Return a list of logical (size m) arrays used to select gene chunks
+        on which to compute the MCKP estimate. These chunks are independent.
 
         Args:
             noise_log_prob_coo: Full noise log prob posterior COO
@@ -551,36 +550,18 @@ def _gene_chunk_iterator(self,
-            Logical array which indexes elements of coo posterior for the chunk
+            List of logical arrays, each of which indexes the elements of the
+            COO posterior that belong to one gene chunk
         """
 
-        # TODO this generator is way too slow
-
-        # approximate number of entries in a chunk
-        # approx_chunk_entries = (noise_log_prob_coo.data.size - 1) // n_chunks
-
         # get gene annotations
         _, genes = self.index_converter.get_ng_indices(m_inds=noise_log_prob_coo.row)
         genes_series = pd.Series(genes)
 
-        # things we need to keep track of for each chunk
-        # current_chunk_genes = []
-        # entry_logic = np.zeros(noise_log_prob_coo.data.size, dtype=bool)
-
-        # TODO eliminate for loop to speed this up
-        # take the list of genes from the coo, sort it, and divide it evenly
-        # somehow break ties for genes overlapping boundaries of divisions
-        sorted_genes = np.sort(genes)
-        gene_arrays = np.array_split(sorted_genes, n_chunks)
-        last_gene_set = {}
-        for gene_array in gene_arrays:
-            gene_set = set(gene_array)
-            gene_set = gene_set.difference(last_gene_set)  # only the new stuff
-            # if there is a second chunk, make sure there is a gene unique to it
-            if (n_chunks > 1) and (len(gene_set) == len(set(genes))):  # all genes in first set
-                # this mainly exists for tests
-                gene_set = gene_set - {gene_arrays[-1][-1]}
-            last_gene_set = gene_set
-            entry_logic = genes_series.isin(gene_set).values
-            if sum(entry_logic) > 0:
-                yield entry_logic
+        # Divide the full range of gene indices into n_chunks contiguous pieces.
+        gene_chunk_arrays = np.array_split(np.arange(self.index_converter.total_n_genes),
+                                           n_chunks)
+
+        # One boolean mask per piece, selecting the COO entries for those genes.
+        gene_logic_arrays = [genes_series.isin(x).values for x in gene_chunk_arrays]
+        return gene_logic_arrays
 
     def _chunk_estimate_noise(self,
                               noise_log_prob_coo: sp.coo_matrix,
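
For reference, a minimal standalone sketch of the new chunking logic above, with toy values (`total_n_genes`, `genes`, and `n_chunks` here are hypothetical stand-ins for the attributes used in `_gene_chunk_iterator`):

```python
import numpy as np
import pandas as pd

# Toy stand-ins: 10 genes total, and one gene index per COO matrix entry.
total_n_genes = 10
genes = np.array([0, 0, 1, 3, 3, 4, 7, 7, 8, 9])
n_chunks = 3

# Split the full range of gene indices into contiguous pieces.
gene_chunk_arrays = np.array_split(np.arange(total_n_genes), n_chunks)

# One boolean mask per piece, selecting the COO entries for those genes.
genes_series = pd.Series(genes)
gene_logic_arrays = [genes_series.isin(chunk).values for chunk in gene_chunk_arrays]

# The masks partition the entries: every entry falls in exactly one chunk.
assert sum(mask.sum() for mask in gene_logic_arrays) == genes.size
```

Because each mask covers all entries for the genes it contains, each chunk can be handed to a separate worker process and estimated independently, which is what replaces the slow generator the old `TODO` comments flagged.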
diff --git a/cellbender/remove_background/posterior.py b/cellbender/remove_background/posterior.py
index 3758702..6aaeefa 100644
--- a/cellbender/remove_background/posterior.py
+++ b/cellbender/remove_background/posterior.py
@@ -451,7 +451,7 @@ def _get_cell_noise_count_posterior_coo(
                          f'accurate for your dataset.')
             raise RuntimeError('Zero cells found!')
 
-        dataloader_index_to_analyzed_bc_index = np.where(cell_logic)[0]
+        dataloader_index_to_analyzed_bc_index = torch.where(torch.tensor(cell_logic))[0]
         cell_data_loader = DataLoader(
             count_matrix[cell_logic],
             empty_drop_dataset=None,
@@ -468,6 +468,14 @@ def _get_cell_noise_count_posterior_coo(
         log_probs = []
         ind = 0
         n_minibatches = len(cell_data_loader)
+        # Pre-convert index lookup arrays to torch tensors once, outside the loop.
+        analyzed_gene_inds = torch.tensor(self.analyzed_gene_inds.copy())
+        if analyzed_bcs_only:
+            barcode_inds = torch.tensor(self.dataset_obj.analyzed_barcode_inds.copy())
+        else:
+            barcode_inds = torch.tensor(self.barcode_inds.copy())
+        # Accumulates {m-index: offset} for entries with nonzero noise count offsets.
+        nonzero_noise_offset_dict = {}
 
         logger.info('Computing posterior noise count probabilities in mini-batches.')
 
@@ -505,46 +513,43 @@
             )
 
             # Get the original gene index from gene index in the trimmed dataset.
-            genes_i = self.analyzed_gene_inds[genes_i_analyzed]
+            genes_i = analyzed_gene_inds[genes_i_analyzed.cpu()]
 
             # Barcode index in the dataloader.
-            bcs_i = bcs_i_chunk + ind
+            bcs_i = (bcs_i_chunk + ind).cpu()
 
             # Obtain the real barcode index since we only use cells.
             bcs_i = dataloader_index_to_analyzed_bc_index[bcs_i]
 
             # Translate chunk barcode inds to overall inds.
-            if analyzed_bcs_only:
-                bcs_i = self.dataset_obj.analyzed_barcode_inds[bcs_i]
-            else:
-                bcs_i = self.barcode_inds[bcs_i]
+            bcs_i = barcode_inds[bcs_i]
 
             # Add sparse matrix values to lists.
-            try:
-                bcs.extend(bcs_i.tolist())
-                genes.extend(genes_i.tolist())
-                c.extend(c_i.tolist())
-                log_probs.extend(log_prob_i.tolist())
-                c_offset.extend(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
-                                .detach().cpu().numpy())
-            except TypeError as e:
-                # edge case of a single value
-                bcs.append(bcs_i)
-                genes.append(genes_i)
-                c.append(c_i)
-                log_probs.append(log_prob_i)
-                c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
-                                .detach().cpu().numpy())
+            bcs.append(bcs_i.detach())
+            genes.append(genes_i.detach())
+            c.append(c_i.detach().cpu())
+            log_probs.append(log_prob_i.detach().cpu())
+
+            # Update offset dict with any nonzeros.
+            nonzero_offset_inds, nonzero_noise_count_offsets = dense_to_sparse_op_torch(
+                noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed].detach().flatten(),
+            )
+            m_i = self.index_converter.get_m_indices(cell_inds=bcs_i, gene_inds=genes_i)
+
+            nonzero_noise_offset_dict.update(
+                dict(zip(m_i[nonzero_offset_inds.detach().cpu()].tolist(),
+                         nonzero_noise_count_offsets.detach().cpu().tolist()))
+            )
+            c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed].detach().cpu())
 
             # Increment barcode index counter.
             ind += data.shape[0]  # Same as data_loader.batch_size
 
-        # Convert the lists to numpy arrays.
-        log_probs = np.array(log_probs, dtype=float)
-        c = np.array(c, dtype=np.uint32)
-        barcodes = np.array(bcs, dtype=np.uint64)  # uint32 is too small!
-        genes = np.array(genes, dtype=np.uint64)  # use same as above for IndexConverter
-        noise_count_offsets = np.array(c_offset, dtype=np.uint32)
+        # Concatenate the per-minibatch tensors.
+        log_probs = torch.cat(log_probs)
+        c = torch.cat(c)
+        barcodes = torch.cat(bcs)
+        genes = torch.cat(genes)
 
         # Translate (barcode, gene) inds to 'm' format index.
         m = self.index_converter.get_m_indices(cell_inds=barcodes, gene_inds=genes)
@@ -554,8 +559,6 @@ def _get_cell_noise_count_posterior_coo(
             (log_probs, (m, c)),
             shape=[np.prod(self.count_matrix_shape), n_counts_max],
         )
-        noise_offset_dict = dict(zip(m, noise_count_offsets))
-        nonzero_noise_offset_dict = {k: v for k, v in noise_offset_dict.items() if (v > 0)}
         self._noise_count_posterior_coo_offsets = nonzero_noise_offset_dict
         return self._noise_count_posterior_coo
 
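
To illustrate the new per-minibatch bookkeeping, here is a sketch of how nonzero noise count offsets end up keyed by 'm' index. The tensors `offsets` and `m_i` are hypothetical toy values, and the two lines marked below mirror the updated `dense_to_sparse_op_torch` contract on a 1-d input rather than calling the real function:

```python
import torch

# Hypothetical flattened noise count offsets for one minibatch, and the
# 'm' indices (flattened barcode-gene pairs) of the same entries.
offsets = torch.tensor([0, 2, 0, 0, 1])
m_i = torch.tensor([10, 11, 12, 13, 14])

# Mirrors dense_to_sparse_op_torch on a 1-d tensor: (indices, values) as torch tensors.
(nonzero_inds,) = torch.nonzero(offsets, as_tuple=True)
nonzero_vals = offsets[nonzero_inds].flatten().clone()

# Keep only nonzero offsets, keyed by 'm' index, accumulated across minibatches.
nonzero_noise_offset_dict = {}
nonzero_noise_offset_dict.update(
    dict(zip(m_i[nonzero_inds.cpu()].tolist(), nonzero_vals.cpu().tolist()))
)
assert nonzero_noise_offset_dict == {11: 2, 14: 1}
```

Building the dict incrementally inside the loop is what lets the patch delete the post-hoc `noise_offset_dict` filtering step at the end of the function.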
diff --git a/cellbender/remove_background/sparse_utils.py b/cellbender/remove_background/sparse_utils.py
index 4a0f26f..ca31329 100644
--- a/cellbender/remove_background/sparse_utils.py
+++ b/cellbender/remove_background/sparse_utils.py
@@ -10,7 +10,7 @@
 @torch.no_grad()
 def dense_to_sparse_op_torch(t: torch.Tensor,
                              tensor_for_nonzeros: Optional[torch.Tensor] = None) \
-        -> Tuple[np.ndarray, ...]:
+        -> Tuple[torch.Tensor, ...]:
     """Converts dense matrix to sparse COO format tuple of numpy arrays (*indices, data)
 
     Args:
@@ -28,9 +28,9 @@ def dense_to_sparse_op_torch(t: torch.Tensor,
         tensor_for_nonzeros = t
 
     nonzero_inds_tuple = torch.nonzero(tensor_for_nonzeros, as_tuple=True)
-    nonzero_values = t[nonzero_inds_tuple].flatten()
+    nonzero_values = t[nonzero_inds_tuple].flatten().clone()
 
-    return tuple([ten.cpu().numpy() for ten in (nonzero_inds_tuple + (nonzero_values,))])
+    return nonzero_inds_tuple + (nonzero_values,)
 
 
 def log_prob_sparse_to_dense(coo: sp.coo_matrix) -> np.ndarray:
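
A quick usage sketch of the changed contract: the function now returns torch tensors (left on the input tensor's device) instead of numpy arrays, so any `.cpu()` moves become the caller's responsibility.

```python
import torch
from cellbender.remove_background.sparse_utils import dense_to_sparse_op_torch

t = torch.tensor([[0.0, 1.5],
                  [0.0, 0.0],
                  [2.0, 0.0]])
rows, cols, vals = dense_to_sparse_op_torch(t)

# All three outputs are torch tensors, not numpy arrays.
assert torch.equal(rows, torch.tensor([0, 2]))
assert torch.equal(cols, torch.tensor([1, 0]))
assert torch.equal(vals, torch.tensor([1.5, 2.0]))
```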
diff --git a/cellbender/remove_background/tests/test_dataprep.py b/cellbender/remove_background/tests/test_dataprep.py
index 8fcb156..5d11380 100644
--- a/cellbender/remove_background/tests/test_dataprep.py
+++ b/cellbender/remove_background/tests/test_dataprep.py
@@ -75,9 +75,9 @@ def test_dataloader_sorting(simulated_dataset, cuda):
             bcs_i = loader.unsort_inds(bcs_i)
 
             # Add sparse matrix values to lists.
-            barcodes.append(bcs_i)
-            genes.append(genes_i)
-            counts.append(counts_i)
+            barcodes.append(bcs_i.detach().cpu())
+            genes.append(genes_i.detach().cpu())
+            counts.append(counts_i.detach().cpu())
 
             # Increment barcode index counter.
             ind += data.shape[0]  # Same as data_loader.batch_size
diff --git a/cellbender/remove_background/tests/test_sparse_utils.py b/cellbender/remove_background/tests/test_sparse_utils.py
index 01230da..2f2e13e 100644
--- a/cellbender/remove_background/tests/test_sparse_utils.py
+++ b/cellbender/remove_background/tests/test_sparse_utils.py
@@ -76,9 +76,9 @@ def test_dense_to_sparse_op_torch(simulated_dataset, cuda):
         bcs_i = data_loader.unsort_inds(bcs_i)
 
         # Add sparse matrix values to lists.
-        barcodes.append(bcs_i)
-        genes.append(genes_i)
-        counts.append(counts_i)
+        barcodes.append(bcs_i.detach().cpu())
+        genes.append(genes_i.detach().cpu())
+        counts.append(counts_i.detach().cpu())
 
         # Increment barcode index counter.
         ind += data.shape[0]  # Same as data_loader.batch_size
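
Both test edits follow the same pattern as the posterior changes. A self-contained sketch (the `minibatches` tuples are hypothetical toy values) of why per-minibatch `.detach().cpu()` plus a single `torch.cat` replaces the old per-element list accumulation:

```python
import torch

# Hypothetical minibatch outputs: (barcode inds, gene inds, counts) per batch.
minibatches = [
    (torch.tensor([0, 1]), torch.tensor([5, 7]), torch.tensor([3, 1])),
    (torch.tensor([2]), torch.tensor([5]), torch.tensor([2])),
]

barcodes, genes, counts = [], [], []
for bcs_i, genes_i, counts_i in minibatches:
    # Detach and move each minibatch to CPU so GPU memory is released promptly.
    barcodes.append(bcs_i.detach().cpu())
    genes.append(genes_i.detach().cpu())
    counts.append(counts_i.detach().cpu())

# A single concatenation at the end avoids per-element .tolist() overhead.
barcodes = torch.cat(barcodes)
assert barcodes.tolist() == [0, 1, 2]
```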