diff --git a/CHANGELOG.md b/CHANGELOG.md
index 25dbc19..4fc6e2e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,24 @@

 - Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment [#282](https://github.com/inseq-team/inseq/pull/282)

+- Added a `scores_precision` parameter to `FeatureAttributionOutput.save` to enable efficient saving in `float16` and `float8` formats, reducing the disk footprint of large attribution outputs. [#273](https://github.com/inseq-team/inseq/pull/273)
+
+```python
+import inseq
+
+attrib_model = inseq.load_model("gpt2", "attention")
+out = attrib_model.attribute("Hello world", generation_kwargs={'max_new_tokens': 100})
+
+# Previous usage, memory inefficient
+out.save("output.json")
+
+# Memory-efficient saving
+out.save("output_fp16.json", scores_precision="float16")  # or "float8"
+
+# Scores are automatically restored as float32 on load
+out_loaded = inseq.FeatureAttributionOutput.load("output_fp16.json")
+```
+
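To gauge the savings on disk, the same output can be saved at each precision and compared. A minimal sketch building on the snippet above (actual sizes depend on the model, attribution method, and sequence length):

```python
import os

import inseq

attrib_model = inseq.load_model("gpt2", "attention")
out = attrib_model.attribute("Hello world", generation_kwargs={"max_new_tokens": 100})

for precision in ("float32", "float16", "float8"):
    path = f"output_{precision}.json"
    # save() deepcopies internally, so repeated saves of the same output are safe
    out.save(path, scores_precision=precision, overwrite=True)
    print(precision, os.path.getsize(path), "bytes")
```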
+ """ + + if self.source_attributions is not None: + self.source_attributions = convert_to_safetensor( + self.source_attributions.contiguous(), scores_precision=scores_precision + ) + if self.target_attributions is not None: + self.target_attributions = convert_to_safetensor( + self.target_attributions.contiguous(), scores_precision=scores_precision + ) + if self.step_scores is not None: + self.step_scores = { + k: convert_to_safetensor(v.contiguous(), scores_precision=scores_precision) + for k, v in self.step_scores.items() + } + if self.sequence_scores is not None: + self.sequence_scores = { + k: convert_to_safetensor(v.contiguous(), scores_precision=scores_precision) + for k, v in self.sequence_scores.items() + } + return self + + def _recover_from_safetensors(self): + """ + Converts tensor attributes within the class from b64-encoded safetensors to torch tensors.`. + """ + if self.source_attributions is not None: + self.source_attributions = convert_from_safetensor(base64.b64decode(self.source_attributions)) + if self.target_attributions is not None: + self.target_attributions = convert_from_safetensor(base64.b64decode(self.target_attributions)) + if self.step_scores is not None: + self.step_scores = {k: convert_from_safetensor(base64.b64decode(v)) for k, v in self.step_scores.items()} + if self.sequence_scores is not None: + self.sequence_scores = { + k: convert_from_safetensor(base64.b64decode(v)) for k, v in self.sequence_scores.items() + } + return self + @staticmethod def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable: if attr.source_attributions is None or name.startswith("decoder"): @@ -562,6 +615,7 @@ def save( ndarray_compact: bool = True, use_primitives: bool = False, split_sequences: bool = False, + scores_precision: ScorePrecision = "float32", ) -> None: """Save class contents to a JSON file. @@ -583,22 +637,33 @@ def save( If True, the output is split into multiple files, one per sequence. The file names are generated by appending the sequence index to the given path (e.g. ``./out.json`` with two sequences -> ``./out_0.json``, ``./out_1.json``) + scores_precision (:obj:`str`, *optional*, defaults to "float32"): + Rounding precision for saved scores. Can be used to reduce space on disk but introduces rounding + errors. Can be combined with compress=True for further space reduction. + Accepted values: "float32", "float16", or "float8". Default: "float32" (no rounding). """ if not overwrite and Path(path).exists(): raise ValueError(f"{path} already exists. 
@@ -562,6 +615,7 @@ def save(
         ndarray_compact: bool = True,
         use_primitives: bool = False,
         split_sequences: bool = False,
+        scores_precision: ScorePrecision = "float32",
     ) -> None:
         """Save class contents to a JSON file.

@@ -583,22 +637,33 @@ def save(
                 If True, the output is split into multiple files, one per sequence. The file names are generated by
                 appending the sequence index to the given path (e.g. ``./out.json`` with two sequences ->
                 ``./out_0.json``, ``./out_1.json``)
+            scores_precision (:obj:`str`, *optional*, defaults to "float32"):
+                Rounding precision for saved scores. Reduces space on disk at the cost of rounding errors, and
+                can be combined with ``compress=True`` for further size reduction.
+                Accepted values: "float32" (no rounding), "float16" or "float8".
         """
         if not overwrite and Path(path).exists():
             raise ValueError(f"{path} already exists. Override with overwrite=True.")
         save_outs = []
         paths = []
         if split_sequences:
-            for i, seq in enumerate(self.sequence_attributions):
+            for seq_id in range(len(self.sequence_attributions)):
                 attr_out = deepcopy(self)
-                attr_out.sequence_attributions = [seq]
+                attr_out.sequence_attributions = [
+                    attr_out.sequence_attributions[seq_id]._convert_to_safetensors(scores_precision=scores_precision)
+                ]  # the conversion modifies the copied sequence in place
                 attr_out.step_attributions = None
-                attr_out.info["input_texts"] = [attr_out.info["input_texts"][i]]
-                attr_out.info["generated_texts"] = [attr_out.info["generated_texts"][i]]
+                attr_out.info["input_texts"] = [attr_out.info["input_texts"][seq_id]]
+                attr_out.info["generated_texts"] = [attr_out.info["generated_texts"][seq_id]]
                 save_outs.append(attr_out)
-                paths.append(f"{str(path).split('.json')[0]}_{i}.json{'.gz' if compress else ''}")
+                paths.append(f"{str(path).split('.json')[0]}_{seq_id}.json{'.gz' if compress else ''}")
         else:
-            save_outs.append(self)
+            self_out = deepcopy(self)
+            self_out.sequence_attributions = [
+                seq._convert_to_safetensors(scores_precision=scores_precision)
+                for seq in self_out.sequence_attributions
+            ]
+            save_outs.append(self_out)
             paths.append(path)
         for attr_out, path_out in zip(save_outs, paths):
             with open(path_out, f"w{'b' if compress else ''}") as f:
@@ -631,9 +696,9 @@ def load(
             :class:`~inseq.data.FeatureAttributionOutput`: Loaded attribution output
         """
         out = json_advanced_load(path, decompression=decompress)
-        out.sequence_attributions = [seq.torch() for seq in out.sequence_attributions]
+        out.sequence_attributions = [seq._recover_from_safetensors() for seq in out.sequence_attributions]
         if out.step_attributions is not None:
-            out.step_attributions = [step.torch() for step in out.step_attributions]
+            out.step_attributions = [step._recover_from_safetensors() for step in out.step_attributions]
         return out

     def aggregate(
diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py
index c7512bd..7471265 100644
--- a/inseq/utils/__init__.py
+++ b/inseq/utils/__init__.py
@@ -50,6 +50,8 @@
 from .torch_utils import (
     aggregate_contiguous,
     check_device,
+    convert_from_safetensor,
+    convert_to_safetensor,
     euclidean_distance,
     filter_logits,
     find_block_stack,
@@ -71,6 +73,8 @@
     "UnknownAttributionMethodError",
     "MissingAlignmentsError",
     "cache_results",
+    "convert_to_safetensor",
+    "convert_from_safetensor",
     "optional",
     "pad",
     "pretty_list",
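The new parameter composes with `split_sequences`: each per-sequence copy is converted independently and written to the `_<idx>.json` paths described in the docstring. A minimal sketch (file names follow the documented scheme):

```python
import inseq

model = inseq.load_model("gpt2", "attention")
out = model.attribute(["first input", "second input"])

# Writes multi_0.json and multi_1.json, one sequence each, with float16 scores
out.save("multi.json", split_sequences=True, scores_precision="float16")

# Each file loads as a standalone output; scores come back as float32
first = inseq.FeatureAttributionOutput.load("multi_0.json")
```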
diff --git a/inseq/utils/serialization.py b/inseq/utils/serialization.py
index b45f858..f796619 100644
--- a/inseq/utils/serialization.py
+++ b/inseq/utils/serialization.py
@@ -29,6 +29,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+import base64
 import json
 from collections import OrderedDict
 from json import JSONEncoder
@@ -59,6 +60,8 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
     """
     if isinstance(obj, (list, dict)):
         return obj
+    if isinstance(obj, bytes):
+        return base64.b64encode(obj).decode("UTF8")
     if hasattr(obj, "__class__") and hasattr(obj, "__dict__"):
         if not hasattr(obj, "__new__"):
             raise TypeError(f"class '{obj.__class__}' does not have a __new__ method; ")
@@ -84,9 +87,7 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
             dct["attributes"] = hashodict(obj.__dict__)
         if use_primitives:
             attrs = dct.get("attributes", {})
-            return attrs
-        else:
-            return dct
+        return attrs if use_primitives else dct
     return obj
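The new `bytes` branch is the standard base64-in-JSON pattern. A standalone sketch of the round trip (plain `json` is used here for illustration in place of the library's own encoder):

```python
import base64
import json

payload = b"\x00\x01 raw safetensors bytes"

# Serialization: what class_instance_encode now does for bytes values
blob = json.dumps({"scores": base64.b64encode(payload).decode("UTF8")})

# Deserialization: the b64decode step performed in _recover_from_safetensors
restored = base64.b64decode(json.loads(blob)["scores"])
assert restored == payload
```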
diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py
index be790db..d3c13b9 100644
--- a/inseq/utils/torch_utils.py
+++ b/inseq/utils/torch_utils.py
@@ -4,6 +4,7 @@
 from inspect import signature
 from typing import TYPE_CHECKING, Callable, Literal, Optional, Union

+import safetensors.torch
 import torch
 import torch.nn.functional as F
 from jaxtyping import Int, Num
@@ -40,6 +41,38 @@ def remap_from_filtered(
     return new_source.scatter(0, index, filtered)


+def convert_to_safetensor(tensor: torch.Tensor, scores_precision: str = "float32") -> bytes:
+    """
+    Converts a torch tensor to a safetensors byte buffer.
+
+    Args:
+        tensor (torch.Tensor): The tensor to serialize.
+        scores_precision (str): Precision to which weights are converted: "float32", "float16" or "float8".
+
+    Returns:
+        bytes: The serialized tensor in safetensors format.
+
+    Raises:
+        ValueError: If ``scores_precision`` is not one of the accepted options.
+    """
+    if scores_precision == "float32":
+        return safetensors.torch.save({"attribution": tensor})
+    elif scores_precision == "float16":
+        return safetensors.torch.save({"attribution": tensor.to(torch.float16)})
+    elif scores_precision == "float8":
+        logger.warning("Float8 precision is experimental and may result in loss of precision.")
+        return safetensors.torch.save({"attribution": tensor.to(torch.float8_e4m3fn)})
+    else:
+        raise ValueError("`scores_precision` has to be one of [float32, float16, float8]")
+
+
+def convert_from_safetensor(safetensor: bytes) -> torch.Tensor:
+    """
+    Converts a safetensors byte buffer back to a torch tensor, casting weights to float32.
+
+    Adapted from https://huggingface.co/docs/safetensors/metadata_parsing
+    """
+    return safetensors.torch.load(safetensor)["attribution"].to(torch.float32)
+
+
 def postprocess_attribution_scores(func: Callable) -> Callable:
     @wraps(func)
     def postprocess_scores_wrapper(
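Since both helpers are exported from `inseq.utils` (see the `__init__.py` changes above), the tensor-level round trip can be exercised in isolation. A minimal sketch; the tolerance is sized for float16 rounding of values in [0, 1):

```python
import torch

from inseq.utils import convert_from_safetensor, convert_to_safetensor

scores = torch.rand(5, 5)

buf = convert_to_safetensor(scores, scores_precision="float16")  # serialized bytes
restored = convert_from_safetensor(buf)  # always cast back to float32

assert restored.dtype == torch.float32
assert torch.allclose(scores, restored, atol=1e-03)
```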
diff --git a/inseq/utils/typing.py b/inseq/utils/typing.py
index 4eec4a5..4673c07 100644
--- a/inseq/utils/typing.py
+++ b/inseq/utils/typing.py
@@ -1,6 +1,6 @@
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import TYPE_CHECKING, Callable, Literal, Optional, Union

 import torch
 from captum.attr._utils.attribution import Attribution
@@ -71,6 +71,8 @@ class TextSequences:
 OneOrMoreTokenWithIdSequences = Sequence[Sequence[TokenWithId]]
 OneOrMoreAttributionSequences = Sequence[Sequence[float]]

+ScorePrecision = Literal["float32", "float16", "float8"]
+
 IndexSpan = Union[tuple[int, int], Sequence[tuple[int, int]]]
 OneOrMoreIndices = Union[int, list[int], tuple[int, int]]
 OneOrMoreIndicesDict = dict[int, OneOrMoreIndices]
diff --git a/pyproject.toml b/pyproject.toml
index 83d663c..fcfd8af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,7 @@ dependencies = [
     "numpy>=1.21.6",
     "jaxtyping>=0.2.25",
     "typeguard<=2.13.3",
-    "torch>=2.1.1",
+    "torch>=2.0",
     "matplotlib>=3.5.3",
     "tqdm>=4.64.0",
     "nvidia-cublas-cu11>=11.10.3.66; sys_platform=='Linux'",
@@ -84,7 +84,7 @@ lint = [
     "ruff>=0.2.0"
 ]
 sklearn = [
-    "scikit-learn>=1.4.0",
+    "scikit-learn>=1.5.1",
     "joblib>=1.3.2"
 ]
 datasets = [
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 4020d11..4374e00 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -14,7 +14,7 @@ asttokens==2.4.1
     # via stack-data
 attrs==23.2.0
     # via aiohttp
-authlib==1.3.0
+authlib==1.3.1
     # via safety
 babel==2.14.0
     # via sphinx
@@ -46,7 +46,7 @@ contourpy==1.2.0
     # via matplotlib
 coverage==7.4.1
     # via pytest-cov
-cryptography==42.0.5
+cryptography==43.0.0
     # via authlib
 cycler==0.12.1
     # via matplotlib
@@ -311,11 +311,11 @@ safety==3.1.0
     # via inseq (pyproject.toml)
 safety-schemas==0.0.2
     # via safety
-scikit-learn==1.4.0
+scikit-learn==1.5.1
     # via inseq (pyproject.toml)
 scipy==1.12.0
     # via scikit-learn
-sentencepiece==0.1.99
+sentencepiece==0.2.0
     # via transformers
 setuptools==69.1.0
     # via
@@ -375,7 +375,7 @@ threadpoolctl==3.2.0
     # via scikit-learn
 tokenizers==0.15.2
     # via transformers
-torch==2.2.0
+torch==2.3.1
     # via
     #   inseq (pyproject.toml)
     #   captum
diff --git a/requirements.txt b/requirements.txt
index 9c02174..27e3e89 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -86,7 +86,7 @@ rich==13.7.0
     # via inseq (pyproject.toml)
 safetensors==0.4.2
     # via transformers
-sentencepiece==0.1.99
+sentencepiece==0.2.0
     # via transformers
 six==1.16.0
     # via python-dateutil
@@ -94,7 +94,7 @@ sympy==1.12
     # via torch
 tokenizers==0.15.2
     # via transformers
-torch==2.2.0
+torch==2.3.1
     # via
     #   inseq (pyproject.toml)
     #   captum
diff --git a/tests/data/test_attribution.py b/tests/data/test_attribution.py
index 6b08150..a8ccb8c 100644
--- a/tests/data/test_attribution.py
+++ b/tests/data/test_attribution.py
@@ -26,10 +26,10 @@ def test_save_load_attribution_split(tmp_path, saliency_mt_model):
     out_path = tmp_path / "tmp_attr.json"
     out = saliency_mt_model.attribute(["This is a test.", "sequence number two"], device="cpu", show_progress=False)
     out.save(out_path, split_sequences=True)
-    out_path_1 = tmp_path / "tmp_attr_1.json"
+    out_path_1 = tmp_path / "tmp_attr_0.json"
     loaded_out = FeatureAttributionOutput.load(out_path_1)
     assert torch.allclose(
-        out.sequence_attributions[1].source_attributions, loaded_out.sequence_attributions[0].source_attributions
+        out.sequence_attributions[0].source_attributions, loaded_out.sequence_attributions[0].source_attributions
     )


@@ -41,6 +41,30 @@ def test_save_load_attribution_compressed(tmp_path, saliency_mt_model):
     assert out == loaded_out


+def test_save_load_attribution_float16(tmp_path, saliency_mt_model):
+    out_path = tmp_path / "tmp_attr_fp16.json.gz"
+    out = saliency_mt_model.attribute("This is a test.", device="cpu", show_progress=False)
+    out.save(out_path, compress=True, scores_precision="float16")
+    loaded_out = FeatureAttributionOutput.load(out_path, decompress=True)
+    assert torch.allclose(
+        out.sequence_attributions[0].source_attributions,
+        loaded_out.sequence_attributions[0].source_attributions,
+        atol=1e-05,
+    )
+
+
+def test_save_load_attribution_float8(tmp_path, saliency_mt_model):
+    out_path = tmp_path / "tmp_attr_fp8.json.gz"
+    out = saliency_mt_model.attribute("This is a test.", device="cpu", show_progress=False)
+    out.save(out_path, compress=True, scores_precision="float8")
+    loaded_out = FeatureAttributionOutput.load(out_path, decompress=True)
+    assert torch.allclose(
+        out.sequence_attributions[0].source_attributions,
+        loaded_out.sequence_attributions[0].source_attributions,
+        atol=1e-02,
+    )
+
+
 def test_get_scores_dicts_encoder_decoder(saliency_mt_model):
     out = saliency_mt_model.attribute(["This is a test.", "Hello world!"], device="cpu", show_progress=False)
     dicts = out.get_scores_dicts()