Integrating MPS backend #9

Merged · 6 commits · Jan 7, 2024

Changes from all commits:

154 changes: 148 additions & 6 deletions .gitignore
@@ -1,6 +1,148 @@
-__pycache__
-*.pdf
-build
-dist
-*.egg-info
-.eggs
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Sublime workspace
+*.sublime-workspace
+.DS_Store
+
+#Custom folders
+results/
+figures/
+
+*.sublime-workspace
+*.sublime-project
+
+# Jupyter notebooks
+*.ipynb
+
+.idea/
+
+*.h5ad

2 changes: 1 addition & 1 deletion harmony/__init__.py
@@ -6,7 +6,7 @@
 from importlib_metadata import version, PackageNotFoundError
 
 try:
-    __version__ = version('harmony-pytorch')
+    __version__ = version("harmony-pytorch")
     del version
 except PackageNotFoundError:
     pass
71 changes: 35 additions & 36 deletions harmony/harmony.py
@@ -10,7 +10,6 @@
 from .utils import one_hot_tensor, get_batch_codes
 
 
-
 def harmonize(
     X: np.array,
     batch_mat: pd.DataFrame,
@@ -105,13 +104,16 @@ def harmonize(
     >>> X_harmony = harmonize(adata.obsm['X_pca'], adata.obs, ['Channel', 'Lab'])
     """
 
-    assert(isinstance(X, np.ndarray))
+    assert isinstance(X, np.ndarray)
 
     if n_jobs < 0:
         import psutil
-        n_jobs = psutil.cpu_count(logical=False) # get physical cores
+
+        n_jobs = psutil.cpu_count(logical=False)  # get physical cores
         if n_jobs is None:
-            n_jobs = psutil.cpu_count(logical=True) # if undetermined, use logical cores instead
+            n_jobs = psutil.cpu_count(
+                logical=True
+            )  # if undetermined, use logical cores instead
     torch.set_num_threads(n_jobs)
 
     device_type = "cpu"
@@ -120,9 +122,14 @@
             device_type = "cuda"
             if verbose:
                 print("Use GPU mode.")
-        else:
+        elif torch.backends.mps.is_available():
+            device_type = "mps"
             if verbose:
-                print("CUDA is not available on your machine. Use CPU mode instead.")
+                print("Use Metal (MPS) mode.")
+        elif verbose:
+            print(
+                "Neither CUDA nor MPS is available on your machine. Use CPU mode instead."
+            )
 
     (stride_0, stride_1) = X.strides
     if stride_0 < 0 or stride_1 < 0:
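This hunk is the heart of the PR: the device probe now tries CUDA first, then Apple's Metal Performance Shaders backend, then falls back to CPU. A minimal sketch of the same selection order using only public PyTorch calls (torch.backends.mps.is_available() exists as of torch 1.12); pick_device is a hypothetical helper, not part of the diff:

import torch

def pick_device(use_gpu: bool = True) -> torch.device:
    # Preference order mirrors the diff: CUDA, then MPS, then CPU.
    if use_gpu and torch.cuda.is_available():
        return torch.device("cuda")
    if use_gpu and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

Z = torch.randn(4, 2, device=pick_device())  # tensor allocated on the chosen backend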
@@ -156,7 +163,7 @@ def harmonize(
         theta = theta.view(1, -1)
 
     assert block_proportion > 0 and block_proportion <= 1
-    assert correction_method in ["fast", "original"]
+    assert correction_method in {"fast", "original"}
 
     np.random.seed(random_state)
 
@@ -206,13 +213,10 @@ def harmonize(
 
         if is_convergent_harmony(objectives_harmony, tol=tol_harmony):
             if verbose:
-                print("Reach convergence after {} iteration(s).".format(i + 1))
+                print(f"Reach convergence after {i + 1} iteration(s).")
             break
 
-    if device_type == "cpu":
-        return Z_hat.numpy()
-    else:
-        return Z_hat.cpu().numpy()
+    return Z_hat.numpy() if device_type == "cpu" else Z_hat.cpu().numpy()
 
 
 def initialize_centroids(
@@ -229,17 +233,19 @@
 ):
     n_cells = Z_norm.shape[0]
 
-    kmeans_params = {'n_clusters': n_clusters,
-                     'init': "k-means++",
-                     'n_init': n_init,
-                     'random_state': random_state,
-                     'max_iter': 25,
-                     }
+    kmeans_params = {
+        "n_clusters": n_clusters,
+        "init": "k-means++",
+        "n_init": n_init,
+        "random_state": random_state,
+        "max_iter": 25,
+    }
 
     kmeans = KMeans(**kmeans_params)
 
     from threadpoolctl import threadpool_limits
-    with threadpool_limits(limits = n_jobs):
+
+    with threadpool_limits(limits=n_jobs):
         if device_type == "cpu":
             kmeans.fit(Z_norm)
         else:
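As background, threadpool_limits caps the native BLAS/OpenMP thread pools that scikit-learn's KMeans uses under the hood, so the fit respects the caller's n_jobs. A minimal sketch of the same pattern with made-up data:

import numpy as np
from sklearn.cluster import KMeans
from threadpoolctl import threadpool_limits

Z = np.random.rand(1000, 50).astype(np.float32)  # hypothetical embedding matrix
kmeans = KMeans(n_clusters=10, init="k-means++", n_init=10, max_iter=25, random_state=0)
with threadpool_limits(limits=4):  # cap OpenBLAS/MKL threads during the fit
    kmeans.fit(Z)
print(kmeans.cluster_centers_.shape)  # (10, 50)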
@@ -249,9 +255,7 @@
     Y_norm = normalize(Y, p=2, dim=1)
 
     # Initialize R
-    R = torch.exp(
-        -2 / sigma * (1 - torch.matmul(Z_norm, Y_norm.t()))
-    )
+    R = torch.exp(-2 / sigma * (1 - torch.matmul(Z_norm, Y_norm.t())))
     R = normalize(R, p=1, dim=1)
 
     E = torch.matmul(Pr_b, torch.sum(R, dim=0, keepdim=True))
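The collapsed one-liner is the soft-assignment initialization. Since Z_norm and Y_norm both have unit-length rows, $\lVert z_i - y_k \rVert^2 = 2\,(1 - z_i^\top y_k)$, so up to the row normalization on the following line this computes a Gaussian kernel over cell-to-centroid distances:

$$R_{ik} \propto \exp\!\Big(-\tfrac{2}{\sigma}\,(1 - z_i^\top y_k)\Big) = \exp\!\Big(-\tfrac{\lVert z_i - y_k \rVert^2}{\sigma}\Big)$$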
@@ -282,12 +286,11 @@ def clustering(
     device_type,
     n_init=10,
 ):
-
     n_cells = Z_norm.shape[0]
 
     objectives_clustering = []
 
-    for i in range(max_iter):
+    for _ in range(max_iter):
         # Compute Cluster Centroids
         Y = torch.matmul(R.t(), Z_norm)
         Y_norm = normalize(Y, p=2, dim=1)
@@ -298,12 +301,8 @@
         pos = 0
         while pos < len(idx_list):
             idx_in = idx_list[pos : (pos + block_size)]
-            R_in = R[
-                idx_in,
-            ]
-            Phi_in = Phi[
-                idx_in,
-            ]
+            R_in = R[idx_in,]
+            Phi_in = Phi[idx_in,]
 
             # Compute O and E on left out data.
             O -= torch.matmul(Phi_in.t(), R_in)
@@ -347,14 +346,12 @@ def correction_original(X, R, Phi, ridge_lambda, device_type):
     Phi_1 = torch.cat((torch.ones(n_cells, 1, device=device_type), Phi), dim=1)
 
     Z = X.clone()
-    id_mat = torch.eye(n_batches + 1, n_batches + 1, device = device_type)
+    id_mat = torch.eye(n_batches + 1, n_batches + 1, device=device_type)
     id_mat[0, 0] = 0
     Lambda = ridge_lambda * id_mat
     for k in range(n_clusters):
         Phi_t_diag_R = Phi_1.t() * R[:, k].view(1, -1)
-        inv_mat = torch.inverse(
-            torch.matmul(Phi_t_diag_R, Phi_1) + Lambda
-        )
+        inv_mat = torch.inverse(torch.matmul(Phi_t_diag_R, Phi_1) + Lambda)
         W = torch.matmul(inv_mat, torch.matmul(Phi_t_diag_R, X))
         W[0, :] = 0
         Z -= torch.matmul(Phi_t_diag_R.t(), W)
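Reading the loop as written, each iteration solves a weighted ridge regression per cluster: with $\Phi_1$ the one-hot design matrix plus intercept column and $\mathrm{diag}(R_{\cdot k})$ the soft cluster weights,

$$W^{(k)} = \big(\Phi_1^\top \mathrm{diag}(R_{\cdot k})\,\Phi_1 + \Lambda\big)^{-1} \Phi_1^\top \mathrm{diag}(R_{\cdot k})\,X, \qquad Z \leftarrow Z - \mathrm{diag}(R_{\cdot k})\,\Phi_1 W^{(k)},$$

where $\Lambda$ penalizes every coefficient except the intercept (id_mat[0, 0] = 0), and zeroing W[0, :] keeps the cluster intercept in Z so only batch-specific effects are subtracted.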
@@ -375,7 +372,7 @@ def correction_fast(X, R, Phi, O, ridge_lambda, device_type):
         N_k = torch.sum(O_k)
 
         factor = 1 / (O_k + ridge_lambda)
-        c = N_k + torch.sum(-factor * O_k ** 2)
+        c = N_k + torch.sum(-factor * O_k**2)
         c_inv = 1 / c
 
         P[0, 1:] = -factor * O_k
@@ -401,7 +398,9 @@ def compute_objective(
     Y_norm, Z_norm, R, theta, sigma, O, E, objective_arr, device_type
 ):
     kmeans_error = torch.sum(R * 2 * (1 - torch.matmul(Z_norm, Y_norm.t())))
-    entropy_term = sigma * torch.sum(-torch.distributions.Categorical(probs=R).entropy())
+    entropy_term = sigma * torch.sum(
+        -torch.distributions.Categorical(probs=R).entropy()
+    )
     diversity_penalty = sigma * torch.sum(
         torch.matmul(theta, O * torch.log(torch.div(O + 1, E + 1)))
     )
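With the changes above, existing callers get MPS for free on Apple silicon. A hedged end-to-end sketch (scanpy, the file name, and the "Channel" column are assumptions for illustration, not part of this PR):

import scanpy as sc
from harmony import harmonize

adata = sc.read_h5ad("pbmc.h5ad")  # hypothetical AnnData input
# use_gpu=True now selects CUDA if present, else MPS, else falls back to CPU.
Z = harmonize(adata.obsm["X_pca"], adata.obs, "Channel", use_gpu=True, verbose=True)
adata.obsm["X_harmony"] = Z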
18 changes: 11 additions & 7 deletions harmony/utils.py
@@ -2,24 +2,28 @@
 
 
 def get_batch_codes(batch_mat, batch_key):
-    if type(batch_key) is str or len(batch_key) == 1:
-        if not type(batch_key) is str:
-            batch_key = batch_key[0]
+    if type(batch_key) is str:
         batch_vec = batch_mat[batch_key]
 
+    elif len(batch_key) == 1:
+        batch_key = batch_key[0]
+
+        batch_vec = batch_mat[batch_key]
+
     else:
-        df = batch_mat[batch_key].astype('str')
-        batch_vec = df.apply(lambda row: ','.join(row), axis = 1)
+        df = batch_mat[batch_key].astype("str")
+        batch_vec = df.apply(lambda row: ",".join(row), axis=1)
 
     return batch_vec.astype("category")
 
 
 def one_hot_tensor(X, device_type):
-    ids = torch.as_tensor(X.cat.codes.values.copy(), dtype = torch.long, device = device_type).view(-1, 1)
+    ids = torch.as_tensor(
+        X.cat.codes.values.copy(), dtype=torch.long, device=device_type
+    ).view(-1, 1)
     n_row = X.size
     n_col = X.cat.categories.size
-    Phi = torch.zeros(n_row, n_col, dtype=torch.float, device = device_type)
+    Phi = torch.zeros(n_row, n_col, dtype=torch.float, device=device_type)
     Phi.scatter_(dim=1, index=ids, value=1.0)
 
     return Phi
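For reference, the reformatted one_hot_tensor builds the one-hot design matrix Phi with scatter_; the same pattern on a toy pandas categorical (data values are made up):

import pandas as pd
import torch

codes = pd.Series(["a", "b", "a", "c"], dtype="category")
ids = torch.as_tensor(codes.cat.codes.values.copy(), dtype=torch.long).view(-1, 1)
Phi = torch.zeros(codes.size, codes.cat.categories.size, dtype=torch.float)
Phi.scatter_(dim=1, index=ids, value=1.0)  # row i gets a 1 in the column of codes[i]
print(Phi)  # 4x3 one-hot matrix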
5 changes: 2 additions & 3 deletions setup.py
@@ -7,13 +7,12 @@
     long_description = f.read()
 
 requires = [
-    "torch",
+    "torch>=1.12",
     "numpy",
     "pandas",
     "psutil",
     "threadpoolctl",
-    "scikit-learn>=0.23",
-    "importlib_metadata>=0.7; python_version < '3.8'",
+    "scikit-learn>=0.23"
 ]
 
 setup(
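Pinning torch>=1.12 matches the first PyTorch release that shipped the MPS backend, so the new torch.backends.mps.is_available() probe is always defined. A quick sanity check one might run on an Apple-silicon machine (output is machine-dependent):

import torch

print(torch.__version__)
print(torch.backends.mps.is_built())      # was this wheel compiled with MPS support?
print(torch.backends.mps.is_available())  # True on Apple silicon with macOS 12.3+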