Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve device support and add support for Apple Silicon chipset (mps) #34

Merged
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
e3bd4fb
add to() method to move cebra models (sklearn API) from devices
gonlairo May 16, 2023
d3bbcd9
add EOF
gonlairo May 17, 2023
19ab4c6
better name of test
gonlairo May 17, 2023
922050d
add EOF
gonlairo May 17, 2023
2c34242
improve test with suggestions
gonlairo May 17, 2023
83ef70b
assign self.device_ if it exists only
gonlairo May 19, 2023
2a4ebf1
modify check_device() to allow GPU id specification
gonlairo May 19, 2023
4466064
adapt test given the possibility of specifying GPU ids
gonlairo May 19, 2023
ea0a67d
skip when only cpu is available
gonlairo Jun 7, 2023
e1e6848
fix docs
gonlairo Jun 20, 2023
f384103
delete inline comments
gonlairo Jul 4, 2023
78d7eb2
add support for mps device
gonlairo Jul 4, 2023
9f61bcf
fix indentation
gonlairo Jul 4, 2023
11c6acf
add mps to _set_device() in io
gonlairo Jul 4, 2023
c3e39d3
add to() method to move cebra models (sklearn API) from devices
gonlairo May 16, 2023
ad42c35
add EOF
gonlairo May 17, 2023
aaf7164
better name of test
gonlairo May 17, 2023
9cbce43
add EOF
gonlairo May 17, 2023
fc90a66
improve test with suggestions
gonlairo May 17, 2023
dca1ade
assign self.device_ if it exists only
gonlairo May 19, 2023
7da05ff
modify check_device() to allow GPU id specification
gonlairo May 19, 2023
a9be811
adapt test given the possibility of specifying GPU ids
gonlairo May 19, 2023
b50d75b
skip when only cpu is available
gonlairo Jun 7, 2023
b09b26b
fix docs
gonlairo Jun 20, 2023
ebd481b
delete inline comments
gonlairo Jul 4, 2023
893c0a4
add support for mps device
gonlairo Jul 4, 2023
e92f9bc
fix indentation
gonlairo Jul 4, 2023
42965ea
add mps to _set_device() in io
gonlairo Jul 4, 2023
ae66bad
Merge branch 'main' into rodrigo/move-model-to-device
MMathisLab Jul 6, 2023
044b52f
add mps logic when cuda_if_available + fix test for torch versions < …
gonlairo Jul 12, 2023
7638b0b
Merge branch 'rodrigo/move-model-to-device' of github.com:gonlairo/CE…
gonlairo Jul 13, 2023
9111e9d
fix docs in utils.py
gonlairo Jul 13, 2023
d026391
Merge branch 'main' into rodrigo/move-model-to-device
gonlairo Jul 13, 2023
afd1789
Run pre-commit
stes Jul 13, 2023
f294b58
add more tests
gonlairo Jul 14, 2023
9a9fc3d
Merge branch 'rodrigo/move-model-to-device' of github.com:gonlairo/CE…
gonlairo Jul 14, 2023
dfb0ccc
fix test when cuda is not available
gonlairo Jul 14, 2023
961b4a8
fix test when pytorch < 1.12
gonlairo Jul 14, 2023
d7729cd
Merge branch 'main' into rodrigo/move-model-to-device
gonlairo Jul 17, 2023
5e74c0b
Run pre-commit
stes Jul 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cebra/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import zipfile
from typing import List

import pkg_resources
import requests

import cebra.data
Expand Down Expand Up @@ -88,3 +89,14 @@ def download_file_from_zip_url(url, file="montblanc_tracks.h5"):
except zipfile.error:
pass
return pathlib.Path(foldername) / "data" / file


def _is_mps_availabe(torch):
available = False
if pkg_resources.parse_version(
torch.__version__) >= pkg_resources.parse_version("1.12"):
if torch.backends.mps.is_available():
if torch.backends.mps.is_built():
available = True

return available
49 changes: 49 additions & 0 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,3 +1256,52 @@ def load(cls,
raise RuntimeError("Model loaded from file is not compatible with "
"the current CEBRA version.")
return model

def to(self, device: Union[str, torch.device]):
"""Moves the cebra model to the specified device.

Args:
device: The device to move the cebra model to. This can be a string representing
the device ('cpu','cuda', cuda:device_id, or 'mps') or a torch.device object.

Returns:
The cebra model instance.

Example:

>>> import cebra
>>> import numpy as np
>>> dataset = np.random.uniform(0, 1, (1000, 30))
>>> cebra_model = cebra.CEBRA(max_iterations=10, device = "cuda_if_available")
>>> cebra_model.fit(dataset)
CEBRA(max_iterations=10)
>>> cebra_model = cebra_model.to("cpu")
"""

if not isinstance(device, (str, torch.device)):
raise TypeError(
"The 'device' parameter must be a string or torch.device object."
)

if (not device == 'cpu') and (not device.startswith('cuda')) and (
not device == 'mps'):
raise ValueError(
"The 'device' parameter must be a valid device string or device object."
)

if isinstance(device, str):
device = torch.device(device)

if (not device.type == 'cpu') and (
not device.type.startswith('cuda')) and (not device == 'mps'):
raise ValueError(
"The 'device' parameter must be a valid device string or device object."
)

if hasattr(self, "device_"):
self.device_ = device

self.device = device
self.solver_.model.to(device)

return self
42 changes: 39 additions & 3 deletions cebra/integrations/sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import sklearn.utils.validation as sklearn_utils_validation
import torch

import cebra.helper


def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
"""Handle deprecated arguments of a function until they are replaced.
Expand Down Expand Up @@ -114,16 +116,50 @@ def check_device(device: str) -> str:
device: The device to return, if possible.

Returns:
Either cuda or cpu depending on {device} and availability in the environment.
Either cuda, cuda:device_id, mps, or cpu depending on {device} and availability in the environment.
"""

if device == "cuda_if_available":
if torch.cuda.is_available():
return "cuda"
elif cebra.helper._is_mps_availabe(torch):
return "mps"
else:
return "cpu"
elif device in ["cuda", "cpu"]:
elif device.startswith("cuda:") and len(device) > 5:
cuda_device_id = device[5:]
if cuda_device_id.isdigit():
device_count = torch.cuda.device_count()
device_id = int(cuda_device_id)
if device_id < device_count:
return f"cuda:{device_id}"
else:
raise ValueError(
f"CUDA device {device_id} is not available. Available device IDs are 0 to {device_count - 1}."
)
else:
raise ValueError(
f"Invalid CUDA device ID format. Please use 'cuda:device_id' where '{cuda_device_id}' is an integer."
)
elif device == "cuda" and torch.cuda.is_available():
return "cuda:0"
elif device == "cpu":
return device
raise ValueError(f"Device needs to be cuda or cpu, but got {device}.")
elif device == "mps":
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
raise ValueError(
"MPS not available because the current PyTorch install was not "
"built with MPS enabled.")
else:
raise ValueError(
"MPS not available because the current MacOS version is not 12.3+ "
"and/or you do not have an MPS-enabled device on this machine."
)

return device

stes marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Device needs to be cuda, cpu or mps, but got {device}.")


def check_fitted(model: "cebra.models.Model") -> bool:
Expand Down
2 changes: 1 addition & 1 deletion cebra/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _set_device(self, device):
return
if not isinstance(device, str):
device = device.type
if device not in ("cpu", "cuda"):
if device not in ("cpu", "cuda", "mps"):
if device.startswith("cuda"):
_, id_ = device.split(":")
if int(id_) >= torch.cuda.device_count():
Expand Down
Loading