Commit: use generator
YQ-Wang committed Aug 28, 2024
1 parent 17d261b commit 1fce270
Showing 1 changed file with 31 additions and 22 deletions.
scbsp/scbsp.py: 53 changes (31 additions & 22 deletions)
```diff
@@ -79,21 +79,21 @@ def _binary_distance_matrix_threshold(
     indices = ball_tree.query_radius(
         input_sparse_mat_array, r=d_val, return_distance=False
     )
-    rows = np.repeat(
-        np.arange(input_sparse_mat_array.shape[0]), [len(i) for i in indices]
-    )
-    cols = np.concatenate(indices)
-
-    # Construct binary csr_matrix
-    data = np.ones_like(rows)
+    def generate_data():
+        for i, idx in enumerate(indices):
+            yield from ((i, j, 1) for j in idx)
+
+    rows, cols, data = zip(*generate_data())

     sparse_mat = csr_matrix(
         (data, (rows, cols)),
         shape=(input_sparse_mat_array.shape[0], input_sparse_mat_array.shape[0]),
         dtype=np.int8
     )

     return sparse_mat + identity(
-        input_sparse_mat_array.shape[0], format="csr", dtype=bool
+        input_sparse_mat_array.shape[0], format="csr", dtype=np.int8
     )
```
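
For reference, a minimal standalone sketch of the generator-based triplet construction adopted in this hunk; the `points` array, radius, and leaf size are hypothetical toy values, not inputs from scBSP:

```python
import numpy as np
from scipy.sparse import csr_matrix, identity
from sklearn.neighbors import BallTree

# Hypothetical toy input: five 1-D points spaced one unit apart.
points = np.arange(5, dtype=float).reshape(-1, 1)
indices = BallTree(points, leaf_size=2).query_radius(points, r=1.0)

# Lazily yield (row, col, value) triplets, as in the committed generator.
def generate_data():
    for i, idx in enumerate(indices):
        yield from ((i, j, 1) for j in idx)

rows, cols, data = zip(*generate_data())
mat = csr_matrix((data, (rows, cols)), shape=(5, 5), dtype=np.int8)

# query_radius includes each point itself, so `mat` already has a unit
# diagonal; adding identity(), as the function does, makes it 2.
print((mat + identity(5, format="csr", dtype=np.int8)).toarray())
```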


```diff
@@ -157,7 +157,7 @@ def _var_local_means(
             input_sp_mat, d_val, leaf_size
         )
         patches_cells_centroid = diags(
-            (patches_cells.sum(axis=1) > 1).astype(float).A.ravel(),
+            (patches_cells.sum(axis=1) > 1).astype(np.float32).A.ravel(),
             offsets=0,
             format="csr",
         )
```
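
The only change in this hunk is the cast from float to np.float32 before building the diagonal centroid matrix. A small sketch of that pattern, with a hypothetical `patches_cells` matrix standing in for the real patch-membership data:

```python
import numpy as np
from scipy.sparse import csr_matrix, diags

# Hypothetical binary patch-membership matrix (3 cells x 3 cells).
patches_cells = csr_matrix(
    np.array([[1, 1, 0], [0, 1, 0], [1, 1, 1]], dtype=np.int8)
)

# Mask of cells whose patch holds more than one cell; float32 halves the
# memory of the default float64 cast.
mask = (patches_cells.sum(axis=1) > 1).astype(np.float32).A.ravel()
print(diags(mask, offsets=0, format="csr").toarray())
```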
```diff
@@ -167,26 +167,34 @@ def _var_local_means(

         if use_gpu and gpu_enabled:
             # Convert the csr_matrix to PyTorch tensors and move to GPU
-            input_exp_mat_norm_torch = torch.tensor(  # type: ignore
+            input_exp_mat_norm_torch = torch.tensor(
                 input_exp_mat_norm.toarray(), device="cuda"
             )
-            patches_cells_torch = torch.tensor(patches_cells.toarray(), device="cuda")  # type: ignore
-            diag_matrix_sparse_torch = torch.tensor(  # type: ignore
+            patches_cells_torch = torch.tensor(patches_cells.toarray(), device="cuda")
+            diag_matrix_sparse_torch = torch.tensor(
                 diag_matrix_sparse.toarray(), device="cuda"
             )

-            result = torch.matmul(  # type: ignore
+            result = torch.matmul(
                 input_exp_mat_norm_torch,
-                torch.matmul(patches_cells_torch, diag_matrix_sparse_torch),  # type: ignore
+                torch.matmul(patches_cells_torch, diag_matrix_sparse_torch),
             )
             x_kj = scipy.sparse.csr_matrix(result.cpu().numpy())
             del result  # Free up GPU memory
         else:
             x_kj = input_exp_mat_norm @ (patches_cells @ diag_matrix_sparse)

         # Free up memory
         del patches_cells, patches_cells_centroid, diag_matrix_sparse

         return _calculate_sparse_variances(x_kj, axis=1)

-    var_x = np.column_stack([_var_local_means(input_sp_mat, d_val, input_exp_mat_norm, leaf_size, use_gpu).A.ravel() for d_val in (d1, d2)])  # type: ignore
-    var_x_0_add = _calculate_sparse_variances(input_exp_mat_raw, axis=1).A.ravel()  # type: ignore
+    def var_x_generator():
+        for d_val in (d1, d2):
+            yield _var_local_means(input_sp_mat, d_val, input_exp_mat_norm, leaf_size, use_gpu).A.ravel()
+
+    var_x = np.column_stack(list(var_x_generator()))
+    var_x_0_add = _calculate_sparse_variances(input_exp_mat_raw, axis=1).A.ravel()
     var_x_0_add /= max(var_x_0_add)
     t_matrix = (var_x[:, 1] / var_x[:, 0]) * var_x_0_add
     return t_matrix.tolist()
```
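
The `var_x_generator` pattern above computes one variance vector per radius and stacks them column-wise. A self-contained sketch, where `fake_var_local_means` is a hypothetical stand-in for `_var_local_means`:

```python
import numpy as np

# Hypothetical stand-in for _var_local_means: one variance per gene.
def fake_var_local_means(d_val: int) -> np.ndarray:
    rng = np.random.default_rng(d_val)
    return rng.random(4)

def var_x_generator():
    for d_val in (1, 2):  # stand-ins for the radii d1 and d2
        yield fake_var_local_means(d_val)

# np.column_stack needs a full sequence, so the generator is
# materialized exactly once before stacking.
var_x = np.column_stack(list(var_x_generator()))
print(var_x.shape)  # (4, 2)
```

Functionally this matches the replaced list comprehension, since np.column_stack materializes the sequence either way; the gain here is mainly readability.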
```diff
@@ -243,13 +251,14 @@ def granp(

     # Calculate p-values
     t_matrix_sum_upper90 = np.quantile(t_matrix_sum, 0.90)
-    t_matrix_sum_mid = [val for val in t_matrix_sum if val < t_matrix_sum_upper90]
-    log_t_matrix_sum_mid = np.log(t_matrix_sum_mid)
+    t_matrix_sum_mid = (val for val in t_matrix_sum if val < t_matrix_sum_upper90)
+    log_t_matrix_sum_mid = np.fromiter((np.log(val) for val in t_matrix_sum_mid), dtype=float)
     log_norm_params = (log_t_matrix_sum_mid.mean(), log_t_matrix_sum_mid.std(ddof=1))

     # Calculate p-values using the log-normal distribution.
-    p_values = 1 - lognorm.cdf(
-        t_matrix_sum, scale=np.exp(log_norm_params[0]), s=log_norm_params[1]
-    )
+    def p_value_generator():
+        for val in t_matrix_sum:
+            yield 1 - lognorm.cdf(val, scale=np.exp(log_norm_params[0]), s=log_norm_params[1])
+
+    p_values = list(p_value_generator())

     return pd.DataFrame({"gene_names": gene_names, "p_values": p_values})
```
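
A compact sketch of the log-normal p-value step shown in this hunk, using a hypothetical `t_matrix_sum` vector in place of the real test statistics:

```python
import numpy as np
from scipy.stats import lognorm

# Hypothetical test statistics for five genes.
t_matrix_sum = np.array([0.5, 1.0, 2.0, 4.0, 8.0])

# Fit a log-normal to the statistics below the 90th percentile.
upper90 = np.quantile(t_matrix_sum, 0.90)
log_mid = np.fromiter(
    (np.log(v) for v in t_matrix_sum if v < upper90), dtype=float
)
mu, sigma = log_mid.mean(), log_mid.std(ddof=1)

# Upper-tail p-value per statistic, as in p_value_generator above.
p_values = [
    1 - lognorm.cdf(v, scale=np.exp(mu), s=sigma) for v in t_matrix_sum
]
print(np.round(p_values, 3))
```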
