Skip to content

Commit

Permalink
Add support for device_map use (#264)
Browse files Browse the repository at this point in the history
* Add device_map support

* Fix device setter in HF model
  • Loading branch information
gsarti authored Apr 23, 2024
1 parent fa088c8 commit 343fc79
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 6 deletions.
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
author = "The Inseq Team"

# The short X.Y version
version = "0.6"
version = "0.7"
# The full version, including alpha/beta/rc tags
release = "0.6.0"
release = "0.7.0.dev0"


# Prefix link to point to master, comment this during version release and uncomment below line
Expand Down
1 change: 1 addition & 0 deletions inseq/models/attribution_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __init__(self, **kwargs) -> None:
self.pad_token: Optional[str] = None
self.embed_scale: Optional[float] = None
self._device: Optional[str] = None
self.device_map: Optional[dict[str, Union[str, int, torch.device]]] = None
self.attribution_method: Optional[FeatureAttribution] = None
self.is_hooked: bool = False
self._default_attributed_fn_id: str = "probability"
Expand Down
12 changes: 9 additions & 3 deletions inseq/models/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def __init__(
self.embed_scale = 1.0
self.encoder_int_embeds = None
self.decoder_int_embeds = None
self.device_map = None
if hasattr(self.model, "hf_device_map") and self.model.hf_device_map is not None:
self.device_map = self.model.hf_device_map
self.is_encoder_decoder = self.model.config.is_encoder_decoder
self.configure_embeddings_scale()
self.setup(device, attribution_method, **kwargs)
Expand Down Expand Up @@ -162,16 +165,19 @@ def device(self, new_device: str) -> None:
is_loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
is_loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
is_quantized = is_loaded_in_8bit or is_loaded_in_4bit
has_device_map = self.device_map is not None

# Enable compatibility with 8bit models
if self.model:
if not is_quantized:
self.model.to(self._device)
else:
if is_quantized:
mode = "8bit" if is_loaded_in_8bit else "4bit"
logger.warning(
f"The model is loaded in {mode} mode. The device cannot be changed after loading the model."
)
elif has_device_map:
logger.warning("The model is loaded with a device map. The device cannot be changed after loading.")
else:
self.model.to(self._device)

@abstractmethod
def configure_embeddings_scale(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions inseq/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from .hooks import StackFrame, get_post_variable_assignment_hook
from .import_utils import (
is_accelerate_available,
is_captum_available,
is_datasets_available,
is_ipywidgets_available,
Expand Down Expand Up @@ -130,4 +131,5 @@
"validate_indices",
"pad_with_nan",
"recursive_get_submodule",
"is_accelerate_available",
]
5 changes: 5 additions & 0 deletions inseq/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
_captum_available = find_spec("captum") is not None
_joblib_available = find_spec("joblib") is not None
_nltk_available = find_spec("nltk") is not None
_accelerate_available = find_spec("accelerate") is not None


def is_ipywidgets_available():
Expand Down Expand Up @@ -40,3 +41,7 @@ def is_joblib_available():

def is_nltk_available():
return _nltk_available


def is_accelerate_available():
return _accelerate_available
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "inseq"
version = "0.6.0"
version = "0.7.0.dev0"
description = "Interpretability for Sequence Generation Models 🔍"
readme = "README.md"
requires-python = ">=3.9"
Expand Down

0 comments on commit 343fc79

Please sign in to comment.