Skip to content

Commit

Permalink
Resolves #32: Auto-detect available device when cuda is not present (#34
Browse files Browse the repository at this point in the history
)

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>
  • Loading branch information
avsolatorio authored Jul 13, 2023
1 parent 6472ad3 commit dc757bd
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/realtabformer/realtabformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _normalize_gpt2_state_dict(state_dict):


def _validate_get_device(device: str) -> str:
if torch.cuda.device_count() == 0:
if (device == "cuda") and (torch.cuda.device_count() == 0):
if torch.backends.mps.is_available():
_device = "mps"
else:
Expand Down Expand Up @@ -1209,6 +1209,7 @@ def sample(
DataFrame with n_samples rows of generated data
"""
self._check_model()
device = _validate_get_device(device)

# Clear the cache
torch.cuda.empty_cache()
Expand Down Expand Up @@ -1341,6 +1342,7 @@ def predict(
self.model_type == ModelType.tabular
), "The predict method is only implemented for tabular data..."
self._check_model()
device = _validate_get_device(device)
batch = min(batch, data.shape[0])

# Clear the cache
Expand Down

0 comments on commit dc757bd

Please sign in to comment.